diff --git a/.gitignore b/.gitignore index be75938ec401b1d72fa54773c85191aaac7d7f35..828bbe9bd3363853ae3f58f54a8d5f60cefad837 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,7 @@ Podfile.lock /tensorflow/contrib/lite/examples/ios/simple/data/*.txt /tensorflow/contrib/lite/examples/ios/simple/data/*.tflite xcuserdata/** +/api_init_files_list.txt # Android .gradle diff --git a/CODEOWNERS b/CODEOWNERS index 007a304c3e706ce968576ec8979c08f1a3bcc552..b9f0313cc6d59d3fbdcd014e1a528126d863075a 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -45,7 +45,7 @@ # /tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh # /tensorflow/contrib/slim/ @sguada @thenbasilmanran # /tensorflow/contrib/stateless/ @girving -# /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst +# /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank # /tensorflow/contrib/testing/ @dandelionmane # /tensorflow/contrib/timeseries/ @allenlavoie # /tensorflow/contrib/tpu/ @frankchn @saeta @jhseu diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3dad41a88c8212b7445c32f241d887306d3c19ad..db4b1581ae671b1e676e215c9a80dfaab832fa21 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,5 +1,16 @@ # Contributing guidelines +## Pull Request Checklist + +Before sending your pull requests, make sure you followed this list. + +- Read [contributing guidelines](CONTRIBUTING.md). +- Read [Code of Conduct](CODE_OF_CONDUCT.md). +- Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/). +- Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). +- Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style). +- Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). + ## How to become a contributor and submit your own code ### Contributor License Agreements @@ -79,7 +90,7 @@ Bazel BUILD files also need to include a license section, e.g., Changes to TensorFlow C++ code should conform to [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). -Use `clang-tidy` to check your C/C++ changes. To install clang-tidy on ubuntu:16.04, do: +Use `clang-tidy` to check your C/C++ changes. To install `clang-tidy` on ubuntu:16.04, do: ```bash apt-get install -y clang-tidy diff --git a/README.md b/README.md index ef5bdc66ef03131318e1dde627e0224cca9137fd..63853137cfd30b396f8c7d204811f3e4a1794c07 100644 --- a/README.md +++ b/README.md @@ -5,16 +5,16 @@ ----------------- -| **`Documentation`** | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** | -|-----------------|---------------------|------------------|-------------------|---------------|---------------| -| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) +| **`Documentation`** | +|-----------------| +| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | **TensorFlow** is an open source software library for numerical computation using data flow graphs. The graph nodes represent mathematical operations, while the graph edges represent the multidimensional data arrays (tensors) that flow -between them. This flexible architecture lets you deploy computation to one +between them. This flexible architecture enables you to deploy computation to one or more CPUs or GPUs in a desktop, server, or mobile device without rewriting -code. TensorFlow also includes TensorBoard, a data visualization toolkit. +code. TensorFlow also includes [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard), a data visualization toolkit. TensorFlow was originally developed by researchers and engineers working on the Google Brain team within Google's Machine Intelligence Research @@ -22,6 +22,10 @@ organization for the purposes of conducting machine learning and deep neural networks research. The system is general enough to be applicable in a wide variety of other domains, as well. +Keep up to date with release announcements and security updates by +subscribing to +[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce). + ## Installation *See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.* @@ -36,15 +40,6 @@ environment to install the nightly TensorFlow build. We support CPU and GPU packages on Linux, Mac, and Windows. -**Individual whl files** -* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/)) -* Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/)) -* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/)) -* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/)) -* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/)) -* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) -([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)) - #### *Try your first TensorFlow program* ```shell $ python @@ -61,6 +56,7 @@ $ python 42 >>> sess.close() ``` +Learn more examples about how to do specific tasks in TensorFlow at the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/). ## Contribution guidelines @@ -78,10 +74,35 @@ The TensorFlow project strives to abide by generally accepted best practices in [![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486) + +## Continuous build status + +### Official Builds + +| Build Type | Status | Artifacts | +| --- | --- | --- | +| **Linux CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Linux GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Linux XLA** | TBA | TBA | +| **MacOS** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Windows CPU** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Windows GPU** | [![Status](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/badge/icon)](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Android** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) [build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/) | + + +### Community Supported Builds + +| Build Type | Status | Artifacts | +| --- | --- | --- | +| **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | +| **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | + + ## For more information * [TensorFlow Website](https://www.tensorflow.org) * [TensorFlow White Papers](https://www.tensorflow.org/about/bib) +* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) * [TensorFlow Model Zoo](https://github.com/tensorflow/models) * [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) * [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) diff --git a/RELEASE.md b/RELEASE.md index 6f54dee58f75c29a16545ba25de12fe059baf1eb..e09e9c6190f57adec67c2ae1d85848dabfd9c2a7 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,192 @@ +# Release 1.9.0 + +## Major Features And Improvements +* Update tf.keras to the Keras 2.1.6 API. +* `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`. +* Adding support of core feature columns and losses to gradient boosted trees estimators. +* The distributions.Bijector API supports broadcasting for Bijectors with new API changes. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/distributions/bijectors/Bijector) for more details. +* Layered variable names have changed in the following conditions: + * Using `tf.keras.layers` with custom variable scopes. + * Using `tf.layers` in a subclassed `tf.keras.Model` class. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details + +## Breaking Chances + * If you're opening empty variable scopes; replace `variable_scope`('', ...) by `variable_scope`(`tf.get_variable_scope()`, ...). + +## Bug Fixes and Other Changes +* `tf.data`: + * The `DatasetBase::DebugString()` method is now `const`. + * Added the `tf.contrib.data.sample_from_datasets()` API for randomly sampling from multiple datasets. +* Eager Execution: +* `tf.keras`: + * Move Keras code out of _impl folder and remove API files. + * `tf.keras.Model.save_weights` now saves in TensorFlow format by default. + * Enable dataset iterators to be passed to `tf.keras.Model` training/eval methods. +* Accelerated Linear Algebra (XLA): +* TensorFlow Debugger (tfdbg): fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB). +* `tf.contrib`: + * Add `tf.contrib.data.choose_from_datasets()`. + * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings. Two arguments were removed from `make_csv_dataset`. + * `tf.contrib.framework.zero_initializer` supports ResourceVariable. + * Adding "constrained_optimization" to tensorflow/contrib. +* Other: + * Add GCS Configuration Ops. + * Changing signature of `MakeIterator` to enable propagating error status. + * KL divergence for two Dirichlet distributions. + * More consistent GcsFileSystem behavior for certain reads past EOF. + * Update benchmark for tf.scan to match ranges across eager and graph modes. + * Fixed bug in `tf.reduce_prod gradient` for complex dtypes. + * Add optional `args` argument to `Dataset.from_generator()`. + * Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)"). + * Benchmark for tf.scan in graph and eager modes. + * Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D. + * Making ids unique in `nn.embedding_lookup_sparse`. This helps to reduce RPC calls for looking up the embeddings when there are repeated ids in the batch. + * Support indicator column in boosted trees. + * Prevent `tf.gradients()` from backpropagating through integer tensors. + * LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`. + * Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now supports arbitrary. + * Added `tf.train.Checkpoint` for reading/writing object-based checkpoints. + * `Dataset.list_files()` now produces determinstic results when `shuffle=False` or a `seed` is passed. + * Added LinearOperatorKronecker, a dense-free implementation of the Kronecker Product. + * Allow LinearOperator to broadcast. + * SavedModelBuilder will now deduplicate asset names that point to files with the same basename and the same contents. Note that this may result in new asset files included in SavedModels in cases where assets with the same name but different contents were previously overwriting each other. + + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Abdullah Alrasheed, Achal Shah, Ad-530, ADiegoCAlonso, Aditya Yogi, Ag Ramesh, akindyakov, Andy Kernahan, Anya Petrova, Aurelien Geron, Ben, Ben Barsdell, Bhavani-Subramanian, braincodercn, Brett Koonce, Brian Nemsick, Brian Zier, Bryan Heden, candy.dc, cclauss, Clayne Robison, ctiijima, Dalmo Cirne, David Norman, David T.H. Kao, DosLin, ekelsen, Elson Rodriguez, Erik Smistad, Felix Abecassis, Fergal Cotter, fo40225, foo0x29a, Freedom" Koan-Sin Tan, FréDéRic Branchaud-Charron, gdh1995, Geoffrey Irving, Giuseppe, gracehoney, Guido Zuidhof, Guillaume Klein, Guozhong Zhuang, Haggai, Harald Husum, imsheridan, Ivan Zhang, Jan Zikes, Jayaram Bobba, Jesse Benson, Jesse Gumz, Jiajia Li, Jie, jinghuangintel, Jingwen, jjsjann123, Joe Yearsley, Joel Hestness, Joel Shor, josephyearsley, Junpeng Lao, Karol M. Langner, Kb Sriram, krantideep95, Krish Ravindranath, Letian Feng, Loo Rong Jie, Lukas Geiger, Maciej, Mahmoud Abuzaina, ManHyuk, Mark Ryan, mbhuiyan, Michal Turek, Mostafa Alaa, Myungsung Kwak, Nand Dalal, Nehal J Wani, Neil Tenenholtz, ngc92, Nicholas Nadeau, P.Eng., Avs, Niranjan Hasabnis, P-Hidringer, Paul Van Eck, Peng Yu, Qing Zhao, Qingying Chen, Quanlong, Rajendra Arora, Rholais Lii, rmanyari, Robin Richtsfeld, Russell Klopfer, Sagi, Sam Sendelbach, Sandeep N Gupta, Sandip Giri, Sarah Edkins, Scott Tseng, Sdalbsoo, Sergii Khomenko, Seungwoo Choi (Biggie), Seyed Majid Azimi, Shaoning Zeng, shengfuintel, Siu Kei, Muk, Smit Shilu, soonson, Stefan Schweter, Sukhwan Kim, Sunitha Kambhampati, Taehoon Lee, tamimaddari82, Tang, Wenyi, Ted Chang, u2takey, Utkarsh Upadhyay, Vadim Markovtsev, voegtlel, Wai Hon Law, wangsiyu, Wenhao Hu, wenhao.hu, William D. Irons, Yan Facai (颜发才), Yanbo Liang, Yihong Wang, Yilei (Dolee) Yang, Yong Tang, Yuan (Terry) Tang + +# Release 1.8.0 + +## Major Features And Improvements +* Can now pass `tf.contrib.distribute.MirroredStrategy()` to `tf.estimator.RunConfig()` to run an Estimator model on multiple GPUs on one machine. +* Add `tf.contrib.data.prefetch_to_device()`, which supports prefetching to GPU memory. +* Added Gradient Boosted Trees as pre-made Estimators: BoostedTreesClassifier, BoostedTreesRegressor. +* Add 3rd generation pipeline config for Cloud TPUs which improves performance and usability. +* `tf.contrib.bayesflow` is moving out to it's own repo. +* Added `tf.contrib.{proto,rpc}` to allow generic proto parsing and RPC communication[1](#rpc-issue). + +## Bug Fixes and Other Changes +* `tf.data`: + * Add `tf.contrib.data.prefetch_to_device`, which enables prefetching dataset elements to GPU memory. + * Add `tf.contrib.data.AUTOTUNE`, which allows the tf.data runtime to automatically tune the prefetch buffer sizes based on your system and environment. + * Add `tf.contrib.data.make_csv_dataset` for building datasets of CSV files. +* Eager Execution: + * With eager execution Datasets can now be used as standard python iterators (`for batch in dataset:`). Both `Dataset.__iter__()` and `Dataset.make_one_shot_iterator()` can now be used to create iterators when eager execution is enabled. + * Automatic device placement has been enabled (i.e., use a GPU if available automatically, without requiring an explicit `with tf.device(“/gpu:0”)`) (Fixes #14133) + * `tf.GradientTape` has moved out of contrib. +* `tf.keras`: + * Added the fashion mnist dataset. + * New data preprocessing functions: `image/random_brightness`, `sequence/TimeseriesGenerator`, and `text/hashing_trick`. +* Accelerated Linear Algebra (XLA): + * Select and scatter in reference util and evaluator now use lexicographical order to break ties. +* TensorFlow Debugger (tfdbg) CLI: + * During tensor-filter operations, allow exclusion of nodes by regular expressions. + * Fix spurious background colors in some text terminals. +* `tf.contrib`: + * Add meta-distribution BatchReshape which reshapes batch dimensions. + * `tf.contrib.layers.recompute_grad` works for explicit gradient checkpointing on TPU. + * Add `tf.contrib.framework.argsort`. + * Allow `DNNBoostedTreeCombinedEstimator` to work with core versions of feature columns and losses. + * Add non-linear image warping ops: `tf.contrib.image.sparse_image_warp`, `tf.contrib.image.dense_image_warp`, and `tf.contrib.image.interpolate_spline`. + * Fix bug in `tf.contrib.opt.MultitaskOptimizerWrapper` where types of tensors were mismatched. +* Other: + * Low-level graph construction now calls the TensorFlow C API. This change should be invisible to most users, but can be disabled by setting the environment variable `TF_C_API_GRAPH_CONSTRUCTION=0` in this release. Future releases will remove the ability to disable this change. Please [file a bug](https://github.com/tensorflow/tensorflow/issues/new) if you find yourself using this escape hatch. + * Add description of shapes and a pointer to tutorial notebook in `tf.distributions.Distribution`. + * Update scatter operations: + * Add `tf.scatter_min` and `tf.scatter_max` + * Extend scatter operations to work with a scalar update parameter. + * Move cuDNN RNN ops to core for use in TensorFlow codebase only. + * Add `float64` support for `Conv2d`, `Conv2dBackpropInput`, and `Conv2dBackpropFilter`. + * Add `float64` support for `AvgPool`/`AvgPoolGrad`. + * Make graph name scope thread local so that they work correctly in multi-threaded environments. + * Update nsync synchronization library to avoid slow primitives on Linux. + * Removed need to put nsync/public on C include path when building custom ops. + * Add `tf.image.psnr`, `tf.image.ssim`, `tf.image.ssim_multiscale`, `tf.image.image_gradients`, `tf.image.sobel_edges`. + * Add links to https://js.tensorflow.org. + * Fix non-uniformity of orthogonal matrices. + * Fix bug where multi-image Estimator eval summaries were not displayed correctly. + +1 The cancellation logic of the RPC op contains a concurrency error. A fix has been submitted to master and will be part of the next release. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +4d55397500, Aghasy, Alan Du, Alan Lee, Alan Yee, Alex Wiltschko, Animesh Karnewar, Ankit Gupta, Anton Matosov, Aris L, Ben Barsdell, Brent Yi, Brett Koonce, Carl Thomé, cbockman, Chikanaga Tomoyuki, Chris Tava, CéDric Deltheil, Dahan Gong, Dalmo Cirne, Daniel Erenrich, David Norman, DavidNorman, Edd Wilder-James, Fanjin Zeng, Felix Abecassis, fo40225, George Sterpu, Giovanni Terlingen, Gor Baghdasaryan, Guillaume Klein, Hanchen Li, Ilya Polenov, Jakub Kolodziejczyk, Jason Sadler, Jayaram Bobba, Jerry Liu, jinghuangintel, Jiongyan Zhang (张炯衍), Joel Shor, Jong Wook Kim, Julian Eisenschlos, Karl Lessard, Krish Ravindranath, Loo Rong Jie, Lukas Geiger, Luke Iwanski, Mahmoud Abuzaina, ManHyuk, Marvin Richter, Maximilian Mitchell, Mohammad Ashraf Bhuiyan, msofka, Mustafa Kasap, Nathan Burnham, Nathan Luehr, Naveen Marri, ngc92, nio1814, Oleg Zabluda, Ou Changkun, Panos Ipeirotis, Paul Van Eck, Peter Lee, Piotr Czapla, qjivy, Rholais Lii, Rodrigo Formigone, Russell Klopfer, ryantimjohn, Sang Han, SebastiáN RamíRez, shengfuintel, Siby Jose Plathottam, Silver Chan, Stanislaw Antol, Taehoon Lee, Tarang Chugh, Ted Chang, Thomas Bastiani, Xian Xu, Xiaoming (Jason) Cui, Yan Facai (颜发才), yaox12, Yashal Shakti Kanungo, Yong Tang, Yuan (Terry) Tang, Yuxin Wu, Ziyue(Louis) Lu + +# Release 1.7.0 + +## Major Features And Improvements +* Eager mode is moving out of contrib, try `tf.enable_eager_execution()`. +* Graph rewrites emulating fixed-point quantization compatible with TensorFlow Lite, supported by new `tf.contrib.quantize` package. +* Easily customize gradient computation with `tf.custom_gradient`. +* [TensorBoard Debugger Plugin](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md), the graphical user interface (GUI) of TensorFlow Debugger (tfdbg), is now in alpha. +* Experimental support for reading a sqlite database as a `Dataset` with new `tf.contrib.data.SqlDataset`. +* Distributed Mutex / CriticalSection added to `tf.contrib.framework.CriticalSection`. +* Better text processing with `tf.regex_replace`. +* Easy, efficient sequence input with `tf.contrib.data.bucket_by_sequence_length` +* Initial support for `tf.contrib.tensorrt` that enables native TensorRT in + TensorFlow. + +## Bug Fixes and Other Changes +* Accelerated Linear Algebra (XLA): + * Add `MaxPoolGradGrad` support for XLA + * CSE pass from Tensorflow is now disabled in XLA. +* `tf.data`: + * `tf.data.Dataset` + * Add support for building C++ Dataset op kernels as external libraries, using the `tf.load_op_library()` mechanism. + * `Dataset.list_files()` now shuffles its output by default. + * `Dataset.shuffle(..., seed=tf.constant(0, dtype=tf.int64))` now yields the same sequence of elements as `Dataset.shuffle(..., seed=0)`. + * Add `num_parallel_reads` argument to `tf.data.TFRecordDataset`. +* `tf.contrib`: + * `tf.contrib.bayesflow.halton_sequence` now supports randomization. + * Add support for scalars in `tf.contrib.all_reduce`. + * Add `effective_sample_size` to `tf.contrib.bayesflow.mcmc_diagnostics`. + * Add `potential_scale_reduction` to `tf.contrib.bayesflow.mcmc_diagnostics`. + * Add `BatchNormalization`, `Kumaraswamy` bijectors. + * Deprecate `tf.contrib.learn`. Please check contrib/learn/README.md for instructions on how to convert existing code. + * `tf.contrib.data` + * Remove deprecated `tf.contrib.data.Dataset`, `tf.contrib.data.Iterator`, `tf.contrib.data.FixedLengthRecordDataset`, `tf.contrib.data.TextLineDataset`, and `tf.contrib.data.TFRecordDataset` classes. + * Added `bucket_by_sequence_length`, `sliding_window_batch`, and `make_batched_features_dataset` + * Remove unmaintained `tf.contrib.ndlstm`. You can find it externally at https://github.com/tmbarchive/tfndlstm. + * Moved most of `tf.contrib.bayesflow` to its own repo: `tfp` +* Other: + * tf.py_func now reports the full stack trace if an exception occurs. + * Integrate `TPUClusterResolver` with GKE's integration for Cloud TPUs. + * Add a library for statistical testing of samplers. + * Add Helpers to stream data from the GCE VM to a Cloud TPU. + * Integrate ClusterResolvers with TPUEstimator. + * Unify metropolis_hastings interface with HMC kernel. + * Move LIBXSMM convolutions to a separate --define flag so that they are disabled by default. + * Fix `MomentumOptimizer` lambda. + * Reduce `tfp.layers` boilerplate via programmable docstrings. + * Add `auc_with_confidence_intervals`, a method for computing the AUC and confidence interval with linearithmic time complexity. + * `regression_head` now accepts customized link function, to satisfy the usage that user can define their own link function if the `array_ops.identity` does not meet the requirement. + * Fix `initialized_value` and `initial_value` behaviors for `ResourceVariables` created from `VariableDef` protos. + * Add TensorSpec to represent the specification of Tensors. + * Constant folding pass is now deterministic. + * Support `float16` `dtype` in `tf.linalg.*`. + * Add `tf.estimator.export.TensorServingInputReceiver` that allows `tf.estimator.Estimator.export_savedmodel` to pass raw tensors to model functions. + +## Deprecations + +* TensorFlow 1.7 may be the last time we support Cuda versions below 8.0. + Starting with TensorFlow 1.8 release, 8.0 will be the minimum supported + version. +* TensorFlow 1.7 may be the last time we support cuDNN versions below 6.0. + Starting with TensorFlow 1.8 release, 6.0 will be the minimum supported + version. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +4d55397500, Abe, Alistair Low, Andy Kernahan, Appledore, Ben, Ben Barsdell, Boris Pfahringer, Brad Wannow, Brett Koonce, Carl Thomé, cclauss, Chengzhi Chen, Chris Drake, Christopher Yeh, Clayne Robison, Codrut Grosu, Daniel Trebbien, Danny Goodman, David Goodwin, David Norman, Deron Eriksson, Donggeon Lim, Donny Viszneki, DosLin, DylanDmitri, Francisco Guerrero, Fred Reiss, gdh1995, Giuseppe, Glenn Weidner, gracehoney, Guozhong Zhuang, Haichen "Hc" Li, Harald Husum, harumitsu.nobuta, Henry Spivey, hsm207, Jekyll Song, Jerome, Jiongyan Zhang, jjsjann123, John Sungjin Park, Johnson145, JoshVarty, Julian Wolff, Jun Wang, June-One, Kamil Sindi, Kb Sriram, Kdavis-Mozilla, Kenji, lazypanda1, Liang-Chi Hsieh, Loo Rong Jie, Mahesh Bhosale, MandarJKulkarni, ManHyuk, Marcus Ong, Marshal Hayes, Martin Pool, matthieudelaro, mdfaijul, mholzel, Michael Zhou, Ming Li, Minmin Sun, Myungjoo Ham, MyungsungKwak, Naman Kamra, Peng Yu, Penghao Cen, Phil, Raghuraman-K, resec, Rohin Mohanadas, Sandeep N Gupta, Scott Tseng, seaotterman, Seo Sanghyeon, Sergei Lebedev, Ted Chang, terrytangyuan, Tim H, tkunic, Tod, vihanjain, Yan Facai (颜发才), Yin Li, Yong Tang, Yukun Chen, Yusuke Yamada + + + # Release 1.6.0 ## Breaking Changes @@ -106,7 +295,7 @@ Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, 田 * Add `complex64` support to XLA compiler. * `bfloat` support is now added to XLA infrastructure. * Make `ClusterSpec` propagation work with XLA devices. - * Use a determinisitic executor to generate XLA graph. + * Use a deterministic executor to generate XLA graph. * `tf.contrib`: * `tf.contrib.distributions`: * Add `tf.contrib.distributions.Autoregressive`. @@ -274,14 +463,6 @@ answered questions, and were part of inspiring discussions. # Release 1.4.0 -## Major Features And Improvements -* `tf.keras` is now part of the core TensorFlow API. -* [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of - the core TensorFlow API. - * The API is now subject to backwards compatibility guarantees. - -# Release 1.4.0 - ## Major Features And Improvements * `tf.keras` is now part of the core TensorFlow API. * [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of diff --git a/tensorflow/SECURITY.md b/SECURITY.md similarity index 91% rename from tensorflow/SECURITY.md rename to SECURITY.md index fea24b273920885ba8a1ae96aafbf7710df46e1f..e2f6ff353a3c04a6ec6b8ccbaeb75db59fa22d54 100644 --- a/tensorflow/SECURITY.md +++ b/SECURITY.md @@ -6,7 +6,7 @@ report vulnerabilities in TensorFlow. ## TensorFlow models are programs -TensorFlow's runtime system interprets and executes programs. What machine +TensorFlow's runtime system interprets and executes programs. What machine learning practitioners term [**models**](https://developers.google.com/machine-learning/glossary/#model) are expressed as programs that TensorFlow executes. TensorFlow programs are encoded @@ -28,12 +28,12 @@ data you supply to TensorFlow to train a model, or to use a model to run inference on the data. **TensorFlow models are programs, and need to be treated as such from a security -perspective.** +perspective.** ## Running untrusted models As a general rule: **Always** execute untrusted models inside a sandbox (e.g., -[nsjail](https://github.com/google/nsjail)). +[nsjail](https://github.com/google/nsjail)). There are several ways in which a model could become untrusted. Obviously, if an untrusted party supplies TensorFlow kernels, arbitrary code may be executed. @@ -109,11 +109,11 @@ graphs known to the `ModelServer`. This means that an attacker may run graphs using untrusted inputs as described above, but they would not be able to execute arbitrary graphs. It is possible to safely expose a `ModelServer` directly to an untrusted network, **but only if the graphs it is configured to -use have been carefully audited to be safe**. +use have been carefully audited to be safe**. Similar to best practices for other servers, we recommend running any `ModelServer` with appropriate privileges (i.e., using a separate user with -reduced permisisons). In the spirit of defense in depth, we recommend +reduced permissions). In the spirit of defense in depth, we recommend authenticating requests to any TensorFlow server connected to an untrusted network, as well as sandboxing the server to minimize the adverse effects of any breach. @@ -129,11 +129,11 @@ with specially crafted inputs. ### What is a vulnerability? Given TensorFlow's flexibility, it is possible to specify computation graphs -which exhibit unexpected or unwanted behaviors. The fact that TensorFlow models +which exhibit unexpected or unwanted behavior. The fact that TensorFlow models can perform arbitrary computations means that they may read and write files, communicate via the network, produce deadlocks and infinite loops, or run out of memory. It is only when these behaviors are outside the specifications of the -operations involved that such behavior is a vulnerability. +operations involved that such behavior is a vulnerability. A `FileWriter` writing a file is not unexpected behavior and therefore is not a vulnerability in TensorFlow. A `MatMul` allowing arbitrary binary code execution @@ -168,7 +168,18 @@ below). Please use a descriptive subject line for your report email. After the initial reply to your report, the security team will endeavor to keep you informed of -the progress being made towards a fix and announcement. +the progress being made towards a fix and announcement. + +In addition, please include the following information along with your report: + +* Your name and affiliation (if any). +* A description of the technical details of the vulnerabilities. It is very + important to let us know how we can reproduce your findings. +* An explanation who can exploit this vulnerability, and what they gain when + doing so -- write an attack scenario. This will help us evaluate your report + quickly, especially if the issue is complex. +* Whether this vulnerability public or known to third parties. If it is, please + provide details. If you believe that an existing (public) issue is security-related, please send an email to `security@tensorflow.org`. The email should include the issue ID and @@ -231,9 +242,7 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc= -----END PGP PUBLIC KEY BLOCK----- ``` -### Known vulnerabilities - -| Type | Versions affected | Reported by | Additional Information | -|-------------------|:-----------------:|--------------------|-----------------------------| -| out of bounds read| <=1.4 | TenCent Blade Team | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | +### Known Vulnerabilities +For a list of known vulnerabilities and security advisories for TensorFlow, +(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md)[click here]. diff --git a/WORKSPACE b/WORKSPACE index 1e38a9a8cd754886fc5232531816b875de0879a3..fd7570a80ae2ee0087f7d2fd771fcce5b9690028 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -2,11 +2,11 @@ workspace(name = "org_tensorflow") http_archive( name = "io_bazel_rules_closure", - sha256 = "6691c58a2cd30a86776dd9bb34898b041e37136f2dc7e24cadaeaf599c95c657", - strip_prefix = "rules_closure-08039ba8ca59f64248bb3b6ae016460fe9c9914f", + sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae", + strip_prefix = "rules_closure-dbb96841cc0a5fb2664c37822803b06dab20c7d1", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz", - "https://github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz", # 2018-01-16 + "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", # 2018-04-13 ], ) @@ -14,28 +14,18 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") closure_repositories() +# We must check the bazel version before trying to parse any other BUILD +# files, in case the parsing of those build files depends on the bazel +# version we require here. +load("//tensorflow:version_check.bzl", "check_bazel_version_at_least") +check_bazel_version_at_least("0.10.0") + load("//tensorflow:workspace.bzl", "tf_workspace") -# Uncomment and update the paths in these entries to build the Android demo. -#android_sdk_repository( -# name = "androidsdk", -# api_level = 23, -# # Ensure that you have the build_tools_version below installed in the -# # SDK manager as it updates periodically. -# build_tools_version = "26.0.1", -# # Replace with path to Android SDK on your system -# path = "", -#) -# -#android_ndk_repository( -# name="androidndk", -# path="", -# # This needs to be 14 or higher to compile TensorFlow. -# # Please specify API level to >= 21 to build for 64-bit -# # archtectures or the Android NDK will automatically select biggest -# # API level that it supports without notice. -# # Note that the NDK version is not the API level. -# api_level=14) +load("//third_party/android:android_configure.bzl", "android_configure") +android_configure(name="local_config_android") +load("@local_config_android//:android.bzl", "android_workspace") +android_workspace() # Please add all new TensorFlow dependencies in workspace.bzl. tf_workspace() diff --git a/configure.py b/configure.py index b5436dba20ad1aadeffe8057c0a85709914f603e..ada342a50ab5104509156d3e44e6435a308255a3 100644 --- a/configure.py +++ b/configure.py @@ -35,12 +35,13 @@ except ImportError: _DEFAULT_CUDA_VERSION = '9.0' _DEFAULT_CUDNN_VERSION = '7' +_DEFAULT_NCCL_VERSION = '1.3' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' 'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION) -_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/x86_64-linux-gnu' +_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine() _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' @@ -225,8 +226,6 @@ def setup_python(environ_cp): # Set-up env variables used by python_configure.bzl write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path) write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path) - write_to_bazelrc('build --force_python=py%s' % python_major_version) - write_to_bazelrc('build --host_force_python=py%s' % python_major_version) write_to_bazelrc('build --python_path=\"%s"' % python_bin_path) environ_cp['PYTHON_BIN_PATH'] = python_bin_path @@ -250,7 +249,11 @@ def reset_tf_configure_bazelrc(workspace_path): if _TF_BAZELRC_FILENAME in l: continue f.write('%s\n' % l) - f.write('import %s\n' % _TF_BAZELRC) + if is_windows(): + tf_bazelrc_path = _TF_BAZELRC.replace("\\", "/") + else: + tf_bazelrc_path = _TF_BAZELRC + f.write('import %s\n' % tf_bazelrc_path) def cleanup_makefile(): @@ -444,7 +447,7 @@ def check_bazel_version(min_version): if which('bazel') is None: print('Cannot find bazel. Please install bazel.') sys.exit(0) - curr_version = run_shell(['bazel', '--batch', 'version']) + curr_version = run_shell(['bazel', '--batch', '--bazelrc=/dev/null', 'version']) for line in curr_version.split('\n'): if 'Build label: ' in line: @@ -480,6 +483,8 @@ def set_cc_opt_flags(environ_cp): if is_ppc64le(): # gcc on ppc64le does not support -march, use mcpu instead default_cc_opt_flags = '-mcpu=native' + elif is_windows(): + default_cc_opt_flags = '/arch:AVX' else: default_cc_opt_flags = '-march=native' question = ('Please specify optimization flags to use during compilation when' @@ -490,14 +495,9 @@ def set_cc_opt_flags(environ_cp): for opt in cc_opt_flags.split(): write_to_bazelrc('build:opt --copt=%s' % opt) # It should be safe on the same build host. - if not is_ppc64le(): + if not is_ppc64le() and not is_windows(): write_to_bazelrc('build:opt --host_copt=-march=native') write_to_bazelrc('build:opt --define with_default_optimizations=true') - # TODO(mikecase): Remove these default defines once we are able to get - # TF Lite targets building without them. - write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') - write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') - def set_tf_cuda_clang(environ_cp): """set TF_CUDA_CLANG action_env. @@ -520,7 +520,7 @@ def set_tf_cuda_clang(environ_cp): def set_tf_download_clang(environ_cp): """Set TF_DOWNLOAD_CLANG action_env.""" - question = 'Do you want to download a fresh release of clang? (Experimental)' + question = 'Do you wish to download a fresh release of clang? (Experimental)' yes_reply = 'Clang will be downloaded and used to compile tensorflow.' no_reply = 'Clang will not be downloaded.' set_action_env_var( @@ -670,8 +670,9 @@ def create_android_ndk_rule(environ_cp): error_msg=('The path %s or its child file "source.properties" ' 'does not exist.') ) - - write_android_ndk_workspace_rule(android_ndk_home_path) + write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path) + write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL', + check_ndk_level(android_ndk_home_path)) def create_android_sdk_rule(environ_cp): @@ -733,41 +734,12 @@ def create_android_sdk_rule(environ_cp): error_msg=('The selected SDK does not have build-tools version %s ' 'available.')) - write_android_sdk_workspace_rule(android_sdk_home_path, - android_build_tools_version, - android_api_level) - - -def write_android_sdk_workspace_rule(android_sdk_home_path, - android_build_tools_version, - android_api_level): - print('Writing android_sdk_workspace rule.\n') - with open(_TF_WORKSPACE, 'a') as f: - f.write(""" -android_sdk_repository( - name="androidsdk", - api_level=%s, - path="%s", - build_tools_version="%s")\n -""" % (android_api_level, android_sdk_home_path, android_build_tools_version)) - - -def write_android_ndk_workspace_rule(android_ndk_home_path): - print('Writing android_ndk_workspace rule.') - ndk_api_level = check_ndk_level(android_ndk_home_path) - if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: - print('WARNING: The API level of the NDK in %s is %s, which is not ' - 'supported by Bazel (officially supported versions: %s). Please use ' - 'another version. Compiling Android targets may result in confusing ' - 'errors.\n' % (android_ndk_home_path, ndk_api_level, - _SUPPORTED_ANDROID_NDK_VERSIONS)) - with open(_TF_WORKSPACE, 'a') as f: - f.write(""" -android_ndk_repository( - name="androidndk", - path="%s", - api_level=%s)\n -""" % (android_ndk_home_path, ndk_api_level)) + write_action_env_to_bazelrc('ANDROID_BUILD_TOOLS_VERSION', + android_build_tools_version) + write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL', + android_api_level) + write_action_env_to_bazelrc('ANDROID_SDK_HOME', + android_sdk_home_path) def check_ndk_level(android_ndk_home_path): @@ -780,18 +752,16 @@ def check_ndk_level(android_ndk_home_path): revision = re.search(r'Pkg.Revision = (\d+)', filedata) if revision: - return revision.group(1) - return None - - -def workspace_has_any_android_rule(): - """Check the WORKSPACE for existing android_*_repository rules.""" - with open(_TF_WORKSPACE, 'r') as f: - workspace = f.read() - has_any_rule = re.search(r'^android_[ns]dk_repository', - workspace, - re.MULTILINE) - return has_any_rule + ndk_api_level = revision.group(1) + else: + raise Exception('Unable to parse NDK revision.') + if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: + print('WARNING: The API level of the NDK in %s is %s, which is not ' + 'supported by Bazel (officially supported versions: %s). Please use ' + 'another version. Compiling Android targets may result in confusing ' + 'errors.\n' % (android_ndk_home_path, ndk_api_level, + _SUPPORTED_ANDROID_NDK_VERSIONS)) + return ndk_api_level def set_gcc_host_compiler_path(environ_cp): @@ -841,8 +811,8 @@ def reformat_version_sequence(version_str, sequence_count): def set_tf_cuda_version(environ_cp): """Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION.""" ask_cuda_version = ( - 'Please specify the CUDA SDK version you want to use, ' - 'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION + 'Please specify the CUDA SDK version you want to use. ' + '[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): # Configure the Cuda SDK version to use. @@ -1044,7 +1014,10 @@ def set_tf_tensorrt_install_path(environ_cp): for lib_file in possible_files: if is_compatible(lib_file, cuda_ver, cudnn_ver): - ver_str = nvinfer_pattern.search(lib_file).group(1) + matches = nvinfer_pattern.search(lib_file) + if len(matches.groups()) == 0: + continue + ver_str = matches.group(1) ver = convert_version_to_int(ver_str) if len(ver_str) else 0 if ver > highest_ver[0]: highest_ver = [ver, ver_str, lib_file] @@ -1067,7 +1040,7 @@ def set_tf_tensorrt_install_path(environ_cp): break # Reset and Retry - if len(possible_files): + if possible_files: print('TensorRT libraries found in one the following directories', 'are not compatible with selected cuda and cudnn installations') print(trt_install_path) @@ -1076,7 +1049,8 @@ def set_tf_tensorrt_install_path(environ_cp): if search_result: print(libnvinfer_path_from_ldconfig) else: - print('Invalid path to TensorRT. None of the following files can be found:') + print( + 'Invalid path to TensorRT. None of the following files can be found:') print(trt_install_path) print(os.path.join(trt_install_path, 'lib')) print(os.path.join(trt_install_path, 'lib64')) @@ -1095,6 +1069,81 @@ def set_tf_tensorrt_install_path(environ_cp): write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version) +def set_tf_nccl_install_path(environ_cp): + """Set NCCL_INSTALL_PATH and TF_NCCL_VERSION. + + Args: + environ_cp: copy of the os.environ. + + Raises: + ValueError: if this method was called under non-Linux platform. + UserInputError: if user has provided invalid input multiple times. + """ + if not is_linux(): + raise ValueError('Currently NCCL is only supported on Linux platforms.') + + ask_nccl_version = ( + 'Please specify the NCCL version you want to use. ' + '[Leave empty to default to NCCL %s]: ') % _DEFAULT_NCCL_VERSION + + for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): + tf_nccl_version = get_from_env_or_user_or_default( + environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, _DEFAULT_NCCL_VERSION) + tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1) + + if tf_nccl_version == '1': + break # No need to get install path, NCCL 1 is a GitHub repo. + + # TODO(csigg): Look with ldconfig first if we can find the library in paths + # like /usr/lib/x86_64-linux-gnu and the header file in the corresponding + # include directory. This is where the NCCL .deb packages install them. + # Then ask the user if we should use that. Instead of a single + # NCCL_INSTALL_PATH, pass separate NCCL_LIB_PATH and NCCL_HDR_PATH to + # nccl_configure.bzl + default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH') + ask_nccl_path = (r'Please specify the location where NCCL %s library is ' + 'installed. Refer to README.md for more details. [Default ' + 'is %s]:') % (tf_nccl_version, default_nccl_path) + nccl_install_path = get_from_env_or_user_or_default( + environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path) + + # Result returned from "read" will be used unexpanded. That make "~" + # unusable. Going through one more level of expansion to handle that. + nccl_install_path = os.path.realpath(os.path.expanduser(nccl_install_path)) + if is_windows() or is_cygwin(): + nccl_install_path = cygpath(nccl_install_path) + + if is_windows(): + nccl_lib_path = 'lib/x64/nccl.lib' + elif is_linux(): + nccl_lib_path = 'lib/libnccl.so.%s' % tf_nccl_version + elif is_macos(): + nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version + + nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path) + nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h') + if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path): + # Set NCCL_INSTALL_PATH + environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path + write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path) + break + + # Reset and Retry + print('Invalid path to NCCL %s toolkit, %s or %s not found. Please use the ' + 'O/S agnostic package of NCCL 2' % (tf_nccl_version, nccl_lib_path, + nccl_hdr_path)) + + environ_cp['TF_NCCL_VERSION'] = '' + else: + raise UserInputError('Invalid TF_NCCL setting was provided %d ' + 'times in a row. Assuming to be a scripting mistake.' % + _DEFAULT_PROMPT_ASK_ATTEMPTS) + + # Set TF_NCCL_VERSION + environ_cp['TF_NCCL_VERSION'] = tf_nccl_version + write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version) + + def get_native_cuda_compute_capabilities(environ_cp): """Get native cuda compute capabilities. @@ -1143,6 +1192,9 @@ def set_tf_cuda_compute_capabilities(environ_cp): ask_cuda_compute_capabilities, default_cuda_compute_capabilities) # Check whether all capabilities from the input is valid all_valid = True + # Remove all whitespace characters before splitting the string + # that users may insert by accident, as this will result in error + tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split()) for compute_capability in tf_cuda_compute_capabilities.split(','): m = re.match('[0-9]+.[0-9]+', compute_capability) if not m: @@ -1345,6 +1397,10 @@ def set_grpc_build_flags(): write_to_bazelrc('build --define grpc_no_ares=true') +def set_build_strip_flag(): + write_to_bazelrc('build --strip=always') + + def set_windows_build_flags(): if is_windows(): # The non-monolithic build is not supported yet @@ -1372,7 +1428,7 @@ def main(): # environment variables. environ_cp = dict(os.environ) - check_bazel_version('0.5.4') + check_bazel_version('0.10.0') reset_tf_configure_bazelrc(args.workspace) cleanup_makefile() @@ -1389,6 +1445,9 @@ def main(): environ_cp['TF_NEED_OPENCL'] = '0' environ_cp['TF_CUDA_CLANG'] = '0' environ_cp['TF_NEED_TENSORRT'] = '0' + # TODO(ibiryukov): Investigate using clang as a cpu or cuda compiler on + # Windows. + environ_cp['TF_DOWNLOAD_CLANG'] = '0' if is_macos(): environ_cp['TF_NEED_JEMALLOC'] = '0' @@ -1403,7 +1462,7 @@ def main(): set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System', 'with_s3_support', True, 's3') set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform', - 'with_kafka_support', False, 'kafka') + 'with_kafka_support', True, 'kafka') set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', False, 'xla') set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support', @@ -1428,6 +1487,8 @@ def main(): set_tf_cudnn_version(environ_cp) if is_linux(): set_tf_tensorrt_install_path(environ_cp) + set_tf_nccl_install_path(environ_cp) + set_tf_cuda_compute_capabilities(environ_cp) if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get( 'LD_LIBRARY_PATH') != '1': @@ -1436,16 +1497,8 @@ def main(): set_tf_cuda_clang(environ_cp) if environ_cp.get('TF_CUDA_CLANG') == '1': - if not is_windows(): - # Ask if we want to download clang release while building. - set_tf_download_clang(environ_cp) - else: - # We use bazel's generated crosstool on Windows and there is no - # way to provide downloaded toolchain for that yet. - # TODO(ibiryukov): Investigate using clang as a cuda compiler on - # Windows. - environ_cp['TF_DOWNLOAD_CLANG'] = '0' - + # Ask whether we should download the clang toolchain. + set_tf_download_clang(environ_cp) if environ_cp.get('TF_DOWNLOAD_CLANG') != '1': # Set up which clang we should use as the cuda / host compiler. set_clang_cuda_compiler_path(environ_cp) @@ -1455,6 +1508,13 @@ def main(): if not is_windows(): set_gcc_host_compiler_path(environ_cp) set_other_cuda_vars(environ_cp) + else: + # CUDA not required. Ask whether we should download the clang toolchain and + # use it for the CPU build. + set_tf_download_clang(environ_cp) + if environ_cp.get('TF_DOWNLOAD_CLANG') == '1': + write_to_bazelrc('build --config=download_clang') + write_to_bazelrc('test --config=download_clang') set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False) if environ_cp.get('TF_NEED_MPI') == '1': @@ -1463,23 +1523,18 @@ def main(): set_grpc_build_flags() set_cc_opt_flags(environ_cp) + set_build_strip_flag() set_windows_build_flags() - if workspace_has_any_android_rule(): - print('The WORKSPACE file has at least one of ["android_sdk_repository", ' - '"android_ndk_repository"] already set. Will not ask to help ' - 'configure the WORKSPACE. Please delete the existing rules to ' - 'activate the helper.\n') - else: - if get_var( - environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', - False, - ('Would you like to interactively configure ./WORKSPACE for ' - 'Android builds?'), - 'Searching for NDK and SDK installations.', - 'Not configuring the WORKSPACE for Android builds.'): - create_android_ndk_rule(environ_cp) - create_android_sdk_rule(environ_cp) + if get_var( + environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', + False, + ('Would you like to interactively configure ./WORKSPACE for ' + 'Android builds?'), + 'Searching for NDK and SDK installations.', + 'Not configuring the WORKSPACE for Android builds.'): + create_android_ndk_rule(environ_cp) + create_android_sdk_rule(environ_cp) print('Preconfigured Bazel build configs. You can use any of the below by ' 'adding "--config=<>" to your build command. See tools/bazel.rc for ' diff --git a/tensorflow/BUILD b/tensorflow/BUILD index dc995d231d3e591771f801e28024a76610cdba26..6d134dbb80cb8c3dcf15b2ba20783870a67e9a62 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -19,6 +19,10 @@ load( "//tensorflow/core:platform/default/build_config.bzl", "tf_additional_binary_deps", ) +load( + "//tensorflow/tools/api/generator:api_gen.bzl", + "gen_api_init_files", # @unused +) # Config setting for determining if we are building for Android. config_setting( @@ -240,6 +244,13 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_kafka_support_windows_override", + define_values = {"with_kafka_support": "true"}, + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + config_setting( name = "with_gcp_support_android_override", define_values = {"with_gcp_support": "true"}, @@ -394,309 +405,6 @@ package_group( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "g3doc/sitemap.md", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - -py_library( - name = "tensorflow_py", - srcs = ["__init__.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = ["//tensorflow/python"], -) - -filegroup( - name = "all_opensource_files", - data = [ - ":all_files", - "//tensorflow/c:all_files", - "//tensorflow/cc:all_files", - "//tensorflow/cc/saved_model:all_files", - "//tensorflow/cc/saved_model/python:all_files", - "//tensorflow/cc/tools:all_files", - "//tensorflow/compiler/aot:all_files", - "//tensorflow/compiler/aot/tests:all_files", - "//tensorflow/compiler/jit:all_files", - "//tensorflow/compiler/jit/graphcycles:all_files", - "//tensorflow/compiler/jit/kernels:all_files", - "//tensorflow/compiler/jit/legacy_flags:all_files", - "//tensorflow/compiler/jit/ops:all_files", - "//tensorflow/compiler/plugin:all_files", - "//tensorflow/compiler/tests:all_files", - "//tensorflow/compiler/tf2xla:all_files", - "//tensorflow/compiler/tf2xla/cc:all_files", - "//tensorflow/compiler/tf2xla/kernels:all_files", - "//tensorflow/compiler/tf2xla/lib:all_files", - "//tensorflow/compiler/tf2xla/ops:all_files", - "//tensorflow/compiler/xla:all_files", - "//tensorflow/compiler/xla/client:all_files", - "//tensorflow/compiler/xla/client/lib:all_files", - "//tensorflow/compiler/xla/legacy_flags:all_files", - "//tensorflow/compiler/xla/python:all_files", - "//tensorflow/compiler/xla/service:all_files", - "//tensorflow/compiler/xla/service/cpu:all_files", - "//tensorflow/compiler/xla/service/gpu:all_files", - "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend:all_files", - "//tensorflow/compiler/xla/service/interpreter:all_files", - "//tensorflow/compiler/xla/service/llvm_ir:all_files", - "//tensorflow/compiler/xla/tests:all_files", - "//tensorflow/compiler/xla/tools:all_files", - "//tensorflow/compiler/xla/tools/parser:all_files", - "//tensorflow/contrib:all_files", - "//tensorflow/contrib/all_reduce:all_files", - "//tensorflow/contrib/android:all_files", - "//tensorflow/contrib/batching:all_files", - "//tensorflow/contrib/bayesflow:all_files", - "//tensorflow/contrib/boosted_trees:all_files", - "//tensorflow/contrib/boosted_trees/estimator_batch:all_files", - "//tensorflow/contrib/boosted_trees/lib:all_files", - "//tensorflow/contrib/boosted_trees/proto:all_files", - "//tensorflow/contrib/boosted_trees/resources:all_files", - "//tensorflow/contrib/cloud:all_files", - "//tensorflow/contrib/cloud/kernels:all_files", - "//tensorflow/contrib/cluster_resolver:all_files", - "//tensorflow/contrib/coder:all_files", - "//tensorflow/contrib/compiler:all_files", - "//tensorflow/contrib/copy_graph:all_files", - "//tensorflow/contrib/crf:all_files", - "//tensorflow/contrib/cudnn_rnn:all_files", - "//tensorflow/contrib/data:all_files", - "//tensorflow/contrib/data/kernels:all_files", - "//tensorflow/contrib/data/python/kernel_tests:all_files", - "//tensorflow/contrib/data/python/ops:all_files", - "//tensorflow/contrib/decision_trees/proto:all_files", - "//tensorflow/contrib/deprecated:all_files", - "//tensorflow/contrib/distributions:all_files", - "//tensorflow/contrib/eager/proto:all_files", - "//tensorflow/contrib/eager/python:all_files", - "//tensorflow/contrib/estimator:all_files", - "//tensorflow/contrib/factorization:all_files", - "//tensorflow/contrib/factorization/examples:all_files", - "//tensorflow/contrib/factorization/kernels:all_files", - "//tensorflow/contrib/feature_column:all_files", - "//tensorflow/contrib/ffmpeg:all_files", - "//tensorflow/contrib/ffmpeg/default:all_files", - "//tensorflow/contrib/framework:all_files", - "//tensorflow/contrib/fused_conv:all_files", - "//tensorflow/contrib/gan:all_files", - "//tensorflow/contrib/gdr:all_files", - "//tensorflow/contrib/graph_editor:all_files", - "//tensorflow/contrib/grid_rnn:all_files", - "//tensorflow/contrib/hooks:all_files", - "//tensorflow/contrib/hvx/clock_cycle_profiling:all_files", - "//tensorflow/contrib/hvx/hvx_ops_support_checker:all_files", - "//tensorflow/contrib/image:all_files", - "//tensorflow/contrib/input_pipeline:all_files", - "//tensorflow/contrib/input_pipeline/kernels:all_files", - "//tensorflow/contrib/integrate:all_files", - "//tensorflow/contrib/keras:all_files", - "//tensorflow/contrib/kernel_methods:all_files", - "//tensorflow/contrib/kfac:all_files", - "//tensorflow/contrib/kfac/examples:all_files", - "//tensorflow/contrib/kfac/examples/tests:all_files", - "//tensorflow/contrib/kfac/python/kernel_tests:all_files", - "//tensorflow/contrib/kfac/python/ops:all_files", - "//tensorflow/contrib/labeled_tensor:all_files", - "//tensorflow/contrib/layers:all_files", - "//tensorflow/contrib/layers/kernels:all_files", - "//tensorflow/contrib/learn:all_files", - "//tensorflow/contrib/learn/python/learn/datasets:all_files", - "//tensorflow/contrib/legacy_seq2seq:all_files", - "//tensorflow/contrib/libsvm:all_files", - "//tensorflow/contrib/linalg:all_files", - "//tensorflow/contrib/linear_optimizer:all_files", - "//tensorflow/contrib/lite:all_files", - "//tensorflow/contrib/lite/java:all_files", - "//tensorflow/contrib/lite/java/demo/app/src/main:all_files", - "//tensorflow/contrib/lite/java/demo/app/src/main/assets:all_files", - "//tensorflow/contrib/lite/java/src/main/native:all_files", - "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:all_files", - "//tensorflow/contrib/lite/kernels:all_files", - "//tensorflow/contrib/lite/kernels/internal:all_files", - "//tensorflow/contrib/lite/models/smartreply:all_files", - "//tensorflow/contrib/lite/nnapi:all_files", - "//tensorflow/contrib/lite/python:all_files", - "//tensorflow/contrib/lite/schema:all_files", - "//tensorflow/contrib/lite/testing:all_files", - "//tensorflow/contrib/lite/toco:all_files", - "//tensorflow/contrib/lite/toco/graph_transformations/tests:all_files", - "//tensorflow/contrib/lite/toco/python:all_files", - "//tensorflow/contrib/lite/toco/tensorflow_graph_matching:all_files", - "//tensorflow/contrib/lite/toco/tflite:all_files", - "//tensorflow/contrib/lite/tools:all_files", - "//tensorflow/contrib/lookup:all_files", - "//tensorflow/contrib/losses:all_files", - "//tensorflow/contrib/makefile:all_files", - "//tensorflow/contrib/memory_stats:all_files", - "//tensorflow/contrib/meta_graph_transform:all_files", - "//tensorflow/contrib/metrics:all_files", - "//tensorflow/contrib/model_pruning:all_files", - "//tensorflow/contrib/model_pruning/examples/cifar10:all_files", - "//tensorflow/contrib/nccl:all_files", - "//tensorflow/contrib/nearest_neighbor:all_files", - "//tensorflow/contrib/nn:all_files", - "//tensorflow/contrib/opt:all_files", - "//tensorflow/contrib/periodic_resample:all_files", - "//tensorflow/contrib/predictor:all_files", - "//tensorflow/contrib/py2tf:all_files", - "//tensorflow/contrib/py2tf/converters:all_files", - "//tensorflow/contrib/py2tf/impl:all_files", - "//tensorflow/contrib/py2tf/pyct:all_files", - "//tensorflow/contrib/py2tf/pyct/static_analysis:all_files", - "//tensorflow/contrib/py2tf/utils:all_files", - "//tensorflow/contrib/quantize:all_files", - "//tensorflow/contrib/receptive_field:all_files", - "//tensorflow/contrib/reduce_slice_ops:all_files", - "//tensorflow/contrib/remote_fused_graph/pylib:all_files", - "//tensorflow/contrib/resampler:all_files", - "//tensorflow/contrib/rnn:all_files", - "//tensorflow/contrib/saved_model:all_files", - "//tensorflow/contrib/saved_model/cc/saved_model:all_files", - "//tensorflow/contrib/seq2seq:all_files", - "//tensorflow/contrib/session_bundle:all_files", - "//tensorflow/contrib/session_bundle/example:all_files", - "//tensorflow/contrib/signal:all_files", - "//tensorflow/contrib/slim:all_files", - "//tensorflow/contrib/slim/python/slim/data:all_files", - "//tensorflow/contrib/slim/python/slim/nets:all_files", - "//tensorflow/contrib/solvers:all_files", - "//tensorflow/contrib/sparsemax:all_files", - "//tensorflow/contrib/specs:all_files", - "//tensorflow/contrib/staging:all_files", - "//tensorflow/contrib/stat_summarizer:all_files", - "//tensorflow/contrib/stateless:all_files", - "//tensorflow/contrib/summary:all_files", - "//tensorflow/contrib/tensor_forest:all_files", - "//tensorflow/contrib/tensor_forest/hybrid:all_files", - "//tensorflow/contrib/tensor_forest/kernels/v4:all_files", - "//tensorflow/contrib/tensor_forest/proto:all_files", - "//tensorflow/contrib/tensorboard:all_files", - "//tensorflow/contrib/tensorboard/db:all_files", - "//tensorflow/contrib/tensorrt:all_files", - "//tensorflow/contrib/testing:all_files", - "//tensorflow/contrib/text:all_files", - "//tensorflow/contrib/tfprof:all_files", - "//tensorflow/contrib/timeseries:all_files", - "//tensorflow/contrib/timeseries/examples:all_files", - "//tensorflow/contrib/timeseries/python/timeseries:all_files", - "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:all_files", - "//tensorflow/contrib/tpu:all_files", - "//tensorflow/contrib/tpu/profiler:all_files", - "//tensorflow/contrib/tpu/proto:all_files", - "//tensorflow/contrib/training:all_files", - "//tensorflow/contrib/util:all_files", - "//tensorflow/contrib/verbs:all_files", - "//tensorflow/core:all_files", - "//tensorflow/core/api_def:all_files", - "//tensorflow/core/debug:all_files", - "//tensorflow/core/distributed_runtime:all_files", - "//tensorflow/core/distributed_runtime/rpc:all_files", - "//tensorflow/core/grappler:all_files", - "//tensorflow/core/grappler/clusters:all_files", - "//tensorflow/core/grappler/costs:all_files", - "//tensorflow/core/grappler/inputs:all_files", - "//tensorflow/core/grappler/optimizers:all_files", - "//tensorflow/core/grappler/utils:all_files", - "//tensorflow/core/kernels:all_files", - "//tensorflow/core/kernels/batching_util:all_files", - "//tensorflow/core/kernels/data:all_files", - "//tensorflow/core/kernels/data/sql:all_files", - "//tensorflow/core/kernels/fuzzing:all_files", - "//tensorflow/core/kernels/hexagon:all_files", - "//tensorflow/core/kernels/neon:all_files", - "//tensorflow/core/lib/db:all_files", - "//tensorflow/core/ops/compat:all_files", - "//tensorflow/core/platform/cloud:all_files", - "//tensorflow/core/platform/default/build_config:all_files", - "//tensorflow/core/platform/hadoop:all_files", - "//tensorflow/core/platform/s3:all_files", - "//tensorflow/core/profiler:all_files", - "//tensorflow/core/profiler/internal:all_files", - "//tensorflow/core/profiler/internal/advisor:all_files", - "//tensorflow/core/util/ctc:all_files", - "//tensorflow/core/util/tensor_bundle:all_files", - "//tensorflow/examples/adding_an_op:all_files", - "//tensorflow/examples/android:all_files", - "//tensorflow/examples/benchmark:all_files", - "//tensorflow/examples/get_started/regression:all_files", - "//tensorflow/examples/how_tos/reading_data:all_files", - "//tensorflow/examples/image_retraining:all_files", - "//tensorflow/examples/label_image:all_files", - "//tensorflow/examples/learn:all_files", - "//tensorflow/examples/multibox_detector:all_files", - "//tensorflow/examples/saved_model:all_files", - "//tensorflow/examples/speech_commands:all_files", - "//tensorflow/examples/tutorials/estimators:all_files", - "//tensorflow/examples/tutorials/layers:all_files", - "//tensorflow/examples/tutorials/mnist:all_files", - "//tensorflow/examples/tutorials/monitors:all_files", - "//tensorflow/examples/tutorials/word2vec:all_files", - "//tensorflow/examples/wav_to_spectrogram:all_files", - "//tensorflow/go:all_files", - "//tensorflow/java:all_files", - "//tensorflow/java/src/main/java/org/tensorflow/examples:all_files", - "//tensorflow/java/src/main/native:all_files", - "//tensorflow/python:all_files", - "//tensorflow/python/data:all_files", - "//tensorflow/python/data/kernel_tests:all_files", - "//tensorflow/python/data/ops:all_files", - "//tensorflow/python/data/util:all_files", - "//tensorflow/python/debug:all_files", - "//tensorflow/python/eager:all_files", - "//tensorflow/python/estimator:all_files", - "//tensorflow/python/feature_column:all_files", - "//tensorflow/python/keras:all_files", - "//tensorflow/python/kernel_tests:all_files", - "//tensorflow/python/kernel_tests/distributions:all_files", - "//tensorflow/python/kernel_tests/linalg:all_files", - "//tensorflow/python/kernel_tests/random:all_files", - "//tensorflow/python/ops/distributions:all_files", - "//tensorflow/python/ops/linalg:all_files", - "//tensorflow/python/ops/losses:all_files", - "//tensorflow/python/profiler:all_files", - "//tensorflow/python/profiler/internal:all_files", - "//tensorflow/python/saved_model:all_files", - "//tensorflow/python/tools:all_files", - "//tensorflow/tools/api/generator:all_files", - "//tensorflow/tools/api/golden:all_files", - "//tensorflow/tools/api/lib:all_files", - "//tensorflow/tools/api/tests:all_files", - "//tensorflow/tools/benchmark:all_files", - "//tensorflow/tools/build_info:all_files", - "//tensorflow/tools/ci_build/gpu_build:all_files", - "//tensorflow/tools/common:all_files", - "//tensorflow/tools/compatibility:all_files", - "//tensorflow/tools/dist_test/server:all_files", - "//tensorflow/tools/docker:all_files", - "//tensorflow/tools/docker/notebooks:all_files", - "//tensorflow/tools/docs:all_files", - "//tensorflow/tools/git:all_files", - "//tensorflow/tools/graph_transforms:all_files", - "//tensorflow/tools/mlpbtxt:all_files", - "//tensorflow/tools/proto_text:all_files", - "//tensorflow/tools/quantization:all_files", - "//tensorflow/tools/test:all_files", - "//tensorflow/user_ops:all_files", - "//third_party/eigen3:all_files", - "//third_party/fft2d:all_files", - "//third_party/flatbuffers:all_files", - "//third_party/hadoop:all_files", - "//third_party/sycl:all_files", - "//third_party/sycl/sycl:all_files", - ], - visibility = ["//visibility:public"], -) - load( "//third_party/mkl:build_defs.bzl", "if_mkl", @@ -746,11 +454,12 @@ tf_cc_shared_object( linkstatic = 1, visibility = ["//visibility:public"], deps = [ + "//tensorflow/core:core_cpu_impl", "//tensorflow/core:framework_internal_impl", + "//tensorflow/core:gpu_runtime_impl", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", "//tensorflow/core:lib_internal_impl", - "//tensorflow/core:core_cpu_impl", "//tensorflow/stream_executor:stream_executor_impl", - "//tensorflow/core:gpu_runtime_impl", ] + tf_additional_binary_deps(), ) @@ -766,27 +475,27 @@ tf_cc_shared_object( # excludes all but a subset of function names. # On MacOS, the linker does not support version_script, but has an # an "-exported_symbols_list" command. -z defs disallows undefined -# symbols in object files and -s strips the output. +# symbols in object files. tf_cc_shared_object( name = "libtensorflow.so", linkopts = select({ "//tensorflow:darwin": [ "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "//tensorflow/c:exported_symbols.lds", + "$(location //tensorflow/c:exported_symbols.lds)", "-Wl,-install_name,@rpath/libtensorflow.so", ], "//tensorflow:windows": [], "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", - "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "//tensorflow/c:version_script.lds", + "$(location //tensorflow/c:version_script.lds)", ], }), deps = [ "//tensorflow/c:c_api", + "//tensorflow/c:c_api_experimental", "//tensorflow/c:exported_symbols.lds", "//tensorflow/c:version_script.lds", "//tensorflow/c/eager:c_api", @@ -799,15 +508,14 @@ tf_cc_shared_object( linkopts = select({ "//tensorflow:darwin": [ "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "//tensorflow:tf_exported_symbols.lds", + "$(location //tensorflow:tf_exported_symbols.lds)", ], "//tensorflow:windows": [], "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", - "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "//tensorflow:tf_version_script.lds", + "$(location //tensorflow:tf_version_script.lds)", ], }), deps = [ @@ -829,3 +537,20 @@ exports_files( "tf_exported_symbols.lds", ], ) + +gen_api_init_files( + name = "tensorflow_python_api_gen", + srcs = ["api_template.__init__.py"], + root_init_template = "api_template.__init__.py", +) + +py_library( + name = "tensorflow_py", + srcs = [ + ":tensorflow_python_api_gen", + "//tensorflow/python/estimator/api:estimator_python_api_gen", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = ["//tensorflow/python"], +) diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py index 78ad6aec19f3bbbfcb389012ac1577573b3e4901..440e9f8dbd2f4b2a2ab78eaaf26408584e7c1446 100644 --- a/tensorflow/__init__.py +++ b/tensorflow/__init__.py @@ -20,14 +20,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=wildcard-import -from tensorflow.python import * # pylint: disable=redefined-builtin -# pylint: enable=wildcard-import +# pylint: disable=g-bad-import-order +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import from tensorflow.python.util.lazy_loader import LazyLoader contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') del LazyLoader +from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top +app.flags = flags # pylint: disable=undefined-variable + del absolute_import del division del print_function diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9662d7b478ba61c69edc20b0d47293f9939e7881 --- /dev/null +++ b/tensorflow/api_template.__init__.py @@ -0,0 +1,58 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Bring in all of the public TensorFlow interface into this module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=g-bad-import-order +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +# API IMPORTS PLACEHOLDER + +try: + import os # pylint: disable=g-import-not-at-top + # Add `estimator` attribute to allow access to estimator APIs via + # "tf.estimator..." + from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top + + # Add `estimator` to the __path__ to allow "from tensorflow.estimator..." + # style imports. + from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top + __path__ += [os.path.dirname(estimator_api.__file__)] + del estimator_api + del os +except (ImportError, AttributeError): + print('tf.estimator package not installed.') + +from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top +contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') +del LazyLoader + +from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top +app.flags = flags # pylint: disable=undefined-variable + +del absolute_import +del division +del print_function + +# These symbols appear because we import the python package which +# in turn imports from tensorflow.core and tensorflow.python. They +# must come from this module. So python adds these symbols for the +# resolution to succeed. +# pylint: disable=undefined-variable +del python +del core +# pylint: enable=undefined-variable diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 5dfb743681255d8c03e91ea43fd441d94fdee59d..8a9301d584775cff3ae315e6fd856b00d1734248 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -17,7 +17,10 @@ load( filegroup( name = "headers", - srcs = ["c_api.h"], + srcs = [ + "c_api.h", + "c_api_experimental.h", + ], visibility = ["//tensorflow:__subpackages__"], ) @@ -31,6 +34,8 @@ filegroup( exclude = [ "c_api_experimental.cc", "c_api_experimental.h", + "python_api.cc", + "python_api.h", "*test*", ], ), @@ -113,6 +118,11 @@ tf_cuda_library( ":c_api", ":c_api_internal", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", + "//tensorflow/contrib/tpu:all_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_platform", "//tensorflow/core:protos_all_cc", ], ) @@ -209,6 +219,27 @@ tf_cuda_cc_test( ], ) +tf_cc_test( + name = "c_api_experimental_test", + size = "small", + srcs = ["c_api_experimental_test.cc"], + data = ["testdata/tf_record"], + linkopts = select({ + "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + # We must ensure that the dependencies can be dynamically linked since + # the shared library must be able to use core:framework. + # linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":c_api_experimental", + ":c_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "c_api_function_test", size = "small", @@ -253,20 +284,7 @@ tf_cuda_library( deps = [ ":c_api", ":c_api_internal", + # TODO(b/74620627): remove when _USE_C_SHAPES is removed + "//tensorflow/python:cpp_shape_inference_proto_cc", ], ) - -# ----------------------------------------------------------------------------- -# Google-internal targets. - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 85f1d1639b4d09f2de77d326481a86ec246270d0..cb0b093ad260e000dcef9d1123e967a77cf1a041 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -30,6 +30,7 @@ limitations under the License. #endif #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/eval_const_tensor.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/log_memory.h" @@ -62,6 +63,7 @@ limitations under the License. // brain namespace because we are defining 'extern "C"' functions. using tensorflow::AllocationDescription; using tensorflow::DataType; +using tensorflow::ExtendSessionGraphHelper; using tensorflow::Graph; using tensorflow::GraphDef; using tensorflow::mutex_lock; @@ -73,6 +75,7 @@ using tensorflow::NodeBuilder; using tensorflow::NodeDef; using tensorflow::OpDef; using tensorflow::OpRegistry; +using tensorflow::OutputTensor; using tensorflow::PartialTensorShape; using tensorflow::RunMetadata; using tensorflow::RunOptions; @@ -628,7 +631,22 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, "Failed to allocate memory to serialize message of type '", in.GetTypeName(), "' and size ", proto_size); } - in.SerializeToArray(buf, proto_size); + // SerializeToArray takes size as an int. + // This next 'if' is a workaround till we update to depend on a version + // of protocol buffers that includes + // https://github.com/google/protobuf/pull/4739 + if (proto_size > std::numeric_limits::max()) { + return InvalidArgument("Cannot serialize protocol buffer of type ", + in.GetTypeName(), " as the serialized size (", + proto_size, + "bytes) would be larger than the limit (", + std::numeric_limits::max(), " bytes)"); + } + if (!in.SerializeToArray(buf, proto_size)) { + return InvalidArgument("Unable to serialize ", in.GetTypeName(), + " protocol buffer, perhaps the serialized size (", + proto_size, " bytes) is too large?"); + } out->data = buf; out->length = proto_size; out->data_deallocator = [](void* data, size_t length) { @@ -638,17 +656,17 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, } void RecordMutation(TF_Graph* graph, const TF_Operation& op, - const char* mutation_type) - EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { + const char* mutation_type) { // If any session has already run this node_id, mark this session as // unrunnable. for (auto it : graph->sessions) { + mutex_lock session_lock(it.first->mu); if (it.first->last_num_graph_nodes > op.node.id()) { - it.second = FailedPrecondition( + it.second = strings::StrCat( "Operation '", op.node.DebugString(), "' was changed by ", mutation_type, - " after it was run by a session. Nodes can be mutated " - "only before they are executed by a session. Either don't modify " + " after it was run by a session. This mutation will have no effect, " + "and will trigger an error in the future. Either don't modify " "nodes after running them or create a new session."); } } @@ -708,6 +726,61 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, Status LoadLibrary(const char* library_filename, void** result, const void** buf, size_t* len); +// TODO(josh11b,mrry): Change Session to be able to use a Graph* +// directly, instead of requiring us to serialize to a GraphDef and +// call Session::Extend(). +bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { + if (session->graph != nullptr) { + // Take the graph lock before the session lock to avoid deadlock. This is + // safe since session->graph does not change. + session->graph->mu.lock(); + mutex_lock session_lock(session->mu); + const Graph& graph = session->graph->graph; + + const string& mutation_warning = session->graph->sessions[session]; + if (!mutation_warning.empty()) { + // TODO(b/74949947): turn this back into an error status + LOG(WARNING) << mutation_warning; + session->graph->sessions[session].clear(); + } + + const auto num_nodes = graph.num_node_ids(); + if (session->last_num_graph_nodes < num_nodes) { + status->status = tensorflow::ValidateNoCycles(session->graph->graph); + if (!status->status.ok()) { + session->graph->mu.unlock(); + return false; + } + + GraphDef graph_def; + *graph_def.mutable_versions() = graph.versions(); + // Fill graph_def with nodes with ids in the range + // [session->last_num_graph_nodes, num_nodes), that is the nodes + // added since the last TF_SessionRun() call. + for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) { + Node* const node = graph.FindNodeId(id); + if (node != nullptr && node->IsOp()) { + NodeDef* const node_def = graph_def.add_node(); + *node_def = node->def(); + } + } + *graph_def.mutable_library() = graph.flib_def().ToProto(); + session->graph->mu.unlock(); + status->status = session->session->Extend(graph_def); + if (!status->status.ok()) { + // Contract is we always delete input_values[i]. + return false; + } + // Note: session->session is not modified if Extend() fails, so + // we only set last_num_graph_nodes if it succeeds. + session->last_num_graph_nodes = num_nodes; + } else { + session->graph->mu.unlock(); + } + } + return true; +} + } // namespace tensorflow static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs, @@ -2039,7 +2112,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, for (int i = 0; i < size; ++i) { TensorId id = results.missing_unused_input_map_keys[i]; - tf_results->missing_unused_key_names_data.push_back(id.first.ToString()); + tf_results->missing_unused_key_names_data.push_back(std::string(id.first)); tf_results->missing_unused_key_names[i] = tf_results->missing_unused_key_names_data.back().c_str(); tf_results->missing_unused_key_indexes[i] = id.second; @@ -2408,11 +2481,7 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, // TF_Session functions ---------------------------------------------- TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g) - : session(s), graph(g), last_num_graph_nodes(0), device_mgr(nullptr) { - if (s->LocalDeviceManager(&device_mgr).ok()) { - devices = device_mgr->ListDevices(); - } -} + : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {} TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, TF_Status* status) { @@ -2422,7 +2491,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, TF_Session* new_session = new TF_Session(session, graph); if (graph != nullptr) { mutex_lock l(graph->mu); - graph->sessions[new_session] = Status::OK(); + graph->sessions[new_session] = ""; } return new_session; } else { @@ -2488,7 +2557,7 @@ TF_Session* TF_LoadSessionFromSavedModel( TF_Session* session = new TF_Session(bundle.session.release(), graph); - graph->sessions[session] = Status::OK(); + graph->sessions[session] = ""; session->last_num_graph_nodes = graph->graph.num_node_ids(); return session; #endif // __ANDROID__ @@ -2512,58 +2581,6 @@ void TF_DeleteSession(TF_Session* s, TF_Status* status) { delete s; } -// TODO(josh11b,mrry): Change Session to be able to use a Graph* -// directly, instead of requiring us to serialize to a GraphDef and -// call Session::Extend(). -static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { - if (session->graph != nullptr) { - mutex_lock session_lock(session->mu); - session->graph->mu.lock(); - const Graph& graph = session->graph->graph; - - status->status = session->graph->sessions[session]; - if (!status->status.ok()) { - session->graph->mu.unlock(); - return false; - } - - const auto num_nodes = graph.num_node_ids(); - if (session->last_num_graph_nodes < num_nodes) { - status->status = tensorflow::ValidateNoCycles(session->graph->graph); - if (!status->status.ok()) { - session->graph->mu.unlock(); - return false; - } - - GraphDef graph_def; - *graph_def.mutable_versions() = graph.versions(); - // Fill graph_def with nodes with ids in the range - // [session->last_num_graph_nodes, num_nodes), that is the nodes - // added since the last TF_SessionRun() call. - for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) { - Node* const node = graph.FindNodeId(id); - if (node != nullptr && node->IsOp()) { - NodeDef* const node_def = graph_def.add_node(); - *node_def = node->def(); - } - } - *graph_def.mutable_library() = graph.flib_def().ToProto(); - session->graph->mu.unlock(); - status->status = session->session->Extend(graph_def); - if (!status->status.ok()) { - // Contract is we always delete input_values[i]. - return false; - } - // Note: session->session is not modified if Extend() fails, so - // we only set last_num_graph_nodes if it succeeds. - session->last_num_graph_nodes = num_nodes; - } else { - session->graph->mu.unlock(); - } - } - return true; -} - void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, const TF_Output* outputs, @@ -2573,7 +2590,8 @@ void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, // TODO(josh11b,mrry): Change Session to be able to use a Graph* // directly, instead of requiring us to serialize to a GraphDef and // call Session::Extend(). - if (!ExtendSessionGraphHelper(session, status)) { + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { return; } @@ -2610,7 +2628,8 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, const char** handle, TF_Status* status) { *handle = nullptr; - if (!ExtendSessionGraphHelper(session, status)) { + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { return; } @@ -2653,7 +2672,8 @@ void TF_SessionPRun(TF_Session* session, const char* handle, // TODO(josh11b,mrry): Change Session to be able to use a Graph* // directly, instead of requiring us to serialize to a GraphDef and // call Session::Extend(). - if (!ExtendSessionGraphHelper(session, status)) { + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { return; } @@ -2682,6 +2702,24 @@ void TF_SessionPRun(TF_Session* session, const char* handle, output_values, target_names, nullptr, status); } +unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output, + TF_Tensor** result, TF_Status* status) { + *result = nullptr; + mutex_lock l(graph->mu); + OutputTensor tensor(&output.oper->node, output.index); + bool evaluated; + Tensor result_tensor; + status->status = EvaluateConstantTensor( + tensor, graph->refiner, *graph->graph.op_registry(), + graph->graph.versions().producer(), &evaluated, &result_tensor); + if (evaluated) { + DCHECK(status->status.ok()); + *result = TF_TensorFromTensor(result_tensor, status); + if (!status->status.ok()) evaluated = false; + } + return evaluated; +} + TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) { tensorflow::OpList op_list; if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) { diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index ad592ef70961ef427bfe9fd322a82bd64df7f9f1..c8594347451dffd465d7fa926cc53818dc9e38d4 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -72,7 +72,7 @@ limitations under the License. #ifdef SWIG #define TF_CAPI_EXPORT #else -#if defined(COMPILER_MSVC) +#if defined(_WIN32) #ifdef TF_COMPILE_LIBRARY #define TF_CAPI_EXPORT __declspec(dllexport) #else @@ -80,7 +80,7 @@ limitations under the License. #endif // TF_COMPILE_LIBRARY #else #define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // COMPILER_MSVC +#endif // _WIN32 #endif // SWIG #ifdef __cplusplus @@ -1275,13 +1275,22 @@ TF_CAPI_EXPORT extern void TF_FunctionGetAttrValueProto( // Deleting a function does not remove it from any graphs it was copied to. TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function* func); +// Attempts to evaluate `output`. This will only be possible if `output` doesn't +// depend on any graph inputs (this function is safe to call if this isn't the +// case though). +// +// If the evaluation is successful, this function returns true and `output`s +// value is returned in `result`. Otherwise returns false. An error status is +// returned if something is wrong with the graph or input. Note that this may +// return false even if no error status is set. +TF_CAPI_EXPORT extern unsigned char TF_TryEvaluateConstant(TF_Graph* graph, + TF_Output output, + TF_Tensor** result, + TF_Status* status); + // TODO(josh11b): Register OpDef, available to all operations added // to this graph. -// The following two may both benefit from a subgraph-definition API -// that re-uses most of the graph-definition API. -// TODO(andydavis): Add functions to a graph. - // -------------------------------------------------------------------------- // API for driving Graph execution. @@ -1487,7 +1496,8 @@ TF_CAPI_EXPORT extern int TF_DeviceListCount(const TF_DeviceList* list); // If index is out of bounds, an error code will be set in the status object, // and a null pointer will be returned. TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, - int index, TF_Status*); + int index, + TF_Status* status); // Retrieves the type of the device at the given index. // @@ -1497,14 +1507,15 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, // If index is out of bounds, an error code will be set in the status object, // and a null pointer will be returned. TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list, - int index, TF_Status*); + int index, + TF_Status* status); // Retrieve the amount of memory associated with a given device. // // If index is out of bounds, an error code will be set in the status object, // and -1 will be returned. TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes( - const TF_DeviceList* list, int index, TF_Status*); + const TF_DeviceList* list, int index, TF_Status* status); // -------------------------------------------------------------------------- // Load plugins containing custom ops and kernels diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index be7f85a5bb06dce84579b109d506ded049042b50..95b04f9058afdfaadbc24f0238860279fcd3e800 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -17,8 +17,27 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/platform.h" #include "tensorflow/core/protobuf/config.pb.h" +using tensorflow::FunctionDef; +using tensorflow::Node; +using tensorflow::NodeBuilder; +using tensorflow::Status; + +namespace { +typedef std::unique_ptr + UniqueFuncPtr; +} + +// struct TF_Operation { tensorflow::Node node; }; +static TF_Operation* ToTF_Operation(Node* node) { + return static_cast(static_cast(node)); +} + void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { tensorflow::ConfigProto& config = options->options.config; auto* optimizer_options = @@ -37,3 +56,8402 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF); } } + +const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) { + tensorflow::mutex_lock c(graph->mu); + const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString(); + *len = debug_str.size(); + char* ret = static_cast(malloc(*len + 1)); + memcpy(ret, debug_str.c_str(), *len + 1); + return ret; +} + +// On success, returns a set of TF_Function instances from `text_proto` of +// GraphDef type. These functions must be deleted by calling TF_DeleteFunction. +// +// If `mutate_proto_func` is non-NULL, run it over each FunctionDef proto, +// before creating a TF_Function out of the possibly mutated proto. +static std::vector CreateFunctionsFromTextProto( + const char* text_proto, + std::function* mutate_proto_func, TF_Status* status) { + tensorflow::GraphDef gdef; + if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &gdef)) { + status->status = tensorflow::errors::Internal( + "Invalid text proto for GraphDef: ", text_proto); + return {}; + } + const auto& fdef_lib = gdef.library(); + if (fdef_lib.gradient_size() > 0) { + status->status = tensorflow::errors::Internal( + "GradientDef is not supported in reading Dataset related functions: ", + text_proto); + return {}; + } + std::vector ret; + for (const FunctionDef& fdef : fdef_lib.function()) { + // Make a copy so that we can mutate it. + FunctionDef fdef_to_load = fdef; + if (mutate_proto_func) { + (*mutate_proto_func)(&fdef_to_load); + } + VLOG(1) << "Adding func to graph: " << fdef_to_load.DebugString(); + std::vector binary_proto_buf(fdef_to_load.ByteSizeLong()); + fdef_to_load.SerializeToArray(binary_proto_buf.data(), + binary_proto_buf.size()); + TF_Function* func = TF_FunctionImportFunctionDef( + binary_proto_buf.data(), binary_proto_buf.size(), status); + if (!status->status.ok()) return {}; + ret.push_back(UniqueFuncPtr(func, TF_DeleteFunction)); + } + return ret; +} + +// On success, returns a newly created TF_Function instance encoding a dataset +// node stack that returns a sequence of 3 floats, and sets `dataset_name` to +// the created dataset name. The returned function must be deleted by calling +// TF_DeleteFunction. +static UniqueFuncPtr CreateFakeDatasetFunction(std::string* dataset_name, + TF_Status* status) { + const char* func_def = R"PREFIX( +library { + function { + signature { + name: "_make_dataset_d8de2712" + output_arg { + name: "TensorSliceDataset" + type: DT_VARIANT + } + is_stateful: true + } + node_def { + name: "TensorSliceDataset/tensors/component_0" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 3 + } + } + tensor_content: "\000\000(B\000\000,B\000\0000B" + } + } + } + } + node_def { + name: "TensorSliceDataset" + op: "TensorSliceDataset" + input: "TensorSliceDataset/tensors/component_0:output:0" + attr { + key: "Toutput_types" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + } + ret { + key: "TensorSliceDataset" + value: "TensorSliceDataset:handle:0" + } + } +} +)PREFIX"; + + *dataset_name = "_make_dataset_d8de2712"; + auto functions = CreateFunctionsFromTextProto( + func_def, /*mutate_proto_func*/ nullptr, status); + DCHECK_EQ(functions.size(), 1); + return std::move(functions[0]); +} + +#if not defined(PLATFORM_WINDOWS) +// On success, returns a set of TF_Function instances encoding a dataset +// node stack that reads a Imagenet TFRecordFile dataset from `file_path`, and +// sets `dataset_name` to the created dataset name. The returned functions must +// be deleted by calling TF_DeleteFunction. +static std::vector CreateImagenetDatasetFunctions( + const char* file_path, std::string* dataset_name, TF_Status* status) { +#if defined(PLATFORM_WINDOWS) + status->status = tensorflow::errors::Unimplemented( + "TF_MakeFileBasedIteratorGetNextWithDatasets in the experimental C API " + "is not implemented for Windows"); + return std::vector(); +#else + const char* func_def = R"PREFIX( +library { + function { + signature { + name: "tf_map_func_91295dea" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "FlatMapDataset" + type: DT_VARIANT + } + description: "A wrapper for Defun that facilitates shape inference." + is_stateful: true + } + node_def { + name: "flat_filenames/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } + } + node_def { + name: "flat_filenames" + op: "Reshape" + input: "arg0" + input: "flat_filenames/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "TensorSliceDataset" + op: "TensorSliceDataset" + input: "flat_filenames:output:0" + attr { + key: "Toutput_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "FlatMapDataset" + op: "FlatMapDataset" + input: "TensorSliceDataset:handle:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_0cc8c35b" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + } + ret { + key: "FlatMapDataset" + value: "FlatMapDataset:handle:0" + } + } + function { + signature { + name: "tf_map_func_0cc8c35b" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "TFRecordDataset" + type: DT_VARIANT + } + description: "A wrapper for Defun that facilitates shape inference." + is_stateful: true + } + node_def { + name: "compression_type" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "" + } + } + } + } + node_def { + name: "buffer_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 8388608 + } + } + } + } + node_def { + name: "TFRecordDataset" + op: "TFRecordDataset" + input: "arg0" + input: "compression_type:output:0" + input: "buffer_size:output:0" + } + ret { + key: "TFRecordDataset" + value: "TFRecordDataset:handle:0" + } + } + function { + signature { + name: "tf_map_func_74b6b15c" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "Reshape_1" + type: DT_FLOAT + } + output_arg { + name: "sub_1" + type: DT_INT32 + } + description: "A wrapper for Defun that facilitates shape inference." + is_stateful: true + } + node_def { + name: "ParseSingleExample/key_image/class/label" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -1 + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape" + op: "Reshape" + input: "ParseSingleExample/key_image/class/label:output:0" + input: "ParseSingleExample/Reshape/shape:output:0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ParseSingleExample/key_image/class/text" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "" + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_1/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_1" + op: "Reshape" + input: "ParseSingleExample/key_image/class/text:output:0" + input: "ParseSingleExample/Reshape_1/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ParseSingleExample/key_image/encoded" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "" + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_2/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_2" + op: "Reshape" + input: "ParseSingleExample/key_image/encoded:output:0" + input: "ParseSingleExample/Reshape_2/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ParseSingleExample/key_image/format" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "jpeg" + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_3/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_3" + op: "Reshape" + input: "ParseSingleExample/key_image/format:output:0" + input: "ParseSingleExample/Reshape_3/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ParseSingleExample/ParseSingleExample" + op: "ParseSingleExample" + input: "arg0" + input: "ParseSingleExample/Reshape:output:0" + input: "ParseSingleExample/Reshape_1:output:0" + input: "ParseSingleExample/Reshape_2:output:0" + input: "ParseSingleExample/Reshape_3:output:0" + attr { + key: "Tdense" + value { + list { + type: DT_INT64 + type: DT_STRING + type: DT_STRING + type: DT_STRING + } + } + } + attr { + key: "dense_keys" + value { + list { + s: "image/class/label" + s: "image/class/text" + s: "image/encoded" + s: "image/format" + } + } + } + attr { + key: "dense_shapes" + value { + list { + shape { + } + shape { + } + shape { + } + shape { + } + } + } + } + attr { + key: "num_sparse" + value { + i: 5 + } + } + attr { + key: "sparse_keys" + value { + list { + s: "image/object/bbox/xmax" + s: "image/object/bbox/xmin" + s: "image/object/bbox/ymax" + s: "image/object/bbox/ymin" + s: "image/object/class/label" + } + } + } + attr { + key: "sparse_types" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + type: DT_INT64 + } + } + } + } + node_def { + name: "Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "Reshape" + op: "Reshape" + input: "ParseSingleExample/ParseSingleExample:dense_values:2" + input: "Reshape/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/Substr/pos" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node_def { + name: "decode_image/Substr/len" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/Substr" + op: "Substr" + input: "Reshape:output:0" + input: "decode_image/Substr/pos:output:0" + input: "decode_image/Substr/len:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/is_jpeg/Substr/pos" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node_def { + name: "decode_image/is_jpeg/Substr/len" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/is_jpeg/Substr" + op: "Substr" + input: "Reshape:output:0" + input: "decode_image/is_jpeg/Substr/pos:output:0" + input: "decode_image/is_jpeg/Substr/len:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/is_jpeg/Equal/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "\377\330\377" + } + } + } + } + node_def { + name: "decode_image/is_jpeg/Equal" + op: "Equal" + input: "decode_image/is_jpeg/Substr:output:0" + input: "decode_image/is_jpeg/Equal/y:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + } + node_def { + name: "decode_image/cond_jpeg/Switch" + op: "Switch" + input: "decode_image/is_jpeg/Equal:z:0" + input: "decode_image/is_jpeg/Equal:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/switch_t" + op: "Identity" + input: "decode_image/cond_jpeg/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/switch_f" + op: "Identity" + input: "decode_image/cond_jpeg/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/pred_id" + op: "Identity" + input: "decode_image/is_jpeg/Equal:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/check_jpeg_channels/x" + op: "Const" + input: "^decode_image/cond_jpeg/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/check_jpeg_channels/y" + op: "Const" + input: "^decode_image/cond_jpeg/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 4 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/check_jpeg_channels" + op: "NotEqual" + input: "decode_image/cond_jpeg/check_jpeg_channels/x:output:0" + input: "decode_image/cond_jpeg/check_jpeg_channels/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/Assert/Const" + op: "Const" + input: "^decode_image/cond_jpeg/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 1, 3) when decoding JPEG images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/Assert/Assert/data_0" + op: "Const" + input: "^decode_image/cond_jpeg/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 1, 3) when decoding JPEG images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/Assert/Assert" + op: "Assert" + input: "decode_image/cond_jpeg/check_jpeg_channels:z:0" + input: "decode_image/cond_jpeg/Assert/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "decode_image/cond_jpeg/DecodeJpeg" + op: "DecodeJpeg" + input: "decode_image/cond_jpeg/DecodeJpeg/Switch:output_true:0" + input: "^decode_image/cond_jpeg/Assert/Assert" + attr { + key: "acceptable_fraction" + value { + f: 1.0 + } + } + attr { + key: "channels" + value { + i: 3 + } + } + attr { + key: "dct_method" + value { + s: "" + } + } + attr { + key: "fancy_upscaling" + value { + b: true + } + } + attr { + key: "ratio" + value { + i: 1 + } + } + attr { + key: "try_recover_truncated" + value { + b: false + } + } + } + node_def { + name: "decode_image/cond_jpeg/DecodeJpeg/Switch" + op: "Switch" + input: "Reshape:output:0" + input: "decode_image/cond_jpeg/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/is_png/y" + op: "Const" + input: "^decode_image/cond_jpeg/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "\211PN" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/is_png" + op: "Equal" + input: "decode_image/cond_jpeg/is_png/Switch:output_false:0" + input: "decode_image/cond_jpeg/is_png/y:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + } + node_def { + name: "decode_image/cond_jpeg/is_png/Switch" + op: "Switch" + input: "decode_image/Substr:output:0" + input: "decode_image/cond_jpeg/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@decode_image/Substr" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/is_png:z:0" + input: "decode_image/cond_jpeg/is_png:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/switch_t" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/switch_f" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/pred_id" + op: "Identity" + input: "decode_image/cond_jpeg/is_png:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/DecodePng" + op: "DecodePng" + input: "decode_image/cond_jpeg/cond_png/DecodePng/Switch_1:output_true:0" + attr { + key: "channels" + value { + i: 3 + } + } + attr { + key: "dtype" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/DecodePng/Switch" + op: "Switch" + input: "Reshape:output:0" + input: "decode_image/cond_jpeg/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/DecodePng/Switch_1" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/DecodePng/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/is_gif/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "GIF" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/is_gif" + op: "Equal" + input: "decode_image/cond_jpeg/cond_png/is_gif/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/is_gif/y:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/is_gif/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/is_png/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@decode_image/Substr" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/is_gif:z:0" + input: "decode_image/cond_jpeg/cond_png/is_gif:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/pred_id" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/is_gif:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/x" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels" + op: "NotEqual" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/x:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/x" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 4 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1" + op: "NotEqual" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/x:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/LogicalAnd" + op: "LogicalAnd" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels:z:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1:z:0" + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Const" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 3) when decoding GIF images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert/data_0" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 3) when decoding GIF images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert" + op: "Assert" + input: "decode_image/cond_jpeg/cond_png/cond_gif/LogicalAnd:z:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif" + op: "DecodeGif" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch_1:output_true:0" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert" + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/DecodePng/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch_1" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/pos" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/len" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr" + op: "Substr" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/pos:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/len:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "BM" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp" + op: "Equal" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp/y:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Const" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Unable to decode bytes as JPEG, PNG, GIF, or BMP" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert/data_0" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Unable to decode bytes as JPEG, PNG, GIF, or BMP" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert" + op: "Assert" + input: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp:z:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/x" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels" + op: "NotEqual" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/x:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Const" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 3) when decoding BMP images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert/data_0" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 3) when decoding BMP images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert" + op: "Assert" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels:z:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeBmp" + op: "DecodeBmp" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/Switch:output_false:0" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert" + attr { + key: "channels" + value { + i: 0 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Merge" + op: "Merge" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeBmp:image:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif:image:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/Merge" + op: "Merge" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Merge:output:0" + input: "decode_image/cond_jpeg/cond_png/DecodePng:image:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "decode_image/cond_jpeg/Merge" + op: "Merge" + input: "decode_image/cond_jpeg/cond_png/Merge:output:0" + input: "decode_image/cond_jpeg/DecodeJpeg:image:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "convert_image/Cast" + op: "Cast" + input: "decode_image/cond_jpeg/Merge:output:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "convert_image/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.00392156885937 + } + } + } + } + node_def { + name: "convert_image" + op: "Mul" + input: "convert_image/Cast:y:0" + input: "convert_image/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 4 + } + } + tensor_content: "\000\000\000\000\000\000\000\000\000\000\200?\000\000\200?" + } + } + } + } + node_def { + name: "distorted_bounding_box_crop/Shape" + op: "Shape" + input: "convert_image:z:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2/min_object_covered" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.10000000149 + } + } + } + } + node_def { + name: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2" + op: "SampleDistortedBoundingBoxV2" + input: "distorted_bounding_box_crop/Shape:output:0" + input: "Const:output:0" + input: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2/min_object_covered:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "area_range" + value { + list { + f: 0.0799999982119 + f: 1.0 + } + } + } + attr { + key: "aspect_ratio_range" + value { + list { + f: 0.75 + f: 1.33333337307 + } + } + } + attr { + key: "max_attempts" + value { + i: 1 + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + attr { + key: "use_image_if_no_bounding_boxes" + value { + b: true + } + } + } + node_def { + name: "distorted_bounding_box_crop/Slice" + op: "Slice" + input: "convert_image:z:0" + input: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2:begin:0" + input: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2:size:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Shape" + op: "Shape" + input: "convert_image:z:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Shape_1" + op: "Shape" + input: "distorted_bounding_box_crop/Slice:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Equal" + op: "Equal" + input: "Shape:output:0" + input: "Shape_1:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Cast" + op: "Cast" + input: "Equal:z:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_BOOL + } + } + } + node_def { + name: "Const_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "Sum" + op: "Sum" + input: "Cast:y:0" + input: "Const_1:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node_def { + name: "GreaterEqual/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "GreaterEqual" + op: "GreaterEqual" + input: "Sum:output:0" + input: "GreaterEqual/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/Switch" + op: "Switch" + input: "GreaterEqual:z:0" + input: "GreaterEqual:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/switch_t" + op: "Identity" + input: "cond/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/switch_f" + op: "Identity" + input: "cond/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/pred_id" + op: "Identity" + input: "GreaterEqual:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/Shape" + op: "Shape" + input: "cond/Shape/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/Shape/Switch" + op: "Switch" + input: "convert_image:z:0" + input: "cond/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@convert_image" + } + } + } + } + node_def { + name: "cond/Cast" + op: "Cast" + input: "cond/Shape:output:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/strided_slice/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice" + op: "StridedSlice" + input: "cond/Cast:y:0" + input: "cond/strided_slice/stack:output:0" + input: "cond/strided_slice/stack_1:output:0" + input: "cond/strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/strided_slice_1/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_1/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/strided_slice_1/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_1" + op: "StridedSlice" + input: "cond/Cast:y:0" + input: "cond/strided_slice_1/stack:output:0" + input: "cond/strided_slice_1/stack_1:output:0" + input: "cond/strided_slice_1/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/Greater" + op: "Greater" + input: "cond/strided_slice:output:0" + input: "cond/strided_slice_1:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/Switch" + op: "Switch" + input: "cond/Greater:z:0" + input: "cond/Greater:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/cond/switch_t" + op: "Identity" + input: "cond/cond/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/cond/switch_f" + op: "Identity" + input: "cond/cond/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/cond/pred_id" + op: "Identity" + input: "cond/Greater:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/cond/strided_slice/stack" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/cond/strided_slice/stack_1" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice/stack_2" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice" + op: "StridedSlice" + input: "cond/cond/strided_slice/Switch:output_true:0" + input: "cond/cond/strided_slice/stack:output:0" + input: "cond/cond/strided_slice/stack_1:output:0" + input: "cond/cond/strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/cond/strided_slice/Switch" + op: "Switch" + input: "cond/Cast:y:0" + input: "cond/cond/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@cond/Cast" + } + } + } + } + node_def { + name: "cond/cond/strided_slice_1/stack" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_1/stack_1" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_1/stack_2" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_1" + op: "StridedSlice" + input: "cond/cond/strided_slice/Switch:output_true:0" + input: "cond/cond/strided_slice_1/stack:output:0" + input: "cond/cond/strided_slice_1/stack_1:output:0" + input: "cond/cond/strided_slice_1/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/cond/truediv" + op: "RealDiv" + input: "cond/cond/strided_slice:output:0" + input: "cond/cond/strided_slice_1:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/mul/y" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 224.0 + } + } + } + } + node_def { + name: "cond/cond/mul" + op: "Mul" + input: "cond/cond/truediv:z:0" + input: "cond/cond/mul/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/Cast/x/1" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 224.0 + } + } + } + } + node_def { + name: "cond/cond/Cast/x" + op: "Pack" + input: "cond/cond/mul:z:0" + input: "cond/cond/Cast/x/1:output:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/cond/Cast" + op: "Cast" + input: "cond/cond/Cast/x:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/strided_slice_2/stack" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_2/stack_1" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_2/stack_2" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_2" + op: "StridedSlice" + input: "cond/cond/strided_slice_2/Switch:output_false:0" + input: "cond/cond/strided_slice_2/stack:output:0" + input: "cond/cond/strided_slice_2/stack_1:output:0" + input: "cond/cond/strided_slice_2/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/cond/strided_slice_2/Switch" + op: "Switch" + input: "cond/Cast:y:0" + input: "cond/cond/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@cond/Cast" + } + } + } + } + node_def { + name: "cond/cond/strided_slice_3/stack" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_3/stack_1" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_3/stack_2" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_3" + op: "StridedSlice" + input: "cond/cond/strided_slice_2/Switch:output_false:0" + input: "cond/cond/strided_slice_3/stack:output:0" + input: "cond/cond/strided_slice_3/stack_1:output:0" + input: "cond/cond/strided_slice_3/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/cond/truediv_1" + op: "RealDiv" + input: "cond/cond/strided_slice_2:output:0" + input: "cond/cond/strided_slice_3:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/mul_1/y" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 224.0 + } + } + } + } + node_def { + name: "cond/cond/mul_1" + op: "Mul" + input: "cond/cond/truediv_1:z:0" + input: "cond/cond/mul_1/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/Cast_1/x/0" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 224.0 + } + } + } + } + node_def { + name: "cond/cond/Cast_1/x" + op: "Pack" + input: "cond/cond/Cast_1/x/0:output:0" + input: "cond/cond/mul_1:z:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/cond/Cast_1" + op: "Cast" + input: "cond/cond/Cast_1/x:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/Merge" + op: "Merge" + input: "cond/cond/Cast_1:y:0" + input: "cond/cond/Cast:y:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/ResizeBicubic/images" + op: "Pack" + input: "cond/Shape/Switch:output_true:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/ResizeBicubic" + op: "ResizeBicubic" + input: "cond/ResizeBicubic/images:output:0" + input: "cond/cond/Merge:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "align_corners" + value { + b: false + } + } + } + node_def { + name: "cond/strided_slice_2/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice_2/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_2/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_2" + op: "StridedSlice" + input: "cond/ResizeBicubic:resized_images:0" + input: "cond/strided_slice_2/stack:output:0" + input: "cond/strided_slice_2/stack_1:output:0" + input: "cond/strided_slice_2/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/Shape_1" + op: "Shape" + input: "cond/strided_slice_2:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/strided_slice_3/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice_3/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_3/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_3" + op: "StridedSlice" + input: "cond/Shape_1:output:0" + input: "cond/strided_slice_3/stack:output:0" + input: "cond/strided_slice_3/stack_1:output:0" + input: "cond/strided_slice_3/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/Shape_2" + op: "Shape" + input: "cond/strided_slice_2:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/strided_slice_4/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_4/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/strided_slice_4/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_4" + op: "StridedSlice" + input: "cond/Shape_2:output:0" + input: "cond/strided_slice_4/stack:output:0" + input: "cond/strided_slice_4/stack_1:output:0" + input: "cond/strided_slice_4/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/sub/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/sub" + op: "Sub" + input: "cond/strided_slice_3:output:0" + input: "cond/sub/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/add/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/add" + op: "Add" + input: "cond/sub:z:0" + input: "cond/add/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/truediv/Cast" + op: "Cast" + input: "cond/add:z:0" + attr { + key: "DstT" + value { + type: DT_DOUBLE + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv/Cast_1" + op: "Cast" + input: "cond/truediv/y:output:0" + attr { + key: "DstT" + value { + type: DT_DOUBLE + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv" + op: "RealDiv" + input: "cond/truediv/Cast:y:0" + input: "cond/truediv/Cast_1:y:0" + attr { + key: "T" + value { + type: DT_DOUBLE + } + } + } + node_def { + name: "cond/sub_1/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/sub_1" + op: "Sub" + input: "cond/strided_slice_4:output:0" + input: "cond/sub_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/add_1/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/add_1" + op: "Add" + input: "cond/sub_1:z:0" + input: "cond/add_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv_1/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/truediv_1/Cast" + op: "Cast" + input: "cond/add_1:z:0" + attr { + key: "DstT" + value { + type: DT_DOUBLE + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv_1/Cast_1" + op: "Cast" + input: "cond/truediv_1/y:output:0" + attr { + key: "DstT" + value { + type: DT_DOUBLE + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv_1" + op: "RealDiv" + input: "cond/truediv_1/Cast:y:0" + input: "cond/truediv_1/Cast_1:y:0" + attr { + key: "T" + value { + type: DT_DOUBLE + } + } + } + node_def { + name: "cond/Shape_3" + op: "Shape" + input: "cond/strided_slice_2:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/Rank" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "cond/Equal/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "cond/Equal" + op: "Equal" + input: "cond/Rank:output:0" + input: "cond/Equal/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/Assert/Const" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Rank of image must be equal to 3." + } + } + } + } + node_def { + name: "cond/Assert/Assert/data_0" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Rank of image must be equal to 3." + } + } + } + } + node_def { + name: "cond/Assert/Assert" + op: "Assert" + input: "cond/Equal:z:0" + input: "cond/Assert/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "cond/strided_slice_5/stack" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/strided_slice_5/stack_1" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 3 + } + } + } + } + node_def { + name: "cond/strided_slice_5/stack_2" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_5" + op: "StridedSlice" + input: "cond/Shape_3:output:0" + input: "cond/strided_slice_5/stack:output:0" + input: "cond/strided_slice_5/stack_1:output:0" + input: "cond/strided_slice_5/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/stack/0" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/stack/1" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/stack" + op: "Pack" + input: "cond/stack/0:output:0" + input: "cond/stack/1:output:0" + input: "cond/strided_slice_5:output:0" + attr { + key: "N" + value { + i: 3 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/strided_slice_6/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice_6/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_6/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_6" + op: "StridedSlice" + input: "cond/Shape_3:output:0" + input: "cond/strided_slice_6/stack:output:0" + input: "cond/strided_slice_6/stack_1:output:0" + input: "cond/strided_slice_6/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/GreaterEqual/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/GreaterEqual" + op: "GreaterEqual" + input: "cond/strided_slice_6:output:0" + input: "cond/GreaterEqual/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/strided_slice_7/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_7/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/strided_slice_7/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_7" + op: "StridedSlice" + input: "cond/Shape_3:output:0" + input: "cond/strided_slice_7/stack:output:0" + input: "cond/strided_slice_7/stack_1:output:0" + input: "cond/strided_slice_7/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/GreaterEqual_1/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/GreaterEqual_1" + op: "GreaterEqual" + input: "cond/strided_slice_7:output:0" + input: "cond/GreaterEqual_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/LogicalAnd" + op: "LogicalAnd" + input: "cond/GreaterEqual:z:0" + input: "cond/GreaterEqual_1:z:0" + } + node_def { + name: "cond/Assert_1/Const" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Crop size greater than the image size." + } + } + } + } + node_def { + name: "cond/Assert_1/Assert/data_0" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Crop size greater than the image size." + } + } + } + } + node_def { + name: "cond/Assert_1/Assert" + op: "Assert" + input: "cond/LogicalAnd:z:0" + input: "cond/Assert_1/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "cond/stack_1/2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_DOUBLE + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_DOUBLE + tensor_shape { + } + double_val: 0.0 + } + } + } + } + node_def { + name: "cond/stack_1" + op: "Pack" + input: "cond/truediv:z:0" + input: "cond/truediv_1:z:0" + input: "cond/stack_1/2:output:0" + attr { + key: "N" + value { + i: 3 + } + } + attr { + key: "T" + value { + type: DT_DOUBLE + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/ToInt32" + op: "Cast" + input: "cond/stack_1:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_DOUBLE + } + } + } + node_def { + name: "cond/Slice" + op: "Slice" + input: "cond/strided_slice_2:output:0" + input: "cond/ToInt32:y:0" + input: "cond/stack:output:0" + input: "^cond/Assert_1/Assert" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/Reshape" + op: "Reshape" + input: "cond/Slice:output:0" + input: "cond/stack:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/ResizeBicubic_1/images" + op: "Pack" + input: "cond/ResizeBicubic_1/images/Switch:output_false:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/ResizeBicubic_1/images/Switch" + op: "Switch" + input: "distorted_bounding_box_crop/Slice:output:0" + input: "cond/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@distorted_bounding_box_crop/Slice" + } + } + } + } + node_def { + name: "cond/ResizeBicubic_1/size" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\340\000\000\000\340\000\000\000" + } + } + } + } + node_def { + name: "cond/ResizeBicubic_1" + op: "ResizeBicubic" + input: "cond/ResizeBicubic_1/images:output:0" + input: "cond/ResizeBicubic_1/size:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "align_corners" + value { + b: false + } + } + } + node_def { + name: "cond/strided_slice_8/stack" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice_8/stack_1" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_8/stack_2" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_8" + op: "StridedSlice" + input: "cond/ResizeBicubic_1:resized_images:0" + input: "cond/strided_slice_8/stack:output:0" + input: "cond/strided_slice_8/stack_1:output:0" + input: "cond/strided_slice_8/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/Merge" + op: "Merge" + input: "cond/strided_slice_8:output:0" + input: "cond/Reshape:output:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Const_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 3 + } + } + tensor_content: "\354Q\370>\325x\351>;\337\317>" + } + } + } + } + node_def { + name: "sub" + op: "Sub" + input: "cond/Merge:output:0" + input: "Const_2:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Const_3" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 3 + } + } + tensor_content: "\372~j>B`e>fff>" + } + } + } + } + node_def { + name: "truediv" + op: "RealDiv" + input: "sub:z:0" + input: "Const_3:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/control_dependency" + op: "Identity" + input: "truediv:z:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@truediv" + } + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/min" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/max" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/RandomUniform" + op: "RandomUniform" + input: "random_flip_left_right/random_uniform/shape:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/sub" + op: "Sub" + input: "random_flip_left_right/random_uniform/max:output:0" + input: "random_flip_left_right/random_uniform/min:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/mul" + op: "Mul" + input: "random_flip_left_right/random_uniform/RandomUniform:output:0" + input: "random_flip_left_right/random_uniform/sub:z:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/random_uniform" + op: "Add" + input: "random_flip_left_right/random_uniform/mul:z:0" + input: "random_flip_left_right/random_uniform/min:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/Less/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } + } + } + } + node_def { + name: "random_flip_left_right/Less" + op: "Less" + input: "random_flip_left_right/random_uniform:z:0" + input: "random_flip_left_right/Less/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/Switch" + op: "Switch" + input: "random_flip_left_right/Less:z:0" + input: "random_flip_left_right/Less:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "random_flip_left_right/switch_t" + op: "Identity" + input: "random_flip_left_right/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "random_flip_left_right/switch_f" + op: "Identity" + input: "random_flip_left_right/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "random_flip_left_right/pred_id" + op: "Identity" + input: "random_flip_left_right/Less:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "random_flip_left_right/ReverseV2/axis" + op: "Const" + input: "^random_flip_left_right/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "random_flip_left_right/ReverseV2" + op: "ReverseV2" + input: "random_flip_left_right/ReverseV2/Switch:output_true:0" + input: "random_flip_left_right/ReverseV2/axis:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + } + node_def { + name: "random_flip_left_right/ReverseV2/Switch" + op: "Switch" + input: "random_flip_left_right/control_dependency:output:0" + input: "random_flip_left_right/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@truediv" + } + } + } + } + node_def { + name: "random_flip_left_right/Switch_1" + op: "Switch" + input: "random_flip_left_right/control_dependency:output:0" + input: "random_flip_left_right/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@truediv" + } + } + } + } + node_def { + name: "random_flip_left_right/Merge" + op: "Merge" + input: "random_flip_left_right/Switch_1:output_false:0" + input: "random_flip_left_right/ReverseV2:output:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Reshape_1/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 3 + } + } + tensor_content: "\340\000\000\000\340\000\000\000\003\000\000\000" + } + } + } + } + node_def { + name: "Reshape_1" + op: "Reshape" + input: "random_flip_left_right/Merge:output:0" + input: "Reshape_1/shape:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Reshape_2/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "Reshape_2" + op: "Reshape" + input: "ParseSingleExample/ParseSingleExample:dense_values:0" + input: "Reshape_2/shape:output:0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Cast_1" + op: "Cast" + input: "Reshape_2:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_INT64 + } + } + } + node_def { + name: "sub_1/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "sub_1" + op: "Sub" + input: "Cast_1:y:0" + input: "sub_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + ret { + key: "Reshape_1" + value: "Reshape_1:output:0" + } + ret { + key: "sub_1" + value: "sub_1:z:0" + } + } + function { + signature { + name: "tf_predicate_7089b845" + input_arg { + name: "arg0" + type: DT_FLOAT + } + input_arg { + name: "arg1" + type: DT_INT32 + } + input_arg { + name: "Equal/Placeholder" + type: DT_INT64 + } + output_arg { + name: "Equal" + type: DT_BOOL + } + description: "A wrapper for Defun that facilitates shape inference." + } + node_def { + name: "Shape" + op: "Shape" + input: "arg0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT64 + } + } + } + node_def { + name: "strided_slice/stack" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "strided_slice/stack_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "strided_slice/stack_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "strided_slice" + op: "StridedSlice" + input: "Shape:output:0" + input: "strided_slice/stack:output:0" + input: "strided_slice/stack_1:output:0" + input: "strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "Equal" + op: "Equal" + input: "strided_slice:output:0" + input: "Equal/Placeholder" + attr { + key: "T" + value { + type: DT_INT64 + } + } + } + ret { + key: "Equal" + value: "Equal:z:0" + } + } + function { + signature { + name: "_make_dataset_5fa5e1f4" + output_arg { + name: "PrefetchDataset_1" + type: DT_VARIANT + } + is_stateful: true + } + node_def { + name: "TensorSliceDataset/MatchingFiles/pattern" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "$(DATA_DIR)" + } + } + } + } + node_def { + name: "TensorSliceDataset/MatchingFiles" + op: "MatchingFiles" + input: "TensorSliceDataset/MatchingFiles/pattern:output:0" + } + node_def { + name: "TensorSliceDataset" + op: "TensorSliceDataset" + input: "TensorSliceDataset/MatchingFiles:filenames:0" + attr { + key: "Toutput_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "ShuffleDataset/MatchingFiles/pattern" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "$(DATA_DIR)" + } + } + } + } + node_def { + name: "ShuffleDataset/MatchingFiles" + op: "MatchingFiles" + input: "ShuffleDataset/MatchingFiles/pattern:output:0" + } + node_def { + name: "ShuffleDataset/Shape" + op: "Shape" + input: "ShuffleDataset/MatchingFiles:filenames:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "out_type" + value { + type: DT_INT64 + } + } + } + node_def { + name: "ShuffleDataset/strided_slice/stack" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset/strided_slice/stack_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "ShuffleDataset/strided_slice/stack_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "ShuffleDataset/strided_slice" + op: "StridedSlice" + input: "ShuffleDataset/Shape:output:0" + input: "ShuffleDataset/strided_slice/stack:output:0" + input: "ShuffleDataset/strided_slice/stack_1:output:0" + input: "ShuffleDataset/strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "ShuffleDataset/Maximum/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } + } + node_def { + name: "ShuffleDataset/Maximum" + op: "Maximum" + input: "ShuffleDataset/strided_slice:output:0" + input: "ShuffleDataset/Maximum/y:output:0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + } + node_def { + name: "ShuffleDataset/seed" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset/seed2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset" + op: "ShuffleDataset" + input: "TensorSliceDataset:handle:0" + input: "ShuffleDataset/Maximum:z:0" + input: "ShuffleDataset/seed:output:0" + input: "ShuffleDataset/seed2:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "reshuffle_each_iteration" + value { + b: true + } + } + } + node_def { + name: "ShuffleDataset_1/buffer_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1024 + } + } + } + } + node_def { + name: "ShuffleDataset_1/seed_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset_1/seed2_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset_1" + op: "ShuffleDataset" + input: "ShuffleDataset:handle:0" + input: "ShuffleDataset_1/buffer_size:output:0" + input: "ShuffleDataset_1/seed_1:output:0" + input: "ShuffleDataset_1/seed2_1:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "reshuffle_each_iteration" + value { + b: true + } + } + } + node_def { + name: "RepeatDataset/count" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -1 + } + } + } + } + node_def { + name: "RepeatDataset" + op: "RepeatDataset" + input: "ShuffleDataset_1:handle:0" + input: "RepeatDataset/count:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/cycle_length" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 8 + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/block_length" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/sloppy" + op: "Const" + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_BOOL + tensor_shape { + } + bool_val: true + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/buffer_output_elements" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 2 + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/prefetch_input_elements" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 16 + } + } + } + } + node_def { + name: "ParallelInterleaveDataset" + op: "ParallelInterleaveDataset" + input: "RepeatDataset:handle:0" + input: "ParallelInterleaveDataset/cycle_length:output:0" + input: "ParallelInterleaveDataset/block_length:output:0" + input: "ParallelInterleaveDataset/sloppy:output:0" + input: "ParallelInterleaveDataset/buffer_output_elements:output:0" + input: "ParallelInterleaveDataset/prefetch_input_elements:output:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_91295dea" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + } + node_def { + name: "ShuffleDataset_2/buffer_size_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1024 + } + } + } + } + node_def { + name: "ShuffleDataset_2/seed_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset_2/seed2_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset_2" + op: "ShuffleDataset" + input: "ParallelInterleaveDataset:handle:0" + input: "ShuffleDataset_2/buffer_size_1:output:0" + input: "ShuffleDataset_2/seed_2:output:0" + input: "ShuffleDataset_2/seed2_2:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "reshuffle_each_iteration" + value { + b: true + } + } + } + node_def { + name: "ParallelMapDataset/num_parallel_calls" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 64 + } + } + } + } + node_def { + name: "ParallelMapDataset" + op: "ParallelMapDataset" + input: "ShuffleDataset_2:handle:0" + input: "ParallelMapDataset/num_parallel_calls:output:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_74b6b15c" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "PrefetchDataset/buffer_size_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 64 + } + } + } + } + node_def { + name: "PrefetchDataset" + op: "PrefetchDataset" + input: "ParallelMapDataset:handle:0" + input: "PrefetchDataset/buffer_size_2:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "BatchDataset/batch_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 64 + } + } + } + } + node_def { + name: "BatchDataset" + op: "BatchDataset" + input: "PrefetchDataset:handle:0" + input: "BatchDataset/batch_size:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "FilterDataset/batch_size_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 64 + } + } + } + } + node_def { + name: "FilterDataset" + op: "FilterDataset" + input: "BatchDataset:handle:0" + input: "FilterDataset/batch_size_1:output:0" + attr { + key: "Targuments" + value { + list { + type: DT_INT64 + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + attr { + key: "predicate" + value { + func { + name: "tf_predicate_7089b845" + } + } + } + } + node_def { + name: "PrefetchDataset_1/buffer_size_3" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 2 + } + } + } + } + node_def { + name: "PrefetchDataset_1" + op: "PrefetchDataset" + input: "FilterDataset:handle:0" + input: "PrefetchDataset_1/buffer_size_3:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 64 + } + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + dim { + size: 64 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + ret { + key: "PrefetchDataset_1" + value: "PrefetchDataset_1:handle:0" + } + } +} +)PREFIX"; + + *dataset_name = "_make_dataset_5fa5e1f4"; + std::function mutate_proto_func = + [dataset_name, file_path](FunctionDef* fdef) { + VLOG(1) << "Processsing function " << fdef->DebugString(); + if (std::string(fdef->signature().name()) != *dataset_name) return; + // Change the input file pattern to `file_path`. + bool found = false; + for (auto& node_def : *fdef->mutable_node_def()) { + if (node_def.name() != "TensorSliceDataset/MatchingFiles/pattern" && + node_def.name() != "ShuffleDataset/MatchingFiles/pattern") + continue; + DCHECK_EQ(node_def.op(), "Const"); + DCHECK_GT(node_def.attr().count("value"), 0); + found = true; + DCHECK_EQ(node_def.attr().at("value").tensor().string_val(0), + "$(DATA_DIR)"); + VLOG(1) << "Setting the value of node_def " + "TensorSliceDataset/MatchingFiles/pattern to " + << file_path; + auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor->clear_string_val(); + tensor->add_string_val(file_path); + } + VLOG(1) << "Rewrote function to " << fdef->DebugString(); + DCHECK(found); + }; + return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status); +#endif +} +#endif + +#if not defined(PLATFORM_WINDOWS) +// On success, returns a set of TF_Function instances encoding a dataset +// node stack that reads an MNIST file dataset from `file_path`, and +// sets `dataset_name` to the created dataset name. The returned functions must +// be deleted by calling TF_DeleteFunction. +static std::vector CreateMNISTDatasetFunctions( + const char* file_path, int batch_size, std::string* dataset_name, + TF_Status* status) { +#if defined(PLATFORM_WINDOWS) + status->status = tensorflow::errors::Unimplemented( + "TF_MakeFileBasedIteratorGetNextWithDatasets in the experimental C API " + "is not implemented for Windows"); + return nullptr; +#else + const char* func_def = R"PREFIX( +library { + function { + signature { + name: "tf_map_func_521bfd08" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "truediv" + type: DT_FLOAT + } + description: "A wrapper for Defun that facilitates shape inference." + } + node_def { + name: "DecodeRaw" + op: "DecodeRaw" + input: "arg0" + attr { + key: "little_endian" + value { + b: true + } + } + attr { + key: "out_type" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "Cast" + op: "Cast" + input: "DecodeRaw:output:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 784 + } + } + } + } + node_def { + name: "Reshape" + op: "Reshape" + input: "Cast:y:0" + input: "Reshape/shape:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "truediv/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 255.0 + } + } + } + } + node_def { + name: "truediv" + op: "RealDiv" + input: "Reshape:output:0" + input: "truediv/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + ret { + key: "truediv" + value: "truediv:z:0" + } + } + function { + signature { + name: "tf_map_func_9a08860d" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "ToInt32" + type: DT_INT32 + } + description: "A wrapper for Defun that facilitates shape inference." + } + node_def { + name: "DecodeRaw" + op: "DecodeRaw" + input: "arg0" + attr { + key: "little_endian" + value { + b: true + } + } + attr { + key: "out_type" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "Reshape" + op: "Reshape" + input: "DecodeRaw:output:0" + input: "Reshape/shape:output:0" + attr { + key: "T" + value { + type: DT_UINT8 + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ToInt32" + op: "Cast" + input: "Reshape:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_UINT8 + } + } + } + ret { + key: "ToInt32" + value: "ToInt32:y:0" + } + } + function { + signature { + name: "tf_predicate_7089b845" + input_arg { + name: "arg0" + type: DT_FLOAT + } + input_arg { + name: "arg1" + type: DT_INT32 + } + input_arg { + name: "Equal/Placeholder" + type: DT_INT64 + } + output_arg { + name: "Equal" + type: DT_BOOL + } + description: "A wrapper for Defun that facilitates shape inference." + } + node_def { + name: "Shape" + op: "Shape" + input: "arg0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT64 + } + } + } + node_def { + name: "strided_slice/stack" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "strided_slice/stack_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "strided_slice/stack_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "strided_slice" + op: "StridedSlice" + input: "Shape:output:0" + input: "strided_slice/stack:output:0" + input: "strided_slice/stack_1:output:0" + input: "strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "Equal" + op: "Equal" + input: "strided_slice:output:0" + input: "Equal/Placeholder" + attr { + key: "T" + value { + type: DT_INT64 + } + } + } + ret { + key: "Equal" + value: "Equal:z:0" + } + } + function { + signature { + name: "_make_dataset_2451e43a" + output_arg { + name: "FilterDataset" + type: DT_VARIANT + } + is_stateful: true + } + node_def { + name: "FixedLengthRecordDataset/filenames" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "$(DATA_DIR)/train-images-idx3-ubyte" + } + } + } + } + node_def { + name: "FixedLengthRecordDataset/header_bytes" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 16 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset/record_bytes" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 784 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset/footer_bytes" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset/buffer_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 262144 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset" + op: "FixedLengthRecordDataset" + input: "FixedLengthRecordDataset/filenames:output:0" + input: "FixedLengthRecordDataset/header_bytes:output:0" + input: "FixedLengthRecordDataset/record_bytes:output:0" + input: "FixedLengthRecordDataset/footer_bytes:output:0" + input: "FixedLengthRecordDataset/buffer_size:output:0" + } + node_def { + name: "MapDataset" + op: "MapDataset" + input: "FixedLengthRecordDataset:handle:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_521bfd08" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/filenames_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "$(DATA_DIR)/train-labels-idx1-ubyte" + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/header_bytes_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 8 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/record_bytes_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/footer_bytes_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/buffer_size_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 262144 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1" + op: "FixedLengthRecordDataset" + input: "FixedLengthRecordDataset_1/filenames_1:output:0" + input: "FixedLengthRecordDataset_1/header_bytes_1:output:0" + input: "FixedLengthRecordDataset_1/record_bytes_1:output:0" + input: "FixedLengthRecordDataset_1/footer_bytes_1:output:0" + input: "FixedLengthRecordDataset_1/buffer_size_1:output:0" + } + node_def { + name: "MapDataset_1" + op: "MapDataset" + input: "FixedLengthRecordDataset_1:handle:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_9a08860d" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT32 + } + } + } + } + node_def { + name: "ZipDataset" + op: "ZipDataset" + input: "MapDataset:handle:0" + input: "MapDataset_1:handle:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "CacheDataset/filename" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "" + } + } + } + } + node_def { + name: "CacheDataset" + op: "CacheDataset" + input: "ZipDataset:handle:0" + input: "CacheDataset/filename:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "RepeatDataset/count" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -1 + } + } + } + } + node_def { + name: "RepeatDataset" + op: "RepeatDataset" + input: "CacheDataset:handle:0" + input: "RepeatDataset/count:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "ShuffleDataset/buffer_size_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 50000 + } + } + } + } + node_def { + name: "ShuffleDataset/seed" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset/seed2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset" + op: "ShuffleDataset" + input: "RepeatDataset:handle:0" + input: "ShuffleDataset/buffer_size_2:output:0" + input: "ShuffleDataset/seed:output:0" + input: "ShuffleDataset/seed2:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + attr { + key: "reshuffle_each_iteration" + value { + b: true + } + } + } + node_def { + name: "BatchDataset/batch_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -123 + } + } + } + } + node_def { + name: "BatchDataset" + op: "BatchDataset" + input: "ShuffleDataset:handle:0" + input: "BatchDataset/batch_size:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "FilterDataset/batch_size_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -123 + } + } + } + } + node_def { + name: "FilterDataset" + op: "FilterDataset" + input: "BatchDataset:handle:0" + input: "FilterDataset/batch_size_1:output:0" + attr { + key: "Targuments" + value { + list { + type: DT_INT64 + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + attr { + key: "predicate" + value { + func { + name: "tf_predicate_7089b845" + } + } + } + } + ret { + key: "FilterDataset" + value: "FilterDataset:handle:0" + } + } +} +)PREFIX"; + + *dataset_name = "_make_dataset_2451e43a"; + std::function mutate_proto_func = + [dataset_name, file_path, batch_size](FunctionDef* fdef) { + VLOG(1) << "Processsing function " << fdef->DebugString(); + if (std::string(fdef->signature().name()) != *dataset_name) return; + // Change the input file pattern to `file_path`. + bool found_file_path = false, found_batch_size = false; + // `node_def` may be mutated. + for (auto& node_def : *fdef->mutable_node_def()) { + if (node_def.name() == "FixedLengthRecordDataset/filenames" || + node_def.name() == "FixedLengthRecordDataset_1/filenames_1") { + DCHECK_EQ(node_def.op(), "Const"); + DCHECK_GT(node_def.attr().count("value"), 0); + found_file_path = true; + // Replace $(DATA_DIR)/foo with /foo + // TODO(hongm): Use StringPiece manipulation for better efficiency. + const std::string cur_value = + node_def.attr().at("value").tensor().string_val(0); + const std::string pattern = "$(DATA_DIR)"; + DCHECK_EQ(cur_value.compare(0, pattern.length(), pattern), 0); + const std::string new_value = + file_path + cur_value.substr(pattern.length()); + VLOG(1) << "Setting the value of node_def " << node_def.name() + << " to " << new_value; + auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor->clear_string_val(); + tensor->add_string_val(new_value); + } else if (node_def.name() == "BatchDataset/batch_size" || + node_def.name() == "FilterDataset/batch_size_1") { + DCHECK_EQ(node_def.op(), "Const"); + DCHECK_GT(node_def.attr().count("value"), 0); + found_batch_size = true; + // Replace $(BATCH_SIZE) with `batch_size` + DCHECK_EQ(node_def.attr().at("value").tensor().int64_val(0), -123); + VLOG(1) << "Setting the batch size attr value of node_def " + << node_def.name() << " to " << batch_size; + auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor->clear_int64_val(); + tensor->add_int64_val(batch_size); + } + } + VLOG(1) << "Rewrote function to " << fdef->DebugString(); + DCHECK(found_file_path); + DCHECK(found_batch_size); + }; + return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status); +#endif +} +#endif + +// Adds the input functions to `graph`. On success, returns the created +// IteratorGetNext node. +static TF_Operation* AddDatasetFunctionAndIteratorNodesToGraph( + const std::vector& funcs, const std::string& dataset_name, + const std::vector& output_types, + const std::vector& output_shapes, + TF_Graph* graph, TF_Status* status) { + DCHECK(!dataset_name.empty()); + for (auto& func : funcs) { + TF_GraphCopyFunction(graph, func.get(), /*gradient*/ nullptr, status); + if (!status->status.ok()) { + return nullptr; + } + } + + tensorflow::mutex_lock c(graph->mu); + + tensorflow::NameAttrList func; + func.set_name(dataset_name); + // Run the iterator node on CPU. + Node* oneshot_iterator_node; + tensorflow::Status s = NodeBuilder("OneShotIterator", "OneShotIterator") + .Device("/device:CPU:0") + .Attr("container", "") + .Attr("dataset_factory", func) + .Attr("output_types", output_types) + .Attr("output_shapes", output_shapes) + .Attr("shared_name", "") + .Finalize(&graph->graph, &oneshot_iterator_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + // Run shape inference function for each newly added node, so that more + // subsequent nodes can be added to the graph via C API (TF_NewOperation()). + s = graph->refiner.AddNode(oneshot_iterator_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + + // Run the iterator node on CPU. + Node* getnext_node; + s = NodeBuilder("IteratorGetNext", "IteratorGetNext") + .Input(oneshot_iterator_node) + .Device("/device:CPU:0") + .Attr("output_types", output_types) + .Attr("output_shapes", output_shapes) + .Finalize(&graph->graph, &getnext_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + // Run shape inference function for each newly added node, so that more + // subsequent nodes can be added to the graph via C API (TF_NewOperation()). + s = graph->refiner.AddNode(getnext_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + + VLOG(1) << "Output graph: " << graph->graph.ToGraphDefDebug().DebugString(); + return ToTF_Operation(getnext_node); +} + +TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets(TF_Graph* graph, + TF_Status* status) { + tensorflow::Status s; + + std::string dataset_name; + UniqueFuncPtr result_func = CreateFakeDatasetFunction(&dataset_name, status); + if (!status->status.ok()) { + return nullptr; + } + + std::vector funcs; + funcs.push_back(std::move(result_func)); + std::vector output_shape_list; + output_shape_list.push_back(tensorflow::TensorShapeProto()); + auto* getnext_node = AddDatasetFunctionAndIteratorNodesToGraph( + funcs, dataset_name, {tensorflow::DT_FLOAT}, output_shape_list, graph, + status); + if (!status->status.ok()) { + return nullptr; + } + + return getnext_node; +} + +TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets( + TF_Graph* graph, const char* file_path, int batch_size, + unsigned char is_mnist, TF_Status* status) { +#if defined(PLATFORM_WINDOWS) + // TODO(ashankar): get these functions working on Windows. + status->status = tensorflow::errors::Unimplemented( + "TF_MakeFileBasedIteratorGetNextWithDatasets in the experimental C API " + "is not implemented for Windows"); + return nullptr; +#else + tensorflow::Status s; + + std::string dataset_name; + const auto& funcs = + is_mnist + ? CreateMNISTDatasetFunctions(file_path, batch_size, &dataset_name, + status) + : CreateImagenetDatasetFunctions(file_path, &dataset_name, status); + if (!status->status.ok()) { + return nullptr; + } + + std::vector output_shape_list; + // batch_size X 224 X 224 X 3 + auto image_shape = tensorflow::TensorShapeProto(); + image_shape.add_dim()->set_size(batch_size); + if (is_mnist) { + image_shape.add_dim()->set_size(784); + } else { + image_shape.add_dim()->set_size(224); + image_shape.add_dim()->set_size(224); + image_shape.add_dim()->set_size(3); + } + output_shape_list.push_back(image_shape); + + // batch_size + auto label_shape = tensorflow::TensorShapeProto(); + label_shape.add_dim()->set_size(batch_size); + output_shape_list.push_back(label_shape); + auto* getnext_node = AddDatasetFunctionAndIteratorNodesToGraph( + funcs, dataset_name, {tensorflow::DT_FLOAT, tensorflow::DT_INT32}, + output_shape_list, graph, status); + if (!status->status.ok()) { + return nullptr; + } + + tensorflow::mutex_lock c(graph->mu); + VLOG(1) << "The extended graph: " + << graph->graph.ToGraphDefDebug().DebugString(); + + return getnext_node; +#endif +} + +TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id, + TF_Status* status) { + assert(session); + { + tensorflow::mutex_lock c(session->graph->mu); + VLOG(1) << "Dequeuing named tensor with id " << tensor_id + << ", with input graph: " + << session->graph->graph.ToGraphDefDebug().DebugString(); + } + + TF_Operation* dequeue_op = TF_GraphOperationByName( + session->graph, + tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str()); + if (dequeue_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find the dequeue node in the TF graph."); + return nullptr; + } + + VLOG(1) << "Running the dequeue op"; + TF_Output output{dequeue_op, 0}; + TF_Tensor* ret; + TF_SessionRun(session, /*run_options*/ nullptr, + // input related parameters + /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0, + // output related parameters + /*outputs*/ &output, /*output_values*/ &ret, + /*noutputs*/ 1, + /*targets*/ nullptr, /*ntargets*/ 0, + /*run_metadata*/ nullptr, status); + if (VLOG_IS_ON(1) && status->status.ok()) { + tensorflow::Tensor tensor; + if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) { + VLOG(1) << "Dequeued tensor content: " << tensor.DebugString(); + } + } + return ret; +} + +void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id, + TF_Tensor* tensor, TF_Status* status) { + assert(session); + { + tensorflow::mutex_lock c(session->graph->mu); + if (VLOG_IS_ON(1)) { + VLOG(1) << "Enqueuing named tensor with id " << tensor_id + << ", with input graph: " + << session->graph->graph.ToGraphDefDebug().DebugString(); + tensorflow::Tensor internal_tensor; + if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) { + VLOG(1) << "Enqueu'ing tensor content: " + << internal_tensor.DebugString(); + } + } + } + + TF_Operation* enqueue_op = TF_GraphOperationByName( + session->graph, + tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str()); + if (enqueue_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find the enqueue node in the TF graph."); + return; + } + + TF_Operation* placeholder_op = TF_GraphOperationByName( + session->graph, + tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str()); + if (placeholder_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find the placeholder node as input to enqueue in the TF " + "graph."); + return; + } + + VLOG(1) << "Running the enqueue op"; + TF_Output input{placeholder_op, 0}; + TF_SessionRun(session, /*run_options*/ nullptr, + // input related parameters + /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1, + // output related parameters + /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0, + /*targets*/ &enqueue_op, /*ntargets*/ 1, + /*run_metadata*/ nullptr, status); + VLOG(1) << "Enqueuing is done."; +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 5a7b007e40aa199889b2d00b2bde5976c19e2966..20bdace40f1272ded06e710034053a7610326e7f 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -25,6 +25,7 @@ limitations under the License. // Experimental C API for TensorFlow. // // The API here is subject to changes in the future. +// -------------------------------------------------------------------------- // Macro to control visibility of exported symbols in the shared library (.so, // .dylib, .dll). @@ -34,7 +35,7 @@ limitations under the License. #ifdef SWIG #define TF_CAPI_EXPORT #else -#if defined(COMPILER_MSVC) +#if defined(_WIN32) #ifdef TF_COMPILE_LIBRARY #define TF_CAPI_EXPORT __declspec(dllexport) #else @@ -42,7 +43,7 @@ limitations under the License. #endif // TF_COMPILE_LIBRARY #else #define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // COMPILER_MSVC +#endif // _WIN32 #endif // SWIG #ifdef __cplusplus @@ -59,6 +60,61 @@ extern "C" { TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable); +// Returns the graph content in a human-readable format, with length set in +// `len`. The format is subject to change in the future. +// The returned string is heap-allocated, and caller should call free() on it. +TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph, + size_t* len); + +// Creates a stack of data set + iterator nodes, currently hard-coded to return +// a sequence of 3 float values <42.0, 43.0, 44.0> over 3 calls. On success, +// returns the IteratorGetNext node, which caller can run or feed into an node. +// +// TODO(hongm): Extend the API to allow customization of the nodes created. +TF_CAPI_EXPORT extern TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets( + TF_Graph* graph, TF_Status* status); + +// Similar to the above API, except that the returned iterator reads the +// file based dataset from `file_path`. +// If `is_mnist` is 0, the dataset corresponds to ImageNet. +// The iterators outputs 2 tensors: +// - A float tensor of shape `batch_size` X 784 when `is_mnist` is non-zero, or +// `batch_size` X 224 X 224 X 3 otherwise. +// - An int32 tensor of shape `batch_size` +// TODO(hongm): Extend the API to allow customization of the nodes created. +TF_CAPI_EXPORT extern TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets( + TF_Graph* graph, const char* file_path, int batch_size, + unsigned char is_mnist, TF_Status* status); + +// On success, dequeues a tensor from a TF-managed FifoQueue given by +// `tensor_id`, associated with `session`. There must be a graph node named +// "fifo_queue_dequeue_", to be executed by this API call. + +// Caller must call TF_DeleteTensor() over the returned tensor. If the queue is +// empty, this call is blocked. +// +// Tensors are enqueued via the corresponding TF enqueue op. +// TODO(hongm): Add support for `timeout_ms`. +TF_CAPI_EXPORT extern TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, + int tensor_id, + TF_Status* status); + +// On success, enqueues `tensor` into a TF-managed FifoQueue given by +// `tensor_id`, associated with `session`. There must be a graph node named +// "fifo_queue_enqueue_", to be executed by this API call. It reads +// from a placeholder node "arg_tensor_enqueue_". +// +// `tensor` is still owned by the caller. This call will be blocked if the queue +// has reached its capacity, and will be unblocked when the queued tensors again +// drop below the capacity due to dequeuing. +// +// Tensors are dequeued via the corresponding TF dequeue op. +// TODO(hongm): Add support for `timeout_ms`. +TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, + int tensor_id, + TF_Tensor* tensor, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..30fcfd401d9d634962d64aaa3bf348de91f2ecae --- /dev/null +++ b/tensorflow/c/c_api_experimental_test.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/c_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +void TestFakeIteratorStack() { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + TF_Operation* get_next = TF_MakeFakeIteratorGetNextWithDatasets(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + CSession csession(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Run the graph. + const float base_value = 42.0; + for (int i = 0; i < 3; ++i) { + csession.SetOutputs({get_next}); + csession.Run(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Tensor* out = csession.output_tensor(0); + ASSERT_TRUE(out != nullptr); + ASSERT_EQ(TF_FLOAT, TF_TensorType(out)); + ASSERT_EQ(0, TF_NumDims(out)); // scalar + ASSERT_EQ(sizeof(float), TF_TensorByteSize(out)); + float* output_contents = static_cast(TF_TensorData(out)); + ASSERT_EQ(base_value + i, *output_contents); + } + + // This should error out since we've exhausted the iterator. + csession.Run(s); + ASSERT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)) << TF_Message(s); + + // Clean up + csession.CloseAndDelete(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +TEST(CAPI_EXPERIMENTAL, FakeIteratorGetNext) { TestFakeIteratorStack(); } + +TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + const string file_path = tensorflow::io::JoinPath( + tensorflow::testing::TensorFlowSrcRoot(), "c/testdata/tf_record"); + VLOG(1) << "data file path is " << file_path; + const int batch_size = 64; + TF_Operation* get_next = TF_MakeFileBasedIteratorGetNextWithDatasets( + graph, file_path.c_str(), batch_size, /*is_mnist*/ false, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + CSession csession(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Run the graph. + // The two output tensors should look like: + // Tensor("IteratorGetNext:0", shape=(batch_size, 224, 224, 3), dtype=float32) + // Tensor("IteratorGetNext:1", shape=(batch_size, ), dtype=int32) + for (int i = 0; i < 3; ++i) { + LOG(INFO) << "Running iter " << i; + csession.SetOutputs({{get_next, 0}, {get_next, 1}}); + csession.Run(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + { + TF_Tensor* image = csession.output_tensor(0); + ASSERT_TRUE(image != nullptr); + ASSERT_EQ(TF_FLOAT, TF_TensorType(image)); + // Confirm shape is 224 X 224 X 3 + ASSERT_EQ(4, TF_NumDims(image)); + ASSERT_EQ(batch_size, TF_Dim(image, 0)); + ASSERT_EQ(224, TF_Dim(image, 1)); + ASSERT_EQ(224, TF_Dim(image, 2)); + ASSERT_EQ(3, TF_Dim(image, 3)); + ASSERT_EQ(sizeof(float) * batch_size * 224 * 224 * 3, + TF_TensorByteSize(image)); + } + + { + TF_Tensor* label = csession.output_tensor(1); + ASSERT_TRUE(label != nullptr); + ASSERT_EQ(TF_INT32, TF_TensorType(label)); + ASSERT_EQ(1, TF_NumDims(label)); + ASSERT_EQ(batch_size, TF_Dim(label, 0)); + ASSERT_EQ(sizeof(int32) * batch_size, TF_TensorByteSize(label)); + } + } + + // Clean up + csession.CloseAndDelete(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 7ca50119eafe299b307f06c555aec1388e7e82e2..610274696f5940c063e68f2310cfd9cc1e0bd964 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 91667056e0eeb224b4b8a034766f11a123cd1a03..95652a11378d6276b5ba6540a07baa15aa77cc1c 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -84,19 +84,20 @@ struct TF_Graph { std::unordered_map name_map GUARDED_BY(mu); - // The keys of this map are all the active sessions using this graph. - // Each value is the current "runnability" status of the corresponding - // session. Under normal conditions all statuses are Status::OK(), but - // if some operation is mutated after it was run by a session (this - // is detected in RecordMutation function), that session is no longer - // safe to run. Its status will contain the error that will be returned - // to the user, should she try running this session. + // The keys of this map are all the active sessions using this graph. Each + // value records whether the graph has been mutated since the corresponding + // session has been run (this is detected in RecordMutation function). If the + // string is empty, no mutation has occurred. Otherwise the string is a + // description of the mutation suitable for returning to the user. // // Sessions are added to this map in TF_NewSession, and removed in // TF_DeleteSession. // TF_Graph may only / must be deleted when // sessions.size() == 0 && delete_requested == true - tensorflow::gtl::FlatMap sessions + // + // TODO(b/74949947): mutations currently trigger a warning instead of a bad + // status, this should be reverted when possible. + tensorflow::gtl::FlatMap sessions GUARDED_BY(mu); bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph @@ -124,15 +125,16 @@ struct TF_Session { TF_Session(tensorflow::Session* s, TF_Graph* g); tensorflow::Session* session; - TF_Graph* graph; + TF_Graph* const graph; - tensorflow::mutex mu; + tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu); int last_num_graph_nodes; - // NOTE(ashankar): Experimental fields to help keep the - // buffers of a TF_Tensor pinned in device memory. - const tensorflow::DeviceMgr* device_mgr; // Owned by session. - std::vector devices; // Owned by device_mgr. + // If true, TF_SessionRun and similar methods will call + // ExtendSessionGraphHelper before running the graph (this is the default + // public behavior). Can be set to false if the caller needs to call + // ExtendSessionGraphHelper manually. + std::atomic extend_before_run; }; struct TF_ImportGraphDefOptions { @@ -210,7 +212,11 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, TF_Status* status); void RecordMutation(TF_Graph* graph, const TF_Operation& op, - const char* mutation_type); + const char* mutation_type) + EXCLUSIVE_LOCKS_REQUIRED(graph->mu); + +bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) + LOCKS_EXCLUDED(session->graph->mu, session->mu); } // end namespace tensorflow diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 028f146be31790b211e546978302e81afe26b231..577f10c5e69ea9ecbe8ce821c6bd5167e98bef25 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -53,7 +53,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); namespace { static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(StringPiece(s).contains(expected)) + EXPECT_TRUE(str_util::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } @@ -1368,7 +1368,7 @@ TEST(CAPI, SavedModel) { } const tensorflow::string input_op_name = - tensorflow::ParseTensorName(input_name).first.ToString(); + std::string(tensorflow::ParseTensorName(input_name).first); TF_Operation* input_op = TF_GraphOperationByName(graph, input_op_name.c_str()); ASSERT_TRUE(input_op != nullptr); @@ -1376,7 +1376,7 @@ TEST(CAPI, SavedModel) { ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); const tensorflow::string output_op_name = - tensorflow::ParseTensorName(output_name).first.ToString(); + std::string(tensorflow::ParseTensorName(output_name).first); TF_Operation* output_op = TF_GraphOperationByName(graph, output_op_name.c_str()); ASSERT_TRUE(output_op != nullptr); @@ -1700,7 +1700,7 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) { TestGradientsError(false); } -// REGISTER_OP for CApiTestAttributesTest test cases. +// REGISTER_OP for CApiAttributesTest test cases. // Registers two ops, each with a single attribute called 'v'. // The attribute in one op will have a type 'type', the other // will have list(type). diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 3db2852ce6560ba493d60ef54a110161c112d110..f3b28c1708129d39e451d927a89c0d10e2193b63 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -34,6 +34,10 @@ static void DoubleDeallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } +static void FloatDeallocator(void* data, size_t, void* arg) { + delete[] static_cast(data); +} + TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { int64_t num_values = 1; for (int i = 0; i < num_dims; ++i) { @@ -78,21 +82,34 @@ TF_Tensor* DoubleTensor(double v) { &DoubleDeallocator, nullptr); } +TF_Tensor* FloatTensor(float v) { + const int num_bytes = sizeof(float); + float* values = new float[1]; + values[0] = v; + return TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, + &FloatDeallocator, nullptr); +} + // All the *Helper methods are used as a workaround for the restrictions that // one cannot call ASSERT_* methods in non-void-returning functions (when // exceptions are disabled during compilation) void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name, + TF_DataType dtype, const std::vector& dims, TF_Operation** op) { TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); - TF_SetAttrType(desc, "dtype", TF_INT32); + TF_SetAttrType(desc, "dtype", dtype); + if (!dims.empty()) { + TF_SetAttrShape(desc, "shape", dims.data(), dims.size()); + } *op = TF_FinishOperation(desc, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_NE(*op, nullptr); } -TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) { +TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name, + TF_DataType dtype, const std::vector& dims) { TF_Operation* op; - PlaceholderHelper(graph, s, name, &op); + PlaceholderHelper(graph, s, name, dtype, dims, &op); return op; } @@ -126,6 +143,12 @@ TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s, return Const(tensor.get(), graph, s, name); } +TF_Operation* ScalarConst(float v, TF_Graph* graph, TF_Status* s, + const char* name) { + unique_tensor_ptr tensor(FloatTensor(v), TF_DeleteTensor); + return Const(tensor.get(), graph, s, name); +} + void AddOpHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name, TF_Operation** op, bool check) { diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index 2a70177c724c569844a5d8ad42b99bed20209946..c16aba666ee6974fed5351c2d9ac291dcbcdecab 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/types.pb.h" @@ -44,8 +45,12 @@ TF_Tensor* Int32Tensor(int32_t v); TF_Tensor* DoubleTensor(double v); +TF_Tensor* FloatTensor(float v); + TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, - const char* name = "feed"); + const char* name = "feed", + TF_DataType dtype = TF_INT32, + const std::vector& dims = {}); TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name = "const"); @@ -56,6 +61,9 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s, const char* name = "scalar"); +TF_Operation* ScalarConst(float v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar"); + TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name = "add"); diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index b1f7bdaa5420a56386e6983052df20aa976aa867..74bc25a491ac01cb725d1c004197e48727c30230 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -125,7 +125,7 @@ CheckpointReader::BuildV2VarMaps() { const auto& slice_proto = entry.slices(i); CHECK(filtered_keys .insert(EncodeTensorNameSlice( - v2_reader_->key().ToString() /* full var's name */, + std::string(v2_reader_->key()) /* full var's name */, TensorSlice(slice_proto))) .second); } @@ -138,11 +138,11 @@ CheckpointReader::BuildV2VarMaps() { new TensorSliceReader::VarToDataTypeMap); v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { - if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue; + if (filtered_keys.count(std::string(v2_reader_->key())) > 0) continue; CHECK(entry.ParseFromArray(v2_reader_->value().data(), v2_reader_->value().size())) << entry.InitializationErrorString(); - string key = v2_reader_->key().ToString(); + string key = std::string(v2_reader_->key()); (*var_to_shape_map)[key] = TensorShape(entry.shape()); (*var_to_data_type_map)[key] = DataType(entry.dtype()); } diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index e55cb672e97e1403a3dd864c91c176426eb3f067..f265da2c2c89c0e9caf14f2213c606fcb69997e0 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -14,6 +14,7 @@ tf_cuda_library( name = "c_api", srcs = [ "c_api.cc", + "c_api_debug.cc", "c_api_internal.h", ], hdrs = ["c_api.h"], @@ -24,9 +25,16 @@ tf_cuda_library( "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ - ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", + "//tensorflow/core:core_cpu", + "//tensorflow/core/common_runtime/eager:attr_builder", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:eager_executor", + "//tensorflow/core/common_runtime/eager:execute", + "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/common_runtime/eager:copy_to_device_node", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -38,9 +46,22 @@ tf_cuda_library( "//tensorflow:with_xla_support": [ "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/jit", + "//tensorflow/compiler/jit:xla_device", ], "//conditions:default": [], }) + [ + "//tensorflow/core/common_runtime/eager:eager_operation", + "//tensorflow/core/distributed_runtime/eager:eager_client", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", + "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:remote_device", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core:gpu_runtime", ], ) @@ -51,72 +72,74 @@ tf_cuda_library( visibility = ["//tensorflow:internal"], deps = [ ":c_api", - ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", + "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:framework_lite", + "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/common_runtime/eager:attr_builder", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:eager_executor", + "//tensorflow/core/common_runtime/eager:eager_operation", + "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/distributed_runtime:remote_device", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime/eager:eager_client", + "//tensorflow/core/distributed_runtime/eager:remote_tensor_handle", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", + "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", + "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", ], ) -tf_cuda_cc_test( - name = "c_api_test", - srcs = ["c_api_test.cc"], - extra_copts = tfe_xla_copts(), - tags = [ - "guitar", - "multi_gpu", +tf_cuda_library( + name = "c_api_test_util", + testonly = 1, + srcs = ["c_api_test_util.cc"], + hdrs = ["c_api_test_util.h"], + visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", ], deps = [ ":c_api", "//tensorflow/c:c_test_util", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", - "//tensorflow/core:test_main", ], ) -tf_cuda_library( - name = "runtime", - srcs = ["runtime.cc"], - hdrs = ["runtime.h"], - copts = tf_copts(), - visibility = ["//tensorflow:internal"], - deps = select({ - "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", - ], - "//conditions:default": [ - "//tensorflow/c:c_api", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - ], - }), -) - -tf_cc_test( - name = "runtime_test", - srcs = ["runtime_test.cc"], +tf_cuda_cc_test( + name = "c_api_test", + srcs = [ + "c_api_debug_test.cc", + "c_api_test.cc", + ], + extra_copts = tfe_xla_copts(), + tags = [ + "guitar", + "multi_gpu", + ], deps = [ - ":runtime", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:client_session", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", + ":c_api", + ":c_api_test_util", + "//tensorflow/c:c_test_util", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 98ef6f0d0ab094eae3e2e21624c3a4ba30d1c3d3..81221c4078bec9820ee187efdf0314da378be62b 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -24,22 +24,33 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" -#include "tensorflow/c/eager/runtime.h" #ifdef TENSORFLOW_EAGER_USE_XLA #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #endif // TENSORFLOW_EAGER_USE_XLA #include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" +#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h" +#include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/public/version.h" @@ -62,9 +73,121 @@ string DeviceName(const tensorflow::Device* d) { return (d == nullptr) ? "cpu:0" : d->name(); } -#ifdef TENSORFLOW_EAGER_USE_XLA -std::atomic_int_fast64_t func_id_generator(0); -#endif // TENSORFLOW_EAGER_USE_XLA +tensorflow::Status GetAllRemoteDevices( + const std::vector& remote_workers, + tensorflow::WorkerCacheInterface* worker_cache, + std::unique_ptr* device_mgr) { + std::vector remote_devices; + tensorflow::Status status; + // TODO(nareshmodi) do this in parallel instead of serially. + for (const string& remote_worker : remote_workers) { + tensorflow::Notification n; + tensorflow::NewRemoteDevices( + tensorflow::Env::Default(), worker_cache, remote_worker, + [&status, &n, &remote_devices]( + const tensorflow::Status& s, + std::vector* devices) { + status = s; + if (s.ok()) { + for (tensorflow::Device* d : *devices) { + remote_devices.push_back(d); + } + } + n.Notify(); + }); + n.WaitForNotification(); + } + std::unique_ptr remote_device_mgr( + new tensorflow::DeviceMgr(remote_devices)); + + TF_RETURN_IF_ERROR(status); + + *device_mgr = std::move(remote_device_mgr); + return tensorflow::Status::OK(); +} + +tensorflow::Status CreateRemoteContexts( + const std::vector& remote_workers, + tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, + tensorflow::gtl::FlatMap* remote_contexts) { + for (int i = 0; i < remote_workers.size(); i++) { + const string& remote_worker = remote_workers[i]; + + tensorflow::eager::CreateContextRequest request; + tensorflow::eager::CreateContextResponse response; + tensorflow::DeviceNameUtils::ParsedName parsed_name; + if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, + &parsed_name)) { + return tensorflow::errors::InvalidArgument( + "Unable to parse ", remote_worker, " as a device name"); + } + request.mutable_server_def()->set_job_name(parsed_name.job); + request.mutable_server_def()->set_task_index(parsed_name.task); + request.set_async(async); + auto* eager_client = remote_eager_workers->GetClient(remote_worker); + if (eager_client == nullptr) { + return tensorflow::errors::Internal( + "Cannot find a client for the given target:", remote_worker); + } + tensorflow::Notification n; + tensorflow::Status status; + // TODO(nareshmodi) do this in parallel instead of serially. + eager_client->CreateContextAsync( + &request, &response, [&status, &n](const tensorflow::Status& s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + TF_RETURN_IF_ERROR(status); + + remote_contexts->emplace(remote_worker, response.context_id()); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, + TFE_Context** ctx) { + string worker_name = tensorflow::strings::StrCat( + "/job:", opts->server_def.job_name(), + "/replica:0/task:", opts->server_def.task_index()); + std::unique_ptr server; + TF_RETURN_IF_ERROR( + tensorflow::eager::EagerGrpcServer::Create(opts->server_def, &server)); + + TF_RETURN_IF_ERROR(server->Start()); + + std::vector remote_workers; + server->master_env()->worker_cache->ListWorkers(&remote_workers); + remote_workers.erase( + std::remove(remote_workers.begin(), remote_workers.end(), worker_name), + remote_workers.end()); + + std::unique_ptr remote_device_mgr; + TF_RETURN_IF_ERROR(GetAllRemoteDevices( + remote_workers, server->master_env()->worker_cache, &remote_device_mgr)); + + std::shared_ptr channel_cache = + server->channel_cache(); + std::unique_ptr remote_eager_workers( + tensorflow::eager::NewGrpcEagerClientCache(channel_cache)); + + // Initialize remote eager workers. + tensorflow::gtl::FlatMap remote_contexts; + TF_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers, + remote_eager_workers.get(), + opts->async, &remote_contexts)); + + tensorflow::RemoteRendezvous* r = + server->worker_env()->rendezvous_mgr->Find(0); + + auto* device_mgr = server->worker_env()->device_mgr; + *ctx = new TFE_Context(opts->session_options.options, opts->policy, + opts->async, device_mgr, r, std::move(server), + std::move(remote_eager_workers), + std::move(remote_device_mgr), remote_contexts); + + return tensorflow::Status::OK(); +} } // namespace extern "C" { @@ -76,181 +199,158 @@ void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto, TF_SetConfig(&options->session_options, proto, proto_len, status); } +void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options, + unsigned char async) { + options->async = async; +} void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) { options->policy = policy; } +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( + TFE_ContextOptions* options, const void* proto, size_t proto_len, + TF_Status* status) { + if (!options->server_def.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid tensorflow.ServerDef protocol buffer"); + } +} + +TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, + unsigned char async, + TF_Status* status) { + status->status = ctx->context.SetAsyncForThread(async); +} + void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { - TF_Graph* graph = TF_NewGraph(); - TF_Session* session = TF_NewSession(graph, &opts->session_options, status); - if (status->status.ok()) { - if (session->device_mgr == nullptr || session->devices.empty()) { - status->status = tensorflow::errors::InvalidArgument( - "Provided TF_SessionOptions are not compatible with eager execution " - "(perhaps the TF_SessionOptions alluded to session execution in a " - "remote address space?)"); - } - } - if (!status->status.ok()) { - TF_DeleteGraph(graph); - return nullptr; + if (!opts->server_def.job_name().empty()) { + TFE_Context* ctx = nullptr; + status->status = NewRemoteAwareTFE_Context(opts, &ctx); + return ctx; } - return new TFE_Context(*opts, session); -} + std::vector devices; + status->status = tensorflow::DeviceFactory::AddDevices( + opts->session_options.options, "/job:localhost/replica:0/task:0", + &devices); + if (!status->status.ok()) return nullptr; + std::unique_ptr device_mgr( + new tensorflow::DeviceMgr(devices)); -void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { - status->status = tensorflow::Status::OK(); - { - tensorflow::mutex_lock ml(ctx->cache_mu); - tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); - } - TF_Graph* graph = ctx->session->graph; - TF_DeleteSession(ctx->session, status); - TF_DeleteGraph(graph); - ctx->rendezvous->Unref(); - delete ctx; + tensorflow::Rendezvous* r = + new tensorflow::IntraProcessRendezvous(device_mgr.get()); + + return new TFE_Context(opts->session_options.options, opts->policy, + opts->async, std::move(device_mgr), r); } +void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { delete ctx; } + TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { - return TF_SessionListDevices(ctx->session, status); + TF_DeviceList* list = new TF_DeviceList; + ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response); + if (ctx->context.remote_device_mgr()) { + ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response); + } + return list; } -void TFE_ContextClearCaches(TFE_Context* ctx) { - tensorflow::mutex_lock ml(ctx->cache_mu); - tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); -} +void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { - tensorflow::mutex_lock ml(ctx->policy_map_mu); - ctx->thread_local_policies[std::this_thread::get_id()] = policy; + ctx->context.SetThreadLocalDevicePlacementPolicy( + static_cast(policy)); } +// Note: this function looks up a thread local policy. So it should be called in +// the appropriate client thread. In particular, in async mode, it may not be +// safe to call this function from the async EagerExecutor threads. extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { - tensorflow::mutex_lock ml(ctx->policy_map_mu); - auto policy_map_it = - ctx->thread_local_policies.find(std::this_thread::get_id()); - if (policy_map_it != ctx->thread_local_policies.end()) { - return policy_map_it->second; - } - return ctx->policy; + return static_cast( + ctx->context.GetDevicePlacementPolicy()); +} + +void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) { + status->status = ctx->context.AsyncWait(); +} + +void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) { + status->status = ctx->context.GetStatus(); +} + +void TFE_ContextAsyncClearError(TFE_Context* ctx) { + ctx->context.ClearAsyncError(); } TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { tensorflow::Tensor tensor; status->status = tensorflow::TF_TensorToTensor(t, &tensor); if (!status->status.ok()) return nullptr; - return new TFE_TensorHandle(tensor, nullptr); + return new TFE_TensorHandle(tensor, nullptr, nullptr); } -void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; } +void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { + DCHECK(h); + if (h->handle) { + h->handle->Unref(); + } + delete h; +} TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { - return static_cast(h->t.dtype()); + return static_cast(h->handle->dtype); } -int TFE_TensorHandleNumDims(TFE_TensorHandle* h) { return h->t.dims(); } - -int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index) { - return h->t.dim_size(dim_index); +int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { + const tensorflow::Tensor* t = nullptr; + status->status = h->handle->Tensor(&t); + return t == nullptr ? 0 : t->dims(); } -const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) { - // TODO(apassos) this will be potentially incorrect in the distributed case as - // our local device will have a name which depends on the ClusterSpec and - // hence will require the context to resolve. - return (h->d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" - : h->d->name().c_str(); +int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, + TF_Status* status) { + const tensorflow::Tensor* t = nullptr; + status->status = h->handle->Tensor(&t); + return t == nullptr ? 0 : t->dim_size(dim_index); } -TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { - if (!IsCPU(h->d)) { - TF_SetStatus(status, TF_UNIMPLEMENTED, - tensorflow::strings::StrCat( - "TFE_TensorHandle can be resolved iff it is on CPU (this " - "handle is on ", - h->d->name(), - "). Consider using TFE_TensorHandleCopyToDevice to get a " - "copy of the tensor on CPU") - .c_str()); - return nullptr; - } - return tensorflow::TF_TensorFromTensor(h->t, status); +const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { + tensorflow::Device* d = nullptr; + status->status = h->handle->OpDevice(&d); + return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" + : d->name().c_str(); } -TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, - TFE_Context* ctx, - const char* device_name, - TF_Status* status) { - tensorflow::Device* dstd = ctx->devices()[0]; - if (device_name != nullptr && strlen(device_name) > 0) { - status->status = ctx->session->device_mgr->LookupDevice(device_name, &dstd); - if (!status->status.ok()) return nullptr; - } - - tensorflow::Device* srcd = h->d == nullptr ? ctx->devices()[0] : h->d; - bool is_same_device = - (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd)); - const bool dst_cpu = IsCPU(dstd); - const bool src_cpu = IsCPU(srcd); - // both_on_cpu can be true and yet is_same_device is false, if one of src/dst - // has device type XLA_CPU, and the other CPU. - const bool both_on_cpu = src_cpu && dst_cpu; - if (is_same_device || both_on_cpu) { - return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd); - } - tensorflow::Tensor* src = &(h->t); - if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT && - !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) { - TF_SetStatus( - status, TF_INVALID_ARGUMENT, - tensorflow::strings::StrCat("Can't copy Tensor with type ", - tensorflow::DataTypeString(src->dtype()), - " to device ", DeviceName(dstd), ".") - .c_str()); - return nullptr; - } - tensorflow::AllocatorAttributes attr; - if (src->dtype() == tensorflow::DT_VARIANT) { - attr.set_on_host(true); - } - tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape()); - if (src->shape().num_elements() == 0) { - return new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd); - } - tensorflow::DeviceContext* src_device_context = nullptr; - if (!src_cpu) { - src_device_context = srcd->tensorflow_gpu_device_info()->default_context; +TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { + // TODO(agarwal): move this implementation inside TFE_TensorHandle. + tensorflow::Device* d = nullptr; + tensorflow::Device* op_device = nullptr; + const tensorflow::Tensor* t = nullptr; + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) return nullptr; + tensorflow::TensorHandle* h_cpu = nullptr; + if (!IsCPU(d)) { + status->status = h->handle->CopyToDevice( + h->handle->Context(), h->handle->Context()->HostCPU(), &h_cpu); + if (!status->status.ok()) { + return nullptr; + } + status->status = h_cpu->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) { + h_cpu->Unref(); + return nullptr; + } } - tensorflow::DeviceContext* dst_device_context = nullptr; - if (!dst_cpu) { - dst_device_context = dstd->tensorflow_gpu_device_info()->default_context; + TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status); + if (h_cpu != nullptr) { + h_cpu->Unref(); } - // TODO(ashankar): The Sync() call below may be more aggressive than - // necessary. It is based on knowledge of implementation details - that - // GPU devices are implemented using 3 streams - one for host->device copies, - // one for device->host copies and one for sending operations to the GPU. - // With that setup, Sync()ing across all 3 streams should be sufficient - // but more than necessary (since it waits for operations that might have - // nothing to do with this tensor to complete). - status->status = srcd->Sync(); - tensorflow::Notification n; - tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context, - srcd, dstd, tensorflow::AllocatorAttributes(), - tensorflow::AllocatorAttributes(), src, &dst, - [status, &n](const tensorflow::Status& s) { - status->status = s; - n.Notify(); - }); - n.WaitForNotification(); - return (TF_GetCode(status) == TF_OK) - ? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd) - : nullptr; + return retval; } TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, @@ -260,8 +360,7 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, status->status = tensorflow::AttrTypeMapForOp(name, &types); if (status->status.ok()) return new TFE_Op(ctx, name, types); if (TF_GetCode(status) == TF_NOT_FOUND) { - tensorflow::mutex_lock l(ctx->functions_mu); - if (ctx->func_lib_def.Find(name) != nullptr) { + if (ctx->context.FindFunctionByName(name)) { status->status = tensorflow::Status::OK(); return new TFE_Op(ctx, name, nullptr); } @@ -272,23 +371,18 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, void TFE_DeleteOp(TFE_Op* op) { delete op; } void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { - tensorflow::Device* d = nullptr; - if (device_name != nullptr && strlen(device_name) > 0) { - status->status = - op->ctx->session->device_mgr->LookupDevice(device_name, &d); - if (!status->status.ok()) return; - } - op->device = d; + status->status = op->operation.SetDevice(device_name); } const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { - tensorflow::Device* device = - (op->device == nullptr) ? op->ctx->devices()[0] : op->device; + tensorflow::Device* device = (op->operation.Device() == nullptr) + ? op->operation.EagerContext()->HostCPU() + : op->operation.Device(); return device->name().c_str(); } void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { - op->use_xla = enable; + op->operation.SetUseXla(enable); #ifndef TENSORFLOW_EAGER_USE_XLA LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not " "built with XLA support."; @@ -296,31 +390,20 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { } void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { - // Questionable heuristic ... - // - // Motivation: After an 'op' is placed on GPU because some of its earlier - // inputs are on GPU, we want to keep the 'op' there, even if some later - // inputs of it are not on GPU. - if (IsCPU(op->device) && !IsCPU(h->d)) { - op->device = h->d; - } - if (!status->status.ok()) return; - op->inputs.push_back(h->t); - op->input_devices.push_back(h->d); - op->attrs.NumInputs(op->inputs.size()); + op->operation.AddInput(h->handle); } TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, TF_Status* status) { TF_AttrType ret; - if (op->is_function()) { + if (op->operation.is_function()) { status->status = tensorflow::errors::Unimplemented( "TODO(apassos): Support for attributes for TensorFlow functions is not " "ready yet."); return TF_ATTR_INT; // The compiler requires that we return something. } - status->status = - tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list); + status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(), + attr_name, &ret, is_list); return ret; } @@ -339,23 +422,24 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, } void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) { - op->attrs.Set(attr_name, value); + op->operation.MutableAttrs()->Set(attr_name, value); } void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) { - op->attrs.Set(attr_name, static_cast(value)); + op->operation.MutableAttrs()->Set(attr_name, static_cast(value)); } void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) { - op->attrs.Set(attr_name, value); + op->operation.MutableAttrs()->Set(attr_name, value); } void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) { - op->attrs.Set(attr_name, (value == 0) ? false : true); + op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true); } void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) { - op->attrs.Set(attr_name, static_cast(value)); + op->operation.MutableAttrs()->Set(attr_name, + static_cast(value)); } void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims, @@ -377,23 +461,24 @@ void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims, proto.add_dim()->set_size(dims[d]); } } - op->attrs.Set(attr_name, proto); + op->operation.MutableAttrs()->Set(attr_name, proto); } void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, const TFE_Op* value) { tensorflow::AttrValue attr_value; tensorflow::NameAttrList* func = attr_value.mutable_func(); - func->set_name(value->name); - value->attrs.FillAttrValueMap(func->mutable_attr()); - op->attrs.Set(attr_name, attr_value); + func->set_name(value->operation.Name()); + value->operation.Attrs().FillAttrValueMap(func->mutable_attr()); + op->operation.MutableAttrs()->Set(attr_name, attr_value); } #define TFE_OP_SET_ATTR_LIST(fn, type) \ void fn(TFE_Op* op, const char* attr_name, const type* values, \ int num_values) { \ - op->attrs.Set(attr_name, tensorflow::gtl::ArraySlice( \ - values, num_values)); \ + op->operation.MutableAttrs()->Set( \ + attr_name, \ + tensorflow::gtl::ArraySlice(values, num_values)); \ } TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*) TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float) @@ -401,14 +486,14 @@ TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float) void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, const int64_t* values, int num_values) { - op->attrs.Set(attr_name, - tensorflow::gtl::ArraySlice( - reinterpret_cast(values), num_values)); + op->operation.MutableAttrs()->Set( + attr_name, tensorflow::gtl::ArraySlice( + reinterpret_cast(values), num_values)); } void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, const TF_DataType* values, int num_values) { - op->attrs.Set( + op->operation.MutableAttrs()->Set( attr_name, tensorflow::gtl::ArraySlice( reinterpret_cast(values), num_values)); @@ -420,8 +505,8 @@ void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, for (int i = 0; i < num_values; ++i) { b[i] = values[i]; } - op->attrs.Set(attr_name, - tensorflow::gtl::ArraySlice(b.get(), num_values)); + op->operation.MutableAttrs()->Set( + attr_name, tensorflow::gtl::ArraySlice(b.get(), num_values)); } void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, @@ -451,9 +536,9 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, } } } - op->attrs.Set(attr_name, - tensorflow::gtl::ArraySlice( - proto.get(), num_values)); + op->operation.MutableAttrs()->Set( + attr_name, tensorflow::gtl::ArraySlice( + proto.get(), num_values)); } void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, @@ -461,432 +546,41 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, std::unique_ptr funcs( new tensorflow::NameAttrList[num_values]); for (int i = 0; i < num_values; i++) { - funcs[i].set_name(value[i]->name); - value[i]->attrs.FillAttrValueMap(funcs[i].mutable_attr()); - } - op->attrs.Set(attr_name, - tensorflow::gtl::ArraySlice( - funcs.get(), num_values)); -} - -namespace { - -tensorflow::Status ValidateInputTypeAndPlacement( - TFE_Context* ctx, tensorflow::Device* host_device, - tensorflow::Device* op_device, TFE_Op* op, - const tensorflow::OpKernel* kernel, - std::vector* copied_tensors) { - const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types(); - if (memtypes.size() != op->inputs.size()) { - return tensorflow::errors::InvalidArgument( - "expected ", memtypes.size(), " inputs, got ", op->inputs.size()); - } - for (int i = 0; i < op->inputs.size(); ++i) { - const tensorflow::Device* expected_device = - memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device; - const tensorflow::Device* actual_device = - op->input_devices[i] == nullptr ? host_device : op->input_devices[i]; - if (expected_device != actual_device) { - switch (TFE_ContextGetDevicePlacementPolicy(ctx)) { - case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32: - // TODO(xpan): See if we could bubble python related error up - // to python level. - if (op->inputs[i].dtype() == tensorflow::DT_INT32) { - // Note: enabling silent copies of int32 tensors to match behavior - // of graph mode. - break; - } - TF_FALLTHROUGH_INTENDED; - case TFE_DEVICE_PLACEMENT_EXPLICIT: - return tensorflow::errors::InvalidArgument( - "Tensors on conflicting devices:" - " cannot compute ", - op->name, " as input #", i, " was expected to be on ", - expected_device->name(), " but is actually on ", - actual_device->name(), " (operation running on ", - op_device->name(), ")", - " Tensors can be copied explicitly using .gpu() or .cpu()," - " or transparently copied by using tfe.enable_eager_execution(" - "tfe.DEVICE_PLACEMENT_SILENT). Copying tensors between devices" - " may slow down your model"); - case TFE_DEVICE_PLACEMENT_WARN: - LOG(WARNING) << "before computing " << op->name << " input #" << i - << " was expected to be on " << expected_device->name() - << " but is actually on " << actual_device->name() - << " (operation running on " << op_device->name() - << "). This triggers a copy which can be a performance " - "bottleneck."; - break; - case TFE_DEVICE_PLACEMENT_SILENT: // Do nothing. - break; - } - // We are only here if the policy is warn or silent copies, so we should - // trigger a copy. - TFE_TensorHandle original{op->inputs[i], op->input_devices[i]}; - TF_Status* s = TF_NewStatus(); - TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice( - &original, ctx, expected_device->name().c_str(), s); - if (!s->status.ok()) { - tensorflow::Status status = s->status; - delete s; - return tensorflow::errors::Internal( - "Failed copying input tensor from ", actual_device->name(), " to ", - expected_device->name(), " in order to run ", op->name, ": ", - status.error_message()); - } - op->inputs[i] = copied_tensor->t; - copied_tensors->push_back(copied_tensor); - op->input_devices[i] = copied_tensor->d; - delete s; - } - if (op->inputs[i].dtype() != kernel->input_type(i)) { - return tensorflow::errors::InvalidArgument( - "cannot compute ", op->name, " as input #", i, - " was expected to be a ", - tensorflow::DataTypeString(kernel->input_type(i)), - " tensor but is a ", - tensorflow::DataTypeString(op->inputs[i].dtype()), " tensor"); - } - } - return tensorflow::Status::OK(); -} - -#ifdef TENSORFLOW_EAGER_USE_XLA -// Synthesizes and returns a wrapper function over `op`, which must be a -// primitive op (e.g. matmul). -// -// The wrapper function conforms to the function signature expected by -// _XlaLaunchOp, with input params ordered by . For example, if the op has input params , they will be reordered to as the input params to the synthesized function. -// -// It populates `const_input_types`, `arg_input_types` and -// `op_input_to_func_input` based on the reordering results, that the caller can -// use them to build an _XlaLaunchOp. On error, it returns NULL, and sets -// `status` accordingly. -const tensorflow::FunctionDef* OpToFunction( - TFE_Op* op, std::vector* const_input_types, - std::vector* arg_input_types, - tensorflow::gtl::FlatMap* op_input_to_func_input, - TF_Status* status) { - DCHECK(!op->is_function()); - - tensorflow::FunctionDef fdef; - - // Get the OpDef of the op we are trying to encapsulate. - TFE_Context* ctx = op->ctx; - const tensorflow::OpRegistrationData* op_data; - { - tensorflow::tf_shared_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.LookUp(op->name, &op_data); - if (!status->status.ok()) { - return nullptr; - } - } - const tensorflow::OpDef& op_def = op_data->op_def; - - tensorflow::OpDef* signature = fdef.mutable_signature(); - - // Handle constant inputs. - const std::unordered_set const_inputs( - *tensorflow::XlaOpRegistry::CompileTimeConstantInputs(op->name)); - - // First add place holders for the input args, so that we can refer to them by - // position in the next loop. Also tally up the resource inputs. - int num_resource_inputs = 0; - for (int i = 0; i < op_def.input_arg_size(); ++i) { - if (op_def.input_arg(i).type() == tensorflow::DT_RESOURCE) { - ++num_resource_inputs; - } - signature->add_input_arg(); - } - - // Now we map the input params from `op_def` to `signature`, where the param - // ordering for `signature` is: . - int const_index = 0; - int arg_index = const_inputs.size(); - int resource_index = op_def.input_arg_size() - num_resource_inputs; - for (int i = 0; i < op_def.input_arg_size(); ++i) { - const tensorflow::OpDef::ArgDef& op_input_arg = op_def.input_arg(i); - tensorflow::OpDef::ArgDef* func_input_arg = nullptr; - if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) { - VLOG(1) << "For const input, mapping op input " << i << " to func input " - << const_index; - (*op_input_to_func_input)[i] = const_index; - func_input_arg = signature->mutable_input_arg(const_index++); - const_input_types->push_back( - static_cast(op->inputs[i].dtype())); - } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) { - VLOG(1) << "For resource input, mapping op input " << i - << " to func input " << resource_index; - (*op_input_to_func_input)[i] = resource_index; - func_input_arg = signature->mutable_input_arg(resource_index++); - } else { - VLOG(1) << "For arg input, mapping op input " << i << " to func input " - << arg_index; - (*op_input_to_func_input)[i] = arg_index; - func_input_arg = signature->mutable_input_arg(arg_index++); - arg_input_types->push_back( - static_cast(op->inputs[i].dtype())); - } - - func_input_arg->set_name(op_input_arg.name()); - func_input_arg->set_type(op->inputs[i].dtype()); - } - VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString(); - - // Resources args are at the end of the function input params, and we should - // have iterated over all of them. - DCHECK_EQ(signature->input_arg_size(), resource_index); - - // Make the synthesized function's name unique. - signature->set_name(tensorflow::strings::StrCat( - op_def.name(), func_id_generator.fetch_add(1))); - - // Add the node def and set its input names to match op_def's names. - const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); - DCHECK_EQ(signature->input_arg_size(), ndef.input_size()); - *fdef.add_node_def() = ndef; - for (int i = 0; i < op_def.input_arg_size(); ++i) { - fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name()); - } - VLOG(1) << "Added NodeDef: " << fdef.DebugString(); - - // Fix the output names and set output types. - for (int i = 0; i < op_def.output_arg_size(); ++i) { - tensorflow::OpDef::ArgDef* arg = signature->add_output_arg(); - const tensorflow::OpDef::ArgDef& op_def_arg = op_def.output_arg(i); - const string& out_tensor_name = tensorflow::strings::StrCat( - ndef.name(), ":", op_def_arg.name(), ":", 0); - arg->set_name(op_def_arg.name()); - (*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name; - const string& type_attr = op_def_arg.type_attr(); - if (!type_attr.empty()) { - auto i = ndef.attr().find(type_attr); - if (i == ndef.attr().end()) { - status->status = tensorflow::errors::InvalidArgument( - tensorflow::strings::StrCat("Could not find attr ", type_attr, - " in NodeDef ", ndef.DebugString())); - return nullptr; - } - arg->set_type(i->second.type()); - } - } - VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString(); - - tensorflow::mutex_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.AddFunctionDef(fdef); - if (!status->status.ok()) return nullptr; - const auto ret = ctx->func_lib_def.Find(signature->name()); - DCHECK(ret != nullptr); - return ret; -} - -// Builds an _XLALaunchOp as a wrapper over 'op', so that 'op' can be executed -// via XLA. -std::unique_ptr BuildXlaLaunch(TFE_Op* op, TF_Status* status) { - VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->name; - auto launch_op = - std::unique_ptr(TFE_NewOp(op->ctx, "_XlaLaunch", status)); - if (TF_GetCode(status) != TF_OK) return nullptr; - if (op->device) { - TFE_OpSetDevice(launch_op.get(), op->device->name().c_str(), status); - if (TF_GetCode(status) != TF_OK) return nullptr; - } - - const tensorflow::FunctionDef* fdef; - { - tensorflow::tf_shared_lock l(op->ctx->functions_mu); - fdef = op->ctx->func_lib_def.Find(op->name); - } - std::vector const_input_types; - std::vector arg_input_types; - tensorflow::gtl::FlatMap op_input_to_func_input; - if (fdef == nullptr) { - // See if this is a primitive op, and if so create a function for it, so - // that _XlaLaunchOp can access it. - fdef = OpToFunction(op, &const_input_types, &arg_input_types, - &op_input_to_func_input, status); - if (!status->status.ok()) return nullptr; - } else { - // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work for - // functions, so we need to find another way to handle constant inputs. - for (int i = const_input_types.size(); - i < fdef->signature().input_arg_size(); ++i) { - VLOG(1) << "Adding Targs from input arg " << i; - const tensorflow::OpDef::ArgDef& arg = fdef->signature().input_arg(i); - arg_input_types.push_back(static_cast(arg.type())); - } - } - DCHECK(fdef != nullptr); - - // Copy inputs and their devices. - // Since input param reordering may have occurred between `op` and `launch_op` - // via `op_input_to_func_input`, adjust the actual inputs accordingly. - launch_op->inputs = op->inputs; - launch_op->input_devices = op->input_devices; - if (!op_input_to_func_input.empty()) { - DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size()); - if (!op->input_devices.empty()) { - DCHECK_EQ(op->input_devices.size(), op_input_to_func_input.size()); - } - for (int i = 0; i < op_input_to_func_input.size(); ++i) { - VLOG(1) << "mapping op input " << i << " to func input " - << op_input_to_func_input[i]; - - launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i]; - if (!op->input_devices.empty()) { - launch_op->input_devices[op_input_to_func_input[i]] = - op->input_devices[i]; - } - } + funcs[i].set_name(value[i]->operation.Name()); + value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr()); } - launch_op->attrs.NumInputs(op->inputs.size()); - - TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants", const_input_types.data(), - const_input_types.size()); - - // Set Targs and Nresources attrs. - TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(), - arg_input_types.size()); - const int num_resource_inputs = fdef->signature().input_arg_size() - - const_input_types.size() - - arg_input_types.size(); - TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs); - - // Set Tresults attr. - std::vector tresults; - for (const tensorflow::OpDef::ArgDef& arg : fdef->signature().output_arg()) { - tresults.push_back(static_cast(arg.type())); - } - TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(), - tresults.size()); - - // Set function attr. - tensorflow::AttrValue attr_value; - tensorflow::NameAttrList* func = attr_value.mutable_func(); - func->set_name(fdef->signature().name()); - launch_op->attrs.Set("function", attr_value); - - return launch_op; + op->operation.MutableAttrs()->Set( + attr_name, tensorflow::gtl::ArraySlice( + funcs.get(), num_values)); } -#endif // TENSORFLOW_EAGER_USE_XLA -} // namespace void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { - TFE_Context* ctx = op->ctx; - // TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU - tensorflow::Device* device = - (op->device == nullptr) ? ctx->devices()[0] : op->device; - -#ifdef TENSORFLOW_EAGER_USE_XLA - std::unique_ptr xla_launch_op; - if (op->use_xla && op->name != "_XlaLaunch") { - xla_launch_op = BuildXlaLaunch(op, status); - if (!status->status.ok()) { - return; - } - op = xla_launch_op.get(); - } -#endif // TENSORFLOW_EAGER_USE_XLA - - std::vector outputs(1); - const tensorflow::MemoryTypeVector* output_memory_types = nullptr; - tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name()); - tensorflow::KernelAndDevice* kernel; - { - tensorflow::tf_shared_lock l(ctx->cache_mu); - kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key); - } - if (kernel == nullptr) { - const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); - kernel = new tensorflow::KernelAndDevice(ctx->rendezvous); - // Knowledge of the implementation of Init (and in-turn - // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def - // will be accessed, so grab on to the lock. - // See WARNING comment below - would be nice to rework to avoid this - // subtlety. - tensorflow::tf_shared_lock l(ctx->functions_mu); - status->status = - tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel); - if (!status->status.ok()) { - delete kernel; - return; - } - tensorflow::mutex_lock ml(ctx->cache_mu); - tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); - } - std::vector copied_tensors; - status->status = ValidateInputTypeAndPlacement( - ctx, ctx->devices()[0], device, op, kernel->kernel(), &copied_tensors); - output_memory_types = &kernel->kernel()->output_memory_types(); + tensorflow::gtl::InlinedVector handle_retvals( + *num_retvals); + status->status = + tensorflow::EagerExecute(&op->operation, &handle_retvals, num_retvals); if (!status->status.ok()) { - for (auto* t : copied_tensors) { - TFE_DeleteTensorHandle(t); - } return; } - std::unique_ptr maybe_stats; - if (ctx->should_store_metadata.load()) { - maybe_stats.reset(new tensorflow::NodeExecStats); - maybe_stats->set_node_name(op->name); - maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros()); - maybe_stats->set_op_start_rel_micros(0); - maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros()); - // TODO(apassos) track referenced tensors - } - // WARNING: kernel->Run utilizes the FunctionLibraryRuntime - // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def, - // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation - // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by - // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here. - // This is quite subtle. Re-work things to make this better? (Would it make - // sense for FunctionLibraryRuntime to ensure thread-safe access to - // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats - // for ops which are a part of functions. - status->status = kernel->Run(&op->inputs, &outputs, maybe_stats.get()); - for (auto* t : copied_tensors) { - TFE_DeleteTensorHandle(t); - } - if (!status->status.ok()) return; - if (maybe_stats != nullptr) { - maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() - - maybe_stats->all_start_micros()); - tensorflow::mutex_lock ml(ctx->metadata_mu); - if (ctx->should_store_metadata.load()) { - auto* step_stats = ctx->run_metadata.mutable_step_stats(); - // Lazily initialize the RunMetadata with information about all devices if - // this is the first call. - while (step_stats->dev_stats_size() < ctx->devices().size()) { - step_stats->add_dev_stats(); - } - // Find the current device's index. - int device_idx = 0; - for (int i = 0; i < ctx->devices().size(); ++i) { - if (ctx->devices()[i] == device) { - device_idx = i; - break; - } - } - // Populate the device stats for this device. - auto* dev_stats = step_stats->mutable_dev_stats(device_idx); - dev_stats->set_device(device->name()); - *dev_stats->add_node_stats() = *maybe_stats; - } - } - *num_retvals = std::min(*num_retvals, outputs.size()); for (int i = 0; i < *num_retvals; ++i) { - tensorflow::Device* d = IsCPU(device) ? nullptr : device; - if (d != nullptr && output_memory_types != nullptr && - (*output_memory_types)[i] == tensorflow::HOST_MEMORY) { - d = nullptr; - } - retvals[i] = new TFE_TensorHandle(outputs[i], d); + retvals[i] = new TFE_TensorHandle(handle_retvals[i]); } } +TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, + TFE_Context* ctx, + const char* device_name, + TF_Status* status) { + tensorflow::TensorHandle* handle; + status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context, + device_name, &handle); + if (status->status.ok()) { + return new TFE_TensorHandle(handle); + } + return nullptr; +} + void TFE_ContextAddFunctionDef(TFE_Context* ctx, const char* serialized_function_def, size_t size, TF_Status* status) { @@ -896,46 +590,120 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); return; } - tensorflow::mutex_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.AddFunctionDef(function_def); + status->status = ctx->context.AddFunctionDef(function_def); } void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { - tensorflow::mutex_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.AddFunctionDef(function->fdef); + status->status = ctx->context.AddFunctionDef(function->fdef); +} + +void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { + ctx->context.SetShouldStoreMetadata(true); +} + +void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { + ctx->context.SetShouldStoreMetadata(false); } } // extern "C" TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) { - return new TFE_TensorHandle(t, nullptr); + return new TFE_TensorHandle(t, nullptr, nullptr); } const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( TFE_TensorHandle* h, TF_Status* status) { - if (h->d != nullptr) { + tensorflow::Device* d = nullptr; + tensorflow::Device* op_device = nullptr; + const tensorflow::Tensor* t = nullptr; + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) return nullptr; + if (d != nullptr) { status->status = tensorflow::errors::FailedPrecondition( "TFE_TensorHandle is placed in device (not host) memory. Cannot return " "a tensorflow::Tensor"); return nullptr; } - return &h->t; + return t; } -void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - ctx->should_store_metadata.store(true); +void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, + TF_Status* status) { + TFE_ContextAsyncWait(ctx, status); + if (!status->status.ok()) return; + tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); + status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf); + ctx->context.RunMetadataProto()->Clear(); } -void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - tensorflow::mutex_lock ml(ctx->metadata_mu); - ctx->should_store_metadata.store(false); - ctx->run_metadata.Clear(); +namespace { +TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, + TF_Status* status) { + TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status); + for (const auto& attr : func.attr()) { + if (TF_GetCode(status) != TF_OK) return nullptr; + SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + } + return func_op; } +} // namespace -void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, - TF_Status* status) { - tensorflow::mutex_lock ml(ctx->metadata_mu); - status->status = MessageToBuffer(ctx->run_metadata, buf); - ctx->run_metadata.Clear(); -} +namespace tensorflow { +void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, + const tensorflow::AttrValue& default_value, + const char* attr_name, TF_Status* status) { + switch (default_value.value_case()) { + case tensorflow::AttrValue::kS: + TFE_OpSetAttrString(op, attr_name, default_value.s().data()); + break; + case tensorflow::AttrValue::kI: + TFE_OpSetAttrInt(op, attr_name, static_cast(default_value.i())); + break; + case tensorflow::AttrValue::kF: + TFE_OpSetAttrFloat(op, attr_name, default_value.f()); + break; + case tensorflow::AttrValue::kB: + TFE_OpSetAttrBool(op, attr_name, default_value.b()); + break; + case tensorflow::AttrValue::kType: + TFE_OpSetAttrType(op, attr_name, + static_cast(default_value.type())); + break; + case tensorflow::AttrValue::kShape: { + const auto& tensor_shape = default_value.shape(); + if (tensor_shape.unknown_rank()) { + TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status); + } else { + const auto num_dims = tensor_shape.dim_size(); + std::unique_ptr dims(new int64_t[num_dims]); + for (int i = 0; i < num_dims; ++i) { + dims[i] = tensor_shape.dim(i).size(); + } + TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status); + } + } break; + case tensorflow::AttrValue::kFunc: { + const auto func_op = GetFunc(ctx, default_value.func(), status); + if (TF_GetCode(status) != TF_OK) return; + // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList + // require TFE_Op* and just convert it internally a NameAttrValue, so + // consider adding an overload to the C API to make this case easier. + TFE_OpSetAttrFunction(op, attr_name, func_op); + } break; + case tensorflow::AttrValue::kList: + TF_FALLTHROUGH_INTENDED; + case tensorflow::AttrValue::kTensor: + TF_FALLTHROUGH_INTENDED; + case tensorflow::AttrValue::kPlaceholder: + TF_FALLTHROUGH_INTENDED; + case tensorflow::AttrValue::VALUE_NOT_SET: + TF_SetStatus( + status, TF_UNIMPLEMENTED, + tensorflow::strings::StrCat("Unable to get setfor default value: ", + default_value.DebugString()) + .data()); + } +} +} // namespace tensorflow diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 7a321b54da343fd2b8912187bc620c1e7456db0c..1862af3ce2f505a6e83b4805417eaf335ed07bc0 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -30,7 +30,7 @@ limitations under the License. #ifdef SWIG #define TF_CAPI_EXPORT #else -#if defined(COMPILER_MSVC) +#if defined(_WIN32) #ifdef TF_COMPILE_LIBRARY #define TF_CAPI_EXPORT __declspec(dllexport) #else @@ -38,7 +38,7 @@ limitations under the License. #endif // TF_COMPILE_LIBRARY #else #define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // COMPILER_MSVC +#endif // _WIN32 #endif // SWIG #ifdef __cplusplus @@ -65,17 +65,32 @@ typedef enum TFE_ContextDevicePlacementPolicy { TFE_DEVICE_PLACEMENT_EXPLICIT = 0, // Copy the tensor to the right device but log a warning. TFE_DEVICE_PLACEMENT_WARN = 1, - // Silently copy the tensor, which has a performance cost since the - // operation will be blocked till the copy completes. + // Silently copy the tensor, which has a performance cost since the operation + // will be blocked till the copy completes. This is the default placement + // policy. TFE_DEVICE_PLACEMENT_SILENT = 2, - // Default placement policy which silently copies int32 tensors but not other - // dtypes. + // Placement policy which silently copies int32 tensors but not other dtypes. TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, } TFE_ContextDevicePlacementPolicy; +// Sets the default execution mode (sync/async). Note that this can be +// overridden per thread using TFE_ContextSetAsyncForThread. +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, + unsigned char async); + TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); +// A tensorflow.ServerDef specifies remote workers (in addition to the current +// workers name). Operations created on this context can then be executed on +// any of these remote workers by setting an appropriate device. +// +// If the following is set, all servers identified by the +// ServerDef must be up when the context is created. +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( + TFE_ContextOptions* options, const void* proto, size_t proto_len, + TF_Status* status); + // Destroy an options object. TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); @@ -108,6 +123,30 @@ TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalDevicePlacementPolicy( TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(TFE_Context*); +// Overrides the execution mode (sync/async) for the current thread. +TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*, + unsigned char async, + TF_Status* status); + +// Causes the calling thread to block till all ops dispatched in async mode +// have been executed. Note that "execution" here refers to kernel execution / +// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee +// that lower level device queues (like GPU streams) have been flushed. +// +// This call may not block for execution of ops enqueued concurrently with this +// call. +TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context*, + TF_Status* status); + +// When an error happens, any pending operations are discarded and newly issued +// ops return an error. This call clears the error state and re-enables +// execution of newly issued ops. +// +// Note that outputs of discarded ops remain in a corrupt state and should not +// be used for future calls. +// TODO(agarwal): mark the affected handles and raise errors if they are used. +TF_CAPI_EXPORT extern void TFE_ContextAsyncClearError(TFE_Context*); + // A handle to a tensor on a device. // // Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape, @@ -117,13 +156,25 @@ typedef struct TFE_TensorHandle TFE_TensorHandle; TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status); +// Indicates that the caller will not be using `h` any more. TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); -TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h); +// This function will block till the operation that produces `h` has completed. +TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h, + TF_Status* status); +// This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, - int dim_index); + int dim_index, + TF_Status* status); +// This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( - TFE_TensorHandle* h); + TFE_TensorHandle* h, TF_Status* status); + +// This function will block till the operation that produces `h` has +// completed. The memory returned might alias the internal memory used by +// TensorFlow. Hence, callers should not mutate this memory (for example by +// modifying the memory region pointed to by TF_TensorData() on the returned +// TF_Tensor). TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status); @@ -133,10 +184,52 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, // that shares the underlying buffer. Otherwise, it currently requires at least // one of the source or destination devices to be CPU (i.e., for the source or // destination tensor to be placed in host memory). +// If async execution is enabled, the copy may be enqueued and the call will +// return "non-ready" handle. Else, this function returns after the copy has +// been done. TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice( TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, TF_Status* status); +// Debugging/Profiling information for TFE_TensorHandle +// +// TFE_TensorDebugInfo contains information useful for debugging and +// profiling tensors. +typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo; + +// Retrieves TFE_TensorDebugInfo for `handle`. +// If TFE_TensorHandleTensorDebugInfo succeeds, `status` is set to OK and caller +// is responsible for deleting returned TFE_TensorDebugInfo. +// If TFE_TensorHandleTensorDebugInfo fails, `status` is set to appropriate +// error and nullptr is returned. This function can block till the operation +// that produces `handle` has completed. +TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( + TFE_TensorHandle* handle, TF_Status* status); + +// Deletes `debug_info`. +TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo( + TFE_TensorDebugInfo* debug_info); + +// Returns the number of dimensions used to represent the tensor on its device. +// The number of dimensions used to reprensent the tensor on device can be +// different from the number returned by TFE_TensorHandleNumDims. +// The return value was current at the time of TFE_TensorDebugInfo creation. +TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims( + TFE_TensorDebugInfo* debug_info); + +// Returns the number of elements in dimension `dim_index`. +// Tensor representation on device can be transposed from its representation +// on host. The data contained in dimension `dim_index` on device +// can correspond to the data contained in another dimension in on-host +// representation. The dimensions are indexed using the standard TensorFlow +// major-to-minor order (slowest varying dimension first), +// not the XLA's minor-to-major order. +// On-device dimensions can be padded. TFE_TensorDebugInfoOnDeviceDim returns +// the number of elements in a dimension after padding. +// The return value was current at the time of TFE_TensorDebugInfo creation. +TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim( + TFE_TensorDebugInfo* debug_info, int dim_index); + // Description of the TensorFlow op to execute. // // Assumes that the provided 'ctx' outlives the returned TFE_Op, i.e., @@ -153,6 +246,7 @@ typedef struct TFE_Op TFE_Op; TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status); + TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op); TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name, @@ -238,13 +332,21 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op, int num_values); // Execute the operation defined by 'op' and return handles to computed -// tensors in 'retvals'. +// tensors in `retvals`. +// +// 'retvals' must point to a pre-allocated array of TFE_TensorHandle* and +// '*num_retvals' should be set to the size of this array. It is an error if +// the size of 'retvals' is less than the number of outputs. This call sets +// *num_retvals to the number of outputs. // -// 'retvals' must point to a pre-allocated array of TFE_TensorHandle* -// and '*num_retvals' should be set to the size of this array. +// If async execution is enabled, the call may simply enqueue the execution +// and return "non-ready" handles in `retvals`. Note that any handles contained +// in 'op' should not be mutated till the kernel execution actually finishes. // -// On return, 'num_retvals' will be set to the actual number of outputs -// returned by the operation. +// For sync execution, if any of the inputs to `op` are not ready, this call +// will block till they become ready and then return when the kernel execution +// is done. +// TODO(agarwal): change num_retvals to int from int*. TF_CAPI_EXPORT extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status); @@ -270,6 +372,8 @@ TF_CAPI_EXPORT extern void TFE_ContextDisableRunMetadata(TFE_Context* ctx); // Populates the passed-in buffer with a serialized RunMetadata protocol buffer // containing any run metadata information accumulated so far and clears this // information. +// If async mode is enabled, this call blocks till all currently pending ops are +// done. TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status); diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc new file mode 100644 index 0000000000000000000000000000000000000000..5006b76f1981d068e99a2c081115ebb3a66d8c7f --- /dev/null +++ b/tensorflow/c/eager/c_api_debug.cc @@ -0,0 +1,167 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api.h" + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#ifdef TENSORFLOW_EAGER_USE_XLA +#include "tensorflow/compiler/jit/xla_device.h" +#endif // TENSORFLOW_EAGER_USE_XLA + +using tensorflow::int64; +using tensorflow::string; + +namespace { + +std::vector TensorShapeAsVector(TFE_TensorHandle* handle, + TF_Status* status) { + std::vector shape; + int rank = TFE_TensorHandleNumDims(handle, status); + if (!status->status.ok()) { + return shape; + } + shape.reserve(rank); + for (int i = 0; i < rank; ++i) { + shape.push_back(TFE_TensorHandleDim(handle, i, status)); + if (!status->status.ok()) { + return shape; + } + } + return shape; +} + +} // namespace + +extern "C" { + +TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( + TFE_TensorHandle* handle, TF_Status* status) { + const tensorflow::Tensor* tensor; + status->status = handle->handle->Tensor(&tensor); + if (!status->status.ok()) { + return nullptr; + } + + tensorflow::Device* device; + status->status = handle->handle->Device(&device); + if (!status->status.ok()) { + return nullptr; + } + +#ifdef TENSORFLOW_EAGER_USE_XLA + // If tensor resides on an XLA device, use XLA device's PaddedShapeFn. + tensorflow::XlaDevice* xla_device = + dynamic_cast(device); + if (xla_device != nullptr) { + tensorflow::XlaDevice::PaddedShapeFn shape_fn = + xla_device->metadata().padded_shape_fn(); + xla::Shape padded_shape; + status->status = shape_fn(*tensor, &padded_shape); + if (!status->status.ok()) { + return nullptr; + } + if (VLOG_IS_ON(3)) { + std::vector shape_to_log = TensorShapeAsVector(handle, status); + if (!status->status.ok()) { + // Ignore the status here as we are simply logging. + status->status = tensorflow::Status::OK(); + } else { + VLOG(3) << "Fully padded shape of [" + << tensorflow::str_util::Join(shape_to_log, ", ") << "] is " + << padded_shape.DebugString(); + } + } + + if (xla::ShapeUtil::IsTuple(padded_shape)) { + if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) { + // Currently, the only case of XlaTensor containing a tuple shape is to + // represent 64 bit ints, doubles, and complex numbers (we don't support + // 64bit complex numbers). + status->status = tensorflow::errors::InvalidArgument( + "XlaTensors should only contain tuples of size 2. Shape: ", + padded_shape.DebugString()); + return nullptr; + } + + // shape0 is not a const& because we will assign it to padded_shape below. + // It is illegal to assign a part of a message to itself. + xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0); + const xla::Shape& shape1 = + xla::ShapeUtil::GetTupleElementShape(padded_shape, 1); + if (xla::ShapeUtil::IsTuple(shape0) || xla::ShapeUtil::IsTuple(shape1)) { + status->status = tensorflow::errors::InvalidArgument( + "XlaTensors should not contain nested tuples. Shape: ", + padded_shape.DebugString()); + return nullptr; + } + if (!xla::ShapeUtil::Equal(shape0, shape1)) { + status->status = tensorflow::errors::InvalidArgument( + "Subshapes of XlaTensors should be the same. Shape: ", + padded_shape.DebugString()); + return nullptr; + } + + // Since the only case we handle here are two equal subshapes, we + // simply return one of them. The caller will interpret it as this + // shape directly storing the 64bit types. This approximation is good + // enough for this API's debugging use case. + padded_shape = shape0; + } + + int rank = padded_shape.dimensions_size(); + std::vector dev_dims; + dev_dims.reserve(rank); + if (rank == 1) { + // Rank 1 tensors might not have padded_shape.layout.minor_to_major set, + dev_dims.push_back(padded_shape.dimensions(0)); + } else { + for (int i = rank - 1; i >= 0; --i) { + int64 dim_index = padded_shape.layout().minor_to_major(i); + dev_dims.push_back(padded_shape.dimensions(dim_index)); + } + } + status->status = tensorflow::Status::OK(); + return new TFE_TensorDebugInfo(dev_dims); + } +#endif // TENSORFLOW_EAGER_USE_XLA + + // If the tensor is not an XLA tensor, the device shape is + // the same as regular tensor shape. + std::vector dev_dims = TensorShapeAsVector(handle, status); + if (!status->status.ok()) { + return nullptr; + } + return new TFE_TensorDebugInfo(dev_dims); +} + +TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo( + TFE_TensorDebugInfo* debug_info) { + delete debug_info; +} + +TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims( + TFE_TensorDebugInfo* debug_info) { + return debug_info->dev_dims.size(); +} + +TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim( + TFE_TensorDebugInfo* debug_info, int dim_index) { + return debug_info->dev_dims[dim_index]; +} + +} // extern "C" diff --git a/tensorflow/c/eager/c_api_debug_test.cc b/tensorflow/c/eager/c_api_debug_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cddb9f6e00e9d639026f4bbe061d58f76771c0a9 --- /dev/null +++ b/tensorflow/c/eager/c_api_debug_test.cc @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api.h" + +#include +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +TEST(CApiDebug, ScalarCPU) { + TFE_TensorHandle* h = TestScalarTensorHandle(); + TF_Status* status = TF_NewStatus(); + TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + ASSERT_EQ(0, TFE_TensorDebugInfoOnDeviceNumDims(debug_info)); + + TFE_DeleteTensorDebugInfo(debug_info); + TFE_DeleteTensorHandle(h); + TF_DeleteStatus(status); +} + +TEST(CApiDebug, 2DCPU) { + TFE_TensorHandle* h = TestMatrixTensorHandle3X2(); + TF_Status* status = TF_NewStatus(); + TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + ASSERT_EQ(2, TFE_TensorDebugInfoOnDeviceNumDims(debug_info)); + // Shape is the same for CPU tensors. + EXPECT_EQ(3, TFE_TensorDebugInfoOnDeviceDim(debug_info, 0)); + EXPECT_EQ(2, TFE_TensorDebugInfoOnDeviceDim(debug_info, 1)); + + TFE_DeleteTensorDebugInfo(debug_info); + TFE_DeleteTensorHandle(h); + TF_DeleteStatus(status); +} diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 7b9f1db02ed9c53a280c7bd1284165cac4fb6353..04a6efc47c5177c82b7e88168b67cc584587de7c 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -19,18 +19,35 @@ limitations under the License. #include #include +#include #include +#include #include #include #include #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" -#include "tensorflow/c/eager/runtime.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/eager/eager_client.h" +#include "tensorflow/core/distributed_runtime/remote_device.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" +#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/mutex.h" @@ -39,89 +56,79 @@ limitations under the License. struct TFE_ContextOptions { TF_SessionOptions session_options; - TFE_ContextDevicePlacementPolicy policy{ - TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32}; + // true if async execution is enabled. + bool async = false; + TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT}; + tensorflow::ServerDef server_def; }; struct TFE_Context { - explicit TFE_Context(const TFE_ContextOptions& opts, TF_Session* s) - : policy(opts.policy), - session(s), - rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)), - pflr(new tensorflow::ProcessFunctionLibraryRuntime( - session->device_mgr, opts.session_options.options.env, - TF_GRAPH_DEF_VERSION, &func_lib_def, {})) {} - - const TFE_ContextDevicePlacementPolicy policy; - - // Note: we cannot use C++11 thread_local here as there is no concept of a - // thread-local-object-local variable in C++11. - tensorflow::mutex policy_map_mu; - std::unordered_map - thread_local_policies GUARDED_BY(policy_map_mu); - - // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph. - TF_Session* const session; - tensorflow::Rendezvous* const rendezvous; - - tensorflow::mutex functions_mu; - tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){ - tensorflow::OpRegistry::Global(), {}}; - - // One FunctionLibraryRuntime per device. - // func_libs[i] is the FunctionLibraryRuntime corresponding to - // session->devices[i]. - const std::unique_ptr pflr; - - tensorflow::mutex cache_mu; - std::unordered_map - kernel_cache GUARDED_BY(cache_mu); - - tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const { - return pflr->GetFLR(d->name()); - } - - const std::vector& devices() { return session->devices; } - - // Whether we should compute RunMetadata. - std::atomic should_store_metadata{false}; - tensorflow::mutex metadata_mu; - tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu); + explicit TFE_Context(const tensorflow::SessionOptions& opts, + TFE_ContextDevicePlacementPolicy default_policy, + bool async, + std::unique_ptr device_mgr, + tensorflow::Rendezvous* rendezvous) + : context(opts, + static_cast( + default_policy), + async, std::move(device_mgr), rendezvous) {} + + explicit TFE_Context( + const tensorflow::SessionOptions& opts, + TFE_ContextDevicePlacementPolicy default_policy, bool async, + tensorflow::DeviceMgr* local_device_mgr, + tensorflow::Rendezvous* rendezvous, + std::unique_ptr server, + std::unique_ptr remote_eager_workers, + std::unique_ptr remote_device_mgr, + const tensorflow::gtl::FlatMap& + remote_contexts) + : context(opts, + static_cast( + default_policy), + async, local_device_mgr, rendezvous, std::move(server), + std::move(remote_eager_workers), std::move(remote_device_mgr), + remote_contexts) {} + + tensorflow::EagerContext context; }; struct TFE_TensorHandle { - TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d) - : t(t), d(d) {} - - tensorflow::Tensor t; - // TODO(ashankar): d == nullptr iff local CPU - // This was expedient, but perhaps worth revisiting ('d' should always be a - // valid pointer?) - // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are - // provided with the appropriate TFE_Context. - // - // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a - // TFE_TensorHandle does not outlive the TFE_Context from which it came? - tensorflow::Device* d; + TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d, + tensorflow::Device* op_device) + : handle(new tensorflow::TensorHandle(t, d, op_device, nullptr)) {} + + TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype, + tensorflow::EagerContext* ctx) + : handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {} + + TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} + + tensorflow::TensorHandle* handle; +}; + +struct TFE_TensorDebugInfo { + TFE_TensorDebugInfo(const std::vector& dims) + : dev_dims(dims) {} + + // Fully-padded, minor-to-major. + std::vector dev_dims; }; struct TFE_Op { // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a // primitive operation. TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) - : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} - - bool const is_function() const { return attr_types == nullptr; } - - TFE_Context* ctx; // Must outlive the TFE_Op. - const tensorflow::string name; - tensorflow::AttrBuilder attrs; - const tensorflow::AttrTypeMap* attr_types; - std::vector inputs; - std::vector input_devices; - tensorflow::Device* device; - bool use_xla = false; + : operation(&ctx->context, op, t) {} + + tensorflow::EagerOperation operation; }; +namespace tensorflow { +// Set an AttrValue on the op. Doesn't handle the list types. +void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, + const tensorflow::AttrValue& default_value, + const char* attr_name, TF_Status* status); +} // namespace tensorflow + #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 4a3ecbc0abb16296a84c0d2184dc3fc9f7f3ebb4..992d1afd5fcb0641794bb2abbe5ab20a287d3b62 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -23,100 +25,14 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" using tensorflow::string; namespace { -TFE_TensorHandle* TestMatrixTensorHandle() { - int64_t dims[] = {2, 2}; - float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; - TF_Tensor* t = TF_AllocateTensor( - TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { - TF_Status* status = TF_NewStatus(); - - TFE_Op* op = TFE_NewOp(ctx, "MatMul", status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, a, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, b, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteStatus(status); - TFE_OpSetAttrBool(op, "transpose_a", 0); - TFE_OpSetAttrBool(op, "transpose_b", 0); - TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); - - return op; -} - -TFE_TensorHandle* TestAxisTensorHandle() { - int64_t dims[] = {1}; - int data[] = {1}; - TF_Tensor* t = TF_AllocateTensor( - TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, - TFE_TensorHandle* axis) { - TF_Status* status = TF_NewStatus(); - - TFE_Op* op = TFE_NewOp(ctx, "Min", status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, input, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, axis, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpSetAttrBool(op, "keep_dims", 1); - TFE_OpSetAttrType(op, "Tidx", TF_INT32); - TF_DeleteStatus(status); - TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input)); - - return op; -} - -// If there is a GPU device, returns true and sets 'gpu_device_name' -// accordingly. -bool GetGPUDeviceName(TFE_Context* ctx, string* gpu_device_name) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); - CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - - const int num_devices = TF_DeviceListCount(devices); - for (int i = 0; i < num_devices; ++i) { - const string device_type(TF_DeviceListType(devices, i, status.get())); - CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); - const string device_name(TF_DeviceListName(devices, i, status.get())); - CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); - if (device_type == "GPU") { - *gpu_device_name = device_name; - LOG(INFO) << "Found GPU device " << device_name; - TF_DeleteDeviceList(devices); - return true; - } - } - TF_DeleteDeviceList(devices); - return false; -} - void BM_InitOp(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); @@ -139,10 +55,12 @@ void BM_InitOp(int iters) { } BENCHMARK(BM_InitOp); -void BM_Execute(int iters) { +void BM_Execute(int iters, int async) { tensorflow::testing::StopTiming(); + tensorflow::testing::SetLabel(async ? "ExecuteAsync" : "Execute"); TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -156,6 +74,9 @@ void BM_Execute(int iters) { TFE_Execute(matmul, &retvals[0], &num_retvals, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); } + if (async) { + TFE_ContextAsyncWait(ctx, status); + } tensorflow::testing::StopTiming(); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); @@ -163,7 +84,7 @@ void BM_Execute(int iters) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } -BENCHMARK(BM_Execute); +BENCHMARK(BM_Execute)->Arg(0)->Arg(1); TEST(CAPI, Context) { TF_Status* status = TF_NewStatus(); @@ -187,6 +108,182 @@ TEST(CAPI, Context) { TF_DeleteStatus(status); } +tensorflow::ServerDef GetServerDef(int num_tasks) { + tensorflow::ServerDef server_def; + server_def.set_protocol("grpc"); + server_def.set_job_name("localhost"); + server_def.set_task_index(0); + tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); + tensorflow::JobDef* job_def = cluster_def->add_job(); + job_def->set_name("localhost"); + for (int i = 0; i < num_tasks; i++) { + int port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->insert( + {i, tensorflow::strings::StrCat("localhost:", port)}); + } + return server_def; +} + +void TestRemoteExecute(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE( + tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ContextOptionsSetAsync(opts, static_cast(1)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, + TFE_DEVICE_PLACEMENT_EXPLICIT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + auto* h0_task1 = + TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + auto* h1_task1 = + TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = TFE_TensorHandleCopyToDevice( + retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retval_task0); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(h1_task0); + TFE_DeleteTensorHandle(h0_task1); + TFE_DeleteTensorHandle(h1_task1); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_ContextAsyncWait(ctx, status); + TFE_DeleteContext(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } +TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } + +void TestRemoteExecuteSilentCopies(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE( + tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ContextOptionsSetAsync(opts, static_cast(1)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + + // Handles are on task0, but op is on remote (task1). + TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task0); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = TFE_TensorHandleCopyToDevice( + retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retval_task0); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(h1_task0); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_ContextAsyncWait(ctx, status); + TFE_DeleteContext(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); } +TEST(CAPI, RemoteExecuteSilentCopiesAsync) { + TestRemoteExecuteSilentCopies(true); +} + TEST(CAPI, TensorHandle) { TFE_TensorHandle* h = TestMatrixTensorHandle(); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); @@ -205,10 +302,11 @@ TEST(CAPI, TensorHandle) { TFE_DeleteTensorHandle(h); } -TEST(CAPI, TensorHandleCopyBetweenDevices) { +void TensorHandleCopyBetweenDevices(bool async) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status.get()); TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -274,10 +372,56 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) { EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } -TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { +TEST(CAPI, TensorHandleCopyBetweenDevices) { + TensorHandleCopyBetweenDevices(false); +} + +TEST(CAPI, TensorHandleCopyBetweenDevicesAsync) { + TensorHandleCopyBetweenDevices(true); +} + +void TensorHandleCopyBetweenDevicesError(bool async) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + const char* kErrorDevice = "NoSuchDevice:0"; + TFE_TensorHandle* hdevice = + TFE_TensorHandleCopyToDevice(hcpu, ctx, kErrorDevice, status.get()); + EXPECT_NE(TF_OK, TF_GetCode(status.get())); + const char* msg = "NoSuchDevice:0 unknown device"; + EXPECT_TRUE(strstr(TF_Message(status.get()), msg) != nullptr) + << TF_Message(status.get()); + TF_SetStatus(status.get(), TF_OK, ""); + const char* kCPUDevice = "CPU:0"; + TFE_TensorHandle* hcopy = + TFE_TensorHandleCopyToDevice(hcpu, ctx, kCPUDevice, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_ContextAsyncWait(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())); + TFE_DeleteTensorHandle(hcopy); + TFE_DeleteTensorHandle(hcpu); + if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice); + TFE_DeleteContext(ctx, status.get()); +} + +TEST(CAPI, TensorHandleCopyBetweenDevicesError) { + TensorHandleCopyBetweenDevicesError(false); +} + +TEST(CAPI, TensorHandleCopyBetweenDevicesErrorAsync) { + TensorHandleCopyBetweenDevicesError(true); +} + +void TensorHandleCopyBetweenTwoGPUDevices(bool async) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status.get()); TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -332,11 +476,20 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } -TEST(CAPI, TensorHandleSilentCopy) { +TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { + TensorHandleCopyBetweenTwoGPUDevices(false); +} + +TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) { + TensorHandleCopyBetweenTwoGPUDevices(true); +} + +void TensorHandleSilentCopy(bool async) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status.get()); TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -347,7 +500,7 @@ TEST(CAPI, TensorHandleSilentCopy) { // Disable the test if no GPU is present. string gpu_device_name; - if (GetGPUDeviceName(ctx, &gpu_device_name)) { + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); @@ -366,14 +519,20 @@ TEST(CAPI, TensorHandleSilentCopy) { TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); + TFE_ContextAsyncWait(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_DeleteContext(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } -TEST(CAPI, TensorHandleSilentCopyLocal) { +TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); } +TEST(CAPI, TensorHandleSilentCopyAsync) { TensorHandleSilentCopy(true); } + +void TensorHandleSilentCopyLocal(bool async) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_EXPLICIT); TFE_Context* ctx = TFE_NewContext(opts, status.get()); @@ -388,7 +547,7 @@ TEST(CAPI, TensorHandleSilentCopyLocal) { // Disable the test if no GPU is present. string gpu_device_name; - if (GetGPUDeviceName(ctx, &gpu_device_name)) { + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); @@ -407,11 +566,17 @@ TEST(CAPI, TensorHandleSilentCopyLocal) { TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); + TFE_ContextAsyncWait(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_DeleteContext(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } +TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); } +TEST(CAPI, TensorHandleSilentCopyLocalAsync) { + TensorHandleSilentCopyLocal(true); +} -TEST(CAPI, SetAndGetOpDevices) { +void SetAndGetOpDevices(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); @@ -423,7 +588,7 @@ TEST(CAPI, SetAndGetOpDevices) { // Disable the test if no GPU is present. string gpu_device_name; - if (GetGPUDeviceName(ctx, &gpu_device_name)) { + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { TFE_OpSetDevice(matmul, "GPU:0", status); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); const char* device_name = TFE_OpGetDevice(matmul, status); @@ -442,27 +607,28 @@ TEST(CAPI, SetAndGetOpDevices) { TF_DeleteStatus(status); } -TEST(CAPI, Execute_MatMul_CPU) { +void Execute_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[2] = {nullptr, nullptr}; + int num_retvals = 2; TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(1, num_retvals); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - ASSERT_EQ(1, num_retvals); TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -474,7 +640,107 @@ TEST(CAPI, Execute_MatMul_CPU) { EXPECT_EQ(22, product[3]); TF_DeleteStatus(status); } +TEST(CAPI, Execute_MatMul_CPU) { Execute_MatMul_CPU(false); } +TEST(CAPI, Execute_MatMul_CPUAsync) { Execute_MatMul_CPU(true); } + +void Execute_MatMul_CPU_Runtime_Error(bool async) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m1 = TestMatrixTensorHandle(); + TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2(); + TFE_Op* matmul = MatMulOp(ctx, m1, m2); + TFE_OpSetDevice(matmul, "/job:localhost/replica:0/task:0/device:CPU:0", + status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Op* matmul2 = MatMulOp(ctx, m1, m1); + TFE_OpSetDevice(matmul2, "/job:localhost/replica:0/task:0/device:CPU:0", + status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + TFE_DeleteOp(matmul); + if (!async) { + EXPECT_NE(TF_OK, TF_GetCode(status)); + } else { + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + EXPECT_NE(TF_OK, TF_GetCode(status)); + EXPECT_EQ(nullptr, t); + const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]"; + EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr) + << TF_Message(status); + // Since error is not cleared, the following copy with correct device will + // still fail. + TF_SetStatus(status, TF_OK, ""); + TFE_DeleteTensorHandle(retvals[0]); + retvals[0] = nullptr; + TFE_Execute(matmul2, &retvals[0], &num_retvals, status); + EXPECT_NE(TF_OK, TF_GetCode(status)); + TFE_ContextAsyncClearError(ctx); + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + } + // Following works in async mode since TFE_ContextAsyncClearError was called. + TF_SetStatus(status, TF_OK, ""); + if (retvals[0] != nullptr) { + TFE_DeleteTensorHandle(retvals[0]); + } + retvals[0] = nullptr; + TFE_Execute(matmul2, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteTensor(t); + TFE_DeleteOp(matmul2); + TFE_DeleteTensorHandle(m1); + TFE_DeleteTensorHandle(m2); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx, status); + TF_DeleteStatus(status); +} +TEST(CAPI, Execute_MatMul_CPU_Runtime_Error) { + Execute_MatMul_CPU_Runtime_Error(false); +} +TEST(CAPI, Execute_MatMul_CPU_Runtime_ErrorAsync) { + Execute_MatMul_CPU_Runtime_Error(true); +} +void Execute_MatMul_CPU_Type_Error(bool async) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m1 = TestMatrixTensorHandle(); + TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m1, m2); + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_NE(TF_OK, TF_GetCode(status)); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m1); + TFE_DeleteTensorHandle(m2); + if (retvals[0] != nullptr) { + TFE_DeleteTensorHandle(retvals[0]); + } + TFE_DeleteContext(ctx, status); + TF_DeleteStatus(status); +} + +TEST(CAPI, Execute_MatMul_CPU_Type_Error) { + Execute_MatMul_CPU_Type_Error(false); +} +TEST(CAPI, Execute_MatMul_CPU_Type_ErrorAsync) { + Execute_MatMul_CPU_Type_Error(true); +} TEST(CAPI, Execute_Min_CPU) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -485,33 +751,34 @@ TEST(CAPI, Execute_Min_CPU) { TFE_TensorHandle* input = TestMatrixTensorHandle(); TFE_TensorHandle* axis = TestAxisTensorHandle(); TFE_Op* minOp = MinOp(ctx, input, axis); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(minOp, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(minOp); TFE_DeleteTensorHandle(input); TFE_DeleteTensorHandle(axis); - TFE_DeleteContext(ctx, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); - TFE_DeleteTensorHandle(retvals[0]); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retvals[0]); float output[2] = {0}; EXPECT_EQ(sizeof(output), TF_TensorByteSize(t)); memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t)); TF_DeleteTensor(t); EXPECT_EQ(1, output[0]); EXPECT_EQ(3, output[1]); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } #ifdef TENSORFLOW_EAGER_USE_XLA -TEST(CAPI, Execute_MatMul_XLA_CPU) { +void Execute_MatMul_XLA_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -521,15 +788,14 @@ TEST(CAPI, Execute_MatMul_XLA_CPU) { TFE_OpSetXLACompilation(matmul, true); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(matmul, &retvals[0], &num_retvals, status); // Running a primitive TF operator via XLA is not yet supported. ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); EXPECT_EQ(1, num_retvals); @@ -545,13 +811,16 @@ TEST(CAPI, Execute_MatMul_XLA_CPU) { EXPECT_EQ(10, product[1]); EXPECT_EQ(15, product[2]); EXPECT_EQ(22, product[3]); - + TFE_DeleteContext(ctx, status); TF_DeleteStatus(status); } +TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); } +TEST(CAPI, Execute_MatMul_XLA_CPUAsync) { Execute_MatMul_XLA_CPU(true); } -TEST(CAPI, Execute_Min_XLA_CPU) { +void Execute_Min_XLA_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -562,14 +831,13 @@ TEST(CAPI, Execute_Min_XLA_CPU) { TFE_OpSetXLACompilation(minOp, true); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(minOp, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(minOp); TFE_DeleteTensorHandle(input); TFE_DeleteTensorHandle(axis); - TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); @@ -582,13 +850,17 @@ TEST(CAPI, Execute_Min_XLA_CPU) { TF_DeleteTensor(t); EXPECT_EQ(1, output[0]); EXPECT_EQ(3, output[1]); + TFE_DeleteContext(ctx, status); TF_DeleteStatus(status); } +TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); } +TEST(CAPI, Execute_Min_XLA_CPUAsync) { Execute_Min_XLA_CPU(true); } #endif // TENSORFLOW_EAGER_USE_XLA -TEST(CAPI, ExecuteWithTracing) { +void ExecuteWithTracing(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); TFE_ContextEnableRunMetadata(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -596,8 +868,8 @@ TEST(CAPI, ExecuteWithTracing) { TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(matmul, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(matmul); @@ -609,12 +881,12 @@ TEST(CAPI, ExecuteWithTracing) { EXPECT_TRUE( rm.ParseFromString({reinterpret_cast(b->data), b->length})); TF_DeleteBuffer(b); - TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -626,6 +898,8 @@ TEST(CAPI, ExecuteWithTracing) { EXPECT_EQ(22, product[3]); TF_DeleteStatus(status); } +TEST(CAPI, ExecuteWithTracing) { ExecuteWithTracing(false); } +TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithTracing(true); } TEST(CAPI, Function_ident_CPU) { // First create a simple identity function. @@ -657,32 +931,37 @@ TEST(CAPI, Function_ident_CPU) { ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteFunction(fn); - TF_Tensor* t = - TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); - *reinterpret_cast(TF_TensorData(t)) = 42; - TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteTensor(t); + for (bool async : {false, true, false}) { + TFE_ContextSetAsyncForThread(ctx, static_cast(async), + status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK); + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); + *reinterpret_cast(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); - TFE_Op* op = TFE_NewOp(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_OpAddInput(op, h, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - std::vector result; - result.push_back(nullptr); - int num_retvals = 1; - TFE_Execute(op, result.data(), &num_retvals, status); - TFE_DeleteOp(op); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - ASSERT_EQ(num_retvals, 1); + std::vector result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + ASSERT_EQ(num_retvals, 1); - TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); - TFE_DeleteTensorHandle(h); - TF_DeleteTensor(r); - TFE_DeleteTensorHandle(result[0]); + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + } TFE_DeleteContext(ctx, status); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); @@ -719,35 +998,40 @@ TEST(CAPI, Function_ident_XLA_CPU) { ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteFunction(fn); - TF_Tensor* t = - TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); - *reinterpret_cast(TF_TensorData(t)) = 42; - TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteTensor(t); + for (bool async : {false, true, false}) { + TFE_ContextSetAsyncForThread(ctx, static_cast(async), + status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK); + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); + *reinterpret_cast(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); - TFE_Op* op = TFE_NewOp(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_OpAddInput(op, h, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - // Now run it via XLA. - TFE_OpSetXLACompilation(op, true); + // Now run it via XLA. + TFE_OpSetXLACompilation(op, true); - std::vector result; - result.push_back(nullptr); - int num_retvals = 1; - TFE_Execute(op, result.data(), &num_retvals, status); - TFE_DeleteOp(op); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - ASSERT_EQ(num_retvals, 1); + std::vector result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + ASSERT_EQ(num_retvals, 1); - TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); - TFE_DeleteTensorHandle(h); - TF_DeleteTensor(r); - TFE_DeleteTensorHandle(result[0]); + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + } TFE_DeleteContext(ctx, status); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); @@ -788,9 +1072,10 @@ string MatMulFunction() { return def.SerializeAsString(); } -TEST(CAPI, FunctionDefAndExecute) { +void FunctionDefAndExecute(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -827,11 +1112,16 @@ TEST(CAPI, FunctionDefAndExecute) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } +TEST(CAPI, FunctionDefAndExecute) { FunctionDefAndExecute(false); } +TEST(CAPI, FunctionDefAndExecuteAsync) { FunctionDefAndExecute(true); } -void BM_ExecuteFunction(int iters) { +void BM_ExecuteFunction(int iters, int async) { tensorflow::testing::StopTiming(); + tensorflow::testing::SetLabel(async ? "ExecuteFunctionAsync" + : "ExecuteFunction"); TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -853,6 +1143,9 @@ void BM_ExecuteFunction(int iters) { TFE_Execute(matmul, &retval[0], &num_retvals, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); } + if (async) { + TFE_ContextAsyncWait(ctx, status); + } tensorflow::testing::StopTiming(); TFE_DeleteTensorHandle(m); TFE_DeleteTensorHandle(retval[0]); @@ -860,7 +1153,7 @@ void BM_ExecuteFunction(int iters) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } -BENCHMARK(BM_ExecuteFunction); +BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1); TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, TF_Status* status) { @@ -932,7 +1225,8 @@ TEST(CAPI, Variables) { ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle)); - EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle)); + EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle, status)); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float value = 0.0f; TF_Tensor* t = TFE_TensorHandleResolve(value_handle, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -974,7 +1268,8 @@ void BM_ReadVariable(int iters) { CHECK_EQ(1, num_retvals); CHECK(h); CHECK_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); - CHECK_EQ(0, TFE_TensorHandleNumDims(h)); + CHECK_EQ(0, TFE_TensorHandleNumDims(h, status)); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); h = nullptr; } tensorflow::testing::StopTiming(); diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..5607c9dcb0bbec72b2f86def3dd4e6590d73197b --- /dev/null +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -0,0 +1,163 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_test_util.h" + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +using tensorflow::string; + +TFE_TensorHandle* TestScalarTensorHandle() { + float data[] = {1.0f}; + TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* DoubleTestMatrixTensorHandle() { + int64_t dims[] = {2, 2}; + double data[] = {1.0, 2.0, 3.0, 4.0}; + TF_Tensor* t = TF_AllocateTensor( + TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* TestMatrixTensorHandle() { + int64_t dims[] = {2, 2}; + float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() { + int64_t dims[] = {3, 2}; + double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* TestMatrixTensorHandle3X2() { + int64_t dims[] = {3, 2}; + float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "MatMul", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, a, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, b, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + TFE_OpSetAttrBool(op, "transpose_a", 0); + TFE_OpSetAttrBool(op, "transpose_b", 0); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); + + return op; +} + +TFE_TensorHandle* TestAxisTensorHandle() { + int64_t dims[] = {1}; + int data[] = {1}; + TF_Tensor* t = TF_AllocateTensor( + TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, + TFE_TensorHandle* axis) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "Min", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, input, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, axis, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetAttrBool(op, "keep_dims", 1); + TFE_OpSetAttrType(op, "Tidx", TF_INT32); + TF_DeleteStatus(status); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input)); + + return op; +} + +bool GetDeviceName(TFE_Context* ctx, string* device_name, + const char* device_type) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); + CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + const int num_devices = TF_DeviceListCount(devices); + for (int i = 0; i < num_devices; ++i) { + const string dev_type(TF_DeviceListType(devices, i, status.get())); + CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + const string dev_name(TF_DeviceListName(devices, i, status.get())); + CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + if (dev_type == device_type) { + *device_name = dev_name; + LOG(INFO) << "Found " << device_type << " device " << *device_name; + TF_DeleteDeviceList(devices); + return true; + } + } + TF_DeleteDeviceList(devices); + return false; +} diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..474cae67c89249af3a62707f0db00ba458ca8f31 --- /dev/null +++ b/tensorflow/c/eager/c_api_test_util.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ +#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ + +#include "tensorflow/c/eager/c_api.h" + +#include "tensorflow/core/platform/types.h" + +// Return a tensor handle containing a float scalar +TFE_TensorHandle* TestScalarTensorHandle(); + +// Return a tensor handle containing a 2x2 matrix of doubles +TFE_TensorHandle* DoubleTestMatrixTensorHandle(); + +// Return a tensor handle containing a 2x2 matrix of floats +TFE_TensorHandle* TestMatrixTensorHandle(); + +// Return a tensor handle containing a 3x2 matrix of doubles +TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(); + +// Return a tensor handle containing a 3x2 matrix of floats +TFE_TensorHandle* TestMatrixTensorHandle3X2(); + +// Return a matmul op multiplying `a` by `b`. +TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b); + +// Return an 1-D INT32 tensor containing a single value 1. +TFE_TensorHandle* TestAxisTensorHandle(); + +// Return an op taking minimum of `input` long `axis` dimension. +TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, + TFE_TensorHandle* axis); + +// If there is a device of type `device_type`, returns true +// and sets 'device_name' accordingly. +// `device_type` must be either "GPU" or "TPU". +bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name, + const char* device_type); + +#endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index bdb0815d6b68444ec1c89b835d563db20ce4d8a1..734e712daa39c03f0177eb199b1acb1b19e5d845 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -48,7 +48,7 @@ struct OpTapeEntry { // Should be called before deleting the backward function. TODO(apassos) use // unique_ptrs to ensure this happens. - std::function backward_function_deleter; + std::function backward_function_deleter; }; // Map from tensor_id to internally-defined operation-id of the operation which @@ -104,14 +104,12 @@ class VSpace { gtl::ArraySlice output_gradients, std::vector* result) const = 0; + // Marks the following gradient as a result so it's not consumed by backward + // functions. + virtual void MarkAsResult(Gradient* gradient) const = 0; + // Deletes the input tensor. virtual void DeleteGradient(Gradient* gradient) const = 0; - - // Lets this VSpace know that it can release resources held by the - // `backward_function`, It will not be called again. - // `backward_function` must not be null. - virtual void ReleaseBackwardFunction( - BackwardFunction* backward_function) const = 0; }; // Traces the execution of operations, doing eager garbage collection, and @@ -126,19 +124,21 @@ class GradientTape { GradientTape(bool persistent) : persistent_(persistent) {} ~GradientTape() { for (const auto& pair : op_tape_) { - pair.second.backward_function_deleter(); + pair.second.backward_function_deleter(pair.second.backward_function); } } - bool ShouldRecord(gtl::ArraySlice tensor_ids); + bool ShouldRecord(gtl::ArraySlice tensor_ids, + gtl::ArraySlice dtypes); void Watch(int64 tensor_id); - void RecordOperation(const string& op_type, - gtl::ArraySlice output_tensors, - gtl::ArraySlice input_tensor_id, - BackwardFunction* backward_function, - const std::function& backward_function_deleter); + void RecordOperation( + const string& op_type, gtl::ArraySlice output_tensors, + gtl::ArraySlice input_tensor_id, + gtl::ArraySlice input_dtypes, + BackwardFunction* backward_function, + const std::function& backward_function_deleter); void DeleteTrace(int64 tensor_id); @@ -152,6 +152,8 @@ class GradientTape { gtl::ArraySlice output_gradients, std::vector* result); + bool IsPersistent() const { return persistent_; } + private: TensorTape tensor_tape_; OpTape op_tape_; @@ -168,12 +170,32 @@ class GradientTape { // Template instantiations here +inline bool IsDtypeTrainable(DataType dtype) { + switch (dtype) { + case DT_HALF: + case DT_BFLOAT16: + case DT_FLOAT: + case DT_DOUBLE: + case DT_COMPLEX64: + case DT_COMPLEX128: + case DT_RESOURCE: + case DT_VARIANT: + return true; + default: + return false; + } +} + template bool GradientTape::ShouldRecord( - gtl::ArraySlice tensor_ids) { - for (int64 i : tensor_ids) { - if (tensor_tape_.find(i) != tensor_tape_.end()) { - return true; + gtl::ArraySlice tensor_ids, + gtl::ArraySlice dtypes) { + CHECK_EQ(tensor_ids.size(), dtypes.size()); + for (int i = 0; i < tensor_ids.size(); ++i) { + if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) { + if (IsDtypeTrainable(dtypes[i])) { + return true; + } } } return false; @@ -187,10 +209,12 @@ void GradientTape::Watch(int64 tensor_id) { template void GradientTape::RecordOperation( const string& op_type, gtl::ArraySlice output_tensors, - gtl::ArraySlice input_tensor_id, BackwardFunction* backward_function, - const std::function& backward_function_deleter) { - if (!ShouldRecord(input_tensor_id)) { - backward_function_deleter(); + gtl::ArraySlice input_tensor_id, + gtl::ArraySlice input_dtypes, + BackwardFunction* backward_function, + const std::function& backward_function_deleter) { + if (!ShouldRecord(input_tensor_id, input_dtypes)) { + backward_function_deleter(backward_function); return; } std::vector ids; @@ -245,7 +269,7 @@ void GradientTape::DeleteTrace(int64 tensor_id) { for (int64 id : op_it->second.input_tensor_id) { DeleteTrace(id); } - op_it->second.backward_function_deleter(); + op_it->second.backward_function_deleter(op_it->second.backward_function); op_tape_.erase(op_it); } @@ -330,8 +354,7 @@ BackpropInitialState PrepareBackprop( count_it->second++; } else { result.tensor_usage_counts[it] = 1; - if (sources_set.find(it) == sources_set.end() && - tensor_tape.find(it) != tensor_tape.end()) { + if (tensor_tape.find(it) != tensor_tape.end()) { tensor_stack.push_back(it); } } @@ -352,7 +375,8 @@ BackpropInitialState PrepareBackprop( // backward functions that will be used for gradient computation // has been transferred to `result`. for (const auto& op_pair : *op_tape) { - op_pair.second.backward_function_deleter(); + op_pair.second.backward_function_deleter( + op_pair.second.backward_function); } op_tape->clear(); } @@ -378,49 +402,39 @@ Status InitialGradients(const VSpace& vspace, gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, const OpTape& op_tape, - const gtl::FlatMap& tensor_usage_counts, gtl::FlatMap>* result) { for (int i = 0; i < target_tensor_ids.size(); ++i) { const int64 id = target_tensor_ids[i]; - if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) { - if (!output_gradients.empty() && output_gradients[i] != nullptr) { - // TODO(apassos) figure out how to print debugging information here. - return errors::InvalidArgument( - "A gradient was provided for a tensor which is used as part of the " - "computation."); - } - } else { - if (output_gradients.empty() || output_gradients[i] == nullptr) { - auto tensor_it = tensor_tape.find(id); - if (tensor_it != tensor_tape.end() && tensor_it->second != -1) { - auto op_it = op_tape.find(tensor_it->second); - if (op_it == op_tape.end()) { - return errors::Internal( - "Internal state of the gradient tape is invalid: " - "failed to find operation producing a tensor"); - } - bool found = false; - for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { - if (op_it->second.output_tensor_info[j].id == id) { - found = true; - (*result)[id].push_back( - vspace.Ones(op_it->second.output_tensor_info[j].shape, - op_it->second.output_tensor_info[j].dtype)); - break; - } - } - if (!found) { - return errors::Internal( - "Internal state of the gradient tape is invalid: " - "none of operations outputs match expected tensor"); + if (output_gradients.empty() || output_gradients[i] == nullptr) { + auto tensor_it = tensor_tape.find(id); + if (tensor_it != tensor_tape.end() && tensor_it->second != -1) { + auto op_it = op_tape.find(tensor_it->second); + if (op_it == op_tape.end()) { + return errors::Internal( + "Internal state of the gradient tape is invalid: " + "failed to find operation producing a tensor"); + } + bool found = false; + for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { + if (op_it->second.output_tensor_info[j].id == id) { + found = true; + (*result)[id].push_back( + vspace.Ones(op_it->second.output_tensor_info[j].shape, + op_it->second.output_tensor_info[j].dtype)); + break; } - } else { - // No record of the target tensor found on the tape, so no gradient - // needs to be computed from it. Do nothing. + } + if (!found) { + return errors::Internal( + "Internal state of the gradient tape is invalid: " + "none of operations outputs match expected tensor"); } } else { - (*result)[id].push_back(output_gradients[i]); + // No record of the target tensor found on the tape, so no gradient + // needs to be computed from it. Do nothing. } + } else { + (*result)[id].push_back(output_gradients[i]); } } return Status::OK(); @@ -449,13 +463,12 @@ Status GradientTape::ComputeGradient( InitialStack(state.op_tape, state.op_missing_tensor); gtl::FlatMap> gradients; Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, - tensor_tape_, state.op_tape, - state.tensor_usage_counts, &gradients); + tensor_tape_, state.op_tape, &gradients); auto cleanup = [this, &state]() { if (!persistent_) { // Release all backprop functions for (const auto& pair : state.op_tape) { - pair.second.backward_function_deleter(); + pair.second.backward_function_deleter(pair.second.backward_function); } } }; @@ -507,10 +520,15 @@ Status GradientTape::ComputeGradient( } } else { any_gradient_nonzero = true; - out_gradients.push_back(vspace.AggregateGradients(grad_it->second)); + auto new_gradients = vspace.AggregateGradients(grad_it->second); if (sources_set.find(grad_it->first) == sources_set.end()) { gradients.erase(grad_it); + } else { + grad_it->second.clear(); + grad_it->second.push_back(new_gradients); + vspace.MarkAsResult(new_gradients); } + out_gradients.push_back(new_gradients); } } std::vector in_gradients; @@ -518,7 +536,7 @@ Status GradientTape::ComputeGradient( Status s = vspace.CallBackwardFunction(trace.backward_function, out_gradients, &in_gradients); if (!persistent_) { - vspace.ReleaseBackwardFunction(trace.backward_function); + trace.backward_function_deleter(trace.backward_function); } if (!s.ok()) { cleanup(); @@ -527,7 +545,7 @@ Status GradientTape::ComputeGradient( } else { in_gradients.resize(trace.input_tensor_id.size()); if (!persistent_) { - vspace.ReleaseBackwardFunction(trace.backward_function); + trace.backward_function_deleter(trace.backward_function); } for (Gradient* grad : out_gradients) { if (grad != nullptr) { @@ -599,23 +617,28 @@ Status GradientTape::ComputeGradient( } CHECK(state.op_tape.empty()); result->reserve(source_tensor_ids.size()); + gtl::FlatSet used_gradient_ids(source_tensor_ids.size()); for (auto is : source_tensor_ids) { auto grad_it = gradients.find(is); if (grad_it == gradients.end()) { result->push_back(nullptr); } else { - if (grad_it->second.size() == 1) { - result->push_back(grad_it->second[0]); - } else { - result->push_back(vspace.AggregateGradients(grad_it->second)); + if (grad_it->second.size() > 1) { + Gradient* grad = vspace.AggregateGradients(grad_it->second); + grad_it->second.clear(); + grad_it->second.push_back(grad); } - gradients.erase(grad_it); + result->push_back(grad_it->second[0]); + used_gradient_ids.insert(is); } } - VLOG(1) << "Final gradients size: " << gradients.size(); + VLOG(1) << "Final gradients size: " + << gradients.size() - used_gradient_ids.size(); for (auto grad_pair : gradients) { - for (const auto& g : grad_pair.second) { - vspace.DeleteGradient(g); + if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) { + for (const auto& g : grad_pair.second) { + vspace.DeleteGradient(g); + } } } return Status::OK(); diff --git a/tensorflow/c/generate-pc.sh b/tensorflow/c/generate-pc.sh index 02a6a58b6153bb78c684f9290ef95900f96e9357..7184ad68fb79f2598067d68d5ab5ba8f2c7a22c8 100755 --- a/tensorflow/c/generate-pc.sh +++ b/tensorflow/c/generate-pc.sh @@ -15,10 +15,12 @@ # ============================================================================== TF_PREFIX='/usr/local' +LIBDIR='lib' usage() { echo "Usage: $0 OPTIONS" echo -e "-p, --prefix\tset installation prefix (default: /usr/local)" + echo -e "-l, --libdir\tset lib directory (default: lib)" echo -e "-v, --version\tset TensorFlow version" echo -e "-h, --help\tdisplay this message" } @@ -26,7 +28,7 @@ usage() { [ $# == 0 ] && usage && exit 0 # read the options -ARGS=$(getopt -o p:v:h --long prefix:,version:,help -n $0 -- "$@") +ARGS=$(getopt -o p:l:v:h --long prefix:,libdir:,version:,help -n $0 -- "$@") eval set -- "$ARGS" # extract options and their arguments into variables. @@ -38,6 +40,11 @@ while true ; do "") shift 2 ;; *) TF_PREFIX=$2 ; shift 2 ;; esac ;; + -l|--libdir) + case "$2" in + "") shift 2 ;; + *) LIBDIR=$2 ; shift 2 ;; + esac ;; -v|--version) case "$2" in "") shift 2 ;; @@ -55,7 +62,7 @@ echo "Generating pkgconfig file for TensorFlow $TF_VERSION in $TF_PREFIX" cat << EOF > tensorflow.pc prefix=${TF_PREFIX} exec_prefix=\${prefix} -libdir=\${exec_prefix}/lib +libdir=\${exec_prefix}/${LIBDIR} includedir=\${prefix}/include Name: TensorFlow diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 6e37cdb5f4beea53d4a2ded0705ae482d0bc2d68..e18fdf6c57bd3f432d8cb73536fb816df90b3963 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/python_api.h" #include "tensorflow/c/c_api_internal.h" +#include "tensorflow/python/framework/cpp_shape_inference.pb.h" namespace tensorflow { @@ -99,4 +100,65 @@ void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { } } +void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) { + mutex_lock l(graph->mu); + graph->refiner.set_require_shape_inference_fns(require); +} + +void ExtendSession(TF_Session* session, TF_Status* status) { + ExtendSessionGraphHelper(session, status); + session->extend_before_run = false; +} + +std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { + Node* node = &output.oper->node; + CppShapeInferenceResult::HandleData handle_data; + handle_data.set_is_set(true); + { + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + CHECK(ic != nullptr); + CHECK_LT(output.index, ic->num_outputs()); + const auto* shapes_and_types = + ic->output_handle_shapes_and_types(output.index); + if (shapes_and_types == nullptr) return ""; + + for (const auto& p : *shapes_and_types) { + auto* out_shape_and_type = handle_data.add_shape_and_type(); + ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape()); + out_shape_and_type->set_dtype(p.dtype); + } + } + string result; + handle_data.SerializeToString(&result); + return result; +} + +void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, + const void* proto, size_t proto_len, + TF_Status* status) { + tensorflow::CppShapeInferenceResult::HandleData handle_data; + if (!handle_data.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Couldn't deserialize HandleData proto"); + return; + } + DCHECK(handle_data.is_set()); + + tensorflow::mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(&output.oper->node); + + std::vector shapes_and_types; + for (const auto& shape_and_type_proto : handle_data.shape_and_type()) { + tensorflow::shape_inference::ShapeHandle shape; + status->status = + ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); + if (status->status.ok()) return; + shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype()); + } + ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); +} + } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index aa9d9e06b28c54cb8869eb547d36ee3cb0d4e6b8..4bcb5bde62c8a4df4e68c1ce0daaf459434ceb5d 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_C_PYTHON_API_H_ #define TENSORFLOW_C_PYTHON_API_H_ +#include + #include "tensorflow/c/c_api.h" // These functions can be removed without notice. They exist to facilitate some @@ -37,6 +39,31 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); +// Sets whether ops missing a shape inference function should trigger an +// error. The default is true. +void SetRequireShapeInferenceFns(TF_Graph* graph, bool require); + +// Extends `session` with any new operations added to its associated graph. +// Usually this happens automatically in TF_SessionRun. After this is called, +// TF_SessionRun will no longer extend the session on every call. +// +// We expose this here to allow fine-grained synchronization in multi-threaded +// workloads, which is required since the Python implementation depends on the +// above mutation methods. This allows us to prevent modifications to nodes in +// the graph after the session has been made aware of them. +void ExtendSession(TF_Session* session, TF_Status* status); + +// Returns the serialized CppShapeInferenceResult::HandleData proto for +// `output` if its a resource tensor, or otherwise returns the empty string. +std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output); + +// Sets `output` based on `proto`, which should be a serialized +// CppShapeInferenceResult::HandleData proto. +// NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string +// because I couldn't get SWIG to work otherwise. +void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, + const void* proto, size_t proto_len, + TF_Status* status); } // namespace tensorflow #endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/c/testdata/tf_record b/tensorflow/c/testdata/tf_record new file mode 100644 index 0000000000000000000000000000000000000000..6e16076bfb79ad8151952e96567565e8820b0f5b Binary files /dev/null and b/tensorflow/c/testdata/tf_record differ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 9060c19e9d2cf965c2b9be07be07c42017da45a8..079e063d3e3fbdaf833e9031f5f9438853c14099 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -620,18 +620,6 @@ tf_cc_binary( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - cc_library( name = "queue_runner", srcs = ["training/queue_runner.cc"], diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index a40ad1ffc3b262840e6ca0043139b1b61e04510d..dfdef88945deca376368edd6f7aa322b1e1cbf94 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb_text.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" @@ -272,6 +273,12 @@ string PrintAttrValue(const string& op, const AttrValue& attr_value) { return ""; // Prevent missing return warning } +bool IsEmptyList(const AttrValue::ListValue& list) { + return list.s_size() == 0 && list.i_size() == 0 && list.f_size() == 0 && + list.b_size() == 0 && list.type_size() == 0 && + list.shape_size() == 0 && list.tensor_size() == 0; +} + string ToCamelCase(const string& str) { string result; const char joiner = '_'; @@ -296,9 +303,9 @@ string ToCamelCase(const string& str) { // indicate whether to treat the type as const when accepting the C++ type as an // argument to a function. std::pair AttrTypeName(StringPiece attr_type) { - static const std::unordered_map, - StringPieceHasher> - attr_type_map{ + static const auto* attr_type_map = + new std::unordered_map, + StringPieceHasher>{ {"string", {"StringPiece", false}}, {"list(string)", {"gtl::ArraySlice", true}}, {"int", {"int64", false}}, @@ -316,14 +323,34 @@ std::pair AttrTypeName(StringPiece attr_type) { {"func", {"NameAttrList", true}}, }; - auto entry = attr_type_map.find(attr_type); - if (entry == attr_type_map.end()) { + auto entry = attr_type_map->find(attr_type); + if (entry == attr_type_map->end()) { LOG(FATAL) << "Unsupported Attr type: " << attr_type; return {"", false}; } return entry->second; } +const char* ListElementTypeName(StringPiece attr_type) { + static const auto* attr_list_type_map = + new std::unordered_map{ + {"list(string)", "string"}, + {"list(int)", "int"}, + {"list(float)", "float"}, + {"list(bool)", "bool"}, + {"list(type)", "DataType"}, + {"list(shape)", "PartialTensorShape"}, + {"list(tensor)", "TensorProto"}, + }; + + auto entry = attr_list_type_map->find(attr_type); + if (entry == attr_list_type_map->end()) { + LOG(FATAL) << "Unsupported or non-list Attr type: " << attr_type; + return ""; + } + return entry->second; +} + bool IsCPPKeyword(StringPiece name) { static const std::unordered_set // Keywords obtained from http://en.cppreference.com/w/cpp/keyword @@ -439,7 +466,7 @@ string AvoidCPPKeywords(StringPiece name) { if (IsCPPKeyword(name)) { return strings::StrCat(name, "_"); } - return name.ToString(); + return std::string(name); } void InferArgAttributes(const OpDef::ArgDef& arg, @@ -667,6 +694,7 @@ OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, string OpInfo::GetOpAttrStruct() const { string struct_fields; string setters; + string defaults_static_storage; for (int i = 0; i < graph_op_def.attr_size(); ++i) { const auto& attr(graph_op_def.attr(i)); @@ -697,17 +725,39 @@ string OpInfo::GetOpAttrStruct() const { attr_comment = MakeComment(attr_comment, " "); strings::StrAppend(&setters, attr_comment); - strings::StrAppend(&setters, " Attrs ", attr_func_def, " x) {\n"); + strings::StrAppend(&setters, " TF_MUST_USE_RESULT Attrs ", attr_func_def, + " x) {\n"); strings::StrAppend(&setters, " Attrs ret = *this;\n"); strings::StrAppend(&setters, " ret.", api_def_attr.rename_to(), "_ = x;\n"); strings::StrAppend(&setters, " return ret;\n }\n\n"); - strings::StrAppend( - &struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(), - "_ = ", - PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), - ";\n"); + string field_initiliazer; + auto& default_value = api_def_attr.default_value(); + if (default_value.value_case() == AttrValue::kList && + !IsEmptyList(default_value.list())) { + // Non-empty lists need static storage for their defaults. Define a + // function with static local variable that stores the array. + strings::StrAppend(&defaults_static_storage, " static ", + attr_type_name, " Default_", api_def_attr.rename_to(), + "() {\n"); + strings::StrAppend( + &defaults_static_storage, " static const ", + ListElementTypeName(attr.type()), " kStorage[] = ", + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), + ";\n"); + strings::StrAppend(&defaults_static_storage, " return ", + attr_type_name, "(kStorage);\n }\n"); + // Set the field_initializer to call the defined function. + strings::StrAppend(&field_initiliazer, "Default_", + api_def_attr.rename_to(), "()"); + } else { + field_initiliazer = + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()); + } + strings::StrAppend(&struct_fields, " ", attr_type_name, " ", + api_def_attr.rename_to(), "_ = ", field_initiliazer, + ";\n"); } if (struct_fields.empty()) { @@ -719,6 +769,9 @@ string OpInfo::GetOpAttrStruct() const { string struct_decl = MakeComment(attrs_comment, " "); strings::StrAppend(&struct_decl, " struct Attrs {\n"); strings::StrAppend(&struct_decl, setters, struct_fields); + if (!defaults_static_storage.empty()) { + strings::StrAppend(&struct_decl, " private:\n", defaults_static_storage); + } strings::StrAppend(&struct_decl, " };\n"); return struct_decl; diff --git a/tensorflow/cc/framework/cc_op_gen_test.cc b/tensorflow/cc/framework/cc_op_gen_test.cc index 1e0f2d241bb350897a840dda90d6d0c009b1daad..5d9dfd95a5538ae0f3d2d111a1f989552c3363b8 100644 --- a/tensorflow/cc/framework/cc_op_gen_test.cc +++ b/tensorflow/cc/framework/cc_op_gen_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -61,12 +62,12 @@ op { )"; void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(s.contains(expected)) + EXPECT_TRUE(str_util::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) { - EXPECT_FALSE(s.contains(expected)) + EXPECT_FALSE(str_util::StrContains(s, expected)) << "'" << s << "' contains '" << expected << "'"; } diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 71642492627422e09c19b7bcb4dc522846cf08b1..62a889181e787f2e181135ab0563c45e1bab8812 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { @@ -218,8 +219,8 @@ std::unordered_set Scope::Impl::GetColocationConstraints( if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) { for (const string& entry : node_constraints) { StringPiece s(entry); - if (s.Consume(kColocationGroupPrefix)) { - current_constraints.insert(s.ToString()); + if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) { + current_constraints.insert(std::string(s)); } } } else { diff --git a/tensorflow/cc/framework/while_gradients.cc b/tensorflow/cc/framework/while_gradients.cc index 0734075fc6144d7c9f4fdb48c5e097faa58b8355..81870a0efa309ae6dbd5cc05a5dbe8c3e2d437c8 100644 --- a/tensorflow/cc/framework/while_gradients.cc +++ b/tensorflow/cc/framework/while_gradients.cc @@ -72,9 +72,9 @@ Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope, }; // Body function that adds one to input. - BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope, - const std::vector& inputs, - std::vector* outputs) { + BodyGraphBuilderFn body_fn = [](const Scope& scope, + const std::vector& inputs, + std::vector* outputs) { DCHECK_EQ(inputs.size(), 1); outputs->emplace_back(ops::Add(scope, inputs[0], 1)); return scope.status(); diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index 6545e4ee3eb406436937a43ddac66d017af8e108..ff348fadb24e29a83bd6c8853aa67931f6df4182 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -385,6 +385,42 @@ Status MirrorPadGradGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad); +Status StridedSliceGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + Input x = Shape(scope, op.input(0)); + Input begin = op.input(1); + Input end = op.input(2); + Input strides = op.input(3); + int64 begin_mask; + int64 end_mask; + int64 ellipsis_mask; + int64 new_axis_mask; + int64 shrink_axis_mask; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "begin_mask", &begin_mask)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "end_mask", &end_mask)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "ellipsis_mask", &ellipsis_mask)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "new_axis_mask", &new_axis_mask)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "shrink_axis_mask", &shrink_axis_mask)); + grad_outputs->push_back( + StridedSliceGrad(scope, x, begin, end, strides, grad_inputs[0], + StridedSliceGrad::BeginMask(begin_mask) + .EndMask(end_mask) + .EllipsisMask(ellipsis_mask) + .NewAxisMask(new_axis_mask) + .ShrinkAxisMask(shrink_axis_mask))); + // No gradients returned for begin, end and strides + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index 4a215fcc9299cf8b8da04cbf151640631ed0d449..de3bd0fc9e2493f8ff76163f5be6bd4327c58c5a 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -354,5 +354,29 @@ TEST_F(ArrayGradTest, MirrorPadGradGrad_Symmetric) { RunTest(x, x_shape, y, y_shape); } +TEST_F(ArrayGradTest, StridedSliceGrad) { + TensorShape x_shape({6, 4, 4}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + + // y = x[2:6:2, 1:3, 1:3] + auto y = StridedSlice(scope_, x, {2, 1, 1}, {6, 3, 3}, {2, 1, 1}); + // y.shape = [2, 2, 2]; + RunTest(x, x_shape, y, {2, 2, 2}); + + // y = x[2:6:2, 1:3, 1:3] + // begin_mask = 1<<1 (ignore begin_index = 1) + // end_mask = 1<<2 (ignore end_index = 2) + y = StridedSlice(scope_, x, {2, 1, 1}, {6, 3, 3}, {2, 1, 1}, + StridedSlice::BeginMask(1 << 1).EndMask(1 << 2)); + // y.shape = [2, 3, 3]; + RunTest(x, x_shape, y, {2, 3, 3}); + + // y = [tf.newaxis, 2:6:2, 1:3, 1:3] + y = StridedSlice(scope_, x, {0, 2, 1, 1}, {0, 6, 3, 3}, {1, 2, 1, 1}, + StridedSlice::NewAxisMask(1 << 0)); + // y.shape = [1, 2, 2, 2]; + RunTest(x, x_shape, y, {1, 2, 2, 2}); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 52c177212a8c88f1857defcc38de4a01ac47dab0..35a01e0341cb08c9b314908b6dcd76fd99c1e68b 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -38,6 +38,7 @@ REGISTER_NO_GRADIENT_OP("NotEqual"); REGISTER_NO_GRADIENT_OP("LogicalAnd"); REGISTER_NO_GRADIENT_OP("LogicalOr"); REGISTER_NO_GRADIENT_OP("LogicalNot"); +REGISTER_NO_GRADIENT_OP("Floor"); // Conjugate helper function returns the conjugate of an Output if it // is complex valued. diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 1b4c7c2688083e74433da3dce2849b8c37443684..fd7b6fe6625f27bda92e2f56f60908658cdecd7e 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -31,7 +31,6 @@ using ops::AddN; using ops::BatchMatMul; using ops::Const; using ops::Div; -using ops::Greater; using ops::MatMul; using ops::Max; using ops::Maximum; @@ -46,7 +45,6 @@ using ops::RealDiv; using ops::SquaredDifference; using ops::Sub; using ops::Sum; -using ops::Where3; // TODO(andydavis) Test gradient function against numeric gradients output. // TODO(andydavis) As more gradients are added move common test functions diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 13a3bba5e6d5ca19ff3f0eca76665ba7d3ab628d..c73482d5f4d13ade0dc0412941251d1651371b6e 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -48,8 +48,8 @@ Status SoftmaxGrad(const Scope& scope, const Operation& op, REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad); Status LogSoftmaxGrad(const Scope& scope, const Operation& op, - const std::vector& grad_inputs, - std::vector* grad_outputs) { + const std::vector& grad_inputs, + std::vector* grad_outputs) { auto softmax = Exp(scope, op.output(0)); auto sum = Sum(scope, grad_inputs[0], {1}, Sum::KeepDims(true)); auto mul = Mul(scope, sum, softmax); @@ -107,11 +107,10 @@ Status BiasAddGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { string data_format; - BiasAddGrad::Attrs input_attrs; TF_RETURN_IF_ERROR( GetNodeAttr(op.output(0).node()->attrs(), "data_format", &data_format)); - input_attrs.DataFormat(data_format); - auto dx_1 = BiasAddGrad(scope, grad_inputs[0], input_attrs); + auto dx_1 = + BiasAddGrad(scope, grad_inputs[0], BiasAddGrad::DataFormat(data_format)); grad_outputs->push_back(Identity(scope, grad_inputs[0])); grad_outputs->push_back(dx_1); return scope.status(); @@ -130,19 +129,16 @@ Status Conv2DGrad(const Scope& scope, const Operation& op, TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "use_cudnn_on_gpu", &use_cudnn_on_gpu)); - Conv2DBackpropInput::Attrs input_attrs; - input_attrs.DataFormat(data_format); - input_attrs.UseCudnnOnGpu(use_cudnn_on_gpu); - auto dx_1 = Conv2DBackpropInput(scope, Shape(scope, op.input(0)), - op.input(1), grad_inputs[0], - strides, padding, input_attrs); + auto dx_1 = Conv2DBackpropInput(scope, Shape(scope, op.input(0)), op.input(1), + grad_inputs[0], strides, padding, + Conv2DBackpropInput::DataFormat(data_format) + .UseCudnnOnGpu(use_cudnn_on_gpu)); grad_outputs->push_back(dx_1); - Conv2DBackpropFilter::Attrs filter_attrs; - filter_attrs.DataFormat(data_format); - filter_attrs.UseCudnnOnGpu(use_cudnn_on_gpu); - auto dx_2 = Conv2DBackpropFilter(scope, op.input(0), - Shape(scope, op.input(1)), grad_inputs[0], - strides, padding, filter_attrs); + auto dx_2 = + Conv2DBackpropFilter(scope, op.input(0), Shape(scope, op.input(1)), + grad_inputs[0], strides, padding, + Conv2DBackpropFilter::DataFormat(data_format) + .UseCudnnOnGpu(use_cudnn_on_gpu)); grad_outputs->push_back(dx_2); return scope.status(); } @@ -160,13 +156,9 @@ Status MaxPoolGradHelper(const Scope& scope, const Operation& op, TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); - internal::MaxPoolGrad::Attrs grad_attrs; - grad_attrs.DataFormat(data_format); - auto dx = internal::MaxPoolGrad(scope, op.input(0), - op.output(0), - grad_inputs[0], - ksize, strides, - padding, grad_attrs); + auto dx = internal::MaxPoolGrad( + scope, op.input(0), op.output(0), grad_inputs[0], ksize, strides, padding, + internal::MaxPoolGrad::DataFormat(data_format)); grad_outputs->push_back(dx); return scope.status(); } @@ -180,15 +172,9 @@ Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op, auto attrs = op.output(0).node()->attrs(); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); - MaxPoolGradV2::Attrs grad_attrs; - grad_attrs.DataFormat(data_format); - auto dx = MaxPoolGradV2(scope, op.input(0), - op.output(0), - grad_inputs[0], - op.input(1), - op.input(2), - padding, - grad_attrs); + auto dx = MaxPoolGradV2(scope, op.input(0), op.output(0), grad_inputs[0], + op.input(1), op.input(2), padding, + MaxPoolGradV2::DataFormat(data_format)); grad_outputs->push_back(dx); grad_outputs->push_back(NoGradient()); grad_outputs->push_back(NoGradient()); @@ -196,18 +182,126 @@ Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("MaxPoolV2", MaxPoolGradV2Helper); +Status MaxPool3DGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + std::vector ksize; + std::vector strides; + string padding; + string data_format; + auto attrs = op.output(0).node()->attrs(); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); + MaxPool3DGrad::Attrs grad_attrs; + auto dx = MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0], + ksize, strides, padding, + grad_attrs.DataFormat(data_format)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("MaxPool3D", MaxPool3DGradHelper); + +Status AvgPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + std::vector ksize; + std::vector strides; + string padding; + string data_format; + auto attrs = op.output(0).node()->attrs(); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); + internal::AvgPoolGrad::Attrs grad_attrs; + auto dx = + internal::AvgPoolGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], + ksize, strides, padding, + grad_attrs.DataFormat(data_format)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("AvgPool", AvgPoolGradHelper); + +Status AvgPool3DGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + std::vector ksize; + std::vector strides; + string padding; + string data_format; + auto attrs = op.output(0).node()->attrs(); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); + AvgPool3DGrad::Attrs grad_attrs; + auto dx = AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], + ksize, strides, padding, + grad_attrs.DataFormat(data_format)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("AvgPool3D", AvgPool3DGradHelper); + Status LRNGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, - std::vector* grad_outputs){ - internal::LRNGrad::Attrs grad_attrs; - - auto dx = internal::LRNGrad(scope, grad_inputs[0], op.input(0), op.output(0), - grad_attrs); + std::vector* grad_outputs) { + auto dx = internal::LRNGrad(scope, grad_inputs[0], op.input(0), op.output(0)); grad_outputs->push_back(dx); return scope.status(); } REGISTER_GRADIENT_OP("LRN", LRNGradHelper); +Status SoftplusGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SoftplusGrad(scope, grad_inputs[0], op.input(0)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Softplus", SoftplusGradHelper); + +Status SoftsignGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SoftsignGrad(scope, grad_inputs[0], op.input(0)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Softsign", SoftsignGradHelper); + +Status FractionalAvgPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool overlapping; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping)); + auto dx = internal::FractionalAvgPoolGrad( + scope, Shape(scope, op.input(0), Shape::OutType(DT_INT64)), + grad_inputs[0], op.output(1), op.output(2), + internal::FractionalAvgPoolGrad::Overlapping(overlapping)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("FractionalAvgPool", FractionalAvgPoolGradHelper); + +Status FractionalMaxPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool overlapping; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping)); + auto dx = internal::FractionalMaxPoolGrad( + scope, op.input(0), op.output(0), grad_inputs[0], op.output(1), + op.output(2), internal::FractionalMaxPoolGrad::Overlapping(overlapping)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("FractionalMaxPool", FractionalMaxPoolGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index 0cfe5f6e3c49f7c4a3cafbf48ff4e54a0ffd0d47..b4d457a9d14eb79232cda9412fa0050f6a9968cc 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -28,16 +28,23 @@ namespace { using ops::BiasAdd; using ops::Conv2D; using ops::Elu; +using ops::FractionalAvgPool; +using ops::FractionalMaxPool; using ops::L2Loss; using ops::LogSoftmax; using ops::LRN; +using ops::AvgPool; +using ops::AvgPool3D; using ops::MaxPool; using ops::MaxPoolV2; +using ops::MaxPool3D; using ops::Placeholder; using ops::Relu; using ops::Relu6; using ops::Selu; using ops::Softmax; +using ops::Softplus; +using ops::Softsign; class NNGradTest : public ::testing::Test { protected: @@ -68,22 +75,30 @@ class NNGradTest : public ::testing::Test { EXPECT_LT(max_error, 1e-3); } - // Sets tensor with random values, ensuring that the max value is largest by - // a reasonable amount. - // This is an issue for MaxPool and MaxPoolV2, in which perturbations by the + // Sets tensor with random values, ensuring that every pair of elements are at + // least a reasonable amount apart. + // This is an issue for max pooling operations, in which perturbations by the // numeric gradient computation in the gradient checker can change the max - // value if values are too close together. + // value if a pool has values that are too close together. template - void SetRandomValuesWithBumpedMax(Tensor* tensor) { + void SetRandomValuesForMaxPooling(Tensor* tensor) { auto tensor_flat = tensor->flat(); - tensor_flat.setRandom(); - int32 max_index = 0; - for (size_t i = 1; i < tensor->NumElements(); i++) { - if (tensor_flat(i) > tensor_flat(max_index)) { - max_index = i; - } + // First set the array to an increasing sequence of values spaced + // a reasonable amount apart + T cur = 0; + for (size_t i = 0; i < tensor->NumElements(); i++) { + tensor_flat(i) = cur; + cur += 5e-2; + } + // Fischer-Yates shuffle the array + for (size_t i = tensor->NumElements() - 1; i >= 1; i--) { + // j <- random integer 0 <= j <= i + size_t j = random::New64() % (i + 1); + // swap values at i, j + T tmp = tensor_flat(i); + tensor_flat(i) = tensor_flat(j); + tensor_flat(j) = tmp; } - tensor_flat(max_index) += 1e-2; } Scope scope_; @@ -186,7 +201,7 @@ TEST_F(NNGradTest, MaxPoolGradHelper) { const std::vector strides{1, 2, 2, 1}; auto y = MaxPool(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -199,10 +214,45 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { Tensor strides = test::AsTensor({1, 2, 2, 1}, {4}); auto y = MaxPoolV2(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } +TEST_F(NNGradTest, MaxPool3DGradHelper) { + TensorShape x_shape({1, 3, 3, 3, 1}); + TensorShape y_shape({1, 1, 1, 1, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Setup window and strides so that we only do one MaxPool3D. + const std::vector ksize{1, 3, 3, 3, 1}; + const std::vector strides{1, 3, 3, 3, 1}; + auto y = MaxPool3D(scope_, x, ksize, strides, "VALID"); + Tensor x_init_value = Tensor(DT_FLOAT, x_shape); + SetRandomValuesForMaxPooling(&x_init_value); + RunTest(x, x_init_value, y, y_shape); +} + +TEST_F(NNGradTest, AvgPoolGradHelper) { + TensorShape x_shape({1, 2, 2, 1}); + TensorShape y_shape({1, 1, 1, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Setup window and strides so that we only do one AvgPool. + const std::vector ksize{1, 2, 2, 1}; + const std::vector strides{1, 2, 2, 1}; + auto y = AvgPool(scope_, x, ksize, strides, "SAME"); + RunTest(x, x_shape, y, y_shape); +} + +TEST_F(NNGradTest, AvgPool3DGradHelper) { + TensorShape x_shape({1, 3, 3, 3, 1}); + TensorShape y_shape({1, 1, 1, 1, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Setup window and strides so that we only do one AvgPool3D. + const std::vector ksize{1, 3, 3, 3, 1}; + const std::vector strides{1, 3, 3, 3, 1}; + auto y = AvgPool3D(scope_, x, ksize, strides, "SAME"); + RunTest(x, x_shape, y, y_shape); +} + TEST_F(NNGradTest, LRN){ TensorShape x_shape({1, 1, 2, 1}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); @@ -210,5 +260,45 @@ TEST_F(NNGradTest, LRN){ RunTest(x, x_shape, y, x_shape); } +TEST_F(NNGradTest, SoftplusGrad) { + TensorShape shape({3, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Softplus(scope_, x); + RunTest(x, shape, y, shape); +} + +TEST_F(NNGradTest, SoftsignGrad) { + TensorShape shape({3, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Softsign(scope_, x); + RunTest(x, shape, y, shape); +} + +TEST_F(NNGradTest, FractionalAvgPoolGradHelper) { + TensorShape x_shape({1, 3, 7, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Force consistent pooling regions for unit testing. + auto y = FractionalAvgPool( + scope_, x, {1, 1.2, 1.9, 1}, + FractionalAvgPool::Deterministic(true).Overlapping(true).Seed(1).Seed2( + 2)); + TensorShape y_shape({1, 2, 3, 1}); + RunTest(x, x_shape, y.output, y_shape); +} + +TEST_F(NNGradTest, FractionalMaxPoolGradHelper) { + TensorShape x_shape({1, 3, 7, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Force consistent pooling regions for unit testing. + auto y = FractionalMaxPool( + scope_, x, {1, 1.2, 1.9, 1}, + FractionalMaxPool::Deterministic(true).Overlapping(true).Seed(1).Seed2( + 2)); + Tensor x_init_value = Tensor(DT_FLOAT, x_shape); + SetRandomValuesForMaxPooling(&x_init_value); + TensorShape y_shape({1, 2, 3, 1}); + RunTest(x, x_init_value, y.output, y_shape); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/profiler/BUILD b/tensorflow/cc/profiler/BUILD index 00799526fce572e7bb80199ccb8ce1cc89874031..cf65fe1ab99b49207a64e86310178141b30d07d7 100644 --- a/tensorflow/cc/profiler/BUILD +++ b/tensorflow/cc/profiler/BUILD @@ -9,6 +9,9 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") tf_cuda_cc_test( name = "profiler_test", srcs = ["profiler_test.cc"], + tags = [ + "noguitar", # b/77649654 + ], deps = [ ":profiler", "//tensorflow/cc:cc_ops", diff --git a/tensorflow/cc/profiler/profiler.h b/tensorflow/cc/profiler/profiler.h index 6077c45c5854fd5812ccb7c91522f93ed4e54883..64edbb5766c3604fbe0f15c2299843718381aa3f 100644 --- a/tensorflow/cc/profiler/profiler.h +++ b/tensorflow/cc/profiler/profiler.h @@ -61,18 +61,18 @@ class Profiler { /// Adds tracing information `run_meta` to profiler. A `run_meta` is /// generated by a TensorFlow session run call. `step` is the key /// to the `run_meta`. When calling ProfileXXX methods, caller can specify - /// `step` in `options` to seletively profile the corresponding `run_meta`. + /// `step` in `options` to selectively profile the corresponding `run_meta`. /// Multiple different `run_meta` can be keyed by the same `step` in order /// to group them together. void AddStep(int64 step, const RunMetadata& run_meta); /// Profiles the model by organizing nodes in graph structure. - /// Each node is an op and the nodes are contected by the op inputs/outputs. + /// Each node is an op and the nodes are connected by the op inputs/outputs. GraphNodeProto ProfileGraph(const Options& options); /// Profiles the model by organizing nodes in name scope structure. /// Each node is an op, and nodes are organized by the ops' name - /// scope, similar to a filesystem tree. + /// scope, similar to a file system tree. /// E.g. /foo is the root of operation /foo/matmul_1 and foo/conv_2. GraphNodeProto ProfileNameScope(const Options& options); diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index d29ad3ebcbe29087d5572b51c7713e0c98d0d840..06a3be18e08f611d3ecf9804908d791d15fdab13 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -94,18 +94,3 @@ filegroup( "testdata/half_plus_two/**", ]), ) - -# ----------------------------------------------------------------------------- -# Google-internal targets. - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 4c64d2cfe3c10e6c7ed82a2d72460a0b34283bb2..72b8bc18710b0ee77cb01ed3ad0c2abb5183efb2 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -133,9 +134,9 @@ TEST_F(LoaderTest, NoTagMatch) { Status st = LoadSavedModel(session_options, run_options, export_dir, {"missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE(StringPiece(st.error_message()) - .contains("Could not find meta graph def matching supplied " - "tags: { missing-tag }")) + EXPECT_TRUE(str_util::StrContains( + st.error_message(), + "Could not find meta graph def matching supplied tags: { missing-tag }")) << st.error_message(); } @@ -149,9 +150,9 @@ TEST_F(LoaderTest, NoTagMatchMultiple) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe, "missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE( - StringPiece(st.error_message()) - .contains("Could not find meta graph def matching supplied tags: ")) + EXPECT_TRUE(str_util::StrContains( + st.error_message(), + "Could not find meta graph def matching supplied tags: ")) << st.error_message(); } @@ -169,7 +170,7 @@ TEST_F(LoaderTest, SessionCreationFailure) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE(StringPiece(st.error_message()).contains(kInvalidTarget)) + EXPECT_TRUE(str_util::StrContains(st.error_message(), kInvalidTarget)) << st.error_message(); } diff --git a/tensorflow/cc/saved_model/python/BUILD b/tensorflow/cc/saved_model/python/BUILD index f5fbc75edcba9d5ae9ef7432de224df766bcab9e..6f04ebdc55cda329527c95f62efc37c8dfbb4ae5 100644 --- a/tensorflow/cc/saved_model/python/BUILD +++ b/tensorflow/cc/saved_model/python/BUILD @@ -7,18 +7,6 @@ package( default_visibility = ["//visibility:public"], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - load("//tensorflow/core:platform/default/build_config.bzl", "tf_py_clif_cc") tf_py_clif_cc( diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD index 97f66e79b8ad9f383b22f56e9385fc6d2080e1f8..6f1c87354076565af22f7ba0610a5c6bb999d25c 100644 --- a/tensorflow/cc/tools/BUILD +++ b/tensorflow/cc/tools/BUILD @@ -32,6 +32,7 @@ tf_cc_test( deps = [ ":freeze_saved_model", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:resource_variable_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework_internal", "//tensorflow/core:protos_all_cc", @@ -40,18 +41,3 @@ tf_cc_test( "//tensorflow/core:testlib", ], ) - -# ----------------------------------------------------------------------------- -# Google-internal targets. - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index ddf372cdef21e1b3892c9a03714478d5a5785517..23e9dc40d23899b9cef168c9128b6d8ed1be3ee9 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/cc/tools/freeze_saved_model.h" +#include #include #include "tensorflow/core/framework/attr_value.pb.h" @@ -71,25 +72,29 @@ void GetNodeNameToNodeDefMap( } } +// Strips off the tensor part of the tensor_name to get the node_name. +const string GetNodeNameFromTensorName(string tensor_name) { + if (tensor_name[0] == '^') { + tensor_name.erase(0, 1); + } + std::vector tensor_name_parts = str_util::Split(tensor_name, ':'); + return tensor_name_parts[0]; +} + // Gets the set of node names needed by `outputs` and the corresponding set of // variable nodes to convert. void GetReachableNodesAndVariables( GraphDef* graph_def, const std::unordered_set& outputs, + const std::unordered_map& name_to_node_map, std::unordered_set* reachable_node_names, std::unordered_set* variable_node_names) { // TODO(suharshs): Add support for ResourceVariables. static const std::unordered_set* kVariableTypes = - new std::unordered_set({"Variable", "VariableV2"}); - // name_to_node_map is needed to get the inputs from the NodeDef corresponding - // the a string node name. These inputs are used when doing our backwards - // traversal. - std::unordered_map name_to_node_map; - GetNodeNameToNodeDefMap(graph_def, &name_to_node_map); + new std::unordered_set({"Variable", "VariableV2", "VarHandleOp"}); + std::queue nodes_to_visit; - for (const string& tensor_name : outputs) { - // We need to strip off the tensor part to get the node name. - std::vector tensor_name_parts = str_util::Split(tensor_name, ':'); - nodes_to_visit.push(tensor_name_parts[0]); + for (const string& output_tensor_name : outputs) { + nodes_to_visit.push(GetNodeNameFromTensorName(output_tensor_name)); } // We do a traversal backwards from the outputs specified in the MetaGraphDef. while (!nodes_to_visit.empty()) { @@ -99,19 +104,21 @@ void GetReachableNodesAndVariables( continue; } reachable_node_names->insert(node_name); - NodeDef* node = name_to_node_map[node_name]; + NodeDef* node = name_to_node_map.at(node_name); if (kVariableTypes->find(node->op()) != kVariableTypes->end()) { variable_node_names->insert(node->name()); } - for (const string& input : node->input()) { - nodes_to_visit.push(input); + for (const string& input_tensor_name : node->input()) { + nodes_to_visit.push(GetNodeNameFromTensorName(input_tensor_name)); } } } // Gets a map from variable name to variable value. Status GetVariableNameToTensorMap( - Session* session, std::unordered_set variable_names_set, + Session* session, + const std::unordered_map& name_to_node_map, + std::unordered_set variable_names_set, std::unordered_map* variable_name_to_value_map) { if (variable_names_set.empty()) { return Status::OK(); @@ -120,8 +127,14 @@ Status GetVariableNameToTensorMap( std::vector tensor_names; for (const string& node_name : variable_names_set) { variable_names.push_back(node_name); - // We need to run tensors, so append ":0". - tensor_names.push_back(node_name + ":0"); + NodeDef* node_def = name_to_node_map.at(node_name); + if (node_def->op() == "VarHandleOp") { + // If this is a resource variable, we have to run the corresponding + // ReadVariableOp. + tensor_names.push_back(node_name + "/Read/ReadVariableOp:0"); + } else { + tensor_names.push_back(node_name + ":0"); + } } std::vector outputs; TF_RETURN_IF_ERROR( @@ -143,6 +156,15 @@ void ConvertVariableToConstant(const NodeDef& variable_node, (*const_node->mutable_attr())["value"].mutable_tensor()); } +// Converts a ReadVariableOp NodeDef to an Identity NodeDef. +void ConvertReadVariableOpToIdentity(const NodeDef& node, + NodeDef* identity_node) { + identity_node->set_name(node.name()); + identity_node->set_op("Identity"); + (*identity_node->mutable_attr())["T"] = node.attr().at("dtype"); + identity_node->add_input(node.input(0)); +} + // Freezes the subgraph of all nodes needed by `outputs`. Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, const std::unordered_set& outputs, @@ -155,14 +177,19 @@ Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, if (graph_def.node_size() == 0) { return Status::OK(); } + // name_to_node_map is needed to get the inputs from the NodeDef corresponding + // the a string node name. These inputs are used when doing our backwards + // traversal. + std::unordered_map name_to_node_map; + GetNodeNameToNodeDefMap(&graph_def, &name_to_node_map); std::unordered_set reachable_node_names; std::unordered_set variable_node_names; - GetReachableNodesAndVariables(&graph_def, outputs, &reachable_node_names, - &variable_node_names); + GetReachableNodesAndVariables(&graph_def, outputs, name_to_node_map, + &reachable_node_names, &variable_node_names); std::unordered_map variable_to_value_map; - TF_RETURN_IF_ERROR( - GetVariableNameToTensorMap(saved_model_bundle.session.get(), - variable_node_names, &variable_to_value_map)); + TF_RETURN_IF_ERROR(GetVariableNameToTensorMap( + saved_model_bundle.session.get(), name_to_node_map, variable_node_names, + &variable_to_value_map)); // We copy the nodes in the same order they were in the original graph_def. for (const NodeDef& node : graph_def.node()) { if (reachable_node_names.find(node.name()) == reachable_node_names.end()) { @@ -171,6 +198,12 @@ Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, if (variable_node_names.find(node.name()) != variable_node_names.end()) { ConvertVariableToConstant(node, variable_to_value_map[node.name()], frozen_graph_def->add_node()); + } else if (node.op() == "ReadVariableOp" && + variable_node_names.find(node.input(0)) != + variable_node_names.end()) { + // If the node is a ReadVariableOp, its input VarHandleOp will be + // converted to a Constant, so we will need to convert it to an Identity. + ConvertReadVariableOpToIdentity(node, frozen_graph_def->add_node()); } else { // If the node isn't a variable, just copy the node as-is. *frozen_graph_def->add_node() = node; diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc index 52a81a50284aec36bba4e56a0232c886cb0cb6cf..979b23c3fc5f66ec574736cb4d39cec0ffd8e6b6 100644 --- a/tensorflow/cc/tools/freeze_saved_model_test.cc +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/cc/tools/freeze_saved_model.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph.pb.h" @@ -113,6 +114,160 @@ class FreezeTest : public ::testing::Test { test::ExpectTensorEqual(unfrozen_outputs[0], frozen_outputs[0]); } + + void TestFreezeGraphWithoutDependentVariables(bool use_resource) { + // Test freezing a graph with variables that are not needed by the outputs + // in the SignatureDef. The resulting graph shouldn't be frozen, but + // non-dependent nodes should be pruned. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), a, b); + if (use_resource) { + Output var = + ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {}); + Output read_var = ops::ReadVariableOp( + scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT); + auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a); + } else { + Output var = + ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output assign = ops::Assign(scope.WithOpName("assign"), var, a); + } + + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, "assign", &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, + &inputs, &outputs)); + + GraphDef expected_graph_def; + Scope expected_scope = Scope::NewRootScope(); + Output expected_a = ops::Const(expected_scope.WithOpName("a"), 10.0f, {}); + Output expected_b = ops::Const(expected_scope.WithOpName("b"), 10.0f, {}); + Output expected_c = + ops::Mul(expected_scope.WithOpName("c"), expected_a, expected_b); + TF_ASSERT_OK(expected_scope.ToGraphDef(&expected_graph_def)); + + GraphDefEqual(frozen_graph_def, expected_graph_def); + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); + } + + void TestFreezeGraphWithDependentVariables(bool use_resource) { + // Test freezing a graph with variables that are needed by outputs in the + // SignatureDef. The variables should be frozen. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output read_var; + if (use_resource) { + Output var = + ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {}); + read_var = ops::ReadVariableOp( + scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT); + auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a); + } else { + Output read_var = + ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output assign = ops::Assign(scope.WithOpName("assign"), read_var, a); + } + Output c = ops::Mul(scope.WithOpName("c"), a, read_var); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, "assign", &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, + &inputs, &outputs)); + + // If using normal variables there should be 3 nodes in the resulting + // graph_def. If using resource variables there should be 4 nodes in the + // resulting graph_def. + // In both cases, none should be variables. + size_t expected_nodes = use_resource ? 4 : 3; + EXPECT_EQ(frozen_graph_def.node_size(), expected_nodes); + for (const NodeDef& node : frozen_graph_def.node()) { + EXPECT_NE(node.op(), "Variable") << node.name(); + EXPECT_NE(node.op(), "VariableV2") << node.name(); + EXPECT_NE(node.op(), "VarHandleOp") << node.name(); + EXPECT_NE(node.op(), "ReadVariableOp") << node.name(); + } + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); + } + + void TestFreezeGraphWithAndWithoutDependentVariables(bool use_resource) { + // Test freezing a graph with some variables that are needed and not needed + // by + // the outputs in the SignatureDef. The resulting graph should only freeze + // dependent variables. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output read_var; + + if (use_resource) { + Output var = + ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {}); + read_var = ops::ReadVariableOp( + scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT); + auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a); + Output var_1 = + ops::VarHandleOp(scope.WithOpName("var_1"), DataType::DT_FLOAT, {}); + Output read_var_1 = + ops::ReadVariableOp(scope.WithOpName("var_1/Read/ReadVariableOp"), + var, DataType::DT_FLOAT); + auto assign_1 = + ops::AssignVariableOp(scope.WithOpName("assign_1"), var_1, a); + } else { + read_var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output assign = ops::Assign(scope.WithOpName("assign"), read_var, a); + Output var_1 = + ops::Variable(scope.WithOpName("var_1"), {}, DataType::DT_FLOAT); + Output assign_1 = ops::Assign(scope.WithOpName("assign_1"), var_1, a); + } + + Output c = ops::Mul(scope.WithOpName("c"), a, read_var); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, "assign", &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, + &inputs, &outputs)); + + // There should be 3 nodes in the resulting graph_def, and none should be + // variables. + size_t expected_nodes = use_resource ? 4 : 3; + EXPECT_EQ(frozen_graph_def.node_size(), expected_nodes); + for (const NodeDef& node : frozen_graph_def.node()) { + EXPECT_NE(node.op(), "Variable") << node.name(); + EXPECT_NE(node.op(), "VariableV2") << node.name(); + EXPECT_NE(node.op(), "VarHandleOp") << node.name(); + EXPECT_NE(node.op(), "ReadVariableOp") << node.name(); + } + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); + } }; TEST_F(FreezeTest, InputsAndOutputsSingleSignatureDef) { @@ -196,22 +351,21 @@ TEST_F(FreezeTest, GraphDefWithNoVariables) { GraphDefEqual(frozen_graph_def, graph_def); } -TEST_F(FreezeTest, GraphDefWithVariablesNotNeededByOutputs) { - // Test freezing a graph with variables that are not needed by the outputs in - // the SignatureDef. The resulting graph shouldn't be frozen, but - // non-dependent nodes should be pruned. +TEST_F(FreezeTest, GraphDefWithMultiOutputOperation) { + // Tensors from operations with multiple outputs get tensor suffixes when used + // in input fields of following nodes, i.e. split:0, split:1. + // Test that we traverse those correctly. SavedModelBundle saved_model_bundle; GraphDef graph_def; Scope scope = Scope::NewRootScope(); - Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output a = ops::Const(scope.WithOpName("a"), {10.0f, 10.0f}, {2}); + Output axis = ops::Const(scope.WithOpName("axis"), 0, {}); + OutputList split = ops::Split(scope.WithOpName("split"), axis, a, 2).output; Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); - Output c = ops::Mul(scope.WithOpName("c"), a, b); - Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); - Output assign = ops::Assign(scope.WithOpName("assign"), var, a); + Output c = ops::Mul(scope.WithOpName("c"), split[1], b); TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); - // "c" isnt dependent on the variable, so nothing should be frozen. - TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( - graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "", + &saved_model_bundle)); GraphDef frozen_graph_def; std::unordered_set inputs; @@ -219,34 +373,24 @@ TEST_F(FreezeTest, GraphDefWithVariablesNotNeededByOutputs) { TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, &outputs)); - GraphDef expected_graph_def; - Scope expected_scope = Scope::NewRootScope(); - Output expected_a = ops::Const(expected_scope.WithOpName("a"), 10.0f, {}); - Output expected_b = ops::Const(expected_scope.WithOpName("b"), 10.0f, {}); - Output expected_c = - ops::Mul(expected_scope.WithOpName("c"), expected_a, expected_b); - TF_ASSERT_OK(expected_scope.ToGraphDef(&expected_graph_def)); - - GraphDefEqual(frozen_graph_def, expected_graph_def); - - RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), - frozen_graph_def, "c:0"); + GraphDefEqual(frozen_graph_def, graph_def); } -TEST_F(FreezeTest, GraphDefWithVariablesNeededByOutputs) { - // Test freezing a graph with variables that are needed by outputs in the - // SignatureDef. The variables should be frozen. +TEST_F(FreezeTest, GraphDefWithControlDependency) { + // Inputs that are control dependencies get tensor prefixes, + // i.e. ^control_dependency. + // Test that we traverse those correctly. SavedModelBundle saved_model_bundle; GraphDef graph_def; Scope scope = Scope::NewRootScope(); - Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); - Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); - Output c = ops::Mul(scope.WithOpName("c"), a, var); - Output assign = ops::Assign(scope.WithOpName("assign"), var, a); + Output source = ops::Const(scope.WithOpName("source"), 10.0f, {}); + Output a = ops::Const(scope.WithOpName("a").WithControlDependencies(source), + {10.0f, 10.0f}, {2}); + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), a, b); TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); - // "c" isnt dependent on the variable, so nothing should be frozen. - TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( - graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "", + &saved_model_bundle)); GraphDef frozen_graph_def; std::unordered_set inputs; @@ -254,53 +398,31 @@ TEST_F(FreezeTest, GraphDefWithVariablesNeededByOutputs) { TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, &outputs)); - // There should be 3 nodes in the resulting graph_def, and none should be - // variables. - EXPECT_EQ(frozen_graph_def.node_size(), 3); - for (const NodeDef& node : frozen_graph_def.node()) { - EXPECT_NE(node.op(), "Variable") << node.name(); - EXPECT_NE(node.op(), "VariableV2") << node.name(); - } + GraphDefEqual(frozen_graph_def, graph_def); +} - RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), - frozen_graph_def, "c:0"); +TEST_F(FreezeTest, GraphDefWithoutDependentVariables) { + TestFreezeGraphWithoutDependentVariables(false); } -TEST_F(FreezeTest, GraphDefWithVariablesNeededAndNotNeededByOutputs) { - // Test freezing a graph with some variables that are needed and not needed by - // the outputs in the SignatureDef. The resulting graph should only freeze - // dependent variables. - SavedModelBundle saved_model_bundle; - GraphDef graph_def; - Scope scope = Scope::NewRootScope(); - Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); - Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); - Output c = ops::Mul(scope.WithOpName("c"), a, var); - Output assign = ops::Assign(scope.WithOpName("assign"), var, a); - Output var_1 = - ops::Variable(scope.WithOpName("var_1"), {}, DataType::DT_FLOAT); - Output assign_1 = ops::Assign(scope.WithOpName("assign_1"), var, a); - TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); - // "c" isnt dependent on the variable, so nothing should be frozen. - TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( - graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); +TEST_F(FreezeTest, GraphDefWithoutDependentResourceVariables) { + TestFreezeGraphWithoutDependentVariables(true); +} - GraphDef frozen_graph_def; - std::unordered_set inputs; - std::unordered_set outputs; - TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, - &outputs)); +TEST_F(FreezeTest, GraphDefWithDependentVariables) { + TestFreezeGraphWithDependentVariables(false); +} - // There should be 3 nodes in the resulting graph_def, and none should be - // variables. - EXPECT_EQ(frozen_graph_def.node_size(), 3); - for (const NodeDef& node : frozen_graph_def.node()) { - EXPECT_NE(node.op(), "Variable") << node.name(); - EXPECT_NE(node.op(), "VariableV2") << node.name(); - } +TEST_F(FreezeTest, GraphDefWithDependentResourceVariables) { + TestFreezeGraphWithDependentVariables(true); +} + +TEST_F(FreezeTest, GraphDefWithAndWithoutDependentVariables) { + TestFreezeGraphWithAndWithoutDependentVariables(false); +} - RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), - frozen_graph_def, "c:0"); +TEST_F(FreezeTest, GraphDefWithAndWithoutDependentResourceVariables) { + TestFreezeGraphWithAndWithoutDependentVariables(true); } } // namespace diff --git a/tensorflow/cc/tutorials/example_trainer.cc b/tensorflow/cc/tutorials/example_trainer.cc index 3675d72ee354533a7d84b5e8783cde452d8d60c9..5dbc4f5f6aa389978e55ca2656c17ff97202203d 100644 --- a/tensorflow/cc/tutorials/example_trainer.cc +++ b/tensorflow/cc/tutorials/example_trainer.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/graph/default_device.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -166,7 +167,8 @@ namespace { bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, int32* dst) { - if (arg.Consume(flag) && arg.Consume("=")) { + if (tensorflow::str_util::ConsumePrefix(&arg, flag) && + tensorflow::str_util::ConsumePrefix(&arg, "=")) { char extra; return (sscanf(arg.data(), "%d%c", dst, &extra) == 1); } @@ -176,7 +178,7 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, bool* dst) { - if (arg.Consume(flag)) { + if (tensorflow::str_util::ConsumePrefix(&arg, flag)) { if (arg.empty()) { *dst = true; return true; diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 0900e87ebabd378e6237b77ca0ef01677c07c244..2119c8ec47f941a76e81346ae5d20da78eae11a3 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -60,6 +60,7 @@ cc_library( "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -72,6 +73,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", ], ) @@ -212,7 +214,6 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@llvm//:core", - "@llvm//:execution_engine", "@llvm//:support", "@llvm//:target", ], @@ -249,17 +250,3 @@ exports_files([ "benchmark_main.template", # used by tf_library(...,gen_benchmark=True) "test.cc", # used by tf_library(...,gen_test=True) ]) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 2cae85e8965216eaaee4d3032015d0016258a5c1..0025842aead53973befc794378a26fa8db2ae1cb 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -333,6 +333,20 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, R"(#include "tensorflow/compiler/xla/xla_data.pb.h")" : ""; + const string include_hlo_profile_printer_data_proto = + opts.gen_hlo_profile_printer_data + ? R"(#include "tensorflow/compiler/xla/service/hlo_profile_printer_data.pb.h")" + : ""; + + // When HLO profiling is disabled we only forward declare the + // HloProfilePrinter protobuf. So we can only conditionally emit this code + // calling HloProfilePrinter::profile_counters_size. + const string assign_profile_counters_size = + opts.gen_hlo_profile_printer_data + ? "data->profile_counters_size = " + "data->hlo_profile_printer_data->profile_counters_size();" + : ""; + // Use a poor-man's text templating mechanism; first populate the full header // with placeholder tokens, and then rewrite the tokens with real values. *header = @@ -348,6 +362,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, #define TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard) {{INCLUDE_XLA_DATA_PROTO}} +{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}} #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/core/platform/types.h" @@ -418,6 +433,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { data->arg_names = StaticArgNames(); data->result_names = StaticResultNames(); data->program_shape = StaticProgramShape(); + data->hlo_profile_printer_data = StaticHloProfilePrinterData(); + {{ASSIGN_PROFILE_COUNTERS_SIZE}} return data; }(); return *kStaticData; @@ -487,6 +504,13 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}}; return kShape; } + + // Metadata that can be used to pretty-print profile counters. + static const xla::HloProfilePrinterData* StaticHloProfilePrinterData() { + static const xla::HloProfilePrinterData* kHloProfilePrinterData = + {{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}; + return kHloProfilePrinterData; + } }; {{NS_END}} @@ -501,35 +525,41 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())}, {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, + {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, + {"{{DECLS_FROM_OBJ_FILE}}", + str_util::Join(metadata_result.header_variable_decls, "\n")}, {"{{ENTRY}}", compile_result.entry_point}, + {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}", + metadata_result.hlo_profile_printer_data_access_shim}, {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto}, + {"{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}", + include_hlo_profile_printer_data_proto}, {"{{METHODS_ARG}}\n", methods_arg}, {"{{METHODS_RESULT}}\n", methods_result}, {"{{NS_END}}\n", ns_end}, {"{{NS_START}}\n", ns_start}, {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, + {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", + metadata_result.program_shape_access_shim}, {"{{RESULT_INDEX}}", strings::StrCat(result_index)}, {"{{RESULT_NAMES_CODE}}", result_names_code}, {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())}, - {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}, - {"{{DECLS_FROM_OBJ_FILE}}", - str_util::Join(metadata_result.header_variable_decls, "\n")}, - {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", - metadata_result.program_shape_access_shim}}; + {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}}; str_util::ReplaceAllPairs(header, rewrites); return Status::OK(); } -static string CreateUniqueIdentifierForProgramShape(const CodegenOpts& opts) { +static string CreateUniqueIdentifier(const CodegenOpts& opts, + StringPiece suffix) { string result = "__tfcompile"; for (const string& n : opts.namespaces) { strings::StrAppend(&result, "_", n); } - strings::StrAppend(&result, "_", opts.class_name, "_ProgramShape"); + strings::StrAppend(&result, "_", opts.class_name, "_", suffix); return result; } @@ -550,18 +580,31 @@ Status GenerateMetadata(const CodegenOpts& opts, // When asked to serialize a null protobuf, CreateEmbeddedProtocolBuffer gives // a shim that evaluates to nullptr, which is what we want. + ProtobufToEmbed program_shape_protobuf{ + CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape", + program_shape.get()}; + + ProtobufToEmbed hlo_profile_printer_data_protobuf{ + CreateUniqueIdentifier(opts, "HloProfilePrinterData"), + "xla::HloProfilePrinterData", + compile_result.aot->hlo_profile_printer_data()}; + TF_ASSIGN_OR_RETURN( - EmbeddedProtocolBuffer embedded_program_shape, - CreateEmbeddedProtocolBuffer(opts.target_triple, - CreateUniqueIdentifierForProgramShape(opts), - "xla::ProgramShape", program_shape.get())); + EmbeddedProtocolBuffers embedded_protobufs, + CreateEmbeddedProtocolBuffers( + opts.target_triple, + {program_shape_protobuf, hlo_profile_printer_data_protobuf})); metadata_result->program_shape_access_shim = - std::move(embedded_program_shape.cpp_shim_expression); + std::move(embedded_protobufs.cpp_shims[0].expression); + metadata_result->hlo_profile_printer_data_access_shim = + std::move(embedded_protobufs.cpp_shims[1].expression); + metadata_result->header_variable_decls.emplace_back( + std::move(embedded_protobufs.cpp_shims[0].variable_decl)); metadata_result->header_variable_decls.emplace_back( - std::move(embedded_program_shape.cpp_variable_decl)); + std::move(embedded_protobufs.cpp_shims[1].variable_decl)); metadata_result->object_file_data = - std::move(embedded_program_shape.object_file_data); + std::move(embedded_protobufs.object_file_data); return Status::OK(); } diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 3430b1f96cf4d3c035b76c77ccf124c5d164751e..83f2d3ee11d09d66f16d7ecdc11945ebe994a82a 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -44,6 +44,10 @@ struct CodegenOpts { // If true, generate program shape data for the ProgramShape method. bool gen_program_shape = false; + + // If true, emit a serialized HloProfilePrinterData protobuf that can be used + // to pretty print HLO profile counters. + bool gen_hlo_profile_printer_data = false; }; // Describes a generated metadata object file. @@ -57,6 +61,12 @@ struct MetadataResult { // GenerateMetadata. string program_shape_access_shim; + // hlo_profile_printer_data_access_shim is a C++ expression that constructs + // the xla::HloProfilePrinterData instance for the CompileResult passed to + // GenerateMetadata. If the xla::HloProfilePrinterData is null then this is a + // C++ expression that evaluates to nullptr at runtime. + string hlo_profile_printer_data_access_shim; + // The contents of the object (".o") file. string object_file_data; }; diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 972b7d51ecb3798e61757ac55e973075a23b433a..29bc9c13b889c86c2ba8776c7b067c54cb05bc43 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" @@ -33,7 +34,7 @@ namespace { void ExpectErrorContains(const Status& status, StringPiece str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(StringPiece(status.error_message()).contains(str)) + EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } @@ -171,7 +172,7 @@ TEST(CodegenTest, Golden) { fetch->set_name("myfetch"); CompileResult compile_result; compile_result.aot.reset( - new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5)); + new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5, {})); compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( { xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index ac3b5873318873b5fdf41bd556a0b2abddc2b30b..6641d45e83020f4144616a6a2837c844330298f5 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -10,6 +10,7 @@ #define TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard) #include "tensorflow/compiler/xla/xla_data.pb.h" + #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/core/platform/types.h" @@ -23,6 +24,7 @@ extern "C" void entry_point( extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[]; + namespace foo { namespace bar { @@ -54,9 +56,9 @@ namespace bar { // // Memory stats: // arg bytes total: 104 -// arg bytes aligned: 128 +// arg bytes aligned: 192 // temp bytes total: 126 -// temp bytes aligned: 224 +// temp bytes aligned: 320 class MyClass : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. @@ -82,6 +84,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { data->arg_names = StaticArgNames(); data->result_names = StaticResultNames(); data->program_shape = StaticProgramShape(); + data->hlo_profile_printer_data = StaticHloProfilePrinterData(); + return data; }(); return *kStaticData; @@ -243,6 +247,13 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { }(); return kShape; } + + // Metadata that can be used to pretty-print profile counters. + static const xla::HloProfilePrinterData* StaticHloProfilePrinterData() { + static const xla::HloProfilePrinterData* kHloProfilePrinterData = + nullptr; + return kHloProfilePrinterData; + } }; } // end namespace bar diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index c87f2b75dfa18ad5c3eda4bd6fcbcb3083ef73fd..bbc35da2ef6d14ff0d3570ef2d5cf6743456c674 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -43,7 +44,7 @@ namespace { // Compiles the XLA computation into executable code. Status CompileXla(xla::CompileOnlyClient* client, - const xla::Computation& computation, + const xla::XlaComputation& computation, const xla::cpu::CpuAotCompilationOptions& aot_opts, CompileResult* compile_result) { // Retrieves arg and result layouts from the computation. @@ -61,7 +62,7 @@ Status CompileXla(xla::CompileOnlyClient* client, for (int i = 0; i < pshape->parameters_size(); ++i) { arg_layouts.push_back(pshape->mutable_parameters(i)); } - xla::CompileOnlyClient::AotComputationInstance instance; + xla::CompileOnlyClient::AotXlaComputationInstance instance; instance.computation = &computation; instance.argument_layouts = std::move(arg_layouts); instance.result_layout = &pshape->result(); @@ -87,20 +88,19 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, // Converts the graph into an XLA computation, and compiles the // computation. // TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client? - namespace gpu = perftools::gputools; - gpu::Platform* cpu_platform = - gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie(); + se::Platform* cpu_platform = + se::MultiPlatformManager::PlatformWithName("Host").ValueOrDie(); xla::CompileOnlyClient* client = xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform) .ValueOrDie(); - xla::Computation computation; + xla::XlaComputation computation; TF_RETURN_IF_ERROR( ConvertGraphDefToXla(graph_def, config, client, &computation)); if (!flags.out_session_module.empty()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr module, + TF_ASSIGN_OR_RETURN(std::unique_ptr module, computation.Snapshot()); - // Serialize the SessionModule deterministically so that all the outputs of - // a tf_library genrule are deterministic. + // Serialize the HloSnapshot deterministically so that all the outputs of a + // tf_library genrule are deterministic. string proto; TF_RET_CHECK(SerializeToStringDeterministic(*module, &proto)); TF_RETURN_IF_ERROR( @@ -110,6 +110,7 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, flags.target_triple, flags.target_cpu, flags.target_features, flags.entry_point, xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic); + return CompileXla(client, computation, aot_opts, compile_result); } diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index 6489929a576d6469c4ff1358ca5ee9d27fb578bb..4e27aafec7747655d8e4ea3ddd1788d495ca0710 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "llvm/ADT/Triple.h" -#include "llvm/ExecutionEngine/ObjectMemoryBuffer.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" @@ -37,9 +36,8 @@ namespace tfcompile { using xla::llvm_ir::AsStringRef; -static std::unique_ptr CreateModuleWithEmbeddedProtocolBuffer( - llvm::LLVMContext* llvm_context, llvm::TargetMachine* target_machine, - const ::tensorflow::protobuf::MessageLite& proto, +static void AddEmbeddedProtocolBufferToLlvmModule( + llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto, StringPiece unique_identifier, string* protobuf_array_symbol_name, int64* protobuf_array_size) { string protobuf_array_contents = proto.SerializeAsString(); @@ -47,19 +45,14 @@ static std::unique_ptr CreateModuleWithEmbeddedProtocolBuffer( strings::StrCat(unique_identifier, "_protobuf_array_contents"); *protobuf_array_size = protobuf_array_contents.size(); - std::unique_ptr module = - MakeUnique("embedded_data_module", *llvm_context); - llvm::Constant* protobuf_array_initializer = - llvm::ConstantDataArray::getString(*llvm_context, + llvm::ConstantDataArray::getString(module->getContext(), AsStringRef(protobuf_array_contents), /*AddNull=*/false); new llvm::GlobalVariable( *module, protobuf_array_initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name)); - - return module; } static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name, @@ -89,7 +82,8 @@ static StatusOr CodegenModule(llvm::TargetMachine* target_machine, llvm::legacy::PassManager codegen_passes; if (target_machine->addPassesToEmitFile( - codegen_passes, ostream, llvm::TargetMachine::CGFT_ObjectFile)) { + codegen_passes, ostream, nullptr, + llvm::TargetMachine::CGFT_ObjectFile)) { return xla::InternalError( "Could not create pass pipeline to generate object file"); } @@ -116,42 +110,44 @@ GetTargetMachineFromTriple(StringPiece target_triple) { /*Features=*/"", llvm::TargetOptions(), llvm::None)); } -StatusOr CreateEmbeddedProtocolBuffer( - StringPiece target_triple, StringPiece symbol_prefix, - StringPiece qualified_cpp_protobuf_name, - const ::tensorflow::protobuf::MessageLite* proto) { +StatusOr CreateEmbeddedProtocolBuffers( + StringPiece target_triple, + gtl::ArraySlice protobufs_to_embed) { TF_ASSIGN_OR_RETURN(std::unique_ptr target_machine, GetTargetMachineFromTriple(target_triple)); llvm::LLVMContext llvm_context; - string object_file, cpp_shim, cpp_variable_decl; - - if (proto) { - string protobuf_array_symbol_name; - int64 protobuf_array_size; - - std::unique_ptr module_with_serialized_proto = - CreateModuleWithEmbeddedProtocolBuffer( - &llvm_context, target_machine.get(), *proto, symbol_prefix, - &protobuf_array_symbol_name, &protobuf_array_size); - TF_ASSIGN_OR_RETURN(object_file, - CodegenModule(target_machine.get(), - std::move(module_with_serialized_proto))); - cpp_shim = CreateCPPShimExpression(qualified_cpp_protobuf_name, - protobuf_array_symbol_name, - protobuf_array_size); - - cpp_variable_decl = strings::StrCat("extern \"C\" char ", - protobuf_array_symbol_name, "[];"); - } else { - TF_ASSIGN_OR_RETURN( - object_file, - CodegenModule(target_machine.get(), - MakeUnique("empty_module", llvm_context))); - cpp_shim = "nullptr"; + std::unique_ptr module_with_serialized_proto = + MakeUnique("embedded_data_module", llvm_context); + + EmbeddedProtocolBuffers result; + + for (const ProtobufToEmbed& protobuf_to_embed : protobufs_to_embed) { + string cpp_shim, cpp_variable_decl; + if (protobuf_to_embed.message) { + string protobuf_array_symbol_name; + int64 protobuf_array_size; + + AddEmbeddedProtocolBufferToLlvmModule( + module_with_serialized_proto.get(), *protobuf_to_embed.message, + protobuf_to_embed.symbol_prefix, &protobuf_array_symbol_name, + &protobuf_array_size); + cpp_shim = CreateCPPShimExpression( + protobuf_to_embed.qualified_cpp_protobuf_name, + protobuf_array_symbol_name, protobuf_array_size); + + cpp_variable_decl = strings::StrCat("extern \"C\" char ", + protobuf_array_symbol_name, "[];"); + } else { + cpp_shim = "nullptr"; + } + result.cpp_shims.push_back({cpp_shim, cpp_variable_decl}); } - return {{cpp_shim, cpp_variable_decl, object_file}}; + TF_ASSIGN_OR_RETURN(result.object_file_data, + CodegenModule(target_machine.get(), + std::move(module_with_serialized_proto))); + return result; } } // namespace tfcompile diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index 8436e0ff67f352a24e3d16b46f16c1ad2f3a5957..4e194a6aba9a9efcad27c47c42e148d8e537ae68 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -21,51 +21,70 @@ limitations under the License. #define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { namespace tfcompile { using xla::StatusOr; -// Represents a protocol buffer embedded into an object file and describes a way -// to access it at runtime. -struct EmbeddedProtocolBuffer { - // cpp_shim_expression is a C++ expression that creates an instance of said - // protocol buffer when executed. - string cpp_shim_expression; - - // cpp_variable_decl is an "extern C" array declaration that is used in - // cpp_shim_expression. It must be visible wherever cpp_shim_expression is - // emitted. - string cpp_variable_decl; - - // The contents of the object (".o") file the protocol buffer is embbed in. - // This needs to be linked in to any program that wants to execute - // cpp_variable_decl . +// Represents a set of protocol buffers embedded into an object file and +// describes how to access them at runtime. +struct EmbeddedProtocolBuffers { + // Each instance CPPShim describes how to generate C++ code to instantiate a + // protobuf instance from the corresponding static data emitted into the + // object file. + struct CPPShim { + // `expression` is a C++ expression that creates an instance of said + // protocol buffer when executed. + string expression; + + // `variable_decl` is an "extern C" array declaration that is used in + // `expression`. It must be visible wherever `expression` is emitted. + string variable_decl; + }; + + // Each cpp_shim corresponds to one embedded protocol buffer. + std::vector cpp_shims; + + // The contents of the object (".o") file the protocol buffers are embbed in. + // This needs to be linked in to any program that wants to execute any of the + // expressions in `cpp_shims`. string object_file_data; }; -// Creates an object file that contains `proto`. -// -// `proto` is allowed to be nullptr, in which case the generated C++ shim -// expression is just `nullptr`, and the generated object file does not define -// any symbols. +// Describes a protocol buffer to embed into an object file. +struct ProtobufToEmbed { + // `symbol_prefix` is prefix that is guaranteed to be unique across the binary + // or DSO the generated object file will be linked into. + string symbol_prefix; + + // `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++ + // namespace qualified) protocol buffer name. This is only used in + // CPPShim::expression so relatively qualified names are fine as long as + // they're valid wherever CPPShim::expression is emitted. + string qualified_cpp_protobuf_name; + + // `message` is the protocol buffer to be embedded. It is allowed to be + // nullptr, in which case the generated C++ shim expression is just `nullptr`, + // and the generated object file does not define any symbols. + const ::tensorflow::protobuf::MessageLite* message; +}; + +// Embeds a sequence of protocol buffers into an object file. // // `target_triple` is the target triple for the target architecture for the // generated object file. // -// `symbol_prefix` is prefix that is guaranteed to be unique across the binary -// or DSO the generated object file will be linked into. -// -// `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++ -// namespace qualified) protocol buffer name. This needs is only used in -// EmbeddedProtocolBuffer::cpp_shim_expression so relatively qualified -// names are fine as long as they're valid wherever cpp_shim_expression -// is emitted. -StatusOr CreateEmbeddedProtocolBuffer( - StringPiece target_triple, StringPiece symbol_prefix, - StringPiece qualified_cpp_protobuf_name, - const ::tensorflow::protobuf::MessageLite* proto); +// `protobufs_to_embed` describes the protocol buffers to embed into the +// resulting object file. The C++ shim for protobufs_to_embed[i] is +// cpp_shims[i] in the returned EmbeddedProtocolBuffers instance. The contents +// of all the protocol buffers are embedded into a single .o file whose content +// is stored in the object_file_data field in the returned +// EmbeddedProtocolBuffers instance. +StatusOr CreateEmbeddedProtocolBuffers( + StringPiece target_triple, + gtl::ArraySlice protobufs_to_embed); } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/aot/runtime.cc index 5772776666129ed55a479c8917e69df3f3ce2fc0..5e74079fc158379b8977ada6412141e39142c3d3 100644 --- a/tensorflow/compiler/aot/runtime.cc +++ b/tensorflow/compiler/aot/runtime.cc @@ -31,7 +31,7 @@ namespace { inline void* aligned_malloc(size_t size, int minimum_alignment) { #if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN) return memalign(minimum_alignment, size); -#elif defined(COMPILER_MSVC) +#elif defined(_WIN32) return _aligned_malloc(size, minimum_alignment); #else // !__ANDROID__ && !OS_ANDROID && !OS_CYGWIN void* ptr = nullptr; @@ -48,7 +48,7 @@ inline void* aligned_malloc(size_t size, int minimum_alignment) { } inline void aligned_free(void* aligned_memory) { -#if defined(COMPILER_MSVC) +#if defined(_WIN32) _aligned_free(aligned_memory); #else free(aligned_memory); diff --git a/tensorflow/compiler/aot/runtime.h b/tensorflow/compiler/aot/runtime.h index d085864f0012e4de55685bb46961417bb3070e6f..d1a669ceb17b9fd71d26e978035283f8824b0376 100644 --- a/tensorflow/compiler/aot/runtime.h +++ b/tensorflow/compiler/aot/runtime.h @@ -25,8 +25,8 @@ namespace tensorflow { namespace tfcompile { namespace runtime { -// Align to 32-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. -static constexpr size_t kAlign = 32; +// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. +static constexpr size_t kAlign = 64; // aligned_buffer_bytes returns the sum of each size in `sizes`, skipping -1 // values. There are `n` entries in `sizes`. Each buffer is aligned to kAlign diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/aot/runtime_test.cc index 6d603a02eb4ceade6832ba67b2981814ee25327a..06ec623eb2dce5f8dc7156fb7e7b9ad57d90c8ee 100644 --- a/tensorflow/compiler/aot/runtime_test.cc +++ b/tensorflow/compiler/aot/runtime_test.cc @@ -24,7 +24,7 @@ namespace runtime { namespace { TEST(Runtime, AlignmentValue) { - // We've chosen 32 byte alignment for the tfcompile runtime to mimic the + // We've chosen 64 byte alignment for the tfcompile runtime to mimic the // regular tensorflow allocator, which was chosen to play nicely with Eigen. // The tfcompile runtime also has a requirement that comes from the xla // generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8 @@ -39,13 +39,13 @@ TEST(Runtime, AlignedBufferBytes) { EXPECT_EQ(aligned_buffer_bytes(sizesA, 1), 0); static constexpr intptr_t sizesB[1] = {3}; - EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 32); + EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 64); static constexpr intptr_t sizesC[1] = {32}; - EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 32); + EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 64); static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; - EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 192); + EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 320); } void* add_ptr(void* base, uintptr_t delta) { @@ -101,11 +101,11 @@ TEST(Runtime, MallocFreeContiguousBuffers) { EXPECT_NE(base, nullptr); EXPECT_EQ(bufD[0], add_ptr(base, 0)); EXPECT_EQ(bufD[1], nullptr); - EXPECT_EQ(bufD[2], add_ptr(base, 32)); + EXPECT_EQ(bufD[2], add_ptr(base, 64)); EXPECT_EQ(bufD[3], nullptr); - EXPECT_EQ(bufD[4], add_ptr(base, 64)); - EXPECT_EQ(bufD[5], add_ptr(base, 128)); - EXPECT_EQ(bufD[6], add_ptr(base, 160)); + EXPECT_EQ(bufD[4], add_ptr(base, 128)); + EXPECT_EQ(bufD[5], add_ptr(base, 192)); + EXPECT_EQ(bufD[6], add_ptr(base, 256)); for (int i = 0; i < 7; ++i) { const intptr_t size = sizesD[i]; if (size != -1) { diff --git a/tensorflow/compiler/aot/test.cc b/tensorflow/compiler/aot/test.cc index 47ef5f82cbc718ea300afa0c4eb4b73e1ca22fd0..6b098049cbd7539a2b2e2696b13139a8a6b28e0f 100644 --- a/tensorflow/compiler/aot/test.cc +++ b/tensorflow/compiler/aot/test.cc @@ -35,6 +35,7 @@ limitations under the License. // clang-format on #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 28aab6eb614ca7123d9e00f7f5cc3661b62e23f7..0ecc3feeb6fef1dd691ab2785b3221075a79ba88 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -7,6 +7,10 @@ package( load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +# We disable some tfcompile tests in the open source build with the +# "manual" tag to avoid making our OSS users build LLVM twice +# (once for host and once for target). + test_suite( name = "all_tests", tags = ["manual"], @@ -14,6 +18,8 @@ test_suite( ":test_graph_tfadd_test", ":test_graph_tfadd_with_ckpt_saver_test", ":test_graph_tfadd_with_ckpt_test", + ":test_graph_tfassert_eq_test", + ":test_graph_tfcond_test", ":test_graph_tffunction_test", ":test_graph_tfgather_test", ":test_graph_tfmatmul_test", @@ -33,6 +39,7 @@ py_binary( "//tensorflow/python", # TODO(b/34059704): remove when fixed "//tensorflow/python:array_ops", "//tensorflow/python:client", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform", @@ -52,6 +59,8 @@ genrule( "test_graph_tfadd_with_ckpt_saver.ckpt", "test_graph_tfadd_with_ckpt_saver.pb", "test_graph_tfadd_with_ckpt_saver.saver", + "test_graph_tfassert_eq.pb", + "test_graph_tfcond.pb", "test_graph_tffunction.pb", "test_graph_tfgather.pb", "test_graph_tfmatmul.pb", @@ -104,6 +113,28 @@ tf_library( ], ) +tf_library( + name = "test_graph_tfassert_eq", + testonly = 1, + config = "test_graph_tfassert_eq.config.pbtxt", + cpp_class = "AssertComp", + graph = "test_graph_tfassert_eq.pb", + tags = [ + "manual", + ], +) + +tf_library( + name = "test_graph_tfcond", + testonly = 1, + config = "test_graph_tfcond.config.pbtxt", + cpp_class = "CondComp", + graph = "test_graph_tfcond.pb", + tags = [ + "manual", + ], +) + tf_library( name = "test_graph_tffunction", testonly = 1, @@ -149,6 +180,15 @@ tf_library( tfcompile_flags = "--gen_name_to_index --gen_program_shape", ) +tf_library( + name = "test_graph_tfmatmulandadd_with_profiling", + testonly = 1, + config = "test_graph_tfmatmulandadd.config.pbtxt", + cpp_class = "MatMulAndAddCompWithProfiling", + enable_xla_hlo_profiling = True, + graph = "test_graph_tfmatmulandadd.pb", +) + tf_library( name = "test_graph_tfsplits", testonly = 1, @@ -170,29 +210,21 @@ tf_cc_test( ":test_graph_tfadd", ":test_graph_tfadd_with_ckpt", ":test_graph_tfadd_with_ckpt_saver", + ":test_graph_tfassert_eq", + ":test_graph_tfcond", ":test_graph_tffunction", ":test_graph_tfgather", ":test_graph_tfmatmul", ":test_graph_tfmatmulandadd", + ":test_graph_tfmatmulandadd_with_profiling", ":test_graph_tfsplits", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_profile_printer", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//third_party/eigen3", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 89c7cd4507cbd476104a039d6083d8f89de11278..9ec7df163b1425f917e9ec51559efad3e6f05e75 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import app @@ -77,6 +78,22 @@ def tfadd_with_ckpt_saver(out_dir): f.write(saver.as_saver_def().SerializeToString()) +def tfassert_eq(_): + x = array_ops.placeholder(dtypes.int32, name='x_hold') + y = array_ops.placeholder(dtypes.int32, name='y_hold') + control_flow_ops.Assert( + math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq') + math_ops.add(x, math_ops.negative(y), name='x_y_diff') + + +def tfcond(_): + p = array_ops.placeholder(dtypes.bool, name='p_hold') + x = array_ops.placeholder(dtypes.int32, name='x_hold') + y = array_ops.placeholder(dtypes.int32, name='y_hold') + z = control_flow_ops.cond(p, lambda: x, lambda: y) + array_ops.identity(z, name='result') + + def tfgather(_): params = array_ops.placeholder(dtypes.float32, name='params') indices = array_ops.placeholder(dtypes.int32, name='indices') @@ -139,10 +156,12 @@ def main(_): write_graph(tfadd, FLAGS.out_dir) write_graph(tfadd_with_ckpt, FLAGS.out_dir) write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir) + write_graph(tfassert_eq, FLAGS.out_dir) + write_graph(tfcond, FLAGS.out_dir) + write_graph(tffunction, FLAGS.out_dir) write_graph(tfgather, FLAGS.out_dir) write_graph(tfmatmul, FLAGS.out_dir) write_graph(tfmatmulandadd, FLAGS.out_dir) - write_graph(tffunction, FLAGS.out_dir) write_graph(tfsplits, FLAGS.out_dir) diff --git a/tensorflow/compiler/aot/tests/test_graph_tfassert_eq.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfassert_eq.config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8732d1709e809bb47d3769c483483c2c4f350e1c --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tfassert_eq.config.pbtxt @@ -0,0 +1,16 @@ +# Text form of tensorflow.tf2xla.Config proto. +feed { + id { node_name: "x_hold" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "y_hold" } + shape { + dim { size: 1 } + } +} +fetch { + id { node_name: "x_y_diff" } +} diff --git a/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..94a01ad4abfaab5e4b087b7cc219e86c1d0179b8 --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt @@ -0,0 +1,20 @@ +# Text form of tensorflow.tf2xla.Config proto. +feed { + id { node_name: "p_hold" } + shape {} +} +feed { + id { node_name: "x_hold" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "y_hold" } + shape { + dim { size: 1 } + } +} +fetch { + id { node_name: "result" } +} diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 413efd9cea3b6f71574615ad9ca92471ff925781..fee46280e9a0e7ba2cf7c3ed46469ae8cc0841d4 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -20,19 +20,28 @@ limitations under the License. #include "tensorflow/compiler/aot/tests/test_graph_tfadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfcond.h" #include "tensorflow/compiler/aot/tests/test_graph_tffunction.h" #include "tensorflow/compiler/aot/tests/test_graph_tfgather.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" +#include "tensorflow/compiler/xla/service/hlo_profile_printer.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace tfcompile { namespace { +using ::testing::HasSubstr; +using ::testing::IsSupersetOf; + TEST(TFCompileTest, Add) { AddComp add; EXPECT_EQ(add.arg0_data(), add.args()[0]); @@ -142,6 +151,31 @@ TEST(TFCompileTest, AddWithCkptSaver) { EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); } +TEST(TFCompileTest, Cond) { + CondComp cond; + EXPECT_EQ(cond.arg0_data(), cond.args()[0]); + EXPECT_EQ(cond.arg1_data(), cond.args()[1]); + EXPECT_EQ(cond.arg2_data(), cond.args()[2]); + cond.arg1() = 10; + cond.arg2() = 20; + { + cond.arg0() = true; + const int32 expected_result = cond.arg1(); + EXPECT_TRUE(cond.Run()); + EXPECT_EQ(cond.result0(), expected_result); + EXPECT_EQ(cond.result0_data()[0], expected_result); + EXPECT_EQ(cond.result0_data(), cond.results()[0]); + } + { + cond.arg0() = false; + const int32 expected_result = cond.arg2(); + EXPECT_TRUE(cond.Run()); + EXPECT_EQ(cond.result0(), expected_result); + EXPECT_EQ(cond.result0_data()[0], expected_result); + EXPECT_EQ(cond.result0_data(), cond.results()[0]); + } +} + TEST(TFCompileTest, Gather) { GatherComp gather; EXPECT_EQ(gather.arg0_data(), gather.args()[0]); @@ -413,6 +447,23 @@ TEST(TFCompileTest, Splits) { EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4); } +TEST(TFCompileTest, AssertEqAndReturnDiff) { + // Assert is converted into a no-op in XLA, so there is no failure even if the + // two args are different. + AssertComp assert; + EXPECT_EQ(assert.arg0_data(), assert.args()[0]); + EXPECT_EQ(assert.arg1_data(), assert.args()[1]); + + assert.arg0() = 2; + assert.arg1() = 1; + const int32 expected_result = assert.arg0() - assert.arg1(); + EXPECT_TRUE(assert.Run()); + EXPECT_EQ(assert.error_msg(), ""); + EXPECT_EQ(assert.result0(), expected_result); + EXPECT_EQ(assert.result0_data()[0], expected_result); + EXPECT_EQ(assert.result0_data(), assert.results()[0]); +} + TEST(TFCompileTest, LookupNameIndex) { // add doesn't have any names defined in its config. AddComp add; @@ -466,6 +517,56 @@ TEST(TFCompileTest, ProgramShape) { EXPECT_TRUE(ShapeUtil::Compatible(muladd_result1, f32_2x2)); } +TEST(TFCompileTest, HloProfiling) { + Eigen::ThreadPool tp(1); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + MatMulAndAddCompWithProfiling fn; + ASSERT_TRUE(fn.hlo_profiling_enabled()); + + fn.set_thread_pool(&device); + + // x = [[1, 2], [3, 4]] + fn.arg0(0, 0) = 1; + fn.arg0(0, 1) = 2; + fn.arg0(1, 0) = 3; + fn.arg0(1, 1) = 4; + + // y = [[10, 20], [30, 40]] + fn.arg1(0, 0) = 10; + fn.arg1(0, 1) = 20; + fn.arg1(1, 0) = 30; + fn.arg1(1, 1) = 40; + + EXPECT_TRUE(fn.Run()); + + string hlo_profile_as_string = + xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(), + /*clock_rate_ghz=*/1.0); + VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string; + + std::vector hlo_profile_lines = + tensorflow::str_util::Split(hlo_profile_as_string, '\n'); + + auto header = HasSubstr("Execution profile for"); + auto total_cycles_profile_line = HasSubstr("[total]"); + auto dot_profile_line = HasSubstr( + "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " + "%arg1.0.1)"); + auto add_profile_line = HasSubstr( + "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " + "%arg1.0.1)"); + auto tuple_profile_line = HasSubstr( + "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} " + "%dot.0.4, f32[2,2]{1,0} %add.0.6)"); + auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)"); + auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)"); + + EXPECT_THAT(hlo_profile_lines, + IsSupersetOf({header, total_cycles_profile_line, dot_profile_line, + add_profile_line, tuple_profile_line})); +} + } // namespace } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 9dff1be09fede6f65f82c2f36d94be07e781949f..5c57fee326ca743dcb8aaae354d261ed4d7f44be 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -25,7 +25,8 @@ def tf_library(name, graph, config, visibility=None, testonly=None, tfcompile_flags=None, tfcompile_tool="//tensorflow/compiler/aot:tfcompile", - include_standard_runtime_deps=True, deps=None, tags=None): + include_standard_runtime_deps=True, + enable_xla_hlo_profiling=False, deps=None, tags=None): """Runs tfcompile to compile a TensorFlow graph into executable code. Given an invocation of tf_library(name="foo", ...), generates the following @@ -68,6 +69,8 @@ def tf_library(name, graph, config, include_standard_runtime_deps: If True, the standard list of kernel/runtime deps is added to deps. If False, deps must contain the full set of deps needed by the generated library. + enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program, + and emit metadata that lets us pretty-print the gathered profile counters. deps: a list of deps to include on the build rules for the generated library, added to the standard deps if standard_runtime_deps is True. tags: tags to apply to subsidiary build rules. @@ -132,11 +135,15 @@ def tf_library(name, graph, config, header_file = name + ".h" metadata_object_file = name + "_tfcompile_metadata.o" function_object_file = name + "_tfcompile_function.o" - ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_") + ep = ("__" + native.package_name() + "__" + name).replace("/", "_") if type(tfcompile_flags) == type(""): flags = tfcompile_flags else: flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])]) + if enable_xla_hlo_profiling: + profiling_flag = "--xla_hlo_profile" + else: + profiling_flag = "" native.genrule( name=("gen_" + name), srcs=[ @@ -157,7 +164,7 @@ def tf_library(name, graph, config, " --out_header=$(@D)/" + header_file + " --out_metadata_object=$(@D)/" + metadata_object_file + " --out_function_object=$(@D)/" + function_object_file + - " " + flags), + " " + flags + " " + profiling_flag), tools=[tfcompile_tool], visibility=visibility, testonly=testonly, @@ -220,6 +227,8 @@ def tf_library(name, graph, config, ] + (need_xla_data_proto and [ # If we're generating the program shape, we must depend on the proto. "//tensorflow/compiler/xla:xla_data_proto", + ] or []) + (enable_xla_hlo_profiling and [ + "//tensorflow/compiler/xla/service:hlo_profile_printer_data" ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually needed. "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index e2f01179d4e2e4f6ef72b2761d06e130ffa3a94f..839e1588b7be6c91cf30c87bbaf75402446bd169 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -55,7 +55,7 @@ const char kUsageHeader[] = "\n"; Status ReadProtoFile(const string& fname, protobuf::Message* proto) { - if (StringPiece(fname).ends_with(".pbtxt")) { + if (str_util::EndsWith(fname, ".pbtxt")) { return ReadTextProto(Env::Default(), fname, proto); } else { return ReadBinaryProto(Env::Default(), fname, proto); @@ -100,6 +100,8 @@ Status Main(const MainFlags& flags) { if (flags.cpp_class.empty()) { return errors::InvalidArgument("Must specify --cpp_class"); } + codegen_opts.gen_hlo_profile_printer_data = + xla::legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile(); TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name, &codegen_opts.namespaces)); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index a711319607f4ff2b83aa0ebe50e215b3d0e2258e..8c74014614789758192691ee065f92759a113a7a 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -25,11 +25,15 @@ load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") # Target that bundles up the XLA CPU and GPU JIT devices. cc_library( name = "jit", - visibility = [":friends"], + visibility = [ + ":friends", + "//learning/tfx:__subpackages__", + ], deps = [ ":xla_cpu_device", ":xla_cpu_jit", @@ -73,6 +77,7 @@ cc_library( ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/legacy_flags:xla_device_flags", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep @@ -102,22 +107,45 @@ cc_library( cc_library( name = "xla_interpreter_device", srcs = ["xla_interpreter_device.cc"], + visibility = [":friends"], deps = [ + ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_launch_op", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_tensor", + srcs = ["xla_tensor.cc"], + hdrs = ["xla_tensor.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", ], - alwayslink = True, ) cc_library( name = "xla_device", srcs = [ + "xla_compile_on_demand_op.cc", "xla_device.cc", "xla_device_context.cc", "xla_device_ops.cc", ], hdrs = [ + "xla_compile_on_demand_op.h", "xla_device.h", "xla_device_context.h", "xla_device_ops.h", @@ -127,6 +155,8 @@ cc_library( deps = [ ":common", ":jit_compilation_passes", + ":xla_launch_util", + ":xla_tensor", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:dump_graph", @@ -146,13 +176,23 @@ cc_library( "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", + "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", + "//tensorflow/core/kernels:shape_ops", "//tensorflow/core/kernels:variable_ops", ], ) +cc_library( + name = "shape_inference_helpers", + srcs = ["shape_inference_helpers.cc"], + hdrs = ["shape_inference_helpers.h"], + deps = ["//tensorflow/core:graph"], +) + # Internal targets below this point. cc_library( @@ -166,6 +206,30 @@ cc_library( visibility = [":friends"], ) +cc_library( + name = "xla_launch_util", + srcs = ["xla_launch_util.cc"], + hdrs = ["xla_launch_util.h"], + deps = [ + ":common", + ":xla_compilation_cache", + ":xla_tensor", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:gpu_runtime", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels:variable_ops", + ], +) + cc_library( name = "xla_compilation_cache", srcs = ["xla_compilation_cache.cc"], @@ -197,30 +261,40 @@ cc_library( ) cc_library( - name = "graph_to_functiondef", - srcs = ["graph_to_functiondef.cc"], - hdrs = ["graph_to_functiondef.h"], + name = "create_xla_launch_op", + srcs = [ + "create_xla_launch_op.cc", + "create_xla_launch_op.h", + ], deps = [ - "//tensorflow/core:core_cpu", + ":common", + ":compilation_passes", + "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], + alwayslink = 1, ) -cc_library( - name = "create_xla_launch_op", +tf_cc_test( + name = "create_xla_launch_op_test", srcs = [ - "create_xla_launch_op.cc", + "create_xla_launch_op.h", + "create_xla_launch_op_test.cc", ], deps = [ - ":common", - ":compilation_passes", - "//tensorflow/compiler/jit/kernels:xla_launch_op", - "//tensorflow/compiler/tf2xla:xla_compiler", + ":create_xla_launch_op", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) @@ -238,11 +312,11 @@ cc_library( ], deps = [ ":common", - ":graph_to_functiondef", + ":shape_inference_helpers", ":union_find", + ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/kernels:parallel_check_op", - "//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:parallel_check_op", "//tensorflow/compiler/jit/ops:xla_ops", @@ -256,6 +330,20 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels:bounds_check", + ], +) + +cc_library( + name = "xla_cluster_util", + srcs = ["xla_cluster_util.cc"], + hdrs = ["xla_cluster_util.h"], + deps = [ + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels:bounds_check", ], ) @@ -264,22 +352,19 @@ cc_library( hdrs = ["union_find.h"], ) +cc_library( + name = "producer_consumer_queue", + hdrs = ["producer_consumer_queue.h"], + deps = ["//tensorflow/core:lib"], +) + tf_cc_test( - name = "graph_to_functiondef_test", + name = "producer_consumer_queue_test", size = "small", - srcs = [ - "graph_to_functiondef_test.cc", - ], + srcs = ["producer_consumer_queue_test.cc"], deps = [ - ":graph_to_functiondef", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:ops", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework_internal", + ":producer_consumer_queue", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", @@ -306,24 +391,68 @@ tf_cc_test( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", ], ) -# ----------------------------------------------------------------------------- +tf_cc_test( + name = "xla_launch_util_test", + size = "small", + srcs = ["xla_launch_util_test.cc"], + deps = [ + ":common", + ":xla_compilation_cache", + ":xla_launch_util", + ":xla_tensor", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:gpu_runtime", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core/kernels:variable_ops", + ], +) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], +cc_library( + name = "xla_fusion_optimizer", + srcs = ["xla_fusion_optimizer.cc"], + hdrs = ["xla_fusion_optimizer.h"], + visibility = ["//visibility:public"], + deps = [ + ":common", + ":union_find", + ":xla_cluster_util", + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + ], +) + +tf_cuda_cc_test( + name = "xla_fusion_optimizer_test", + srcs = ["xla_fusion_optimizer_test.cc"], + deps = [ + ":common", + ":xla_cluster_util", + ":xla_fusion_optimizer", + "//tensorflow/core:graph", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler/utils:grappler_test", + ], ) # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc index 9a2bb0007527557f79b70ad2b9c9576af2ab10ea..b17ff589e2597f8d1b5e61f4eaaed7d6ebe6214c 100644 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc @@ -40,7 +40,7 @@ static Status BuildLaunchNode( Graph* graph, Node** node) { NodeDef def; def.set_name(graph->NewName(nodename)); - def.set_op("_XlaLaunch"); + def.set_op("XlaLaunch"); def.set_device(device_name); AddNodeAttr("Tconstants", constant_dtypes, &def); AddNodeAttr("Targs", arg_dtypes, &def); @@ -79,7 +79,7 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { node->input_types().begin() + num_constant_args, node->input_types().begin() + num_constant_args + num_nonconst_args); - // Build a _XlaLaunch operator to execute the function body. + // Build a XlaLaunch operator to execute the function body. Node* launch_node; TF_RETURN_IF_ERROR(BuildLaunchNode( graph->NewName(node->name()), node->type_string(), node->def().attr(), diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index 18d901323f108505979be484c2bfad5998ab0748..731b8ebfdc6262500940274c94a03ae7c0376096 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/jit/create_xla_launch_op.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/kernels/xla_launch_op.h" @@ -21,82 +22,194 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace { -// Givens a NodeDef 'ndef' and the function library runtime 'flr', if -// 'ndef' is a call to a compilable function defined in 'flr', returns OK -// and fills in 'kernel' with a XlaLaunchOp kernel which computes the -// node. Otherwise, returns a non-OK. +// Utility which searches for values in a sorted list by scanning over it once. +// No matter how many times ScanForValue is called, the list is scanned at most +// once. However, if a call to ScanForValue skips over a value, that value is +// not revisited in future calls to ScanForValue, so callers must take +// care to order their calls. // -// This routine is here so that FunctionLibraryRuntime can jit a -// specific function call as requested. -Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef, - std::unique_ptr* kernel) { - bool xla_compile = false; - if (!flr->GetFunctionLibraryDefinition() - ->GetAttr(ndef, kXlaCompileAttr, &xla_compile) - .ok() || - !xla_compile) { - // Not marked as _XlaCompile=true. - return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op()); +// Useful for merging multiple sorted lists in O(n) time. +class SinglePassSearch { + public: + // Creates a SinglePassSearch object that can be used to search in `values`. + // Does not take ownership of `values`. `values` must outlive this. + // `values` must be sorted. + explicit SinglePassSearch(const std::vector* values) + : current_index_(0), values_(values) {} + + // Scans forward in the vector looking for "value", updating the internal + // position in to the vector. + // Returns true iff the vector contains the given value at or after current + // position. + // Not thread-safe. + bool ScanForValue(int value) { + while (current_index_ < values_->size() && + (*values_)[current_index_] <= value) { + if ((*values_)[current_index_] == value) { + current_index_++; + return true; + } + current_index_++; + } + return false; } - // Make sure that kernels have been registered on the JIT device. - XlaOpRegistry::RegisterCompilationKernels(); - if (!IsCompilable(flr, ndef)) { - // ndef is calling a function that XLA can't compile. - return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString()); + + private: + int current_index_; + const std::vector* values_; +}; + +Status CompilationRequested(const FunctionLibraryRuntime& flr, + const NodeDef& node_def) { + bool xla_compile = false; + // Check if op is marked _XlaCompile=true. + Status status = flr.GetFunctionLibraryDefinition()->GetAttr( + node_def, kXlaCompileAttr, &xla_compile); + if (!status.ok() || !xla_compile) { + if (VLOG_IS_ON(3)) { + if (!status.ok()) { + VLOG(3) << "No " << kXlaCompileAttr << " attr defined for " + << node_def.op() << ". status=" << status.ToString(); + } else { + VLOG(3) << node_def.op() << " is explicitly marked not to be compiled"; + } + } + return Status(error::INVALID_ARGUMENT, ""); } + return Status::OK(); +} + +// Given a FunctionLibraryRuntime and a NodeDef calling a function in the +// runtime, returns this function's body in `fbody` as well as the indices +// of its constant and resource arguments. +// `fbody` is owned by `flr`. +// `constant_arg_indices` and `resource_arg_indices` should be empty vector. +// They are sorted in ascending order on this function's return. +Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, + const NodeDef& node_def, + const FunctionBody** fbody, + std::vector* constant_arg_indices, + std::vector* resource_arg_indices) { FunctionLibraryRuntime::Handle handle; - // If ndef is not instantiable, e.g., the function does not exist, + // If node_def is not instantiable, e.g., the function does not exist, // simply bail out. TF_RETURN_IF_ERROR( - flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle)); - const FunctionBody* fbody = flr->GetFunctionBody(handle); - CHECK(fbody); // Can't be nullptr since we just instantiated it. - std::vector const_args(fbody->arg_types.size()); + flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle)); + *fbody = flr->GetFunctionBody(handle); + CHECK(*fbody); // Can't be nullptr since we just instantiated it. + const DataTypeVector& arg_types = (*fbody)->arg_types; + std::vector const_args(arg_types.size()); // If we can't analyze the const args. Bail out. - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args)); for (int i = 0; i < const_args.size(); ++i) { if (const_args[i]) { - // There is a const arg. Bail out. - return errors::InvalidArgument("Const arg: ", i, " in ", - DebugString(fbody->fdef)); + constant_arg_indices->push_back(i); + } + } + + // There can be hundreds of resource variables. Reserve the space for them. + // We don't reserve for constants above as they are usually few. + resource_arg_indices->reserve(arg_types.size()); + for (int i = 0; i < arg_types.size(); ++i) { + if (arg_types[i] == DT_RESOURCE) { + resource_arg_indices->push_back(i); } } - NodeDef launch_def; - launch_def.set_name(ndef.name()); - launch_def.set_op("_XlaLaunch"); - launch_def.set_device(flr->device()->name()); - AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def); - AddNodeAttr("Nresources", 0, &launch_def); - AddNodeAttr("Targs", fbody->arg_types, &launch_def); - AddNodeAttr("Tresults", fbody->ret_types, &launch_def); - NameAttrList func; - func.set_name(ndef.op()); - *(func.mutable_attr()) = ndef.attr(); - AddNodeAttr("function", func, &launch_def); - - // TODO(b/32387911): Handles the host memory types across function - // calls properly. For now, we assume all inputs and outputs are on - // the device memory. + return Status::OK(); +} + +} // namespace + +Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, + std::unique_ptr* kernel) { + TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def)); + + VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString(); + + // Make sure that kernels have been registered on the JIT device. + XlaOpRegistry::RegisterCompilationKernels(); + if (!IsCompilable(flr, node_def)) { + // node_def is calling a function that XLA can't compile. + return errors::InvalidArgument("Not compilable: ", + node_def.ShortDebugString()); + } + + // Get function body, constant args, and resource args. + const FunctionBody* fbody = nullptr; + std::vector constant_arg_indices; + std::vector resource_arg_indices; + TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( + flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices)); + + // Set input and output memory types. MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY); + // These indices are used only for optimization purposes. They allow us + // to loop over constant_arg_indices and resource_arg_indices only once + // while iterating over all the function arguments checking if it is a + // resource or a constant. + // The reason we optimized this code is because functions can have a lot of + // captured arguments. For example, the backward pass of ResNet50 takes in all + // 214 variables and a similar number of activations. + SinglePassSearch constants_search(&constant_arg_indices); + SinglePassSearch resources_search(&resource_arg_indices); + for (int i = 0; i < fbody->arg_types.size(); ++i) { + if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) { + // Compile-time constants and resource handles are expected to be in + // host memory. + input_memory_types[i] = HOST_MEMORY; + } + } + // One might wonder, about the case where a compile-time constant argument + // (which must be in host memory) is also used as an input into an op, + // e.g. Add, that expects its inputs in device memory. Here is how it + // works now. + // First, what do we mean by "op expects an input in XYZ memory"? + // There are two types of "ops" here: the tf2xla kernel and the HLO + // computation it builds. The tf2xla kernel needs to retrieve the actual + // numeric value of the compile-time constant tensors, so it really expects + // them to be on in host memory. However, for other inputs, it refers to them + // using xla::ComputationDataHandle, which is just a symbolic handle that + // xla::ComputationBuilder assigns. How does this handle gets assigned for + // constant arguments? Even constant arguments get an _Arg node in the graph + // instatiated for Function compilation. The tf2xla kernel for constant _Arg + // nodes takes the constant value, converts it to XlaLiteral, and feeds it + // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This + // constant XlaLiteral is included in the HLO graph, and subsequently, in + // the actual executable, which is copied to the device before being + // executed. Thus, when this executable runs, the constant is available in + // device memory. + + // XlaLaunch kernel keeps all outputs (including constants, which it copies), + // in device memory MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); + // Create the kernel. + NameAttrList function; + function.set_name(node_def.op()); + *(function.mutable_attr()) = node_def.attr(); + Device* dev = flr->device(); Status s; OpKernelConstruction construction( DeviceType(dev->device_type()), dev, - dev->GetAllocator(AllocatorAttributes()), &launch_def, + dev->GetAllocator(AllocatorAttributes()), &node_def, &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types, fbody->ret_types, output_memory_types, flr->graph_def_version(), &s); - kernel->reset(new XlaLocalLaunchOp(&construction)); + + *kernel = MakeUnique(&construction, constant_arg_indices, + resource_arg_indices, function); return s; } +namespace { + bool RegisterLaunchOpCreator() { RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp); return true; diff --git a/tensorflow/compiler/jit/create_xla_launch_op.h b/tensorflow/compiler/jit/create_xla_launch_op.h new file mode 100644 index 0000000000000000000000000000000000000000..98a22e351532c197c69c5ea908305d885fd2c9d0 --- /dev/null +++ b/tensorflow/compiler/jit/create_xla_launch_op.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ +#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class FunctionLibraryRuntime; +class OpKernel; + +// Given a NodeDef 'node_def' and the function library runtime 'flr', if +// 'node_def' is a call to a compilable function defined in 'flr', returns OK +// and fills in 'kernel' with a XlaLaunchOp kernel which computes the +// node. Otherwise, returns a non-OK. +Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, + std::unique_ptr* kernel); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b75ab486b80e098bc0a59f9ea8cdbaa23a28fef9 --- /dev/null +++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc @@ -0,0 +1,145 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/create_xla_launch_op.h" + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { + +NodeDef ToNodeDef(const string& text) { + NodeDef node_def; + EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def)); + return node_def; +} + +// Create a FunctionDef that takes one resource and one regular param +FunctionDef XTimesY() { + return FunctionDefHelper::Define( + // Name + "XTimesY", + // Args + {"x: float", "y: resource"}, + // Return values + {"z: float"}, + // Attr def + {}, + // Nodes + { + {{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}}, + {{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}}, + }); +} + +class CreateXlaLaunchOpTest : public ::testing::Test { + protected: + void Init(const std::vector& flib) { + SessionOptions options; + auto* device_count = options.config.mutable_device_count(); + device_count->insert({"CPU", 1}); + TF_CHECK_OK(DeviceFactory::AddDevices( + options, "/job:localhost/replica:0/task:0", &devices_)); + + FunctionDefLibrary proto; + for (const auto& fdef : flib) { + *(proto.add_function()) = fdef; + } + lib_def_ = + MakeUnique(OpRegistry::Global(), proto); + OptimizerOptions opts; + device_mgr_ = MakeUnique(devices_); + pflr_ = MakeUnique( + device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), + opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); + } + + FunctionLibraryRuntime* flr_; + std::vector devices_; + std::unique_ptr device_mgr_; + std::unique_ptr lib_def_; + std::unique_ptr pflr_; + + std::unique_ptr kernel_; +}; + +AttrValue BoolAttr(bool b) { + AttrValue v; + v.set_b(b); + return v; +} + +TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) { + FunctionDef fdef = XTimesY(); + (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true); + Init({fdef}); + + Status status = CreateXlaLaunchOp( + flr_, ToNodeDef(R"pb( + name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b' + )pb"), &kernel_); + ASSERT_TRUE(status.ok()) << status.ToString(); + + EXPECT_EQ("XTimesY", kernel_->name()); + EXPECT_EQ("XTimesY", kernel_->type_string()); + + EXPECT_EQ(2, kernel_->num_inputs()); + EXPECT_EQ(DT_FLOAT, kernel_->input_type(0)); + EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1)); + EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]); + EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]); + + EXPECT_EQ(1, kernel_->num_outputs()); + EXPECT_EQ(DT_FLOAT, kernel_->output_type(0)); + EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]); +} + +TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrNotSet) { + FunctionDef fdef = XTimesY(); + Init({fdef}); + + Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), &kernel_); + EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString(); +} + +TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrIsSetToFalse) { + FunctionDef fdef = XTimesY(); + (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false); + Init({fdef}); + + Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), &kernel_); + EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 9c372a012789fc25ca0a711349c09ca62edc6754..ea90d714c8fe2c2959a92a83ec105dbf2c098f7a 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -22,9 +22,9 @@ limitations under the License. #include #include -#include "tensorflow/compiler/jit/graph_to_functiondef.h" -#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -33,9 +33,11 @@ limitations under the License. #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/tensor_id.h" @@ -53,6 +55,8 @@ namespace tensorflow { const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel"; const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs"; const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; +const char* const kXlaHostTransferSequencerAttr = + "_xla_host_transfer_sequencer"; namespace { @@ -143,7 +147,7 @@ struct NodeSlot { // everything to use it. static const char* const kArgOp = "_Arg"; static const char* const kRetValOp = "_Retval"; -static const char* const kHostComputeOp = "_XlaHostCompute"; +static const char* const kHostComputeOp = "XlaHostCompute"; static const char* const kSendFromHostOp = "_XlaSendFromHost"; static const char* const kRecvAtHostOp = "_XlaRecvAtHost"; @@ -156,6 +160,11 @@ class Encapsulator { std::move(outside_compilation_attribute)), graph_in_(graph_in) {} + // Find dependencies between subgraphs and outside_compilation clusters that + // only manifest via edges between outside_compilation clusters in the outer + // (non-compiled) graph. + Status FindClusterDependencies(); + // Find subgraphs marked with 'group_attribute', and build a new // subgraph, one for each value of 'group_attribute'. Status SplitIntoSubgraphs(); @@ -172,8 +181,7 @@ class Encapsulator { // Write a copy of the input graph to 'graph_out', where the subgraphs are // replaced with calls to the new functions. - Status BuildOutputGraph(bool parallel_checking, Graph* graph_out, - FunctionLibraryDefinition* library); + Status BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library); private: // A subgraph of the input, all marked with a common 'group_attribute' @@ -226,6 +234,19 @@ class Encapsulator { // the shapes of any ancestor RAH outputs. If it can be determined that the // shape of the SFH inputs will not be inferrable even once the shapes of the // RAH outputs are known, an error is returned by the rewriter. + // + // Once edges between compiled and outside_compilation clusters have been + // replaced by send/recv ops, some dependencies may no longer be apparent. + // A clustering pass finds all the dependencies between HC nodes that are only + // present as a result of edges between nodes in outside_compilation clusters. + // Suppose there is a path from outside_compilation cluster C in subgraph S + // to outside_compilation cluster D in subgraph T. If S != T then a control + // edge is added from the call node for S to the call node for T, which + // ensures that C will execute before D because S executes before T. If S==T + // then a control dependency is added between the HC nodes for C and D in S, + // and the HC node for C is added to an 'ancestors' attr in the HC node for D + // so that during compilation of the HC node for D, an XLA control dependency + // can be added to ensure C's SendToHost executes before D's RecvFromHost. class Subgraph { public: // Creates a graph to build the subgraph in, if it doesn't already exist, @@ -248,11 +269,12 @@ class Encapsulator { // Adds the function call node to graph_out. Status AddFunctionCallNode( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out); + Graph* graph_out); // Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out. Status AddOutsideCompilationHostIONodes( - const string& subgraph_name, + const string& group_attribute, const string& subgraph_name, + const string& outside_compilation_attribute, const std::unordered_map& node_images, Graph* graph_out); @@ -260,11 +282,9 @@ class Encapsulator { // Subgraph. void GetOutsideCompilationSubgraphNames(std::vector* names) const; - // Returns the Node that inputs to the function should be wired up to. - Node* GetCallNodeForInputs() const; - - // Returns the Node that outputs to the function should be wired up to. - Node* GetCallNodeForOutputs() const; + // Returns the Node that the inputs and outputs of the function should be + // wired up to. + Node* GetCallNode() const; // Returns the index of the arg that the dst of edge should connect to. int GetArgIndexForEdge(const Edge* edge) const; @@ -319,6 +339,18 @@ class Encapsulator { void RecordOutsideCompilationOutputOrControl( const string& outside_compilation_id, const Edge* edge); + // Records the fact that there is a path from a node in outside_compilation + // cluster ancestor to node in cluster successor that does not go through + // the subgraph. + void RecordOutsideCompilationDependency(const string& successor, + const string& ancestor); + + // Returns the mapping from outside_compilation cluster C to the set of + // outside_compilation clusters that have a path to C entirely outside + // compiled subgraphs. + const std::unordered_map> + OutsideCompilationAncestorMap() const; + // Adds the HostCompute nodes for each outside_compilation subgraph. Status AddHostComputes( const string& subgraph_name, @@ -328,12 +360,14 @@ class Encapsulator { Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out); // If there is a sequencer node, adds a control edge from the sequencer to - // all the downstream nodes of call_node_outputs. - void ConnectSequencerToOutputs(Graph* graph_out); + // the call node. + void ConnectSequencerToCallNode(Graph* graph_out); Status AddShapeInferenceInfo( + const string& subgraph_name, const string& outside_compilation_subgraph_name, - const std::vector& shapes, GraphDef* inference_graph); + const std::vector& shapes, Graph* inference_graph, + FunctionLibraryDefinition* library); Status ReplaceFunctionDef(FunctionLibraryDefinition* library); @@ -381,15 +415,30 @@ class Encapsulator { Node* send_from_host = nullptr; }; - // Builds a ParallelCheck op that compares the output of the original - // subgraph with the encapsulated subgraph. - Status BuildParallelCheckOp( - const std::unordered_map& node_images, - Graph* graph_out); + // Creates an outside_compilation subgraph for outside_compilation_id if + // none exists yet. Returns the (possible newly created) subgraph for + // outside_compilation_id. + OutsideCompilationSubgraph* LookupOrCreateOutsideCompilationSubgraph( + const string& outside_compilation_id); + + // Builds a placeholder node used to provide the key input to a RecvAtHost + // or SendFromHost node. This placeholder node will be removed by a later + // pass. + Status AddHostComputeKeyPlaceholder(OutsideCompilationSubgraph* oc_subgraph, + Graph* graph_out); + + // Get the set of outside_compilation clusters and the dependency edges + // between them. + void GetActiveClusterDependencyGraph( + std::unordered_set* clusters, + std::unordered_set* has_successor, + std::unordered_map>* ancestors_map); // Builds a _RecvAtHost node producing all the inputs of an // outside_compilation subgraph and stores it in oc_subgraph.recv_at_host. - Status AddRecvAtHostNode(const string& subgraph_name, + Status AddRecvAtHostNode(const string& group_attribute, + const string& subgraph_name, + const string& outside_compilation_attribute, const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out); @@ -398,8 +447,10 @@ class Encapsulator { // outside_compilation subgraph and stores it in oc_subgraph.send_from_host. Status AddSendFromHostNode( const std::unordered_map& node_images, - const string& subgraph_name, const string& oc_subgraph_name, - OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out); + const string& group_attribute, const string& subgraph_name, + const string& outside_compilation_attribute, + const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph, + Graph* graph_out); // The subgraph extracted from the input graph, suitable for being turned // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are @@ -413,13 +464,16 @@ class Encapsulator { // NodeDef for the function call node. NodeDef call_node_def_; - // Function call node(s) in the output graph. Not owned. - // If parallel_checking is enabled, 'call_node_inputs' is the function call - // node to which inputs should be fed, and 'call_node_outputs' is the - // parallel check op from which outputs should be read. If parallel checking - // is disabled, both point to the function call node. - Node* call_node_inputs_; - Node* call_node_outputs_; + // Name that is used for the call node. This may not be + // call_node_def_.name() if the client supplies a rewrite lambda. + string function_def_name_; + + // Placeholder node simulating the host compute key in the output graph. + // Not owned. + Node* host_compute_key_placeholder_ = nullptr; + + // Function call node in the output graph. Not owned. + Node* call_node_; // Maps from source (producer node/slot) and destination // (consumer node/slot) tensors in the input graph to _Arg numbers in @@ -437,6 +491,14 @@ class Encapsulator { // The outside_compilation clusters in this subgraph. std::unordered_map outside_compilation_subgraphs_; + // For each outside_compilation cluster C, the outside_compilation clusters + // that have a path to C outside the compiled graph. + std::unordered_map> + outside_compilation_ancestors_; + // For each outside_compilation cluster C, the outside_compilation clusters + // that have a path from C outside the compiled graph. + std::unordered_map> + outside_compilation_successors_; // NoOp node in the output graph that is sequenced after the call node and // used to prevent host-side outside_compilation sends and recvs from being @@ -464,13 +526,12 @@ class Encapsulator { // Copies all nodes that aren't in a compiled subgraph to the output graph. Status CopyNodesToOutputGraph( - bool parallel_checking, Graph* graph_out, - std::unordered_map* node_images); + Graph* graph_out, std::unordered_map* node_images); // Adds function call nodes for each compiled subgraph. Status AddFunctionCallNodes( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out); + Graph* graph_out); // Adds _RecvAtHost and _SendFromHost nodes, where needed, for all // outside_compilation subgraphs. @@ -521,14 +582,18 @@ class Encapsulator { const string& src_outside_compilation_id, const string& dst_func_id, const string& dst_outside_compilation_id, const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out, + Graph* graph_out, std::unordered_set, NodeSlot::PairHasher>* edges_added); + // Adds control dependencies between subgraph call nodes that have + // dependencies via outside_compilation edges. + Status AddCallNodeDependencies(Graph* graph_out); + // Adds all edges to the output graph. Status AddEdgesToOutputGraph( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out); + Graph* graph_out); // Constructs a minimal shape inference graph that can be used to determine // the shape of send_node at the time that the subgraph is compiled. @@ -547,11 +612,12 @@ class Encapsulator { // satisfied, e.g., because send_node depends on a node that doesn't have a // registered shape inference function. Status DoStaticShapeInferenceForOutsideCompilationSend( - const Graph& graph_in, const ShapeRefiner& shape_refiner, + const Graph& graph_in, const BackEdgeHelper& back_edge_helper, + const ShapeRefiner& shape_refiner, const std::unordered_set& recv_at_host_nodes, Node* send_node, FunctionLibraryDefinition* library, std::vector* static_shape_out, - std::unique_ptr* graphdef_out); + std::unique_ptr* graph_out); // Makes a copy of graph containing only nodes that are ancestors of at least // one node in send_from_host_nodes and store it in pruned_graph. On exit @@ -570,7 +636,7 @@ class Encapsulator { // to nodes in pruned_graph. Status MakeGraphForOutsideCompilationSends( const Graph& graph, std::unique_ptr* pruned_graph, - ShapeRefiner* shape_refiner, + BackEdgeHelper* back_edge_helper, ShapeRefiner* shape_refiner, std::unordered_map* node_images, FunctionLibraryDefinition* library); @@ -588,18 +654,67 @@ class Encapsulator { const Graph* graph_in_; std::unordered_map subgraphs_; + // For each subgraph S the subgraphs S' such that there is a path in some + // outside_compilation cluster C in S to some outside_compilation cluster C' + // in S', that goes only through the uncompiled graph. + std::unordered_map> subgraph_ancestors_; TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator); }; -Node* Encapsulator::Subgraph::GetCallNodeForInputs() const { - return call_node_inputs_; -} +namespace { -Node* Encapsulator::Subgraph::GetCallNodeForOutputs() const { - return call_node_outputs_; +// Return in 'sorted' a topological sort of clusters according to the +// dependencies encoded in ancestors. clusters is the list of all clusters +// including clusters that are not present in the ancestors map. has_successors +// is the set of clusters that are ancestors of some other cluster. +void TopologicalClusterSort( + const std::unordered_set& clusters, + const std::unordered_set& has_successors, + const std::unordered_map>& ancestors, + std::vector* sorted) { + // The nodes are placed in 'sorted' in topological order. + sorted->clear(); + // We don't use the standard DFS because we are not operating on Node* + // objects. + struct Work { + string cluster; + bool leave; + }; + std::set visited; + std::vector stack; + // Seed the processing list with clusters that have no successors. + for (const auto& cluster : clusters) { + if (has_successors.find(cluster) == has_successors.end()) { + stack.push_back({cluster, false}); + } + } + while (!stack.empty()) { + const Work item = stack.back(); + stack.pop_back(); + if (item.leave) { + sorted->push_back(item.cluster); + continue; + } + + if (visited.find(item.cluster) != visited.end()) continue; + visited.insert(item.cluster); + + stack.push_back({item.cluster, true}); + const auto& iter = ancestors.find(item.cluster); + if (iter != ancestors.end()) { + for (const auto& ancestor : iter->second) { + stack.push_back({ancestor, false}); + } + } + } + CHECK(sorted->size() == clusters.size()); } +} // namespace + +Node* Encapsulator::Subgraph::GetCallNode() const { return call_node_; } + int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const { return args_by_dst_.at(NodeSlot(edge->dst(), edge->dst_input())); } @@ -712,49 +827,113 @@ Status Encapsulator::Subgraph::RecordResult( return Status::OK(); } -void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl( - const string& outside_compilation_id, const Edge* edge) { +Encapsulator::Subgraph::OutsideCompilationSubgraph* +Encapsulator::Subgraph::LookupOrCreateOutsideCompilationSubgraph( + const string& outside_compilation_id) { auto iter = outside_compilation_subgraphs_ .emplace(outside_compilation_id, OutsideCompilationSubgraph()) .first; - OutsideCompilationSubgraph& outside_subgraph = iter->second; + OutsideCompilationSubgraph* outside_subgraph = &iter->second; + return outside_subgraph; +} + +void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl( + const string& outside_compilation_id, const Edge* edge) { + OutsideCompilationSubgraph* outside_subgraph = + LookupOrCreateOutsideCompilationSubgraph(outside_compilation_id); if (edge->IsControlEdge()) { - outside_subgraph.control_inputs.insert(edge->src()); + outside_subgraph->control_inputs.insert(edge->src()); } else { - int input_index = outside_subgraph.inputs.size(); - outside_subgraph.inputs.emplace(NodeSlot(edge->src(), edge->src_output()), - input_index); + int input_index = outside_subgraph->inputs.size(); + outside_subgraph->inputs.emplace(NodeSlot(edge->src(), edge->src_output()), + input_index); } } void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl( const string& outside_compilation_id, const Edge* edge) { - auto subgraph_iter = - outside_compilation_subgraphs_ - .emplace(outside_compilation_id, OutsideCompilationSubgraph()) - .first; - OutsideCompilationSubgraph& outside_subgraph = subgraph_iter->second; + OutsideCompilationSubgraph* outside_subgraph = + LookupOrCreateOutsideCompilationSubgraph(outside_compilation_id); if (edge->IsControlEdge()) { - outside_subgraph.control_outputs.insert(edge->dst()); + outside_subgraph->control_outputs.insert(edge->dst()); } else { DataType dtype = edge->dst()->input_type(edge->dst_input()); auto output_iter = - outside_subgraph.outputs_by_src + outside_subgraph->outputs_by_src .emplace(NodeSlot(edge->src(), edge->src_output(), dtype), - outside_subgraph.outputs_by_src.size()) + outside_subgraph->outputs_by_src.size()) .first; int output_index = output_iter->second; - outside_subgraph.outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = + outside_subgraph->outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = output_index; } } +void Encapsulator::Subgraph::RecordOutsideCompilationDependency( + const string& successor, const string& ancestor) { + outside_compilation_ancestors_[successor].insert(ancestor); + outside_compilation_successors_[ancestor].insert(successor); +} + +const std::unordered_map> +Encapsulator::Subgraph::OutsideCompilationAncestorMap() const { + return outside_compilation_ancestors_; +} + +void Encapsulator::Subgraph::GetActiveClusterDependencyGraph( + std::unordered_set* clusters, + std::unordered_set* has_successor, + std::unordered_map>* ancestors_map) { + // During initial clustering the ancestor and successor datastructures may + // have been built including oc_cluster names that never turned into subgraphs + // because they had no edges into or out of the compiled cluster. Remove them + // before proceeding to simplify the logic. Get the set of clusters that was + // actually added, then remove references to the others. + for (const auto& oc_subgraph : outside_compilation_subgraphs_) { + clusters->insert(oc_subgraph.first); + } + for (const auto& cluster : outside_compilation_successors_) { + if (clusters->find(cluster.first) != clusters->end()) { + for (const auto& successor : cluster.second) { + if (clusters->find(successor) != clusters->end()) { + has_successor->insert(cluster.first); + break; + } + } + } + } + for (const auto& cluster : outside_compilation_ancestors_) { + if (clusters->find(cluster.first) != clusters->end()) { + std::unordered_set& ancestors = (*ancestors_map)[cluster.first]; + for (const auto& ancestor : cluster.second) { + if (clusters->find(ancestor) != clusters->end()) { + ancestors.insert(ancestor); + } + } + } + } +} + Status Encapsulator::Subgraph::AddHostComputes( const string& subgraph_name, const std::unordered_map& node_images) { - for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) { - const string& oc_subgraph_name = oc_subgraph_iter.first; - OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second; + // Get the set of outside_compilation clusters and the dependency edges + // between them. + std::unordered_set clusters; + std::unordered_set has_successor; + std::unordered_map> ancestors_map; + GetActiveClusterDependencyGraph(&clusters, &has_successor, &ancestors_map); + // Topologically sort the outside_compilation clusters according to their + // dependency relation. + std::vector sorted_clusters; + TopologicalClusterSort(clusters, has_successor, ancestors_map, + &sorted_clusters); + + // The host compute nodes added for each outside_compilation_cluster; + std::unordered_map host_compute_node; + for (const string& oc_subgraph_name : sorted_clusters) { + OutsideCompilationSubgraph& oc_subgraph = + outside_compilation_subgraphs_[oc_subgraph_name]; if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty() || !oc_subgraph.outputs_by_src.empty() || !oc_subgraph.control_outputs.empty()) { @@ -774,13 +953,22 @@ Status Encapsulator::Subgraph::AddHostComputes( inputs[input_index].Reset(src_image->name(), src_slot, dtype); input_dtypes[input_index] = dtype; } - for (const auto& output : oc_subgraph.outputs_by_src) { DataType dtype = output.first.dtype; int output_index = output.second; output_dtypes[output_index] = dtype; } + std::vector host_compute_ancestors; + const auto iter = ancestors_map.find(oc_subgraph_name); + if (iter != ancestors_map.end()) { + for (const string& ancestor_cluster : iter->second) { + host_compute_ancestors.push_back( + outside_compilation_subgraphs_[ancestor_cluster] + .host_compute_name); + } + } + NodeDef host_compute_def; NodeDefBuilder builder(strings::StrCat("outside_compilation_", oc_subgraph_name, "_host_compute"), @@ -788,14 +976,17 @@ Status Encapsulator::Subgraph::AddHostComputes( builder.Input(inputs); builder.Attr("Tinputs", input_dtypes); builder.Attr("Toutputs", output_dtypes); + builder.Attr("ancestors", host_compute_ancestors); builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, "_", oc_subgraph_name)); + builder.Attr("_outside_compilation_subgraph", oc_subgraph_name); Status s = builder.Finalize(&host_compute_def); if (!s.ok()) return s; Node* host_compute = graph_->AddNode(host_compute_def, &s); if (!s.ok()) return s; + host_compute_node[host_compute->name()] = host_compute; oc_subgraph.host_compute_name = host_compute->name(); // Connect the _HostCompute node to its producers in the subgraph. @@ -814,6 +1005,12 @@ Status Encapsulator::Subgraph::AddHostComputes( graph_->AddControlEdge(src_image, host_compute); } + // Connect the _HostCompute node to its ancestor host compute nodes. + for (const auto& ancestor_name : host_compute_ancestors) { + Node* ancestor = host_compute_node[ancestor_name]; + graph_->AddControlEdge(ancestor, host_compute); + } + // Connect the consumers in the subgraph to the _HostCompute node. for (const auto& output : oc_subgraph.outputs_by_dst) { const Node* dst_node = output.first.node; @@ -842,25 +1039,21 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, NodeDef seq_def; NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"), "NoOp"); + builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name); + builder.Device(device_); Status s = builder.Finalize(&seq_def); if (!s.ok()) return s; sequencer_ = graph_out->AddNode(seq_def, &s); if (!s.ok()) return s; - sequencer_->set_assigned_device_name(device_); } return Status::OK(); } -void Encapsulator::Subgraph::ConnectSequencerToOutputs(Graph* graph_out) { +void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) { if (sequencer_ != nullptr) { - std::unordered_set output_dependencies; - for (Node* node : call_node_outputs_->out_nodes()) { - output_dependencies.insert(node); - } - for (Node* node : output_dependencies) { - graph_out->AddControlEdge(sequencer_, node); - } + VLOG(2) << "ConnectSequencerToCallNode"; + graph_out->AddControlEdge(sequencer_, call_node_); } } @@ -906,6 +1099,8 @@ Status Encapsulator::Subgraph::BuildFunctionDef( name = call_node_def_.op(); } + function_def_name_ = name; + FunctionDef fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); @@ -924,8 +1119,10 @@ Status Encapsulator::Subgraph::BuildFunctionDef( } Status Encapsulator::Subgraph::AddShapeInferenceInfo( + const string& subgraph_name, const string& outside_compilation_subgraph_name, - const std::vector& shapes, GraphDef* inference_graph) { + const std::vector& shapes, Graph* inference_graph, + FunctionLibraryDefinition* library) { OutsideCompilationSubgraph& oc_subgraph = outside_compilation_subgraphs_.at(outside_compilation_subgraph_name); @@ -947,21 +1144,22 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo( host_compute->AddAttr("shape_inference_graph", ""); host_compute->AddAttr("shapes", shapes); } else { - string serialized_graph; - if (!inference_graph->SerializeToString(&serialized_graph)) { - return errors::Internal( - "Failed to serialize graph for outside compilation subgraph ", - oc_subgraph.host_compute_name); - } - host_compute->AddAttr("shape_inference_graph", serialized_graph); + string inference_graph_name = + strings::StrCat("_outside_compilation_shape_inference_", subgraph_name, + "_", outside_compilation_subgraph_name); + FunctionDef fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef)); + host_compute->AddAttr("shape_inference_graph", inference_graph_name); host_compute->AddAttr("shapes", std::vector()); + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); } return Status::OK(); } Status Encapsulator::Subgraph::ReplaceFunctionDef( FunctionLibraryDefinition* library) { - const string& name = call_node_def_.name(); + const string& name = function_def_name_; FunctionDef fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); @@ -980,89 +1178,50 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( return Status::OK(); } -Status Encapsulator::Subgraph::BuildParallelCheckOp( +Status Encapsulator::Subgraph::AddFunctionCallNode( const std::unordered_map& node_images, Graph* graph_out) { - // Build an index mapping output positions to node/slot pairs in the - // original graph. - std::vector results_by_num(results_.size()); - for (const auto& entry : results_) { - results_by_num[entry.second] = entry.first; - } - - // Build a parallel check NodeDef. - int num_results = results_by_num.size(); - std::vector result_dtypes(num_results); - std::vector expected_outputs(num_results); - std::vector actual_outputs(num_results); - for (int i = 0; i < num_results; ++i) { - const NodeSlot& node_slot = results_by_num[i]; - result_dtypes[i] = node_slot.node->output_type(node_slot.slot); - expected_outputs[i] = - NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(), - node_slot.slot, result_dtypes[i]); - actual_outputs[i] = - NodeDefBuilder::NodeOut(call_node_def_.name(), i, result_dtypes[i]); - } - // Assign the parallel check op to a CPU on the same task as the cluster it is - // checking. - string device, dummy; - if (!DeviceNameUtils::SplitDeviceName( - call_node_inputs_->assigned_device_name(), &device, &dummy)) { - return errors::InvalidArgument("Could not parse device name"); - } - strings::StrAppend(&device, "/cpu:0"); - - NodeDef check_def; - TF_RETURN_IF_ERROR( - NodeDefBuilder(graph_out->NewName(strings::StrCat(call_node_def_.name(), - "_parallel_check")), - "ParallelCheck") - .Device(device) - .Attr("T", result_dtypes) - .Input(expected_outputs) - .Input(actual_outputs) - .Finalize(&check_def)); - Status s; - Node* check_op = graph_out->AddNode(check_def, &s); + call_node_ = graph_out->AddNode(call_node_def_, &s); if (!s.ok()) return s; - check_op->set_assigned_device_name(device); - // TODO(phawkins): it seems redundant to call AddEdge as well as - // pass Inputs to the NodeDefBuilder, but I have been unable to find a - // way to avoid it. - for (int i = 0; i < num_results; ++i) { - const NodeSlot& node_slot = results_by_num[i]; - graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op, - i); - graph_out->AddEdge(call_node_inputs_, i, check_op, num_results + i); - } + // Copy the assigned device and the key_annotation over. + call_node_->set_assigned_device_name(device_); - call_node_outputs_ = check_op; return Status::OK(); } -Status Encapsulator::Subgraph::AddFunctionCallNode( - const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out) { - Status s; - call_node_inputs_ = graph_out->AddNode(call_node_def_, &s); +Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder( + OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { + TensorShapeProto shape_proto; + TensorShape shape({2}); + shape.AsProto(&shape_proto); + GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); + NodeDef key_def; + NodeDefBuilder builder( + strings::StrCat(call_node_def_.name(), "_key_placeholder"), + "Placeholder"); + builder.Attr("dtype", DT_STRING); + builder.Attr("shape", shape_proto); + builder.Attr("_host_compute_call_node", call_node_def_.name()); + Status s = builder.Finalize(&key_def); if (!s.ok()) return s; - // Copy the assigned device and the key_annotation over. - call_node_inputs_->set_assigned_device_name(device_); - call_node_outputs_ = call_node_inputs_; + host_compute_key_placeholder_ = graph_out->AddNode(key_def, &s); + if (!s.ok()) return s; + host_compute_key_placeholder_->set_assigned_device_name(device_); - if (parallel_checking) { - TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, graph_out)); - } return Status::OK(); } Status Encapsulator::Subgraph::AddRecvAtHostNode( - const string& subgraph_name, const string& oc_subgraph_name, + const string& group_attribute, const string& subgraph_name, + const string& outside_compilation_attribute, const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { + if (host_compute_key_placeholder_ == nullptr) { + TF_RETURN_IF_ERROR(AddHostComputeKeyPlaceholder(oc_subgraph, graph_out)); + } + std::vector dtypes(oc_subgraph->inputs.size(), DT_INVALID); for (const auto& input : oc_subgraph->inputs) { @@ -1078,15 +1237,23 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_recv"), kRecvAtHostOp); + builder.Device(device_); builder.Attr("Toutputs", dtypes); + // The correct device_ordinal will be inserted during replication in a + // subsequent rewrite. + builder.Attr("device_ordinal", 0); builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, "_", oc_subgraph_name)); + builder.Attr(group_attribute, subgraph_name); + builder.Attr(outside_compilation_attribute, oc_subgraph_name); + builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING); Status s = builder.Finalize(&recv_def); if (!s.ok()) return s; oc_subgraph->recv_at_host = graph_out->AddNode(recv_def, &s); if (!s.ok()) return s; - oc_subgraph->recv_at_host->set_assigned_device_name(device_); + graph_out->AddEdge(host_compute_key_placeholder_, 0, + oc_subgraph->recv_at_host, 0); // Add a control dependency forcing the RecvAtHost to run before the subgraph // completes. This has no effect on execution order but prevents the @@ -1099,8 +1266,13 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( Status Encapsulator::Subgraph::AddSendFromHostNode( const std::unordered_map& node_images, - const string& subgraph_name, const string& oc_subgraph_name, + const string& group_attribute, const string& subgraph_name, + const string& outside_compilation_attribute, const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { + if (host_compute_key_placeholder_ == nullptr) { + TF_RETURN_IF_ERROR(AddHostComputeKeyPlaceholder(oc_subgraph, graph_out)); + } + std::vector dtypes(oc_subgraph->outputs_by_src.size(), DT_INVALID); std::vector inputs( oc_subgraph->outputs_by_src.size()); @@ -1120,16 +1292,24 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_send"), kSendFromHostOp); + builder.Device(device_); builder.Attr("Tinputs", dtypes); builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, "_", oc_subgraph_name)); + // The correct device_ordinal will be inserted during replication in a + // subsequent rewrite. + builder.Attr("device_ordinal", 0); + builder.Attr(group_attribute, subgraph_name); + builder.Attr(outside_compilation_attribute, oc_subgraph_name); builder.Input(inputs); + builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING); Status s = builder.Finalize(&send_def); if (!s.ok()) return s; oc_subgraph->send_from_host = graph_out->AddNode(send_def, &s); if (!s.ok()) return s; - oc_subgraph->send_from_host->set_assigned_device_name(device_); + graph_out->AddEdge(host_compute_key_placeholder_, 0, + oc_subgraph->send_from_host, inputs.size()); // Add a control dependency forcing the SendFromHost to run before the // subgraph completes. This has no effect on execution order but prevents the @@ -1141,7 +1321,8 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( } Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes( - const string& subgraph_name, + const string& group_attribute, const string& subgraph_name, + const string& outside_compilation_attribute, const std::unordered_map& node_images, Graph* graph_out) { for (auto& outside_compilation_subgraph_entry : @@ -1151,14 +1332,16 @@ Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes( outside_compilation_subgraph_entry.second; if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) { - TF_RETURN_IF_ERROR( - AddRecvAtHostNode(subgraph_name, oc_name, &oc_subgraph, graph_out)); + TF_RETURN_IF_ERROR(AddRecvAtHostNode(group_attribute, subgraph_name, + outside_compilation_attribute, + oc_name, &oc_subgraph, graph_out)); } if (!oc_subgraph.outputs_by_src.empty() || !oc_subgraph.control_outputs.empty()) { - TF_RETURN_IF_ERROR(AddSendFromHostNode(node_images, subgraph_name, - oc_name, &oc_subgraph, graph_out)); + TF_RETURN_IF_ERROR(AddSendFromHostNode( + node_images, group_attribute, subgraph_name, + outside_compilation_attribute, oc_name, &oc_subgraph, graph_out)); } } return Status::OK(); @@ -1355,29 +1538,17 @@ Status Encapsulator::BuildFunctionDefs( } Status Encapsulator::CopyNodesToOutputGraph( - bool parallel_checking, Graph* graph_out, - std::unordered_map* node_images) { + Graph* graph_out, std::unordered_map* node_images) { for (Node* node : graph_in_->op_nodes()) { string func_id; string outside_compilation_id; TF_RETURN_IF_ERROR( GetFunctionNameAttr(node, &func_id, &outside_compilation_id)); - // Don't copy nodes that going to be encapsulated, unless parallel checking - // is enabled. - if (IsInSubgraph(func_id, outside_compilation_id) && !parallel_checking) - continue; + // Don't copy nodes that are going to be encapsulated. + if (IsInSubgraph(func_id, outside_compilation_id)) continue; Node* image = graph_out->CopyNode(node); - if (!outside_compilation_id.empty()) { - if (parallel_checking) { - return errors::InvalidArgument( - "Parallel checking is not supported when outside_compilation " - "clusters are present."); - } - image->ClearAttr(group_attribute_); - image->ClearAttr(outside_compilation_attribute_); - } (*node_images)[node] = image; } (*node_images)[graph_in_->source_node()] = graph_out->source_node(); @@ -1387,10 +1558,10 @@ Status Encapsulator::CopyNodesToOutputGraph( Status Encapsulator::AddFunctionCallNodes( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out) { + Graph* graph_out) { for (auto& subgraph_entry : subgraphs_) { - TF_RETURN_IF_ERROR(subgraph_entry.second.AddFunctionCallNode( - node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR( + subgraph_entry.second.AddFunctionCallNode(node_images, graph_out)); } return Status::OK(); } @@ -1402,7 +1573,8 @@ Status Encapsulator::AddOutsideCompilationHostIONodes( const string& subgraph_name = subgraph_entry.first; Subgraph& subgraph = subgraph_entry.second; TF_RETURN_IF_ERROR(subgraph.AddOutsideCompilationHostIONodes( - subgraph_name, node_images, graph_out)); + group_attribute_, subgraph_name, outside_compilation_attribute_, + node_images, graph_out)); } return Status::OK(); } @@ -1423,7 +1595,7 @@ Status Encapsulator::FindOutputImageOfEdgeSrc( } else { // The edge is from a subgraph to a regular node in the output graph so // use the subgraph's call node output. - *src_image = subgraphs_.at(src_func_id).GetCallNodeForOutputs(); + *src_image = subgraphs_.at(src_func_id).GetCallNode(); } } else { // The source of the edge is in the output graph so use the node image in @@ -1471,7 +1643,7 @@ Status Encapsulator::FindOutputImageOfEdgeDst( } else { // The edge is to a subgraph from a regular node in the output graph so // use the subgraph's call node input. - *dst_image = subgraphs_.at(dst_func_id).GetCallNodeForInputs(); + *dst_image = subgraphs_.at(dst_func_id).GetCallNode(); } } else { // The destination of the edge is in the output graph so use the node image @@ -1507,8 +1679,7 @@ Status Encapsulator::CopyEdgeToOutputGraph( const Edge* edge, const string& src_func_id, const string& src_outside_compilation_id, const string& dst_func_id, const string& dst_outside_compilation_id, - const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out, + const std::unordered_map& node_images, Graph* graph_out, std::unordered_set, NodeSlot::PairHasher>* edges_added) { Node* src_image; @@ -1530,11 +1701,6 @@ Status Encapsulator::CopyEdgeToOutputGraph( graph_out->AddControlEdge(src_image, dst_image); } - // If parallel checking is enabled, also add a control edge to the - // corresponding parallel check op. - if (parallel_checking) { - graph_out->AddControlEdge(src_image, node_images.at(edge->dst())); - } return Status::OK(); } @@ -1546,14 +1712,6 @@ Status Encapsulator::CopyEdgeToOutputGraph( FindOutputSlotOfEdgeDst(src_func_id, src_outside_compilation_id, dst_func_id, dst_outside_compilation_id, edge); - if (IsInSubgraph(dst_func_id, dst_outside_compilation_id) && - parallel_checking) { - // If we are parallel checking, also feed the tensor as an input to the - // corresponding parallel check subgraph. - graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()), - edge->dst_input()); - } - // Add the edge, if we have not already added it. if (edges_added ->emplace(NodeSlot(src_image, src_output), @@ -1564,9 +1722,20 @@ Status Encapsulator::CopyEdgeToOutputGraph( return Status::OK(); } +Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) { + for (const auto& ancestors : subgraph_ancestors_) { + const string& subgraph = ancestors.first; + for (const string& ancestor : ancestors.second) { + graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNode(), + subgraphs_[subgraph].GetCallNode()); + } + } + return Status::OK(); +} + Status Encapsulator::AddEdgesToOutputGraph( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out) { + Graph* graph_out) { // Set of edges already added to the output graph, represented as (src, dst) // pairs. We use the set to deduplicate edges; multiple edges in the input // graph may map to one edge in the output graph. @@ -1588,16 +1757,6 @@ Status Encapsulator::AddEdgesToOutputGraph( if (IsInSubgraph(src_func_id, src_outside_compilation_id) && IsInSubgraph(dst_func_id, dst_outside_compilation_id) && src_func_id == dst_func_id) { - if (parallel_checking) { - Node* src_image = node_images.at(edge->src()); - Node* dst_image = node_images.at(edge->dst()); - if (edge->IsControlEdge()) { - graph_out->AddControlEdge(src_image, dst_image); - } else { - graph_out->AddEdge(src_image, edge->src_output(), dst_image, - edge->dst_input()); - } - } continue; } @@ -1605,14 +1764,14 @@ Status Encapsulator::AddEdgesToOutputGraph( // unclustered graph. TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph( edge, src_func_id, src_outside_compilation_id, dst_func_id, - dst_outside_compilation_id, node_images, parallel_checking, graph_out, - &edges_added)); + dst_outside_compilation_id, node_images, graph_out, &edges_added)); } for (auto& subgraph_entry : subgraphs_) { Subgraph& subgraph = subgraph_entry.second; - subgraph.ConnectSequencerToOutputs(graph_out); + subgraph.ConnectSequencerToCallNode(graph_out); } + TF_RETURN_IF_ERROR(AddCallNodeDependencies(graph_out)); return Status::OK(); } @@ -1625,9 +1784,13 @@ namespace { // matter because it will only be used subsequently for shape inference. (It // would be possible to add a switch statement over data_type to create a value // for the constant, but that would entail maintaining the logic as new types -// are added, and is not necessary.) -Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape, - Graph* graph_out) { +// are added, and is not necessary.) If the node being replaced was within a +// control flow frame, adds appropriate Enter nodes so that the use of the Const +// is well-formed. +Node* AddDummyShapedNode(const Node* src_node, int src_port, + const std::vector& control_flow_info, + const TensorShapeProto& shape, Graph* graph_out) { + DataType data_type = src_node->output_type(src_port); TensorProto dummy_proto; dummy_proto.set_dtype(data_type); *dummy_proto.mutable_tensor_shape() = shape; @@ -1638,7 +1801,23 @@ Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape, NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const", options.op_registry()); node_builder.Attr("dtype", data_type).Attr("value", dummy_proto); - return options.FinalizeBuilder(&node_builder); + Node* node = options.FinalizeBuilder(&node_builder); + // Add any Enter nodes required to bring the constant to the correct control + // flow frame. + while (!control_flow_info[src_node->id()].frame_name.empty()) { + NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter", + options.op_registry()); + enter_builder.Attr("frame_name", + control_flow_info[src_node->id()].frame_name); + enter_builder.Attr("is_constant", true); + enter_builder.Input(node, 0); + Node* enter_node = options.FinalizeBuilder(&enter_builder); + // Adopt the new Enter node as the value in the current frame. + node = enter_node; + // Recurse to the parent frame to see if more Enter nodes need to be added. + src_node = control_flow_info[src_node->id()].parent_frame; + } + return node; } // Adds a copy of node_in to graph_out and adds the mapping to @@ -1680,17 +1859,30 @@ Status CopyShapeInferenceNodeToGraph( } } } + // Work around the fact that Enter nodes refuse to propagate shape information + // unless they are marked loop invariant. Since we are never going to execute + // this graph, marking them all loop invariant is fine. + if (node_out->type_string() == "Enter") { + node_out->ClearAttr("is_constant"); + node_out->AddAttr("is_constant", true); + } return Status::OK(); } } // namespace Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( - const Graph& graph_in, const ShapeRefiner& shape_refiner, + const Graph& graph_in, const BackEdgeHelper& back_edge_helper, + const ShapeRefiner& shape_refiner, const std::unordered_set& recv_at_host_nodes, Node* send_node, FunctionLibraryDefinition* library, std::vector* static_shape_out, - std::unique_ptr* graphdef_out) { + std::unique_ptr* graph_out) { + // Get the control flow structure of the input graph so we can build + // well-formed output graphs. + std::vector control_flow_info; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph_in, &control_flow_info)); + // Maps from nodes in graph_in to nodes in graph_out. // // When an edge has fully defined shape the source node in graph_in is @@ -1707,13 +1899,14 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( std::unordered_map dummy_node_images; std::unordered_map copied_node_images; - std::unique_ptr graph_out(new Graph(graph_in.op_registry())); - graph_out->set_versions(graph_in.versions()); - static_shape_out->resize(send_node->num_inputs()); + graph_out->reset(new Graph(graph_in.op_registry())); + (*graph_out)->set_versions(graph_in.versions()); + // The final input to the send node is the dynamic key, which we don't include + // in the static shapes. + static_shape_out->resize(send_node->num_inputs() - 1); // We don't use the standard ReverseDFS because we want to cut off traversal // whenever we find an output with fully defined shape. - // TODO(misard) make this work properly in the presence of control flow. struct Work { Node* node; bool leave; // Are we entering or leaving node? @@ -1728,7 +1921,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( if (w.leave) { TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph( n, send_node, dummy_node_images, library, &copied_node_images, - graph_out.get())); + graph_out->get())); } else { if (visited[n->id()]) continue; visited[n->id()] = true; @@ -1750,14 +1943,24 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( // continue. TensorShapeProto proto; context->ShapeHandleToProto(shape, &proto); - dummy_node_images[src_node] = AddDummyShapedNode( - src_node->output_type(src_port), proto, graph_out.get()); - if (n == send_node) { + if (dummy_node_images.find(src_node) == dummy_node_images.end()) { + dummy_node_images[src_node] = + AddDummyShapedNode(src_node, src_port, control_flow_info, + proto, graph_out->get()); + } + // The final input to the send node is the dynamic key, which we + // don't include in the static shapes. + if (n == send_node && + in_edge->dst_input() < static_shape_out->size()) { (*static_shape_out)[in_edge->dst_input()] = proto; } } else { + has_parent_with_unknown_shape = true; if (!visited[src_node->id()]) { - has_parent_with_unknown_shape = true; + if (VLOG_IS_ON(2)) { + TensorShapeProto proto; + context->ShapeHandleToProto(shape, &proto); + } stack.push_back({src_node, false}); } } @@ -1768,7 +1971,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( // The shapes of all the inputs to send_node are statically known. We // won't have to do any inference at compile time so return now: the // shapes were stored in static_shape_out above. - graphdef_out->reset(); + graph_out->reset(); return Status::OK(); } else { // Any shape that is being processed is either the original send node @@ -1791,9 +1994,214 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( } } - graphdef_out->reset(new GraphDef()); - graph_out->ToGraphDef(graphdef_out->get()); + for (const auto edge : back_edge_helper.RemovedEdges()) { + if (copied_node_images.find(edge.dst) != copied_node_images.end()) { + // The destination of this back edge was added to the inference graph, so + // fix it up. + Node* dst = copied_node_images[edge.dst]; + if (dst->type_string() != "Merge") { + return errors::InvalidArgument( + "outside_compilation cluster contains a back-edge to node ", + dst->name(), " of type ", dst->type_string(), + ". The analysis pass only supports back-edges to Merge nodes."); + } + const Edge* existing_input_edge; + if (edge.dst_input != 1 || dst->num_inputs() != 2 || + !dst->input_edge(0, &existing_input_edge).ok()) { + // TODO(misard) if we see graphs built with a different structure, relax + // this constraint. Leaving it here for now to avoid writing unnecessary + // complex code since we believe graphs generated by front ends all have + // the back edge as the second input to the merge node. + return errors::Internal( + "Internal assumption failed while rewriting an outside_compilation " + "cluster that contains a while loop. Logic assumes back-edge is to " + "port 1 of a 2-input " + "Merge node."); + } + // Connect the existing edge to both inputs of the Merge node so that the + // graph will be well-formed. + (*graph_out) + ->AddEdge(existing_input_edge->src(), + existing_input_edge->src_output(), dst, edge.dst_input); + } + } + + return Status::OK(); +} + +namespace { + +// Helper struct for building cluster dependencies and also debugging cycles in +// the dependencies. While computing dependencies we construct a mapping from +// Node* to PathDetails. +struct PathDetails { + struct SubgraphAndCluster { + string subgraph; + string outside_compilation_cluster; + bool operator==(const SubgraphAndCluster& other) const { + return subgraph == other.subgraph && + outside_compilation_cluster == other.outside_compilation_cluster; + } + }; + + struct SubgraphAndClusterHash { + inline std::size_t operator()(const SubgraphAndCluster& v) const { + return hash()( + strings::StrCat(v.subgraph, v.outside_compilation_cluster)); + } + }; + + typedef std::unordered_set + SubgraphAndClusterSet; + + // Returns the set of (subgraph, oc_cluster) pairs that should be recorded as + // ancestors for any successor of this node. If the node is in the outer + // graph, it returns the transitive union of the ancestors of the node's + // inputs. If the node is in an outside_compilation cluster, it returns just + // that cluster. If the node is compiled, it returns the empty set. + SubgraphAndClusterSet AncestorsForSuccessor() { + if (subgraph.empty()) { + return ancestor_clusters; + } else if (outside_compilation_cluster.empty()) { + return SubgraphAndClusterSet(); + } else { + SubgraphAndCluster entry; + entry.subgraph = subgraph; + entry.outside_compilation_cluster = outside_compilation_cluster; + return SubgraphAndClusterSet({entry}); + } + } + + // The transitive union of the ancestor's of this node's inputs. This is only + // saved for debugging in order to print out enough information to debug a + // discovered cycle. + SubgraphAndClusterSet ancestor_clusters; + // The subgraph attr on this node. + string subgraph; + // The outside_compilation attr on this node. + string outside_compilation_cluster; +}; + +// Adds an edge from ancestor to successor to the cycle detector, and returns an +// error if that edge causes the formation of a cycle. In the error case, logs +// the contents of the node_ancestors_map to facilitate debugging. +Status CheckClusterDependencyForCycles( + const string& ancestor, const string& successor, + const std::unordered_map>& ancestors, + const std::unordered_map& node_ancestors_map, + GraphCycles* cycle_detector, std::map* cycle_detector_map) { + if (cycle_detector_map->find(ancestor) == cycle_detector_map->end()) { + (*cycle_detector_map)[ancestor] = cycle_detector->NewNode(); + } + if (cycle_detector_map->find(successor) == cycle_detector_map->end()) { + (*cycle_detector_map)[successor] = cycle_detector->NewNode(); + } + + if (!cycle_detector->InsertEdge((*cycle_detector_map)[ancestor], + (*cycle_detector_map)[successor])) { + LOG(ERROR) << "Cycle in outside_compilation clusters"; + for (const auto& cluster : ancestors) { + LOG(ERROR) << "Cluster " << cluster.first << " depends on:"; + for (const auto& ancestor : cluster.second) { + LOG(ERROR) << " " << ancestor; + } + } + for (const auto& node_ancestors : node_ancestors_map) { + LOG(ERROR) << "Node " << node_ancestors.first->name() << " (" + << node_ancestors.second.subgraph << ";" + << node_ancestors.second.outside_compilation_cluster + << ") has ancestor clusters:"; + for (const auto& ancestor : node_ancestors.second.ancestor_clusters) { + LOG(ERROR) << " " << ancestor.subgraph << ";" + << ancestor.outside_compilation_cluster; + } + } + return errors::InvalidArgument( + "Can't compile outside_compilation clusters because there is a " + "dependency cycle: see error log for details."); + } + return Status::OK(); +} + +} // namespace +Status Encapsulator::FindClusterDependencies() { + // Map from nodes to ancestor details. A node is entered into the map if it is + // in a compilation subgraph, and outside_compilation cluster, or appears on a + // path in the outer graph leading from an outside_compilation subgraph. + std::unordered_map node_ancestors_map; + // We check that clusters are acyclic using this cycle detector. + GraphCycles cycle_detector; + // Map from cluster name to cycle detector node id. + std::map cycle_detector_map; + // Process the nodes in topologically-sorted order. + std::vector nodes; + GetReversePostOrder(*graph_in_, &nodes); + for (Node* node : nodes) { + string subgraph_name; + string oc_cluster; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &subgraph_name, &oc_cluster)); + // First create an entry in the ancestors map if the node is in a compiled + // subgraph or outside_compilation cluster, or if any incoming edge is from + // a node with an ancestor map entry; and find the union of all the + // ancestors. + if (!subgraph_name.empty()) { + node_ancestors_map[node].subgraph = subgraph_name; + node_ancestors_map[node].outside_compilation_cluster = oc_cluster; + } + for (Node* src : node->in_nodes()) { + const auto iter = node_ancestors_map.find(src); + if (iter != node_ancestors_map.end()) { + const auto& ancestors_to_follow = iter->second.AncestorsForSuccessor(); + for (const auto& ancestor : ancestors_to_follow) { + if (ancestor.subgraph != subgraph_name || + ancestor.outside_compilation_cluster != oc_cluster) { + node_ancestors_map[node].ancestor_clusters.insert(ancestor); + } + } + } + } + if (!subgraph_name.empty()) { + // The node is in a compiled subgraph or an outside_compilation cluster. + if (oc_cluster.empty()) { + // The node is not in an outside_compilation cluster. Record the + // subgraph's ancestor dependencies. + for (const auto& cluster : node_ancestors_map[node].ancestor_clusters) { + if (cluster.subgraph != subgraph_name) { + subgraph_ancestors_[subgraph_name].insert(cluster.subgraph); + TF_RETURN_IF_ERROR(CheckClusterDependencyForCycles( + cluster.subgraph, subgraph_name, subgraph_ancestors_, + node_ancestors_map, &cycle_detector, &cycle_detector_map)); + } + } + } else { + Subgraph& subgraph = subgraphs_[subgraph_name]; + // The node is in an outside_compilation cluster. Record the cluster + // and/or subgraph ancestor dependencies. + for (const auto& cluster : node_ancestors_map[node].ancestor_clusters) { + if (cluster.subgraph == subgraph_name) { + // The ancestor is in the same subgraph. + if (cluster.outside_compilation_cluster != oc_cluster) { + // But not in the same oc_cluster, so record the dependency. + subgraph.RecordOutsideCompilationDependency( + oc_cluster, cluster.outside_compilation_cluster); + TF_RETURN_IF_ERROR(CheckClusterDependencyForCycles( + cluster.outside_compilation_cluster, oc_cluster, + subgraph.OutsideCompilationAncestorMap(), node_ancestors_map, + &cycle_detector, &cycle_detector_map)); + } + } else { + // The ancestor is in a different subgraph, so record the + // dependency. + subgraph_ancestors_[subgraph_name].insert(cluster.subgraph); + TF_RETURN_IF_ERROR(CheckClusterDependencyForCycles( + cluster.subgraph, subgraph_name, subgraph_ancestors_, + node_ancestors_map, &cycle_detector, &cycle_detector_map)); + } + } + } + } + } return Status::OK(); } @@ -1861,7 +2269,7 @@ Status Encapsulator::MakePrunedGraphCopyAndInline( Status Encapsulator::MakeGraphForOutsideCompilationSends( const Graph& graph, std::unique_ptr* pruned_graph, - ShapeRefiner* shape_refiner, + BackEdgeHelper* back_edge_helper, ShapeRefiner* shape_refiner, std::unordered_map* node_images, FunctionLibraryDefinition* library) { // Find all the send_from_host nodes in all subgraphs, to use as roots for the @@ -1883,10 +2291,15 @@ Status Encapsulator::MakeGraphForOutsideCompilationSends( // nodes, inlining any functions as needed. TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline( graph, send_from_host_nodes, pruned_graph, node_images, library)); + FixupSourceAndSinkEdges(pruned_graph->get()); + + // Remove back edges from any cycles in the pruned graph to simplify shape + // inference traversal. They will be fixed up in the per-subgraph shape + // inference graphs stored in the function library. + TF_RETURN_IF_ERROR(back_edge_helper->Remove(pruned_graph->get())); // Perform shape inference on the pruned graph. shape_refiner->set_require_shape_inference_fns(false); - FixupSourceAndSinkEdges(pruned_graph->get()); std::vector post_order; GetReversePostOrder(*(*pruned_graph), &post_order); for (auto node : post_order) { @@ -1904,20 +2317,28 @@ Status Encapsulator::MakeGraphForOutsideCompilationSends( Status Encapsulator::GetShapeInfoForOutsideCompilationSends( Graph* graph_out, FunctionLibraryDefinition* library) { + BackEdgeHelper back_edge_helper; std::unique_ptr pruned_graph; ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry()); std::unordered_map node_images; TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends( - *graph_out, &pruned_graph, &shape_refiner, &node_images, library)); + *graph_out, &pruned_graph, &back_edge_helper, &shape_refiner, + &node_images, library)); + + if (VLOG_IS_ON(1)) { + dump_graph::DumpGraphToFile("pruned_graph_for_shape_inference", + *pruned_graph, library); + } for (auto& subgraph_entry : subgraphs_) { + const string& subgraph_name = subgraph_entry.first; Subgraph& subgraph = subgraph_entry.second; // Find all the recv_at_host nodes in this subgraph. std::vector outside_compilation_names; subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names); std::unordered_set recv_at_host_names; - for (const auto& name : outside_compilation_names) { - Node* recv_node = subgraph.GetRecvAtHostNode(name); + for (const auto& oc_name : outside_compilation_names) { + Node* recv_node = subgraph.GetRecvAtHostNode(oc_name); if (recv_node != nullptr) { recv_at_host_names.insert(recv_node->name()); } @@ -1926,26 +2347,30 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends( // without knowing the shape of the recv_at_host nodes, and store the // result, along with enough information to complete the job at compile time // once the recv_at_host shapes are known. - for (const auto& name : outside_compilation_names) { - Node* send_node = subgraph.GetSendFromHostNode(name); + for (const auto& oc_name : outside_compilation_names) { + Node* send_node = subgraph.GetSendFromHostNode(oc_name); std::vector static_shape; - std::unique_ptr graphdef; + std::unique_ptr graph; if (send_node != nullptr) { TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend( - *pruned_graph, shape_refiner, recv_at_host_names, - node_images[send_node], library, &static_shape, &graphdef)); - if (graphdef == nullptr) { + *pruned_graph, back_edge_helper, shape_refiner, recv_at_host_names, + node_images[send_node], library, &static_shape, &graph)); + if (graph == nullptr) { VLOG(2) << "Send node " << send_node->name() << " shapes"; for (int i = 0; i < static_shape.size(); ++i) { VLOG(2) << static_shape[i].DebugString(); } } else { - VLOG(2) << "Send node " << send_node->name() << " graph\n" - << graphdef->DebugString(); + if (VLOG_IS_ON(2)) { + GraphDef graphdef; + graph->ToGraphDef(&graphdef); + VLOG(2) << "Send node " << send_node->name() << " graph\n" + << graphdef.DebugString(); + } } } - TF_RETURN_IF_ERROR( - subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get())); + TF_RETURN_IF_ERROR(subgraph.AddShapeInferenceInfo( + subgraph_name, oc_name, static_shape, graph.get(), library)); } if (!outside_compilation_names.empty()) { TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library)); @@ -1955,18 +2380,15 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends( return Status::OK(); } -Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out, +Status Encapsulator::BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library) { // Map from nodes in the input graph to nodes in the output graph. std::unordered_map node_images; - TF_RETURN_IF_ERROR( - CopyNodesToOutputGraph(parallel_checking, graph_out, &node_images)); - TF_RETURN_IF_ERROR( - AddFunctionCallNodes(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &node_images)); + TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out)); TF_RETURN_IF_ERROR(AddOutsideCompilationHostIONodes(node_images, graph_out)); - TF_RETURN_IF_ERROR( - AddEdgesToOutputGraph(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR(AddEdgesToOutputGraph(node_images, graph_out)); TF_RETURN_IF_ERROR( GetShapeInfoForOutsideCompilationSends(graph_out, library)); @@ -1979,13 +2401,14 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out, Status EncapsulateSubgraphsInFunctions( string group_attribute, string outside_compilation_attribute, const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, - bool parallel_checking, bool reuse_existing_functions, - std::unique_ptr* graph_out, FunctionLibraryDefinition* library) { + bool reuse_existing_functions, std::unique_ptr* graph_out, + FunctionLibraryDefinition* library) { Status s; Encapsulator encapsulator(std::move(group_attribute), std::move(outside_compilation_attribute), &graph_in); + TF_RETURN_IF_ERROR(encapsulator.FindClusterDependencies()); TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs()); TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs( @@ -1993,8 +2416,7 @@ Status EncapsulateSubgraphsInFunctions( std::unique_ptr out(new Graph(library)); out->set_versions(graph_in.versions()); - TF_RETURN_IF_ERROR( - encapsulator.BuildOutputGraph(parallel_checking, out.get(), library)); + TF_RETURN_IF_ERROR(encapsulator.BuildOutputGraph(out.get(), library)); *graph_out = std::move(out); return Status::OK(); @@ -2035,8 +2457,6 @@ static Status RenumberArguments(Graph* graph, Status EncapsulateSubgraphsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "EncapsulateSubgraphsPass::Run"; - legacy_flags::EncapsulateSubgraphsPassFlags* flags = - legacy_flags::GetEncapsulateSubgraphsPassFlags(); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph, options.flib_def); @@ -2113,7 +2533,7 @@ Status EncapsulateSubgraphsPass::Run( TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, - rewrite_subgraph, flags->tf_xla_parallel_checking, + rewrite_subgraph, /*reuse_existing_functions=*/false, &graph_out, library)); if (VLOG_IS_ON(1)) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 34be4409a381197d2191e083727aa8d48ab8cd63..e5dab7c657c79afa861b0443314d0c7801e4b66d 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -61,10 +61,6 @@ typedef std::function* graph_out, FunctionLibraryDefinition* library); + bool reuse_existing_functions, std::unique_ptr* graph_out, + FunctionLibraryDefinition* library); // The attribute that marks function calls produced by the encapsulate -// subgraphs pass and that should in turn be compiled via _XlaLaunch operators. +// subgraphs pass and that should in turn be compiled via XlaLaunch operators. extern const char* const kXlaCompiledKernelAttr; // Does `node` have the kXlaCompiledKernelAttr attribute? diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index aed9cae0f1799c4524da8ee309344849798755d5..6a7cd932e53cc0428850ab048cc325e84ef1fce6 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" @@ -20,15 +21,38 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { namespace { +const char* const kXlaHostTransferSequencerAttr = + "_xla_host_transfer_sequencer"; + +Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder, + const string& name_suffix, + FunctionDefLibrary* library) { + GraphDef graphdef; + TF_RETURN_IF_ERROR(graphdef_builder.ToGraphDef(&graphdef)); + std::unique_ptr graph = + std::unique_ptr(new Graph(OpRegistry::Global())); + GraphConstructorOptions opts; + opts.allow_internal_ops = true; + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graphdef, graph.get())); + FunctionDef* fdef = library->add_function(); + TF_RETURN_IF_ERROR(GraphToFunctionDef( + *graph, + strings::StrCat("_outside_compilation_shape_inference_", name_suffix), + fdef)); + return Status::OK(); +} + template bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, const ::tensorflow::protobuf::Map& b, @@ -50,7 +74,7 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, if (!compare(elt_a.first, elt_a.second, iter->second)) { if (diff) { *diff = strings::StrCat(map_name, " expected: element with key '", - key_to_string(elt_a.first), " has value '", + key_to_string(elt_a.first), "' has value '", value_to_string(elt_a.second), "' got: '", value_to_string(iter->second), "'"); } @@ -97,8 +121,22 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, } return false; } + std::unordered_set control_input_a; + std::unordered_set control_input_b; for (int i = 0; i < a.input_size(); ++i) { - if (a.input(i) != b.input(i)) { + if (str_util::StartsWith(a.input(i), "^")) { + if (!str_util::StartsWith(b.input(i), "^")) { + if (diff) { + *diff = strings::StrCat( + diff_preamble, " mismatch for node ", a.name(), " input ", i, + ", expected control input ", a.input(i), " got ", b.input(i), + " expected:\n", a.DebugString(), "\ngot:\n", b.DebugString()); + } + return false; + } + control_input_a.insert(a.input(i)); + control_input_b.insert(b.input(i)); + } else if (a.input(i) != b.input(i)) { if (diff) { *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), " input ", i, ", expected ", a.input(i), @@ -108,24 +146,26 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, return false; } } + if (control_input_a != control_input_b) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + " control inputs differ expected:\n", + a.DebugString(), "\ngot:\n", b.DebugString()); + } + return false; + } return EqualProtoMap( a.attr(), b.attr(), [](const string& s) { return s; }, [](const AttrValue& v) { return v.DebugString(); }, [](const string& key, const AttrValue& av, const AttrValue& bv) { - if (key == "shape_inference_graph") { - // Default serialization of GraphDef is unstable because maps don't - // serialize deterministically. Rather than go through the hoops to - // turn on deterministic serialization of this attr just for this - // test, add logic here to compare determinstically. - GraphDef ga; - if (!ga.ParseFromString(av.s())) { - return false; - } - GraphDef gb; - if (!gb.ParseFromString(bv.s())) { - return false; - } - return EqualGraphDef(ga, gb, nullptr); + if (key == "ancestors") { + // The ancestors are added from a set so the order is unpredictable; + // just compare set equality not list equality. + std::unordered_set a_set(av.list().s().begin(), + av.list().s().end()); + std::unordered_set b_set(bv.list().s().begin(), + bv.list().s().end()); + return a_set == b_set; } else { return av.DebugString() == bv.DebugString(); } @@ -246,26 +286,33 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, << diff << "\nActual: " << actual.DebugString(); \ } while (false) -// TODO(misard): remove these fake registrations once there are real Ops to be -// compiled. -REGISTER_OP("_XlaHostCompute") +// These dummy Op registrations are here because the real Op registrations live +// in contrib and there can't be a dependence from this test to contrib. +REGISTER_OP("XlaHostCompute") .Input("inputs: Tinputs") .Output("outputs: Toutputs") .Attr("Tinputs: list(type) >= 0") .Attr("Toutputs: list(type) >= 0") + .Attr("ancestors: list(string) >= 0") .Attr("key: string") + .Attr("shape_inference_graph: string = ''") + .Attr("shapes: list(shape) >= 0") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("_XlaSendFromHost") - .Input("input: Tinputs") + .Input("inputs: Tinputs") + .Input("dynamic_key: string") .Attr("Tinputs: list(type) >= 0") .Attr("key: string") + .Attr("device_ordinal: int") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("_XlaRecvAtHost") - .Output("output: Toutputs") + .Input("dynamic_key: string") + .Output("outputs: Toutputs") .Attr("Toutputs: list(type) >= 0") .Attr("key: string") + .Attr("device_ordinal: int") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("InputTest") @@ -315,8 +362,13 @@ REGISTER_OP("AddNLikeTest") .SetIsCommutative() .SetIsAggregate(); -Node* NoOp(const GraphDefBuilder::Options& opts) { - return ops::SourceOp("NoOp", opts); +Node* Sequencer(const GraphDefBuilder::Options& opts, + const string& call_node_name) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("NoOp"), "NoOp", + opts.op_registry()); + return opts.WithAttr(kXlaHostTransferSequencerAttr, call_node_name) + .FinalizeBuilder(&node_builder); } Node* Input(const GraphDefBuilder::Options& opts) { @@ -327,43 +379,85 @@ Node* InputShaped(const GraphDefBuilder::Options& opts) { return ops::SourceOp("InputTestShaped", opts); } -Node* KnownShape(const gtl::ArraySlice& shape, - const GraphDefBuilder::Options& opts) { +Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice& shape, + const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const", opts.op_registry()); TensorProto value; - value.set_dtype(DT_FLOAT); + value.set_dtype(dtype); for (int dim : shape) { value.mutable_tensor_shape()->add_dim()->set_size(dim); } return opts.WithAttr("value", value) - .WithAttr("dtype", DT_FLOAT) + .WithAttr("dtype", dtype) + .FinalizeBuilder(&node_builder); +} + +Node* KnownShape(const gtl::ArraySlice& shape, + const GraphDefBuilder::Options& opts) { + return KnownShapeBase(DT_FLOAT, shape, opts); +} + +Node* KeyPlaceholderShape(const GraphDefBuilder::Options& opts) { + return KnownShapeBase(DT_STRING, {2}, opts); +} + +Node* KeyPlaceholder(const string& call_node, + const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("Placeholder"), "Placeholder", + opts.op_registry()); + TensorShapeProto shape; + shape.add_dim()->set_size(2); + return opts.WithAttr("shape", shape) + .WithAttr("dtype", DT_STRING) + .WithAttr("_host_compute_call_node", call_node) .FinalizeBuilder(&node_builder); } -Node* RecvAtHost(const string& key, const gtl::ArraySlice& dtypes, +Node* RecvAtHost(ops::NodeOut key_input, const string& cluster, + const string& oc_cluster, + const gtl::ArraySlice& dtypes, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"), + string key = + strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster); + string name = strings::StrCat("outside_compilation_", cluster, "_", + oc_cluster, "_recv"); + NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"), "_XlaRecvAtHost", opts.op_registry()); + node_builder.Input(std::move(key_input)); return opts.WithAttr("Toutputs", dtypes) .WithAttr("key", key) + .WithAttr("device_ordinal", 0) + .WithAttr("_encapsulate", cluster) + .WithAttr("_outside", oc_cluster) .FinalizeBuilder(&node_builder); } -Node* SendFromHost(const string& key, const std::vector& inputs, +Node* SendFromHost(ops::NodeOut key_input, const string& cluster, + const string& oc_cluster, + const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"), + string key = + strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster); + string name = strings::StrCat("outside_compilation_", cluster, "_", + oc_cluster, "_send"); + NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"), "_XlaSendFromHost", opts.op_registry()); node_builder.Input(inputs); + node_builder.Input(std::move(key_input)); std::vector dtypes; for (const auto& node : inputs) { dtypes.push_back(node.dt); } - return opts.WithAttr("key", key) - .WithAttr("Tinputs", dtypes) + return opts.WithAttr("Tinputs", dtypes) + .WithAttr("key", key) + .WithAttr("device_ordinal", 0) + .WithAttr("_encapsulate", cluster) + .WithAttr("_outside", oc_cluster) .FinalizeBuilder(&node_builder); } @@ -417,7 +511,6 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { std::unique_ptr graph_out; s = EncapsulateSubgraphsInFunctions("_encapsulate", "_outside", *graph, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph_out, lib_def.get()); if (!s.ok()) return s; @@ -466,8 +559,9 @@ TEST(EncapsulateSubgraphsTest, OneFunction) { Node* b = Input(b1.opts().WithName("B")); // Give nodes 'c' and 'd' names that collide after lowercasing. Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); - Node* d = Binary(b, c, b1.opts().WithName("c").WithControlInput(c).WithAttr( - "_encapsulate", "F1")); + Node* d = Binary(b, c, + b1.opts().WithName("c").WithControlInput(c).WithAttr( + "_encapsulate", "F1")); Binary(a, d, b1.opts().WithName("E")); TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } @@ -520,8 +614,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctions) { Node* c = Unary(a, b1.opts().WithName("C").WithControlInput(control).WithAttr( "_encapsulate", "F1")); - Node* d = - Binary(b, c, b1.opts().WithName("D").WithControlInput(control).WithAttr( + Node* d = Binary(b, c, + b1.opts().WithName("D").WithControlInput(control).WithAttr( "_encapsulate", "F2")); Binary(a, d, b1.opts().WithName("E")); TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); @@ -613,7 +707,7 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_cluster", "_outside", graph_before_encapsulation, - /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false, + /*rewrite_subgraph_fn=*/{}, /*reuse_existing_functions=*/false, &graph, &library)); std::vector expected_nodes = {"cluster1", "cluster2", "mul", "x"}; @@ -627,47 +721,6 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { EXPECT_EQ(expected_edges, GraphEdges(*graph)); } -TEST(EncapsulateSubgraphsTest, ParallelChecking) { - Scope root = Scope::NewRootScope().ExitOnError().WithDevice( - "/job:localhost/replica:0/task:0/cpu:0"); - auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); - auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT); - auto add1 = ops::Add(root.WithOpName("add1"), x1, x2); - add1.node()->AddAttr("_cluster", "cluster1"); - auto add2 = ops::Add(root.WithOpName("add2"), add1, x2); - add2.node()->AddAttr("_cluster", "cluster1"); - auto out = ops::Mul(root.WithOpName("mul"), x1, add2); - - Graph graph_before_encapsulation(OpRegistry::Global()); - TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation)); - - FunctionLibraryDefinition library(OpRegistry::Global(), {}); - std::unique_ptr graph; - TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_cluster", "_outside", graph_before_encapsulation, - /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/true, - /*reuse_existing_functions=*/false, &graph, &library)); - - std::vector expected_nodes = { - "add1", "add2", "cluster1", "cluster1_parallel_check/_0", - "mul", "x1", "x2"}; - EXPECT_EQ(expected_nodes, GraphNodes(*graph)); - - std::vector> expected_edges = { - {"add1:0", "add2:0"}, - {"add2:0", "cluster1_parallel_check/_0:0"}, - {"cluster1:0", "cluster1_parallel_check/_0:1"}, - {"cluster1_parallel_check/_0:0", "mul:1"}, - {"x1:0", "add1:0"}, - {"x1:0", "cluster1:0"}, - {"x1:0", "mul:0"}, - {"x2:0", "add1:1"}, - {"x2:0", "add2:1"}, - {"x2:0", "cluster1:1"}, - }; - EXPECT_EQ(expected_edges, GraphEdges(*graph)); -} - const Node* FindNodeByName(const Graph& graph, const string& name) { for (const Node* node : graph.nodes()) { if (node->name() == name) return node; @@ -711,7 +764,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { Graph* graph = graph_ptr->get(); for (const Node* n : graph->nodes()) { if (n->type_string() == "_Arg" && - StringPiece(n->name()).starts_with("const")) { + str_util::StartsWith(n->name(), "const")) { ++guaranteed_consts; EXPECT_TRUE(HasGuaranteeConstAttr(*n)); } else { @@ -720,7 +773,6 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { } return Status::OK(); }, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph_after, &library)); EXPECT_EQ(2, guaranteed_consts); } @@ -756,7 +808,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { Graph* graph = graph_ptr->get(); for (const Node* n : graph->nodes()) { if (n->type_string() == "_Arg" && - StringPiece(n->name()).starts_with("const")) { + str_util::StartsWith(n->name(), "const")) { ++guaranteed_consts; EXPECT_TRUE(HasGuaranteeConstAttr(*n)); } else { @@ -765,7 +817,6 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { } return Status::OK(); }, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph_after, &library)); // Only 1 runtime const, which is const_guarantee_add1. Add2 has one const // and another non-const, so overall non-const. @@ -806,19 +857,20 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; - string shape_string_expected; { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); - Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, - shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant = + KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, shape.opts()); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), - shape.opts().WithName("E")); - SendFromHost("host_compute_channel_F1_O1", {e}, - shape.opts().WithName("outside_compilation_F1_O1_send")); - GraphDef shape_graph; - TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); - EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + shape.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } *library_expected.add_function() = test::function::XTimesTwo(); @@ -833,13 +885,16 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"C:o:0", "c:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", shape_string_expected}, - {"shapes", gtl::ArraySlice({})}}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -851,24 +906,29 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - NodeBuilder node_builder("F1", "F1", lib_def.get()); - node_builder.Input(a).Input(b); - Node* call = b2.opts().FinalizeBuilder(&node_builder); - - Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, b2.opts()); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), - b2.opts().WithName("E").WithControlInputs({recv, b})); - Node* send = SendFromHost("host_compute_channel_F1_O1", {e}, - b2.opts() - .WithName("outside_compilation_F1_O1_send") - .WithControlInput(e)); + b2.opts() + .WithName("E") + .WithControlInputs({recv, b}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithControlInput(e)); + + Node* s = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); - Node* s = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send})); + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(a).Input(b); + Node* call = + b2.opts().WithControlInputs({s}).FinalizeBuilder(&node_builder); - Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e})); + Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -918,38 +978,43 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; - string shape_string_expected_1; { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); - Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, - shape1.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant = + KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, shape1.opts()); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), - shape1.opts().WithName("E")); - SendFromHost("host_compute_channel_F1_O1", {e}, - shape1.opts().WithName("outside_compilation_F1_O1_send")); - GraphDef shape1_graph; - TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph)); - EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1)); + shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } - string shape_string_expected_2; { GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); - Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, - shape2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant = + KeyPlaceholderShape(shape2.opts().WithName("KnownShape/_0")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, shape2.opts()); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), - shape2.opts().WithName("E")); - Node* recv2 = - RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT}, - shape2.opts().WithName("outside_compilation_F1_O2_recv")); - Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H")); - SendFromHost("host_compute_channel_F1_O2", {h}, - shape2.opts().WithName("outside_compilation_F1_O2_send")); - GraphDef shape2_graph; - TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph)); - EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2)); + shape2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", + {DT_FLOAT, DT_FLOAT}, shape2.opts()); + Node* h = Binary(ops::NodeOut(recv2, 1), e, + shape2.opts() + .WithName("H") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, shape2.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected)); } *library_expected.add_function() = FunctionDefHelper::Create( @@ -966,22 +1031,29 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O2_host_compute"}, - "_XlaHostCompute", - {"D:o:0", "F:o:0"}, + "XlaHostCompute", + {"F:o:0", "D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", + gtl::ArraySlice({"outside_compilation_O1_host_compute"})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", shape_string_expected_2}, - {"shapes", gtl::ArraySlice({})}}, - {"F"}}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O2"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O2"}}, + {"F", "outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"C:o:0", "D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", shape_string_expected_1}, - {"shapes", gtl::ArraySlice({})}}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, {{"i_0_retval", "I:o:0"}}); @@ -993,35 +1065,45 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - NodeBuilder node_builder("F1", "F1", lib_def.get()); - node_builder.Input(a).Input(b); - Node* call = b2.opts().FinalizeBuilder(&node_builder); - - Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, b2.opts()); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), - b2.opts().WithName("E").WithControlInputs({recv1, b})); - Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e}, - b2.opts() - .WithName("outside_compilation_F1_O1_send") - .WithControlInput(e)); - - Node* recv2 = - RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O2_recv")); - Node* g = Binary(e, ops::NodeOut(recv2, 1), - b2.opts().WithName("G").WithControlInputs({recv2, e})); - Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H")); + b2.opts() + .WithName("E") + .WithControlInputs({recv1, b}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithControlInput(e)); + + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", + {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* g = Binary(e, ops::NodeOut(recv2, 0), + b2.opts() + .WithName("G") + .WithControlInputs({recv2, e}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); + Node* h = Binary(ops::NodeOut(recv2, 1), e, + b2.opts() + .WithName("H") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); Node* send2 = - SendFromHost("host_compute_channel_F1_O2", {h}, - b2.opts().WithName("outside_compilation_F1_O2_send")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, b2.opts()); + + Node* s = Sequencer(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, send1, recv2, send2}), + "F1"); - Node* s = NoOp(b2.opts() - .WithName("F1_sequencer") - .WithControlInputs({recv1, send1, recv2, send2})); + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(a).Input(b); + Node* call = b2.opts().WithControlInput(s).FinalizeBuilder(&node_builder); - Binary(g, call, b2.opts().WithName("J").WithControlInput(s)); + Binary(g, call, b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1070,19 +1152,20 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; - string shape_string_expected; { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); - Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, - shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant = + KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, shape.opts()); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), - shape.opts().WithName("E")); - SendFromHost("host_compute_channel_F1_O1", {e}, - shape.opts().WithName("outside_compilation_F1_O1_send")); - GraphDef shape_graph; - TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); - EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + shape.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } TensorShapeProto shape_proto_expected; @@ -1100,13 +1183,16 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"C:o:0", "D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", shape_string_expected}, - {"shapes", gtl::ArraySlice({})}}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, {{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}}); @@ -1120,14 +1206,16 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { "BinaryTest", {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"G:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F2_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}}}, + gtl::ArraySlice({shape_proto_expected})}, + {"_outside_compilation_subgraph", "O1"}}}, }, {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}}); @@ -1138,39 +1226,221 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { Node* a = InputShaped(b2.opts().WithName("A")); Node* b = InputShaped(b2.opts().WithName("B")); - Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant1 = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, b2.opts()); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), - b2.opts().WithName("E").WithControlInputs({recv1, b})); - Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e}, - b2.opts() - .WithName("outside_compilation_F1_O1_send") - .WithControlInput(e)); + b2.opts() + .WithName("E") + .WithControlInputs({recv1, b}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e}, + b2.opts().WithControlInput(e)); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Node* recv2 = - RecvAtHost("host_compute_channel_F2_O1", {DT_FLOAT}, - b2.opts().WithName("outside_compilation_F2_O1_recv")); + Node* key_constant2 = + KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder")); + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "O1", + {DT_FLOAT}, b2.opts()); Node* h = Binary(ops::NodeOut(call1, 1), recv2, - b2.opts().WithName("H").WithControlInput(s1)); - Node* send2 = - SendFromHost("host_compute_channel_F2_O1", {h}, - b2.opts().WithName("outside_compilation_F2_O1_send")); + b2.opts() + .WithName("H") + .WithAttr("_encapsulate", "F2") + .WithAttr("_outside", "O1")); + Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, + b2.opts()); + Node* s2 = Sequencer( + b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}), + "F2"); NodeBuilder node_builder2("F2", "F2", lib_def.get()); node_builder2.Input(e).Input(call1); Node* call2 = b2.opts() - .WithControlInputs({s1, e, call1}) + .WithControlInputs({s2, e, call1}) .FinalizeBuilder(&node_builder2); - Node* s2 = NoOp( - b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2})); - Binary(call2, ops::NodeOut(call2, 1), - b2.opts().WithName("J").WithControlInput(s2)); + Binary(call2, ops::NodeOut(call2, 1), b2.opts().WithName("J")); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with two functions to transform, each with one outside_compilation +// cluster, with the dependency between them purely from an outside_compilation +// edge. +TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = InputShaped(b1.opts().WithName("A")); + Node* b = InputShaped(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Binary(c, d, + b1.opts() + .WithName("E") + .WithControlInputs({b, d}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Binary(c, e, + b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Node* g = + Binary(a, b, b1.opts().WithName("G").WithAttr("_encapsulate", "F2")); + Node* h = Unary(g, b1.opts() + .WithName("H") + .WithAttr("_encapsulate", "F2") + .WithAttr("_outside", "O1") + .WithControlInput(e)); + Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2")); + Binary(f, i, b1.opts().WithName("J")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + { + GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* key_constant = + KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, shape.opts()); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + shape.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); + } + + { + GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* key_constant = + KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F2", "O1", + {DT_FLOAT}, shape.opts()); + Node* h = Unary(recv, shape.opts() + .WithName("H") + .WithAttr("_encapsulate", "F2") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F2", "O1", {h}, shape.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape, "F2_O1", &library_expected)); + } + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, + "BinaryTest", + {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, + {}, + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"C:o:0", "D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}, + {"D"}}, + }, + {{"f_0_retval", "F:o:0"}}); + + *library_expected.add_function() = FunctionDefHelper::Create( + "F2", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {}, + { + {{"G"}, "BinaryTest", {"a_0_arg", "b_0_arg"}}, + {{"I"}, + "UnaryTest", + {"outside_compilation_O1_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"G:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F2_O1"}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F2_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}}, + }, + {{"i_0_retval", "I:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = InputShaped(b2.opts().WithName("A")); + Node* b = InputShaped(b2.opts().WithName("B")); + + Node* key_constant1 = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), + b2.opts() + .WithName("E") + .WithControlInputs({recv1, b}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e}, + b2.opts().WithControlInput(e)); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); + + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); + + Node* key_constant2 = + KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder")); + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "O1", + {DT_FLOAT}, b2.opts()); + Node* h = Unary(recv2, b2.opts() + .WithName("H") + .WithAttr("_encapsulate", "F2") + .WithAttr("_outside", "O1") + .WithControlInput(e)); + Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, + b2.opts()); + + Node* s2 = Sequencer( + b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}), + "F2"); + NodeBuilder node_builder2("F2", "F2", lib_def.get()); + node_builder2.Input(a).Input(b); + Node* call2 = b2.opts() + .WithControlInputs({s2, call1}) + .FinalizeBuilder(&node_builder2); + Binary(call1, call2, b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1218,14 +1488,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { "BinaryTest", {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {}, {{"Tinputs", gtl::ArraySlice({})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}}}, + gtl::ArraySlice({shape_proto_expected})}, + {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1236,16 +1508,22 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* e = Unary(a, b2.opts().WithName("E")); + Node* e = Unary(a, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); Node* send1 = - SendFromHost("host_compute_channel_F1_O1", {e}, - b2.opts().WithName("outside_compilation_F1_O1_send")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInput(send1), "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(send1)); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Unary(call1, b2.opts().WithName("G").WithControlInput(s1)); + Unary(call1, b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1294,14 +1572,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { "BinaryTest", {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {}, {{"Tinputs", gtl::ArraySlice({})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}}, + gtl::ArraySlice({shape_proto_expected})}, + {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1313,20 +1593,26 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); - Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1)); + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, b2.opts()); + Node* e = Unary(a, b2.opts() + .WithName("E") + .WithControlInput(recv1) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send1 = - SendFromHost("host_compute_channel_F1_O1", {e}, - b2.opts().WithName("outside_compilation_F1_O1_send")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Unary(call1, b2.opts().WithName("G").WithControlInput(s1)); + Unary(call1, b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1368,13 +1654,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"D:o:0"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}}}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1385,16 +1673,22 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); - Node* e = Unary(recv1, b2.opts().WithName("E")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInput(recv1), "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(recv1)); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1)); + Binary(e, call1, b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1441,13 +1735,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}}}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1458,21 +1754,390 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); - Node* e = Unary(recv1, b2.opts().WithName("E")); - Node* send1 = SendFromHost("host_compute_channel_F1_O1", {}, - b2.opts() - .WithName("outside_compilation_F1_O1_send") - .WithControlInput(e)); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, + b2.opts().WithControlInput(e)); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); + + Binary(e, call1, b2.opts().WithName("G")); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with two outside_compilation clusters that interact outside the compiled +// subgraph, where the ancestor has no HostCompute Op. +TEST(EncapsulateSubgraphsTest, + OutsideCompilationClusterDependencyNoSrcCluster) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(a, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + Node* g = Unary(f, b1.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* h = Unary(g, b1.opts().WithName("H").WithAttr("_encapsulate", "F1")); + Binary(e, h, b1.opts().WithName("I")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + { + GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); + Node* key_constant = + KeyPlaceholderShape(shape2.opts().WithName("KnownShape/_0")); + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", + {DT_FLOAT}, shape2.opts()); + Node* g = Unary(ops::NodeOut(recv2, 0), shape2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, shape2.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected)); + } + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, "UnaryTest", {"D:o:0"}}, + {{"H"}, + "UnaryTest", + {"outside_compilation_O2_host_compute:outputs:0"}}, + {{"outside_compilation_O2_host_compute"}, + "XlaHostCompute", + {"F:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F1_O2"}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O2"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O2"}}}, + }, + {{"h_0_retval", "H:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* e = Unary(a, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", + {DT_FLOAT}, b2.opts()); + Node* g = Unary(recv, b2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* send = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, b2.opts()); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b).ControlInput(s1); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); - Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1)); + Binary(e, call1, b2.opts().WithName("I")); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with two outside_compilation clusters that interact outside the compiled +// subgraph, where the successor has no HostCompute Op. +TEST(EncapsulateSubgraphsTest, + OutsideCompilationClusterDependencyNoDstCluster) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(d, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Unary(e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + /*Node* g =*/Unary(a, b1.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* h = Unary(f, b1.opts().WithName("H").WithAttr("_encapsulate", "F1")); + Binary(e, h, b1.opts().WithName("I")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = + KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0")); + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, shape1.opts()); + Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, + "UnaryTest", + {"outside_compilation_O1_host_compute:outputs:0"}}, + {{"H"}, "UnaryTest", {"F:o:0"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}}, + }, + {{"h_0_retval", "H:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); + /*Node* g =*/Unary(a, b2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b).ControlInput(s1); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + + Binary(e, call1, b2.opts().WithName("I")); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with two outside_compilation clusters that interact outside the compiled +// subgraph. +TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(d, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Unary(e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + Node* g = Unary(d, b1.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* h = Unary(f, b1.opts().WithName("H").WithAttr("_encapsulate", "F1")); + /*Node* i =*/Binary(d, e, + b1.opts() + .WithName("I") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O3") + .WithControlInput(g)); + Binary(e, h, b1.opts().WithName("J")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = + KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0")); + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, shape1.opts()); + Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval:float"}, {}, + {{{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}}, + {{"H"}, "UnaryTest", {"F:o:0"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}}, + {{"outside_compilation_O2_host_compute"}, + "XlaHostCompute", + {"D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({})}, + {"ancestors", + gtl::ArraySlice({"outside_compilation_O1_host_compute"})}, + {"key", "host_compute_channel_F1_O2"}, + {"shape_inference_graph", ""}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O2"}}, + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O3_host_compute"}, + "XlaHostCompute", + {"D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({})}, + {"ancestors", + gtl::ArraySlice({"outside_compilation_O1_host_compute", + "outside_compilation_O2_host_compute"})}, + {"key", "host_compute_channel_F1_O3"}, + {"shape_inference_graph", ""}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O3"}}, + {"outside_compilation_O1_host_compute", + "outside_compilation_O2_host_compute"}}}, + {{"h_0_retval", "H:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", + {DT_FLOAT}, b2.opts()); + Node* g = Unary(recv2, b2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* recv3 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3", + {DT_FLOAT}, b2.opts()); + /*Node* i =*/Binary(recv3, e, + b2.opts() + .WithName("I") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O3") + .WithControlInput(g)); + Node* s1 = Sequencer(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, send, recv2, recv3}), + "F1"); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b).ControlInput(s1); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + + Binary(e, call1, b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1523,7 +2188,10 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* e = Unary(a, b2.opts().WithName("E")); + Node* e = Unary(a, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); @@ -1569,19 +2237,21 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; - string shape_string_expected; { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); - Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_0")); - Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, - shape.opts().WithName("outside_compilation_F1_O1_recv")); - Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E")); - SendFromHost("host_compute_channel_F1_O1", {e}, - shape.opts().WithName("outside_compilation_F1_O1_send")); - GraphDef shape_graph; - TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); - EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + Node* key_constant = + KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); + Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_1")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, shape.opts()); + Node* e = BinaryUnknownShape(known, recv, + shape.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } *library_expected.add_function() = test::function::XTimesTwo(); @@ -1595,13 +2265,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"c:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", shape_string_expected}, - {"shapes", gtl::ArraySlice({})}}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1614,26 +2287,29 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { Node* b = Input(b2.opts().WithName("B")); Node* c = Unary(a, b2.opts().WithName("C")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = BinaryUnknownShape(c, ops::NodeOut(recv, 0), + b2.opts() + .WithName("E") + .WithControlInputs({recv, b}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithControlInput(e)); + + Node* s = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); + NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(b).Input(c); Node* call = - b2.opts().WithControlInputs({c}).FinalizeBuilder(&node_builder); - - Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); - Node* e = BinaryUnknownShape( - c, ops::NodeOut(recv, 0), - b2.opts().WithName("E").WithControlInputs({recv, b})); - Node* send = SendFromHost("host_compute_channel_F1_O1", {e}, - b2.opts() - .WithName("outside_compilation_F1_O1_send") - .WithControlInput(e)); - - Node* s = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send})); - - Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e})); + b2.opts().WithControlInputs({s, c}).FinalizeBuilder(&node_builder); + + Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD index 15507b3851751c681044a744c07c247410fb3e2d..676f71a75aede2a7720ae0c8a579d64cc184509a 100644 --- a/tensorflow/compiler/jit/graphcycles/BUILD +++ b/tensorflow/compiler/jit/graphcycles/BUILD @@ -27,17 +27,3 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc index bc68afb322b5cfc814ce0537254ba14053ae4550..805bbc62c1e2e877de87ab8faf3d60b829743df8 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc @@ -354,6 +354,16 @@ bool GraphCycles::IsReachableNonConst(int32 x, int32 y) { return reachable; } +bool GraphCycles::CanContractEdge(int32 a, int32 b) { + CHECK(HasEdge(a, b)) << "No edge exists from " << a << " to " << b; + RemoveEdge(a, b); + bool reachable = IsReachableNonConst(a, b); + // Restore the graph to its original state. + InsertEdge(a, b); + // If reachable, then contracting edge will cause cycle. + return !reachable; +} + bool GraphCycles::ContractEdge(int32 a, int32 b) { CHECK(HasEdge(a, b)); RemoveEdge(a, b); @@ -388,4 +398,8 @@ std::unordered_set GraphCycles::Successors(int32 node) { return rep_->nodes_[node]->out; } +std::unordered_set GraphCycles::Predecessors(int32 node) { + return rep_->nodes_[node]->in; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.h b/tensorflow/compiler/jit/graphcycles/graphcycles.h index d11d6e27b1b7bb514127e16a9be21f044100d885..44448fa3d787d0785a797d40ed1b968438a903c9 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.h +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.h @@ -85,6 +85,9 @@ class GraphCycles { // and returns false. bool ContractEdge(int32 a, int32 b); + // Return true if can contract edge, otherwise return false. + bool CanContractEdge(int32 a, int32 b); + // Return whether dest_node is reachable from source_node // by following edges. bool IsReachable(int32 source_node, int32 dest_node) const; @@ -115,6 +118,7 @@ class GraphCycles { bool CheckInvariants() const; std::unordered_set Successors(int32 node); + std::unordered_set Predecessors(int32 node); // ---------------------------------------------------- struct Rep; diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc index e47b782207e9122740fe9d5daf1fa0dbaeb47754..274f5938a1228baf68ad4d8e1a7b13f276321d27 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc @@ -494,6 +494,20 @@ TEST_F(GraphCyclesTest, ContractEdge) { EXPECT_TRUE(g_.HasEdge(1, 4)); } +TEST_F(GraphCyclesTest, CanContractEdge) { + ASSERT_TRUE(AddEdge(1, 2)); + ASSERT_TRUE(AddEdge(1, 3)); + ASSERT_TRUE(AddEdge(2, 3)); + ASSERT_TRUE(AddEdge(2, 4)); + ASSERT_TRUE(AddEdge(3, 4)); + + EXPECT_FALSE(g_.CanContractEdge(1, 3)); + EXPECT_FALSE(g_.CanContractEdge(2, 4)); + EXPECT_TRUE(g_.CanContractEdge(1, 2)); + EXPECT_TRUE(g_.CanContractEdge(2, 3)); + EXPECT_TRUE(g_.CanContractEdge(3, 4)); +} + static void BM_StressTest(int iters, int num_nodes) { while (iters > 0) { tensorflow::GraphCycles g; diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 9bea5663319c8a25249fdc265cee0191556a7c04..00a6f4075f9a18efc3895b033eb6d08e36088a53 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -14,6 +14,7 @@ cc_library( "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:xla_compilation_cache", "//tensorflow/compiler/jit:xla_device", + "//tensorflow/compiler/jit:xla_launch_util", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", @@ -40,17 +41,3 @@ cc_library( ], alwayslink = 1, ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 6353149e4afdf739fe44dd5c76502ef5d98b8477..902fe27acdec1cb323217e6409fbd02f62177612 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -36,135 +37,28 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/util/stream_executor_util.h" -namespace gpu = perftools::gputools; - namespace tensorflow { -// Adapter class that wraps a Tensorflow allocator as an XLA allocator. -// Assumes that the Tensorflow allocator permits asynchronous deallocation: -// see comment on `AllowsAsynchronousDeallocation()`. -class XlaAllocator : public xla::DeviceMemoryAllocator { - public: - XlaAllocator(const gpu::Platform* platform, OpKernelContext* op_context); - ~XlaAllocator() override; - xla::StatusOr Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) override; - Status Deallocate(int device_ordinal, gpu::DeviceMemoryBase* mem) override; - - // Register an Tensor (input or resource variable) with the allocator. If - // the operation returns an alias to one of its inputs, then the allocator - // needs to be able to handle it. - Status RegisterArgument(const Tensor* t); - - // Makes 'tensor' a wrapper around the data buffer at 'ptr'. The buffer is - // interpreted as having data type 'dtype' and shape 'shape'. - Status MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, DataType dtype, - const TensorShape& shape, Tensor* tensor) const; - - // The Tensorflow BFC allocator used on GPU allows host-side deallocation - // before GPU execution takes place. Tensorflow uses the ordering of the main - // compute stream to enforce a happens-before relationship between a memory - // allocation and code that reuses the same memory. If Tensorflow adds - // support for multiple GPU streams or allocators with different ordering - // requirements, this code may need to change. - // (This attribute has no effect on CPU.) - bool AllowsAsynchronousDeallocation() const override { return true; } - - private: - OpKernelContext* const op_context_; - - // Map from pointer address to the owning Tensor; used by - // MakeTensorFromBuffer. Also used to automatically release Tensors when the - // allocator is freed. - std::unordered_map tensors_; -}; - -XlaAllocator::XlaAllocator(const gpu::Platform* platform, - OpKernelContext* op_context) - : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {} - -XlaAllocator::~XlaAllocator() = default; - -xla::StatusOr XlaAllocator::Allocate( - int device_ordinal, uint64 size, bool retry_on_failure) { - AllocatorAttributes allocator_attrs; - allocator_attrs.set_on_host(false); - - AllocationAttributes allocation_attrs; - allocation_attrs.no_retry_on_failure = !retry_on_failure; - - Tensor t; - Status status = op_context_->allocate_temp( - DT_UINT8, TensorShape({static_cast(size)}), &t, allocator_attrs, - allocation_attrs); - if (!status.ok()) { - VLOG(2) << "Allocation failed " << size; - return status; - } - void* data = - reinterpret_cast(const_cast(t.tensor_data().data())); - tensors_[data] = t; - return gpu::DeviceMemoryBase(data, size); -} - -Status XlaAllocator::RegisterArgument(const Tensor* t) { - void* data = - reinterpret_cast(const_cast(t->tensor_data().data())); - tensors_[data] = *t; - return Status::OK(); -} - -Status XlaAllocator::Deallocate(int device_ordinal, - gpu::DeviceMemoryBase* mem) { - if (mem->opaque() != nullptr) { - if (tensors_.erase(mem->opaque()) == 0) { - return tensorflow::errors::InvalidArgument("Unknown tensor address"); - } - } - return Status::OK(); -} - -Status XlaAllocator::MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, - DataType dtype, - const TensorShape& shape, - Tensor* out_tensor) const { - void* ptr = const_cast(buffer.opaque()); - auto it = tensors_.find(ptr); - if (it == tensors_.end()) { - return errors::InvalidArgument("Unknown tensor address"); - } - const Tensor& tensor = it->second; - - int64 output_size = DataTypeSize(dtype) * shape.num_elements(); - if (tensor.TotalBytes() == output_size) { - out_tensor->UnsafeCopyFromInternal(tensor, dtype, shape); - } else { - Tensor slice = tensor.Slice(0, output_size); - out_tensor->UnsafeCopyFromInternal(slice, dtype, shape); - } - return Status::OK(); -} - -XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) - : OpKernel(ctx), device_type_(ctx->device_type()) { - const NameAttrList* func; - OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func)); - function_ = *func; - DataTypeVector constant_types; - OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types)); - num_constant_args_ = constant_types.size(); - OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_)); +XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector& constants, + const std::vector& resources, + const NameAttrList& function) + : OpKernel(ctx), + constants_(constants), + resources_(resources), + device_type_(ctx->device_type()), + function_(function) { if (device_type_ == DeviceType(DEVICE_CPU)) { - platform_id_ = gpu::host::kHostPlatformId; + platform_id_ = se::host::kHostPlatformId; } else if (device_type_ == DeviceType(DEVICE_GPU)) { - platform_id_ = gpu::cuda::kCudaPlatformId; + platform_id_ = se::cuda::kCudaPlatformId; } else { platform_id_ = nullptr; } } -Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** cache) { +Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** cache) { const XlaDevice::Metadata* metadata; Status s = XlaDevice::GetMetadata(ctx, &metadata); if (s.ok()) { @@ -173,9 +67,9 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, return Status::OK(); } - auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id_); + auto platform = se::MultiPlatformManager::PlatformWithId(platform_id_); if (!platform.ok()) { - return StreamExecutorUtil::ConvertStatus(platform.status()); + return platform.status(); } xla::LocalClientOptions client_options; client_options.set_platform(platform.ValueOrDie()); @@ -196,32 +90,15 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, return Status::OK(); } -std::vector SnapshotResourceVariables(OpKernelContext* ctx, - int num_variables) { - std::vector snapshot(num_variables); - int first_variable = ctx->num_inputs() - num_variables; - for (int i = 0; i < num_variables; ++i) { - Var* variable = nullptr; - ResourceHandle handle = HandleFromInput(ctx, first_variable + i); - if (LookupResource(ctx, handle, &variable).ok()) { - tf_shared_lock lock(*variable->mu()); - snapshot[i].name = handle.name(); - snapshot[i].present = true; - snapshot[i].value = *variable->tensor(); - } - } - return snapshot; -} - -void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { - VLOG(1) << "XlaLocalLaunchOp::Compute " +void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { + VLOG(1) << "XlaLocalLaunchOpBase::Compute " << Canonicalize(function_.name(), AttrSlice(&function_.attr())); // We store information about the JIT-compiled XLA computation // in the ResourceMgr. ResourceMgr* rm = ctx->resource_manager(); OP_REQUIRES(ctx, rm, errors::Internal("No resource manager.")); - gpu::Stream* stream = + se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; XlaCompilationCache* cache; @@ -235,184 +112,153 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { // this is more obviously correct.) core::ScopedUnref cache_ref(cache); + const XlaDevice::Metadata* metadata = nullptr; + Status s = XlaDevice::GetMetadata(ctx, &metadata); + bool allocate_xla_tensors = s.ok(); + // Get the platform_id_ for XLA_* devices. if (platform_id_ == nullptr) { - const XlaDevice::Metadata* metadata; - Status s = XlaDevice::GetMetadata(ctx, &metadata); if (s.ok()) { platform_id_ = metadata->platform()->id(); } } - std::vector variables = - SnapshotResourceVariables(ctx, num_resource_args_); + std::map variables = + SnapshotResourceVariables(ctx, resources_); xla::LocalClient* client = static_cast(cache->client()); - // Builds an XLA allocator for the device. - XlaAllocator xla_allocator(client->platform(), ctx); + XlaAllocator local_xla_allocator(client->backend().platform(), + ctx->device()->GetAllocator({})); + xla::DeviceMemoryAllocator* xla_allocator; + // If we are on an XlaDevice, use the underlying XLA platform's allocator + // directly. We could use the StreamExecutor's allocator which may + // theoretically be more correct, but XLA returns a nice OOM message in a + // Status and StreamExecutor does not. + // + // Importantly we can't use ctx->device()->GetAllocator() as the allocator + // (which local_xla_allocator above uses) as on an XlaDevice, this is a + // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a + // real allocator to allocate real buffers. + if (allocate_xla_tensors) { + xla_allocator = client->backend().memory_allocator(); + } else { + xla_allocator = &local_xla_allocator; + } XlaCompiler::Options options; options.client = client; - options.device_type = &cache->device_type(); + options.device_type = cache->device_type(); options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); - options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId); - options.device_allocator = &xla_allocator; + options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); + options.device_allocator = xla_allocator; + if (metadata) { + options.shape_representation_fn = metadata->shape_representation_fn(); + } const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; - OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_, - variables, ctx, &kernel, &executable, - /*compile_options=*/nullptr)); + std::map constant_args; + for (int i : constants_) { + constant_args.insert({i, ctx->input(i)}); + } + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; + OP_REQUIRES_OK( + ctx, cache->Compile(options, function_, constant_args, variables, ctx, + &kernel, &executable, &compile_options)); VLOG(1) << "Executing XLA Computation..."; - std::unique_ptr output; - // Build xla::ShapedBuffers that point directly to the Tensor buffers. - std::vector> arg_buffers; - arg_buffers.reserve(kernel->xla_input_shapes.size() + 1); - arg_buffers.resize(kernel->xla_input_shapes.size()); - std::vector arg_ptrs(arg_buffers.size()); - - const int first_variable_arg = ctx->num_inputs() - num_resource_args_; - // Pass remaining parameters. - const Tensor* t; - for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { - int arg_num = kernel->input_mapping[i]; - const xla::Shape& shape = kernel->xla_input_shapes[i]; - if (arg_num >= first_variable_arg) { - t = &(variables[arg_num - first_variable_arg].value); - } else { - t = &(ctx->input(arg_num)); - } - - gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase( - const_cast(t->tensor_data().data()), t->tensor_data().size()); - - const xla::Shape on_device_shape = - client->backend().transfer_manager()->HostShapeToDeviceShape(shape); - CHECK(xla::ShapeUtil::Equal(shape, on_device_shape)) - << "On-device shape " - << xla::ShapeUtil::HumanStringWithLayout(on_device_shape) - << " not the same as on-host shape " - << xla::ShapeUtil::HumanStringWithLayout(shape); - arg_buffers[i] = xla::MakeUnique( - /*on_host_shape=*/shape, /*on_device_shape=*/shape, client->platform(), - client->default_device_ordinal()); - arg_buffers[i]->set_buffer(dmem, /*index=*/{}); - arg_ptrs[i] = arg_buffers[i].get(); - - OP_REQUIRES_OK(ctx, xla_allocator.RegisterArgument(t)); - } + XlaComputationLaunchContext launch_context(client, xla_allocator, + allocate_xla_tensors); + launch_context.PopulateInputs(ctx, kernel, variables); // Execute the computation. VLOG(2) << "Executing computation."; xla::ExecutableRunOptions run_options; run_options.set_stream(stream); - run_options.set_allocator(&xla_allocator); + run_options.set_allocator(xla_allocator); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + run_options.set_rng_seed(ctx->step_id()); Env* env = Env::Default(); auto start_time = env->NowMicros(); - auto run_result = executable->Run(arg_ptrs, run_options); + + auto run_result = executable->Run(launch_context.arguments(), run_options); OP_REQUIRES(ctx, run_result.ok(), run_result.status()); - output = run_result.ConsumeValueOrDie()->release(); auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; - // Computation output should always be a tuple. - if (VLOG_IS_ON(2)) { - VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString(); - } - CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); - - // Copy XLA results to the OpOutputList. - int output_num = 0; - for (int i = 0; i < ctx->num_outputs(); ++i) { - if (kernel->outputs[i].is_constant) { - // Output is a constant. - const Tensor& const_tensor = kernel->outputs[i].constant_value; - const size_t total_bytes = const_tensor.TotalBytes(); - if (stream && total_bytes > 0) { - // Copy host -> device. (Empty tensors don't have backing buffers.) - VLOG(1) << "Constant output tensor on device"; - Tensor* output_tensor; - TF_CHECK_OK( - ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); - - const void* src_ptr = DMAHelper::base(&const_tensor); - void* dst_ptr = DMAHelper::base(output_tensor); - gpu::DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes); - stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes); - } else { - // No copy required. - ctx->set_output(i, const_tensor); - } - } else { - const TensorShape& shape = kernel->outputs[i].shape; - VLOG(2) << "Retval " << i << " shape " << shape.DebugString(); - - gpu::DeviceMemoryBase buffer = output->buffer({output_num}); - Tensor output_tensor; - // Looks up the owning Tensor by buffer address. - OP_REQUIRES_OK(ctx, xla_allocator.MakeTensorFromBuffer( - buffer, ctx->expected_output_dtype(i), shape, - &output_tensor)); - ctx->set_output(i, output_tensor); - ++output_num; - } + launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie()); + VLOG(1) << "Done"; +} - if (VLOG_IS_ON(3)) { - VLOG(3) << ctx->mutable_output(i)->DebugString(); - } - } +namespace { + +// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that +// in error case, it returns RET instead of void. +#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + return RET; \ + } \ + } while (0) + +// Helper static functions to construct parameters for +// XlaLocalLaunchBase constructor from OpKernelConstruction. +std::vector ConstantsVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Tconstants", &constant_types)); + std::vector constants(constant_types.size()); + std::iota(constants.begin(), constants.end(), 0); + return constants; +} - // Apply variable updates, if any. - VLOG(2) << "Applying variable updates"; - for (int i = 0; i < kernel->resource_updates.size(); ++i) { - const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - OP_REQUIRES(ctx, - write.input_index >= 0 && write.input_index < ctx->num_inputs(), - errors::Internal("Invalid input index for variable write.")); - - gpu::DeviceMemoryBase buffer = output->buffer({output_num}); - - Var* variable = nullptr; - // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, not - // a Tensor. - OP_REQUIRES_OK(ctx, LookupOrCreateResource( - ctx, HandleFromInput(ctx, write.input_index), - &variable, [this, ctx, &write](Var** ptr) { - *ptr = new Var(write.type); - return Status::OK(); - })); - - core::ScopedUnref s(variable); - - mutex_lock ml(*variable->mu()); - OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type, - errors::Internal("Mismatched type in variable write")); - - // Looks up the owning Tensor by buffer address. - OP_REQUIRES_OK( - ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write.shape, - variable->tensor())); - ++output_num; - } +std::vector ResourcesVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Tconstants", &constant_types)); - VLOG(1) << "Done"; + DataTypeVector arg_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Targs", &arg_types)); + + int num_resources; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Nresources", &num_resources)); + + std::vector resources(num_resources); + std::iota(resources.begin(), resources.end(), + constant_types.size() + arg_types.size()); + return resources; } +NameAttrList FunctionAttr(OpKernelConstruction* ctx) { + const NameAttrList* func; + OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func)); + return *func; +} + +#undef OP_REQUIRES_OK_RETURN +} // namespace + +XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) + : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx), + FunctionAttr(ctx)) {} + XlaLocalLaunchOp::~XlaLocalLaunchOp() { VLOG(1) << "XlaLocalLaunchOp destroyed"; } -REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU), - XlaLocalLaunchOp); +REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); -REGISTER_KERNEL_BUILDER(Name("_XlaLaunch") +REGISTER_KERNEL_BUILDER(Name("XlaLaunch") .Device(DEVICE_GPU) .HostMemory("constants") .HostMemory("resources"), diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h index 47fd912b12abbbe876e933ab57f6f586fd299909..8dfc4b382d51151b6383fe7dd75429f3124d39be 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.h +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h @@ -26,13 +26,40 @@ limitations under the License. namespace tensorflow { -// Takes a snapshot of the values of resource variable arguments, which are -// the last `num_variables` arguments. We snapshot tensors that back -// resource variables since concurrent updates may modify the shape, and it is -// important that the shapes used for compilation match the true shapes of the -// buffers. -std::vector SnapshotResourceVariables(OpKernelContext* ctx, - int num_variables); +// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. +// The only difference is that it does not require arguments to follow +// the "constants, then regular args, then resources" order. +// It takes vectors of constant and resource arguments explicitly. +// It does not have corresponding OpDef because it is never present +// in the GraphDef. +// Currently, it is used by eager runtime. FunctionLibraryRuntime creates +// this kernel when asked to create a kernel for an XLA-compiled function. +class XlaLocalLaunchBase : public OpKernel { + public: + XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector& constants, + const std::vector& resources, + const NameAttrList& function); + XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete; + XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete; + ~XlaLocalLaunchBase() override = default; + + void Compute(OpKernelContext* ctx) override; + + protected: + // Builds a XlaCompilationCache class suitable for the current device. + Status BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** cache); + + // Indexes of compile-time constant inputs + std::vector constants_; + // Indexes of resource inputs + std::vector resources_; + + DeviceType device_type_; + NameAttrList function_; + se::Platform::Id platform_id_; +}; // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph // which will be compiled and executed using XLA. The XlaLocalLaunchOp is @@ -43,26 +70,12 @@ std::vector SnapshotResourceVariables(OpKernelContext* ctx, // XlaLocalLaunchOp uses xla::LocalClient::Compile() and // xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device // memory. -class XlaLocalLaunchOp : public OpKernel { +class XlaLocalLaunchOp : public XlaLocalLaunchBase { public: explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); ~XlaLocalLaunchOp() override; - void Compute(OpKernelContext* ctx) override; - private: - // Builds a XlaCompilationCache class suitable for the current device. - Status BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** compiler); - - DeviceType device_type_; - NameAttrList function_; - int num_constant_args_; - // Number of resource variable arguments. - int num_resource_args_; - - perftools::gputools::Platform::Id platform_id_; - TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); }; diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD index 4491dd6ac8f2b84f341162eb469cc8194f817c9a..5b6692f523658749f7ef48f9d7d89e97d4ce8b09 100644 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -16,18 +16,6 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) -cc_library( - name = "encapsulate_subgraphs_pass_flags", - srcs = ["encapsulate_subgraphs_pass_flags.cc"], - hdrs = ["encapsulate_subgraphs_pass_flags.h"], - deps = - [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - cc_library( name = "mark_for_compilation_pass_flags", srcs = ["mark_for_compilation_pass_flags.cc"], @@ -52,16 +40,14 @@ cc_library( ], ) -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", +cc_library( + name = "xla_device_flags", + srcs = ["xla_device_flags.cc"], + hdrs = ["xla_device_flags.h"], + deps = + [ + "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", ], - ), - visibility = ["//tensorflow:__subpackages__"], ) diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc index 4bc209b7ecf499d82e7567f7eff12b17cefa9863..7277a1d1f8ad5fa045645ead839ab9efa01e89c7 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc @@ -40,6 +40,8 @@ static void AllocateFlags() { flags->tf_xla_max_cluster_size = std::numeric_limits::max(); flags->tf_xla_clustering_debug = false; flags->tf_xla_cpu_global_jit = false; + flags->tf_xla_clustering_fuel = std::numeric_limits::max(); + flags->tf_xla_fusion_only = false; flag_list = new std::vector( {Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, "Control compilation of operators into XLA computations on CPU and " @@ -55,7 +57,13 @@ static void AllocateFlags() { Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug, "Dump graphs during XLA compilation."), Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit, - "Enables global JIT compilation for CPU via SessionOptions.")}); + "Enables global JIT compilation for CPU via SessionOptions."), + Flag("tf_xla_clustering_fuel", &flags->tf_xla_clustering_fuel, + "Places an artificial limit on the number of ops marked as " + "eligible for clustering."), + Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only, + "enable fusion of element-wise operations only using XLA when " + "global_jit_level is ON*.")}); xla::legacy_flags::ParseFlagsFromEnv(*flag_list); } diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h index e1ccd7ddb8706ca445b6811ca1fec369af7cd5d5..2affda6ab4e0fbad32a246744fa5b38aeb629c1b 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h @@ -48,6 +48,13 @@ typedef struct { bool tf_xla_clustering_debug; // Dump graphs during XLA compilation. bool tf_xla_cpu_global_jit; // Enables global JIT compilation for CPU // via SessionOptions. + int64 tf_xla_clustering_fuel; // "Compiler fuel" for clustering. Only this + // many ops will be marked as eligible for + // clustering. + bool tf_xla_fusion_only; // This flag is effective only when global_jit_level + // is set to ON* and overrides its behavior. If + // true, enable fusion of element-wise operations + // only using XLA. } MarkForCompilationPassFlags; // Return a pointer to the MarkForCompilationPassFlags struct; diff --git a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc similarity index 58% rename from tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc rename to tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc index 856475f12c8a411cd80c1c1859323304ca4029e0..1bb2fce2dbad5bffce2e33b665b7222090d0855a 100644 --- a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module. +// Legacy flags for the XLA bridge's xla_device module. #include #include -#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" +#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" #include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" @@ -28,33 +28,26 @@ namespace legacy_flags { // Pointers to the parsed value of the flags and flag descriptors, initialized // via flags_init. -static EncapsulateSubgraphsPassFlags* flags; +static XlaDeviceFlags* flags; static std::vector* flag_list; static std::once_flag flags_init; // Allocate *flags. Called via call_once(&flags_init,...). static void AllocateFlags() { - flags = new EncapsulateSubgraphsPassFlags; - flags->tf_xla_parallel_checking = false; + flags = new XlaDeviceFlags; + flags->tf_xla_compile_on_demand = false; flag_list = new std::vector({ - Flag("tf_xla_parallel_checking", &flags->tf_xla_parallel_checking, - "Debug tool. Runs both JIT-compiled and interpreted graphs in " - "parallel and verifies they produce the same outputs."), + Flag("tf_xla_compile_on_demand", &flags->tf_xla_compile_on_demand, + "Switch a device into 'on-demand' mode, where instead of " + "autoclustering ops are compiled one by one just-in-time."), }); xla::legacy_flags::ParseFlagsFromEnv(*flag_list); } -// Append to *append_to flag definitions associated with the XLA bridge's -// encapsulate_subgraphs_pass module. -void AppendEncapsulateSubgraphsPassFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the EncapsulateSubgraphsPassFlags struct; +// Return a pointer to the XlaDeviceFlags struct; // repeated calls return the same pointer. // This should be called only after Flags::Parse() has returned. -EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags() { +XlaDeviceFlags* GetXlaDeviceFlags() { std::call_once(flags_init, &AllocateFlags); return flags; } diff --git a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h similarity index 50% rename from tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h rename to tensorflow/compiler/jit/legacy_flags/xla_device_flags.h index d371bd269dbdfbf737d81490fb877fcf88661a8f..27b22121ac1e089bd5d5a494e1e3fb60b05bc76d 100644 --- a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h +++ b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ +#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ +#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ -// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module. +// Legacy flags for the XLA bridge's xla_device module. #include @@ -26,25 +26,22 @@ limitations under the License. namespace tensorflow { namespace legacy_flags { -// Append to *flag_list flag definitions associated with the XLA bridge's -// encapsulate_subgraphs_pass module. -void AppendEncapsulateSubgraphsPassFlags( - std::vector* flag_list); - // The values of flags associated with the XLA bridge's -// encapsulate_subgraphs_pass module. +// xla_device module. typedef struct { - bool tf_xla_parallel_checking; // Debug tool. Runs both JIT-compiled and - // interpreted graphs in parallel and verifies - // they produce the same outputs. -} EncapsulateSubgraphsPassFlags; - -// Return a pointer to the EncapsulateSubgraphsPassFlags struct; + // Switch the CPU device into "on-demand" mode, where instead of + // autoclustering ops are compiled one by one just-in-time. + // Enabling this mode by a legacy flag is a temporary mechanism. When this + // feature is battle-tested, we will switch this to be a session option. + bool tf_xla_compile_on_demand; +} XlaDeviceFlags; + +// Return a pointer to the XlaDeviceFlags struct; // repeated calls return the same pointer. // This should be called only after Flags::Parse() has returned. -EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags(); +XlaDeviceFlags* GetXlaDeviceFlags(); } // namespace legacy_flags } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ +#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index a0211acbbe9eec77d30c7d14293650de8826f41c..8c3882116dd4f048ea3e32c037bf4139c67a3eb9 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" @@ -35,14 +36,12 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/public/version.h" namespace tensorflow { -const char* const kXlaClusterAttr = "_XlaCluster"; -const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; - namespace { bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { @@ -50,6 +49,23 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // is really a kind of function call and will be handled by // IsCompilableCall(). if (node.type_string() == "SymbolicGradient") return false; + if (node.type_string() == "Const") { + // Skip Const op with type DT_STRING, since XLA doesn't support it, but the + // registered Const KernelDef says that it does, to support no-op Assert for + // tfcompile. + const AttrValue* attr = node.attrs().Find("dtype"); + if (attr != nullptr && attr->type() == DT_STRING) { + return false; + } + } + + // XLA does not offer guaranteed aliasing between the input and output of the + // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave + // such nodes out of XLA clusters. + if (HasForwardedRefInput(node)) { + return false; + } + return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok(); } @@ -155,16 +171,6 @@ bool IsCompilableCall(const NodeDef& call_def, return true; } -// Returns the DeviceType corresponding to 'device'. -Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) { - DeviceNameUtils::ParsedName parsed; - if (!DeviceNameUtils::ParseFullName(device, &parsed)) { - return errors::Internal("Malformed assigned device '", device, "'"); - } - *device_type = DeviceType(parsed.type); - return Status::OK(); -} - // Tests whether `node` has a DT_RESOURCE typed input or output. bool HasResourceInputOrOutput(const Node& node) { return std::find(node.input_types().begin(), node.input_types().end(), @@ -173,10 +179,157 @@ bool HasResourceInputOrOutput(const Node& node) { DT_RESOURCE) != node.output_types().end(); } -struct NodeCompare { - bool operator()(const Node* a, const Node* b) { return a->id() < b->id(); } -}; -using OrderedNodeSet = std::set; +// Returns true if the op can be decomposed into XLA ops for which +// there are fusable elemental implementations. +// +// TODO(hpucha): Remove this code since this functionality is subsumed by +// Grappler XlaFusionOptimizer. +bool IsXlaFusable(const NodeDef& node) { + static const std::unordered_set* elementwise_ops = + new std::unordered_set( + {// tf2xla/kernels/aggregate_ops.cc + "AddN", + // tf2xla/kernels/batchtospace_op.cc + "BatchToSpace", "BatchToSpaceND", + // tf2xla/kernels/bcast_ops.cc + "BroadcastArgs", "BroadcastGradientArgs", + // tf2xla/kernels/bias_ops.cc + "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/, + // tf2xla/kernels/binary_ops.cc + "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv", + "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift", + "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv", + "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference", + "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater", + "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", + "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual", + // tf2xla/kernels/cast_op.cc + "Cast", + // tf2xla/kernels/categorical_op.cc + "Multinomial" /* (Rng ops are disabled on GPU backend currently)*/, + // tf2xla/kernels/concat_op.cc + "Concat", "ConcatV2", "ConcatOffset", + // tf2xla/kernels/const_op.cc + "Const", + // tf2xla/kernels/cross_op.cc + "Cross", + // tf2xla/kernels/depthtospace_op.cc + "DepthToSpace", + // tf2xla/kernels/diag_op.cc + "Diag", "DiagPart", "MatrixDiag", "MatrixDiagPart", + // tf2xla/kernels/dynamic_stitch_op.cc + "DynamicStitch", "ParallelDynamicStitch", + // tf2xla/kernels/elu_op.cc + "Elu", "EluGrad", "Selu", "SeluGrad", + // tf2xla/kernels/fake_quantize_ops.cc + "FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxArgsGradient", + "FakeQuantWithMinMaxVars", + "FakeQuantWithMinMaxVarsGradient" /*(Reduce)*/, + // tf2xla/kernels/fill_op.cc + "Fill", + // tf2xla/kernels/gather_op.cc + "Gather", "GatherV2", "GatherNd", + // tf2xla/kernels/identity_op.cc + "Identity", "IdentityN", "PreventGradient", "StopGradient", + "Snapshot", + // tf2xla/kernels/image_ops.cc + "RGBToHSV", "HSVToRGB", "AdjustContrastv2" /*(Reduce)*/, + "AdjustSaturation", "AdjustHue", + // tf2xla/kernels/index_ops.cc + "ArgMax", "ArgMin", + // tf2xla/kernels/l2loss_op.cc + "L2Loss" /*(Reduce)*/, + // tf2xla/kernels/lrn_ops.cc (ReduceWindow) + "LRN", "LRNGrad", + // tf2xla/kernels/matrix_band_part_op.cc + "MatrixBandPart", + // tf2xla/kernels/matrix_set_diag_op.cc + "MatrixSetDiag", + // tf2xla/kernels/mirror_pad_op.cc + "MirrorPad", + // tf2xla/kernels/no_op.cc + "NoOp", "ControlTrigger", + // tf2xla/kernels/one_hot_op.cc + "OneHot", + // tf2xla/kernels/pack_op.cc + "Pack", + // tf2xla/kernels/pad_op.cc + "Pad", "PadV2", + // tf2xla/kernels/pooling_ops.cc + "MaxPool", "MaxPoolV2", "MaxPool3D", "AvgPool", + "AvgPool3D", /*(all the pooling ops use ReduceWindow)*/ + "MaxPoolGrad", "MaxPoolGradV2", "MaxPool3DGrad", "AvgPoolGrad", + "AvgPool3DGrad", + // tf2xla/kernels/quantize_and_dequantize_op.cc (Reduce) + "QuantizeAndDequantizeV2", + // tf2xla/kernels/random_ops.cc (Rng ops are disabled on GPU backend + // currently) + "RandomUniform", "RandomUniformInt", "RandomStandardNormal", + "TruncatedNormal", + // tf2xla/kernels/reduction_ops.cc (Reduce) + "Sum", "Prod", "Min", "Max", "Mean", "All", "Any", + // tf2xla/kernels/relu_op.cc + "Relu", "Relu6", "ReluGrad", "Relu6Grad", + // tf2xla/kernels/reshape_op.cc + "Reshape", + // tf2xla/kernels/reverse_op.cc + "Reverse", "ReverseV2", + // tf2xla/kernels/reverse_sequence_op.cc + "ReverseSequence", + // tf2xla/kernels/scan_ops.cc (ReduceWindow) + "Cumsum", "Cumprod", + // tf2xla/kernels/scatter_nd_op.cc (Reduce) + "ScatterNd", + // tf2xla/kernels/segment_reduction_ops.cc (Reduce) + "UnsortedSegmentSum", + // tf2xla/kernels/select_op.cc + "Select", + // tf2xla/kernels/sequence_ops.cc + "Range", "LinSpace", + // tf2xla/kernels/shape_op.cc + "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze", + "ZerosLike", "OnesLike", + // tf2xla/kernels/slice_op.cc + "Slice", + // tf2xla/kernels/softmax_op.cc (Reduce) + "Softmax", "LogSoftmax", "SoftmaxCrossEntropyWithLogits", + "SparseSoftmaxCrossEntropyWithLogits", + // tf2xla/kernels/spacetobatch_op.cc + "SpaceToBatchND", "SpaceToBatch", + // tf2xla/kernels/spacetodepth_op.cc + "SpaceToDepth", + // tf2xla/kernels/split_op.cc + "Split", "SplitV", + // tf2xla/kernels/stack_ops.cc + "StackV2", "StackPushV2", "StackPopV2", "StackCloseV2", + // tf2xla/kernels/stateless_random_ops.cc (Rng ops are disabled on + // GPU + // backend currently) + "StatelessRandomUniform", + "StatelessRandomNormal" + // tf2xla/kernels/strided_slice_op.cc + "StridedSlice", + "StridedSliceGrad", "ResourceStridedSliceAssign", + // tf2xla/kernels/tile_ops.cc + "Tile", + // tf2xla/kernels/training_ops.cc + "ResourceApplyGradientDescent", "ResourceApplyMomentum", + "ResourceApplyAdagrad", "ResourceApplyAdam", "ResourceApplyRMSProp", + "ResourceApplyFtrl", "ResourceApplyFtrlV2", + // tf2xla/kernels/transpose_op.cc + "Transpose", "InvertPermutation", + // tf2xla/kernels/unary_ops.cc + "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin", + "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", + "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", + "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round", + "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", + "Square", "Tan", "Tanh", "Real", "Imag", + // tf2xla/kernels/unpack_op.cc + "Unpack"}); + + return elementwise_ops->count(node.op()) > 0; +} Status FindCompilationCandidates( const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, @@ -189,15 +342,39 @@ Status FindCompilationCandidates( FunctionLibraryRuntime* lib_runtime = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + int64& fuel = + legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; + + // Iterate over nodes in sorted order so that compiler fuel is deterministic. + // We can't simply pass op_nodes().begin() and op_nodes().end to the + // std::vector constructor because they're not proper iterators, with + // iterator_traits defined and so on. + std::vector sorted_nodes; for (Node* node : graph.op_nodes()) { + sorted_nodes.push_back(node); + } + std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID()); + + for (Node* node : sorted_nodes) { + VLOG(2) << "Fuel: " << fuel; + if (fuel <= 0) { + VLOG(2) + << "Hit fuel limit; not marking any remaining ops as clusterable."; + break; + } + VLOG(2) << "FindCompilationCandidates(): Processing " << node->DebugString(); DeviceType device_type(""); TF_RETURN_IF_ERROR( - DeviceTypeOfDevice(node->assigned_device_name(), &device_type)); + DeviceToDeviceType(node->assigned_device_name(), &device_type)); - if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue; + if (is_compilable_fn && !is_compilable_fn(node, device_type)) { + VLOG(2) << "Compilation rejected node: not compilable " << node->name() + << ": " << node->type_string(); + continue; + } const XlaOpRegistry::DeviceRegistration* registration; CHECK( @@ -234,7 +411,9 @@ Status FindCompilationCandidates( continue; } candidates->insert(node); + --fuel; } + VLOG(2) << "candidates->size() = " << candidates->size(); return Status::OK(); } @@ -244,43 +423,6 @@ struct Cluster { int representative = -1; }; -// Returns a string describing how an edge from src to dst would -// create a cycle. -string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src, - int dst) { - int32 max_path_size = graph.num_node_ids() + 1; - std::vector path(max_path_size); - int32 path_size = cycles.FindPath(dst, src, max_path_size, path.data()); - if (path_size == 0) { - return ""; - } - - auto node_name = [&cycles, &graph](int node_id) { - auto* node = graph.FindNodeId(node_id); - if (node == nullptr) { - return string("(null)"); - } - return node->name(); - }; - - string description; - strings::StrAppend(&description, "Edge from ", node_name(src), " to ", - node_name(dst), " would create a cycle.\n"); - path.resize(path_size); - for (int32 node_id : path) { - string ascii_art; - if (node_id == dst) { - ascii_art = "+-> "; - } else if (node_id != src) { - ascii_art = "| "; - } else { - ascii_art = "+-- "; - } - strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); - } - return description; -} - } // anonymous namespace bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { @@ -314,10 +456,13 @@ Status MarkForCompilationPass::Run( static_cast(flags->tf_xla_auto_jit); } bool cpu_global_jit = flags->tf_xla_cpu_global_jit; + bool fusion_only = flags->tf_xla_fusion_only; + VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; + VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; const FunctionLibraryDefinition* fld = options.flib_def; - auto is_compilable = [global_jit_level, cpu_global_jit, fld]( + auto is_compilable = [global_jit_level, cpu_global_jit, fusion_only, fld]( const Node* node, const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), @@ -340,6 +485,11 @@ Status MarkForCompilationPass::Run( status = fld->GetAttr(*node, kXlaCompileAttr, &compile); if (status.ok()) return compile; + // Check for fusable ops only if requested. + if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) { + return false; + } + // Otherwise use the value of global_jit_level. // Ignore enable_jit_by_default if global jit compilation for CPU // is explicitly requested via tf_xla_cpu_global_jit flag @@ -378,84 +528,13 @@ Status MarkForCompilationPass::RunImpl( : Env::Default(), is_compilable_fn, &compilation_candidates)); - GraphCycles cycles; - for (int i = 0; i < graph->num_node_ids(); ++i) { - // We rely on the node IDs in the cycle detection graph being consecutive - // integers starting from 0. - CHECK_EQ(i, cycles.NewNode()); + if (compilation_candidates.empty()) { + VLOG(2) << "No compilable candidates"; + return Status::OK(); } - // Compute the loop structure of the graph. - std::vector control_flow_info; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info)); - - // The clustering code must avoid adding cycles to the graph to prevent - // deadlock. However, the graph may contain loops, which would trigger the - // cycle detection code. To handle loops, we alter the structure of the cycle - // detection graph, disconnecting each loop from the enclosing graph. - // Specifically, we: - // * add a new "frame" node for each loop. - // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges - // to/from the corresponding frame node. In essence, we collapse the loop - // into a single node for the purpose of cycle detection in the enclosing - // graph. - // * the body of the loop should now be disconnected from the rest of the - // graph; we make it acyclic by breaking loop backedges (edges outgoing from - // "NextIteration" nodes. - - // Map from frame name strings to node IDs in the cycle detection graph. - std::unordered_map frame_nodes; - - // Get the cycle graph node ID for frame 'frame_name', or add one if none - // exists. - auto GetOrAddFrameNodeId = [&frame_nodes, &cycles](const string& frame_name) { - int& frame_id = frame_nodes.emplace(frame_name, -1).first->second; - if (frame_id < 0) { - // The emplace succeeded; we have not allocated a frame node yet. - frame_id = cycles.NewNode(); - } - return frame_id; - }; - - for (Edge const* edge : graph->edges()) { - if (edge->dst()->IsEnter()) { - // Lift edges to an "Enter" node to the corresponding frame node. - const string& frame_name = - control_flow_info[edge->dst()->id()].frame_name; - int dst = GetOrAddFrameNodeId(frame_name); - if (!cycles.InsertEdge(edge->src()->id(), dst)) { - return errors::Internal( - "Cycle detected when adding enter->frame edge: ", - DescribeCycle(cycles, *graph, edge->src()->id(), dst)); - } - continue; - } - if (edge->src()->IsExit()) { - // Lift edges from an "Exit" node to the corresponding frame node. - const string& frame_name = - control_flow_info[edge->src()->id()].frame_name; - int src = GetOrAddFrameNodeId(frame_name); - if (!cycles.InsertEdge(src, edge->dst()->id())) { - return errors::Internal( - "Cycle detected when adding frame->exit edge: ", - DescribeCycle(cycles, *graph, src, edge->dst()->id())); - } - // Drop the original edge. - continue; - } - if (edge->src()->IsNextIteration()) { - // Break loop back-edges. - continue; - } - if (!cycles.InsertEdge(edge->src()->id(), edge->dst()->id())) { - // This should never happen. All cycles in the graph should contain - // a control flow operator. - return errors::Internal( - "Found cycle in graph without control flow operator during XLA " - "compilation: ", - DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); - } - } + GraphCycles cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles)); // Each compilation candidate belongs to a cluster. The cluster's // representative @@ -473,6 +552,9 @@ Status MarkForCompilationPass::RunImpl( // Repeatedly contract edges between clusters that are on the same device, // provided the contraction would not create a cycle. + // + // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for + // example, from the Grappler fusion pass). while (!worklist.empty()) { int from = worklist.front()->Get().representative; worklist.pop_front(); @@ -544,11 +626,15 @@ Status MarkForCompilationPass::RunImpl( } } - // Count the number of elements in each cluster. - std::vector cluster_sizes(graph->num_node_ids()); + // Count the number of non-trivial elements in each cluster. + std::vector effective_cluster_sizes(graph->num_node_ids()); for (const Node* n : compilation_candidates) { int cluster = clusters[n->id()].Get().representative; - cluster_sizes[cluster]++; + // Identity nodes will be removed if the node gets marked for compilation. + // Therefore we don't want to count them towards the effective cluster size. + if (n->def().op() != "Identity") { + effective_cluster_sizes[cluster]++; + } } // Names for each cluster. @@ -577,13 +663,16 @@ Status MarkForCompilationPass::RunImpl( // compilation. DeviceType device_type(""); TF_RETURN_IF_ERROR( - DeviceTypeOfDevice(n->assigned_device_name(), &device_type)); + DeviceToDeviceType(n->assigned_device_name(), &device_type)); const XlaOpRegistry::DeviceRegistration* registration; XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration); - // Or compile if this is a cluster of >= min_cluster_size compilable - // operators. - if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation || + // Compile if this is a cluster of >= min_cluster_size compilable operators. + // Also, always compile if the operator is placed on a device that requires + // compilation, or if it contains at least one op that is marked for + // compilation that is not an Identity op. + if (effective_cluster_sizes[cluster] >= min_cluster_size || + (effective_cluster_sizes[cluster] > 0 && marked_for_compilation) || registration->requires_compilation) { string& name = cluster_names[cluster]; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 1a8858cccef623185709ab5dc2187a313dd130f7..772c92d369e67f431b5d030d1d5cdc5ae2700d39 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -16,7 +16,9 @@ limitations under the License. #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -27,6 +29,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -137,7 +140,7 @@ TEST(XlaCompilationTest, CompilableCycles) { EXPECT_EQ(clusters["A"], clusters["C"]); } -TEST(XlaCompilationTest, UnsupportedTypes) { +TEST(XlaCompilationTest, Complex128Unsupported) { std::unique_ptr graph(new Graph(OpRegistry::Global())); GraphDef graphdef; { @@ -157,6 +160,27 @@ TEST(XlaCompilationTest, UnsupportedTypes) { EXPECT_TRUE(clusters.empty()); } +TEST(XlaCompilationTest, HalfSupported) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Tensor t(DT_HALF, TensorShape()); + t.scalar()() = static_cast(0.0f); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_HALF) + .WithAttr("value", t)); + Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); + ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(MarkForCompilation(&graph)); + auto clusters = GetClusters(*graph); + EXPECT_FALSE(clusters.empty()); +} + TEST(XlaCompilationTest, ConcatWithConstArg) { std::unique_ptr graph(new Graph(OpRegistry::Global())); GraphDef graphdef; @@ -519,11 +543,11 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { Status status = MarkForCompilation(&graph); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(StringPiece(status.ToString()) - .contains("Edge from c to a would create a cycle.\n" - "+-> a\n" - "| b\n" - "+-- c\n")); + EXPECT_TRUE(str_util::StrContains(status.ToString(), + "Edge from c to a would create a cycle.\n" + "+-> a\n" + "| b\n" + "+-- c\n")); } TEST(XlaCompilationTest, Retval) { @@ -553,5 +577,108 @@ TEST(XlaCompilationTest, Retval) { EXPECT_EQ(clusters["A"], clusters["B"]); } +TEST(XlaCompilationTest, DontCountIdentityOps) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + Scope root = Scope::NewRootScope().ExitOnError(); + { + auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0); + auto b = ops::Identity(root.WithOpName("B"), a); + auto c = ops::Identity(root.WithOpName("C"), b); + auto r = ops::_Retval(root.WithOpName("R"), c, 0); + } + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + auto clusters = GetClusters(*graph); + + EXPECT_TRUE(clusters.empty()); +} + +TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + Scope root = Scope::NewRootScope().ExitOnError(); + { + auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0); + auto b = ops::Identity(root.WithOpName("B"), a); + b.node()->AddAttr(kXlaCompileAttr, true); + auto r = ops::_Retval(root.WithOpName("R"), b, 0); + } + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + auto clusters = GetClusters(*graph); + + EXPECT_TRUE(clusters.empty()); +} + +TEST(XlaCompilationTest, ConstOp) { + // valid data type + { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + Scope root = Scope::NewRootScope().ExitOnError(); + auto c = ops::Const(root.WithOpName("const"), 0.5f); + c.node()->AddAttr(kXlaCompileAttr, true); + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + EXPECT_EQ(1, GetClusters(*graph).size()); + } + + // invalid data type + { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + Scope root = Scope::NewRootScope().ExitOnError(); + auto c = ops::Const(root.WithOpName("const"), string("string")); + c.node()->AddAttr(kXlaCompileAttr, true); + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + EXPECT_TRUE(GetClusters(*graph).empty()); + } +} + +TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output variable = ops::Variable(root.WithOpName("variable"), + PartialTensorShape{}, DT_FLOAT); + Output read = ops::Identity(root.WithOpName("read"), variable); + Output neg = ops::Negate(root.WithOpName("negate"), read); + Output add = ops::Add(root.WithOpName("add"), neg, neg); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + ASSERT_FALSE(clusters.empty()); + string cluster_name = clusters.begin()->second; + + std::unordered_map expected_clusters( + {{"negate", cluster_name}, {"add", cluster_name}}); + EXPECT_EQ(clusters, expected_clusters); +} + +TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output variable = ops::Variable(root.WithOpName("variable"), + PartialTensorShape{}, DT_FLOAT); + Output read = ops::Identity(root.WithOpName("read"), variable); + Output neg = ops::Negate(root.WithOpName("negate"), read); + Output identity = ops::Negate(root.WithOpName("identity"), neg); + Output add = ops::Add(root.WithOpName("add"), identity, neg); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + ASSERT_FALSE(clusters.empty()); + string cluster_name = clusters.begin()->second; + + std::unordered_map expected_clusters( + {{"negate", cluster_name}, + {"identity", cluster_name}, + {"add", cluster_name}}); + EXPECT_EQ(clusters, expected_clusters); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index e5787ca4c8cff436e4404b8488970248b24a5eda..c9e46bc1475aed0e35a48765ad70eef4362e8281 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -17,17 +17,3 @@ cc_library( deps = ["//tensorflow/core:framework"], alwayslink = 1, ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index 07320b43dab790e6cda5e85688bdacf48a35adc4..f2473d98ffd5dae55983e601b8d2d65af6a6d54c 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -17,7 +17,7 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("_XlaLaunch") +REGISTER_OP("XlaLaunch") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") .Input("args: Targs") @@ -28,7 +28,7 @@ REGISTER_OP("_XlaLaunch") .Attr("Tresults: list(type) >= 0") .Attr("function: func") // XLA random-number generation ops are stateful. - // TODO(phawkins): create stateful and non-stateful variants of _XlaLaunch. + // TODO(phawkins): create stateful and non-stateful variants of XlaLaunch. .SetIsStateful() .Doc("XLA Launch Op. For use by the XLA JIT only."); diff --git a/tensorflow/compiler/jit/producer_consumer_queue.h b/tensorflow/compiler/jit/producer_consumer_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..7c8c04152d2f3a0fd46711df24756b7e68b967ea --- /dev/null +++ b/tensorflow/compiler/jit/producer_consumer_queue.h @@ -0,0 +1,132 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ +#define TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ + +#include +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// A thread-safe, first-in-first-out queue. +template +class ProducerConsumerQueue { + public: + ProducerConsumerQueue() + : capacity_(std::numeric_limits::max()) {} + ~ProducerConsumerQueue() = default; + + // Wait until the queue is non-full, then append a copy of v. + void Put(const T &v); + + // Wait until the queue is non-empty, then remove and return the head value. + T Get(); + + // If the queue is non-empty, remove the head value, placing it in *pv, and + // return true; otherwise return false. + bool TryGet(T *pv); + + // Set the capacity of the queue; the queue is full whenever count() >= + // capacity(). The initial value is the maximum size_t. Requires size > 0. + void set_capacity(std::size_t size); + + // Return the capacity of the queue. + std::size_t capacity() const; + + // Return the number of elements in the queue. + std::size_t count() const; + + // Implementation details follow. Clients should ignore. + private: + mutable tensorflow::mutex mu_; // protects all fields below + tensorflow::condition_variable non_empty_ GUARDED_BY(mu_); + tensorflow::condition_variable non_full_ GUARDED_BY(mu_); + std::size_t capacity_ GUARDED_BY(mu_); + std::deque queue_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(ProducerConsumerQueue); +}; + +// ------------------------------------------------------ +// Implementation details follow. Clients should ignore. + +// Wait until the queue is non-full, then append a copy of v. +template +void ProducerConsumerQueue::Put(const T &v) { + mutex_lock lock(mu_); + while (queue_.size() >= capacity_) { + non_full_.wait(lock); + } + queue_.push_back(v); + non_empty_.notify_one(); +} + +// Wait until the queue is non-empty, then remove and return the head value. +template +T ProducerConsumerQueue::Get() { + mutex_lock lock(mu_); + while (queue_.empty()) { + non_empty_.wait(lock); + } + non_full_.notify_one(); + T result_value = queue_.front(); + queue_.pop_front(); + return result_value; +} + +// If the queue is non-empty, remove the head value, placing it in *pv, and +// return true; otherwise return false. +template +bool ProducerConsumerQueue::TryGet(T *pv) { + mutex_lock lock(mu_); + bool got_element = !queue_.empty(); + if (got_element) { + non_full_.notify_one(); + *pv = queue_.front(); + queue_.pop_front(); + } + return got_element; +} + +// Set the capacity of the queue; the queue is full whenever count() >= +// capacity(). The initial value is the maximum size_t. Requires size > 0. +template +void ProducerConsumerQueue::set_capacity(std::size_t size) { + mutex_lock lock(mu_); + CHECK_NE(size, 0); + capacity_ = size; + non_full_.notify_all(); +} + +// Return the capacity of the queue. +template +std::size_t ProducerConsumerQueue::capacity() const { + mutex_lock lock(mu_); + std::size_t max_elements = capacity_; + return max_elements; +} + +// Return the number of elements in the queue. +template +std::size_t ProducerConsumerQueue::count() const { + mutex_lock lock(mu_); + std::size_t num_elements = queue_.size(); + return num_elements; +} +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ diff --git a/tensorflow/compiler/jit/producer_consumer_queue_test.cc b/tensorflow/compiler/jit/producer_consumer_queue_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f61260c6e52756ee039829afdc7452f5f760c221 --- /dev/null +++ b/tensorflow/compiler/jit/producer_consumer_queue_test.cc @@ -0,0 +1,139 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/producer_consumer_queue.h" + +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +typedef ProducerConsumerQueue IntQueue; + +// Insert integers between low inclusive and high exclusive into q. +void PushRange(IntQueue *q, int low, int high) { + while (low != high) { + q->Put(low); + VLOG(2) << "Pushing " << low; + ++low; + } +} + +// Push the numbers between 0 and 999 inclusive from several threads in the +// pool. +void PushRanges(IntQueue *queue, thread::ThreadPool *pool) { + VLOG(1) << "Adding 20-36"; + pool->Schedule([queue] { PushRange(queue, 20, 36); }); + VLOG(1) << "Adding 7-20"; + pool->Schedule([queue] { PushRange(queue, 7, 20); }); + VLOG(1) << "Adding 36-501"; + pool->Schedule([queue] { PushRange(queue, 36, 501); }); + VLOG(1) << "Adding 501-1000"; + pool->Schedule([queue] { PushRange(queue, 501, 1000); }); + VLOG(1) << "Adding 0-5"; + pool->Schedule([queue] { PushRange(queue, 0, 5); }); + VLOG(1) << "Adding 5-7"; + pool->Schedule([queue] { PushRange(queue, 5, 7); }); +} + +// Pop elements from queue using Get(). Make sure that exactly elements +// were present and their values are all integers between 0 and high-1 +// inclusive. +void GetRange(IntQueue *queue, int high) { + VLOG(1) << "Testing Wait"; + std::vector results; + for (int i = 0; i != high; ++i) { + int r = queue->Get(); + VLOG(2) << "Waited and got " << r; + results.push_back(r); + } + CHECK_EQ(queue->count(), 0); + std::sort(results.begin(), results.end()); + for (int i = 0; i != high; ++i) { + CHECK(results[i] == i); + } +} + +// Pop elements from queue using TryGet(). Make sure that exactly +// elements were present and their values are all integers between 0 and high-1 +// inclusive. +void TryGetRange(IntQueue *queue, int high) { + std::vector results; + // Give up if we don't get all the elements back from the queue + // in 10 seconds. + int timeout = 10; + int r; + for (int i = 0; i != high; ++i) { + while (!queue->TryGet(&r)) { + if (!timeout--) { + LOG(FATAL) << "Can't find all elements in the queue"; + } + VLOG(1) << "Sleeping for a second..."; + sleep(1); + } + VLOG(2) << "Popped " << r; + results.push_back(r); + } + CHECK_EQ(queue->count(), 0); + CHECK(!queue->TryGet(&r)); + std::sort(results.begin(), results.end()); + for (int i = 0; i != high; ++i) { + CHECK_EQ(i, results[i]); + } +} + +const int kNumThreads = 15; + +TEST(ProducerConsumerQueue, GetRange) { + IntQueue queue; + { + thread::ThreadPool pool(Env::Default(), "test", kNumThreads); + PushRanges(&queue, &pool); + } + GetRange(&queue, 1000); +} + +TEST(ProducerConsumerQueue, TryGetRange) { + IntQueue queue; + { + thread::ThreadPool pool(Env::Default(), "test", kNumThreads); + PushRanges(&queue, &pool); + } + TryGetRange(&queue, 1000); +} + +TEST(ProducerConsumerQueue, ParallelGetRange) { + IntQueue queue; + { + thread::ThreadPool pool(Env::Default(), "test", kNumThreads); + pool.Schedule([&queue] { GetRange(&queue, 1000); }); + PushRanges(&queue, &pool); + } +} + +TEST(ProducerConsumerQueue, ParallelTryGetRange) { + IntQueue queue; + { + thread::ThreadPool pool(Env::Default(), "test", kNumThreads); + pool.Schedule([&queue] { TryGetRange(&queue, 1000); }); + PushRanges(&queue, &pool); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/shape_inference_helpers.cc b/tensorflow/compiler/jit/shape_inference_helpers.cc new file mode 100644 index 0000000000000000000000000000000000000000..d9cfa16526bc5d809942a35e86075b4ec6e88a59 --- /dev/null +++ b/tensorflow/compiler/jit/shape_inference_helpers.cc @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains helpers for use in shape inference. + +#include "tensorflow/compiler/jit/shape_inference_helpers.h" + +#include + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +Status BackEdgeHelper::Remove(Graph* graph) { + if (graph_ != nullptr) { + return errors::Internal("BackEdgeHelper duplicate call to Remove."); + } + graph_ = graph; + for (Node* n : graph_->nodes()) { + if (n->IsMerge()) { + for (const Edge* e : n->in_edges()) { + if (e->src()->IsNextIteration()) { + back_edges_.push_back( + BackEdge{e, e->src(), e->src_output(), e->dst(), e->dst_input()}); + } + } + } + } + for (const BackEdge& be : back_edges_) { + graph_->RemoveEdge(be.edge); + } + return Status::OK(); +} + +const std::vector& BackEdgeHelper::RemovedEdges() + const { + return back_edges_; +} + +Status BackEdgeHelper::Replace() { + if (graph_ == nullptr) { + return errors::Internal("BackEdgeHelper Replace called before Remove."); + } + if (replaced_) { + return errors::Internal("BackEdgeHelper Replace called more than once."); + } + replaced_ = true; + for (const BackEdge& be : back_edges_) { + graph_->AddEdge(be.src, be.src_output, be.dst, be.dst_input); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/shape_inference_helpers.h b/tensorflow/compiler/jit/shape_inference_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..2f053c9a45dd47ca1b056634d2248d6181e77d68 --- /dev/null +++ b/tensorflow/compiler/jit/shape_inference_helpers.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_ +#define TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_ + +#include + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Helper class to temporarily remove, then replace, the back edges in a +// graph. Simple algorithms for shape inference don't work with cycles, and this +// class can be used to remove cycles before running inference and replace them +// after. Correct usage requires exactly one call to Remove(), followed by any +// number of calls to RemovedEdges() and at most one call to Replace(). The call +// to Replace() is optional if the graph will be discarded without being +// executed, e.g., if it is being used purely for a shape inference pass. +class BackEdgeHelper { + public: + struct BackEdge { + const Edge* edge; + Node* src; + int src_output; + Node* dst; + int dst_input; + }; + + BackEdgeHelper() = default; + // Disallows copy and assign. + BackEdgeHelper(const BackEdgeHelper& other) = delete; + BackEdgeHelper& operator=(const BackEdgeHelper& other) = delete; + + // Temporarily removes all the back edges in graph. + Status Remove(Graph* graph); + + // Gets the list of removed edges. + const std::vector& RemovedEdges() const; + + // Replaces the back edges removed by a prior call to Remove. + Status Replace(); + + private: + Graph* graph_ = nullptr; // not owned + std::vector back_edges_; + // Set once Replace has been called. + bool replaced_ = false; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_ diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..05b7821b8865d0f210ca9af92370e177d6043e80 --- /dev/null +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_cluster_util.h" + +#include + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +const char* const kXlaClusterAttr = "_XlaCluster"; +const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; + +namespace { +// Returns a string describing how an edge from src to dst would +// create a cycle. +string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, + int dst) { + int32 max_path_size = graph.num_node_ids() + 1; + std::vector path(max_path_size); + int32 path_size = cycles->FindPath(dst, src, max_path_size, path.data()); + if (path_size == 0) { + return ""; + } + + auto node_name = [cycles, &graph](int node_id) { + if (!FastBoundsCheck(node_id, graph.num_node_ids())) { + return string("(null)"); + } + auto* node = graph.FindNodeId(node_id); + if (node == nullptr) { + return string("(null)"); + } + return node->name(); + }; + + string description; + strings::StrAppend(&description, "Edge from ", node_name(src), " to ", + node_name(dst), " would create a cycle.\n"); + path.resize(path_size); + for (int32 node_id : path) { + string ascii_art; + if (node_id == dst) { + ascii_art = "+-> "; + } else if (node_id != src) { + ascii_art = "| "; + } else { + ascii_art = "+-- "; + } + strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); + } + return description; +} + +bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); } + +} // namespace + +Status DeviceToDeviceType(const string& device, DeviceType* device_type) { + DeviceNameUtils::ParsedName parsed; + if (!DeviceNameUtils::ParseFullName(device, &parsed)) { + return errors::Internal("Malformed assigned device '", device, "'"); + } + *device_type = DeviceType(parsed.type); + return Status::OK(); +} + +bool HasForwardedRefInput(const Node& node) { + if (AlwaysForwardsRefInput(node)) { + for (const Edge* incoming_edge : node.in_edges()) { + if (incoming_edge->IsControlEdge()) { + continue; + } + + Node* incoming_node = incoming_edge->src(); + if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) { + VLOG(2) << "Node " << node.def().ShortDebugString() << " has ref input " + << incoming_node->name() << " " << incoming_node->type_string(); + return true; + } + } + } + return false; +} + +Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { + for (int i = 0; i < graph->num_node_ids(); ++i) { + // We rely on the node IDs in the cycle detection graph being consecutive + // integers starting from 0. + CHECK_EQ(i, cycles->NewNode()); + } + + // Compute the loop structure of the graph. + std::vector control_flow_info; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info)); + + // The clustering code must avoid adding cycles to the graph to prevent + // deadlock. However, the graph may contain loops, which would trigger the + // cycle detection code. To handle loops, we alter the structure of the cycle + // detection graph, disconnecting each loop from the enclosing graph. + // Specifically, we: + // * add a new "frame" node for each loop. + // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges + // to/from the corresponding frame node. In essence, we collapse the loop + // into a single node for the purpose of cycle detection in the enclosing + // graph. + // * the body of the loop should now be disconnected from the rest of the + // graph; we make it acyclic by breaking loop backedges (edges outgoing from + // "NextIteration" nodes. + + // Map from frame name strings to node IDs in the cycle detection graph. + std::unordered_map frame_nodes; + + // Get the cycle graph node ID for frame 'frame_name', or add one if none + // exists. + auto GetOrAddFrameNodeId = [&frame_nodes, cycles](const string& frame_name) { + int& frame_id = frame_nodes.emplace(frame_name, -1).first->second; + if (frame_id < 0) { + // The emplace succeeded; we have not allocated a frame node yet. + frame_id = cycles->NewNode(); + } + return frame_id; + }; + + for (Edge const* edge : graph->edges()) { + if (edge->dst()->IsEnter()) { + // Lift edges to an "Enter" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->dst()->id()].frame_name; + int dst = GetOrAddFrameNodeId(frame_name); + if (!cycles->InsertEdge(edge->src()->id(), dst)) { + return errors::Internal( + "Cycle detected when adding enter->frame edge: ", + DescribeCycle(cycles, *graph, edge->src()->id(), dst)); + } + continue; + } + if (edge->src()->IsExit()) { + // Lift edges from an "Exit" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->src()->id()].frame_name; + int src = GetOrAddFrameNodeId(frame_name); + if (!cycles->InsertEdge(src, edge->dst()->id())) { + return errors::Internal( + "Cycle detected when adding frame->exit edge: ", + DescribeCycle(cycles, *graph, src, edge->dst()->id())); + } + // Drop the original edge. + continue; + } + if (edge->src()->IsNextIteration()) { + // Break loop back-edges. + continue; + } + if (!cycles->InsertEdge(edge->src()->id(), edge->dst()->id())) { + // This should never happen. All cycles in the graph should contain + // a control flow operator. + return errors::Internal( + "Found cycle in graph without control flow operator during XLA " + "compilation: ", + DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h new file mode 100644 index 0000000000000000000000000000000000000000..bcce082aaf6044ff0654efa4d78c0f493a350d00 --- /dev/null +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains utilities for clustering compilable graph nodes via XLA. + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ + +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/core/graph/algorithm.h" + +namespace tensorflow { + +// The attribute that marks nodes to be grouped into functions by the +// encapsulate subgraphs pass. +extern const char* const kXlaClusterAttr; + +// The attribute that marks nodes in a cluster to be placed outside the xla +// compilation by the encapsulate subgraphs pass. +extern const char* const kXlaOutsideCompilationAttr; + +using OrderedNodeSet = std::set; + +// Returns the DeviceType corresponding to 'device'. +Status DeviceToDeviceType(const string& device, DeviceType* device_type); + +// Returns true if `node` has a ref tensor input that it forwards to its output. +bool HasForwardedRefInput(const Node& node); + +// Creates a graph representation to enable cycle detection when clustering. +// This representation handles loops in graph by disconnecting each loop from +// the enclosing graph. +Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 6d854a920eb0b4c01b09024ceaef5035e847d392..7ed609c43748062656b631243c01d790519c54fd 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -92,113 +92,88 @@ uint64 XlaCompilationCache::Signature::Hash::operator()( } Status XlaCompilationCache::BuildSignature( - const NameAttrList& function, int num_constant_args, - const std::vector& variable_args, OpKernelContext* ctx, + const NameAttrList& function, const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, Signature* signature) { signature->name = Canonicalize(function.name(), AttrSlice(&function.attr())); - signature->arg_values.resize(num_constant_args); - - signature->arg_types.reserve(ctx->num_inputs() - num_constant_args); - - // Inputs are in the order: constants, non-constants, resource variables. - int input_num = 0; - // Use the values of compile time constants in the signature-> - while (input_num < num_constant_args) { - signature->arg_values[input_num] = ctx->input(input_num); - ++input_num; - } - // Add the types and shapes of the remaining arguments. - while (input_num < ctx->num_inputs() - variable_args.size()) { - signature->arg_types.emplace_back(ctx->input_dtype(input_num), - ctx->input(input_num).shape()); - ++input_num; - } - // For variable signatures, use the type and shape of the variable's - // current value. - for (const OptionalTensor& variable : variable_args) { - TF_RET_CHECK(input_num < ctx->num_inputs()); - if (variable.present) { - signature->arg_types.emplace_back(variable.value.dtype(), - variable.value.shape()); + signature->arg_values.reserve(constant_args.size()); + + signature->arg_types.reserve(ctx->num_inputs() - constant_args.size()); + + for (int i = 0; i < ctx->num_inputs(); ++i) { + if (constant_args.count(i) > 0) { + // Use the values of compile time constants in the signature. + signature->arg_values.push_back(constant_args.at(i)); + } else if (variable_args.count(i) > 0) { + const OptionalTensor& variable = variable_args.at(i); + if (variable.present) { + signature->arg_types.emplace_back(variable.value.dtype(), + variable.value.shape()); + } else { + signature->arg_types.emplace_back(DT_INVALID, TensorShape()); + } } else { - signature->arg_types.emplace_back(DT_INVALID, TensorShape()); + signature->arg_types.emplace_back(ctx->input_dtype(i), + ctx->input(i).shape()); } - ++input_num; } return Status::OK(); } namespace { -// Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch -// op. The first `num_constant_args` arguments must be host-memory Tensors. -Status BuildArguments(int num_constant_args, - const std::vector& variable_args, +// Builds a XlaCompiler::Argument vector from the arguments to the XlaLaunch op. +Status BuildArguments(const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, std::vector* args) { args->resize(ctx->num_inputs()); - int input_num = 0; - - // Handles compile-time constants. - TF_RET_CHECK(num_constant_args <= ctx->num_inputs()); - while (input_num < num_constant_args) { - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); - XlaCompiler::Argument& arg = (*args)[input_num]; - arg.kind = XlaCompiler::Argument::kConstant; - arg.type = input.dtype(); - arg.shape = input.shape(); - arg.constant_value = input; - ++input_num; - } - - // Handles the non-constant arguments. - int num_variable_args = variable_args.size(); - int num_nonconst_args = - ctx->num_inputs() - num_variable_args - num_constant_args; - TF_RET_CHECK(num_nonconst_args >= 0); - while (input_num < num_constant_args + num_nonconst_args) { - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); + for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) { XlaCompiler::Argument& arg = (*args)[input_num]; - if (input.NumElements() > 0) { - arg.kind = XlaCompiler::Argument::kParameter; - } else { + if (constant_args.count(input_num) > 0) { + // Handles compile-time constants. + const Tensor& input = constant_args.at(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); arg.kind = XlaCompiler::Argument::kConstant; + arg.type = input.dtype(); + arg.shape = input.shape(); arg.constant_value = input; - } - arg.type = input.dtype(); - arg.shape = input.shape(); - ++input_num; - } - - // Handles resource variables. - TF_RET_CHECK(input_num + num_variable_args == ctx->num_inputs()); - for (int variable_id = 0; variable_id < num_variable_args; ++variable_id) { - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() == DT_RESOURCE); - - XlaCompiler::Argument& arg = (*args)[input_num]; - - arg.name = variable_args[variable_id].name; - arg.kind = XlaCompiler::Argument::kResource; - arg.resource_kind = XlaResource::kVariable; - if (variable_args[variable_id].present) { - const Tensor& value = variable_args[variable_id].value; - arg.type = value.dtype(); - arg.shape = value.shape(); - arg.initialized = true; + } else if (variable_args.count(input_num) == 0) { + // Handles the non-constant arguments. + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); + if (input.NumElements() > 0) { + arg.kind = XlaCompiler::Argument::kParameter; + } else { + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = input; + } + arg.type = input.dtype(); + arg.shape = input.shape(); } else { - // The values of uninitialized variables are not passed as inputs, since - // they are meaningless. However, it is legal to assign to a resource - // variable for the first time inside the XLA computation, so we do permit - // uninitialized variables. - arg.initialized = false; - arg.type = DT_INVALID; - arg.shape = TensorShape(); + // Handles resource variables. + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() == DT_RESOURCE); + const OptionalTensor& variable = variable_args.at(input_num); + arg.name = variable.name; + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = XlaResource::kVariable; + if (variable.present) { + const Tensor& value = variable.value; + arg.type = value.dtype(); + arg.shape = value.shape(); + arg.initialized = true; + } else { + // The values of uninitialized variables are not passed as inputs, since + // they are meaningless. However, it is legal to assign to a resource + // variable for the first time inside the XLA computation, so we do + // permit uninitialized variables. + arg.initialized = false; + arg.type = DT_INVALID; + arg.shape = TensorShape(); + } } - ++input_num; } return Status::OK(); @@ -233,16 +208,43 @@ Status XlaCompilationCache::BuildExecutable( Status XlaCompilationCache::Compile( const XlaCompiler::Options& options, const NameAttrList& function, - int num_constant_args, const std::vector& variable_args, - OpKernelContext* ctx, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, const XlaCompiler::CompileOptions* compile_options) { + return CompileImpl(options, function, constant_args, variable_args, ctx, + compilation_result, executable, compile_options, false); +} + +Status XlaCompilationCache::CompileSingleOp( + const XlaCompiler::Options& options, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options) { + const NodeDef& def = ctx->op_kernel().def(); + NameAttrList name; + name.set_name(def.op()); + *name.mutable_attr() = def.attr(); + return CompileImpl(options, name, constant_args, variable_args, ctx, + compilation_result, executable, compile_options, true); +} + +Status XlaCompilationCache::CompileImpl( + const XlaCompiler::Options& options, const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options, + bool compile_single_op) { VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { VLOG(2) << "num_inputs=" << ctx->num_inputs() - << " num_constant_args=" << num_constant_args + << " num_constant_args=" << constant_args.size() << " num_variable_args=" << variable_args.size(); for (int i = 0; i < ctx->num_inputs(); i++) { TensorShape shape = ctx->input(i).shape(); @@ -250,10 +252,12 @@ Status XlaCompilationCache::Compile( << " present=" << ctx->has_input(i) << " shape=" << shape.DebugString(); } - for (const OptionalTensor& variable : variable_args) { + for (auto& iterator : variable_args) { + const OptionalTensor& variable = iterator.second; VLOG(2) << "variable present=" << variable.present << " type=" << DataTypeString(variable.value.dtype()) - << " shape=" << variable.value.shape().DebugString(); + << " shape=" << variable.value.shape().DebugString() + << " TF arg= " << iterator.first; } VLOG(2) << "num_outputs = " << ctx->num_outputs(); for (int i = 0; i < ctx->num_outputs(); i++) { @@ -261,11 +265,12 @@ Status XlaCompilationCache::Compile( } } - TF_RET_CHECK(num_constant_args + variable_args.size() <= ctx->num_inputs()); + TF_RET_CHECK(constant_args.size() + variable_args.size() <= + ctx->num_inputs()); Signature signature; - TF_RETURN_IF_ERROR(BuildSignature(function, num_constant_args, variable_args, - ctx, &signature)); + TF_RETURN_IF_ERROR( + BuildSignature(function, constant_args, variable_args, ctx, &signature)); VLOG(2) << "Signature: " << SignatureDebugString(signature); // The outer lock protects the existence of the cache entry. It does not @@ -292,13 +297,20 @@ Status XlaCompilationCache::Compile( // a long time.) std::vector args; TF_RETURN_IF_ERROR( - BuildArguments(num_constant_args, variable_args, ctx, &args)); + BuildArguments(constant_args, variable_args, ctx, &args)); XlaCompiler compiler(options); entry->compiled = true; - entry->compilation_status = compiler.CompileFunction( - compile_options ? *compile_options : XlaCompiler::CompileOptions(), - function, args, &entry->compilation_result); + + if (compile_single_op) { + entry->compilation_status = compiler.CompileSingleOp( + compile_options ? *compile_options : XlaCompiler::CompileOptions(), + signature.name, ctx, args, &entry->compilation_result); + } else { + entry->compilation_status = compiler.CompileFunction( + compile_options ? *compile_options : XlaCompiler::CompileOptions(), + function, args, &entry->compilation_result); + } } *compilation_result = &entry->compilation_result; if (entry->compilation_status.ok() && executable) { diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 0858020716fcf4763e42dc0699ad22cfda756942..be1043d8c3fc0573922837e541615114a6d7a1a5 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -52,29 +52,52 @@ class XlaCompilationCache : public ResourceBase { // Compiles a function into a XlaCompiler::CompilationResult that can be used // to execute an XLA Computation. Compilation results are cached. // `function` is the name of a Tensorflow function to compile. - // `num_constant_args` is the number of compile-time constant arguments to - // `function`. `variable_args` is a snapshot of the current values of the + // `constant_args` is a map of tensorflow argument number to its constant + // value. + // `variable_args` is a snapshot of the current values of the // resource variable arguments to `function`; uninitialized variables are // represented by an absent OptionalTensor. // The result of compilation is written to `*compilation_result`, which must // be non-null. If `executable` is non-null, also builds an - // xla::LocalExecutable and sets `executable to point to it. The resulting + // xla::LocalExecutable and sets `executable` to point to it. The resulting // executable pointer may be null if the computation has no non-constant // outputs. Status Compile(const XlaCompiler::Options& options, - const NameAttrList& function, int num_constant_args, - const std::vector& variable_args, + const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, const XlaCompiler::CompileOptions* compile_options); + // As above, but calls XlaCompiler::CompileSingleOp instead of + // XlaCompiler::CompileFunction. + Status CompileSingleOp( + const XlaCompiler::Options& options, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options); + xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } string DebugString() override; private: + // Common implementation of Compile and CompileSingleOp. + Status CompileImpl(const XlaCompiler::Options& options, + const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, + OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options, + bool compile_single_op); + // Takes `result` which has been compiled from a Tensorflow subgraph to a // XLA computation already, and generates an XLA LocalExecutable `executable`. Status BuildExecutable(const XlaCompiler::Options& options, @@ -104,8 +127,9 @@ class XlaCompilationCache : public ResourceBase { static string SignatureDebugString(const Signature& sig); // Builds the signature for a compilation. - Status BuildSignature(const NameAttrList& function, int num_constant_args, - const std::vector& variable_args, + Status BuildSignature(const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, Signature* signature); // The value associated with a cache entry. diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b1943d3e1a7e321b5a3796a0c6e4f2b5d9ac7018 --- /dev/null +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -0,0 +1,177 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Defines the XlaCompileOnDemandOp. + +#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { + +namespace { +std::map GetVariables(OpKernelContext* ctx) { + std::map variables; + for (int64 i = 0; i < ctx->num_inputs(); ++i) { + if (ctx->input(i).dtype() == DT_RESOURCE) { + Var* variable = nullptr; + ResourceHandle handle = HandleFromInput(ctx, i); + OptionalTensor& optional = variables[i]; + optional.name = handle.name(); + if (LookupResource(ctx, handle, &variable).ok()) { + tf_shared_lock lock(*variable->mu()); + optional.present = true; + optional.value = *variable->tensor(); + } + } + } + return variables; +} +} // namespace + +Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, + const XlaDevice::Metadata& metadata, + const XlaCompiler::CompilationResult* result, + xla::LocalExecutable* executable) { + std::map variables = GetVariables(ctx); + + xla::LocalClient* client = metadata.client(); + + // Builds an XLA allocator for the device. + XlaComputationLaunchContext launch_context( + client, client->backend().memory_allocator(), true); + + launch_context.PopulateInputs(ctx, result, variables); + + se::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + TF_RET_CHECK(stream); + + VLOG(2) << "Executing computation."; + xla::ExecutableRunOptions run_options; + run_options.set_stream(stream); + run_options.set_allocator(client->backend().memory_allocator()); + run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + run_options.set_rng_seed(ctx->step_id()); + + auto run_result = executable->Run(launch_context.arguments(), run_options); + TF_RETURN_IF_ERROR(run_result.status()); + + launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie()); + return Status::OK(); +} + +bool XlaCompileOnDemandOp::MustArgumentBeConstant(const OpKernel* op_kernel, + int64 argument_idx) { + // TODO(jmolloy): This could be expensive, so memoize. + auto* constant_inputs = tensorflow::XlaOpRegistry::CompileTimeConstantInputs( + op_kernel->def().op()); + CHECK(constant_inputs); + std::set constant_input_indices; + for (const auto& name : *constant_inputs) { + int start, stop; + TF_CHECK_OK(op_kernel->InputRange(name, &start, &stop)); + for (int i = start; i < stop; ++i) { + constant_input_indices.insert(i); + } + } + return constant_input_indices.count(argument_idx) > 0; +} + +bool XlaCompileOnDemandOp::ShouldArgumentBeConstant(const OpKernel* op_kernel, + int64 argument_idx) { + // Right now we only create kConstant arguments when absolutely required, but + // there may be benefit in eagerly constant-folding a larger subset of + // arguments in the future. + return MustArgumentBeConstant(op_kernel, argument_idx); +} + +Status XlaCompileOnDemandOp::Compile( + OpKernelContext* ctx, const XlaDevice::Metadata& metadata, + const XlaCompiler::CompilationResult** result, + xla::LocalExecutable** executable) { + std::map constant_arguments; + for (int64 i = 0; i < ctx->num_inputs(); ++i) { + const Tensor& device_tensor = ctx->input(i); + if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) { + if (xla_tensor->has_host_tensor() && + ShouldArgumentBeConstant(&ctx->op_kernel(), i)) { + constant_arguments[i] = xla_tensor->host_tensor(); + } + } + if (constant_arguments.count(i) == 0 && + MustArgumentBeConstant(&ctx->op_kernel(), i)) { + // Slow path; the argument is not available as a host constant so we must + // fetch it synchronously. + Tensor host_tensor; + AllocatorAttributes attrs; + attrs.set_on_host(true); + TF_RETURN_IF_ERROR(ctx->allocate_temp( + device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs)); + Notification n; + ctx->op_device_context()->CopyDeviceTensorToCPU( + &device_tensor, "ConstantArgument", + reinterpret_cast(ctx->device()), &host_tensor, + [&](Status status) { n.Notify(); }); + n.WaitForNotification(); + constant_arguments[i] = host_tensor; + } + } + + // We store information about the JIT-compiled XLA computation + // in the ResourceMgr. + ResourceMgr* rm = ctx->resource_manager(); + CHECK(rm); + + XlaCompilationCache* cache; + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), "xla_cache", &cache, + [&](XlaCompilationCache** cache) { + *cache = new XlaCompilationCache(metadata.client(), + metadata.jit_device_type()); + return Status::OK(); + })); + // Hold the reference to the JIT during evaluation. (We could probably + // free it sooner because the ResourceMgr will retain a reference, but + // this is more obviously correct.) + core::ScopedUnref cache_ref(cache); + + XlaCompiler::Options options; + options.device_type = metadata.jit_device_type(); + options.client = metadata.client(); + options.flib_def = + new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); + options.shape_representation_fn = metadata.shape_representation_fn(); + + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; + + std::map variable_args = GetVariables(ctx); + return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, + result, executable, &compile_options); +} + +void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { + const XlaCompiler::CompilationResult* result; + xla::LocalExecutable* executable; + const XlaDevice::Metadata* metadata; + OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata)); + OP_REQUIRES_OK(ctx, Compile(ctx, *metadata, &result, &executable)); + OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable)); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h new file mode 100644 index 0000000000000000000000000000000000000000..7cc3d0e007ba2974fbfbe6fbabc4aa08f9fa910f --- /dev/null +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The XlaCompileOnDemandOp is an OpKernel that, when its Compute method is +// called, will generate an xla::Computation and run it asynchronously. + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ + +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// An OpKernel that compiles an op to an XLA computation and runs it. Unlike +// XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a +// vanilla TensorFlow op as long as the bridge supports it. +class XlaCompileOnDemandOp : public OpKernel { + public: + explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override; + + private: + XlaCompiler::Argument CreateCompilerArgument(OpKernelContext* ctx, int64 i); + bool ShouldArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx); + bool MustArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx); + Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata, + const XlaCompiler::CompilationResult** result, + xla::LocalExecutable** executable); + Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata, + const XlaCompiler::CompilationResult* result, + xla::LocalExecutable* executable); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index e238252751e677eb947f6df03e3b2f2e948ffe19..43648402f65c656b6b4eb2e83e61ce45f1c73669 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -17,6 +17,8 @@ limitations under the License. // operators using XLA via the XLA "Host" (CPU) backend. #include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" +#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -34,14 +36,26 @@ class XlaCpuDeviceFactory : public DeviceFactory { Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector* devices) { + legacy_flags::XlaDeviceFlags* flags = legacy_flags::GetXlaDeviceFlags(); + bool compile_on_demand = flags->tf_xla_compile_on_demand; + + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = DEVICE_CPU_XLA_JIT; + registration.requires_compilation = !compile_on_demand; + registration.enable_jit_by_default = false; + registration.compile_resource_ops = true; + static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(DEVICE_XLA_CPU, DEVICE_CPU_XLA_JIT); (void)registrations; std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create( - "Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, name_prefix, - /*register_device_for_compilation=*/true, &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, + DEVICE_CPU_XLA_JIT, options, name_prefix, + registration, + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, + /*padded_shape_fn=*/{}, &device)); devices->push_back(device.release()); return Status::OK(); } @@ -50,8 +64,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaCpuTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kAllXlaCpuTypes = { + {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index d4d8fe1c1d575b4e35d624621cc709e3a16569d5..ed007d603ea1b3d27dd25f00726261cdd029c20c 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -19,9 +19,11 @@ limitations under the License. #include #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/common_runtime/device.h" @@ -47,10 +49,9 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/ptr_util.h" #include "tensorflow/core/util/stream_executor_util.h" -namespace se = ::perftools::gputools; - namespace tensorflow { // Caches a XlaDeviceAllocator per pair. A @@ -99,34 +100,49 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( } std::unique_ptr alloc = - xla::MakeUnique(backend, device_ordinal); + xla::MakeUnique(); XlaDeviceAllocator* alloc_ptr = alloc.get(); state.allocators_[{backend, device_ordinal}] = std::move(alloc); return alloc_ptr; } +namespace { + +// Default PaddedShapeFn implementation that simply returns the unpadded +// on-device shape. This is accurate for CPU and GPU devices that neither +// transpose nor pad tensors. +Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { + const tensorflow::XlaTensor* xla_tensor = + tensorflow::XlaTensor::FromTensor(&tensor); + if (xla_tensor == nullptr) { + return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape); + } + + const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer(); + *shape = shaped_buffer.on_device_shape(); + return Status::OK(); +} + +} // namespace + /* static */ Status XlaDevice::Create( const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, - const string& name_prefix, bool register_device_for_compilation, - std::unique_ptr* device) { + const string& name_prefix, + const XlaOpRegistry::DeviceRegistration& registration, + bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + const PaddedShapeFn& padded_shape_fn, std::unique_ptr* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" << device_ordinal; - if (register_device_for_compilation) { - // These are no-ops if they have already been done previously for - // this device_name/compilation_device_name pair. - XlaOpRegistry::DeviceRegistration registration; - registration.compilation_device_name = jit_device_name; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; - registration.compile_resource_ops = true; - XlaOpRegistry::RegisterCompilationDevice(device_name, registration); - } + // These are no-ops if they have already been done previously for + // this device_name/compilation_device_name pair. + XlaOpRegistry::RegisterCompilationDevice(device_name, registration); auto platform = se::MultiPlatformManager::PlatformWithName(platform_name); if (!platform.ok()) { - return StreamExecutorUtil::ConvertStatus(platform.status()); + return platform.status(); } const DeviceAttributes attrs = Device::BuildDeviceAttributes( @@ -135,17 +151,22 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), strings::StrCat("device: ", device_name, " device")); - device->reset(new XlaDevice(options, attrs, device_ordinal, - DeviceType(jit_device_name), - platform.ValueOrDie())); + device->reset(new XlaDevice( + options, attrs, device_ordinal, DeviceType(jit_device_name), + platform.ValueOrDie(), transfer_as_literal, shape_representation_fn, + padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn)); return Status::OK(); } -XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform, - const DeviceType& device_type) +XlaDevice::Metadata::Metadata( + int device_ordinal, se::Platform* platform, const DeviceType& device_type, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + PaddedShapeFn padded_shape_fn) : device_ordinal_(device_ordinal), device_type_(device_type), - platform_(platform) {} + platform_(platform), + shape_representation_fn_(std::move(shape_representation_fn)), + padded_shape_fn_(std::move(padded_shape_fn)) {} int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; } @@ -162,6 +183,7 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { /* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, const Metadata** metadata) { + *metadata = nullptr; XlaDevice* xla_device = dynamic_cast(ctx->device()->UnderlyingDevice()); if (xla_device == nullptr) { @@ -175,17 +197,29 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return Status::OK(); } -XlaDevice::XlaDevice(const SessionOptions& options, - const DeviceAttributes& attrs, int device_ordinal, - const DeviceType& jit_device_name, se::Platform* platform) +XlaDevice::XlaDevice( + const SessionOptions& options, const DeviceAttributes& attrs, + int device_ordinal, const DeviceType& jit_device_name, + se::Platform* platform, bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + const PaddedShapeFn& padded_shape_fn) : LocalDevice(options, attrs), - xla_metadata_(device_ordinal, platform, jit_device_name), + xla_metadata_(device_ordinal, platform, jit_device_name, + shape_representation_fn, padded_shape_fn), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(nullptr), - platform_(platform) {} + platform_(platform), + transfer_as_literal_(transfer_as_literal), + shape_representation_fn_(shape_representation_fn) { + VLOG(1) << "Created XLA device " << jit_device_name; +} -XlaDevice::~XlaDevice() {} +XlaDevice::~XlaDevice() { + if (gpu_device_info_ != nullptr) { + gpu_device_info_->default_context->Unref(); + } +} xla::LocalClient* XlaDevice::client() const { // We lazily create the client because the platform commits to the @@ -193,9 +227,8 @@ xla::LocalClient* XlaDevice::client() const { // don't want to do it until we get a chance to hook the platform up // to a simulator. - // For now GetOrCreateLocalClient always returns success when passed - // a non-null platform. If that changes we may have to plumb in some - // way to pass Status back. + // TODO(b/78468222): This can fail, at least when the backend is GPU and + // there is no GPU on the host. return xla::ClientLibrary::GetOrCreateLocalClient(platform_).ValueOrDie(); } @@ -220,12 +253,33 @@ xla::StatusOr XlaDevice::GetStream() { return stream_.get(); } +Status XlaDevice::CreateAndSetGpuDeviceInfo() { + if (gpu_device_info_ == nullptr) { + TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); + // Call GetAllocator for the side-effect of ensuring the allocator + // is created. + GetAllocator({}); + // XlaDevice owns both gpu_device_info_ and + // gpu_device_info_->default_context. + gpu_device_info_ = MakeUnique(); + gpu_device_info_->stream = stream; + gpu_device_info_->default_context = new XlaDeviceContext( + stream, client(), transfer_as_literal_, shape_representation_fn_); + set_tensorflow_gpu_device_info(gpu_device_info_.get()); + } + + return Status::OK(); +} + Status XlaDevice::FillContextMap(const Graph* graph, DeviceContextMap* device_context_map) { VLOG(1) << "XlaDevice::FillContextMap"; device_context_map->resize(graph->num_node_ids()); TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - auto ctx = new XlaDeviceContext(stream); + // Call GetAllocator for the side-effect of ensuring the allocator is created. + GetAllocator({}); + auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_, + shape_representation_fn_); for (Node* n : graph->nodes()) { VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); ctx->Ref(); @@ -238,11 +292,10 @@ Status XlaDevice::FillContextMap(const Graph* graph, void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":" << op_kernel->type_string(); - // When TraceMe profiling is off (which is the default), the - // following TraceMe constructor is simply a conditional test of - // false value. Measurements show that its overhead is negligible. - port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(), - op_kernel->IsExpensive()); + // When Xprof profiling is off (which is the default), constructing the + // activity is simple enough that its overhead is negligible. + tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(), + op_kernel->IsExpensive()); op_kernel->Compute(context); } @@ -250,8 +303,8 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) { VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" << op_kernel->type_string(); - port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(), - op_kernel->IsExpensive()); + tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(), + op_kernel->IsExpensive()); op_kernel->ComputeAsync(context, done); } @@ -273,7 +326,8 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); Notification n; TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - XlaTransferManager manager(stream); + XlaTransferManager manager(stream, client(), transfer_as_literal_, + shape_representation_fn_); manager.CopyCPUTensorToDevice(&parsed, this, ©, [&n, &status](const Status& s) { status = s; @@ -288,19 +342,23 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, const char* jit_device) { + // Any op assigned to the device that isn't rewritten by the graph rewriter + // gets executed by a n XlaCompileOnDemandOp, which compiles it and executes + // it just-in-time. + kernel_factory::OpKernelRegistrar::Factory factory = + [](OpKernelConstruction* context) -> OpKernel* { + return new XlaCompileOnDemandOp(context); + }; XlaOpRegistry::RegisterCompilationKernels(); XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations; - auto dummy_factory = [](OpKernelConstruction* context) -> OpKernel* { - return new XlaDeviceDummyOp(context); - }; for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels( jit_device, /*include_compilation_only_kernels=*/false)) { KernelDef* def = new KernelDef(*jit_def); def->set_device_type(device); registrations->op_kernel_registrars.emplace_back( - new kernel_factory::OpKernelRegistrar(def, "XlaDeviceDummyOp", - dummy_factory)); + new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp", + factory)); } return registrations; } diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index d2ec38293c429f04f088bf3726ba97eb4e4b0dba..02e88ee6793e984a7b782790f8011cbcbc5a5026 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -17,8 +17,7 @@ limitations under the License. // runtime. // // Operators assigned to an XlaDevice are compiled into XLA computations. -// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state -// is managed by XLA. +// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers. // // XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU), // under different names (e.g., XLA_CPU or XLA_GPU). @@ -26,6 +25,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ +#include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" @@ -43,24 +45,37 @@ namespace tensorflow { class XlaDevice : public LocalDevice { public: + // Given a tensor, sets `xla::Shape*` the shape of tensor's representation + // on device, fully padded. On error, the contents of `xla::Shape*` + // are undefined. + typedef std::function PaddedShapeFn; + // Wrapper class to store metadata about the XlaDevice, where it can be // retrieved e.g., when lazily creating the XlaCompilationCache device. class Metadata { public: - Metadata(int device_ordinal, perftools::gputools::Platform* platform, - const DeviceType& device_type); + Metadata(int device_ordinal, se::Platform* platform, + const DeviceType& device_type, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + PaddedShapeFn padded_shape_fn); // The index of the device on this host. int device_ordinal() const; - perftools::gputools::Platform* platform() const; + se::Platform* platform() const; xla::LocalClient* client() const; const DeviceType& jit_device_type() const; + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const { + return shape_representation_fn_; + } + const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; } private: const int device_ordinal_; const DeviceType device_type_; - perftools::gputools::Platform* platform_; // Not owned. + se::Platform* platform_; // Not owned. + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + PaddedShapeFn padded_shape_fn_; TF_DISALLOW_COPY_AND_ASSIGN(Metadata); }; @@ -71,15 +86,28 @@ class XlaDevice : public LocalDevice { // Factory function. 'platform_name' is the name of the XLA platform. // 'device_name' is the name of the Tensorflow device to create. // 'jit_device_name' is the name of the corresponding JIT device. - static Status Create(const string& platform_name, const string& device_name, - int device_ordinal, const string& jit_device_name, - const SessionOptions& options, const string& name_prefix, - bool register_device_for_compilation, - std::unique_ptr* device); - + // 'transfer_as_literal' is true if device<->host transfers must be done using + // XLA's TransferLiteral{To,From}Device interface. If false, we can use + // ThenMemcpy instead. + // If padded_shape_fn is empty, a default implementation that returns + // the on-host shape is used. + static Status Create( + const string& platform_name, const string& device_name, + int device_ordinal, const string& jit_device_name, + const SessionOptions& options, const string& name_prefix, + const XlaOpRegistry::DeviceRegistration& registration, + bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + const PaddedShapeFn& padded_shape_fn, std::unique_ptr* device); + + // Creates a new XLA Device. + // If padded_shape_fn is empty, a default implementation that returns + // the logical on-device shape without padding is used. XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, - ::perftools::gputools::Platform* platform); + se::Platform* platform, bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + const PaddedShapeFn& padded_shape_fn); ~XlaDevice() override; Allocator* GetAllocator(AllocatorAttributes attr) override; @@ -96,7 +124,12 @@ class XlaDevice : public LocalDevice { Tensor* tensor) override; xla::LocalClient* client() const; - xla::StatusOr<::perftools::gputools::Stream*> GetStream(); + const Metadata& metadata() { return xla_metadata_; } + xla::StatusOr GetStream(); + + // If not already set, create and set GpuDeviceInfo. + // Not thread-safe + Status CreateAndSetGpuDeviceInfo(); private: // The metadata of this XlaDevice. @@ -104,18 +137,26 @@ class XlaDevice : public LocalDevice { // Which hardware device in the client's platform this XlaDevice controls. const int device_ordinal_; // The name of the device that is used to compile Ops for this XlaDevice. - const DeviceType& jit_device_name_; + DeviceType jit_device_name_; // Memory allocator associated with this device. - Allocator* xla_allocator_; // Not owned. - ::perftools::gputools::Platform* platform_; // Not owned. + Allocator* xla_allocator_; // Not owned. + se::Platform* platform_; // Not owned. // Stream associated with this device. Operations enqueued on this // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and // computations enqueued by XLA. xla::Backend::StreamPtr stream_; + // Must we use XLA's transfer manager for correct host<->device transfers? if + // false, we can use ThenMemcpy() instead. + bool transfer_as_literal_; + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + + // If set, holds default device context (that we must Unref) + // and its stream. + std::unique_ptr gpu_device_info_; }; -// Builds dummy OpKernel registrations on 'device' for the JIT operators +// Builds OpKernel registrations on 'device' for the JIT operators // registered on 'jit_device'. Returns ownership of a XlaDeviceOpRegistrations // object that encapsulates the kernel registrations. struct XlaDeviceOpRegistrations { diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index c936222f32056e92efced82d5adb3a96c8041a17..71e63b110b3b132a57fc291e53a165954c72a03c 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -15,44 +15,89 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_context.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/platform/mem.h" -namespace se = ::perftools::gputools; - namespace tensorflow { // The allocator used for Tensors assigned to the XLA device. -XlaDeviceAllocator::XlaDeviceAllocator(const xla::Backend* backend, - int device_ordinal) - : backend_(backend), device_ordinal_(device_ordinal) {} - +XlaDeviceAllocator::XlaDeviceAllocator() {} XlaDeviceAllocator::~XlaDeviceAllocator() = default; string XlaDeviceAllocator::Name() { return "xla"; } void* XlaDeviceAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { - se::DeviceMemoryBase dmem = - backend_->memory_allocator() - ->Allocate(device_ordinal_, num_bytes, /*retry_on_failure=*/false) - .ValueOrDie(); - VLOG(2) << "Allocated XLA device tensor " << dmem.opaque() << "(" << num_bytes - << ")"; - return dmem.opaque(); + // We always return an empty XlaTensor object, encoded as an opaque tagged + // pointer. We can return an empty object and ignore num_bytes here because we + // have control over all of the uses of this device tensor, and can lazily + // allocate memory when used. This allows us to also know the shape of the + // allocated Tensor, which is useful if the device's tensor representation + // differs from the host. + return XlaTensor::ToOpaquePointer(new XlaTensor()); } void XlaDeviceAllocator::DeallocateRaw(void* ptr) { - se::DeviceMemoryBase dmem(ptr); - TF_CHECK_OK(backend_->memory_allocator()->Deallocate(device_ordinal_, &dmem)); - VLOG(2) << "Deallocated XLA device tensor " << ptr; + delete XlaTensor::FromOpaquePointer(ptr); } void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } -XlaTransferManager::XlaTransferManager(se::Stream* stream) : stream_(stream) {} +XlaTransferManager::XlaTransferManager( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn) + : stream_(stream), + client_(client), + transfer_manager_(client->backend().transfer_manager()), + transfer_as_literal_(transfer_as_literal), + shape_representation_fn_(std::move(shape_representation_fn)) { + if (!shape_representation_fn_) { + shape_representation_fn_ = [](const TensorShape& shape, DataType dtype) { + return shape; + }; + } +} + +Status XlaTransferManager::TransferLiteralToDevice( + const Tensor& host_tensor, Tensor* device_tensor) const { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), + host_tensor.shape(), &xla_shape)); + xla::BorrowingLiteral literal( + static_cast(DMAHelper::base(&host_tensor)), xla_shape); + + const xla::ShapedBuffer& shaped_buffer = + XlaTensor::FromTensor(device_tensor)->shaped_buffer(); + VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " + << shaped_buffer.ToString(); + return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal, + shaped_buffer); +} + +Status XlaTransferManager::TransferLiteralFromDevice( + Tensor* host_tensor, const Tensor& device_tensor) const { + const xla::ShapedBuffer& shaped_buffer = + XlaTensor::FromTensor(&device_tensor)->shaped_buffer(); + + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + transfer_manager_->TransferLiteralFromDevice( + stream_->parent(), shaped_buffer)); + VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " " + << shaped_buffer.ToString(); + Tensor tensor; + TF_RETURN_IF_ERROR( + LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); + // Reshape the tensor back to its declared shape. + if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { + return errors::Internal( + "Tensor::CopyFrom failed when copying from XLA device to CPU"); + } + return Status::OK(); +} void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, @@ -64,22 +109,50 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, << " " << reinterpret_cast( device_tensor->tensor_data().data()) - << " " << cpu_tensor->NumElements(); + << " " << cpu_tensor->NumElements() << " " + << cpu_tensor->shape().DebugString() << " " + << device_tensor->shape().DebugString(); void* src_ptr = const_cast(DMAHelper::base(cpu_tensor)); const int64 total_bytes = cpu_tensor->TotalBytes(); - void* dst_ptr = DMAHelper::base(device_tensor); - se::DeviceMemoryBase dev_dst_ptr(dst_ptr, total_bytes); + + XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); + CHECK(xla_tensor); + + TensorShape shape = shape_representation_fn_(device_tensor->shape(), + device_tensor->dtype()); + if (!xla_tensor->has_shaped_buffer()) { + Status s = xla_tensor->AllocateShapedBuffer( + device_tensor->dtype(), shape, client_, + stream_->parent()->device_ordinal()); + if (!s.ok()) { + done(s); + return; + } + } Status status; - stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); - // TODO(hpucha): Make this asynchronous. - Status block_status = stream_->BlockHostUntilDone(); - if (!block_status.ok()) { - status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_, - block_status.error_message().c_str()); + if (transfer_as_literal_) { + Tensor reshaped_cpu_tensor; + if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) { + done(errors::Internal( + "Tensor::CopyFrom failed when copying from CPU to XLA device")); + return; + } + status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); + } else { + se::DeviceMemoryBase dev_dst_ptr = + XlaTensor::DeviceMemoryFromTensor(*device_tensor); + stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); + // TODO(hpucha): Make this asynchronous. + Status block_status = stream_->BlockHostUntilDone(); + if (!block_status.ok()) { + status = xla::InternalError( + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); + } } + xla_tensor->set_host_tensor(*cpu_tensor); done(status); return; @@ -100,21 +173,27 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, device_tensor->tensor_data().data()) << " " << reinterpret_cast(cpu_tensor->tensor_data().data()) - << device_tensor->NumElements(); + << " " << device_tensor->NumElements() << " " + << cpu_tensor->shape().DebugString() << " " + << device_tensor->shape().DebugString(); const int64 total_bytes = cpu_tensor->TotalBytes(); - void* src_ptr = const_cast(DMAHelper::base(device_tensor)); - se::DeviceMemoryBase dev_src_ptr(src_ptr, total_bytes); + se::DeviceMemoryBase dev_src_ptr = + XlaTensor::DeviceMemoryFromTensor(*device_tensor); void* dst_ptr = DMAHelper::base(cpu_tensor); Status status; - stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); - // TODO(hpucha): Make this asynchronous. - Status block_status = stream_->BlockHostUntilDone(); - if (!block_status.ok()) { - status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_, - block_status.error_message().c_str()); + if (transfer_as_literal_) { + status = TransferLiteralFromDevice(cpu_tensor, *device_tensor); + } else { + stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); + // TODO(hpucha): Make this asynchronous. + Status block_status = stream_->BlockHostUntilDone(); + if (!block_status.ok()) { + status = xla::InternalError( + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); + } } done(status); @@ -125,7 +204,47 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, done(Status::OK()); } -XlaDeviceContext::XlaDeviceContext(se::Stream* stream) : manager_(stream) {} +void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, + Tensor* dst_tensor, + const StatusCallback& done) { + // TODO(phawkins): replace this code with an asynchronous implementation. + auto body = [&]() { + if (src_tensor.NumElements() == 0) { + return Status::OK(); + } + XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor); + XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor); + CHECK(xla_src && xla_dst) + << "Missing destination tensor for device-to-device copy"; + if (!xla_dst->has_shaped_buffer()) { + TensorShape shape = + shape_representation_fn_(src_tensor.shape(), src_tensor.dtype()); + TF_RETURN_IF_ERROR( + xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_, + stream_->parent()->device_ordinal())); + } + TF_RETURN_IF_ERROR( + xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus( + [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { + const se::DeviceMemoryBase& from_buffer = + xla_src->shaped_buffer().buffers().element(index); + CHECK_EQ(buffer->size(), from_buffer.size()); + if (!stream_->parent()->SynchronousMemcpy(buffer, from_buffer, + buffer->size())) { + return errors::Internal("Device to device memcpy failed"); + } + return Status::OK(); + })); + return Status::OK(); + }; + done(body()); +} + +XlaDeviceContext::XlaDeviceContext( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn) + : manager_(stream, client, transfer_as_literal, + std::move(shape_representation_fn)) {} void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, @@ -142,4 +261,10 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, done); } +void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor, + Tensor* dst_tensor, + const StatusCallback& done) { + manager_.CopyDeviceTensorToDevice(src_tensor, dst_tensor, done); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index c4edcd474e48f791af9340c3cd6e4d031407bb68..ee346e5653bbf9f393df202572c2150b4989506f 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -18,6 +18,8 @@ limitations under the License. #include +#include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/framework/allocator.h" @@ -26,11 +28,12 @@ limitations under the License. namespace tensorflow { -// The allocator used for Tensors assigned to the XLA device. It uses -// XLA backend's allocator. +// The allocator used for Tensors assigned to the XLA device. The allocator +// ignores the alignment and size of the request and always returns a new, +// empty, XlaTensor. class XlaDeviceAllocator : public Allocator { public: - XlaDeviceAllocator(const xla::Backend* backend, int device_ordinal); + XlaDeviceAllocator(); ~XlaDeviceAllocator() override; string Name() override; @@ -38,30 +41,42 @@ class XlaDeviceAllocator : public Allocator { void* AllocateRaw(size_t alignment, size_t num_bytes) override; void DeallocateRaw(void* ptr) override; void GetStats(AllocatorStats* stats) override; - - private: - // Which backend in the client this allocator belongs to. - const xla::Backend* backend_; - // Which hardware device in the client's backend this allocator belongs to. - const int device_ordinal_; }; // Helper class for managing data transfers between host and XLA devices. class XlaTransferManager { public: - explicit XlaTransferManager(perftools::gputools::Stream* stream); + explicit XlaTransferManager( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done); - perftools::gputools::Stream* stream() const { return stream_; } + + void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, + const StatusCallback& done); + + se::Stream* stream() const { return stream_; } private: + Status TransferLiteralToDevice(const Tensor& host_tensor, + Tensor* device_tensor) const; + Status TransferLiteralFromDevice(Tensor* host_tensor, + const Tensor& device_tensor) const; + // Stream obtained from a Device, used to transfer tensors between // CPU and device. - perftools::gputools::Stream* stream_; + se::Stream* stream_; + // For the underlying memory allocator and XLA's TransferManager. + xla::LocalClient* client_; + // Transfer manager, for marshalling data to and from the device. + xla::TransferManager* transfer_manager_; + // True if we must use XLA's TransferManager for correct device transfers. + const bool transfer_as_literal_; + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -69,7 +84,9 @@ class XlaTransferManager { // wraps the methods in XlaTransferManager. class XlaDeviceContext : public DeviceContext { public: - explicit XlaDeviceContext(perftools::gputools::Stream* stream); + explicit XlaDeviceContext( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, @@ -77,9 +94,10 @@ class XlaDeviceContext : public DeviceContext { void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; - perftools::gputools::Stream* stream() const override { - return manager_.stream(); - } + void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, + const StatusCallback& done); + + se::Stream* stream() const override { return manager_.stream(); } private: XlaTransferManager manager_; diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc index f68dba6b6a26c0c289fd8457ad143d62e5fb9a69..5ecb1afa7bcec910ca843ccd3a782745f2bb6ca8 100644 --- a/tensorflow/compiler/jit/xla_device_ops.cc +++ b/tensorflow/compiler/jit/xla_device_ops.cc @@ -15,7 +15,10 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_ops.h" +#include + #include "tensorflow/compiler/jit/xla_device_context.h" +#include "tensorflow/compiler/jit/xla_tensor.h" namespace tensorflow { @@ -26,4 +29,82 @@ void XlaDeviceDummyOp::Compute(OpKernelContext* ctx) { << type_string() << " on an XLA device. This should never happen."; } +XlaAssignVariableOp::XlaAssignVariableOp(OpKernelConstruction* c) + : AsyncOpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_)); +} + +void XlaAssignVariableOp::ComputeAsync(OpKernelContext* context, + DoneCallback done) { + OP_REQUIRES_ASYNC(context, dtype_ == context->input(1).dtype(), + errors::InvalidArgument( + "Variable and value dtypes don't match; respectively, ", + dtype_, " and ", context->input(1).dtype()), + done); + Var* variable = nullptr; + OP_REQUIRES_OK_ASYNC( + context, + LookupOrCreateResource( + context, HandleFromInput(context, 0), &variable, + [this, context](Var** ptr) { + *ptr = new Var(dtype_); + PersistentTensor unused; + Tensor* tmp; + AllocatorAttributes attr; + TF_RETURN_IF_ERROR(context->allocate_persistent( + dtype_, context->input(1).shape(), &unused, &tmp, attr)); + *(*ptr)->tensor() = *tmp; + return Status::OK(); + }), + done); + core::ScopedUnref s(variable); + + OP_REQUIRES_ASYNC(context, variable->tensor()->dtype() == dtype_, + errors::InvalidArgument( + "Trying to assign variable with wrong dtype. Expected ", + DataTypeString(variable->tensor()->dtype()), " got ", + DataTypeString(dtype_)), + done); + + const Tensor& value = context->input(1); + AllocatorAttributes attr; + + // Copying is unnecessary if we are the last user of the value tensor, we can + // just adopt the input tensor's buffer instead. + std::unique_ptr input_alias = context->forward_input( + 1, /*output_index=*/OpKernelContext::Params::kNoReservation, dtype_, + value.shape(), DEVICE_MEMORY, attr); + mutex_lock ml(*variable->mu()); + variable->is_initialized = true; + if (input_alias) { + *variable->tensor() = *input_alias; + done(); + return; + } + + // Need to copy, but maybe we can re-use variable's buffer? + if (!XlaTensor::RefCountIsOne(*variable->tensor()) || + !variable->tensor()->shape().IsSameSize(value.shape())) { + // Copy to new buffer + PersistentTensor unused; + Tensor* tmp; + OP_REQUIRES_OK_ASYNC(context, + context->allocate_persistent(dtype_, value.shape(), + &unused, &tmp, attr), + done); + *variable->tensor() = *tmp; + } + + XlaDeviceContext* device_context = + static_cast(context->op_device_context()); + + variable->Ref(); + device_context->CopyDeviceTensorToDevice( + value, variable->tensor(), [context, variable, done](Status status) { + variable->Unref(); + context->SetStatus(status); + done(); + }); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 498d25cf566a91f68e5eb1ac312e17900471aeca..11e45d2823da2b623bd3cd45f7147686b05fdb2f 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -23,16 +23,19 @@ limitations under the License. #include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" +#include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/no_op.h" +#include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h" +#include "tensorflow/core/kernels/shape_ops.h" #include "tensorflow/core/kernels/variable_ops.h" namespace tensorflow { // Dummy OpKernel, used for kernels assigned to an XLA device that should be // compiled. Should never be called at runtime since such ops should be -// rewritten to a _XlaLaunch op. If it is called, it means the placer placed an +// rewritten to a XlaLaunch op. If it is called, it means the placer placed an // operator on an XLA device but the compiler did not compile it. class XlaDeviceDummyOp : public OpKernel { public: @@ -40,8 +43,17 @@ class XlaDeviceDummyOp : public OpKernel { void Compute(OpKernelContext* ctx) override; }; +class XlaAssignVariableOp : public AsyncOpKernel { + public: + explicit XlaAssignVariableOp(OpKernelConstruction* c); + void ComputeAsync(OpKernelContext* context, DoneCallback done) override; + + private: + DataType dtype_; +}; + #define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \ - REGISTER_KERNEL_BUILDER(Name("_XlaLaunch") \ + REGISTER_KERNEL_BUILDER(Name("XlaLaunch") \ .Device(DEVICE) \ .HostMemory("constants") \ .HostMemory("resources"), \ @@ -63,13 +75,77 @@ class XlaDeviceDummyOp : public OpKernel { ConstantOp); \ REGISTER_KERNEL_BUILDER( \ Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("IdentityN").Device(DEVICE).TypeConstraint("T", TYPES), \ + IdentityNOp); \ REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \ REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \ PlaceholderOp); \ \ REGISTER_KERNEL_BUILDER( \ Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \ - ResourceHandleOp); + ResourceHandleOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \ + ReadVariableOp); \ + REGISTER_KERNEL_BUILDER(Name("Shape") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeOp); \ + REGISTER_KERNEL_BUILDER(Name("Shape") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeOp); \ + REGISTER_KERNEL_BUILDER(Name("ShapeN") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeNOp); \ + REGISTER_KERNEL_BUILDER(Name("ShapeN") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeNOp); \ + REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + SizeOp); \ + REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + SizeOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Rank").Device(DEVICE).HostMemory("output").TypeConstraint("T", \ + TYPES), \ + RankOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"), \ + XlaAssignVariableOp); \ + REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \ + ControlTriggerOp); \ + REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \ + SwitchOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \ + REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \ + REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \ + REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \ + NextIterationOp); \ + REGISTER_KERNEL_BUILDER(Name("LoopCond") \ + .Device(DEVICE) \ + .HostMemory("input") \ + .HostMemory("output"), \ + LoopCondOp); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc new file mode 100644 index 0000000000000000000000000000000000000000..74257b09a808a39454eace3b1a9bf57a2e071360 --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -0,0 +1,328 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_fusion_optimizer.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" + +namespace tensorflow { + +// Is 'node' an operator that consumes only the shape of its input, not the +// data itself? +static bool IsShapeConsumerOp(const Node& node) { + return node.type_string() == "Shape" || node.type_string() == "ShapeN" || + node.type_string() == "Rank" || node.type_string() == "Size"; +} + +// Returns true if the op can be decomposed into XLA ops for which +// there are fusable elemental implementations. +bool IsXlaFusable(const NodeDef& node) { + static const std::unordered_set* elementwise_ops = + new std::unordered_set( + {// tf2xla/kernels/aggregate_ops.cc + "AddN", + // tf2xla/kernels/binary_ops.cc + "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv", + "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift", + "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv", + "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference", + "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater", + "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", + "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual", + // tf2xla/kernels/unary_ops.cc + "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin", + "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", + "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", + "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round", + "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", + "Square", "Tan", "Tanh", "Real", "Imag", + // tf2xla/kernels/bcast_ops.cc + "BroadcastArgs", "BroadcastGradientArgs", + // tf2xla/kernels/bias_ops.cc + "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/, + // tf2xla/kernels/cast_op.cc + "Cast", + // tf2xla/kernels/concat_op.cc + "Concat", "ConcatV2", "ConcatOffset", + // tf2xla/kernels/const_op.cc + "Const", + // tf2xla/kernels/elu_op.cc + "Elu", "EluGrad", "Selu", "SeluGrad", + // tf2xla/kernels/fill_op.cc + "Fill", + // tf2xla/kernels/identity_op.cc + "Identity", "IdentityN", "PreventGradient", + "StopGradient", /*"Snapshot",*/ + // tf2xla/kernels/index_ops.cc + "ArgMax", "ArgMin", + // tf2xla/kernels/mirror_pad_op.cc + "MirrorPad", + // tf2xla/kernels/one_hot_op.cc + "OneHot", + // tf2xla/kernels/pack_op.cc + "Pack", + // tf2xla/kernels/pad_op.cc + "Pad", "PadV2", + // tf2xla/kernels/relu_op.cc + "Relu", "Relu6", "ReluGrad", "Relu6Grad", + // tf2xla/kernels/reshape_op.cc + "Reshape", + // tf2xla/kernels/reverse_op.cc + "Reverse", "ReverseV2", + // tf2xla/kernels/reverse_sequence_op.cc + "ReverseSequence", + // tf2xla/kernels/shape_op.cc + "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze", + "ZerosLike", "OnesLike", + // tf2xla/kernels/slice_op.cc + "Slice", + // tf2xla/kernels/split_op.cc + "Split", "SplitV", + // tf2xla/kernels/strided_slice_op.cc + "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign", + // tf2xla/kernels/tile_ops.cc + "Tile", + // tf2xla/kernels/transpose_op.cc + "Transpose", "InvertPermutation", + // tf2xla/kernels/unpack_op.cc + "Unpack"}); + + return elementwise_ops->count(node.op()) > 0; +} + +Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + GraphDef* output) { + VLOG(2) << "Here at fusion optimizer"; + + // TODO(hpucha): Implement encapsulation and replacing with XlaLaunch op. + // Once that happens, the expected interaction between this optimizer and when + // the global_jit_level is set is as follows: Fusion optimizer will replace + // appropriate fusion clusters with XlaLaunch nodes. The remaining graph can + // be further compiled where possible via mark_for_compilation_pass. Note that + // this might lead to inefficient clustering, and it is best to use either the + // fusion optimizer or the global_jit flag, and not combine the two. + + // Create a Graph out of GraphDef. This is required currently because the + // helpers around clustering, encapsulation etc work on graphs. + FunctionLibraryDefinition function_library(OpRegistry::Global(), + item.graph.library()); + Graph graph(function_library); + ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); + shape_refiner.set_require_shape_inference_fns(false); + shape_refiner.set_disable_constant_propagation(true); + ImportGraphDefOptions options; + // Graph optimization happens at the late stage of graph execution, when + // colocation constraints are already validated previously and the device + // placement of nodes has also completed, so there is no need to validate + // colocation constraints again. + options.validate_colocation_constraints = false; + options.validate_shape = false; + TF_RETURN_IF_ERROR( + ImportGraphDef(options, item.graph, &graph, &shape_refiner)); + + // Collect nodes that can be fused via XLA, while ignoring those that + // explicitly ask for XLA: (*) nodes that are marked to be compiled + // explicitly. (*) nodes assigned to XLA device. + OrderedNodeSet compilation_candidates; + for (Node* node : graph.op_nodes()) { + // If there is a _XlaCompile annotation, ignore the node if it is + // true. Nodes are marked with this attr via experimental_jit_scope, and + // will be handled by the mark_for_compilation pass. + bool compile = false; + Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); + if (status.ok() && compile) { + continue; + } + // If there is already a _XlaCluster annotation, ignore the node. Nodes are + // marked with this attr to indicate they are already part of a cluster and + // hence ignored. + status = GetNodeAttr(node->attrs(), kXlaClusterAttr, &compile); + if (status.ok()) { + continue; + } + + // If there is an explicit XLA device placement, ignore the node. + DeviceType device_type(""); + TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type)); + if (device_type.type_string().find("XLA") != string::npos) continue; + + // Assume all fusable ops are registered. + // TODO(hpucha): Check for registration if possible. + if (!IsXlaFusable(node->def())) { + continue; + } + + // XLA does not offer guaranteed aliasing between the input and output of + // the XLA cluster so it can't implement the forward-tensor-ref semantic. + // Leave such nodes out of XLA clusters. + if (HasForwardedRefInput(*node)) { + continue; + } + + compilation_candidates.insert(node); + } + + if (compilation_candidates.empty()) { + VLOG(2) << "No compilable candidates"; + *output = item.graph; + return Status::OK(); + } + + GraphCycles cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles)); + + // TODO(hpucha): Make clustering more robust. There are two known issues that + // we need to mitigate: (a) Non-resource variables can cause deadlocks + // when clustering changes order of execution. See b/77263461 for a specific + // example. (b) Queue operations can also cause deadlocks. See b/77261498 for + // example. + + struct Cluster { + // Identifies the node that represents this cluster in the cycle detection + // graph. + int representative = -1; + }; + + // Each compilation candidate belongs to a cluster. The cluster's + // representative names the node in the 'cycles' graph that represents the + // cluster. + std::vector> clusters(graph.num_node_ids()); + std::deque*> worklist; + for (Node* node : compilation_candidates) { + Cluster& cluster = clusters[node->id()].Get(); + cluster.representative = node->id(); + worklist.push_back(&clusters[node->id()]); + } + + // Repeatedly contract edges between clusters that are on the same device, + // provided the contraction would not create a cycle. This is a simplified + // version of the clustering in mark_for_compilation_pass that also deals with + // nodes that are explicitly tagged to be compiled/clustered. + while (!worklist.empty()) { + int from = worklist.front()->Get().representative; + worklist.pop_front(); + + Node* node_from = graph.FindNodeId(from); + if (node_from->IsControlFlow()) { + // Control flow nodes aren't compilation candidates and should never + // appear. + return errors::Internal( + "Found control flow node in clustering worklist: ", + node_from->type_string()); + } + for (int to : cycles.Successors(from)) { + if (to >= graph.num_node_ids()) { + // Node is a "frame" node that is present only in the cycle detection + // graph. No clustering is possible. + continue; + } + Node* node_to = graph.FindNodeId(to); + if (compilation_candidates.find(node_to) == + compilation_candidates.cend()) { + continue; + } + + // Do not cluster across devices. + if (node_from->def().device() != node_to->def().device()) { + VLOG(2) << "Devices " << node_from->def().device() << " " + << node_to->def().device(); + VLOG(2) << "Device names " << node_from->assigned_device_name() << " " + << node_to->assigned_device_name(); + continue; + } + + // Ops that consume shapes cannot be the root of a cluster. This is an + // optimization. + if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) { + continue; + } + + // If contracting the edge would create a cycle, bail out. + // However, just because we can't merge the clusters now does not mean + // we won't be able to merge them in the future. + // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge + // 1->3. But if we first contract 1->2 then we can later contract 1->3. + if (!cycles.ContractEdge(from, to)) continue; + + // Merge the clusters. ContractEdge uses 'from' as the number of the + // merged node, so make sure 'from' is the chosen representative. + clusters[from].Merge(&clusters[to]); + + worklist.push_back(&clusters[from]); + break; + } + } + + // Count the number of non-trivial elements in each cluster. + std::vector effective_cluster_sizes(graph.num_node_ids()); + for (const Node* n : compilation_candidates) { + int cluster = clusters[n->id()].Get().representative; + // Identity nodes will be removed if the node gets marked for compilation. + // Therefore we don't want to count them towards the effective cluster size. + if (n->def().op() != "Identity") { + effective_cluster_sizes[cluster]++; + } + } + + const int min_cluster_size = 2; + int num_clusters = 0; + for (auto size : effective_cluster_sizes) { + if (size >= min_cluster_size) { + VLOG(3) << "Cluster " << num_clusters << " " << size; + num_clusters++; + } + } + + // Names for each cluster. + std::unordered_map cluster_names; + // Sequence number generator to ensure clusters have unique names. + static std::atomic cluster_sequence_num; + + for (Node* n : compilation_candidates) { + int cluster = clusters[n->id()].Get().representative; + + // Compile if this is a cluster of >= min_cluster_size compilable operators. + if (effective_cluster_sizes[cluster] >= min_cluster_size) { + string& name = cluster_names[cluster]; + + if (name.empty()) { + name = strings::StrCat("cluster_", cluster_sequence_num++); + } + n->AddAttr(kXlaClusterAttr, name); + VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; + } + } + + graph.ToGraphDef(output); + return Status::OK(); +} + +REGISTER_GRAPH_OPTIMIZER_AS(XlaFusionOptimizer, "xla-fusion"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.h b/tensorflow/compiler/jit/xla_fusion_optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..3d2309e782d38725f8db025fbfda0bf0f63d18be --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { + +// Optimizes graphs by fusing ops where possible, resulting in more efficient +// execution. +class XlaFusionOptimizer : public grappler::CustomGraphOptimizer { + public: + XlaFusionOptimizer() {} + ~XlaFusionOptimizer() override {} + + Status Init( + const RewriterConfig_CustomGraphOptimizer* config = nullptr) override { + return Status::OK(); + } + + string name() const override { return "xla-fusion"; }; + + Status Optimize(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + GraphDef* output) override; + + void Feedback(grappler::Cluster* cluster, const grappler::GrapplerItem& item, + const GraphDef& optimize_output, double result) override { + // Nothing to do for XlaFusionOptimizer. + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5736760a878dc857a8558093054d0adc0f727398 --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_fusion_optimizer.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +REGISTER_OP("UncompilableNullary").Output("o: float"); +REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); + +class XlaFusionOptimizerTest : public grappler::GrapplerTest { + protected: + std::unordered_map GetClusters(const GraphDef& graph) { + std::unordered_map ids; + for (const NodeDef& node : graph.node()) { + string cluster; + if (GetNodeAttr(AttrSlice(node), kXlaClusterAttr, &cluster).ok()) { + CHECK(!cluster.empty()); + ids[node.name()] = cluster; + } + } + return ids; + } +}; + +TEST_F(XlaFusionOptimizerTest, Chains) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = + ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); + Node* d = + ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D")); + Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E")); + ops::UnaryOp("Relu", e, builder.opts().WithName("F")); + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(4, clusters.size()); + EXPECT_EQ(clusters["B"], clusters["C"]); + EXPECT_EQ(clusters["E"], clusters["F"]); + EXPECT_NE(clusters["B"], clusters["E"]); + EXPECT_TRUE(clusters.find("A") == clusters.cend()); + EXPECT_TRUE(clusters.find("D") == clusters.cend()); +} + +TEST_F(XlaFusionOptimizerTest, FusableOps) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp( + "Placeholder", + builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT)); + Node* b = ops::SourceOp( + "Placeholder", + builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT)); + + Node* c = ops::BinaryOp("Add", a, b, builder.opts().WithName("C")); + ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D")); + ops::UnaryOp("Abs", c, builder.opts().WithName("E")); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(2, clusters.size()); + EXPECT_EQ(clusters["C"], clusters["E"]); + EXPECT_TRUE(clusters.find("D") == clusters.cend()); +} + +TEST_F(XlaFusionOptimizerTest, IgnoreExplicitXLAAttrs) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp( + "Placeholder", + builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT)); + Node* b = ops::SourceOp( + "Placeholder", + builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT)); + + Node* c = ops::BinaryOp( + "Add", a, b, + builder.opts().WithName("C").WithDevice("/device:XLA_CPU")); + ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D")); + Node* e = ops::UnaryOp("Abs", c, builder.opts().WithName("E")); + ops::UnaryOp("Cos", e, + builder.opts().WithName("F").WithAttr(kXlaCompileAttr, true)); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_TRUE(clusters.empty()); +} + +TEST_F(XlaFusionOptimizerTest, UncompilableCycles) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = + ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B")); + ops::BinaryOp("Mul", a, b, builder.opts().WithName("C")); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_TRUE(clusters.empty()); +} + +TEST_F(XlaFusionOptimizerTest, CompilableCycles) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + ops::BinaryOp("Mul", a, b, builder.opts().WithName("C")); + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(3, clusters.size()); + EXPECT_EQ(clusters["A"], clusters["B"]); + EXPECT_EQ(clusters["A"], clusters["C"]); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 2326070358d67c0cf30ef17fab5c93862cd8932c..c0d86a28c7698c302e28bab972bb2f847cc00ca4 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -34,19 +34,37 @@ class XlaGpuDeviceFactory : public DeviceFactory { Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector* devices) { + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = DEVICE_GPU_XLA_JIT; + registration.requires_compilation = true; + registration.enable_jit_by_default = false; + registration.compile_resource_ops = true; + static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT); (void)registrations; std::unique_ptr device; - Status status = XlaDevice::Create( - "CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, - /*register_device_for_compilation=*/true, &device); + Status status = + XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, + name_prefix, registration, + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, + /*padded_shape_fn=*/{}, &device); if (!status.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. VLOG(1) << "Failed to create XLA_GPU device: " << status; return Status::OK(); } + + // TODO(b/78468222): Uncomment after fixing this bug + // status = device->CreateAndSetGpuDeviceInfo(); + // if (!status.ok()) { + // errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT, + // " device"); + // return status; + // } + devices->push_back(device.release()); return Status::OK(); } @@ -55,8 +73,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaGpuTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kAllXlaGpuTypes = { + {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, + DT_BFLOAT16}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes); diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 2614deefd8823dcb8f38e9e22ae4e78145d0d96a..661187f4a873b03b8d013aa74cb6b6315bb4e2eb 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -25,8 +25,8 @@ namespace tensorflow { const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER"; const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT"; -constexpr std::array kExecAllTypes = { - {DT_INT32, DT_FLOAT, DT_BOOL, DT_DOUBLE, DT_INT64}}; +constexpr std::array kExecAllTypes = { + {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; class XlaInterpreterDeviceFactory : public DeviceFactory { public: @@ -41,10 +41,19 @@ Status XlaInterpreterDeviceFactory::CreateDevices( DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT); (void)registrations; + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; + registration.requires_compilation = true; + registration.enable_jit_by_default = false; + registration.compile_resource_ops = true; + std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create( - "Interpreter", DEVICE_XLA_INTERPRETER, 0, DEVICE_INTERPRETER_XLA_JIT, - options, name_prefix, /*register_device_for_compilation=*/true, &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0, + DEVICE_INTERPRETER_XLA_JIT, options, + name_prefix, registration, + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, + /*padded_shape_fn=*/{}, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..d0c7a9365125708b2af43f87c7617d8d84050a61 --- /dev/null +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -0,0 +1,296 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_launch_util.h" + +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/gpu_device_context.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/util/stream_executor_util.h" + +namespace tensorflow { +namespace { +using xla::ScopedShapedBuffer; +using xla::ShapedBuffer; +} // anonymous namespace + +std::map SnapshotResourceVariables( + OpKernelContext* ctx, const std::vector& variables) { + std::map snapshot; + for (int i : variables) { + Var* variable = nullptr; + ResourceHandle handle = HandleFromInput(ctx, i); + OptionalTensor& tensor = snapshot[i]; + if (LookupResource(ctx, handle, &variable).ok()) { + tf_shared_lock lock(*variable->mu()); + tensor.name = handle.name(); + tensor.present = true; + tensor.value = *variable->tensor(); + } + } + return snapshot; +} + +XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped) + : xla::DeviceMemoryAllocator(platform), wrapped_(wrapped) {} + +XlaAllocator::~XlaAllocator() {} + +xla::StatusOr XlaAllocator::Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) { + AllocationAttributes attrs; + attrs.no_retry_on_failure = !retry_on_failure; + void* data = + wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs); + if (data == nullptr) { + return errors::ResourceExhausted("Out of memory while trying to allocate ", + size, " bytes."); + } + return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size), + device_ordinal, this); +} + +Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { + wrapped_->DeallocateRaw(mem.opaque()); + return Status::OK(); +} + +namespace internal { +// Return the 'index''th subtree of the given ShapedBuffer as a +// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the +// subtree, and sets the input's buffer pointers to nullptr for the subtree. +ScopedShapedBuffer ExtractSubShapedBuffer( + ShapedBuffer* shaped_buffer, int index, + xla::DeviceMemoryAllocator* allocator) { + const xla::Shape& on_host_shape = xla::ShapeUtil::GetTupleElementShape( + shaped_buffer->on_host_shape(), index); + const xla::Shape& on_device_shape = xla::ShapeUtil::GetTupleElementShape( + shaped_buffer->on_device_shape(), index); + + ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape, + shaped_buffer->platform(), + shaped_buffer->device_ordinal()); + + auto& shape_tree = shaped_buffer->buffers(); + auto& sub_shape_tree = sub_shaped_buffer.buffers(); + sub_shape_tree.CopySubtreeFrom(shape_tree, + /*source_base_index=*/{index}, + /*target_base_index=*/{}); + shape_tree.ForEachMutableElement( + [index](const xla::ShapeIndex& shape_index, + tensorflow::se::DeviceMemoryBase* data) { + // shape_index is empty for the root node. Ignore that. + if (!shape_index.empty() && shape_index[0] == index) { + *data = tensorflow::se::DeviceMemoryBase(nullptr, 0); + } + }); + return ScopedShapedBuffer(std::move(sub_shaped_buffer), allocator); +} +} // namespace internal +using internal::ExtractSubShapedBuffer; + +XlaComputationLaunchContext::XlaComputationLaunchContext( + xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, + bool allocate_xla_tensors) + : client_(client), + xla_allocator_(xla_allocator), + allocate_xla_tensors_(allocate_xla_tensors) {} + +void XlaComputationLaunchContext::PopulateInputs( + OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, + const std::map& variables) { + // Build ShapedBuffers that point directly to the Tensor buffers. + arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1); + arg_buffers_.resize(kernel->xla_input_shapes.size()); + arg_ptrs_ = std::vector(arg_buffers_.size()); + + // Pass remaining parameters. + const Tensor* t; + for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { + int arg_num = kernel->input_mapping[i]; + const xla::Shape& shape = kernel->xla_input_shapes[i]; + if (variables.count(arg_num)) { + t = &(variables.at(arg_num).value); + CHECK(t); + } else { + t = &(ctx->input(arg_num)); + } + + const xla::Shape on_device_shape = + client_->backend().transfer_manager()->HostShapeToDeviceShape(shape); + if (xla::ShapeUtil::IsTuple(on_device_shape)) { + const XlaTensor* xla_tensor = XlaTensor::FromTensor(t); + CHECK(xla_tensor && xla_tensor->has_shaped_buffer()); + arg_ptrs_[i] = const_cast(&xla_tensor->shaped_buffer()); + } else { + CHECK(xla::ShapeUtil::Equal(shape, on_device_shape)) + << "On-device shape " + << xla::ShapeUtil::HumanStringWithLayout(on_device_shape) + << " not the same as on-host shape " + << xla::ShapeUtil::HumanStringWithLayout(shape); + se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t); + arg_buffers_[i] = xla::MakeUnique( + /*on_host_shape=*/shape, /*on_device_shape=*/shape, + client_->platform(), client_->default_device_ordinal()); + arg_buffers_[i]->set_buffer(dmem, /*index=*/{}); + arg_ptrs_[i] = arg_buffers_[i].get(); + } + } +} + +void XlaComputationLaunchContext::PopulateOutputs( + OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, + ScopedShapedBuffer output) { + se::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + + // Computation output should always be a tuple. + if (VLOG_IS_ON(2)) { + VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString(); + VLOG(2) << "Result tuple shape (on device): " + << output.on_device_shape().DebugString(); + } + CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); + + // Copy XLA results to the OpOutputList. + int output_num = 0; + for (int i = 0; i < ctx->num_outputs(); ++i) { + Allocator* allocator = ctx->device()->GetAllocator({}); + if (kernel->outputs[i].is_constant) { + // Output is a constant. + const Tensor& const_tensor = kernel->outputs[i].constant_value; + Tensor* output_tensor; + const size_t total_bytes = const_tensor.TotalBytes(); + if (stream && total_bytes > 0) { + // Copy host -> device. (Empty tensors don't have backing buffers.) + // Manually allocate memory using an XlaTensorBuffer so we can allocate + // as much memory as the device requires (as given by + // GetByteSizeRequirement). This avoids XlaTransferManager having to + // reallocate the device buffer later. + VLOG(1) << "Constant output tensor on device"; + + OP_REQUIRES_OK( + ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); + + Device* device = dynamic_cast(ctx->device()); + OP_REQUIRES(ctx, device != nullptr, + errors::Internal("DeviceBase was not a Device.")); + ctx->op_device_context()->CopyCPUTensorToDevice( + &const_tensor, device, output_tensor, + [&](Status status) { TF_CHECK_OK(status); }); + + if (device->device_type() == DEVICE_GPU) { + // The GPUDeviceContext enqueues the host->device transfer in a + // separate stream from the main compute stream. We must ensure the + // compute stream is synchronized with the host->device transfer + // stream now otherwise we will create a race condition. + auto* gpu_device_context = + static_cast(ctx->op_device_context()); + gpu_device_context->stream()->ThenWaitFor( + gpu_device_context->host_to_device_stream()); + } + } else { + // No copy required. + ctx->set_output(i, const_tensor); + output_tensor = ctx->mutable_output(i); + } + if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) { + xla_tensor->set_host_tensor(const_tensor); + } + } else { + const TensorShape& shape = kernel->outputs[i].shape; + VLOG(2) << "Retval " << i << " shape " << shape.DebugString(); + + se::DeviceMemoryBase buffer = output.buffer({output_num}); + if (allocate_xla_tensors_) { + Tensor* output_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor)); + XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); + CHECK(xla_tensor); + xla_tensor->set_shaped_buffer(ScopedShapedBuffer( + ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + } else { + Tensor output_tensor = XlaTensorBuffer::MakeTensor( + ctx->expected_output_dtype(i), shape, buffer, allocator); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); + ctx->set_output(i, output_tensor); + } + ++output_num; + } + + if (VLOG_IS_ON(3)) { + VLOG(3) << ctx->mutable_output(i)->DebugString(); + } + } + + // Apply variable updates, if any. + VLOG(2) << "Applying variable updates"; + for (int i = 0; i < kernel->resource_updates.size(); ++i) { + Allocator* allocator = ctx->device()->GetAllocator({}); + const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; + OP_REQUIRES(ctx, + write.input_index >= 0 && write.input_index < ctx->num_inputs(), + errors::Internal("Invalid input index for variable write.")); + + se::DeviceMemoryBase buffer = output.buffer({output_num}); + + Var* variable = nullptr; + // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, + // not a Tensor. + OP_REQUIRES_OK(ctx, LookupOrCreateResource( + ctx, HandleFromInput(ctx, write.input_index), + &variable, [this, ctx, &write](Var** ptr) { + *ptr = new Var(write.type); + return Status::OK(); + })); + + core::ScopedUnref s(variable); + + mutex_lock ml(*variable->mu()); + OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type, + errors::Internal("Mismatched type in variable write")); + + if (allocate_xla_tensors_) { + Tensor output_tensor; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(write.type, write.shape, &output_tensor)); + XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); + CHECK(xla_tensor); + xla_tensor->set_shaped_buffer( + ExtractSubShapedBuffer(&output, output_num, xla_allocator_)); + *variable->tensor() = output_tensor; + } else { + Tensor output_tensor = XlaTensorBuffer::MakeTensor( + write.type, write.shape, buffer, allocator); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); + *variable->tensor() = output_tensor; + } + ++output_num; + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h new file mode 100644 index 0000000000000000000000000000000000000000..4390701ccbd0bc3971413ddcd917c11019990087 --- /dev/null +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -0,0 +1,159 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains utilities for launching compiled XLA kernels for a KernelContext. + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ + +#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/owning_device_memory.h" +#include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/variable_ops.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +class XlaAllocator; + +// Takes a snapshot of the values of resource variable arguments, whose +// indices are specified in `variables` argument. We snapshot tensors that back +// resource variables since concurrent updates may modify the shape, and it is +// important that the shapes used for compilation match the true shapes of the +// buffers. +// +// Returns a map of TensorFlow argument index to resource variable. If a +// resource variable is not initialized, the corresponding OptionalTensor +// will have its `present` field set to false. +std::map SnapshotResourceVariables( + OpKernelContext* ctx, const std::vector& variables); + +// Adapter class that wraps a Tensorflow allocator as an XLA allocator. +// Assumes that the Tensorflow allocator permits asynchronous deallocation: +// see comment on `AllowsAsynchronousDeallocation()`. +class XlaAllocator : public xla::DeviceMemoryAllocator { + public: + XlaAllocator(const se::Platform* platform, Allocator* wrapped); + ~XlaAllocator() override; + xla::StatusOr Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) override; + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; + + // The Tensorflow BFC allocator used on GPU allows host-side deallocation + // before GPU execution takes place. Tensorflow uses the ordering of the main + // compute stream to enforce a happens-before relationship between a memory + // allocation and code that reuses the same memory. If Tensorflow adds + // support for multiple GPU streams or allocators with different ordering + // requirements, this code may need to change. + // (This attribute has no effect on CPU.) + bool AllowsAsynchronousDeallocation() const override { return true; } + + private: + Allocator* wrapped_; +}; + +// Helper class to perform the marshalling of TensorFlow inputs and outputs to +// ShapedBuffers suitable for passing to an XLA computation. +class XlaComputationLaunchContext { + public: + // Create a new launch context. 'allocate_xla_tensors' is true if allocated + // output tensors and variables are always XlaTensors. If false they are + // assumed to be "normal" device pointers. + XlaComputationLaunchContext(xla::LocalClient* client, + xla::DeviceMemoryAllocator* xla_allocator, + bool allocate_xla_tensors); + + // Add all inputs within `ctx` as XLA arguments (returned by arguments()). + // `variables` is a map from TensorFlow argument number to resource variable. + void PopulateInputs(OpKernelContext* ctx, + const XlaCompiler::CompilationResult* kernel, + const std::map& variables); + + // Given the XLA output in `output`, populate all outputs of `ctx`. + void PopulateOutputs(OpKernelContext* ctx, + const XlaCompiler::CompilationResult* kernel, + xla::ScopedShapedBuffer output); + + // Return the argument list. Only valid after PopulateInputs() has been + // called. + const std::vector& arguments() const { return arg_ptrs_; } + + private: + xla::LocalClient* client_; + xla::DeviceMemoryAllocator* xla_allocator_; + bool allocate_xla_tensors_; + std::vector> arg_buffers_; + std::vector arg_ptrs_; +}; + +// A simple TensorBuffer implementation that allows us to create Tensors that +// take ownership of pre-allocated memory. +class XlaTensorBuffer : public TensorBuffer { + public: + XlaTensorBuffer(const void* ptr, size_t expected_size, size_t actual_size, + Allocator* allocator) + : expected_size_(expected_size), + actual_size_(actual_size), + allocator_(allocator) { + data_ = const_cast(ptr); + } + + ~XlaTensorBuffer() override { allocator_->DeallocateRaw(data_); } + + void* data() const override { return data_; } + size_t size() const override { return expected_size_; } + + TensorBuffer* root_buffer() override { return this; } + + void FillAllocationDescription(AllocationDescription* proto) const override { + proto->set_allocated_bytes(actual_size_); + } + + static Tensor MakeTensor(DataType dtype, const TensorShape& shape, + se::DeviceMemoryBase buffer, Allocator* allocator) { + size_t expected_size = shape.num_elements() * DataTypeSize(dtype); + auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size, + buffer.size(), allocator); + Tensor t(dtype, shape, tensor_buffer); + tensor_buffer->Unref(); + return t; + } + + private: + void* data_; + size_t expected_size_; + size_t actual_size_; + Allocator* allocator_; +}; + +// Exposed in this header file for microbenchmarking purposes, but this is an +// internal implementation detail. +namespace internal { +// Return the 'index''th subtree of the given ShapedBuffer as a +// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the +// subtree, and sets the input's buffer pointers to nullptr for the subtree. +xla::ScopedShapedBuffer ExtractSubShapedBuffer( + xla::ShapedBuffer* shaped_buffer, int index, + xla::DeviceMemoryAllocator* allocator); +} // namespace internal + +} // namespace tensorflow + +#endif diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a45932403ec1760d6b985d5357fd6d84fbf257a2 --- /dev/null +++ b/tensorflow/compiler/jit/xla_launch_util_test.cc @@ -0,0 +1,64 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains microbenchmarks for performance critical functions in +// xla_launch_util.cc. + +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +// Test ExtractSubBuffer with different depths (depth of ShapeTree) and fan-outs +// (cardinality of each non-leaf node's children). +void BM_ExtractSubBuffer(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = xla::ShapeUtil::MakeTupleShape(shapes); + } + xla::ShapedBuffer shaped_buffer(shape, shape, /*platform=*/nullptr, + /*device_ordinal=*/0); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + // Extract a buffer from approximately the middle of the first level of the + // tree. + (void)tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer, + /*index=*/fan_out / 2, + /*allocator=*/nullptr) + .release(); + } +} + +BENCHMARK(BM_ExtractSubBuffer) + ->ArgPair(1, 4) + ->ArgPair(1, 8) + ->ArgPair(1, 32) + ->ArgPair(1, 64) + ->ArgPair(1, 128) + ->ArgPair(1, 256) + ->ArgPair(1, 512) + ->ArgPair(2, 4) + ->ArgPair(2, 8) + ->ArgPair(2, 32) + ->ArgPair(2, 64) + ->ArgPair(2, 128); + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + tensorflow::testing::RunBenchmarks(); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc new file mode 100644 index 0000000000000000000000000000000000000000..3c44c4ae6df7f3e2d60d8933561c0c71888e8c3f --- /dev/null +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" + +namespace tensorflow { + +/*static*/ XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) { + if (tensor->NumElements() == 0) { + return nullptr; + } + XlaTensor* xla_tensor = + FromOpaquePointer(const_cast(tensor->tensor_data().data())); + return xla_tensor; +} + +/*static*/ bool XlaTensor::RefCountIsOne(const Tensor& tensor) { + return tensor.RefCountIsOne(); +} + +/*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor( + const Tensor& tensor) { + const XlaTensor* xla_tensor = FromTensor(&tensor); + if (xla_tensor) { + CHECK(xla_tensor->has_shaped_buffer()); + return xla_tensor->shaped_buffer().root_buffer(); + } else { + return se::DeviceMemoryBase(const_cast(tensor.tensor_data().data()), + tensor.tensor_data().size()); + } +} + +Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, + xla::LocalClient* client, + int device_ordinal) { + xla::Shape on_host_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &on_host_shape)); + xla::Shape on_device_shape = + client->backend().transfer_manager()->HostShapeToDeviceShape( + on_host_shape); + + xla::ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape, + client->backend().memory_allocator(), + device_ordinal); + for (auto& index_to_buffer : shaped_buffer.buffers()) { + xla::Shape subshape = + xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first); + uint64 size = + client->backend().transfer_manager()->GetByteSizeRequirement(subshape); + TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer, + client->backend().memory_allocator()->Allocate( + device_ordinal, size, /*retry_on_failure=*/false)); + // Move our buffer into shaped_buffer, which takes ownership of it. + index_to_buffer.second = buffer.Forget(); + } + + VLOG(4) << shaped_buffer.ToString(); + + set_shaped_buffer(std::move(shaped_buffer)); + return Status::OK(); +} + +// The pointer tag, OR-ed into the XlaTensor's address to distinguish it from +// device-side tensors, which are either CPU or GPU memory pointers. This works +// because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits. +namespace { +constexpr uintptr_t kTag = 0x1ULL; +} + +/*static*/ XlaTensor* XlaTensor::FromOpaquePointer(void* ptr) { + uintptr_t value = reinterpret_cast(ptr); + if (value & kTag) { + return reinterpret_cast(value & ~kTag); + } else { + return nullptr; + } +} + +/*static*/ void* XlaTensor::ToOpaquePointer(XlaTensor* tensor) { + uintptr_t value = reinterpret_cast(tensor); + CHECK_EQ(value & kTag, 0); + value |= kTag; + return reinterpret_cast(value); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..c54001a999998f45c0cdacd752ca4036f0792857 --- /dev/null +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ + +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// The implementation of a Tensor for an XlaDevice. All device tensors are +// actually one of these. +// +// To distinguish between "normal" device tensors and XlaTensors, the raw +// pointer data stored in the TensorBuffer is a tagged pointer. +class XlaTensor { + public: + // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast + // fails. + static XlaTensor* FromTensor(const Tensor* tensor); + + static bool RefCountIsOne(const Tensor& tensor); + + // Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in + // which case the returned value is shaped_buffer()->root_buffer(), or a + // normal Tensor in which case the returned value is + // {tensor.tensor_data().data(), tensor.tensor_data().size}. + static se::DeviceMemoryBase DeviceMemoryFromTensor(const Tensor& tensor); + + // Assign the internal ShapedBuffer to new memory for the given dtype and + // shape. If a ShapedBuffer exists already (has_shaped_buffer() == true), it + // is replaced and the managed memory deallocated. + Status AllocateShapedBuffer(DataType dtype, const TensorShape& shape, + xla::LocalClient* client, int device_ordinal); + + // Some Tensors can have complex on-device shapes, including tuple shapes. To + // manage the memory for these tensors a ShapedBuffer may be required. + + // Return true if this XlaTensor contains a ShapedBuffer. + bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; } + // Return the contained ShapedBuffer. + // REQUIRES: has_shaped_buffer() + const xla::ShapedBuffer& shaped_buffer() const { + CHECK(has_shaped_buffer()); + return *shaped_buffer_; + } + xla::ShapedBuffer& shaped_buffer() { + CHECK(has_shaped_buffer()); + return *shaped_buffer_; + } + // Mutates the XlaTensor to set the ShapedBuffer. + void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) { + shaped_buffer_ = + xla::MakeUnique(std::move(shaped_buffer)); + } + + // Some tensors on the device may have known values on the host. We use these + // in on-demand mode to avoid re-copying values from the device if we know the + // host value already. + + // Return true if this XlaTensor contains a host tensor. + bool has_host_tensor() const { return host_tensor_ != nullptr; } + // Return the contained host tensor. + // REQUIRES: has_host_tensor() + const Tensor& host_tensor() const { return *host_tensor_; } + // Sets the contained host tensor. + void set_host_tensor(const Tensor& tensor) { + host_tensor_.reset(new Tensor(tensor)); + } + + // Convert from a raw pointer to an XlaTensor, removing the pointer tag. + static XlaTensor* FromOpaquePointer(void* ptr); + // Convert to a raw pointer from an XlaTensor, adding the pointer tag. + static void* ToOpaquePointer(XlaTensor* tensor); + + private: + // The optional contained ShapedBuffer. + std::unique_ptr shaped_buffer_; + // An optional host tensor value. + std::unique_ptr host_tensor_; +}; + +} // namespace tensorflow + +#endif diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD index da4bc44c7a75c9f8faf16c537a17a1f2d16d5d61..238fd15166c0b08ee109d6a3888e16c39f87a603 100644 --- a/tensorflow/compiler/plugin/BUILD +++ b/tensorflow/compiler/plugin/BUILD @@ -49,17 +49,3 @@ cc_library( "//tensorflow/compiler/jit:xla_device", ], ) - -#----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 782bf82d4149968d5e5fbfb93bbd4ff1dcd75494..e6c92f9720e1285617280f60d1c5fea443c5ebef 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -42,7 +42,7 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform", "//tensorflow/python:random_seed", "//tensorflow/python:session", @@ -58,7 +58,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -72,7 +72,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -86,11 +86,14 @@ tf_xla_py_test( # ArgMax needs CustomCall on CPU, which is not available in normal # (not precompiled) TensorFlow. The flag below excludes the CPU # backend. - disabled_backends = "cpu", + disabled_backends = [ + "cpu", + "cpu_ondemand", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -98,7 +101,7 @@ tf_xla_py_test( tf_xla_py_test( name = "binary_ops_test", - size = "small", + size = "medium", srcs = ["binary_ops_test.py"], shard_count = 5, tags = [ @@ -108,7 +111,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:bitwise_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -117,13 +120,27 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "bucketize_op_test", + size = "small", + srcs = ["bucketize_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "categorical_op_test", size = "small", srcs = ["categorical_op_test.py"], + tags = ["optonly"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], @@ -137,7 +154,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -152,7 +169,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -166,7 +183,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -180,7 +197,34 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", + "//tensorflow/python:gradient_checker", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "oom_test", + size = "medium", + srcs = ["oom_test.py"], + # TODO(b/80081500): Re-enable on GPU. Disabled on 2018-05-21. + disabled_backends = [ + "cpu", + "cpu_ondemand", + "gpu", + ], + tags = [ + # Allocates very large amounts of memory and does not work under TSAN. + "notsan", + "optonly", # Times out frequently in fastbuild. + ], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:array_ops_gen", + "//tensorflow/python:framework", "//tensorflow/python:gradient_checker", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", @@ -196,7 +240,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -212,7 +256,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -234,7 +278,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -242,6 +286,18 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "dynamic_slice_ops_test", + size = "small", + srcs = ["dynamic_slice_ops_test.py"], + deps = [ + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + ], +) + tf_xla_py_test( name = "dynamic_stitch_test", size = "small", @@ -250,7 +306,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -259,11 +315,38 @@ tf_xla_py_test( name = "extract_image_patches_op_test", size = "small", srcs = ["extract_image_patches_op_test.py"], + tags = [ + "manual", + "notap", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "eager_test", + size = "small", + srcs = ["eager_test.py"], + disabled_backends = [ + # TODO(b/78199195) Support XLA CPU devices in eager runtime + "cpu", + "cpu_ondemand", + # TODO(b/78468222) Enable GPU backend + "gpu", + ], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:layers", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", "//tensorflow/python:platform_test", + "//tensorflow/python/eager:function", ], ) @@ -278,7 +361,7 @@ tf_xla_py_test( "//tensorflow/contrib/signal:signal_py", "//tensorflow/python:array_ops", "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:spectral_ops", ], @@ -292,19 +375,19 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) tf_xla_py_test( name = "ftrl_test", - size = "small", + size = "medium", srcs = ["ftrl_test.py"], deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -315,10 +398,12 @@ tf_xla_py_test( name = "function_test", size = "small", srcs = ["function_test.py"], + # Functions are not implemented in the on-demand compilation model yet. + disabled_backends = "cpu_ondemand", deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -333,12 +418,27 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:image_ops", "//tensorflow/python:platform_test", ], ) +tf_xla_py_test( + name = "listdiff_op_test", + size = "small", + srcs = ["listdiff_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform_test", + "@six_archive//:six", + ], +) + tf_xla_py_test( name = "lrn_ops_test", size = "medium", @@ -346,7 +446,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -361,7 +461,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -373,7 +473,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -387,7 +487,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -400,7 +500,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -413,7 +513,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -428,7 +528,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -445,7 +545,9 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], @@ -460,7 +562,23 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:errors", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "reduce_window_test", + size = "small", + srcs = ["reduce_window_test.py"], + disabled_backends = ["cpu_ondemand"], + deps = [ + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -473,7 +591,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", ], ) @@ -485,7 +603,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -497,7 +615,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -512,7 +630,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -525,7 +643,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:platform_test", @@ -540,7 +658,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -550,11 +668,13 @@ tf_xla_py_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], + # Stack ops are not implemented in the on-demand compilation model yet. + disabled_backends = "cpu_ondemand", deps = [ ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -567,7 +687,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/contrib/stateless", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -576,10 +696,12 @@ tf_xla_py_test( name = "tensor_array_ops_test", size = "small", srcs = ["tensor_array_ops_test.py"], + # TensorArray ops are not implemented in the on-demand compilation model yet. + disabled_backends = "cpu_ondemand", deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -598,7 +720,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -611,7 +733,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -625,7 +747,7 @@ tf_xla_py_test( srcs = ["fused_batchnorm_test.py"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn", @@ -644,7 +766,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -654,15 +776,31 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "while_test", + size = "small", + srcs = ["while_test.py"], + disabled_backends = ["cpu_ondemand"], + deps = [ + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "gather_test", size = "medium", srcs = ["gather_test.py"], + tags = ["noasan"], # times out, http://b/78599043 deps = [ ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -674,7 +812,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -687,21 +825,34 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) -cuda_py_test( +tf_xla_py_test( name = "xla_device_test", size = "small", srcs = ["xla_device_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( + name = "xla_device_gpu_test", + size = "small", + srcs = ["xla_device_gpu_test.py"], additional_deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", ], ) @@ -718,15 +869,23 @@ cuda_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", ], - # TODO(b/62961789): Test fails with SIGABRT - tags = [ - "manual", - "notap", +) + +cuda_py_test( + name = "dense_layer_test", + size = "small", + srcs = ["dense_layer_test.py"], + additional_deps = [ + "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:layers", + "//tensorflow/python:variables", ], ) @@ -769,7 +928,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:variables", @@ -784,7 +943,7 @@ cuda_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", @@ -822,21 +981,19 @@ tf_xla_py_test( srcs = ["fake_quant_ops_test.py"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], +tf_xla_py_test( + name = "placeholder_test", + size = "small", + srcs = ["placeholder_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], ) diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py index ec547e16cd9c91a1e25bc963b9a3cafddf7326cd..9d3a889b1f54c813e881bb03b5275f809af1b3c8 100644 --- a/tensorflow/compiler/tests/argminmax_test.py +++ b/tensorflow/compiler/tests/argminmax_test.py @@ -29,51 +29,70 @@ from tensorflow.python.platform import test class ArgMinMaxTest(xla_test.XLATestCase): - def _assertOpOutputMatchesExpected(self, op, inp, expected): - """Verifies that 'op' produces 'expected' when fed input 'inp' . + def _assertOpOutputMatchesExpected(self, op, axis, output_type, op_input, + expected): + """Verifies that 'op' produces 'expected' when fed input 'op_input' . Args: - op: operator to test - inp: numpy input array to use as input to 'op'. + op: argmin or argmax operator to test. + axis: integer axis to reduce across. + output_type: numpy datatype of the output to produce. + op_input: numpy input array to use as input to 'op'. expected: numpy array representing the expected output of 'op'. """ with self.test_session() as session: with self.test_scope(): pinp = array_ops.placeholder( - dtypes.as_dtype(inp.dtype), inp.shape, name="a") - output = op(pinp) - result = session.run(output, {pinp: inp}) + dtypes.as_dtype(op_input.dtype), op_input.shape, name="a") + output = op(pinp, axis=axis, output_type=output_type) + result = session.run(output, {pinp: op_input}) self.assertAllEqual(result, expected) def testArgMinMax(self): # Complex numbers do not support argmin/argmax. minmax_types = set(self.numeric_types) - set(self.complex_types) for dtype in minmax_types: - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32), - np.array([1, 10, 27, 3, 3, 4], dtype=dtype), - expected=np.int32(2)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32), - np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), - expected=np.array([0, 1, 0], dtype=np.int32)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmax(x, axis=1, output_type=dtypes.int32), - np.array([[4, 1], [3, 2]], dtype=dtype), - expected=np.array([0, 0], dtype=np.int32)) + # output_type is a numpy data type that is used to specify the desired + # output type of the op as well as to convert the Python number to the + # array scalar of the type. + for output_type in self.int_types: + self._assertOpOutputMatchesExpected( + math_ops.argmax, + axis=0, + output_type=output_type, + op_input=np.array([1, 10, 27, 3, 3, 4], dtype=dtype), + expected=output_type(2)) + self._assertOpOutputMatchesExpected( + math_ops.argmax, + axis=0, + output_type=output_type, + op_input=np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), + expected=np.array([0, 1, 0], dtype=output_type)) + self._assertOpOutputMatchesExpected( + math_ops.argmax, + axis=1, + output_type=output_type, + op_input=np.array([[4, 1], [3, 2]], dtype=dtype), + expected=np.array([0, 0], dtype=output_type)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmin(x, axis=0, output_type=dtypes.int32), - np.array([3, 10, 27, 3, 2, 4], dtype=dtype), - expected=np.int32(4)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmin(x, axis=0, output_type=dtypes.int32), - np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), - expected=np.array([1, 0, 1], dtype=np.int32)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmin(x, axis=1, output_type=dtypes.int32), - np.array([[4, 1], [3, 2]], dtype=dtype), - expected=np.array([1, 1], dtype=np.int32)) + self._assertOpOutputMatchesExpected( + math_ops.argmin, + axis=0, + output_type=output_type, + op_input=np.array([3, 10, 27, 3, 2, 4], dtype=dtype), + expected=output_type(4)) + self._assertOpOutputMatchesExpected( + math_ops.argmin, + axis=0, + output_type=output_type, + op_input=np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), + expected=np.array([1, 0, 1], dtype=output_type)) + self._assertOpOutputMatchesExpected( + math_ops.argmin, + axis=1, + output_type=output_type, + op_input=np.array([[4, 1], [3, 2]], dtype=dtype), + expected=np.array([1, 1], dtype=output_type)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 30a6d3a74d64f90ad33062df6d1e16e3a575bd63..1e4dd32916c3a40282735fb8f75670b0e9ef0dc9 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -71,7 +71,7 @@ class BinaryOpsTest(XLATestCase): expected=np.array([[[[False, True], [True, False]]]], dtype=dtype)) self._testBinary( - gen_math_ops._real_div, + gen_math_ops.real_div, np.array([3, 3, -1.5, -8, 44], dtype=dtype), np.array([2, -2, 7, -4, 0], dtype=dtype), expected=np.array( @@ -108,57 +108,57 @@ class BinaryOpsTest(XLATestCase): [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype)) self._testBinary( - gen_math_ops._reciprocal_grad, + gen_math_ops.reciprocal_grad, np.array([4, -3, -2, 1], dtype=dtype), np.array([5, -6, 7, -8], dtype=dtype), expected=np.array([-80, 54, -28, 8], dtype=dtype)) self._testBinary( - gen_math_ops._sigmoid_grad, + gen_math_ops.sigmoid_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([-60, -36, -14, 0], dtype=dtype)) self._testBinary( - gen_math_ops._rsqrt_grad, + gen_math_ops.rsqrt_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([-160, -81, -28, -4], dtype=dtype)) self._testBinary( - gen_math_ops._sqrt_grad, + gen_math_ops.sqrt_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([0.625, 1, 1.75, 4], dtype=dtype)) self._testBinary( - gen_nn_ops._softplus_grad, + gen_nn_ops.softplus_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array( [3.97322869, 2.99258232, 1.99817801, 0.99966466], dtype=dtype)) self._testBinary( - gen_nn_ops._softsign_grad, + gen_nn_ops.softsign_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array( [0.11111111, 0.06122449, 0.03125, 0.01234568], dtype=dtype)) self._testBinary( - gen_math_ops._tanh_grad, + gen_math_ops.tanh_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([-75, -48, -21, 0], dtype=dtype)) self._testBinary( - gen_nn_ops._elu_grad, + gen_nn_ops.elu_grad, np.array([1, 2, 3, 4, 5, 6], dtype=dtype), np.array([-.6, -.4, -.2, 0, .2, .4], dtype=dtype), expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype)) self._testBinary( - gen_nn_ops._selu_grad, + gen_nn_ops.selu_grad, np.array([1, 2, 3, 4, 5, 6], dtype=dtype), np.array([-.6, -.4, -.2, .2, .4, .6], dtype=dtype), expected=np.array( @@ -166,20 +166,20 @@ class BinaryOpsTest(XLATestCase): 4.202803949422, 5.2535049367774, 6.30420592413], dtype=dtype)) self._testBinary( - gen_nn_ops._relu_grad, + gen_nn_ops.relu_grad, np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype), expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10], dtype=dtype)) self._testBinary( - gen_nn_ops._relu6_grad, + gen_nn_ops.relu6_grad, np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=dtype), np.array( [0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9, 6.1, 10.0], dtype=dtype), expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 0, 0], dtype=dtype)) self._testBinary( - gen_nn_ops._softmax_cross_entropy_with_logits, + gen_nn_ops.softmax_cross_entropy_with_logits, np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype), np.array([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]], dtype=dtype), expected=[ @@ -190,24 +190,29 @@ class BinaryOpsTest(XLATestCase): ], equality_test=self.ListsAreClose) - self._testBinary( - gen_nn_ops._sparse_softmax_cross_entropy_with_logits, - np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], - [0.9, 1.0, 1.1, 1.2]], dtype=dtype), - np.array([2, 1, 7], dtype=np.int32), - expected=[ - np.array([1.342536, 1.442536, np.nan], dtype=dtype), - np.array([[0.213838, 0.236328, -0.738817, 0.288651], - [0.213838, -0.763672, 0.261183, 0.288651], - [np.nan, np.nan, np.nan, np.nan]], - dtype=dtype), - ], - equality_test=self.ListsAreClose) + # TODO(b/68813416): Fails with bfloat16. + if dtype != dtypes.bfloat16.as_numpy_dtype: + self._testBinary( + gen_nn_ops.sparse_softmax_cross_entropy_with_logits, + np.array( + [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], + [0.9, 1.0, 1.1, 1.2]], + dtype=dtype), + np.array([2, 1, 7], dtype=np.int32), + expected=[ + np.array([1.342536, 1.442536, np.nan], dtype=dtype), + np.array( + [[0.213838, 0.236328, -0.738817, 0.288651], [ + 0.213838, -0.763672, 0.261183, 0.288651 + ], [np.nan, np.nan, np.nan, np.nan]], + dtype=dtype), + ], + equality_test=self.ListsAreClose) def testIntOps(self): for dtype in self.int_types: self._testBinary( - gen_math_ops._truncate_div, + gen_math_ops.truncate_div, np.array([3, 3, -1, -9, -8], dtype=dtype), np.array([2, -2, 7, 2, -4], dtype=dtype), expected=np.array([1, -1, 0, -4, 2], dtype=dtype)) @@ -232,11 +237,16 @@ class BinaryOpsTest(XLATestCase): expected=np.right_shift(lhs, rhs)) if dtype in [np.int8, np.int16, np.int32, np.int64]: - lhs = np.array([-1, -5, -3, -14], dtype=dtype) - rhs = np.array([5, 0, 1, 11], dtype=dtype) - self._testBinary( - bitwise_ops.right_shift, lhs, rhs, - expected=np.right_shift(lhs, rhs)) + lhs = np.array([-1, -5, -3, -14, -2], dtype=dtype) + rhs = np.array([5, 0, 1, 11, 36], dtype=dtype) + # HLO has saturating shift behavior. + bits = np.ceil( + np.log(np.iinfo(dtype).max - np.iinfo(dtype).min) / np.log(2)) + expected = [ + np.right_shift(l, r) if r < bits else np.sign(l) + for l, r in zip(lhs, rhs) + ] + self._testBinary(bitwise_ops.right_shift, lhs, rhs, expected=expected) def testNumericOps(self): for dtype in self.numeric_types: @@ -258,9 +268,9 @@ class BinaryOpsTest(XLATestCase): self._testBinary( math_ops.subtract, - np.array([1, 2], dtype=dtype), - np.array([10, 20], dtype=dtype), - expected=np.array([-9, -18], dtype=dtype)) + np.array([1, 2, 100], dtype=dtype), + np.array([10, 20, -1], dtype=dtype), + expected=np.array([-9, -18, 101], dtype=dtype)) self._testBinary( math_ops.subtract, dtype(5), @@ -350,6 +360,14 @@ class BinaryOpsTest(XLATestCase): np.array([2, -1], dtype=dtype), expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype)) + if np.int64 in self.numeric_types: + self._testBinary( + math_ops.add, + np.array([0xffffffff, 0xfffffffff, 1, 1], dtype=np.int64), + np.array([1, 1, 0xffffffff, 0xfffffffff], dtype=np.int64), + expected=np.array([1 << 32, 1 << 36, 1 << 32, 1 << 36], + dtype=np.int64)) + def testComplexOps(self): for dtype in self.complex_types: ctypes = {np.complex64: np.float32} @@ -369,7 +387,7 @@ class BinaryOpsTest(XLATestCase): expected=np.array([[[[False, True], [True, False]]]], dtype=dtype)) self._testBinary( - gen_math_ops._real_div, + gen_math_ops.real_div, np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j], dtype=dtype), np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j], dtype=dtype), expected=np.array( @@ -378,7 +396,7 @@ class BinaryOpsTest(XLATestCase): # Test inf/nan scenarios. self._testBinary( - gen_math_ops._real_div, + gen_math_ops.real_div, np.array([4 + 3j, 4, 3j, -4, -4j, 2 - 3j], dtype=dtype), np.array([0, 0, 0, 0, 0, 0], dtype=dtype), expected=np.array( @@ -418,19 +436,19 @@ class BinaryOpsTest(XLATestCase): lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype) rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype) self._testBinary( - gen_math_ops._reciprocal_grad, lhs, rhs, expected=-rhs * lhs * lhs) + gen_math_ops.reciprocal_grad, lhs, rhs, expected=-rhs * lhs * lhs) self._testBinary( - gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) + gen_math_ops.sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) self._testBinary( - gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) + gen_math_ops.rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) self._testBinary( - gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) + gen_math_ops.sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) self._testBinary( - gen_math_ops._tanh_grad, lhs, rhs, expected=rhs * (1 - lhs * lhs)) + gen_math_ops.tanh_grad, lhs, rhs, expected=rhs * (1 - lhs * lhs)) def testComplexMath(self): for dtype in self.complex_types: @@ -538,7 +556,7 @@ class BinaryOpsTest(XLATestCase): if dtype not in self.complex_types: # floordiv unsupported for complex. self._testBinary( - gen_math_ops._floor_div, + gen_math_ops.floor_div, np.array([3, 3, -1, -9, -8], dtype=dtype), np.array([2, -2, 7, 2, -4], dtype=dtype), expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) @@ -554,12 +572,12 @@ class BinaryOpsTest(XLATestCase): def _testRemainder(self, dtype): """Test cases for remainder operators.""" self._testBinary( - gen_math_ops._floor_mod, + gen_math_ops.floor_mod, np.array([3, 3, -1, -8], dtype=dtype), np.array([2, -2, 7, -4], dtype=dtype), expected=np.array([1, -1, 6, 0], dtype=dtype)) self._testBinary( - gen_math_ops._truncate_mod, + gen_math_ops.truncate_mod, np.array([3, 3, -1, -8], dtype=dtype), np.array([2, -2, 7, -4], dtype=dtype), expected=np.array([1, 1, -1, 0], dtype=dtype)) @@ -668,6 +686,11 @@ class BinaryOpsTest(XLATestCase): np.array([[10], [7], [2]], dtype=np.float32), np.float32(7), expected=np.array([[False], [False], [True]], dtype=np.bool)) + self._testBinary( + less_op, + np.array([[10], [7], [2], [-1]], dtype=np.int64), + np.int64(7), + expected=np.array([[False], [False], [True], [True]], dtype=np.bool)) for less_equal_op in [math_ops.less_equal, (lambda x, y: x <= y)]: self._testBinary( @@ -686,6 +709,80 @@ class BinaryOpsTest(XLATestCase): np.float32(7), expected=np.array([[False], [True], [True]], dtype=np.bool)) + def testS64Comparisons(self): + for op in [(lambda x, y: x < y), (lambda x, y: x <= y), + (lambda x, y: x >= y), (lambda x, y: x > y)]: + lhs = np.array( + [ + np.int64(0x000000007FFFFFFF), + np.int64(0x000000007FFFFFFF), + np.int64(0x0000000080000000), + np.int64(0x0000000080000000), + np.int64(0x0000000080000001), + np.int64(0x00000000FFFF0000), + np.int64(0x00000000FFFF0000), + np.int64(0x00000000FFFFFFFE), + np.int64(0x00000000FFFFFFFF), + np.int64(0x00000000FFFFFFFF), + np.int64(0x0000000100000000), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(-0x7FFFFFFF00000002), + np.int64(-0x7FFFFFFF00000002), + np.int64(-0x7FFFFFFF00000001), + np.int64(-0x7FFFFFFF00000001), + np.int64(-0x7FFFFFFF00000001), + np.int64(-0x7FFFFFFF00000001), + np.int64(0x7ffffffefff00010), + np.int64(0x7ffffffefff00010), + np.int64(-1), + np.int64(-1) + ], + dtype=np.int64) + rhs = np.array( + [ + np.int64(0x000000007FFFFFFE), + np.int64(0x000000007FFFFFFF), + np.int64(0x000000007FFFFFFF), + np.int64(0x0000000080000000), + np.int64(0x0000000080000001), + np.int64(0x00000000FFFF0000), + np.int64(0x00000000FFFF0001), + np.int64(0x00000000FFFFFFFF), + np.int64(0x00000000FFFFFFFE), + np.int64(0x00000000FFFFFFFF), + np.int64(0x00000000FFFFFFFF), + np.int64(0x0000000100000001), + np.int64(0x0000000100000002), + np.int64(0x0000000100000003), + np.int64(0x0000000200000001), + np.int64(0x0000000200000002), + np.int64(0x0000000200000003), + np.int64(0x0000000300000001), + np.int64(0x0000000300000002), + np.int64(0x0000000300000003), + np.int64(0x00000000FFFFFFFF), + np.int64(-0x7FFFFFFF00000001), + np.int64(0x00000000FFFFFFFE), + np.int64(0x00000000FFFFFFFF), + np.int64(-0x7FFFFFFF00000002), + np.int64(-0x7FFFFFFF00000001), + np.int64(0x00000000FFFFFFFF), + np.int64(-0x7FFFFFFF00000001), + np.int64(-2), + np.int64(-1) + ], + dtype=np.int64) + expected = np.array([op(l, r) for l, r in zip(lhs, rhs)], dtype=np.bool) + self._testBinary(op, lhs, rhs, expected=expected) + def testBroadcasting(self): """Tests broadcasting behavior of an operator.""" @@ -1045,6 +1142,20 @@ class BinaryOpsTest(XLATestCase): ], equality_test=self.ListsAreClose) + def splitvOp(x, y): # pylint: disable=invalid-name + return array_ops.split(value=y, num_or_size_splits=[2, 3], axis=x) + for axis in [1, -1]: + self._testBinary( + splitvOp, + np.int32(axis), + np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], + dtype=dtype), + expected=[ + np.array([[0, 1], [5, 6]], dtype=dtype), + np.array([[2, 3, 4], [7, 8, 9]], dtype=dtype), + ], + equality_test=self.ListsAreClose) + def testTile(self): for dtype in self.numeric_types: self._testBinary( diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fde9759a1c209844caac99d5f303cd3e406e5370 --- /dev/null +++ b/tensorflow/compiler/tests/bucketize_op_test.py @@ -0,0 +1,78 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for bucketize_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class BucketizationOpTest(XLATestCase): + + def testInt(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) + expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4] + self.assertAllEqual(expected_out, + sess.run(op, {p: [-5, 0, 2, 3, 5, 8, 10, 11, 12]})) + + def testFloat(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0., 3., 8., 11.]) + expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4] + self.assertAllEqual( + expected_out, + sess.run(op, {p: [-5., 0., 2., 3., 5., 8., 10., 11., 12.]})) + + def test2DInput(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) + expected_out = [[0, 1, 1, 2, 2], [3, 3, 4, 4, 1]] + self.assertAllEqual( + expected_out, sess.run(op, + {p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]})) + + def testInvalidBoundariesOrder(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 8, 3, 11]) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "Expected sorted boundaries"): + sess.run(op, {p: [-5, 0]}) + + def testBoundariesNotList(self): + with self.test_session(): + with self.assertRaisesRegexp(TypeError, "Expected list.*"): + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + math_ops._bucketize(p, boundaries=0) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 0528a5415d579a844e68403ace1bb8982a10a841..7b114d4f85d3a5cadc6af25b55c5a21f90d2a768 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -51,12 +51,12 @@ def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None, if backend == "cpu": backend_args += [ "--test_device=XLA_CPU", - "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64" + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64" ] elif backend == "gpu": backend_args += [ "--test_device=XLA_GPU", - "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64" + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16" ] backend_tags += ["requires-gpu-sm35"] elif backend in plugins: @@ -89,4 +89,3 @@ def generate_backend_suites(backends=[]): backends = all_backends() for backend in backends: native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend]) - diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index 5010fe5e21d0782e68d4e6d5bf6b4df1b44793a3..1a8989d7c2f617525c301f30fd899a01362310bf 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -34,6 +34,13 @@ from tensorflow.python.platform import test class CholeskyOpTest(XLATestCase): + # Cholesky defined for float64, float32, complex64, complex128 + # (https://www.tensorflow.org/api_docs/python/tf/cholesky) + @property + def float_types(self): + return set(super(CholeskyOpTest, self).float_types).intersection( + (np.float64, np.float32, np.complex64, np.complex128)) + def _verifyCholeskyBase(self, sess, placeholder, x, chol, verification, atol): chol_np, verification_np = sess.run([chol, verification], {placeholder: x}) self.assertAllClose(x, verification_np, atol=atol) diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 81734082d9aab86f8bc763681265ef64ef32bd31..f10973e19f1945515b776cf86349445ed7334629 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -301,7 +301,7 @@ class ConcatOffsetTest(XLATestCase): s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) s2 = constant_op.constant([2, 20, 5], dtypes.int32) - off = gen_array_ops._concat_offset(cdim, [s0, s1, s2]) + off = gen_array_ops.concat_offset(cdim, [s0, s1, s2]) ans = sess.run(off) self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..865f60ccab46ec6829e49409508303052944e13b --- /dev/null +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -0,0 +1,135 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for DenseLayer JIT compilation on the CPU and GPU devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np + +from tensorflow.contrib.compiler import jit +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.layers import layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + +jit_scope = jit.experimental_jit_scope + + +def GetRunMetadataLabels(run_metadata): + """Returns all labels in run_metadata.""" + labels = [] + for dev_stats in run_metadata.step_stats.dev_stats: + for node_stats in dev_stats.node_stats: + labels.append(node_stats.timeline_label) + return labels + + +def InLabels(labels, substr): + """Returns true iff one of the labels contains substr.""" + return any([substr in x for x in labels]) + + +def XlaLaunchOpCount(labels): + """Count how many XlaLaunch labels are present.""" + return sum("XlaLaunch(" in x for x in labels) + + +class DenseLayerTest(test.TestCase): + + def testDenseLayerAutoJit(self): + """Tests dense layer compilation in auto-jit mode. + + Dense layer should be compiled into a single XlaLaunch op in auto-jit mode. + """ + + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit") + config = config_pb2.ConfigProto() + config.graph_options.optimizer_options.global_jit_level = ( + config_pb2.OptimizerOptions.ON_1) + + with self.test_session(config=config) as sess: + x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) + y = layers.dense(x, 3) + + sess.run(variables.initialize_all_variables()) + run_metadata = config_pb2.RunMetadata() + sess.run( + y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = GetRunMetadataLabels(run_metadata) + self.assertEqual(1, XlaLaunchOpCount(labels)) + self.assertFalse(InLabels(labels, "ListDiff")) + + def testDenseLayerJitScopeDefinedShape(self): + """Tests that the dense layer node is properly compiled in jit scope. + + Dense layer with static shape input tensor should be compiled into a single + XlaLaunch op by XLA. + """ + + with self.test_session() as sess: + x = array_ops.placeholder(shape=[2, 2, 3], dtype=np.float32) + with jit_scope(): + y = layers.dense(x, 3) + + sess.run(variables.initialize_all_variables()) + run_metadata = config_pb2.RunMetadata() + sess.run( + y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = GetRunMetadataLabels(run_metadata) + self.assertEqual(1, XlaLaunchOpCount(labels)) + # No need to check whether ListDiff is compiled or not because ListDiff op + # is not used when input tensor shape is fully defined. + + def testDenseLayerJitScopeUndefinedShape(self): + """Tests that the dense layer node is properly compiled in jit scope. + + Dense layer uses shape op to get shape of input tensor if its shape is not + fully defined. XLA does not cluster shape op with other operators. But in + experimental_jit_scope, XLA is forced to compile shape op into its own + cluster, causing dense layer to be split into TWO XlaLaunch ops. + """ + + with self.test_session() as sess: + x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) + with jit_scope(): + y = layers.dense(x, 3) + + sess.run(variables.initialize_all_variables()) + run_metadata = config_pb2.RunMetadata() + sess.run( + y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = GetRunMetadataLabels(run_metadata) + self.assertEqual(2, XlaLaunchOpCount(labels)) + self.assertFalse(InLabels(labels, "ListDiff")) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 0a0d335ca76dd7ec7ca3b12f9e8a83b596daa07e..03d96a2cd8ab22a472a67f092e36224820405fa8 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -153,7 +153,7 @@ class DepthwiseConv2DTest(XLATestCase): dtype=data_type).reshape(filter_in_sizes) with self.test_session() as sess: if data_type == np.float32: - tolerance = 1e-5 + tolerance = 1e-4 else: self.assertEqual(data_type, np.float64) tolerance = 1e-8 @@ -339,7 +339,7 @@ class DepthwiseConv2DTest(XLATestCase): gpu_value = _GetVal(use_xla=True) cpu_value = _GetVal(use_xla=False) - self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4) + self.assertAllClose(cpu_value, gpu_value, rtol=1e-3, atol=1e-3) def testDepthwiseConv2DInputGradCompare(self): for index, (input_size, filter_size, output_size, stride, diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6a46d2ec3e7aee3a4ecfbf1ab9f622d8eb659e3c --- /dev/null +++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py @@ -0,0 +1,93 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for XLA dynamic slicing ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class DynamicUpdateSliceOpsTest(XLATestCase): + + def _assertOpOutputMatchesExpected(self, op, args, expected): + with self.test_session() as session: + with self.test_scope(): + placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) + for arg in args + ] + feeds = {placeholders[i]: args[i] for i in range(0, len(args))} + output = op(*placeholders) + result = session.run(output, feeds) + self.assertAllClose(result, expected, rtol=1e-3) + + def testUpdateSlice(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.dynamic_update_slice, [ + np.array([], dtype=dtype), + np.array([], dtype=dtype), + np.array([0], dtype=np.int32) + ], + expected=np.array([], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + xla.dynamic_update_slice, [ + np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), + np.array([-1, -2, -3], dtype=dtype), + np.array([6], dtype=np.int32) + ], + expected=np.array([1, 2, 3, 4, 5, 6, -1, -2, -3, 10], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + xla.dynamic_update_slice, [ + np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=dtype), + np.array([[42, 43], [44, 45]], dtype=dtype), + np.array([1, 2], dtype=np.int32) + ], + expected=np.array( + [[1, 2, 3, 4], [5, 6, 42, 43], [9, 10, 44, 45]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + xla.dynamic_update_slice, [ + np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=dtype), + np.array([[], []], dtype=dtype), + np.array([1, 2], dtype=np.int32) + ], + expected=np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + xla.dynamic_update_slice, [ + np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=dtype), + np.ones([3, 4], dtype=dtype), + np.array([0, 0], dtype=np.int32) + ], + expected=np.ones([3, 4], dtype=dtype)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a4154ad1e846f8241a2ab6598da36ccb6b3b653e --- /dev/null +++ b/tensorflow/compiler/tests/eager_test.py @@ -0,0 +1,443 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test cases for eager execution using XLA.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import convolutional +from tensorflow.python.layers import pooling +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import googletest +from tensorflow.python.training import adam + + +class EagerTest(XLATestCase): + + def testBasic(self): + with self.test_scope(): + three = constant_op.constant(3) + five = constant_op.constant(5) + product = three * five + self.assertAllEqual(15, product) + + def testExecuteListOutputLen0(self): + with self.test_scope(): + empty = constant_op.constant([], dtype=dtypes.float32) + result = array_ops.unstack(empty, 0) + self.assertTrue(isinstance(result, list)) + self.assertEqual(0, len(result)) + + def testExecuteListOutputLen1(self): + with self.test_scope(): + split_dim = constant_op.constant(1) + value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]]) + result = array_ops.split(value, 1, axis=split_dim) + self.assertTrue(isinstance(result, list)) + self.assertEqual(1, len(result)) + self.assertAllEqual([[0, 1, 2], [3, 4, 5]], result[0]) + + def testExecuteListOutputLen3(self): + with self.test_scope(): + split_dim = constant_op.constant(1) + value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]]) + result = array_ops.split(value, 3, axis=split_dim) + self.assertTrue(isinstance(result, list)) + self.assertEqual(3, len(result)) + self.assertAllEqual([[0], [3]], result[0]) + self.assertAllEqual([[1], [4]], result[1]) + self.assertAllEqual([[2], [5]], result[2]) + + def testBasicGraph(self): + # Run some ops eagerly + with self.test_scope(): + three = constant_op.constant(3) + five = constant_op.constant(5) + product = three * five + self.assertAllEqual(15, product) + + # Run some ops graphly + with context.graph_mode(), self.test_session() as sess: + with self.test_scope(): + three = constant_op.constant(3) + five = constant_op.constant(5) + product = three * five + self.assertAllEqual(15, sess.run(product)) + + def testDegenerateSlices(self): + with self.test_scope(): + npt = np.arange(1, 19, dtype=np.float32).reshape(3, 2, 3) + t = constant_op.constant(npt) + # degenerate by offering a forward interval with a negative stride + self.assertAllEqual(npt[0:-1:-1, :, :], t[0:-1:-1, :, :]) + # degenerate with a reverse interval with a positive stride + self.assertAllEqual(npt[-1:0, :, :], t[-1:0, :, :]) + # empty interval in every dimension + self.assertAllEqual(npt[-1:0, 2:2, 2:3:-1], t[-1:0, 2:2, 2:3:-1]) + + def testIdentity(self): + with self.test_scope(): + self.assertAllEqual(2, array_ops.identity(2)) + + def testIdentityOnVariable(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(True) + i = array_ops.identity(v) + self.assertAllEqual(True, i.numpy()) + + def testAssignAddVariable(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + v.assign_add(2.0) + self.assertEqual(3.0, v.numpy()) + + def testReadAssignRead(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + val1 = v.read_value() + v.assign_add(2.0) + val2 = v.read_value() + self.assertEqual(1.0, val1.numpy()) + self.assertEqual(3.0, val2.numpy()) + + def testGradient(self): + def f(x): + return x + + with self.test_scope(): + grad_fn = backprop.gradients_function(f) + self.assertAllEqual(2., grad_fn(1., dy=2.)[0]) + + def testVariableGradient(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(1.0) + + def f(): + x = v0 * v0 + return x + + grads = backprop.implicit_grad(f)() + self.assertEqual(2., grads[0][0].numpy()) + + def testMultipleVariableReads(self): + # This test makes sure consecutive variable reads don't copy + # the underlying memory. + with self.test_scope(): + # Create 128MiB variables + var = resource_variable_ops.ResourceVariable( + array_ops.ones([32, 1024, 1024])) + + # Read the same variable 100 times. If the underlying tensor + # is not copied, this is a trivial operation. If it is copied, + # this will eat over 13GB and OOM. + values = [] + for _ in range(100): + values.append(var.value()) + + # The shape, shape_n, size, and rank are tested here because their + # execution kernels (as opposed to compilation only tf2xla kernels) + # are distincts from tf2xla kernels. + + def testShape(self): + def const(value): + return array_ops.shape( + constant_op.constant(value)).numpy() + + def ones(value): + return array_ops.shape( + array_ops.ones(value)).numpy() + + with self.test_scope(): + # Shapes of directly constructed tensors + self.assertAllEqual([], const(3)) + self.assertAllEqual([3], const([1.0, 2.0, 3.0])) + self.assertAllEqual([2, 2], const([[1.0, 2.0], [3.0, 4.0]])) + self.assertAllEqual([2, 1, 2], const([[[1.0, 2.0]], [[3.0, 4.0]]])) + + # Shapes of tensors created by op running on device + # We make this distinction because directly constructed tensors + # are treated differently in a few places that can influence shape: + # - they always have on_host_tensor + # - they and their shapes can be cached + # - they end up on device via a copy, instead of as program output + self.assertAllEqual([], ones([])) + self.assertAllEqual([3], ones([3])) + self.assertAllEqual([2, 2], ones([2, 2])) + self.assertAllEqual([2, 1, 2], ones([2, 1, 2])) + + def testShapeN(self): + with self.test_scope(): + # Shapes of directly constructed tensors + shapes = array_ops.shape_n([ + constant_op.constant(1.0), + constant_op.constant([1.0, 2.0, 3.0]), + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])]) + self.assertAllEqual( + [[], [3], [2, 2]], + [x.numpy().tolist() for x in shapes]) + + # Shapes of tensors created by op running on device + shapes = array_ops.shape_n([ + array_ops.ones([]), + array_ops.ones([3]), + array_ops.ones([2, 2])]) + self.assertAllEqual( + [[], [3], [2, 2]], + [x.numpy().tolist() for x in shapes]) + + def testSize(self): + with self.test_scope(): + self.assertEqual( + 1, array_ops.size(constant_op.constant(1.0)).numpy()) + self.assertEqual( + 3, array_ops.size(constant_op.constant([1.0, 2.0, 3.0])).numpy()) + self.assertEqual( + 4, array_ops.size( + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy()) + + def testRank(self): + with self.test_scope(): + self.assertEqual( + 0, array_ops.rank(constant_op.constant(1.0)).numpy()) + self.assertEqual( + 1, array_ops.rank(constant_op.constant([1.0, 2.0, 3.0])).numpy()) + self.assertEqual( + 2, array_ops.rank( + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy()) + + def testAdam(self): + with self.test_scope(): + optimizer = adam.AdamOptimizer(0.1) + x = resource_variable_ops.ResourceVariable(10.0) + with backprop.GradientTape() as tape: + y = x * x + dy_dx = tape.gradient(y, x) + optimizer.apply_gradients([(dy_dx, x)]) + self.assertAlmostEqual(9.9, x.numpy(), places=3) + + def testAdamSparse(self): + with ops.device('/cpu:0'): + # Create 2-D embedding for 3 objects on CPU because sparse/sliced updates + # are not implemented on TPU. + embedding_matrix = resource_variable_ops.ResourceVariable( + array_ops.ones([3, 2])) + + with self.test_scope(): + with backprop.GradientTape() as tape: + embedding = embedding_ops.embedding_lookup(embedding_matrix, [1]) + y = math_ops.reduce_sum(embedding) + dy_dx = tape.gradient(y, embedding_matrix) + self.assertIsInstance(dy_dx, ops.IndexedSlices) + optimizer = adam.AdamOptimizer(0.1) + # The gradient application operations will run on CPU because optimizer + # updates are always collocated with the variable. + optimizer.apply_gradients([(dy_dx, embedding_matrix)]) + + # This assign_add will run on CPU because when an input to an + # operation is a resource, this operation is placed on the resource's + # device by the eager runtime. + embedding_matrix.assign_add(array_ops.ones([3, 2])) + + self.assertAllClose([[2.0, 2.0], + [1.9, 1.9], + [2.0, 2.0]], embedding_matrix.numpy()) + + +class EagerFunctionTest(XLATestCase): + + def testBasic(self): + with self.test_scope(): + matmul = function.defun(math_ops.matmul, compiled=True) + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + sq = matmul(t, t, transpose_a=True) + self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20]) + + def testConv(self): + if 'GPU' in self.device: + # TODO(b/32333178) + self.skipTest('Current implementation of RandomStandardNormal kernel ' + 'is very slow on GPU, and has been blacklisted.') + with self.test_scope(): + data_format = 'channels_last' + conv = convolutional.Conv2D( + filters=1, kernel_size=2, padding='VALID', + data_format=data_format, activation=nn_ops.relu, + kernel_initializer=init_ops.ones_initializer(), + bias_initializer=init_ops.zeros_initializer()) + pool = pooling.MaxPooling2D(2, 2, data_format=data_format) + + def model(x): + x = conv(x) + return pool(x) + model = function.defun(model, compiled=True) + + x = array_ops.ones([1, 4, 4, 1]) + y = model(x) + self.assertAllEqual(y.numpy(), [[[[4.]]]]) + + def testReadVariable(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun(compiled=True) + def f(): + return v.read_value() + + var = f() + self.assertEqual(1.0, var.numpy()) + + def testUpdateVariable(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + + def f(v): + v.assign_add(1.0) + return v + + f = function.defun(f, compiled=True) + + var = f(v) + self.assertEqual(2.0, var.numpy()) + + def testAllArgumentKinds(self): + """Test a complex function that takes different argument kinds. + + tf2xla machinery that translates, compiles, and runs defuns + classifies arguments into: compile-time constants, regular tensors, + and resources. This test creates a function with a mix of all these + kinds. Moreover, the order of function arguments is intentionally mixed up. + + This also tests the case when the same argument is a compile-time constant + as well as used in an operation that normally expects its inputs to be + in device memory - addition in this case. + """ + with self.test_scope(): + def foo(c1, r1, v1, c2, v2, r2): + # c1 and c2 are compile-time constants + # r1 and r2 are regular tensors + # v1 and v2 are resource variables + a = c1 + r1 + b = math_ops.cast(c2, dtypes.float32) + v2 + c = array_ops.slice(v1, c1, c2) + d = r2 * v2 + return a, b, c, d + + foo = function.defun(foo, compiled=True) + + c1 = [0, 0] + c2 = array_ops.ones([2], dtype=dtypes.int32) + + r1 = array_ops.ones([2]) + r2 = [[2., 2.], [3., 3.]] + + v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]]) + v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]]) + + a, b, c, d = foo(c1, r1, v1, c2, v2, r2) + + self.assertAllEqual([1, 1], a.numpy()) + self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy()) + self.assertAllEqual([[1.]], c.numpy()) + self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy()) + + def testDefunInGradientTape(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun(compiled=True) + def f(x): + x = v0 * v0 * x + return x + + x = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + y = f(x) + dy = tape.gradient(y, v0) + + self.assertEqual(75, y.numpy()) + self.assertEqual(30, dy.numpy()) + + +class ExcessivePaddingTest(XLATestCase): + """Test that eager execution works with TPU flattened tensors. + + Tensors that would normally be excessively padded when written + to TPU memory are reshaped to 1-D flat tensors. + + This test case verifies that such tensors work with eager execution. + + The flattening currently only happens on TPU, but tests should work + fine with all backends as flattening is transparent. + """ + + def testFromConstant(self): + with self.test_scope(): + # Create constant of shape [100, 2, 1]. This tensor would be + # excessively padded on TPU. + tensor = constant_op.constant(100 * [[[10.0], [2.0]]]) + # Use reduce_sum since it requires correctly working with + # a particular dimension. + reduced = math_ops.reduce_sum(tensor, axis=1) + self.assertAllEqual(100 * [[12.0]], reduced) + + def testFromOperation(self): + with self.test_scope(): + tensor = array_ops.ones([3, 100, 2, 2]) + reduced = math_ops.reduce_sum(tensor, axis=[0, 2, 3]) + self.assertAllEqual(100 * [12.0], reduced) + + def testAsFunctionInput(self): + with self.test_scope(): + + @function.defun(compiled=True) + def f(x): + return math_ops.reduce_sum(x, axis=2) + + tensor = constant_op.constant(100 * [[[10.0, 2.0]]]) + reduced = f(tensor) + self.assertAllEqual(100 * [[12.0]], reduced) + + def testAsFunctionOutput(self): + with self.test_scope(): + + @function.defun(compiled=True) + def f(x): + return x * constant_op.constant(100 * [[[10.0, 2.0]]]) + + y = f(3) + reduced = math_ops.reduce_sum(y, axis=2) + self.assertAllEqual(100 * [[36.0]], reduced) + + +if __name__ == '__main__': + ops.enable_eager_execution( + config=config_pb2.ConfigProto(log_device_placement=True)) + googletest.main() diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index f9db4cf2017c0b4b6dc0cfeeda6dca7bb9d14f19..8e6407dffdac3adbcda8cbca2109ef9196defa8c 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -134,9 +134,15 @@ class FtrlOptimizerTest(XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-2.60260963, -4.29698515]), var0.eval(), float_rtol=1e-5) + np.array([-2.60260963, -4.29698515]), + var0.eval(), + float_rtol=1e-5, + half_rtol=1e-2) self.assertAllCloseAccordingToType( - np.array([-0.28432083, -0.56694895]), var1.eval(), float_rtol=1e-5) + np.array([-0.28432083, -0.56694895]), + var1.eval(), + float_rtol=1e-5, + half_rtol=1e-2) def testFtrlwithoutRegularization2(self): for dtype in self.float_types: @@ -272,8 +278,8 @@ class FtrlOptimizerTest(XLATestCase): with self.test_session(), self.test_scope(): val2, val3 = self.equivAdagradTest_AdagradPart(steps, dtype) - self.assertAllCloseAccordingToType(val0, val2, rtol=1e-4) - self.assertAllCloseAccordingToType(val1, val3, rtol=1e-4) + self.assertAllCloseAccordingToType(val0, val2, rtol=1e-4, half_rtol=1e-2) + self.assertAllCloseAccordingToType(val1, val3, rtol=1e-4, half_rtol=1e-2) def testEquivGradientDescentwithoutRegularization(self): steps = 5 diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index 11d8a99ffe1a136a54b16e20f1792062203f7969..8a3f4b0bdc7a61d6cfa2ba7474ce8579e293a5c7 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -24,12 +24,10 @@ from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -@test_util.with_c_api class FunctionTest(XLATestCase): def testFunction(self): @@ -105,6 +103,28 @@ class FunctionTest(XLATestCase): result = sess.run(call_f) self.assertAllClose(result, expected, rtol=1e-3) + def testCompileTimeConstantsInDefun(self): + """Tests that XLA handles compile-time constants in defuns.""" + with self.test_session() as sess: + + @function.Defun(dtypes.float32, dtypes.int32, dtypes.int32) + def Foo(a, c, d): + # c and d must be known at compile time + x = array_ops.slice(a, c, d) + return x + + a = array_ops.placeholder(dtypes.float32) + c = array_ops.placeholder(dtypes.int32, shape=[4]) + d = array_ops.placeholder(dtypes.int32, shape=[4]) + with self.test_scope(): + call_f = Foo(a, c, d) + result = sess.run(call_f, feed_dict={ + a: np.ones([1, 4, 4, 1]), + c: [0, 0, 0, 0], + d: [1, 2, 2, 1]}) + + self.assertAllEqual(np.ones([1, 2, 2, 1]), result) + # TODO(b/36139787): Re-enable this test when noinline works again. def DISABLED_testFunctionsNoInline(self): diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 538fa8e8e570b83ed681ecc0501285520cabdecb..7cf953ef25ef5daf8a6d4fc9985ed8dbfb2081e5 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -34,6 +34,13 @@ from tensorflow.python.ops import image_ops from tensorflow.python.platform import test +def GenerateNumpyRandomRGB(shape): + # Only generate floating points that are fractions like n / 256, since they + # are RGB pixels. Some low-precision floating point types in this test can't + # handle arbitrary precision floating points well. + return np.random.randint(0, 256, shape) / 256. + + class RGBToHSVTest(XLATestCase): def testBatch(self): @@ -43,7 +50,7 @@ class RGBToHSVTest(XLATestCase): shape = (batch_size, 2, 7, 3) for nptype in self.float_types: - inp = np.random.rand(*shape).astype(nptype) + inp = GenerateNumpyRandomRGB(shape).astype(nptype) # Convert to HSV and back, as a batch and individually with self.test_session() as sess: @@ -58,14 +65,13 @@ class RGBToHSVTest(XLATestCase): join1 = array_ops.stack(split1) join2 = array_ops.stack(split2) batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2], - { - batch0: inp - }) + {batch0: inp}) # Verify that processing batch elements together is the same as separate self.assertAllClose(batch1, join1) self.assertAllClose(batch2, join2) - self.assertAllCloseAccordingToType(batch2, inp, bfloat16_atol=0.03) + self.assertAllCloseAccordingToType( + batch2, inp, bfloat16_atol=0.03, half_rtol=0.02) def testRGBToHSVRoundTrip(self): data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] @@ -82,7 +88,7 @@ class RGBToHSVTest(XLATestCase): def testRGBToHSVNumpy(self): """Tests the RGB to HSV conversion matches a reference implementation.""" for nptype in self.float_types: - rgb_flat = np.random.random(64 * 3).reshape((64, 3)).astype(nptype) + rgb_flat = GenerateNumpyRandomRGB((64, 3)).astype(nptype) rgb_np = rgb_flat.reshape(4, 4, 4, 3) hsv_np = np.array([ colorsys.rgb_to_hsv( @@ -393,9 +399,7 @@ class AdjustSaturationTest(XLATestCase): x = array_ops.placeholder(dtypes.float32, shape=x_shape) with self.test_scope(): y_fused = self._adjust_saturation(x, - scale).eval(feed_dict={ - x: x_np - }) + scale).eval(feed_dict={x: x_np}) self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5) @@ -404,7 +408,8 @@ class ResizeBilinearTest(XLATestCase): def _assertForwardOpMatchesExpected(self, image_np, target_shape, - expected=None): + expected=None, + large_tolerance=False): if expected is None: self.fail("expected must be specified") with self.test_session() as sess, self.test_scope(): @@ -412,7 +417,11 @@ class ResizeBilinearTest(XLATestCase): resized = gen_image_ops.resize_bilinear( image, target_shape, align_corners=True) out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) - self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + if large_tolerance: + self.assertAllClose( + expected[np.newaxis, :, :, np.newaxis], out, rtol=0.03, atol=0.1) + else: + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) def _assertBackwardOpMatchesExpected(self, grads_np, @@ -426,7 +435,7 @@ class ResizeBilinearTest(XLATestCase): with self.test_session() as sess, self.test_scope(): dtype = dtype or np.float32 grads = array_ops.placeholder(np.float32) - resized = gen_image_ops._resize_bilinear_grad( + resized = gen_image_ops.resize_bilinear_grad( grads, np.zeros([1, input_shape[0], input_shape[1], 1], dtype=dtype), align_corners=True) @@ -547,6 +556,28 @@ class ResizeBilinearTest(XLATestCase): [[12.5, 27.5, 21.875], [42.5, 80.0, 57.5], [40.625, 72.5, 50]], dtype=np.float32)) + def testAlignCorners4x4To8x8(self): + self._assertForwardOpMatchesExpected( + (np.array([[0, 1, 2, 3]], dtype=np.float32) + np.array( + [[0], [1], [2], [3]], dtype=np.float32)) * 7.0, [8, 8], + expected=3 * + (np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)), + large_tolerance=True) + + def testAlignCorners8x8To16x16(self): + self._assertForwardOpMatchesExpected( + (np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)) * 15.0, + [16, 16], + expected=7 * (np.array( + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]], + dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], + [12], [13], [14], [15]], + dtype=np.float32)), + large_tolerance=True) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 2d8236e2cbdfafb35626cd582ee39b1f917aec7f..6e0db54b7a74b284dc7d18bcbb07c178c664c1e5 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -18,10 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np from tensorflow.contrib.compiler import jit from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -37,6 +39,18 @@ from tensorflow.python.platform import test jit_scope = jit.experimental_jit_scope +# Disable rewrites to make sure we don't end up having to update this test +# whenever we implement new ones. +def NoRewriteSessionConfig(): + rewriter_config = rewriter_config_pb2.RewriterConfig( + disable_model_pruning=True, + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, + function_optimization=rewriter_config_pb2.RewriterConfig.OFF) + graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) + return config_pb2.ConfigProto(graph_options=graph_options) + + def CompiledKernel(fn, *inputs, **kwargs): """Execute 'fn' as a compiled XLA kernel, with 'inputs'.""" name = kwargs.pop("name", None) @@ -64,10 +78,10 @@ def InLabels(labels, substr): def MetadataHasXlaLaunch(run_metadata): - """Returns true if there is a _XlaLaunch kernel in run_metadata's timeline.""" + """Returns true if there is a XlaLaunch kernel in run_metadata's timeline.""" # TODO(phawkins): find a less hacky way to test whether a kernel ran. - return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch") + return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch") class JitLaunchTest(test.TestCase): @@ -76,11 +90,11 @@ class JitLaunchTest(test.TestCase): # Verifies that the outputs match and that XLA was invoked. 'fn' must take # the same number of tensors as arguments that are in 'args', and must return # a tuple of output tensors. - # If 'require_kernel_launch' is True, then we verify that a _XlaLaunch node - # actually ran. However, it is sometimes possible for _XlaLaunch ops to be + # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node + # actually ran. However, it is sometimes possible for XlaLaunch ops to be # constant-folded away, so the check is optional. def _compare(self, fn, args, require_kernel_launch=True, noinline=None): - with session_lib.Session() as sess: + with session_lib.Session(config=NoRewriteSessionConfig()) as sess: placeholders = [] feeds = {} for arg in args: @@ -111,7 +125,7 @@ class JitLaunchTest(test.TestCase): for (x, y) in zip(compiled, direct): self.assertAllClose(x, y, rtol=1e-1) else: - self.assertAllClose(compiled, direct) + self.assertAllClose(compiled, direct, rtol=1e-2) def testNoOutputs(self): with session_lib.Session() as sess: @@ -257,7 +271,7 @@ class XlaCompilationTest(test.TestCase): def testReshape(self): """Tests an operator with compile-time constant and non-constant inputs.""" - with self.test_session() as sess: + with self.test_session(config=NoRewriteSessionConfig()) as sess: x = array_ops.placeholder(dtypes.float32) y = array_ops.placeholder(dtypes.int32) with jit_scope(): @@ -281,7 +295,7 @@ class XlaCompilationTest(test.TestCase): def testIgnoredArguments(self): """Tests that JIT computations can ignore formal parameters.""" - with self.test_session() as sess: + with self.test_session(config=NoRewriteSessionConfig()) as sess: x = array_ops.placeholder(dtypes.int32) y = array_ops.placeholder(dtypes.int32) with jit_scope(): @@ -305,7 +319,7 @@ class XlaCompilationTest(test.TestCase): def testLoops(self): """Tests that compilation accepts computations containing loops.""" - with self.test_session() as session: + with self.test_session(config=NoRewriteSessionConfig()) as session: x = array_ops.placeholder(dtypes.float32) with jit_scope(): c = lambda i, _: math_ops.less(i, 5) @@ -323,7 +337,7 @@ class XlaCompilationTest(test.TestCase): def testCond(self): """Tests that compilation handles switch operators.""" - with self.test_session() as session: + with self.test_session(config=NoRewriteSessionConfig()) as session: x = array_ops.placeholder(dtypes.float32) y = array_ops.placeholder(dtypes.float32) c = array_ops.placeholder(dtypes.bool) @@ -364,7 +378,8 @@ class XlaCompilationTest(test.TestCase): inp = array_ops.placeholder(dtypes.float32) out = Entry(inp) - with self.test_session(graph=g, use_gpu=True) as sess: + with self.test_session( + config=NoRewriteSessionConfig(), graph=g, use_gpu=True) as sess: run_metadata = config_pb2.RunMetadata() val = sess.run(out, feed_dict={inp: [2., 10.]}, @@ -376,7 +391,7 @@ class XlaCompilationTest(test.TestCase): def testLoopDeadlock(self): """Regression test for bug that caused deadlocks in graphs with loops.""" - with self.test_session() as session: + with self.test_session(config=NoRewriteSessionConfig()) as session: x = array_ops.placeholder(dtypes.float32) with jit_scope(): y = x + 1.0 @@ -403,10 +418,10 @@ class XlaCompilationTest(test.TestCase): y = Forward(x) dx, = gradients_impl.gradients(y, [x], 1.0) - cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions( - optimizer_options=config_pb2.OptimizerOptions( - opt_level=config_pb2.OptimizerOptions.L1, - do_function_inlining=True))) + cfg = NoRewriteSessionConfig() + cfg.graph_options.optimizer_options.opt_level = ( + config_pb2.OptimizerOptions.L1) + cfg.graph_options.optimizer_options.do_function_inlining = True with session_lib.Session(graph=g, config=cfg) as sess: run_metadata = config_pb2.RunMetadata() dx_val = sess.run(dx, @@ -426,14 +441,65 @@ class XlaCompilationTest(test.TestCase): self.assertFalse(InLabels(labels, "Log")) self.assertTrue(InLabels(labels, "Reciprocal")) self.assertTrue(InLabels(labels, "Mul")) - self.assertFalse(InLabels(labels, "_XlaLaunch")) + self.assertFalse(InLabels(labels, "XlaLaunch")) - # Compile the backprop. One _XlaLaunch. + # Compile the backprop. One XlaLaunch. labels = _Run(compiled=True) self.assertFalse(InLabels(labels, "Log")) self.assertFalse(InLabels(labels, "Reciprocal")) self.assertFalse(InLabels(labels, "Mul")) - self.assertTrue(InLabels(labels, "_XlaLaunch")) + self.assertTrue(InLabels(labels, "XlaLaunch")) + + +class ElementWiseFusionTest(test.TestCase): + + # Runs a simple test with the input jit_level and fusion_only flag. + def simpleTest(self, arg0, arg1, global_jit_level): + config = config_pb2.ConfigProto() + config.graph_options.optimizer_options.global_jit_level = global_jit_level + + with session_lib.Session(config=config) as sess: + a1 = array_ops.placeholder(dtypes.float32, [2, 2], name="a1") + a2 = array_ops.placeholder(dtypes.float32, [2, 2], name="a2") + # Two element-wise ops. We need at least two ops since single + # element clusters are not passed to XLA in fusion_only mode. + a3 = a1 * a2 + a4 = a3 + a1 + # A matmul to break XLA clustering. + a5 = math_ops.matmul(a4, a1) + # Two more element-wise ops. + a6 = a5 - a4 + a7 = a6 + a2 + + run_metadata = config_pb2.RunMetadata() + output = sess.run( + a7, { + a1: arg0, + a2: arg1 + }, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = RunMetadataLabels(run_metadata) + count = sum("XlaLaunch(" in x for x in labels) + + return output, count + + def testElementWiseClustering(self): + arg0 = np.random.rand(2, 2).astype(np.float32) + arg1 = np.random.rand(2, 2).astype(np.float32) + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true " + "--tf_xla_cpu_global_jit") + tf_op, tf_count = self.simpleTest(arg0, arg1, + config_pb2.OptimizerOptions.OFF) + self.assertEqual(0, tf_count) + + tfef_op, tfef_count = self.simpleTest(arg0, arg1, + config_pb2.OptimizerOptions.ON_1) + self.assertEqual(2, tfef_count) + + self.assertAllClose(tf_op, tfef_op, rtol=1e-1) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..45a04f0cf56e88946b946bedacb25ce6da3121b4 --- /dev/null +++ b/tensorflow/compiler/tests/listdiff_op_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for XLA listdiff operator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ListDiffTest(xla_test.XLATestCase): + + def _testListDiff(self, x, y, out, idx): + for dtype in [dtypes.int32, dtypes.int64]: + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.test_session() as sess: + x_tensor = ops.convert_to_tensor(x, dtype=dtype) + y_tensor = ops.convert_to_tensor(y, dtype=dtype) + with self.test_scope(): + out_tensor, idx_tensor = array_ops.listdiff( + x_tensor, y_tensor, out_idx=index_dtype) + tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) + self.assertAllEqual(out, tf_out) + self.assertAllEqual(idx, tf_idx) + self.assertEqual(1, out_tensor.get_shape().ndims) + self.assertEqual(1, idx_tensor.get_shape().ndims) + + def testBasic1(self): + self._testListDiff(x=[1, 2, 3, 4], y=[1, 2], out=[3, 4], idx=[2, 3]) + + def testBasic2(self): + self._testListDiff(x=[1, 2, 3, 4], y=[2], out=[1, 3, 4], idx=[0, 2, 3]) + + def testBasic3(self): + self._testListDiff(x=[1, 4, 3, 2], y=[4, 2], out=[1, 3], idx=[0, 2]) + + def testDuplicates(self): + self._testListDiff(x=[1, 2, 4, 3, 2, 3, 3, 1], + y=[4, 2], + out=[1, 3, 3, 3, 1], + idx=[0, 3, 5, 6, 7]) + + def testRandom(self): + num_random_tests = 10 + int_low = -7 + int_high = 8 + max_size = 50 + for _ in xrange(num_random_tests): + x_size = np.random.randint(max_size + 1) + x = np.random.randint(int_low, int_high, size=x_size) + y_size = np.random.randint(max_size + 1) + y = np.random.randint(int_low, int_high, size=y_size) + out_idx = [(entry, pos) for pos, entry in enumerate(x) if entry not in y] + if out_idx: + out, idx = map(list, zip(*out_idx)) + else: + out = [] + idx = [] + self._testListDiff(list(x), list(y), out, idx) + + def testFullyOverlapping(self): + self._testListDiff(x=[1, 2, 3, 4], y=[1, 2, 3, 4], out=[], idx=[]) + + def testNonOverlapping(self): + self._testListDiff(x=[1, 2, 3, 4], + y=[5, 6], + out=[1, 2, 3, 4], + idx=[0, 1, 2, 3]) + + def testEmptyX(self): + self._testListDiff(x=[], y=[1, 2], out=[], idx=[]) + + def testEmptyY(self): + self._testListDiff(x=[1, 2, 3, 4], y=[], out=[1, 2, 3, 4], idx=[0, 1, 2, 3]) + + def testEmptyXY(self): + self._testListDiff(x=[], y=[], out=[], idx=[]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py index 5d8d89224d4a778d84803811710bb095872e86b2..69bd8f7230d4394c45764d02a88fb0ec097c5756 100644 --- a/tensorflow/compiler/tests/lrn_ops_test.py +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -115,11 +115,11 @@ class LRNTest(XLATestCase): out_image = constant_op.constant(out_image_vals, shape=shape) out_grads = constant_op.constant(out_grads_vals, shape=shape) with ops.device(CPU_DEVICE): - expected = gen_nn_ops._lrn_grad(out_grads, in_image, out_image, - depth_radius, bias, alpha, beta) + expected = gen_nn_ops.lrn_grad(out_grads, in_image, out_image, + depth_radius, bias, alpha, beta) with self.test_scope(): - actual = gen_nn_ops._lrn_grad(out_grads, in_image, out_image, - depth_radius, bias, alpha, beta) + actual = gen_nn_ops.lrn_grad(out_grads, in_image, out_image, + depth_radius, bias, alpha, beta) expected_val = expected.eval() actual_val = actual.eval() self.assertAllClose(actual_val, expected_val, rtol=1e-3) diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index cccb7f5789dce39ef8c3d4b3a7573aaa983b3fbd..5819b2bf2b55b9213a039c0ba82dd0bf1c738b00 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -37,6 +37,14 @@ def MakePlaceholder(x): class MatrixTriangularSolveOpTest(XLATestCase): + # MatrixTriangularSolve defined for float64, float32, complex64, complex128 + # (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve) + @property + def float_types(self): + return set(super(MatrixTriangularSolveOpTest, + self).float_types).intersection( + (np.float64, np.float32, np.complex64, np.complex128)) + def _VerifyTriangularSolveBase(self, sess, placeholder_a, placeholder_ca, placeholder_b, a, clean_a, b, verification, atol): diff --git a/tensorflow/compiler/tests/oom_test.py b/tensorflow/compiler/tests/oom_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d68d32057a367776d5b70d5ac21d5618297c605d --- /dev/null +++ b/tensorflow/compiler/tests/oom_test.py @@ -0,0 +1,76 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for out-of-memory conditions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest + + +class OutOfMemoryTest(xla_test.XLATestCase): + + def testOutputOutOfMemory(self): + """Allocates tensors until out of memory. + + Generates a large rank-1 tensor. The tensor is an output of an XLA + computation, not constant. + + Check that a ResourceExhaustedError is raised and can be caught. + + We spin in a loop generating larger and larger tensors until an OOM event + happens. We may be running sandboxed, so have a small host memory limit, so + any hardcoded value is unlikely to land in the sweet spot between device + memory size and host memory size with stability. + """ + + def test_loop(): + size = int(2e8) + while True: + with self.test_session(): + # Force the compiled code to not be constant by feeding in a + # parameter. + p = array_ops.placeholder(dtypes.float32, shape=[2, 1, 1]) + with self.test_scope(): + # Create a computation that produces a large R1 tensor as an + # intermediate result. Reduce it down so that if this file was + # compiled without --config=cuda, we don't force a D2H copy of a + # large tensor and potentially OOM the host. + # + # This is a bit tricky because XLA:GPU doesn't currently support RNG + # ops. Here we rely on the fact that XLA doesn't do algebraic + # simplifications on conv(, ). + c = math_ops.reduce_sum( + nn_ops.convolution( + array_ops.ones([1, size, 1]), + p, + padding='SAME', + data_format='NWC')) + + c.eval(feed_dict={p: [[[1.0]], [[2.0]]]}) + size *= 2 + + self.assertRaises(errors.ResourceExhaustedError, test_loop) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5e6d1313bd0336eba71fcf3658d949bd3342ae11 --- /dev/null +++ b/tensorflow/compiler/tests/placeholder_test.py @@ -0,0 +1,48 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for xla handling of placeholder_with_default.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + + +class PlaceholderTest(XLATestCase): + + def test_placeholder_with_default_default(self): + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(4.0) + ph = array_ops.placeholder_with_default(v, shape=[]) + out = ph * 2 + sess.run(variables.variables_initializer([v])) + self.assertEqual(8.0, sess.run(out)) + + def test_placeholder_with_default_fed(self): + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(4.0) + ph = array_ops.placeholder_with_default(v, shape=[]) + out = ph * 2 + sess.run(variables.variables_initializer([v])) + self.assertEqual(2.0, sess.run(out, {ph: 1.0})) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py index eb48fe555a0b182ea7983cbd8c3b217d56350408..4eed903963a34a253ea5c409782d9a89a97a4fdf 100644 --- a/tensorflow/compiler/tests/pooling_ops_3d_test.py +++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py @@ -33,7 +33,7 @@ from tensorflow.python.platform import test # MaxPoolGrad. def _AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding): del outputs # Unused by average-pooling gradients. - return gen_nn_ops._avg_pool3d_grad( + return gen_nn_ops.avg_pool3d_grad( inputs.get_shape().as_list(), output_gradients, ksize=ksize, @@ -263,7 +263,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradValidPadding1_1_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[1, 3, 3, 3, 1], ksize=[1, 1, 1], strides=[1, 1, 1], @@ -272,7 +272,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradValidPadding2_1_6_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 3, 3, 6, 3], ksize=[2, 2, 2], strides=[1, 1, 1], @@ -281,7 +281,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradValidPadding2_1_7_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 3, 5, 7, 3], ksize=[2, 2, 2], strides=[1, 1, 1], @@ -290,7 +290,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradValidPadding2_2_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 2, 2, 2, 3], ksize=[2, 2, 2], strides=[2, 2, 2], @@ -299,7 +299,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradSamePadding1_1_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 3, 2, 4, 1], ksize=[1, 1, 1], strides=[1, 1, 1], @@ -308,7 +308,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradSamePadding2_1_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 3, 2, 4, 1], ksize=[2, 2, 2], strides=[1, 1, 1], @@ -317,7 +317,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradSamePadding2_2_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 5, 2, 4, 3], ksize=[2, 2, 2], strides=[2, 2, 2], @@ -326,7 +326,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradSamePadding3_1_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[1, 3, 3, 7, 1], ksize=[3, 3, 3], strides=[1, 1, 1], diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index 7c19a99c4eb4be3ca34b3ce949216e557b0a681d..fe270af3d636c0824621f36360ce9e7d14d8fc91 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -292,8 +292,15 @@ class PoolGradTest(XLATestCase): CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" - def _VerifyOneTest(self, pool_func, pool_grad_func, input_sizes, ksize, - strides, padding, data_format): + def _VerifyOneTest(self, + pool_func, + pool_grad_func, + input_sizes, + ksize, + strides, + padding, + data_format, + pool_grad_grad_func=None): """Verifies the output values of the pooling gradient function. Args: @@ -304,9 +311,19 @@ class PoolGradTest(XLATestCase): strides: The stride dimensions padding: Padding type. data_format: The data format we use to run the pooling operation. + pool_grad_grad_func: Second-order gradient function, if available. """ total_size = np.prod(input_sizes) - x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes) + # TODO(b/73062247): MaxPoolGradGrad can confuse gradients when x is equally + # maximal at 16 bits. Switch to np.random.randn when resolved. + x = np.arange(1, total_size + 1, dtype=np.float32) + x *= (np.random.randint(2, size=total_size) * 2 - 1) # Flip signs randomly + # Verify some specifically interesting values... + x[np.random.choice(total_size)] = np.inf + x[np.random.choice(total_size)] = -np.inf + # TODO(b/74222344): Fix nan handling for max pool grad. + # x[np.random.choice(total_size)] = np.nan + x = x.reshape(input_sizes) with self.test_session() as sess: # Use the forward pool function to compute some corresponding outputs # (needed for the CPU device, and we need the shape in both cases). @@ -323,6 +340,8 @@ class PoolGradTest(XLATestCase): output_gradient_vals = np.arange( 1, output_vals.size + 1, dtype=np.float32) output_gradient_vals = output_gradient_vals.reshape(output_vals.shape) + output_grad_grad_vals = np.arange(1, x.size + 1, dtype=np.float32) + output_grad_grad_vals = output_grad_grad_vals.reshape(x.shape) # Use the Tensorflow CPU pooling gradient to compute the expected input # gradients. @@ -342,18 +361,36 @@ class PoolGradTest(XLATestCase): {inputs: x, output_gradients: output_gradient_vals}) + output_grad_gradients = array_ops.placeholder( + dtypes.float32, shape=expected_input_gradient_vals.shape) + if pool_grad_grad_func is not None: + expected_grad_gradients = pool_grad_grad_func( + inputs, + outputs, + output_grad_gradients, + ksize=ksize, + strides=strides, + padding=padding, + data_format="NHWC") + expected_grad_gradients_vals = sess.run(expected_grad_gradients, { + inputs: x, + output_grad_gradients: output_grad_grad_vals + }) + # Run the gradient op on the XLA device with self.test_scope(): outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape) xla_inputs = inputs xla_outputs = outputs xla_output_gradients = output_gradients + xla_output_grad_gradients = output_grad_gradients xla_ksize = ksize xla_strides = strides if data_format == "NCHW": xla_inputs = NHWCToNCHW(inputs) xla_outputs = NHWCToNCHW(outputs) xla_output_gradients = NHWCToNCHW(output_gradients) + xla_output_grad_gradients = NHWCToNCHW(output_grad_gradients) xla_ksize = NHWCToNCHW(ksize) xla_strides = NHWCToNCHW(strides) actual_input_gradients = pool_grad_func( @@ -366,22 +403,54 @@ class PoolGradTest(XLATestCase): data_format=data_format) if data_format == "NCHW": actual_input_gradients = NCHWToNHWC(actual_input_gradients) - actual = sess.run(actual_input_gradients, { + if pool_grad_grad_func is not None: + actual_grad_gradients = pool_grad_grad_func( + xla_inputs, + xla_outputs, + xla_output_grad_gradients, + ksize=xla_ksize, + strides=xla_strides, + padding=padding, + data_format=data_format) + if data_format == "NCHW": + actual_grad_gradients = NCHWToNHWC(actual_grad_gradients) + actual_input_gradients_vals = sess.run(actual_input_gradients, { inputs: x, outputs: output_vals, output_gradients: output_gradient_vals }) - # Compare the Tensorflow and XLA results. self.assertAllClose( - expected_input_gradient_vals.flatten(), - actual.flatten(), + expected_input_gradient_vals, + actual_input_gradients_vals, rtol=1e-4, atol=1e-6) - self.assertShapeEqual(actual, inputs) - - def _VerifyValues(self, pool_func, pool_grad_func, input_sizes, ksize, - strides, padding): + self.assertShapeEqual(actual_input_gradients_vals, inputs) + + if pool_grad_grad_func is not None: + actual_grad_gradients_vals = sess.run( + actual_grad_gradients, { + inputs: x, + outputs: output_vals, + output_grad_gradients: output_grad_grad_vals + }) + + # Compare the Tensorflow and XLA results. + self.assertAllClose( + expected_grad_gradients_vals, + actual_grad_gradients_vals, + rtol=1e-4, + atol=1e-6) + self.assertShapeEqual(actual_grad_gradients_vals, outputs) + + def _VerifyValues(self, + pool_func, + pool_grad_func, + input_sizes, + ksize, + strides, + padding, + pool_grad_grad_func=None): """Verifies the output values of the pooling function. Args: @@ -391,12 +460,20 @@ class PoolGradTest(XLATestCase): ksize: The kernel size dimensions strides: The stride dimensions padding: Padding type. + pool_grad_grad_func: Second-order gradient function, if available. """ for data_format in GetTestConfigs(): - self._VerifyOneTest(pool_func, pool_grad_func, input_sizes, ksize, - strides, padding, data_format) - - def _TestPooling(self, forward_op, backward_op): + self._VerifyOneTest( + pool_func, + pool_grad_func, + input_sizes, + ksize, + strides, + padding, + data_format, + pool_grad_grad_func=pool_grad_grad_func) + + def _TestPooling(self, forward_op, backward_op, pool_grad_grad_func=None): # VALID padding self._VerifyValues( forward_op, @@ -404,7 +481,8 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 3, 3, 3], ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=pool_grad_grad_func) # SAME padding self._VerifyValues( @@ -413,7 +491,8 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 2, 3, 3], ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=pool_grad_grad_func) # SAME padding, non square window self._VerifyValues( @@ -422,7 +501,8 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 2, 2, 1], ksize=[1, 1, 2, 1], strides=[1, 1, 1, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=pool_grad_grad_func) # VALID padding, uneven stride self._VerifyValues( @@ -431,14 +511,16 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 4, 4, 1], ksize=[1, 2, 2, 1], strides=[1, 1, 2, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=pool_grad_grad_func) self._VerifyValues( forward_op, backward_op, input_sizes=[1, 4, 4, 1], ksize=[1, 2, 2, 1], strides=[1, 2, 1, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=pool_grad_grad_func) # SAME padding, size 4 input self._VerifyValues( @@ -447,7 +529,8 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 4, 4, 4], ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=pool_grad_grad_func) # SAME padding, size 8 input self._VerifyValues( @@ -456,10 +539,14 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 8, 8, 8], ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=pool_grad_grad_func) def testMaxPool(self): - self._TestPooling(nn_ops.max_pool, gen_nn_ops._max_pool_grad) + self._TestPooling( + nn_ops.max_pool, + gen_nn_ops.max_pool_grad, + pool_grad_grad_func=gen_nn_ops.max_pool_grad_grad) def testAvgPool(self): # Wrapper around AvgPoolGrad that ignores extra arguments needed by @@ -467,7 +554,7 @@ class PoolGradTest(XLATestCase): def AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding, data_format): del outputs # Unused by average-pooling gradients. - return gen_nn_ops._avg_pool_grad( + return gen_nn_ops.avg_pool_grad( inputs.get_shape().as_list(), output_gradients, ksize=ksize, @@ -483,7 +570,7 @@ class PoolGradTest(XLATestCase): def testMaxPoolKernelSmallerThanStrideValid(self): self._VerifyValues( nn_ops.max_pool, - gen_nn_ops._max_pool_grad, + gen_nn_ops.max_pool_grad, input_sizes=[1, 7, 7, 1], ksize=[1, 2, 2, 1], strides=[1, 3, 3, 1], @@ -492,7 +579,7 @@ class PoolGradTest(XLATestCase): def testMaxPoolKernelSmallerThanStrideSame(self): self._VerifyValues( nn_ops.max_pool, - gen_nn_ops._max_pool_grad, + gen_nn_ops.max_pool_grad, input_sizes=[1, 3, 3, 1], ksize=[1, 1, 1, 1], strides=[1, 2, 2, 1], @@ -500,7 +587,7 @@ class PoolGradTest(XLATestCase): self._VerifyValues( nn_ops.max_pool, - gen_nn_ops._max_pool_grad, + gen_nn_ops.max_pool_grad, input_sizes=[1, 4, 4, 1], ksize=[1, 1, 1, 1], strides=[1, 2, 2, 1], diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index d6c93088d4efff7d8306e262a79ae49d3d8ac722..f13dff96203b5480480c2a2fc9ac38ca78b7f78a 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -22,6 +22,8 @@ import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import googletest @@ -47,18 +49,18 @@ class RandomOpsTest(XLATestCase): # We use exact equality here. If the random-number generator is producing # deterministic output, all three outputs will be bitwise identical. self.assertTrue((not np.array_equal(y, z)) or - (not np.array_equal(z, w)) or - (not np.array_equal(y, w))) + (not np.array_equal(z, w)) or (not np.array_equal(y, w))) def testRandomUniformIsNotConstant(self): + def rng(dtype): - return random_ops.random_uniform(shape=[2], dtype=dtype, - maxval=1000000) + return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000) for dtype in self._random_types(): self._testRngIsNotConstant(rng, dtype) def testRandomNormalIsNotConstant(self): + def rng(dtype): return random_ops.random_normal(shape=[2], dtype=dtype) @@ -70,12 +72,20 @@ class RandomOpsTest(XLATestCase): for dtype in self._random_types(): with self.test_session() as sess: with self.test_scope(): - x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2, - maxval=33) + x = random_ops.random_uniform( + shape=[1000], dtype=dtype, minval=-2, maxval=33) y = sess.run(x) self.assertTrue((y >= -2).sum() == 1000) self.assertTrue((y < 33).sum() == 1000) + def testTruncatedNormalIsNotConstant(self): + + def rng(dtype): + return random_ops.truncated_normal(shape=[2], dtype=dtype) + + # TODO(b/34339814): implement inverse erf support for non-F32 types. + self._testRngIsNotConstant(rng, dtypes.float32) + def testTruncatedNormalIsInRange(self): count = 10000 # TODO(b/34339814): implement inverse erf support for non-F32 types. @@ -87,6 +97,29 @@ class RandomOpsTest(XLATestCase): self.assertTrue((y >= -2).sum() == count) self.assertTrue((y <= 2).sum() == count) + def testShuffle1d(self): + with self.test_session() as sess: + with self.test_scope(): + x = math_ops.range(20) + shuffle = random_ops.random_shuffle(x) + result = sess.run(shuffle) + expected = range(20) + # Compare sets to avoid randomness behavior changes but make sure still + # have all the values. + self.assertAllEqual(set(result), set(expected)) + + def testShuffle2d(self): + with self.test_session() as sess: + with self.test_scope(): + x = array_ops.diag(math_ops.range(20)) + shuffle = random_ops.random_shuffle(x) + result = sess.run(shuffle) + expected = np.diag(range(20)).flatten() + # Compare sets to avoid randomness behavior changes but make sure still + # have all the values. + self.assertAllEqual(len(result.flatten()), len(expected)) + self.assertAllEqual(set(result.flatten()), set(expected)) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index e72dd4eea9f127e1df96ab166103c4c16372adb6..16f293891d56d78885dd515bb7b9899faf0690f7 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -83,8 +83,8 @@ string LocalDeviceToFullDeviceName(const string& device) { return strings::StrCat("/job:localhost/replica:0/task:0/device:", device); } -constexpr std::array kAllXlaTypes = { - {DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64}}; +constexpr std::array kAllXlaTypes = { + {DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64, DT_INT64}}; // An OpTestBuilder is a graph builder class that takes as input an operator to // test, its inputs and attributes, and builds a graph that executes the @@ -619,8 +619,8 @@ std::vector OpTest::ImageDims(TensorFormat format, int batch, dims.push_back(dim); } break; - case FORMAT_NCHW_VECT_C: - LOG(FATAL) << "FORMAT_NCHW_VECT_C not supported."; + default: + LOG(FATAL) << "Tensor format " << ToString(format) << " not supported."; } return dims; } diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index 965fdf684b973498d0b3c3cde17711cca7279705..7420724bdbeab63b39542ada59328621febad895 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools +import itertools import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase @@ -30,8 +32,13 @@ from tensorflow.python.platform import googletest class ReduceOpsTest(XLATestCase): - def _testReduction(self, tf_reduce_fn, np_reduce_fn, dtype, test_inputs, - rtol=1e-4, atol=1e-4): + def _testReduction(self, + tf_reduce_fn, + np_reduce_fn, + dtype, + test_inputs, + rtol=1e-4, + atol=1e-4): """Tests that the output of 'tf_reduce_fn' matches numpy's output.""" for test_input in test_inputs: @@ -41,16 +48,16 @@ class ReduceOpsTest(XLATestCase): index = array_ops.placeholder(dtypes.int32) out = tf_reduce_fn(a, index) result = sess.run(out, {a: test_input, index: [0]}) - self.assertAllClose(result, np_reduce_fn(test_input, axis=0), - rtol=rtol, atol=atol) + self.assertAllClose( + result, np_reduce_fn(test_input, axis=0), rtol=rtol, atol=atol) result = sess.run(out, {a: test_input, index: [1]}) - self.assertAllClose(result, np_reduce_fn(test_input, axis=1), - rtol=rtol, atol=atol) + self.assertAllClose( + result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol) result = sess.run(out, {a: test_input, index: [-1]}) - self.assertAllClose(result, np_reduce_fn(test_input, axis=1), - rtol=rtol, atol=atol) + self.assertAllClose( + result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol) with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, 'Invalid reduction dim'): @@ -60,7 +67,7 @@ class ReduceOpsTest(XLATestCase): errors_impl.InvalidArgumentError, 'Invalid reduction dim'): sess.run(out, {a: test_input, index: [2]}) - FLOAT_DATA = [ + REAL_DATA = [ np.zeros(shape=(2, 0)), np.zeros(shape=(0, 30)), np.arange(1, 7).reshape(2, 3), @@ -74,7 +81,7 @@ class ReduceOpsTest(XLATestCase): np.arange(-14, -2, dtype=np.float32).view(np.complex64).reshape(2, 3), np.arange(-4, 8, dtype=np.float32).view(np.complex64).reshape(2, 3), ] - NONEMPTY_FLOAT_DATA = [x for x in FLOAT_DATA if np.size(x) > 0] + NONEMPTY_REAL_DATA = [x for x in REAL_DATA if np.size(x) > 0] NONEMPTY_COMPLEX_DATA = [x for x in COMPLEX_DATA if np.size(x) > 0] BOOL_DATA = [ np.array([], dtype=np.bool).reshape(2, 0), @@ -83,8 +90,7 @@ class ReduceOpsTest(XLATestCase): ] def testReduceSumF32(self): - self._testReduction(math_ops.reduce_sum, np.sum, np.float32, - self.FLOAT_DATA) + self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA) def testReduceSumC64(self): self._testReduction(math_ops.reduce_sum, np.sum, np.complex64, @@ -92,7 +98,7 @@ class ReduceOpsTest(XLATestCase): def testReduceProdF32(self): self._testReduction(math_ops.reduce_prod, np.prod, np.float32, - self.FLOAT_DATA) + self.REAL_DATA) def testReduceProdC64(self): self._testReduction(math_ops.reduce_prod, np.prod, np.complex64, @@ -100,31 +106,44 @@ class ReduceOpsTest(XLATestCase): def testReduceMin(self): - def reference_min(inp, axis): + def reference_min(dtype, inp, axis): """Wrapper around np.amin that returns +infinity for an empty input.""" if inp.shape[axis] == 0: - return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf')) + if np.issubdtype(dtype, np.floating): + return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf')) + return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], + np.iinfo(dtype).max) return np.amin(inp, axis) - self._testReduction(math_ops.reduce_min, reference_min, np.float32, - self.FLOAT_DATA) + for dtype in set(self.all_types).intersection( + [np.float32, np.int32, np.int64]): + self._testReduction(math_ops.reduce_min, + functools.partial(reference_min, dtype), dtype, + self.REAL_DATA) def testReduceMax(self): - def reference_max(inp, axis): + def reference_max(dtype, inp, axis): """Wrapper around np.amax that returns -infinity for an empty input.""" if inp.shape[axis] == 0: - return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('-inf')) + if np.issubdtype(dtype, np.floating): + return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], + float('-inf')) + return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], + np.iinfo(dtype).min) return np.amax(inp, axis) - self._testReduction(math_ops.reduce_max, reference_max, np.float32, - self.FLOAT_DATA) + for dtype in set(self.all_types).intersection( + [np.float32, np.int32, np.int64]): + self._testReduction(math_ops.reduce_max, + functools.partial(reference_max, dtype), dtype, + self.REAL_DATA) def testReduceMeanF32(self): # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when # reducing across zero inputs. self._testReduction(math_ops.reduce_mean, np.mean, np.float32, - self.NONEMPTY_FLOAT_DATA) + self.NONEMPTY_REAL_DATA) def testReduceMeanC64(self): self._testReduction(math_ops.reduce_mean, np.mean, np.complex64, @@ -137,5 +156,68 @@ class ReduceOpsTest(XLATestCase): self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA) +class ReduceOpPrecisionTest(XLATestCase): + + def _testReduceSum(self, + expected_result, + dtype, + test_inputs, + rtol=1e-3, + atol=1e-4): + """Tests reduce sum on a list of input arrays. + + For each array in test_inputs, check that performing reduce sum on the array + produces a value that is close to the expected result. + + Args: + expected_result: the expected result. + dtype: the data type of the reduce sum operation. + test_inputs: a list of input arrays for the reduce sum operation. + rtol: the relative error. + atol: the absolute error. + """ + + for test_input in test_inputs: + with self.test_session() as sess: + with self.test_scope(): + a = array_ops.placeholder(dtype) + index = array_ops.placeholder(dtypes.int32) + out = math_ops.reduce_sum(a, index) + result = sess.run(out, { + a: np.array(test_input, dtype=dtype), + index: [0] + }) + # Compare the results using float32 type. + self.assertAllClose( + np.float32(result), + np.float32(expected_result), + rtol=rtol, + atol=atol) + + def testReduceSumF16(self): + """Tests the reduce sum of float16 doesn't lose too much precision.""" + + if np.float16 not in self.all_types: + return + + f16_max = np.finfo(np.float16).max + self._testReduceSum( + f16_max, np.float16, + itertools.permutations([f16_max, f16_max, f16_max * (-1.0)], 3)) + + def testReduceSumBF16(self): + """Tests the reduce sum of bfloat16 doesn't lose too much precision.""" + + if dtypes.bfloat16.as_numpy_dtype not in self.all_types: + return + + bf16_max = np.float32(dtypes.bfloat16.max) + f32_max = dtypes.float32.max + value = min(bf16_max, f32_max - bf16_max) + self._testReduceSum( + dtypes.bfloat16.as_numpy_dtype(value), dtypes.bfloat16.as_numpy_dtype, + itertools.permutations([bf16_max, value, bf16_max * (-1.0)], 3)) + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tests/reduce_window_test.py b/tensorflow/compiler/tests/reduce_window_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e78a63465b80644d8810d9fa7433653bc4639fed --- /dev/null +++ b/tensorflow/compiler/tests/reduce_window_test.py @@ -0,0 +1,102 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for xla.reduce_window.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class ReduceWindowTest(XLATestCase): + """Test cases for xla.reduce_window.""" + + def _reduce_window(self, operand, init, reducer, **kwargs): + with self.test_session(): + placeholder = array_ops.placeholder(operand.dtype) + with self.test_scope(): + output = xla.reduce_window(placeholder, init, reducer, **kwargs) + return output.eval(feed_dict={placeholder: operand}) + + def testReduceWindow(self): + + # TODO(b/77644762): float16 and float64 ReduceWindow are unimplemented. + for dtype in set(self.numeric_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + @function.Defun(dtype, dtype) + def sum_reducer(x, y): + return x + y + + @function.Defun(dtype, dtype) + def mul_reducer(x, y): + return x * y + + self.assertAllClose( + np.array([3, 5, 7, 9, 11, 13], dtype=dtype), + self._reduce_window( + np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype), + 0.0, + sum_reducer, + window_dimensions=[2])) + + self.assertAllClose( + np.array([3, 7, 11], dtype=dtype), + self._reduce_window( + np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype), + 0.0, + sum_reducer, + window_dimensions=[2], + window_strides=[2])) + + self.assertAllClose( + np.array([1, 4, 7], dtype=dtype), + self._reduce_window( + np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype), + 0.0, + sum_reducer, + window_dimensions=[1], + window_strides=[3])) + + self.assertAllClose( + np.array([[24, 36, 24], [96, 0, 0]], dtype=dtype), + self._reduce_window( + np.array([[1, 2, 3, 4], [4, 3, 2, 1], [2, 4, 0, 1]], dtype=dtype), + 1.0, + mul_reducer, + window_dimensions=[2, 2], + window_strides=[1, 1])) + + self.assertAllClose( + np.array([[0, 0, 0], [5, 10, 5], [2, 4, 1], [0, 0, 0]], dtype=dtype), + self._reduce_window( + np.array([[1, 2, 3, 4], [4, 3, 2, 1], [2, 4, 0, 1]], dtype=dtype), + 0.0, + sum_reducer, + window_dimensions=[2, 2], + window_strides=[2, 2], + padding=[[2, 3], [1, 2]])) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py index a7cbfb04003c397212a35e16c6b23d7c2a18f7df..305ca0c6b78d3ef985deb38816f9388e7983906b 100644 --- a/tensorflow/compiler/tests/slice_ops_test.py +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest @@ -137,6 +138,34 @@ class StridedSliceTest(XLATestCase): self.assertAllEqual([6, 4], result) + def test2DDegenerate(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[2, 3]) + with self.test_scope(): + o = array_ops.strided_slice(i, [-1, 0], [0, 3]) + params = { + i: [[0, 1, 2], + [3, 4, 5]] + } + result = o.eval(feed_dict=params) + + self.assertEqual(tensor_shape.TensorShape((0, 3)), result.shape) + + def test2DDegenerateNegativeStride(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[2, 3]) + with self.test_scope(): + o = array_ops.strided_slice(i, [0, 0], [-1, 3], [-1, 1]) + params = { + i: [[0, 1, 2], + [3, 4, 5]] + } + result = o.eval(feed_dict=params) + + self.assertEqual(tensor_shape.TensorShape((0, 3)), result.shape) + def test3D(self): for dtype in self.numeric_types: with self.test_session(): diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index c013f4b50a4cf95be8028248c52b10b1c3be2bd3..f37c34156f96761632247be4bc1b62fca54f666e 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.platform import test @@ -75,11 +76,11 @@ class SpaceToBatchTest(XLATestCase): for dtype in self.float_types: # outputs = space_to_batch(inputs) placeholder = array_ops.placeholder(dtype) - x_tf = gen_array_ops._space_to_batch( + x_tf = gen_array_ops.space_to_batch( placeholder, paddings, block_size=block_size) self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs) # inputs = batch_to_space(outputs) - x_tf = gen_array_ops._batch_to_space( + x_tf = gen_array_ops.batch_to_space( placeholder, paddings, block_size=block_size) self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs) @@ -156,14 +157,32 @@ class SpaceToBatchNDTest(XLATestCase): paddings = np.array(paddings).reshape((len(block_shape), 2)) with self.test_session() as sess, self.test_scope(): for dtype in self.float_types: + # TODO(b/68813416): Skip bfloat16's as the input type for direct is + # float32 and results in a mismatch, while making testDirect provide the + # correctly typed input results in 'no fill-function for data-type' + # error. + if dtype == dtypes.bfloat16.as_numpy_dtype: + continue + if dtype == np.float16: + actual_inputs = np.array(inputs).astype(dtype) + actual_paddings = np.array(paddings).astype(dtype) + expected_outputs = np.array(outputs).astype(dtype) + else: + actual_inputs = inputs + actual_paddings = paddings + expected_outputs = outputs placeholder = array_ops.placeholder(dtype) # outputs = space_to_batch(inputs) - x_tf = array_ops.space_to_batch_nd(placeholder, block_shape, paddings) - self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs) + x_tf = array_ops.space_to_batch_nd(placeholder, block_shape, + actual_paddings) + self.assertAllEqual( + sess.run(x_tf, {placeholder: actual_inputs}), expected_outputs) # inputs = batch_to_space(outputs) placeholder = array_ops.placeholder(dtype) - x_tf = array_ops.batch_to_space_nd(placeholder, block_shape, paddings) - self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs) + x_tf = array_ops.batch_to_space_nd(placeholder, block_shape, + actual_paddings) + self.assertAllEqual( + sess.run(x_tf, {placeholder: expected_outputs}), actual_inputs) def _testDirect(self, input_shape, block_shape, paddings): inputs = np.arange(np.prod(input_shape), dtype=np.float32) diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py index 2b9c2279737ccee531d488d27ccdb0cafa1dc8fc..94342f9567ca71274609e63b0482d55637c98d51 100644 --- a/tensorflow/compiler/tests/stack_ops_test.py +++ b/tensorflow/compiler/tests/stack_ops_test.py @@ -34,33 +34,33 @@ class StackOpTest(XLATestCase): with self.test_session(), self.test_scope(): size = array_ops.placeholder(dtypes.int32) v = array_ops.placeholder(dtypes.float32) - h = gen_data_flow_ops._stack_v2(size, dtypes.float32, stack_name="foo") - c = gen_data_flow_ops._stack_push_v2(h, v) + h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops.stack_push_v2(h, v) with ops.control_dependencies([c]): - c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32) + c1 = gen_data_flow_ops.stack_pop_v2(h, dtypes.float32) self.assertAllClose([[4.0, 5.0]], c1.eval({size: 5, v: [[4.0, 5.0]]})) def testStackPushPopSwap(self): with self.test_session(), self.test_scope(): a = np.arange(2000) x = array_ops.placeholder(dtypes.float32) - h = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") - c = gen_data_flow_ops._stack_push_v2(h, x, swap_memory=True) + h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops.stack_push_v2(h, x, swap_memory=True) with ops.control_dependencies([c]): - c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32) + c1 = gen_data_flow_ops.stack_pop_v2(h, dtypes.float32) self.assertAllClose(a, c1.eval({x: a})) def testMultiStack(self): with self.test_session(), self.test_scope(): v = array_ops.placeholder(dtypes.float32) - h1 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops._stack_push_v2(h1, v) + h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + c1 = gen_data_flow_ops.stack_push_v2(h1, v) with ops.control_dependencies([c1]): - c1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32) - h2 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="bar") - c2 = gen_data_flow_ops._stack_push_v2(h2, 5.0) + c1 = gen_data_flow_ops.stack_pop_v2(h1, dtypes.float32) + h2 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="bar") + c2 = gen_data_flow_ops.stack_push_v2(h2, 5.0) with ops.control_dependencies([c2]): - c2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32) + c2 = gen_data_flow_ops.stack_pop_v2(h2, dtypes.float32) r = c1 + c2 self.assertAllClose(9.0, r.eval({v: 4.0})) @@ -69,15 +69,15 @@ class StackOpTest(XLATestCase): with self.test_session() as sess, self.test_scope(): v1 = array_ops.placeholder(dtypes.float32) v2 = array_ops.placeholder(dtypes.float32) - h1 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") - h2 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") + h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + h2 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops._stack_push_v2(h1, v1) + c1 = gen_data_flow_ops.stack_push_v2(h1, v1) with ops.control_dependencies([c1]): - c2 = gen_data_flow_ops._stack_push_v2(h2, v2) + c2 = gen_data_flow_ops.stack_push_v2(h2, v2) with ops.control_dependencies([c2]): - pop1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32) - pop2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32) + pop1 = gen_data_flow_ops.stack_pop_v2(h1, dtypes.float32) + pop2 = gen_data_flow_ops.stack_pop_v2(h2, dtypes.float32) out1, out2 = sess.run([pop1, pop2], {v1: 4.0, v2: 5.0}) self.assertAllClose(out1, 4.0) @@ -86,17 +86,17 @@ class StackOpTest(XLATestCase): def testCloseStack(self): with self.test_session() as sess, self.test_scope(): size = array_ops.placeholder(dtypes.int32) - h = gen_data_flow_ops._stack_v2(size, dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops._stack_close_v2(h) + h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo") + c1 = gen_data_flow_ops.stack_close_v2(h) sess.run(c1, {size: 5}) def testPushCloseStack(self): with self.test_session() as sess, self.test_scope(): v = array_ops.placeholder(dtypes.float32) - h = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") - c = gen_data_flow_ops._stack_push_v2(h, v) + h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops.stack_push_v2(h, v) with ops.control_dependencies([c]): - c1 = gen_data_flow_ops._stack_close_v2(h) + c1 = gen_data_flow_ops.stack_close_v2(h) sess.run(c1, {v: [[4.0, 5.0]]}) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 4336ebdbd184a081619f0a6951dd4514735c6eb6..b6f8390a45d43bf7666b90e14cc6ff2f3f61947e 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -86,6 +86,15 @@ class StatelessRandomOpsTest(XLATestCase): # seed were not fixed. self.assertTrue(self._chi_squared(y, 10) < 16.92) + def testRandomNormalIsFinite(self): + with self.test_session() as sess, self.test_scope(): + for dtype in self._random_types(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + x = stateless.stateless_random_uniform( + shape=[10000], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + self.assertTrue(np.all(np.isfinite(y))) + def _normal_cdf(self, x): """Cumulative distribution function for a standard normal distribution.""" return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2)) diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index a62925a1818da00cb0a9e82e1281db20fb38b208..f332aa2e9b97e13654cf9b10588c18fed32f7ad4 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -338,7 +338,7 @@ class TensorArrayTest(xla_test.XLATestCase): w0 = ta.write(0, [[4.0, 5.0]]) # Test reading wrong datatype. - r0_bad = gen_data_flow_ops._tensor_array_read_v3( + r0_bad = gen_data_flow_ops.tensor_array_read_v3( handle=w0.handle, index=0, dtype=dtype2, flow_in=w0.flow) with self.assertRaisesOpError("TensorArray dtype is "): r0_bad.eval() @@ -472,7 +472,9 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(c([[-2.0, -10.0]]), grad_vals[1]) def testTensorArrayGradientWriteRead(self): - for dtype in self.numeric_types: + for dtype in self.float_types: + self._testTensorArrayGradientWriteReadType(dtype) + for dtype in self.complex_types: self._testTensorArrayGradientWriteReadType(dtype) def _testTensorArrayGradientWritePackConcatAndRead(self): diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index ba5f829936fd82ca0cc53eda34aefbca6d80482b..ef047005b60bd156a677050368ef67ae030d6c3a 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest @@ -68,40 +69,41 @@ class TernaryOpsTest(XLATestCase): expected=np.array([1, 3, 5], dtype=np.int32)) def testSelect(self): - self._testTernary( - array_ops.where, - np.array(0, dtype=np.bool), - np.array(2, dtype=np.float32), - np.array(7, dtype=np.float32), - expected=np.array(7, dtype=np.float32)) + for dtype in self.numeric_types: + self._testTernary( + array_ops.where, + np.array(0, dtype=np.bool), + np.array(2, dtype=dtype), + np.array(7, dtype=dtype), + expected=np.array(7, dtype=dtype)) - self._testTernary( - array_ops.where, - np.array(1, dtype=np.bool), - np.array([1, 2, 3, 4], dtype=np.float32), - np.array([5, 6, 7, 8], dtype=np.float32), - expected=np.array([1, 2, 3, 4], dtype=np.float32)) + self._testTernary( + array_ops.where, + np.array(1, dtype=np.bool), + np.array([1, 2, 3, 4], dtype=dtype), + np.array([5, 6, 7, 8], dtype=dtype), + expected=np.array([1, 2, 3, 4], dtype=dtype)) - self._testTernary( - array_ops.where, - np.array(0, dtype=np.bool), - np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32), - np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32), - expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32)) + self._testTernary( + array_ops.where, + np.array(0, dtype=np.bool), + np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype), + np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), + expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype)) - self._testTernary( - array_ops.where, - np.array([0, 1, 1, 0], dtype=np.bool), - np.array([1, 2, 3, 4], dtype=np.float32), - np.array([5, 6, 7, 8], dtype=np.float32), - expected=np.array([5, 2, 3, 8], dtype=np.float32)) + self._testTernary( + array_ops.where, + np.array([0, 1, 1, 0], dtype=np.bool), + np.array([1, 2, 3, 4], dtype=dtype), + np.array([5, 6, 7, 8], dtype=dtype), + expected=np.array([5, 2, 3, 8], dtype=dtype)) - self._testTernary( - array_ops.where, - np.array([0, 1, 0], dtype=np.bool), - np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32), - np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32), - expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=np.float32)) + self._testTernary( + array_ops.where, + np.array([0, 1, 0], dtype=np.bool), + np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype), + np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), + expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=dtype)) def testSlice(self): for dtype in self.numeric_types: @@ -119,6 +121,23 @@ class TernaryOpsTest(XLATestCase): np.array([2, 1], dtype=np.int32), expected=np.array([[2], [5]], dtype=dtype)) + def testClipByValue(self): + # TODO(b/78258593): enable integer types here too. + for dtype in self.float_types: + test_cases = [ + (np.array([2, 4, 5], dtype=dtype), dtype(7)), # + (dtype(1), np.array([2, 4, 5], dtype=dtype)), # + (np.array([-2, 7, 7], dtype=dtype), np.array([-2, 9, 8], dtype=dtype)) + ] + x = np.array([-2, 10, 6], dtype=dtype) + for lower, upper in test_cases: + self._testTernary( + gen_math_ops._clip_by_value, + x, + lower, + upper, + expected=np.minimum(np.maximum(x, lower), upper)) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 3d3e112f4821ea8e57ea9589a5b4433647ad294b..689a4a1f4e02f5dd48f64dc94afd0fcb50df8b5b 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -154,6 +154,9 @@ class UnaryOpsTest(XLATestCase): def testFloatOps(self): for dtype in self.float_types: + # TODO(b/77694432): Half test failed on CPU, last ran on 04-06-2018. + if dtype == np.float16 and self.device == "XLA_CPU": + continue x = np.arange(-0.90, 0.90, 0.25) self._assertOpOutputMatchesExpected( math_ops.acos, @@ -206,7 +209,8 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.expm1, np.array([[-1, 1]], dtype=dtype), - expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype)) + expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype), + rtol=1e-5) self._assertOpOutputMatchesExpected( math_ops.floor, @@ -248,12 +252,12 @@ class UnaryOpsTest(XLATestCase): np.array([[1, 2]], dtype=dtype), expected=np.array([[0.540297, -0.41614]], dtype=dtype)) - # TODO(b/34703906): improve log1p implementation and make tolerance - # tighter. self._assertOpOutputMatchesExpected( math_ops.log1p, np.array([[1e-14, 1e-15, 0.6]], dtype=dtype), - expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype))) + expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)), + rtol=1e-4, + atol=1e-6) self._assertOpOutputMatchesExpected( math_ops.rint, @@ -330,13 +334,19 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( nn_ops.elu, - np.array([[-1, 0, 1]], dtype=dtype), - expected=np.array([[-0.63212056, 0, 1]], dtype=dtype)) + np.array([[-1, 0, 1, -1e-6]], dtype=dtype), + expected=np.array([[-0.63212056, 0, 1, -9.999995e-07]], dtype=dtype), + rtol=1e-5, + atol=1e-6) self._assertOpOutputMatchesExpected( nn_ops.selu, - np.array([[-1, 0, 1]], dtype=dtype), - expected=np.array([[-1.11133074, 0., 1.05070099]], dtype=dtype)) + np.array([[-1, 0, 1, -1e-5]], dtype=dtype), + expected=np.array( + [[-1.11133074, 0., 1.05070099, -1.758090550379974e-05]], + dtype=dtype), + rtol=1e-5, + atol=1e-6) self._assertOpOutputMatchesExpected( nn_ops.relu, @@ -416,7 +426,9 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.expm1, np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype), - expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype))) + expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)), + rtol=1e-6, + atol=1e-6) self._assertOpOutputMatchesExpected( math_ops.reciprocal, @@ -438,13 +450,13 @@ class UnaryOpsTest(XLATestCase): np.array([[5j, 3 - 2j]], dtype=dtype), expected=np.cos(np.array([[5j, 3 - 2j]], dtype=dtype))) - # TODO(b/34703906): improve log1p implementation and make tolerance - # tighter. self._assertOpOutputMatchesExpected( math_ops.log1p, np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype), expected=np.log1p( - np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)), + rtol=1e-4, + atol=1e-6) val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) self._assertOpOutputMatchesExpected( @@ -600,6 +612,20 @@ class UnaryOpsTest(XLATestCase): src, expected=dst) + def testBitcast(self): + self._assertOpOutputMatchesExpected( + lambda x: array_ops.bitcast(x, dtypes.int32), + np.array([1, 0x3f800000], np.int32), + expected=np.array([1, 0x3f800000], np.int32)) + self._assertOpOutputMatchesExpected( + lambda x: array_ops.bitcast(x, dtypes.float32), + np.array([1, 0x3f800000], np.int32), + expected=np.array([1e-45, 1.0], np.float32)) + self._assertOpOutputMatchesExpected( + lambda x: array_ops.bitcast(x, dtypes.int32), + np.array([1e-45, 1.0], np.float32), + expected=np.array([1, 0x3f800000], np.int32)) + def testInvertPermutation(self): self._assertOpOutputMatchesExpected( array_ops.invert_permutation, @@ -772,14 +798,19 @@ class UnaryOpsTest(XLATestCase): zero = np.asarray(0).astype(dtype) expected = np.logaddexp(zero, features) self._assertOpOutputMatchesExpected( - nn_ops.softplus, features, expected=expected) + nn_ops.softplus, features, expected=expected, + rtol=1e-6, + atol=9.1e-6) def testSoftplus(self): for dtype in self.float_types: self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype) self._assertSoftplusMatchesExpected( [[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]], dtype) - log_eps = np.log(np.finfo(dtype).eps) + if dtype == dtypes.bfloat16.as_numpy_dtype: + log_eps = np.log(np.finfo(np.float32).eps) + else: + log_eps = np.log(np.finfo(dtype).eps) one = dtype(1) ten = dtype(10) self._assertSoftplusMatchesExpected([ diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index b08d6ab21e0746558cb3d4818d4c822c45d2e9ee..2c09b03d5a35cde2c42d8a145781270c0c908587 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -187,6 +187,25 @@ class VariableOpsTest(XLATestCase): rtol=1e-4) self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4) + def testWriteOfAliasedTensor(self): + for dtype in self.numeric_types: + init = np.array([[1, 2j], [3, 4]]).astype(dtype) + update = np.array([[7, 1j], [2, 11]]).astype(dtype) + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + sess.run(variables.variables_initializer([v])) + p = array_ops.placeholder(dtype) + q = array_ops.identity(p) + x = v.read_value() + # Writes the value of 'p' to 'v', but keeps a reference to the original + # value of 'v' so the variable update cannot reuse its buffer. + with ops.control_dependencies([x]): + y = v.assign(q) + result = sess.run([x, y, q], {p: update}) + self.assertAllClose(init, result[0]) + self.assertAllClose(update, result[1]) + self.assertAllClose(update, result[2]) + class StridedSliceAssignChecker(object): """Compares the results of a slice assignment using Tensorflow and numpy.""" @@ -230,7 +249,10 @@ class SliceAssignTest(XLATestCase): # shrink shape changes checker[1:2, 1] = [66] checker[1, 1:2] = [66] - checker[1, 1] = 66 + if dtype != dtypes.bfloat16.as_numpy_dtype: + # TODO(b/68813416): valnp call above results in an ndarray and not a + # number for bfloat16s. + checker[1, 1] = 66 # newaxis shape changes checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]] # shrink and newaxis @@ -243,8 +265,11 @@ class SliceAssignTest(XLATestCase): # Assign vector to scalar (rank-0) using newaxis checker2 = StridedSliceAssignChecker(self, 222, dtype=dtype) - checker2[()] = 6 # no indices - checker2[...] = 6 # ellipsis + if dtype != dtypes.bfloat16.as_numpy_dtype: + # TODO(b/68813416): valnp call above results in an ndarray and not a + # number for bfloat16s. + checker2[()] = 6 # no indices + checker2[...] = 6 # ellipsis checker2[None] = [6] # new axis def testUninitialized(self): diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f79eb27435cc954cebde4357c1d946a320f4ed75 --- /dev/null +++ b/tensorflow/compiler/tests/while_test.py @@ -0,0 +1,130 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for while loops in XLA.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class WhileTest(XLATestCase): + + def testSingletonLoopHandrolled(self): + # Define a function for the loop body + @function.Defun(dtypes.int32) + def loop_body(step): + step_out = step + constant_op.constant(1, dtype=dtypes.int32) + return step_out + + # Define a function for the loop condition + @function.Defun(dtypes.int32) + def loop_cond(step): + return step < 10 + + with self.test_session() as sess: + init_index = array_ops.placeholder(dtypes.int32, []) + with self.test_scope(): + loop_outputs = xla.while_loop([init_index], loop_cond, loop_body) + + result = sess.run(loop_outputs, {init_index: 0}) + self.assertAllClose(result, [10], rtol=1e-3) + + def testCountingLoopHandrolled(self): + # Define a function for the loop body + @function.Defun(dtypes.int32, dtypes.float32) + def loop_body(step, rsum): + step_out = step + constant_op.constant(1, dtype=dtypes.int32) + sum_out = rsum + constant_op.constant(1.5, dtype=dtypes.float32) + return step_out, sum_out + + # Define a function for the loop condition + @function.Defun(dtypes.int32, dtypes.float32) + def loop_cond(step, rsum): + del rsum + return step < 10 + + with self.test_session() as sess: + init_index = array_ops.placeholder(dtypes.int32, []) + init_sum = array_ops.placeholder(dtypes.float32, []) + with self.test_scope(): + loop_outputs = xla.while_loop([init_index, init_sum], loop_cond, + loop_body) + + result = sess.run(loop_outputs, {init_index: 0, init_sum: 0.0}) + self.assertAllClose(result, [10, 15.0], rtol=1e-3) + no_iters_result = sess.run(loop_outputs, {init_index: 10, init_sum: 0.0}) + self.assertAllClose(no_iters_result, [10, 0.0], rtol=1e-3) + + def testCountingLoopHandrolledC64(self): + # Define a function for the loop body + @function.Defun(dtypes.int32, dtypes.complex64) + def loop_body(step, rsum): + step_out = step + constant_op.constant(1, dtype=dtypes.int32) + sum_out = rsum + constant_op.constant(1.5 + 2j, dtype=dtypes.complex64) + return step_out, sum_out + + # Define a function for the loop condition + @function.Defun(dtypes.int32, dtypes.complex64) + def loop_cond(step, rsum): + del rsum + return step < 10 + + with self.test_session() as sess: + init_index = array_ops.placeholder(dtypes.int32, []) + init_sum = array_ops.placeholder(dtypes.complex64, []) + with self.test_scope(): + loop_outputs = xla.while_loop([init_index, init_sum], loop_cond, + loop_body) + + result = sess.run(loop_outputs, {init_index: 0, init_sum: 0.0}) + self.assertAllClose(result[1], np.complex64(15 + 20j), rtol=1e-3) + no_iters_result = sess.run(loop_outputs, {init_index: 10, init_sum: 0.0}) + self.assertAllClose(no_iters_result[1], np.complex64(0), rtol=1e-3) + + def testLoopWithConstantOutput(self): + # Define a function for the loop body + @function.Defun(dtypes.int32, dtypes.int32) + def loop_body(step, x): + del x + step_out = step + constant_op.constant(1, dtype=dtypes.int32) + return (step_out, 7) + + # Define a function for the loop condition + @function.Defun(dtypes.int32, dtypes.int32) + def loop_cond(step, x): + del x + return step < 10 + + with self.test_session() as sess: + init_index = array_ops.placeholder(dtypes.int32, []) + with self.test_scope(): + loop_outputs = xla.while_loop([init_index, 42], loop_cond, loop_body) + + result = sess.run(loop_outputs, {init_index: 0}) + self.assertAllClose(result, [10, 7], rtol=1e-3) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/converters/for_loops_test.py b/tensorflow/compiler/tests/xla_device_gpu_test.py similarity index 52% rename from tensorflow/contrib/py2tf/converters/for_loops_test.py rename to tensorflow/compiler/tests/xla_device_gpu_test.py index 70a367d3b517e528b67f260d607431d324d2ab7d..1e30ebd55d09fe00449fb67b92a8325f5809d89a 100644 --- a/tensorflow/contrib/py2tf/converters/for_loops_test.py +++ b/tensorflow/compiler/tests/xla_device_gpu_test.py @@ -12,36 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for for_loops module.""" +"""Test cases for XLA devices.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import for_loops +from tensorflow.python.client import session as session_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ControlFlowTest(converter_test_base.TestCase): +class XlaDeviceGpuTest(test.TestCase): - def test_basic_for(self): + def testCopiesToAndFromGpuWork(self): + """Tests that copies between GPU and XLA devices work.""" + if not test.is_gpu_available(): + return - def test_fn(l): - s = 0 - for e in l: - s += e - return s + with session_lib.Session() as sess: + x = array_ops.placeholder(dtypes.float32, [2]) + with ops.device("GPU"): + y = x * 2 + with ops.device("device:XLA_CPU:0"): + z = y * y + with ops.device("GPU"): + w = y + z + result = sess.run(w, {x: [1.5, 0.5]}) + self.assertAllClose(result, [12., 2.], rtol=1e-3) - node = self.parse_and_analyze(test_fn, {}) - node = for_loops.transform(node, self.ctx) - with self.compiled(node) as result: - l = [1, 2, 3] - self.assertEqual(test_fn(l), result.test_fn(l)) - l = [] - self.assertEqual(test_fn(l), result.test_fn(l)) - - -if __name__ == '__main__': +if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index f5c228f8305d740b994dadc34c93b4e0ae32d785..f0b010fa67f2ffb3f81fd14d4d89585f716b4890 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,30 +18,40 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.client import session as session_lib -from tensorflow.python.framework import dtypes +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.platform import test -class XlaDeviceTest(test.TestCase): +class XlaDeviceTest(XLATestCase): def testCopies(self): - """Tests that copies between GPU and XLA devices work.""" - if not test.is_gpu_available(): - return - - with session_lib.Session() as sess: - x = array_ops.placeholder(dtypes.float32, [2]) - with ops.device("GPU"): - y = x * 2 - with ops.device("device:XLA_CPU:0"): - z = y * y - with ops.device("GPU"): - w = y + z - result = sess.run(w, {x: [1.5, 0.5]}) - self.assertAllClose(result, [12., 2.], rtol=1e-3) + """Tests that copies onto and off XLA devices work.""" + shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3], + [16384, 1], [1, 16384], [1, 20000, 1, 1]] + for dtype in self.numeric_types: + for shape in shapes: + with self.test_session() as sess: + with ops.device("CPU"): + x = array_ops.placeholder(dtype, shape) + with self.test_scope(): + y = x + x + with ops.device("CPU"): + z = array_ops.identity(y) + + inputs = np.random.randint(-100, 100, shape).astype(dtype) + result = sess.run(z, {x: inputs}) + self.assertAllCloseAccordingToType(result, inputs + inputs) + + def testControlTrigger(self): + with self.test_session() as sess: + with self.test_scope(): + x = gen_control_flow_ops.control_trigger() + sess.run(x) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 7e1f5c76ed65946363cc3c113ab1a9862f87b289..e924fe1e61454aefda622a5a46a0e483d26db5c1 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import contextlib +import os import random import re @@ -44,6 +45,8 @@ flags.DEFINE_string('test_device', None, flags.DEFINE_string('types', None, 'Types to test. Comma-separated list.') flags.DEFINE_string('disabled_manifest', None, 'Path to a file with a list of tests that should not run.') +flags.DEFINE_string('tf_xla_flags', None, + 'Value to set the TF_XLA_FLAGS environment variable to') class XLATestCase(test.TestCase): @@ -71,14 +74,14 @@ class XLATestCase(test.TestCase): self._all_types = set( [dtype.as_numpy_dtype for dtype in self._all_tf_types]) - self.int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types]) + self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types]) self._float_types = set( [dtype.as_numpy_dtype for dtype in self._float_tf_types]) self.complex_types = set([ dtype.as_numpy_dtype for dtype in self.complex_tf_types ]) - self._numeric_types = set( - self.int_types | self._float_types | self.complex_types) + self._numeric_types = set(self._int_types | self._float_types + | self.complex_types) # Parse the manifest file, if any, into a regex identifying tests to # disable @@ -97,6 +100,8 @@ class XLATestCase(test.TestCase): disabled_tests = [] disabled_method_types = [] for l in manifest_file.read().splitlines(): + if not l: + continue entry = comments_re.sub('', l).strip().split(' ') if len(entry) == 1: disabled_tests.append(entry[0]) @@ -113,6 +118,9 @@ class XLATestCase(test.TestCase): for name in types]) manifest_file.close() + if FLAGS.tf_xla_flags is not None: + os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags + @property def all_tf_types(self): name = '{}.{}'.format(type(self).__name__, self._testMethodName) @@ -130,6 +138,11 @@ class XLATestCase(test.TestCase): name = '{}.{}'.format(type(self).__name__, self._testMethodName) return self._float_tf_types - self._method_types_filter.get(name, set()) + @property + def int_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + return self._int_types - self._method_types_filter.get(name, set()) + @property def numeric_tf_types(self): name = '{}.{}'.format(type(self).__name__, self._testMethodName) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index fb82c2601c432cee425a46a3b6dc2c55febeda87..cd57452302fcbde37d79ce760a80615a76d7ad8c 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -58,6 +58,15 @@ xla_proto_library( ], ) +xla_proto_library( + name = "host_compute_metadata_proto", + srcs = ["host_compute_metadata.proto"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "tf2xla", srcs = ["tf2xla.cc"], @@ -72,7 +81,7 @@ cc_library( "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -149,6 +158,7 @@ cc_library( ":common", ":dump_graph", ":functionalize_control_flow", + ":host_compute_metadata_proto", ":sharding_util", ":tf2xla_util", "//tensorflow/compiler/tf2xla/lib:util", @@ -158,9 +168,9 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -205,7 +215,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:sharding_builder", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -316,12 +325,14 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -401,10 +412,9 @@ cc_library( hdrs = ["functionalize_control_flow.h"], deps = [ ":tf2xla_util", - "//tensorflow/compiler/jit:graph_to_functiondef", "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla:dump_graph", - "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", @@ -426,7 +436,7 @@ tf_cc_test( "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", - "//tensorflow/compiler/tf2xla/cc:functional_ops", + "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -452,17 +462,3 @@ cc_library( "//tensorflow/core:protos_all_cc", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index 311dddca94c458a60fd00afe5532840e0dbf0437..ea8d1b3d14939d4f4fba598318200f71c2eb0270 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -7,21 +7,20 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc") tf_gen_op_wrapper_cc( - name = "functional_ops_gen", - include_internal_ops = 1, - out_ops_file = "ops/functional_ops", - deps = ["//tensorflow/compiler/tf2xla/ops:functional_ops"], + name = "xla_ops_gen", + out_ops_file = "ops/xla_ops", + deps = ["//tensorflow/compiler/tf2xla/ops:xla_ops"], ) cc_library( - name = "functional_ops", - srcs = ["ops/functional_ops.cc"], - hdrs = ["ops/functional_ops.h"], + name = "xla_ops", + srcs = ["ops/xla_ops.cc"], + hdrs = ["ops/xla_ops.h"], deps = [ "//tensorflow/cc:const_op", "//tensorflow/cc:ops", "//tensorflow/cc:scope", - "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -30,38 +29,23 @@ cc_library( ) tf_gen_op_wrapper_cc( - name = "sendrecv_ops_gen", - include_internal_ops = 1, - out_ops_file = "ops/sendrecv_ops", - deps = ["//tensorflow/compiler/tf2xla/ops:sendrecv_ops"], + name = "xla_jit_op_gen", + out_ops_file = "ops/xla_jit_op", + deps = ["//tensorflow/compiler/jit/ops:xla_ops"], ) cc_library( - name = "sendrecv_ops", - srcs = ["ops/sendrecv_ops.cc"], - hdrs = ["ops/sendrecv_ops.h"], + name = "xla_jit_ops", + srcs = ["ops/xla_jit_op.cc"], + hdrs = ["ops/xla_jit_op.h"], deps = [ "//tensorflow/cc:const_op", "//tensorflow/cc:ops", "//tensorflow/cc:scope", - "//tensorflow/compiler/tf2xla/ops:sendrecv_ops", + "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 82923722c54d235716b9138d95a75a441df924ca..de1008803d69fefa415c7bdbe6c27a62e625b417 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -37,7 +37,7 @@ Status BackwardsConstAnalysis(const Graph& g, }; Status status; - std::unordered_set must_be_const; + std::unordered_set must_be_const; auto visit = [&status, &metadata_ops, &must_be_const, compile_time_const_args](Node* node) { if (!status.ok()) return; @@ -55,8 +55,10 @@ Status BackwardsConstAnalysis(const Graph& g, compile_time_const_args->at(index) = true; return; } - for (Node* pred : node->in_nodes()) { - must_be_const.insert(pred); + for (const Edge* pred : node->in_edges()) { + if (!pred->IsControlEdge()) { + must_be_const.insert(pred->src()); + } } return; } diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc index 9d125f8d499863cfaa0e26b5b633ca02914d1b7d..992b12c06db5efc0ae54284d0ea77017c1c79aca 100644 --- a/tensorflow/compiler/tf2xla/const_analysis_test.cc +++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc @@ -79,5 +79,24 @@ TEST(ConstAnalysisTest, TopologicalOrder) { } } +TEST(ConstAnalysisTest, DontFollowControlDependencies) { + Scope root = Scope::NewRootScope(); + + Output arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0); + Output arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1); + Output c1 = + ops::Const(root.WithOpName("c1").WithControlDependencies(arg0), 1, {1}); + Output add = ops::Add(root, arg1, c1); + Output reshape = ops::Reshape(root, arg1, add); + + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph)); + + std::vector const_args(2, false); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + + EXPECT_EQ(const_args, std::vector({false, true})); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index f8169795ddfb7fd4e93d3f136c51623385868951..1438f6b48c4913e60b0c0a9f5c3d67fe595cbfe8 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -21,13 +21,13 @@ limitations under the License. #include #include -#include "tensorflow/compiler/jit/graph_to_functiondef.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" @@ -282,7 +282,58 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, return Status::OK(); } -Status FunctionalizeLoop(Graph* graph, Frame* frame, +// Copy the FunctionDef of given function from lookup_library to library, if +// it can be found in lookup_library but is missing from library. +Status AddMissingFunctionByName(const string& function_name, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + if (!library->Find(function_name) && lookup_library->Find(function_name)) { + return library->AddFunctionDef(*lookup_library->Find(function_name)); + } + return Status::OK(); +} + +// Iterate over all functions that the given fdef refers to. Copy the missing +// FunctionDefs from lookup_library to library. +Status AddMissingFunctionDef(const FunctionDef& fdef, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + TF_RET_CHECK(lookup_library); + for (const NodeDef& node : fdef.node_def()) { + if (library->Find(node.op())) { + continue; + } + // The function refered by 'SymbolicGradient' node is specified in its + // attribute 'f'. + if (node.op() == FunctionLibraryDefinition::kGradientOp) { + const AttrValue* attr = + AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr); + if (!attr) { + return errors::InvalidArgument("SymbolicGradient is missing attr: f"); + } + const string& func_name = attr->func().name(); + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(func_name, lookup_library, library)); + // Copy the user-defined gradient function if it exists. + const string grad_name = lookup_library->FindGradient(func_name); + if (!grad_name.empty() && library->FindGradient(func_name).empty()) { + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(grad_name, lookup_library, library)); + GradientDef grad_def; + grad_def.set_function_name(func_name); + grad_def.set_gradient_func(grad_name); + TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def)); + } + } else if (lookup_library->Find(node.op())) { + TF_RETURN_IF_ERROR( + library->AddFunctionDef(*lookup_library->Find(node.op()))); + } + } + return Status::OK(); +} + +Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, + Graph* graph, Frame* frame, FunctionLibraryDefinition* library) { VLOG(2) << "Frame " << frame->name << " before: " << dump_graph::DumpGraphToFile("functionalize_before", *graph, @@ -489,6 +540,14 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef)); TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); + if (lookup_library) { + // Copy missing FunctionDefs from lookup_library to library to make library + // self-contained. + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(cond_fdef, lookup_library, library)); + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(body_fdef, lookup_library, library)); + } // Builds a While operator. NodeDef while_def; @@ -583,13 +642,15 @@ class FunctionalizeCond { // CondArgNode represents a input to the conditional and its corresponding // switch nodes. struct CondArgNode { - explicit CondArgNode(Node* input) : input(input) {} + explicit CondArgNode(Node* src, int src_output) + : src(src), src_output(src_output) {} string ToString() const { - return strings::StrCat("input=", input->name(), + return strings::StrCat("src=", src->name(), ":", src_output, " switches=", NodesToString(switches)); } - Node* input; + Node* src; + int src_output; std::vector switches; }; using CondArgNodes = std::vector; @@ -606,14 +667,15 @@ class FunctionalizeCond { // Group of switch nodes that will be part of the same XlaIf. struct SwitchCluster { - explicit SwitchCluster(Node* predicate) : predicate(predicate) {} + explicit SwitchCluster(const Edge* predicate_edge) + : predicate_edge(predicate_edge) {} string ToString() const { - return strings::StrCat(name, " predicate=", predicate->name(), + return strings::StrCat(name, " predicate=", predicate_edge->src()->name(), " switches=", NodesToString(switches)); } string name; - Node* predicate; + const Edge* predicate_edge; std::vector switches; }; @@ -653,8 +715,8 @@ class FunctionalizeCond { Graph* body); // Adds all the input edges to `if_node` corresponding to the arguments. - Status AddInputEdges(const CondArgNodes& cond_arg_nodes, Node* predicate, - Node* if_node); + Status AddInputEdges(const CondArgNodes& cond_arg_nodes, + const Edge* predicate_edge, Node* if_node); // Adds all output edges from the `if_node`. Status AddOutputEdges(const std::vector& outputs, Node* if_node); @@ -756,8 +818,8 @@ Status FunctionalizeCond::Join(const ForwardFlowNode& src_state, if (IsMerge(dst)) { dst_state->branch = Branch::kBoth; } else { - return errors::Internal("Illegal merge: ", src_state.ToString(), " with ", - dst_state->ToString(), " for ", + return errors::Internal("Illegal merge:\n", src_state.ToString(), + " with ", dst_state->ToString(), " for\n", dst->DebugString()); } } @@ -861,12 +923,15 @@ FunctionalizeCond::DeterminePredicateSwitchOrder() { if (IsSwitch(n)) { Node* input; TF_CHECK_OK(n->input_node(0, &input)); - entry_cluster[n->id()] = &clusters[input->id()]; - UnionFind* cluster = find_output_cluster(input); + entry_cluster[n->id()] = find_output_cluster(input); + UnionFind* cluster = entry_cluster[n->id()]; int cluster_depth = switch_depth[cluster->Get().representative]; // Merge the inputs of the switch node with one another. This results in // predicates and control input residing in the same cluster. for (const Edge* e : n->in_edges()) { + // Only consider the data inputs to the Switch node. + if (e->IsControlEdge()) continue; + Node* src = e->src(); UnionFind* src_cluster = find_output_cluster(src); int src_cluster_depth = switch_depth[src_cluster->Get().representative]; @@ -898,6 +963,14 @@ FunctionalizeCond::DeterminePredicateSwitchOrder() { int src_depth = switch_depth[src_id]; if (!e->IsControlEdge() || new_switch_depth == src_depth) { if (src_depth != new_switch_depth) { + // TODO(b/77601805) remove this when outside_compilation supports + // control flow. + if (str_util::StrContains(src->name(), "outside_compilation") || + str_util::StrContains(n->name(), "outside_compilation")) { + return errors::InvalidArgument( + "outside_compilation is not yet supported within TensorFlow " + "control flow constructs b/77601805"); + } return errors::InvalidArgument( "Unable to functionalize control flow in graph: Operand ('", src->name(), "') and operator ('", n->name(), @@ -956,16 +1029,21 @@ FunctionalizeCond::DeterminePredicateSwitchOrder() { // node whose cluster is later in the topological order of clustered // switches). for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) { - Node* pred; - TF_CHECK_OK((*it)->input_node(1, &pred)); - auto repr = std::make_pair(pred, clusters[(*it)->id()].Get()); + const Edge* pred_edge; + TF_CHECK_OK((*it)->input_edge(1, &pred_edge)); + // The predicate can be preceded by a identity node. Look through identity + // nodes to predicate. + while (pred_edge->src()->IsIdentity()) { + TF_CHECK_OK(pred_edge->src()->input_edge(0, &pred_edge)); + } + auto repr = std::make_pair(pred_edge->src(), clusters[(*it)->id()].Get()); if (predicate_index.find(repr) == predicate_index.end()) { predicate_index[repr] = switch_clusters.size(); - switch_clusters.emplace_back(pred); + switch_clusters.emplace_back(pred_edge); // Generate a name by concatenating with the cluster representative as // there could be multiple switch clusters with the same predicate. - switch_clusters[predicate_index[repr]].name = - strings::StrCat(pred->name(), "_", repr.second.representative, "_If"); + switch_clusters[predicate_index[repr]].name = strings::StrCat( + pred_edge->src()->name(), "_", repr.second.representative, "_If"); } switch_clusters[predicate_index[repr]].switches.push_back(*it); } @@ -1044,9 +1122,12 @@ FunctionalizeCond::DetermineBranchMapAndFrontier( ForwardFlowNode& ffn = branch_map[out]; if (IsSwitch(n)) { int index = e->IsControlEdge() ? Branch::kNeither : e->src_output(); - TF_RETURN_IF_ERROR(Join(ForwardFlowNode(Branch(index)), out, &ffn)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + Join(ForwardFlowNode(Branch(index)), out, &ffn), " when joining ", + e->DebugString()); } else { - TF_RETURN_IF_ERROR(Join(branch_map[n], out, &ffn)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(Join(branch_map[n], out, &ffn), + " when joining ", e->DebugString()); } if (IsMerge(out)) { if (out->in_edges().size() == ffn.count) { @@ -1083,8 +1164,7 @@ Status FunctionalizeCond::FunctionalizeInternal() { for (auto it = predicate_switch_order.rbegin(); it != predicate_switch_order.rend(); ++it) { auto& ps = *it; - VLOG(3) << "Flow down from: " << NodesToString(ps.switches) << " (" - << ps.predicate->name() << ")"; + VLOG(3) << "Flow down from: " << ps.ToString(); std::unordered_map branch_map; std::unordered_set frontier; @@ -1097,21 +1177,29 @@ Status FunctionalizeCond::FunctionalizeInternal() { library_); TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier)); + struct Hash { + size_t operator()(const std::pair& item) const { + return Hash64Combine(hash()(item.first), + std::hash()(item.second)); + } + }; + // Sort the merge and switch nodes using NodeCmp. The switch-nodes are // further grouped (post sorting) by input to the switch node as in the // functionalized form each input will be passed in only once. This grouping // should retain the sorted order. CondArgNodes cond_arg_nodes; - std::unordered_map input_index; std::sort(ps.switches.begin(), ps.switches.end(), NodeCmp()); + std::unordered_map, int, Hash> input_index; for (Node* switch_node : ps.switches) { - Node* in; - TF_RETURN_IF_ERROR(switch_node->input_node(0, &in)); - if (input_index.find(in) == input_index.end()) { - input_index[in] = cond_arg_nodes.size(); - cond_arg_nodes.emplace_back(in); + const Edge* e; + TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e)); + std::pair key = std::make_pair(e->src(), e->src_output()); + if (input_index.find(key) == input_index.end()) { + input_index[key] = cond_arg_nodes.size(); + cond_arg_nodes.emplace_back(key.first, key.second); } - cond_arg_nodes.at(input_index.at(in)).switches.push_back(switch_node); + cond_arg_nodes.at(input_index.at(key)).switches.push_back(switch_node); } std::vector merge_nodes(frontier.begin(), frontier.end()); std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp()); @@ -1200,11 +1288,12 @@ StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( builder.Attr("Tout", out_type); builder.Attr("Tcond", DT_BOOL); - builder.Device(switch_cluster.predicate->assigned_device_name()); + builder.Device(switch_cluster.predicate_edge->src()->assigned_device_name()); // Conditional should be the first input ... - builder.Input( - NodeDefBuilder::NodeOut(switch_cluster.predicate->name(), 0, - switch_cluster.predicate->output_type(0))); + builder.Input(NodeDefBuilder::NodeOut( + switch_cluster.predicate_edge->src()->name(), + switch_cluster.predicate_edge->src_output(), + switch_cluster.predicate_edge->src()->output_type(0))); // ... followed by the other inputs. builder.Input(inputs); @@ -1264,24 +1353,17 @@ Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, } Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes, - Node* predicate, Node* if_node) { + const Edge* predicate_edge, + Node* if_node) { VLOG(3) << "AddInputEdges for " << if_node->name(); int index = 0; - graph_->AddEdge(predicate, 0, if_node, index++); - for (auto& kv : cond_arg_nodes) { - bool inserted = false; - for (const Node* arg : kv.switches) { - const Edge* in_edge; - TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); - if (in_edge->IsControlEdge()) { - graph_->AddControlEdge(in_edge->src(), if_node); - } else { - if (!inserted) { - graph_->AddEdge(in_edge->src(), in_edge->src_output(), if_node, - index++); - inserted = true; - } - } + graph_->AddEdge(predicate_edge->src(), predicate_edge->src_output(), if_node, + index++); + for (auto& arg : cond_arg_nodes) { + if (arg.src_output == Graph::kControlSlot) { + graph_->AddControlEdge(arg.src, if_node); + } else { + graph_->AddEdge(arg.src, arg.src_output, if_node, index++); } } return Status::OK(); @@ -1302,10 +1384,10 @@ Status FunctionalizeCond::AddOutputEdges(const std::vector& outputs, return errors::Unimplemented("Output of index (", edge->src_output(), ") of merge node ", node->name()); } - graph_->RemoveEdge(edge); int src_output = dst_input == Graph::kControlSlot ? Graph::kControlSlot : i; + graph_->RemoveEdge(edge); graph_->AddEdge(if_node, src_output, dst, dst_input); } } @@ -1323,7 +1405,7 @@ StatusOr FunctionalizeCond::ConvertToXlaIf( Node * if_node, BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes)); TF_RETURN_IF_ERROR( - AddInputEdges(cond_arg_nodes, switch_cluster.predicate, if_node)); + AddInputEdges(cond_arg_nodes, switch_cluster.predicate_edge, if_node)); TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); return if_node; @@ -1342,14 +1424,27 @@ Status FunctionalizeCond::Functionalize(Graph* graph, // functional equivalents. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library) { + return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); +} + +Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, + Graph* graph, + FunctionLibraryDefinition* library) { VLOG(2) << "FunctionalizeControlFlow (initial): " << dump_graph::DumpGraphToFile("functionalize_initial", *graph, library); + // Note: BuildControlFlowInfo() requires that the graph's source node is // connected to all source nodes in the graph. Many graphs violate this // invariant. std::vector cf_info; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info)); + std::vector unreachable_nodes; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes)); + if (!unreachable_nodes.empty()) { + return errors::InvalidArgument( + "The following nodes are unreachable from the source in the graph: ", + tensorflow::str_util::Join(unreachable_nodes, ", ")); + } // Builds Frames, indexed by name. std::unordered_map frames; @@ -1410,7 +1505,8 @@ Status FunctionalizeControlFlow(Graph* graph, continue; } - TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library)); + TF_RETURN_IF_ERROR( + FunctionalizeLoop(lookup_library, graph, frame, library)); // If the parent has no remaining children, add it to the worklist. --frame->parent->num_children; diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index 4d4ee3054c2914bb614bf75f7a51be8f6292683e..d941041d15532446d1413f16fe64602bfb1a7daa 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -22,9 +22,13 @@ limitations under the License. namespace tensorflow { // Transformation that converts tf.while_loop() loops into functional While -// operators, suitable for XLA compilation. +// operators, suitable for XLA compilation. If lookup_library is provided, use +// it to make the library for control flow self-contained. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library); +Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, + Graph* graph, + FunctionLibraryDefinition* library); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index bc7276c3afd5060d6faeceb4d479416299ecc5da..14977a908ae2b0ff7e13b634c41b6d331b4b8a36 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/compiler/tf2xla/cc/ops/functional_ops.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" #include "tensorflow/compiler/tf2xla/test_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" @@ -299,6 +299,131 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { } } +// @function.Defun(noinline=True) +// def increment_fn(x): +// return [x + 1] +// Define the above function, and add it to the given graph. It's used as the +// while loop body in NoinlineLoopBody test. +Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { + FunctionDef fdef = FunctionDefHelper::Create( + "increment_fn", {"x:int32"}, {"add:int32"}, {}, + { + {{"add/y"}, "Const", {}, {{"dtype", DT_INT32}}}, + {{"add_0"}, "Add", {"x", "add/y:output:0"}, {{"T", DT_INT32}}}, + }, + {{"add", "add_0:z:0"}}); + (*fdef.mutable_attr())["_noinline"].set_b(true); + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = fdef; + TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib)); + NodeDef increment_fn; + increment_fn.set_name(node_name); + increment_fn.set_op("increment_fn"); + *increment_fn.add_input() = "while/Identity"; + *increment_fn.add_input() = "^while/Identity"; + Status status; + graph->AddNode(increment_fn, &status); + return status; +} + +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x]) +TEST(FunctionalizeControlFlow, NoinlineLoopBody) { + const string& noinline_node_name = "while/increment_fn"; + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto enter = ops::internal::Enter(scope.WithOpName("while/Enter"), source, + "while/while_context"); + auto merge = ops::Merge(scope.WithOpName("while/Merge"), + std::initializer_list{enter, dummy}); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), + 10); + auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + auto switch_ = + ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); + auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"), + switch_.output_false); + auto identity = + ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); + + TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); + + NodeDef next_iter; + next_iter.set_name("while/NextIteration"); + next_iter.set_op("NextIteration"); + *next_iter.add_input() = noinline_node_name; + (*next_iter.mutable_attr())["T"].set_type(DT_INT32); + + Status status; + Node* n = scope.graph()->AddNode(next_iter, &status); + TF_ASSERT_OK(status); + + // Remove the dummy node and add the loop backedge. + scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(n, 0, merge.output.node(), 1); + TF_ASSERT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition lookup_lib(graph.flib_def()); + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + // Function increment_fn will be copied from lookup_lib to library. + TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList cond_fn, body_fn; + TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::XlaWhile(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + NodeDef retval; + retval.set_name("_retval0_RetVal"); + retval.set_op(FunctionLibraryDefinition::kRetOp); + *retval.add_input() = noinline_node_name; + (*retval.mutable_attr())["T"].set_type(DT_INT32); + (*retval.mutable_attr())["index"].set_i(0); + Status status; + scope.graph()->AddNode(retval, &status); + TF_ASSERT_OK(status); + + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + // Verify that increment_fn has been copied to library. + TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + // Ignore the function library when comparing the graphs. + expected.clear_library(); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + // Tests functionalizing OneLoopVar where the loop value is not used post the // loop. // Graph: diff --git a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md index 91351421bcacd26c41b5c9f98ea833730e4aef30..20179b67991d3d23d678cf1df2642e029ea037fd 100644 --- a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md +++ b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md @@ -3,6 +3,7 @@ Operator | Type Constraint ------------------------------------- | --------------- `Abs` | `T={double,float,int32,int64}` +`Acos` | `T={complex64,double,float,int32,int64}` `Acosh` | `T={complex64,double,float}` `Add` | `T={complex64,double,float,int32,int64}` `AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` @@ -15,10 +16,12 @@ Operator | Type Constraint `ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` `ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={float}` `ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asin` | `T={complex64,double,float,int32,int64}` `Asinh` | `T={complex64,double,float}` `AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` `AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` `AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan` | `T={complex64,double,float,int32,int64}` `Atan2` | `T={double,float}` `Atanh` | `T={complex64,double,float}` `AvgPool` | `T={double,float}` @@ -75,6 +78,10 @@ Operator | Type Constraint `FFT` | `FFT2D` | `FFT3D` | +`FakeQuantWithMinMaxArgs` | +`FakeQuantWithMinMaxArgsGradient` | +`FakeQuantWithMinMaxVars` | +`FakeQuantWithMinMaxVarsGradient` | `Fill` | `index_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Floor` | `T={double,float}` `FloorDiv` | `T={complex64,double,float,int32,int64}` @@ -84,6 +91,7 @@ Operator | Type Constraint `FusedBatchNormGradV2` | `U={float}`
`T={float}` `FusedBatchNormV2` | `U={float}`
`T={float}` `Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherNd` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `Greater` | `T={double,float,int32,int64,uint32,uint64}` `GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` @@ -117,14 +125,18 @@ Operator | Type Constraint `LogicalNot` | `LogicalOr` | `MatMul` | `T={complex64,double,float}` +`MatrixBandPart` | `Tindex={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixSetDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixTriangularSolve` | `T={complex64,double,float}` `Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `MaxPool` | `T={double,float,int32,int64}` `MaxPool3D` | `T={float}` `MaxPool3DGrad` | `TInput={float}`
`T={float}` `MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolGradGrad` | `T={float}` +`MaxPoolGradGradV2` | `T={float}` `MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}` `MaxPoolV2` | `T={double,float,int32,int64}` `Maximum` | `T={double,float,int32,int64}` @@ -186,6 +198,7 @@ Operator | Type Constraint `Round` | `T={complex64,double,float,int32,int64}` `Rsqrt` | `T={complex64,double,float}` `RsqrtGrad` | `T={complex64,double,float}` +`ScatterNd` | `Tindices={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Selu` | `T={double,float}` `SeluGrad` | `T={double,float}` @@ -198,6 +211,7 @@ Operator | Type Constraint `Sinh` | `T={complex64,double,float}` `Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Snapshot` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Softmax` | `T={double,float}` `SoftmaxCrossEntropyWithLogits` | `T={double,float}` `Softplus` | `T={double,float,int32,int64,uint32,uint64}` diff --git a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md index b9bdb829d773825005a8921f48d28b6892d8f0cd..55f0538dba7c1941dfea88e0631cd299e51f76d0 100644 --- a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md +++ b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md @@ -3,6 +3,7 @@ Operator | Type Constraint ------------------------------------- | --------------- `Abs` | `T={double,float,int32,int64}` +`Acos` | `T={complex64,double,float,int32,int64}` `Acosh` | `T={complex64,double,float}` `Add` | `T={complex64,double,float,int32,int64}` `AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` @@ -15,10 +16,12 @@ Operator | Type Constraint `ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` `ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asin` | `T={complex64,double,float,int32,int64}` `Asinh` | `T={complex64,double,float}` `AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` `AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` `AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan` | `T={complex64,double,float,int32,int64}` `Atan2` | `T={double,float}` `Atanh` | `T={complex64,double,float}` `AvgPool` | `T={double,float}` @@ -75,6 +78,10 @@ Operator | Type Constraint `FFT` | `FFT2D` | `FFT3D` | +`FakeQuantWithMinMaxArgs` | +`FakeQuantWithMinMaxArgsGradient` | +`FakeQuantWithMinMaxVars` | +`FakeQuantWithMinMaxVarsGradient` | `Fill` | `index_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Floor` | `T={double,float}` `FloorDiv` | `T={complex64,double,float,int32,int64}` @@ -84,6 +91,7 @@ Operator | Type Constraint `FusedBatchNormGradV2` | `U={float}`
`T={float}` `FusedBatchNormV2` | `U={float}`
`T={float}` `Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherNd` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `Greater` | `T={double,float,int32,int64,uint32,uint64}` `GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` @@ -117,14 +125,18 @@ Operator | Type Constraint `LogicalNot` | `LogicalOr` | `MatMul` | `T={complex64,double,float}` +`MatrixBandPart` | `Tindex={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixSetDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixTriangularSolve` | `T={complex64,double,float}` `Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `MaxPool` | `T={double,float,int32,int64}` `MaxPool3D` | `T={float}` `MaxPool3DGrad` | `TInput={float}`
`T={float}` `MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolGradGrad` | `T={float}` +`MaxPoolGradGradV2` | `T={float}` `MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}` `MaxPoolV2` | `T={double,float,int32,int64}` `Maximum` | `T={double,float,int32,int64}` @@ -183,6 +195,7 @@ Operator | Type Constraint `Round` | `T={complex64,double,float,int32,int64}` `Rsqrt` | `T={complex64,double,float}` `RsqrtGrad` | `T={complex64,double,float}` +`ScatterNd` | `Tindices={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Selu` | `T={double,float}` `SeluGrad` | `T={double,float}` @@ -195,6 +208,7 @@ Operator | Type Constraint `Sinh` | `T={complex64,double,float}` `Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Snapshot` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Softmax` | `T={double,float}` `SoftmaxCrossEntropyWithLogits` | `T={double,float}` `Softplus` | `T={double,float,int32,int64,uint32,uint64}` diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 058a1f2621c64a735bd9d9c9d0ae007f93aa4dea..212f6f3966149ca0b2d2e012b19300e1f488f996 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -51,6 +51,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, const std::vector& expressions, std::vector* args) { auto builder = ctx->builder(); + auto client = ctx->compiler()->client(); std::vector compile_time_constant_flags(expressions.size()); TF_RETURN_IF_ERROR( @@ -72,8 +73,10 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, arg.kind = XlaCompiler::Argument::kConstant; TF_RET_CHECK(expressions[i]->resource() == nullptr) << "Input with resource is not yet implemented."; + TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph( + expressions[i]->handle())); TF_ASSIGN_OR_RETURN(auto literal, - builder->ComputeConstant(expressions[i]->handle())); + client->ComputeConstant(constant_graph)); TF_RETURN_IF_ERROR( LiteralToHostTensor(*literal, arg.type, &arg.constant_value)); } else { @@ -130,7 +133,7 @@ Status GraphCompiler::Compile() { // Set up inputs from outputs of previous nodes. for (auto* e : n->in_edges()) { if (e->IsControlEdge()) continue; - Node* src = e->src(); + const Node* src = e->src(); TF_RET_CHECK(src->id() < output_registry.size()); const NodeOutputs& src_outputs = output_registry[src->id()]; @@ -205,14 +208,15 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, TF_RETURN_IF_ERROR( PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = false; XlaCompiler::CompilationResult result; - - TF_RETURN_IF_ERROR(compiler->CompileFunction(XlaCompiler::CompileOptions(), - func, arguments, &result)); + TF_RETURN_IF_ERROR( + compiler->CompileFunction(compile_options, func, arguments, &result)); TF_RET_CHECK(arguments.size() == expressions.size()); - std::vector handles; + std::vector handles; for (int64 i = 0; i < expressions.size(); ++i) { if (arguments[i].kind == XlaCompiler::Argument::kConstant) { continue; @@ -226,11 +230,14 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, auto output_handle = b->Call(*result.computation, handles); // The output handle of `Call` computation is a tuple type. Unzip it so // that it can fit into future computations. + int computation_output = 0; for (int64 i = 0; i < n->num_outputs(); ++i) { if (result.outputs[i].is_constant) { xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value); } else { - xla_op_context.SetOutput(i, b->GetTupleElement(output_handle, i)); + xla_op_context.SetOutput( + i, b->GetTupleElement(output_handle, computation_output)); + ++computation_output; } } return b->first_error(); diff --git a/tensorflow/compiler/tf2xla/host_compute_metadata.proto b/tensorflow/compiler/tf2xla/host_compute_metadata.proto new file mode 100644 index 0000000000000000000000000000000000000000..43ab371a217e6c4521a160715104c96e3c8782c6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/host_compute_metadata.proto @@ -0,0 +1,38 @@ +syntax = "proto3"; + +package tensorflow.tf2xla; +option cc_enable_arenas = true; +option java_outer_classname = "Tf2XlaProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.tf2xla"; + +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +// TensorMetadata indicates the type and shape of a Tensor that is +// part of a host compute transfer. +message TensorMetadata { + DataType type = 1; + TensorShapeProto shape = 2; +} + +// HostTransferMetadata describes a transfer either from host to device +// or device to host. It has a key that is unique to the computation, +// and metadata about the list of tensors being transferred. +message HostTransferMetadata { + // The key used to identify this transfer. + string key = 1; + + // For each Tensor being transferred, its type and shape. + repeated TensorMetadata metadata = 2; +} + +// HostComputeMetadata describes all the sends and recvs +// from all host compute transfer ops in a computation. +message HostComputeMetadata { + // Metadata about each device_to_host transfer + repeated HostTransferMetadata device_to_host = 1; + + // Metadata about each host_to_device transfer + repeated HostTransferMetadata host_to_device = 2; +} diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index d2fa933cf9c085f92b2f442827a94d72938e4bb2..edd2ab6301ee891c433639ce300cde0c72929cea 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -18,9 +18,11 @@ tf_kernel_library( "bcast_ops.cc", "bias_ops.cc", "binary_ops.cc", + "bucketize_op.cc", "cast_op.cc", "categorical_op.cc", "cholesky_op.cc", + "clip_by_value_op.cc", "concat_op.cc", "const_op.cc", "conv_ops.cc", @@ -29,6 +31,7 @@ tf_kernel_library( "cwise_ops.h", "depthtospace_op.cc", "diag_op.cc", + "dynamic_slice_ops.cc", "dynamic_stitch_op.cc", "elu_op.cc", "extract_image_patches_op.cc", @@ -43,6 +46,7 @@ tf_kernel_library( "image_resize_ops.cc", "index_ops.cc", "l2loss_op.cc", + "listdiff_op.cc", "lrn_ops.cc", "matmul_op.cc", "matrix_band_part_op.cc", @@ -56,6 +60,7 @@ tf_kernel_library( "pooling_ops.cc", "quantize_and_dequantize_op.cc", "random_ops.cc", + "reduce_window_op.cc", "reduction_ops.cc", "reduction_ops.h", "reduction_ops_common.cc", @@ -93,6 +98,7 @@ tf_kernel_library( "shape_util.h", ], deps = [ + ":if_op", ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -102,7 +108,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/lib:triangular_solve", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/lib:while_loop", - "//tensorflow/compiler/tf2xla/ops:sendrecv_ops", + "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -110,8 +116,8 @@ tf_kernel_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", @@ -145,15 +151,48 @@ tf_kernel_library( deps = [ "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], ) +tf_kernel_library( + name = "if_op", + srcs = ["if_op.cc"], + hdrs = ["if_op.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/ops:xla_ops", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +# Kernels that have a dummy (no-op) implementation. +tf_kernel_library( + name = "xla_dummy_ops", + srcs = [ + "assert_op.cc", + "check_numerics_op.cc", + ], + deps = [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:logging_ops_op_lib", + ], + alwayslink = 1, +) + # Kernels that only work on CPU, because they use XLA custom calls. # Only link this when using the CPU backend for XLA. tf_kernel_library( @@ -166,8 +205,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/kernels:argmax_op", @@ -200,17 +239,3 @@ cc_library( ], alwayslink = 1, ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc index 5c9f66df101bfb731d6114c23933e241af5dcbeb..1e59868621475cf72f4cc8b14dafec2dd8cd5c95 100644 --- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc @@ -29,7 +29,7 @@ class AddNOp : public XlaOpKernel { OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("AddN requires at least one argument")); - xla::ComputationDataHandle sum = ctx->Input(0); + xla::XlaOp sum = ctx->Input(0); for (int i = 1; i < ctx->num_inputs(); ++i) { sum = ctx->builder()->Add(sum, ctx->Input(i)); } diff --git a/tensorflow/compiler/tf2xla/kernels/assert_op.cc b/tensorflow/compiler/tf2xla/kernels/assert_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..af4ab5e8ef6e268226edc90515706405ac36858c --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/assert_op.cc @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +namespace { + +// This TensorFlow op supports the Assert primitve. +class AssertOp : public XlaOpKernel { + public: + explicit AssertOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + ~AssertOp() override {} + + void Compile(XlaOpKernelContext* ctx) override { + static mutex mu(tensorflow::LINKER_INITIALIZED); + static int log_counter = 0; + + mutex_lock l(mu); + if (log_counter < 20) { + ++log_counter; + LOG(WARNING) << "Ignoring Assert operator " << name(); + } + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(AssertOp); +}; + +REGISTER_XLA_OP(Name("Assert"), AssertOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index a249b1869f547f8e5aa725f9f5cf391b10429928..15e1815a4cf07ff50dd1431b6790d14781da590f 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -48,9 +48,9 @@ class FusedBatchNormOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1), &scale_type)); - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); TensorShape input_shape = ctx->InputShape(0); int feature_index = @@ -62,7 +62,7 @@ class FusedBatchNormOp : public XlaOpKernel { input = builder->ConvertElementType(input, scale_type); if (is_training_) { - xla::ComputationDataHandle output = builder->BatchNormTraining( + xla::XlaOp output = builder->BatchNormTraining( input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index); // In training mode, outputs the normalized value as well as the @@ -79,7 +79,7 @@ class FusedBatchNormOp : public XlaOpKernel { ctx->SetOutput(3, builder->GetTupleElement(output, 1)); ctx->SetOutput(4, builder->GetTupleElement(output, 2)); } else { - xla::ComputationDataHandle output = builder->BatchNormInference( + xla::XlaOp output = builder->BatchNormInference( input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), epsilon_, feature_index); ctx->SetOutput(0, builder->ConvertElementType(output, input_type)); @@ -118,36 +118,30 @@ class FusedBatchNormGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); - - auto grad_backprop = ctx->Input(0); - auto activations = ctx->Input(1); - auto scale = ctx->Input(2); - auto mean = ctx->Input(3); - auto var = ctx->Input(4); - - TensorShape input_shape = ctx->InputShape(0); - int feature_index = - GetTensorFeatureDimIndex(input_shape.dims(), data_format_); - + xla::XlaBuilder* const b = ctx->builder(); DataType input_dtype = ctx->input_type(0); DataType scale_dtype = ctx->input_type(2); - xla::PrimitiveType input_type; - OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_dtype, &input_type)); - xla::PrimitiveType scale_type; - OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(scale_dtype, &scale_type)); // TODO(b/69928690): support mixed precision in the XLA batch normalization // operators. For now, cast everything to the statistics type (which // may be more precise than the input type). - grad_backprop = b->ConvertElementType(grad_backprop, scale_type); - activations = b->ConvertElementType(activations, scale_type); + auto grad_backprop = + XlaHelpers::ConvertElementType(b, ctx->Input(0), scale_dtype); + auto activations = + XlaHelpers::ConvertElementType(b, ctx->Input(1), scale_dtype); + auto scale = ctx->Input(2); + auto mean = ctx->Input(3); + auto var = ctx->Input(4); + + const int input_dims = ctx->InputShape(0).dims(); + const int feature_index = + GetTensorFeatureDimIndex(input_dims, data_format_); - xla::ComputationDataHandle x_backprop; - xla::ComputationDataHandle scale_backprop; - xla::ComputationDataHandle offset_backprop; + xla::XlaOp x_backprop; + xla::XlaOp scale_backprop; + xla::XlaOp offset_backprop; if (is_training_) { - xla::ComputationDataHandle output = + xla::XlaOp output = b->BatchNormGrad(activations, scale, mean, var, grad_backprop, epsilon_, feature_index); @@ -156,7 +150,7 @@ class FusedBatchNormGradOp : public XlaOpKernel { offset_backprop = b->GetTupleElement(output, 2); } else { // Reduce over all dimensions except the feature dim. - std::vector reduction_dims(input_shape.dims() - 1); + std::vector reduction_dims(input_dims - 1); std::iota(reduction_dims.begin(), reduction_dims.begin() + feature_index, 0); std::iota(reduction_dims.begin() + feature_index, reduction_dims.end(), @@ -165,9 +159,14 @@ class FusedBatchNormGradOp : public XlaOpKernel { // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + // epsilon)) // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon)) - offset_backprop = - b->Reduce(grad_backprop, XlaHelpers::Zero(b, scale_dtype), - *ctx->GetOrCreateAdd(scale_dtype), reduction_dims); + const DataType accumulation_type = + XlaHelpers::SumAccumulationType(scale_dtype); + auto converted = + XlaHelpers::ConvertElementType(b, grad_backprop, accumulation_type); + auto reduce = + b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); + offset_backprop = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); // scratch1 = rsqrt(pop_var + epsilon) auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); @@ -175,17 +174,21 @@ class FusedBatchNormGradOp : public XlaOpKernel { b->Pow(b->Add(var, b->ConstantR0(epsilon_)), neg_half); // scratch2 = sum(y_backprop * (x - mean)) - auto scratch2 = b->Reduce( - b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})), - XlaHelpers::Zero(b, scale_dtype), *ctx->GetOrCreateAdd(scale_dtype), - reduction_dims); + auto mul = + b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})); + converted = XlaHelpers::ConvertElementType(b, mul, accumulation_type); + reduce = + b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); + auto scratch2 = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); x_backprop = b->Mul(grad_backprop, b->Mul(scratch1, scale), {feature_index}); scale_backprop = b->Mul(scratch1, scratch2); } - ctx->SetOutput(0, b->ConvertElementType(x_backprop, input_type)); + ctx->SetOutput(0, + XlaHelpers::ConvertElementType(b, x_backprop, input_dtype)); ctx->SetOutput(1, scale_backprop); ctx->SetOutput(2, offset_backprop); ctx->SetConstantOutput(3, Tensor(scale_dtype, {})); diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 344a2ab2b6835c518c41de6f7a30fb2a34d130d2..642278ab994bf3cc84396f093ed56b009a1435c1 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -20,9 +20,8 @@ limitations under the License. namespace tensorflow { namespace { -void BatchToSpace(XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, DataType input_dtype, - const TensorShape& input_tensor_shape, +void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, + DataType input_dtype, const TensorShape& input_tensor_shape, gtl::ArraySlice block_shape, const xla::Literal& crops) { const int input_rank = input_tensor_shape.dims(); @@ -46,7 +45,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, ", 2] instead of ", xla::ShapeUtil::HumanString(crops.shape()))); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const int64 batch_size = input_shape[0]; // Compute the product of the block_shape values. @@ -73,7 +72,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, reshaped_shape[block_rank] = batch_size / block_num_elems; std::copy(input_shape.begin() + 1, input_shape.end(), reshaped_shape.begin() + block_rank + 1); - xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce `permuted` of shape // [batch / prod(block_shape), @@ -91,7 +90,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, } std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), 1 + block_rank * 2); - xla::ComputationDataHandle permuted = b->Transpose(reshaped, permutation); + xla::XlaOp permuted = b->Transpose(reshaped, permutation); // 3. Reshape `permuted` to produce `reshaped_permuted` of shape // [batch / prod(block_shape), @@ -111,8 +110,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, std::copy(remainder_shape.begin(), remainder_shape.end(), reshaped_permuted_shape.begin() + 1 + block_rank); - xla::ComputationDataHandle reshaped_permuted = - b->Reshape(permuted, reshaped_permuted_shape); + xla::XlaOp reshaped_permuted = b->Reshape(permuted, reshaped_permuted_shape); // 4. Crop the start and end of dimensions `[1, ..., M]` of // `reshaped_permuted` according to `crops` to produce the output of shape: @@ -139,7 +137,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, "Cropped size must be non-negative: start: ", crop_start, " end: ", crop_end, " size ", reshaped_permuted_shape[1 + i])); } - xla::ComputationDataHandle output = + xla::XlaOp output = b->Slice(reshaped_permuted, start_indices, end_indices, strides); ctx->SetOutput(0, output); } @@ -159,7 +157,9 @@ class BatchToSpaceNDOp : public XlaOpKernel { block_shape, crops); } }; -REGISTER_XLA_OP(Name("BatchToSpaceND").CompileTimeConstInput("crops"), +REGISTER_XLA_OP(Name("BatchToSpaceND") + .CompileTimeConstInput("block_shape") + .CompileTimeConstInput("crops"), BatchToSpaceNDOp); class BatchToSpaceOp : public XlaOpKernel { @@ -182,9 +182,7 @@ class BatchToSpaceOp : public XlaOpKernel { private: int block_size_; }; -REGISTER_XLA_OP(Name("BatchToSpace") - .CompileTimeConstInput("crops") - .CompileTimeConstInput("block_shape"), +REGISTER_XLA_OP(Name("BatchToSpace").CompileTimeConstInput("crops"), BatchToSpaceOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index c667b4e3e326b776faba49387760abbd582fcc68..9d677f426650ea17a49e5ab1401078f04623fe97 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -60,7 +60,7 @@ class BiasOp : public XlaOpKernel { "of the input tensor: ", bias_shape.DebugString(), " vs. ", input_shape.DebugString())); - xla::ComputationDataHandle result = + xla::XlaOp result = ctx->builder()->Add(ctx->Input(0), ctx->Input(1), {feature_dim}); ctx->SetOutput(0, result); } @@ -103,10 +103,15 @@ class BiasAddGradOp : public XlaOpKernel { std::iota(reduce_dims.begin(), reduce_dims.begin() + feature_dim, 0); std::iota(reduce_dims.begin() + feature_dim, reduce_dims.end(), feature_dim + 1); - xla::ComputationDataHandle result = ctx->builder()->Reduce( - ctx->Input(0), XlaHelpers::Zero(ctx->builder(), input_type(0)), - *ctx->GetOrCreateAdd(input_type(0)), reduce_dims); - ctx->SetOutput(0, result); + xla::XlaBuilder* const b = ctx->builder(); + const DataType accumulation_type = + XlaHelpers::SumAccumulationType(input_type(0)); + auto converted = + XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); + auto reduce = + b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduce_dims); + ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, reduce, input_type(0))); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 2436a6074a11ad66387b232dd1c5aa135875bfc3..f04cde878e98002d9442e0f3ec251c5197ef7969 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -34,14 +34,13 @@ namespace { class NAME##Op : public XlaBinaryOp { \ public: \ explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \ - xla::ComputationDataHandle Computation( \ - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs, \ - const gtl::ArraySlice& lhs_shape, \ - const xla::ComputationDataHandle& rhs, \ + xla::XlaOp Computation( \ + XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \ + const gtl::ArraySlice& lhs_shape, const xla::XlaOp& rhs, \ const gtl::ArraySlice& rhs_shape, \ const BCast& broadcast_helper, \ const std::vector& extend_dimensions) override { \ - xla::ComputationBuilder* b = ctx->builder(); \ + xla::XlaBuilder* b = ctx->builder(); \ return HLO; \ } \ }; \ @@ -63,11 +62,8 @@ XLA_MAKE_BINARY(Complex, b->Complex(lhs, rhs, extend_dimensions)); // } else { // return x / y; // } -static xla::ComputationDataHandle FloorDivImpl(xla::ComputationBuilder* b, - DataType dtype, - xla::ComputationDataHandle x, - xla::ComputationDataHandle y, - const BCast& broadcast_helper) { +static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::One(b, dtype); @@ -87,11 +83,8 @@ XLA_MAKE_BINARY(FloorDiv, // Implementation of FloorMod. Pseudo-code: // T trunc_mod = std::fmod(x, y); // return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y); -static xla::ComputationDataHandle FloorModImpl(xla::ComputationBuilder* b, - DataType dtype, - xla::ComputationDataHandle x, - xla::ComputationDataHandle y, - const BCast& broadcast_helper) { +static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto same_sign = b->Eq(b->Lt(x, zero), b->Lt(y, zero)); @@ -127,8 +120,7 @@ XLA_MAKE_BINARY(SqrtGrad, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), lhs, extend_dimensions)); -static xla::ComputationDataHandle Square(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& x) { +static xla::XlaOp Square(xla::XlaBuilder* builder, const xla::XlaOp& x) { return builder->Mul(x, x); } @@ -175,11 +167,11 @@ class ApproximateEqualOp : public XlaOpKernel { // Computes the max of the scalar input x and 0. void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); auto abs = b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))); auto abs_shape = b->GetShape(abs); OP_REQUIRES_OK(ctx, abs_shape.status()); - auto abs_type = abs_shape.ValueOrDie()->element_type(); + auto abs_type = abs_shape.ValueOrDie().element_type(); auto result = b->Lt( abs, b->ConvertElementType(b->ConstantR0(tolerance_), abs_type)); ctx->SetOutput(0, result); diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca9a6b40688d1e8496d1b823e20d273d519f65e8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BucketizeOp : public XlaOpKernel { + public: + explicit BucketizeOp(OpKernelConstruction* context) : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("boundaries", &boundaries_)); + OP_REQUIRES(context, std::is_sorted(boundaries_.begin(), boundaries_.end()), + errors::InvalidArgument("Expected sorted boundaries")); + } + + void Compile(XlaOpKernelContext* context) override { + xla::XlaBuilder* builder = context->builder(); + const DataType dtype = context->input_type(0); + xla::XlaOp input = context->Input(0); + + xla::XlaOp boundaries = builder->ConstantR1(boundaries_); + // TODO(phawkins): the following behavior matches the behavior of the core + // Bucketize kernel. However, comparing an int32 or int64 against float may + // lead to inaccurate bucketing due to rounding. + if (dtype == DT_DOUBLE) { + input = builder->ConvertElementType(input, xla::F64); + boundaries = builder->ConvertElementType(boundaries, xla::F64); + } else { + input = builder->ConvertElementType(input, xla::F32); + } + xla::XlaOp comparison = builder->ConvertElementType( + builder->Ge(builder->Broadcast(input, {1}), boundaries, + /*broadcast_dimensions=*/{0}), + xla::S32); + xla::XlaOp buckets = builder->Reduce( + comparison, /*init_value=*/builder->ConstantR0(0), + /*computation=*/xla::CreateScalarAddComputation(xla::S32, builder), + /*dimensions_to_reduce=*/{0}); + context->SetOutput(0, buckets); + } + + private: + std::vector boundaries_; +}; + +REGISTER_XLA_OP(Name("Bucketize"), BucketizeOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index 43a6a747c6bcc441f33f276fde4a66f367d99731..e9d98c768572c52825fa5192ecec834889f040fe 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -33,9 +33,9 @@ class CastOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); - xla::ComputationDataHandle output; + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp input = ctx->Input(0); + xla::XlaOp output; if (src_dtype_ == dst_dtype_) { output = input; @@ -62,5 +62,50 @@ class CastOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Cast"), CastOp); +class BitcastOp : public XlaOpKernel { + public: + explicit BitcastOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &src_dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("type", &dst_dtype_)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(src_dtype_, &src_type_)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dst_dtype_, &dst_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp input = ctx->Input(0); + xla::XlaOp output; + + if (src_dtype_ == dst_dtype_) { + output = input; + } else { + // The only complex type in XLA is C64, so error out if the bitcast has a + // complex source or destination type and the bitcast is not trivial. + OP_REQUIRES(ctx, + !xla::primitive_util::IsComplexType(src_type_) && + !xla::primitive_util::IsComplexType(dst_type_), + errors::Unimplemented("Complex types not supported.")); + // XLA bitcast requires that the bit-width of the source and destination + // matches, and currently only the simple lowering is performed. + OP_REQUIRES(ctx, + xla::primitive_util::BitWidth(src_type_) == + xla::primitive_util::BitWidth(dst_type_), + errors::Unimplemented( + "Only bitcasts between equally sized types supported.")); + output = builder->BitcastConvertType(input, dst_type_); + } + + ctx->SetOutput(0, output); + } + + protected: + DataType src_dtype_, dst_dtype_; + xla::PrimitiveType src_type_, dst_type_; + + TF_DISALLOW_COPY_AND_ASSIGN(BitcastOp); +}; + +REGISTER_XLA_OP(Name("Bitcast"), BitcastOp); + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 545aa364f937b2dc972dbe7b8c18b5897aa8e5c3..835a7f568945f0bee86fe2b39491c3326726e1aa 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -34,7 +34,7 @@ class CategoricalOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { // Get the logits - const xla::ComputationDataHandle& logits = ctx->Input(0); + const xla::XlaOp& logits = ctx->Input(0); TensorShape logits_shape = ctx->InputShape(0); int64 num_samples; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_samples)); @@ -56,7 +56,7 @@ class CategoricalOp : public XlaOpKernel { const int64 batch_size = logits_shape.dim_size(0); const int64 num_classes = logits_shape.dim_size(1); - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); std::array uniform_shape_array = { {batch_size, num_samples, num_classes}}; @@ -78,7 +78,7 @@ class CategoricalOp : public XlaOpKernel { /*broadcast_dimensions=*/{0, 2}); TensorShape softmax_shape(uniform_shape_array); - xla::ComputationDataHandle argmax; + xla::XlaOp argmax; OP_REQUIRES_OK( ctx, XlaHelpers::ArgMax(builder, ctx, softmax_entries, softmax_shape, diff --git a/tensorflow/compiler/tf2xla/kernels/check_numerics_op.cc b/tensorflow/compiler/tf2xla/kernels/check_numerics_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6061e822d8d9c6c807a63aad4e9e9526a49e456c --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/check_numerics_op.cc @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace { + +class CheckNumericsOp : public XlaOpKernel { + public: + explicit CheckNumericsOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* ctx) override { + // TODO(b/32223192): add a real implementation of CheckNumerics + { + static mutex mu(tensorflow::LINKER_INITIALIZED); + static int log_counter = 0; + mutex_lock l(mu); + if (log_counter < 20) { + ++log_counter; + LOG(WARNING) << "Ignoring CheckNumerics operator " << name(); + } + } + ctx->SetOutput(0, ctx->Input(0)); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(CheckNumericsOp); +}; + +REGISTER_XLA_OP(Name("CheckNumerics"), CheckNumericsOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a00bc912f9f40052565446c6bf9390629af9a4cd --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class ClipByValueOp : public XlaOpKernel { + public: + explicit ClipByValueOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape shape = ctx->InputShape(0); + const TensorShape min_shape = ctx->InputShape(1); + const TensorShape max_shape = ctx->InputShape(2); + + xla::XlaBuilder* builder = ctx->builder(); + auto input = ctx->Input(0); + auto min = ctx->Input(1); + auto max = ctx->Input(2); + + auto shape_error = [&]() -> tensorflow::Status { + return errors::InvalidArgument( + "clip_value_min and clip_value_max must be either of " + "the same shape as input, or a scalar. ", + "Input shape: ", shape.DebugString(), + " clip_value_min shape: ", min_shape.DebugString(), + " clip_value_max shape: ", max_shape.DebugString()); + }; + + if (shape != min_shape) { + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(min_shape), shape_error()); + min = builder->Broadcast(min, shape.dim_sizes()); + } + if (shape != max_shape) { + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(max_shape), shape_error()); + max = builder->Broadcast(max, shape.dim_sizes()); + } + ctx->SetOutput(0, builder->Clamp(min, input, max)); + } +}; + +REGISTER_XLA_OP(Name("ClipByValue"), ClipByValueOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index 1a246e8df9b2cd83147b50d960744332f8582a51..78285affa1c399ae107a9172fb85cf257457c368 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -54,7 +54,7 @@ class ConcatBaseOp : public XlaOpKernel { // TODO(annarev): add a helper to support int64 input. const int32 concat_dim = literal.Get({}); - std::vector values; + std::vector values; std::vector shapes; OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes)); const int N = values.size(); @@ -70,13 +70,13 @@ class ConcatBaseOp : public XlaOpKernel { "[", -input_dims, ", ", input_dims, "), but got ", concat_dim)); - // Make a vector holding the ComputationDataHandles for each of - // the inputs that has non-zero elements. - std::vector input_data; + // Make a vector holding the XlaOp for each of the inputs that has non-zero + // elements. + std::vector input_data; int output_concat_dim = 0; const bool input_is_scalar = IsLegacyScalar(input_shape); for (int i = 0; i < N; ++i) { - xla::ComputationDataHandle handle = values[i]; + xla::XlaOp handle = values[i]; const TensorShape& in_shape = shapes[i]; const bool in_is_scalar = IsLegacyScalar(in_shape); OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 8f78b4c8f90cf00d5fa9ba71a78bb1c0fe280dc6..59d06c654de18c9003fe0bdc706d0c2443de6d7b 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -45,7 +45,7 @@ class ConstOp : public XlaOpKernel { ctx->SetInvalidOutput(0); return; } - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); // To avoid blowups for large constants filled with the same value, // recognize that case and emit a scalar broadcast instead. diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 81cea6d376d02c956a5257c5475fe5c10b83deb9..627bad12f33c82e91bc3c6f3323f562bc8174056 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -47,9 +47,8 @@ TensorShape ExpandedFilterShapeForDepthwiseConvolution( } // Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution. -xla::ComputationDataHandle CreateExpandedZero( - const TensorShape& filter_shape, DataType dtype, - xla::ComputationBuilder* builder) { +xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype, + xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); return builder->Broadcast(XlaHelpers::Zero(builder, dtype), @@ -58,7 +57,7 @@ xla::ComputationDataHandle CreateExpandedZero( // Create a mask for depthwise convolution that will make a normal convolution // produce the same results as a depthwise convolution. For a [2, 2, 3, 2] -// depthwise filter this returns a [2, 2, 3, 6] tesnsor +// depthwise filter this returns a [2, 2, 3, 6] tensor // 1 1 0 0 0 0 1 1 0 0 0 0 // 0 0 1 1 0 0 0 0 1 1 0 0 // 0 0 0 0 1 1 0 0 0 0 1 1 @@ -87,8 +86,8 @@ xla::ComputationDataHandle CreateExpandedZero( // // Finally compare A and broadcasted B in dimension 2 amd return the result at // the beginning of the comment. -xla::ComputationDataHandle CreateExpandedFilterMask( - const TensorShape& filter_shape, xla::ComputationBuilder* builder) { +xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape, + xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); @@ -96,11 +95,11 @@ xla::ComputationDataHandle CreateExpandedFilterMask( // Create a M sized linspace and an M*N sized linspace that will be // broadcasted into perpendicular dimensions and compared. - xla::ComputationDataHandle input_feature_iota; + xla::XlaOp input_feature_iota; // DT_INT32 Iota will always return status::OK(). TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature, &input_feature_iota)); - xla::ComputationDataHandle expanded_feature_iota; + xla::XlaOp expanded_feature_iota; TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature * depthwise_multiplier, &expanded_feature_iota)); @@ -126,10 +125,10 @@ xla::ComputationDataHandle CreateExpandedFilterMask( // Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding // zeros for the cross-depth filters. Used to build a depthwise convolution. -xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution( - const TensorShape& filter_shape, DataType dtype, - const xla::ComputationDataHandle& filter, - xla::ComputationBuilder* builder) { +xla::XlaOp ExpandFilterForDepthwiseConvolution(const TensorShape& filter_shape, + DataType dtype, + const xla::XlaOp& filter, + xla::XlaBuilder* builder) { int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); TensorShape expanded_filter_shape = @@ -156,16 +155,21 @@ xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution( } // Inverse of ExpandFilterForDepthwiseConvolution. -xla::ComputationDataHandle ContractFilterForDepthwiseBackprop( - XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype, - const xla::ComputationDataHandle& filter_backprop, - xla::ComputationBuilder* builder) { +xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx, + const TensorShape& filter_shape, + DataType dtype, + const xla::XlaOp& filter_backprop, + xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); auto masked_expanded_filter = builder->Select( CreateExpandedFilterMask(filter_shape, builder), filter_backprop, CreateExpandedZero(filter_shape, dtype, builder)); return builder->Reshape( + // This reduce does not need inputs to be converted with + // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with + // ExpandedZero guarantees that only one element is non zero, so there + // cannot be accumulated precision error. builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), *ctx->GetOrCreateAdd(dtype), {expanded_filter_shape.dims() - 2}), @@ -244,9 +248,9 @@ class ConvOp : public XlaOpKernel { "input and filter must have the same depth: ", in_depth, " vs ", input_shape.dim_size(feature_dim))); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); - xla::ComputationDataHandle filter = ctx->Input(1); + xla::XlaOp filter = ctx->Input(1); TensorShape expanded_filter_shape = filter_shape; if (depthwise_) { filter = ExpandFilterForDepthwiseConvolution( @@ -284,7 +288,7 @@ class ConvOp : public XlaOpKernel { &unused_output_size, &padding[i].first, &padding[i].second)); } - xla::ComputationDataHandle conv = + xla::XlaOp conv = b->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, lhs_dilation, rhs_dilation, dims); ctx->SetOutput(0, conv); @@ -387,7 +391,7 @@ class ConvBackpropInputOp : public XlaOpKernel { expanded_filter_shape, out_backprop_shape, dilations_, strides_, padding_, data_format_, &dims)); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); auto filter = ctx->Input(1); auto out_backprop = ctx->Input(2); @@ -431,12 +435,11 @@ class ConvBackpropInputOp : public XlaOpKernel { } // Mirror the filter in the spatial dimensions. - xla::ComputationDataHandle mirrored_weights = - b->Rev(filter, kernel_spatial_dims); + xla::XlaOp mirrored_weights = b->Rev(filter, kernel_spatial_dims); // activation gradients // = gradients (with padding and dilation) mirrored_weights - xla::ComputationDataHandle in_backprop = b->ConvGeneralDilated( + xla::XlaOp in_backprop = b->ConvGeneralDilated( out_backprop, mirrored_weights, /*window_strides=*/ones, padding, lhs_dilation, rhs_dilation, dnums); @@ -542,9 +545,9 @@ class ConvBackpropFilterOp : public XlaOpKernel { expanded_filter_shape, out_backprop_shape, dilations_, strides_, padding_, data_format_, &dims)); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle activations = ctx->Input(0); - xla::ComputationDataHandle gradients = ctx->Input(2); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp activations = ctx->Input(0); + xla::XlaOp gradients = ctx->Input(2); // The filter gradients are computed by a convolution of the input // activations and the output gradients, with some appropriate padding. diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc index 3df8c00f1b83556d7d954aedc8eeac0728251c3e..7fcd4170fb79a574663c1abffe873d4b53f471d3 100644 --- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc @@ -53,7 +53,7 @@ class CrossOp : public XlaOpKernel { } std::vector strides(in0_shape.dims(), 1); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); auto in0 = ctx->Input(0); auto in1 = ctx->Input(1); starts.back() = 0; diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index 0cf03ceb948a5165a71e902eef5264eaddbd71e9..01aa1a83e7967921f1583b3ef18ec57e452dcfea 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/util/bcast.h" @@ -75,7 +75,7 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { } // Call virtual method to emit the computation. - xla::ComputationDataHandle output = + xla::XlaOp output = Computation(ctx, lhs_handle, lhs_shape.dim_sizes(), rhs_handle, rhs_shape.dim_sizes(), bcast, extend_dimension); @@ -85,11 +85,9 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { ctx->SetOutput(0, output); } -/* static */ std::pair -XlaBinaryOp::Broadcast(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& lhs, - const xla::ComputationDataHandle& rhs, - const BCast& broadcast_helper) { +/* static */ std::pair XlaBinaryOp::Broadcast( + xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs, + const BCast& broadcast_helper) { // Manually construct the broadcasting since MapN does not do // automatic broadcasting. The bcast helper ensures that // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index 5bc1d5fb1f08fb576df654e1f4068b6be9114096..4f92dbc8740b697322424058530b8477c35d809a 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/bcast.h" @@ -30,7 +30,7 @@ namespace tensorflow { // inputs that can be broadcast to the same shape. The base class // contains pure virtual methods to override: description is a textual // description of the operation; and Computation adds the -// implementation of the operation to a xla::ComputationBuilder. For most +// implementation of the operation to a xla::XlaBuilder. For most // arithmetic Ops XLA handles the broadcasting automatically given the input // tensors. class XlaBinaryOp : public XlaOpKernel { @@ -55,10 +55,9 @@ class XlaBinaryOp : public XlaOpKernel { // higher-rank input should be matched when broadcasting the // lower-rank input. See comment below and the documentation on broadcasting // in the XLA documentation. - virtual xla::ComputationDataHandle Computation( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs, - const gtl::ArraySlice& lhs_shape, - const xla::ComputationDataHandle& rhs, + virtual xla::XlaOp Computation( + XlaOpKernelContext* ctx, const xla::XlaOp& lhs, + const gtl::ArraySlice& lhs_shape, const xla::XlaOp& rhs, const gtl::ArraySlice& rhs_shape, const BCast& broadcast_helper, const std::vector& extend_dimensions) = 0; @@ -67,11 +66,9 @@ class XlaBinaryOp : public XlaOpKernel { // Helper function that performs the broadcasting described by // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same // shape. - static std::pair - Broadcast(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& lhs, - const xla::ComputationDataHandle& rhs, - const BCast& broadcast_helper); + static std::pair Broadcast( + xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs, + const BCast& broadcast_helper); }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index 96d7809f7995634b6bc31ab801b93526d9da7e6f..23243f62462c6315e359d9621823b19fc98c6218 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -50,8 +50,8 @@ class DepthToSpaceOp : public XlaOpKernel { const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp input = ctx->Input(0); int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); @@ -130,7 +130,7 @@ class DepthToSpaceOp : public XlaOpKernel { ") is not divisible by square of the block size (", block_size_, ")")); - xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -141,8 +141,7 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2], // block_size_, // depth / (block_size_ * block_size_)] - xla::ComputationDataHandle permuted_reshaped = - b->Transpose(reshaped, transpose_order); + xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -152,8 +151,7 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2] * block_size_, // depth / (block_size_ * block_size_)] // - xla::ComputationDataHandle output = - b->Reshape(permuted_reshaped, output_shape); + xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 765ea922a532a085a552192348ab360c4c30ff0a..931705ba837153e1175cd9a209876ef5ec93f0fc 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -25,10 +25,10 @@ namespace tensorflow { namespace { // Create a diagonal / batch diagonal matrix with 'input' on the diagonal. -xla::StatusOr CreateDiagonal( - const xla::ComputationDataHandle& input, int64 last_dim_size, +xla::StatusOr CreateDiagonal( + const xla::XlaOp& input, int64 last_dim_size, tensorflow::gtl::ArraySlice other_dims, XlaOpKernelContext* ctx, - xla::ComputationBuilder* builder) { + xla::XlaBuilder* builder) { // Create two matrices that have the following forms, and compare them: // // [[0, 0, 0, 0] [[0, 1, 2, 3] @@ -38,12 +38,11 @@ xla::StatusOr CreateDiagonal( // // This produces a predicate matrix of the right size, with "true" on the // diagonal. - xla::ComputationDataHandle iota; + xla::XlaOp iota; TF_RETURN_IF_ERROR( XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota)); - xla::ComputationDataHandle iota_broadcast = - builder->Broadcast(iota, {last_dim_size}); - xla::ComputationDataHandle mask = builder->Eq(iota_broadcast, iota, {0}); + xla::XlaOp iota_broadcast = builder->Broadcast(iota, {last_dim_size}); + xla::XlaOp mask = builder->Eq(iota_broadcast, iota, {0}); // If this is a batched diagonal, broadcast the mask across the other // dimensions. @@ -65,8 +64,7 @@ xla::StatusOr CreateDiagonal( std::vector broadcast_dims(other_dims.begin(), other_dims.end()); broadcast_dims.push_back(1LL); broadcast_dims.push_back(last_dim_size); - xla::ComputationDataHandle input_broadcast = - builder->Reshape(input, broadcast_dims); + xla::XlaOp input_broadcast = builder->Reshape(input, broadcast_dims); broadcast_dims[broadcast_dims.size() - 2] = last_dim_size; xla::PrimitiveType element_type; @@ -74,7 +72,7 @@ xla::StatusOr CreateDiagonal( DataTypeToPrimitiveType(ctx->input_type(0), &element_type)); auto broadcast_shape = xla::ShapeUtil::MakeShape(element_type, broadcast_dims); - xla::ComputationDataHandle zeros = Zeros(builder, broadcast_shape); + xla::XlaOp zeros = Zeros(builder, broadcast_shape); input_broadcast = builder->Add(input_broadcast, zeros); return builder->Select(mask, input_broadcast, zeros); @@ -85,7 +83,7 @@ class DiagOp : public XlaOpKernel { explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("Diag op must have at an input")); @@ -96,7 +94,7 @@ class DiagOp : public XlaOpKernel { errors::InvalidArgument("Expected 1 <= dims, got shape ", input_shape.DebugString())); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); // Picture: // tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0] @@ -112,7 +110,7 @@ class DiagOp : public XlaOpKernel { auto diag_or_status = CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder); OP_REQUIRES_OK(ctx, diag_or_status.status()); - xla::ComputationDataHandle diag = diag_or_status.ValueOrDie(); + xla::XlaOp diag = diag_or_status.ValueOrDie(); // Reshapes to the final shape. std::vector new_dims(dims.size() * 2); @@ -131,7 +129,7 @@ class DiagPartOp : public XlaOpKernel { explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -158,7 +156,7 @@ class DiagPartOp : public XlaOpKernel { new_dims.push_back(dims[i]); } - xla::ComputationDataHandle diag = ctx->Input(0); + xla::XlaOp diag = ctx->Input(0); // TODO(b/30878775): use Slice with strides when supported, in place of // the Pad -> Reshape -> Slice. @@ -199,7 +197,7 @@ class MatrixDiagOp : public XlaOpKernel { explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("MatrixDiag op must have at an input")); @@ -210,7 +208,7 @@ class MatrixDiagOp : public XlaOpKernel { errors::InvalidArgument("Expected 1 <= dims, got shape ", input_shape.DebugString())); - xla::ComputationDataHandle diag = ctx->Input(0); + xla::XlaOp diag = ctx->Input(0); int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); @@ -232,7 +230,7 @@ class MatrixDiagPartOp : public XlaOpKernel { explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -241,7 +239,7 @@ class MatrixDiagPartOp : public XlaOpKernel { errors::InvalidArgument("Expected 2 <= dims, got shape ", input_shape.DebugString())); - xla::ComputationDataHandle diag = ctx->Input(0); + xla::XlaOp diag = ctx->Input(0); int last_dim = dims.size() - 1; int64 last_dim_size = dims[last_dim]; diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..0419de78b2ee83fd395e8bf23444fde84f30bba2 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { +namespace { + +class DynamicUpdateSliceOp : public XlaOpKernel { + public: + explicit DynamicUpdateSliceOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* ctx) override { + VLOG(3) << "DynamicUpdateSliceOp::Compile"; + + DataType index_type = input_type(2); + OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64, + errors::InvalidArgument("index must be int32 or int64")); + + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape update_shape = ctx->InputShape(1); + const TensorShape index_shape = ctx->InputShape(2); + + OP_REQUIRES( + ctx, + TensorShapeUtils::IsVector(index_shape) && + index_shape.num_elements() == input_shape.dims(), + errors::InvalidArgument("index must be a vector with length equal to " + "the number of input dimensions")); + OP_REQUIRES( + ctx, input_shape.dims() == update_shape.dims(), + errors::InvalidArgument("input and update must have the same rank," + " input shape is ", + input_shape.DebugString(), "; update shape is ", + update_shape.DebugString())); + + xla::XlaOp result = ctx->builder()->DynamicUpdateSlice( + ctx->Input(0), ctx->Input(1), ctx->Input(2)); + ctx->SetOutput(0, result); + } +}; + +REGISTER_XLA_OP(Name("XlaDynamicUpdateSlice"), DynamicUpdateSliceOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index f2cd21ffb9ce88747c04f3c71e66dadeb1faf0f9..dd4a16908779508380b36f43ce2306ff2f5fb8c4 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -56,7 +56,7 @@ class DynamicStitchOp : public XlaOpKernel { std::vector indices_input; OP_REQUIRES_OK(ctx, ctx->ConstantInputList("indices", &indices_input)); - std::vector data; + std::vector data; std::vector data_shapes; OP_REQUIRES_OK(ctx, ctx->InputList("data", &data, &data_shapes)); @@ -136,7 +136,7 @@ class DynamicStitchOp : public XlaOpKernel { // Look up all the children expressions that represent the data // inputs. - std::vector input(indices.size()); + std::vector input(indices.size()); for (int input_num = 0; input_num < indices.size(); input_num++) { TensorShape new_shape; // first reshaped dimension is the number of indices for this input. @@ -166,7 +166,7 @@ class DynamicStitchOp : public XlaOpKernel { for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d); } - std::vector to_concat(number_of_indices); + std::vector to_concat(number_of_indices); for (int index_num = 0; index_num < number_of_indices; index_num++) { const auto& expression = input[src_input_vector[index_num]]; // Take the appropriate slice of data. diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index 2fd27c5ca7e87c8b387d9d0854b787d30e7f7b6f..493781a1e68b8906f1a7e018e5710130e2eb08b5 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" @@ -32,11 +32,10 @@ class EluOp : public XlaOpKernel { explicit EluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Computes the max of the scalar input x and 0. void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto one = XlaHelpers::One(b, input_type(0)); const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); + const auto expm1 = b->Expm1(ctx->Input(0)); ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1)); } }; @@ -47,7 +46,7 @@ class EluGradOp : public XlaOpKernel { // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return lhs * (1 + rhs). void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); const auto one = XlaHelpers::One(b, input_type(0)); const auto grad = ctx->Input(0); @@ -66,15 +65,14 @@ class SeluOp : public XlaOpKernel { explicit SeluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Computes the max of the scalar input x and 0. void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto one = XlaHelpers::One(b, input_type(0)); const auto scale = XlaHelpers::FloatLiteral(b, input_type(0), 1.0507009873554804934193349852946); const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0), 1.7580993408473768599402175208123); const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); + const auto expm1 = b->Expm1(ctx->Input(0)); ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)), b->Mul(scale_alpha, expm1))); } @@ -86,9 +84,8 @@ class SeluGradOp : public XlaOpKernel { // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return lhs * (1 + rhs). void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto one = XlaHelpers::One(b, input_type(0)); const auto scale = XlaHelpers::FloatLiteral(b, input_type(0), 1.0507009873554804934193349852946); const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0), diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index b2970eae20a3fb71f06619f476a49d41b22bca56..6df01cabbf1d98c0299bfd808bcc6db6223c4777 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -93,7 +93,7 @@ class ExtractImagePatchesOp : public XlaOpKernel { input_shape.DebugString())); const int64 depth = input_shape.dim_size(feature_dim); - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); // The following code is equivalent to: // eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD]) @@ -110,7 +110,7 @@ class ExtractImagePatchesOp : public XlaOpKernel { // Builds an identity matrix as a broadcast equality of iotas. // iota = np.arange(np.prod(ksize), depth) // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32) - xla::ComputationDataHandle iota; + xla::XlaOp iota; TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, kernel_size * depth, &iota)); @@ -147,7 +147,7 @@ class ExtractImagePatchesOp : public XlaOpKernel { &padding[i].first, &padding[i].second)); } - xla::ComputationDataHandle conv = + xla::XlaOp conv = builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, lhs_dilation, rhs_dilation, dims); ctx->SetOutput(0, conv); diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 453a32c494b42e9922bc35fc526f3306530054fd..8f0de0a524c908b598c1a2165a462275346ad137 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -44,23 +44,20 @@ void CpuNudge(const float min, const float max, const float quant_min, } // An XLA version of CpuNudge(). -void XlaNudge(xla::ComputationBuilder* b, const DataType data_type, - const xla::ComputationDataHandle& min, - const xla::ComputationDataHandle& max, +void XlaNudge(xla::XlaBuilder* b, const DataType data_type, + const xla::XlaOp& min, const xla::XlaOp& max, const float quant_min_value, const float quant_max_value, - xla::ComputationDataHandle* nudged_min, - xla::ComputationDataHandle* nudged_max, - xla::ComputationDataHandle* scale) { + xla::XlaOp* nudged_min, xla::XlaOp* nudged_max, + xla::XlaOp* scale) { *scale = b->Div(b->Sub(max, min), XlaHelpers::FloatLiteral(b, data_type, quant_max_value - quant_min_value)); - xla::ComputationDataHandle quant_min = + xla::XlaOp quant_min = XlaHelpers::FloatLiteral(b, data_type, quant_min_value); - xla::ComputationDataHandle zero_point_from_min = - b->Sub(quant_min, b->Div(min, *scale)); - xla::ComputationDataHandle quant_max = + xla::XlaOp zero_point_from_min = b->Sub(quant_min, b->Div(min, *scale)); + xla::XlaOp quant_max = XlaHelpers::FloatLiteral(b, data_type, quant_max_value); - xla::ComputationDataHandle nudged_zero_point = + xla::XlaOp nudged_zero_point = b->Select(b->Le(zero_point_from_min, quant_min), quant_min, b->Select(b->Ge(zero_point_from_min, quant_max), quant_max, b->Round(zero_point_from_min))); @@ -68,22 +65,18 @@ void XlaNudge(xla::ComputationBuilder* b, const DataType data_type, *nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale); } -xla::ComputationDataHandle Quantize( - xla::ComputationBuilder* b, const xla::ComputationDataHandle& input, - const DataType data_type, - const xla::ComputationDataHandle& nudged_input_min, - const xla::ComputationDataHandle& nudged_input_max, - const xla::ComputationDataHandle& input_scale) { - xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, data_type, 1.0f); - xla::ComputationDataHandle inv_scale = b->Div(one, input_scale); - xla::ComputationDataHandle half = - XlaHelpers::FloatLiteral(b, data_type, 0.5f); - - xla::ComputationDataHandle clamped = - b->Clamp(nudged_input_min, input, nudged_input_max); - xla::ComputationDataHandle clamped_shifted = - b->Sub(clamped, nudged_input_min); - xla::ComputationDataHandle rounded = +xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input, + const DataType data_type, + const xla::XlaOp& nudged_input_min, + const xla::XlaOp& nudged_input_max, + const xla::XlaOp& input_scale) { + xla::XlaOp one = XlaHelpers::FloatLiteral(b, data_type, 1.0f); + xla::XlaOp inv_scale = b->Div(one, input_scale); + xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5f); + + xla::XlaOp clamped = b->Clamp(nudged_input_min, input, nudged_input_max); + xla::XlaOp clamped_shifted = b->Sub(clamped, nudged_input_min); + xla::XlaOp rounded = b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half)); return b->Add(b->Mul(rounded, input_scale), nudged_input_min); } @@ -111,18 +104,18 @@ class FakeQuantWithMinMaxArgsOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); const DataType data_type = ctx->input_type(0); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle nudged_input_min = + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp nudged_input_min = XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_); - xla::ComputationDataHandle nudged_input_max = + xla::XlaOp nudged_input_max = XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_); - xla::ComputationDataHandle input_scale = + xla::XlaOp input_scale = XlaHelpers::FloatLiteral(b, data_type, input_scale_); - xla::ComputationDataHandle output = Quantize( - b, input, data_type, nudged_input_min, nudged_input_max, input_scale); + xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min, + nudged_input_max, input_scale); ctx->SetOutput(0, output); } @@ -159,23 +152,22 @@ class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle gradient = ctx->Input(0); + xla::XlaOp gradient = ctx->Input(0); const TensorShape gradient_shape = ctx->InputShape(0); - xla::ComputationDataHandle input = ctx->Input(1); + xla::XlaOp input = ctx->Input(1); const DataType data_type = ctx->input_type(1); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle nudged_input_min = + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp nudged_input_min = XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_); - xla::ComputationDataHandle nudged_input_max = + xla::XlaOp nudged_input_max = XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_); - xla::ComputationDataHandle between_nudged_min_max = + xla::XlaOp between_nudged_min_max = b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); - xla::ComputationDataHandle zeroes = b->Broadcast( - XlaHelpers::Zero(b, data_type), gradient_shape.dim_sizes()); - xla::ComputationDataHandle output = - b->Select(between_nudged_min_max, gradient, zeroes); + xla::XlaOp zeroes = b->Broadcast(XlaHelpers::Zero(b, data_type), + gradient_shape.dim_sizes()); + xla::XlaOp output = b->Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output); } @@ -204,18 +196,18 @@ class FakeQuantWithMinMaxVarsOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); const DataType data_type = ctx->input_type(0); - xla::ComputationDataHandle input_min = ctx->Input(1); - xla::ComputationDataHandle input_max = ctx->Input(2); + xla::XlaOp input_min = ctx->Input(1); + xla::XlaOp input_max = ctx->Input(2); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale; + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp nudged_input_min, nudged_input_max, input_scale; XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, &nudged_input_min, &nudged_input_max, &input_scale); - xla::ComputationDataHandle output = Quantize( - b, input, data_type, nudged_input_min, nudged_input_max, input_scale); + xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min, + nudged_input_max, input_scale); ctx->SetOutput(0, output); } @@ -243,37 +235,43 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle gradient = ctx->Input(0); + xla::XlaOp gradient = ctx->Input(0); const TensorShape gradient_shape = ctx->InputShape(0); - xla::ComputationDataHandle input = ctx->Input(1); + xla::XlaOp input = ctx->Input(1); const DataType data_type = ctx->input_type(1); - xla::ComputationDataHandle input_min = ctx->Input(2); - xla::ComputationDataHandle input_max = ctx->Input(3); + const DataType accumulation_type = + XlaHelpers::SumAccumulationType(data_type); + xla::XlaOp input_min = ctx->Input(2); + xla::XlaOp input_max = ctx->Input(3); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale; + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp nudged_input_min, nudged_input_max, input_scale; XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, &nudged_input_min, &nudged_input_max, &input_scale); - xla::ComputationDataHandle between_nudged_min_max = + xla::XlaOp between_nudged_min_max = b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); - xla::ComputationDataHandle zero = XlaHelpers::Zero(b, data_type); - xla::ComputationDataHandle zeroes = - b->Broadcast(zero, gradient_shape.dim_sizes()); - xla::ComputationDataHandle output0 = - b->Select(between_nudged_min_max, gradient, zeroes); + xla::XlaOp zero = XlaHelpers::Zero(b, data_type); + xla::XlaOp zeroes = b->Broadcast(zero, gradient_shape.dim_sizes()); + xla::XlaOp output0 = b->Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output0); - xla::ComputationDataHandle below_min = b->Lt(input, nudged_input_min); - xla::ComputationDataHandle output1 = - b->ReduceAll(b->Select(below_min, gradient, zeroes), zero, - *ctx->GetOrCreateAdd(data_type)); + xla::XlaOp below_min = b->Lt(input, nudged_input_min); + xla::XlaOp select1 = b->Select(below_min, gradient, zeroes); + xla::XlaOp reduce1 = b->ReduceAll( + XlaHelpers::ConvertElementType(b, select1, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type)); + xla::XlaOp output1 = XlaHelpers::ConvertElementType(b, reduce1, data_type); ctx->SetOutput(1, output1); - xla::ComputationDataHandle above_max = b->Gt(input, nudged_input_max); - xla::ComputationDataHandle output2 = - b->ReduceAll(b->Select(above_max, gradient, zeroes), zero, - *ctx->GetOrCreateAdd(data_type)); + xla::XlaOp above_max = b->Gt(input, nudged_input_max); + xla::XlaOp select2 = b->Select(above_max, gradient, zeroes); + xla::XlaOp reduce2 = b->ReduceAll( + XlaHelpers::ConvertElementType(b, select2, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type)); + xla::XlaOp output2 = XlaHelpers::ConvertElementType(b, reduce2, data_type); ctx->SetOutput(2, output2); } diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index a4f3c1c3ad9a928e0552c388a25ed9fcb08edabb..933924cad1c7cac2879bd4720cb21ffc33c23f50 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -62,9 +62,8 @@ class GenericFftOp : public XlaOpKernel { } } - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle fft = - b->Fft(ctx->Input(0), fft_type_, fft_length); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp fft = b->Fft(ctx->Input(0), fft_type_, fft_length); ctx->SetOutput(0, fft); } @@ -82,9 +81,11 @@ class FFTOp : public GenericFftOp { explicit FFTOp(OpKernelConstruction* ctx) : GenericFftOp(ctx, /*fft_type=*/FftType::FFT, /*fft_rank=*/FFTRank) {} }; -REGISTER_XLA_OP(Name("FFT"), FFTOp<1>); -REGISTER_XLA_OP(Name("FFT2D"), FFTOp<2>); -REGISTER_XLA_OP(Name("FFT3D"), FFTOp<3>); +REGISTER_XLA_OP(Name("FFT").TypeConstraint("Tcomplex", DT_COMPLEX64), FFTOp<1>); +REGISTER_XLA_OP(Name("FFT2D").TypeConstraint("Tcomplex", DT_COMPLEX64), + FFTOp<2>); +REGISTER_XLA_OP(Name("FFT3D").TypeConstraint("Tcomplex", DT_COMPLEX64), + FFTOp<3>); template class IFFTOp : public GenericFftOp { @@ -92,9 +93,12 @@ class IFFTOp : public GenericFftOp { explicit IFFTOp(OpKernelConstruction* ctx) : GenericFftOp(ctx, /*fft_type=*/FftType::IFFT, /*fft_rank=*/FFTRank) {} }; -REGISTER_XLA_OP(Name("IFFT"), IFFTOp<1>); -REGISTER_XLA_OP(Name("IFFT2D"), IFFTOp<2>); -REGISTER_XLA_OP(Name("IFFT3D"), IFFTOp<3>); +REGISTER_XLA_OP(Name("IFFT").TypeConstraint("Tcomplex", DT_COMPLEX64), + IFFTOp<1>); +REGISTER_XLA_OP(Name("IFFT2D").TypeConstraint("Tcomplex", DT_COMPLEX64), + IFFTOp<2>); +REGISTER_XLA_OP(Name("IFFT3D").TypeConstraint("Tcomplex", DT_COMPLEX64), + IFFTOp<3>); template class RFFTOp : public GenericFftOp { diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index eaa13b8dfacce9aaca42ce5fcdfa467ce7fa7b7f..e4467a0fb138ed7919af62ed032c0f5abee3e4f6 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -48,7 +48,7 @@ class FillOp : public XlaOpKernel { 0, {dims_shape.num_elements()}, &dims_literal)); // Convert the dims literal into a vector that we can pass to - // ComputationBuilder. + // XlaBuilder. std::vector broadcast; broadcast.reserve(dims_literal.shape().dimensions(0)); for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) { @@ -56,7 +56,7 @@ class FillOp : public XlaOpKernel { } // Look up the value input, reshaping to a scalar if it was a // 'legacy' scalar (secretly a vector). - xla::ComputationDataHandle data = ctx->Input(1); + xla::XlaOp data = ctx->Input(1); if (value_shape.dims() > 0) { CHECK_EQ(value_shape.dims(), 1); data = ctx->builder()->Reshape(data, {}); diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 7945c05af40df21a798a2cff51fe7f8e935793f6..d13e25bcddae16d0cd630403219657121b80868d 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -26,55 +26,55 @@ limitations under the License. namespace tensorflow { -Status XlaGather(const xla::ComputationDataHandle& input, - const TensorShape& input_shape, - const xla::ComputationDataHandle& indices, - TensorShape indices_shape, int64 axis, bool indices_are_nd, - DataType dtype, DataType index_type, - xla::ComputationBuilder* builder, - xla::ComputationDataHandle* gather_output) { +Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, + const xla::XlaOp& indices, const TensorShape& indices_shape, + int64 axis, bool indices_are_nd, DataType dtype, + DataType index_type, xla::XlaBuilder* builder, + xla::XlaOp* gather_output) { + // There is no deep reason why we need this precondition, but this is the only + // combination that is used and tested today. + CHECK(!indices_are_nd || axis == 0); + + // num_index_dims is the number of components in each index in the indices + // tensor. + // + // num_indices is the total number of (n dimensional or scalar) indices in the + // indices tensor. + // // If the indices are N-dimensional, then the minor dimension of indices // should be of size N and correspond to the N indices. - int64 num_index_dims = 1; + int64 num_index_dims; + int64 num_indices = 1; if (indices_are_nd) { CHECK_GE(indices_shape.dims(), 1); num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1); - indices_shape.RemoveLastDims(1); + for (int64 i = 0, e = indices_shape.dims() - 1; i < e; i++) { + num_indices *= indices_shape.dim_size(i); + } + } else { + num_index_dims = 1; + for (int64 i = 0, e = indices_shape.dims(); i < e; i++) { + num_indices *= indices_shape.dim_size(i); + } } - // Although the indices Tensor is flattened into rank 1 during the lookup, - // and each scalar entry is used as an index into the first dimension of the - // input, the output is returned with shape: - // input.shape[:axis] + indices.shape + input.shape[axis+1:] - - const int64 num_indices = indices_shape.num_elements(); - TensorShape input_shape_pre_axis(input_shape); - input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims()); - TensorShape input_shape_post_axis(input_shape); - input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims); - // Each slice of the input tensor has shape: - // [, 1, ..., 1, ] - TensorShape slice_shape(input_shape); - for (int64 i = 0; i < num_index_dims; ++i) { - slice_shape.set_dim(axis + i, 1); - } + // Degenerate case: empty indices. + if (num_indices == 0) { + TensorShape input_shape_pre_axis{input_shape}; + input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims()); + TensorShape input_shape_post_axis{input_shape}; + input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims); - TensorShape loop_out_shape; - loop_out_shape.AppendShape(input_shape_pre_axis); - loop_out_shape.AddDim(num_indices); - loop_out_shape.AppendShape(input_shape_post_axis); - TensorShape loop_out_slice_shape; - loop_out_slice_shape.AppendShape(input_shape_pre_axis); - loop_out_slice_shape.AddDim(1); - loop_out_slice_shape.AppendShape(input_shape_post_axis); + TensorShape indices_shape_no_index_vectors{indices_shape}; + if (indices_are_nd) { + indices_shape_no_index_vectors.RemoveLastDims(1); + } - TensorShape out_shape; - out_shape.AppendShape(input_shape_pre_axis); - out_shape.AppendShape(indices_shape); - out_shape.AppendShape(input_shape_post_axis); + TensorShape out_shape; + out_shape.AppendShape(input_shape_pre_axis); + out_shape.AppendShape(indices_shape_no_index_vectors); + out_shape.AppendShape(input_shape_post_axis); - // Degenerate case: empty indices. - if (num_indices == 0) { *gather_output = builder->Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes()); return Status::OK(); @@ -88,76 +88,61 @@ Status XlaGather(const xla::ComputationDataHandle& input, } } - // Flatten the major dimensions of indices into a single dimension for ease of - // iteration. If there is an axis dimension, we must leave it alone. - std::vector flat_indices_shape = {num_indices}; - if (indices_are_nd) { - flat_indices_shape.push_back(num_index_dims); - } - - // Specify the shape of the loop-carried Tensor tuple. - - // Construct the initial values of the loop-carried Tensors. - auto flat_indices = builder->Reshape(indices, flat_indices_shape); - auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype), - loop_out_shape.dim_sizes()); - auto init = {input, flat_indices, init_out}; - - // Construct the while loop body's function. The implementation of gather is: - // for i in range(num_indices): - // index = dynamic-slice(indices, i) - // xi = dynamic-slice(input, index) - // output = dynamic-update-slice(output, xi, i) - auto body_fn = [&](xla::ComputationDataHandle i, - gtl::ArraySlice loop_vars, - xla::ComputationBuilder* bodyb) { - auto input = loop_vars[0]; - auto indices = loop_vars[1]; - auto output = loop_vars[2]; - - auto zero_index = XlaHelpers::Zero(bodyb, index_type); - - // Slice the i-th index from the indices array. - xla::ComputationDataHandle index; - auto indices_offset = bodyb->Reshape(i, {1}); - if (indices_are_nd) { - // Slice out the entire nd index, if applicable. - indices_offset = bodyb->Pad(indices_offset, zero_index, - xla::MakeEdgePaddingConfig({{0, 1}})); - index = bodyb->DynamicSlice(indices, indices_offset, {1, num_index_dims}); - index = bodyb->Collapse(index, {0, 1}); + // Example of a 1-D gather with axis=1, pulling two [3,1] tensors out of a + // tensor of shape [3,3]. + // + // operand = s32[3,3] parameter(0) + // indices = s32[2] parameter(1) + // gather = s32[3,2] gather(operand, indices), + // output_window_dims={0}, + // elided_window_dims={1}, + // gather_dims_to_operand_dims={1}, + // index_vector_dim=1, + // window_bounds={3, 1} + // + // + // Example of an N-D gather pulling out slices of shape [1,1,2] out of a + // tensor of shape [3,3,2]. + // + // operand = s32[3,3,2] parameter(0) + // indices = s32[2,2] parameter(1) + // gather = s32[2,2] gather(operand, indices), + // output_window_dims={1}, + // elided_window_dims={0,1}, + // gather_dims_to_operand_dims={0,1}, + // index_vector_dim=0, + // window_bounds={1,1,2} + + xla::GatherDimensionNumbers dim_numbers; + std::vector window_bounds; + window_bounds.reserve(input_shape.dims()); + for (int64 i = 0; i < input_shape.dims(); i++) { + int64 window_bound; + if (axis <= i && i < (axis + num_index_dims)) { + dim_numbers.add_elided_window_dims(i); + window_bound = 1; } else { - index = bodyb->DynamicSlice(indices, indices_offset, {1}); + window_bound = input_shape.dim_size(i); + } + + window_bounds.push_back(window_bound); + + if (i < axis) { + dim_numbers.add_output_window_dims(i); + } else if (i >= (axis + num_index_dims)) { + int64 indices_rank = + indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims(); + dim_numbers.add_output_window_dims(i + indices_rank - num_index_dims); } + } + + dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1) + : indices_shape.dims()); + for (int64 i = axis; i < axis + num_index_dims; i++) { + dim_numbers.add_gather_dims_to_operand_dims(i); + } - // Slice the corresponding data from the input array. - auto start_indices = bodyb->Pad( - index, zero_index, - xla::MakeEdgePaddingConfig( - {{input_shape_pre_axis.dims(), input_shape_post_axis.dims()}})); - auto slice_i = bodyb->Reshape( - bodyb->DynamicSlice(input, start_indices, slice_shape.dim_sizes()), - loop_out_slice_shape.dim_sizes()); - - // Construct the index into the output Tensor 0, ..., , 0, ... - std::vector out_index_vals( - loop_out_shape.dims(), bodyb->Reshape(zero_index, {1})); - out_index_vals[input_shape_pre_axis.dims()] = bodyb->Reshape(i, {1}); - auto out_index = bodyb->ConcatInDim(out_index_vals, 0); - - // Update the output Tensor - auto updated_output = bodyb->DynamicUpdateSlice(output, slice_i, out_index); - - return std::vector{input, indices, - updated_output}; - }; - - // Construct the While loop, extract and reshape the output. - xla::PrimitiveType ptype; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(index_type, &ptype)); - TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices, ptype, body_fn, - init, "gather", builder)); - *gather_output = builder->Reshape(outputs[2], out_shape.dim_sizes()); + *gather_output = builder->Gather(input, indices, dim_numbers, window_bounds); return Status::OK(); } @@ -166,7 +151,7 @@ class GatherOp : public XlaOpKernel { explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* context) override { - xla::ComputationBuilder* builder = context->builder(); + xla::XlaBuilder* builder = context->builder(); auto input = context->Input(0); auto input_shape = context->InputShape(0); auto indices = context->Input(1); @@ -195,7 +180,7 @@ class GatherOp : public XlaOpKernel { OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, errors::InvalidArgument("indices must be int32 or int64")); - xla::ComputationDataHandle gather; + xla::XlaOp gather; OP_REQUIRES_OK( context, XlaGather(input, input_shape, indices, indices_shape, axis, /*indices_are_nd=*/false, input_type(0), index_type, @@ -233,10 +218,10 @@ class GatherNdOp : public XlaOpKernel { indices_shape.dim_size(indices_shape.dims() - 1), " vs. ", params_shape.dims())); - xla::ComputationBuilder* builder = context->builder(); + xla::XlaBuilder* builder = context->builder(); auto params = context->Input(0); auto indices = context->Input(1); - xla::ComputationDataHandle gather; + xla::XlaOp gather; OP_REQUIRES_OK(context, XlaGather(params, params_shape, indices, indices_shape, /*axis=*/0, /*indices_are_nd=*/true, params_type, diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index bd8b92c22d71fe89ab8951ec79f411feef6505e3..d898e43b858bac706d524c7c271f48b1b5fa258f 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/bcast.h" @@ -33,13 +33,11 @@ namespace tensorflow { // If `indices_are_nd` is true, the last dimension of `indices` are treated as // a multidimensional index values. Otherwise, `indices` is treated as a tensor // of scalar indices. -Status XlaGather(const xla::ComputationDataHandle& input, - const TensorShape& input_shape, - const xla::ComputationDataHandle& indices, - TensorShape indices_shape, int64 axis, bool indices_are_nd, - DataType dtype, DataType index_type, - xla::ComputationBuilder* builder, - xla::ComputationDataHandle* gather_output); +Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, + const xla::XlaOp& indices, const TensorShape& indices_shape, + int64 axis, bool indices_are_nd, DataType dtype, + DataType index_type, xla::XlaBuilder* builder, + xla::XlaOp* gather_output); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index 39af662b638cb9d723118e58fcfc983633fed497..e72200bfbcff20c55ac03030f1afc4bacaabf7ce 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -38,6 +38,7 @@ class IdentityOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Identity").CompilationOnly(), IdentityOp); REGISTER_XLA_OP(Name("IdentityN").CompilationOnly(), IdentityOp); +REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp); REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp); REGISTER_XLA_OP(Name("StopGradient"), IdentityOp); REGISTER_XLA_OP(Name("Snapshot"), IdentityOp); diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d48c6eea754f75a8879d3938f233a6a591d26d0d --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -0,0 +1,250 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/if_op.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { + +XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + const NameAttrList* name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &name_attr)); + then_branch_ = *name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &name_attr)); + else_branch_ = *name_attr; + + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_)); +} + +// TODO(b/35949885): There is duplication here with the handling of the +// while_op. Refactor the common code out/rework. +void XlaIfOp::Compile(XlaOpKernelContext* ctx) { + xla::XlaBuilder* b = ctx->builder(); + + OP_REQUIRES(ctx, cond_type_ == DT_BOOL, + errors::InvalidArgument( + "Condition argument must be a boolean for XLA compilation")); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(0)), + errors::InvalidArgument( + "Condition argument must be a scalar for XLA compilation")); + + VLOG(1) << "Building If: " << input_types_.size() << " inputs"; + + std::vector arguments(input_types_.size()); + for (int i = 0; i < input_types_.size(); ++i) { + XlaCompiler::Argument& arg = arguments[i]; + DataType type = ctx->input_type(i + 1); + + if (type == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource)); + + arg.initialized = resource->initialized(); + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = resource->kind(); + + arg.type = resource->type(); + arg.shape = resource->shape(); + OP_REQUIRES(ctx, arg.initialized, + errors::Unimplemented("Uninitialized arguments: ", arg.name)); + arg.tensor_array_size = resource->tensor_array_size(); + for (const auto& gradient : resource->tensor_array_gradients()) { + arg.tensor_array_gradients.insert(gradient.first); + } + arg.name = resource->name(); + VLOG(2) << "Resource " << resource->name() + << " type: " << DataTypeString(arg.type) + << " shape: " << arg.shape.DebugString() + << " initialized: " << arg.initialized; + } else { + arg.kind = XlaCompiler::Argument::kParameter; + arg.type = input_types_[i]; + arg.shape = ctx->InputShape(i + 1); + VLOG(2) << "Arg type: " << DataTypeString(arg.type) + << " shape: " << arg.shape.DebugString(); + } + } + + // Compile both branches of the conditional. + XlaCompiler::CompileOptions options; + options.use_tuple_arg = true; + options.resolve_compile_time_constants = false; + options.return_updated_values_for_all_resources = true; + options.is_entry_computation = false; + XlaCompiler* compiler = ctx->compiler(); + + XlaCompiler::CompilationResult then_result; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_, + arguments, &then_result)); + XlaCompiler::CompilationResult else_result; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, + arguments, &else_result)); + + bool has_tensor_array_gradients = false; + for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) { + for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, + ctx->GetResourceInput(update.input_index + 1, &resource)); + XlaCompiler::Argument& arg = arguments[update.input_index]; + + // Add any TensorArray gradients touched by the then/else computation to + // the enclosing graph. + for (const string& grad_source : update.tensor_array_gradients_accessed) { + VLOG(5) << "TensorArray " << resource->name() << " accessed gradient " + << grad_source; + XlaResource* gradient; + OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient( + grad_source, b, &gradient)); + } + // Add all of the TensorArray gradients to the argument. For simplicity, + // we always pass all known gradients. + for (const auto& gradient : resource->tensor_array_gradients()) { + arg.tensor_array_gradients.insert(gradient.first); + } + if (!resource->tensor_array_gradients().empty()) + has_tensor_array_gradients = true; + } + } + + // Recompile the functions to update the argument shapes for tensor arrays. + if (has_tensor_array_gradients) { + then_result = {}; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_, + arguments, &then_result)); + else_result = {}; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, + arguments, &else_result)); + } + + // Check that both branches have identical input shapes. + OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape then_input_shape = then_result.xla_input_shapes[0]; + OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(then_input_shape), + errors::FailedPrecondition("Expected tuple shape")); + OP_REQUIRES(ctx, else_result.xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape else_input_shape = else_result.xla_input_shapes[0]; + OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(else_input_shape), + errors::FailedPrecondition("Expected tuple shape")); + OP_REQUIRES(ctx, + xla::ShapeUtil::Compatible(then_input_shape, else_input_shape), + errors::InvalidArgument( + "Input shapes of then and else branches do not match: ", + xla::ShapeUtil::HumanString(then_input_shape), " vs. ", + xla::ShapeUtil::HumanString(else_input_shape))); + + // Check that both branches have identical output shapes. + OP_REQUIRES( + ctx, + xla::ShapeUtil::Compatible(then_result.xla_output_shape, + else_result.xla_output_shape), + errors::InvalidArgument( + "Output shapes of then and else branches do not match: ", + xla::ShapeUtil::HumanString(then_result.xla_output_shape), " vs. ", + xla::ShapeUtil::HumanString(else_result.xla_output_shape))); + + VLOG(2) << "Input shape: " << xla::ShapeUtil::HumanString(then_input_shape); + VLOG(2) << "Output shape: " + << xla::ShapeUtil::HumanString(then_result.xla_output_shape); + + // We set return_updated_values_for_all_resources=true and we pass the same + // arguments to both computations, so the resource update count must match. + OP_REQUIRES(ctx, + then_result.resource_updates.size() == + else_result.resource_updates.size(), + errors::FailedPrecondition( + "Different number of resources in then and else branch")); + for (int i = 0; i < then_result.resource_updates.size(); ++i) { + const auto& lhs = then_result.resource_updates[i]; + const auto& rhs = else_result.resource_updates[i]; + bool equal = lhs.input_index == rhs.input_index && lhs.shape == rhs.shape && + lhs.tensor_array_gradients_accessed == + rhs.tensor_array_gradients_accessed; + OP_REQUIRES( + ctx, equal, + errors::FailedPrecondition( + "Mismatch in resource of then and else branch for resource ", i)); + } + + int num_inputs = then_result.input_mapping.size(); + std::vector inputs(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + int input_num = then_result.input_mapping[i] + 1; + if (ctx->input_type(input_num) == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); + OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); + } else { + inputs[i] = ctx->Input(i + 1); + } + } + + xla::XlaOp outputs = + b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation, + b->Tuple(inputs), *else_result.computation); + // Sets non-variable outputs. + for (int i = 0; i < output_types_.size(); ++i) { + if (ctx->input_type(i) != DT_RESOURCE) { + xla::XlaOp output_handle = b->GetTupleElement(outputs, i); + if (VLOG_IS_ON(2)) { + LOG(INFO) << "Setting output " << i; + auto shape_or = b->GetShape(output_handle); + if (shape_or.ok()) { + LOG(INFO) << "Shape for output " << i << ": " + << xla::ShapeUtil::HumanString(shape_or.ValueOrDie()); + } else { + LOG(INFO) << "Shape unknown for output " << i; + } + } + ctx->SetOutput(i, output_handle); + } + } + + // Updates the values of any resource variables modified by the conditional + // bodies. + for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) { + for (int i = 0; i < result->resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& update = result->resource_updates[i]; + XlaResource* resource; + OP_REQUIRES_OK(ctx, + ctx->GetResourceInput(update.input_index + 1, &resource)); + if (update.modified) { + int pos = result->outputs.size() + i; + OP_REQUIRES_OK(ctx, + resource->SetFromPack( + arguments[update.input_index].tensor_array_gradients, + b->GetTupleElement(outputs, pos), b)); + } + VLOG(2) << "If variable: pos: " << update.input_index + << " name: " << resource->name() + << " modified: " << update.modified + << " type: " << DataTypeString(update.type) + << " shape: " << update.shape.DebugString(); + } + } + VLOG(1) << "Done building If"; +} + +REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h new file mode 100644 index 0000000000000000000000000000000000000000..f9bc98a198a72dcc0594e61971713bf890ce30b6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/if_op.h @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/attr_value.pb.h" + +namespace tensorflow { + +// This TensorFlow op provides a functional conditional primitive. +// +// The outputs of the then/else branches must agree on the number, types, and +// shapes of the Tensors carried around the two bodies. +// +// Computations in then/else bodies may read from and write to resource +// variables. +// Resource variables may be passed as arguments to the then/else function's +// bodies. The XlaCompiler converts resource variable arguments +// into parameters to the XLA computation and moves them to the end of the +// parameter list, and by using the `return_updated_values_for_all_variables` +// we ensure that all variables that appear in the input also appear at the +// end of the then/else bodies output. This ensures the then/else bodies output +// signatures match. +// +// It is the user's responsibility to ensure that each non-variable _Arg matches +// the corresponding _Retval. +class XlaIfOp : public XlaOpKernel { + public: + explicit XlaIfOp(OpKernelConstruction* ctx); + + void Compile(XlaOpKernelContext* ctx) override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaIfOp); + + NameAttrList then_branch_; + NameAttrList else_branch_; + DataType cond_type_; + DataTypeVector input_types_; + DataTypeVector output_types_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index f22f384256a8ddd8c05de4a1322aba741dc4d7fd..1568b33679963c1a6630525f60560180d40b8d53 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -23,10 +23,9 @@ namespace { // Converts 'input' from RGB format to HSV format. // 'shape' is the shape of the red/green/blue tensors. -std::array RGBToHSV( - XlaOpKernelContext* ctx, xla::ComputationBuilder* b, - const std::array& rgb, DataType dtype, - const TensorShape& shape) { +std::array RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b, + const std::array& rgb, + DataType dtype, const TensorShape& shape) { auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::One(b, dtype); @@ -54,12 +53,12 @@ std::array RGBToHSV( } // Converts 'input' from HSV format to RGB format. -std::array HSVToRGB( - xla::ComputationBuilder* b, - const std::array& hsv, DataType dtype) { - xla::ComputationDataHandle hue = hsv[0]; - xla::ComputationDataHandle saturation = hsv[1]; - xla::ComputationDataHandle value = hsv[2]; +std::array HSVToRGB(xla::XlaBuilder* b, + const std::array& hsv, + DataType dtype) { + xla::XlaOp hue = hsv[0]; + xla::XlaOp saturation = hsv[1]; + xla::XlaOp value = hsv[2]; auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); @@ -95,16 +94,16 @@ class RGBToHSVOp : public XlaOpKernel { errors::FailedPrecondition("input must have 3 channels but input has ", channels, " channels.")); - xla::ComputationBuilder* b = context->builder(); - xla::ComputationDataHandle input = context->Input(0); + xla::XlaBuilder* b = context->builder(); + xla::XlaOp input = context->Input(0); - xla::ComputationDataHandle red = + xla::XlaOp red = b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle green = + xla::XlaOp green = b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle blue = + xla::XlaOp blue = b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; @@ -133,15 +132,15 @@ class HSVToRGBOp : public XlaOpKernel { errors::FailedPrecondition("input must have 3 channels but input has ", channels, " channels.")); - xla::ComputationBuilder* b = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle hue = + xla::XlaBuilder* b = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp hue = b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle saturation = + xla::XlaOp saturation = b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle value = + xla::XlaOp value = b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, /*dimno=*/channel_dim); @@ -174,15 +173,19 @@ class AdjustContrastOpV2 : public XlaOpKernel { errors::InvalidArgument("contrast_factor must be scalar: ", factor_shape.DebugString())); - xla::ComputationBuilder* b = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle factor = context->Input(1); + xla::XlaBuilder* b = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp factor = context->Input(1); DataType type = context->input_type(0); - auto output = b->Reduce(input, /*init_value=*/XlaHelpers::Zero(b, type), - /*computation=*/*context->GetOrCreateAdd(type), + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + auto converted = + XlaHelpers::ConvertElementType(b, input, accumulation_type); + auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *context->GetOrCreateAdd(accumulation_type), {height_dim, width_dim}); + auto output = XlaHelpers::ConvertElementType(b, reduce, type); output = b->Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); std::vector broadcast_dims(input_shape.dims() - 2); @@ -217,19 +220,19 @@ class AdjustSaturationOp : public XlaOpKernel { errors::InvalidArgument("input must have 3 channels but instead has ", channels, " channels.")); - xla::ComputationBuilder* b = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle scale = context->Input(1); + xla::XlaBuilder* b = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp scale = context->Input(1); DataType type = context->input_type(0); - xla::ComputationDataHandle red = + xla::XlaOp red = b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle green = + xla::XlaOp green = b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle blue = + xla::XlaOp blue = b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; @@ -267,19 +270,19 @@ class AdjustHueOp : public XlaOpKernel { errors::InvalidArgument("input must have 3 channels but instead has ", channels, " channels.")); - xla::ComputationBuilder* b = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle delta = context->Input(1); + xla::XlaBuilder* b = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp delta = context->Input(1); DataType type = context->input_type(0); - xla::ComputationDataHandle red = + xla::XlaOp red = b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle green = + xla::XlaOp green = b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle blue = + xla::XlaOp blue = b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index f36b3f594826c27b7866d956c855aa3638db9cb4..79d3a6979cec4c6bda92a71dcff4ddd2151367d5 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -99,28 +99,35 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( return dims; } -xla::ComputationDataHandle MakeBilinearResizeKernel( - xla::ComputationBuilder* builder, gtl::ArraySlice kernel_size, - int64 channels) { - // Form a 2D convolution kernel like: - // 1 2 3 2 1 - // 2 4 6 4 2 - // 1/9 * 3 6 9 6 3 - // 2 4 6 4 2 - // 1 2 3 2 1 - // by multiplying two 1D kernels of the form: - // 1/3 * [1 2 3 2 1] - auto make_1d_kernel = [](int64 n) { - std::vector kernel(n * 2 - 1); - for (int64 i = 0; i < n; ++i) { - float v = (i + 1.0f) / n; - kernel[i] = v; - kernel[n * 2 - 2 - i] = v; - } - return kernel; - }; +// Form a 2D convolution kernel like: +// 1 2 3 2 1 +// 2 4 6 4 2 +// 1/9 * 3 6 9 6 3 +// 2 4 6 4 2 +// 1 2 3 2 1 +// by multiplying two 1D kernels of the form: +// 1/3 * [1 2 3 2 1] +// If the 2D kernel would be very large, the 1D kernel can be applied once in +// each dimension due to the symmetry of the kernel along all axis to reduce the +// computational intensity. +std::vector Make1DKernel(int64 n) { + std::vector kernel(n * 2 - 1); + for (int64 i = 0; i < n; ++i) { + float v = (i + 1.0f) / n; + kernel[i] = v; + kernel[n * 2 - 2 - i] = v; + } + return kernel; +} + +// Kernels with more than 16 spatial elements are considered intense and the +// kernel should applied to each dimension independently. +const int64 kMax2DKernelSize = 16; - xla::ComputationDataHandle channels_iota; +xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, + gtl::ArraySlice kernel_size, + int64 channels) { + xla::XlaOp channels_iota; // DT_INT32 Iota will always return status::OK(). TF_CHECK_OK( XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); @@ -133,16 +140,43 @@ xla::ComputationDataHandle MakeBilinearResizeKernel( xla::PrimitiveType::F32); return builder->Mul( builder->Mul(diag, - builder->ConstantR1(make_1d_kernel(kernel_size[1])), + builder->ConstantR1(Make1DKernel(kernel_size[1])), /*broadcast_dimensions=*/{1}), - builder->ConstantR1(make_1d_kernel(kernel_size[0])), + builder->ConstantR1(Make1DKernel(kernel_size[0])), /*broadcast_dimensions=*/{0}); } -xla::ComputationDataHandle ResizeUsingDilationAndConvolution( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& input, - const int num_spatial_dims, std::vector in_size, - std::vector out_size, const int64 channels) { +xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, + gtl::ArraySlice kernel_size, + int64 channels, int64 dim) { + xla::XlaOp channels_iota; + // DT_INT32 Iota will always return status::OK(). + TF_CHECK_OK( + XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); + + auto diag = builder->ConvertElementType( + builder->Eq(builder->Broadcast( + channels_iota, + {dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}), + channels_iota, /*broadcast_dimensions=*/{2}), + xla::PrimitiveType::F32); + if (dim == 1) { + return builder->Mul( + diag, builder->ConstantR1(Make1DKernel(kernel_size[1])), + /*broadcast_dimensions=*/{1}); + } + return builder->Mul(diag, + builder->ConstantR1(Make1DKernel(kernel_size[0])), + /*broadcast_dimensions=*/{0}); +} + +xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, + const xla::XlaOp& input, + const int num_spatial_dims, + std::vector in_size, + std::vector out_size, + const int64 channels) { // Picture for a 1x3 to 1x4 resize: // stride = 2, kernel size = 3 // Input: @@ -163,20 +197,42 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolution( dimension_numbers.add_output_spatial_dimensions(1 + i); dimension_numbers.add_kernel_spatial_dimensions(i); } - dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); - dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, out_size); - xla::ComputationDataHandle kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); - xla::ComputationDataHandle output = builder->ConvGeneralDilated( - input, kernel, dims.stride, - /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, - {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, - /*lhs_dilation=*/dims.kernel_size, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + xla::XlaOp output; + // Split convolutions into independent dimensions if they wmuld be a very + // large kernel. + if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { + xla::XlaOp kernel = + MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + output = builder->ConvGeneralDilated( + input, kernel, dims.stride, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.kernel_size, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } else { + xla::XlaOp kernel0 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); + output = builder->ConvGeneralDilated( + input, kernel0, {dims.stride[0], 1}, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, + /*lhs_dilation=*/{dims.kernel_size[0], 1}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + xla::XlaOp kernel1 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); + output = builder->ConvGeneralDilated( + output, kernel1, {1, dims.stride[1]}, + /*padding=*/ + {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/{1, dims.kernel_size[1]}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } // Add broadcasts to handle expanding from a size == 1 dimension to a // size > 1 dimension. @@ -189,10 +245,12 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolution( return output; } -xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& grad, - const int num_spatial_dims, std::vector in_size, - std::vector grad_size, const int64 channels) { +xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, + const xla::XlaOp& grad, + const int num_spatial_dims, + std::vector in_size, + std::vector grad_size, + const int64 channels) { ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, grad_size); @@ -210,26 +268,63 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp( } dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); - xla::ComputationDataHandle kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + xla::XlaOp output; + if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { + xla::XlaOp kernel = + MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1 && grad_size[i] > 1) { + kernel = + builder->Add(kernel, builder->ConstantR1(grad_size[i], 0), + /*broadcast_dimensions=*/{i}); + } + } - // Broadcast the input kernel where the forward op expanded from a size == 1 - // dimension to a size > 1 dimension. This has the effect of summing the - // gradient contributions in that dimension. - for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] == 1 && grad_size[i] > 1) { - kernel = builder->Add(kernel, builder->ConstantR1(grad_size[i], 0), - /*broadcast_dimensions=*/{i}); + output = builder->ConvGeneralDilated( + grad, kernel, /*window_strides=*/dims.kernel_size, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.stride, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } else { + xla::XlaOp kernel0 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); + xla::XlaOp kernel1 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. + if (in_size[0] == 1 && grad_size[0] > 1) { + kernel0 = + builder->Add(kernel0, builder->ConstantR1(grad_size[0], 0), + /*broadcast_dimensions=*/{0}); + } + if (in_size[1] == 1 && grad_size[1] > 1) { + kernel1 = + builder->Add(kernel0, builder->ConstantR1(grad_size[1], 0), + /*broadcast_dimensions=*/{1}); } - } - xla::ComputationDataHandle output = builder->ConvGeneralDilated( - grad, kernel, /*window_strides=*/dims.kernel_size, - /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, - {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, - /*lhs_dilation=*/dims.stride, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + output = builder->ConvGeneralDilated( + grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1}, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, + /*lhs_dilation=*/{dims.stride[0], 1}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + + output = builder->ConvGeneralDilated( + output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]}, + /*padding=*/ + {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/{1, dims.stride[1]}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i. // Opposite of the slice performed by the forward op. @@ -258,7 +353,7 @@ class ResizeBilinearOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); TensorShape input_shape = ctx->InputShape(0); OP_REQUIRES(ctx, input_shape.dims() == 4, @@ -283,7 +378,7 @@ class ResizeBilinearOp : public XlaOpKernel { const int num_spatial_dims = 2; - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in // dimension i. @@ -318,7 +413,7 @@ class ResizeBilinearOp : public XlaOpKernel { // from image of size axb -> cxd is same as resizing axb -> exf -> cxd. // // This makes the convolutions kernels smaller and the operation faster. - xla::ComputationDataHandle output = input; + xla::XlaOp output = input; while (in_size != out_size) { if (in_size[0] != 1 && in_size[1] != 1) { std::vector k = { @@ -369,7 +464,7 @@ class ResizeBilinearGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); TensorShape input_shape = ctx->InputShape(1); OP_REQUIRES(ctx, input_shape.dims() == 4, @@ -406,9 +501,9 @@ class ResizeBilinearGradOp : public XlaOpKernel { const int num_spatial_dims = 2; - xla::ComputationDataHandle grad = ctx->Input(0); + xla::XlaOp grad = ctx->Input(0); - xla::ComputationDataHandle output = grad; + xla::XlaOp output = grad; while (in_size != grad_size) { if (in_size[0] != 1 && in_size[1] != 1) { std::vector k = { diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 7bf4b435f526afa93d8a218b191928acb932cd6b..36eb4c75454ed82804c40b82e5dbaec2eef0a719 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -61,10 +61,10 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { DataType index_type = output_type(0); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp input = ctx->Input(0); - xla::ComputationDataHandle output; + xla::XlaOp output; if (is_min_) { OP_REQUIRES_OK(ctx, XlaHelpers::ArgMin(b, ctx, input, input_shape, input_type(0), diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index b1f3c3c298ce0cadf38b9bda715761fe7e2896d7..2c2d88486fda99d2380382a3e2f633f5bdc7478c 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -71,10 +71,10 @@ class ArgMaxCustomCallOp : public XlaOpKernel { OP_REQUIRES(ctx, XlaContext::Get(ctx).allow_cpu_custom_calls(), errors::InvalidArgument( "ArgMax implementation requires a CustomCall on CPU")); - xla::ComputationBuilder& b = *ctx->builder(); + xla::XlaBuilder& b = *ctx->builder(); // XLA passes to the function, so it is not included here. - std::vector args; + std::vector args; args.push_back(ctx->Input(0)); args.push_back(b.ConstantLiteral( *xla::Literal::CreateR1(input_shape.dim_sizes()))); @@ -91,7 +91,7 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // Tell XLA to call the custom code, defined in // index_ops_kernel_argmax_float_1d.cc. - xla::ComputationDataHandle output; + xla::XlaOp output; switch (input_shape.dims()) { case 1: output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape); diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index d096415087e47a73503a06526ab133ac34803c5d..1decf7d72d72bb697477e7f841ced2a1a0d5fbe9 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" @@ -29,21 +29,22 @@ class L2LossOp : public XlaOpKernel { explicit L2LossOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); + std::vector dims(ctx->InputShape(0).dims()); + std::iota(dims.begin(), dims.end(), 0); DataType dtype = ctx->input_type(0); - xla::ComputationBuilder* b = ctx->builder(); - - auto zero = XlaHelpers::Zero(b, dtype); - auto two = XlaHelpers::IntegerLiteral(b, dtype, 2); - const xla::Computation& add = *ctx->GetOrCreateAdd(dtype); - - std::vector dims(input_shape.dims()); - std::iota(dims.begin(), dims.end(), 0); + xla::XlaBuilder* const b = ctx->builder(); // output = sum(t ** 2) / 2 - auto x = ctx->Input(0); - ctx->SetOutput(0, b->Div(b->Reduce(b->Mul(x, x), zero, add, dims), two)); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); + auto t = + XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); + auto square = b->Mul(t, t); + auto reduce = b->Reduce(square, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), dims); + auto deconverted = XlaHelpers::ConvertElementType(b, reduce, dtype); + auto two = XlaHelpers::IntegerLiteral(b, dtype, 2); + ctx->SetOutput(0, b->Div(deconverted, two)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..0388b4c830702ea00ec69fc42c6468326c88cf38 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific ListDiff Op. This only supports constant DT_INT32 and DT_INT64 +// input. + +#include + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +constexpr std::array kListDiffTypes = {DT_INT32, DT_INT64}; + +// ListDiffOp is an XLA kernel that supports constant-only x and y input. +class ListDiffOp : public XlaOpKernel { + public: + explicit ListDiffOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + OP_REQUIRES(context, TensorShapeUtils::IsVector(context->InputShape(0)), + errors::InvalidArgument("ListDiff expects x as a vector, not ", + context->InputShape(0).DebugString())); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(context->InputShape(1)), + errors::InvalidArgument("ListDiff expects y as a vector, not ", + context->InputShape(1).DebugString())); + + DataType val_type = context->expected_output_dtype(0); + DataType idx_type = context->expected_output_dtype(1); + + Status status; + switch (val_type) { + case DT_INT32: + status = ListDiffWithIndexType(context, idx_type); + break; + case DT_INT64: + status = ListDiffWithIndexType(context, idx_type); + break; + default: + // This should never happen since we restrict this kernel to only match + // inputs with supported Tensor datatype. + status = errors::InvalidArgument("ListDiff expects x and y as either ", + "int32 or int64, not ", + DataTypeString(val_type)); + } + OP_REQUIRES_OK(context, status); + } + + private: + template + Status ListDiff(XlaOpKernelContext* context) { + std::vector x_input, y_input; + TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(0, &x_input)); + TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(1, &y_input)); + + std::unordered_set y_input_set; + y_input_set.reserve(y_input.size()); + for (auto y : y_input) { + y_input_set.insert(y); + } + + std::vector val_output; + std::vector idx_output; + auto x_size = x_input.size(); + for (Tidx i = 0; i < x_size; ++i) { + if (y_input_set.count(x_input[i]) > 0) { + continue; + } + val_output.push_back(x_input[i]); + idx_output.push_back(i); + } + + context->SetOutput(0, context->builder()->ConstantR1(val_output)); + context->SetOutput(1, context->builder()->ConstantR1(idx_output)); + return Status::OK(); + } + + template + Status ListDiffWithIndexType(XlaOpKernelContext* context, DataType idx_type) { + switch (idx_type) { + case DT_INT32: + return ListDiff(context); + case DT_INT64: + return ListDiff(context); + default: + return errors::InvalidArgument( + "ListDiff expects idx_out as either int32 or int64, not ", + DataTypeString(idx_type)); + } + } +}; + +REGISTER_XLA_OP(Name("ListDiff") + .TypeConstraint("T", kListDiffTypes) + .CompileTimeConstInput("x") + .CompileTimeConstInput("y"), + ListDiffOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index 759d1a1a2d996d4f5deb1774be7014bb6de30f40..39fbf98a6274918840e9e351470f04c2d80c5d01 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -38,8 +38,8 @@ class LRNOp : public XlaOpKernel { OP_REQUIRES(ctx, in_shape.dims() == 4, errors::InvalidArgument("in must be 4-dimensional")); - xla::ComputationBuilder* builder = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp input = ctx->Input(0); // sqr_sum[a, b, c, d] = // sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) @@ -47,12 +47,17 @@ class LRNOp : public XlaOpKernel { // We use a window of depth_radius_ * 2 + 1, to account for the current // element and a depth_radius_ on either side. - auto squared = builder->Mul(input, input); - auto sqr_sum = builder->ReduceWindow( - squared, XlaHelpers::Zero(builder, input_type(0)), - *ctx->GetOrCreateAdd(input_type(0)), + auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); + auto converted = + XlaHelpers::ConvertElementType(builder, input, accumulation_type); + auto squared = builder->Mul(converted, converted); + auto reduce = builder->ReduceWindow( + squared, XlaHelpers::Zero(builder, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); + auto sqr_sum = + XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); auto scale = builder->Pow( builder->Add(builder->ConstantR0(bias_), @@ -106,10 +111,10 @@ class LRNGradOp : public XlaOpKernel { "input_grads, input_image, and out_image should have the same " "shape")); - xla::ComputationBuilder* builder = ctx->builder(); - xla::ComputationDataHandle in_grads = ctx->Input(0); - xla::ComputationDataHandle in_image = ctx->Input(1); - xla::ComputationDataHandle out_image = ctx->Input(2); + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp in_grads = ctx->Input(0); + xla::XlaOp in_image = ctx->Input(1); + xla::XlaOp out_image = ctx->Input(2); // This code is ported from tensorflow/core/kernels/lrn_op.cc. In Python // pseudo-code, the Eigen code does this for each spatial position: @@ -130,12 +135,17 @@ class LRNGradOp : public XlaOpKernel { // dyi *= out_grads[j] // grads[k] += dyi - auto squared = builder->Mul(in_image, in_image); - auto sqr_sum = builder->ReduceWindow( - squared, XlaHelpers::Zero(builder, input_type(0)), - *ctx->GetOrCreateAdd(input_type(0)), + auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); + auto converted = + XlaHelpers::ConvertElementType(builder, in_image, accumulation_type); + auto squared = builder->Mul(converted, converted); + auto reduce = builder->ReduceWindow( + squared, XlaHelpers::Zero(builder, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); + auto sqr_sum = + XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); auto norm = builder->Add(builder->ConstantR0(bias_), @@ -146,13 +156,17 @@ class LRNGradOp : public XlaOpKernel { builder->Div(out_image, norm)), in_grads); - auto dy_reduced = builder->ReduceWindow( - dy, XlaHelpers::Zero(builder, input_type(0)), - *ctx->GetOrCreateAdd(input_type(0)), + auto converted_dy = + XlaHelpers::ConvertElementType(builder, dy, accumulation_type); + auto dy_reduce = builder->ReduceWindow( + converted_dy, XlaHelpers::Zero(builder, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); + auto dy_reduced = + XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0)); - xla::ComputationDataHandle gradients = builder->Add( + xla::XlaOp gradients = builder->Add( builder->Mul(in_image, dy_reduced), builder->Mul(in_grads, builder->Pow(norm, builder->ConstantR0(-beta_)))); diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 886baf8115243a22b7255a3961c914d4cf6c2ed5..6949b296f4b9afe4a0c9152c763a9ad233b9f595 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -66,8 +66,8 @@ class MatMulOp : public XlaOpKernel { a_shape.DebugString(), ", In[1]: ", b_shape.DebugString())); - xla::ComputationDataHandle a = ctx->Input(0); - xla::ComputationDataHandle b = ctx->Input(1); + xla::XlaOp a = ctx->Input(0); + xla::XlaOp b = ctx->Input(1); if (is_sparse_) { if (a_type_ == DT_BFLOAT16) { a = ctx->builder()->ConvertElementType(a, xla::F32); diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc index faa415a97b053b4b11d015fefcd430210b98118a..fbd5dc0fdad4483aadbe9bc263cc1f7a034cee09 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -44,10 +44,10 @@ class MatrixBandPartOp : public XlaOpKernel { errors::InvalidArgument("num_upper must be scalar, got shape ", num_upper_in_shape.DebugString())); - xla::ComputationBuilder* builder = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle num_lower = context->Input(1); - xla::ComputationDataHandle num_upper = context->Input(2); + xla::XlaBuilder* builder = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp num_lower = context->Input(1); + xla::XlaOp num_upper = context->Input(2); DataType input_type = context->input_type(0); DataType index_type = context->input_type(1); @@ -58,10 +58,10 @@ class MatrixBandPartOp : public XlaOpKernel { // Compute 'offset', which is how many diagonals we are above/below the // diagonal. - xla::ComputationDataHandle iota_m; + xla::XlaOp iota_m; OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m)); - xla::ComputationDataHandle iota_n; + xla::XlaOp iota_n; OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n)); auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m, diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc index b2940bdcff75a087c914fdad0cb2426276e41aff..db53f6fef8d6bf901c8281f50791ca6766c46efd 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -54,16 +54,16 @@ class MatrixSetDiagOp : public XlaOpKernel { input_shape.DebugString(), " and diagonal shape: ", diag_shape.DebugString())); - xla::ComputationBuilder* builder = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle diag = context->Input(1); + xla::XlaBuilder* builder = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp diag = context->Input(1); auto zero = XlaHelpers::Zero(builder, context->input_type(0)); // Create an indicator tensor that is true only on the diagonal. - xla::ComputationDataHandle iota_m; + xla::XlaOp iota_m; OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m)); - xla::ComputationDataHandle iota_n; + xla::XlaOp iota_n; OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n)); auto indicator = builder->Eq(iota_m, builder->Broadcast(iota_n, {m}), diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 05a36a031ad73be289604da1b7e56203ff12fbf5..7e9de3ef9b245c113cc143128fe58e7e017a361c 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -25,10 +25,11 @@ class MirrorPadOp : public XlaOpKernel { public: explicit MirrorPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {} - xla::StatusOr DoMirrorPad( - const xla::ComputationDataHandle& t, const xla::Shape& original_shape, - const xla::Literal& pad_literal, xla::ComputationBuilder* b) { - xla::ComputationDataHandle accum = t; + xla::StatusOr DoMirrorPad(const xla::XlaOp& t, + const xla::Shape& original_shape, + const xla::Literal& pad_literal, + xla::XlaBuilder* b) { + xla::XlaOp accum = t; for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; --dimno) { auto t_rev = b->Rev(accum, {dimno}); @@ -76,12 +77,12 @@ class MirrorPadOp : public XlaOpKernel { OP_REQUIRES_OK( ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal)); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); auto in0 = ctx->Input(0); - xla::StatusOr> in0_shape = b->GetShape(in0); + xla::StatusOr in0_shape = b->GetShape(in0); OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status()); - xla::StatusOr accum_status = - DoMirrorPad(in0, *in0_shape.ValueOrDie(), pad_literal, b); + xla::StatusOr accum_status = + DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, b); OP_REQUIRES_OK(ctx, accum_status.status()); diff --git a/tensorflow/compiler/tf2xla/kernels/no_op.cc b/tensorflow/compiler/tf2xla/kernels/no_op.cc index 8c8a9bbe787f3224e7444b62dcf8ad99130cf37f..65ab9da8d7ca0509a4a69c43727a0e6c0435908a 100644 --- a/tensorflow/compiler/tf2xla/kernels/no_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/no_op.cc @@ -24,8 +24,7 @@ namespace tensorflow { REGISTER_XLA_OP(Name("NoOp").CompilationOnly(), NoOp); // We register ControlTrigger as a no-op. This is correct since nodes seen -// by the XLA compiler are never dead. This may need rethinking when we add -// support for conditionals to XLA. -REGISTER_XLA_OP(Name("ControlTrigger"), NoOp); +// by the XLA compiler are never dead. +REGISTER_XLA_OP(Name("ControlTrigger").CompilationOnly(), NoOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc index 9f7c9913802d311895479b914b66553e135aa426..cac2eea96eeed723b2a63bc9193070cad04b005d 100644 --- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc @@ -62,7 +62,7 @@ class OneHotOp : public XlaOpKernel { ctx, depth >= 0, errors::InvalidArgument("depth must be non-negative, got: ", depth)); - xla::ComputationDataHandle one_hot; + xla::XlaOp one_hot; OP_REQUIRES_OK( ctx, XlaHelpers::OneHot(ctx->builder(), depth, axis, input_type(0), indices_shape, ctx->Input(0), ctx->Input(2), diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index a4318e29d2532faf1f0cc6bb9418d29c2df20cd4..aecaabb6dcf46bdd6ae3da929448d6370acb989b 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -43,7 +43,7 @@ class PackOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - std::vector values; + std::vector values; std::vector shapes; OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes)); const int num = values.size(); @@ -69,7 +69,7 @@ class PackOp : public XlaOpKernel { -expanded_num_dims, ", ", expanded_num_dims, ")")); - std::vector reshaped_inputs(num); + std::vector reshaped_inputs(num); TensorShape child_shape(shapes[0]); child_shape.InsertDim(axis, 1); diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 791351637aee61c5fdd911dd8a48959990514395..7c95475e7b1f02183e44f73f116a4aeb25f05c09 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -70,7 +70,7 @@ class PadOp : public XlaOpKernel { } // PadV2 added a "constant_values" input that indicates the pad value. - xla::ComputationDataHandle constant_values; + xla::XlaOp constant_values; if (ctx->num_inputs() == 3) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)), errors::InvalidArgument("constant_values must be a scalar.")); diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index d4fb5dd4e06c7c70591262c0d63a91c383a2a6e0..f8e7b48a0fd94835964aea033ad33523150067b4 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -35,8 +35,11 @@ namespace { // Superclass of pooling ops. class PoolingOp : public XlaOpKernel { public: - PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims) - : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { + PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims, + const DataType reduction_type) + : XlaOpKernel(ctx), + num_spatial_dims_(num_spatial_dims), + reduction_type_(reduction_type) { if (ctx->num_inputs() == 1) { std::vector ksize_int; std::vector stride_int; @@ -63,22 +66,17 @@ class PoolingOp : public XlaOpKernel { int num_dims() const { return num_spatial_dims_ + 2; } // Method that builds an initial value to use in reductions. - virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, - DataType data_type) = 0; + virtual xla::XlaOp InitValue(xla::XlaBuilder* b) = 0; // The reduction operation to apply to each window. - virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx, - DataType dtype) = 0; + virtual const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) = 0; // A post-processing operation to apply on the outputs of the ReduceWindow. - virtual xla::ComputationDataHandle PostProcessOutput( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, - DataType dtype, const TensorShape& input_shape) = 0; + virtual xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, + const xla::XlaOp& output, DataType dtype, + const TensorShape& input_shape) = 0; void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle input = ctx->Input(0); - const TensorShape input_shape = ctx->InputShape(0); - std::vector ksize = ksize_; std::vector stride = stride_; if (ctx->num_inputs() != 1) { @@ -106,16 +104,20 @@ class PoolingOp : public XlaOpKernel { stride.clear(); OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride)); } + const TensorShape input_shape = ctx->InputShape(0); OP_REQUIRES(ctx, input_shape.dims() == num_dims(), errors::InvalidArgument("Input to ", type_string(), " operator must have ", num_dims(), " dimensions")); - const DataType type = input_type(0); - xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow( - input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize, - stride, padding_); - ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape)); + xla::XlaBuilder* const b = ctx->builder(); + auto input = + XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_); + auto reduce = ctx->builder()->ReduceWindow( + input, InitValue(b), *Reduction(ctx), ksize, stride, padding_); + auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); + ctx->SetOutput(0, + PostProcessOutput(ctx, pooled, input_type(0), input_shape)); } protected: @@ -124,26 +126,26 @@ class PoolingOp : public XlaOpKernel { std::vector stride_; xla::Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; + DataType reduction_type_; }; class MaxPoolOp : public PoolingOp { public: MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) - : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims) {} + : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, + /*reduction_type=*/ctx->input_type(0)) {} - xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, - DataType data_type) override { - return XlaHelpers::MinValue(b, data_type); + xla::XlaOp InitValue(xla::XlaBuilder* b) override { + return XlaHelpers::MinValue(b, reduction_type_); } - const xla::Computation* Reduction(XlaOpKernelContext* ctx, - DataType dtype) override { - return ctx->GetOrCreateMax(dtype); + const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { + return ctx->GetOrCreateMax(reduction_type_); } - xla::ComputationDataHandle PostProcessOutput( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, - DataType dtype, const TensorShape& input_shape) override { + xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, + const xla::XlaOp& output, DataType dtype, + const TensorShape& input_shape) override { return output; } }; @@ -174,9 +176,9 @@ REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); // Common computation shared between AvgPool and AvgPoolGrad. Divide each // element of an image by the count of elements that contributed to that // element during pooling. -static xla::ComputationDataHandle AvgPoolDivideByCount( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, - DataType dtype, const TensorShape& input_shape, xla::Padding padding, +static xla::XlaOp AvgPoolDivideByCount( + XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype, + const TensorShape& input_shape, xla::Padding padding, const std::vector& ksize, const std::vector& stride, int num_spatial_dims, TensorFormat data_format) { if (padding == xla::Padding::kValid) { @@ -209,15 +211,17 @@ static xla::ComputationDataHandle AvgPoolDivideByCount( } // Build a matrix of all 1s, with the same width/height as the input. + const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); auto ones = ctx->builder()->Broadcast( - XlaHelpers::One(ctx->builder(), dtype), input_dim_sizes); + XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes); // Perform a ReduceWindow with the same window size, strides, and padding // to count the number of contributions to each result element. - auto counts = ctx->builder()->ReduceWindow( - ones, XlaHelpers::Zero(ctx->builder(), dtype), - *ctx->GetOrCreateAdd(dtype), window_ksize, window_stride, + auto reduce = ctx->builder()->ReduceWindow( + ones, XlaHelpers::Zero(ctx->builder(), accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride, xla::Padding::kSame); + auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype); return ctx->builder()->Div(output, counts, window_dims); } @@ -226,21 +230,21 @@ static xla::ComputationDataHandle AvgPoolDivideByCount( class AvgPoolOp : public PoolingOp { public: AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) - : PoolingOp(ctx, num_spatial_dims) {} + : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, + /*reduction_type=*/ + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, - DataType data_type) override { - return XlaHelpers::Zero(b, data_type); + xla::XlaOp InitValue(xla::XlaBuilder* b) override { + return XlaHelpers::Zero(b, reduction_type_); } - const xla::Computation* Reduction(XlaOpKernelContext* ctx, - DataType dtype) override { - return ctx->GetOrCreateAdd(dtype); + const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { + return ctx->GetOrCreateAdd(reduction_type_); } - xla::ComputationDataHandle PostProcessOutput( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, - DataType dtype, const TensorShape& input_shape) override { + xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, + const xla::XlaOp& output, DataType dtype, + const TensorShape& input_shape) override { return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_, ksize_, stride_, num_spatial_dims_, data_format_); @@ -340,11 +344,10 @@ class MaxPoolGradOp : public XlaOpKernel { xla::PrimitiveType element_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type)); - xla::ComputationDataHandle init_value = - XlaHelpers::Zero(ctx->builder(), input_type(2)); + xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2)); auto select = CreateScalarGeComputation(element_type, ctx->builder()); auto scatter = CreateScalarAddComputation(element_type, ctx->builder()); - xla::ComputationDataHandle gradients = ctx->builder()->SelectAndScatter( + xla::XlaOp gradients = ctx->builder()->SelectAndScatter( input, select, ksize_, stride_, xla_padding, out_backprop, init_value, scatter); @@ -455,14 +458,12 @@ class AvgPoolGradOp : public XlaOpKernel { gradients_shape, filter_shape, out_backprop_shape, stride_, padding_, data_format_, &dims)); + // The input gradients are computed by a convolution of the output gradients + // and the filter, with some appropriate padding. See the comment at the top + // of conv_grad_ops.h for details. + xla::XlaBuilder* const b = ctx->builder(); auto out_backprop = ctx->Input(1); - - // The input gradients are computed by a convolution of the output - // gradients - // and the filter, with some appropriate padding. See the comment at - // the top of conv_grad_ops.h for details. - DataType dtype = input_type(1); - + auto dtype = input_type(1); xla::Padding xla_padding = (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; @@ -483,17 +484,18 @@ class AvgPoolGradOp : public XlaOpKernel { padding->set_interior_padding(dims.spatial_dims[i].stride - 1); } - auto zero = XlaHelpers::Zero(ctx->builder(), dtype); - auto padded_gradients = - ctx->builder()->Pad(out_backprop_div, zero, padding_config); + auto zero = XlaHelpers::Zero(b, dtype); + auto padded_gradients = b->Pad(out_backprop_div, zero, padding_config); // in_backprop = padded_gradients ones std::vector ones(num_dims(), 1LL); - xla::ComputationDataHandle in_backprop = ctx->builder()->ReduceWindow( - padded_gradients, zero, *ctx->GetOrCreateAdd(dtype), ksize_, + auto accumulation_type = XlaHelpers::SumAccumulationType(dtype); + auto in_backprop = b->ReduceWindow( + XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), ksize_, /* window_strides=*/ones, xla::Padding::kValid); - - ctx->SetOutput(0, in_backprop); + ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, in_backprop, dtype)); } protected: @@ -525,5 +527,172 @@ class AvgPool3DGradOp : public AvgPoolGradOp { REGISTER_XLA_OP(Name("AvgPool3DGrad").CompileTimeConstInput("orig_input_shape"), AvgPool3DGradOp); +class MaxPoolGradGradOp : public XlaOpKernel { + public: + MaxPoolGradGradOp(OpKernelConstruction* ctx, int num_spatial_dims) + : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { + if (ctx->num_inputs() == 3) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); + } + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); + } + + int num_dims() const { return num_spatial_dims_ + 2; } + + void Compile(XlaOpKernelContext* ctx) override { + if (ctx->num_inputs() != 3) { + OP_REQUIRES( + ctx, ctx->num_inputs() == 5, + errors::InvalidArgument("Must supply ksize and stride arguments.")); + const TensorShape ksize_shape = ctx->InputShape(3); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), + errors::InvalidArgument("ksize must be a vector, not shape ", + ksize_shape.DebugString())); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_)); + + const TensorShape stride_shape = ctx->InputShape(4); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), + errors::InvalidArgument("stride must be a vector, not shape ", + stride_shape.DebugString())); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_)); + } + + OP_REQUIRES(ctx, ksize_.size() == num_dims(), + errors::InvalidArgument("Sliding window ksize field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES(ctx, stride_.size() == num_dims(), + errors::InvalidArgument("Sliding window strides field must " + "specify ", + num_dims(), " dimensions")); + + const TensorShape tensor_in_shape = ctx->InputShape(0); + const TensorShape tensor_out_shape = ctx->InputShape(1); + const TensorShape out_backprop_shape = ctx->InputShape(2); + + // For maxpooling, tensor_in should have num_dims() dimensions. + OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(), + errors::InvalidArgument("tensor_in must be ", num_dims(), + "-dimensional")); + OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(), + errors::InvalidArgument("tensor_out must be ", num_dims(), + "-dimensional")); + // For maxpooling, out_backprop should have num_dims() dimensions. + OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(), + errors::InvalidArgument("out_backprop must be ", num_dims(), + "-dimensional")); + + // What we want to compute: + // Given y = MaxPool(x), and xs_grad = MaxPoolGrad(x, y, ys_grad) + // MaxPoolGradGrad computes {ys_grad}_grad given x, y, and {xs_grad}_grad. + // + // In the regular TF op, this amounts to selecting for each window the + // incoming backprop value from xs_grad_grad that corresponds to the maximal + // value in the corresponding window of x. + // + // TODO(b/73062247): What we really want is a ReduceWindow with different + // arrays for index selection vs return value selection--a select-to-gather. + // + // Here, we implement a bitwise hack: we use the hi 16 bits of input for + // separate max pooling alongside each of the hi and lo 16 bits of + // out_backprop packed into 16 lo bits, which we then glue back together at + // the end to get a full 32 bits of gradient. + // + // This could select the wrong backprop value for two x values that are + // equally maximal up to the first 16 bits, in which case we are taking the + // latter. + // + // Note that in principle we could use 32 separate maxpools to recover each + // of 32 bits of the gradient while preserving 31 bits of input for the max + // pooling criteria; here, we just truncate to the first 16 bits of input. + + auto input = ctx->Input(0); + auto out_backprop = ctx->Input(2); + + auto b = ctx->builder(); + + auto sixteen = b->ConstantR0(16); + // in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32 + auto in_hi = b->BitcastConvertType( + b->ConvertElementType(b->ConvertElementType(input, xla::BF16), + xla::F32), + xla::U32); + auto bp_int = b->BitcastConvertType(out_backprop, xla::U32); + auto bp_hi = b->ShiftRightLogical(bp_int, sixteen); + auto bp_lo = b->ShiftRightLogical(b->ShiftLeft(bp_int, sixteen), sixteen); + auto in_hi_bp_hi = b->Add(in_hi, bp_hi); // Want an unsigned add. + auto in_hi_bp_lo = b->Add(in_hi, bp_lo); // Want an unsigned add. + + auto init_value = XlaHelpers::MinValue(b, DT_FLOAT); + // We will reduce by taking the maximal value up to 16 bits (ignoring the lo + // 16 bits of packed-in hi/lo backprop value). + auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits"); + { + // F32 parameters to satisfy lowering type restriction for reduce opcode. + const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {}); + auto lhs = rb->Parameter(0, scalar, "lhs"); + auto rhs = rb->Parameter(1, scalar, "rhs"); + auto sixteen = rb->ConstantR0(16); + auto lhs_criteria = rb->ShiftLeft( + rb->ShiftRightLogical(rb->BitcastConvertType(lhs, xla::S32), sixteen), + sixteen); + auto rhs_criteria = rb->ShiftLeft( + rb->ShiftRightLogical(rb->BitcastConvertType(rhs, xla::S32), sixteen), + sixteen); + // Must use a F32 comparison, because S32 would not work for negatives. + rb->Select(rb->Ge(rb->BitcastConvertType(lhs_criteria, xla::F32), + rb->BitcastConvertType(rhs_criteria, xla::F32)), + lhs, rhs); + } + auto reduce = rb->BuildAndNoteError(); + xla::Padding xla_padding = + (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; + auto pooled_hi = + b->ReduceWindow(b->BitcastConvertType(in_hi_bp_hi, xla::F32), + init_value, reduce, ksize_, stride_, xla_padding); + auto pooled_lo = + b->ReduceWindow(b->BitcastConvertType(in_hi_bp_lo, xla::F32), + init_value, reduce, ksize_, stride_, xla_padding); + auto grads_hi = + b->ShiftLeft(b->BitcastConvertType(pooled_hi, xla::U32), sixteen); + auto grads_lo = b->ShiftRightLogical( + b->ShiftLeft(b->BitcastConvertType(pooled_lo, xla::U32), sixteen), + sixteen); + auto grads = b->Add(grads_hi, grads_lo); // Want an unsigned add. + + xla::PrimitiveType element_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type)); + ctx->SetOutput(0, b->BitcastConvertType(grads, element_type)); + } + + protected: + const int num_spatial_dims_; + std::vector ksize_; + std::vector stride_; + Padding padding_; + TensorFormat data_format_ = FORMAT_NHWC; +}; + +class MaxPool2DGradGradOp : public MaxPoolGradGradOp { + public: + explicit MaxPool2DGradGradOp(OpKernelConstruction* ctx) + : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/2) { + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + } +}; +REGISTER_XLA_OP(Name("MaxPoolGradGrad").TypeConstraint("T", DT_FLOAT), + MaxPool2DGradGradOp); +REGISTER_XLA_OP(Name("MaxPoolGradGradV2") + .TypeConstraint("T", DT_FLOAT) + .CompileTimeConstInput("ksize") + .CompileTimeConstInput("strides"), + MaxPool2DGradGradOp); + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index 4171e076ff6d9dd4f809454377620324d1fe5ae4..661cd5923e1023eaf89a6bc4f56fcc362c8bcfb6 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -35,7 +35,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); const DataType data_type = ctx->input_type(0); // Comments taken from semantics description at @@ -46,8 +46,8 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { // m = max(abs(input_min), abs(input_max)) if range_given is true, // m = max(abs(min_elem(input)), // abs(max_elem(input))) otherwise. - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle input_min, input_max; + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp input_min, input_max; if (range_given_) { double input_min_value, input_max_value; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &input_min_value)); @@ -55,14 +55,14 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { input_min = XlaHelpers::FloatLiteral(b, data_type, input_min_value); input_max = XlaHelpers::FloatLiteral(b, data_type, input_max_value); } else { - const xla::Computation* fmax = ctx->GetOrCreateMax(data_type); - const xla::Computation* fmin = ctx->GetOrCreateMin(data_type); + const xla::XlaComputation* fmax = ctx->GetOrCreateMax(data_type); + const xla::XlaComputation* fmin = ctx->GetOrCreateMin(data_type); input_min = b->ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin); input_max = b->ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax); } - xla::ComputationDataHandle m = b->Max(b->Abs(input_min), b->Abs(input_max)); + xla::XlaOp m = b->Max(b->Abs(input_min), b->Abs(input_max)); // Next, we choose our fixed-point quantization buckets, [min_fixed, // max_fixed]. If signed_input is true, this is @@ -85,7 +85,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { // From this we compute our scaling factor, s: // // s = (max_fixed - min_fixed) / (2 * m). - xla::ComputationDataHandle s = + xla::XlaOp s = b->Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed), b->Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m)); @@ -93,7 +93,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { // e is transformed into e': // // e' = (e * s).round_to_nearest() / s. - xla::ComputationDataHandle result = b->Div(b->Round(b->Mul(input, s)), s); + xla::XlaOp result = b->Div(b->Round(b->Mul(input, s)), s); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index c0994c434bca5174eaee7b9e63e10432d9c2ed8d..105be38fe26b6667e8b4ce6da92a3969cdc0c187 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -17,6 +17,9 @@ limitations under the License. // TODO(misard,phawkins): handle random number generator seeds/states correctly. // TODO(misard,phawkins): add tests. +#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -41,9 +44,9 @@ class RandomUniformOp : public XlaOpKernel { xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle result = b->RngUniform( - XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp result = b->RngUniform(XlaHelpers::Zero(b, dtype), + XlaHelpers::One(b, dtype), xla_shape); ctx->SetOutput(0, result); } @@ -55,6 +58,78 @@ class RandomUniformOp : public XlaOpKernel { REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"), RandomUniformOp); +class RandomShuffleOp : public XlaOpKernel { + public: + explicit RandomShuffleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + auto builder = ctx->builder(); + xla::XlaOp input = ctx->Input(0); + TensorShape input_shape = ctx->InputShape(0); + const int64 n = input_shape.dim_size(0); + int64 num_elements = 1; + for (tensorflow::TensorShapeDim dimension : input_shape) { + num_elements *= dimension.size; + } + if (num_elements <= 1 || n <= 1) { + // No shuffling is required, so copy input directly to output + ctx->SetOutput(0, input); + } else { + // Generate the random swaps for the indices. + auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n}); + auto swaps = + builder->RngUniform(builder->ConstantR0(0), + builder->ConstantR0(n), swaps_shape); + + // Generate range(n) as the initial value for the indices to be swapped. + xla::XlaOp indices; + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, n, &indices)); + + // Swap the indices at i and swaps[i]. + auto swap_body_fn = [&](xla::XlaOp i, + gtl::ArraySlice loop_vars, + xla::XlaBuilder* builder) + -> xla::StatusOr> { + auto swaps = loop_vars[0]; + auto indices = loop_vars[1]; + i = builder->Reshape(i, {1}); + // temp = indices[i] + auto temp = builder->DynamicSlice(indices, i, {1}); + // swap_index = swaps[i] + auto swap_index = builder->DynamicSlice(swaps, i, {1}); + // swap_value = indices[swaps[i]] + auto swap_value = builder->DynamicSlice(indices, swap_index, {1}); + // indices[i] = indices[swaps[i]] + indices = builder->DynamicUpdateSlice(indices, swap_value, i); + // indices[swaps[i]] = temp + indices = builder->DynamicUpdateSlice(indices, temp, swap_index); + return std::vector{swaps, indices}; + }; + // for i in range(n): + auto swap_loop_result = + XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, + "indices_swap_loop", builder) + .ValueOrDie(); + auto swapped_indices = swap_loop_result[1]; + + // Gather the data using the swapped indices as the shuffled order. + auto indices_tensor_shape = TensorShape({n}); + DataType type = ctx->expected_output_dtype(0); + xla::XlaOp gather; + OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices, + indices_tensor_shape, + /*axis=*/0, /*indices_are_nd=*/false, type, + DT_INT32, builder, &gather)); + ctx->SetOutput(0, gather); + } + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleOp); +}; + +REGISTER_XLA_OP(Name("RandomShuffle"), RandomShuffleOp); + class RandomUniformIntOp : public XlaOpKernel { public: explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} @@ -100,11 +175,11 @@ class RandomStandardNormalOp : public XlaOpKernel { xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); // Normal distribution with a mean of 0 and a standard deviation of 1: - xla::ComputationDataHandle result = b->RngNormal( - XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape); + xla::XlaOp result = b->RngNormal(XlaHelpers::Zero(b, dtype), + XlaHelpers::One(b, dtype), xla_shape); ctx->SetOutput(0, result); } @@ -127,22 +202,16 @@ class TruncatedNormalOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); - xla::Shape xla_element_shape = - xla::ShapeUtil::MakeShape(xla_shape.element_type(), {}); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle mean = XlaHelpers::Zero(b, dtype); - xla::ComputationDataHandle stddev = XlaHelpers::One(b, dtype); - xla::ComputationDataHandle candidate = - b->RngNormal(mean, stddev, xla_shape); + xla::XlaBuilder* b = ctx->builder(); - auto two_sd = [dtype](bool negate, xla::ComputationBuilder* b) { + auto two_sd = [dtype](bool negate, xla::XlaBuilder* b) { return XlaHelpers::FloatLiteral(b, dtype, negate ? -2.0 : 2.0); }; - auto out_of_range_mask = [two_sd](xla::ComputationDataHandle candidate, - xla::ComputationBuilder* b) { - xla::ComputationDataHandle too_large = b->Gt(candidate, two_sd(false, b)); - xla::ComputationDataHandle too_small = b->Lt(candidate, two_sd(true, b)); + auto out_of_range_mask = [two_sd](xla::XlaOp candidate, + xla::XlaBuilder* b) { + xla::XlaOp too_large = b->Gt(candidate, two_sd(false, b)); + xla::XlaOp too_small = b->Lt(candidate, two_sd(true, b)); return b->Or(too_large, too_small); }; @@ -152,37 +221,38 @@ class TruncatedNormalOp : public XlaOpKernel { // out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd // candidate = select(out_of_range_mask, rng_normal(), candidate) // } - std::unique_ptr test_builder = - b->CreateSubBuilder("truncated_normal_test"); - { - auto* b = test_builder.get(); - xla::ComputationDataHandle candidate = - b->Parameter(0, xla_shape, "candidate"); - xla::ComputationDataHandle oor_mask = out_of_range_mask(candidate, b); - OP_REQUIRES_OK(ctx, Any(out_of_range_mask(candidate, b), b).status()); - } - - std::unique_ptr body_builder = - b->CreateSubBuilder("truncated_normal_body"); - { - auto* b = body_builder.get(); - xla::ComputationDataHandle candidate = - b->Parameter(0, xla_shape, "candidate"); - xla::ComputationDataHandle to_resample = out_of_range_mask(candidate, b); - xla::ComputationDataHandle mean = XlaHelpers::Zero(b, dtype); - xla::ComputationDataHandle stddev = XlaHelpers::One(b, dtype); - b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), candidate); - } - - xla::StatusOr test_computation = test_builder->Build(); - OP_REQUIRES_OK(ctx, test_computation.status()); - xla::StatusOr body_computation = body_builder->Build(); - OP_REQUIRES_OK(ctx, body_computation.status()); - xla::ComputationDataHandle result = - b->While(test_computation.ValueOrDie(), body_computation.ValueOrDie(), - candidate); - - ctx->SetOutput(0, result); + std::vector initial_values = { + // The current candidate. + b->Broadcast(XlaHelpers::Zero(b, dtype), shape.dim_sizes()), + // The to_resample mask, where 'true' identifies a location in the + // current candidate that is out of range and must be regenerated. + b->Broadcast(b->ConstantR0(true), shape.dim_sizes()), + // Is any element in the mask true? + b->ConstantR0(true)}; + auto condition = [&](gtl::ArraySlice values, + xla::XlaBuilder* b) -> xla::StatusOr { + // Continue while any element in the mask is true. + return values[2]; + }; + auto body = + [&](gtl::ArraySlice values, + xla::XlaBuilder* b) -> xla::StatusOr> { + xla::XlaOp candidate = values[0]; + xla::XlaOp to_resample = values[1]; + xla::XlaOp mean = XlaHelpers::Zero(b, dtype); + xla::XlaOp stddev = XlaHelpers::One(b, dtype); + candidate = b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), + candidate); + // Compute a new to_resample mask, and determine whether any value is + // still out of range. + to_resample = out_of_range_mask(candidate, b); + TF_ASSIGN_OR_RETURN(xla::XlaOp done, Any(to_resample, b)); + return std::vector{candidate, to_resample, done}; + }; + auto result = + XlaWhileLoop(condition, body, initial_values, "truncated_normal", b); + OP_REQUIRES_OK(ctx, result.status()); + ctx->SetOutput(0, result.ValueOrDie()[0]); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..08894489ac77bbbe4ddb067c06a6d031a537697d --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/while_op.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class ReduceWindowOp : public XlaOpKernel { + public: + explicit ReduceWindowOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("computation", &computation_)); + OP_REQUIRES_OK(context, + context->GetAttr("window_dimensions", &window_dimensions_)); + OP_REQUIRES_OK(context, + context->GetAttr("window_strides", &window_strides_)); + OP_REQUIRES_OK(context, context->GetAttr("padding_low", &padding_low_)); + OP_REQUIRES_OK(context, context->GetAttr("padding_high", &padding_high_)); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const DataType dtype = context->input_type(0); + + const int rank = input_shape.dims(); + OP_REQUIRES(context, rank == window_dimensions_.size(), + errors::InvalidArgument( + "The size of window_dimensions must be equal to the input " + "rank (", + window_dimensions_.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == window_strides_.size(), + errors::InvalidArgument( + "The size of window_strides must be equal to the input " + "rank (", + window_strides_.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == padding_low_.size(), + errors::InvalidArgument( + "The size of padding_low must be equal to the input " + "rank (", + padding_low_.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == padding_high_.size(), + errors::InvalidArgument( + "The size of padding_high must be equal to the input " + "rank (", + padding_high_.size(), " vs. ", rank, ")")); + + xla::XlaBuilder* builder = context->builder(); + + // Build the reducer function. + XlaCompiler::Argument reducer_arg; + reducer_arg.kind = XlaCompiler::Argument::kParameter; + reducer_arg.type = dtype; + reducer_arg.shape = TensorShape(); + + XlaCompiler::CompileOptions compile_options; + compile_options.use_tuple_arg = false; + compile_options.resolve_compile_time_constants = false; + compile_options.is_entry_computation = false; + XlaCompiler::CompilationResult reducer; + OP_REQUIRES_OK(context, context->compiler()->CompileFunction( + compile_options, *computation_, + {reducer_arg, reducer_arg}, &reducer)); + + xla::Shape scalar_shape; + OP_REQUIRES_OK(context, + TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + OP_REQUIRES(context, + xla::ShapeUtil::Compatible( + reducer.xla_output_shape, + xla::ShapeUtil::MakeTupleShape({scalar_shape})), + errors::InvalidArgument( + "Invalid output shape of ReduceWindow reducer. Expected ", + xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(reducer.xla_output_shape))); + + // Wraps the reducer in a computation that unpacks the output tuple. + xla::XlaComputation wrapper; + { + std::unique_ptr cb = + builder->CreateSubBuilder("wrapper"); + auto x = cb->Parameter(0, scalar_shape, "x"); + auto y = cb->Parameter(1, scalar_shape, "y"); + auto outputs = cb->Call(*reducer.computation, {x, y}); + cb->GetTupleElement(outputs, 0); + xla::StatusOr result = cb->Build(); + OP_REQUIRES_OK(context, result.status()); + wrapper = std::move(result.ValueOrDie()); + } + + std::vector> padding(rank); + for (int i = 0; i < rank; ++i) { + padding[i] = {padding_low_[i], padding_high_[i]}; + } + + xla::XlaOp output = builder->ReduceWindowWithGeneralPadding( + context->Input(0), context->Input(1), wrapper, window_dimensions_, + window_strides_, padding); + context->SetOutput(0, output); + } + + private: + const NameAttrList* computation_; + std::vector window_dimensions_; + std::vector window_strides_; + std::vector padding_low_; + std::vector padding_high_; + + TF_DISALLOW_COPY_AND_ASSIGN(ReduceWindowOp); +}; + +REGISTER_XLA_OP(Name("XlaReduceWindow"), ReduceWindowOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 03b13b2924f4b81c1017804c91d5ffb81c44ea0b..0f425637795e9633a8e36f921000ee2f5e25813a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -27,10 +27,14 @@ namespace { class SumOp : public XlaReductionOp { public: - explicit SumOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + explicit SumOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return XlaHelpers::Zero(builder, reduction_type_); + } + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Add(scalar_lhs, scalar_rhs); } }; @@ -39,16 +43,16 @@ REGISTER_XLA_OP(Name("Sum").CompileTimeConstInput("reduction_indices"), SumOp); class ProdOp : public XlaReductionOp { public: - explicit ProdOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit ProdOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { - return XlaHelpers::One(builder, input_type(0)); + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return XlaHelpers::One(builder, reduction_type_); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Mul(scalar_lhs, scalar_rhs); } }; @@ -58,18 +62,15 @@ REGISTER_XLA_OP(Name("Prod").CompileTimeConstInput("reduction_indices"), class MinOp : public XlaReductionOp { public: - explicit MinOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit MinOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, ctx->input_type(0)) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { - xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - return builder->ConstantLiteral(xla::Literal::MaxValue(type)); + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return XlaHelpers::MaxValue(builder, reduction_type_); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Min(scalar_lhs, scalar_rhs); } }; @@ -78,18 +79,15 @@ REGISTER_XLA_OP(Name("Min").CompileTimeConstInput("reduction_indices"), MinOp); class MaxOp : public XlaReductionOp { public: - explicit MaxOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit MaxOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, ctx->input_type(0)) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { - xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - return builder->ConstantLiteral(xla::Literal::MinValue(type)); + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return XlaHelpers::MinValue(builder, reduction_type_); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Max(scalar_lhs, scalar_rhs); } }; @@ -98,18 +96,21 @@ REGISTER_XLA_OP(Name("Max").CompileTimeConstInput("reduction_indices"), MaxOp); class MeanOp : public XlaReductionOp { public: - explicit MeanOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit MeanOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return XlaHelpers::Zero(builder, reduction_type_); + } + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Add(scalar_lhs, scalar_rhs); } - xla::ComputationDataHandle BuildFinalizer( - xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& reduce_output, - int64 num_elements_reduced) override { + xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, + const xla::XlaOp& reduce_output, + int64 num_elements_reduced) override { auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0), num_elements_reduced); return builder->Div(reduce_output, divisor); @@ -121,16 +122,15 @@ REGISTER_XLA_OP(Name("Mean").CompileTimeConstInput("reduction_indices"), class AllOp : public XlaReductionOp { public: - explicit AllOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit AllOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, ctx->input_type(0)) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return builder->ConstantR0(true); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->And(scalar_lhs, scalar_rhs); } }; @@ -139,16 +139,15 @@ REGISTER_XLA_OP(Name("All").CompileTimeConstInput("reduction_indices"), AllOp); class AnyOp : public XlaReductionOp { public: - explicit AnyOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit AnyOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, ctx->input_type(0)) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return builder->ConstantR0(false); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Or(scalar_lhs, scalar_rhs); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 9aca6d8fedf92f176b3b7b40c5961d4a2e557a8a..2ecfb854a1c8625524d4f1199af3927edd204926 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -19,7 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -28,41 +28,42 @@ namespace tensorflow { // to override: description is a textual description of the mapped // function; InitialValue constructs the base case for the reduction; // BuildReducer adds the implementation of the reduction lambda to a -// xla::ComputationBuilder and BuildFinalizer adds the +// xla::XlaBuilder and BuildFinalizer adds the // implementation of the finalizer lambda (if there is one) to a -// xla::ComputationBuilder. +// xla::XlaBuilder. class XlaReductionOp : public XlaOpKernel { public: - explicit XlaReductionOp(OpKernelConstruction* ctx); + XlaReductionOp(OpKernelConstruction* ctx, DataType reduction_type); ~XlaReductionOp() override {} - // Return the base case for the reduction. Defaults to zero. - virtual xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder); + // Return the base case for the reduction. + virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0; // Implement the (scalar,scalar)->scalar lambda that should be // applied to each pair of elements to be reduced. The desired // computation should be added to 'builder' and // '(scalar_lhs,scalar_rhs)' are the function's inputs. - virtual void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) = 0; + virtual void BuildReducer(xla::XlaBuilder* builder, + const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) = 0; // Applies a transformation to the output of the reduction. The desired // computation should be added to 'builder'. Argument 'reduce_output' is the // output of the reduction. 'num_elements_reduced' is the number of elements // that contributed to the reduction. Returns the transformed reduction // output, Defaults to returning 'reduce_output' unchanged. - virtual xla::ComputationDataHandle BuildFinalizer( - xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& reduce_output, - int64 num_elements_reduced); + virtual xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, + const xla::XlaOp& reduce_output, + int64 num_elements_reduced); void Compile(XlaOpKernelContext* ctx) override; private: // True if the number of dimensions should be maintained. bool keep_dims_; + + protected: + DataType reduction_type_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 4b5d09eb9fd4110cdc4221099ff55767e9132540..4fd5bfd03999a7f8b7bb081cc4b03aa1434d4c3d 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -24,25 +24,20 @@ limitations under the License. namespace tensorflow { -XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { +XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, + DataType reduction_type) + : XlaOpKernel(ctx), reduction_type_(reduction_type) { const DataType dt = BaseType(input_type(0)); OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt})); OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_)); } -// Return the base case for the reduction. Defaults to zero. -xla::ComputationDataHandle XlaReductionOp::InitialValue( - xla::ComputationBuilder* builder) { - return XlaHelpers::Zero(builder, input_type(0)); -} - // Unless BuildFinalizer is overridden the reduction has no // finalizer. -xla::ComputationDataHandle XlaReductionOp::BuildFinalizer( - xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& reduce_output, - int64 num_elements_reduced) { +xla::XlaOp XlaReductionOp::BuildFinalizer(xla::XlaBuilder* builder, + const xla::XlaOp& reduce_output, + int64 num_elements_reduced) { return reduce_output; } @@ -100,36 +95,26 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { string desc = ctx->op_kernel().name(); - // Call virtual method to get the initial value. - const xla::ComputationDataHandle initial = InitialValue(ctx->builder()); + xla::XlaBuilder* const b = ctx->builder(); // Construct the builder for the reduction lambda. - xla::ComputationBuilder r(ctx->builder()->client(), - strings::StrCat(desc, "-reduction")); + xla::XlaBuilder r(strings::StrCat(desc, "-reduction")); xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - // Make two scalar parameters of the desired type for the lambda. - xla::ComputationDataHandle rx = - r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); - xla::ComputationDataHandle ry = - r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); - - auto data = ctx->Input(0); + TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type)); + auto data = b->ConvertElementType(ctx->Input(0), type); + // Call virtual method to get the initial value. + auto initial = b->ConvertElementType(InitialValue(b), type); + // Make two scalar parameters of the desired type for the lambda. + auto rx = r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); + auto ry = r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); // Call virtual method to build the reduction lambda. BuildReducer(&r, rx, ry); - xla::Computation reduction_computation = r.Build().ConsumeValueOrDie(); - xla::ComputationDataHandle reduce = - ctx->builder()->Reduce(data, initial, reduction_computation, xla_axes); - - xla::ComputationDataHandle finalized = - BuildFinalizer(ctx->builder(), reduce, num_elements_reduced); - - xla::ComputationDataHandle result; - if (keep_dims_) { - result = ctx->builder()->Reshape(finalized, final_shape); - } else { - result = finalized; - } + xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie(); + + auto reduce = b->Reduce(data, initial, reduction_computation, xla_axes); + auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); + auto finalized = BuildFinalizer(b, deconverted, num_elements_reduced); + auto result = keep_dims_ ? b->Reshape(finalized, final_shape) : finalized; ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index 12a35529992e6160566046dd28f9321c88afec91..ba7d484d53d7258edaa5bc42fa116cf16e94835b 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" @@ -32,7 +32,7 @@ class ReluOp : public XlaOpKernel { explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Computes the max of the scalar input x and 0. void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); auto zero = XlaHelpers::Zero(builder, input_type(0)); ctx->SetOutput(0, builder->Max(zero, ctx->Input(0))); } @@ -43,7 +43,7 @@ class Relu6Op : public XlaOpKernel { explicit Relu6Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Clamp the scalar input between 0 and 6. void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); auto zero = XlaHelpers::Zero(builder, input_type(0)); auto six = XlaHelpers::IntegerLiteral(builder, input_type(0), 6); ctx->SetOutput(0, builder->Clamp(zero, ctx->Input(0), six)); @@ -56,7 +56,7 @@ class ReluGradOp : public XlaOpKernel { // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return 0. void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const TensorShape shape = ctx->InputShape(0); const auto zero = b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); @@ -71,7 +71,7 @@ class Relu6GradOp : public XlaOpKernel { // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return 0. void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const TensorShape shape = ctx->InputShape(0); const auto zero = b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index c283e3b02c2676785952e3e17bffa671b0dabc1e..a711278638444be01fb865561957702368b75114 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -45,7 +45,7 @@ class RetvalOp : public XlaOpKernel { // compilation. OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input)); } else { - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); const TensorShape input_shape = ctx->InputShape(0); auto is_constant = ctx->builder()->IsConstant(input); @@ -55,18 +55,33 @@ class RetvalOp : public XlaOpKernel { } XlaContext& tc = XlaContext::Get(ctx); - if (input_shape.num_elements() == 0 || is_constant.ValueOrDie()) { + if (tc.resolve_compile_time_constants() && + (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) { xla::Literal literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); } else { - // The core from which a return value is returned depends on the core - // assignment of the input to the retval .Since we can't change the core - // assignment of as this point, create a tuple/get-tuple-element - // combination so that the core will be set on them. - auto tuple_elem = - ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0); - tc.AddRetval(index_, dtype_, tuple_elem); + TensorShape shape = ctx->InputShape(0); + TensorShape representation_shape = + tc.is_entry_computation() + ? tc.RepresentationShape(shape, ctx->input_type(0)) + : shape; + + xla::XlaOp output = input; + if (tc.is_entry_computation()) { + output = + ctx->builder()->Reshape(input, representation_shape.dim_sizes()); + } else { + // The core from which a return value is returned depends on the + // device assignment of the input to the retval. Since we can't change + // the device assignment of "input" at this point, we must always + // introduce an operator here, even if the shape does not change. + // TODO(b/76097077): propagate device assignments onto arguments and + // return values of functions, and then reshape unconditionally. + output = ctx->builder()->GetTupleElement( + ctx->builder()->Tuple({output}), 0); + } + tc.AddRetval(index_, dtype_, shape, output); } } } diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index e51d386926763ecbb5a943dfb6f872e78901dc69..2872a3c4d49d0d269aa3d216887a5c32cd51f1c3 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -48,7 +48,7 @@ class ReverseOp : public XlaOpKernel { ctx->SetOutput(0, ctx->Input(0)); return; } - // ComputationBuilder::Rev() requires concrete values for dimensions arg. + // XlaBuilder::Rev() requires concrete values for dimensions arg. xla::Literal lax; OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax)); std::vector revdims(x_shape.dims()); @@ -90,7 +90,7 @@ class ReverseV2Op : public XlaOpKernel { ctx->SetOutput(0, ctx->Input(0)); return; } - // ComputationBuilder::Rev() requires concrete values for dimensions arg. + // XlaBuilder::Rev() requires concrete values for dimensions arg. std::vector axes; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &axes)); diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 6bc5d3adb091cd238974c5b69b7a2f8fe639cc68..5d1c05268493f4f6404c40a4092a71f1e5b3f3b9 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -54,7 +54,7 @@ class ReverseSequenceOp : public XlaOpKernel { "), ", "(", seq_lens_shape.num_elements(), " vs. ", input_shape.dim_size(batch_dim_))); - xla::ComputationBuilder* builder = context->builder(); + xla::XlaBuilder* builder = context->builder(); const auto input = context->Input(0); const auto seq_lens = context->Input(1); @@ -106,20 +106,40 @@ class ReverseSequenceOp : public XlaOpKernel { seq_lens, body_builder->Reshape(i, {1}), {1}); // Indices is the offset of the batch element in the input. - auto indices = body_builder->Broadcast( + auto batch_element_indices = body_builder->Broadcast( XlaHelpers::Zero(body_builder.get(), seq_lens_type), {input_shape.dims()}); - indices = body_builder->DynamicUpdateSlice( - indices, body_builder->Reshape(i, {1}), + batch_element_indices = body_builder->DynamicUpdateSlice( + batch_element_indices, body_builder->Reshape(i, {1}), body_builder->Reshape( XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, batch_dim_), {1})); - // slice_indices is the offset of the start of the reversed sequence in - // the input. - auto slice_indices = body_builder->DynamicUpdateSlice( - indices, + // Slice out the current batch element and pad it out in the sequence + // dimension. + TensorShape slice_shape = input_shape; + slice_shape.set_dim(batch_dim_, 1); + slice_shape.set_dim(seq_dim_, max_seq_len); + auto slice = body_builder->DynamicSlice(output, batch_element_indices, + slice_shape.dim_sizes()); + auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims()); + padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high( + slice_shape.dim_size(seq_dim_)); + slice = body_builder->Pad( + slice, XlaHelpers::Zero(body_builder.get(), input_type), + padding_config); + + // Now slice out the reversed sequence from its actual start. + // sequence_start_indices is the offset of the start of the reversed + // sequence in the input. The slice will go into the padding, however, we + // will mask off these elements and replace them with elements from the + // original input so their values do not matter. + auto sequence_start_indices = body_builder->Broadcast( + XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {slice_shape.dims()}); + sequence_start_indices = body_builder->DynamicUpdateSlice( + sequence_start_indices, body_builder->Sub(XlaHelpers::IntegerLiteral( body_builder.get(), seq_lens_type, max_seq_len), seq_len), @@ -127,18 +147,12 @@ class ReverseSequenceOp : public XlaOpKernel { XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, seq_dim_), {1})); - - // Slice out the reversed sequence. The slice will overflow the end of the - // sequence, and the contents of the overflow are implementation-defined. - // However, we will mask off these elements and replace them with elements - // from the original input so their values do not matter. - TensorShape slice_shape = input_shape; - slice_shape.set_dim(batch_dim_, 1); - auto slice = body_builder->DynamicSlice(output, slice_indices, - slice_shape.dim_sizes()); + slice = body_builder->DynamicSlice(slice, sequence_start_indices, + slice_shape.dim_sizes()); // Shift the reversed sequence to the left. - output = body_builder->DynamicUpdateSlice(output, slice, indices); + output = body_builder->DynamicUpdateSlice(output, slice, + batch_element_indices); body_builder->Tuple( {body_builder->Add( @@ -155,7 +169,7 @@ class ReverseSequenceOp : public XlaOpKernel { auto output = builder->GetTupleElement(loop_output, 2); // Mask out elements after the sequence length. - xla::ComputationDataHandle iota; + xla::XlaOp iota; OP_REQUIRES_OK( context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota)); std::vector dims(input_shape.dims(), 1); diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index ee4a94164c4a43828eb4feedbfa9d1a9e231ef8f..1819fb543317eed15b2fe0518d74aba5c564697d 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -66,7 +66,7 @@ class ScanOp : public XlaOpKernel { -input_shape.dims(), ", ", input_shape.dims(), "), but got ", axis)); - DataType dtype = ctx->input_type(0); + DataType dtype = XlaHelpers::SumAccumulationType(ctx->input_type(0)); if (input_shape.num_elements() == 0) { // Exit early if there is nothing to compute. @@ -74,7 +74,7 @@ class ScanOp : public XlaOpKernel { return; } - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); std::vector window_strides(input_shape.dims(), 1); std::vector window_dims(input_shape.dims(), 1); @@ -91,9 +91,8 @@ class ScanOp : public XlaOpKernel { std::swap(padding[axis].first, padding[axis].second); } - xla::ComputationDataHandle input = ctx->Input(0); - xla::ComputationDataHandle init; - const xla::Computation* reducer; + xla::XlaOp init; + const xla::XlaComputation* reducer; if (sum_) { init = XlaHelpers::Zero(builder, dtype); reducer = ctx->GetOrCreateAdd(dtype); @@ -102,7 +101,10 @@ class ScanOp : public XlaOpKernel { reducer = ctx->GetOrCreateMul(dtype); } auto output = builder->ReduceWindowWithGeneralPadding( - ctx->Input(0), init, *reducer, window_dims, window_strides, padding); + XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init, + *reducer, window_dims, window_strides, padding); + output = + XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0)); // In exclusive mode, we have computed an extra element containing the sum // of all the input elements. Slice off this extra "last" element. diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index 8433a29c4e203cac726ee6bf7f67a863447326ed..f2c63b4f9083ad3c7dd7cf318dc22def1e99fa9f 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -102,7 +102,7 @@ class ScatterNdOp : public XlaOpKernel { OP_REQUIRES_OK(context, ValidateUpdateShape(buffer_shape, indices_shape, updates_shape)); - xla::ComputationBuilder* builder = context->builder(); + xla::XlaBuilder* builder = context->builder(); auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype), buffer_shape.dim_sizes()); auto indices = context->Input(0); diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 80d6df6c48b0141734dcee1c2a3c413926931feb..664078ca16c6d5d4b57c4a8c661ad0848f30dd7d 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -62,16 +62,16 @@ class UnsortedSegmentSum : public XlaOpKernel { d, " differs ", data_shape.dim_size(d), " vs. ", indices_shape.dim_size(d))); } - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); TensorShape buffer_shape = data_shape; buffer_shape.RemoveDimRange(0, indices_shape.dims()); buffer_shape.InsertDim(0, num_segments); auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype_), buffer_shape.dim_sizes()); - auto combiner = - [](xla::ComputationDataHandle a, xla::ComputationDataHandle b, - xla::ComputationBuilder* builder) { return builder->Add(a, b); }; + auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) { + return builder->Add(a, b); + }; auto result = XlaScatter(buffer, /*updates=*/data, indices, /*indices_are_vectors=*/false, combiner, builder); @@ -83,7 +83,9 @@ class UnsortedSegmentSum : public XlaOpKernel { DataType dtype_; }; -REGISTER_XLA_OP(Name("UnsortedSegmentSum"), UnsortedSegmentSum); +REGISTER_XLA_OP( + Name("UnsortedSegmentSum").CompileTimeConstInput("num_segments"), + UnsortedSegmentSum); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 8081d3c41c436324c21858124121fecfac71cefa..f9f48164d63492b057d4950abfc2ca6153e44870 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -40,7 +40,7 @@ class SelectOp : public XlaOpKernel { "'then' and 'else' must have the same size. but received: ", then_shape.DebugString(), " vs. ", else_shape.DebugString())); - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); auto cond_handle = ctx->Input(0); auto then_handle = ctx->Input(1); diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index 5172781c0d05b6682fe92086654e3b86961949ee..9ce01d0d44509bbcbea18afdb4210a675834bb6d 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" @@ -48,7 +48,7 @@ void SendOp::Compile(XlaOpKernelContext* ctx) { ctx->builder()->Send(ctx->Input(0), channel); } -REGISTER_XLA_OP(Name("_XLASend"), SendOp); +REGISTER_XLA_OP(Name("XlaSend"), SendOp); class RecvOp : public XlaOpKernel { public: @@ -68,7 +68,7 @@ RecvOp::RecvOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { TensorShape tensor_shape; DataType dtype; OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tensor_shape)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype)); OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, tensor_shape, &shape_)); } @@ -79,7 +79,7 @@ void RecvOp::Compile(XlaOpKernelContext* ctx) { ctx->SetOutput(0, ctx->builder()->Recv(shape_, channel)); } -REGISTER_XLA_OP(Name("_XLARecv"), RecvOp); +REGISTER_XLA_OP(Name("XlaRecv"), RecvOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 05354bca5bb089703fdcceb6f44648bbb98d004b..d59720bef742c7441ee01a954247013559bb909c 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -43,7 +43,7 @@ class ShapeOp : public XlaOpKernel { DataType out_dtype_; }; -REGISTER_XLA_OP(Name("Shape"), ShapeOp); +REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp); class ShapeNOp : public XlaOpKernel { public: @@ -65,7 +65,7 @@ class ShapeNOp : public XlaOpKernel { private: DataType out_dtype_; }; -REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp); +REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp); class RankOp : public XlaOpKernel { public: @@ -81,7 +81,7 @@ class RankOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Rank"), RankOp); +REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp); class SizeOp : public XlaOpKernel { public: @@ -100,7 +100,7 @@ class SizeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Size"), SizeOp); +REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp); class ExpandDimsOp : public XlaOpKernel { public: @@ -189,10 +189,9 @@ class SqueezeOp : public XlaOpKernel { if (!wrapped_squeeze_dims.empty()) { if (wrapped_squeeze_dims.count(i) > 0) { OP_REQUIRES(ctx, existing_dim == 1, - errors::InvalidArgument("Tried to explicitly squeeze " - "dimension ", - i, " but dimension was not 1: ", - existing_dim)); + errors::InvalidArgument( + "Tried to explicitly squeeze dimension ", i, + " but dimension was not 1: ", existing_dim)); } else { // This dimension is not being squeezed. new_shape.push_back(existing_dim); diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 750a4c2dec8154f97f307978b3d8884271292279..bbf5ee8b12186a582666121b1df5d8b7d881863e 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { namespace { @@ -28,7 +29,7 @@ namespace { class SoftmaxOp : public XlaOpKernel { public: explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - log_ = StringPiece(type_string()).starts_with("Log"); + log_ = str_util::StartsWith(type_string(), "Log"); } void Compile(XlaOpKernelContext* ctx) override { @@ -42,9 +43,8 @@ class SoftmaxOp : public XlaOpKernel { const DataType type = input_type(0); auto logits = ctx->Input(0); - xla::ComputationBuilder* b = ctx->builder(); - const xla::Computation& max_func = *ctx->GetOrCreateMax(type); - const xla::Computation& add_func = *ctx->GetOrCreateAdd(type); + xla::XlaBuilder* const b = ctx->builder(); + const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type); // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = @@ -52,21 +52,20 @@ class SoftmaxOp : public XlaOpKernel { // Subtract the max in batch b from every element in batch b. Broadcasts // along the batch dimension. auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim}); - xla::ComputationDataHandle softmax; - if (log_) { - // softmax = shifted_logits - log(sum(exp(shifted_logits))) - auto log_sum_exp = - b->Log(b->Reduce(b->Exp(shifted_logits), XlaHelpers::Zero(b, type), - add_func, {kClassDim})); - softmax = b->Sub(shifted_logits, log_sum_exp, {kBatchDim}); - } else { - // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) - auto exp_shifted = b->Exp(shifted_logits); - auto sum_exp = b->Reduce(exp_shifted, XlaHelpers::Zero(b, type), add_func, - {kClassDim}); - softmax = b->Div(exp_shifted, sum_exp, {kBatchDim}); - } - + auto exp_shifted = b->Exp(shifted_logits); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + auto converted = + XlaHelpers::ConvertElementType(b, exp_shifted, accumulation_type); + auto reduce = + b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto sum = XlaHelpers::ConvertElementType(b, reduce, type); + auto softmax = + log_ + // softmax = shifted_logits - log(sum(exp(shifted_logits))) + ? b->Sub(shifted_logits, b->Log(sum), {kBatchDim}) + // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) + : b->Div(exp_shifted, sum, {kBatchDim}); ctx->SetOutput(0, softmax); } @@ -77,17 +76,15 @@ class SoftmaxOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp); REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp); -std::pair -CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, - const xla::ComputationDataHandle& logits, - const xla::ComputationDataHandle& labels) { - const xla::Computation& max_func = *ctx->GetOrCreateMax(type); - const xla::Computation& add_func = *ctx->GetOrCreateAdd(type); +std::pair CrossEntropyWithLogits( + XlaOpKernelContext* ctx, DataType type, const xla::XlaOp& logits, + const xla::XlaOp& labels) { + const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type); const int kBatchDim = 0; const int kClassDim = 1; - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); @@ -100,8 +97,12 @@ CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, auto exp_shifted_logits = b->Exp(shifted_logits); // sum_{class} (exp(logits - max_logits)) - auto sum_exp = b->Reduce(exp_shifted_logits, XlaHelpers::Zero(b, type), - add_func, {kClassDim}); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + auto converted = + XlaHelpers::ConvertElementType(b, exp_shifted_logits, accumulation_type); + auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto sum_exp = XlaHelpers::ConvertElementType(b, reduce, type); // log(sum(exp(logits - max_logits))) auto log_sum_exp = b->Log(sum_exp); @@ -110,14 +111,18 @@ CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) // along classes // (The subtraction broadcasts along the batch dimension.) - xla::ComputationDataHandle loss = b->Reduce( - b->Mul(b->Neg(labels), b->Sub(shifted_logits, log_sum_exp, {kBatchDim})), - XlaHelpers::Zero(b, type), add_func, {kClassDim}); + auto sub = b->Sub(shifted_logits, log_sum_exp, {kBatchDim}); + auto mul = b->Mul(b->Neg(labels), sub); + auto sum = + b->Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto loss = XlaHelpers::ConvertElementType(b, sum, type); // backprop: prob - labels, where // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) // (where the division broadcasts along the batch dimension) - xla::ComputationDataHandle backprop = + xla::XlaOp backprop = b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels); return {loss, backprop}; } @@ -144,7 +149,7 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel { auto logits = ctx->Input(0); auto labels = ctx->Input(1); - xla::ComputationDataHandle loss, backprop; + xla::XlaOp loss, backprop; std::tie(loss, backprop) = CrossEntropyWithLogits(ctx, type, logits, labels); ctx->SetOutput(0, loss); @@ -185,10 +190,10 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { DataType logits_type = input_type(0); DataType indices_type = input_type(1); - xla::ComputationDataHandle indices = ctx->Input(1); + xla::XlaOp indices = ctx->Input(1); - xla::ComputationBuilder* builder = ctx->builder(); - xla::ComputationDataHandle labels; + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp labels; OP_REQUIRES_OK(ctx, XlaHelpers::OneHot( builder, depth, /*axis=*/1, input_type(1), labels_shape, @@ -201,7 +206,7 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { // Builds a vector of {batch_size} that is 0 if the index is in range, or // NaN otherwise; then add that vector to the labels to force out-of-range // values to NaNs. - xla::ComputationDataHandle nan_or_zero = builder->Select( + xla::XlaOp nan_or_zero = builder->Select( builder->And( builder->Le(XlaHelpers::Zero(builder, indices_type), indices), builder->Lt(indices, XlaHelpers::IntegerLiteral( @@ -212,7 +217,7 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { {batch_size})); labels = builder->Add(labels, nan_or_zero, {0}); - xla::ComputationDataHandle loss, backprop; + xla::XlaOp loss, backprop; std::tie(loss, backprop) = CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels); ctx->SetOutput(0, loss); diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 01b46e160d1f1f10a43faf7ca35afb42dfde6e33..ec077924b5b5af4a573c86c8d9aeb8623bd7f801 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -20,9 +20,8 @@ limitations under the License. namespace tensorflow { namespace { -void SpaceToBatch(XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, DataType input_dtype, - const TensorShape& input_tensor_shape, +void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, + DataType input_dtype, const TensorShape& input_tensor_shape, gtl::ArraySlice block_shape, const xla::Literal& paddings) { const int input_rank = input_tensor_shape.dims(); @@ -46,7 +45,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, ", 2] instead of ", xla::ShapeUtil::HumanString(paddings.shape()))); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); // 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the // input according to `paddings` to produce `padded` of shape `padded_shape`. @@ -73,7 +72,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, errors::InvalidArgument( "The product of the block dimensions must be positive")); - xla::ComputationDataHandle padded = + xla::XlaOp padded = b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config); // 2. Reshape `padded` to `reshaped_padded` of shape: @@ -101,8 +100,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, std::copy(remainder_shape.begin(), remainder_shape.end(), reshaped_padded_shape.begin() + 1 + 2 * block_rank); - xla::ComputationDataHandle reshaped_padded = - b->Reshape(padded, reshaped_padded_shape); + xla::XlaOp reshaped_padded = b->Reshape(padded, reshaped_padded_shape); // 3. Permute dimensions of `reshaped_padded` to produce // `permuted_reshaped_padded` of shape: @@ -121,7 +119,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, permutation[block_rank] = 0; std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), 1 + block_rank * 2); - xla::ComputationDataHandle permuted_reshaped_padded = + xla::XlaOp permuted_reshaped_padded = b->Transpose(reshaped_padded, permutation); // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the @@ -142,8 +140,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, std::copy(remainder_shape.begin(), remainder_shape.end(), output_shape.begin() + 1 + block_rank); - xla::ComputationDataHandle output = - b->Reshape(permuted_reshaped_padded, output_shape); + xla::XlaOp output = b->Reshape(permuted_reshaped_padded, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 806fda632cde64c1b37ae3b9199028d6b6b0a215..4c5886ee2a0f63d609f79fc690f457d93e284e3e 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -50,8 +50,8 @@ class SpaceToDepthOp : public XlaOpKernel { const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp input = ctx->Input(0); int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); @@ -135,7 +135,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[1] / block_size_, block_size_, // input_shape[2] / block_size_, block_size_, // depth] - xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -145,8 +145,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_, block_size_, // depth] - xla::ComputationDataHandle permuted_reshaped = - b->Transpose(reshaped, transpose_order); + xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -156,8 +155,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_ * block_size_ * depth] // - xla::ComputationDataHandle output = - b->Reshape(permuted_reshaped, output_shape); + xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 79c435c90a1f57250be90c2c2523bf3d7d231461..8958b2e7701e62d802e37a895c14b662ecf9786a 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -111,27 +111,24 @@ class SplitVOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const int32 num_split = num_outputs(); + const TensorShape input_shape = ctx->InputShape(0); const TensorShape index_shape = ctx->InputShape(2); - xla::Literal literal_index; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &literal_index)); - int32 split_dim; - OP_REQUIRES(ctx, index_shape.dims() == 0, - errors::InvalidArgument("split_dim input to Split Op must be a " - "scalar")); - split_dim = literal_index.Get({}); + int64 split_dim_orig; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &split_dim_orig)); + int64 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims() + : split_dim_orig; + OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(), + errors::InvalidArgument("-input rank(-", input_shape.dims(), + ") <= split_dim < input rank (", + input_shape.dims(), "), but got ", + split_dim_orig)); - xla::ComputationDataHandle input = ctx->Input(0); - const TensorShape input_shape = ctx->InputShape(0); + xla::XlaOp input = ctx->Input(0); OP_REQUIRES(ctx, input_shape.dims() > 0, errors::InvalidArgument("Can't split a 0 dimensional input")); - OP_REQUIRES( - ctx, 0 <= split_dim && split_dim < input_shape.dims(), - errors::InvalidArgument("0 <= split_dim < number of input dimensions (", - input_shape.dims(), "), but got ", split_dim)); - OP_REQUIRES( ctx, num_split > 0, errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 1a78c7ab9be701d3d02285ed21604f0f856b3f1f..0fb05a2be7b1034d6c2e864643b69647d622ede7 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -38,13 +38,13 @@ limitations under the License. namespace tensorflow { namespace { -Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource, +Status GetStackShape(xla::XlaBuilder* builder, XlaResource* resource, TensorShape* stack_shape) { auto shape_or_status = builder->GetShape(resource->value()); if (!shape_or_status.ok()) { return shape_or_status.status(); } - xla::Shape shape = *shape_or_status.ValueOrDie(); + xla::Shape shape = shape_or_status.ValueOrDie(); TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape)); return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), stack_shape); @@ -60,9 +60,8 @@ Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource, // // TODO(phawkins): consider changing the API of the stack operators to // allow an optional element shape at stack construction time. -Status MaybeInitializeStack(xla::ComputationBuilder* builder, - XlaResource* resource, DataType dtype, - const TensorShape& elem_shape) { +Status MaybeInitializeStack(xla::XlaBuilder* builder, XlaResource* resource, + DataType dtype, const TensorShape& elem_shape) { if (resource->type() != dtype) { return errors::InvalidArgument( "Stack dtype is ", DataTypeString(resource->type()), @@ -75,8 +74,6 @@ Status MaybeInitializeStack(xla::ComputationBuilder* builder, if (!resource->initialized()) { // Stack has not been initialized. - xla::ComputationDataHandle zero = - XlaHelpers::Zero(builder, resource->type()); TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); TF_RETURN_IF_ERROR(resource->SetZeroValue(builder)); } else { @@ -111,7 +108,7 @@ class StackOp : public XlaOpKernel { // We defer initializing the Stack resource until we see the first push. // Otherwise we do not know the shape of the stack elements. - xla::ComputationDataHandle value; + xla::XlaOp value; XlaContext& xc = XlaContext::Get(ctx); XlaResource* resource; string name = strings::StrCat("Stack: ", stack_name_); @@ -138,7 +135,7 @@ class StackPushOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); TensorShape elem_shape = ctx->InputShape(1); XlaResource* resource; @@ -147,9 +144,9 @@ class StackPushOp : public XlaOpKernel { // Initializes the Stack, if the element shape was not already known. OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = b->GetTupleElement(resource->value(), 0); - xla::ComputationDataHandle index = b->GetTupleElement(resource->value(), 1); - xla::ComputationDataHandle value = ctx->Input(1); + xla::XlaOp ta = b->GetTupleElement(resource->value(), 0); + xla::XlaOp index = b->GetTupleElement(resource->value(), 1); + xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto start_indices = @@ -184,7 +181,7 @@ class StackPopOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); @@ -199,9 +196,9 @@ class StackPopOp : public XlaOpKernel { TensorShape stack_shape; OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape)); - xla::ComputationDataHandle state = resource->value(); - xla::ComputationDataHandle ta = b->GetTupleElement(state, 0); - xla::ComputationDataHandle index = b->GetTupleElement(state, 1); + xla::XlaOp state = resource->value(); + xla::XlaOp ta = b->GetTupleElement(state, 0); + xla::XlaOp index = b->GetTupleElement(state, 1); index = b->Sub(index, b->ConstantR0(1)); OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index}))); @@ -216,8 +213,7 @@ class StackPopOp : public XlaOpKernel { // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. - xla::ComputationDataHandle read = - b->DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index b10880de77e6b9811008076cd4a959c284e558d1..a99d4ddc7c4956f7144512a9bdf6f4c2eb0f944f 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -30,9 +30,8 @@ namespace tensorflow { namespace { // Rotates a 32-bit integer 'v' left by 'distance' bits. -xla::ComputationDataHandle RotateLeftS32(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& v, - int distance) { +xla::XlaOp RotateLeftS32(xla::XlaBuilder* builder, const xla::XlaOp& v, + int distance) { return builder->Or( builder->ShiftLeft(v, builder->ConstantR0(distance)), builder->ShiftRightLogical(v, builder->ConstantR0(32 - distance))); @@ -40,25 +39,24 @@ xla::ComputationDataHandle RotateLeftS32(xla::ComputationBuilder* builder, // TODO(b/65209188): add a primitive XOR to XLA and call it here, rather than // building XOR out of other bitwise operators. -xla::ComputationDataHandle BitwiseXor(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& y) { +xla::XlaOp BitwiseXor(xla::XlaBuilder* builder, const xla::XlaOp& x, + const xla::XlaOp& y) { return builder->Or(builder->And(x, builder->Not(y)), builder->And(builder->Not(x), y)); } -using ThreeFry2x32State = std::array; +using ThreeFry2x32State = std::array; // Implements the ThreeFry counter-based PRNG algorithm. // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -ThreeFry2x32State ThreeFry2x32(xla::ComputationBuilder* builder, +ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder, ThreeFry2x32State input, ThreeFry2x32State key) { // Rotation distances specified by the Threefry2x32 algorithm. constexpr std::array rotations = {13, 15, 26, 6, 17, 29, 16, 24}; ThreeFry2x32State x; - std::array ks; + std::array ks; // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. ks[2] = builder->ConstantR0(0x1BD11BDA); for (int i = 0; i < 2; ++i) { @@ -121,10 +119,9 @@ ThreeFry2x32State ThreeFry2x32(xla::ComputationBuilder* builder, // Returns a tensor of 'shape' random values uniformly distributed in the range // [minval, maxval) -xla::ComputationDataHandle RandomUniform(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& seed, - const TensorShape& shape, - double minval, double maxval) { +xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed, + const TensorShape& shape, double minval, + double maxval) { // Split the seed into two 32-bit scalars to form a key. auto seed0 = builder->Reshape(builder->Slice(seed, {0}, {1}, {1}), {}); auto seed1 = builder->Reshape(builder->Slice(seed, {1}, {2}, {1}), {}); @@ -178,9 +175,8 @@ xla::ComputationDataHandle RandomUniform(xla::ComputationBuilder* builder, // p = sum_{i=1}^n gq[i]*w^i // } // return p*x -xla::ComputationDataHandle ErfInvF32(xla::ComputationBuilder* b, - const xla::ComputationDataHandle& x, - const TensorShape& shape) { +xla::XlaOp ErfInvF32(xla::XlaBuilder* b, const xla::XlaOp& x, + const TensorShape& shape) { constexpr int kDegree = 9; constexpr std::array w_less_than_5_constants = { 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, @@ -220,7 +216,7 @@ class StatelessRandomUniformOp : public XlaOpKernel { : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); TensorShape shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); @@ -229,7 +225,7 @@ class StatelessRandomUniformOp : public XlaOpKernel { OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2, errors::InvalidArgument("seed must have shape [2], not ", seed_shape.DebugString())); - xla::ComputationDataHandle seed = ctx->Input(1); + xla::XlaOp seed = ctx->Input(1); ctx->SetOutput(0, RandomUniform(builder, seed, shape, 0.0, 1.0)); } @@ -239,6 +235,7 @@ class StatelessRandomUniformOp : public XlaOpKernel { // TODO(phawkins): generalize to non-float, non-int32 seed types. REGISTER_XLA_OP(Name("StatelessRandomUniform") + .CompileTimeConstInput("shape") .TypeConstraint("dtype", DT_FLOAT) .TypeConstraint("Tseed", DT_INT32), StatelessRandomUniformOp); @@ -256,9 +253,10 @@ class StatelessRandomNormalOp : public XlaOpKernel { OP_REQUIRES(ctx, seed_shape == TensorShape({2}), errors::InvalidArgument("seed must have shape [2], not ", seed_shape.DebugString())); - xla::ComputationDataHandle seed = ctx->Input(1); - xla::ComputationBuilder* builder = ctx->builder(); - auto uniform = RandomUniform(builder, seed, shape, -1.0, 1.0); + xla::XlaOp seed = ctx->Input(1); + xla::XlaBuilder* builder = ctx->builder(); + auto uniform = + RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0); // Convert uniform distribution to normal distribution by computing // sqrt(2) * erfinv(x) auto normal = builder->Mul(builder->ConstantR0(std::sqrt(2.0)), @@ -272,6 +270,7 @@ class StatelessRandomNormalOp : public XlaOpKernel { // TODO(phawkins): generalize to non-float, non-int32 seed types. REGISTER_XLA_OP(Name("StatelessRandomNormal") + .CompileTimeConstInput("shape") .TypeConstraint("dtype", DT_FLOAT) .TypeConstraint("Tseed", DT_INT32), StatelessRandomNormalOp); diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 91c169428c7a88a8d107a97445aeea999946e3e9..55254c746e5ebaf6b468c24ab59b968bf0d6260b 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -77,19 +77,20 @@ class StridedSliceOp : public XlaOpKernel { for (int i = 0; i < begin.size(); ++i) { if (strides[i] > 0) { slice_begin.push_back(begin[i]); - slice_end.push_back(end[i]); + slice_end.push_back(std::max(end[i], begin[i])); slice_strides.push_back(strides[i]); } else { // Negative stride: swap begin and end, add 1 because the interval // is semi-open, and mark the dimension to be reversed. slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1); - slice_end.push_back(input_shape.dim_size(i) - end[i] - 1); + slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1, + input_shape.dim_size(i) - begin[i] - 1)); slice_strides.push_back(-strides[i]); dimensions_to_reverse.push_back(i); } } - xla::ComputationDataHandle slice = ctx->Input(0); + xla::XlaOp slice = ctx->Input(0); if (!dimensions_to_reverse.empty()) { slice = ctx->builder()->Rev(slice, dimensions_to_reverse); } @@ -167,7 +168,7 @@ class StridedSliceGradOp : public XlaOpKernel { auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0)); - xla::ComputationDataHandle grad = ctx->Input(4); + xla::XlaOp grad = ctx->Input(4); // Undo any new/shrink axes. grad = ctx->builder()->Reshape(grad, processing_shape.dim_sizes()); @@ -254,7 +255,7 @@ class StridedSliceAssignOp : public XlaOpKernel { &strides_tensor)); TensorShape lhs_shape; - xla::ComputationDataHandle lhs; + xla::XlaOp lhs; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs)); const TensorShape rhs_shape = ctx->InputShape(4); @@ -283,7 +284,7 @@ class StridedSliceAssignOp : public XlaOpKernel { " does not match r-value shape ", rhs_shape.DebugString(), ". Automatic broadcasting not yet implemented.")); - xla::ComputationDataHandle rhs = ctx->Input(4); + xla::XlaOp rhs = ctx->Input(4); gtl::InlinedVector dimensions_to_reverse; gtl::InlinedVector slice_begin, slice_dims; diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 000b50af6bd86b7268c016865fb0856c16053ece..9adee78a1fd1fb9a12afae83197425c328b5fe7e 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -47,7 +47,7 @@ namespace { // the TensorArray with elements of `elem_shape`. For both initialized and // uninitialized TensorArrays, checks that the tensor has a type compatible with // 'dtype' and shape compatible with 'elem_shape'. -Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, +Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, XlaResource* resource, DataType dtype, const TensorShape& elem_shape) { if (resource->kind() != XlaResource::kTensorArray) { @@ -64,9 +64,6 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, << resource->name() << " size " << resource->tensor_array_size(); if (!resource->initialized()) { - xla::ComputationDataHandle zero = - XlaHelpers::Zero(builder, resource->type()); - TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); TF_RETURN_IF_ERROR(resource->SetZeroValue(builder)); } else { @@ -77,7 +74,7 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, } TensorShape shape; TF_RETURN_IF_ERROR( - XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); + XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape)); TensorShape ta_shape; ta_shape.AddDim(resource->tensor_array_size()); @@ -114,23 +111,21 @@ Status CheckTensorArrayIsInitialized(const string& op_name, } Status GetTensorArrayShape(const XlaResource* resource, - xla::ComputationBuilder* builder, - TensorShape* shape) { + xla::XlaBuilder* builder, TensorShape* shape) { *shape = resource->shape(); shape->InsertDim(0, resource->tensor_array_size()); return Status::OK(); } -// Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the +// Like XlaBuilder::DynamicUpdateSlice, but adds 'update' to the // relevant slice of 'operand'. -xla::ComputationDataHandle DynamicAddSlice( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& operand, - const xla::ComputationDataHandle& update, - const gtl::ArraySlice& update_dims, - const xla::ComputationDataHandle& start_indices) { - xla::ComputationDataHandle current = +xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, + const xla::XlaOp& update, + const gtl::ArraySlice& update_dims, + const xla::XlaOp& start_indices) { + xla::XlaOp current = builder->DynamicSlice(operand, start_indices, update_dims); - xla::ComputationDataHandle sum = builder->Add(current, update); + xla::XlaOp sum = builder->Add(current, update); return builder->DynamicUpdateSlice(operand, sum, start_indices); } @@ -155,18 +150,18 @@ class TensorArrayOp : public XlaOpKernel { OP_REQUIRES(ctx, size >= 0, errors::InvalidArgument("TensorArray size must be >= 0")); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); // Initializes the TensorArray value if we know the element shape. // Otherwise, defer initialization to the first write. - xla::ComputationDataHandle value; + xla::XlaOp value; TensorShape shape; if (element_shape_.IsFullyDefined()) { CHECK(element_shape_.AsTensorShape(&shape)); TensorShape ta_shape; ta_shape.AddDim(size); ta_shape.AppendShape(shape); - xla::ComputationDataHandle zero = XlaHelpers::Zero(b, dtype_); + xla::XlaOp zero = XlaHelpers::Zero(b, dtype_); value = b->Broadcast(zero, ta_shape.dim_sizes()); } @@ -202,7 +197,7 @@ class TensorArrayWriteOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); TensorShape elem_shape = ctx->InputShape(2); @@ -213,10 +208,10 @@ class TensorArrayWriteOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = resource->value(); - xla::ComputationDataHandle index = ctx->Input(1); - xla::ComputationDataHandle value = ctx->Input(2); - xla::ComputationDataHandle flow = ctx->Input(3); + xla::XlaOp ta = resource->value(); + xla::XlaOp index = ctx->Input(1); + xla::XlaOp value = ctx->Input(2); + xla::XlaOp flow = ctx->Input(3); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto start_indices = @@ -227,7 +222,7 @@ class TensorArrayWriteOp : public XlaOpKernel { slice_shape.InsertDim(0, 1LL); auto update = b->Reshape(value, slice_shape.dim_sizes()); - xla::ComputationDataHandle written = + xla::XlaOp written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); OP_REQUIRES_OK(ctx, resource->SetValue(written)); @@ -249,7 +244,7 @@ class TensorArrayReadOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); @@ -259,8 +254,8 @@ class TensorArrayReadOp : public XlaOpKernel { TensorShape ta_shape; OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); - xla::ComputationDataHandle ta = resource->value(); - xla::ComputationDataHandle index = ctx->Input(1); + xla::XlaOp ta = resource->value(); + xla::XlaOp index = ctx->Input(1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. auto start_indices = @@ -270,8 +265,7 @@ class TensorArrayReadOp : public XlaOpKernel { auto slice_shape = ta_shape.dim_sizes(); slice_shape[0] = 1LL; - xla::ComputationDataHandle read = - b->DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); @@ -293,7 +287,7 @@ class TensorArrayGatherOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); @@ -309,7 +303,7 @@ class TensorArrayGatherOp : public XlaOpKernel { auto indices = ctx->Input(1); DataType index_type = ctx->input_type(1); - xla::ComputationDataHandle ta = resource->value(); + xla::XlaOp ta = resource->value(); // Look for the case where the gather takes a simple slice from the // tensor array (0, 1, 2, 3, 4, ..., N) @@ -337,7 +331,7 @@ class TensorArrayGatherOp : public XlaOpKernel { } } - xla::ComputationDataHandle gather; + xla::XlaOp gather; OP_REQUIRES_OK( ctx, XlaGather(ta, ta_shape, indices, indices_shape, /*axis=*/0, @@ -360,7 +354,7 @@ class TensorArrayScatterOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const TensorShape value_shape = ctx->InputShape(2); @@ -375,11 +369,11 @@ class TensorArrayScatterOp : public XlaOpKernel { OP_REQUIRES(ctx, indices_shape.dims() >= 1, errors::InvalidArgument("indices must be rank 1")); const int num_indices = indices_shape.dim_size(0); - const xla::ComputationDataHandle indices = ctx->Input(1); + const xla::XlaOp indices = ctx->Input(1); - xla::ComputationDataHandle ta = resource->value(); - const xla::ComputationDataHandle value = ctx->Input(2); - const xla::ComputationDataHandle flow = ctx->Input(3); + xla::XlaOp ta = resource->value(); + const xla::XlaOp value = ctx->Input(2); + const xla::XlaOp flow = ctx->Input(3); // Look for the case where the scatter is for each sub-tensor in order. The // tensor array implementation allows for this to be a straight addition. @@ -443,7 +437,7 @@ class TensorArrayConcatOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); @@ -453,7 +447,7 @@ class TensorArrayConcatOp : public XlaOpKernel { TensorShape ta_shape; OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); - xla::ComputationDataHandle ta = resource->value(); + xla::XlaOp ta = resource->value(); auto ta_dims = ta_shape.dim_sizes(); std::vector shape(ta_dims.begin() + 1, ta_dims.end()); @@ -503,12 +497,12 @@ class TensorArraySplitOp : public XlaOpKernel { TensorShape elem_shape = value_shape; elem_shape.set_dim(0, length); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = resource->value(); + xla::XlaOp ta = resource->value(); TensorShape ta_shape; ta_shape.AddDim(resource->tensor_array_size()); @@ -520,8 +514,8 @@ class TensorArraySplitOp : public XlaOpKernel { "TensorArray's size is not equal to the size of lengths (", lengths.size(), " vs. ", resource->tensor_array_size(), ")")); - const xla::ComputationDataHandle value = ctx->Input(1); - const xla::ComputationDataHandle flow = ctx->Input(3); + const xla::XlaOp value = ctx->Input(1); + const xla::XlaOp flow = ctx->Input(3); OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(), errors::InvalidArgument("mismatched element count ", @@ -569,7 +563,7 @@ class TensorArrayGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 9aefcd4fc7f94a1dba1c56273c55d0b98fbbfaf2..e91075196bd8414939888e22b5483ad637487af6 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -112,7 +112,7 @@ class TileOp : public XlaOpKernel { flattened.push_back(i); flattened.push_back(i + output_shape.size()); } - xla::ComputationDataHandle output = + xla::XlaOp output = ctx->builder()->Reshape(broadcasted, flattened, output_shape); ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index f750f7003be288461f5f10455e58932d1b4e4524..34caefa050c0d58f5f7bad557286b6ed64b996ad 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" @@ -30,8 +30,8 @@ class ResourceApplyGradientDescent : public XlaOpKernel { explicit ResourceApplyGradientDescent(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle handle; - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaOp handle; + xla::XlaBuilder* b = ctx->builder(); DataType type = ctx->input_type(1); TensorShape var_shape; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle)); @@ -63,12 +63,12 @@ class ResourceApplyMomentum : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); DataType type = ctx->input_type(2); TensorShape var_shape, accum_shape; - xla::ComputationDataHandle var, accum; + xla::XlaOp var, accum; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); @@ -93,9 +93,9 @@ class ResourceApplyMomentum : public XlaOpKernel { errors::InvalidArgument("momentum is not a scalar: ", momentum_shape.DebugString())); - xla::ComputationDataHandle lr = ctx->Input(2); - xla::ComputationDataHandle grad = ctx->Input(3); - xla::ComputationDataHandle momentum = ctx->Input(4); + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp grad = ctx->Input(3); + xla::XlaOp momentum = ctx->Input(4); accum = b->Add(b->Mul(accum, momentum), grad); if (use_nesterov_) { @@ -121,12 +121,12 @@ class ResourceApplyAdagrad : public XlaOpKernel { explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); DataType type = ctx->input_type(2); TensorShape var_shape, accum_shape; - xla::ComputationDataHandle var, accum; + xla::XlaOp var, accum; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); @@ -146,8 +146,8 @@ class ResourceApplyAdagrad : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle lr = ctx->Input(2); - xla::ComputationDataHandle grad = ctx->Input(3); + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp grad = ctx->Input(3); accum = b->Add(accum, b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0))); var = b->Sub( @@ -168,7 +168,7 @@ class ResourceApplyAdam : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape var_shape, m_shape, v_shape; - xla::ComputationDataHandle var, m, v; + xla::XlaOp var, m, v; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v)); @@ -213,25 +213,25 @@ class ResourceApplyAdam : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle beta1_power = ctx->Input(3); - xla::ComputationDataHandle beta2_power = ctx->Input(4); - xla::ComputationDataHandle lr = ctx->Input(5); - xla::ComputationDataHandle beta1 = ctx->Input(6); - xla::ComputationDataHandle beta2 = ctx->Input(7); - xla::ComputationDataHandle epsilon = ctx->Input(8); - xla::ComputationDataHandle grad = ctx->Input(9); + xla::XlaOp beta1_power = ctx->Input(3); + xla::XlaOp beta2_power = ctx->Input(4); + xla::XlaOp lr = ctx->Input(5); + xla::XlaOp beta1 = ctx->Input(6); + xla::XlaOp beta2 = ctx->Input(7); + xla::XlaOp epsilon = ctx->Input(8); + xla::XlaOp grad = ctx->Input(9); // alpha <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) // m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t // v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t // variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon) - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); - xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); - xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); + xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); + xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); - xla::ComputationDataHandle alpha = + xla::XlaOp alpha = b->Div(b->Mul(lr, b->Pow(b->Sub(one, beta2_power), half)), b->Sub(one, beta1_power)); m = b->Add(m, b->Mul(b->Sub(grad, m), b->Sub(one, beta1))); @@ -255,12 +255,12 @@ class ResourceApplyRMSProp : public XlaOpKernel { explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); DataType type = ctx->input_type(3); TensorShape var_shape, ms_shape, mom_shape; - xla::ComputationDataHandle var, ms, mom; + xla::XlaOp var, ms, mom; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom)); @@ -297,11 +297,11 @@ class ResourceApplyRMSProp : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle lr = ctx->Input(3); - xla::ComputationDataHandle rho = ctx->Input(4); - xla::ComputationDataHandle momentum = ctx->Input(5); - xla::ComputationDataHandle epsilon = ctx->Input(6); - xla::ComputationDataHandle grad = ctx->Input(7); + xla::XlaOp lr = ctx->Input(3); + xla::XlaOp rho = ctx->Input(4); + xla::XlaOp momentum = ctx->Input(5); + xla::XlaOp epsilon = ctx->Input(6); + xla::XlaOp grad = ctx->Input(7); // ms <- rho * ms_{t-1} + (1-rho) * grad * grad // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) @@ -320,16 +320,16 @@ class ResourceApplyRMSProp : public XlaOpKernel { // ms <- grad**2 (1 - rho) + ms * rho // // Which is the equation listed above. - xla::ComputationDataHandle new_ms = b->Add( + xla::XlaOp new_ms = b->Add( ms, b->Mul(b->Sub(b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)), ms), b->Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho))); - xla::ComputationDataHandle new_mom = + xla::XlaOp new_mom = b->Add(b->Mul(mom, momentum), b->Mul(b->Mul(grad, lr), b->Pow(b->Add(new_ms, epsilon), XlaHelpers::FloatLiteral(b, type, -0.5)))); - xla::ComputationDataHandle new_var = b->Sub(var, new_mom); + xla::XlaOp new_var = b->Sub(var, new_mom); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms)); @@ -341,10 +341,10 @@ REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes), void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, bool has_l2_shrinkage) { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); TensorShape var_shape, accum_shape, linear_shape; - xla::ComputationDataHandle var, accum, linear; + xla::XlaOp var, accum, linear; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear)); @@ -399,12 +399,12 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, errors::InvalidArgument("lr_power is not a scalar: ", lr_power_shape.DebugString())); - xla::ComputationDataHandle grad = ctx->Input(3); - xla::ComputationDataHandle lr = ctx->Input(4); - xla::ComputationDataHandle l1 = ctx->Input(5); - xla::ComputationDataHandle l2 = ctx->Input(6); - xla::ComputationDataHandle l2_shrinkage; - xla::ComputationDataHandle lr_power; + xla::XlaOp grad = ctx->Input(3); + xla::XlaOp lr = ctx->Input(4); + xla::XlaOp l1 = ctx->Input(5); + xla::XlaOp l2 = ctx->Input(6); + xla::XlaOp l2_shrinkage; + xla::XlaOp lr_power; if (has_l2_shrinkage) { l2_shrinkage = ctx->Input(7); lr_power = ctx->Input(8); @@ -421,26 +421,23 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, // var = (linear_clipped - linear) / quadratic // accum = new_accum - xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype, 2.0); - xla::ComputationDataHandle grad_to_use; + xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0); + xla::XlaOp grad_to_use; if (has_l2_shrinkage) { grad_to_use = b->Add(grad, b->Mul(two, b->Mul(l2_shrinkage, var))); } else { grad_to_use = grad; } - xla::ComputationDataHandle new_accum = - b->Add(accum, b->Pow(grad_to_use, two)); - xla::ComputationDataHandle new_accum_lr_pow = - b->Pow(new_accum, b->Neg(lr_power)); - xla::ComputationDataHandle accum_lr_pow = b->Pow(accum, b->Neg(lr_power)); + xla::XlaOp new_accum = b->Add(accum, b->Pow(grad_to_use, two)); + xla::XlaOp new_accum_lr_pow = b->Pow(new_accum, b->Neg(lr_power)); + xla::XlaOp accum_lr_pow = b->Pow(accum, b->Neg(lr_power)); linear = b->Add( linear, b->Sub(grad_to_use, b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr), var))); - xla::ComputationDataHandle linear_clipped = b->Clamp(b->Neg(l1), linear, l1); - xla::ComputationDataHandle quadratic = - b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2)); + xla::XlaOp linear_clipped = b->Clamp(b->Neg(l1), linear, l1); + xla::XlaOp quadratic = b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2)); var = b->Div(b->Sub(linear_clipped, linear), quadratic); accum = new_accum; diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 0c5ad9e5255ffc3dfcfb83335060ae833937b3ce..71a9fd051bfc8db09738a4bfe8ddde447895ecf0 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -33,9 +33,9 @@ namespace { public: \ explicit NAME##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \ void Compile(XlaOpKernelContext* ctx) { \ - xla::ComputationBuilder* b = ctx->builder(); \ - xla::ComputationDataHandle x = ctx->Input(0); \ - xla::ComputationDataHandle y = COMPUTATION; \ + xla::XlaBuilder* b = ctx->builder(); \ + xla::XlaOp x = ctx->Input(0); \ + xla::XlaOp y = COMPUTATION; \ ctx->SetOutput(0, y); \ } \ }; \ @@ -60,11 +60,13 @@ XLAJIT_MAKE_UNARY( b->Add(XlaHelpers::One(b, input_type(0)), x)))); // acosh(x) = log(x + sqrt(x^2 - 1)) +// = log(x + sqrt((x+1)*(x-1))) XLAJIT_MAKE_UNARY( Acosh, - b->Log(b->Add(x, b->Pow(b->Sub(b->Mul(x, x), - XlaHelpers::One(b, input_type(0))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); + b->Log(b->Add(x, + b->Pow(b->Mul(b->Add(x, XlaHelpers::One(b, input_type(0))), + b->Sub(x, XlaHelpers::One(b, input_type(0)))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) XLAJIT_MAKE_UNARY( @@ -98,8 +100,7 @@ XLAJIT_MAKE_UNARY(Cosh, XLAJIT_MAKE_UNARY(Sin, b->Sin(x)); XLAJIT_MAKE_UNARY(Exp, b->Exp(x)); -// TODO(b/34703906): use a more accurate implementation of expm1. -XLAJIT_MAKE_UNARY(Expm1, b->Sub(b->Exp(x), XlaHelpers::One(b, input_type(0)))); +XLAJIT_MAKE_UNARY(Expm1, b->Expm1(x)); XLAJIT_MAKE_UNARY(Floor, b->Floor(x)); XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x)); @@ -113,8 +114,7 @@ XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x)); XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x)); XLAJIT_MAKE_UNARY(Log, b->Log(x)); -// TODO(b/34703906): use a more accurate implementation of log1p. -XLAJIT_MAKE_UNARY(Log1p, b->Log(b->Add(XlaHelpers::One(b, input_type(0)), x))); +XLAJIT_MAKE_UNARY(Log1p, b->Log1p(x)); XLAJIT_MAKE_UNARY(Invert, b->Not(x)); XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x)); @@ -122,9 +122,8 @@ XLAJIT_MAKE_UNARY(Neg, b->Neg(x)); // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. -static xla::ComputationDataHandle Round(xla::ComputationBuilder* b, - DataType dtype, - const xla::ComputationDataHandle& x) { +static xla::XlaOp Round(xla::XlaBuilder* b, DataType dtype, + const xla::XlaOp& x) { auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); @@ -146,9 +145,8 @@ XLAJIT_MAKE_UNARY(Rsqrt, b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5))); // Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2. -static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b, - DataType dtype, - const xla::ComputationDataHandle& x) { +static xla::XlaOp Sigmoid(xla::XlaBuilder* b, DataType dtype, + const xla::XlaOp& x) { auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x)))); } @@ -160,26 +158,17 @@ XLAJIT_MAKE_UNARY(Sinh, b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))), XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -static xla::ComputationDataHandle Softplus( - xla::ComputationBuilder* b, DataType dtype, - const xla::ComputationDataHandle& features) { - xla::ComputationDataHandle threshold = - b->Add(b->Log(XlaHelpers::Epsilon(b, dtype)), - XlaHelpers::FloatLiteral(b, dtype, 2.0)); - // Value above which exp(x) may overflow, but softplus(x) == x - // is within machine epsilon. - xla::ComputationDataHandle too_large = b->Gt(features, b->Neg(threshold)); - // Value below which exp(x) may underflow, but softplus(x) == exp(x) - // is within machine epsilon. - xla::ComputationDataHandle too_small = b->Lt(features, threshold); - xla::ComputationDataHandle features_exp = b->Exp(features); - xla::ComputationDataHandle output = b->Select( - too_large, features, - b->Select(too_small, features_exp, - b->Log(b->Add(features_exp, XlaHelpers::One(b, dtype))))); - return output; -} -XLAJIT_MAKE_UNARY(Softplus, Softplus(b, input_type(0), x)); +// softplus(x) = log(1 + exp(x)) +// +// This is not numerically stable when x is large, it can easily overflow. +// However, we can compute it as LogSumExp(x, 0): +// max(x, 0) + log(exp(x - max(x, 0)) + exp(0 - max(x, 0))) +// +// This is equivalent to: +// max(x, 0) + log1p(exp(-abs(x))) +XLAJIT_MAKE_UNARY(Softplus, + b->Add(b->Max(x, XlaHelpers::Zero(b, input_type(0))), + b->Log1p(b->Exp(b->Neg(b->Abs(x)))))); // softsign(x) = x / (abs(x) + 1) XLAJIT_MAKE_UNARY(Softsign, diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 71173f5aead47702f0ed9e95b827a6fefd9b7efd..a163fa0a5b34675e46d0d7c5f4e0ccb1e3fb18eb 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" @@ -48,7 +48,7 @@ class ReadVariableOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle handle; + xla::XlaOp handle; OP_REQUIRES_OK( ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle)); ctx->SetOutput(0, handle); @@ -57,7 +57,7 @@ class ReadVariableOp : public XlaOpKernel { private: DataType dtype_; }; -REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp); +REGISTER_XLA_OP(Name("ReadVariableOp").CompilationOnly(), ReadVariableOp); class AssignVariableOp : public XlaOpKernel { public: @@ -67,14 +67,14 @@ class AssignVariableOp : public XlaOpKernel { ctx->AssignVariable(0, ctx->input_type(1), ctx->Input(1))); } }; -REGISTER_XLA_OP(Name("AssignVariableOp"), AssignVariableOp); +REGISTER_XLA_OP(Name("AssignVariableOp").CompilationOnly(), AssignVariableOp); class AssignAddVariableOp : public XlaOpKernel { public: explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { DataType type = ctx->input_type(1); - xla::ComputationDataHandle handle; + xla::XlaOp handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); handle = ctx->builder()->Add(handle, ctx->Input(1)); @@ -90,7 +90,7 @@ class AssignSubVariableOp : public XlaOpKernel { explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { DataType type = ctx->input_type(1); - xla::ComputationDataHandle handle; + xla::XlaOp handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); handle = ctx->builder()->Sub(handle, ctx->Input(1)); @@ -105,19 +105,19 @@ class ResourceGatherOp : public XlaOpKernel { public: explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); DataType type = ctx->expected_output_dtype(0); TensorShape resource_shape; - xla::ComputationDataHandle resource_handle; + xla::XlaOp resource_handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape, &resource_handle)); auto indices = ctx->Input(1); auto indices_shape = ctx->InputShape(1); DataType index_type = ctx->input_type(1); - xla::ComputationDataHandle gather; + xla::XlaOp gather; OP_REQUIRES_OK( ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape, /*axis=*/0, /*indices_are_nd=*/false, type, index_type, diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 0ff1b65ae9179d506e453f98097cd88083eb2be7..5467c5d9946846ff9f14ce9c5aac9e2be4b9d6ab 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" @@ -101,7 +101,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { ctx, MakeXlaCompilerArgumentsFromInputs( ctx, &arguments, &has_uninitialized_vars, &has_tensor_arrays)); - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); XlaCompiler* compiler = ctx->compiler(); VLOG(1) << "Compiling body"; @@ -234,7 +234,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::ShapeUtil::HumanString(cond.xla_output_shape))); int num_inputs = body.input_mapping.size(); - std::vector inputs(num_inputs); + std::vector inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { int input_num = body.input_mapping[i]; if (ctx->input_type(input_num) == DT_RESOURCE) { @@ -246,24 +246,24 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { } } - xla::ComputationDataHandle init = builder->Tuple(inputs); + xla::XlaOp init = builder->Tuple(inputs); VLOG(1) << "Building while loop"; // Wraps the condition in a computation that unpacks the output tuple. - xla::Computation cond_wrapper; + xla::XlaComputation cond_wrapper; { - std::unique_ptr cb = + std::unique_ptr cb = builder->CreateSubBuilder("cond_wrapper"); auto inputs = cb->Parameter(0, cond_input_shape, "inputs"); auto outputs = cb->Call(*cond.computation, {inputs}); cb->GetTupleElement(outputs, 0); - xla::StatusOr result = cb->Build(); + xla::StatusOr result = cb->Build(); OP_REQUIRES_OK(ctx, result.status()); cond_wrapper = std::move(result.ValueOrDie()); } - xla::ComputationDataHandle while_result = + xla::XlaOp while_result = builder->While(cond_wrapper, *body.computation, init); // Sets non-variable outputs. diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 488fda74bf7b5c1d66f8d706a1be3cc1fc29a492..ee7f5d510ab7a3ce7d3bbe843c5fefd362f79b7b 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -25,8 +25,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -39,12 +39,13 @@ cc_library( ":batch_dot", ":triangular_solve", ":util", + ":while_loop", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -61,9 +62,9 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -79,10 +80,9 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -90,6 +90,7 @@ cc_library( xla_test( name = "triangular_solve_test", srcs = ["triangular_solve_test.cc"], + tags = ["noasan"], # sometimes times out, http://b/78650012 deps = [ ":triangular_solve", "//tensorflow/compiler/xla:array2d", @@ -99,9 +100,9 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -120,12 +121,35 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) +xla_test( + name = "util_test", + srcs = ["util_test.cc"], + deps = [ + ":batch_dot", + ":util", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "while_loop", srcs = ["while_loop.cc"], @@ -135,22 +159,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index 798f0fa78055e800038e8bf41b4f410b670be7dd..526694d5a0c7124e1696f34b516f3b202462bc19 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -25,24 +25,22 @@ limitations under the License. namespace tensorflow { -xla::StatusOr BatchDot( - xla::ComputationBuilder* builder, xla::ComputationDataHandle x, - xla::ComputationDataHandle y, bool transpose_x, bool transpose_y, - bool conjugate_x, bool conjugate_y) { - TF_ASSIGN_OR_RETURN(std::unique_ptr x_shape, - builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(std::unique_ptr y_shape, - builder->GetShape(y)); +xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, + xla::XlaOp y, bool transpose_x, + bool transpose_y, bool conjugate_x, + bool conjugate_y) { + TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); // Check that both tensors have the same number of dimensions. There must be // at least two (the batch dimensions can be empty). - if (xla::ShapeUtil::Rank(*x_shape) != xla::ShapeUtil::Rank(*y_shape)) { + if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { return errors::InvalidArgument( "Arguments to BatchedDot have different ranks: ", - xla::ShapeUtil::HumanString(*x_shape), " vs. ", - xla::ShapeUtil::HumanString(*y_shape)); + xla::ShapeUtil::HumanString(x_shape), " vs. ", + xla::ShapeUtil::HumanString(y_shape)); } - const int ndims = xla::ShapeUtil::Rank(*x_shape); + const int ndims = xla::ShapeUtil::Rank(x_shape); if (ndims < 2) { return errors::InvalidArgument( "Arguments to BatchedDot must have rank >= 2: ", ndims); @@ -52,46 +50,46 @@ xla::StatusOr BatchDot( // valid. std::vector batch_dimension_numbers; for (int i = 0; i < ndims - 2; ++i) { - if (x_shape->dimensions(i) != y_shape->dimensions(i)) { + if (x_shape.dimensions(i) != y_shape.dimensions(i)) { return errors::InvalidArgument( "Dimension ", i, " of inputs to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(*x_shape), " vs ", - xla::ShapeUtil::HumanString(*y_shape)); + xla::ShapeUtil::HumanString(x_shape), " vs ", + xla::ShapeUtil::HumanString(y_shape)); } batch_dimension_numbers.push_back(i); } int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); - if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) { + if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { return errors::InvalidArgument( "Dimensions ", x_inner_dim, " and ", y_inner_dim, " of arguments to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(*x_shape), " transpose: ", transpose_x, - " vs. ", xla::ShapeUtil::HumanString(*y_shape), + xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, + " vs. ", xla::ShapeUtil::HumanString(y_shape), " transpose: ", transpose_y); } // Check for zero lhs/rhs dim size. - if (xla::ShapeUtil::HasZeroElements(*x_shape) || - xla::ShapeUtil::HasZeroElements(*y_shape)) { + if (xla::ShapeUtil::HasZeroElements(x_shape) || + xla::ShapeUtil::HasZeroElements(y_shape)) { std::vector dimensions(batch_dimension_numbers.size()); for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]); + dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); } int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape->dimensions(x_outer_dim)); - dimensions.push_back(y_shape->dimensions(y_outer_dim)); + dimensions.push_back(x_shape.dimensions(x_outer_dim)); + dimensions.push_back(y_shape.dimensions(y_outer_dim)); return builder->Broadcast( - builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())), + builder->ConstantLiteral(xla::Literal::Zero(x_shape.element_type())), dimensions); } - if (x_shape->element_type() == xla::C64 && conjugate_x) { + if (x_shape.element_type() == xla::C64 && conjugate_x) { x = builder->Conj(x); } - if (y_shape->element_type() == xla::C64 && conjugate_y) { + if (y_shape.element_type() == xla::C64 && conjugate_y) { y = builder->Conj(y); } diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index b230e885f10f45a78cdd6e455da3ba55ce589b96..1acc72033b05e73b0f5f88907df20cde5cfffbf0 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" namespace tensorflow { @@ -43,10 +43,10 @@ namespace tensorflow { // It is computed as: // // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::StatusOr BatchDot( - xla::ComputationBuilder* builder, xla::ComputationDataHandle x, - xla::ComputationDataHandle y, bool transpose_x, bool transpose_y, - bool conjugate_x = false, bool conjugate_y = false); +xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, + xla::XlaOp y, bool transpose_x, + bool transpose_y, bool conjugate_x = false, + bool conjugate_y = false); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index e795701181dd80a2ff544743d513bffd52fd2399..20925118bf598a6436c43bd727ce40e3abafc46c 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -31,88 +32,137 @@ namespace tensorflow { namespace { +// The Cholesky–Banachiewicz algorithm. See +// https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms +// for a description. +// // def cholesky_unblocked(a): // assert len(a.shape) == 2 and a.shape[-2] == a.shape[-1] // n = a.shape[-2] // l = np.zeros_like(a) // for j in xrange(n): -// r = l[..., j, :j] -// l[..., j, j] = np.sqrt(a[..., j, j] - np.dot(r, r)) -// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], -// np.transpose(r))) / l[..., j, j] +// row = l[..., j, :j] +// row_t = np.swapaxes(row, -1, -2) +// l[..., j, j] = np.sqrt(a[..., j, j] - np.dot(row, row_t)) +// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / +// l[..., j, j] // return l -xla::StatusOr CholeskyUnblocked( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a) { - TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(a)); - xla::ComputationDataHandle l = Zeros(builder, *shape); - const int64 n = xla::ShapeUtil::GetDimension(*shape, -2); - for (int j = 0; j < n; ++j) { - // Picture of block structure: - // ... \ - // \ - // -- r -- d - // |\ - // B c \ - // | \ - // | ... - // - // ^ - // column j - TF_ASSIGN_OR_RETURN(auto d, - SliceInMinorDims(builder, a, {j, j}, {j + 1, j + 1})); - TF_ASSIGN_OR_RETURN(auto c, - SliceInMinorDims(builder, a, {j + 1, j}, {n, j + 1})); - xla::ComputationDataHandle new_d_squared = d; - xla::ComputationDataHandle br; - if (j > 0) { - TF_ASSIGN_OR_RETURN(auto r, - SliceInMinorDims(builder, l, {j, 0}, {j + 1, j})); - TF_ASSIGN_OR_RETURN(auto b, - SliceInMinorDims(builder, l, {j + 1, 0}, {n, j})); - TF_ASSIGN_OR_RETURN(auto r_squared, - BatchDot(builder, r, r, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false)); - new_d_squared = builder->Sub(new_d_squared, r_squared); - - TF_ASSIGN_OR_RETURN(br, BatchDot(builder, b, r, /*transpose_x=*/false, - /*transpose_y=*/true, - /*conjugate_x=*/false, - /*conjugate_y=*/false)); - } - auto new_d_inv = builder->Pow( - new_d_squared, FloatLiteral(builder, shape->element_type(), -0.5)); - auto new_d = builder->Mul(new_d_inv, new_d_squared); - TF_ASSIGN_OR_RETURN(l, UpdateSliceInMinorDims(builder, l, new_d, {j, j})); - - if (j > 0) { - c = builder->Sub(c, br); +xla::StatusOr CholeskyUnblocked(xla::XlaBuilder* builder, + const xla::XlaOp& a) { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int n_dims = xla::ShapeUtil::Rank(a_shape); + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - 2); + + xla::XlaOp l = Zeros(builder, a_shape); + + // Construct the for loop body to iterate over rows. + auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + xla::XlaBuilder* body_builder) + -> xla::StatusOr> { + xla::Shape col_shape; + xla::Shape row_shape; + for (int64 d : major_dims) { + row_shape.add_dimensions(d); + col_shape.add_dimensions(d); } - auto new_c = builder->Mul(c, new_d_inv); - TF_ASSIGN_OR_RETURN(l, - UpdateSliceInMinorDims(builder, l, new_c, {j + 1, j})); - } - return l; + row_shape.add_dimensions(1); + row_shape.add_dimensions(n); + row_shape.set_element_type(a_shape.element_type()); + auto mask_zeros_row = Zeros(body_builder, row_shape); + + col_shape.add_dimensions(n); + col_shape.add_dimensions(1); + col_shape.set_element_type(a_shape.element_type()); + auto mask_zeros_col = Zeros(body_builder, col_shape); + + std::vector mask_vector(n); + std::iota(mask_vector.begin(), mask_vector.end(), 0); + auto mask_range = body_builder->ConstantR1(mask_vector); + auto mask_range_row = body_builder->Broadcast( + body_builder->Reshape(mask_range, {0}, {1, n}), major_dims); + auto mask_range_col = body_builder->Broadcast( + body_builder->Reshape(mask_range, {0}, {n, 1}), major_dims); + auto body_a = loop_vars[0]; + auto body_l = loop_vars[1]; + + // row = l[..., i, :i] + // select the whole i-th row, then mask out all columns past i-1 + auto zero = body_builder->ConstantR0(0); + TF_ASSIGN_OR_RETURN(auto l_i, DynamicSliceInMinorDims(body_builder, body_l, + {i, zero}, {1, n})); + auto row = body_builder->Select(body_builder->Ge(mask_range_row, i), + mask_zeros_row, l_i); + // a[..., i, i] + TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a, + {i, i}, {1, 1})); + // np.dot(row, np.swapaxes(row, -1, -2)) + xla::XlaOp diag_dot; + TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row, + /*transpose_x=*/false, + /*transpose_y=*/true)); + // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, + // np.swapaxes(row, -1, -2))) + auto l_ii = body_builder->Pow( + body_builder->Sub(a_ii, diag_dot), + FloatLiteral(body_builder, a_shape.element_type(), 0.5)); + + // a[..., i+1:, i] + // select the whole i-th column, then mask out all rows above i+1 + TF_ASSIGN_OR_RETURN( + auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1})); + auto a_ip1i = body_builder->Select(body_builder->Le(mask_range_col, i), + mask_zeros_col, a_0i); + + // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / + // l[..., i, i] + // The columns in [i, n] are zeroed out in `row`, so we just have to + // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], + // r.T) + TF_ASSIGN_OR_RETURN(auto dot, BatchDot(body_builder, body_l, row, + /*transpose_x=*/false, + /*transpose_y=*/true)); + // np.dot(l[..., i+1:, :i], r.T) + auto dot_ip1 = body_builder->Select(body_builder->Le(mask_range_col, i), + mask_zeros_col, dot); + + auto col_update = + body_builder->Div(body_builder->Sub(a_ip1i, dot_ip1), l_ii); + TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( + body_builder, body_l, col_update, {i})); + // Assign the diagonal after the rest of the column because otherwise the + // column assign will wrap around and overwrite the diagonal assign. + TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( + body_builder, body_l, l_ii, {i, i})); + + return std::vector{body_a, body_l}; + }; + + TF_ASSIGN_OR_RETURN( + auto cholesky_while, + XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); + + return cholesky_while[1]; } } // namespace -xla::StatusOr Cholesky( - xla::ComputationBuilder* builder, xla::ComputationDataHandle a, - int64 block_size) { - TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, - builder->GetShape(a)); - const int ndims = xla::ShapeUtil::Rank(*a_shape); +xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, + int64 block_size) { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int ndims = xla::ShapeUtil::Rank(a_shape); if (ndims < 2) { return errors::InvalidArgument( "Arguments to Cholesky must have rank >= 2: ", ndims); } - const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1); - if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) { + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { return errors::InvalidArgument( "Arguments to Cholesky must be square matrices: ", - xla::ShapeUtil::HumanString(*a_shape)); + xla::ShapeUtil::HumanString(a_shape)); } if (block_size < 1) { @@ -124,7 +174,7 @@ xla::StatusOr Cholesky( // Algorithm 1 from // Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only // execution." Proceedings of General Purpose GPUs. ACM, 2017. - xla::ComputationDataHandle l = Zeros(builder, *a_shape); + xla::XlaOp l = Zeros(builder, a_shape); for (int64 i = 0; i < n; i += block_size) { int64 k = std::min(block_size, n - i); if (i > 0) { @@ -163,7 +213,7 @@ xla::StatusOr Cholesky( /*lower=*/true, /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/8)); + /*block_size=*/block_size)); TF_ASSIGN_OR_RETURN( l, UpdateSliceInMinorDims(builder, l, update, {i + k, i})); } diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index e083a383be4be0d1b556b63214fe5f70323b4149..20fca7969ece2729a44933fd3ef3f87230ab6cad 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" namespace tensorflow { @@ -29,10 +29,9 @@ namespace tensorflow { // the block size to use. // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. -// TODO(mattjj): handle the complex Hermitian case -xla::StatusOr Cholesky( - xla::ComputationBuilder* builder, xla::ComputationDataHandle a, - int64 block_size = 256); +// TODO(znado): handle the complex Hermitian case +xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, + int64 block_size = 256); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 45699233ea8b2a75e3850098250307b95546cc28..d5a27abb2585f699ae2719cb8a6b9a829263389e 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -30,24 +30,19 @@ limitations under the License. namespace tensorflow { -xla::StatusOr XlaScatter( - const xla::ComputationDataHandle& buffer, - const xla::ComputationDataHandle& updates, - const xla::ComputationDataHandle& indices, bool indices_are_vectors, - const std::function& combiner, - xla::ComputationBuilder* builder) { - TF_ASSIGN_OR_RETURN(std::unique_ptr buffer_shape, - builder->GetShape(buffer)); - TF_ASSIGN_OR_RETURN(std::unique_ptr updates_shape, - builder->GetShape(updates)); - TF_ASSIGN_OR_RETURN(std::unique_ptr indices_shape, - builder->GetShape(indices)); +xla::StatusOr XlaScatter( + const xla::XlaOp& buffer, const xla::XlaOp& updates, + const xla::XlaOp& indices, bool indices_are_vectors, + const std::function& + combiner, + xla::XlaBuilder* builder) { + TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer)); + TF_RETURN_IF_ERROR(builder->GetShape(updates).status()); + TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices)); gtl::ArraySlice indices_dims = - xla::AsInt64Slice(indices_shape->dimensions()); + xla::AsInt64Slice(indices_shape.dimensions()); gtl::ArraySlice buffer_dims = - xla::AsInt64Slice(buffer_shape->dimensions()); + xla::AsInt64Slice(buffer_shape.dimensions()); // If the indices are N-dimensional, the minor dimension of indices contains // the indices to update. Otherwise the indices are all scalars. @@ -55,12 +50,12 @@ xla::StatusOr XlaScatter( if (indices_are_vectors) { TF_RET_CHECK(!indices_dims.empty()); num_index_dims = indices_dims.back(); - if (num_index_dims > xla::ShapeUtil::Rank(*buffer_shape)) { + if (num_index_dims > xla::ShapeUtil::Rank(buffer_shape)) { return errors::InvalidArgument( "The size of the minor dimension of the indices (shape: ", - xla::ShapeUtil::HumanString(*indices_shape), + xla::ShapeUtil::HumanString(indices_shape), ") must be <= the rank of the buffer (shape: ", - xla::ShapeUtil::HumanString(*buffer_shape), ")"); + xla::ShapeUtil::HumanString(buffer_shape), ")"); } indices_dims.pop_back(); } @@ -78,10 +73,10 @@ xla::StatusOr XlaScatter( // If any of the indexed dimensions are zero in the buffer, the update cannot // succeed since it updates a slice of size 1. for (int64 i = 0; i < num_index_dims; ++i) { - if (xla::ShapeUtil::GetDimension(*buffer_shape, i) == 0) { - return errors::InvalidArgument( - "Scatter dimension ", i, " is of size zero in tensor with shape ", - xla::ShapeUtil::HumanString(*buffer_shape)); + if (xla::ShapeUtil::GetDimension(buffer_shape, i) == 0) { + return errors::InvalidArgument("Scatter dimension ", i, + " is of size zero in tensor with shape ", + xla::ShapeUtil::HumanString(buffer_shape)); } } @@ -111,18 +106,17 @@ xla::StatusOr XlaScatter( // index = dynamic-slice(indices, i) // update = dynamic-slice(updates, i) // buffer = dynamic-update-slice(buffer, update, index) - auto body_fn = [&](xla::ComputationDataHandle i, - gtl::ArraySlice loop_vars, - xla::ComputationBuilder* body_builder) { + auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + xla::XlaBuilder* body_builder) { auto indices = loop_vars[0]; auto updates = loop_vars[1]; auto buffer = loop_vars[2]; auto zero_index = body_builder->ConstantLiteral( - xla::Literal::Zero(indices_shape->element_type())); + xla::Literal::Zero(indices_shape.element_type())); // Slice the i-th index from the indices array. - xla::ComputationDataHandle index; + xla::XlaOp index; auto indices_offset = body_builder->Reshape(i, {1}); if (indices_are_vectors) { indices_offset = body_builder->Pad(indices_offset, zero_index, @@ -180,12 +174,12 @@ xla::StatusOr XlaScatter( // Apply the update. buffer = body_builder->DynamicUpdateSlice(buffer, update, index); - return std::vector{indices, updates, buffer}; + return std::vector{indices, updates, buffer}; }; - TF_ASSIGN_OR_RETURN( - auto outputs, XlaForEachIndex(num_indices, indices_shape->element_type(), - body_fn, init, "scatter", builder)); + TF_ASSIGN_OR_RETURN(auto outputs, + XlaForEachIndex(num_indices, indices_shape.element_type(), + body_fn, init, "scatter", builder)); return outputs[2]; } diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h index 41e6d3b195ebf90662c7b9b42c53fcb0133ab29e..87309e10ede320a81d173cd0a64492f88a2c7376 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.h +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" namespace tensorflow { @@ -39,14 +39,12 @@ namespace tensorflow { // If a `combiner` is provided, updates are combined with the existing values in // the buffer using the combiner function. Otherwise, the updates replace the // existing values. The order of updates is implementation-defined. -xla::StatusOr XlaScatter( - const xla::ComputationDataHandle& buffer, - const xla::ComputationDataHandle& updates, - const xla::ComputationDataHandle& indices, bool indices_are_vectors, - const std::function& combiner, - xla::ComputationBuilder* builder); +xla::StatusOr XlaScatter( + const xla::XlaOp& buffer, const xla::XlaOp& updates, + const xla::XlaOp& indices, bool indices_are_vectors, + const std::function& + combiner, + xla::XlaBuilder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 7f72a6073df218b9e2bd4cc0c0b5bb10b5cd4b84..b4503601f94baa5a595a64c9fc81bc92d9980ac6 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,21 +29,20 @@ limitations under the License. namespace tensorflow { -xla::StatusOr TriangularSolve( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a, - bool conjugate_a, int64 block_size) { - TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, - builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(std::unique_ptr b_shape, - builder->GetShape(b)); - if (xla::ShapeUtil::Rank(*a_shape) != xla::ShapeUtil::Rank(*b_shape)) { +xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, + const xla::XlaOp& a, xla::XlaOp b, + bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + int64 block_size) { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { return errors::InvalidArgument( "Arguments to TriangularSolve have different ranks: ", - xla::ShapeUtil::HumanString(*a_shape), " vs. ", - xla::ShapeUtil::HumanString(*b_shape)); + xla::ShapeUtil::HumanString(a_shape), " vs. ", + xla::ShapeUtil::HumanString(b_shape)); } - const int ndims = xla::ShapeUtil::Rank(*a_shape); + const int ndims = xla::ShapeUtil::Rank(a_shape); if (ndims < 2) { return errors::InvalidArgument( "Arguments to TriangularSolve must have rank >= 2: ", ndims); @@ -51,30 +50,30 @@ xla::StatusOr TriangularSolve( // The batch dimensions must be equal. std::vector batch_dimensions; for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape->dimensions(i); - int64 b_size = b_shape->dimensions(i); + int64 a_size = a_shape.dimensions(i); + int64 b_size = b_shape.dimensions(i); if (a_size != b_size) { return errors::InvalidArgument( "Batch dimensions of arguments to TriangularSolve must be equal: ", - xla::ShapeUtil::HumanString(*a_shape), " vs ", - xla::ShapeUtil::HumanString(*b_shape)); + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); } batch_dimensions.push_back(a_size); } - if (xla::ShapeUtil::GetDimension(*a_shape, -1) != - xla::ShapeUtil::GetDimension(*a_shape, -2)) { + if (xla::ShapeUtil::GetDimension(a_shape, -1) != + xla::ShapeUtil::GetDimension(a_shape, -2)) { return errors::InvalidArgument( "The 'a' arguments to TriangularSolve must be square matrices: ", - xla::ShapeUtil::HumanString(*a_shape)); + xla::ShapeUtil::HumanString(a_shape)); } - const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1); - if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(*a_shape, -1)) { + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { return errors::InvalidArgument( "Arguments to TriangularSolve have incompatible matrix shapes: ", - xla::ShapeUtil::HumanString(*a_shape), " vs ", - xla::ShapeUtil::HumanString(*b_shape)); + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); } if (block_size < 1) { @@ -83,36 +82,20 @@ xla::StatusOr TriangularSolve( block_size); } - // Returns [b1, b2, ... , bn, indices[0], indices[1]]. - auto prepend_batch_dims = [&](std::array indices) { - std::vector output(ndims); - std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin()); - std::copy(indices.begin(), indices.end(), - output.begin() + batch_dimensions.size()); - return output; - }; - - // Applies a complex conjugation operation if `a` is complex and `conjugate_a` - // is true, otherwise returns its argument. - auto maybe_conj = [&](xla::ComputationBuilder* builder, - xla::ComputationDataHandle x) { - auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a; - return perform_conj ? builder->Conj(x) : x; - }; - - std::map base_computations; + std::map base_computations; auto get_base_triangular_solve = - [&](int k) -> xla::StatusOr { - xla::Computation& computation = base_computations[k]; + [&](int k) -> xla::StatusOr { + xla::XlaComputation& computation = base_computations[k]; if (computation.IsNull()) { - std::unique_ptr sub = builder->CreateSubBuilder( + std::unique_ptr sub = builder->CreateSubBuilder( tensorflow::strings::StrCat("trsm_base_", k)); - auto a_param = - sub->Parameter(0, - xla::ShapeUtil::MakeShape(b_shape->element_type(), - prepend_batch_dims({k, k})), - "a"); + auto a_param = sub->Parameter( + 0, + xla::ShapeUtil::MakeShape( + b_shape.element_type(), + PrependMajorDims(sub.get(), batch_dimensions, {k, k})), + "a"); std::array b_lastd; if (left_side) { @@ -120,22 +103,28 @@ xla::StatusOr TriangularSolve( } else { b_lastd = {m, k}; } - auto b_param = - sub->Parameter(1, - xla::ShapeUtil::MakeShape(b_shape->element_type(), - prepend_batch_dims(b_lastd)), - "b"); - - // We use a left-looking subroutine on the block diagonal in some common - // cases, while falling back to a recursive call in unsupported cases. The - // left-looking subroutine is written with a While loop and so yields much - // faster compile times. Moreover, the left-looking variant can give - // higher performance on smaller (sub)problems. + auto b_param = sub->Parameter( + 1, + xla::ShapeUtil::MakeShape( + b_shape.element_type(), + PrependMajorDims(sub.get(), batch_dimensions, b_lastd)), + "b"); + + // We use a left-looking or right-looking subroutine on the block diagonal + // in the lower=true cases, while falling back to a recursive call in + // others. The left-looking and right-looking subroutines are written with + // a While loop and so yields much faster compile times. Moreover, they + // can give higher performance on smaller (sub)problems. if (left_side && lower) { TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param, b_param, transpose_a, conjugate_a) .status()); + } else if (!left_side && lower) { + TF_RETURN_IF_ERROR(TriangularSolveRightLooking(sub.get(), a_param, + b_param, transpose_a, + conjugate_a) + .status()); } else { TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, left_side, lower, transpose_a, @@ -149,7 +138,7 @@ xla::StatusOr TriangularSolve( return &computation; }; - xla::ComputationDataHandle output = Zeros(builder, *b_shape); + xla::XlaOp output = Zeros(builder, b_shape); // Right-looking blocked triangular solve. // For an explanation of the algorithm, see the TRSM discussion in: @@ -172,13 +161,15 @@ xla::StatusOr TriangularSolve( SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); TF_ASSIGN_OR_RETURN(auto b_slice, SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::ComputationDataHandle update; + xla::XlaOp update; if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::Computation * solve, + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, get_base_triangular_solve(k)); update = builder->Call(*solve, {a_slice, b_slice}); } else { - update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + update = builder->Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {0, i})); @@ -188,7 +179,7 @@ xla::StatusOr TriangularSolve( // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2) if (i + k < n) { - xla::ComputationDataHandle a_slice_2; + xla::XlaOp a_slice_2; if (lower) { TF_ASSIGN_OR_RETURN( a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); @@ -222,13 +213,15 @@ xla::StatusOr TriangularSolve( SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); TF_ASSIGN_OR_RETURN(auto b_slice, SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::ComputationDataHandle update; + xla::XlaOp update; if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::Computation * solve, + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, get_base_triangular_solve(k)); update = builder->Call(*solve, {a_slice, b_slice}); } else { - update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + update = builder->Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); @@ -238,7 +231,7 @@ xla::StatusOr TriangularSolve( // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :]) if (i + k < m) { - xla::ComputationDataHandle a_slice_2; + xla::XlaOp a_slice_2; if (lower) { TF_ASSIGN_OR_RETURN( a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k})); @@ -271,13 +264,15 @@ xla::StatusOr TriangularSolve( SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); TF_ASSIGN_OR_RETURN(auto b_slice, SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::ComputationDataHandle update; + xla::XlaOp update; if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::Computation * solve, + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, get_base_triangular_solve(k)); update = builder->Call(*solve, {a_slice, b_slice}); } else { - update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + update = builder->Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {0, i})); @@ -287,7 +282,7 @@ xla::StatusOr TriangularSolve( // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2) if (i - k >= 0) { - xla::ComputationDataHandle a_slice_2; + xla::XlaOp a_slice_2; if (lower) { TF_ASSIGN_OR_RETURN(a_slice_2, SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); @@ -321,13 +316,15 @@ xla::StatusOr TriangularSolve( SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); TF_ASSIGN_OR_RETURN(auto b_slice, SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::ComputationDataHandle update; + xla::XlaOp update; if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::Computation * solve, + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, get_base_triangular_solve(k)); update = builder->Call(*solve, {a_slice, b_slice}); } else { - update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + update = builder->Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); @@ -337,7 +334,7 @@ xla::StatusOr TriangularSolve( // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :]) if (i - k >= 0) { - xla::ComputationDataHandle a_slice_2; + xla::XlaOp a_slice_2; if (lower) { TF_ASSIGN_OR_RETURN(a_slice_2, SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); @@ -363,37 +360,23 @@ xla::StatusOr TriangularSolve( return output; } -xla::StatusOr TriangularSolveLeftLooking( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a) { - TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, - builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(std::unique_ptr b_shape, - builder->GetShape(b)); - const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1); - const int64 ndims = xla::ShapeUtil::Rank(*a_shape); +xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, + const xla::XlaOp& a, + const xla::XlaOp& b, + bool transpose_a, + bool conjugate_a) { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + const int64 ndims = xla::ShapeUtil::Rank(a_shape); std::vector batch_dimensions; for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape->dimensions(i); + int64 a_size = a_shape.dimensions(i); batch_dimensions.push_back(a_size); } - auto prepend_batch_dims = [&](std::array indices) { - std::vector output(ndims); - std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin()); - std::copy(indices.begin(), indices.end(), - output.begin() + batch_dimensions.size()); - return output; - }; - - auto maybe_conj = [&](xla::ComputationBuilder* builder, - xla::ComputationDataHandle x) { - auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a; - return perform_conj ? builder->Conj(x) : x; - }; - // The main computation is performed in a While loop. // Allocate the output and set its first or last row, @@ -402,14 +385,16 @@ xla::StatusOr TriangularSolveLeftLooking( // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:] // else: // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1] - xla::ComputationDataHandle output = Zeros(builder, *b_shape); + xla::XlaOp output = Zeros(builder, b_shape); { auto i = transpose_a ? m - 1 : 0; TF_ASSIGN_OR_RETURN(auto a_slice, SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1})); TF_ASSIGN_OR_RETURN(auto b_slice, SliceInMinorDims(builder, b, {i, 0}, {i + 1, n})); - auto update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + auto update = builder->Div(b_slice, a_slice_conj); TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); } @@ -423,11 +408,11 @@ xla::StatusOr TriangularSolveLeftLooking( // The loop iteration counter is a scalar, incremented each iteration. xla::ShapeUtil::MakeShape(xla::S32, {}), // The output has the shape of b, with one row updated each iteration. - *b_shape, + b_shape, // The coefficient matrix a is a loop invariant. - *a_shape, + a_shape, // The right-hand-side matrix b is a loop invariant. - *b_shape}; + b_shape}; xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); auto init_i = builder->ConstantR0(transpose_a ? m - 2 : 1); auto init = builder->Tuple({init_i, output, a, b}); @@ -436,7 +421,7 @@ xla::StatusOr TriangularSolveLeftLooking( // def cond_fun(loop_carry): // i, output, a, b = loop_carry // return i >= 0 if transpose_a else i < m - std::unique_ptr condb = + std::unique_ptr condb = builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond"); { auto i = condb->GetTupleElement( @@ -466,7 +451,7 @@ xla::StatusOr TriangularSolveLeftLooking( // return (i + 1, output, a, b) // We have to do some extra FLOPs propagating zeros in the matrix multiply // because we can't have the size of its arguments depend on the loop counter. - std::unique_ptr bodyb = + std::unique_ptr bodyb = builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody"); { auto input_tuple = bodyb->Parameter(0, tuple_shape, @@ -479,30 +464,6 @@ xla::StatusOr TriangularSolveLeftLooking( auto body_b = bodyb->GetTupleElement(input_tuple, 3); auto zero = bodyb->ConstantR0(0); - // Set up some helper functions. - auto prepend_zeros = [&](std::array starts) { - auto zero = bodyb->Reshape(bodyb->ConstantR0(0), {1}); - std::vector padded_starts(ndims, zero); - padded_starts[ndims - 2] = bodyb->Reshape(starts[0], {1}); - padded_starts[ndims - 1] = bodyb->Reshape(starts[1], {1}); - return bodyb->ConcatInDim(padded_starts, 0); - }; - - auto dynamic_slice = [&](xla::ComputationDataHandle x, - std::array starts, - std::array sizes) { - auto padded_starts = prepend_zeros(starts); - auto padded_sizes = prepend_batch_dims(sizes); - return bodyb->DynamicSlice(x, padded_starts, padded_sizes); - }; - - auto update = [&](xla::ComputationDataHandle x, - xla::ComputationDataHandle update, - std::array starts) { - auto padded_starts = prepend_zeros(starts); - return bodyb->DynamicUpdateSlice(x, update, padded_starts); - }; - // We'd like to implement this: // if transpose_a: // a_row = T(a[..., i+1:, i:i+1]) @@ -514,24 +475,33 @@ xla::StatusOr TriangularSolveLeftLooking( // But since we can't have intermediate array sizes depend on the loop // counter, we instead exploit the fact that we initialized the output to // all zeros and use that as zero-padding (doing unnecessary FLOPs). - xla::ComputationDataHandle a_row; + xla::XlaOp a_row; if (transpose_a) { - a_row = dynamic_slice(body_a, {zero, i}, {m, 1}); + TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, + {zero, i}, {m, 1})); } else { - a_row = dynamic_slice(body_a, {i, zero}, {1, m}); + TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, + {i, zero}, {1, m})); } TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out, /*transpose_x=*/transpose_a, /*transpose_y=*/false, /*conjugate_x=*/conjugate_a, /*conjugate_y=*/false)); - auto result_row = - bodyb->Sub(dynamic_slice(body_b, {i, zero}, {1, n}), b_update); + TF_ASSIGN_OR_RETURN( + auto result_row_slice, + DynamicSliceInMinorDims(bodyb.get(), body_b, {i, zero}, {1, n})); + auto result_row = bodyb->Sub(result_row_slice, b_update); // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] - auto a_elt = dynamic_slice(body_a, {i, i}, {1, 1}); - auto div_result = bodyb->Div(result_row, maybe_conj(bodyb.get(), a_elt)); - body_out = update(body_out, div_result, {i, zero}); + TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a, + {i, i}, {1, 1})); + TF_ASSIGN_OR_RETURN(auto a_elt_conj, + MaybeConjugate(bodyb.get(), a_elt, conjugate_a)); + auto div_result = bodyb->Div(result_row, a_elt_conj); + TF_ASSIGN_OR_RETURN(body_out, + DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, + div_result, {i, zero})); // if transpose_a: // return (i - 1, body_out, a, b) @@ -548,4 +518,130 @@ xla::StatusOr TriangularSolveLeftLooking( return builder->GetTupleElement(triangular_solve_left_looking_while, 1); } +xla::StatusOr TriangularSolveRightLooking(xla::XlaBuilder* builder, + const xla::XlaOp& a, + const xla::XlaOp& b, + bool transpose_a, + bool conjugate_a) { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + const int64 ndims = xla::ShapeUtil::Rank(a_shape); + + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape.dimensions(i); + batch_dimensions.push_back(a_size); + } + + // The main computation is performed in a While loop. + xla::XlaOp output = Zeros(builder, b_shape); + + // Construct the initial loop carry tuple, + // if transpose_a: + // init = (0, output, a, b) + // else: + // init = (n-1, output, a, b) + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + xla::ShapeUtil::MakeShape(xla::S32, {}), + // The output has the shape of b, with one row updated each iteration. + b_shape, + // The coefficient matrix a is a loop invariant. + a_shape, + // The right-hand-side matrix b is a loop invariant. + b_shape}; + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + auto init_i = builder->ConstantR0(transpose_a ? 0 : n - 1); + auto init = builder->Tuple({init_i, output, a, b}); + + // Construct the loop condition function, + // def cond_fun(loop_carry): + // i, output, a, b = loop_carry + // return i < n if transpose_a else i >= 0 + std::unique_ptr condb = + builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond"); + { + auto i = condb->GetTupleElement( + condb->Parameter(0, tuple_shape, + "TriangularSolveRightLookingWhileTuple"), + 0); + if (transpose_a) { + condb->Lt(i, condb->ConstantR0(n)); + } else { + condb->Ge(i, condb->ConstantR0(0)); + } + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function, + // def body_fun(loop_carry): + // i, output, a, b = loop_carry + // if transpose_a: + // a_row = np.swapaxes(a[..., :, i:i+1], -1 -2) + // else: + // a_row = a[..., :, i:i+1] + // result_row = b[..., :, i:i+1] - np.matmul(output, a_row) + // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] + // if transpose_a: + // return (i - 1, output, a, b) + // else: + // return (i + 1, output, a, b) + // We have to do some extra FLOPs propagating zeros in the matrix multiply + // because we can't have the size of its arguments depend on the loop counter. + std::unique_ptr bodyb = + builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody"); + { + auto input_tuple = bodyb->Parameter( + 0, tuple_shape, "TriangularSolveRightLookingWhileTuple"); + + // i, output, a, b = loop_carry + auto i = bodyb->GetTupleElement(input_tuple, 0); + auto body_out = bodyb->GetTupleElement(input_tuple, 1); + auto body_a = bodyb->GetTupleElement(input_tuple, 2); + auto body_b = bodyb->GetTupleElement(input_tuple, 3); + auto zero = bodyb->ConstantR0(0); + + // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :, + // i:i+1]) But since we can't have intermediate array sizes depend on the + // loop counter, we instead exploit the fact that we initialized the output + // to all zeros and use that as zero-padding (doing unnecessary FLOPs). + TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), body_out, body_a, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a)); + // result = b - np.matmul(output, a) + auto result = bodyb->Sub(body_b, b_update); + // result_row = result[..., :, i:i+1] + TF_ASSIGN_OR_RETURN( + auto result_row, + DynamicSliceInMinorDims(bodyb.get(), result, {zero, i}, {m, 1})); + + // body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] + TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(bodyb.get(), body_a, + {i, i}, {1, 1})); + TF_ASSIGN_OR_RETURN(auto a_ii_conj, + MaybeConjugate(bodyb.get(), a_ii, conjugate_a)); + auto div_result = bodyb->Div(result_row, a_ii_conj); + TF_ASSIGN_OR_RETURN(body_out, + DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, + div_result, {zero, i})); + + // if transpose_a: + // return (i + 1, body_out, a, b) + // else: + // return (i - 1, body_out, a, b) + auto next_i = bodyb->Add(i, bodyb->ConstantR0(transpose_a ? 1 : -1)); + bodyb->Tuple({next_i, body_out, body_a, body_b}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto triangular_solve_left_looking_while = builder->While(cond, body, init); + return builder->GetTupleElement(triangular_solve_left_looking_while, 1); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index e32223bfdddda800b1fd4de3e4f0c8061e0f81d8..540c26b2473df9e7885f4e549b3e516a3d8a0d43 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" namespace tensorflow { @@ -57,14 +57,23 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. -xla::StatusOr TriangularSolve( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a, - bool conjugate_a, int64 block_size = 256); +xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, + const xla::XlaOp& a, xla::XlaOp b, + bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + int64 block_size = 256); -xla::StatusOr TriangularSolveLeftLooking( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a); +xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, + const xla::XlaOp& a, + const xla::XlaOp& b, + bool transpose_a, + bool conjugate_a); + +xla::StatusOr TriangularSolveRightLooking(xla::XlaBuilder* builder, + const xla::XlaOp& a, + const xla::XlaOp& b, + bool transpose_a, + bool conjugate_a); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc index 661707062916263fd0d5d935ce41698a7655df02..87ea4763f7c2357ae179b68ade3715b24c46432f 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -80,9 +80,9 @@ xla::Array2D AValsFull() { } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -102,9 +102,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -124,9 +124,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -146,9 +146,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -168,9 +168,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -191,9 +191,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -214,9 +214,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -237,9 +237,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -260,9 +260,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLowerComplex(), 0, "a", &builder, &a); auto b_data = @@ -288,9 +288,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpperComplex(), 0, "a", &builder, &a); auto b_data = @@ -318,9 +318,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { } XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolveLeftLooking(&builder, a, b, @@ -340,9 +340,9 @@ XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) { } XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolveLeftLooking(&builder, a, b, diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index f579669bbd852b514e021ce71d635f8ce5e4fe4d..d9ff7e6259f3fbab8957394bff5c5670a67dd0eb 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,15 +27,14 @@ limitations under the License. namespace tensorflow { -xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, - const xla::Shape& shape) { +xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) { return builder->Broadcast( builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())), xla::AsInt64Slice(shape.dimensions())); } -xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, - xla::PrimitiveType type, double value) { +xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, + double value) { switch (type) { case xla::F16: return builder->ConstantR0(static_cast(value)); @@ -57,9 +56,8 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, } } -xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, - xla::PrimitiveType type, - int64 value) { +xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, + int64 value) { xla::Literal literal; switch (type) { case xla::U8: @@ -112,17 +110,18 @@ xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, return builder->ConstantLiteral(literal); } -xla::StatusOr SliceInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - gtl::ArraySlice start, gtl::ArraySlice end) { +xla::StatusOr SliceInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x, + gtl::ArraySlice start, + gtl::ArraySlice end) { TF_RET_CHECK(start.size() == end.size()); int64 n_minor_dims = start.size(); - TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(*shape); + const int64 n_dims = xla::ShapeUtil::Rank(shape); TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape->dimensions()), + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - n_minor_dims); @@ -140,20 +139,56 @@ xla::StatusOr SliceInMinorDims( return builder->Slice(x, padded_start, padded_end, strides); } -xla::StatusOr UpdateSlice( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& update, gtl::ArraySlice start) { +std::vector PrependMajorDims(xla::XlaBuilder* builder, + const gtl::ArraySlice& major_dims, + const gtl::ArraySlice& indices) { + std::vector output(indices.size() + major_dims.size()); + std::copy(major_dims.begin(), major_dims.end(), output.begin()); + std::copy(indices.begin(), indices.end(), output.begin() + major_dims.size()); + return output; +} + +xla::StatusOr DynamicSliceInMinorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, + const std::vector& starts, + const gtl::ArraySlice& sizes) { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + int64 n_minor_dims = starts.size(); + TF_RET_CHECK(n_minor_dims == sizes.size()); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - sizes.size()); + TF_ASSIGN_OR_RETURN(auto padded_starts, + PrependZerosInMajorDims(builder, x, starts)); + auto padded_sizes = PrependMajorDims(builder, major_dims, sizes); + return builder->DynamicSlice(x, padded_starts, padded_sizes); +} + +xla::StatusOr UpdateSlice(xla::XlaBuilder* builder, + const xla::XlaOp& x, + const xla::XlaOp& update, + gtl::ArraySlice start) { // TODO(phawkins): make int64 work on all backends, remove the int32 cast. std::vector start_as_int32(start.begin(), start.end()); - return builder->DynamicUpdateSlice( - x, update, builder->ConstantR1(start_as_int32)); + auto start_constant = builder->ConstantR1(start_as_int32); + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, + builder->GetShape(start_constant)); + const int64 start_length = + xla::ShapeUtil::GetDimension(start_constant_shape, -1); + TF_RET_CHECK(start_length == n_dims); + return builder->DynamicUpdateSlice(x, update, start_constant); } -xla::StatusOr UpdateSliceInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& update, gtl::ArraySlice start) { - TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(*shape); +xla::StatusOr UpdateSliceInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x, + const xla::XlaOp& update, + gtl::ArraySlice start) { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); const int64 n_minor_dims = start.size(); TF_RET_CHECK(n_minor_dims <= n_dims); std::vector padded_start(n_dims, 0); @@ -162,10 +197,32 @@ xla::StatusOr UpdateSliceInMinorDims( return UpdateSlice(builder, x, update, padded_start); } -xla::StatusOr TransposeInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) { - TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(*shape); +xla::StatusOr DynamicUpdateSliceInMinorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update, + const std::vector& starts) { + TF_ASSIGN_OR_RETURN(auto padded_starts, + PrependZerosInMajorDims(builder, x, starts)); + return builder->DynamicUpdateSlice(x, update, padded_starts); +} + +xla::StatusOr PrependZerosInMajorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, + const std::vector& starts) { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + auto zero = builder->Reshape(builder->ConstantR0(0), {1}); + std::vector padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = + builder->Reshape(starts[i], {1}); + } + return builder->ConcatInDim(padded_starts, 0); +} + +xla::StatusOr TransposeInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x) { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); TF_RET_CHECK(n_dims >= 2); std::vector permutation(n_dims); std::iota(permutation.begin(), permutation.end(), 0); @@ -173,4 +230,11 @@ xla::StatusOr TransposeInMinorDims( return builder->Transpose(x, permutation); } +xla::StatusOr MaybeConjugate(xla::XlaBuilder* builder, + const xla::XlaOp& x, bool conjugate) { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + auto perform_conj = shape.element_type() == xla::C64 && conjugate; + return perform_conj ? builder->Conj(x) : x; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 51f8baaf00bd8fd25baa1a87be8cb0089dfb22b5..3c120a2548576d6ad46870583ca65beea63507a3 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,47 +16,79 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { // Returns a zero-filled tensor with shape `shape`. -xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, - const xla::Shape& shape); +xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape); // Returns a floating point scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. -xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, - xla::PrimitiveType type, double value); +xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, + double value); + +// Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros +// prepended until the array is length n_dims. +xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder, + gtl::ArraySlice starts); // Returns a integer scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. -xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, - xla::PrimitiveType type, int64 value); +xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, + int64 value); + +// Builds a vector of zeros of length rank(x) with the last two values being +// those in `starts`. +xla::StatusOr PrependZerosInMajorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, + const std::vector& starts); // Performs a slice in the minor dimensions of a Tensor. -xla::StatusOr SliceInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - gtl::ArraySlice start, gtl::ArraySlice end); +xla::StatusOr SliceInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x, + gtl::ArraySlice start, + gtl::ArraySlice end); + +// Builds a 1-d vector out of a concatenation of `major_dims` and `starts`. +std::vector PrependMajorDims(xla::XlaBuilder* builder, + const gtl::ArraySlice& major_dims, + const gtl::ArraySlice& indices); + +// Performs a dynamic slice in the minor dimensions of a Tensor. +xla::StatusOr DynamicSliceInMinorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, + const std::vector& starts, const gtl::ArraySlice& sizes); // Updates a slice of 'x', i.e., // x[start[0], ..., start[n]] = update -xla::StatusOr UpdateSlice( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& update, gtl::ArraySlice start); +xla::StatusOr UpdateSlice(xla::XlaBuilder* builder, + const xla::XlaOp& x, + const xla::XlaOp& update, + gtl::ArraySlice start); // Updates a slice of 'x', where 'start' contains a list of minor dimensions: // x[..., start[0], ..., start[n]] = update -xla::StatusOr UpdateSliceInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& update, gtl::ArraySlice start); +xla::StatusOr UpdateSliceInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x, + const xla::XlaOp& update, + gtl::ArraySlice start); + +xla::StatusOr DynamicUpdateSliceInMinorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update, + const std::vector& starts); // Transposes a stack of matrices `x` by swapping the last two dimensions. -xla::StatusOr TransposeInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x); +xla::StatusOr TransposeInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x); + +// Applies a complex conjugation operation if `a` is complex and `conjugate_a` +// is true, otherwise returns its argument. +xla::StatusOr MaybeConjugate(xla::XlaBuilder* builder, + const xla::XlaOp& x, bool conjugate); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..265b39402c832f8c810a74f281563b05afdf2b1b --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/util_test.cc @@ -0,0 +1,144 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/util.h" + +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +using UtilTest = xla::ClientLibraryTestBase; +using UtilLeftLookingTest = xla::ClientLibraryTestBase; + +xla::Array2D BValsRight() { + return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; +} + +xla::Array2D BValsLeft() { + return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; +} + +xla::Array2D AValsFull() { + return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}}; +} + +xla::Array3D BatchedAValsFull() { + return {{ + {2, 0, 1, 2}, + {3, 6, 0, 1}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }}; +} + +XLA_TEST_F(UtilTest, Simple2dLookup) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp a, x, y; + auto a_data = CreateR2Parameter(BValsRight(), 0, "a", &builder, &a); + auto x_data = CreateR0Parameter(2, 1, "x", &builder, &x); + auto y_data = CreateR0Parameter(1, 2, "y", &builder, &y); + auto result = DynamicSliceInMinorDims(&builder, a, {x, y}, {1, 1}); + TF_ASSERT_OK(result.status()); + + ComputeAndCompareR2(&builder, {{10}}, + {a_data.get(), x_data.get(), y_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(UtilTest, Simple3dLookup) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp a, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto index_data = CreateR0Parameter(1, 1, "index", &builder, &index); + + TF_ASSERT_OK_AND_ASSIGN( + auto l_index, + DynamicSliceInMinorDims(&builder, a, + {index, builder.ConstantR0(0)}, {1, 4})); + + ComputeAndCompareR3(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}}, + {a_data.get(), index_data.get()}); +} + +XLA_TEST_F(UtilTest, SimpleSliceUpdate) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp a, b, x, y; + auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter({{9, 1, -10}}, 1, "b", &builder, &b); + auto x_data = CreateR0Parameter(2, 2, "x", &builder, &x); + auto y_data = CreateR0Parameter(1, 3, "y", &builder, &y); + + auto result = DynamicUpdateSliceInMinorDims(&builder, a, b, {x, y}); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected( + {{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}}); + + ComputeAndCompareR2( + &builder, expected, + {a_data.get(), b_data.get(), x_data.get(), y_data.get()}); +} + +XLA_TEST_F(UtilTest, RowBatchDot) { + xla::XlaBuilder builder(TestName()); + + int n = 4; + + xla::XlaOp a, row, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, + "row", &builder, &row); + // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). + auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); + + TF_ASSERT_OK_AND_ASSIGN( + auto l_index, + DynamicSliceInMinorDims(&builder, a, + {index, builder.ConstantR0(0)}, {1, n})); + TF_ASSERT_OK_AND_ASSIGN( + auto dot, BatchDot(&builder, l_index, row, + /*transpose_x=*/false, /*transpose_y=*/true)); + + ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, + {a_data.get(), row_data.get(), index_data.get()}); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 86c02ac2e65c12d3527c4022df0cc603e522ef7a..09ce594930efc0af47306590d76b322ac730f80f 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -20,24 +20,24 @@ limitations under the License. namespace tensorflow { -xla::StatusOr> XlaWhileLoop( +xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - gtl::ArraySlice initial_values, - StringPiece name, xla::ComputationBuilder* builder) { + gtl::ArraySlice initial_values, StringPiece name, + xla::XlaBuilder* builder) { int arity = initial_values.size(); std::vector var_shapes; var_shapes.reserve(arity); - for (const xla::ComputationDataHandle& input : initial_values) { + for (const xla::XlaOp& input : initial_values) { TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input)); - var_shapes.push_back(std::move(*shape)); + var_shapes.push_back(std::move(shape)); } xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes); // Unpacks a tuple into its component parts. - auto unpack_tuple = [](xla::ComputationDataHandle tuple, int arity, - xla::ComputationBuilder* builder) { - std::vector elements(arity); + auto unpack_tuple = [](xla::XlaOp tuple, int arity, + xla::XlaBuilder* builder) { + std::vector elements(arity); for (int i = 0; i < arity; ++i) { elements[i] = builder->GetTupleElement(tuple, i); } @@ -45,21 +45,20 @@ xla::StatusOr> XlaWhileLoop( }; // Build the condition. - std::unique_ptr cond_builder = + std::unique_ptr cond_builder = builder->CreateSubBuilder(strings::StrCat(name, "_condition")); { auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter"); - TF_ASSIGN_OR_RETURN( - auto result, + TF_RETURN_IF_ERROR( condition_function(unpack_tuple(parameter, arity, cond_builder.get()), - cond_builder.get())); - TF_RETURN_IF_ERROR(cond_builder->SetReturnValue(result)); + cond_builder.get()) + .status()); } TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build()); // Build the body. - std::unique_ptr body_builder = + std::unique_ptr body_builder = builder->CreateSubBuilder(strings::StrCat(name, "_body")); { auto parameter = body_builder->Parameter(0, tuple_shape, "parameter"); @@ -79,38 +78,38 @@ xla::StatusOr> XlaWhileLoop( return unpack_tuple(outputs, arity, builder); } -xla::StatusOr> XlaForEachIndex( +xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - gtl::ArraySlice initial_values, - StringPiece name, xla::ComputationBuilder* builder) { - auto while_cond_fn = [&](gtl::ArraySlice values, - xla::ComputationBuilder* cond_builder) - -> xla::StatusOr { + gtl::ArraySlice initial_values, StringPiece name, + xla::XlaBuilder* builder) { + auto while_cond_fn = + [&](gtl::ArraySlice values, + xla::XlaBuilder* cond_builder) -> xla::StatusOr { return cond_builder->Lt( values[0], IntegerLiteral(cond_builder, num_iterations_type, num_iterations)); }; - auto while_body_fn = [&](gtl::ArraySlice values, - xla::ComputationBuilder* body_builder) - -> xla::StatusOr> { - xla::ComputationDataHandle iteration = values[0]; + auto while_body_fn = [&](gtl::ArraySlice values, + xla::XlaBuilder* body_builder) + -> xla::StatusOr> { + xla::XlaOp iteration = values[0]; - std::vector updated_values; + std::vector updated_values; updated_values.reserve(values.size()); updated_values.push_back(body_builder->Add( iteration, body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type)))); values.remove_prefix(1); - TF_ASSIGN_OR_RETURN(std::vector body_outputs, + TF_ASSIGN_OR_RETURN(std::vector body_outputs, body_function(iteration, values, body_builder)); updated_values.insert(updated_values.end(), body_outputs.begin(), body_outputs.end()); return updated_values; }; - std::vector values; + std::vector values; values.reserve(initial_values.size() + 1); values.push_back( builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type))); diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h index 2e67a0c99b6deb65fa16ab2dec1727f5cb5fcb92..5b6684c995889efbb1378c7ac4903548891d090a 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/tf2xla/lib/while_loop.h @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -29,14 +29,14 @@ namespace tensorflow { // Function that builds a loop condition. Takes as input a sequence of input // values, and returns a boolean value representing if the condition succeeds. -typedef std::function( - gtl::ArraySlice, xla::ComputationBuilder*)> +typedef std::function(gtl::ArraySlice, + xla::XlaBuilder*)> LoopConditionFunction; // Function that builds a loop body. Takes as input a sequence of input values // and returns a sequence of output values. -typedef std::function>( - gtl::ArraySlice, xla::ComputationBuilder*)> +typedef std::function>( + gtl::ArraySlice, xla::XlaBuilder*)> LoopBodyFunction; // Helper function for building an XLA while loop, where the values carried by @@ -47,27 +47,26 @@ typedef std::function>( // init: (a, b, c) // ) // 'name' is a descriptive name for the loop. -xla::StatusOr> XlaWhileLoop( +xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - gtl::ArraySlice initial_values, - StringPiece name, xla::ComputationBuilder* builder); + gtl::ArraySlice initial_values, StringPiece name, + xla::XlaBuilder* builder); // Builds an XLA loop that repeats a computation `num_iterations` times. // // The body function (ForEachIndexBodyFunction) takes as input a pair of // (current iteration number, loop-carried values), and returns an updated // vector of the loop-carried values. -typedef std::function>( - xla::ComputationDataHandle, gtl::ArraySlice, - xla::ComputationBuilder*)> +typedef std::function>( + xla::XlaOp, gtl::ArraySlice, xla::XlaBuilder*)> ForEachIndexBodyFunction; -xla::StatusOr> XlaForEachIndex( +xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - gtl::ArraySlice initial_values, - StringPiece name, xla::ComputationBuilder* builder); + gtl::ArraySlice initial_values, StringPiece name, + xla::XlaBuilder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 2c3cd658e0462368ac0b51938979b7a6815a7574..db56b128375ce8ff2faf12c5d7ea256bdfab0f63 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -40,7 +40,38 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { return Status::OK(); } -Status CopyLiteralToHostTensor(const xla::Literal& literal, +Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), + host_tensor.shape(), &xla_shape)); + *literal = xla::BorrowingLiteral( + static_cast(DMAHelper::base(&host_tensor)), xla_shape); + return Status::OK(); +} + +Status HostTensorsToBorrowingLiteralTuple( + tensorflow::gtl::ArraySlice host_tensors, + xla::BorrowingLiteral* literal) { + std::vector buf_ptrs; + buf_ptrs.reserve(host_tensors.size()); + std::vector tensor_shapes(host_tensors.size()); + + for (int i = 0; i < host_tensors.size(); i++) { + // Validate runtime shapes and fail if it doesn't match the contract. + const Tensor* tensor = &host_tensors[i]; + buf_ptrs.emplace_back(static_cast(DMAHelper::base(tensor))); + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(tensor->dtype(), tensor->shape(), + &tensor_shapes[i])); + } + + *literal = xla::BorrowingLiteral( + buf_ptrs, xla::ShapeUtil::MakeTupleShape(tensor_shapes)); + + return Status::OK(); +} + +Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor) { TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) && xla::ShapeUtil::ElementsIn(literal.shape()) == @@ -63,8 +94,8 @@ Status CopyLiteralToHostTensor(const xla::Literal& literal, return Status::OK(); } -Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, - Tensor* host_tensor) { +Status LiteralToHostTensor(const xla::LiteralSlice& literal, + DataType target_type, Tensor* host_tensor) { TensorShape shape; TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape)); *host_tensor = Tensor(target_type, shape); diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index f283b0236811f8d52e8fe2982a74c11c92cd20d8..74685025c1780c5c0ba56205a98786582e9191e9 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -29,6 +30,17 @@ namespace tensorflow { // unsupported type. Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); +// Returns a BorrowingLiteral that utilizes the same underlying buffer owned by +// 'host_tensor'. +Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal); + +// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers +// owned by 'host_tensors'. +Status HostTensorsToBorrowingLiteralTuple( + tensorflow::gtl::ArraySlice host_tensors, + xla::BorrowingLiteral* literal); + // Copies 'literal' to freshly allocated 'host_tensor', which is allocated of // type . // Fails if the literal's primitive type != @@ -36,13 +48,13 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); // derivable from the type of , because multiple tensorflow types map // to the same XLA type (e.g. INT32 and QINT32 both map to INT32 in // XLA). -Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, - Tensor* host_tensor); +Status LiteralToHostTensor(const xla::LiteralSlice& literal, + DataType target_type, Tensor* host_tensor); // Copies the contents of 'literal' to a previously allocated tensor // 'host_tensor'. The tensor and the literal must have the same number of // elements and the same type. -Status CopyLiteralToHostTensor(const xla::Literal& literal, +Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index 98f72b3792eb147f5a1847c5e1ecef18bccbca5f..bb9168fa358154f3db9dab87bacc9bf28dd16406 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -7,17 +7,13 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") cc_library( - name = "functional_ops", - srcs = ["functional_ops.cc"], - deps = [ - "//tensorflow/core:framework", + name = "xla_ops", + srcs = [ + "dynamic_slice_ops.cc", + "functional_ops.cc", + "reduce_window_op.cc", + "sendrecv_ops.cc", ], - alwayslink = 1, -) - -cc_library( - name = "sendrecv_ops", - srcs = ["sendrecv_ops.cc"], deps = [ "//tensorflow/core:framework", ], @@ -25,31 +21,9 @@ cc_library( ) tf_gen_op_wrapper_py( - name = "gen_functional_ops", - out = "gen_functional_ops.py", - deps = [ - ":functional_ops", - ], -) - -tf_gen_op_wrapper_py( - name = "gen_sendrecv_ops", - out = "gen_sendrecv_ops.py", + name = "gen_xla_ops", + out = "gen_xla_ops.py", deps = [ - ":sendrecv_ops", + ":xla_ops", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..d6c0edbb889b1751ac9d9d47d0c9534b543196ff --- /dev/null +++ b/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("XlaDynamicUpdateSlice") + .Input("input: T") + .Input("update: T") + .Input("indices: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA DynamicUpdateSlice operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice +. + +XlaDynamicUpdateSlice generates a result which is the value of the `input` +operand, with a slice update overwritten at `indices`. The shape of `update` +determines the shape of the sub-array of the result which is updated. The shape +of indices must be rank == 1, with dimension size equal to the rank of `input`. + +Handling of out-of-bounds slice indices is implementation-defined. + +input: A `Tensor` of type T. +indices: A vector of indices into `input`. Must have length equal to the rank of + `input`. +update: A `Tensor` of type T. Same rank as `input`. +output: A `Tensor` of type T. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc b/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d9af982adc090ea78c711fd4656ba429c53b18c9 --- /dev/null +++ b/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("XlaReduceWindow") + .Input("input: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("computation: func") + .Attr("window_dimensions: list(int)") + .Attr("window_strides: list(int)") + .Attr("padding_low: list(int)") + .Attr("padding_high: list(int)") + .Output("output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA ReduceWindow operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . + +input: the input tensor +init_value: a scalar representing the initial value for the reduction +computation: a reducer function to apply +window_dimensions: the shape of the window +window_strides: the inter-window strides +padding_low: the padding to apply at the start of each input dimensions +padding_high: the padding to apply at the end of each input dimension. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc index 4b41c16a8b3fdc0c3412c76d29d3ec2b7bdfd0aa..7ec7b50e905a6cbdecea4543dcb87322b5a7e844 100644 --- a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc @@ -18,22 +18,24 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("_XLASend") +REGISTER_OP("XlaSend") .Input("tensor: T") .Attr("T: type") .Attr("tensor_name: string") .SetIsStateful() .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( -Sends the named tensor to another XLA computation. +Sends the named tensor to another XLA computation. Wraps the XLA Send operator +documented at + https://www.tensorflow.org/performance/xla/operation_semantics#send . tensor: The tensor to send. -tensor_name: The name of the tensor to send. +tensor_name: A string key that identifies the channel. )doc"); -REGISTER_OP("_XLARecv") - .Output("tensor: T") - .Attr("T: type") +REGISTER_OP("XlaRecv") + .Output("tensor: dtype") + .Attr("dtype: type") .Attr("tensor_name: string") .Attr("shape: shape") .SetIsStateful() @@ -46,11 +48,14 @@ REGISTER_OP("_XLARecv") return Status::OK(); }) .Doc(R"doc( -Receives the named tensor from another XLA computation. +Receives the named tensor from another XLA computation. Wraps the XLA Recv +operator documented at + https://www.tensorflow.org/performance/xla/operation_semantics#recv . tensor: The tensor to receive. -tensor_name: The name of the tensor to receive. -shape: The shape of the input tensor. +dtype: The type of the tensor. +tensor_name: A string key that identifies the channel. +shape: The shape of the tensor. )doc"); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index f0a2ef0651ff6115bd201a3b1c34b3c061a22a3d..42b6292f79ffddd155c05758a1420a2a583eb0c6 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -22,3 +22,11 @@ tf_py_clif_cc( "//tensorflow/compiler/tf2xla:xla_compiler", ], ) + +py_library( + name = "xla", + srcs = ["xla.py"], + deps = [ + "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", + ], +) diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ce65bec950fdfd38c3ca5bc62ac745ef8ca4a7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental library that exposes XLA operations directly in TensorFlow. + +It is sometimes useful to be able to build HLO programs directly from +TensorFlow. This file provides Tensorflow operators that map as closely as +possible to HLO operators. + +There is no promise of backward or forward compatibility for operators defined +in this module. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tf2xla.ops import gen_xla_ops + +# TODO(phawkins): provide wrappers for all XLA operators. + +dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice + + +def reduce_window(operand, + init, + reducer, + window_dimensions, + window_strides=None, + padding=None, + name=None): + """Wraps the XLA ReduceWindow operator. + + ReduceWindow is documented at + https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . + + Args: + operand: the input tensor + init: a scalar tensor representing the initial value for the reduction + reducer: a reduction function that combines a pair of scalars. + window_dimensions: shape of the window, as a list of integers + window_strides: inter-window strides, as a list of integers. Optional; + if omitted, defaults to strides of 1. + padding: padding to apply to 'operand'. List of (low, high) pairs of + integers that specify the padding to apply before and after each + dimension. Optional; if omitted, defaults to no padding. + name: the operator name, or None. + Returns: + A tensor that represents the output of the reduce_window operator. + """ + window_strides = window_strides or [1] * len(window_dimensions) + padding = padding or [(0, 0)] * len(window_dimensions) + padding_low = [x for (x, _) in padding] + padding_high = [y for (_, y) in padding] + return gen_xla_ops.xla_reduce_window( + operand, + init, + reducer, + window_dimensions, + window_strides, + padding_low, + padding_high, + name=name) + + +recv = gen_xla_ops.xla_recv +send = gen_xla_ops.xla_send + +while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 1a0e09758f7cc6714793300c6ece14093a8ad246..5759c72af301785f3ca1110b58eeb2fe7dead713 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/util/device_name_utils.h" @@ -65,8 +66,8 @@ ParseShardingFromDevice( if (explicit_sharding.has_value()) { return explicit_sharding; } else if (!parsed_device.has_type || !parsed_device.has_id || - !StringPiece(parsed_device.type) - .contains(kDeviceSuffixReplicatedCore)) { + !str_util::StrContains(parsed_device.type, + kDeviceSuffixReplicatedCore)) { return tensorflow::gtl::optional(); } else { const int core = parsed_device.id; diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 6051d7dffd7493d8cffb07c1b5d10500e7e75522..ac768b206e2a8d163a4253432a1911152f89ce86 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -251,7 +251,7 @@ Status CreateXlaArgs(const Graph& graph, // Converts the TensorFlow graph into an XLA computation, by executing the // graph symbolically, with each op building up the XLA HLO. Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, - xla::Computation* computation) { + xla::XlaComputation* computation) { XlaOpRegistry::RegisterCompilationKernels(); for (Node* node : graph->nodes()) { node->set_assigned_device_name( @@ -263,8 +263,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, // Compile the graph into an XLA computation. XlaCompiler::Options compiler_options; compiler_options.client = client; - DeviceType device_type(DEVICE_CPU_XLA_JIT); - compiler_options.device_type = &device_type; + compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); compiler_options.flib_def = &graph->flib_def(); compiler_options.graph_def_version = graph->versions().producer(); compiler_options.allow_cpu_custom_calls = true; @@ -303,7 +302,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, } // InitGraph creates a graph based on the graph_def, that may then be converted -// to an xla::Computation via ConvertGraphToXla. +// to an xla::XlaComputation via ConvertGraphToXla. // // The graph is rewritten with _Arg and _Retval nodes, representing the inputs // and outputs of the function that will be compiled. Each feed id causes a new @@ -348,7 +347,7 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, - xla::Computation* computation) { + xla::XlaComputation* computation) { std::unique_ptr graph; TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph)); TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation)); diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h index 473c431b12d441c652f1d0d6c11c5e87836ab36d..d02fc56c5b8f58f0e4cfe1779ad34fe3b79324c7 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.h +++ b/tensorflow/compiler/tf2xla/tf2xla.h @@ -18,21 +18,21 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/core/framework/graph.pb.h" namespace tensorflow { -// Converts a tensorflow::GraphDef into an xla::Computation. The given `config` -// specifies the portion of the graph to convert, via feeds and fetches. Each -// feed is a positional input argument for the generated computation, while each -// fetch is a positional output argument. +// Converts a tensorflow::GraphDef into an xla::XlaComputation. The given +// `config` specifies the portion of the graph to convert, via feeds and +// fetches. Each feed is a positional input argument for the generated +// computation, while each fetch is a positional output argument. // // The computation is built in the context of the given `client`, which may // subsequently be used to compile or execute the computation. Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, - xla::Computation* computation); + xla::XlaComputation* computation); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index a9978e697b091715ce120f0d18fdddd259e08b32..84c133ffabe20dbdaa4d5a64e035efb5e4c4c44b 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -69,7 +69,7 @@ TEST(ConvertGraphDefToXla, Sum) { tf2xla::Config config = SumConfig(); xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); - xla::Computation computation; + xla::XlaComputation computation; TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); // Set up arguments. @@ -90,6 +90,11 @@ TEST(ConvertGraphDefToXla, Sum) { TF_EXPECT_OK(result_or.status()); std::unique_ptr result = std::move(result_or.ValueOrDie()); EXPECT_EQ("(s32[]) (\n42\n)", result->ToString()); + + config.mutable_feed(0)->mutable_id()->set_output_index( + 123); /* invalid output_index */ + EXPECT_TRUE(errors::IsInvalidArgument( + ConvertGraphDefToXla(graph_def, config, client, &computation))); } } // namespace diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index f428a194328935fec1210ea96245344de859e611..9203e8d9e607e99ad738350a1c3f2b9e900df179 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -151,8 +151,15 @@ Status AddPlaceholdersForFeeds( Status status; Node* feed_node = g.AddNode(gd.node(0), &status); TF_RETURN_IF_ERROR(status); - info.data_type = - BaseType(feed_node->output_type(info.feed->id().output_index())); + + if (info.feed->id().output_index() < feed_node->num_outputs()) { + info.data_type = + BaseType(feed_node->output_type(info.feed->id().output_index())); + } else { + return errors::InvalidArgument( + "Invalid output_index ", info.feed->id().output_index(), + " for feed node ", info.feed->id().node_name()); + } } } @@ -225,7 +232,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, // Push input nodes of the currently visited node to name_queue. for (const string& in_edge : map_entry.second->input()) { auto id = ParseTensorName(in_edge); - const string node_name = id.first.ToString(); + const string node_name = std::string(id.first); if (feed_tensors.find(std::make_pair(node_name, id.second)) == feed_tensors.end()) { name_queue.push(node_name); @@ -281,4 +288,13 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { return Status::OK(); } +void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, + KernelDef* kdef) { + for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) { + if (constraint.name() == name) { + constraint.mutable_allowed_values()->mutable_list()->add_type(dtype); + } + } +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index e5fba8ede7745febbb42c572a7b52247213afc95..745beb39c1d917cd0d1cd219536ee26a96253ec9 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" @@ -51,6 +52,10 @@ string TensorIdToString(const tf2xla::TensorId& id); // edges are considered. Status SetNodeShardingFromNeighbors(Node* n, bool out_edges); +// Add an allowed data type to the AttrConstraint with the given name. +void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, + KernelDef* kdef); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index ed10d80609641b090cf78bf2e17364fe2fa89c31..ae51446204baf14dc03fc6305641048dbf3872b0 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" @@ -33,7 +34,7 @@ namespace { void ExpectErrorContains(const Status& status, StringPiece str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(StringPiece(status.error_message()).contains(str)) + EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index fcb0a4e63814b4afc114bdaea312a92dd8396a2e..fe7ec633eca2504faf6cbb2f5fd7f59780ab7976 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/platform/mem.h" @@ -108,7 +109,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, // If no sharding metadata is found, XLA is free to use whatever device it // wants. In practice this usually has the effect of placing things on device // 0. - xla::ScopedShardingAssignment assign_sharding(b, op_sharding); + xla::XlaScopedShardingAssignment assign_sharding(b, op_sharding); op_kernel->Compute(context); b->ClearOpMetadata(); @@ -126,9 +127,7 @@ Status XlaCompilationDevice::MakeTensorFromProto( XlaExpression::XlaExpression() = default; -void XlaExpression::set_handle(const xla::ComputationDataHandle& h) { - handle_ = h; -} +void XlaExpression::set_handle(const xla::XlaOp& h) { handle_ = h; } void XlaExpression::set_constant_value(Tensor value) { has_constant_value_ = true; diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index 0243ee332fbdca0fe5e28b1a7d9530df4417f807..d0b9e34e162f3412cd6662a2e2bbfe3df213c4c2 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" @@ -69,7 +69,7 @@ class XlaCompilationDevice : public LocalDevice { // A XlaExpression wraps an XLA computation. Each Tensor on an // XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor -// matches the shape of the subcomputation in the ComputationDataHandle. Each +// matches the shape of the subcomputation in the XlaOp. Each // expression is either a constant, or a function of previously-compiled // expressions. class XlaExpression { @@ -78,8 +78,8 @@ class XlaExpression { // handle() stores the XLA handle of the computation that the // expression represents. - void set_handle(const xla::ComputationDataHandle& h); - const xla::ComputationDataHandle& handle() const { return handle_; } + void set_handle(const xla::XlaOp& h); + const xla::XlaOp& handle() const { return handle_; } void set_constant_value(Tensor value); bool has_constant_value() const { return has_constant_value_; } @@ -90,7 +90,7 @@ class XlaExpression { private: // The XLA handle of the expression's computation. - xla::ComputationDataHandle handle_; + xla::XlaOp handle_; // If this expression is a constant with a known value, 'constant_value' is a // host-memory Tensor containing the value. Used to avoid invoking XLA for diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 15bba46ac62a97592656942afc767a303c9b97f3..9c8e56a17e07348d3cfaaca0b5eb335295af05c3 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -15,10 +15,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include #include +#include -#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" @@ -28,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" @@ -40,7 +38,6 @@ limitations under the License. #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/public/version.h" namespace tensorflow { namespace { @@ -86,12 +83,9 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), initialization_status_(Status::OK()), next_step_id_(1), - device_( - new XlaCompilationDevice(SessionOptions(), *options_.device_type)), + device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), device_mgr_({device_}) { - // We no longer need the device_type. - options_.device_type = nullptr; - + CHECK(!options_.device_type.type_string().empty()); if (options_.populate_resource_manager) { initialization_status_ = (*options_.populate_resource_manager)(device_->resource_manager()); @@ -110,10 +104,10 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) local_flib_runtime_ = local_pflr_->GetFLR(device_->name()); flib_runtime_ = pflr_->GetFLR(device_->name()); - // The default variable representation shape is the identity function. - if (!options_.variable_representation_shape_fn) { - options_.variable_representation_shape_fn = - [](const TensorShape& shape, DataType type) { return shape; }; + // The default shape representation function is the identity. + if (!options_.shape_representation_fn) { + options_.shape_representation_fn = [](const TensorShape& shape, + DataType type) { return shape; }; } } @@ -230,20 +224,25 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // Computes the XLA shape for argument 'arg'. Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, - xla::Shape* xla_shape) { + bool is_entry_computation, + xla::Shape* xla_shape) const { switch (arg.kind) { case XlaCompiler::Argument::kConstant: - return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(), - xla_shape); - case XlaCompiler::Argument::kParameter: - return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); + LOG(FATAL) << "Unreachable case"; + case XlaCompiler::Argument::kParameter: { + TensorShape shape = + is_entry_computation + ? options_.shape_representation_fn(arg.shape, arg.type) + : arg.shape; + return TensorShapeToXLAShape(arg.type, shape, xla_shape); + } case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); switch (arg.resource_kind) { case XlaResource::kVariable: { TensorShape representation_shape = - options_.variable_representation_shape_fn(arg.shape, arg.type); + options_.shape_representation_fn(arg.shape, arg.type); return TensorShapeToXLAShape(arg.type, representation_shape, xla_shape); } @@ -337,16 +336,25 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, Status BuildComputation( const std::vector& args, const std::vector& arg_cores, - const std::vector& retvals, + const std::vector& retvals, const std::vector>& resources, - bool return_updated_values_for_all_resources, - xla::ComputationBuilder* builder, xla::Computation* computation, - int* num_computation_outputs, int* num_nonconst_outputs, + bool return_updated_values_for_all_resources, xla::XlaBuilder* builder, + xla::XlaComputation* computation, int* num_computation_outputs, + int* num_nonconst_outputs, + std::vector* outputs, std::vector* resource_updates) { - std::vector elems; + std::vector elems; elems.reserve(retvals.size()); - for (const XlaExpression& retval : retvals) { - if (!retval.has_constant_value()) { + for (int i = 0; i < retvals.size(); ++i) { + XlaCompiler::OutputDescription& output = (*outputs)[i]; + output.type = retvals[i].type; + output.shape = retvals[i].shape; + const XlaExpression& retval = retvals[i].expression; + if (retval.has_constant_value()) { + output.is_constant = true; + output.constant_value = retval.constant_value(); + } else { + output.is_constant = false; elems.push_back(retval.handle()); } } @@ -365,18 +373,23 @@ Status BuildComputation( return a->arg_num() < b->arg_num(); }); + // Attach a common operator name as metadata. This has no semantic effect — it + // merely makes the HLO graph more readable when visualized via TensorBoard, + // since TensorBoard forms groups out of operators with similar names. + xla::OpMetadata retval_metadata; + retval_metadata.set_op_name("XLA_Retvals"); + builder->SetOpMetadata(retval_metadata); + for (const XlaResource* resource : arg_resources) { const XlaCompiler::Argument& arg = args[resource->arg_num()]; const int core = arg_cores[resource->arg_num()]; DCHECK_LT(resource->arg_num(), arg_cores.size()); - bool modified = - resource->value().handle() != resource->initial_value().handle(); + bool modified = resource->value() != resource->initial_value(); // TensorArray gradients were modified if their values changed or there are // any newly created gradients. for (const auto& grad : resource->tensor_array_gradients()) { modified = modified || - grad.second->value().handle() != - grad.second->initial_value().handle() || + grad.second->value() != grad.second->initial_value() || arg.tensor_array_gradients.count(grad.first) == 0; } if (return_updated_values_for_all_resources || modified) { @@ -391,11 +404,11 @@ Status BuildComputation( } // Request that the value be returned on a specific core. - xla::ScopedShardingAssignment assign_sharding( + xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); - xla::ComputationDataHandle handle; + xla::XlaOp handle; TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); // Since we can't change the sharding metadata of as this point, @@ -412,7 +425,9 @@ Status BuildComputation( // Builds the XLA computation. builder->Tuple(elems); - xla::StatusOr computation_status = builder->Build(); + builder->ClearOpMetadata(); + + xla::StatusOr computation_status = builder->Build(); if (!computation_status.ok()) { return computation_status.status(); } @@ -426,7 +441,7 @@ Status BuildComputation( // `args` are the arguments to the computation. Status XlaCompiler::BuildArguments( const Graph& graph, const std::vector& args, - bool use_tuple_arg, xla::ComputationBuilder* builder, XlaContext* context, + bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, std::vector* arg_cores, std::vector* arg_expressions, std::vector* input_mapping, std::vector* input_shapes, bool is_entry_computation) { @@ -452,8 +467,7 @@ Status XlaCompiler::BuildArguments( // alias. XlaResource* resource; TF_RETURN_IF_ERROR(context->CreateResource( - arg.resource_kind, i, arg.name, arg.type, arg.shape, - xla::ComputationDataHandle(), + arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(), /*tensor_array_size=*/arg.tensor_array_size, /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); arg_expression.set_resource(resource); @@ -484,8 +498,8 @@ Status XlaCompiler::BuildArguments( std::vector arg_shapes(input_mapping->size()); for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { // Computes the shapes of non-constant arguments. - TF_RETURN_IF_ERROR( - XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i])); + TF_RETURN_IF_ERROR(XLAShapeForArgument( + args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i])); } if (use_tuple_arg) { @@ -514,10 +528,17 @@ Status XlaCompiler::BuildArguments( } } + // Attach a common operator name as metadata. This has no semantic effect — it + // merely makes the HLO graph more readable when visualized via TensorBoard, + // since TensorBoard forms groups out of operators with similar names. + xla::OpMetadata arg_metadata; + arg_metadata.set_op_name("XLA_Args"); + builder->SetOpMetadata(arg_metadata); + // Build parameter handles for non-constant arguments. - std::vector arg_handles(input_mapping->size()); + std::vector arg_handles(input_mapping->size()); if (use_tuple_arg) { - xla::ComputationDataHandle tuple; + xla::XlaOp tuple; if (is_entry_computation) { xla::OpSharding tuple_sharding; tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); @@ -528,15 +549,15 @@ Status XlaCompiler::BuildArguments( core == -1 ? xla::sharding_builder::AssignDevice(root_device) : xla::sharding_builder::AssignDevice(core); } - xla::ScopedShardingAssignment assign_tuple_sharding(builder, - tuple_sharding); + xla::XlaScopedShardingAssignment assign_tuple_sharding(builder, + tuple_sharding); tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); } else { tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); } for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const int core = (*arg_cores)[input_mapping->at(i)]; - xla::ScopedShardingAssignment assign_sharding( + xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = builder->GetTupleElement(tuple, i); @@ -544,7 +565,7 @@ Status XlaCompiler::BuildArguments( } else { for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const int core = (*arg_cores)[input_mapping->at(i)]; - xla::ScopedShardingAssignment assign_sharding( + xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = @@ -552,7 +573,10 @@ Status XlaCompiler::BuildArguments( } } - // Fill in the handles in non-constant arguments. + builder->ClearOpMetadata(); + + // Fill in the handles in non-constant arguments, and reshape parameters + // back to their correct shapes. VLOG(2) << "XLA computation inputs:"; for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const XlaCompiler::Argument& arg = args[input_mapping->at(i)]; @@ -571,7 +595,15 @@ Status XlaCompiler::BuildArguments( break; } case XlaCompiler::Argument::kParameter: - arg_expression.set_handle(arg_handles[i]); + // Reshape parameters back to their correct shapes. + // TODO(b/76097077): propagate device assignments onto arguments and + // return values of functions, and then reshape unconditionally. + if (is_entry_computation) { + arg_expression.set_handle( + builder->Reshape(arg_handles[i], arg.shape.dim_sizes())); + } else { + arg_expression.set_handle(arg_handles[i]); + } break; case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kInvalid: @@ -582,12 +614,114 @@ Status XlaCompiler::BuildArguments( return Status::OK(); } +Status XlaCompiler::CompileSingleOp( + const XlaCompiler::CompileOptions& options, string const& name, + OpKernelContext* ctx, const std::vector& args, + CompilationResult* result) { + // TODO(b/74182462): We implement this by creating a new dummy Graph including + // _Arg nodes, and let CompileGraph walk it. This could be optimized. + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + Status status; + // First create the actual node we care about computing. + Node* main_node = graph->AddNode(ctx->op_kernel().def(), &status); + TF_RETURN_IF_ERROR(status); + + // Create dummy _Arg nodes. Link these to `node` and also via a control + // dependency edge to the _SOURCE node. + for (int64 i = 0; i < ctx->num_inputs(); ++i) { + Node* node; + string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); + Status status = NodeBuilder(name, "_Arg") + .ControlInput(graph->source_node()) + .Attr("T", ctx->input_dtype(i)) + .Attr("index", i) + .Finalize(graph.get(), &node); + TF_RETURN_IF_ERROR(status); + graph->AddEdge(node, 0, main_node, i); + } + + // Similarly with return values, create dummy _Retval nodes fed by `node`. + for (int64 i = 0; i < ctx->num_outputs(); ++i) { + Node* node; + string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); + Status status = NodeBuilder(name, "_Retval") + .Input(main_node, i) + .Attr("T", ctx->expected_output_dtype(i)) + .Attr("index", i) + .Finalize(graph.get(), &node); + TF_RETURN_IF_ERROR(status); + } + FixupSourceAndSinkEdges(graph.get()); + + return CompileGraph(options, name, std::move(graph), args, result); +} + +namespace { + +// Check that the ops of all non-functional nodes have been registered. +string ValidateFunctionDef(const FunctionDef* fdef, + const FunctionLibraryDefinition& flib_def) { + std::vector invalid_ops; + for (const NodeDef& node : fdef->node_def()) { + const string& op = node.op(); + if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) { + continue; + } + const OpDef* op_def; + if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) { + invalid_ops.push_back(op); + } + } + return tensorflow::str_util::Join(invalid_ops, ", "); +} + +// Check that the graph doesn't have any invalid nodes (e.g. incompatible with +// given device_type, invalid data type, missing attributes...) +Status ValidateGraph(const Graph* graph, + const FunctionLibraryDefinition& flib_def, + const DeviceType& device_type, const string& name) { + std::vector invalid_ops; + for (const Node* node : graph->nodes()) { + if (node->type_string() == FunctionLibraryDefinition::kGradientOp) { + continue; + } + const FunctionDef* fdef = flib_def.Find(node->def().op()); + if (fdef) { + string error_msg = ValidateFunctionDef(fdef, flib_def); + if (!error_msg.empty()) { + invalid_ops.push_back( + strings::StrCat(node->def().op(), ":{", error_msg, "}")); + } + continue; + } + const OpDef* op_def; + if (!OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def).ok()) { + invalid_ops.push_back(node->def().op()); + continue; + } + TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def)); + if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) { + invalid_ops.push_back(node->def().op()); + } + } + if (!invalid_ops.empty()) { + return errors::InvalidArgument(strings::StrCat( + "Detected unsupported operations when trying to compile graph ", name, + " on ", device_type.type_string(), ":", + tensorflow::str_util::Join(invalid_ops, ", "))); + } + return Status::OK(); +} + +} // namespace + Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, const std::vector& args, CompilationResult* result) { - VLOG(1) << "Executing graph symbolically to populate ComputationBuilder."; + VLOG(1) << "Executing graph symbolically to populate XlaBuilder."; if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " @@ -601,13 +735,19 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, // Converts Tensorflow's graph control-flow constructs into functional // control-flow that can be compiled into XLA code. TF_RETURN_IF_ERROR( - FunctionalizeControlFlow(graph.get(), local_flib_def_.get())); - - xla::ComputationBuilder builder(client(), name); - XlaContext* context = - new XlaContext(this, &builder, options_.allow_cpu_custom_calls, - options.resolve_compile_time_constants, - &options_.variable_representation_shape_fn); + FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), + graph.get(), local_flib_def_.get())); + + // Detect invalid nodes. + // FunctionalizeControlFlow may remove some nodes from the graph. + TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, + options_.device_type, name)); + + xla::XlaBuilder builder(name); + XlaContext* context = new XlaContext( + this, &builder, options_.allow_cpu_custom_calls, + options.resolve_compile_time_constants, options.is_entry_computation, + &options_.shape_representation_fn); core::ScopedUnref context_unref(context); std::vector arg_expressions; @@ -623,59 +763,37 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_nonconst_outputs; int num_computation_outputs; - result->computation = std::make_shared(); + result->computation = std::make_shared(); + result->outputs.resize(context->retvals().size()); TF_RETURN_IF_ERROR(BuildComputation( args, arg_cores, context->retvals(), context->resources(), options.return_updated_values_for_all_resources, &builder, result->computation.get(), &num_computation_outputs, - &num_nonconst_outputs, &result->resource_updates)); + &num_nonconst_outputs, &result->outputs, &result->resource_updates)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; - result->outputs.resize(context->retvals().size()); - for (std::vector::size_type i = 0; - i < context->retvals().size(); ++i) { - const XlaExpression& retval = context->retvals()[i]; - if (retval.has_constant_value()) { - OutputDescription& output = result->outputs[i]; - output.shape = retval.constant_value().shape(); - output.is_constant = true; - output.constant_value = retval.constant_value(); - } - } - // Compute the output shapes, if there is a computation with non-constant + // Compute the XLA output shape, if there is a computation with non-constant // outputs. - auto computation_shape = client()->GetComputationShape(*result->computation); - if (!computation_shape.ok()) { - return computation_shape.status(); - } + TF_ASSIGN_OR_RETURN(std::unique_ptr computation_shape, + client()->GetComputationShape(*result->computation)); - result->xla_output_shape.Swap( - computation_shape.ValueOrDie()->mutable_result()); + result->xla_output_shape.Swap(computation_shape->mutable_result()); VLOG(2) << "XLA output shape: " << xla::ShapeUtil::HumanString(result->xla_output_shape); + // Copy the host transfer metadata to the result. + for (const auto& send : host_compute_sends_) { + *result->host_compute_metadata.add_device_to_host() = send.second; + } + for (const auto& recv : host_compute_recvs_) { + *result->host_compute_metadata.add_host_to_device() = recv.second; + } + // Tensorflow expects a major-to-minor order of results. xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); - // Converts the output shapes to TensorShapes. - int computation_output = 0; - for (std::vector::size_type i = 0; - i < context->retvals().size(); ++i) { - const XlaExpression& retval = context->retvals()[i]; - if (!retval.has_constant_value()) { - TF_RET_CHECK(computation_output < num_computation_outputs) - << "Computation has more outputs than expected"; - OutputDescription& output = result->outputs[i]; - output.is_constant = false; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape( - xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape, - computation_output), - &output.shape)); - ++computation_output; - } - } return Status::OK(); } @@ -690,4 +808,84 @@ Status XlaCompiler::GetChannelHandle(const string& key, return Status::OK(); } +namespace { + +void SetTransfer(const string& key, gtl::ArraySlice types, + gtl::ArraySlice shapes, + tf2xla::HostTransferMetadata* transfer) { + transfer->set_key(key); + CHECK(types.size() == shapes.size()); + for (int i = 0; i < types.size(); ++i) { + tf2xla::TensorMetadata* metadata = transfer->add_metadata(); + metadata->set_type(types[i]); + shapes[i].AsProto(metadata->mutable_shape()); + } +} + +} // namespace + +Status XlaCompiler::SetDeviceToHostMetadata( + const string& key, gtl::ArraySlice types, + gtl::ArraySlice shapes) { + if (host_compute_sends_.find(key) != host_compute_sends_.end()) { + return errors::InvalidArgument( + "Duplicate calls to SetDeviceToHostMetadata with key ", key); + } + tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key]; + SetTransfer(key, types, shapes, &transfer); + return Status::OK(); +} + +Status XlaCompiler::GetDeviceToHostShapes( + const string& key, std::vector* shapes) const { + const auto iter = host_compute_sends_.find(key); + if (iter == host_compute_sends_.end()) { + return errors::InvalidArgument( + "No host compute send shapes registered for key ", key); + } + shapes->clear(); + for (int i = 0; i < iter->second.metadata_size(); ++i) { + TensorShape shape(iter->second.metadata(i).shape()); + shapes->push_back(shape); + } + return Status::OK(); +} + +Status XlaCompiler::SetHostToDeviceMetadata( + const string& key, gtl::ArraySlice types, + gtl::ArraySlice shapes) { + if (host_compute_recvs_.find(key) != host_compute_sends_.end()) { + return errors::InvalidArgument( + "Duplicate calls to SetHostToDeviceMetadata with key ", key); + } + tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key]; + SetTransfer(key, types, shapes, &transfer); + return Status::OK(); +} + +Status XlaCompiler::GetHostComputeControlDependency( + const string& host_compute_name, xla::XlaOp* handle) { + const auto iter = host_compute_control_output_.find(host_compute_name); + if (iter == host_compute_control_output_.end()) { + return errors::InvalidArgument( + "No registered control handle for host compute Op '", host_compute_name, + "'"); + } else { + *handle = iter->second; + } + return Status::OK(); +} + +Status XlaCompiler::SetHostComputeControlDependency( + const string& host_compute_name, const xla::XlaOp& handle) { + if (host_compute_control_output_.find(host_compute_name) != + host_compute_control_output_.end()) { + return errors::InvalidArgument( + "Duplicate control handles registered for for host compute Op ", + host_compute_name); + } + host_compute_control_output_[host_compute_name] = handle; + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index c4449bc4be06daff856eff70c6d89be6ddbcf0ee..c93850ce270502ea1df1f6469963e96e86994fa2 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -37,7 +39,7 @@ class XlaContext; // It does a symbolic execution of the graph starting from specific input // shapes, using a JIT device to convert operators into XLA computations. // -// XlaCompiler is typically invoked from an `_XlaLaunch` operator once the +// XlaCompiler is typically invoked from an `XlaLaunch` operator once the // shapes of all input parameters to the computation are known. This is // because the symbolic execution requires known shapes for all operations. // @@ -66,6 +68,15 @@ class XlaContext; // _Retval values are ordered by _Retval index, whereas kResource values are // ordered by the original _Arg position of the variable. // +// If a shape representation function is provided as part of +// XlaCompiler::CompileOptions, kParameter arguments and return values to an +// entry computation will be reshaped in accordance to the shape function. +// Arguments and return values to a non-entry computation are not reshaped. +// Variable resource arguments are passed and returned in reshaped form, even +// for non-entry computations. This feature allows TensorFlow to keep on-device +// tensors with a different shape to their representation inside the XLA +// computation. +// // In both inputs and outputs, kResource values are placed the end. When // emitting While loop bodies, we must ensure that the loop body has // identical input and output signatures. By moving variable values @@ -170,7 +181,7 @@ class XlaCompiler { }; struct OutputDescription { - // Type and shape of the output. + // Type and shape of the output. The shape is the unflattened shape. DataType type; TensorShape shape; @@ -205,10 +216,12 @@ class XlaCompiler { // original arguments, and are not necessarily in the same order.) std::vector input_mapping; - // Input shapes of the computation. + // Input shapes of the computation. If we are flattening inputs, these are + // the flattened shapes. std::vector xla_input_shapes; - // Output shape in XLA format. The output shape is always a tuple. + // Output shape in XLA format. The output shape is always a tuple. If we + // are flattening outputs, these are the flattened shapes. xla::Shape xla_output_shape; // TensorFlow shapes of outputs, together with the values of any @@ -216,19 +229,25 @@ class XlaCompiler { // containing both constant and non-constant results. std::vector outputs; + // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their + // matching RecvAtHost/SendFromHost Ops in the outer graph. + tf2xla::HostComputeMetadata host_compute_metadata; + // Resources whose values were updated by the computation, ordered // by return value position. Resource updates follow the non-constant // results in the outputs of XLA computation. std::vector resource_updates; // The XLA computation built from the tensorflow subgraph. - std::shared_ptr computation; + std::shared_ptr computation; }; + typedef std::function + ShapeRepresentationFn; struct Options { - // Name of the compilation device to use. Needs to be live only during - // XlaCompiler's constructor. - const DeviceType* device_type = nullptr; + // Name of the compilation device to use. It must be set by the caller. + // The default empty value is invalid. + DeviceType device_type = DeviceType(""); xla::Client* client = nullptr; @@ -245,8 +264,7 @@ class XlaCompiler { // If set, the XLA representation of variables represented to XLA as the // shape given by this shape function. Variables are reshaped to this shape // on write, and reshaped to their original shape on read. - std::function - variable_representation_shape_fn; + ShapeRepresentationFn shape_representation_fn; // If not nullptr, populate_resource_manager is called with the // compilation device's resource manager when the compilation @@ -276,7 +294,7 @@ class XlaCompiler { const NameAttrList& fn_name_attrs, std::vector args, CompilationResult* result); - // Compiles a tensorflow::Graph into an xla::Computation. + // Compiles a tensorflow::Graph into an xla::XlaComputation. // Similar to CompileFunction, but takes a Graph as input rather than a // function. Status CompileGraph(const CompileOptions& options, string const& name, @@ -284,10 +302,19 @@ class XlaCompiler { const std::vector& args, CompilationResult* result); + // Compiles a single Op, given by an OpKernelContext, into an + // xla::XlaComputation. Similar to CompileFunction but takes a single Op as + // input. + Status CompileSingleOp(const CompileOptions& options, string const& name, + OpKernelContext* ctx, + const std::vector& args, + CompilationResult* result); + // Returns the shape of the XLA parameter for an argument 'arg'. // See the class comment for more details about the argument passing // convention. - Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape); + Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation, + xla::Shape* xla_shape) const; // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. @@ -296,6 +323,38 @@ class XlaCompiler { // same XlaCompiler. Status GetChannelHandle(const string& key, xla::ChannelHandle* channel); + // Sets the shapes and types for the device to host transfer associated with + // 'key'. + Status SetDeviceToHostMetadata(const string& key, + gtl::ArraySlice types, + gtl::ArraySlice shapes); + + // Gets the shapes the device to host transfer associated with 'key'. + Status GetDeviceToHostShapes(const string& key, + std::vector* shapes) const; + + // Sets the shapes and types for the host to device transfer associated with + // 'key'. + Status SetHostToDeviceMetadata(const string& key, + gtl::ArraySlice types, + gtl::ArraySlice shapes); + + // In order to avoid deadlocks from dependencies in host computations, it can + // be necessary to enforce a partial order on the execution of HostCompute + // Ops. In particular it may be necessary to constrain the SendToHost for one + // HostCompute to run before blocking on the RecvAtHost for another + // HostCompute. The compiler maintains a mapping from 'host_compute_name' to + // handle, where the handle is an 'output' of the HostCompute Op corresponding + // to 'host_compute_name'. Another HostCompute Op that needs to be sequenced + // later can add the handle as an 'input' to enforce the constraints. + // 'host_compute_name' can be any string the client wishes to use to identify + // a given HostCompute Op as long as the names are unique within the + // compilation. + Status GetHostComputeControlDependency(const string& host_compute_name, + xla::XlaOp* handle); + Status SetHostComputeControlDependency(const string& host_compute_name, + const xla::XlaOp& handle); + const Options& options() const { return options_; } xla::Client* client() const { return options_.client; } FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } @@ -312,7 +371,7 @@ class XlaCompiler { // `args` are the arguments to the computation. Status BuildArguments(const Graph& graph, const std::vector& args, - bool use_tuple_arg, xla::ComputationBuilder* builder, + bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, std::vector* arg_cores, std::vector* arg_expressions, std::vector* input_mapping, @@ -359,6 +418,11 @@ class XlaCompiler { std::unordered_map channels_; + std::unordered_map host_compute_sends_; + std::unordered_map host_compute_recvs_; + + std::unordered_map host_compute_control_output_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index a18eeacd41808884fac9ec5d617cb0d274ea27d8..613230452b74755ce7543ec2ab82861aa0dfeb7a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -25,16 +25,20 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/version.h" @@ -42,8 +46,6 @@ namespace tensorflow { class XlaCompilerTest : public ::testing::Test { protected: - XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} - void SetUp() override { client_ = xla::ClientLibrary::LocalClientOrDie(); @@ -55,7 +57,7 @@ class XlaCompilerTest : public ::testing::Test { XlaCompiler::Options DefaultOptions() { XlaCompiler::Options options; - options.device_type = &cpu_device_type_; + options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); options.client = client_; options.flib_def = flib_def_.get(); return options; @@ -65,7 +67,6 @@ class XlaCompilerTest : public ::testing::Test { return compiler->local_flib_def_.get(); } - DeviceType cpu_device_type_; xla::Client* client_; std::unique_ptr flib_def_; }; @@ -163,7 +164,6 @@ REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT), REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT), DummyDuplicateOp); - // Tests compilation and execution of an empty graph. TEST_F(XlaCompilerTest, EmptyReturnValues) { XlaCompiler compiler(DefaultOptions()); @@ -225,7 +225,7 @@ TEST_F(XlaCompilerTest, Simple) { xla::Literal::CreateR1({4, 143}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { @@ -257,10 +257,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { std::move(graph), args, &result); EXPECT_FALSE(status.ok()); EXPECT_TRUE( - StringPiece(status.error_message()).contains("depends on a parameter")) + str_util::StrContains(status.error_message(), "depends on a parameter")) << status.error_message(); EXPECT_TRUE( - StringPiece(status.error_message()).contains("[[Node: C = Reshape")) + str_util::StrContains(status.error_message(), "[[Node: C = Reshape")) << status.error_message(); } @@ -320,7 +320,8 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { xla::Literal::CreateR1({-7, -42}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE( + xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } { @@ -355,10 +356,80 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { xla::Literal::CreateR1({-7, -42}); std::unique_ptr expected = xla::Literal::MakeTuple({expected0.get(), expected1.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); } } +TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) { + // Define a function with one compile-time constant output and one + // data-dependent output. + // @function.Defun(noinline=True) + // foo(a) {b=7; return b, a; } + const Tensor seven = test::AsScalar(7); + FunctionDef fdef = FunctionDefHelper::Create( + "foo", {"a_0:int32"}, {"const:int32", "a:int32"}, {}, + { + {{"Const"}, "Const", {}, {{"dtype", DT_INT32}, {"value", seven}}}, + }, + {{"a", "a_0"}, {"const", "Const:output:0"}}); + (*fdef.mutable_attr())["_noinline"].set_b(true); + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = fdef; + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(fdef_lib)); + auto arg = ops::_Arg(scope.WithOpName("input_arg"), DT_INT32, 0); + NodeDef foo; + foo.set_name("foo"); + foo.set_op("foo"); + *foo.add_input() = "input_arg"; + Status status; + scope.graph()->AddNode(foo, &status); + TF_ASSERT_OK(status); + NodeDef retval_1; + retval_1.set_name("retval_0"); + retval_1.set_op(FunctionLibraryDefinition::kRetOp); + *retval_1.add_input() = "foo"; + (*retval_1.mutable_attr())["T"].set_type(DT_INT32); + (*retval_1.mutable_attr())["index"].set_i(0); + scope.graph()->AddNode(retval_1, &status); + TF_ASSERT_OK(status); + NodeDef retval_2; + retval_2.set_name("retval_1"); + retval_2.set_op(FunctionLibraryDefinition::kRetOp); + *retval_2.add_input() = "foo:1"; + (*retval_2.mutable_attr())["T"].set_type(DT_INT32); + (*retval_2.mutable_attr())["index"].set_i(1); + scope.graph()->AddNode(retval_2, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + } + + // Builds a description of the arguments. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({1}); + + XlaCompiler::Options options = DefaultOptions(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); + options.flib_def = &flib_def; + XlaCompiler compiler(options); + + XlaCompiler::CompileOptions compile_options; + compile_options.resolve_compile_time_constants = true; + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants", + std::move(graph), args, &result)); + + ASSERT_EQ(2, result.outputs.size()); + EXPECT_TRUE(result.outputs[0].is_constant); + test::ExpectTensorEqual(result.outputs[0].constant_value, + test::AsScalar(7)); + EXPECT_FALSE(result.outputs[1].is_constant); +} + // Tests compilation and execution of a graph that adds two tensors. TEST_F(XlaCompilerTest, ResourceManager) { // Builds a graph that calls the dummy resource Op. @@ -432,21 +503,26 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) { } for (int64 i = 1; i < test_count; ++i) { - auto m1 = - results[i - 1].computation->Snapshot().ValueOrDie()->entry().requests(); - auto m2 = - results[i].computation->Snapshot().ValueOrDie()->entry().requests(); - // Check if every entry is the same. - for (auto& entry1 : m1) { - int64 key = entry1.first; - auto value1 = entry1.second; - auto entry2 = m2.find(key); - auto value2 = entry2->second; - EXPECT_TRUE(entry2 != m2.end()); - string str1, str2; - value1.AppendToString(&str1); - value2.AppendToString(&str2); - EXPECT_EQ(str1, str2); + const auto& m1 = results[i - 1].computation->proto(); + const auto& m2 = results[i].computation->proto(); + ASSERT_EQ(m1.computations_size(), m2.computations_size()); + // Check if every hlo computation is the same. + for (int k = 0; k < m1.computations_size(); k++) { + const auto& c1 = m1.computations(k); + const auto& c2 = m2.computations(k); + ASSERT_EQ(c1.instructions_size(), c2.instructions_size()); + for (int j = 0; j < c1.instructions_size(); j++) { + auto instr1 = c1.instructions(j); + auto instr2 = c2.instructions(j); + instr1.clear_name(); + instr2.clear_name(); + // The names of instructions were uniquified by the XlaBuilder, the rest + // of the fields should be identical. + string str1, str2; + instr1.AppendPartialToString(&str1); + instr2.AppendPartialToString(&str2); + EXPECT_EQ(str1, str2); + } } } } @@ -518,7 +594,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { {output_base.get(), output_grad1.get(), output_grad2.get()}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({output_read.get(), output_resource.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } // Tests compilation and execution of a graph that adds two tensors. @@ -597,7 +673,8 @@ TEST_F(XlaCompilerTest, UndefinedFunctionFails) { compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, /*args=*/{}, &result); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined.")) + EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), + "is not defined.")) << status.error_message(); } @@ -676,11 +753,12 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { ASSERT_FALSE(status.ok()); // Flib lookup failure. - EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined.")) + EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), + "is not defined.")) << status.error_message(); // Local flib lookup failure. - EXPECT_TRUE( - StringPiece(status.error_message()).contains("Attr T is not found")) + EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), + "Attr T is not found")) << status.error_message(); } @@ -739,13 +817,10 @@ TEST_F(XlaCompilerTest, Variables) { xla::Literal::CreateR1({4, 143}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get(), expected1.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } -// Tests a simple graph that reads and writes a variable, with a -// variable_representation_shape_fn passed to the compiler that flattens all -// variable tensors to vectors. -TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { +xla::StatusOr> BuildTestGraph() { Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); @@ -756,7 +831,15 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); std::unique_ptr graph(new Graph(OpRegistry::Global())); - TF_ASSERT_OK(scope.ToGraph(graph.get())); + TF_RETURN_IF_ERROR(scope.ToGraph(graph.get())); + return std::move(graph); +} + +// Tests a simple graph that reads and writes a variable, with a +// shape_representation_fn passed to the compiler that flattens all +// variable tensors to vectors. +TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, BuildTestGraph()); // Builds a description of the arguments. std::vector args(2); @@ -771,15 +854,33 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Compiles the graph. XlaCompiler::Options options = DefaultOptions(); - options.variable_representation_shape_fn = [](const TensorShape& shape, - DataType type) { + options.shape_representation_fn = [](const TensorShape& shape, + DataType type) { return TensorShape({shape.num_elements()}); }; XlaCompiler compiler(options); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = false; // Only reshape variables. + XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", - std::move(graph), args, &result)); + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr program_shape, + client_->GetComputationShape(*result.computation)); + + ASSERT_EQ(program_shape->parameters_size(), 2); + EXPECT_TRUE( + xla::ShapeUtil::Compatible(program_shape->parameters(0), + xla::ShapeUtil::MakeShape(xla::S32, {2, 2}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->result(), + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {2, 2}), + xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. std::unique_ptr param0_literal = @@ -804,7 +905,186 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { xla::Literal::CreateR1({26, 66, 34, 401}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get(), expected1.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); +} + +TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, BuildTestGraph()); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 2}); + + // Compiles the graph. + XlaCompiler::Options options = DefaultOptions(); + options.shape_representation_fn = [](const TensorShape& shape, + DataType type) { + return TensorShape({shape.num_elements()}); + }; + XlaCompiler compiler(options); + + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; // Reshape args and retvals. + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr program_shape, + client_->GetComputationShape(*result.computation)); + + ASSERT_EQ(program_shape->parameters_size(), 2); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(0), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->result(), + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {4}), + xla::ShapeUtil::MakeShape(xla::S32, {4})}))); + + // Tests that the generated computation works. + std::unique_ptr param0_literal = + xla::Literal::CreateR1({4, 55, 1, -3}); + std::unique_ptr param1_literal = + xla::Literal::CreateR1({22, 11, 33, 404}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr expected0 = + xla::Literal::CreateR1({27, 67, 35, 402}); + std::unique_ptr expected1 = + xla::Literal::CreateR1({26, 66, 34, 401}); + std::unique_ptr expected_literal = + xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); +} + +// Tests a graph which has a function with an invalid op. +TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { + XlaCompiler compiler(DefaultOptions()); + + FunctionDefLibrary flib; + FunctionDef fn = FillFn(); + NodeDef* node = fn.add_node_def(); + node->set_name("Invalid"); + node->set_op("InvalidOp"); /* unsupported op */ + node = fn.add_node_def(); + node->set_name("Switch"); + node->set_op("Switch"); /* control flow node */ + *flib.add_function() = fn; + + TF_ASSERT_OK(flib_def_->AddFunctionDef(fn)); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto value = ops::Const(scope.WithOpName("value"), 1, {}); + auto shape = ops::Const(scope.WithOpName("shape"), {5}, {1}); + TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib)); + + NodeDef def; + TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get()) + .Input(value.name(), 0, DT_INT32) + .Input(shape.name(), 1, DT_INT32) + .Finalize(&def)); + Status status; + Node* fill = scope.graph()->AddNode(def, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.DoShapeInference(fill)); + scope.graph()->AddEdge(value.node(), 0, fill, 0); + scope.graph()->AddEdge(shape.node(), 0, fill, 1); + + auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0); + + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + std::vector args; + XlaCompiler::CompilationResult result; + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", + std::move(graph), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE( + str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}")) + << status.error_message(); +} + +// Tests a graph which has a node with invalid data type. +TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + NodeDef shape; + shape.set_name("Shape"); + shape.set_op("Shape"); + (*shape.mutable_attr())["T"].set_type(DT_INT32); + (*shape.mutable_attr())["out_type"].set_type(DT_BOOL); /* invalid type */ + Status status; + Node* shape_node = graph->AddNode(shape, &status); + TF_ASSERT_OK(status); + graph->AddControlEdge(graph->source_node(), shape_node); + + std::vector args; + XlaCompiler::CompilationResult result; + XlaCompiler compiler(DefaultOptions()); + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type", + std::move(graph), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "is not in the list of allowed values")) + << status.error_message(); +} + +TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + NodeDef no_op; + no_op.set_name("NoOp"); + no_op.set_op("NoOp"); + Status status; + graph->AddNode(no_op, &status); + TF_ASSERT_OK(status); + + std::vector args; + XlaCompiler compiler(DefaultOptions()); + // No control edge linking NoOp with source/sink. + { + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompilationResult result; + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", + std::move(graph_copy), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "The following nodes are unreachable " + "from the source in the graph: NoOp")) + << status.error_message(); + } + + // Fix control edges for NoOp. + { + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get())); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", + std::move(graph_copy), args, &result)); + EXPECT_EQ(0, result.resource_updates.size()); + } } } // namespace diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 8423921086fec1cf534cf613102fc3839035cb85..098072d33cd4eb7f7dec0ec4196b43eca0220d4a 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -63,28 +63,32 @@ void XlaContext::set_args(std::vector args) { } XlaContext::XlaContext( - XlaCompiler* compiler, xla::ComputationBuilder* builder, + XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + bool is_entry_computation, const std::function* - variable_representation_shape_fn) + shape_representation_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), resolve_compile_time_constants_(resolve_compile_time_constants), - variable_representation_shape_fn_(variable_representation_shape_fn) {} + is_entry_computation_(is_entry_computation), + shape_representation_fn_(shape_representation_fn) {} string XlaContext::DebugString() { return "TLA JIT context"; } // This is called by the Retval Op to associate a computed value // with a specific return value of the subgraph. void XlaContext::AddRetval(int retval_index, DataType type, - const xla::ComputationDataHandle& handle) { + const TensorShape& shape, const xla::XlaOp& handle) { VLOG(1) << "Added retval index " << retval_index << " to XLA computation"; // Add the return value to the list being built up. if (retvals_.size() <= retval_index) { retvals_.resize(retval_index + 1); } - retvals_[retval_index].set_handle(handle); + XlaExpression e; + e.set_handle(handle); + retvals_[retval_index] = Retval{type, shape, e}; } Status XlaContext::AddConstRetval(int retval_index, DataType dtype, @@ -94,23 +98,20 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype, if (retvals_.size() <= retval_index) { retvals_.resize(retval_index + 1); } - if (resolve_compile_time_constants_) { - Tensor value; - TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); - retvals_[retval_index].set_constant_value(std::move(value)); - } else { - retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal)); - } + Tensor value; + TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); + XlaExpression e; + e.set_constant_value(value); + retvals_[retval_index] = Retval{dtype, value.shape(), e}; return Status::OK(); } -xla::ComputationBuilder* XlaContext::builder() { return builder_; } +xla::XlaBuilder* XlaContext::builder() { return builder_; } Status XlaContext::CreateResource( XlaResource::Kind kind, int arg_num, string name, DataType type, - TensorShape shape, const xla::ComputationDataHandle& handle, - int64 tensor_array_size, const std::set& tensor_array_gradients, - XlaResource** resource) { + TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size, + const std::set& tensor_array_gradients, XlaResource** resource) { resources_.emplace_back( new XlaResource(kind, arg_num, std::move(name), type, std::move(shape), handle, tensor_array_size, tensor_array_gradients)); @@ -118,16 +119,16 @@ Status XlaContext::CreateResource( return Status::OK(); } -TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape, - DataType type) const { - return (*variable_representation_shape_fn_)(shape, type); +TensorShape XlaContext::RepresentationShape(const TensorShape& shape, + DataType type) const { + return (*shape_representation_fn_)(shape, type); } -const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) { +const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { return LookupOrCreate(type, &max_func_, [this, type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Max() for " << type_string; - xla::ComputationBuilder b(builder()->client(), "max<" + type_string + ">"); + xla::XlaBuilder b("max<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); @@ -137,11 +138,11 @@ const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) { }); } -const xla::Computation* XlaContext::GetOrCreateMin(const DataType type) { +const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) { return LookupOrCreate(type, &min_func_, [this, type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Min() for " << type_string; - xla::ComputationBuilder b(builder()->client(), "min<" + type_string + ">"); + xla::XlaBuilder b("min<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); @@ -151,11 +152,11 @@ const xla::Computation* XlaContext::GetOrCreateMin(const DataType type) { }); } -const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) { +const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) { return LookupOrCreate(type, &add_func_, [this, type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Add() for " << type_string; - xla::ComputationBuilder b(builder()->client(), "add<" + type_string + ">"); + xla::XlaBuilder b("add<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); @@ -165,11 +166,11 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) { }); } -const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) { +const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) { return LookupOrCreate(type, &mul_func_, [this, type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Mul() for " << type_string; - xla::ComputationBuilder b(builder()->client(), "mul<" + type_string + ">"); + xla::XlaBuilder b("mul<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); @@ -179,9 +180,9 @@ const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) { }); } -const xla::Computation* XlaContext::LookupOrCreate( +const xla::XlaComputation* XlaContext::LookupOrCreate( DataType type, ComputationMap* out, - const std::function& create) { + const std::function& create) { { const auto& entry = (*out)[type]; if (!entry.IsNull()) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 00fbaba37c542954f690b310a184cff985a05156..341bf6ff1f37fa7cd81f41c02a941214067b1bd1 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -42,32 +42,44 @@ class XlaContext : public ResourceBase { static XlaContext& Get(const OpKernelContext* ctx); static XlaContext& Get(const XlaOpKernelContext* ctx); - // Creates a new XlaContext. - XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, + // Creates a new XlaContext. See the documentation on the class data fields + // for descriptions of the arguments. + XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + bool is_entry_computation, const std::function* - variable_representation_shape_fn); + shape_representation_fn); // Virtual method defined by ResourceBase. string DebugString() override; XlaCompiler* compiler() const { return compiler_; } - // Returns the ComputationBuilder that Ops use for compiling new - // expressions. - xla::ComputationBuilder* builder(); + // Returns the XlaBuilder that Ops use for compiling new expressions. + xla::XlaBuilder* builder(); bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } + bool resolve_compile_time_constants() const { + return resolve_compile_time_constants_; + } + bool is_entry_computation() const { return is_entry_computation_; } + const std::vector& args() const { return args_; } void set_args(std::vector args); - const std::vector& retvals() { return retvals_; } + struct Retval { + DataType type; + TensorShape shape; + // An XlaExpression representing the Retval's value. + XlaExpression expression; + }; + const std::vector& retvals() { return retvals_; } // This is called by the Retval Op to associate a computed value // with a specific return value of the subgraph. - void AddRetval(int retval_index, DataType type, - const xla::ComputationDataHandle& handle); + void AddRetval(int retval_index, DataType type, const TensorShape& shape, + const xla::XlaOp& handle); // As for Retval, but for return values that are compile-time constants. Status AddConstRetval(int retval_index, DataType dtype, @@ -79,8 +91,7 @@ class XlaContext : public ResourceBase { // Fails if the resource already exists. Status CreateResource(XlaResource::Kind kind, int arg_num, string name, DataType type, TensorShape shape, - const xla::ComputationDataHandle& handle, - int64 tensor_array_size, + const xla::XlaOp& handle, int64 tensor_array_size, const std::set& tensor_array_gradients, XlaResource** resource); @@ -89,29 +100,29 @@ class XlaContext : public ResourceBase { } // Returns the XLA shape to be used to represent a variable of TF `shape` - // and `type`. - TensorShape VariableRepresentationShape(const TensorShape& shape, - DataType type) const; + // and `type`, or of an argument or return value of a top-level computation. + TensorShape RepresentationShape(const TensorShape& shape, + DataType type) const; // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateMax(const DataType type); + const xla::XlaComputation* GetOrCreateMax(const DataType type); // Get an XLA lambda to compute Min. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateMin(const DataType type); + const xla::XlaComputation* GetOrCreateMin(const DataType type); // Get an XLA lambda to compute Add. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateAdd(const DataType type); + const xla::XlaComputation* GetOrCreateAdd(const DataType type); // Get an XLA lambda to compute Mul. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateMul(const DataType type); + const xla::XlaComputation* GetOrCreateMul(const DataType type); // The name of the XlaContext resource during symbolic graph execution. static const char kXlaContextResourceName[]; @@ -119,9 +130,8 @@ class XlaContext : public ResourceBase { private: XlaCompiler* const compiler_; - // The ComputationBuilder used to construct the subgraph's compiled - // representation. - xla::ComputationBuilder* builder_; + // The XlaBuilder used to construct the subgraph's compiled representation. + xla::XlaBuilder* builder_; // Allow ops to emit CustomCall operations for CPU. const bool allow_cpu_custom_calls_; @@ -135,25 +145,33 @@ class XlaContext : public ResourceBase { std::vector args_; // Return values of the Tensorflow graph, indexed by _Retval index. - std::vector retvals_; + std::vector retvals_; // Holds ownership of resources. The resources are not ordered. std::vector> resources_; - // A function that describes how variable shapes should be represented - // in XLA. Variable values will be reshaped to this shape. Must be non-null. + // Is this a top-level computation, or an inner computation (e.g., a while + // body)? + const bool is_entry_computation_; + + // A function that describes how the shapes of + // a) argument and return value, for entry computations + // b) variables, for all computations, + // should be represented in XLA. Parameters/return values will be shaped + // according to this function, and reshaped back to/from their declared shapes + // for computations. Must be non-null. const std::function* - variable_representation_shape_fn_; + shape_representation_fn_; // Cache of prebuilt computations indexed by their type. - using ComputationMap = std::map; + using ComputationMap = std::map; // Finds the value for the given type in out map if it already // exists or makes a new value with create function and keeps it the // map. The returned value != nullptr and is owned by the map. - const xla::Computation* LookupOrCreate( + const xla::XlaComputation* LookupOrCreate( DataType type, ComputationMap* out, - const std::function& create); + const std::function& create); // Cached computation to compute Max of two elements, specialized by type. ComputationMap max_func_; diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc index 8286480e0ea07429adbe31ec4f16d043e321df0a..ead229aaccc292d4944db0c1eaf98c82583533cd 100644 --- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def.pb.h" @@ -30,6 +31,12 @@ bool CpuOpFilter(KernelDef* kdef) { DT_FLOAT); return true; } + if (kdef->op() == "Const") { + AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); + } + if (kdef->op() == "Assert") { + AddDtypeToKernalDefConstraint("T", DT_STRING, kdef); + } return true; } diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc index 8ca757e72355d890c13b8b448d35c327d3986696..62168b648331844bfe2db1a4d5dcad895c8726f3 100644 --- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def.pb.h" @@ -25,6 +26,12 @@ bool GpuOpFilter(KernelDef* kdef) { kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") { return false; } + if (kdef->op() == "Const") { + AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); + } + if (kdef->op() == "Assert") { + AddDtypeToKernalDefConstraint("T", DT_STRING, kdef); + } return true; } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index f048662953e20b2a612271e2daeef6e370c4822a..a1da176fe30ddd0d4460a51b60b2568ecc1af6aa 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -19,25 +19,27 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { namespace { -Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, - const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, bool is_min, - xla::ComputationDataHandle* argminmax) { - xla::ComputationDataHandle init_value; - const xla::Computation* reducer; +Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, + const xla::XlaOp& input, const TensorShape& input_shape, + DataType input_type, DataType output_type, int axis, + bool is_min, xla::XlaOp* argminmax) { + xla::XlaOp init_value; + const xla::XlaComputation* reducer; if (is_min) { init_value = XlaHelpers::MaxValue(builder, input_type); reducer = ctx->GetOrCreateMin(input_type); @@ -49,13 +51,13 @@ Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx, xla::PrimitiveType xla_output_type; TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type)); - xla::ComputationDataHandle input_max = builder->Reduce( - input, init_value, *reducer, /*dimensions_to_reduce=*/{axis}); + xla::XlaOp input_max = builder->Reduce(input, init_value, *reducer, + /*dimensions_to_reduce=*/{axis}); std::vector broadcast_dims(input_shape.dims() - 1); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); // Compute a mask that has 1s for elements equal to the maximum. - xla::ComputationDataHandle partial_mask = builder->ConvertElementType( + xla::XlaOp partial_mask = builder->ConvertElementType( builder->Eq(input, input_max, broadcast_dims), xla_output_type); // In order to make identity elements for a bitwise And, we: @@ -64,23 +66,23 @@ Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx, // 0xFF...F int32 bits_in_type = xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1; - xla::ComputationDataHandle shift_amount = + xla::XlaOp shift_amount = XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type); - xla::ComputationDataHandle full_mask = builder->ShiftRightArithmetic( + xla::XlaOp full_mask = builder->ShiftRightArithmetic( builder->ShiftLeft(partial_mask, shift_amount), shift_amount); // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its // index. - xla::ComputationDataHandle iota; + xla::XlaOp iota; const int64 axis_size = input_shape.dim_size(axis); TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota)); - xla::ComputationDataHandle product = + xla::XlaOp product = builder->And(full_mask, iota, /*broadcast_dimensions=*/{axis}); // If there are multiple maximum elements, choose the one with the highest // index. - xla::ComputationDataHandle output = + xla::XlaOp output = builder->Reduce(product, XlaHelpers::MinValue(builder, output_type), *ctx->GetOrCreateMax(output_type), /*dimensions_to_reduce=*/{axis}); @@ -90,37 +92,35 @@ Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx, } // namespace -xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b, - DataType data_type) { +xla::XlaOp XlaHelpers::MinValue(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return b->ConstantLiteral(xla::Literal::MinValue(type)); } -xla::ComputationDataHandle XlaHelpers::MaxValue(xla::ComputationBuilder* b, - DataType data_type) { +xla::XlaOp XlaHelpers::MaxValue(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return b->ConstantLiteral(xla::Literal::MaxValue(type)); } -xla::ComputationDataHandle XlaHelpers::Zero(xla::ComputationBuilder* b, - DataType data_type) { +xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return b->ConstantLiteral(xla::Literal::Zero(type)); } -xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, - DataType data_type) { +xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return b->ConstantLiteral(xla::Literal::One(type)); } -xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, - DataType data_type) { +xla::XlaOp XlaHelpers::Epsilon(xla::XlaBuilder* b, DataType data_type) { switch (data_type) { + case DT_HALF: + return b->ConstantR0( + static_cast(Eigen::NumTraits::epsilon())); case DT_BFLOAT16: return b->ConstantR0(bfloat16::epsilon()); case DT_FLOAT: @@ -133,16 +133,15 @@ xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, } } -xla::ComputationDataHandle XlaHelpers::IntegerLiteral( - xla::ComputationBuilder* b, DataType data_type, int64 value) { +xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type, + int64 value) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return ::tensorflow::IntegerLiteral(b, type, value); } -xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, - DataType data_type, - double value) { +xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, + double value) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return ::tensorflow::FloatLiteral(b, type, value); @@ -179,28 +178,24 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { return linspace; } -Status XlaHelpers::ArgMax(xla::ComputationBuilder* builder, - XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, +Status XlaHelpers::ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, + const xla::XlaOp& input, const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, - xla::ComputationDataHandle* argmax) { + DataType output_type, int axis, xla::XlaOp* argmax) { return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, axis, /*is_min=*/false, argmax); } -Status XlaHelpers::ArgMin(xla::ComputationBuilder* builder, - XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, +Status XlaHelpers::ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, + const xla::XlaOp& input, const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, - xla::ComputationDataHandle* argmin) { + DataType output_type, int axis, xla::XlaOp* argmin) { return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, axis, /*is_min=*/true, argmin); } -Status XlaHelpers::Iota(xla::ComputationBuilder* builder, DataType dtype, - int64 size, xla::ComputationDataHandle* iota) { +Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, + xla::XlaOp* iota) { TensorShape linspace_shape({size}); Tensor linspace; switch (dtype) { @@ -217,19 +212,17 @@ Status XlaHelpers::Iota(xla::ComputationBuilder* builder, DataType dtype, return errors::InvalidArgument("Invalid argument type ", DataTypeString(dtype)); } - xla::Literal linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); + xla::BorrowingLiteral linspace_literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); + *iota = builder->ConstantLiteral(linspace_literal); return Status::OK(); } -Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth, - int axis, DataType index_type, - const TensorShape& indices_shape, - const xla::ComputationDataHandle& indices, - const xla::ComputationDataHandle& on_value, - const xla::ComputationDataHandle& off_value, - xla::ComputationDataHandle* one_hot) { +Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, + DataType index_type, const TensorShape& indices_shape, + const xla::XlaOp& indices, const xla::XlaOp& on_value, + const xla::XlaOp& off_value, xla::XlaOp* one_hot) { const int indices_dims = indices_shape.dims(); const int output_dims = indices_dims + 1; @@ -255,15 +248,15 @@ Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth, return errors::InvalidArgument("Invalid argument type ", DataTypeString(index_type)); } - xla::Literal linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); + xla::BorrowingLiteral linspace_literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); // Broadcast the linspace constant across the indices along the new axis, // and test equality at each position. std::vector broadcast_dims(indices_shape.dims()); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - xla::ComputationDataHandle one_hot_bool = builder->Eq( + xla::XlaOp one_hot_bool = builder->Eq( indices, builder->ConstantLiteral(linspace_literal), broadcast_dims); // Selects the user-provided off_value and on_value values. @@ -273,4 +266,19 @@ Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth, return Status::OK(); } +DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { + if (dtype == DT_BFLOAT16 || dtype == DT_HALF) { + return DT_FLOAT; + } + return dtype; +} + +xla::XlaOp XlaHelpers::ConvertElementType(xla::XlaBuilder* const builder, + const xla::XlaOp& operand, + const DataType new_element_type) { + xla::PrimitiveType convert_to; + TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to)); + return builder->ConvertElementType(operand, convert_to); +} + } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 2a027db4c839c917f3a7acd27184792d157356bf..c3fdc5252e74363fe289eeabb2cb0d68298ee291 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -19,7 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -30,41 +30,34 @@ class XlaHelpers { public: // Returns a handle representing the minimum value of a scalar // element of data_type. - static xla::ComputationDataHandle MinValue(xla::ComputationBuilder* b, - DataType data_type); + static xla::XlaOp MinValue(xla::XlaBuilder* b, DataType data_type); // Returns a handle representing the maximum value of a scalar // element of data_type. - static xla::ComputationDataHandle MaxValue(xla::ComputationBuilder* b, - DataType data_type); + static xla::XlaOp MaxValue(xla::XlaBuilder* b, DataType data_type); // Returns a handle representing the zero value of a scalar // element of data_type. - static xla::ComputationDataHandle Zero(xla::ComputationBuilder* b, - DataType data_type); + static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type); // Returns a handle representing the one value of a scalar // element of data_type. - static xla::ComputationDataHandle One(xla::ComputationBuilder* b, - DataType data_type); + static xla::XlaOp One(xla::XlaBuilder* b, DataType data_type); // Returns the machine epsilon for floating-point type `data_type`, i.e., // the difference between 1.0 and the next representable value. - static xla::ComputationDataHandle Epsilon(xla::ComputationBuilder* b, - DataType data_type); + static xla::XlaOp Epsilon(xla::XlaBuilder* b, DataType data_type); // Returns a handle representing the given value of an integer scalar // element of data_type. // Note that unlike One and Zero, does not work on boolean types. - static xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* b, - DataType data_type, - int64 value); + static xla::XlaOp IntegerLiteral(xla::XlaBuilder* b, DataType data_type, + int64 value); // Returns a handle representing the given value of a floating-point scalar // element of data_type. - static xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* b, - DataType data_type, - double value); + static xla::XlaOp FloatLiteral(xla::XlaBuilder* b, DataType data_type, + double value); // Reshapes literal 'input' to have 'shape'. Both the original shape and // 'shape' must contain the same number of elements. @@ -75,38 +68,43 @@ class XlaHelpers { // Sets `argmax` to the argmax of `input` along `axis`. `input_shape` and // `input_dtype` are the shape and dtype of `input` respectively, and // `output_type` is the dtype to use for `argmax`. - static Status ArgMax(xla::ComputationBuilder* builder, - XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, - const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, - xla::ComputationDataHandle* argmax); + static Status ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, + const xla::XlaOp& input, const TensorShape& input_shape, + DataType input_type, DataType output_type, int axis, + xla::XlaOp* argmax); // Sets `argmin` to the argmin of `input` along `axis`. `input_shape` and // `input_dtype` are the shape and dtype of `input` respectively, and // `output_type` is the dtype to use for `argmin`. - static Status ArgMin(xla::ComputationBuilder* builder, - XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, - const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, - xla::ComputationDataHandle* argmin); + static Status ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, + const xla::XlaOp& input, const TensorShape& input_shape, + DataType input_type, DataType output_type, int axis, + xla::XlaOp* argmin); // Sets *iota to a rank 1 tensor with values [0, 1, 2, ...] of `dtype`. - static Status Iota(xla::ComputationBuilder* builder, DataType dtype, - int64 size, xla::ComputationDataHandle* iota); + static Status Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, + xla::XlaOp* iota); // Converts `indices` into a one-hot representation. `depth` is the size // of the new axis to add. `axis` is the position at which to add the new // axis. `indices_shape` is the shape of `indices`. `on_value` and // `off_value` represent the values to use for the on and off positions, // respectively. - static Status OneHot(xla::ComputationBuilder* builder, int64 depth, int axis, + static Status OneHot(xla::XlaBuilder* builder, int64 depth, int axis, DataType index_type, const TensorShape& indices_shape, - const xla::ComputationDataHandle& indices, - const xla::ComputationDataHandle& on_value, - const xla::ComputationDataHandle& off_value, - xla::ComputationDataHandle* one_hot); + const xla::XlaOp& indices, const xla::XlaOp& on_value, + const xla::XlaOp& off_value, xla::XlaOp* one_hot); + + // Certain DataTypes should use increased precision DataTypes when performing + // reductions. This function remaps a given DataType to a higher precision + // DataType if needed. + static DataType SumAccumulationType(const DataType& dtype); + + // A helper for creating a ConvertElementType xla op given a DataType rather + // than the xla::PrimitiveType. + static xla::XlaOp ConvertElementType(xla::XlaBuilder* const builder, + const xla::XlaOp& operand, + const DataType new_element_type); }; } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 1fe6e69ff2dc838152032ac3d7b21de41684c6f6..9e17756b27733e2453ea1688d13e1d718c25cfc8 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -112,10 +112,10 @@ void CollectNames(const T& entries, std::vector* nonempty_names, XlaJitCompiledCpuFunction::Compile( const GraphDef& graph_def, const tf2xla::Config& config, const xla::ExecutableBuildOptions& build_options) { - // Convert the graph_def into an xla::Computation. + // Convert the graph_def into an xla::XlaComputation. TF_ASSIGN_OR_RETURN(xla::LocalClient * client, xla::ClientLibrary::GetOrCreateLocalClient()); - xla::Computation computation; + xla::XlaComputation computation; TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla(graph_def, config, client, &computation)); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index c4bb90d58755f16672ca7c6a6738065be6330485..76c68d81af4dd9ec40fe6b1c33b03a876a0c6dc6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -30,7 +30,7 @@ bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { return context_->ValidateInputsAreSameShape(op); } -xla::ComputationBuilder* XlaOpKernelContext::builder() const { +xla::XlaBuilder* XlaOpKernelContext::builder() const { return XlaContext::Get(this).builder(); } @@ -38,9 +38,9 @@ xla::ComputationBuilder* XlaOpKernelContext::builder() const { static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { const XlaExpression* expression = reinterpret_cast(tensor.tensor_data().data()); - CHECK(expression->handle().handle() != 0 || + CHECK(expression->handle().builder() != nullptr || expression->resource() != nullptr); - VLOG(1) << "Fetched T" << expression->handle().handle(); + VLOG(1) << "Fetched T" << expression->handle(); return expression; } @@ -48,20 +48,18 @@ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) { const XlaExpression* expression = reinterpret_cast(tensor->tensor_data().data()); - CHECK_EQ(expression->handle().handle(), 0); + CHECK_EQ(expression->handle().builder(), nullptr); return const_cast(expression); } -// Retrieves the ComputationDataHandle from an input Tensor to an Op. This -// computation was constructed by an Op that executed previously and -// created the output Tensor using CreateOutputTensorFromComputation -// or CreateConstantOutputTensor. -static const xla::ComputationDataHandle& GetComputationFromTensor( - const Tensor& tensor) { +// Retrieves the XlaOp from an input Tensor to an Op. This computation was +// constructed by an Op that executed previously and created the output Tensor +// using CreateOutputTensorFromComputation or CreateConstantOutputTensor. +static const xla::XlaOp& GetComputationFromTensor(const Tensor& tensor) { return CastExpressionFromTensor(tensor)->handle(); } -const xla::ComputationDataHandle& XlaOpKernelContext::Input(int index) { +const xla::XlaOp& XlaOpKernelContext::Input(int index) { return GetComputationFromTensor(context_->input(index)); } @@ -106,7 +104,7 @@ Status XlaOpKernelContext::ConstantInputReshaped( return HostTensorToLiteral(temp, constant_literal); } - xla::ComputationDataHandle handle = expression->handle(); + xla::XlaOp handle = expression->handle(); if (new_shape != tensor.shape()) { // Reshape the handle to the desired shape. handle = builder()->Reshape(handle, new_shape.dim_sizes()); @@ -141,8 +139,17 @@ Status XlaOpKernelContext::ConstantInputReshaped( } // Ask the XLA compiler to evaluate the data handle to a literal. + xla::StatusOr constant_graph = + builder()->BuildConstantSubGraph(handle); + if (!constant_graph.ok()) { + return errors::Internal( + "Error getting a compile-time constant graph for ", + context_->op_kernel().name(), " input ", index, + ".\nError: ", constant_graph.status().error_message()); + } xla::StatusOr> computed = - builder()->ComputeConstant(handle, &layout); + compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(), + &layout); if (!computed.ok()) { return errors::Internal("Error evaluating ", context_->op_kernel().name(), " input ", index, @@ -260,9 +267,9 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { return Status::OK(); } -Status XlaOpKernelContext::InputList( - StringPiece name, std::vector* handles, - std::vector* shapes) { +Status XlaOpKernelContext::InputList(StringPiece name, + std::vector* handles, + std::vector* shapes) { OpInputList inputs; TF_RETURN_IF_ERROR(context_->input_list(name, &inputs)); handles->clear(); @@ -285,9 +292,9 @@ Status XlaOpKernelContext::ConstantInputList( return Status::OK(); } -Status XlaOpKernelContext::ReadVariableInput( - int index, DataType type, TensorShape* shape, - xla::ComputationDataHandle* value) { +Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, + TensorShape* shape, + xla::XlaOp* value) { const Tensor& tensor = context_->input(index); const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); @@ -307,8 +314,8 @@ Status XlaOpKernelContext::ReadVariableInput( } XlaContext& xla_context = XlaContext::Get(context_); - TensorShape representation_shape = xla_context.VariableRepresentationShape( - variable->shape(), variable->type()); + TensorShape representation_shape = + xla_context.RepresentationShape(variable->shape(), variable->type()); if (representation_shape == variable->shape()) { *value = variable->value(); } else { @@ -334,8 +341,7 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, return Status::OK(); } -void XlaOpKernelContext::SetOutput(int index, - const xla::ComputationDataHandle& handle) { +void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { // Makes the host Tensor that will refer to the expression. Tensor* output = nullptr; auto shape = builder()->GetShape(handle); @@ -349,7 +355,7 @@ void XlaOpKernelContext::SetOutput(int index, // corresponds. TensorShape tensor_shape; OP_REQUIRES_OK(context_, - XLAShapeToTensorShape(*shape.ValueOrDie(), &tensor_shape)); + XLAShapeToTensorShape(shape.ValueOrDie(), &tensor_shape)); OP_REQUIRES_OK(context_, context_->allocate_output(index, tensor_shape, &output)); @@ -364,8 +370,8 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { xla::Literal literal; OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal)); - xla::ComputationDataHandle handle = builder()->ConstantLiteral(literal); - CHECK_NE(handle.handle(), 0); + xla::XlaOp handle = builder()->ConstantLiteral(literal); + CHECK_NE(handle.builder(), nullptr); // Make the Tensor that will refer to the expression. Tensor* output = nullptr; @@ -386,8 +392,7 @@ void XlaOpKernelContext::SetInvalidOutput(int index) { OP_REQUIRES_OK(context_, context_->allocate_output(index, TensorShape({}), &output)); XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - xla::ComputationDataHandle handle; - handle.set_handle(0); + xla::XlaOp handle; expression->set_handle(handle); } @@ -410,8 +415,8 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { } Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, - xla::ComputationDataHandle handle) { - TF_RET_CHECK(handle.handle() != 0); + xla::XlaOp handle) { + TF_RET_CHECK(handle.builder() != nullptr); const XlaExpression* expression = CastExpressionFromTensor(context_->input(input_index)); @@ -425,13 +430,13 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, } TensorShape shape; TF_RETURN_IF_ERROR( - XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); + XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape)); TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); XlaContext& xla_context = XlaContext::Get(context_); TensorShape representation_shape = - xla_context.VariableRepresentationShape(shape, type); + xla_context.RepresentationShape(shape, type); if (shape != representation_shape) { handle = builder()->Reshape(handle, representation_shape.dim_sizes()); } @@ -457,22 +462,22 @@ void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line, context_->CtxFailureWithWarning(file, line, s); } -const xla::Computation* XlaOpKernelContext::GetOrCreateMax( +const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax( const DataType type) { return XlaContext::Get(context_).GetOrCreateMax(type); } -const xla::Computation* XlaOpKernelContext::GetOrCreateMin( +const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin( const DataType type) { return XlaContext::Get(context_).GetOrCreateMin(type); } -const xla::Computation* XlaOpKernelContext::GetOrCreateAdd( +const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd( const DataType type) { return XlaContext::Get(context_).GetOrCreateAdd(type); } -const xla::Computation* XlaOpKernelContext::GetOrCreateMul( +const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( const DataType type) { return XlaContext::Get(context_).GetOrCreateMul(type); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 4e4b97e0cec8d16b9b5686a779b1285906765dbd..667dc262ca03ca716ffbf015a78fc14c7a8b7c1a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/macros.h" @@ -58,8 +58,8 @@ class XlaOpKernelContext { public: explicit XlaOpKernelContext(OpKernelContext* context); - // Returns the XLA ComputationBuilder containing the output of compilation. - xla::ComputationBuilder* builder() const; + // Returns the XLA XlaBuilder containing the output of compilation. + xla::XlaBuilder* builder() const; // Inputs @@ -72,10 +72,10 @@ class XlaOpKernelContext { // Returns the shape of input 'index'. TensorShape InputShape(int index); - // Returns input 'index' as a ComputationDataHandle. Unlike + // Returns input 'index' as a XlaOp. Unlike // OpKernelContext::Input returns a symbolic value rather than a concrete // Tensor. - const xla::ComputationDataHandle& Input(int index); + const xla::XlaOp& Input(int index); // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. @@ -85,8 +85,7 @@ class XlaOpKernelContext { // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. - Status InputList(StringPiece name, - std::vector* handles, + Status InputList(StringPiece name, std::vector* handles, std::vector* shapes); // Helper methods for constant inputs. @@ -132,10 +131,10 @@ class XlaOpKernelContext { return context_->expected_output_dtype(index); } - // Sets output 'index' to the ComputationDataHandle 'handle'. + // Sets output 'index' to the XlaOp 'handle'. // All outputs should be set using SetOutput and SetConstantOutput, not // via the underlying OpKernelContext. - void SetOutput(int index, const xla::ComputationDataHandle& handle); + void SetOutput(int index, const xla::XlaOp& handle); // Sets output 'index' to compile-time constant 'host_tensor', where // 'host_tensor' is a tensor in host memory. It is preferable to use @@ -168,14 +167,13 @@ class XlaOpKernelContext { // variable. Returns an error if the variable has not been initialized, or if // its type does not match `type`. Status ReadVariableInput(int index, DataType type, TensorShape* shape, - xla::ComputationDataHandle* value); + xla::XlaOp* value); // Assigns the value `handle` to the variable referenced by input // `input_index`. The variable must be of `type`. Returns an error if the // variable has been initialized with a different type or with a // different shape. - Status AssignVariable(int input_index, DataType type, - xla::ComputationDataHandle handle); + Status AssignVariable(int input_index, DataType type, xla::XlaOp handle); // Helper routines for the OP_REQUIRES macros void CtxFailure(const Status& s); @@ -205,22 +203,22 @@ class XlaOpKernelContext { // Gets an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateMax(const DataType type); + const xla::XlaComputation* GetOrCreateMax(const DataType type); // Gets an XLA lambda to compute Min. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateMin(const DataType type); + const xla::XlaComputation* GetOrCreateMin(const DataType type); // Gets an XLA lambda to compute Add. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateAdd(const DataType type); + const xla::XlaComputation* GetOrCreateAdd(const DataType type); // Gets an XLA lambda to compute Mul. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateMul(const DataType type); + const xla::XlaComputation* GetOrCreateMul(const DataType type); private: OpKernelContext* const context_; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index bbe808595d958346bd55bf8419306bf3de4cd1d0..4692038b61f6871a8a16299fd4d11e963eb46a57 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -39,10 +39,10 @@ const char* const DEVICE_XLA_GPU = "XLA_GPU"; static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) { const OpDef* op_def; - TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("_XlaLaunch", &op_def)); + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("XlaLaunch", &op_def)); NodeDef node_def; node_def.set_name("_XlaLaunch-op"); - node_def.set_op("_XlaLaunch"); + node_def.set_op("XlaLaunch"); string kernel_class_name; TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr, &kernel_class_name)); @@ -311,7 +311,7 @@ XlaOpRegistry& XlaOpRegistry::Instance() { XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(StringPiece name) { registration_.reset(new XlaOpRegistry::OpRegistration); - registration_->name = name.ToString(); + registration_->name = std::string(name); } XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) { @@ -323,14 +323,14 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( gtl::ArraySlice devices) { registration_->has_device_whitelist = true; for (StringPiece device : devices) { - registration_->device_whitelist.insert(device.ToString()); + registration_->device_whitelist.insert(std::string(device)); } return *this; } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(StringPiece device) { registration_->has_device_whitelist = true; - registration_->device_whitelist.insert(device.ToString()); + registration_->device_whitelist.insert(std::string(device)); return *this; } @@ -347,7 +347,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() { XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( StringPiece attr_name, DataType allowed) { std::set& types = - registration_->type_constraints[attr_name.ToString()]; + registration_->type_constraints[std::string(attr_name)]; types.insert(allowed); return *this; } @@ -355,7 +355,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( StringPiece attr_name, gtl::ArraySlice allowed) { std::set& types = - registration_->type_constraints[attr_name.ToString()]; + registration_->type_constraints[std::string(attr_name)]; for (DataType t : allowed) { types.insert(t); } @@ -364,7 +364,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( StringPiece input_name) { - registration_->compile_time_constant_inputs.insert(input_name.ToString()); + registration_->compile_time_constant_inputs.insert(std::string(input_name)); return *this; } @@ -394,7 +394,7 @@ XlaBackendRegistrar::XlaBackendRegistrar( StringPiece name, gtl::ArraySlice types, XlaOpRegistry::BackendOpFilter op_filter) { XlaOpRegistry& registry = XlaOpRegistry::Instance(); - registry.RegisterBackend(name.ToString(), types, op_filter); + registry.RegisterBackend(std::string(name), types, op_filter); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index ff7453194af3a85bded86a5ce298f8779422dccb..e255b01dd7fdcb095c7992d4352d2d9bb7d36ac3 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -51,13 +51,13 @@ constexpr std::array kNumericTypes = { {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}}; -constexpr std::array kCpuAllTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, +constexpr std::array kCpuAllTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; -constexpr std::array kGpuAllTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kGpuAllTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}}; // Class that manages registrations of operators and devices for the XLA JIT. // Not thread-safe. diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index c2075b44b82ba279d1246ec6bfcf305d12c418a6..540c65c597f20d5bb26494e56c09ff2187cfb0db 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -26,8 +26,7 @@ limitations under the License. namespace tensorflow { XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, - TensorShape shape, - const xla::ComputationDataHandle& initial_value, + TensorShape shape, const xla::XlaOp& initial_value, int64 tensor_array_size, const std::set& tensor_array_gradients) : kind_(kind), @@ -41,11 +40,10 @@ XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, CHECK(kind_ != kInvalid); for (const string& gradient : tensor_array_gradients) { - tensor_array_gradients_[gradient].reset( - new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, - /*name=*/strings::StrCat("TensorArrayGrad: ", name_), - type_, shape_, xla::ComputationDataHandle(), - tensor_array_size_, /*tensor_array_gradients=*/{})); + tensor_array_gradients_[gradient].reset(new XlaResource( + /*kind=*/kTensorArray, /*arg_num=*/-1, + /*name=*/strings::StrCat("TensorArrayGrad: ", name_), type_, shape_, + xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{})); } } @@ -73,7 +71,7 @@ Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) { return Status::OK(); } -Status XlaResource::SetValue(const xla::ComputationDataHandle& value) { +Status XlaResource::SetValue(const xla::XlaOp& value) { if (type_ == DT_INVALID) { return errors::InvalidArgument( "Resource '", name_, @@ -83,7 +81,7 @@ Status XlaResource::SetValue(const xla::ComputationDataHandle& value) { return Status::OK(); } -Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) { +Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { if (type_ == DT_INVALID) { return errors::InvalidArgument( "Resource '", name_, @@ -121,9 +119,9 @@ Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) { return Status::OK(); } -Status XlaResource::GetOrCreateTensorArrayGradient( - const string& source, xla::ComputationBuilder* builder, - XlaResource** gradient_out) { +Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, + xla::XlaBuilder* builder, + XlaResource** gradient_out) { VLOG(2) << "Gradient lookup for resource: " << name_ << " gradient: " << source; TF_RET_CHECK(kind_ == kTensorArray); @@ -132,7 +130,7 @@ Status XlaResource::GetOrCreateTensorArrayGradient( TensorShape ta_shape; ta_shape.AddDim(tensor_array_size_); ta_shape.AppendShape(shape_); - xla::ComputationDataHandle gradient_value = builder->Broadcast( + xla::XlaOp gradient_value = builder->Broadcast( XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, @@ -144,13 +142,12 @@ Status XlaResource::GetOrCreateTensorArrayGradient( return Status::OK(); } -Status XlaResource::Pack(xla::ComputationDataHandle* pack, - xla::ComputationBuilder* builder) const { +Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const { if (tensor_array_gradients_.empty()) { *pack = value_; } else { TF_RET_CHECK(kind_ == kTensorArray); - std::vector elems; + std::vector elems; elems.push_back(value_); for (const auto& gradient : tensor_array_gradients_) { elems.push_back(gradient.second->value_); @@ -161,8 +158,8 @@ Status XlaResource::Pack(xla::ComputationDataHandle* pack, } Status XlaResource::SetFromPack(const std::set& gradient_sources, - const xla::ComputationDataHandle& pack, - xla::ComputationBuilder* builder) { + const xla::XlaOp& pack, + xla::XlaBuilder* builder) { if (gradient_sources.empty()) { if (!initialized()) { initial_value_ = pack; diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 1bb2c7274ecdf0954768fd96def51194e52deee8..9ce36d1aa7622334b2acfbe9aa85d7419c4772ed 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" @@ -37,8 +37,7 @@ class XlaResource { }; XlaResource(Kind kind, int arg_num, string name, DataType type, - TensorShape shape, - const xla::ComputationDataHandle& initial_value, + TensorShape shape, const xla::XlaOp& initial_value, int64 tensor_array_size, const std::set& tensor_array_gradients); @@ -69,16 +68,14 @@ class XlaResource { // this is the shape of each entry in the TensorArray/Stack. const TensorShape& shape() const { return shape_; } - const xla::ComputationDataHandle& value() const { return value_; } + const xla::XlaOp& value() const { return value_; } // Value of the resource at computation entry. Used to detect which // variables have new values that need to be written back. - const xla::ComputationDataHandle& initial_value() const { - return initial_value_; - } + const xla::XlaOp& initial_value() const { return initial_value_; } // A variable is initialized if it has a value. - bool initialized() const { return value_.handle() > 0; } + bool initialized() const { return value_.builder() != nullptr; } // Sets the type and shape of the resource. The type and shape of a resource // must not change once the variable has been initialized. @@ -86,17 +83,17 @@ class XlaResource { // Sets the current value of the resource. Returns an error if the type is not // set to a valid value. - Status SetValue(const xla::ComputationDataHandle& value); + Status SetValue(const xla::XlaOp& value); // Sets the current value of the resource to an all-zero value. - Status SetZeroValue(xla::ComputationBuilder* builder); + Status SetZeroValue(xla::XlaBuilder* builder); // Looks up the gradient for `source`, or creates it if it does not already // exist. The call target must be an initialized TensorArray resource. A // TensorArray can have multiple named gradients; see the operator // documentation for TensorArrayGradV3 for details. Status GetOrCreateTensorArrayGradient(const string& source, - xla::ComputationBuilder* builder, + xla::XlaBuilder* builder, XlaResource** gradient_out); // Packs a resource into a single XLA value `pack`, suitable for use as @@ -104,8 +101,7 @@ class XlaResource { // gradients, sets `*pack` to `value`. // For TensorArrays with gradients, packs the value and its gradient values in // a tuple; the gradients values are packed in order by source name. - Status Pack(xla::ComputationDataHandle* pack, - xla::ComputationBuilder* builder) const; + Status Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const; // Updates the resource with values from `pack`. If `gradient_sources` is // non-empty, treats `pack` as a tuple that represents a TensorArray and @@ -114,8 +110,7 @@ class XlaResource { // values. // Opposite of Pack(). Status SetFromPack(const std::set& gradient_sources, - const xla::ComputationDataHandle& pack, - xla::ComputationBuilder* builder); + const xla::XlaOp& pack, xla::XlaBuilder* builder); // TensorArray and Stack specific fields @@ -144,8 +139,8 @@ class XlaResource { DataType type_; TensorShape shape_; - xla::ComputationDataHandle value_; - xla::ComputationDataHandle initial_value_; + xla::XlaOp value_; + xla::XlaOp initial_value_; int64 tensor_array_size_ = -1; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 34e733bc8d80b364cec1783006eba0a5468b55ea..1b8e516770c3e217dd7c2f26ce426895b478c2e4 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -52,7 +52,7 @@ xla_proto_library( visibility = ["//visibility:public"], deps = [ ":xla_data_proto", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", ], ) @@ -98,8 +98,9 @@ cc_library( hdrs = ["service_interface.h"], visibility = [":friends"], deps = [ + ":status", + ":xla_data_proto", ":xla_proto", - "//tensorflow/core:lib", ], ) @@ -243,6 +244,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protobuf_util", + ":status", ":status_macros", ":statusor", ":types", @@ -301,13 +303,13 @@ cc_library( ":array2d", ":array3d", ":array4d", - ":shape_tree", ":shape_util", ":sparse_index_array", ":status_macros", ":types", ":util", ":xla_data_proto", + "//tensorflow/core:framework", "//tensorflow/core:lib", ], ) @@ -322,12 +324,30 @@ tf_cc_test( ":shape_util", ":test", ":types", + "//tensorflow/compiler/tf2xla:common", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) +cc_library( + name = "error_spec", + hdrs = ["error_spec.h"], +) + +cc_library( + name = "literal_comparison", + srcs = ["literal_comparison.cc"], + hdrs = ["literal_comparison.h"], + deps = [ + ":error_spec", + ":literal_util", + ":util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "metric_table_report", srcs = ["metric_table_report.cc"], @@ -372,7 +392,6 @@ tf_cc_test( cc_library( name = "array2d", - srcs = ["array2d.cc"], hdrs = ["array2d.h"], visibility = ["//visibility:public"], deps = [ @@ -443,6 +462,9 @@ cc_library( srcs = ["executable_run_options.cc"], hdrs = ["executable_run_options.h"], visibility = ["//visibility:public"], + deps = [ + ":types", + ], ) cc_library( @@ -560,6 +582,7 @@ tf_cc_test( ":shape_util", ":test", ":xla_data_proto", + "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -602,8 +625,8 @@ cc_library( ":util", ":window_util", ":xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_evaluator", "//tensorflow/compiler/xla/service:shape_inference", @@ -654,18 +677,6 @@ tf_cc_test( # ----------------------------------------------------------------------------- -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - # This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. cc_header_only_library( name = "xla_headers_lib", diff --git a/tensorflow/compiler/xla/README.md b/tensorflow/compiler/xla/README.md index c93c39e180655e7930e943e6aa6514c47f2859d7..39f8caaa961dc7b57d2b45f974fc6ecf89cf6748 100644 --- a/tensorflow/compiler/xla/README.md +++ b/tensorflow/compiler/xla/README.md @@ -1 +1,7 @@ -This is the home of XLA. +

+ +

+ +XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear +algebra that optimizes TensorFlow computations. See the +[documentation](https://www.tensorflow.org/performance/xla/) for more details. diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 46ee4e64c9ae7ca111d9d04bedcb74ff02a42386..ea75ad32d5df7bbadd37e89de6144b264ab6d5d1 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -121,10 +122,31 @@ class Array { CHECK(idx == num_elements()); } - // Creates a 2D array of Eigen::half from the given nested initializer list of - // float values. + // Creates a 1D array of a floating-point type (half, bfloat16, float, + // or double) from an initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && + std::is_same::value>::type> + Array(std::initializer_list values) + : Array(ToInt64Vector({values.size()})) { + int64 idx = 0; + for (const auto& it1 : values) { + values_[idx] = static_cast(it1); + ++idx; + } + CHECK(idx == num_elements()); + } + + // Creates a 2D array of a floating-point type (half, bfloat16, float, + // or double) from an initializer list of float values. + template ::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array(std::initializer_list> values) : Array(ToInt64Vector({values.size(), values.begin()->size()})) { @@ -155,10 +177,13 @@ class Array { CHECK(idx == num_elements()); } - // Creates a 3D array of Eigen::half from the given nested initializer list of - // float values. + // Creates a 3D array of a floating-point type (half, bfloat16, float, + // or double) from an initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array(std::initializer_list>> values) @@ -196,10 +221,13 @@ class Array { CHECK(idx == num_elements()); } - // Creates a 4D array of Eigen::half from the given nested initializer list of - // float values. + // Creates a 4D array of a floating-point type (half, bfloat16, float, + // or double) from an initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array(std::initializer_list< std::initializer_list>>> diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index 41f563486d21e42e88dcf6c751ce4a64da5e3213..a17e81f44832f272fd93dce9f854042b4a84fde4 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -25,6 +25,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -52,10 +53,13 @@ class Array2D : public Array { Array2D(std::initializer_list> values) : Array(values) {} - // Creates an array of Eigen::half from the given nested initializer list of - // float values. + // Creates an array of a floating-point type (half, bfloat16, float, + // or double) from the given nested initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array2D(std::initializer_list> values) : Array(values) {} @@ -94,9 +98,23 @@ class Array2D : public Array { // Returns a linspace-populated Array2D in the range [from, to] (inclusive) // with dimensions n1 x n2. -std::unique_ptr> MakeLinspaceArray2D(float from, float to, - int64 n1, int64 n2); - +template +std::unique_ptr> MakeLinspaceArray2D(double from, double to, + int64 n1, int64 n2) { + auto array = MakeUnique>(n1, n2); + int64 count = n1 * n2; + NativeT step = + static_cast((count > 1) ? (to - from) / (count - 1) : 0); + auto set = [&array, n1, n2](int64 index, NativeT value) { + (*array)(index / n2, index % n2) = value; + }; + for (int64 i = 0; i < count - 1; ++i) { + set(i, (static_cast(from) + + static_cast(i) * static_cast(step))); + } + set(count - 1, static_cast(to)); + return array; +} } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_ARRAY2D_H_ diff --git a/tensorflow/compiler/xla/array3d.h b/tensorflow/compiler/xla/array3d.h index e5eb235d45d160d486d1499db665ed14a8509043..0e9a0722ae43e1dc6ecddde9cbc3daf1db058840 100644 --- a/tensorflow/compiler/xla/array3d.h +++ b/tensorflow/compiler/xla/array3d.h @@ -57,10 +57,13 @@ class Array3D : public Array { values) : Array(values) {} - // Creates an array of Eigen::half from the given nested initializer list of - // float values. + // Creates an array of a floating-point type (half, bfloat16, float, + // or double) from the given nested initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array3D( std::initializer_list>> diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index cff70e54bad0116bdd08674b626b3bf99dc89e1f..a75fffc605aa0df3e1e2eeb6d3129718cbbba0e4 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -82,10 +82,13 @@ class Array4D : public Array { values) : Array(values) {} - // Creates an array of Eigen::half from the given nested initializer list of - // float values. + // Creates an array of a floating-point type (half, bfloat16, float, + // or double) from the given nested initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array4D(std::initializer_list>>> diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 02356699a25e47be50eb15872df4c9c302fc289b..8f08d3b2e04670ad6590aca1db0fd9d25faed83f 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -63,7 +63,6 @@ cc_library( srcs = ["client.cc"], hdrs = ["client.h"], deps = [ - ":computation", ":global_data", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", @@ -74,8 +73,9 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -86,6 +86,7 @@ cc_library( hdrs = ["executable_build_options.h"], deps = [ "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", @@ -98,17 +99,18 @@ cc_library( hdrs = ["local_client.h"], deps = [ ":client", - ":computation", ":executable_build_options", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:source_map_util", @@ -124,11 +126,11 @@ cc_library( hdrs = ["compile_only_client.h"], deps = [ ":client", - ":computation", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:stream_executor_no_cuda", @@ -159,47 +161,6 @@ cc_library( ], ) -cc_library( - name = "computation", - srcs = ["computation.cc"], - hdrs = ["computation.h"], - deps = [ - "//tensorflow/compiler/xla:service_interface", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/service:session_proto", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "computation_builder", - srcs = ["computation_builder.cc"], - hdrs = ["computation_builder.h"], - deps = [ - ":client", - ":computation", - ":global_data", - ":padding", - "//tensorflow/compiler/xla:array", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:array3d", - "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "sharding_builder", srcs = ["sharding_builder.cc"], @@ -213,17 +174,3 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index d15ccb0c28522c647617153aaa8e738d029dfaba..3d596a6e65430b6e9692aabd65fc8aa84b7b873d 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -64,7 +64,7 @@ StatusOr> Client::Transfer( } StatusOr> Client::TransferToServer( - const Literal& literal, const DeviceHandle* device_handle) { + const LiteralSlice& literal, const DeviceHandle* device_handle) { TransferToServerRequest request; *request.mutable_literal() = literal.ToProto(); if (device_handle) { @@ -91,7 +91,7 @@ StatusOr> Client::TransferToServer( return MakeUnique(stub_, response.data()); } -Status Client::TransferToInfeed(const Literal& literal, int64 replica_id, +Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id, const DeviceHandle* device_handle) { TransferToInfeedRequest request; *request.mutable_literal() = literal.ToProto(); @@ -162,7 +162,7 @@ Status Client::ResetDevice() { } StatusOr> Client::ExecuteAndTransfer( - const Computation& computation, + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { @@ -177,27 +177,46 @@ StatusOr> Client::ExecuteAndTransfer( return Transfer(*data, shape_with_output_layout); } -StatusOr Client::LoadSnapshot(const SessionModule& module) { - LoadComputationSnapshotRequest request; - *request.mutable_module() = module; - LoadComputationSnapshotResponse response; +StatusOr> Client::ComputeConstant( + const XlaComputation& computation, const Layout* output_layout) const { + ComputeConstantGraphRequest request; + *request.mutable_computation() = computation.proto(); + if (output_layout != nullptr) { + *request.mutable_output_layout() = *output_layout; + } + + ComputeConstantResponse response; + + VLOG(2) << "making compute-constant-graph request"; + Status s = stub_->ComputeConstantGraph(&request, &response); + VLOG(2) << "done with request"; - Status s = stub_->LoadComputationSnapshot(&request, &response); if (!s.ok()) { return s; } - VLOG(1) << "load snapshot response: " << response.ShortDebugString(); - return Computation(stub_, response.computation()); + VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}"; + + if (!response.has_literal()) { + return InternalError( + "no computed literal in the provided response in ComputeConstantGraph " + "request"); + } + return Literal::CreateFromProto(response.literal()); +} + +StatusOr Client::LoadSnapshot(const HloSnapshot& module) { + TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module()); + return XlaComputation(module.hlo().hlo_module()); } StatusOr> Client::Execute( - const Computation& computation, + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { - ExecuteRequest request; - *request.mutable_computation() = computation.handle(); + ExecuteGraphRequest request; + *request.mutable_computation() = computation.proto(); if (execution_options == nullptr) { *request.mutable_execution_options() = CreateDefaultExecutionOptions(); @@ -211,7 +230,7 @@ StatusOr> Client::Execute( ExecuteResponse response; VLOG(1) << "making execute request: " << request.ShortDebugString(); - Status s = stub_->Execute(&request, &response); + Status s = stub_->ExecuteGraph(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -232,12 +251,12 @@ StatusOr> Client::Execute( } StatusOr>> Client::ExecuteParallel( - tensorflow::gtl::ArraySlice computations) { - ExecuteParallelRequest request; + tensorflow::gtl::ArraySlice computations) { + ExecuteGraphParallelRequest request; - for (const ComputationInstance& computation : computations) { - ExecuteRequest single_request; - *single_request.mutable_computation() = computation.computation.handle(); + for (const XlaComputationInstance& computation : computations) { + ExecuteGraphRequest single_request; + *single_request.mutable_computation() = computation.computation.proto(); for (GlobalData* argument : computation.arguments) { *single_request.add_arguments() = argument->handle(); } @@ -246,8 +265,9 @@ StatusOr>> Client::ExecuteParallel( } ExecuteParallelResponse response; - VLOG(1) << "making execute-parallel request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->ExecuteParallel(&request, &response); + VLOG(1) << "making execute-graph-parallel request: " + << request.ShortDebugString(); + Status s = stub_->ExecuteGraphParallel(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -276,7 +296,7 @@ StatusOr> Client::GetDeviceHandles( GetDeviceHandlesResponse response; VLOG(1) << "making get device request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->GetDeviceHandles(&request, &response); + Status s = stub_->GetDeviceHandles(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -325,14 +345,17 @@ StatusOr>> Client::DeconstructTuple( } StatusOr Client::GetComputationStats( - const Computation& computation, const DebugOptions& debug_options) const { - ComputationStatsRequest request; - *request.mutable_computation() = computation.handle(); + const XlaComputation& computation, + const DebugOptions& debug_options) const { + ComputationGraphStatsRequest request; + + // TODO(b/74197823): Find a way to avoid the copy of the hlo proto. + *request.mutable_computation() = computation.proto(); *request.mutable_debug_options() = debug_options; ComputationStatsResponse response; - VLOG(1) << "making computation stats request"; - Status s = stub_->GetComputationStats(&request, &response); + VLOG(1) << "making computation graph stats request"; + Status s = stub_->GetComputationGraphStats(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -343,20 +366,9 @@ StatusOr Client::GetComputationStats( } StatusOr> Client::GetComputationShape( - const Computation& computation) { - GetComputationShapeRequest request; - *request.mutable_computation() = computation.handle(); - GetComputationShapeResponse response; - - VLOG(1) << "making get-computation-shape request"; - Status s = stub_->GetComputationShape(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - return WrapUnique(response.release_program_shape()); + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape()); + return MakeUnique(result); } StatusOr Client::GetShape(const GlobalData& data) { @@ -376,7 +388,7 @@ StatusOr Client::GetShape(const GlobalData& data) { } StatusOr Client::ExecutionStatsAsString( - const Computation& computation, const ExecutionProfile& profile) { + const XlaComputation& computation, const ExecutionProfile& profile) { TF_ASSIGN_OR_RETURN( auto computation_stats, GetComputationStats(computation, diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index c28380b689c7a0e16bf0bcbf15003f4aa15e42a7..68f0d0ac78c859fde7a6a007cd250b047a7bfcda 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -19,10 +19,10 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -52,7 +52,7 @@ class Client { // * If execution_profile is not nullptr then the pointed-to ExecutionProfile // will be filled with profile data from the execution. StatusOr> Execute( - const Computation& computation, + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); @@ -62,26 +62,27 @@ class Client { // executed on the devices associated with the handles by partitioning the // computation based on the attached sharding attributes. Otherwise, a // device is chosen by the service. - struct ComputationInstance { - const Computation& computation; + struct XlaComputationInstance { + const XlaComputation& computation; std::vector arguments; ExecutionOptions execution_options; ExecutionProfile* execution_profile; - ComputationInstance(const Computation& computation, - std::vector arguments, - ExecutionOptions execution_options, - ExecutionProfile* execution_profile) + XlaComputationInstance(const XlaComputation& computation, + std::vector arguments, + ExecutionOptions execution_options, + ExecutionProfile* execution_profile) : computation(computation), arguments(std::move(arguments)), execution_options(execution_options), execution_profile(execution_profile) {} }; - // Executes a list ComputationInstances and returns global data produced from - // each computation. + // Executes a list XlaComputationInstances and returns global data produced + // from each computation. + // StatusOr>> ExecuteParallel( - tensorflow::gtl::ArraySlice computations); + tensorflow::gtl::ArraySlice computations); // Requests device_count device handles available on the target. The returned // device handles are used to specify the devices to execute the computations @@ -106,14 +107,14 @@ class Client { // device (and its replicas if replication is enabled). Otherwise, data is // transferred to the default device (and its replicas). StatusOr> TransferToServer( - const Literal& literal, const DeviceHandle* device_handle = nullptr); + const LiteralSlice& literal, const DeviceHandle* device_handle = nullptr); // Transfer the given literal to the Infeed interface of the device. // // device_handle and replica_id together specify a particular device; a device // assigned for the given replica_id among the replicas that the given device // handle belongs to. - Status TransferToInfeed(const Literal& literal, int64 replica_id = 0, + Status TransferToInfeed(const LiteralSlice& literal, int64 replica_id = 0, const DeviceHandle* device_handle = nullptr); // Transfers from the Outfeed of the device. @@ -132,11 +133,30 @@ class Client { // to the client as a literal. Parameters are defined the same as for // Execute() and Transfer(). StatusOr> ExecuteAndTransfer( - const Computation& computation, + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); + // Computes the value of the given computation using a non-optimized + // interpreter on the host. + // + // The computation must not depend on any parameters, or on stateful operators + // such as `RngNormal` or `Infeed`. + // + // This functionality can be useful when translating a computation into XLA + // where something that looked dynamic is required by XLA to be specified as a + // constant. E.g. the source computation (outside of XLA) may include a + // dynamic computation of the shape of something and ComputeConstant lets you + // determine what the value of that computation is in the case where the value + // can be determined at compile time. + // + // If output_layout is non-null, then the output of the computation will be + // stored using that layout. + StatusOr> ComputeConstant( + const XlaComputation& computation, + const Layout* output_layout = nullptr) const; + // Unregister the memory for the given GlobalData on the device. Status Unregister(const GlobalData& data); @@ -146,7 +166,8 @@ class Client { // Retrieves the statistics of the given computation. StatusOr GetComputationStats( - const Computation& computation, const DebugOptions& debug_options) const; + const XlaComputation& computation, + const DebugOptions& debug_options) const; // Returns the Shape of the given array specified by 'data'. The shape // includes the Layout of the array as it is stored on the service. @@ -155,20 +176,20 @@ class Client { // As above, but returns the shape of the provided computation (parameter // types/names and return type). StatusOr> GetComputationShape( - const Computation& computation); + const XlaComputation& computation); // Creates a channel handle that can be used to transfer data between // two computations via a pair of Send and Recv instructions. StatusOr CreateChannelHandle(); - StatusOr LoadSnapshot(const SessionModule& module); + StatusOr LoadSnapshot(const HloSnapshot& module); ServiceInterface* stub() { return stub_; } private: // Returns the execution statistics (e.g., gflop/s) as a string from the // ExecutionProfile returned from an execution of the computation. - StatusOr ExecutionStatsAsString(const Computation& computation, + StatusOr ExecutionStatsAsString(const XlaComputation& computation, const ExecutionProfile& profile); ServiceInterface* stub_; // Stub that this client is connected on. diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc index b1663bc815719c3da75b37593ac665b1f3493db8..803a9e40094391ba47ed27713f4538caf875c4f6 100644 --- a/tensorflow/compiler/xla/client/client_library.cc +++ b/tensorflow/compiler/xla/client/client_library.cc @@ -23,22 +23,19 @@ limitations under the License. namespace xla { -LocalClientOptions::LocalClientOptions(perftools::gputools::Platform* platform, +LocalClientOptions::LocalClientOptions(se::Platform* platform, int number_of_replicas, int intra_op_parallelism_threads) : platform_(platform), number_of_replicas_(number_of_replicas), intra_op_parallelism_threads_(intra_op_parallelism_threads) {} -LocalClientOptions& LocalClientOptions::set_platform( - perftools::gputools::Platform* platform) { +LocalClientOptions& LocalClientOptions::set_platform(se::Platform* platform) { platform_ = platform; return *this; } -perftools::gputools::Platform* LocalClientOptions::platform() const { - return platform_; -} +se::Platform* LocalClientOptions::platform() const { return platform_; } LocalClientOptions& LocalClientOptions::set_number_of_replicas( int number_of_replicas) { @@ -69,7 +66,7 @@ ClientLibrary::ClientLibrary() = default; ClientLibrary::~ClientLibrary() = default; /* static */ StatusOr ClientLibrary::GetOrCreateLocalClient( - perftools::gputools::Platform* platform) { + se::Platform* platform) { LocalClientOptions default_options; default_options.set_platform(platform); return GetOrCreateLocalClient(default_options); @@ -77,7 +74,7 @@ ClientLibrary::~ClientLibrary() = default; /* static */ StatusOr ClientLibrary::GetOrCreateLocalClient( const LocalClientOptions& options) { - perftools::gputools::Platform* platform = options.platform(); + se::Platform* platform = options.platform(); int replica_count = options.number_of_replicas(); ClientLibrary& client_library = Singleton(); tensorflow::mutex_lock lock(client_library.service_mutex_); @@ -115,7 +112,7 @@ ClientLibrary::~ClientLibrary() = default; } /* static */ LocalService* ClientLibrary::GetXlaService( - perftools::gputools::Platform* platform) { + se::Platform* platform) { ClientLibrary& client_library = Singleton(); tensorflow::mutex_lock lock(client_library.service_mutex_); auto it = client_library.local_instances_.find(platform->id()); @@ -124,8 +121,7 @@ ClientLibrary::~ClientLibrary() = default; } /* static */ StatusOr -ClientLibrary::GetOrCreateCompileOnlyClient( - perftools::gputools::Platform* platform) { +ClientLibrary::GetOrCreateCompileOnlyClient(se::Platform* platform) { ClientLibrary& client_library = Singleton(); tensorflow::mutex_lock lock(client_library.service_mutex_); diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h index a6f30d82e43587135697e76e8bc7d122edc0f602..3ad558fa532931937fab898f7b855f0a3370eaec 100644 --- a/tensorflow/compiler/xla/client/client_library.h +++ b/tensorflow/compiler/xla/client/client_library.h @@ -43,13 +43,13 @@ namespace xla { // Options to configure the local client when it is created. class LocalClientOptions { public: - LocalClientOptions(perftools::gputools::Platform* platform = nullptr, + LocalClientOptions(se::Platform* platform = nullptr, int number_of_replicas = 1, int intra_op_parallelism_threads = -1); // Set the platform backing the service, or nullptr for the default platform. - LocalClientOptions& set_platform(perftools::gputools::Platform* platform); - perftools::gputools::Platform* platform() const; + LocalClientOptions& set_platform(se::Platform* platform); + se::Platform* platform() const; // Set the number of replicas to use when compiling replicated // programs. @@ -61,7 +61,7 @@ class LocalClientOptions { int intra_op_parallelism_threads() const; private: - perftools::gputools::Platform* platform_; + se::Platform* platform_; int number_of_replicas_; int intra_op_parallelism_threads_; }; @@ -74,7 +74,7 @@ class ClientLibrary { // platform : The platform the underlying XLA service should target. If // null then default platform is used. static StatusOr GetOrCreateLocalClient( - perftools::gputools::Platform* platform = nullptr); + se::Platform* platform = nullptr); static StatusOr GetOrCreateLocalClient( const LocalClientOptions& options); @@ -84,14 +84,14 @@ class ClientLibrary { // Returns the service from the service thread. Only used in unit tests to // access user computations from client. - static LocalService* GetXlaService(perftools::gputools::Platform* platform); + static LocalService* GetXlaService(se::Platform* platform); // Singleton constructor-or-accessor for compile-only clients. Arguments: // // platform : The platform the underlying XLA service should target. If // null then default platform is used. static StatusOr GetOrCreateCompileOnlyClient( - perftools::gputools::Platform* platform = nullptr); + se::Platform* platform = nullptr); // Clears the local instance and compile only instance caches. The client // pointers returned by the previous GetOrCreateLocalClient() or @@ -120,12 +120,10 @@ class ClientLibrary { }; tensorflow::mutex service_mutex_; // Guards the singleton creation state. - std::unordered_map> + std::unordered_map> local_instances_ GUARDED_BY(service_mutex_); - std::unordered_map> + std::unordered_map> compile_only_instances_ GUARDED_BY(service_mutex_); TF_DISALLOW_COPY_AND_ASSIGN(ClientLibrary); diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index c7e2c4367b89ca2112022fa40449ae3ebe28463e..5c9abad4c3126be5e45e96c770c0679fe8606788 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -23,32 +23,33 @@ namespace xla { StatusOr>> CompileOnlyClient::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { - std::vector service_instances; + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options, + std::unique_ptr* metadata) { + std::vector service_instances; service_instances.reserve(computations.size()); - for (const AotComputationInstance& instance : computations) { - service_instances.push_back({}); - CompileOnlyService::AotComputationInstance& service_instance = + for (const AotXlaComputationInstance& instance : computations) { + service_instances.emplace_back(); + CompileOnlyService::AotXlaComputationInstance& service_instance = service_instances.back(); TF_RET_CHECK(instance.computation != nullptr); - service_instance.computation = instance.computation->handle(); + service_instance.computation = instance.computation->proto(); service_instance.argument_layouts = instance.argument_layouts; service_instance.result_layout = instance.result_layout; } - return compiler_service_->CompileAheadOfTime(service_instances, options); + return compiler_service_->CompileAheadOfTime(service_instances, options, + metadata); } -int64 CompileOnlyClient::PointerSizeForTriple( - tensorflow::StringPiece target_triple) { - llvm::Triple triple(llvm::Triple::normalize( - llvm::StringRef(target_triple.data(), target_triple.size()))); - if (triple.isArch64Bit()) { +int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) { + llvm::Triple llvm_triple( + llvm::Triple::normalize(llvm::StringRef(triple.data(), triple.size()))); + if (llvm_triple.isArch64Bit()) { return 8; - } else if (triple.isArch32Bit()) { + } else if (llvm_triple.isArch32Bit()) { return 4; } else { - CHECK(triple.isArch16Bit()); + CHECK(llvm_triple.isArch16Bit()); return 2; } } diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index 5900048711384e0240a3cd502260eb388eb40f51..332c96503637344d56e363e19db4880c37ca9684 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/compile_only_service.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/statusor.h" @@ -37,22 +37,24 @@ class CompileOnlyClient : public Client { CompileOnlyClient(const CompileOnlyClient&) = delete; void operator=(const CompileOnlyClient&) = delete; - // A description of a computation to compile using CompileAheadOfTime. - struct AotComputationInstance { - const Computation* computation; + // A description of an xla computation to compile using CompileAheadOfTime. + struct AotXlaComputationInstance { + const XlaComputation* computation; // Inform the compiler of the expected layout for arguments. std::vector argument_layouts; // Specifies the expected result layout. const Shape* result_layout; }; - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. The |options| parameter describes - // the target for which the compiler should emit code. + // Compiles a list of xla computations for ahead-of-time execution. + // This is intended for use in static compilation. The |options| + // parameter describes the target for which the compiler should emit + // code. |metadata|, if provided, is populated during compilation. StatusOr>> CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options); + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options, + std::unique_ptr* metadata = nullptr); // Returns the size of a pointer in bytes for a given triple. static int64 PointerSizeForTriple(tensorflow::StringPiece triple); diff --git a/tensorflow/compiler/xla/client/computation.cc b/tensorflow/compiler/xla/client/computation.cc deleted file mode 100644 index e6c57bda0f0c4cb969939883efebcf3a6d6be381..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/computation.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/client/computation.h" - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace xla { - -Computation::Computation() : parent_(nullptr) {} - -Computation::Computation(ServiceInterface* parent, - const ComputationHandle& handle) - : handle_(handle), parent_(parent) {} - -Computation::Computation(Computation&& computation) - : handle_(std::move(computation.handle_)), parent_(computation.parent_) { - computation.ResetWithoutFreeing(); -} - -void Computation::Reset() { - // TODO(b/34469253) deallocate any owned computation. - ResetWithoutFreeing(); -} - -StatusOr> Computation::Snapshot() const { - SnapshotComputationRequest request; - *request.mutable_computation() = handle_; - SnapshotComputationResponse response; - - TF_RETURN_IF_ERROR(parent_->SnapshotComputation(&request, &response)); - - return WrapUnique(response.release_module()); -} - -Computation::~Computation() { Reset(); } - -Computation& Computation::operator=(Computation&& computation) { - if (&computation != this) { - Reset(); - handle_ = computation.handle_; - parent_ = computation.parent_; - computation.ResetWithoutFreeing(); - } - return *this; -} - -void Computation::ResetWithoutFreeing() { - handle_.Clear(); - parent_ = nullptr; -} - -StatusOr Computation::GetProgramShape() const { - GetComputationShapeRequest request; - *request.mutable_computation() = handle_; - GetComputationShapeResponse response; - - TF_RETURN_IF_ERROR(parent_->GetComputationShape(&request, &response)); - - return std::move(*response.mutable_program_shape()); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/client/computation.h b/tensorflow/compiler/xla/client/computation.h deleted file mode 100644 index a53fc9e9cf34704bd08ddb5bf062c1ec1107f5fb..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/computation.h +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ - -#include - -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service_interface.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" - -namespace xla { - -// Wraps a ComputationHandle protobuf with a lifetime. Computation is -// movable and not copyable to capture the same kind of unique -// ownership that std::unique_ptr represents. -class Computation { - public: - // Creates a null Computation. - Computation(); - - // parent: stub for the service on which we will deallocate the computation - // when it is no longer needed. - // handle: the computation handle protobuf from the service. - Computation(ServiceInterface* parent, const ComputationHandle& handle); - - Computation(Computation&& computation); - - // Deallocates the computation. - ~Computation(); - - Computation& operator=(Computation&& computation); - - // Returns the underlying handle. - const ComputationHandle& handle() const { return handle_; } - - // Sets handle to a null state and clears any owned computation. - void Reset(); - - // Requests that we snapshot the computation into a serializable protocol - // buffer form. - StatusOr> Snapshot() const; - - // Returns true if this object is a null Computation. - bool IsNull() const { return parent_ == nullptr; } - - // Returns the "program shape" (parameter and return shapes) for this - // computation. - StatusOr GetProgramShape() const; - - private: - void ResetWithoutFreeing(); - - ComputationHandle handle_; // Handle that is wrapped by this class. - - // Stub that the handle is deallocated on when this object's lifetime ends. - ServiceInterface* parent_; - - TF_DISALLOW_COPY_AND_ASSIGN(Computation); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc deleted file mode 100644 index 2a6e02649d15bc9fd47a893c41f9c8a62ac076c6..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ /dev/null @@ -1,1580 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/client/computation_builder.h" - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" - -namespace xla { - -ComputationBuilder::ComputationBuilder(Client* client, - const string& computation_name) - : name_(computation_name), client_(client) {} - -ComputationBuilder::~ComputationBuilder() {} - -void ComputationBuilder::NoteError(const Status& error) { - if (die_immediately_on_error_) { - LOG(FATAL) << "error building computation: " << error; - } - - if (first_error_.ok()) { - first_error_ = error; - first_error_backtrace_.CreateCurrent(/*skip_count=*/1); - } -} - -std::unique_ptr ComputationBuilder::CreateSubBuilder( - const string& computation_name) { - auto sub_builder = MakeUnique(client_, computation_name); - sub_builder->parent_builder_ = this; - sub_builder->die_immediately_on_error_ = die_immediately_on_error_; - return sub_builder; -} - -Status ComputationBuilder::PrepareComputation() { - TF_RETURN_IF_ERROR(first_error_); - - if (!computation_.IsNull()) { - return Status::OK(); - } - - ComputationRequest request; - request.set_name(name_); - ComputationResponse response; - - VLOG(2) << "making computation request"; - Status s = client_->stub()->Computation(&request, &response); - VLOG(2) << "done with computation request"; - - if (!s.ok()) { - NoteError(s); - return first_error_; - } - - computation_ = Computation(client_->stub(), response.computation()); - return Status::OK(); -} - -Status ComputationBuilder::RunOp(OpRequest* op_request, - OpResponse* op_response) { - TF_RETURN_IF_ERROR(first_error_); - TF_RETURN_IF_ERROR(PrepareComputation()); - - // Fill in fields that are set on every OpRequest. - *op_request->mutable_computation() = computation_.handle(); - *op_request->mutable_metadata() = metadata_; - if (sharding_) { - *op_request->mutable_sharding() = *sharding_; - } - - const string& op_name = - OpRequest::descriptor()->FindFieldByNumber(op_request->op_case())->name(); - VLOG(2) << "running op request: " << op_name; - Status status = client_->stub()->Op(op_request, op_response); - VLOG(2) << "done with op request: " << op_name; - return status; -} - -void ComputationBuilder::RunOpAndNoteError(OpRequest* op_request) { - OpResponse op_response; - Status status = RunOp(op_request, &op_response); - if (!status.ok()) { - NoteError(status); - } -} - -ComputationDataHandle ComputationBuilder::RunOpAndParseResponse( - OpRequest* op_request) { - OpResponse op_response; - Status status = RunOp(op_request, &op_response); - if (!status.ok()) { - NoteError(status); - return ComputationDataHandle(); - } - if (op_response.output().handle() == 0) { - NoteError(InternalError("No output handle")); - return ComputationDataHandle(); - } - return op_response.output(); -} - -bool ComputationBuilder::MakeWindow( - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, Window* window) { - const auto verify_size = [&](const size_t x, const char* x_name) { - if (x == 0 || x == window_dimensions.size()) { - return true; - } else { - NoteError(InvalidArgument( - "%s", tensorflow::strings::StrCat( - "Window has different number of window dimensions than of ", - x_name, "\nNumber of window dimensions: ", - window_dimensions.size(), "\nNumber of ", x_name, ": ", x, - "\n") - .c_str())); // - return false; - } - }; - if (!verify_size(window_strides.size(), "window strides") || - !verify_size(padding.size(), "padding entries") || - !verify_size(lhs_dilation.size(), "lhs dilation factors") || - !verify_size(rhs_dilation.size(), "rhs dilation factors")) { - return false; - } - - window->Clear(); - for (size_t i = 0; i < window_dimensions.size(); i++) { - auto dim = window->add_dimensions(); - dim->set_size(window_dimensions[i]); - if (!window_strides.empty()) { - dim->set_stride(window_strides[i]); - } else { - dim->set_stride(1); - } - if (!padding.empty()) { - dim->set_padding_low(padding[i].first); - dim->set_padding_high(padding[i].second); - } else { - dim->set_padding_low(0); - dim->set_padding_high(0); - } - if (!lhs_dilation.empty()) { - dim->set_base_dilation(lhs_dilation[i]); - } else { - dim->set_base_dilation(1); - } - if (!rhs_dilation.empty()) { - dim->set_window_dilation(rhs_dilation[i]); - } else { - dim->set_window_dilation(1); - } - dim->set_window_reversal(false); - } - return true; -} - -ComputationDataHandle ComputationBuilder::ConstantLiteral( - const Literal& literal) { - OpRequest op_request; - ConstantRequest* request = op_request.mutable_constant_request(); - *request->mutable_literal() = literal.ToProto(); - VLOG(3) << "created constant: " << request->literal().ShortDebugString(); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { - OpRequest op_request; - ParameterRequest* request = op_request.mutable_parameter_request(); - *request->mutable_shape() = shape; - request->set_parameter(parameter_number); - request->set_name(name); - return RunOpAndParseResponse(&op_request); -} - -StatusOr> ComputationBuilder::GetShapeWithoutNoteError( - const ComputationDataHandle& operand) { - GetLocalShapeRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - GetLocalShapeResponse response; - - VLOG(2) << "making get-shape request"; - TF_RETURN_IF_ERROR(client_->stub()->GetLocalShape(&request, &response)); - VLOG(2) << "done with request"; - - TF_RET_CHECK(response.has_shape()); - std::unique_ptr shape = WrapUnique(response.release_shape()); - TF_RET_CHECK(shape != nullptr); - return std::move(shape); -} - -StatusOr> ComputationBuilder::GetShape( - const ComputationDataHandle& operand) { - TF_RETURN_IF_ERROR(first_error_); - - auto status_or_shape = GetShapeWithoutNoteError(operand); - if (!status_or_shape.ok()) { - NoteError(status_or_shape.status()); - return first_error_; - } - return status_or_shape; -} - -StatusOr ComputationBuilder::GetProgramShape() { - TF_RETURN_IF_ERROR(first_error_); - - GetComputationShapeRequest request; - *request.mutable_computation() = computation_.handle(); - GetComputationShapeResponse response; - - VLOG(2) << "making get-program-shape-request"; - Status status = client_->stub()->GetComputationShape(&request, &response); - VLOG(2) << "done with get-program-shape-request"; - - if (!status.ok()) { - first_error_ = status; - return status; - } - - TF_RET_CHECK(response.has_program_shape()); - return std::move(*response.mutable_program_shape()); -} - -ComputationDataHandle ComputationBuilder::CheckShape( - const ComputationDataHandle& operand, const Shape& expected_shape) { - std::unique_ptr actual_shape = GetShape(operand).ConsumeValueOrDie(); - CHECK(ShapeUtil::Equal(expected_shape, *actual_shape)) - << "want " << ShapeUtil::HumanString(expected_shape) << " got " - << ShapeUtil::HumanString(*actual_shape); - return operand; -} - -void ComputationBuilder::CheckSameShape(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { - std::unique_ptr lhs_shape = GetShape(lhs).ConsumeValueOrDie(); - std::unique_ptr rhs_shape = GetShape(rhs).ConsumeValueOrDie(); - VLOG(2) << "checking " << ShapeUtil::HumanString(*lhs_shape) << " equals " - << ShapeUtil::HumanString(*rhs_shape); - CHECK(ShapeUtil::Equal(*lhs_shape, *rhs_shape)) - << "lhs " << ShapeUtil::HumanString(*lhs_shape) << " rhs " - << ShapeUtil::HumanString(*rhs_shape); -} - -ComputationDataHandle ComputationBuilder::Slice( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { - OpRequest op_request; - SliceRequest* request = op_request.mutable_slice_request(); - *request->mutable_operand() = operand; - for (int64 index : start_indices) { - request->add_start_indices(index); - } - for (int64 index : limit_indices) { - request->add_limit_indices(index); - } - for (int64 index : strides) { - request->add_strides(index); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::SliceInDim( - const ComputationDataHandle& operand, int64 start_index, int64 limit_index, - int64 stride, int64 dimno) { - StatusOr> shape_status = GetShape(operand); - if (!shape_status.ok()) { - NoteError(shape_status.status()); - return ComputationDataHandle{}; - } - const Shape& shape = *shape_status.ValueOrDie(); - std::vector starts(ShapeUtil::Rank(shape), 0); - std::vector limits(shape.dimensions().begin(), - shape.dimensions().end()); - std::vector strides(ShapeUtil::Rank(shape), 1); - starts[dimno] = start_index; - limits[dimno] = limit_index; - strides[dimno] = stride; - return Slice(operand, starts, limits, strides); -} - -ComputationDataHandle ComputationBuilder::DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { - OpRequest op_request; - DynamicSliceRequest* request = op_request.mutable_dynamic_slice_request(); - *request->mutable_operand() = operand; - *request->mutable_start_indices() = start_indices; - for (int64 index : slice_sizes) { - request->add_slice_sizes(index); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices) { - OpRequest op_request; - DynamicUpdateSliceRequest* request = - op_request.mutable_dynamic_update_slice_request(); - *request->mutable_operand() = operand; - *request->mutable_update() = update; - *request->mutable_start_indices() = start_indices; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension) { - OpRequest op_request; - ConcatenateRequest* request = op_request.mutable_concatenate_request(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - request->set_dimension(dimension); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice broadcast_sizes) { - OpRequest op_request; - BroadcastRequest* request = op_request.mutable_broadcast_request(); - *request->mutable_operand() = operand; - for (int64 size : broadcast_sizes) { - request->add_broadcast_sizes(size); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Pad( - const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config) { - OpRequest op_request; - PadRequest* request = op_request.mutable_pad_request(); - *request->mutable_operand() = operand; - *request->mutable_padding_value() = padding_value; - *request->mutable_padding_config() = padding_config; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { - OpRequest op_request; - ReshapeRequest* request = op_request.mutable_reshape_request(); - *request->mutable_operand() = operand; - for (int64 dimension : dimensions) { - request->add_dimensions(dimension); - } - for (int64 new_size : new_sizes) { - request->add_new_sizes(new_size); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice new_sizes) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - std::vector dimensions(shape.ValueOrDie()->dimensions().size()); - std::iota(dimensions.begin(), dimensions.end(), 0); - return Reshape(operand, dimensions, new_sizes); -} - -ComputationDataHandle ComputationBuilder::Collapse( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dims_to_collapse) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - // Don't support out-of-order collapse here. - // Checks that the collapsed dimensions are in order and consecutive. - for (tensorflow::gtl::ArraySlice::size_type i = 1; - i < dims_to_collapse.size(); ++i) { - if (dims_to_collapse[i] - 1 != dims_to_collapse[i - 1]) { - NoteError(InvalidArgument( - "Collapsed dimensions are not in order and consecutive.")); - return ComputationDataHandle(); - } - } - - // Create a new sizes vector from the old shape, replacing the collapsed - // dimensions by the product of their sizes. - StatusOr> shape_or_status = GetShape(operand); - if (!shape_or_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr original_shape = shape_or_status.ConsumeValueOrDie(); - - VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape); - VLOG(3) << "dims to collapse: " - << tensorflow::str_util::Join(dims_to_collapse, ","); - - if (dims_to_collapse.size() <= 1) { - // Not collapsing anything, trivially we can return the operand versus - // enqueueing a trivial reshape. - return operand; - } - - std::vector new_sizes; - for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) { - if (i <= dims_to_collapse.front() || i > dims_to_collapse.back()) { - new_sizes.push_back(original_shape->dimensions(i)); - } else { - new_sizes.back() *= original_shape->dimensions(i); - } - } - - VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",") - << "]"; - - return Reshape(operand, new_sizes); -} - -void ComputationBuilder::Trace(const string& tag, - const ComputationDataHandle& operand) { - OpRequest op_request; - TraceRequest* request = op_request.mutable_trace_request(); - request->set_tag(tag); - *request->mutable_operand() = operand; - RunOpAndNoteError(&op_request); -} - -ComputationDataHandle ComputationBuilder::Select( - const ComputationDataHandle& pred, const ComputationDataHandle& on_true, - const ComputationDataHandle& on_false) { - return TernaryOp(TRIOP_SELECT, pred, on_true, on_false); -} - -ComputationDataHandle ComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice elements) { - OpRequest op_request; - VariadicOpRequest* request = op_request.mutable_variadic_op_request(); - request->set_varop(VAROP_TUPLE); - for (const ComputationDataHandle& operand : elements) { - *request->add_operands() = operand; - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::GetTupleElement( - const ComputationDataHandle& tuple_data, int64 index) { - OpRequest op_request; - GetTupleElementRequest* request = - op_request.mutable_get_tuple_element_request(); - *request->mutable_operand() = tuple_data; - request->set_index(index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Eq( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_EQ, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Ne( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_NE, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Ge( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_GE, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Gt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_GT, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Le( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_LE, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Lt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_LT, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Dot( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - StatusOr> lhs_shape_or_status = GetShape(lhs); - if (!lhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); - - DotDimensionNumbers dimension_numbers; - dimension_numbers.add_lhs_contracting_dimensions( - lhs_shape->dimensions_size() == 1 ? 0 : 1); - dimension_numbers.add_rhs_contracting_dimensions(0); - return DotGeneral(lhs, rhs, dimension_numbers); -} - -ComputationDataHandle ComputationBuilder::DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers) { - OpRequest op_request; - DotRequest* request = op_request.mutable_dot_request(); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - *request->mutable_dimension_numbers() = dimension_numbers; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Conv( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { - return ConvWithGeneralDimensions( - lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); -} - -ComputationDataHandle ComputationBuilder::ConvWithGeneralPadding( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { - return ConvGeneral(lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); -} - -bool ComputationBuilder::VerifyConvolution( - const Shape& lhs_shape, const Shape& rhs_shape, - const ConvolutionDimensionNumbers& dimension_numbers) { - if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) { - NoteError( - InvalidArgument("Convolution arguments must have same number of " - "dimensions. Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str())); - return false; - } - int num_dims = ShapeUtil::Rank(lhs_shape); - if (num_dims < 2) { - NoteError(InvalidArgument( - "Convolution expects argument arrays with >= 3 dimensions. " - "Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str())); - return false; - } - int num_spatial_dims = num_dims - 2; - - const auto check_spatial_dimensions = - [&](const char* const field_name, - const tensorflow::protobuf::RepeatedField& - numbers) { - if (numbers.size() != num_spatial_dims) { - NoteError(InvalidArgument("Expected %d elements for %s, but got %d.", - num_spatial_dims, field_name, - numbers.size())); - return false; - } - for (int i = 0; i < numbers.size(); ++i) { - if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) { - NoteError( - InvalidArgument("Convolution %s[%d] is out of bounds: %lld", - field_name, i, numbers.Get(i))); - return false; - } - } - return true; - }; - return check_spatial_dimensions( - "input_spatial_dimensions", - dimension_numbers.input_spatial_dimensions()) && - check_spatial_dimensions( - "kernel_spatial_dimensions", - dimension_numbers.kernel_spatial_dimensions()) && - check_spatial_dimensions( - "output_spatial_dimensions", - dimension_numbers.output_spatial_dimensions()); -} - -ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> lhs_shape_or_status = GetShape(lhs); - if (!lhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - StatusOr> rhs_shape_or_status = GetShape(rhs); - if (!rhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); - std::unique_ptr rhs_shape = rhs_shape_or_status.ConsumeValueOrDie(); - - if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) { - NoteError(InternalError("failed to verify convolution")); - return ComputationDataHandle(); - } - - std::vector base_area_dimensions( - dimension_numbers.input_spatial_dimensions_size()); - for (std::vector::size_type i = 0; i < base_area_dimensions.size(); - ++i) { - base_area_dimensions[i] = - lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i)); - } - - std::vector window_dimensions( - dimension_numbers.kernel_spatial_dimensions_size()); - for (std::vector::size_type i = 0; i < window_dimensions.size(); ++i) { - window_dimensions[i] = - rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); - } - - return ConvGeneral(lhs, rhs, window_strides, - MakePadding(base_area_dimensions, window_dimensions, - window_strides, padding), - dimension_numbers); -} - -ComputationDataHandle ComputationBuilder::ConvGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers) { - return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, - dimension_numbers); -} - -ComputationDataHandle ComputationBuilder::ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> lhs_shape_or_status = GetShape(lhs); - if (!lhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - StatusOr> rhs_shape_or_status = GetShape(rhs); - if (!rhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); - std::unique_ptr rhs_shape = rhs_shape_or_status.ConsumeValueOrDie(); - if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) { - // Error is recorded in VerifyConvolution. - return ComputationDataHandle(); - } - - std::vector window_dimensions( - dimension_numbers.kernel_spatial_dimensions_size()); - for (std::vector::size_type i = 0; i < window_dimensions.size(); ++i) { - window_dimensions[i] = - rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); - } - - OpRequest op_request; - ConvolveRequest* request = op_request.mutable_convolve_request(); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - *request->mutable_dimension_numbers() = dimension_numbers; - - if (!MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, - rhs_dilation, request->mutable_window())) { - // Error is recorded in MakeWindow. - return ComputationDataHandle(); - } - - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Fft( - const ComputationDataHandle& operand, const FftType fft_type, - const tensorflow::gtl::ArraySlice fft_length) { - OpRequest op_request; - FftRequest* request = op_request.mutable_fft_request(); - *request->mutable_operand() = operand; - request->set_fft_type(fft_type); - for (int64 dim_len : fft_length) { - request->add_fft_length(dim_len); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, - const string& config) { - OpRequest op_request; - InfeedRequest* request = op_request.mutable_infeed_request(); - *request->mutable_shape() = shape; - *request->mutable_config() = config; - return RunOpAndParseResponse(&op_request); -} - -void ComputationBuilder::Outfeed(const ComputationDataHandle& operand, - const Shape& shape, - const string& outfeed_config) { - OpRequest op_request; - OutfeedRequest* request = op_request.mutable_outfeed_request(); - request->set_outfeed_config(outfeed_config); - *request->mutable_operand() = operand; - *request->mutable_shape() = shape; - RunOpAndNoteError(&op_request); -} - -ComputationDataHandle ComputationBuilder::Call( - const Computation& computation, - tensorflow::gtl::ArraySlice operands) { - OpRequest op_request; - CallRequest* request = op_request.mutable_call_request(); - *request->mutable_to_apply() = computation.handle(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::CustomCall( - const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape) { - OpRequest op_request; - CustomCallRequest* request = op_request.mutable_custom_call_request(); - request->set_call_target_name(call_target_name); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - *request->mutable_shape() = shape; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::HostCompute( - tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, const Shape& shape) { - OpRequest op_request; - HostComputeRequest* request = op_request.mutable_host_compute_request(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - *request->mutable_shape() = shape; - request->set_channel_name(channel_name); - request->set_cost_estimate_ns(cost_estimate_ns); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Complex( - const ComputationDataHandle& real, const ComputationDataHandle& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_COMPLEX, real, imag, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Conj( - const ComputationDataHandle& operand) { - return Complex(Real(operand), Neg(Imag(operand))); -} - -ComputationDataHandle ComputationBuilder::Add( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_ADD, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Sub( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SUB, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Mul( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_MUL, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Div( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_DIV, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Rem( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_REM, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Max( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_MAX, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Min( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_MIN, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::And( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_AND, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Or( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_OR, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Not( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_NOT, operand); -} - -ComputationDataHandle ComputationBuilder::ShiftLeft( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SHIFT_LEFT, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::ShiftRightArithmetic( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SHIFT_RIGHT_ARITHMETIC, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::ShiftRightLogical( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SHIFT_RIGHT_LOGICAL, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Abs( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_ABS, operand); -} - -ComputationDataHandle ComputationBuilder::Atan2( - const ComputationDataHandle& y, const ComputationDataHandle& x, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_ATAN2, y, x, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Exp( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_EXP, operand); -} - -ComputationDataHandle ComputationBuilder::Floor( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_FLOOR, operand); -} - -ComputationDataHandle ComputationBuilder::Ceil( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_CEIL, operand); -} - -ComputationDataHandle ComputationBuilder::Round( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_ROUND_NEAREST_AFZ, operand); -} - -ComputationDataHandle ComputationBuilder::Log( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_LOG, operand); -} - -ComputationDataHandle ComputationBuilder::Sign( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_SIGN, operand); -} - -ComputationDataHandle ComputationBuilder::Cos( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_COS, operand); -} - -ComputationDataHandle ComputationBuilder::Sin( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_SIN, operand); -} - -ComputationDataHandle ComputationBuilder::Tanh( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_TANH, operand); -} - -ComputationDataHandle ComputationBuilder::Real( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_REAL, operand); -} - -ComputationDataHandle ComputationBuilder::Imag( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_IMAG, operand); -} - -ComputationDataHandle ComputationBuilder::IsFinite( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_IS_FINITE, operand); -} - -ComputationDataHandle ComputationBuilder::Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation) { - OpRequest op_request; - TransposeRequest* request = op_request.mutable_transpose_request(); - *request->mutable_operand() = operand; - for (int64 dimension : permutation) { - request->add_dimensions(dimension); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Rev( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - OpRequest op_request; - ReverseRequest* request = op_request.mutable_reverse_request(); - *request->mutable_operand() = operand; - for (int64 dimension : dimensions) { - request->add_dimensions(dimension); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Sort( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_SORT, operand); -} - -ComputationDataHandle ComputationBuilder::SqrtF32( - const ComputationDataHandle& operand) { - return BinaryOp(BINOP_POW, operand, ConstantR0(0.5), - /*broadcast_dimensions=*/{}); -} - -ComputationDataHandle ComputationBuilder::Pow( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_POW, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::ConvertElementType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape_status = GetShape(operand); - if (!shape_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr original = shape_status.ConsumeValueOrDie(); - - OpRequest op_request; - ConvertRequest* request = op_request.mutable_convert_request(); - *request->mutable_operand() = operand; - request->set_new_element_type(new_element_type); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BitcastConvertType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape_status = GetShape(operand); - if (!shape_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr original = shape_status.ConsumeValueOrDie(); - - OpRequest op_request; - ConvertRequest* request = op_request.mutable_bitcast_convert_request(); - *request->mutable_operand() = operand; - request->set_new_element_type(new_element_type); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::SquareF32( - const ComputationDataHandle& operand) { - return BinaryOp(BINOP_POW, operand, ConstantR0(2.0), - /*broadcast_dimensions=*/{}); -} - -ComputationDataHandle ComputationBuilder::ReciprocalF32( - const ComputationDataHandle& operand) { - return BinaryOp(BINOP_POW, operand, ConstantR0(-1.0), - /*broadcast_dimensions=*/{}); -} - -ComputationDataHandle ComputationBuilder::Neg( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_NEGATE, operand); -} - -ComputationDataHandle ComputationBuilder::Clamp( - const ComputationDataHandle& min, const ComputationDataHandle& operand, - const ComputationDataHandle& max) { - return TernaryOp(TRIOP_CLAMP, min, operand, max); -} - -ComputationDataHandle ComputationBuilder::UnaryOp( - UnaryOperation unop, const ComputationDataHandle& operand) { - OpRequest op_request; - UnaryOpRequest* request = op_request.mutable_unary_op_request(); - request->set_unop(unop); - *request->mutable_operand() = operand; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BinaryOp( - BinaryOperation binop, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - OpRequest op_request; - BinaryOpRequest* request = op_request.mutable_binary_op_request(); - request->set_binop(binop); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - for (int64 dimension : broadcast_dimensions) { - request->add_broadcast_dimensions(dimension); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::RngOp( - RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, - const Shape& shape) { - OpRequest op_request; - RngRequest* request = op_request.mutable_rng_request(); - request->set_distribution(distribution); - for (const ComputationDataHandle& param : parameters) { - *request->add_parameter() = param; - } - *request->mutable_shape() = shape; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::TernaryOp( - TernaryOperation triop, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, const ComputationDataHandle& ehs) { - OpRequest op_request; - TernaryOpRequest* request = op_request.mutable_ternary_op_request(); - request->set_triop(triop); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - *request->mutable_ehs() = ehs; - return RunOpAndParseResponse(&op_request); -} - -Status ComputationBuilder::SetReturnValue( - const ComputationDataHandle& operand) { - TF_RETURN_IF_ERROR(first_error_); - - SetReturnValueRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - - SetReturnValueResponse response; - - VLOG(2) << "making set-handle-to-execute request"; - Status s = client_->stub()->SetReturnValue(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - NoteError(s); - return first_error_; - } - - return Status::OK(); -} - -StatusOr ComputationBuilder::IsConstant( - const ComputationDataHandle& operand, int64 num_parameters) { - TF_RETURN_IF_ERROR(first_error_); - - IsConstantRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - request.set_num_parameters(num_parameters); - IsConstantResponse response; - - VLOG(2) << "making IsConstant request"; - Status s = client_->stub()->IsConstant(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - return s; - } - return response.is_constant(); -} - -StatusOr> ComputationBuilder::ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters) { - TF_RETURN_IF_ERROR(first_error_); - - ComputeConstantRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - if (output_layout != nullptr) { - *request.mutable_output_layout() = *output_layout; - } - for (const auto& param : parameters) { - *request.add_parameters() = param.ToProto(); - } - - ComputeConstantResponse response; - - VLOG(2) << "making compute-constant request"; - Status s = client_->stub()->ComputeConstant(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - return s; - } - - VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}"; - - if (!response.has_literal()) { - return InternalError( - "no computed literal in the provided response in ComputeConstant " - "request"); - } - return Literal::CreateFromProto(response.literal()); -} - -ComputationDataHandle ComputationBuilder::Map( - tensorflow::gtl::ArraySlice operands, - const Computation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { - OpRequest op_request; - MapRequest* request = op_request.mutable_map_request(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - *request->mutable_to_apply() = computation.handle(); - for (int64 dimension : dimensions) { - request->add_dimensions(dimension); - } - for (const ComputationDataHandle& sop : static_operands) { - *request->add_static_operands() = sop; - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::RngNormal( - const ComputationDataHandle& mu, const ComputationDataHandle& sigma, - const Shape& shape) { - return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape); -} - -ComputationDataHandle ComputationBuilder::RngUniform( - const ComputationDataHandle& a, const ComputationDataHandle& b, - const Shape& shape) { - return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape); -} - -ComputationDataHandle ComputationBuilder::While( - const Computation& condition, const Computation& body, - const ComputationDataHandle& init) { - OpRequest op_request; - WhileRequest* request = op_request.mutable_while_request(); - *request->mutable_condition() = condition.handle(); - *request->mutable_body() = body.handle(); - *request->mutable_init() = init; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Gather( - const ComputationDataHandle& input, - const ComputationDataHandle& gather_indices, - const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds) { - OpRequest op_request; - GatherRequest* gather_request = op_request.mutable_gather_request(); - *gather_request->mutable_input() = input; - *gather_request->mutable_gather_indices() = gather_indices; - *gather_request->mutable_dimension_numbers() = dimension_numbers; - for (int64 window_bound : window_bounds) { - gather_request->add_window_bounds(window_bound); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Conditional( - const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const Computation& true_computation, - const ComputationDataHandle& false_operand, - const Computation& false_computation) { - OpRequest op_request; - ConditionalRequest* request = op_request.mutable_conditional_request(); - *request->mutable_predicate() = predicate; - *request->mutable_true_operand() = true_operand; - *request->mutable_true_computation() = true_computation.handle(); - *request->mutable_false_operand() = false_operand; - *request->mutable_false_computation() = false_computation.handle(); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { - OpRequest op_request; - ReduceRequest* request = op_request.mutable_reduce_request(); - *request->mutable_operand() = operand; - *request->mutable_init_value() = init_value; - for (int64 dimension : dimensions_to_reduce) { - request->add_dimensions(dimension); - } - *request->mutable_to_apply() = computation.handle(); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::ReduceAll( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - - std::vector all_dimnos(ShapeUtil::Rank(*shape.ValueOrDie())); - std::iota(all_dimnos.begin(), all_dimnos.end(), 0); - return Reduce(operand, init_value, computation, all_dimnos); -} - -ComputationDataHandle ComputationBuilder::ReduceWindow( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - - Status padding_valid = - ValidatePaddingValues(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides); - if (!padding_valid.ok()) { - first_error_ = padding_valid; - return ComputationDataHandle(); - } - - std::vector> padding_values = - MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides, padding); - return ReduceWindowWithGeneralPadding(operand, init_value, computation, - window_dimensions, window_strides, - padding_values); -} - -ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { - OpRequest op_request; - ReduceWindowRequest* request = op_request.mutable_reduce_window_request(); - *request->mutable_operand() = operand; - *request->mutable_to_apply() = computation.handle(); - *request->mutable_init_value() = init_value; - - if (!MakeWindow(window_dimensions, window_strides, padding, {}, {}, - request->mutable_window())) { - NoteError(InternalError("failed to make window")); - return ComputationDataHandle(); - } - - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BatchNormTraining( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& offset, float epsilon, int64 feature_index) { - OpRequest op_request; - BatchNormTrainingRequest* request = - op_request.mutable_batch_norm_training_request(); - *request->mutable_operand() = operand; - *request->mutable_scale() = scale; - *request->mutable_offset() = offset; - request->set_epsilon(epsilon); - request->set_feature_index(feature_index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BatchNormInference( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& offset, const ComputationDataHandle& mean, - const ComputationDataHandle& variance, float epsilon, int64 feature_index) { - OpRequest op_request; - BatchNormInferenceRequest* request = - op_request.mutable_batch_norm_inference_request(); - *request->mutable_operand() = operand; - *request->mutable_scale() = scale; - *request->mutable_offset() = offset; - *request->mutable_mean() = mean; - *request->mutable_variance() = variance; - request->set_epsilon(epsilon); - request->set_feature_index(feature_index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BatchNormGrad( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& mean, const ComputationDataHandle& var, - const ComputationDataHandle& grad_output, float epsilon, - int64 feature_index) { - OpRequest op_request; - BatchNormGradRequest* request = op_request.mutable_batch_norm_grad_request(); - *request->mutable_operand() = operand; - *request->mutable_scale() = scale; - *request->mutable_mean() = mean; - *request->mutable_variance() = var; - *request->mutable_grad_output() = grad_output; - request->set_epsilon(epsilon); - request->set_feature_index(feature_index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::CrossReplicaSum( - const ComputationDataHandle& operand) { - OpRequest op_request; - CrossReplicaSumRequest* request = - op_request.mutable_cross_replica_sum_request(); - *request->mutable_operand() = operand; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::SelectAndScatter( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - return SelectAndScatterWithGeneralPadding( - operand, select, window_dimensions, window_strides, - MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides, padding), - source, init_value, scatter); -} - -ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter) { - OpRequest op_request; - SelectAndScatterRequest* request = - op_request.mutable_select_and_scatter_request(); - *request->mutable_operand() = operand; - *request->mutable_select() = select.handle(); - *request->mutable_source() = source; - *request->mutable_init_value() = init_value; - *request->mutable_scatter() = scatter.handle(); - - if (!MakeWindow(window_dimensions, window_strides, padding, {}, {}, - request->mutable_window())) { - NoteError(InternalError("failed to make window")); - return ComputationDataHandle(); - } - - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::ReducePrecision( - const ComputationDataHandle& operand, const int exponent_bits, - const int mantissa_bits) { - OpRequest op_request; - ReducePrecisionRequest* request = - op_request.mutable_reduce_precision_request(); - *request->mutable_operand() = operand; - request->set_exponent_bits(exponent_bits); - request->set_mantissa_bits(mantissa_bits); - return RunOpAndParseResponse(&op_request); -} - -void ComputationBuilder::Send(const ComputationDataHandle& operand, - const ChannelHandle& handle) { - OpRequest op_request; - SendRequest* request = op_request.mutable_send_request(); - *request->mutable_operand() = operand; - *request->mutable_channel_handle() = handle; - *op_request.mutable_computation() = computation_.handle(); - RunOpAndNoteError(&op_request); -} - -ComputationDataHandle ComputationBuilder::Recv(const Shape& shape, - const ChannelHandle& handle) { - OpRequest op_request; - RecvRequest* request = op_request.mutable_recv_request(); - *request->mutable_shape() = shape; - *request->mutable_channel_handle() = handle; - return RunOpAndParseResponse(&op_request); -} - -Computation ComputationBuilder::BuildAndNoteError() { - DCHECK(parent_builder_ != nullptr); - auto build_status = Build(); - if (!build_status.ok()) { - parent_builder_->NoteError( - AddStatus(build_status.status(), - tensorflow::strings::StrCat("error from: ", name_))); - return Computation(); - } - return build_status.ConsumeValueOrDie(); -} - -StatusOr ComputationBuilder::Build() { - if (!first_error_.ok()) { - string backtrace; - first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); - return AppendStatus(first_error_, backtrace); - } - - if (computation_.IsNull()) { - return FailedPrecondition("no computation was built"); - } - - return {std::move(computation_)}; -} - -/* static */ ConvolutionDimensionNumbers -ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { - ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_input_batch_dimension(kConvBatchDimension); - dimension_numbers.set_input_feature_dimension(kConvFeatureDimension); - dimension_numbers.set_output_batch_dimension(kConvBatchDimension); - dimension_numbers.set_output_feature_dimension(kConvFeatureDimension); - dimension_numbers.set_kernel_output_feature_dimension( - kConvKernelOutputDimension); - dimension_numbers.set_kernel_input_feature_dimension( - kConvKernelInputDimension); - for (int i = 0; i < num_spatial_dims; ++i) { - dimension_numbers.add_input_spatial_dimensions(i + 2); - dimension_numbers.add_kernel_spatial_dimensions(i + 2); - dimension_numbers.add_output_spatial_dimensions(i + 2); - } - return dimension_numbers; -} - -/* static */ StatusOr -ComputationBuilder::CreateConvDimensionNumbers( - int64 input_batch, int64 input_feature, int64 input_first_spatial, - int64 input_second_spatial, int64 output_batch, int64 output_feature, - int64 output_first_spatial, int64 output_second_spatial, - int64 kernel_output_feature, int64 kernel_input_feature, - int64 kernel_first_spatial, int64 kernel_second_spatial) { - if (std::set({input_batch, input_feature, input_first_spatial, - input_second_spatial}) - .size() != 4) { - return FailedPrecondition( - "dimension numbers for the input are not unique: (%lld, %lld, %lld, " - "%lld)", - input_batch, input_feature, input_first_spatial, input_second_spatial); - } - if (std::set({kernel_output_feature, kernel_input_feature, - kernel_first_spatial, kernel_second_spatial}) - .size() != 4) { - return FailedPrecondition( - "dimension numbers for the weight are not unique: (%lld, %lld, %lld, " - "%lld)", - kernel_output_feature, kernel_input_feature, kernel_first_spatial, - kernel_second_spatial); - } - if (std::set({output_batch, output_feature, output_first_spatial, - output_second_spatial}) - .size() != 4) { - return FailedPrecondition( - "dimension numbers for the output are not unique: (%lld, %lld, %lld, " - "%lld)", - output_batch, output_feature, output_first_spatial, - output_second_spatial); - } - ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_input_batch_dimension(input_batch); - dimension_numbers.set_input_feature_dimension(input_feature); - dimension_numbers.add_input_spatial_dimensions(input_first_spatial); - dimension_numbers.add_input_spatial_dimensions(input_second_spatial); - dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature); - dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature); - dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial); - dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial); - dimension_numbers.set_output_batch_dimension(output_batch); - dimension_numbers.set_output_feature_dimension(output_feature); - dimension_numbers.add_output_spatial_dimensions(output_first_spatial); - dimension_numbers.add_output_spatial_dimensions(output_second_spatial); - return dimension_numbers; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h deleted file mode 100644 index e3facb3f258bb0bdf4b2dd3648e55421dfd56e79..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ /dev/null @@ -1,1068 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/array.h" -#include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/array3d.h" -#include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/bitmap.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/stacktrace.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Wraps an XLA client with a convenient interface for building up -// computations. Any errors encountered in building up the computation are -// deferred from being handled until Build() is called. -// -// Thread-compatible. -class ComputationBuilder { - public: - // client: client in which to build the computation. - // computation_name: name to use for the built computation. - ComputationBuilder(Client* client, const string& computation_name); - - ~ComputationBuilder(); - - // Returns the client the builder was initialized with. - Client* client() const { return client_; } - - // Returns the computation name. - const string& name() const { return name_; } - - // Sets OpMetadata that will be added to all instructions until cleared. - // - // OpMetadata is often applied to a series of XLA HLO instructions. As a - // result, OpMetadata is set on the Computation Builder. All subsequent - // instructions generated via this Computation Builder will have the same - // OpMetadata attached until a call to ClearOpMetadata. - void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } - - // Clears the HloMetadata state. - void ClearOpMetadata() { metadata_.Clear(); } - - // Sets an OpSharding that will be attached to all instructions until cleared. - void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } - - // Clears the sharding. Ops will be sharded according to the default placement - // policy. - void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } - - // Returns the OpSharding that will be attached to all instructions. - const tensorflow::gtl::optional& sharding() const { - return sharding_; - } - - // Sets the builder to a mode where it will die immediately when an error is - // encountered, rather than producing it in a deferred fashion when Build() is - // called (which is the default). - void set_die_immediately_on_error(bool enabled) { - die_immediately_on_error_ = enabled; - } - - // Enqueues a "retrieve parameter value" instruction for a parameter that was - // passed to the computation. - ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, - const string& name); - - // Retrieves the (inferred) shape of the operand in the computation. - StatusOr> GetShape( - const ComputationDataHandle& operand); - - // Retrieves the (inferred) result for the current computation's shape. - StatusOr GetProgramShape(); - - // Checks that the operand has the given expected shape. Returns the operand - // if yes, fails with a CHECK error if no. - ComputationDataHandle CheckShape(const ComputationDataHandle& operand, - const Shape& expected_shape); - - // Checks that the lhs and rhs results have the same shape. - void CheckSameShape(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); - - // Enqueues a constant with the value of the given literal onto the - // computation. - ComputationDataHandle ConstantLiteral(const Literal& literal); - - // Enqueues a constant onto the computation. Methods are templated on the - // native host type (NativeT) which corresponds to a specific XLA - // PrimitiveType as given in the following table: - // - // Native Type PrimitiveType - // ----------------------------- - // bool PRED - // int32 S32 - // int64 S64 - // uint32 U32 - // uint64 U64 - // float F32 - // double F64 - // - // Note: not all primitive types defined in xla_data.proto have a - // corresponding native type yet. - template - ComputationDataHandle ConstantR0(NativeT value); - template - ComputationDataHandle ConstantR1(tensorflow::gtl::ArraySlice values); - ComputationDataHandle ConstantR1(const tensorflow::core::Bitmap& values); - template - ComputationDataHandle ConstantR2( - std::initializer_list> values); - template - ComputationDataHandle ConstantFromArrayWithLayout( - const Array& values, const Layout& layout); - template - ComputationDataHandle ConstantFromArray(const Array& values); - template - ComputationDataHandle ConstantR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout); - template - ComputationDataHandle ConstantR2FromArray2D(const Array2D& values); - template - ComputationDataHandle ConstantR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout); - template - ComputationDataHandle ConstantR3FromArray3D(const Array3D& values); - template - ComputationDataHandle ConstantR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout); - template - ComputationDataHandle ConstantR4FromArray4D(const Array4D& values); - - // Enqueues a rank one constant (vector) onto the computation. The vector has - // size 'length' and every element has the value 'value'. - template - ComputationDataHandle ConstantR1(int64 length, NativeT value); - - // Adds dimensions to an array by duplicating the data in the array. - // - // The new dimensions are inserted on the left, i.e. if - // broadcast_sizes has values {a0, ..., aN} and the operand shape - // has dimensions {b0, ..., bM} then the shape of the output has - // dimensions {a0, ..., aN, b0, ..., bM}. - // - // The new dimensions index into copies of the operand, i.e. - // - // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] - ComputationDataHandle Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); - - // Enqueues a pad operation onto the computation that pads the given value on - // the edges as well as between the elements of the input. padding_config - // specifies the padding amount for each dimension. - ComputationDataHandle Pad(const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config); - - // Enqueues an operation onto the computation that flattens the operand based - // on the dimension order (major/slowest-varying to minor/fastest-varying) - // given, followed by reshaping it into the shape with the given dimension - // sizes (also major to minor). Conceptually, this is a limited form of - // "shape casting". - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); - - // Enqueues an operation onto the computation that collapses the operand, from - // minor to major order, then reshapes it into the shape with the given - // dimension sizes, also from major to minor. Conceptually, this is a limited - // form of "shape casting". - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice new_sizes); - - // Wrapper for Reshape. - // Enqueues an operation to collapse the provided dimensions; e.g. an - // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to - // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must - // be a consecutive, in-order subsequence of the operand dimensions. - // - // Note that collapsing a single dimension does nothing: - // - // {256} collapsing {0} => {256} - // {1} collapsing {0} => {1} - // - // Collapsing multiple dimensions produces a single result dimension: - // - // {256, 2} collapsing {0,1} => {512} - // {256, 2, 3} collapsing {0,1} => {512, 3} - // - // This could potentially cause data to be moved -- it provides a more - // structured form of reshaping than an arbitrary Reshape operation. - ComputationDataHandle Collapse(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); - - // Enqueues a slice operation onto the computation that slices the operand - // from the start indices to the limit indices; e.g. - // - // x - // [ 0 1 2 3 ] - // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] - // [ 8 9 a b ] - // - // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D - // range notation. - // The strides parameter determines the stride over the slice - ComputationDataHandle Slice(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); - - // Enqueues a slice operation in a given dimension, taking all other - // dimensions as they are; e.g. if dimno is 1 from start_index 2 to - // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand - // for: - // - // array[:, 2:4:1, :] - ComputationDataHandle SliceInDim(const ComputationDataHandle& operand, - int64 start_index, int64 limit_index, - int64 stride, int64 dimno); - - // Enqueues a slice operation onto the computation that slices the 'operand' - // from dynamic start indices which are passed in 'start_indices'. - // The size of the slice in each dimension is passed in 'slice_sizes', - // which specify the end point of exclusive slice intervals in each - // dimension [start, start + size). - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo input dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. - ComputationDataHandle DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); - - // Enqueues a dynamic update slice operation onto the computation, which - // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. - // The shape of 'update' determines the shape of the slice of 'operand' - // which is updated. - // The indices specified in 'start_indices' specify the offset of the slice - // of 'operand' which is updated. - // - // update = {10, 11} // calculated at runtime. - // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] - // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] - // [7 8 9] [7 8 9 ] - // - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo update dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. - ComputationDataHandle DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices); - - // Enqueues a concatenate instruction onto the computation. 'operands' must - // have >= 1 entry. - ComputationDataHandle ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension); - - // Enqueue a tracing operation onto the computation; the computation will emit - // a logging message with the operand. - void Trace(const string& tag, const ComputationDataHandle& operand); - - // Enqueues a conditional-move-like select operation onto the computation; - // predicated on pred, selects between on_true and on_false. - ComputationDataHandle Select(const ComputationDataHandle& pred, - const ComputationDataHandle& on_true, - const ComputationDataHandle& on_false); - - // Enqueues a tuple-creation instruction onto the computation. - ComputationDataHandle Tuple( - tensorflow::gtl::ArraySlice elements); - - // Enqueues a tuple-element-get instruction onto the computation. - ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, - int64 index); - - // Enqueues an equal-to comparison instruction onto the computation. - ComputationDataHandle Eq( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a not-equal comparison instruction onto the computation. - ComputationDataHandle Ne( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a greater-or-equal comparison instruction onto the computation. - ComputationDataHandle Ge( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a greater-than comparison instruction onto the computation. - ComputationDataHandle Gt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a less-than comparison instruction onto the computation. - ComputationDataHandle Lt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a less-or-equal comparison instruction onto the computation. - ComputationDataHandle Le( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a dot instruction onto the computation. - ComputationDataHandle Dot(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); - - // Enqueues a general dot instruction onto the computation. - ComputationDataHandle DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers); - - // Default dimension numbers used for a 2D convolution. - static constexpr int64 kConvBatchDimension = 0; - static constexpr int64 kConvFeatureDimension = 1; - static constexpr int64 kConvFirstSpatialDimension = 2; - static constexpr int64 kConvSecondSpatialDimension = 3; - static constexpr int64 kConvKernelOutputDimension = 0; - static constexpr int64 kConvKernelInputDimension = 1; - static constexpr int64 kConvKernelFirstSpatialDimension = 2; - static constexpr int64 kConvKernelSecondSpatialDimension = 3; - - // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for - // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for - // the kernel operand - // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. - static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( - int num_spatial_dims = 2); - - // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an - // error if either the input or the weight dimension numbers have conflicts. - static StatusOr CreateConvDimensionNumbers( - int64 input_batch, int64 input_feature, int64 input_first_spatial, - int64 input_second_spatial, int64 output_batch, int64 output_feature, - int64 output_first_spatial, int64 output_second_spatial, - int64 kernel_output_feature, int64 kernel_input_feature, - int64 kernel_first_spatial, int64 kernel_second_spatial); - - // Enqueues a convolution instruction onto the computation, which uses the - // default convolution dimension numbers. - ComputationDataHandle Conv(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration in the format returned by MakePadding(). - ComputationDataHandle ConvWithGeneralPadding( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided dimension numbers configuration. - ComputationDataHandle ConvWithGeneralDimensions( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration as well as the dimension numbers. - ComputationDataHandle ConvGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration, dilation factors and dimension numbers. - ComputationDataHandle ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues an FFT instruction onto the computation, of the given type and - // with the given FFT length. - ComputationDataHandle Fft(const ComputationDataHandle& operand, - FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); - - // Enqueues an infeed instruction onto the computation, which writes data of - // the given shape to the infeed buffer of the device. - ComputationDataHandle Infeed(const Shape& shape, const string& config = ""); - - // Enqueues an outfeed instruction onto the computation. This instruction - // generates outgoing data transfers for the given data. - // - // shape_with_layout communicates the laid out shape that we want to outfeed - // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error - // will occur. - void Outfeed(const ComputationDataHandle& operand, - const Shape& shape_with_layout, const string& outfeed_config); - - // Enqueues a call instruction onto the computation. - ComputationDataHandle Call( - const Computation& computation, - tensorflow::gtl::ArraySlice operands); - - // Enqueues a custom call instruction onto the computation. - // During code generation, a call instruction is emitted which targets a - // symbol with the name |call_target_name|. The |operands| are passed to the - // call instruction. |shape| is the resultant shape. - ComputationDataHandle CustomCall( - const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); - - // Enqueues a pseudo-op to represent host-side computation data-dependencies. - // During code generation, host send and receive operations will be generated - // to transfer |operands| to the host and a single result of |shape| back to - // the device. Host send/recv operations are emitted using |channel_name|. - // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO - // instruction scheduling. - ComputationDataHandle HostCompute( - tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, const Shape& shape); - - // The following methods enqueue element-wise binary arithmetic operations - // onto the computation. The shapes of the operands have to match unless one - // of the operands is a scalar, or an explicit broadcast dimension is given - // (see g3doc for more details). - - // Enqueues a complex compose instruction onto the computation. - ComputationDataHandle Complex( - const ComputationDataHandle& real, const ComputationDataHandle& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a complex conjugate instruction onto the computation. - ComputationDataHandle Conj(const ComputationDataHandle& operand); - - // Enqueues an add instruction onto the computation. - ComputationDataHandle Add( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a subtract instruction onto the computation. - ComputationDataHandle Sub( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a multiply instruction onto the computation. - ComputationDataHandle Mul( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a divide instruction onto the computation. - ComputationDataHandle Div( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a remainder instruction onto the computation. - ComputationDataHandle Rem( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a max instruction onto the computation. - ComputationDataHandle Max( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a min instruction onto the computation. - ComputationDataHandle Min( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Element-wise logical operators - ComputationDataHandle And( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - ComputationDataHandle Or( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - ComputationDataHandle Not(const ComputationDataHandle& operand); - - ComputationDataHandle ShiftLeft( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - ComputationDataHandle ShiftRightArithmetic( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - ComputationDataHandle ShiftRightLogical( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Reduces an array among the provided dimensions, given "computation" as a - // reduction operator. - ComputationDataHandle Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); - - // Convenience wrapper around the above that reduces all the dimensions in the - // operand shape. - ComputationDataHandle ReduceAll(const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, - const Computation& computation); - - // Enqueues a windowed reduce instruction onto the computation. - ComputationDataHandle ReduceWindow( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding); - - // As ReduceWindow(), but the padding is given in the format - // returned by MakePadding(). - ComputationDataHandle ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); - - // Returns the sum of the operand value across all replicas. All replicas - // supply one input to the sum and all replicas receive the resulting sum. - ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand); - - // Enqueues an operation that scatters the `source` array to the selected - // indices of each window. - ComputationDataHandle SelectAndScatter( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter); - - // As SelectAndScatter(), but the padding is given in the format - // returned by MakePadding(). - ComputationDataHandle SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter); - - // Enqueues an abs instruction onto the computation. - ComputationDataHandle Abs(const ComputationDataHandle& operand); - - // Enqueues a atan2 instruction onto the computation. - ComputationDataHandle Atan2( - const ComputationDataHandle& y, const ComputationDataHandle& x, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues an exp instruction onto the computation. - ComputationDataHandle Exp(const ComputationDataHandle& operand); - - // Enqueues a floor instruction onto the computation. - ComputationDataHandle Floor(const ComputationDataHandle& operand); - - // Enqueues a ceil instruction onto the computation. - ComputationDataHandle Ceil(const ComputationDataHandle& operand); - - // Enqueues a round instruction onto the computation, rounding to nearest even - // with half-way cases rounding away from zero. - ComputationDataHandle Round(const ComputationDataHandle& operand); - - // Enqueues an log instruction (natural logarithm) onto the computation. - ComputationDataHandle Log(const ComputationDataHandle& operand); - - // Enqueues a sign instruction onto the computation. - ComputationDataHandle Sign(const ComputationDataHandle& operand); - - // Enqueues a cosine instruction onto the computation. - ComputationDataHandle Cos(const ComputationDataHandle& operand); - - // Enqueues a sine instruction onto the computation. - ComputationDataHandle Sin(const ComputationDataHandle& operand); - - // Enqueues a tanh instruction onto the computation. - ComputationDataHandle Tanh(const ComputationDataHandle& operand); - - // Enqueues a real-part instruction onto the computation. - ComputationDataHandle Real(const ComputationDataHandle& operand); - - // Enqueues an imaginary-part instruction onto the computation. - ComputationDataHandle Imag(const ComputationDataHandle& operand); - - // Enqueues a float32 sqrt instruction onto the computation. - // (float32 is specified as there is an implicit float32 0.5f constant - // exponent). - ComputationDataHandle SqrtF32(const ComputationDataHandle& operand); - - // Enqueues a float32 square instruction onto the computation. - // (float32 is specified as there is an implicit float32 2.0f constant - // exponent). - ComputationDataHandle SquareF32(const ComputationDataHandle& operand); - - // Enqueues a lhs^rhs computation onto the computation. - ComputationDataHandle Pow( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues an operator that tests if the operand's values are finite, i.e., - // not Inf or NaN. Defined only for floating-point types. Returns an array of - // booleans with the same shape where entries are true iff the corresponding - // entry was NaN. - ComputationDataHandle IsFinite(const ComputationDataHandle& operand); - - // Enqueues a convert instruction onto the computation that changes the - // element type of the operand array to primitive_type. - ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, - PrimitiveType new_element_type); - - // Enqueues a no-op instruction onto the computation that changes - // the element type of the operand array to primitive_type. The - // bit-widths of the source and destination element types must be - // identical. - ComputationDataHandle BitcastConvertType(const ComputationDataHandle& operand, - PrimitiveType new_element_type); - - // Enqueues a float32 reciprocal instruction onto the computation. - // (float32 is specified as there is an implicit float32 -1.0f constant - // exponent). - // - // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the - // shape of the operand. - ComputationDataHandle ReciprocalF32(const ComputationDataHandle& operand); - - // Enqueues a negate instruction onto the computation. - ComputationDataHandle Neg(const ComputationDataHandle& operand); - - // Enqueues a transpose instruction onto the computation. - ComputationDataHandle Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation); - - // Enqueues a reverse instruction onto the computation. The order of the - // elements in the given dimensions is reversed (i.e., the element at index i - // is moved to index dimension_size - 1 - i). - ComputationDataHandle Rev(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); - - // Enqueues a sort (as increasing order) instruction onto the computation. - ComputationDataHandle Sort(const ComputationDataHandle& operand); - - // Enqueues a clamp instruction onto the computation. - ComputationDataHandle Clamp(const ComputationDataHandle& min, - const ComputationDataHandle& operand, - const ComputationDataHandle& max); - - // Enqueues a map instruction onto the computation. - ComputationDataHandle Map( - tensorflow::gtl::ArraySlice operands, - const Computation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands = {}); - - // Enqueues a N(mu, sigma) random number generation instruction onto the - // computation. - ComputationDataHandle RngNormal(const ComputationDataHandle& mu, - const ComputationDataHandle& sigma, - const Shape& shape); - - // Enqueues a U(a, b) random number generation instruction onto the - // computation. Returns values in the semi-open interval [a, b). - ComputationDataHandle RngUniform(const ComputationDataHandle& a, - const ComputationDataHandle& b, - const Shape& shape); - - // Enqueues a while node onto the computation. - ComputationDataHandle While(const Computation& condition, - const Computation& body, - const ComputationDataHandle& init); - - // Enqueues a conditional node onto the computation. - ComputationDataHandle Conditional(const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const Computation& true_computation, - const ComputationDataHandle& false_operand, - const Computation& false_computation); - - // Enqueues a ReducePrecision node onto the computation. - ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand, - const int exponent_bits, - const int mantissa_bits); - - // Enqueues a Gather node onto the computation. - ComputationDataHandle Gather( - const ComputationDataHandle& input, - const ComputationDataHandle& gather_indices, - const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds); - - // Enqueues a Send node onto the computation, to send the given operand to - // a Recv instruction that shares the same channel handle. - void Send(const ComputationDataHandle& operand, const ChannelHandle& handle); - - // Enqueues a Recv node onto the computation. The data comes from a Send - // instruction that shares the same channel handle and its shape must - // be the same as the given shape. - ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle); - - // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on parameters with index greater than or equal to - // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`. - // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a - // compile-time constant without evaluating the computation. - StatusOr IsConstant(const ComputationDataHandle& operand, - int64 num_parameters = 0); - - // Normalizes operand across spatial and batch dimensions for each feature. - // - // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` - // is the normalized result and batch_mean and batch_var are the mean and - // variance, respectively, across batch for the operand. - ComputationDataHandle BatchNormTraining(const ComputationDataHandle& operand, - const ComputationDataHandle& scale, - const ComputationDataHandle& offset, - float epsilon, int64 feature_index); - - // Normalizes operand across spatial and batch dimensions for each feature. - // - // `BatchNormInference` is equivalent to calling `BatchNormTraining` without - // computing `mean` and `variance` for each batch inside the operation. It - // uses the input `mean` and `variance` instead as estimated values. The - // purpose of this op is to reduce latency in inference, hence the name - // `BatchNormInference`. - // - // The output has the same shape as `operand`, and contains the normalized - // values for each batch. - ComputationDataHandle BatchNormInference( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& offset, const ComputationDataHandle& mean, - const ComputationDataHandle& variance, float epsilon, - int64 feature_index); - - // Calculates the gradients of a batch norm op. - // - // The inputs `batch_mean` and `batch_var` represent the mean and variance - // across the batch. - // - // Returns a tuple of three elements: - // - grad_operand: Gradient with respect to input `operand` - // - grad_offset: Gradient with respect to input `offset` - // - grad_scale: Gradient with respect to input `scale` - ComputationDataHandle BatchNormGrad(const ComputationDataHandle& operand, - const ComputationDataHandle& scale, - const ComputationDataHandle& batch_mean, - const ComputationDataHandle& batch_var, - const ComputationDataHandle& grad_output, - float epsilon, int64 feature_index); - - // Computes the value of a constant indicated by a - // ComputationDataHandle using a non-optimized interpreter on the host. - // - // The operand must be from the computation currently being built - - // i.e., returned from this builder with no intervening call to - // Build(). This happens to currently work regardless of that, but - // that may stop working at any time. - // - // The operand must represent a constant value, which in this case - // means that it must not statically depend on any parameter of the - // computation that is being built other then the ones specified on the - // parameter list. The parameters in the list will be indexed by their - // parameter id property so the number of parameters specified should be at - // least as many as the largest used parameter index. - // - // `IsConstant` can be used to test whether a computation is a compile-time - // constant without evaluation it. `ComputeConstant` only succeeds for - // computations where `IsConstant` returns true. - // - // This functionality can be useful when translating a computation - // into XLA where something that looked dynamic is required by - // XLA to be specified as a constant. E.g. the source - // computation (outside of XLA) may include a dynamic - // computation of the shape of something and ComputeConstant lets - // you determine what the value of that computation is in the case - // where the value can be determined at compile time. - // - // If output_layout is non-null, then the output of the computation - // will be stored using that layout. - StatusOr> ComputeConstant( - const ComputationDataHandle& operand, - const Layout* output_layout = nullptr, - tensorflow::gtl::ArraySlice parameters = {}); - - // Returns a new ComputationBuilder whose resultant Computation is used only - // by this ComputationBuilder. The sub-ComputationBuilder has the same - // die_immediately_on_error behavior as the parent. - std::unique_ptr CreateSubBuilder( - const string& computation_name); - - // Modifies the computation being built so that executions of it - // will return the value associated with operand, rather than the - // last expression enqueued on the ComputationBuilder. Any subsequent - // operations added to the ComputationBuilder will not have any effect unless - // SetReturnValue is called again. - Status SetReturnValue(const ComputationDataHandle& operand); - - // Builds the computation with the requested operations, or returns a non-ok - // status. - StatusOr Build(); - - // Builds the computation with the requested operations, or notes an error in - // the parent ComputationBuilder and returns an empty computation if building - // failed. This function is intended to be used where the returned - // Computation is only used by the parent ComputationBuilder and hence further - // operation on the returned Computation will simply be error'ed out if an - // error occurred while building this computation. If the built computation is - // to be used by a ComputationBuilder other than the parent ComputationBuilder - // then Build() should be used instead. - Computation BuildAndNoteError(); - - // Returns the first error that was encountered while building the - // computation. When an error is encountered, by default we return a vacuous - // ComputationDataHandle and inform the user of the error that occurred while - // building the computation when they make a final call to Build(). - // - // See also set_die_immediately_on_error(). - Status first_error() const { return first_error_; } - - private: - // Limited checking of convolution parameters. Returns false on - // error. - bool VerifyConvolution(const Shape& lhs_shape, const Shape& rhs_shape, - const ConvolutionDimensionNumbers& dimension_numbers); - - // The parent ComputationBuilder of a sub-ComputationBuilder. The - // parent_builder_ will be the nullptr if not a sub-ComputationBuilder. - ComputationBuilder* parent_builder_{nullptr}; - - // Helper function for creating a Window proto from user-supplied - // data. Returns true if the user-supplied data was valid. - bool MakeWindow(tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - Window* window); - - // Internal helper method that does the building for an arbitrary unary op. - ComputationDataHandle UnaryOp(UnaryOperation binop, - const ComputationDataHandle& operand); - - // Internal helper method that does the building for an arbitrary binary op. - // broadcast_dimensions specifies which dimensions to use for broadcasting - // when the operation is between tensors of different ranks. - ComputationDataHandle BinaryOp( - BinaryOperation binop, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); - - // Internal helper method that does the building for an arbitrary ternary op. - ComputationDataHandle TernaryOp(TernaryOperation triop, - const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - const ComputationDataHandle& ehs); - - // Internal helper method that does the building for a random number generator - // of a given distribution with an explicitly specified shape. - ComputationDataHandle RngOp( - RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, - const Shape& shape); - - // Populates computation_ with a valid object or returns a failing status. - // This is used before any given operation is enqueued. - Status PrepareComputation(); - - // Notes that the error occurred by: - // * storing it internally and capturing a backtrace if it's the first error - // (this deferred value will be produced on the call to Build()) - // * dying if die_immediately_on_error_ is true - void NoteError(const Status& error); - - // Helper function that runs the given op_request, filling in op_response. - // Before the op is run, PrepareComputation is called, and common fields in - // the op_request are filled in. - Status RunOp(OpRequest* op_request, OpResponse* op_response); - - // Helper function that calls RunOp and calls NoteError on failures. - void RunOpAndNoteError(OpRequest* op_request); - - // Helper function that calls RunOp and either returns the output computation - // data handle (on success) or a vacuous computation data handle (on failure). - ComputationDataHandle RunOpAndParseResponse(OpRequest* op_request); - - // Helper function that implements GetShape without noting errors. This makes - // it easier to ensure the real GetShape will note errors on every error path. - StatusOr> GetShapeWithoutNoteError( - const ComputationDataHandle& operand); - - string name_; // Name to use for the built computation. - - // The first error encountered while building the computation. - // This is OK until the first error is encountered. - Status first_error_; - - // The saved stack trace from the point at which the first error occurred. - tensorflow::SavedStackTrace first_error_backtrace_; - - // The computation that operations are enqueued onto. - Computation computation_; - - // The client that the computation is created in. Not owned. - Client* client_; - - // Mode bit that indicates whether to die when a first error is encountered. - bool die_immediately_on_error_ = false; - - // The metadata to attach to each op. This is structured as a "modal"-like - // operation, in order to simplify client code (and not sprinkle this metadata - // throughout the TensorFlow op kernel implementations). - OpMetadata metadata_; - - // Sharding for this operator. This is structured as a "model"-like operation, - // in order to simplify client code, similar to metadata_. - tensorflow::gtl::optional sharding_; - - TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder); -}; - -template -ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(*Literal::CreateR0(value)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR1( - tensorflow::gtl::ArraySlice values) { - return ConstantLiteral(*Literal::CreateR1(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR1(int64 length, - NativeT value) { - Literal literal(ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), {length})); - literal.PopulateWithValue(value); - return ConstantLiteral(literal); -} - -inline ComputationDataHandle ComputationBuilder::ConstantR1( - const tensorflow::core::Bitmap& values) { - return ConstantLiteral(*Literal::CreateR1(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR2( - std::initializer_list> values) { - return ConstantLiteral(*Literal::CreateR2(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout( - const Array& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateFromArrayWithLayout(values, layout)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantFromArray( - const Array& values) { - return ConstantLiteral(*Literal::CreateFromArray(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateFromArrayWithLayout(values, layout)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( - const Array2D& values) { - return ConstantLiteral(*Literal::CreateR2FromArray2D(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateR3FromArray3DWithLayout(values, layout)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D( - const Array3D& values) { - return ConstantFromArray(values); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout) { - return ConstantFromArrayWithLayout(values, layout); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D( - const Array4D& values) { - return ConstantFromArray(values); -} - -// RAII-style object: sets the current sharding assignment in builder on -// construction, and sets back to the previous assignment on destruction. -class ScopedShardingAssignment { - public: - ScopedShardingAssignment(xla::ComputationBuilder* builder, - tensorflow::gtl::optional sharding) - : builder_(builder), prev_sharding_(builder->sharding()) { - SetSharding(sharding); - } - - ~ScopedShardingAssignment() { SetSharding(prev_sharding_); } - - private: - void SetSharding(const tensorflow::gtl::optional& sharding) { - if (sharding.has_value()) { - builder_->SetSharding(sharding.value()); - } else { - builder_->ClearSharding(); - } - } - - xla::ComputationBuilder* const builder_; - tensorflow::gtl::optional prev_sharding_; - - TF_DISALLOW_COPY_AND_ASSIGN(ScopedShardingAssignment); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 804e34f5e75ce2d153ac7627b94a543fda88e810..7dee41f6a05025ec196b78e54015e8e71777031f 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -76,4 +76,47 @@ ExecutableBuildOptions::generate_hlo_graph() const { return generate_hlo_graph_; } +ExecutableBuildOptions& ExecutableBuildOptions::set_dump_optimized_hlo_proto_to( + tensorflow::StringPiece dirpath) { + dump_optimized_hlo_proto_to_ = dirpath.ToString(); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::dump_optimized_hlo_proto_to() const { + return dump_optimized_hlo_proto_to_; +} + +ExecutableBuildOptions& +ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to( + tensorflow::StringPiece dirpath) { + dump_unoptimized_hlo_proto_to_ = dirpath.ToString(); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const { + return dump_unoptimized_hlo_proto_to_; +} + +ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to( + tensorflow::StringPiece dirpath) { + dump_per_pass_hlo_proto_to_ = dirpath.ToString(); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::dump_per_pass_hlo_proto_to() const { + return dump_per_pass_hlo_proto_to_; +} + +ExecutableBuildOptions& ExecutableBuildOptions::set_hlo_profile(bool enabled) { + hlo_profile_ = enabled; + return *this; +} + +tensorflow::gtl::optional ExecutableBuildOptions::hlo_profile() const { + return hlo_profile_; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 3a52dbac9adb155ad9a7d91a8102707f70fe2fbf..9dc9be4423564fb967b247c2d1df31099cb80237 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -17,7 +17,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/optional.h" namespace xla { @@ -57,16 +59,53 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_generate_hlo_graph(string regex); const tensorflow::gtl::optional& generate_hlo_graph() const; + // If set, specifies a dirpath to dump the end-of-optimization-pipeline HLO + // protobuf to (as in DebugOptions). + ExecutableBuildOptions& set_dump_optimized_hlo_proto_to( + tensorflow::StringPiece dirpath); + const tensorflow::gtl::optional& dump_optimized_hlo_proto_to() const; + + // If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO + // protobuf to (as in DebugOptions). + ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to( + tensorflow::StringPiece dirpath); + const tensorflow::gtl::optional& dump_unoptimized_hlo_proto_to() + const; + + // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs + // to (as in DebugOptions). + ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to( + tensorflow::StringPiece dirpath); + const tensorflow::gtl::optional& dump_per_pass_hlo_proto_to() const; + + // If true, specifies that we should record an HLO profile during execution + // and log it after execution (as in DebugOptions). If nullopt the default is + // used. + ExecutableBuildOptions& set_hlo_profile(bool enabled); + tensorflow::gtl::optional hlo_profile() const; + + void add_disabled_hlo_pass(tensorflow::StringPiece pass_name) { + disabled_hlo_passes_.push_back(std::string(pass_name)); + } + const tensorflow::gtl::ArraySlice disabled_hlo_passes() const { + return disabled_hlo_passes_; + } + // Returns a string representation of the build options, suitable for // debugging. string ToString() const; private: + tensorflow::gtl::optional hlo_profile_; int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; tensorflow::gtl::optional generate_hlo_graph_; + tensorflow::gtl::optional dump_optimized_hlo_proto_to_; + tensorflow::gtl::optional dump_unoptimized_hlo_proto_to_; + tensorflow::gtl::optional dump_per_pass_hlo_proto_to_; DeviceMemoryAllocator* device_allocator_ = nullptr; + std::vector disabled_hlo_passes_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/client/global_data.cc b/tensorflow/compiler/xla/client/global_data.cc index 40f59eaa68ebeb47edbd2afbeabad0cd2623ebc6..2986d4060013703873b2cffb6aacbb012606d16f 100644 --- a/tensorflow/compiler/xla/client/global_data.cc +++ b/tensorflow/compiler/xla/client/global_data.cc @@ -31,7 +31,7 @@ GlobalData::~GlobalData() { *request.mutable_data() = handle_; UnregisterResponse response; VLOG(1) << "requesting to unregister " << handle_.ShortDebugString(); - tensorflow::Status s = parent_->Unregister(&request, &response); + Status s = parent_->Unregister(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index fca2bf2688cd21b44f099da3bae3b890cbb069ab..d49d959a6c8112d3701857a70cecb24701c7b6d9 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -22,8 +22,8 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -41,24 +41,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 24048a1e5a782661ba577ba50e3b5b2914f17c0a..a1d34796ccfd86f2025eff0ecb51338eb6a9b1da 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -17,7 +17,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -26,14 +27,13 @@ limitations under the License. namespace xla { namespace { -using InstructionGenerator = - ComputationDataHandle (*)(ComputationBuilder*, const ComputationDataHandle&, - const ComputationDataHandle&); - -Computation CreateScalarComputation(const string& name, PrimitiveType type, - ComputationBuilder* builder, - InstructionGenerator generator) { - std::unique_ptr b; + +using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&); + +XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, + XlaBuilder* builder, + XlaOpGenerator generator) { + std::unique_ptr b; if (type == PRED) { b = builder->CreateSubBuilder(name); } else { @@ -47,69 +47,76 @@ Computation CreateScalarComputation(const string& name, PrimitiveType type, generator(b.get(), lhs, rhs); return b->BuildAndNoteError(); } + } // namespace -Computation CreateScalarAddComputation(PrimitiveType type, - ComputationBuilder* builder) { +XlaComputation CreateScalarAddComputation(PrimitiveType type, + XlaBuilder* builder) { return CreateScalarComputation( "add", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Add(lhs, rhs); }); + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Add(lhs, rhs); + }); } -Computation CreateScalarMultiplyComputation(PrimitiveType type, - ComputationBuilder* builder) { +XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, + XlaBuilder* builder) { return CreateScalarComputation( - "add", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Mul(lhs, rhs); }); + "mul", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Mul(lhs, rhs); + }); } -Computation CreateScalarGeComputation(PrimitiveType type, - ComputationBuilder* builder) { +XlaComputation CreateScalarGeComputation(PrimitiveType type, + XlaBuilder* builder) { return CreateScalarComputation( "ge", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Ge(lhs, rhs); }); + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Ge(lhs, rhs); + }); } -Computation CreateScalarMaxComputation(PrimitiveType type, - ComputationBuilder* builder) { +XlaComputation CreateScalarMaxComputation(PrimitiveType type, + XlaBuilder* builder) { return CreateScalarComputation( "max", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Max(lhs, rhs); }); + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Max(lhs, rhs); + }); } -Computation CreateScalarMinComputation(PrimitiveType type, - ComputationBuilder* builder) { +XlaComputation CreateScalarMinComputation(PrimitiveType type, + XlaBuilder* builder) { return CreateScalarComputation( "min", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Min(lhs, rhs); }); + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Min(lhs, rhs); + }); } -Computation CreateScalarAndComputation(ComputationBuilder* builder) { +XlaComputation CreateScalarAndComputation(XlaBuilder* builder) { return CreateScalarComputation( "and", PRED, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->And(lhs, rhs); }); + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->And(lhs, rhs); + }); } -Computation CreateScalarOrComputation(ComputationBuilder* builder) { +XlaComputation CreateScalarOrComputation(XlaBuilder* builder) { return CreateScalarComputation( "or", PRED, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Or(lhs, rhs); }); + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Or(lhs, rhs); + }); } -StatusOr Any(const ComputationDataHandle& predicates, - ComputationBuilder* builder) { +StatusOr Any(const XlaOp& predicates, XlaBuilder* builder) { auto f = builder->ConstantR0(false); - Computation logical_or = CreateScalarOrComputation(builder); - TF_ASSIGN_OR_RETURN(std::unique_ptr predicates_shape, + XlaComputation logical_or = CreateScalarOrComputation(builder); + TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, builder->GetShape(predicates)); - std::vector all_dimensions(ShapeUtil::Rank(*predicates_shape)); + std::vector all_dimensions(ShapeUtil::Rank(predicates_shape)); std::iota(all_dimensions.begin(), all_dimensions.end(), 0); return builder->Reduce(predicates, f, logical_or, all_dimensions); } diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index ae89784bc227d837cf15f0a89687dd00dccc2745..64b6b7d63353165e45bf12d35126a7eeef9e56e4 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -18,43 +18,42 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { // Creates a scalar add computation and returns it. -Computation CreateScalarAddComputation(PrimitiveType type, - ComputationBuilder* builder); +XlaComputation CreateScalarAddComputation(PrimitiveType type, + XlaBuilder* builder); // Creates a scalar multiply computation and returns it. -Computation CreateScalarMultiplyComputation(PrimitiveType type, - ComputationBuilder* builder); +XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, + XlaBuilder* builder); // Creates a scalar ge computation and returns it. -Computation CreateScalarGeComputation(PrimitiveType type, - ComputationBuilder* builder); +XlaComputation CreateScalarGeComputation(PrimitiveType type, + XlaBuilder* builder); // Creates a scalar max computation and returns it. -Computation CreateScalarMaxComputation(PrimitiveType type, - ComputationBuilder* builder); +XlaComputation CreateScalarMaxComputation(PrimitiveType type, + XlaBuilder* builder); // Creates a scalar min computation and returns it. -Computation CreateScalarMinComputation(PrimitiveType type, - ComputationBuilder* builder); +XlaComputation CreateScalarMinComputation(PrimitiveType type, + XlaBuilder* builder); // Creates a scalar logical AND computation and returns it. -Computation CreateScalarAndComputation(ComputationBuilder* builder); +XlaComputation CreateScalarAndComputation(XlaBuilder* builder); // Creates a scalar logical OR computation and returns it. -Computation CreateScalarOrComputation(ComputationBuilder* builder); +XlaComputation CreateScalarOrComputation(XlaBuilder* builder); // Returns whether any predicate in "predicates" is set. // // Note: if predicates is zero-sized, Any() vacuously returns false. -StatusOr Any(const ComputationDataHandle& predicates, - ComputationBuilder* builder); +StatusOr Any(const XlaOp& predicates, XlaBuilder* builder); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index b63a1465ea755b906853860d47768ecbeaa0dcdd..3380af9f303b1dc2cec09aa37410ec40cdeaa526 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -15,8 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -46,16 +45,14 @@ int64 DataSizeOfShape(const Shape& shape) { return total_size; } -// Create a ComputationDataHandle for an op what generates fake data with the -// given shape. -ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape, - ComputationBuilder* builder) { +// Creates a XlaOp for an op what generates fake data with the given shape. +XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { if (ShapeUtil::IsArray(shape)) { return builder->Broadcast( builder->ConstantLiteral(Literal::One(shape.element_type())), AsInt64Slice(shape.dimensions())); } - std::vector parts; + std::vector parts; for (const Shape& s : shape.tuple_shapes()) { parts.push_back(BuildFakeDataOpOnDevice(s, builder)); } @@ -64,11 +61,10 @@ ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape, std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, Client* client) { - ComputationBuilder b( - client, + XlaBuilder b( tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape))); BuildFakeDataOpOnDevice(shape, &b); - Computation computation = b.Build().ConsumeValueOrDie(); + XlaComputation computation = b.Build().ConsumeValueOrDie(); auto execution_options = CreateDefaultExecutionOptions(); *execution_options.mutable_shape_with_output_layout() = shape; @@ -97,14 +93,15 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, } std::vector> MakeFakeArgumentsOrDie( - const Computation& computation, Client* client) { - auto program_shape = - client->GetComputationShape(computation).ConsumeValueOrDie(); + const XlaComputation& computation, Client* client) { + CHECK(computation.proto().has_program_shape()) + << "Computation should have progran shape."; + auto program_shape = computation.proto().program_shape(); // For every (unbound) parameter that the computation wants, we manufacture // some arbitrary data so that we can invoke the computation. std::vector> fake_arguments; - for (const Shape& parameter : program_shape->parameters()) { + for (const Shape& parameter : program_shape.parameters()) { fake_arguments.push_back(MakeFakeDataOrDie(parameter, client)); } diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index 7e640d1307edcc3e2c021f4391c456f578a015ee..dc613099e2b42a60d0c11a654ab5cd41f8bd4f6f 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -20,8 +20,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -34,9 +34,9 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, // Returns vector of GlobalData handles of fake data (created using // MakeFakeDataOrDie) that are correctly shaped arguments for the given -// computation. +// xla computation. std::vector> MakeFakeArgumentsOrDie( - const Computation& computation, Client* client); + const XlaComputation& computation, Client* client); } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 91396f055fe4a3ecbd436139be9470e2a35e1c63..ae0308020d014e038d2f0fd7de6c5f372d6cbed1 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -24,8 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/status_macros.h" -namespace se = ::perftools::gputools; - using xla::source_map_util::InvalidParameterArgument; namespace xla { @@ -50,30 +48,52 @@ LocalExecutable::LocalExecutable(std::unique_ptr executable, << "Must have a valid device ordinal that the executable was built for."; } -tensorflow::Status LocalExecutable::ValidateExecutionOptions( +Status LocalExecutable::ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend) { - const ComputationLayout& computation_layout = - executable_->module_config().entry_computation_layout(); + const ComputationLayout& host_computation_layout = + executable_->module_config().host_entry_computation_layout(); + const ComputationLayout& device_computation_layout = + executable_->module_config().device_entry_computation_layout(); // Check argument number, shapes, and layouts. - if (arguments.size() != computation_layout.parameter_count()) { + if (arguments.size() != host_computation_layout.parameter_count()) { + return InvalidArgument( + "invalid number of arguments for computation: expected %d, got %zu", + host_computation_layout.parameter_count(), arguments.size()); + } + if (arguments.size() != device_computation_layout.parameter_count()) { return InvalidArgument( "invalid number of arguments for computation: expected %d, got %zu", - computation_layout.parameter_count(), arguments.size()); + device_computation_layout.parameter_count(), arguments.size()); } for (int i = 0; i < arguments.size(); ++i) { - if (!computation_layout.parameter_layout(i).MatchesLayoutInShape( + if (!host_computation_layout.parameter_layout(i).MatchesLayoutInShape( arguments[i]->on_host_shape())) { return InvalidParameterArgument( executable_.get(), i, - "Argument does not match shape or layout of computation parameter " + "Argument does not match host shape or layout of computation " + "parameter " "%d: want %s, got %s", i, - ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) + ShapeUtil::HumanString( + host_computation_layout.parameter_layout(i).shape()) .c_str(), ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str()); } + if (!device_computation_layout.parameter_layout(i).MatchesLayoutInShape( + arguments[i]->on_device_shape())) { + return InvalidParameterArgument( + executable_.get(), i, + "Argument does not match device shape or layout of computation " + "parameter " + "%d: want %s, got %s", + i, + ShapeUtil::HumanString( + device_computation_layout.parameter_layout(i).shape()) + .c_str(), + ShapeUtil::HumanString(arguments[i]->on_device_shape()).c_str()); + } } if (run_options.stream() != nullptr) { @@ -136,7 +156,7 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( return Status::OK(); } -StatusOr> LocalExecutable::Run( +StatusOr LocalExecutable::Run( const tensorflow::gtl::ArraySlice arguments, ExecutableRunOptions run_options) { TF_RETURN_IF_ERROR( @@ -165,51 +185,46 @@ StatusOr> LocalExecutable::Run( run_options, backend_->StreamBorrower(), backend_->eigen_intra_op_thread_pool()); - if (executable_->dumping()) { + if (executable_->dumping_snapshot()) { return ExecuteAndDump(&service_options, arguments); } - TF_ASSIGN_OR_RETURN( - std::unique_ptr result, - executable_->ExecuteOnStreamWrapper( - &service_options, run_options.execution_profile(), arguments)); - - return MakeUnique(std::move(*result), - run_options.allocator()); + return executable_->ExecuteOnStreamWrapper( + &service_options, run_options.execution_profile(), arguments); } -StatusOr> LocalExecutable::ExecuteAndDump( +StatusOr LocalExecutable::ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice arguments) { - executable_->session_module()->set_execution_platform( + executable_->hlo_snapshot()->set_execution_platform( backend_->platform()->Name()); - TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->session_module())); + TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot())); TF_ASSIGN_OR_RETURN( - std::unique_ptr result, + ScopedShapedBuffer result, executable_->ExecuteOnStream(run_options, arguments, /*hlo_execution_profile=*/nullptr)); - TF_RETURN_IF_ERROR(RecordResult(result.get(), executable_->session_module())); - TF_RETURN_IF_ERROR(executable_->DumpSessionModule()); - return ScopedShapedBuffer::MakeScoped(result.get(), run_options->allocator()); + TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot())); + TF_RETURN_IF_ERROR(executable_->DumpHloSnapshot()); + return std::move(result); } -tensorflow::Status LocalExecutable::RecordArguments( +Status LocalExecutable::RecordArguments( const tensorflow::gtl::ArraySlice arguments, - SessionModule* session_module) { - session_module->clear_arguments(); + HloSnapshot* hlo_snapshot) { + hlo_snapshot->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, LiteralFromShapedBuffer(*argument)); - *session_module->add_arguments() = literal->ToProto(); + *hlo_snapshot->add_arguments() = literal->ToProto(); } return Status::OK(); } -tensorflow::Status LocalExecutable::RecordResult( - const ShapedBuffer* result, SessionModule* session_module) { - session_module->clear_result(); +Status LocalExecutable::RecordResult(const ShapedBuffer* result, + HloSnapshot* hlo_snapshot) { + hlo_snapshot->clear_result(); TF_ASSIGN_OR_RETURN(std::unique_ptr literal, LiteralFromShapedBuffer(*result)); - *session_module->mutable_result() = literal->ToProto(); + *hlo_snapshot->mutable_result() = literal->ToProto(); return Status::OK(); } @@ -247,7 +262,7 @@ Backend* LocalClient::mutable_backend() { } StatusOr> LocalClient::Compile( - const Computation& computation, + const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& options) { ExecutableBuildOptions updated_options = options; @@ -256,18 +271,17 @@ StatusOr> LocalClient::Compile( VLOG(3) << "Set device ordinal to default value of: " << updated_options.device_ordinal(); } - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - local_service_->CompileExecutable(computation.handle(), argument_layouts, - updated_options)); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + local_service_->CompileExecutable( + computation, argument_layouts, updated_options)); return WrapUnique(new LocalExecutable(std::move(executable), local_service_->mutable_backend(), updated_options)); } -StatusOr> -LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, - DeviceMemoryAllocator* allocator) { +StatusOr LocalClient::LiteralToShapedBuffer( + const Literal& literal, int device_ordinal, + DeviceMemoryAllocator* allocator) { if (allocator == nullptr) { allocator = backend().memory_allocator(); } @@ -277,7 +291,7 @@ LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - executor, literal, *scoped_buffer)); + executor, literal, scoped_buffer)); return std::move(scoped_buffer); } @@ -290,6 +304,11 @@ StatusOr> LocalClient::ShapedBufferToLiteral( shaped_buffer); } +StatusOr LocalClient::GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number) { + return local_service_->GlobalDataToShapedBuffer(data, replica_number); +} + Status LocalClient::TransferToInfeedLocal(const Literal& literal, int device_ordinal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index b52a30f5a0b92e0094e6b0de3241c10a5a909cad..4d9e0d7cd9d6ddebead1e12b23e94b529038039b 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -19,12 +19,13 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" @@ -38,19 +39,10 @@ class LocalExecutable { public: // Run the compiled computation with the given arguments and options and // return the result. - StatusOr> Run( + StatusOr Run( const tensorflow::gtl::ArraySlice arguments, ExecutableRunOptions run_options); - // Return the layout (contained in a shape) of the result produced by the - // computation. - const Shape& result_layout() const { - return executable_->module_config() - .entry_computation_layout() - .result_layout() - .shape(); - } - // Return the options used to build the executable. const ExecutableBuildOptions& build_options() const { return build_options_; } @@ -67,25 +59,30 @@ class LocalExecutable { // Validates that the given arguments and options satisfy various constraints // of the computation. - tensorflow::Status ValidateExecutionOptions( + // + // The given ExecutableRunOptions override any values from legacy_flags + // (TF_XLA_FLAGS environment variable). + Status ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, - const ExecutableRunOptions& options, const Backend& backend); + const ExecutableRunOptions& run_options, const Backend& backend); // Records the computation in a SessionModule proto with the arguments used to // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. - StatusOr> ExecuteAndDump( + // + // The given ServiceExecutableRunOptions override any values from legacy_flags + // (TF_XLA_FLAGS environment variable). + StatusOr ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice arguments); // Records the arguments used to invoke the computation in a SessionModule // proto. - tensorflow::Status RecordArguments( + Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, - SessionModule* session_module); + HloSnapshot* hlo_snapshot); // Records the result of the computation in a SessionModule proto. - tensorflow::Status RecordResult(const ShapedBuffer* result, - SessionModule* session_module); + Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot); // Returns a literal containing the contents of the given ShapedBuffer. StatusOr> LiteralFromShapedBuffer( @@ -117,9 +114,12 @@ class LocalClient : public Client { void operator=(const LocalClient&) = delete; // Build and return a LocalExecutable object. The executable is compiled using - // the given argument layouts and options. + // the given XlaComputation, argument layouts and options. + // + // The given ExecutableBuildOptions override any values from legacy_flags + // (TF_XLA_FLAGS environment variable). StatusOr> Compile( - const Computation& computation, + const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& options); @@ -127,7 +127,7 @@ class LocalClient : public Client { // ScopedShapedBuffer. If non-null the given memory allocator is used for // device memory allocation. If null, the default memory allocator for the // device is used. - StatusOr> LiteralToShapedBuffer( + StatusOr LiteralToShapedBuffer( const Literal& literal, int device_ordinal, DeviceMemoryAllocator* allocator = nullptr); @@ -136,6 +136,11 @@ class LocalClient : public Client { StatusOr> ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer); + // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid + // as long as the handle is valid. + StatusOr GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number); + // Transfer the given literal to the infeed queue of the given device. // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does // not inherit from Client and there is no possibility of confusion with @@ -158,7 +163,7 @@ class LocalClient : public Client { StatusOr ReplicaNumberToDeviceOrdinal(int replica_number); // Returns the platform that the underlying service targets. - perftools::gputools::Platform* platform() const; + se::Platform* platform() const; // Returns the number of devices on the system of the service platform // type. Not all devices may be supported by the service (see diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..507a2dc5f088e159156f0ef3d663ba2819f6a2d4 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -0,0 +1,78 @@ +# Description: +# The new XLA client libraries. +# +# This is NOT YET ready to use. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = [":friends"]) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "xla_computation", + srcs = ["xla_computation.cc"], + hdrs = ["xla_computation.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + ], +) + +cc_library( + name = "xla_builder", + srcs = ["xla_builder.cc"], + hdrs = ["xla_builder.h"], + deps = [ + ":xla_computation", + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:shape_inference", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "xla_builder_test", + srcs = ["xla_builder_test.cc"], + deps = [ + ":xla_builder", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc new file mode 100644 index 0000000000000000000000000000000000000000..ae8fbdb2dc8b3f4fc21bcfef9692645f7e1d48b5 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -0,0 +1,2001 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/mutex.h" + +namespace xla { + +using tensorflow::strings::StrCat; + +namespace { + +int64 GetUniqueId() { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static int64 built_counter = 0; + tensorflow::mutex_lock loc(mu); + const int64 id = built_counter++; + return id; +} + +// Returns true if an instruction with the given opcode can be the root of the +// computation. +bool CanBeRoot(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kOutfeed: + case HloOpcode::kTrace: + return false; + default: + return true; + } +} + +} // namespace + +StatusOr XlaBuilder::GetShape(const XlaOp& op) const { + TF_RETURN_IF_ERROR(first_error_); + + TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op)); + return instr->shape(); +} + +StatusOr> XlaBuilder::GetOperandShapes( + tensorflow::gtl::ArraySlice operands) const { + std::vector operand_shapes; + for (const XlaOp& operand : operands) { + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + operand_shapes.push_back(shape); + } + return operand_shapes; +} + +XlaBuilder::XlaBuilder(const string& computation_name) + : name_(computation_name) {} + +XlaBuilder::~XlaBuilder() {} + +void XlaBuilder::NoteError(const Status& error) { + CHECK(!error.ok()); + if (die_immediately_on_error_) { + LOG(FATAL) << "error building computation: " << error; + } + + if (first_error_.ok()) { + first_error_ = error; + first_error_backtrace_.CreateCurrent(/*skip_count=*/1); + } +} + +XlaOp XlaBuilder::NoteErrorOrReturn( + const std::function()>& op_creator) { + if (!first_error_.ok()) { + return {}; + } + auto op = op_creator(); + if (!op.ok()) { + NoteError(op.status()); + return {}; + } + return op.ConsumeValueOrDie(); +} + +StatusOr XlaBuilder::GetProgramShape(int64* root_id) const { + TF_RETURN_IF_ERROR(first_error_); + + TF_RET_CHECK(root_id != nullptr); + + ProgramShape program_shape; + + // Not all instructions can be roots. Walk backwards from the last added + // instruction until a valid root is found. + int64 index = instructions_.size() - 1; + for (; index >= 0; index--) { + TF_ASSIGN_OR_RETURN(HloOpcode opcode, + StringToHloOpcode(instructions_[index].opcode())); + if (CanBeRoot(opcode)) { + break; + } + } + if (index < 0) { + return FailedPrecondition("no root instruction was found"); + } + *root_id = instructions_[index].id(); + *program_shape.mutable_result() = instructions_[index].shape(); + + // Check that the parameter numbers are continuous from 0, and add parameter + // shapes and names to the program shape. + const int64 param_count = parameter_numbers_.size(); + for (int64 i = 0; i < param_count; i++) { + program_shape.add_parameters(); + program_shape.add_parameter_names(); + } + for (const HloInstructionProto& instr : instructions_) { + // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So + // to verify continuity, we just need to verify that every parameter is in + // the right range. + if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) { + const int64 index = instr.parameter_number(); + TF_RET_CHECK(index >= 0 && index < param_count) + << "invalid parameter number: " << index; + *program_shape.mutable_parameters(index) = instr.shape(); + *program_shape.mutable_parameter_names(index) = instr.name(); + } + } + return program_shape; +} + +StatusOr XlaBuilder::GetProgramShape() const { + int64 root; + return GetProgramShape(&root); +} + +void XlaBuilder::IsConstantVisitor(const int64 op_handle, + std::set* visited, + bool* is_constant) const { + if (visited->count(op_handle) != 0 || !*is_constant) { + return; + } + + CHECK(op_handle < instructions_.size() && op_handle >= 0); + + const HloInstructionProto& instr = instructions_[op_handle]; + const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie(); + switch (opcode) { + default: + for (const int64 operand_id : instr.operand_ids()) { + IsConstantVisitor(operand_id, visited, is_constant); + } + // TODO(b/32495713): We aren't checking the called computations. + break; + + // Non functional ops. + case HloOpcode::kRng: + case HloOpcode::kCrossReplicaSum: + // TODO(b/33009255): Implmement constant folding for cross replica sum. + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kHostCompute: + case HloOpcode::kCall: + // TODO(b/32495713): We aren't checking the to_apply computation itself, + // so we conservatively say that computations containing the Call op + // cannot be constant. We cannot set is_functional=false in other similar + // cases since we're already relying on IsConstant to return true. + case HloOpcode::kCustomCall: + case HloOpcode::kWhile: + // TODO(b/32495713): We aren't checking the condition and body + // computations themselves. + case HloOpcode::kSend: + case HloOpcode::kRecv: + case HloOpcode::kParameter: + *is_constant = false; + break; + } + if (!*is_constant) { + VLOG(1) << "Non-constant: " << instr.name(); + } + visited->insert(op_handle); +} + +XlaComputation XlaBuilder::BuildAndNoteError() { + DCHECK(parent_builder_ != nullptr); + auto build_status = Build(); + if (!build_status.ok()) { + parent_builder_->NoteError( + AddStatus(build_status.status(), + tensorflow::strings::StrCat("error from: ", name_))); + return {}; + } + return build_status.ConsumeValueOrDie(); +} + +StatusOr XlaBuilder::Build() { + if (!first_error_.ok()) { + string backtrace; + first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); + return AppendStatus(first_error_, backtrace); + } + + HloComputationProto entry; + entry.set_id(GetUniqueId()); // Give the computation a global unique id. + entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique. + + { + int64 root_id; + TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), + GetProgramShape(&root_id)); + entry.set_root_id(root_id); + } + + for (auto& instruction : instructions_) { + // Ensures that the instruction names are unique among the whole graph. + const string& new_name = + StrCat(instruction.name(), ".", entry.id(), ".", instruction.id()); + instruction.set_name(new_name); + entry.add_instructions()->Swap(&instruction); + } + + XlaComputation computation(entry.id()); + HloModuleProto* module = computation.mutable_proto(); + module->set_name(entry.name()); + module->set_id(entry.id()); + module->set_entry_computation_name(entry.name()); + module->set_entry_computation_id(entry.id()); + *module->mutable_program_shape() = entry.program_shape(); + for (auto& e : embedded_) { + module->add_computations()->Swap(&e.second); + } + module->add_computations()->Swap(&entry); + + // Clear data held by this builder. + this->instructions_.clear(); + this->embedded_.clear(); + this->parameter_numbers_.clear(); + + return std::move(computation); +} + +StatusOr XlaBuilder::InDimBroadcast( + const Shape& shape, const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + TF_RETURN_IF_ERROR(first_error_); + + HloInstructionProto instr; + *instr.mutable_shape() = shape; + for (int64 dim : broadcast_dimensions) { + instr.add_dimensions(dim); + } + return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand}); +} + +StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, + const XlaOp& operand) { + TF_RETURN_IF_ERROR(first_error_); + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + + CHECK(ShapeUtil::IsScalar(operand_shape) || + ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)); + Shape broadcast_shape = + ShapeUtil::ChangeElementType(output_shape, operand_shape.element_type()); + + // Do explicit broadcast for scalar. + if (ShapeUtil::IsScalar(operand_shape)) { + return InDimBroadcast(broadcast_shape, operand, {}); + } + + // Do explicit broadcast for degenerate broadcast. + std::vector broadcast_dimensions; + std::vector reshaped_dimensions; + for (int i = 0; i < ShapeUtil::Rank(operand_shape); i++) { + if (operand_shape.dimensions(i) == output_shape.dimensions(i)) { + broadcast_dimensions.push_back(i); + reshaped_dimensions.push_back(operand_shape.dimensions(i)); + } else { + TF_RET_CHECK(operand_shape.dimensions(i) == 1) + << "An explicit broadcast sequence requires the broadcasted " + "dimensions to be trivial; operand shape: " + << operand_shape << "; output_shape: " << output_shape; + } + } + // Eliminate the size one dimensions. + TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, + Reshape(ShapeUtil::MakeShape(operand_shape.element_type(), + reshaped_dimensions), + operand)); + // Broadcast 'reshape' up to the larger size. + return InDimBroadcast(broadcast_shape, reshaped_operand, + broadcast_dimensions); +} + +XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferUnaryOpShape(unop, operand_shape)); + return AddInstruction(std::move(instr), unop, {operand}); + }); +} + +XlaOp XlaBuilder::BinaryOp( + HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferBinaryOpShape( + binop, lhs_shape, rhs_shape, broadcast_dimensions)); + + const int64 lhs_rank = ShapeUtil::Rank(lhs_shape); + const int64 rhs_rank = ShapeUtil::Rank(rhs_shape); + + XlaOp updated_lhs = lhs; + XlaOp updated_rhs = rhs; + + if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) { + const bool should_broadcast_lhs = lhs_rank < rhs_rank; + XlaOp from = should_broadcast_lhs ? lhs : rhs; + const Shape& from_shape = should_broadcast_lhs ? lhs_shape : rhs_shape; + + std::vector to_size; + for (int64 size : instr.shape().dimensions()) { + to_size.push_back(size); + } + for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape); + from_dim++) { + int64 to_dim = broadcast_dimensions[from_dim]; + to_size[to_dim] = from_shape.dimensions(from_dim); + } + + const Shape& broadcasted_shape = + ShapeUtil::MakeShape(from_shape.element_type(), to_size); + TF_ASSIGN_OR_RETURN( + XlaOp broadcasted_operand, + InDimBroadcast(broadcasted_shape, from, broadcast_dimensions)); + + updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs; + updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs; + } + + TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, GetShape(updated_lhs)); + if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_lhs, + AddBroadcastSequence(instr.shape(), updated_lhs)); + } + TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, GetShape(updated_rhs)); + if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_rhs, + AddBroadcastSequence(instr.shape(), updated_rhs)); + } + + return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs}); + }); +} + +XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, + const XlaOp& ehs) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferTernaryOpShape( + triop, lhs_shape, rhs_shape, ehs_shape)); + XlaOp updated_lhs = lhs; + XlaOp updated_rhs = rhs; + XlaOp updated_ehs = ehs; + if (!ShapeUtil::IsTuple(instr.shape())) { + if (!ShapeUtil::IsTuple(lhs_shape) && + !ShapeUtil::SameDimensions(instr.shape(), lhs_shape)) { + // lhs is being implicitly broadcasted. Change to explicit. + TF_ASSIGN_OR_RETURN(updated_lhs, + AddBroadcastSequence(instr.shape(), lhs)); + } + if (!ShapeUtil::IsTuple(rhs_shape) && + !ShapeUtil::SameDimensions(instr.shape(), rhs_shape)) { + // rhs is being implicitly broadcasted. Change to explicit. + TF_ASSIGN_OR_RETURN(updated_rhs, + AddBroadcastSequence(instr.shape(), rhs)); + } + if (!ShapeUtil::IsTuple(ehs_shape) && + !ShapeUtil::SameDimensions(instr.shape(), ehs_shape)) { + // ehs is being implicitly broadcasted. Change to explicit. + TF_ASSIGN_OR_RETURN(updated_ehs, + AddBroadcastSequence(instr.shape(), ehs)); + } + } + return AddInstruction(std::move(instr), triop, + {updated_lhs, updated_rhs, updated_ehs}); + }); +} + +XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = literal.shape(); + *instr.mutable_literal() = literal.ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kConstant); + }); +} + +XlaOp XlaBuilder::Call(const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); + c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, + computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferCallShape(operand_shape_ptrs, + /*to_apply=*/called_program_shape)); + + AddCalledComputation(computation, &instr); + + return AddInstruction(std::move(instr), HloOpcode::kCall, operands); + }); +} + +XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, + const string& name) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + if (!parameter_numbers_.insert(parameter_number).second) { + return InvalidArgument("parameter %lld already registered", + parameter_number); + } + instr.set_parameter_number(parameter_number); + instr.set_name(name); + *instr.mutable_shape() = shape; + return AddInstruction(std::move(instr), HloOpcode::kParameter); + }); +} + +XlaOp XlaBuilder::Broadcast( + const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + const Shape& shape, + ShapeInference::InferBroadcastShape(operand_shape, broadcast_sizes)); + + // The client-level broadcast op just appends dimensions on the left (adds + // lowest numbered dimensions). The HLO broadcast instruction is more + // flexible and can add new dimensions anywhere. The instruction's + // dimensions field maps operand dimensions to dimensions in the broadcast + // output, so to append dimensions on the left the instruction's dimensions + // should just be the n highest dimension numbers of the output shape where + // n is the number of input dimensions. + const int64 operand_rank = ShapeUtil::Rank(operand_shape); + std::vector dimensions(operand_rank); + for (int i = 0; i < operand_rank; ++i) { + dimensions[i] = i + ShapeUtil::Rank(shape) - operand_rank; + } + return InDimBroadcast(shape, operand, dimensions); + }); +} + +StatusOr XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { + TF_RETURN_IF_ERROR(first_error_); + + HloInstructionProto instr; + *instr.mutable_shape() = shape; + return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand}); +} + +XlaOp XlaBuilder::Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferSliceShape(operand_shape, start_indices, + limit_indices, strides)); + for (int i = 0; i < start_indices.size(); i++) { + auto* slice_config = instr.add_slice_dimensions(); + slice_config->set_start(start_indices[i]); + slice_config->set_limit(limit_indices[i]); + slice_config->set_stride(strides[i]); + } + + return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand}); + }); +} + +XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + std::vector starts(ShapeUtil::Rank(shape), 0); + std::vector limits(shape.dimensions().begin(), + shape.dimensions().end()); + std::vector strides(ShapeUtil::Rank(shape), 1); + starts[dimno] = start_index; + limits[dimno] = limit_index; + strides[dimno] = stride; + return Slice(operand, starts, limits, strides); + }); +} + +XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, + GetShape(start_indices)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferDynamicSliceShape( + operand_shape, start_indices_shape, slice_sizes)); + + for (int64 size : slice_sizes) { + instr.add_dynamic_slice_sizes(size); + } + + return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, + {operand, start_indices}); + }); +} + +XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update)); + TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, + GetShape(start_indices)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferDynamicUpdateSliceShape( + operand_shape, update_shape, start_indices_shape)); + + return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, + {operand, update, start_indices}); + }); +} + +XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice operands, + int64 dimension) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); + c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferConcatOpShape(operand_shape_ptrs, dimension)); + + instr.add_dimensions(dimension); + + return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands); + }); +} + +XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& padding_value_shape, + GetShape(padding_value)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferPadShape(operand_shape, padding_value_shape, + padding_config)); + + *instr.mutable_padding_config() = padding_config; + + return AddInstruction(std::move(instr), HloOpcode::kPad, + {operand, padding_value}); + }); +} + +XlaOp XlaBuilder::Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& shape, + ShapeInference::InferReshapeShape( + operand_shape, dimensions, new_sizes)); + XlaOp transposed = IsIdentityPermutation(dimensions) + ? operand + : Transpose(operand, dimensions); + return Reshape(shape, transposed); + }); +} + +XlaOp XlaBuilder::Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand)); + std::vector dimensions(shape.dimensions_size()); + std::iota(dimensions.begin(), dimensions.end(), 0); + return Reshape(operand, dimensions, new_sizes); + }); +} + +XlaOp XlaBuilder::Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions) { + return NoteErrorOrReturn([&]() -> StatusOr { + if (dimensions.size() <= 1) { + // Not collapsing anything, trivially we can return the operand versus + // enqueueing a trivial reshape. + return operand; + } + + // Out-of-order collapse is not supported. + // Checks that the collapsed dimensions are in order and consecutive. + for (tensorflow::gtl::ArraySlice::size_type i = 1; + i < dimensions.size(); ++i) { + if (dimensions[i] - 1 != dimensions[i - 1]) { + return InvalidArgument( + "Collapsed dimensions are not in consecutive order."); + } + } + + // Create a new sizes vector from the old shape, replacing the collapsed + // dimensions by the product of their sizes. + TF_ASSIGN_OR_RETURN(const Shape& original_shape, GetShape(operand)); + + VLOG(3) << "original shape: " << ShapeUtil::HumanString(original_shape); + VLOG(3) << "dims to collapse: " + << tensorflow::str_util::Join(dimensions, ","); + + std::vector new_sizes; + for (int i = 0; i < ShapeUtil::Rank(original_shape); ++i) { + if (i <= dimensions.front() || i > dimensions.back()) { + new_sizes.push_back(original_shape.dimensions(i)); + } else { + new_sizes.back() *= original_shape.dimensions(i); + } + } + + VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",") + << "]"; + + return Reshape(operand, new_sizes); + }); +} + +void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { + NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeNil(); + *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); + }); +} + +XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, + const XlaOp& on_false) { + return TernaryOp(HloOpcode::kSelect, pred, on_true, on_false); +} + +XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice elements) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); + c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferVariadicOpShape( + HloOpcode::kTuple, operand_shape_ptrs)); + return AddInstruction(std::move(instr), HloOpcode::kTuple, elements); + }); +} + +XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data)); + if (!ShapeUtil::IsTuple(tuple_shape)) { + return InvalidArgument( + "Operand to GetTupleElement() is not a tuple; got %s", + ShapeUtil::HumanString(tuple_shape).c_str()); + } + *instr.mutable_shape() = + ShapeUtil::GetTupleElementShape(tuple_shape, index); + + instr.set_tuple_index(index); + + return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement, + {tuple_data}); + }); +} + +XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kEq, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kNe, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kGe, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kGt, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kLe, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + + DotDimensionNumbers dimension_numbers; + dimension_numbers.add_lhs_contracting_dimensions( + lhs_shape.dimensions_size() == 1 ? 0 : 1); + dimension_numbers.add_rhs_contracting_dimensions(0); + return DotGeneral(lhs, rhs, dimension_numbers); + }); +} + +XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, + dimension_numbers)); + *instr.mutable_dot_dimension_numbers() = dimension_numbers; + return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs}); + }); +} + +Status XlaBuilder::VerifyConvolution( + const Shape& lhs_shape, const Shape& rhs_shape, + const ConvolutionDimensionNumbers& dimension_numbers) const { + if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) { + return InvalidArgument( + "Convolution arguments must have same number of " + "dimensions. Got: %s and %s", + ShapeUtil::HumanString(lhs_shape).c_str(), + ShapeUtil::HumanString(rhs_shape).c_str()); + } + int num_dims = ShapeUtil::Rank(lhs_shape); + if (num_dims < 2) { + return InvalidArgument( + "Convolution expects argument arrays with >= 3 dimensions. " + "Got: %s and %s", + ShapeUtil::HumanString(lhs_shape).c_str(), + ShapeUtil::HumanString(rhs_shape).c_str()); + } + int num_spatial_dims = num_dims - 2; + + const auto check_spatial_dimensions = + [&](const char* const field_name, + const tensorflow::protobuf::RepeatedField& + numbers) { + if (numbers.size() != num_spatial_dims) { + return InvalidArgument("Expected %d elements for %s, but got %d.", + num_spatial_dims, field_name, numbers.size()); + } + for (int i = 0; i < numbers.size(); ++i) { + if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) { + return InvalidArgument("Convolution %s[%d] is out of bounds: %lld", + field_name, i, numbers.Get(i)); + } + } + return Status::OK(); + }; + TF_RETURN_IF_ERROR( + check_spatial_dimensions("input_spatial_dimensions", + dimension_numbers.input_spatial_dimensions())); + TF_RETURN_IF_ERROR( + check_spatial_dimensions("kernel_spatial_dimensions", + dimension_numbers.kernel_spatial_dimensions())); + return check_spatial_dimensions( + "output_spatial_dimensions", + dimension_numbers.output_spatial_dimensions()); +} + +XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + Padding padding) { + return ConvWithGeneralDimensions( + lhs, rhs, window_strides, padding, + CreateDefaultConvDimensionNumbers(window_strides.size())); +} + +XlaOp XlaBuilder::ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + return ConvGeneral(lhs, rhs, window_strides, padding, + CreateDefaultConvDimensionNumbers(window_strides.size())); +} + +XlaOp XlaBuilder::ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + + TF_RETURN_IF_ERROR( + VerifyConvolution(lhs_shape, rhs_shape, dimension_numbers)); + + std::vector base_area_dimensions( + dimension_numbers.input_spatial_dimensions_size()); + for (std::vector::size_type i = 0; i < base_area_dimensions.size(); + ++i) { + base_area_dimensions[i] = + lhs_shape.dimensions(dimension_numbers.input_spatial_dimensions(i)); + } + + std::vector window_dimensions( + dimension_numbers.kernel_spatial_dimensions_size()); + for (std::vector::size_type i = 0; i < window_dimensions.size(); + ++i) { + window_dimensions[i] = + rhs_shape.dimensions(dimension_numbers.kernel_spatial_dimensions(i)); + } + + return ConvGeneral(lhs, rhs, window_strides, + MakePadding(base_area_dimensions, window_dimensions, + window_strides, padding), + dimension_numbers); + }); +} + +XlaOp XlaBuilder::ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, + dimension_numbers); +} + +XlaOp XlaBuilder::ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + TF_RETURN_IF_ERROR( + VerifyConvolution(lhs_shape, rhs_shape, dimension_numbers)); + + std::vector window_dimensions( + dimension_numbers.kernel_spatial_dimensions_size()); + for (std::vector::size_type i = 0; i < window_dimensions.size(); + ++i) { + window_dimensions[i] = + rhs_shape.dimensions(dimension_numbers.kernel_spatial_dimensions(i)); + } + TF_ASSIGN_OR_RETURN(*instr.mutable_window(), + MakeWindow(window_dimensions, window_strides, padding, + lhs_dilation, rhs_dilation)); + + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, instr.window(), + dimension_numbers)); + + *instr.mutable_convolution_dimension_numbers() = dimension_numbers; + + return AddInstruction(std::move(instr), HloOpcode::kConvolution, + {lhs, rhs}); + }); +} + +StatusOr XlaBuilder::MakeWindow( + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation) const { + const auto verify_size = [&](const size_t x, const char* x_name) { + if (x == 0 || x == window_dimensions.size()) { + return Status::OK(); + } else { + return InvalidArgument( + "%s", tensorflow::strings::StrCat( + "Window has different number of window dimensions than of ", + x_name, + "\nNumber of window dimensions: ", window_dimensions.size(), + "\nNumber of ", x_name, ": ", x, "\n") + .c_str()); + } + }; + TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides")); + TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries")); + TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors")); + TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors")); + + Window window; + for (size_t i = 0; i < window_dimensions.size(); i++) { + auto dim = window.add_dimensions(); + dim->set_size(window_dimensions[i]); + if (!window_strides.empty()) { + dim->set_stride(window_strides[i]); + } else { + dim->set_stride(1); + } + if (!padding.empty()) { + dim->set_padding_low(padding[i].first); + dim->set_padding_high(padding[i].second); + } else { + dim->set_padding_low(0); + dim->set_padding_high(0); + } + if (!lhs_dilation.empty()) { + dim->set_base_dilation(lhs_dilation[i]); + } else { + dim->set_base_dilation(1); + } + if (!rhs_dilation.empty()) { + dim->set_window_dilation(rhs_dilation[i]); + } else { + dim->set_window_dilation(1); + } + dim->set_window_reversal(false); + } + return window; +} + +XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, + const tensorflow::gtl::ArraySlice fft_length) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferFftShape(operand_shape, fft_type, fft_length)); + + instr.set_fft_type(fft_type); + for (int64 i : fft_length) { + instr.add_fft_length(i); + } + + return AddInstruction(std::move(instr), HloOpcode::kFft, {operand}); + }); +} + +XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument("Given shape to Infeed must have a layout"); + } + *instr.mutable_shape() = shape; + instr.set_infeed_config(config); + return AddInstruction(std::move(instr), HloOpcode::kInfeed); + }); +} + +void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config) { + NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + *instr.mutable_shape() = ShapeUtil::MakeNil(); + + // Check and set outfeed shape. + if (!LayoutUtil::HasLayout(shape_with_layout)) { + return InvalidArgument("Given shape to Outfeed must have a layout"); + } + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { + return InvalidArgument( + "Outfeed shape %s must be compatible with operand shape %s", + ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), + ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + } + *instr.mutable_outfeed_shape() = shape_with_layout; + + instr.set_outfeed_config(outfeed_config); + + return AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand}); + }); +} + +XlaOp XlaBuilder::CustomCall(const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + if (tensorflow::str_util::StartsWith(call_target_name, "$")) { + return InvalidArgument( + "Invalid custom_call_target \"%s\": Call targets that start with '$' " + "are reserved for internal use.", + call_target_name.c_str()); + } + *instr.mutable_shape() = shape; + instr.set_custom_call_target(call_target_name); + return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); + }); +} + +XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice operands, + const string& channel_name, + int64 cost_estimate_ns, const Shape& shape) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = shape; + instr.set_channel_name(channel_name); + instr.set_cost_estimate_ns(cost_estimate_ns); + return AddInstruction(std::move(instr), HloOpcode::kHostCompute, operands); + }); +} + +XlaOp XlaBuilder::Complex( + const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kComplex, real, imag, broadcast_dimensions); +} + +XlaOp XlaBuilder::Conj(const XlaOp& operand) { + return Complex(Real(operand), Neg(Imag(operand))); +} + +XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kRemainder, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions); +} + +// TODO(b/65209188): Create a dedicated lowering for Xor. +XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return Or(And(Not(lhs), rhs, broadcast_dimensions), + And(lhs, Not(rhs), broadcast_dimensions)); +} + +XlaOp XlaBuilder::Not(const XlaOp& operand) { + return UnaryOp(HloOpcode::kNot, operand); +} + +XlaOp XlaBuilder::ShiftLeft( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs, + broadcast_dimensions); +} + +XlaOp XlaBuilder::ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs, + broadcast_dimensions); +} + +XlaOp XlaBuilder::Abs(const XlaOp& operand) { + return UnaryOp(HloOpcode::kAbs, operand); +} + +XlaOp XlaBuilder::Atan2( + const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kAtan2, y, x, broadcast_dimensions); +} + +XlaOp XlaBuilder::Exp(const XlaOp& operand) { + return UnaryOp(HloOpcode::kExp, operand); +} + +XlaOp XlaBuilder::Expm1(const XlaOp& operand) { + return UnaryOp(HloOpcode::kExpm1, operand); +} + +XlaOp XlaBuilder::Floor(const XlaOp& operand) { + return UnaryOp(HloOpcode::kFloor, operand); +} + +XlaOp XlaBuilder::Ceil(const XlaOp& operand) { + return UnaryOp(HloOpcode::kCeil, operand); +} + +XlaOp XlaBuilder::Round(const XlaOp& operand) { + return UnaryOp(HloOpcode::kRoundNearestAfz, operand); +} + +XlaOp XlaBuilder::Log(const XlaOp& operand) { + return UnaryOp(HloOpcode::kLog, operand); +} + +XlaOp XlaBuilder::Log1p(const XlaOp& operand) { + return UnaryOp(HloOpcode::kLog1p, operand); +} + +XlaOp XlaBuilder::Sign(const XlaOp& operand) { + return UnaryOp(HloOpcode::kSign, operand); +} + +XlaOp XlaBuilder::Clz(const XlaOp& operand) { + return UnaryOp(HloOpcode::kClz, operand); +} + +XlaOp XlaBuilder::Cos(const XlaOp& operand) { + return UnaryOp(HloOpcode::kCos, operand); +} + +XlaOp XlaBuilder::Sin(const XlaOp& operand) { + return UnaryOp(HloOpcode::kSin, operand); +} + +XlaOp XlaBuilder::Tanh(const XlaOp& operand) { + return UnaryOp(HloOpcode::kTanh, operand); +} + +XlaOp XlaBuilder::Real(const XlaOp& operand) { + return UnaryOp(HloOpcode::kReal, operand); +} + +XlaOp XlaBuilder::Imag(const XlaOp& operand) { + return UnaryOp(HloOpcode::kImag, operand); +} + +XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { + return UnaryOp(HloOpcode::kIsFinite, operand); +} + +XlaOp XlaBuilder::Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferTransposeShape(operand_shape, permutation)); + for (int64 dim : permutation) { + instr.add_dimensions(dim); + } + return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand}); + }); +} + +XlaOp XlaBuilder::Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferReverseShape(operand_shape, dimensions)); + for (int64 dim : dimensions) { + instr.add_dimensions(dim); + } + return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand}); + }); +} + +XlaOp XlaBuilder::Sort(const XlaOp& operand) { + return UnaryOp(HloOpcode::kSort, operand); +} + +XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) { + return BinaryOp(HloOpcode::kPower, operand, ConstantR0(0.5), + /*broadcast_dimensions=*/{}); +} + +XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferConvertShape(operand_shape, new_element_type)); + return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand}); + }); +} + +XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferConvertShape(operand_shape, new_element_type)); + return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert, + {operand}); + }); +} + +XlaOp XlaBuilder::SquareF32(const XlaOp& operand) { + return BinaryOp(HloOpcode::kPower, operand, ConstantR0(2.0), + /*broadcast_dimensions=*/{}); +} + +XlaOp XlaBuilder::ReciprocalF32(const XlaOp& operand) { + return BinaryOp(HloOpcode::kPower, operand, ConstantR0(-1.0), + /*broadcast_dimensions=*/{}); +} + +XlaOp XlaBuilder::Neg(const XlaOp& operand) { + return UnaryOp(HloOpcode::kNegate, operand); +} + +XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand, + const XlaOp& max) { + return TernaryOp(HloOpcode::kClamp, min, operand, max); +} + +XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands) { + return NoteErrorOrReturn([&]() -> StatusOr { + if (!static_operands.empty()) { + return Unimplemented("static_operands is not supported in Map"); + } + + HloInstructionProto instr; + + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); + c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, + computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferMapShape(operand_shape_ptrs, called_program_shape, + dimensions)); + + AddCalledComputation(computation, &instr); + + return AddInstruction(std::move(instr), HloOpcode::kMap, operands); + }); +} + +XlaOp XlaBuilder::RngOp(RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters, + const Shape& shape) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + // Check the number of parameters per RNG distribution. + switch (distribution) { + case RandomDistribution::RNG_NORMAL: + case RandomDistribution::RNG_UNIFORM: + if (parameters.size() != 2) { + return InvalidArgument( + "RNG distribution (%s) expects 2 parameters, but got %ld", + RandomDistribution_Name(distribution).c_str(), parameters.size()); + } + break; + default: + LOG(FATAL) << "unhandled distribution " << distribution; + } + + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); + *instr.mutable_shape() = shape; + + instr.set_distribution(distribution); + + return AddInstruction(std::move(instr), HloOpcode::kRng, parameters); + }); +} + +XlaOp XlaBuilder::RngNormal(const XlaOp& mu, const XlaOp& sigma, + const Shape& shape) { + return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape); +} + +XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b, + const Shape& shape) { + return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape); +} + +XlaOp XlaBuilder::While(const XlaComputation& condition, + const XlaComputation& body, const XlaOp& init) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + // Infer shape. + TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape()); + TF_ASSIGN_OR_RETURN(const auto& condition_program_shape, + condition.GetProgramShape()); + TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferWhileShape(condition_program_shape, + body_program_shape, init_shape)); + // Body comes before condition computation in the vector. + AddCalledComputation(body, &instr); + AddCalledComputation(condition, &instr); + return AddInstruction(std::move(instr), HloOpcode::kWhile, {init}); + }); +} + +XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); + TF_ASSIGN_OR_RETURN(const Shape& gather_indices_shape, + GetShape(gather_indices)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferGatherShape(input_shape, gather_indices_shape, + dimension_numbers, window_bounds)); + + *instr.mutable_gather_dimension_numbers() = dimension_numbers; + for (int64 bound : window_bounds) { + instr.add_gather_window_bounds(bound); + } + + return AddInstruction(std::move(instr), HloOpcode::kGather, + {input, gather_indices}); + }); +} + +XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& predicate_shape, GetShape(predicate)); + TF_ASSIGN_OR_RETURN(const Shape& true_operand_shape, + GetShape(true_operand)); + TF_ASSIGN_OR_RETURN(const ProgramShape& true_computation_shape, + true_computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(const Shape& false_operand_shape, + GetShape(false_operand)); + TF_ASSIGN_OR_RETURN(const ProgramShape& false_computation_shape, + false_computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferConditionalShape( + predicate_shape, true_operand_shape, false_operand_shape, + true_computation_shape, false_computation_shape)); + + // The index of true_computation must be 0 and that of false computation + // must be 1. + AddCalledComputation(true_computation, &instr); + AddCalledComputation(false_computation, &instr); + + return AddInstruction(std::move(instr), HloOpcode::kConditional, + {predicate, true_operand, false_operand}); + }); +} + +XlaOp XlaBuilder::Reduce( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); + TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, + computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferReduceShape( + operand_shape, init_shape, dimensions_to_reduce, + called_program_shape)); + + for (int64 dim : dimensions_to_reduce) { + instr.add_dimensions(dim); + } + + AddCalledComputation(computation, &instr); + + return AddInstruction(std::move(instr), HloOpcode::kReduce, + {operand, init_value}); + }); +} + +XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + std::vector all_dimnos(ShapeUtil::Rank(operand_shape)); + std::iota(all_dimnos.begin(), all_dimnos.end(), 0); + return Reduce(operand, init_value, computation, all_dimnos); + }); +} + +XlaOp XlaBuilder::ReduceWindow( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_RETURN_IF_ERROR( + ValidatePaddingValues(AsInt64Slice(operand_shape.dimensions()), + window_dimensions, window_strides)); + + std::vector> padding_values = + MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions, + window_strides, padding); + return ReduceWindowWithGeneralPadding(operand, init_value, computation, + window_dimensions, window_strides, + padding_values); + }); +} + +XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); + TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape, + computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_window(), + MakeWindow(window_dimensions, window_strides, padding, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{})); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferReduceWindowShape(operand_shape, init_shape, + instr.window(), to_apply_shape)); + + AddCalledComputation(computation, &instr); + return AddInstruction(std::move(instr), HloOpcode::kReduceWindow, + {operand, init_value}); + }); +} + +XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale)); + TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferBatchNormTrainingShape( + operand_shape, scale_shape, offset_shape, feature_index)); + + instr.set_epsilon(epsilon); + instr.set_feature_index(feature_index); + + return AddInstruction(std::move(instr), HloOpcode::kBatchNormTraining, + {operand, scale, offset}); + }); +} + +XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale)); + TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset)); + TF_ASSIGN_OR_RETURN(const Shape& mean_shape, GetShape(mean)); + TF_ASSIGN_OR_RETURN(const Shape& variance_shape, GetShape(variance)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferBatchNormInferenceShape( + operand_shape, scale_shape, offset_shape, + mean_shape, variance_shape, feature_index)); + + instr.set_epsilon(epsilon); + instr.set_feature_index(feature_index); + + return AddInstruction(std::move(instr), HloOpcode::kBatchNormInference, + {operand, scale, offset, mean, variance}); + }); +} + +XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale)); + TF_ASSIGN_OR_RETURN(const Shape& batch_mean_shape, GetShape(batch_mean)); + TF_ASSIGN_OR_RETURN(const Shape& batch_var_shape, GetShape(batch_var)); + TF_ASSIGN_OR_RETURN(const Shape& grad_output_shape, GetShape(grad_output)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferBatchNormGradShape( + operand_shape, scale_shape, batch_mean_shape, + batch_var_shape, grad_output_shape, feature_index)); + + instr.set_epsilon(epsilon); + instr.set_feature_index(feature_index); + + return AddInstruction(std::move(instr), HloOpcode::kBatchNormGrad, + {operand, scale, batch_mean, batch_var, grad_output}); + }); +} + +XlaOp XlaBuilder::CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); + auto b = CreateSubBuilder("sum"); + b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"), + b->Parameter(/*parameter_number=*/1, scalar_shape, "y")); + TF_ASSIGN_OR_RETURN(auto computation, b->Build()); + return CrossReplicaSum(operand, computation, replica_group_ids, + /*channel_id=*/tensorflow::gtl::nullopt); + }); +} + +XlaOp XlaBuilder::CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids, + const tensorflow::gtl::optional& channel_id) { + return NoteErrorOrReturn([&]() -> StatusOr { + if (channel_id.has_value()) { + return Unimplemented( + "replica_group_ids and channel_id and is not supported in AllReduce"); + } + + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferCrossReplicaSumShape({&operand_shape})); + for (int64 replica_group_id : replica_group_ids) { + instr.add_replica_group_ids(replica_group_id); + } + + AddCalledComputation(computation, &instr); + + return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum, + {operand}); + }); +} + +XlaOp XlaBuilder::SelectAndScatter( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + return SelectAndScatterWithGeneralPadding( + operand, select, window_dimensions, window_strides, + MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions, + window_strides, padding), + source, init_value, scatter); + }); +} + +XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& source_shape, GetShape(source)); + TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); + TF_ASSIGN_OR_RETURN(const ProgramShape& select_shape, + select.GetProgramShape()); + TF_ASSIGN_OR_RETURN(const ProgramShape& scatter_shape, + scatter.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_window(), + MakeWindow(window_dimensions, window_strides, padding, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{})); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferSelectAndScatterShape( + operand_shape, select_shape, instr.window(), + source_shape, init_shape, scatter_shape)); + + AddCalledComputation(select, &instr); + AddCalledComputation(scatter, &instr); + + return AddInstruction(std::move(instr), HloOpcode::kSelectAndScatter, + {operand, source, init_value}); + }); +} + +XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferReducePrecisionShape( + operand_shape, exponent_bits, mantissa_bits)); + instr.set_exponent_bits(exponent_bits); + instr.set_mantissa_bits(mantissa_bits); + return AddInstruction(std::move(instr), HloOpcode::kReducePrecision, + {operand}); + }); +} + +void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { + NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + // Send instruction produces a tuple of {aliased operand, U32 context}. + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + *instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); + instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN( + XlaOp send, + AddInstruction(std::move(instr), HloOpcode::kSend, {operand})); + + HloInstructionProto send_done_instr; + *send_done_instr.mutable_shape() = ShapeUtil::MakeNil(); + send_done_instr.set_channel_id(handle.handle()); + return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, + {send}); + }); +} + +XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + // Recv instruction produces a tuple of {receive buffer, U32 context}. + *instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); + instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN(XlaOp recv, + AddInstruction(std::move(instr), HloOpcode::kRecv, {})); + + HloInstructionProto recv_done_instr; + *recv_done_instr.mutable_shape() = shape; + recv_done_instr.set_channel_id(handle.handle()); + return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, + {recv}); + }); +} + +StatusOr XlaBuilder::IsConstant(const XlaOp& operand) const { + TF_RETURN_IF_ERROR(first_error_); + + // Verify that the handle is valid. + TF_RETURN_IF_ERROR(LookUpInstruction(operand).status()); + + bool is_constant = true; + std::set visited; + IsConstantVisitor(operand.handle(), &visited, &is_constant); + return is_constant; +} + +StatusOr XlaBuilder::BuildConstantSubGraph( + const XlaOp& root_op) const { + TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op)); + if (!is_constant) { + auto op_status = LookUpInstruction(root_op); + string op_string = + op_status.ok() ? op_status.ValueOrDie()->name() : ""; + return InvalidArgument( + "Operand to BuildConstantSubGraph depends on a parameter.\n\n" + " op requested for constant subgraph: %s\n\n" + "This is an internal error that typically happens when the XLA user " + "(e.g. TensorFlow) is attempting to determine a value that must be a " + "compile-time constant (e.g. an array dimension) but it is not capable " + "of being evaluated at XLA compile time.\n\n" + "Please file a usability bug with the framework being used (e.g. " + "TensorFlow).", + op_string.c_str()); + } + + TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, + LookUpInstruction(root_op)); + TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode())); + if (!CanBeRoot(opcode)) { + return InvalidArgument("the operand with opcode %s cannot be root", + root->opcode().c_str()); + } + + HloComputationProto entry; + entry.set_id(GetUniqueId()); // Give the computation a global unique id. + entry.set_name(StrCat(name_, entry.id(), "_compute_constant")); + entry.set_root_id(root->id()); + ProgramShape* program_shape = entry.mutable_program_shape(); + *program_shape->mutable_result() = root->shape(); + + // We use std::set to keep the instruction ids in ascending order (which is + // also a valid denpendency order). The related ops will be added to the + // subgraph in the same order. + std::set related_ops; + tensorflow::gtl::FlatSet related_calls; // Related computations. + std::queue worklist; + worklist.push(root->id()); + related_ops.insert(root->id()); + while (!worklist.empty()) { + int64 node = worklist.front(); + worklist.pop(); + for (int64 id : instructions_[node].operand_ids()) { + if (related_ops.insert(id).second) { + worklist.push(id); + } + } + for (int64 called_id : instructions_[node].called_computation_ids()) { + related_calls.insert(called_id); + } + } + + // Add related ops to the computation. + for (int64 id : related_ops) { + auto* instr = entry.add_instructions(); + *instr = instructions_[id]; + // Ensures that the instruction names are unique among the graph. + const string& new_name = + StrCat(instr->name(), ".", entry.id(), ".", instr->id()); + instr->set_name(new_name); + } + + XlaComputation computation(entry.id()); + HloModuleProto* module = computation.mutable_proto(); + module->set_name(entry.name()); + module->set_id(entry.id()); + module->set_entry_computation_name(entry.name()); + module->set_entry_computation_id(entry.id()); + *module->mutable_program_shape() = *program_shape; + for (auto& e : embedded_) { + if (related_calls.find(e.second.id()) != related_calls.end()) { + *module->add_computations() = e.second; + } + } + *module->add_computations() = std::move(entry); + + return std::move(computation); +} + +std::unique_ptr XlaBuilder::CreateSubBuilder( + const string& computation_name) { + auto sub_builder = MakeUnique(computation_name); + sub_builder->parent_builder_ = this; + sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_; + return sub_builder; +} + +/* static */ ConvolutionDimensionNumbers +XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { + ConvolutionDimensionNumbers dimension_numbers; + dimension_numbers.set_input_batch_dimension(kConvBatchDimension); + dimension_numbers.set_input_feature_dimension(kConvFeatureDimension); + dimension_numbers.set_output_batch_dimension(kConvBatchDimension); + dimension_numbers.set_output_feature_dimension(kConvFeatureDimension); + dimension_numbers.set_kernel_output_feature_dimension( + kConvKernelOutputDimension); + dimension_numbers.set_kernel_input_feature_dimension( + kConvKernelInputDimension); + for (int i = 0; i < num_spatial_dims; ++i) { + dimension_numbers.add_input_spatial_dimensions(i + 2); + dimension_numbers.add_kernel_spatial_dimensions(i + 2); + dimension_numbers.add_output_spatial_dimensions(i + 2); + } + return dimension_numbers; +} + +/* static */ Status XlaBuilder::Validate( + const ConvolutionDimensionNumbers& dnum) { + if (dnum.input_spatial_dimensions_size() < 2) { + return FailedPrecondition("input spacial dimension < 2: %d", + dnum.input_spatial_dimensions_size()); + } + if (dnum.kernel_spatial_dimensions_size() < 2) { + return FailedPrecondition("kernel spacial dimension < 2: %d", + dnum.kernel_spatial_dimensions_size()); + } + if (dnum.output_spatial_dimensions_size() < 2) { + return FailedPrecondition("output spacial dimension < 2: %d", + dnum.output_spatial_dimensions_size()); + } + + if (std::set( + {dnum.input_batch_dimension(), dnum.input_feature_dimension(), + dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)}) + .size() != 4) { + return FailedPrecondition( + "dimension numbers for the input are not unique: (%lld, %lld, %lld, " + "%lld)", + dnum.input_batch_dimension(), dnum.input_feature_dimension(), + dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)); + } + if (std::set({dnum.kernel_output_feature_dimension(), + dnum.kernel_input_feature_dimension(), + dnum.kernel_spatial_dimensions(0), + dnum.kernel_spatial_dimensions(1)}) + .size() != 4) { + return FailedPrecondition( + "dimension numbers for the weight are not unique: (%lld, %lld, %lld, " + "%lld)", + dnum.kernel_output_feature_dimension(), + dnum.kernel_input_feature_dimension(), + dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1)); + } + if (std::set({dnum.output_batch_dimension(), + dnum.output_feature_dimension(), + dnum.output_spatial_dimensions(0), + dnum.output_spatial_dimensions(1)}) + .size() != 4) { + return FailedPrecondition( + "dimension numbers for the output are not unique: (%lld, %lld, %lld, " + "%lld)", + dnum.output_batch_dimension(), dnum.output_feature_dimension(), + dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1)); + } + return Status::OK(); +} + +StatusOr XlaBuilder::AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + tensorflow::gtl::ArraySlice operands) { + TF_RETURN_IF_ERROR(first_error_); + + const int64 handle = instructions_.size(); + instr.set_id(handle); + instr.set_opcode(HloOpcodeString(opcode)); + if (instr.name().empty()) { + instr.set_name(StrCat(instr.opcode())); + } + for (const auto& operand : operands) { + if (operand.builder_ == nullptr) { + return InvalidArgument("invalid XlaOp with handle %lld", + operand.handle()); + } + if (operand.builder_ != this) { + return InvalidArgument("Do not add XlaOp from builder %s to builder %s", + operand.builder_->name().c_str(), + this->name().c_str()); + } + instr.add_operand_ids(operand.handle()); + } + + *instr.mutable_metadata() = metadata_; + if (sharding_) { + *instr.mutable_sharding() = *sharding_; + } + + instructions_.push_back(instr); + + XlaOp op(handle, this); + return op; +} + +void XlaBuilder::AddCalledComputation(const XlaComputation& computation, + HloInstructionProto* instr) { + instr->add_called_computation_ids(computation.proto().entry_computation_id()); + for (const HloComputationProto& e : computation.proto().computations()) { + embedded_.insert({e.id(), e}); + } +} + +StatusOr XlaBuilder::LookUpInstruction( + const XlaOp& op) const { + TF_RETURN_IF_ERROR(first_error_); + + if (op.builder_ == nullptr) { + return InvalidArgument( + "invalid XlaOp with handle %lld; the builder of this op is freed", + op.handle()); + } + if (op.builder_ != this) { + return InvalidArgument( + "XlaOp with handle %lld is built by builder '%s', but is trying to use " + "it in builder '%s'", + op.handle(), op.builder_->name().c_str(), this->name().c_str()); + } + + if (op.handle() >= instructions_.size() || op.handle() < 0) { + return InvalidArgument("no XlaOp value %lld", op.handle()); + } + return &instructions_[op.handle()]; +} + +XlaOp XlaBuilder::UnimplementedOp() { + NoteError(Unimplemented("Op not implemented")); + return {}; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..0329e42ed1aef8edd1537e888ddcd78f08584407 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -0,0 +1,1039 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stacktrace.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class XlaBuilder; + +// This represents an instruction that has been enqueued using the XlaBuilder. +// This is used to pass to subsequent computations that depends upon the +// instruction as an operand. +class XlaOp { + public: + XlaOp() : handle_(0), builder_(nullptr) {} + ~XlaOp() {} + + const XlaBuilder* builder() const { return builder_; } + + bool operator==(const XlaOp& rhs) const { + return handle_ == rhs.handle_ && builder_ == rhs.builder_; + } + + bool operator!=(const XlaOp& rhs) const { + return handle_ != rhs.handle_ || builder_ != rhs.builder_; + } + + friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) { + out << op.handle(); + return out; + } + + private: + XlaOp(int64 handle, XlaBuilder* builder) + : handle_(handle), builder_(builder) {} + + int64 handle() const { return handle_; } + + friend class XlaBuilder; + + int64 handle_; + XlaBuilder* builder_; // Not owned. +}; + +// A convenient interface for building up computations. +// +// Thread-compatible. +class XlaBuilder { + public: + // computation_name: name to use for the built computation. + XlaBuilder(const string& computation_name); + + XlaBuilder(const XlaBuilder&) = delete; + XlaBuilder& operator=(const XlaBuilder&) = delete; + + ~XlaBuilder(); + + // Returns the computation name. + const string& name() const { return name_; } + + // Sets OpMetadata that will be added to all instructions until cleared. + // + // OpMetadata is often applied to a series of XLA HLO instructions. As a + // result, OpMetadata is set on the Computation Builder. All subsequent + // instructions generated via this Computation Builder will have the same + // OpMetadata attached until a call to ClearOpMetadata. + void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } + + // Clears the HloMetadata state. + void ClearOpMetadata() { metadata_.Clear(); } + + // Sets an OpSharding that will be attached to all instructions until cleared. + void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } + + // Clears the sharding. Ops will be sharded according to the default placement + // policy. + void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } + + // Returns the OpSharding that will be attached to all instructions. + const tensorflow::gtl::optional& sharding() const { + return sharding_; + } + + // Sets the builder to a mode where it will die immediately when an error is + // encountered, rather than producing it in a deferred fashion when Build() is + // called (which is the default). + void set_die_immediately_on_error(bool enabled) { + die_immediately_on_error_ = enabled; + } + + // Enqueues a "retrieve parameter value" instruction for a parameter that was + // passed to the computation. + XlaOp Parameter(int64 parameter_number, const Shape& shape, + const string& name); + + // Enqueues a constant with the value of the given literal onto the + // computation. + XlaOp ConstantLiteral(const LiteralSlice& literal); + + // Enqueues a constant onto the computation. Methods are templated on the + // native host type (NativeT) which corresponds to a specific XLA + // PrimitiveType as given in the following table: + // + // Native Type PrimitiveType + // ----------------------------- + // bool PRED + // int32 S32 + // int64 S64 + // uint32 U32 + // uint64 U64 + // float F32 + // double F64 + // + // Note: not all primitive types defined in xla_data.proto have a + // corresponding native type yet. + template + XlaOp ConstantR0(NativeT value); + template + XlaOp ConstantR1(tensorflow::gtl::ArraySlice values); + XlaOp ConstantR1(const tensorflow::core::Bitmap& values); + template + XlaOp ConstantR2( + std::initializer_list> values); + template + XlaOp ConstantFromArrayWithLayout(const Array& values, + const Layout& layout); + template + XlaOp ConstantFromArray(const Array& values); + template + XlaOp ConstantR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout); + template + XlaOp ConstantR2FromArray2D(const Array2D& values); + template + XlaOp ConstantR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout); + template + XlaOp ConstantR3FromArray3D(const Array3D& values); + template + XlaOp ConstantR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout); + template + XlaOp ConstantR4FromArray4D(const Array4D& values); + + // Enqueues a rank one constant (vector) onto the computation. The vector has + // size 'length' and every element has the value 'value'. + template + XlaOp ConstantR1(int64 length, NativeT value); + + // Adds dimensions to an array by duplicating the data in the array. + // + // The new dimensions are inserted on the left, i.e. if + // broadcast_sizes has values {a0, ..., aN} and the operand shape + // has dimensions {b0, ..., bM} then the shape of the output has + // dimensions {a0, ..., aN, b0, ..., bM}. + // + // The new dimensions index into copies of the operand, i.e. + // + // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] + XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); + + // Enqueues a pad operation onto the computation that pads the given value on + // the edges as well as between the elements of the input. padding_config + // specifies the padding amount for each dimension. + XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config); + + // Enqueues an operation onto the computation that flattens the operand based + // on the dimension order (major/slowest-varying to minor/fastest-varying) + // given, followed by reshaping it into the shape with the given dimension + // sizes (also major to minor). Conceptually, this is a limited form of + // "shape casting". + XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + + // Enqueues an operation onto the computation that collapses the operand, from + // first to last dimension (C order), then reshapes it to the given dimension + // sizes. Conceptually, this is a limited form of "shape casting". + XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes); + + // Wrapper for Reshape. + // Enqueues an operation to collapse the provided dimensions; e.g. an + // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to + // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must + // be a consecutive, in-order subsequence of the operand dimensions. + // + // Note that collapsing a single dimension does nothing: + // + // {256} collapsing {0} => {256} + // {1} collapsing {0} => {1} + // + // Collapsing multiple dimensions produces a single result dimension: + // + // {256, 2} collapsing {0,1} => {512} + // {256, 2, 3} collapsing {0,1} => {512, 3} + // + // This could potentially cause data to be moved -- it provides a more + // structured form of reshaping than an arbitrary Reshape operation. + XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + + // Enqueues a slice operation onto the computation that slices the operand + // from the start indices to the limit indices; e.g. + // + // x + // [ 0 1 2 3 ] + // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] + // [ 8 9 a b ] + // + // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D + // range notation. + // The strides parameter determines the stride over the slice + XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + // Enqueues a slice operation in a given dimension, taking all other + // dimensions as they are; e.g. if dimno is 1 from start_index 2 to + // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand + // for: + // + // array[:, 2:4:1, :] + XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno); + + // Enqueues a slice operation onto the computation that slices the 'operand' + // from dynamic start indices which are passed in 'start_indices'. + // The size of the slice in each dimension is passed in 'slice_sizes', + // which specify the end point of exclusive slice intervals in each + // dimension [start, start + size). + // The shape of 'start_indices' must be rank == 1, with dimension size + // equal to the rank of the 'operand'. + // Slice index calculations are computed modulo input dimension sizes to + // prevent dynamic start indices from generating out-of-bound array accesses. + XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + + // Enqueues a dynamic update slice operation onto the computation, which + // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. + // The shape of 'update' determines the shape of the slice of 'operand' + // which is updated. + // The indices specified in 'start_indices' specify the offset of the slice + // of 'operand' which is updated. + // + // update = {10, 11} // calculated at runtime. + // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] + // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] + // [7 8 9] [7 8 9 ] + // + // The shape of 'start_indices' must be rank == 1, with dimension size + // equal to the rank of the 'operand'. + // Slice index calculations are computed modulo update dimension sizes to + // prevent dynamic start indices from generating out-of-bound array accesses. + XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices); + + // Enqueues a concatenate instruction onto the computation. 'operands' must + // have >= 1 entry. + XlaOp ConcatInDim(tensorflow::gtl::ArraySlice operands, + int64 dimension); + + // Enqueue a tracing operation onto the computation; the computation will emit + // a logging message with the operand. + void Trace(const string& tag, const XlaOp& operand); + + // Enqueues a conditional-move-like select operation onto the computation; + // predicated on pred, selects between on_true and on_false. + XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); + + // Enqueues a tuple-creation instruction onto the computation. + XlaOp Tuple(tensorflow::gtl::ArraySlice elements); + + // Enqueues a tuple-element-get instruction onto the computation. + XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); + + // Enqueues an equal-to comparison instruction onto the computation. + XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a not-equal comparison instruction onto the computation. + XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a greater-or-equal comparison instruction onto the computation. + XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a greater-than comparison instruction onto the computation. + XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a less-than comparison instruction onto the computation. + XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a less-or-equal comparison instruction onto the computation. + XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a dot instruction onto the computation. + XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + + // Enqueues a general dot instruction onto the computation. + XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers); + + // Default dimension numbers used for a 2D convolution. + static constexpr int64 kConvBatchDimension = 0; + static constexpr int64 kConvFeatureDimension = 1; + static constexpr int64 kConvFirstSpatialDimension = 2; + static constexpr int64 kConvSecondSpatialDimension = 3; + static constexpr int64 kConvKernelOutputDimension = 0; + static constexpr int64 kConvKernelInputDimension = 1; + static constexpr int64 kConvKernelFirstSpatialDimension = 2; + static constexpr int64 kConvKernelSecondSpatialDimension = 3; + + // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for + // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for + // the kernel operand + // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. + static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( + int num_spatial_dims = 2); + + // Returns an error if the convolution dimension numbers have conflicts. + static Status Validate(const ConvolutionDimensionNumbers& dnum); + + // Enqueues a convolution instruction onto the computation, which uses the + // default convolution dimension numbers. + XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration in the format returned by MakePadding(). + XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided dimension numbers configuration. + XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration as well as the dimension numbers. + XlaOp ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration, dilation factors and dimension numbers. + XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues an FFT instruction onto the computation, of the given type and + // with the given FFT length. + XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + + // Enqueues an infeed instruction onto the computation, which writes data of + // the given shape to the infeed buffer of the device. + XlaOp Infeed(const Shape& shape, const string& config = ""); + + // Enqueues an outfeed instruction onto the computation. This instruction + // generates outgoing data transfers for the given data. + // + // shape_with_layout communicates the laid out shape that we want to outfeed + // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error + // will occur. + void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config); + + // Enqueues a call instruction onto the computation. + XlaOp Call(const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands); + + // Enqueues a custom call instruction onto the computation. + // During code generation, a call instruction is emitted which targets a + // symbol with the name |call_target_name|. The |operands| are passed to the + // call instruction. |shape| is the resultant shape. + XlaOp CustomCall(const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape); + + // Enqueues a pseudo-op to represent host-side computation data-dependencies. + // During code generation, host send and receive operations will be generated + // to transfer |operands| to the host and a single result of |shape| back to + // the device. Host send/recv operations are emitted using |channel_name|. + // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO + // instruction scheduling. + XlaOp HostCompute(tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape); + + // The following methods enqueue element-wise binary arithmetic operations + // onto the computation. The shapes of the operands have to match unless one + // of the operands is a scalar, or an explicit broadcast dimension is given + // (see g3doc for more details). + + // Enqueues a complex compose instruction onto the computation. + XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a complex conjugate instruction onto the computation. + XlaOp Conj(const XlaOp& operand); + + // Enqueues an add instruction onto the computation. + XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a subtract instruction onto the computation. + XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a multiply instruction onto the computation. + XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a divide instruction onto the computation. + XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a remainder instruction onto the computation. + XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a max instruction onto the computation. + XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a min instruction onto the computation. + XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Element-wise logical operators + XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + XlaOp Not(const XlaOp& operand); + + XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Reduces an array among the provided dimensions, given "computation" as a + // reduction operator. + XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); + + // Convenience wrapper around the above that reduces all the dimensions in the + // operand shape. + XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation); + + // Enqueues a windowed reduce instruction onto the computation. + XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + + // As ReduceWindow(), but the padding is given in the format + // returned by MakePadding(). + XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + + // Returns the sum of the operand value within each subgroup of replicas. All + // replicas supply one input to the sum and all replicas receive the resulting + // sum for each subgroup. + XlaOp CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids = {}); + + // Enqueues an operation that do an AllReduce of the operand cross cores. Here + // AllReduce means doing a reduction on the input operand cross cores and then + // broadcasting the reduction result to those cores. The reduction function is + // defined by `computation`, which should be a commutative computation on + // scalars, e.g., add, min, or max. The way that AllReduce is applied is + // configured by: + // + // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all + // replicas belong to one group. Allreduce will be applied within subgroups. + // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, + // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. + // + // - `channel_id`: for Allreduce nodes from different models, if they have the + // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be + // applied cross models. + // + // TODO(b/79737069): Rename this to AllReduce when it's ready to use. + XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids = {}, + const tensorflow::gtl::optional& channel_id = + tensorflow::gtl::nullopt); + + // Enqueues an operation that scatters the `source` array to the selected + // indices of each window. + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, + const XlaComputation& scatter); + + // As SelectAndScatter(), but the padding is given in the format + // returned by MakePadding(). + XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + + // Enqueues an abs instruction onto the computation. + XlaOp Abs(const XlaOp& operand); + + // Enqueues a atan2 instruction onto the computation. + XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues an exp instruction onto the computation. + XlaOp Exp(const XlaOp& operand); + + // Enqueues an expm1 instruction onto the computation. + XlaOp Expm1(const XlaOp& operand); + + // Enqueues a floor instruction onto the computation. + XlaOp Floor(const XlaOp& operand); + + // Enqueues a ceil instruction onto the computation. + XlaOp Ceil(const XlaOp& operand); + + // Enqueues a round instruction onto the computation, rounding to nearest even + // with half-way cases rounding away from zero. + XlaOp Round(const XlaOp& operand); + + // Enqueues an log instruction (natural logarithm) onto the computation. + XlaOp Log(const XlaOp& operand); + + // Enqueues an log1p instruction (log(x+1)) onto the computation. + XlaOp Log1p(const XlaOp& operand); + + // Enqueues a sign instruction onto the computation. + XlaOp Sign(const XlaOp& operand); + + // Enqueues a count leading zeros instruction onto the computation. + XlaOp Clz(const XlaOp& operand); + + // Enqueues a cosine instruction onto the computation. + XlaOp Cos(const XlaOp& operand); + + // Enqueues a sine instruction onto the computation. + XlaOp Sin(const XlaOp& operand); + + // Enqueues a tanh instruction onto the computation. + XlaOp Tanh(const XlaOp& operand); + + // Enqueues a real-part instruction onto the computation. + XlaOp Real(const XlaOp& operand); + + // Enqueues an imaginary-part instruction onto the computation. + XlaOp Imag(const XlaOp& operand); + + // Enqueues a float32 sqrt instruction onto the computation. + // (float32 is specified as there is an implicit float32 0.5f constant + // exponent). + XlaOp SqrtF32(const XlaOp& operand); + + // Enqueues a float32 square instruction onto the computation. + // (float32 is specified as there is an implicit float32 2.0f constant + // exponent). + XlaOp SquareF32(const XlaOp& operand); + + // Enqueues a lhs^rhs computation onto the computation. + XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues an operator that tests if the operand's values are finite, i.e., + // not Inf or NaN. Defined only for floating-point types. Returns an array of + // booleans with the same shape where entries are true iff the corresponding + // entry was NaN. + XlaOp IsFinite(const XlaOp& operand); + + // Enqueues a convert instruction onto the computation that changes the + // element type of the operand array to primitive_type. + XlaOp ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type); + + // Enqueues a no-op instruction onto the computation that changes + // the element type of the operand array to primitive_type. The + // bit-widths of the source and destination element types must be + // identical. + XlaOp BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type); + + // Enqueues a float32 reciprocal instruction onto the computation. + // (float32 is specified as there is an implicit float32 -1.0f constant + // exponent). + // + // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the + // shape of the operand. + XlaOp ReciprocalF32(const XlaOp& operand); + + // Enqueues a negate instruction onto the computation. + XlaOp Neg(const XlaOp& operand); + + // Enqueues a transpose instruction onto the computation. + XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation); + + // Enqueues a reverse instruction onto the computation. The order of the + // elements in the given dimensions is reversed (i.e., the element at index i + // is moved to index dimension_size - 1 - i). + XlaOp Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + + // Enqueues a sort (as increasing order) instruction onto the computation. + XlaOp Sort(const XlaOp& operand); + + // Enqueues a clamp instruction onto the computation. + XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); + + // Enqueues a map instruction onto the computation. + XlaOp Map(tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands = {}); + + // Enqueues a N(mu, sigma) random number generation instruction onto the + // computation. + XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); + + // Enqueues a U(a, b) random number generation instruction onto the + // computation. Returns values in the semi-open interval [a, b). + XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); + + // Enqueues a while node onto the computation. + XlaOp While(const XlaComputation& condition, const XlaComputation& body, + const XlaOp& init); + + // Enqueues a conditional node onto the computation. + XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation); + + // Enqueues a ReducePrecision node onto the computation. + XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits); + + // Enqueues a Gather node onto the computation. + XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds); + + // Enqueues a Send node onto the computation, to send the given operand to + // a Recv instruction that shares the same channel handle. + void Send(const XlaOp& operand, const ChannelHandle& handle); + + // Enqueues a Recv node onto the computation. The data comes from a Send + // instruction that shares the same channel handle and its shape must + // be the same as the given shape. + XlaOp Recv(const Shape& shape, const ChannelHandle& handle); + + // Returns true if 'operand' is a compile-time constant. A compile-time + // constant does not depend on any parameters, or on stateful operators such + // as `RngNormal` or `Infeed`. + // + // This tests whether a computation is a compile-time constant without + // evaluating the computation. + StatusOr IsConstant(const XlaOp& operand) const; + + // Normalizes operand across spatial and batch dimensions for each feature. + // + // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` + // is the normalized result and batch_mean and batch_var are the mean and + // variance, respectively, across batch for the operand. + XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index); + + // Normalizes operand across spatial and batch dimensions for each feature. + // + // `BatchNormInference` is equivalent to calling `BatchNormTraining` without + // computing `mean` and `variance` for each batch inside the operation. It + // uses the input `mean` and `variance` instead as estimated values. The + // purpose of this op is to reduce latency in inference, hence the name + // `BatchNormInference`. + // + // The output has the same shape as `operand`, and contains the normalized + // values for each batch. + XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index); + + // Calculates the gradients of a batch norm op. + // + // The inputs `batch_mean` and `batch_var` represent the mean and variance + // across the batch. + // + // Returns a tuple of three elements: + // - grad_operand: Gradient with respect to input `operand` + // - grad_offset: Gradient with respect to input `offset` + // - grad_scale: Gradient with respect to input `scale` + XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index); + + // Returns a new XlaBuilder whose resultant Computation is used only by this + // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error + // behavior as the parent. + std::unique_ptr CreateSubBuilder(const string& computation_name); + + // Builds the computation with the requested operations, or returns a non-ok + // status. Note that all ops that have been enqueued will be moved to the + // computation being returned. + StatusOr Build(); + + // Builds the computation with the requested operations, or notes an error in + // the parent XlaBuilder and returns an empty computation if building failed. + // This function is intended to be used where the returned XlaComputation is + // only used by the parent XlaBuilder and hence further operation on the + // returned XlaComputation will simply be error'ed out if an error occurred + // while building this computation. If the built computation is to be used by + // a XlaBuilder other than the parent XlaBuilder then Build() should be used + // instead. + XlaComputation BuildAndNoteError(); + + // Returns a subgraph that roots on the given root. If the root is not a + // compile-time constant (see `IsConstant`), returns an error. + // + // This will copy the needed ops/computations to the subgraph. + StatusOr BuildConstantSubGraph(const XlaOp& root_op) const; + + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // XlaOp and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + Status first_error() const { return first_error_; } + + // Returns the shape of the given op. + StatusOr GetShape(const XlaOp& op) const; + + // Returns the (inferred) result for the current computation's shape. + StatusOr GetProgramShape() const; + + private: + StatusOr AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + tensorflow::gtl::ArraySlice operands = {}); + + void AddCalledComputation(const XlaComputation& computation, + HloInstructionProto* instr); + + // Notes that the error occurred by: + // * storing it internally and capturing a backtrace if it's the first error + // (this deferred value will be produced on the call to Build()) + // * dying if die_immediately_on_error_ is true + void NoteError(const Status& error); + + XlaOp NoteErrorOrReturn(const std::function()>& op_creator); + + // Helper method that creates an empty op and notes error. + XlaOp UnimplementedOp(); + + StatusOr LookUpInstruction(const XlaOp& op) const; + + // Internal helper method that does the building for an arbitrary unary op. + XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand); + + // Internal helper method that does the building for an arbitrary binary op. + // broadcast_dimensions specifies which dimensions to use for broadcasting + // when the operation is between tensors of different ranks. + XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + + // Internal helper method that does the building for an arbitrary ternary op. + XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, + const XlaOp& ehs); + + XlaOp RngOp(RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters, + const Shape& shape); + + StatusOr InDimBroadcast( + const Shape& shape, const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_dimensions); + + // Internal helper method that creates a sequence of instructions that + // performs an explicit broadcast of the operand to the target shape. + StatusOr AddBroadcastSequence(const Shape& output_shape, + const XlaOp& operand); + + // Internal helper method for creating a Reshape op with the already inferred + // shape. + StatusOr Reshape(const Shape& shape, const XlaOp& operand); + + // Returns the (inferred) result for the program shape for the current + // computation and fills the root_id in the pointer. + StatusOr GetProgramShape(int64* root_id) const; + + // Returns shapes for the operands. + StatusOr> GetOperandShapes( + tensorflow::gtl::ArraySlice operands) const; + + // A visitor which checks whether an operation is a compile-time constant, + // meaning that it doesn't depend on any parameters, or on any stateful + // operation such as `RngNormal` or `Infeed`. The visitor walks the + // computation starting at a given operation and sets is_constant to false iff + // a parameter or stateful operation is encountered. + void IsConstantVisitor(const int64 op_handle, std::set* visited, + bool* is_constant) const; + + // Checks bounds for convolution parameters. + Status VerifyConvolution( + const Shape& lhs_shape, const Shape& rhs_shape, + const ConvolutionDimensionNumbers& dimension_numbers) const; + + // Helper function for creating a Window proto from user-supplied data. + // Returns error if the user-supplied data was invalid. + StatusOr MakeWindow( + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation) const; + + string name_; // Name to use for the built computation. + + // The first error encountered while building the computation. + // This is OK until the first error is encountered. + Status first_error_; + + // The saved stack trace from the point at which the first error occurred. + tensorflow::SavedStackTrace first_error_backtrace_; + + // The instructions of this computation. + std::vector instructions_; + + // The embedded computations used by this computation. Each computation was + // the entry computation of some XlaComputation, the key is the unique id of + // that XlaComputation. + std::map embedded_; + + // The unique parameter numbers. + tensorflow::gtl::FlatSet parameter_numbers_; + + // The metadata to attach to each op. This is structured as a "modal"-like + // operation, in order to simplify client code (and not sprinkle this metadata + // throughout the TensorFlow op kernel implementations). + OpMetadata metadata_; + + // Sharding for this operator. This is structured as a "model"-like operation, + // in order to simplify client code, similar to metadata_. + tensorflow::gtl::optional sharding_; + + // Mode bit that indicates whether to die when a first error is encountered. + bool die_immediately_on_error_ = false; + + XlaBuilder* parent_builder_{nullptr}; +}; + +template +XlaOp XlaBuilder::ConstantR0(NativeT value) { + return ConstantLiteral(*Literal::CreateR0(value)); +} + +template +XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice values) { + return ConstantLiteral(*Literal::CreateR1(values)); +} + +template +XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { + Literal literal(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {length})); + literal.PopulateWithValue(value); + return ConstantLiteral(literal); +} + +inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { + return ConstantLiteral(*Literal::CreateR1(values)); +} + +template +XlaOp XlaBuilder::ConstantR2( + std::initializer_list> values) { + return ConstantLiteral(*Literal::CreateR2(values)); +} + +template +XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array& values, + const Layout& layout) { + return ConstantLiteral( + *Literal::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp XlaBuilder::ConstantFromArray(const Array& values) { + return ConstantLiteral(*Literal::CreateFromArray(values)); +} + +template +XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { + return ConstantLiteral( + *Literal::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D& values) { + return ConstantLiteral(*Literal::CreateR2FromArray2D(values)); +} + +template +XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout) { + return ConstantLiteral( + *Literal::CreateR3FromArray3DWithLayout(values, layout)); +} + +template +XlaOp XlaBuilder::ConstantR3FromArray3D(const Array3D& values) { + return ConstantFromArray(values); +} + +template +XlaOp XlaBuilder::ConstantR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout) { + return ConstantFromArrayWithLayout(values, layout); +} + +template +XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { + return ConstantFromArray(values); +} + +// RAII-style object: sets the current sharding assignment in builder on +// construction, and sets back to the previous assignment on destruction. +class XlaScopedShardingAssignment { + public: + XlaScopedShardingAssignment(xla::XlaBuilder* builder, + tensorflow::gtl::optional sharding) + : builder_(builder), prev_sharding_(builder->sharding()) { + SetSharding(sharding); + } + + XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete; + XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) = + delete; + + ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); } + + private: + void SetSharding(const tensorflow::gtl::optional& sharding) { + if (sharding.has_value()) { + builder_->SetSharding(sharding.value()); + } else { + builder_->ClearSharding(); + } + } + + xla::XlaBuilder* const builder_; + tensorflow::gtl::optional prev_sharding_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2df3ea3af0d4fcfb9bc803feebd96f09042ab1f3 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc @@ -0,0 +1,239 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" + +#include + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +namespace { + +namespace op = xla::testing::opcode_matchers; + +using ::testing::HasSubstr; + +// TODO(b/74197823): Move the tests to service/. +class XlaBuilderTest : public ::testing::Test { + protected: + StatusOr> BuildHloModule(XlaBuilder* b) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build()); + const HloModuleProto& proto = computation.proto(); + TF_ASSIGN_OR_RETURN(const auto& config, + HloModule::CreateModuleConfigFromProto( + proto, legacy_flags::GetDebugOptionsFromFlags())); + return HloModule::CreateFromProto(proto, config); + } + + // Returns the name of the test currently being run. + string TestName() const { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } +}; + +TEST_F(XlaBuilderTest, OnePlusTwo) { + XlaBuilder b(TestName()); + b.Add(b.ConstantR0(1.0), b.ConstantR0(2.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); +} + +TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); + b.Add(x, b.ConstantR0(1.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(), op::Broadcast(op::Constant()))); +} + +TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { + XlaBuilder b(TestName()); + const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6}); + const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4}); + auto x = b.Parameter(0, x_shape, "x"); + auto y = b.Parameter(1, y_shape, "y"); + auto add = b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); + + TF_ASSERT_OK_AND_ASSIGN(auto add_shape, b.GetShape(add)); + EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape)); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(0), op::Broadcast(op::Parameter(1)))); +} + +TEST_F(XlaBuilderTest, XPlusX) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x"); + b.Add(x, x); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(0))); +} + +TEST_F(XlaBuilderTest, ShapeInferenceError) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(U32, {2, 4}), "y"); + b.Add(x, y); + auto statusor = BuildHloModule(&b); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("shape inference")); +} + +TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) { + XlaBuilder b_call("add"); + b_call.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); + + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); + auto y = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "y"); + b.Add(x, y); + auto statusor = BuildHloModule(&b); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("parameter 0 already registered")); +} + +TEST_F(XlaBuilderTest, Call) { + XlaBuilder b_call("the_only_to_apply"); + auto p0 = b_call.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + auto p1 = b_call.Parameter(1, ShapeUtil::MakeShape(F32, {}), "p1"); + b_call.Add(p0, p1); + TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build()); + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + auto one = b.ConstantR0(1); + auto two = b.ConstantR0(2); + b.Add(b.Call(call, {x, y}), b.Call(call, {one, two})); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Call(op::Parameter(), op::Parameter()), + op::Call(op::Constant(), op::Constant()))); +} + +TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y"); + b.Add(x, y); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + + // Expected: + // + // x: f32[1,2,3] y: f32[1,2,1] + // | | + // | reshape: f32[1,2] + // | | + // | broadcast: f32[1,2,3] + // \ / + // add + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(0), + op::Broadcast(op::Reshape(op::Parameter(1))))); +} + +TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y"); + b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + + // The binary operation has in-dim broadcast and degenerate broadcast, should + // first do the in-dim broadcast then convert the degnerate broadcast into a + // reshape and a broadcast. + // + // Expected: + // + // x: f32[2,3] y: f32[2,1,4] + // | | + // broadcast: f32[2,3,4] reshape: f32[2,4] + // | | + // | broadcast: f32[2,3,4] + // \ / + // add + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Broadcast(op::Parameter(0)), + op::Broadcast(op::Reshape(op::Parameter(1))))); +} + +TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { + XlaBuilder b1("b1"); + auto p0 = b1.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + XlaBuilder builder("main"); + builder.Add(p0, p0); + auto statusor = builder.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "built by builder 'b1', but is trying to use it in builder 'main'")); +} + +TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); + b.Reshape(x, /*new_sizes=*/{6, 35}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reshape(op::Parameter())); +} + +TEST_F(XlaBuilderTest, ReshapeHasTranspose) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); + b.Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reshape(op::Transpose(op::Parameter()))); +} + +TEST_F(XlaBuilderTest, Transpose) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + b.Transpose(x, /*permutation=*/{1, 0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Transpose(op::Parameter())); +} + +// TODO(b/65209188): Create a dedicated lowering for Xor. +TEST_F(XlaBuilderTest, Xor) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(PRED, {}), "y"); + b.Xor(x, y); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + LOG(ERROR) << module->ToString(); + EXPECT_THAT(root, + op::Or(op::And(op::Not(op::Parameter(0)), op::Parameter(1)), + op::And(op::Parameter(0), op::Not(op::Parameter(1))))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc new file mode 100644 index 0000000000000000000000000000000000000000..72e3935696e0c44ae3893fc8f1ceb261fa5e2646 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" + +#include + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +StatusOr XlaComputation::GetProgramShape() const { + TF_RET_CHECK(proto_.has_program_shape()); + return proto_.program_shape(); +} + +StatusOr> XlaComputation::Snapshot() const { + if (IsNull()) { + return InvalidArgument("Computation is invalid."); + } + auto session = MakeUnique(); + *session->mutable_hlo()->mutable_hlo_module() = proto_; + return std::move(session); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_client/xla_computation.h new file mode 100644 index 0000000000000000000000000000000000000000..0ffba208b1f8683fe1d26107cbfd096b856267f1 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.h @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// The computation graph that the user builds up with the XlaBuilder. +class XlaComputation { + public: + XlaComputation() : unique_id_(-1) {} + XlaComputation(const HloModuleProto& proto) + : unique_id_(proto.id()), proto_(proto) {} + + ~XlaComputation() {} + + XlaComputation(const XlaComputation&) = delete; + XlaComputation& operator=(const XlaComputation&) = delete; + + XlaComputation(XlaComputation&& from) = default; + + XlaComputation& operator=(XlaComputation&& from) = default; + + // Returns the "program shape" (parameter and return shapes) for this + // computation. + StatusOr GetProgramShape() const; + + const HloModuleProto& proto() const { return proto_; } + + // Requests that we snapshot the computation into a serializable protocol + // buffer form. + StatusOr> Snapshot() const; + + // Returns true if this object is a null Computation. + bool IsNull() const { return unique_id_ == -1; } + + private: + XlaComputation(const int64 unique_id) : unique_id_(unique_id) {} + HloModuleProto* mutable_proto() { return &proto_; } + friend class XlaBuilder; + + int64 unique_id_; + HloModuleProto proto_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/device_util.h b/tensorflow/compiler/xla/device_util.h index 23a622b1ad0e2f3b220645f62767271f28df24e9..1a51fdee680721a4a03fa5de79a81746d92af76b 100644 --- a/tensorflow/compiler/xla/device_util.h +++ b/tensorflow/compiler/xla/device_util.h @@ -29,7 +29,7 @@ namespace xla { // Returns a string that represents the device in terms of platform and ordinal; // e.g. the first CUDA device will be "cuda:0" -string DeviceIdentifier(perftools::gputools::StreamExecutor* stream_exec) { +string DeviceIdentifier(se::StreamExecutor* stream_exec) { return tensorflow::strings::StrCat(stream_exec->platform()->Name(), ":", stream_exec->device_ordinal()); } diff --git a/tensorflow/compiler/xla/error_spec.h b/tensorflow/compiler/xla/error_spec.h new file mode 100644 index 0000000000000000000000000000000000000000..a1463aa15941b9c265db94e2eb3cc176fab6695b --- /dev/null +++ b/tensorflow/compiler/xla/error_spec.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_ +#define TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_ + +namespace xla { + +// Structure describing permissible absolute and relative error bounds. +struct ErrorSpec { + explicit ErrorSpec(float aabs, float arel = 0, bool relaxed_nans = false) + : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {} + + float abs; // Absolute error bound. + float rel; // Relative error bound. + + // If relaxed_nans is true then any result is valid if we are expecting NaNs. + // In effect, this allows the tested operation to produce incorrect results + // for inputs outside its mathematical domain. + bool relaxed_nans; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_ diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 392ad9010ab81923a089c7b00a79ddc281af92bb..a472747bd174e3bbd352f07f2ab092e678b81073 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -36,26 +36,15 @@ DeviceMemoryAllocator* ExecutableRunOptions::allocator() const { } ExecutableRunOptions& ExecutableRunOptions::set_stream( - perftools::gputools::Stream* stream) { + stream_executor::Stream* stream) { stream_ = stream; return *this; } -perftools::gputools::Stream* ExecutableRunOptions::stream() const { +stream_executor::Stream* ExecutableRunOptions::stream() const { return stream_; } -ExecutableRunOptions& ExecutableRunOptions::set_inter_op_thread_pool( - tensorflow::thread::ThreadPool* inter_op_thread_pool) { - inter_op_thread_pool_ = inter_op_thread_pool; - return *this; -} - -tensorflow::thread::ThreadPool* ExecutableRunOptions::inter_op_thread_pool() - const { - return inter_op_thread_pool_; -} - ExecutableRunOptions& ExecutableRunOptions::set_intra_op_thread_pool( const Eigen::ThreadPoolDevice* intra_op_thread_pool) { intra_op_thread_pool_ = intra_op_thread_pool; @@ -87,4 +76,11 @@ const DeviceAssignment* ExecutableRunOptions::device_assignment() const { return device_assignment_; } +ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) { + rng_seed_ = rng_seed; + return *this; +} + +int ExecutableRunOptions::rng_seed() const { return rng_seed_; } + } // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index d4fcbf0493c936ebcd0639a432e56b62ee15672c..416131be006e6ecddb47651f8b684c1d91df4892 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -16,26 +16,27 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ #define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ -// Intentionally forward declared so that ExecutableRunOptions can be linked +// Pulls in the ::stream_executor -> ::xla::se namespace alias. +#include "tensorflow/compiler/xla/types.h" + +// These classes are forward declared so that ExecutableRunOptions can be linked // into an XLA-compiled binary without having to link all of the pointed-to // objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't // need to be linked). -namespace perftools { -namespace gputools { +namespace stream_executor { class Stream; class Platform; -} -} +} // namespace stream_executor namespace tensorflow { namespace thread { class ThreadPool; -} -} +} // namespace thread +} // namespace tensorflow namespace Eigen { struct ThreadPoolDevice; -} +} // namespace Eigen namespace xla { @@ -61,14 +62,8 @@ class ExecutableRunOptions { // If set, this is the stream to run the computation on. The platform of the // stream must match the platform the executable was built for. A value of // nullptr indicates the option has not been set. - ExecutableRunOptions& set_stream(perftools::gputools::Stream* stream); - perftools::gputools::Stream* stream() const; - - // Sets the thread pool on which to run parallel CPU backend - // computations. Does not take ownership. - ExecutableRunOptions& set_inter_op_thread_pool( - tensorflow::thread::ThreadPool* inter_op_thread_pool); - tensorflow::thread::ThreadPool* inter_op_thread_pool() const; + ExecutableRunOptions& set_stream(stream_executor::Stream* stream); + stream_executor::Stream* stream() const; // Sets the thread pool device on which to run Eigen subcomputations. // Does not take ownership. @@ -84,14 +79,17 @@ class ExecutableRunOptions { DeviceAssignment* device_assignment); const DeviceAssignment* device_assignment() const; + ExecutableRunOptions& set_rng_seed(int rng_seed); + int rng_seed() const; + private: DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; DeviceAssignment* device_assignment_ = nullptr; - perftools::gputools::Stream* stream_ = nullptr; - tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr; + stream_executor::Stream* stream_ = nullptr; const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; ExecutionProfile* execution_profile_ = nullptr; + int rng_seed_ = 0; }; } // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index fdc4bbdd8b162b7115788e267c2a53e73c186123..e8f29b83291a7cb238dc25b9f4bb743fe426a162 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -64,6 +65,16 @@ void SetDefaultLayoutToContainer( return layout; } +/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor( + tensorflow::gtl::ArraySlice major_to_minor) { + Layout layout; + layout.set_format(DENSE); + for (int i = major_to_minor.size() - 1; i >= 0; i--) { + layout.add_minor_to_major(major_to_minor[i]); + } + return layout; +} + /* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) { Layout layout; layout.set_format(SPARSE); @@ -87,8 +98,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } // namespace /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) { + if (ShapeUtil::IsOpaque(shape) || ShapeUtil::IsToken(shape)) { + // Opaque and token types have empty layouts. + return Layout(); + } + // A Layout proto corresponds to a single array, not a tuple. - DCHECK(!ShapeUtil::IsTuple(shape)); + CHECK(ShapeUtil::IsArray(shape)); return CreateDefaultLayoutForRank(shape.dimensions_size()); } @@ -115,14 +131,15 @@ Layout CreateDefaultLayoutForRank(int64 rank) { SetToDefaultLayout(&element_shape); } shape->clear_layout(); - } else if (ShapeUtil::IsOpaque(*shape)) { - shape->clear_layout(); - } else { + } else if (ShapeUtil::IsArray(*shape)) { shape->mutable_layout()->set_format(DENSE); tensorflow::protobuf::RepeatedField* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); minor_to_major->Resize(shape->dimensions_size(), 0); SetDefaultLayoutToContainer(minor_to_major); + } else { + // Opaque, token types etc. have no layout. + shape->clear_layout(); } } @@ -139,8 +156,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { LayoutUtil::SetToDefaultLayout(program_shape->mutable_result()); } -/* static */ tensorflow::Status LayoutUtil::ValidateLayoutInShape( - const Shape& shape) { +/* static */ Status LayoutUtil::ValidateLayoutInShape(const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { // Tuple shape. if (shape.has_layout()) { @@ -149,30 +165,34 @@ Layout CreateDefaultLayoutForRank(int64 rank) { for (auto& element_shape : shape.tuple_shapes()) { TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape)); } - return tensorflow::Status::OK(); - } else if (ShapeUtil::IsOpaque(shape)) { - if (shape.has_layout()) { - return InvalidArgument("opaque should not have a layout field"); - } - return tensorflow::Status::OK(); - } else { - // Array shape. + return Status::OK(); + } else if (ShapeUtil::IsArray(shape)) { if (!shape.has_layout()) { return InvalidArgument("shape %s does not have a layout", ShapeUtil::HumanString(shape).c_str()); } return ValidateLayoutForShape(shape.layout(), shape); + } else { + // Token, opaque, etc. shape. + if (shape.has_layout()) { + return InvalidArgument( + "shape of primitive type %s should not have a layout", + PrimitiveType_Name(shape.element_type()).c_str()); + } + return Status::OK(); } } -/* static */ tensorflow::Status LayoutUtil::ValidateLayoutForShape( - const Layout& layout, const Shape& shape) { +/* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout, + const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { return InvalidArgument("a single Layout is not valid for tuple shapes"); } - if (ShapeUtil::IsOpaque(shape)) { - return tensorflow::Status::OK(); + if (!ShapeUtil::IsArray(shape)) { + return InvalidArgument( + "shape of primitive type %s should not have a layout", + PrimitiveType_Name(shape.element_type()).c_str()); } if (layout.format() == INVALID_FORMAT) { @@ -224,7 +244,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } } - return tensorflow::Status::OK(); + return Status::OK(); } /* static */ void LayoutUtil::ClearLayout(Shape* shape) { @@ -263,7 +283,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::IsPadded(const Shape& shape) { - if (ShapeUtil::IsTuple(shape) || !HasLayout(shape) || + if (!ShapeUtil::IsArray(shape) || !HasLayout(shape) || shape.layout().padded_dimensions_size() == 0) { return false; } @@ -313,7 +333,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { // Tuple shape: all subshapes must have a layout. return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), [](const Shape& s) { return HasLayout(s); }); - } else if (ShapeUtil::IsOpaque(shape)) { + } else if (!ShapeUtil::IsArray(shape)) { + // Opaque, token types etc. ignore layout. return true; } return shape.has_layout() && shape.layout().format() != INVALID_FORMAT; @@ -383,7 +404,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { namespace { // Internal helper for recursively copying layouts. -tensorflow::Status CopyLayoutInternal(const Shape& src, Shape* dst) { +Status CopyLayoutInternal(const Shape& src, Shape* dst) { if (ShapeUtil::IsTuple(src) != ShapeUtil::IsTuple(*dst)) { return InvalidArgument( "cannot copy layout from shape: shape structure differs"); @@ -410,25 +431,21 @@ tensorflow::Status CopyLayoutInternal(const Shape& src, Shape* dst) { dst->clear_layout(); } } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace /* static */ -tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, - Shape* dst) { +Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { return CopyLayoutInternal(src, dst); } /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) { - if (ShapeUtil::IsTuple(lhs) != ShapeUtil::IsTuple(rhs)) { - return false; - } if (ShapeUtil::IsTuple(lhs)) { - if (ShapeUtil::TupleElementCount(lhs) != - ShapeUtil::TupleElementCount(rhs)) { + if (!ShapeUtil::IsTuple(rhs) || ShapeUtil::TupleElementCount(lhs) != + ShapeUtil::TupleElementCount(rhs)) { return false; } for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) { @@ -437,9 +454,12 @@ tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, } } return true; - } else { + } else if (ShapeUtil::IsArray(lhs)) { return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) && LayoutUtil::Equal(lhs.layout(), rhs.layout()); + } else { + // Layouts of non-array and non-tuple shapes is ignored. + return true; } } @@ -465,4 +485,25 @@ std::ostream& operator<<(std::ostream& out, const Layout& layout) { return out; } +/*static*/ size_t LayoutUtil::Hash(const Layout& layout) { + using tensorflow::hash; + using tensorflow::Hash64Combine; + + size_t hash_value = hash()(layout.format()); + + for (int64 minor_to_major : layout.minor_to_major()) { + hash_value = Hash64Combine(hash_value, hash()(minor_to_major)); + } + + for (int64 padded_dim : layout.padded_dimensions()) { + hash_value = Hash64Combine(hash_value, hash()(padded_dim)); + } + + hash_value = + Hash64Combine(hash_value, hash()(layout.padding_value())); + hash_value = Hash64Combine(hash_value, layout.max_sparse_elements()); + + return hash_value; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 6c54eb2201b66a4a0c5695bceb14bb2367133935..739bbe73675c7fb855627006028eafdf703d6540 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -20,9 +20,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -36,6 +36,10 @@ class LayoutUtil { // convenience function for protobuf construction.) static Layout MakeLayout(tensorflow::gtl::ArraySlice minor_to_major); + // Similar to MakeLayout, but take indices in reverse order. + static Layout MakeLayoutFromMajorToMinor( + tensorflow::gtl::ArraySlice major_to_minor); + // Creates a sparse layout with the given maximum number of elements. (This is // a convenience function for protobuf construction.) static Layout MakeSparseLayout(int64 max_sparse_elements); @@ -61,12 +65,12 @@ class LayoutUtil { static void SetToDefaultLayout(ProgramShape* program_shape); // Validates that the layout within the given shape is correct. - static tensorflow::Status ValidateLayoutInShape(const Shape& shape); + static Status ValidateLayoutInShape(const Shape& shape); // Validates that the provided layout satisfies invariants for the given // shape. - static tensorflow::Status ValidateLayoutForShape(const Layout& layout, - const Shape& shape); + static Status ValidateLayoutForShape(const Layout& layout, + const Shape& shape); // Clears the layout in the given Shape. After this function is called, // HasLayout will return false for the shape. @@ -179,8 +183,7 @@ class LayoutUtil { // tuples. 'src' and 'dst' need not be compatible but the two shapes must // have the same tuple structure (if any) and arrays must have the same // rank. within the shapes must have the same number of dimensions. - static tensorflow::Status CopyLayoutBetweenShapes(const Shape& src, - Shape* dst); + static Status CopyLayoutBetweenShapes(const Shape& src, Shape* dst); // Returns true if the layouts of lhs and rhs are equal, false // otherwise. Recursively compares layouts of tuples. @@ -195,6 +198,9 @@ class LayoutUtil { static bool AreDimensionsConsecutive(const Layout& layout, tensorflow::gtl::ArraySlice dims); + // Compute a hash for `layout`. + static size_t Hash(const Layout& layout); + private: TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil); }; diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 4fd1d818e3e3b417eee9f6b14bb598bfb9480c6e..e4c825450dcd45a8fbeaacbb2ad145f94307176f 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -218,6 +218,47 @@ TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) { "elements, but shape is rank")); } +TEST_F(LayoutUtilTest, CopyTokenLayout) { + Shape src = ShapeUtil::MakeTokenShape(); + Shape dst = ShapeUtil::MakeTokenShape(); + + // Layouts are trivially the same for token types and copying layouts should + // be a nop. + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + +TEST_F(LayoutUtilTest, CopyOpaqueLayout) { + Shape src = ShapeUtil::MakeOpaqueShape(); + Shape dst = ShapeUtil::MakeOpaqueShape(); + + // Layouts are trivially the same for opaque types and copying layouts should + // be a nop. + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + +TEST_F(LayoutUtilTest, CopyTupleLayoutWithTokenAndOpaque) { + Shape src = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {0, 1}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})}); + Shape dst = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); + + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + TEST_F(LayoutUtilTest, ClearLayoutTuple) { Shape shape = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), @@ -236,6 +277,16 @@ TEST_F(LayoutUtilTest, ClearLayoutTuple) { EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout()); } +TEST_F(LayoutUtilTest, ClearLayoutOpaqueAndToken) { + // Opaque and token types trivially have layouts. + for (Shape shape : + {ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeTokenShape()}) { + EXPECT_TRUE(LayoutUtil::HasLayout(shape)); + LayoutUtil::ClearLayout(&shape); + EXPECT_TRUE(LayoutUtil::HasLayout(shape)); + } +} + TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) { Shape shape = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}), diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD index 0a9725db0a4fcf963cadcacf2cbc1d95d2c7239d..89353448e29ec3d97275dac288e23aa8e96e31b2 100644 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ b/tensorflow/compiler/xla/legacy_flags/BUILD @@ -75,17 +75,3 @@ tf_cc_test( "//tensorflow/core:test", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index c8ed3e3a2b009ddffdfb79a9a6ced8d5e736bee6..f42fb92359f40ec763866af094972046f6407ae1 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -40,10 +40,19 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { flags->set_xla_cpu_multi_thread_eigen(true); flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); flags->set_xla_eliminate_hlo_implicit_broadcast(true); - +#ifdef INTEL_MKL + flags->set_xla_cpu_use_mkl_dnn(true); +#endif // INTEL_MKL + flags->set_xla_gpu_max_kernel_unroll_factor(4); // Set cudnn batchnorm off by default; it does not provide a performance win // on average. flags->set_xla_gpu_use_cudnn_batchnorm(false); + + // Run all GPU work on one stream by default. Using multiple streams + // increases memory usage and we lack strong motivating benchmarks for tuning + // the heuristics needed to decide when to run on multiple streams. See + // b/77879207. + flags->set_xla_gpu_disable_multi_streaming(true); } // Allocates flag_values and flag_objects; this function must not be called more @@ -220,6 +229,11 @@ void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_gpu_disable_multi_streaming), flag_values->xla_gpu_disable_multi_streaming(), "If true, multi-streaming in the GPU backend is disabled."), + tensorflow::Flag( + "xla_gpu_max_kernel_unroll_factor", + int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor), + flag_values->xla_gpu_max_kernel_unroll_factor(), + "Specify the maximum kernel unroll factor for the GPU backend."), tensorflow::Flag( "xla_dump_optimized_hlo_proto_to", flag_values->mutable_xla_dump_optimized_hlo_proto_to(), @@ -288,6 +302,10 @@ void AllocateFlags() { flag_values->xla_gpu_use_cudnn_batchnorm(), "Allows the GPU backend to implement batchnorm HLOs using cudnn, " "rather than expanding them to a soup of HLOs."), + tensorflow::Flag("xla_cpu_use_mkl_dnn", + bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn), + flag_values->xla_cpu_use_mkl_dnn(), + "Generate calls to MKL-DNN in the CPU backend."), }); ParseFlagsFromEnv(*flag_objects); } diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc index a3b4286f4c12bf39a44c63dd6e7d303a46a418c3..7b6ae311c1099dccb8dceb2f49743c1b185cd5ab 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc +++ b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/subprocess.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf9679cafec72c2e9dc5796e9058c6703239c508 --- /dev/null +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -0,0 +1,741 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/literal_comparison.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" + +using tensorflow::strings::Appendf; +using tensorflow::strings::Printf; +using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; + +namespace xla { +namespace literal_comparison { +namespace { + +// Helper function for comparing a floating point type, FloatT, bitwise equal +// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT +// -- on miscompare, a nice error message is given in the AssertionFailure. +template +Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { + auto ulhs = tensorflow::bit_cast(lhs); + auto urhs = tensorflow::bit_cast(rhs); + auto lhs_double = static_cast(lhs); + auto rhs_double = static_cast(rhs); + if (ulhs != urhs) { + return InvalidArgument( + "floating values are not bitwise-equal; and equality testing " + "was requested: %s=%g=%a vs %s=%g=%a", + StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double, + StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double); + } + return Status::OK(); +} + +// Templated comparator that specializes for float equality comparison with the +// bitwise helper above (this is the un-specialized fallback, to just use the +// default gunit implementation). +template +Status CompareEqual(NativeT lhs, NativeT rhs) { + if (lhs == rhs) { + return Status::OK(); + } + return InvalidArgument("Expected equality of these values:\n %s\n %s", + StrCat(lhs).c_str(), StrCat(rhs).c_str()); +} + +// Specializations for floating types that do bitwise comparisons when equality +// comparison is requested. +template <> +Status CompareEqual(bfloat16 lhs, bfloat16 rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(Eigen::half lhs, Eigen::half rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(float lhs, float rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(double lhs, double rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(complex64 lhs, complex64 rhs) { + auto res = CompareEqual(lhs.real(), rhs.real()); + if (!res.ok()) { + return res; + } + return CompareEqual(lhs.imag(), rhs.imag()); +} + +// A recursive function which iterates through every index of expected and +// actual literal and compares their values elementwise. Returns true if all +// elements are equal. +template +Status Equal(LiteralSlice expected, LiteralSlice actual, + tensorflow::gtl::MutableArraySlice multi_index, + int64 dimension) { + if (dimension == expected.shape().dimensions_size()) { + NativeT expected_value = expected.Get(multi_index); + NativeT actual_value = actual.Get(multi_index); + return CompareEqual(expected_value, actual_value); + } + + Status result; + for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { + multi_index[dimension] = i; + result.Update(Equal(expected, actual, multi_index, dimension + 1)); + } + return result; +} + +// Gets the total element count. For tuples, this is not the count of tuple +// elements, but the sum of elements of each tuple element. +int64 RecursiveElementCount(const Shape& shape) { + if (ShapeUtil::IsTuple(shape)) { + const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); + int64 total = 0; + for (int64 i = 0; i < tuple_elements; ++i) { + total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); + } + return total; + } else { + return ShapeUtil::ElementsIn(shape); + } +} + +// Returns whether the actual and expected values are mismatched with respect to +// nans. 'relaxed_nans' is interpreted as in xla::ErrorSpec. +template +bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { + if (relaxed_nans) { + return !std::isnan(expected) && std::isnan(actual); + } else { + return std::isnan(expected) != std::isnan(actual); + } +} + +template <> +bool NanMismatch(complex64 expected, complex64 actual, + bool relaxed_nans) { + return NanMismatch(expected.real(), actual.real(), relaxed_nans) || + NanMismatch(expected.imag(), actual.imag(), relaxed_nans); +} + +template <> +bool NanMismatch(half expected, half actual, bool relaxed_nans) { + return NanMismatch(static_cast(expected), + static_cast(actual), relaxed_nans); +} + +// Converts the given floating-point value to a string. +template +string FpValueToString(NativeT value) { + return Printf("%8.4g", static_cast(value)); +} + +template <> +string FpValueToString(complex64 value) { + return Printf("%8.4g + %8.4fi", value.real(), value.imag()); +} + +// Returns the absolute value of the given floating point value. This function +// is used instead of std::abs directly in order to allow type-dependent +// implementations for NearComparator. +template +float FpAbsoluteValue(NativeT value) { + return std::abs(value); +} + +template <> +float FpAbsoluteValue(bfloat16 value) { + return FpAbsoluteValue(static_cast(value)); +} + +template <> +float FpAbsoluteValue(half value) { + return FpAbsoluteValue(static_cast(value)); +} + +// Helper class for comparing floating-point literals within an error bound. +template +class NearComparator { + public: + // Compares the two array literals elementwise and returns a comparison + // result. The comparison is ok() if all actual and expected elements are + // within the given error bound. In case of error, the status contains a + // detailed message about the discrepancy. + static Status Compare(const LiteralSlice& expected, + const LiteralSlice& actual, ErrorSpec error, + bool detailed_message, + const MiscompareCallback& miscompare_callback) { + NearComparator comparator(expected, actual, error, + detailed_message, miscompare_callback); + return comparator.Run(); + } + + private: + // Data structure encapsulating metadata about a single element mismatch. + struct Mismatch { + NativeT actual; + NativeT expected; + float rel_error; + float abs_error; + + // The linear index of the failure within the shape. This linear index is + // from the 'actual' literal. + int64 linear_index; + + bool operator<(const Mismatch& other) const { + return rel_error < other.rel_error; + } + + string ToString(const Shape& shape) const { + return Printf( + "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", + FpValueToString(actual).c_str(), FpValueToString(expected).c_str(), + Literal::MultiIndexAsString( + IndexUtil::LinearIndexToMultidimensionalIndex(shape, + linear_index)) + .c_str(), + rel_error, abs_error); + } + }; + + NearComparator(const LiteralSlice& expected, const LiteralSlice& actual, + ErrorSpec error, bool detailed_message, + const MiscompareCallback& miscompare_callback) + : expected_(expected), + actual_(actual), + error_(error), + detailed_message_(detailed_message), + miscompare_callback_(miscompare_callback), + abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}), + abs_error_buckets_(kErrorBucketBounds.size(), 0), + rel_error_buckets_(kErrorBucketBounds.size(), 0) {} + + // Runs the comparison between expected and actual literals. + Status Run() { + VLOG(1) << "expected:"; + XLA_VLOG_LINES(1, ToStringTruncated(expected_)); + VLOG(1) << "actual:"; + XLA_VLOG_LINES(1, ToStringTruncated(actual_)); + + // If the shapes mismatch, we simply fail the expectation instead of + // printing out data, as it's a type error rather than a value error. + TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape())); + if (!ShapeUtil::IsArray(expected_.shape())) { + return InvalidArgument("Expected array shape; got %s.", + ShapeUtil::HumanString(expected_.shape()).c_str()); + } + + mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED)); + mismatches_.PopulateWithValue(false); + + CompareLiterals(); + + if (num_mismatches_ == 0) { + return Status::OK(); + } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) { + miscompare_callback_(expected_, actual_, mismatches_); + } + return InvalidArgument("%s", ErrorMessage().c_str()); + } + + // Insert the given absolute value into the absolute value bucket vector. The + // bounds of the buckets are given by kAbsValueBucketBounds. + void UpdateAbsValueBucket(NativeT value, bool is_mismatch) { + // Adjust the bucket containing the absolute values of the 'actual' + // elements. + const float abs_value = FpAbsoluteValue(value); + for (int i = 0; i < abs_value_buckets_.size(); ++i) { + if (i == abs_value_buckets_.size() - 1 || + (abs_value >= kAbsValueBucketBounds[i] && + abs_value < kAbsValueBucketBounds[i + 1])) { + // The first value of the pair is the count of elements in the bucket, + // the second is the count of mismatches in the bucket. + abs_value_buckets_[i].first++; + if (is_mismatch) { + abs_value_buckets_[i].second++; + } + return; + } + } + } + + // Insert the given error into the given error bucket vector. + void UpdateErrorBucket( + float error, tensorflow::gtl::MutableArraySlice error_buckets) { + CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size()); + for (int i = 0; i < error_buckets.size(); ++i) { + if (error >= kErrorBucketBounds[i]) { + error_buckets[i]++; + } + } + } + + // Compares the two given elements from the expected and actual literals at + // the given literal_index and keeps track of various mismatch statistics. + void CompareValues(NativeT expected, NativeT actual, int64 linear_index) { + const bool is_nan_mismatch = + NanMismatch(expected, actual, error_.relaxed_nans); + float abs_error; + float rel_error; + if (actual == expected) { + abs_error = 0; + rel_error = 0; + } else if (is_nan_mismatch) { + num_nan_mismatches_++; + // A nan mismatch is considered to have infinite error. rel_error is used + // for sorting a std::set of the top mismatchs, and a nan value here will + // result in undefined behavior because nan's do not satisfy the strict + // weak ordering requirement of std containers. + abs_error = std::numeric_limits::infinity(); + rel_error = std::numeric_limits::infinity(); + } else { + abs_error = FpAbsoluteValue(actual - expected); + rel_error = abs_error / FpAbsoluteValue(expected); + } + const bool is_abs_mismatch = abs_error > error_.abs; + const bool is_rel_mismatch = rel_error > error_.rel; + const bool is_mismatch = + is_nan_mismatch || (is_abs_mismatch && is_rel_mismatch); + + // Update the error of the relative bucket only if the *absolute* error + // bound is exceeded and vice versa. + if (is_abs_mismatch) { + num_abs_mismatches_++; + UpdateErrorBucket(rel_error, &rel_error_buckets_); + } + if (is_rel_mismatch) { + num_rel_mismatches_++; + UpdateErrorBucket(abs_error, &abs_error_buckets_); + } + + UpdateAbsValueBucket(actual, is_mismatch); + + if (!is_mismatch) { + return; + } + + num_mismatches_++; + + // Keep track of the kTopRelativeErrorCount relative error mismatches. + if (top_rel_mismatches_.size() < kTopRelativeErrorCount || + rel_error > top_rel_mismatches_.begin()->rel_error) { + Mismatch mismatch = {actual, expected, rel_error, abs_error, + linear_index}; + top_rel_mismatches_.insert(mismatch); + if (top_rel_mismatches_.size() > kTopRelativeErrorCount) { + top_rel_mismatches_.erase(top_rel_mismatches_.begin()); + } + } + + mismatches_.data()[linear_index] = true; + } + + // Compares the two literals elementwise. + void CompareLiterals() { + // Fast path optimization for the case were layouts match. + if (LayoutUtil::Equal(actual_.shape().layout(), + expected_.shape().layout())) { + tensorflow::gtl::ArraySlice expected_data = + expected_.data(); + tensorflow::gtl::ArraySlice actual_data = + actual_.data(); + const int64 len = expected_data.size(); + for (int64 i = 0; i < len; ++i) { + CompareValues(expected_data[i], actual_data[i], i); + } + return; + } + std::vector multi_index(ShapeUtil::Rank(actual_.shape()), 0); + CompareLiteralsSlow(0, &multi_index); + } + + // Slow path for CompareLiterals when 'actual' and 'expected' literals have + // different layouts. In this case, multidimensional indices are constructed + // and indexed for each element. + void CompareLiteralsSlow(int64 dimension, std::vector* multi_index) { + if (dimension == multi_index->size()) { + CompareValues(expected_.Get(*multi_index), + actual_.Get(*multi_index), + IndexUtil::MultidimensionalIndexToLinearIndex( + actual_.shape(), *multi_index)); + } else { + for (int64 i = 0; i < expected_.shape().dimensions(dimension); ++i) { + (*multi_index)[dimension] = i; + CompareLiteralsSlow(dimension + 1, multi_index); + } + } + } + + // Returns an error message string with a detailed breakdown of the + // mismatches. Called after calling Run(). + string ErrorMessage() { + string out; + int64 element_count = ShapeUtil::ElementsIn(actual_.shape()); + + auto percent_string = [](float a, float b) { + float pct = b == 0.0 ? 0.0 : 100.0 * a / b; + return Printf("%0.4f%%", pct); + }; + + Appendf(&out, + "\nMismatch count %lld (%s) in shape %s (%lld elements), abs bound " + "%g, rel bound %g\n", + num_mismatches_, + percent_string(num_mismatches_, element_count).c_str(), + ShapeUtil::HumanString(actual_.shape()).c_str(), + ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); + if (num_nan_mismatches_ > 0) { + StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n"); + } + Appendf(&out, "Top relative error mismatches:\n"); + for (auto it = top_rel_mismatches_.rbegin(); + it != top_rel_mismatches_.rend(); ++it) { + StrAppend(&out, " ", it->ToString(actual_.shape()).c_str(), "\n"); + } + + if (!detailed_message_) { + return out; + } + + StrAppend(&out, "Absolute magnitude breakdown of actual values:\n"); + CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size()); + for (int i = 0; i < abs_value_buckets_.size(); ++i) { + const int64 bucket_size = abs_value_buckets_[i].first; + const int64 bucket_mismatches = abs_value_buckets_[i].second; + string mismatch_str = bucket_mismatches > 0 + ? Printf(", mismatches %lld", bucket_mismatches) + : ""; + Appendf(&out, " %-6g <= x < %-6g : %7lld (%9s)%s\n", + kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], + bucket_size, percent_string(bucket_size, element_count).c_str(), + mismatch_str.c_str()); + } + + auto print_accum_buckets = [&](const string& header, int64 total, + tensorflow::gtl::ArraySlice buckets) { + StrAppend(&out, header, ":\n"); + Appendf(&out, " < %-6g : %7lld (%s)\n", kErrorBucketBounds[0], + total - buckets[0], + percent_string(total - buckets[0], total).c_str()); + CHECK_EQ(buckets.size(), kErrorBucketBounds.size()); + for (int i = 0; i < kErrorBucketBounds.size(); ++i) { + Appendf(&out, " >= %-6g : %7lld (%s)\n", kErrorBucketBounds[i], + buckets[i], percent_string(buckets[i], total).c_str()); + } + }; + Appendf(&out, "Elements exceeding abs error bound %g: %lld (%s)\n", + error_.abs, num_abs_mismatches_, + percent_string(num_abs_mismatches_, element_count).c_str()); + print_accum_buckets( + "Relative error breakdown of elements exceeding abs error bound", + num_abs_mismatches_, rel_error_buckets_); + Appendf(&out, "Elements exceeding rel error bound %g: %lld (%s)\n", + error_.rel, num_rel_mismatches_, + percent_string(num_rel_mismatches_, element_count).c_str()); + print_accum_buckets( + "Absolute error breakdown of elements exceeding rel error bound", + num_rel_mismatches_, abs_error_buckets_); + return out; + } + + // 'actual' and 'expected' literals being compared. + LiteralSlice expected_; + LiteralSlice actual_; + + // The error bounds of the comparison. + ErrorSpec error_; + + // Whether to include detailed breakdown of mismatches in the error message. + bool detailed_message_; + + // Callback to invoke on miscompare. + MiscompareCallback miscompare_callback_; + + // Number of element element mismatches encountered so far. + int64 num_mismatches_ = 0; + + // Number of elements with a nan mismatch. + int64 num_nan_mismatches_ = 0; + + // Number of elements which exceed the absolute/relative error bound. + int64 num_abs_mismatches_ = 0; + int64 num_rel_mismatches_ = 0; + + // A Literal containing which elements did not match in the expected and + // actual literals. mismatches_ contains PREDs and is of the same sizes as + // the comparison literals. + Literal mismatches_; + + // The number of mismatches to report in the output, sorted by relative error + // magnitude. + static constexpr int64 kTopRelativeErrorCount = 5; + + // The set of mismatches with the largest relative error. The size of this set + // is bounded by kTopRelativeErrorCount. + std::multiset top_rel_mismatches_; + + // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the + // bounds of these buckets. abs_value_buckets_ contains a pair for each + // bucket: the element count and failure count. + static constexpr std::array kAbsValueBucketBounds = { + 0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits::infinity()}; + std::vector> abs_value_buckets_; + + // Buckets for relative and absolute errors. The relative error buckets only + // contains those elements which exceed the *absolute* error bound, and vice + // versa. This makes it easy to see the effect of adjusting the relative (or + // absolute) error bound on the success of the comparison. kErrorBucketBounds + // are the lower bounds of the buckets in both vectors. The error buckets are + // a cumulative distribution so an error value may appear in more than one + // bucket. For example an error value of 0.003 may appear in the buckets + // bounded by 0.01, 0.1, and 1.0. + static constexpr std::array kErrorBucketBounds = {0.0001, 0.001, + 0.01, 0.1, 1}; + std::vector abs_error_buckets_; + std::vector rel_error_buckets_; +}; + +template +constexpr std::array NearComparator::kAbsValueBucketBounds; +template +constexpr std::array NearComparator::kErrorBucketBounds; + +// Helper function for comparing two literals for nearness. Handles tuple-shapes +// via recursion. shape_index is the ShapeIndex of expected (or actual) +// currently being compared. +Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error, bool detailed_message, + const MiscompareCallback& miscompare_callback, + const ShapeIndex& shape_index) { + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + + if (ShapeUtil::IsTuple(expected.shape())) { + Status return_status; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + const auto expected_element = LiteralSlice(expected, {i}); + const auto actual_element = LiteralSlice(actual, {i}); + ShapeIndex element_index = shape_index; + element_index.push_back(i); + Status res = + NearHelper(expected_element, actual_element, error, detailed_message, + miscompare_callback, element_index); + if (!res.ok()) { + string err_message = Printf("\nArray at shape index %s%s", + element_index.ToString().c_str(), + res.error_message().c_str()); + if (return_status.ok()) { + return_status = res; + } else { + return_status = AppendStatus(return_status, res.error_message()); + } + } + } + if (!return_status.ok() && shape_index.empty()) { + // Emit a top-level error message containing the top-level shape in case + // of mismatch. + int64 total_elements = RecursiveElementCount(actual.shape()); + return_status = InvalidArgument( + "\nMismatches in shape %s (%lld elements):\n%s", + ShapeUtil::HumanString(actual.shape()).c_str(), total_elements, + return_status.error_message().c_str()); + } + return return_status; + } + + if (ShapeUtil::ElementIsFloating(expected.shape()) || + ShapeUtil::ElementIsComplex(expected.shape())) { + switch (expected.shape().element_type()) { + case BF16: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case F16: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case F32: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case F64: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case C64: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + default: + LOG(FATAL) << "Unsupported primitive type in near comparator: " + << PrimitiveType_Name(expected.shape().element_type()) + << ". Must be floating-point type."; + } + } + + // Non-floating point literal. + return literal_comparison::Equal(expected, actual); +} + +} // namespace + +Status EqualShapes(const Shape& expected, const Shape& actual) { + if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { + return InvalidArgument("tupleness-mismatch! want: %s got %s", + ShapeUtil::HumanString(expected).c_str(), + ShapeUtil::HumanString(actual).c_str()); + } + if (ShapeUtil::IsTuple(expected)) { + if (ShapeUtil::TupleElementCount(expected) != + ShapeUtil::TupleElementCount(actual)) { + return InvalidArgument( + "want tuple element count: %lld got tuple element count: %lld", + ShapeUtil::TupleElementCount(expected), + ShapeUtil::TupleElementCount(actual)); + } + for (int i = 0; i < expected.tuple_shapes_size(); ++i) { + Status result = + EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + if (!result.ok()) { + return AppendStatus(result, StrCat("mismatch in tuple index", i)); + } + } + } else { + if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { + return InvalidArgument("want rank of %s got rank of %s", + ShapeUtil::HumanString(expected).c_str(), + ShapeUtil::HumanString(actual).c_str()); + } + if (expected.element_type() != actual.element_type()) { + return InvalidArgument( + "mismatch in primitive type %s vs %s", + PrimitiveType_Name(expected.element_type()).c_str(), + PrimitiveType_Name(actual.element_type()).c_str()); + } + if (expected.dimensions_size() != actual.dimensions_size()) { + return InvalidArgument("want dimensions_size %d got dimensions_size %d", + expected.dimensions_size(), + actual.dimensions_size()); + } + for (int i = 0; i < expected.dimensions_size(); ++i) { + if (expected.dimensions(i) != actual.dimensions(i)) { + return InvalidArgument( + "mismatch in dimension #%d expected: %s actual: %s", i, + ShapeUtil::HumanString(expected).c_str(), + ShapeUtil::HumanString(actual).c_str()); + } + } + } + return Status::OK(); +} + +Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { + VLOG(1) << "expected:"; + XLA_VLOG_LINES(1, expected.ToString()); + VLOG(1) << "actual:"; + XLA_VLOG_LINES(1, actual.ToString()); + + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + std::vector multi_index(expected.shape().dimensions_size(), 0); + Status result; + switch (expected.shape().element_type()) { + case PRED: + result = Equal(expected, actual, &multi_index, 0); + break; + case U8: + result = Equal(expected, actual, &multi_index, 0); + break; + case S32: + result = Equal(expected, actual, &multi_index, 0); + break; + case S64: + result = Equal(expected, actual, &multi_index, 0); + break; + case U32: + result = Equal(expected, actual, &multi_index, 0); + break; + case U64: + result = Equal(expected, actual, &multi_index, 0); + break; + case BF16: + result = Equal(expected, actual, &multi_index, 0); + break; + case F16: + result = Equal(expected, actual, &multi_index, 0); + break; + case F32: + result = Equal(expected, actual, &multi_index, 0); + break; + case F64: + result = Equal(expected, actual, &multi_index, 0); + break; + case C64: + result = Equal(expected, actual, &multi_index, 0); + break; + case TUPLE: { + for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + result.Update( + Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}))); + } + break; + } + default: + LOG(FATAL) + << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " + << PrimitiveType_Name(expected.shape().element_type()); + } + + if (result.ok()) { + return Status::OK(); + } + + return AppendStatus(result, + tensorflow::strings::Printf( + "\nat index: %s\nexpected: %s\nactual: %s", + Literal::MultiIndexAsString(multi_index).c_str(), + ToStringTruncated(expected).c_str(), + ToStringTruncated(actual).c_str())); +} + +Status Near(const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error, bool detailed_message, + const MiscompareCallback& miscompare_callback) { + return NearHelper(expected, actual, error, detailed_message, + miscompare_callback, + /*shape_index=*/{}); +} + +string ToStringTruncated(const LiteralSlice& literal) { + return RecursiveElementCount(literal.shape()) < 1000 + ? literal.ToString() + : "[TRUNCATED, Literal with more than 1000 values]"; +} + +} // namespace literal_comparison +} // namespace xla diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h new file mode 100644 index 0000000000000000000000000000000000000000..00a13e361932e74a9a1e614d5c851d3851208852 --- /dev/null +++ b/tensorflow/compiler/xla/literal_comparison.h @@ -0,0 +1,72 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Library for comparing literals without taking a dependency on testing +// libraries. + +#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ +#define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ + +#include "tensorflow/compiler/xla/error_spec.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { +namespace literal_comparison { + +// Returns ok if the given shapes have the same rank, dimension sizes, and +// primitive types. +Status EqualShapes(const Shape& expected, const Shape& actual); + +// Returns ok if the expected and actual literals are (bitwise) equal for all +// elements in the literal. Also, asserts that the rank, dimensions sizes, and +// primitive type are equal. +Status Equal(const LiteralSlice& expected, const LiteralSlice& actual); + +using MiscompareCallback = + std::function; + +// Inspects whether the expected and actual literals are within the given error +// bound for all elements. Also, inspects whether the rank, dimensions sizes, +// and dimension bounds are equivalent. +// +// Tuples are matched recursively. +// +// When comparing tensors of non-floating-point type, this inspects for exact +// equality, ignoring the ErrorSpec. +// +// If the shape of the literals is neither a complex/floating-point tensor nor a +// tuple which contains a complex/floating-point tensor, Near() is equivalent to +// Equal(). We don't raise an error in this case, because we want to allow +// callers to call Near() even if they have no preconceptions about the shapes +// being compared. +// +// If detailed_message is true, then the error message in the assertion result +// will contain a more detailed breakdown of mismatches. +Status Near(const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error, bool detailed_message, + const MiscompareCallback& miscompare_callback); + +// Calling ToString on a literal with over 100 million elements takes around +// 3 minutes. The utility of printing a literal with >1000 elements is +// questionable, especially when writing the Literal proto to disk is orders +// of magnitude faster. +string ToStringTruncated(const LiteralSlice& literal); + +} // namespace literal_comparison +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index e0a9b148b443e90a0c4f3e19660b6234d49eef84..6b295897004cebce003ddd3999aacf63915ffe5f 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -44,8 +45,16 @@ namespace { constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; -// Converts between little and big endian, assuming elements in the array are 16 -// bits long. +// Converts between little and big endian. +// +// Precondition: size % 2 == 0 (elements in the array are 16 bits long) +void ConvertEndianShort(string* bytes) { + CHECK_EQ(bytes->size() / 2, 0); + for (int64 i = 0; i < bytes->size(); i += 2) { + std::swap((*bytes)[i], (*bytes)[i + 1]); + } +} + void ConvertEndianShort(char* bytes, int64 size) { CHECK_EQ(size / 2, 0); for (int64 i = 0; i < size; i += 2) { @@ -53,8 +62,49 @@ void ConvertEndianShort(char* bytes, int64 size) { } } +// Return a literal with all arrays of type FromNativeT converted to type +// ToNativeT in the given literal. +template +std::unique_ptr ConvertType(LiteralSlice literal) { + // First construct shape of the result. + Shape result_shape(literal.shape()); + ShapeUtil::ForEachMutableSubshape( + &result_shape, [](Shape* subshape, const ShapeIndex&) { + if (subshape->element_type() == + primitive_util::NativeToPrimitiveType()) { + subshape->set_element_type( + primitive_util::NativeToPrimitiveType()); + } + }); + auto result = MakeUnique(result_shape); + + // Then copy over the data from 'literal' converting FromNativeT values to + // ToNativeT values as necessary. + ShapeUtil::ForEachSubshape( + literal.shape(), + [&](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsArray(subshape)) { + if (subshape.element_type() == + primitive_util::NativeToPrimitiveType()) { + auto src = literal.data(shape_index); + auto dest = result->data(shape_index); + for (int64 i = 0; i < src.size(); ++i) { + dest[i] = static_cast(src[i]); + } + } else { + TF_CHECK_OK(result->CopyFrom(literal, + /*dest_shape_index=*/shape_index, + /*src_shape_index=*/shape_index)); + } + } + }); + return result; +} + } // namespace +LiteralBase::~LiteralBase() {} + std::ostream& operator<<(std::ostream& out, const Literal& literal) { out << literal.ToString(); return out; @@ -86,92 +136,89 @@ Literal::StrideConfig::StrideConfig( Literal::Literal(const Shape& shape) : Literal(shape, /*allocate_arrays=*/true) {} -Literal::Literal(const Shape& shape, bool allocate_arrays) - : shape_(shape), pieces_(shape), owns_buffers_(true) { - CHECK(LayoutUtil::HasLayout(shape)); - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - const Shape& subshape = piece.subshape(); - if (ShapeUtil::IsArray(subshape)) { - if (allocate_arrays) { - piece.set_buffer(new char[piece.size_bytes()]); - if (LayoutUtil::IsSparseArray(subshape)) { - piece.set_sparse_indices(new SparseIndexArray( - LayoutUtil::MaxSparseElements(subshape.layout()), - ShapeUtil::Rank(subshape))); - } +void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + SetPiece(subshape, &child_piece, allocate_arrays); + + piece->emplace_back(std::move(child_piece)); + } + } else { + CHECK(ShapeUtil::IsArray(shape)); + if (allocate_arrays) { + if (LayoutUtil::IsSparseArray(shape)) { + // For sparse arrays, the buffer must be of the size of the maximum + // number of sparse elements possible. + const int64 max_sparse_elements = + LayoutUtil::MaxSparseElements(shape.layout()); + piece->set_buffer( + new char[max_sparse_elements * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); + piece->set_sparse_indices( + new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); } else { - piece.set_buffer(nullptr); + piece->set_buffer(new char[piece->size_bytes()]); } } } } -Literal::~Literal() { DeallocateBuffers(); } +Literal::Literal(const Shape& shape, bool allocate_arrays) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(LayoutUtil::HasLayout(*shape_)); + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + CHECK(&root_piece_->subshape() == shape_.get()); -void Literal::DeallocateBuffers() { - if (owns_buffers_) { - for (auto& pair : pieces_) { - Piece& piece = pair.second; - if (piece.buffer() != nullptr) { - delete[] piece.buffer(); - delete piece.sparse_indices(); - } - } - } + SetPiece(*shape_, root_piece_, allocate_arrays); } -Literal::Literal(Literal&& other) { - shape_ = std::move(other.shape_); - pieces_ = std::move(other.pieces_); - // We need to iterate through the pieces to set the subshape pointer - // properly. It must refer to subshapes within shape_. - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); +Literal::~Literal() { + if (root_piece_ != nullptr) { + DeallocateBuffers(); + delete root_piece_; } - owns_buffers_ = other.owns_buffers_; +} - other.shape_ = ShapeUtil::MakeNil(); - other.pieces_ = ShapeTree(other.shape_); - other.piece({}).set_subshape(&other.shape_); +void Literal::DeallocateBuffers() { + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete[] piece->buffer(); + delete piece->sparse_indices(); + } + }); } +Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } + Literal& Literal::operator=(Literal&& other) { - DeallocateBuffers(); - shape_ = std::move(other.shape_); - pieces_ = std::move(other.pieces_); - // We need to iterate through the pieces to set the subshape pointer - // properly. It must refer to subshapes within shape_. - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - } - owns_buffers_ = other.owns_buffers_; - - other.shape_ = ShapeUtil::MakeNil(); - other.pieces_ = ShapeTree(other.shape_); - other.piece({}).set_subshape(&other.shape_); + DCHECK(&other.root_piece_->subshape() == other.shape_.get()); + using std::swap; + swap(shape_, other.shape_); + swap(root_piece_, other.root_piece_); + DCHECK(&root_piece_->subshape() == shape_.get()); + return *this; } -std::unique_ptr Literal::CreateFromShape(const Shape& shape) { +std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { auto literal = MakeUnique(shape); - for (auto& pair : literal->pieces_) { - Piece& piece = pair.second; - if (ShapeUtil::IsArray(piece.subshape())) { - memset(piece.untyped_data(), 0, piece.size_bytes()); - } - } + literal->root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (ShapeUtil::IsArray(piece->subshape())) { + memset(piece->untyped_data(), 0, piece->size_bytes()); + } + }); return literal; } -const SparseIndexArray* Literal::sparse_indices( +const SparseIndexArray* LiteralBase::sparse_indices( const ShapeIndex& shape_index) const { return piece(shape_index).sparse_indices(); } @@ -186,9 +233,19 @@ SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); } +/* static */ std::unique_ptr Literal::ConvertBF16ToF32( + const LiteralSlice& bf16_literal) { + return ConvertType(bf16_literal); +} + +/* static */ std::unique_ptr Literal::ConvertF32ToBF16( + const LiteralSlice& f32_literal) { + return ConvertType(f32_literal); +} + template Status Literal::CopySliceFromInternal( - const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, + const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size) { TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); @@ -223,7 +280,7 @@ Status Literal::CopySliceFromInternal( Literal::StrideConfig stride_config(src_literal.shape(), shape(), copy_size); - auto copy_proc = [&](const std::vector& indexes) { + auto copy_proc = [&](tensorflow::gtl::ArraySlice indexes) { // Map from multi-dimensional index, to source index. std::transform(indexes.begin(), indexes.end(), src_base.begin(), src_indexes.begin(), std::plus()); @@ -248,6 +305,28 @@ Status Literal::CopySliceFromInternal( return Status::OK(); } +Status Literal::CopyElementFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index) { + DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); + const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( + src_literal.shape(), src_index); + const int64 dest_linear_index = + IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index); + const int64 primitive_size = + ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + + char* dest_address = + static_cast(untyped_data()) + dest_linear_index * primitive_size; + const char* source_address = + static_cast(src_literal.untyped_data()) + + src_linear_index * primitive_size; + if (dest_address != source_address) { + memcpy(dest_address, source_address, primitive_size); + } + return Status::OK(); +} + std::vector Literal::DecomposeTuple() { CHECK(ShapeUtil::IsTuple(shape())); std::vector elements; @@ -255,22 +334,21 @@ std::vector Literal::DecomposeTuple() { elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), /*allocate_arrays=*/false)); Literal& element = elements.back(); - for (auto& pair : element.pieces_) { - const ShapeIndex& index = pair.first; - Piece& dest_piece = pair.second; - ShapeIndex src_index = {i}; - for (int64 j : index) { - src_index.push_back(j); - } - Piece& src_piece = piece(src_index); - - // Move the respective buffer and sparse indices over to the element - // Literal. - dest_piece.set_buffer(src_piece.buffer()); - src_piece.set_buffer(nullptr); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - src_piece.set_sparse_indices(nullptr); - } + element.root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* dest_piece) { + ShapeIndex src_index = {i}; + for (int64 j : index) { + src_index.push_back(j); + } + Piece& src_piece = piece(src_index); + + // Move the respective buffer and sparse indices over to the element + // Literal. + dest_piece->set_buffer(src_piece.buffer()); + src_piece.set_buffer(nullptr); + dest_piece->set_sparse_indices(src_piece.sparse_indices()); + src_piece.set_sparse_indices(nullptr); + }); } // Set this literal to be nil-shaped. *this = Literal(); @@ -313,7 +391,9 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, } // namespace -Status Literal::Piece::CopyFrom(const Literal::Piece& src) { +Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { + CHECK(subshape_ != nullptr); + CHECK(src.subshape_ != nullptr); if (ShapeUtil::Equal(subshape(), src.subshape())) { // If the layouts are equal it's faster just to memcpy. memcpy(buffer(), src.buffer(), src.size_bytes()); @@ -343,14 +423,14 @@ Status Literal::Piece::CopyFrom(const Literal::Piece& src) { #undef COPY_ELEMENTS default: return Unimplemented( - "Unhandled primitive type %s", + "Copying a Literal object with element type %s is not implemented.", PrimitiveType_Name(subshape().element_type()).c_str()); } } return Status::OK(); } -Status Literal::CopyFrom(const Literal& src_literal, +Status Literal::CopyFrom(const LiteralSlice& src_literal, const ShapeIndex& dest_shape_index, const ShapeIndex& src_shape_index) { const Shape& dest_subshape = @@ -363,36 +443,32 @@ Status Literal::CopyFrom(const Literal& src_literal, ShapeUtil::HumanString(dest_subshape).c_str(), ShapeUtil::HumanString(src_subshape).c_str()); } + return root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + if (!ShapeUtil::IsArray(piece->subshape())) { + return Status::OK(); + } - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - - // Determine if this index is in the part of this literal that we want to - // copy over from src_literal. - bool in_subtree_to_copy = true; - for (int i = 0; i < dest_shape_index.size(); ++i) { - if (index[i] != dest_shape_index[i]) { - in_subtree_to_copy = false; - break; - } - } - if (!in_subtree_to_copy) { - continue; - } - - // Construct the index of the corresponding piece in the source literal. - ShapeIndex src_piece_index = src_shape_index; - for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { - src_piece_index.push_back(index[i]); - } - - TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index))); - } - return Status::OK(); + // Determine if this index is in the part of this literal that we want + // to copy over from src_literal. + bool in_subtree_to_copy = true; + for (int i = 0; i < dest_shape_index.size(); ++i) { + if (index[i] != dest_shape_index[i]) { + in_subtree_to_copy = false; + break; + } + } + if (!in_subtree_to_copy) { + return Status::OK(); + } + // Construct the index of the corresponding piece in the source literal. + ShapeIndex src_piece_index = src_shape_index; + for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { + src_piece_index.push_back(index[i]); + } + TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index))); + return Status::OK(); + }); } Status Literal::MoveFrom(Literal&& src_literal, @@ -406,37 +482,32 @@ Status Literal::MoveFrom(Literal&& src_literal, ShapeUtil::HumanString(src_literal.shape()).c_str()); } - if (!(owns_buffers_ && src_literal.owns_buffers_)) { - return InvalidArgument( - "Source and destination literals must both own their buffers (ie, not " - "be views)"); - } + src_literal.root_piece_->ForEachSubpiece( + [&](const ShapeIndex& src_index, const Piece& src_piece) { + if (!ShapeUtil::IsArray(src_piece.subshape())) { + return; + } - for (auto& pair : src_literal.pieces_) { - const ShapeIndex& src_index = pair.first; - Piece& src_piece = pair.second; - if (!ShapeUtil::IsArray(src_piece.subshape())) { - continue; - } + ShapeIndex dest_index = dest_shape_index; + for (int64 i : src_index) { + dest_index.push_back(i); + } + Piece& dest_piece = piece(dest_index); + delete[] dest_piece.buffer(); + dest_piece.set_buffer(src_piece.buffer()); + delete dest_piece.sparse_indices(); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + }); - ShapeIndex dest_index = dest_shape_index; - for (int64 i : src_index) { - dest_index.push_back(i); - } - Piece& dest_piece = piece(dest_index); - delete[] dest_piece.buffer(); - dest_piece.set_buffer(src_piece.buffer()); - delete dest_piece.sparse_indices(); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - } + src_literal.shape_ = MakeUnique(ShapeUtil::MakeNil()); + delete src_literal.root_piece_; + src_literal.root_piece_ = new LiteralBase::Piece(); + src_literal.root_piece_->set_subshape(src_literal.shape_.get()); - src_literal.shape_ = ShapeUtil::MakeNil(); - src_literal.pieces_ = ShapeTree(src_literal.shape_); - src_literal.piece({}).set_subshape(&src_literal.shape_); return Status::OK(); } -Status Literal::CopySliceFrom(const Literal& src_literal, +Status Literal::CopySliceFrom(const LiteralSlice& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size) { @@ -491,7 +562,10 @@ Status Literal::CopySliceFrom(const Literal& src_literal, default: break; } - return Unimplemented("Unhandled primitive type %d", shape().element_type()); + return Unimplemented( + "Copying a slice from a Literal object with element type %d is not " + "implemented.", + shape().element_type()); } /* static */ Literal Literal::Zero(PrimitiveType primitive_type) { @@ -702,7 +776,7 @@ void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { return CreateR2FromArray2D(*value); } -std::unique_ptr Literal::Relayout( +std::unique_ptr LiteralBase::Relayout( const Layout& new_layout, const ShapeIndex& shape_index) const { // Create new shape with 'new_layout' set at the given shape index. Shape new_shape = shape(); @@ -714,7 +788,7 @@ std::unique_ptr Literal::Relayout( return result; } -std::unique_ptr Literal::Relayout( +std::unique_ptr LiteralBase::Relayout( const Shape& shape_with_layout) const { CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) @@ -733,7 +807,48 @@ std::unique_ptr Literal::Relayout( return result; } -StatusOr> Literal::Reshape( +StatusOr> LiteralBase::Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice dimensions) const { + if (!ShapeUtil::IsArray(shape())) { + return InvalidArgument("Broadcast only supports arrays."); + } + + for (int64 i = 0; i < dimensions.size(); i++) { + TF_RET_CHECK(shape().dimensions(i) == + result_shape.dimensions(dimensions[i])); + } + + std::unique_ptr result = MakeUnique(result_shape); + + // scratch_source_index is temporary storage space for the computed index into + // the input literal. We put it here to avoid allocating an std::vector in + // every iteration of ShapeUtil::ForEachIndex. + std::vector scratch_source_index(shape().dimensions_size()); + + char* dest_data = static_cast(result->untyped_data()); + const char* source_data = static_cast(untyped_data()); + const int64 primitive_size = + ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + + ShapeUtil::ForEachIndex( + result_shape, [&](tensorflow::gtl::ArraySlice output_index) { + for (int64 i = 0; i < dimensions.size(); ++i) { + scratch_source_index[i] = output_index[dimensions[i]]; + } + int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex( + result_shape, output_index); + int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex( + shape(), scratch_source_index); + memcpy(dest_data + primitive_size * dest_index, + source_data + primitive_size * source_index, primitive_size); + return true; + }); + + return std::move(result); +} + +StatusOr> LiteralBase::Reshape( tensorflow::gtl::ArraySlice dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); @@ -747,7 +862,8 @@ StatusOr> Literal::Reshape( } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. - output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions); + *output->mutable_shape_do_not_use() = + ShapeUtil::MakeShape(shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(shape()); int64 elements_after = ShapeUtil::ElementsIn(output->shape()); @@ -761,7 +877,79 @@ StatusOr> Literal::Reshape( return std::move(output); } -std::unique_ptr Literal::Transpose( +/* static */ std::unique_ptr Literal::ReshapeSlice( + tensorflow::gtl::ArraySlice new_dimensions, + tensorflow::gtl::ArraySlice minor_to_major, + const LiteralSlice& literal) { + int64 new_num_elements = 1; + for (int64 i = 0; i < new_dimensions.size(); ++i) { + new_num_elements *= new_dimensions[i]; + } + CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); + CHECK_EQ(new_dimensions.size(), minor_to_major.size()); + + auto new_literal = MakeUnique( + ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); + + // Create a new shape with the given minor-to-major layout. This shape is used + // solely for converting linear address to multi-dimensional addresses when + // writing elements to the new literal. + Shape shape_with_layout = new_literal->shape(); + *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); + + // Copy data into new literal, element-by-element. + for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { + std::vector from_multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); + std::vector to_multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); + switch (literal.shape().element_type()) { + case PRED: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case U8: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case U32: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case S32: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case U64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case S64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case F32: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case F64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case C64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + default: + LOG(FATAL) << "Unhandled primitive element type: " + << PrimitiveType_Name(literal.shape().element_type()); + } + } + + return new_literal; +} + +std::unique_ptr LiteralBase::Transpose( tensorflow::gtl::ArraySlice permutation) const { CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) @@ -792,15 +980,31 @@ std::unique_ptr Literal::Transpose( for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } - std::unique_ptr new_literal = CreateFromShape(permuted_shape); - DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()), + auto new_literal = MakeUnique(permuted_shape); + DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->root_piece().buffer(), root_piece().buffer(), - root_piece().size_bytes()); + std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); return new_literal; } -std::unique_ptr Literal::Slice( +template +std::unique_ptr LiteralBase::SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const { + auto result_literal = MakeUnique(result_shape); + DimensionVector new_indices(ShapeUtil::Rank(result_shape)); + result_literal->EachCell( + [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { + for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + new_indices[i] = indices[i] + start_indices[i]; + } + NativeT value = Get(new_indices); + result_literal->Set(indices, value); + }); + return result_literal; +} + +std::unique_ptr LiteralBase::Slice( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices) const { CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; @@ -808,79 +1012,46 @@ std::unique_ptr Literal::Slice( DimensionVector result_dimensions; for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { CHECK_GE(start_indices[dnum], 0); - CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)); + CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)) + << "dnum = " << dnum; int64 dimension = limit_indices[dnum] - start_indices[dnum]; - CHECK_GT(dimension, 0); + CHECK_GE(dimension, 0) << "dnum = " << dnum; result_dimensions.push_back(dimension); } const auto result_shape = ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, LayoutUtil::MinorToMajor(shape())); - - auto result_literal = MakeUnique(result_shape); - - DimensionVector new_indices(ShapeUtil::Rank(result_shape)); switch (result_shape.element_type()) { case F32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, float /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - float value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); + case BF16: + return SliceInternal(result_shape, start_indices); case C64: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, complex64 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - complex64 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); case S32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, int32 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - int32 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); case U32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, uint32 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - uint32 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); default: LOG(FATAL) << "not yet implemented: " << PrimitiveType_Name(result_shape.element_type()); } } -Literal Literal::Clone() const { +Literal LiteralBase::Clone() const { Literal result(shape()); TF_CHECK_OK(result.CopyFrom(*this)); return result; } -std::unique_ptr Literal::CloneToUnique() const { +std::unique_ptr LiteralBase::CloneToUnique() const { auto result = MakeUnique(shape()); TF_CHECK_OK(result->CopyFrom(*this)); return result; } -string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { +string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsDenseArray(subshape)); switch (subshape.element_type()) { @@ -903,7 +1074,7 @@ string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, case U64: return StrCat(Get(multi_index, shape_index)); case F16: - return StrCat(Get(multi_index, shape_index)); + return StrCat(static_cast(Get(multi_index, shape_index))); case F32: return StrCat(Get(multi_index, shape_index)); case BF16: @@ -920,8 +1091,8 @@ string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, } } -string Literal::GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index) const { +string LiteralBase::GetSparseElementAsString( + int64 sparse_element_number, const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsSparseArray(subshape)); switch (subshape.element_type()) { @@ -953,7 +1124,8 @@ string Literal::GetSparseElementAsString(int64 sparse_element_number, return StrCat( GetSparseElement(sparse_element_number, shape_index)); case F16: - return StrCat(GetSparseElement(sparse_element_number, shape_index)); + return StrCat(static_cast( + GetSparseElement(sparse_element_number, shape_index))); case F32: return StrCat( GetSparseElement(sparse_element_number, shape_index)); @@ -974,7 +1146,7 @@ string Literal::GetSparseElementAsString(int64 sparse_element_number, } } -StatusOr Literal::GetIntegralAsS64( +StatusOr LiteralBase::GetIntegralAsS64( tensorflow::gtl::ArraySlice multi_index) const { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { @@ -997,7 +1169,58 @@ StatusOr Literal::GetIntegralAsS64( } } -tensorflow::gtl::ArraySlice Literal::GetSparseIndex( +size_t LiteralBase::Hash() const { + using tensorflow::Hash64; + using tensorflow::Hash64Combine; + + size_t hash_value = ShapeUtil::Hash(shape()); + + ShapeUtil::ForEachSubshape( + shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (ShapeUtil::IsTuple(subshape)) { + return; + } + + CHECK(LayoutUtil::IsDense(subshape.layout())); + hash_value = Hash64Combine( + hash_value, Hash64(static_cast(untyped_data(index)), + size_bytes(index))); + }); + + return hash_value; +} + +Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, + int64 value) { + CHECK(LayoutUtil::IsDenseArray(shape())); + switch (shape().element_type()) { + case PRED: + Set(multi_index, value); + break; + case U8: + Set(multi_index, value); + break; + case S32: + Set(multi_index, value); + break; + case S64: + Set(multi_index, value); + break; + case U32: + Set(multi_index, value); + break; + case U64: + Set(multi_index, value); + break; + default: + return FailedPrecondition( + "Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type()).c_str()); + } + return Status::OK(); +} + +tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index) const { const Piece& p = piece(shape_index); CHECK_GE(sparse_element_number, 0); @@ -1009,7 +1232,50 @@ void Literal::SortSparseElements(const ShapeIndex& shape_index) { piece(shape_index).SortSparseElements(); } -void Literal::Piece::SortSparseElements() { +Literal LiteralBase::GetFirstScalarLiteral() const { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_GT(ShapeUtil::ElementsIn(shape()), 0); + switch (shape().element_type()) { + case PRED: + return std::move(*Literal::CreateR0(GetFirstElement())); + // 8 bit types. + case S8: + return std::move(*Literal::CreateR0(GetFirstElement())); + case U8: + return std::move(*Literal::CreateR0(GetFirstElement())); + // 16 bit types. + case BF16: + return std::move( + *Literal::CreateR0(GetFirstElement())); + case F16: + return std::move(*Literal::CreateR0(GetFirstElement())); + case S16: + return std::move(*Literal::CreateR0(GetFirstElement())); + case U16: + return std::move(*Literal::CreateR0(GetFirstElement())); + // 32 bit types. + case F32: + return std::move(*Literal::CreateR0(GetFirstElement())); + case S32: + return std::move(*Literal::CreateR0(GetFirstElement())); + case U32: + return std::move(*Literal::CreateR0(GetFirstElement())); + // 64 bit types. + case C64: + return std::move( + *Literal::CreateR0(GetFirstElement())); + case F64: + return std::move(*Literal::CreateR0(GetFirstElement())); + case S64: + return std::move(*Literal::CreateR0(GetFirstElement())); + case U64: + return std::move(*Literal::CreateR0(GetFirstElement())); + default: + LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); + } +} + +void LiteralBase::Piece::SortSparseElements() { switch (subshape().element_type()) { case PRED: SortSparseElementsInternal(); @@ -1060,7 +1326,7 @@ void Literal::Piece::SortSparseElements() { } template -void Literal::Piece::SortSparseElementsInternal() { +void LiteralBase::Piece::SortSparseElementsInternal() { CHECK(LayoutUtil::IsSparseArray(subshape())); int64 num_elements = sparse_indices()->index_count(); auto values = data(); @@ -1071,9 +1337,11 @@ void Literal::Piece::SortSparseElementsInternal() { namespace { -void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + CHECK(LayoutUtil::HasLayout(literal.shape())); + CHECK(LayoutUtil::HasLayout(subshape)); auto shape_to_string = [print_layout](const Shape& shape) { if (print_layout) { @@ -1232,13 +1500,14 @@ void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, } // namespace -int64 Literal::sparse_element_count() const { +int64 LiteralBase::sparse_element_count() const { CHECK(LayoutUtil::IsSparseArray(shape())); return sparse_indices()->index_count(); } -string Literal::ToString(bool print_layout) const { +string LiteralBase::ToString(bool print_layout) const { std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); ToStringHelper(*this, {}, print_layout, &pieces); return tensorflow::str_util::Join(pieces, ""); } @@ -1246,7 +1515,7 @@ string Literal::ToString(bool print_layout) const { /* static */ std::unique_ptr Literal::MakeTuple( tensorflow::gtl::ArraySlice elements) { std::vector element_shapes; - for (const Literal* element : elements) { + for (const auto* element : elements) { element_shapes.push_back(element->shape()); } auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); @@ -1256,6 +1525,19 @@ string Literal::ToString(bool print_layout) const { return literal; } +/* static */ std::unique_ptr Literal::MakeTupleFromSlices( + tensorflow::gtl::ArraySlice elements) { + std::vector element_shapes; + for (const auto& element : elements) { + element_shapes.push_back(element.shape()); + } + auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + for (int i = 0; i < elements.size(); ++i) { + TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i})); + } + return literal; +} + /* static */ std::unique_ptr Literal::MakeTupleOwned( std::vector> elements) { std::vector element_shapes; @@ -1271,7 +1553,7 @@ string Literal::ToString(bool print_layout) const { return literal; } -void Literal::EachCellAsString( +void LiteralBase::EachCellAsString( const std::function indices, const string& value)>& per_cell) const { if (ShapeUtil::HasZeroElements(shape())) { @@ -1285,8 +1567,9 @@ void Literal::EachCellAsString( } namespace { -template -std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { +template +std::unique_ptr ConvertBetweenNativeTypesWithConverter( + const LiteralBase& src_literal, const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( src_literal.shape(), @@ -1296,13 +1579,43 @@ std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = static_cast(src_data[i]); + dest_data[i] = converter(src_data[i]); } return result_literal; } +template +std::unique_ptr ConvertBetweenNativeTypes( + const LiteralBase& src_literal) { + auto converter = [](NativeSrcT src) { return static_cast(src); }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + +template +typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), + std::unique_ptr>::type +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { + auto converter = [](NativeSrcT src) { + return tensorflow::bit_cast(src); + }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + +// This template specialization is here to make the compiler happy. bit_cast has +// a static check that the types are the same size. This specialization should +// never be used because the source and destination types are checked for +// identical sizes higher up. +template +typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), + std::unique_ptr>::type +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { + LOG(FATAL) << "Invalid bitcast between types of different sizes."; +} + template -std::unique_ptr ConvertToC64(const Literal& src_literal) { +std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); @@ -1320,21 +1633,33 @@ std::unique_ptr ConvertToC64(const Literal& src_literal) { } template -std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal) { +std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, + bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); - return ConvertBetweenNativeTypes< - typename primitive_util::PrimitiveTypeToNative::type, - typename primitive_util::PrimitiveTypeToNative< - primitive_dest_type>::type>(src_literal); + if (bitcast) { + return BitcastBetweenNativeTypes< + typename primitive_util::PrimitiveTypeToNative< + primitive_src_type>::type, + typename primitive_util::PrimitiveTypeToNative< + primitive_dest_type>::type>(src_literal); + } else { + return ConvertBetweenNativeTypes< + typename primitive_util::PrimitiveTypeToNative< + primitive_src_type>::type, + typename primitive_util::PrimitiveTypeToNative< + primitive_dest_type>::type>(src_literal); + } } template StatusOr> ConvertIfDestTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { + const LiteralBase& src_literal, PrimitiveType primitive_dest_type, + bool bitcast) { switch (primitive_dest_type) { -#define CONVERT_IF_TYPES_MATCH(type) \ - case (type): \ - return ConvertIfTypesMatch(src_literal); +#define CONVERT_IF_TYPES_MATCH(type) \ + case (type): \ + return ConvertIfTypesMatch(src_literal, \ + bitcast); CONVERT_IF_TYPES_MATCH(PRED) CONVERT_IF_TYPES_MATCH(S8) CONVERT_IF_TYPES_MATCH(S32) @@ -1348,25 +1673,32 @@ StatusOr> ConvertIfDestTypeMatches( CONVERT_IF_TYPES_MATCH(BF16) #undef CONVERT_IF_TYPES_MATCH case C64: - return ConvertToC64(src_literal); + if (!bitcast) { + return ConvertToC64(src_literal); + } + break; // Other types are not yet supported. default: - return InvalidArgument( - "Unimplemented: Convert from type %s to type %s", - PrimitiveType_Name(src_literal.shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); + break; } + return Unimplemented( + "Converting from type %s to type %s is not implemented.", + PrimitiveType_Name(src_literal.shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); } -} // namespace - -StatusOr> Literal::Convert( - PrimitiveType primitive_dest_type) const { - TF_RET_CHECK(ShapeUtil::IsArray(shape())); - switch (shape().element_type()) { -#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ - case (type): \ - return ConvertIfDestTypeMatches<(type)>(*this, primitive_dest_type); +StatusOr> ConvertSwitch( + const LiteralBase& literal, PrimitiveType primitive_dest_type, + bool bitcast) { + TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); + if (literal.shape().element_type() == primitive_dest_type) { + return literal.CloneToUnique(); + } + switch (literal.shape().element_type()) { +#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ + case (type): \ + return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \ + bitcast); CONVERT_IF_DEST_TYPE_MATCHES(PRED) CONVERT_IF_DEST_TYPE_MATCHES(S8) CONVERT_IF_DEST_TYPE_MATCHES(S32) @@ -1381,15 +1713,65 @@ StatusOr> Literal::Convert( #undef CONVERT_IF_DEST_TYPE_MATCHES // Other types are not yet supported. default: - return InvalidArgument("Unimplemented: Convert from type %s to type %s", - PrimitiveType_Name(shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); + return Unimplemented( + "%s from type %s to type %s is not implemented.", + (bitcast ? "Bitcast converting" : "Converting"), + PrimitiveType_Name(literal.shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); } } +} // namespace + +StatusOr> LiteralBase::Convert( + PrimitiveType primitive_dest_type) const { + return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); +} + +StatusOr> LiteralBase::BitcastConvert( + PrimitiveType primitive_dest_type) const { + if (primitive_util::BitWidth(shape().element_type()) != + primitive_util::BitWidth(primitive_dest_type)) { + return InvalidArgument( + "Cannot bitcast convert from %s to %s, bit widths are different: %d != " + "%d", + PrimitiveType_Name(shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str(), + primitive_util::BitWidth(shape().element_type()), + primitive_util::BitWidth(primitive_dest_type)); + } + return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); +} + +StatusOr> LiteralBase::ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16) const { + if (!ShapeUtil::IsTuple(dest_shape)) { + if (round_f32_to_bf16 && shape().element_type() == F32 && + dest_shape.element_type() == BF16) { + auto converter = [](float src) { + return tensorflow::bfloat16::round_to_bfloat16(src); + }; + return ConvertBetweenNativeTypesWithConverter(*this, + converter); + } + return Convert(dest_shape.element_type()); + } + std::vector elements; + for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { + auto element = LiteralSlice(*this, {i}); + TF_ASSIGN_OR_RETURN( + auto new_element, + element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); + elements.push_back(std::move(*new_element)); + } + auto converted = MakeUnique(); + *converted = Literal::MoveIntoTuple(&elements); + return std::move(converted); +} + template -bool Literal::Piece::EqualElementsInternal( - const Literal::Piece& other, std::vector* multi_index) const { +bool LiteralBase::Piece::EqualElementsInternal( + const LiteralBase::Piece& other, std::vector* multi_index) const { if (multi_index->size() == ShapeUtil::Rank(subshape())) { return (Get(*multi_index) == other.Get(*multi_index)); } @@ -1403,7 +1785,7 @@ bool Literal::Piece::EqualElementsInternal( return true; } -bool Literal::Piece::EqualElements(const Literal::Piece& other) const { +bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); std::vector multi_index; @@ -1431,28 +1813,28 @@ bool Literal::Piece::EqualElements(const Literal::Piece& other) const { case C64: return EqualElementsInternal(other, &multi_index); default: - LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type " + LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " << PrimitiveType_Name(subshape().element_type()); } } -bool Literal::operator==(const Literal& other) const { +bool LiteralBase::operator==(const LiteralBase& other) const { if (!ShapeUtil::Compatible(shape(), other.shape())) { return false; } - for (const auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - const Piece& other_piece = other.piece(index); - if (!piece.EqualElements(other_piece)) { - return false; - } - } - return true; + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + const Piece& other_piece = other.piece(index); + if (!piece.EqualElements(other_piece)) { + return false; + } + return true; + }); } namespace { @@ -1470,11 +1852,11 @@ static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, } // namespace -bool Literal::IsAll(int8 value) const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; +bool LiteralBase::IsAll(int8 value) const { + return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index, + const Piece& piece) { if (!ShapeUtil::IsArray(piece.subshape())) { - continue; + return true; } auto piece_is_all = [&]() { @@ -1527,41 +1909,41 @@ bool Literal::IsAll(int8 value) const { if (!piece_is_all()) { return false; } - } - return true; + return true; + }); } -bool Literal::IsAllFloat(float value) const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } +bool LiteralBase::IsAllFloat(float value) const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } - auto piece_is_all = [&]() { - switch (shape().element_type()) { - case F32: - return AllElementsEqualValue(piece.data(), value); - case F64: - return AllElementsEqualValue(piece.data(), value); - case F16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case BF16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - default: + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue( + piece.data(), static_cast(value)); + default: + return false; + } + }; + if (!piece_is_all()) { return false; - } - }; - if (!piece_is_all()) { - return false; - } - } - return true; + } + return true; + }); } -bool Literal::IsAllComplex(complex64 value) const { +bool LiteralBase::IsAllComplex(complex64 value) const { switch (shape().element_type()) { case C64: return AllElementsEqualValue(root_piece().data(), @@ -1571,7 +1953,93 @@ bool Literal::IsAllComplex(complex64 value) const { } } -bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { +bool LiteralBase::IsAllFirst() const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + // Empty shapes are not all the first element since there is no first + // element. + if (ShapeUtil::HasZeroElements(piece.subshape())) { + return false; + } + auto piece_is_all = [&]() { + switch (piece.subshape().element_type()) { + case PRED: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 8 bit types + case S8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 16 bit types + case BF16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 32 bit types + case F32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 64 bit types + case C64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + default: + return false; + } + }; + + if (!piece_is_all()) { + return false; + } + return true; + }); +} + +bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice indices) const { CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { case U8: @@ -1613,7 +2081,7 @@ void CopyToRepeatedField(RepeatedFieldT* dest, } // namespace -void Literal::Piece::WriteToProto(LiteralProto* proto) const { +void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { *proto->mutable_shape() = subshape(); switch (subshape().element_type()) { case PRED: @@ -1639,16 +2107,14 @@ void Literal::Piece::WriteToProto(LiteralProto* proto) const { *proto->mutable_f16s() = string( reinterpret_cast(data().data()), size_bytes()); if (!kLittleEndian) { - ConvertEndianShort(const_cast(proto->mutable_f16s()->data()), - proto->f16s().size()); + ConvertEndianShort(proto->mutable_f16s()); } break; case BF16: *proto->mutable_bf16s() = string( reinterpret_cast(data().data()), size_bytes()); if (!kLittleEndian) { - ConvertEndianShort(const_cast(proto->mutable_bf16s()->data()), - proto->bf16s().size()); + ConvertEndianShort(proto->mutable_bf16s()); } break; case F32: @@ -1671,12 +2137,12 @@ void Literal::Piece::WriteToProto(LiteralProto* proto) const { } } -const void* Literal::Piece::untyped_data() const { +const void* LiteralBase::Piece::untyped_data() const { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); return buffer(); } -void* Literal::Piece::untyped_data() { +void* LiteralBase::Piece::untyped_data() { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); return buffer(); } @@ -1697,7 +2163,7 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, } // namespace -Status Literal::Piece::CopyFromProto(const LiteralProto& proto) { +Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { // These conditions should have been checked in Literal::CreateFromProto. TF_RET_CHECK(proto.has_shape()); TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); @@ -1764,21 +2230,19 @@ Status Literal::Piece::CopyFromProto(const LiteralProto& proto) { return Status::OK(); } -LiteralProto Literal::ToProto() const { +LiteralProto LiteralBase::ToProto() const { LiteralProto proto; - for (const auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - const Piece& piece = pair.second; - - LiteralProto* proto_piece = &proto; - for (int64 i : index) { - while (proto_piece->tuple_literals_size() <= i) { - proto_piece->add_tuple_literals(); - } - proto_piece = proto_piece->mutable_tuple_literals(i); - } - piece.WriteToProto(proto_piece); - } + root_piece().ForEachSubpiece( + [&](const ShapeIndex& index, const Piece& piece) { + LiteralProto* proto_piece = &proto; + for (int64 i : index) { + while (proto_piece->tuple_literals_size() <= i) { + proto_piece->add_tuple_literals(); + } + proto_piece = proto_piece->mutable_tuple_literals(i); + } + piece.WriteToProto(proto_piece); + }); if (LayoutUtil::IsSparseArray(shape())) { CopyToRepeatedField(proto.mutable_sparse_indices(), @@ -1800,33 +2264,40 @@ StatusOr> Literal::CreateFromProto( auto literal = MakeUnique(proto.shape()); - for (auto& pair : literal->pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - const LiteralProto* proto_element = &proto; - for (int64 i : index) { - TF_RET_CHECK(i < proto_element->tuple_literals_size()); - proto_element = &proto_element->tuple_literals(i); - } + TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + const LiteralProto* proto_element = &proto; + for (int64 i : index) { + CHECK(i < proto_element->tuple_literals_size()); + proto_element = &proto_element->tuple_literals(i); + } - if (ShapeUtil::IsTuple(piece.subshape())) { - if (proto_element->tuple_literals_size() != - ShapeUtil::TupleElementCount(piece.subshape())) { - return InvalidArgument( - "Expected %lld tuple elements in LiteralProto, has %d", - ShapeUtil::TupleElementCount(piece.subshape()), - proto_element->tuple_literals_size()); - } - continue; - } + if (ShapeUtil::IsTuple(piece->subshape())) { + if (proto_element->tuple_literals_size() != + ShapeUtil::TupleElementCount(piece->subshape())) { + return InvalidArgument( + "Expected %lld tuple elements in LiteralProto, has %d", + ShapeUtil::TupleElementCount(piece->subshape()), + proto_element->tuple_literals_size()); + } + return Status::OK(); + } + + CHECK(ShapeUtil::IsArray(piece->subshape())); + TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); + + return Status::OK(); + })); - TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape())); - TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element)); - } return std::move(literal); } -const void* Literal::untyped_data(const ShapeIndex& shape_index) const { +/* static */ string Literal::MultiIndexAsString( + tensorflow::gtl::ArraySlice multi_index) { + return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}"); +} + +const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { return piece(shape_index).untyped_data(); } @@ -1834,11 +2305,11 @@ void* Literal::untyped_data(const ShapeIndex& shape_index) { return piece(shape_index).untyped_data(); } -int64 Literal::size_bytes(const ShapeIndex& shape_index) const { +int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { return piece(shape_index).size_bytes(); } -string Literal::GetR1U8AsString() const { +string LiteralBase::GetR1U8AsString() const { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(shape().element_type(), U8); @@ -1846,51 +2317,55 @@ string Literal::GetR1U8AsString() const { ShapeUtil::ElementsIn(shape())); } -/* static */ const LiteralView LiteralView::Create( - const Literal& literal, const ShapeIndex& view_root) { - return LiteralView(literal, view_root); -} +void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { + CHECK(ShapeUtil::IsTuple(shape)); + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); -LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) { - shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root); - pieces_ = ShapeTree(shape_); - owns_buffers_ = false; - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); - ShapeIndex src_index = view_root; - for (int64 i : index) { - src_index.push_back(i); + if (ShapeUtil::IsTuple(subshape)) { + BuildPieceSubtree(subshape, &child_piece); } - const Piece& src_piece = literal.piece(src_index); - piece.set_buffer(src_piece.buffer()); - piece.set_sparse_indices(src_piece.sparse_indices()); - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); + + piece->emplace_back(std::move(child_piece)); } } -LiteralView::~LiteralView() {} +LiteralSlice::LiteralSlice(const LiteralBase& literal) + : LiteralBase(), root_piece_(&literal.root_piece()) {} -LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); } +LiteralSlice::LiteralSlice(const LiteralBase& literal, + const ShapeIndex& view_root) + : LiteralBase(), root_piece_(&literal.piece(view_root)) {} -LiteralView& LiteralView::operator=(const LiteralView& other) { - CopyFrom(other); - return *this; +BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsArray(*shape_)); + CHECK_NE(src_buf_ptr, nullptr); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = Piece(); + root_piece_.set_buffer(const_cast(src_buf_ptr)); + root_piece_.set_subshape(shape_.get()); } -void LiteralView::CopyFrom(const LiteralView& other) { - // We can't use the default copy-constructor/copy-assignment because - // Piece::subshape_ points to subshapes within the Shape of the owning - // Literal/LiteralView. - shape_ = other.shape(); - pieces_ = other.pieces_; - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); +BorrowingLiteral::BorrowingLiteral( + tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsTuple(*shape_)); + CHECK(!ShapeUtil::IsNestedTuple(*shape_)); + CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); + root_piece_ = Piece(); + root_piece_.set_subshape(shape_.get()); + BuildPieceSubtree(*shape_, &root_piece_); + + for (int i = 0; i < src_buf_ptrs.size(); ++i) { + const auto& src_shape = shape_->tuple_shapes(i); + CHECK(ShapeUtil::IsArray(src_shape)); + root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); } - owns_buffers_ = false; } } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index d996004888ab521790b4c5a10da2a93f0d98d12f..8e4159e360e042beb31a75c432a3c7dfa7356007 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -52,14 +51,509 @@ limitations under the License. namespace xla { +// Forward declare Literal and LiteralSlice class to be used by the creation +// methods in the base class. +class Literal; +class LiteralSlice; + +// Abstract base class for literals. +class LiteralBase { + public: + virtual ~LiteralBase() = 0; + + // Literals are equal if they have compatible shapes and the same data + // values. Layout is not compared. + bool operator==(const LiteralBase& other) const; + bool operator!=(const LiteralBase& other) const { return !(*this == other); } + + // Returns the shape of the literal. + const Shape& shape() const { return root_piece().subshape(); } + + // Serialize to proto. + LiteralProto ToProto() const; + + // Returns an ArraySlice of the array for this literal for the given NativeT + // (e.g., float). CHECKs if the subshape of the literal at the given + // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type + // to native type. + template + tensorflow::gtl::ArraySlice data( + const ShapeIndex& shape_index = {}) const; + + // Returns a const pointer to the sparse index array. Returns nullptr if the + // literal is not a sparse array. + const SparseIndexArray* sparse_indices( + const ShapeIndex& shape_index = {}) const; + + // Returns a const pointer to (or size of) the underlying buffer holding the + // array at the given shape index. CHECKs if the subshape of the literal at + // the given ShapeIndex is not array. + const void* untyped_data(const ShapeIndex& shape_index = {}) const; + int64 size_bytes(const ShapeIndex& shape_index = {}) const; + + // Returns this literal's data as a string. This literal must be a rank-1 U8 + // array. + string GetR1U8AsString() const; + + // Returns a string representation of the literal value. + // Warning: this function can take minutes for multi-million element Literals. + string ToString(bool print_layout = false) const; + + // Gets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const; + // Overloads of Get for array literals. CHECKs if the literal is not + // array-shaped and dense. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; + + // Returns the element value at index (0, ..., 0), however many zeroes are + // required for that index. + template + NativeT GetFirstElement() const; + + // As Get(), but determines the correct type and converts the value + // into text. + string GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index = {}) const; + // As GetSparseElement(), but determines the correct type and converts the + // value into text. + string GetSparseElementAsString(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + // As Get(), but determines the correct type and converts the value into + // int64. This literal must be an array. + StatusOr GetIntegralAsS64( + tensorflow::gtl::ArraySlice multi_index) const; + + // Returns the multi-index of the element in a sparse literal at the given + // sparse element number. The sparse element number is the position with in + // the sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + tensorflow::gtl::ArraySlice GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; + + // Returns the value of the element in a sparse literal at the given sparse + // element number. The sparse element number is the position with in the + // sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + template + NativeT GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + + // Invokes the "per cell" callback for each element in the provided + // literal with the element's indices and a string representation of + // the element's value. + // + // This function is useful if you want a polymorphic representation + // of the tensor's elements (turning it to a string for something + // like representation in a protobuf). + // + // This literal must have a dense layout. + void EachCellAsString( + const std::function indices, + const string& value)>& per_cell) const; + template + void EachCell(std::function indices, + NativeT value)> + per_cell) const; + + // Returns whether every element in this literal is equal to value. + // + // value is an int8 because we expect this to be called with small + // compile-time constants (0, -1, etc.) and so that whatever value you pass + // can be represented exactly by floating-point types as small as 16 bits. + // + // If value doesn't fit in this literal's type, returns false. Values of 1/0 + // are considered equal to true/false; other values are not considered equal + // to true. Also if this literal is not array-shaped false is returned. + bool IsAll(int8 value) const; + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular floating-point number. + // + // If the literal is not a floating-point value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for values that can be expressed precisely as a float, + // e.g. -0.5. Also if this literal is not array-shaped false is returned. + bool IsAllFloat(float value) const; + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular complex number. + // + // If the literal is not a complex value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for complex values that can be expressed precisely as + // float pairs e.g. (-0.5, 1.0). + // + // This literal must have a dense layout. + bool IsAllComplex(complex64 value) const; + + // Literal consists entirely of the first element of the literal. + bool IsAllFirst() const; + + // Returns whether this literal is zero at the specified index. This literal + // must be an array with a dense layout. + bool IsZero(tensorflow::gtl::ArraySlice indices) const; + + // Returns the count of the elements in the array at the given shape index in + // this literal. + int64 element_count(const ShapeIndex& index = {}) const { + return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + } + + // Returns the count of the elements in the sparse array at the given shape + // index in this literal, which will be no larger than + // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). + int64 sparse_element_count() const; + + // Compute a hash for this literal. This literal must not be a sparse tensor + // or a tuple containing a sparse tensor. + size_t Hash() const; + + // Converts this literal to the given shape. Returns an error is the + // conversion is not possible. + // + // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding + // instead of truncation; otherwise, truncation is used. + // + // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes + // the default behavior. + StatusOr> ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + + // Converts this literal to another primitive type using a bitcast + // conversion. The to and from primitive types must have the same bit + // width. Returns an error if the conversion is not possible. This literal + // must be array-shaped. + StatusOr> BitcastConvert( + PrimitiveType primitive_dest_type) const; + + // Converts this literal to another primitive type. Returns an error if the + // conversion is not possible. This literal must be array-shaped. + StatusOr> Convert( + PrimitiveType primitive_dest_type) const; + + // Returns a literal scalar representing the first element. + Literal GetFirstScalarLiteral() const; + + // Clones the underlying buffers into a new Literal, or new + // std::unique_ptr. + Literal Clone() const; + std::unique_ptr CloneToUnique() const; + + // TODO(b/67651157): The methods below which perform computation on Literals + // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with + // evaluator code which operates on Literals. + // + // Creates a new value that has the equivalent value as this + // literal, but conforms to new_layout; e.g. a literal matrix that was in {0, + // 1} minor-to-major dimension layout can be re-layed-out as {1, 0} + // minor-to-major dimension layout and the value in the cell at any given + // logical index (i0, i1) will be the same. + // + // For tuple shaped literals, shape_index should be used to select the inner + // array that the new layout applies to. + // + // Note: this is useful when the client wants to ensure that a value placed in + // the XLA allocation tracker has a particular layout; for efficiency + // purposes or avoiding unimplemented operation/layout combinations. + std::unique_ptr Relayout(const Layout& new_layout, + const ShapeIndex& shape_index = {}) const; + + // An overload of Relayout which changes the layout of the entire shape rather + // than being limited to a single array within the shape. + std::unique_ptr Relayout(const Shape& shape_with_layout) const; + + // Creates a new literal by reshaping this literal to have the given + // dimensions. The total number of elements must not change; The + // implementation currently only supports monotonic dim0-major layouts. + // This literal must be an array. + StatusOr> Reshape( + tensorflow::gtl::ArraySlice dimensions) const; + + // Creates a new literal by broadcasting this literal with `dimensions` to + // yield a literal of shape `result_shape`. + StatusOr> Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice dimensions) const; + + // Creates a new literal by reordering the dimensions of this literal. + // The given `permutation` must be a permutation of the dimension numbers + // in the original literal, and it specifies the order of the new dimensions + // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). + // For example, a transpose call on a literal of shape [3 x 8 x 4] and + // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. + // This literal must be an array. + std::unique_ptr Transpose( + tensorflow::gtl::ArraySlice permutation) const; + + // Creates a sub-array from this literal by extracting the indices + // [start_index, limit_index) of each dimension. The result literal has the + // same rank and layout as for the given literal. The number of indices in + // start_indices and limit_indices must be the rank of the literal, and the + // indices follow the order of the dimensions. + // This literal must be an array. + std::unique_ptr Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const; + + // Creates a literal with a prepended dimension with bound "times"; e.g. a + // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this + // literal replicated four times. + // This literal must be an array. + template + std::unique_ptr Replicate(int64 times) const; + + // Creates a new Literal object with the shape specified as parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + // + // Note: It's an antipattern to use this method then immediately call + // Literal::Populate on the result (since that results in zero initialization, + // then reinitialization. Conside if a call to MakeUnique(shape), + // followed by the call to Literal::Populate can be used instead. + static std::unique_ptr CreateFromShape(const Shape& shape); + + protected: + // A data structure representing a subshape at a particular ShapeIndex within + // the literal. For array-shaped ShapeIndexes, this data structure holds the + // pointer to the memory allocated for the array data. + class Piece { + public: + // Returns the buffer holding the array data for this piece as an array + // slice. This piece must be array-shaped. + template + tensorflow::gtl::ArraySlice data() const; + template + tensorflow::gtl::MutableArraySlice data(); + + // Returns the buffer holding the array data for this piece as a void*. This + // piece must be array-shaped. + void* untyped_data(); + const void* untyped_data() const; + + // Gets or sets an element in the array at the given index. The multi_index + // is CHECKed against the dimension sizes of the array. This piece must be + // array-shaped. + template + NativeT Get(tensorflow::gtl::ArraySlice index) const; + template + void Set(tensorflow::gtl::ArraySlice index, NativeT value); + + // Gets/sets the buffer holding the array data. + char* buffer() const { return buffer_; } + void set_buffer(char* buffer) { buffer_ = buffer; } + + // The array of multi-indices that provide the locations of non-zero + // elements in a sparse array. Only used if + // LayoutUtil::IsSparseArray(shape()) is true. + SparseIndexArray* sparse_indices() const { return sparse_indices_; } + void set_sparse_indices(SparseIndexArray* sparse_indices) { + sparse_indices_ = sparse_indices; + } + + // Gets or sets the subshape of this piece. This reference points to a + // subshape within the shape in the containing Literal (Literal::shape_). + const Shape& subshape() const { return *subshape_; } + void set_subshape(const Shape* subshape) { subshape_ = subshape; } + + // Returns the size in bytes of the buffer holding the array data. + int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } + + // Returns the number of elements in this piece's array. + int64 element_count() const { + // If this is a sparse array, use the number of elements represented by + // the indices in the associated SparseIndexArray. + return LayoutUtil::IsSparseArray(subshape()) + ? sparse_indices()->index_count() + : ShapeUtil::ElementsIn(subshape()); + } + + // Returns the child piece at 'index' of this piece. + Piece& child(int64 index) { return children_[index]; } + + // Adds a child piece to this piece's children. + void emplace_back(Piece child_piece) { + children_.emplace_back(std::move(child_piece)); + } + + // Returns the size of children pieces of this piece. + int64 children_size() { return children_.size(); } + + // Visitor functions that recursively traverses the piece and calls the + // given function at each child piece. The function has the type: + // void (const ShapeIndex& index, const Piece& piece) + template + void ForEachSubpiece(const Fn& func) const { + ShapeIndex index; + return ForEachHelper( + [&func](const ShapeIndex& index, const Piece& piece) { + func(index, piece); + return Status::OK(); + }, + *this, &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, const Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachSubpieceWithStatus(const Fn& func) const { + ShapeIndex index; + return ForEachHelper(func, *this, &index); + } + // Same as above, but the function has the type: + // Bool (const ShapeIndex& index, const Piece& piece) + // The first non-true return value is returned by the function. + template + bool ForEachSubpieceWithBool(const Fn& func) const { + ShapeIndex index; + return ForEachHelperBool(func, *this, &index); + } + // Same as above, but the function has the type: + // Void (const ShapeIndex& index, Piece& piece) + template + void ForEachMutableSubpiece(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + [&func](const ShapeIndex& index, Piece* piece) { + func(index, piece); + return Status::OK(); + }, + const_cast(this), &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachMutableSubpieceWithStatus(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + func, const_cast(this), &index); + } + + // Returns true if this piece and 'other' contain the same data. This piece + // and 'other' must be array-shaped and compatible. + bool EqualElements(const Piece& other) const; + + // Writes the shape and data (if array-shaped) into the given proto. + void WriteToProto(LiteralProto* proto) const; + + // Copy the data from 'src' into this piece's buffer. Shapes of this piece + // and src must be compatible. + Status CopyFrom(const Piece& src); + + // Copies the data from the given proto into this piece. The shape of this + // piece must be equal (not just compatible) to the shape of the proto. + Status CopyFromProto(const LiteralProto& proto); + + // Sorts the elements in a sparse array. + void SortSparseElements(); + + private: + // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'. + // The first non-OK (or non-true) value is returned by the function. + // The callable 'func' has the same signature as described above in + // ForEachSubpiece*. + template + Status ForEachHelper(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + template + bool ForEachHelperBool(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + if (!func(*index, piece)) { + return false; + } + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + if (!ForEachHelperBool(func, piece.children_[i], index)) { + return false; + } + index->pop_back(); + } + return true; + } + template + Status ForEachMutableHelper(const Fn& func, Piece* piece, + ShapeIndex* index) { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece->children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR( + ForEachMutableHelper(func, &piece->children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + + // Recursive helper for EqualElements. + template + bool EqualElementsInternal(const Piece& other, + std::vector* multi_index) const; + + // Helper for SortSparseElements that has the element type as a template + // parameter. + template + void SortSparseElementsInternal(); + + // For array-shaped pieces, this is the buffer holding the literal data. + char* buffer_ = nullptr; + + // For sparse arrays, this is the array of indices. + SparseIndexArray* sparse_indices_ = nullptr; + + // The shape of piece. This points into the shape of the containing Literal + // (Literal::shape_). + const Shape* subshape_ = nullptr; + + // Children pieces for tuple shaped pieces. + std::vector children_ = {}; + }; // class Piece + + const Piece& piece(const ShapeIndex& shape_index) const { + Piece* piece = &const_cast(root_piece()); + for (const auto i : shape_index) { + DCHECK_GE(i, 0); + DCHECK_LT(i, piece->children_size()); + piece = &piece->child(i); + } + return *piece; + } + + // Returns the piece at the root of the shape. + virtual const Piece& root_piece() const = 0; + + // LiteralSlice and Literal must access Pieces of other Literals. + friend class Literal; + friend class LiteralSlice; + friend class BorrowingLiteral; + + private: + template + std::unique_ptr SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const; +}; + // Class representing literal values in XLA. // -// TODO(b/67651157): The methods in this class should be reduced to a minimal -// set of methods which construct Literals and accessors methods. Other methods -// which perform computation on Literals (Reshape, Slice, etc) should be moved -// elsewhere, and perhaps combined with evaluator code which operates on -// Literals. -class Literal { +// The underlying buffer and shape is always owned by this class. +class Literal : public LiteralBase { public: Literal() : Literal(ShapeUtil::MakeNil()) {} @@ -74,48 +568,162 @@ class Literal { Literal(const Literal& other) = delete; Literal& operator=(const Literal& other) = delete; Literal(Literal&& other); + // 'allocate_arrays' indicates whether to allocate memory for the arrays in + // the shape. If false, buffer pointers inside of the Literal::Pieces are set + // to nullptr. + Literal(const Shape& shape, bool allocate_arrays); Literal& operator=(Literal&& other); - // Literals are equal if they have compatible shapes and the same data - // values. Layout is not compared. - bool operator==(const Literal& other) const; - bool operator!=(const Literal& other) const { return !(*this == other); } + // TODO(b/67651157): Remove this accessor. Literal users should not be able to + // mutate the shape as this can produce malformed Literals. + Shape* mutable_shape_do_not_use() { return shape_.get(); } - // Serialize to and from a proto. - static StatusOr> CreateFromProto( - const LiteralProto& proto); - LiteralProto ToProto() const; + // Returns a MutableArraySlice view of the array for this literal for the + // given NativeT (e.g., float). CHECKs if the subshape of the literal at the + // given ShapeIndex is not array. See primitive_util.h for the mapping from + // XLA type to native type. + template + tensorflow::gtl::MutableArraySlice data( + const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::data; + + // Returns a pointer to the sparse index array. Returns nullptr if the literal + // is not a sparse array. + SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + + // Returns a pointer to the underlying buffer holding the array at the given + // shape index. CHECKs if the subshape of the literal at the given ShapeIndex + // is not array. + void* untyped_data(const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::untyped_data; + + // Populates a literal with a sparse layout with the given indices and values. + // Each index in the indices array is CHECKed against the dimensions in the + // literal's shape. If sort is true, then the indices and values will be + // sorted. If sort is false, then the indices and values are assumed to + // already be in sorted order. See CreateSparse for an example of how data + // are populated. + template + void PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort = true); + + // Copy values from 'src_literal' rooted at 'src_shape_index' into this + // literal rooted at 'dest_shape_index'. The subshape of this literal rooted + // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' + // rooted at 'src_shape_index', but need not be arrays. + Status CopyFrom(const LiteralSlice& src_literal, + const ShapeIndex& dest_shape_index = {}, + const ShapeIndex& src_shape_index = {}); + + // Similar to CopyFrom, but with move semantincs. The subshape of this literal + // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' + // (layouts and shapes must match), but need not be arrays. The memory + // allocated in this literal for the subshape at dest_shape_index is + // deallocated, and the respective buffers are replaced with those in + // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). + Status MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index = {}); + + // Copies the values from src_literal, starting at src_base shape indexes, + // to this literal, starting at dest_base, where the copy size in each + // dimension is specified by copy_size. + // The src_literal and this literal must have the same primitive type, + // src_base+copy_size must fit the source literal dimensions, as well as + // dest_base+copy_size must fit the destination literal dimensions. + // Note: if either src_literal or this literal contains dimensions with zero + // element, then copy_size must be 0 in these dimensions while the + // corresponding base indices being 0. + // This literal and 'src_literal' must be arrays. + Status CopySliceFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); + + // Copies one element from src_literal[src_index] to (*this)[dest_index]. + Status CopyElementFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index); - // Return the shape of the literal. - const Shape& shape() const { return shape_; } + // Sets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. + template + void Set(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value); + // Overloads of Set for array literals. CHECKs if the literal is not + // array-shaped and dense. + template + void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); + + // Appends the given element to the literal. If the elements are not appended + // in sorted order, then SortSparseElements should be called before calling + // other methods. This literal must have a sparse layout. + template + void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, + NativeT value, const ShapeIndex& shape_index = {}); + + // Sorts the elements in a sparse array. + void SortSparseElements(const ShapeIndex& shape_index = {}); - // TODO(b/67651157): Remove this accessor. Literal users should not be able to - // mutate the shape as this can produce malformed Literals. - Shape* mutable_shape_do_not_use() { return &shape_; } + // As Set(), but truncates `value` to the literal element type before storing. + // This literal must be an array. + Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, + int64 value); - // Returns a (Mutable)ArraySlice view of the array for this literal for the - // given NativeT (e.g., float). CHECKs if the subshape of the literal at the - // given ShapeIndex is not array. See primitive_util.h for the mapping from - // XLA type to native type. + // Populate this literal with the given values. Examples: + // + // // Populate with floats. + // Array2D float_values = ... + // literal.PopulateR2FromArray2D(values); + // + // // Populate with int32s. + // literal.PopulateR2({{1, 2}, {3, 4}}); + // + // The shape and element type of this literal must match given values. For + // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 + // array of S32. template - tensorflow::gtl::ArraySlice data( - const ShapeIndex& shape_index = {}) const; + void PopulateR1(tensorflow::gtl::ArraySlice values); + void PopulateR1(const tensorflow::core::Bitmap& values); template - tensorflow::gtl::MutableArraySlice data( - const ShapeIndex& shape_index = {}); + void PopulateR2(std::initializer_list> values); + template + void PopulateFromArray(const Array& values); + template + void PopulateR2FromArray2D(const Array2D& values); + template + void PopulateR3FromArray3D(const Array3D& values); + template + void PopulateR4FromArray4D(const Array4D& values); - // Returns a pointer to the sparse index array. Returns nullptr if the literal - // is not a sparse array. - const SparseIndexArray* sparse_indices( - const ShapeIndex& shape_index = {}) const; - SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + // Populates literal values by calling the generator function for every cell + // in this literal object. + // + // generator must be a callable of the type + // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. + // + // This literal must have a dense layout. + template + Status Populate(const FnType& generator); - // Returns a pointer to (or size of) the underlying buffer holding the array - // at the given shape index. CHECKs if the subshape of the literal at the - // given ShapeIndex is not array. - const void* untyped_data(const ShapeIndex& shape_index = {}) const; - void* untyped_data(const ShapeIndex& shape_index = {}); - int64 size_bytes(const ShapeIndex& shape_index = {}) const; + // A parallel version of Populate(). This can be used if the generator is + // thread-safe and the values for the shape's different elements are + // independent. + template + Status PopulateParallel(const FnType& generator); + + // Fills this literal with the given value. + template + void PopulateWithValue(NativeT value); + + // Factory methods below. + // + + // Serialize from a proto. + static StatusOr> CreateFromProto( + const LiteralProto& proto); // Creates a new literal of a given rank. To minimize ambiguity (for users // and the compiler) these CreateR[0-2] methods should explicitly specify the @@ -163,10 +771,6 @@ class Literal { values, const Layout& layout); - // Returns this literal's data as a string. This literal must be a rank-1 U8 - // array. - string GetR1U8AsString() const; - // Creates a literal with a sparse layout and the given indices and values. // The shape is initialized from the given dimensions. The minor dimension of // the indices array must equal the rank of the shape (i.e. size of the @@ -206,147 +810,16 @@ class Literal { tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, tensorflow::gtl::ArraySlice values, bool sort = true); - // Populates a literal with a sparse layout with the given indices and values. - // Each index in the indices array is CHECKed against the dimensions in the - // literal's shape. If sort is true, then the indices and values will be - // sorted. If sort is false, then the indices and values are assumed to - // already be in sorted order. See CreateSparse for an example of how data - // are populated. - template - void PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort = true); - - // Creates a new Literal object with the shape specified as parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromShape(const Shape& shape); - - // Creates a new Literal object with its values havings the primitive_type - // type, and with dimensions defined by the dimensions parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions); - - // Copy values from 'src_literal' rooted at 'src_shape_index' into this - // literal rooted at 'dest_shape_index'. The subshape of this literal rooted - // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' - // rooted at 'src_shape_index', but need not be arrays. - Status CopyFrom(const Literal& src_literal, - const ShapeIndex& dest_shape_index = {}, - const ShapeIndex& src_shape_index = {}); - - // Similar to CopyFrom, but with move semantincs. The subshape of this literal - // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' - // (layouts and shapes must match), but need not be arrays. The memory - // allocated in this literal for the subshape at dest_shape_index is - // deallocated, and the respective buffers are replaced with those in - // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). - Status MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index = {}); - - // Copies the values from src_literal, starting at src_base shape indexes, - // to this literal, starting at dest_base, where the copy size in each - // dimension is specified by copy_size. - // The src_literal and this literal must have the same primitive type, - // src_base+copy_size must fit the source literal dimensions, as well as - // dest_base+copy_size must fit the destination literal dimensions. - // Note: if either src_literal or this literal contains dimensions with zero - // element, then copy_size must be 0 in these dimensions while the - // corresponding base indices being 0. - // This literal and 'src_literal' must be arrays. - Status CopySliceFrom(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); - - // Returns a vector containing the tuple elements of this Literal as separate - // Literals. This Literal must be tuple-shaped and can be a nested tuple. The - // elements are moved into the new Literals; no data is copied. Upon return - // this Literal is set to a nil shape (empty tuple) - std::vector DecomposeTuple(); - - // This operation is the inverse of DecomposeTuple. The given elements are - // moved into the tuple elements of a new tuple-shaped Literal which is - // returned. Upon return, each of the Literals in 'elements' is set to a nil - // shape (empty tuple). - static Literal MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements); - - // Creates a new value that has the equivalent value as this literal, but - // conforms to new_layout; e.g. a literal matrix that was in {0, 1} - // minor-to-major dimension layout can be re-layed-out as {1, 0} - // minor-to-major dimension layout and the value in the cell at any given - // logical index (i0, i1) will be the same. - // - // For tuple shaped literals, shape_index should be used to select the inner - // array that the new layout applies to. - // - // Note: this is useful when the client wants to ensure that a value placed in - // the XLA allocation tracker has a particular layout; for efficiency - // purposes or avoiding unimplemented operation/layout combinations. - std::unique_ptr Relayout(const Layout& new_layout, - const ShapeIndex& shape_index = {}) const; - - // An overload of Relayout which changes the layout of the entire shape rather - // than being limited to a single array within the shape. - std::unique_ptr Relayout(const Shape& shape_with_layout) const; - - // Creates a new literal by reshaping this literal to have the given - // dimensions. The total number of elements must not change; The - // implementation currently only supports monotonic dim0-major layouts. - // This literal must be an array. - StatusOr> Reshape( - tensorflow::gtl::ArraySlice dimensions) const; - - // Creates a new literal by reordering the dimensions of this literal. - // The given `permutation` must be a permutation of the dimension numbers - // in the original literal, and it specifies the order of the new dimensions - // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). - // For example, a transpose call on a literal of shape [3 x 8 x 4] and - // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. - // This literal must be an array. - std::unique_ptr Transpose( - tensorflow::gtl::ArraySlice permutation) const; - - // Creates a sub-array from this literal by extracting the indices - // [start_index, limit_index) of each dimension. The result literal has the - // same rank and layout as for the given literal. The number of indices in - // start_indices and limit_indices must be the rank of the literal, and the - // indices follow the order of the dimensions. - // This literal must be an array. - std::unique_ptr Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const; - - // Creates a literal with a prepended dimension with bound "times"; e.g. a - // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this - // literal replicated four times. - // This literal must be an array. - template - std::unique_ptr Replicate(int64 times) const; - - // Converts this literal to another primitive type. Returns an error if the - // conversion is not possible. This literal must be array-shaped. - StatusOr> Convert( - PrimitiveType primitive_dest_type) const; - // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); - // Creates a scalar literal value one of the given primitive type. static Literal One(PrimitiveType primitive_type); - // Creates a scalar literal value containing the minimum value of the given // primitive type. For floating-point types, returns -inf. static Literal MinValue(PrimitiveType primitive_type); - // Creates a scalar literal value containing the maximum value of the given // primitive type. For floating-point types, returns inf. static Literal MaxValue(PrimitiveType primitive_type); - // Creates a literal of the given shape where each element is `value`. template static std::unique_ptr CreateFullWithDescendingLayout( @@ -390,81 +863,16 @@ class Literal { // Creates a literal that projects the (x, y) dimensions given in values into // the z dimension given by "projection". template - static std::unique_ptr CreateR3Projected( - std::initializer_list> values, - int64 projection); - - // Creates a literal that projects the (x, y) dimensions given in values into - // the z and p dimensions given. - template - static std::unique_ptr CreateR4Projected( - std::initializer_list> values, - int64 projection_p, int64 projection_z); - - // Clones this literal into a new Literal, or new std::unique_ptr. - Literal Clone() const; - std::unique_ptr CloneToUnique() const; - - // Gets or sets an element in the literal at the given index. The multi_index - // is CHECKed against the dimension sizes. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const; - template - void Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value); - - // Overloads of Get and Set for array literals. CHECKs if the literal is not - // array-shaped and dense. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; - template - void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); - - // Returns the multi-index of the element in a sparse literal at the given - // sparse element number. The sparse element number is the position with in - // the sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - tensorflow::gtl::ArraySlice GetSparseIndex( - int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; - - // Returns the value of the element in a sparse literal at the given sparse - // element number. The sparse element number is the position with in the - // sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - template - NativeT GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - - // Appends the given element to the literal. If the elements are not appended - // in sorted order, then SortSparseElements should be called before calling - // other methods. This literal must have a sparse layout. - template - void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, - NativeT value, const ShapeIndex& shape_index = {}); - - // Sorts the elements in a sparse array. - void SortSparseElements(const ShapeIndex& shape_index = {}); - - // Returns the element value at index (0, ..., 0), however many zeroes are - // required for that index. - template - NativeT GetFirstElement() const; - - // As Get(), but determines the correct type and converts the value - // into text. - string GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index = {}) const; - - // As GetSparseElement(), but determines the correct type and converts the - // value into text. - string GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; + static std::unique_ptr CreateR3Projected( + std::initializer_list> values, + int64 projection); - // As Get(), but determines the correct type and converts the value into - // int64. This literal must be an array. - StatusOr GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const; + // Creates a literal that projects the (x, y) dimensions given in values into + // the z and p dimensions given. + template + static std::unique_ptr CreateR4Projected( + std::initializer_list> values, + int64 projection_p, int64 projection_z); // Returns an identity matrix (rank 2) with the given row and column count. template @@ -475,6 +883,9 @@ class Literal { static std::unique_ptr MakeTuple( tensorflow::gtl::ArraySlice elements); + static std::unique_ptr MakeTupleFromSlices( + tensorflow::gtl::ArraySlice elements); + // As above, but intended to be invoked with move semantics; i.e. // // std::vector> elements = ...; @@ -506,127 +917,104 @@ class Literal { return MakeTupleOwned(std::move(v)); } - // Returns a string representation of the literal value. - // Warning: this function can take minutes for multi-million element Literals. - string ToString(bool print_layout = false) const; - - // Invokes the "per cell" callback for each element in the provided - // literal with the element's indices and a string representation of - // the element's value. - // - // This function is useful if you want a polymorphic representation - // of the tensor's elements (turning it to a string for something - // like representation in a protobuf). - // - // This literal must have a dense layout. - void EachCellAsString( - const std::function indices, - const string& value)>& per_cell) const; - template - void EachCell(std::function indices, - NativeT value)> - per_cell) const; - - // Populate this literal with the given values. Examples: - // - // // Populate with floats. - // Array2D float_values = ... - // literal.PopulateR2FromArray2D(values); - // - // // Populate with int32s. - // literal.PopulateR2({{1, 2}, {3, 4}}); - // - // The shape and element type of this literal must match given values. For - // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 - // array of S32. - template - void PopulateR1(tensorflow::gtl::ArraySlice values); - void PopulateR1(const tensorflow::core::Bitmap& values); - template - void PopulateR2(std::initializer_list> values); - template - void PopulateFromArray(const Array& values); - template - void PopulateR2FromArray2D(const Array2D& values); - template - void PopulateR3FromArray3D(const Array3D& values); - template - void PopulateR4FromArray4D(const Array4D& values); + // Returns a vector containing the tuple elements of this Literal as separate + // Literals. This Literal must be tuple-shaped and can be a nested tuple. The + // elements are moved into the new Literals; no data is copied. Upon return + // this Literal is set to a nil shape (empty tuple) + std::vector DecomposeTuple(); - // Populates literal values by calling the generator function for every cell - // in this literal object. - // - // generator must be a callable of the type - // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. - // - // This literal must have a dense layout. - template - Status Populate(const FnType& generator); + // This operation is the inverse of DecomposeTuple. The given elements are + // moved into the tuple elements of a new tuple-shaped Literal which is + // returned. Upon return, each of the Literals in 'elements' is set to a nil + // shape (empty tuple). + static Literal MoveIntoTuple( + tensorflow::gtl::MutableArraySlice elements); - // Fills this literal with the given value. - template - void PopulateWithValue(NativeT value); + // Creates a new Literal object with its values havings the primitive_type + // type, and with dimensions defined by the dimensions parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr CreateFromDimensions( + PrimitiveType primitive_type, + tensorflow::gtl::ArraySlice dimensions); - // Returns whether every element in this literal is equal to value. - // - // value is an int8 because we expect this to be called with small - // compile-time constants (0, -1, etc.) and so that whatever value you pass - // can be represented exactly by floating-point types as small as 16 bits. - // - // If value doesn't fit in this literal's type, returns false. Values of 1/0 - // are considered equal to true/false; other values are not considered equal - // to true. Also if this literal is not array-shaped false is returned. - bool IsAll(int8 value) const; + // If the given literal's data type is bfloat16, converts it to a float + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static std::unique_ptr ConvertBF16ToF32( + const LiteralSlice& bf16_literal); + + // If the given literal's data type is float, converts it to a bfloat16 + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static std::unique_ptr ConvertF32ToBF16( + const LiteralSlice& f32_literal); + + // Creates a literal with a new shape with the given new dimensions using the + // data in the given input literal. For reshaping purposes the (flat) data + // buffer of the input literal is assumed to have the given minor_to_major + // layout order. + static std::unique_ptr ReshapeSlice( + tensorflow::gtl::ArraySlice new_dimensions, + tensorflow::gtl::ArraySlice minor_to_major, + const LiteralSlice& literal); + + // Creates a literal with the supplied shape, and uses the provided value + // generator to populate the literal's values. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, + const std::function)>& generator); + + // Creates a literal with the supplied shape, and initializes the literal + // values using a normal distribution with given mean and stddev standard + // deviation, and using the engine as entropy generator. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, typename E, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev); + + // Creates a literal with the supplied shape, and initializes the literal + // values using a normal distribution with given mean and stddev standard + // deviation. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, T mean, T stddev); - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular floating-point number. - // - // If the literal is not a floating-point value, this always returns false. // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for values that can be expressed precisely as a float, - // e.g. -0.5. Also if this literal is not array-shaped false is returned. - bool IsAllFloat(float value) const; + // End of factory methods. - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular complex number. - // - // If the literal is not a complex value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for complex values that can be expressed precisely as - // float pairs e.g. (-0.5, 1.0). - // - // This literal must have a dense layout. - bool IsAllComplex(complex64 value) const; + // Returns a multi-dimensional index as a string. For example: '{7, 8}' will + // be returned for a 2-dimensional index with dimension 0 index equal to 7, + // dimension 1 equal to 8. + static string MultiIndexAsString( + tensorflow::gtl::ArraySlice multi_index); - // Returns whether this literal is zero at the specified index. This literal - // must be an array with a dense layout. - bool IsZero(tensorflow::gtl::ArraySlice indices) const; + private: + // Recursively sets the subshapes and buffers of all subpieces rooted at + // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in + // the shape. + void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); - // Return the count of the elements in the array at the given shape index in - // this literal. - int64 element_count(const ShapeIndex& index = {}) const { - return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + // Returns the piece at the given ShapeIndex. + Piece& piece(const ShapeIndex& shape_index) { + return const_cast(LiteralBase::piece(shape_index)); } - // Return the count of the elements in the sparse array at the given shape - // index in this literal, which will be no larger than - // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). - int64 sparse_element_count() const; - - protected: - // 'allocate_arrays' indicates whether to allocate memory for the arrays in - // the shape. If false, buffer pointers inside of the Literal::Pieces are set - // to nullptr. - Literal(const Shape& shape, bool allocate_arrays); + Piece& root_piece() const override { return *root_piece_; }; // Internal template helper for the Literal::CopySliceFrom(), matching its // arguments one by one. template - Status CopySliceFromInternal(const Literal& src_literal, + Status CopySliceFromInternal(const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size); @@ -654,152 +1042,71 @@ class Literal { int64 minor_loop_size = 1; }; - // A data structure representing a subshape at a particular ShapeIndex within - // the literal. For array-shaped ShapeIndexes, this data structure holds the - // pointer to the memory allocated for the array data. - class Piece { - public: - // Return the buffer holding the array data for this piece as an array - // slice. This piece must be array-shaped. - template - tensorflow::gtl::ArraySlice data() const; - template - tensorflow::gtl::MutableArraySlice data(); - - // Return the buffer holding the array data for this piece as a void*. This - // piece must be array-shaped. - void* untyped_data(); - const void* untyped_data() const; - - // Gets or sets an element in the array at the given index. The multi_index - // is CHECKed against the dimension sizes of the array. This piece must be - // array-shaped. - template - NativeT Get(tensorflow::gtl::ArraySlice index) const; - template - void Set(tensorflow::gtl::ArraySlice index, NativeT value); - - // Gets/sets the buffer holding the array data. - char* buffer() const { return buffer_; } - void set_buffer(char* buffer) { buffer_ = buffer; } - - // The array of multi-indices that provide the locations of non-zero - // elements in a sparse array. Only used if - // LayoutUtil::IsSparseArray(shape()) is true. - SparseIndexArray* sparse_indices() const { return sparse_indices_; } - void set_sparse_indices(SparseIndexArray* sparse_indices) { - sparse_indices_ = sparse_indices; - } - - // Gets or sets the subshape of this piece. This reference points to a - // subshape within the shape in the containing Literal (Literal::shape_). - const Shape& subshape() const { return *subshape_; } - void set_subshape(const Shape* subshape) { subshape_ = subshape; } - - // Returns the size in bytes of the buffer holding the array data. - int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } - - // Returns the number of elements in this piece's array. - int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); } - - // Copy the data from 'src' into this piece's buffer. Shapes of this piece - // and src must be compatible. - Status CopyFrom(const Piece& src); - - // Returns true if this piece and 'other' contain the same data. This piece - // and 'other' must be array-shaped and compatible. - bool EqualElements(const Piece& other) const; - - // Writes the shape and data (if array-shaped) into the given proto. - void WriteToProto(LiteralProto* proto) const; - - // Copies the data from the given proto into this piece. The shape of this - // piece must be equal (not just compatible) to the shape of the proto. - Status CopyFromProto(const LiteralProto& proto); - - // Sorts the elements in a sparse array. - void SortSparseElements(); - - private: - // Recursive helper for EqualElements. - template - bool EqualElementsInternal(const Piece& other, - std::vector* multi_index) const; - - // Helper for SortSparseElements that has the element type as a template - // parameter. - template - void SortSparseElementsInternal(); - - // For array-shaped pieces, this is the buffer holding the literal data. - char* buffer_ = nullptr; - - // For sparse arrays, this is the array of indices. - SparseIndexArray* sparse_indices_ = nullptr; - - // The shape of piece. This points into the shape of the containing Literal - // (Literal::shape_). - const Shape* subshape_ = nullptr; - }; + // Literal class always owns the shape. The parent class borrows this shape. + std::unique_ptr shape_; - // Returns the piece at the given ShapeIndex. - Piece& piece(const ShapeIndex& shape_index) { - return *pieces_.mutable_element(shape_index); - } - const Piece& piece(const ShapeIndex& shape_index) const { - return pieces_.element(shape_index); - } + Piece* root_piece_ = nullptr; - // Returns the piece at the root of the shape (empty ShapeIndex). - Piece& root_piece() { return piece({}); } - const Piece& root_piece() const { return piece({}); } + // Implementation details shared between Populate() and PopulateParallel() + template + Status PopulateInternal(const FnType& generator, bool parallel); - // Deallocate the buffers held by this literal (if the literal owns the - // buffer). + // Deallocate the buffers held by this literal. void DeallocateBuffers(); - Shape shape_; - ShapeTree pieces_; - - // Whether the buffers held in pieces_ are owned by this Literal. - bool owns_buffers_; - - // LiteralView must access and manipulate Pieces of other Literals. - friend class LiteralView; -}; // namespace xla - + friend class LiteralBase; +}; std::ostream& operator<<(std::ostream& out, const Literal& literal); -// A read-only view of a Literal. A LiteralView contains pointers to buffers -// owned by the viewed Literal. -// -// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable -// and mutable) similar to (Mutable)ArraySlice. -class LiteralView : public Literal { +// A read-only view of a Literal. A LiteralSlice contains pointers to shape and +// literal buffers always owned by others. +class LiteralSlice : public LiteralBase { public: - // Create and return a view of the given literal rooted at the given shape - // index within the given literal. A factory is used rather than a public - // constructor because only const LiteralViews are supported. It's still - // possible to create non-const LiteralViews via the copy constructors, but - // the factory method makes it a bit less likely. Implementing literal slices - // will fix this undesirable situation (b/71550060). - static const LiteralView Create(const Literal& literal, - const ShapeIndex& view_root = {}); + LiteralSlice() : LiteralBase() {} - LiteralView(const LiteralView& other); - LiteralView& operator=(const LiteralView& other); - - virtual ~LiteralView(); + // Implicit conversion constructors. + LiteralSlice(const LiteralBase& literal); + LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root); private: - LiteralView(const Literal& literal, const ShapeIndex& view_root); + const Piece& root_piece() const override { return *root_piece_; }; + + const Piece* root_piece_; // Not owned. +}; + +// A read-only Literal where the underlying buffers are never owned by this +// class. +class BorrowingLiteral : public LiteralBase { + public: + BorrowingLiteral() : LiteralBase() {} + + // 'src_buf_ptr' is not owned by this class and must outlive the + // lifetime of this class. It points to an appropirately sized buffer with + // data interpretered as indicated by 'shape'. + // This constructor is only used for array shapes. + BorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + // Similar as above, except to be used for constructing non-nested tuples. + BorrowingLiteral(tensorflow::gtl::ArraySlice src_buf_ptrs, + const Shape& shape); + // TODO(b/79707221): adding constructors for nested tuples as well. - // Helper for the copy constructor and copy assignment operator. - void CopyFrom(const LiteralView& other); + private: + // Recursively builds the subtree for the given piece and sets the subshapes + // of the given piece with the given shape. + void BuildPieceSubtree(const Shape& shape, Piece* piece); + + // Accessor for the root piece of this literal. + const Piece& root_piece() const override { return root_piece_; }; + Piece root_piece_; + + // Shape of this literal. Stored as unique_ptr so such that the (default) + // move construction of this class would be trivially correct: the pointer to + // Shape root_piece_ stores will still point to the correct address. + std::unique_ptr shape_; }; template -tensorflow::gtl::ArraySlice Literal::Piece::data() const { +tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); CHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) @@ -808,12 +1115,11 @@ tensorflow::gtl::ArraySlice Literal::Piece::data() const { << " type, but literal element type is " << PrimitiveType_Name(subshape().element_type()); return tensorflow::gtl::ArraySlice( - reinterpret_cast(buffer()), - ShapeUtil::ElementsIn(subshape())); + reinterpret_cast(buffer()), element_count()); } template -tensorflow::gtl::MutableArraySlice Literal::Piece::data() { +tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); CHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) @@ -822,11 +1128,11 @@ tensorflow::gtl::MutableArraySlice Literal::Piece::data() { << " type, but literal element type is " << PrimitiveType_Name(subshape().element_type()); return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(buffer()), ShapeUtil::ElementsIn(subshape())); + reinterpret_cast(buffer()), element_count()); } template -NativeT Literal::Piece::Get( +NativeT LiteralBase::Piece::Get( tensorflow::gtl::ArraySlice multi_index) const { CHECK(LayoutUtil::IsDenseArray(subshape())); return data()[IndexUtil::MultidimensionalIndexToLinearIndex( @@ -834,15 +1140,15 @@ NativeT Literal::Piece::Get( } template -void Literal::Piece::Set(tensorflow::gtl::ArraySlice multi_index, - NativeT value) { +void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, + NativeT value) { CHECK(LayoutUtil::IsDenseArray(subshape())); data()[IndexUtil::MultidimensionalIndexToLinearIndex( subshape(), multi_index)] = value; } template -tensorflow::gtl::ArraySlice Literal::data( +tensorflow::gtl::ArraySlice LiteralBase::data( const ShapeIndex& shape_index) const { return piece(shape_index).data(); } @@ -854,13 +1160,13 @@ tensorflow::gtl::MutableArraySlice Literal::data( } template -inline NativeT Literal::Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { +inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { return piece(shape_index).Get(multi_index); } template -inline NativeT Literal::Get( +inline NativeT LiteralBase::Get( tensorflow::gtl::ArraySlice multi_index) const { return root_piece().Get(multi_index); } @@ -1107,13 +1413,13 @@ template } template -NativeT Literal::GetFirstElement() const { +NativeT LiteralBase::GetFirstElement() const { return data().at(0); } template -NativeT Literal::GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index) const { +NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index) const { CHECK( LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); return data(shape_index)[sparse_element_number]; @@ -1146,7 +1452,7 @@ template } template -void Literal::EachCell( +void LiteralBase::EachCell( std::function indices, NativeT value)> per_cell) const { @@ -1237,19 +1543,20 @@ void Literal::PopulateSparse(SparseIndexArray indices, CHECK_LE(num_elements, max_elements); CHECK_EQ(num_elements, indices.index_count()); auto root_data = root_piece().data(); - root_data.remove_suffix(max_elements - values.size()); + // Piece::data() returns an ArraySlice of size equal to the number of indices + // in the SparseIndexArray. So there is no need to adjust the size of the data + // here. It is enough to just copy the incoming values into the data buffer. std::copy(values.begin(), values.end(), root_data.begin()); *this->root_piece().sparse_indices() = std::move(indices); if (sort) { auto root_data = this->root_piece().data(); - root_data.remove_suffix(root_data.size() - num_elements); this->root_piece().sparse_indices()->SortWithValues(root_data); } DCHECK(this->root_piece().sparse_indices()->Validate(shape())); } template -Status Literal::Populate(const FnType& generator) { +Status Literal::PopulateInternal(const FnType& generator, bool parallel) { const Shape& this_shape = shape(); const int64 rank = ShapeUtil::Rank(this_shape); TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); @@ -1259,11 +1566,11 @@ Status Literal::Populate(const FnType& generator) { if (rank > 0) { StrideConfig stride_config(this_shape, this_shape, AsInt64Slice(this_shape.dimensions())); - DimensionVector minor_scan_indexes(rank, 0); int64 minor_dimension_size = ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); - auto init_function = [&](const std::vector& indexes) { + auto init_function = [&](tensorflow::gtl::ArraySlice indexes) { + DimensionVector minor_scan_indexes(rank, 0); const int64 index = IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); @@ -1271,17 +1578,35 @@ Status Literal::Populate(const FnType& generator) { minor_scan_indexes[stride_config.minor_dimension] = i; literal_data.at(index + i) = generator(minor_scan_indexes); } - return true; }; - ShapeUtil::ForEachIndex(this_shape, stride_config.base, - stride_config.dimensions, stride_config.step, - init_function); + if (parallel) { + ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base, + stride_config.dimensions, + stride_config.step, init_function); + } else { + ShapeUtil::ForEachIndex( + this_shape, stride_config.base, stride_config.dimensions, + stride_config.step, + [&init_function](tensorflow::gtl::ArraySlice indexes) { + init_function(indexes); + return true; + }); + } } else { // For scalars. literal_data.at(0) = generator({}); } return Status::OK(); } +template +Status Literal::Populate(const FnType& generator) { + return PopulateInternal(generator, /*parallel=*/false); +} + +template +Status Literal::PopulateParallel(const FnType& generator) { + return PopulateInternal(generator, /*parallel=*/true); +} template void Literal::PopulateWithValue(NativeT value) { @@ -1303,7 +1628,7 @@ template } template -std::unique_ptr Literal::Replicate(int64 times) const { +std::unique_ptr LiteralBase::Replicate(int64 times) const { DimensionVector bounds = {times}; bounds.reserve(shape().dimensions_size() + 1); for (int64 bound : shape().dimensions()) { @@ -1338,6 +1663,38 @@ std::unique_ptr Literal::Replicate(int64 times) const { return literal; } +template +/* static */ StatusOr> Literal::CreateRandomLiteral( + const Shape& shape, + const std::function)>& generator) { + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + TF_RET_CHECK(shape.element_type() == type); + auto literal = MakeUnique(shape); + TF_RETURN_IF_ERROR(literal.get()->Populate( + [&](tensorflow::gtl::ArraySlice indexes) { + return generator(indexes); + })); + return std::move(literal); +} + +template +/* static */ StatusOr> Literal::CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev) { + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + std::normal_distribution generator(mean, stddev); + return CreateRandomLiteral( + shape, [&](tensorflow::gtl::ArraySlice /*indexes*/) { + return generator(*engine); + }); +} + +template +/* static */ StatusOr> Literal::CreateRandomLiteral( + const Shape& shape, T mean, T stddev) { + std::minstd_rand0 engine; + return CreateRandomLiteral(shape, &engine, mean, stddev); +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index b3583c2eb75de8297d5e7507430491f119bd4462..53b926163c472c3ed7b72bf8b035d13996d59e34 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -17,12 +17,14 @@ limitations under the License. #include +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -30,6 +32,7 @@ limitations under the License. namespace xla { namespace { +using tensorflow::gtl::ArraySlice; using ::testing::ElementsAre; using ::testing::HasSubstr; @@ -214,11 +217,9 @@ TEST_F(LiteralUtilTest, CreateSparse) { std::vector expected_values = {8, 9, 7, 10}; EXPECT_EQ(literal->sparse_indices()->data(), - tensorflow::gtl::ArraySlice( - expected_indices.data(), expected_indices.num_elements())); - EXPECT_EQ(tensorflow::gtl::ArraySlice(literal->data().data(), - expected_values.size()), - tensorflow::gtl::ArraySlice(expected_values)); + ArraySlice(expected_indices.data(), + expected_indices.num_elements())); + EXPECT_EQ(literal->data(), ArraySlice(expected_values)); } TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { @@ -290,7 +291,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { // clang-format on std::vector> seen; literal->EachCellAsString( - [&seen](tensorflow::gtl::ArraySlice indices, const string& value) { + [&seen](ArraySlice indices, const string& value) { seen.emplace_back(indices[0], indices[1], value); }); @@ -501,6 +502,24 @@ TEST_F(LiteralUtilTest, IsAllComplex) { ->IsAllComplex({8.0f, 9.0f})); } +TEST_F(LiteralUtilTest, IsAllFirst) { + // IsAllComplex always returns false when the literal is not complex. + EXPECT_FALSE(Literal::CreateR1({false, true})->IsAllFirst()); + EXPECT_TRUE(Literal::CreateR1({false, false})->IsAllFirst()); + EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); + + complex64 c8_9 = {8, 9}; + complex64 c7_9 = {7, 9}; + EXPECT_TRUE(Literal::CreateR2({{c8_9}, {c8_9}})->IsAllFirst()); + EXPECT_FALSE(Literal::CreateR2({{c7_9}, {c8_9}})->IsAllFirst()); +} + TEST_F(LiteralUtilTest, IsZero) { auto scalar_zero = Literal::CreateR0(0.0f); auto scalar_one = Literal::CreateR0(1.0f); @@ -604,11 +623,10 @@ TEST_F(LiteralUtilTest, TransposeR4) { // clang-format on auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1}); - reshape->EachCell( - [&](tensorflow::gtl::ArraySlice indices, float value) { - EXPECT_EQ(value, original->Get( - {indices[2], indices[3], indices[0], indices[1]})); - }); + reshape->EachCell([&](ArraySlice indices, float value) { + EXPECT_EQ(value, original->Get( + {indices[2], indices[3], indices[0], indices[1]})); + }); } TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { @@ -845,7 +863,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { const int64 zero_base[] = {0, 0, 0, 0}; const int64 step[] = {1, 1, 1, 1}; uint32 seqnr = 0; - auto init_proc = [&](const std::vector& indexes) { + auto init_proc = [&](ArraySlice indexes) { source->Set(indexes, ++seqnr); return true; }; @@ -861,7 +879,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); bool matched = true; - auto check_proc = [&](const std::vector& indexes) { + auto check_proc = [&](ArraySlice indexes) { std::copy(indexes.begin(), indexes.end(), source_indexes.begin()); std::transform(source_indexes.begin(), source_indexes.end(), src_base, source_indexes.begin(), std::plus()); @@ -957,7 +975,7 @@ TEST_F(LiteralUtilTest, CopyFromTuples) { Literal::CreateR1({2.0, 4.0}).get(), &nil_literal}); - EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); + EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); EXPECT_EQ(nested_tuple->Get({}, {1, 0}), 42); EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 23.0); EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 44.0); @@ -968,7 +986,7 @@ TEST_F(LiteralUtilTest, CopyFromTuples) { /*src_shape_index=*/{})); // The matrix element should be unchanged. - EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); + EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); // The tuple element should have been copied from 'tuple'. EXPECT_EQ(nested_tuple->Get({}, {1, 0}), -5); @@ -1048,8 +1066,8 @@ TEST_F(LiteralUtilTest, Populate) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = Literal::CreateFromShape(shape); - auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> uint32 { + auto literal = MakeUnique(shape); + auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), @@ -1061,7 +1079,49 @@ TEST_F(LiteralUtilTest, Populate) { std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; - auto check_function = [&](const std::vector& indexes) { + auto check_function = [&](ArraySlice indexes) { + auto value = literal->Get(indexes); + matched = matched && (value == generator(indexes)); + return matched; + }; + ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step, + check_function); + EXPECT_TRUE(matched); + } +} + +TEST_F(LiteralUtilTest, PopulateParallel) { + struct PopulateData { + std::vector dimensions; + std::vector layout; + } populate_data[] = { + {{}, {}}, + {{0}, {0}}, + {{16}, {0}}, + {{2, 0}, {1, 0}}, + {{4, 16}, {1, 0}}, + {{21, 12}, {0, 1}}, + {{6, 11, 17}, {2, 0, 1}}, + {{6, 11, 5, 17}, {3, 2, 0, 1}}, + }; + for (const auto& data : populate_data) { + Shape shape = ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), data.dimensions, + data.layout); + auto literal = MakeUnique(shape); + auto generator = [&](ArraySlice indexes) -> uint32 { + // Offsets from linear index just to avoid R0 literals to be initialized + // with zero. + return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), + indexes) + + 17; + }; + TF_EXPECT_OK(literal->PopulateParallel(generator)); + + std::vector zero_base(data.dimensions.size(), 0); + std::vector step(data.dimensions.size(), 1); + bool matched = true; + auto check_function = [&](ArraySlice indexes) { auto value = literal->Get(indexes); matched = matched && (value == generator(indexes)); return matched; @@ -1214,15 +1274,34 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { EXPECT_EQ(*conv, *c64); EXPECT_EQ(s32->Convert(TUPLE).status().code(), - tensorflow::error::INVALID_ARGUMENT); + tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(s32->Convert(S16).status().code(), - tensorflow::error::INVALID_ARGUMENT); + tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(s32->Convert(U16).status().code(), - tensorflow::error::INVALID_ARGUMENT); + tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(c64->Convert(F32).status().code(), - tensorflow::error::INVALID_ARGUMENT); + tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(c64->Convert(S32).status().code(), - tensorflow::error::INVALID_ARGUMENT); + tensorflow::error::UNIMPLEMENTED); +} + +TEST_F(LiteralUtilTest, BitcastConvert) { + auto original = + Literal::CreateR1({tensorflow::bit_cast(2.5f), + tensorflow::bit_cast(-42.25f), + tensorflow::bit_cast(100.f), 0xbeef}); + auto expected = Literal::CreateR1( + {2.5f, -42.25f, 100.0f, tensorflow::bit_cast(0xbeef)}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr converted, + original->BitcastConvert(F32)); +} + +TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) { + auto literal = Literal::CreateR0(1234); + Status status = literal->BitcastConvert(F64).status(); + EXPECT_NE(Status::OK(), status); + EXPECT_TRUE(tensorflow::str_util::StrContains(status.error_message(), + "bit widths are different")); } TEST_F(LiteralUtilTest, CopyFromProto_Bool) { @@ -1295,36 +1374,36 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { ASSERT_EQ(h1, r[3]); } -TEST_F(LiteralUtilTest, LiteralViewTest) { +TEST_F(LiteralUtilTest, LiteralSliceTest) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); Literal nil(ShapeUtil::MakeNil()); - EXPECT_EQ(LiteralView::Create(*scalar, {}), *scalar); - EXPECT_EQ(LiteralView::Create(*matrix, {}), *matrix); - EXPECT_EQ(LiteralView::Create(*tuple, {}), *tuple); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {}), *nested_tuple); - EXPECT_EQ(LiteralView::Create(nil, {}), nil); + EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar); + EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix); + EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple); + EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple); + EXPECT_EQ(LiteralSlice(nil, {}), nil); - EXPECT_EQ(LiteralView::Create(*tuple, {0}), *scalar); - EXPECT_EQ(LiteralView::Create(*tuple, {1}), *matrix); + EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar); + EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0}), *tuple); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 0}), *scalar); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 1}), *matrix); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {1}), *scalar); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix); + EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar); } -TEST_F(LiteralUtilTest, MutatingLiteralView) { +TEST_F(LiteralUtilTest, MutatingLiteralSlice) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); // Verify that changing the underlying data beneath the view changes the // data of the view itself. - const auto nested_tuple_view = LiteralView::Create(*nested_tuple); + const auto nested_tuple_view = LiteralSlice(*nested_tuple); EXPECT_EQ( nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 1.0f); @@ -1340,19 +1419,57 @@ TEST_F(LiteralUtilTest, MutatingLiteralView) { 555.0f); } -TEST_F(LiteralUtilTest, LiteralViewOfALiteralView) { +TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); - const auto nested_tuple_view = LiteralView::Create(*nested_tuple); - const auto tuple_view = - LiteralView::Create(nested_tuple_view, /*view_root=*/{0}); - const auto matrix_view = LiteralView::Create(tuple_view, /*view_root=*/{1}); + const auto nested_tuple_view = LiteralSlice(*nested_tuple); + const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0}); + const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1}); EXPECT_EQ(matrix_view, *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } +TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) { + std::vector int64_values = {1, 2, 3}; + const Shape literal_shape = ShapeUtil::MakeShape(S64, {3}); + + BorrowingLiteral literal(reinterpret_cast(int64_values.data()), + literal_shape); + + EXPECT_EQ(literal.Get({0}), 1); + EXPECT_EQ(literal.Get({1}), 2); + EXPECT_EQ(literal.Get({2}), 3); +} + +TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) { + std::vector one_two_three = {1, 2, 3}; + const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3}); + + std::vector hundred = {100}; + const Shape hundred_shape = ShapeUtil::MakeShape(S64, {1}); + + std::vector src_buf_ptrs; + src_buf_ptrs.emplace_back( + reinterpret_cast(one_two_three.data())); + src_buf_ptrs.emplace_back(reinterpret_cast(hundred.data())); + auto literal_tuple = BorrowingLiteral( + src_buf_ptrs, + ShapeUtil::MakeTupleShape({one_two_three_shape, hundred_shape})); + + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{0}, /*shape_index=*/{0}), + 1); + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{0}, /*shape_index=*/{1}), + 100); + + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{1}, /*shape_index=*/{0}), + 2); + + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{2}, /*shape_index=*/{0}), + 3); +} + TEST_F(LiteralUtilTest, LiteralMove) { std::unique_ptr matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); @@ -1455,11 +1572,11 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { EXPECT_EQ(literal.Get({1, 1}), 4.0); } -TEST_F(LiteralUtilTest, LiteralViewCopy) { +TEST_F(LiteralUtilTest, LiteralSliceCopy) { std::unique_ptr matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - const auto matrix_view = LiteralView::Create(*matrix); - LiteralView matrix_view_copy(matrix_view); + const auto matrix_view = LiteralSlice(*matrix); + LiteralSlice matrix_view_copy(matrix_view); EXPECT_EQ(matrix_view_copy.Get({0, 0}), 1.0); EXPECT_EQ(matrix_view_copy.Get({0, 1}), 2.0); @@ -1684,7 +1801,7 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { ASSERT_EQ(Literal::CreateSparse(dimensions, indices, {half{1.0}, half{2.0}, half{3.0}}) ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(half{2.0})); + tensorflow::strings::StrCat(static_cast(half{2.0}))); ASSERT_EQ( Literal::CreateSparse( dimensions, indices, @@ -1693,5 +1810,35 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); } +TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { + std::unique_ptr literal = Literal::CreateR1({1, 2}); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr broadcasted_literal, + literal->Broadcast( + /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{0})); + EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{1, 1}, {2, 2}})); +} + +TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) { + std::unique_ptr literal = Literal::CreateR1({1, 2}); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr broadcasted_literal, + literal->Broadcast( + /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{1})); + EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{1, 2}, {1, 2}})); +} + +TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) { + std::unique_ptr literal = Literal::CreateR0(9); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr broadcasted_literal, + literal->Broadcast( + /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), + /*dimensions=*/{})); + EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{9, 9}, {9, 9}})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h index 8db8c6f3de84a6c46625eadbb6b0f83d2262e5f7..3c74e070da529b7f1431e01fbaf31932f582db44 100644 --- a/tensorflow/compiler/xla/map_util.h +++ b/tensorflow/compiler/xla/map_util.h @@ -86,11 +86,10 @@ const typename Collection::value_type::second_type& FindOrDefault( // Inserts the key-value pair into the collection. Dies if key was already // present. -template -void InsertOrDie(Collection* const collection, - const typename Collection::value_type::first_type& key, - const typename Collection::value_type::second_type& data) { - auto p = collection->insert(std::make_pair(key, data)); +template +void InsertOrDie(Collection* const collection, Key&& key, Value&& value) { + auto p = collection->insert( + std::make_pair(std::forward(key), std::forward(value))); CHECK(p.second) << "duplicate key: " << key; } @@ -101,9 +100,10 @@ bool ContainsKey(const Collection& collection, const Key& key) { } // Inserts `value` into `set`. Dies if it was already present. -template -void InsertOrDie(Set* const set, const typename Set::value_type& value) { - CHECK(set->insert(value).second) << "duplicate value: " << value; +template +void InsertOrDie(Set* const set, Value&& value) { + CHECK(set->insert(std::forward(value)).second) + << "duplicate value: " << value; } } // namespace xla diff --git a/tensorflow/compiler/xla/ptr_util.h b/tensorflow/compiler/xla/ptr_util.h index c58c19db2cacbe9b038160f27b9bd76aa58146eb..bfcdfc62f9541ab09b94a48d5121e16bad4d43cd 100644 --- a/tensorflow/compiler/xla/ptr_util.h +++ b/tensorflow/compiler/xla/ptr_util.h @@ -28,26 +28,8 @@ limitations under the License. #include "tensorflow/core/util/ptr_util.h" namespace xla { - -template -std::unique_ptr WrapUnique(T* ptr) { - return tensorflow::WrapUnique(ptr); -} - -template -typename tensorflow::helper::MakeUniqueResult::scalar MakeUnique( - Args&&... args) { - return tensorflow::MakeUnique(std::forward(args)...); -} - -// Overload for array of unknown bound. -// The allocation of arrays needs to use the array form of new, -// and cannot take element constructor arguments. -template -typename tensorflow::helper::MakeUniqueResult::array MakeUnique(size_t n) { - return tensorflow::MakeUnique(n); -} - +using tensorflow::MakeUnique; +using tensorflow::WrapUnique; } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index e2972f06016ab3555c4fc0cc4616993fe6764b1e..83834c1ff65ea2f9989fe08279c29056d9070adb 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -12,6 +12,7 @@ py_library( deps = [ ":pywrap_xla", "//tensorflow/compiler/xla:xla_data_proto_py", + "//tensorflow/compiler/xla/service:hlo_proto_py", ], ) @@ -20,6 +21,7 @@ py_test( srcs = ["xla_client_test.py"], main = "xla_client_test.py", srcs_version = "PY2AND3", + tags = ["no_oss"], deps = [ ":xla_client", "//tensorflow/python:platform_test", @@ -48,9 +50,11 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", @@ -72,15 +76,3 @@ tf_py_wrap_cc( "//tensorflow/compiler/xla/service:cpu_plugin", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index cb7bb21e092c80d7360c23f3d6b00409a75dce23..445cee1aa7b462f7ae2b6b0771ff57f0c8f3db99 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -15,11 +15,11 @@ limitations under the License. #include "tensorflow/compiler/xla/python/local_computation_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/platform/default/thread_annotations.h" +#include "tensorflow/core/platform/thread_annotations.h" namespace xla { - namespace swig { // TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of @@ -89,40 +89,70 @@ StatusOr> TransferFromOutfeedLocalReplica( return client->TransferFromOutfeedLocal(shape, device_ordinal); } -LocalShapedBuffer::LocalShapedBuffer( - std::unique_ptr shaped_buffer) +LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer) : shaped_buffer_(std::move(shaped_buffer)) {} -const std::unique_ptr& LocalShapedBuffer::shaped_buffer() - const { - return shaped_buffer_; +const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const { + return &shaped_buffer_; +} + +ShapedBuffer LocalShapedBuffer::Release() { return shaped_buffer_.release(); } + +LocalShapedBufferTuple::LocalShapedBufferTuple( + std::vector elements) + : elements_(std::move(elements)) { + for (auto* element : elements_) { + DCHECK(element != nullptr); + } } -static StatusOr> ToBuffer( - LocalClient* client, int device_ordinal, const Literal& arg) { +LocalShapedBufferTuple::~LocalShapedBufferTuple() { + for (LocalShapedBuffer* element : elements_) { + if (element != nullptr) { + delete element; + } + } +} + +StatusOr LocalShapedBufferTuple::Release(int i) { + LocalShapedBuffer* element = elements_[i]; + if (element == nullptr) { + return InvalidArgument("Attempted to release already-released element %d.", + i); + } + elements_[i] = nullptr; + return element; +} + +int LocalShapedBufferTuple::size() const { return elements_.size(); } + +static StatusOr ToBuffer(LocalClient* client, + int device_ordinal, + const Literal& arg) { return client->LiteralToShapedBuffer(arg, device_ordinal, client->backend().memory_allocator()); } /* static */ -LocalShapedBuffer* LocalShapedBuffer::FromLiteral( +StatusOr LocalShapedBuffer::FromLiteral( const Literal& argument, const tensorflow::gtl::optional& shape_with_layout) { LocalClient* client = GetOrCreateLocalClient(); - std::unique_ptr buf; - if (shape_with_layout) { - std::unique_ptr relaid = - argument.Relayout(shape_with_layout.value()); - buf = ToBuffer(client, /*device_ordinal=*/0, *relaid).ConsumeValueOrDie(); - } else { - buf = ToBuffer(client, /*device_ordinal=*/0, argument).ConsumeValueOrDie(); - } - return new LocalShapedBuffer(std::move(buf)); + StatusOr buf = [&] { + if (shape_with_layout) { + std::unique_ptr relaid = + argument.Relayout(shape_with_layout.value()); + return ToBuffer(client, /*device_ordinal=*/0, *relaid); + } + return ToBuffer(client, /*device_ordinal=*/0, argument); + }(); + TF_RETURN_IF_ERROR(buf.status()); + return new LocalShapedBuffer(std::move(buf).ValueOrDie()); } -std::unique_ptr LocalShapedBuffer::ToLiteral() const { +StatusOr> LocalShapedBuffer::ToLiteral() const { LocalClient* client = GetOrCreateLocalClient(); - return client->ShapedBufferToLiteral(*shaped_buffer()).ConsumeValueOrDie(); + return client->ShapedBufferToLiteral(*shaped_buffer()); } CompiledLocalComputation::CompiledLocalComputation( @@ -158,14 +188,14 @@ StatusOr> CompiledLocalComputation::Execute( << device_ordinal; // Transfer arguments in - std::vector> scoped_buffers; + std::vector scoped_buffers; scoped_buffers.reserve(arguments.size()); for (int i = 0; i < arguments.size(); ++i) { const Literal& argument = arguments[i]; const tensorflow::gtl::optional& shape_with_layout = shapes_with_layout[i]; - StatusOr> pushed; + StatusOr pushed; if (shape_with_layout) { std::unique_ptr relaid = argument.Relayout(shape_with_layout.value()); @@ -185,7 +215,7 @@ StatusOr> CompiledLocalComputation::Execute( std::vector argument_buffers; argument_buffers.reserve(scoped_buffers.size()); for (auto& buffer : scoped_buffers) { - argument_buffers.push_back(buffer.get()); + argument_buffers.push_back(&buffer); } DeviceAssignment device_assignment = @@ -197,12 +227,10 @@ StatusOr> CompiledLocalComputation::Execute( ExecutableRunOptions options; options.set_device_ordinal(device_ordinal); options.set_allocator(client->backend().memory_allocator()); - options.set_inter_op_thread_pool( - client->backend().inter_op_thread_pool()); options.set_intra_op_thread_pool( client->backend().eigen_intra_op_thread_pool_device()); options.set_device_assignment(&device_assignment); - StatusOr> result_buffer_status = + StatusOr result_buffer_status = executable_->Run(argument_buffers, options); if (!result_buffer_status.ok()) { results[replica] = result_buffer_status.status(); @@ -210,8 +238,8 @@ StatusOr> CompiledLocalComputation::Execute( } // Transfer result out - results[replica] = - client->ShapedBufferToLiteral(*result_buffer_status.ValueOrDie()); + results[replica] = client->ShapedBufferToLiteral( + std::move(result_buffer_status).ValueOrDie()); }); } } @@ -236,22 +264,21 @@ LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers( std::vector argument_buffers; argument_buffers.reserve(argument_handles.size()); for (auto& handle : argument_handles) { - argument_buffers.push_back(handle->shaped_buffer().get()); + argument_buffers.push_back(handle->shaped_buffer()); } // Execute ExecutableRunOptions options; options.set_allocator(client->backend().memory_allocator()); - options.set_inter_op_thread_pool(client->backend().inter_op_thread_pool()); options.set_intra_op_thread_pool( client->backend().eigen_intra_op_thread_pool_device()); - std::unique_ptr result_buffer = + ScopedShapedBuffer result_buffer = executable_->Run(argument_buffers, options).ConsumeValueOrDie(); return new LocalShapedBuffer(std::move(result_buffer)); } -LocalComputation::LocalComputation(Computation computation) +LocalComputation::LocalComputation(XlaComputation computation) : computation_(std::move(computation)) {} StatusOr LocalComputation::Compile( @@ -274,18 +301,31 @@ StatusOr LocalComputation::Compile( return new CompiledLocalComputation(std::move(local_executable)); } -const Computation& LocalComputation::computation() const { +const XlaComputation& LocalComputation::computation() const { return computation_; } +string LocalComputation::GetSerializedProto() const { + string result; + if (!computation_.proto().SerializeToString(&result)) { + LOG(ERROR) << "Failed to serialize the HloModuleProto."; + return ""; + } + return result; +} + StatusOr LocalComputation::GetReturnValueShape() const { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, computation_.GetProgramShape()); return std::move(*program_shape.mutable_result()); } +LocalOp::LocalOp(const XlaOp& op) : op_(op) {} + +const XlaOp& LocalOp::op() const { return op_; } + LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) - : builder_(GetOrCreateLocalClient(), computation_name) {} + : builder_(computation_name) {} void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { builder_.SetOpMetadata(metadata); @@ -294,19 +334,21 @@ void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } StatusOr LocalComputationBuilder::Build() { - TF_ASSIGN_OR_RETURN(Computation computation, builder_.Build()); + TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build()); return new LocalComputation(std::move(computation)); } -ComputationDataHandle LocalComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { +LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, + const string& name) { return builder_.Parameter(parameter_number, shape, name); } std::unique_ptr LocalComputationBuilder::GetShape( - const ComputationDataHandle& operand) { - return builder_.GetShape(operand).ConsumeValueOrDie(); + const LocalOp& operand) { + auto result = MakeUnique(); + *result = builder_.GetShape(operand.op()).ValueOrDie(); + return result; } StatusOr LocalComputationBuilder::GetReturnValueShape() { @@ -314,205 +356,236 @@ StatusOr LocalComputationBuilder::GetReturnValueShape() { return program_shape.result(); } -ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) { +LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { return builder_.Infeed(shape); } -void LocalComputationBuilder::Outfeed(const ComputationDataHandle& operand, +void LocalComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config) { - builder_.Outfeed(operand, shape, outfeed_config); + builder_.Outfeed(operand.op(), shape, outfeed_config); } -ComputationDataHandle LocalComputationBuilder::ConstantLiteral( - const Literal& literal) { +LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { return builder_.ConstantLiteral(literal); } -ComputationDataHandle LocalComputationBuilder::Broadcast( - const ComputationDataHandle& operand, +LocalOp LocalComputationBuilder::Broadcast( + const LocalOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - return builder_.Broadcast(operand, broadcast_sizes); + return builder_.Broadcast(operand.op(), broadcast_sizes); } -ComputationDataHandle LocalComputationBuilder::Pad( - const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config) { - return builder_.Pad(operand, padding_value, padding_config); +LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, + const LocalOp& padding_value, + const PaddingConfig& padding_config) { + return builder_.Pad(operand.op(), padding_value.op(), padding_config); } -ComputationDataHandle LocalComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, +LocalOp LocalComputationBuilder::Reshape( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - return builder_.Reshape(operand, dimensions, new_sizes); + return builder_.Reshape(operand.op(), dimensions, new_sizes); } -ComputationDataHandle LocalComputationBuilder::Collapse( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - return builder_.Collapse(operand, dimensions); +LocalOp LocalComputationBuilder::Collapse( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return builder_.Collapse(operand.op(), dimensions); } -ComputationDataHandle LocalComputationBuilder::CrossReplicaSum( - const ComputationDataHandle& operand) { - return builder_.CrossReplicaSum(operand); +LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { + return builder_.CrossReplicaSum(operand.op()); } -ComputationDataHandle LocalComputationBuilder::Slice( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, +LocalOp LocalComputationBuilder::Slice( + const LocalOp& operand, tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - return builder_.Slice(operand, start_indices, limit_indices, strides); + return builder_.Slice(operand.op(), start_indices, limit_indices, strides); } -ComputationDataHandle LocalComputationBuilder::DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, +LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, + int64 start_index, + int64 limit_index, int64 stride, + int64 dimno) { + return builder_.SliceInDim(operand.op(), start_index, limit_index, stride, + dimno); +} + +LocalOp LocalComputationBuilder::DynamicSlice( + const LocalOp& operand, const LocalOp& start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - return builder_.DynamicSlice(operand, start_indices, slice_sizes); + return builder_.DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } -ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices) { - return builder_.DynamicUpdateSlice(operand, update, start_indices); +LocalOp LocalComputationBuilder::DynamicUpdateSlice( + const LocalOp& operand, const LocalOp& update, + const LocalOp& start_indices) { + return builder_.DynamicUpdateSlice(operand.op(), update.op(), + start_indices.op()); } -ComputationDataHandle LocalComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension) { - return builder_.ConcatInDim(operands, dimension); +LocalOp LocalComputationBuilder::ConcatInDim( + tensorflow::gtl::ArraySlice operands, int64 dimension) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return builder_.ConcatInDim(xla_ops, dimension); } -ComputationDataHandle -LocalComputationBuilder::SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const LocalComputation& select, +LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const LocalComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const LocalComputation& scatter) { + const LocalOp& source, const LocalOp& init_value, + const LocalComputation& scatter) { return builder_.SelectAndScatterWithGeneralPadding( - operand, select.computation(), window_dimensions, window_strides, padding, - source, init_value, scatter.computation()); + operand.op(), select.computation(), window_dimensions, window_strides, + padding, source.op(), init_value.op(), scatter.computation()); } -ComputationDataHandle LocalComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice elements) { - return builder_.Tuple(elements); +LocalOp LocalComputationBuilder::Tuple( + tensorflow::gtl::ArraySlice elements) { + std::vector xla_ops; + xla_ops.reserve(elements.size()); + for (const auto& op : elements) { + xla_ops.push_back(op.op()); + } + + return builder_.Tuple(xla_ops); } -ComputationDataHandle LocalComputationBuilder::GetTupleElement( - const ComputationDataHandle& tuple_data, int64 index) { - return builder_.GetTupleElement(tuple_data, index); +LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, + int64 index) { + return builder_.GetTupleElement(tuple_data.op(), index); } -ComputationDataHandle LocalComputationBuilder::Dot( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - return builder_.Dot(lhs, rhs); +LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { + return builder_.Dot(lhs.op(), rhs.op()); } -ComputationDataHandle LocalComputationBuilder::DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, +LocalOp LocalComputationBuilder::DotGeneral( + const LocalOp& lhs, const LocalOp& rhs, const DotDimensionNumbers& dimension_numbers) { - return builder_.DotGeneral(lhs, rhs, dimension_numbers); + return builder_.DotGeneral(lhs.op(), rhs.op(), dimension_numbers); } -ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, +LocalOp LocalComputationBuilder::ConvGeneralDilated( + const LocalOp& lhs, const LocalOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { - return builder_.ConvGeneralDilated(lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation, + return builder_.ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, + padding, lhs_dilation, rhs_dilation, dimension_numbers); } -ComputationDataHandle LocalComputationBuilder::ConvertElementType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - return builder_.ConvertElementType(operand, new_element_type); +LocalOp LocalComputationBuilder::ConvertElementType( + const LocalOp& operand, PrimitiveType new_element_type) { + return builder_.ConvertElementType(operand.op(), new_element_type); } -ComputationDataHandle LocalComputationBuilder::Call( +LocalOp LocalComputationBuilder::Call( const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands) { - return builder_.Call(local_computation.computation(), operands); + tensorflow::gtl::ArraySlice operands) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return builder_.Call(local_computation.computation(), xla_ops); } -ComputationDataHandle LocalComputationBuilder::Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation) { - return builder_.Transpose(operand, permutation); +LocalOp LocalComputationBuilder::Transpose( + const LocalOp& operand, tensorflow::gtl::ArraySlice permutation) { + return builder_.Transpose(operand.op(), permutation); } -ComputationDataHandle LocalComputationBuilder::Rev( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - return builder_.Rev(operand, dimensions); +LocalOp LocalComputationBuilder::Rev( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return builder_.Rev(operand.op(), dimensions); } -ComputationDataHandle LocalComputationBuilder::Map( - tensorflow::gtl::ArraySlice operands, +LocalOp LocalComputationBuilder::Map( + tensorflow::gtl::ArraySlice operands, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { - return builder_.Map(operands, local_computation.computation(), dimensions, - static_operands); + tensorflow::gtl::ArraySlice static_operands) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + + std::vector static_xla_ops; + static_xla_ops.reserve(static_operands.size()); + for (const auto& op : static_operands) { + static_xla_ops.push_back(op.op()); + } + + return builder_.Map(xla_ops, local_computation.computation(), dimensions, + static_xla_ops); } -ComputationDataHandle LocalComputationBuilder::Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, +LocalOp LocalComputationBuilder::Reduce( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice dimensions_to_reduce) { - return builder_.Reduce(operand, init_value, local_computation.computation(), - dimensions_to_reduce); + return builder_.Reduce(operand.op(), init_value.op(), + local_computation.computation(), dimensions_to_reduce); } -ComputationDataHandle LocalComputationBuilder::ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, +LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { return builder_.ReduceWindowWithGeneralPadding( - operand, init_value, local_computation.computation(), window_dimensions, - window_strides, padding); + operand.op(), init_value.op(), local_computation.computation(), + window_dimensions, window_strides, padding); } -ComputationDataHandle LocalComputationBuilder::RngNormal( - const ComputationDataHandle& mu, const ComputationDataHandle& sigma, - const Shape& shape) { - return builder_.RngNormal(mu, sigma, shape); +LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, + const LocalOp& sigma, + const Shape& shape) { + return builder_.RngNormal(mu.op(), sigma.op(), shape); } -ComputationDataHandle LocalComputationBuilder::RngUniform( - const ComputationDataHandle& a, const ComputationDataHandle& b, - const Shape& shape) { - return builder_.RngUniform(a, b, shape); +LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, + const Shape& shape) { + return builder_.RngUniform(a.op(), b.op(), shape); } -ComputationDataHandle LocalComputationBuilder::While( - const LocalComputation& condition, const LocalComputation& body, - const ComputationDataHandle& init) { - return builder_.While(condition.computation(), body.computation(), init); +LocalOp LocalComputationBuilder::While(const LocalComputation& condition, + const LocalComputation& body, + const LocalOp& init) { + return builder_.While(condition.computation(), body.computation(), init.op()); } -ComputationDataHandle LocalComputationBuilder::Conditional( - const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const LocalComputation& true_computation, - const ComputationDataHandle& false_operand, +LocalOp LocalComputationBuilder::Conditional( + const LocalOp& predicate, const LocalOp& true_operand, + const LocalComputation& true_computation, const LocalOp& false_operand, const LocalComputation& false_computation) { - return builder_.Conditional(predicate, true_operand, - true_computation.computation(), false_operand, - false_computation.computation()); + return builder_.Conditional( + predicate.op(), true_operand.op(), true_computation.computation(), + false_operand.op(), false_computation.computation()); +} + +StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { + return builder_.IsConstant(operand.op()); +} + +StatusOr LocalComputationBuilder::BuildConstantSubGraph( + const LocalOp& operand) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, + builder_.BuildConstantSubGraph(operand.op())); + return new LocalComputation(std::move(computation)); } #define _FORWARD(method_name, return_sig, args_sig, args) \ @@ -520,23 +593,19 @@ ComputationDataHandle LocalComputationBuilder::Conditional( return builder_.method_name args; \ } -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, ComputationDataHandle, \ - (const ComputationDataHandle& operand), (operand)) +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op())) -#define _FORWARD_BINOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions), \ - (lhs, rhs, broadcast_dimensions)) +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions), \ + (lhs.op(), rhs.op(), broadcast_dimensions)) -#define _FORWARD_TRIOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - const ComputationDataHandle& ehs), \ - (lhs, rhs, ehs)) +#define _FORWARD_TRIOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs), \ + (lhs.op(), rhs.op(), ehs.op())) _FORWARD_TRIOP(Select) _FORWARD_TRIOP(Clamp) @@ -558,10 +627,12 @@ _FORWARD_BINOP(Or) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) +_FORWARD_UNOP(Expm1) _FORWARD_UNOP(Floor) _FORWARD_UNOP(Ceil) _FORWARD_UNOP(Round) _FORWARD_UNOP(Log) +_FORWARD_UNOP(Log1p) _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) @@ -591,6 +662,54 @@ void DeleteLocalComputation(LocalComputation* computation) { delete computation; } -} // namespace swig +StatusOr DestructureLocalShapedBufferTuple( + LocalShapedBuffer* local_shaped_buffer) { + if (!ShapeUtil::IsTuple( + local_shaped_buffer->shaped_buffer()->on_device_shape())) { + return InvalidArgument( + "Attemped to destructure a LocalShapedBuffer that did not have a tuple " + "shape; shape: %s", + ShapeUtil::HumanString( + local_shaped_buffer->shaped_buffer()->on_device_shape()) + .c_str()); + } + DeviceMemoryAllocator* allocator = + local_shaped_buffer->shaped_buffer()->memory_allocator(); + ShapedBuffer tuple_buffer = local_shaped_buffer->Release(); + + // Extract some metadata we use to construct scoped buffers. + const se::Platform* platform = tuple_buffer.platform(); + int device_ordinal = tuple_buffer.device_ordinal(); + + ShapeTree& shape_tree = tuple_buffer.buffers(); + const Shape& tuple_shape = tuple_buffer.on_device_shape(); + std::vector results; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { + // Create a shaped buffer for this destructured tuple element. + const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i}); + VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape; + ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal); + + ShapeUtil::ForEachSubshape( + subshape, [&](const Shape& s, const ShapeIndex& index) { + ShapeIndex original(index); + original.push_front(i); + se::DeviceMemoryBase* device_memory = + shape_tree.mutable_element(original); + shaped_buffer.set_buffer(*device_memory, index); + *device_memory = se::DeviceMemoryBase(); + }); + + VLOG(3) << "Completed tuple element: " << i; + results.push_back(new LocalShapedBuffer( + ScopedShapedBuffer(std::move(shaped_buffer), allocator))); + } + // Deallocate the root buffer. + se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer(); + TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer)); + return new LocalShapedBufferTuple(std::move(results)); +} + +} // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index d3e9503ea10b011520ec5148a756ef4d421f244c..0da3964676e9c6729229686f38bb05c8b2427bff 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -17,15 +17,15 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { - namespace swig { // Initializes the number of replicas that XLA will be initialized with (when @@ -59,17 +59,51 @@ StatusOr > TransferFromOutfeedLocalReplica( // client. class LocalShapedBuffer { public: - static LocalShapedBuffer* FromLiteral( + static StatusOr FromLiteral( const Literal& argument, const tensorflow::gtl::optional& shape_with_layout); - LocalShapedBuffer(std::unique_ptr shaped_buffer); - const std::unique_ptr& shaped_buffer() const; - std::unique_ptr ToLiteral() const; + + LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); + const ScopedShapedBuffer* shaped_buffer() const; + + StatusOr > ToLiteral() const; + + // Transfers ownership of the encapsulated ShapedBuffer to the caller, + // analogous to std::unique_ptr::release(). + ShapedBuffer Release(); + + private: + ScopedShapedBuffer shaped_buffer_; +}; + +// Result of a tuple destructuring operation on a LocalShapedBuffer -- this +// appears to be a simpler mechanism for the time being than an alternative like +// using SWIG to transform std::vectors into Python lists of SWIG objects +// directly. +class LocalShapedBufferTuple { + public: + // Note: any LocalShapedBuffer elements that are not Release()'d will be + // deallocated in the destructor. + explicit LocalShapedBufferTuple(std::vector elements); + + ~LocalShapedBufferTuple(); + + // Releases the ith element to the caller. Further attempts to release the ith + // element will return an invalid argument error. + StatusOr Release(int i); + + // Returns the number of elements in the destructured tuple. + int size() const; private: - std::unique_ptr shaped_buffer_; + std::vector elements_; }; +// Destructures a tuple-valued LocalShapedBuffer into its constitutent elements +// in LocalShapedBufferTuple form. +StatusOr DestructureLocalShapedBufferTuple( + LocalShapedBuffer* local_shaped_buffer); + // Wraps a LocalExecutable produced by compiling a // LocalComputation. The Execute method forwards to that of the // underlying LocalExecutable, and additionally handles tranferring @@ -95,25 +129,42 @@ class CompiledLocalComputation { std::unique_ptr executable_; }; -// Wraps a Computation produced by a LocalComputationBuilder. The +// Wraps a XlaComputation produced by a LocalComputationBuilder. The // Compile method compiles the computation to a (local) executable via // the client library's local client. This class is intended to be // made available to Python via SWIG. class LocalComputation { public: - LocalComputation(Computation computation); + LocalComputation(XlaComputation computation); StatusOr Compile( const std::vector& argument_shapes, const ExecutableBuildOptions* build_options); - const Computation& computation() const; + const XlaComputation& computation() const; + + // Returns the HloModuleProto contained in the XlaComputation in the + // serialized binary format. Logs an internal error and returns an empty + // string on failure. + string GetSerializedProto() const; // Returns the return-value shape for this computation. StatusOr GetReturnValueShape() const; private: - Computation computation_; + XlaComputation computation_; +}; + +// Wraps a XlaOp produced by a LocalComputationBuilder. This class is intended +// to be made available to Python via SWIG. +class LocalOp { + public: + LocalOp(const XlaOp& op); + + const XlaOp& op() const; + + private: + XlaOp op_; }; // Wraps the ComputationBuilder API in order to: @@ -133,155 +184,137 @@ class LocalComputationBuilder { // Returns an owned LocalComputation to the caller on success. StatusOr Build(); - ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, - const string& name); + LocalOp Parameter(int64 parameter_number, const Shape& shape, + const string& name); - std::unique_ptr GetShape(const ComputationDataHandle& operand); + std::unique_ptr GetShape(const LocalOp& operand); // Returns the shape of the current return value for the computation. StatusOr GetReturnValueShape(); - ComputationDataHandle Infeed(const Shape& shape); + LocalOp Infeed(const Shape& shape); - void Outfeed(const ComputationDataHandle& operand, const Shape& shape, + void Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config); - ComputationDataHandle ConstantLiteral(const Literal& literal); + LocalOp ConstantLiteral(const Literal& literal); + + LocalOp Broadcast(const LocalOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); - ComputationDataHandle Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, + const PaddingConfig& padding_config); - ComputationDataHandle Pad(const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config); + LocalOp Reshape(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + LocalOp Collapse(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions); - ComputationDataHandle Collapse(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp CrossReplicaSum(const LocalOp& operand); - ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand); + LocalOp Slice(const LocalOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); - ComputationDataHandle Slice(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + LocalOp SliceInDim(const LocalOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno); - ComputationDataHandle DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); - ComputationDataHandle DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices); + LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update, + const LocalOp& start_indices); - ComputationDataHandle ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension); + LocalOp ConcatInDim(tensorflow::gtl::ArraySlice operands, + int64 dimension); - ComputationDataHandle SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const LocalComputation& select, + LocalOp SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const LocalComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const LocalComputation& scatter); + const LocalOp& source, const LocalOp& init_value, + const LocalComputation& scatter); - ComputationDataHandle Tuple( - tensorflow::gtl::ArraySlice elements); + LocalOp Tuple(tensorflow::gtl::ArraySlice elements); - ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, - int64 index); + LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index); - ComputationDataHandle Dot(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); + LocalOp Dot(const LocalOp& lhs, const LocalOp& rhs); - ComputationDataHandle DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers); + LocalOp DotGeneral(const LocalOp& lhs, const LocalOp& rhs, + const DotDimensionNumbers& dimension_numbers); - ComputationDataHandle ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + LocalOp ConvGeneralDilated( + const LocalOp& lhs, const LocalOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers); - ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, - PrimitiveType new_element_type); + LocalOp ConvertElementType(const LocalOp& operand, + PrimitiveType new_element_type); - ComputationDataHandle Call( - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands); + LocalOp Call(const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice operands); - ComputationDataHandle Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation); + LocalOp Transpose(const LocalOp& operand, + tensorflow::gtl::ArraySlice permutation); - ComputationDataHandle Rev(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Rev(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions); - ComputationDataHandle Map( - tensorflow::gtl::ArraySlice operands, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands); + LocalOp Map(tensorflow::gtl::ArraySlice operands, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands); - ComputationDataHandle Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); - ComputationDataHandle ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, + LocalOp ReduceWindowWithGeneralPadding( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding); - ComputationDataHandle RngNormal(const ComputationDataHandle& mu, - const ComputationDataHandle& sigma, - const Shape& shape); + LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, + const Shape& shape); - ComputationDataHandle RngUniform(const ComputationDataHandle& a, - const ComputationDataHandle& b, - const Shape& shape); + LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape); - ComputationDataHandle While(const LocalComputation& condition, - const LocalComputation& body, - const ComputationDataHandle& init); + LocalOp While(const LocalComputation& condition, const LocalComputation& body, + const LocalOp& init); - ComputationDataHandle Conditional(const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const LocalComputation& true_computation, - const ComputationDataHandle& false_operand, - const LocalComputation& false_computation); + LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand, + const LocalComputation& true_computation, + const LocalOp& false_operand, + const LocalComputation& false_computation); + + StatusOr IsConstant(const LocalOp& operand); + + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, ComputationDataHandle, \ - (const ComputationDataHandle& operand)) +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, LocalOp, (const LocalOp& operand)) -#define _FORWARD_BINOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions)) +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions)) -#define _FORWARD_TRIOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - const ComputationDataHandle& ehs)) +#define _FORWARD_TRIOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs)) _FORWARD_TRIOP(Select) _FORWARD_TRIOP(Clamp) @@ -303,10 +336,12 @@ class LocalComputationBuilder { _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) + _FORWARD_UNOP(Expm1) _FORWARD_UNOP(Floor) _FORWARD_UNOP(Ceil) _FORWARD_UNOP(Round) _FORWARD_UNOP(Log) + _FORWARD_UNOP(Log1p) _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) @@ -325,7 +360,7 @@ class LocalComputationBuilder { #undef _FORWARD_TRIOP private: - ComputationBuilder builder_; + XlaBuilder builder_; }; // Functions for freeing resources from the Python side. @@ -334,7 +369,6 @@ void DeleteCompiledLocalComputation(CompiledLocalComputation* computation); void DeleteLocalComputation(LocalComputation* computation); } // namespace swig - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 456e341f877e529f7fc5ebc81d85862bfa291943..477df6fde25d0db760e08df9d335bd12e31ccb55 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -22,9 +22,8 @@ limitations under the License. // // C++ Python // -------------------------------------+--------------------------------------- -// ComputationDataHandle <-> int // ArraySlice <- sequence of int -// ArraySlice <- sequence of int +// ArraySlice <- sequence of LocalOp // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray // Shape -> pair holding (dtype, dimensions) @@ -91,12 +90,9 @@ limitations under the License. // One central reason for the Python-side indirection is that the // Python-side objects produced by the typemaps in this file are // further packaged up by xla_client before being passed on. For -// instance, xla_client wraps the long produced for a C++ -// ComputationDataHandle in a Python ComputationDataHandle proto, -// rather than exposing a raw long outside of the client. Similarly, -// the Python pair produced for a C++ Shape is further wrapped in a -// Python class (xla_client.Shape) so as not to expose the raw pair -// externally. +// instance, the Python pair produced for a C++ Shape is further +// wrapped in a Python class (xla_client.Shape) so as not to expose +// the raw pair externally. // // Other SWIG object wrappers (e.g. of LocalComputation) are further // wrapped by xla_client in order to set up a custom destructor that @@ -124,6 +120,7 @@ using namespace xla; using namespace xla::swig; namespace xla { + namespace swig { bool GetIntAttr(PyObject* o, const char* field, int64* result) { @@ -141,6 +138,33 @@ bool GetIntAttr(PyObject* o, const char* field, int64* result) { return true; } +// Returns "ok"; true if there is no error, false if there was an error. +bool HandleStringAttribute(PyObject* o, + const char* attr_name, + std::function f) { + if (!PyObject_HasAttrString(o, attr_name)) { + return true; // It's ok for the object to not have the attribute. + } + PyObject* attr = PyObject_GetAttrString(o, attr_name); + if (attr == nullptr) { + return false; // An error occurred getting the attribute. + } + if (attr == Py_None) { + Py_DECREF(attr); + return true; // The attribute is None, which we consider ok. + } + if (!PyString_Check(attr)) { + string message = tensorflow::strings::Printf("%s must be a string or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr).c_str()); + PyErr_SetString(PyExc_TypeError, message.c_str()); + Py_DECREF(attr); + return false; // Type error, not ok. + } + f(PyString_AsString(attr)); + Py_DECREF(attr); + return true; // Handled string attribute, ok! +} + } } %} @@ -150,41 +174,53 @@ bool GetIntAttr(PyObject* o, const char* field, int64* result) { tensorflow::ImportNumpy(); %} -// ComputationDataHandle - -%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) { - const int64 handle = numpy::PyIntOrPyLongToLong($input); - if (handle == -1 && PyErr_Occurred()) { - return NULL; +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::CompiledLocalComputation*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; } - temp.set_handle(handle); - $1 = &temp; } -%typemap(out) ComputationDataHandle { - $result = numpy::LongToPyIntOrPyLong($1.handle()); +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::LocalShapedBuffer*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } } -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); { auto* $1 = value; - $typemap(out, xla::swig::CompiledLocalComputation*) + $typemap(out, xla::swig::LocalShapedBufferTuple*) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; } } + %typemap(out) StatusOr< std::unique_ptr > { if ($1.ok()) { std::unique_ptr value = $1.ConsumeValueOrDie(); $result = numpy::PyObjectFromXlaLiteral(*value); } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; } } @@ -197,7 +233,7 @@ tensorflow::ImportNumpy(); } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; } } @@ -206,7 +242,16 @@ tensorflow::ImportNumpy(); $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()); } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = PyBool_FromLong($1.ConsumeValueOrDie()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; } } @@ -214,8 +259,9 @@ tensorflow::ImportNumpy(); if (!$1.ok()) { PyErr_SetString( PyExc_RuntimeError, $1.ToString().c_str()); - return NULL; + SWIG_fail; } + Py_INCREF(Py_None); $result = Py_None; } @@ -225,7 +271,7 @@ tensorflow::ImportNumpy(); (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); temps.resize(size); @@ -237,13 +283,13 @@ tensorflow::ImportNumpy(); PyExc_TypeError, "Argument sequence element cannot be converted to int"); Py_DECREF(o); - return NULL; + SWIG_fail; } temps[i] = numpy::PyIntOrPyLongToLong(py_int); if (temps[i] == -1 && PyErr_Occurred()) { Py_DECREF(py_int); Py_DECREF(o); - return NULL; + SWIG_fail; } Py_DECREF(py_int); Py_DECREF(o); @@ -251,33 +297,23 @@ tensorflow::ImportNumpy(); $1 = temps; } -// ComputationDataHandle +// ArraySlice -%typemap(in) tensorflow::gtl::ArraySlice - (std::vector temps) { +%typemap(in) tensorflow::gtl::ArraySlice( + std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); - temps.resize(size); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - PyObject* py_int = numpy::PyNumberToPyInt(o); - if (!py_int) { - PyErr_SetString( - PyExc_TypeError, - "Argument sequence element cannot be converted to int"); - return NULL; - } - const int64 handle = numpy::PyIntOrPyLongToLong(py_int); - if (handle == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - Py_DECREF(o); - return NULL; + LocalOp* op; + if ((SWIG_ConvertPtr(o, (void**)&op, $descriptor(xla::swig::LocalOp*), + SWIG_POINTER_EXCEPTION)) == -1) { + SWIG_fail; } - temps[i].set_handle(handle); - Py_DECREF(py_int); + temps.push_back(*op); Py_DECREF(o); } $1 = temps; @@ -289,7 +325,7 @@ tensorflow::ImportNumpy(); (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); temps.reserve(size); @@ -298,7 +334,7 @@ tensorflow::ImportNumpy(); LocalShapedBuffer* lsbp; if ((SWIG_ConvertPtr(o, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*), SWIG_POINTER_EXCEPTION)) == -1) { - return NULL; + SWIG_fail; } temps.push_back(lsbp); Py_DECREF(o); @@ -312,7 +348,7 @@ tensorflow::ImportNumpy(); literal_status = numpy::XlaLiteralFromPyObject($input); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); - return NULL; + SWIG_fail; } $1 = literal_status.ValueOrDie().get(); } @@ -324,7 +360,7 @@ tensorflow::ImportNumpy(); %typemap(out) StatusOr< std::unique_ptr > { if (!$1.ok()) { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; } $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie()); } @@ -332,7 +368,7 @@ tensorflow::ImportNumpy(); %typemap(in) const std::vector& (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { @@ -341,7 +377,7 @@ tensorflow::ImportNumpy(); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); Py_DECREF(o); - return NULL; + SWIG_fail; } temps.push_back(std::move(*literal_status.ConsumeValueOrDie())); Py_DECREF(o); @@ -355,7 +391,7 @@ tensorflow::ImportNumpy(); StatusOr statusor = numpy::OpMetadataFromPyObject($input); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temp = std::move(statusor).ValueOrDie(); $1 = &temp; @@ -367,7 +403,7 @@ tensorflow::ImportNumpy(); StatusOr statusor = numpy::XlaShapeFromPyShape($input); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temp = std::move(statusor).ValueOrDie(); $1 = &temp; @@ -382,7 +418,7 @@ tensorflow::ImportNumpy(); StatusOr statusor = numpy::XlaShapeFromPyShape($input); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temp = std::move(statusor).ValueOrDie(); $1 = &temp; @@ -396,7 +432,7 @@ tensorflow::ImportNumpy(); %typemap(in) const std::vector& (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { @@ -405,7 +441,7 @@ tensorflow::ImportNumpy(); Py_DECREF(o); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temps.push_back(statusor.ConsumeValueOrDie()); } @@ -416,7 +452,7 @@ tensorflow::ImportNumpy(); std::vector > temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { @@ -428,7 +464,7 @@ tensorflow::ImportNumpy(); Py_DECREF(o); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temps.push_back(statusor.ConsumeValueOrDie()); } @@ -442,18 +478,18 @@ tensorflow::ImportNumpy(); PyObject* py_int = numpy::PyNumberToPyInt($input); if (!py_int) { PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); - return NULL; + SWIG_fail; } const long value = numpy::PyIntOrPyLongToLong(py_int); if (value == -1 && PyErr_Occurred()) { Py_DECREF(py_int); - return NULL; + SWIG_fail; } if (!PrimitiveType_IsValid(value)) { PyErr_SetString( PyExc_TypeError, "Argument not valid for PrimitiveType enum"); Py_DECREF(py_int); - return NULL; + SWIG_fail; } $1 = static_cast(value); } @@ -464,19 +500,19 @@ tensorflow::ImportNumpy(); (std::vector > temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); temps.reserve(size); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); if (!o) { - return NULL; + SWIG_fail; } PyObject* first = PyTuple_GetItem(o, 0); if (!first) { Py_DECREF(o); - return NULL; + SWIG_fail; } PyObject* first_pyint = numpy::PyNumberToPyInt(first); if (!first_pyint) { @@ -484,13 +520,13 @@ tensorflow::ImportNumpy(); PyExc_TypeError, "First pair item cannot be converted to int"); Py_DECREF(o); - return NULL; + SWIG_fail; } PyObject* second = PyTuple_GetItem(o, 1); if (!second) { Py_DECREF(o); Py_DECREF(first_pyint); - return NULL; + SWIG_fail; } PyObject* second_pyint = numpy::PyNumberToPyInt(second); if (!second_pyint) { @@ -499,21 +535,21 @@ tensorflow::ImportNumpy(); "Second pair item cannot be converted to int"); Py_DECREF(o); Py_DECREF(first_pyint); - return NULL; + SWIG_fail; } const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint); if (first_value == -1 && PyErr_Occurred()) { Py_DECREF(o); Py_DECREF(first_pyint); Py_DECREF(second_pyint); - return NULL; + SWIG_fail; } const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint); if (second_value == -1 && PyErr_Occurred()) { Py_DECREF(o); Py_DECREF(first_pyint); Py_DECREF(second_pyint); - return NULL; + SWIG_fail; } temps.push_back(std::make_pair(first_value, second_value)); Py_DECREF(o); @@ -531,26 +567,26 @@ tensorflow::ImportNumpy(); PyObject* lhs_contracting_dimensions = PyObject_GetAttrString( $input, "lhs_contracting_dimensions"); if (!lhs_contracting_dimensions) { - return NULL; + SWIG_fail; } length = PySequence_Size(lhs_contracting_dimensions); if (length == -1) { Py_DECREF(lhs_contracting_dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i); if (!item) { Py_DECREF(lhs_contracting_dimensions); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(lhs_contracting_dimensions); - return NULL; + SWIG_fail; } dimension_numbers.add_lhs_contracting_dimensions(dimension); Py_DECREF(item); @@ -561,26 +597,26 @@ tensorflow::ImportNumpy(); PyObject* rhs_contracting_dimensions = PyObject_GetAttrString( $input, "rhs_contracting_dimensions"); if (!lhs_contracting_dimensions) { - return NULL; + SWIG_fail; } length = PySequence_Size(rhs_contracting_dimensions); if (length == -1) { Py_DECREF(rhs_contracting_dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i); if (!item) { Py_DECREF(rhs_contracting_dimensions); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(rhs_contracting_dimensions); - return NULL; + SWIG_fail; } dimension_numbers.add_rhs_contracting_dimensions(dimension); Py_DECREF(item); @@ -591,26 +627,26 @@ tensorflow::ImportNumpy(); PyObject* lhs_batch_dimensions = PyObject_GetAttrString( $input, "lhs_batch_dimensions"); if (!lhs_batch_dimensions) { - return NULL; + SWIG_fail; } length = PySequence_Size(lhs_batch_dimensions); if (length == -1) { Py_DECREF(lhs_batch_dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i); if (!item) { Py_DECREF(lhs_batch_dimensions); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(lhs_batch_dimensions); - return NULL; + SWIG_fail; } dimension_numbers.add_lhs_batch_dimensions(dimension); Py_DECREF(item); @@ -621,26 +657,26 @@ tensorflow::ImportNumpy(); PyObject* rhs_batch_dimensions = PyObject_GetAttrString( $input, "rhs_batch_dimensions"); if (!rhs_batch_dimensions) { - return NULL; + SWIG_fail; } length = PySequence_Size(rhs_batch_dimensions); if (length == -1) { Py_DECREF(rhs_batch_dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i); if (!item) { Py_DECREF(rhs_batch_dimensions); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(rhs_batch_dimensions); - return NULL; + SWIG_fail; } dimension_numbers.add_rhs_batch_dimensions(dimension); Py_DECREF(item); @@ -656,20 +692,20 @@ tensorflow::ImportNumpy(); (PaddingConfig padding_config) { PyObject* dimensions = PyObject_GetAttrString($input, "dimensions"); if (!dimensions) { - return NULL; + SWIG_fail; } int length = PySequence_Size(dimensions); if (length == -1) { Py_DECREF(dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(dimensions, i); if (!item) { Py_DECREF(dimensions); - return NULL; + SWIG_fail; } int64 edge_padding_low, edge_padding_high, interior_padding; if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low) @@ -677,7 +713,7 @@ tensorflow::ImportNumpy(); || !GetIntAttr(item, "interior_padding", &interior_padding)) { Py_DECREF(item); Py_DECREF(dimensions); - return NULL; + SWIG_fail; } Py_DECREF(item); @@ -699,32 +735,32 @@ tensorflow::ImportNumpy(); int64 value; if (!GetIntAttr($input, "input_batch_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_input_batch_dimension(value); if (!GetIntAttr($input, "input_feature_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_input_feature_dimension(value); if (!GetIntAttr($input, "output_batch_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_output_batch_dimension(value); if (!GetIntAttr($input, "output_feature_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_output_feature_dimension(value); if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_kernel_output_feature_dimension(value); if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_kernel_input_feature_dimension(value); @@ -733,24 +769,24 @@ tensorflow::ImportNumpy(); o = PyObject_GetAttrString($input, "input_spatial_dimensions"); if (!o) { - return NULL; + SWIG_fail; } length = PySequence_Size(o); if (length == -1) { Py_DECREF(o); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(o, i); if (!item) { Py_DECREF(o); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(o); - return NULL; + SWIG_fail; } dimension_numbers.add_input_spatial_dimensions(dimension); Py_DECREF(item); @@ -759,24 +795,24 @@ tensorflow::ImportNumpy(); o = PyObject_GetAttrString($input, "kernel_spatial_dimensions"); if (!o) { - return NULL; + SWIG_fail; } length = PySequence_Size(o); if (length == -1) { Py_DECREF(o); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(o, i); if (!item) { Py_DECREF(o); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(o); - return NULL; + SWIG_fail; } dimension_numbers.add_kernel_spatial_dimensions(dimension); Py_DECREF(item); @@ -785,24 +821,24 @@ tensorflow::ImportNumpy(); o = PyObject_GetAttrString($input, "output_spatial_dimensions"); if (!o) { - return NULL; + SWIG_fail; } length = PySequence_Size(o); if (length == -1) { Py_DECREF(o); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(o, i); if (!item) { Py_DECREF(o); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(o); - return NULL; + SWIG_fail; } dimension_numbers.add_output_spatial_dimensions(dimension); Py_DECREF(item); @@ -819,16 +855,37 @@ tensorflow::ImportNumpy(); if ($input == Py_None) { $1 = NULL; } else { - PyObject* o = PyObject_GetAttrString($input, "generate_hlo_graph"); - if (!o) { - return NULL; + if (!HandleStringAttribute($input, "generate_hlo_graph", [&](string s) { + build_options.set_generate_hlo_graph(std::move(s)); + })) { + return nullptr; + } + if (!HandleStringAttribute($input, "dump_optimized_hlo_proto_to", [&](string s) { + build_options.set_dump_optimized_hlo_proto_to(std::move(s)); + })) { + return nullptr; + } + if (!HandleStringAttribute($input, "dump_unoptimized_hlo_proto_to", [&](string s) { + build_options.set_dump_unoptimized_hlo_proto_to(std::move(s)); + })) { + return nullptr; + } + if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) { + build_options.set_dump_per_pass_hlo_proto_to(std::move(s)); + })) { + return nullptr; + } + + PyObject* o = PyObject_GetAttrString($input, "hlo_profile"); + if (o == NULL) { + SWIG_fail; } if (o != Py_None) { - if (!PyString_Check(o)) { - PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.generate_hlo_graph must be a string or None."); - return NULL; + if (!PyBool_Check(o)) { + PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.hlo_profile must be a bool or None."); + SWIG_fail; } - build_options.set_generate_hlo_graph(PyString_AsString(o)); + build_options.set_hlo_profile(o == Py_True); } Py_DECREF(o); @@ -841,7 +898,7 @@ tensorflow::ImportNumpy(); if (!statusor.ok()) { PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str()); Py_DECREF(o); - return NULL; + SWIG_fail; } build_options.set_result_layout(statusor.ValueOrDie()); } @@ -862,12 +919,17 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalShapedBuffer; %unignore xla::swig::LocalShapedBuffer::FromLiteral; %unignore xla::swig::LocalShapedBuffer::ToLiteral; +%unignore xla::swig::LocalShapedBufferTuple; +%unignore xla::swig::LocalShapedBufferTuple::Release; +%unignore xla::swig::LocalShapedBufferTuple::size; %unignore xla::swig::CompiledLocalComputation; %unignore xla::swig::CompiledLocalComputation::Execute; %unignore xla::swig::CompiledLocalComputation::ExecuteWithShapedBuffers; %unignore xla::swig::LocalComputation; %unignore xla::swig::LocalComputation::Compile; %unignore xla::swig::LocalComputation::GetReturnValueShape; +%unignore xla::swig::LocalComputation::GetSerializedProto; +%unignore xla::swig::LocalOp; %unignore xla::swig::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::Build; @@ -886,6 +948,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Collapse; %unignore xla::swig::LocalComputationBuilder::CrossReplicaSum; %unignore xla::swig::LocalComputationBuilder::Slice; +%unignore xla::swig::LocalComputationBuilder::SliceInDim; %unignore xla::swig::LocalComputationBuilder::DynamicSlice; %unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice; %unignore xla::swig::LocalComputationBuilder::ConcatInDim; @@ -906,6 +969,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::RngBernoulli; %unignore xla::swig::LocalComputationBuilder::While; %unignore xla::swig::LocalComputationBuilder::Conditional; +%unignore xla::swig::LocalComputationBuilder::IsConstant; %unignore xla::swig::LocalComputationBuilder::Eq; %unignore xla::swig::LocalComputationBuilder::Ne; %unignore xla::swig::LocalComputationBuilder::Ge; @@ -927,10 +991,12 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Not; %unignore xla::swig::LocalComputationBuilder::Abs; %unignore xla::swig::LocalComputationBuilder::Exp; +%unignore xla::swig::LocalComputationBuilder::Expm1; %unignore xla::swig::LocalComputationBuilder::Floor; %unignore xla::swig::LocalComputationBuilder::Ceil; %unignore xla::swig::LocalComputationBuilder::Round; %unignore xla::swig::LocalComputationBuilder::Log; +%unignore xla::swig::LocalComputationBuilder::Log1p; %unignore xla::swig::LocalComputationBuilder::Sign; %unignore xla::swig::LocalComputationBuilder::Cos; %unignore xla::swig::LocalComputationBuilder::Sin; @@ -942,6 +1008,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::ReciprocalF32; %unignore xla::swig::LocalComputationBuilder::Neg; %unignore xla::swig::LocalComputationBuilder::Sort; +%unignore xla::swig::DestructureLocalShapedBufferTuple; %unignore xla::swig::DeleteLocalShapedBuffer; %unignore xla::swig::DeleteLocalComputation; %unignore xla::swig::DeleteCompiledLocalComputation; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 3d87480728aab1d4ebbc71c6c7504d37cae5edaf..68648a3a176363de69a56ecb8070f82862874e94 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -170,8 +170,7 @@ static string PyObjectCppStr(PyObject* o) { return ExtractStringAndDecref(s); } -// Safely returns a repr of the given Python object o as a C++ string. -static string PyObjectCppRepr(PyObject* o) { +string PyObjectCppRepr(PyObject* o) { PyObject* r = PyObject_Repr(o); return ExtractStringAndDecref(r); } @@ -182,16 +181,6 @@ StatusOr XlaShapeFromPyShape(PyObject* o) { PyObjectCppRepr(o).c_str()); }; - auto get_attr = [o, &error](const string& field) -> StatusOr { - PyObject* result = - PyObject_GetAttrString(o, const_cast(field.c_str())); - if (result == nullptr) { - return error(tensorflow::strings::StrCat( - "Failed to get attribute of Shape object:", field)); - } - return result; - }; - auto call_method = [o, &error](const string& method) -> StatusOr { PyObject* result = PyObject_CallMethod(o, const_cast(method.c_str()), nullptr); @@ -203,12 +192,16 @@ StatusOr XlaShapeFromPyShape(PyObject* o) { }; PyObject* np_type; - TF_ASSIGN_OR_RETURN(np_type, get_attr("np_dtype")); + TF_ASSIGN_OR_RETURN(np_type, call_method("numpy_dtype")); if (np_type->ob_type != &PyArrayDescr_Type) { - return error("Shape attribute np_dtype is not an integer numpy dtype"); + return error( + "Return value of shape method numpy_dtype " + "is not an integer numpy dtype"); } if (!NumpyTypeIsValid(NumpyTypenum(np_type))) { - return error("Shape attribute np_dtype is not a valid integer numpy dtype"); + return error( + "Return value of shape method numpy_dtype " + "is not a valid integer numpy dtype"); } const PrimitiveType element_type = NumpyTypeToPrimitiveType(NumpyTypenum(np_type)); @@ -347,13 +340,13 @@ StatusOr OpMetadataFromPyObject(PyObject* o) { return result; } -PyObject* PyObjectFromXlaLiteral(const Literal& literal) { +PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { if (ShapeUtil::IsTuple(literal.shape())) { int num_elements = ShapeUtil::TupleElementCount(literal.shape()); PyObject* tuple = PyTuple_New(num_elements); for (int i = 0; i < num_elements; i++) { - PyTuple_SET_ITEM( - tuple, i, PyObjectFromXlaLiteral(LiteralView::Create(literal, {i}))); + PyTuple_SET_ITEM(tuple, i, + PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); } return tuple; } else { @@ -438,7 +431,7 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, return Status::OK(); } -void CopyLiteralToNumpyArray(int np_type, const Literal& literal, +void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, PyArrayObject* py_array) { switch (np_type) { case NPY_BOOL: diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index adfcc3b8588dce01718bb19dea936bace483be4d..64f0aae0f9790f0199ac6cb931a5c9f6dc356f4c 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -74,7 +74,7 @@ StatusOr OpMetadataFromPyObject(PyObject* o); // array data. // // The return value is a new reference. -PyObject* PyObjectFromXlaLiteral(const Literal& literal); +PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal); // Converts a Numpy ndarray or a nested Python tuple thereof to a // corresponding XLA literal. @@ -90,7 +90,7 @@ StatusOr > XlaLiteralFromPyObject(PyObject* o); Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, Literal* literal); -void CopyLiteralToNumpyArray(int np_type, const Literal& literal, +void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, PyArrayObject* py_array); template @@ -101,12 +101,16 @@ void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) { } template -void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) { +void CopyLiteralToNumpyArray(const LiteralSlice& literal, + PyArrayObject* py_array) { NativeT* dest = static_cast(PyArray_DATA(py_array)); auto source = literal.data(); std::copy(source.begin(), source.end(), dest); } +// Safely returns a repr of the given Python object o as a C++ string. +string PyObjectCppRepr(PyObject* o); + // Workarounds for Python 2 and 3 interop PyObject* LongToPyIntOrPyLong(long x); // NOLINT diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 9bda9d09294bc75acaa35d8e4a512820046e8920..c025127c3cf1871d4def1297ed36c046cae61d4b 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -28,11 +28,12 @@ import numpy as np from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api +from tensorflow.compiler.xla.service import hlo_pb2 -# Most functions are snake_case for consistency with other modules, -# whereas method names of ComputationBuilder and LocalComputation are -# CamelCase for consistency with XLA. +# Most functions are snake_case for consistency with other modules, whereas +# method names of ComputationBuilder and LocalComputation are CamelCase for +# consistency with XLA. # pylint: disable=invalid-name @@ -88,10 +89,12 @@ _UNARY_OPS = [ 'Not', 'Abs', 'Exp', + 'Expm1', 'Floor', 'Round', 'Ceil', 'Log', + 'Log1p', 'Sign', 'Cos', 'Sin', @@ -123,24 +126,34 @@ _BINARY_OPS = [ 'Pow', ] + XLA_ELEMENT_TYPE_TO_DTYPE = { - xla_data_pb2.F32: np.dtype(np.float32), - xla_data_pb2.F64: np.dtype(np.float64), - xla_data_pb2.S32: np.dtype(np.int32), - xla_data_pb2.S64: np.dtype(np.int64), - xla_data_pb2.U32: np.dtype(np.uint32), - xla_data_pb2.U64: np.dtype(np.uint64), - xla_data_pb2.PRED: np.dtype(np.bool), + xla_data_pb2.PRED: np.dtype('bool'), + xla_data_pb2.S8: np.dtype('int8'), + xla_data_pb2.S16: np.dtype('int16'), + xla_data_pb2.S32: np.dtype('int32'), + xla_data_pb2.S64: np.dtype('int64'), + xla_data_pb2.U8: np.dtype('uint8'), + xla_data_pb2.U16: np.dtype('uint16'), + xla_data_pb2.U32: np.dtype('uint32'), + xla_data_pb2.U64: np.dtype('uint64'), + xla_data_pb2.F16: np.dtype('float16'), + xla_data_pb2.F32: np.dtype('float32'), + xla_data_pb2.F64: np.dtype('float64'), + xla_data_pb2.C64: np.dtype('complex64'), xla_data_pb2.TUPLE: np.dtype(np.object), } # Note the conversion on the key. Numpy has a known issue wherein dtype hashing # doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, # when keying by dtype in this dict, we use the string form of dtypes. -DTYPE_TO_XLA_ELEMENT_TYPE = { - str(v): k - for k, v in XLA_ELEMENT_TYPE_TO_DTYPE.items() -} +DTYPE_TO_XLA_ELEMENT_TYPE = {str(dt): et + for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()} + + +def dtype_to_etype(dtype): + """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE.""" + return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] class LocalBuffer(object): @@ -156,14 +169,14 @@ class LocalBuffer(object): self._delete = c_api.DeleteLocalShapedBuffer @staticmethod - def from_py(npval, layout_fn=None): - npval = require_numpy_array_layout(npval) + def from_pyval(pyval, layout_fn=None): + pyval = require_numpy_array_layout(pyval) if layout_fn: - shape = Shape.from_numpy(npval) + shape = Shape.from_pyval(pyval) shape = shape.map_leaves(layout_fn) else: shape = None - return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval, shape)) + return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(pyval, shape)) def to_py(self): return self.c_local_shaped_buffer.ToLiteral() @@ -173,6 +186,14 @@ class LocalBuffer(object): self._delete(self.c_local_shaped_buffer) self.c_local_shaped_buffer = None + def destructure(self): + assert self.c_local_shaped_buffer is not None + result = c_api.DestructureLocalShapedBufferTuple(self.c_local_shaped_buffer) + self.c_local_shaped_buffer = None + size = result.size() + destructured = tuple(LocalBuffer(result.Release(i)) for i in xrange(size)) + return destructured + def is_deleted(self): return self.c_local_shaped_buffer is None @@ -181,53 +202,104 @@ class LocalBuffer(object): class Shape(object): - """XLA shape. + """Represents an XLA shape. - Represents an XLA shape by a corresponding Python/Numpy type and a - list of dimensions, which are themselves Shapes in case this one - represents an XLA tuple. + A shape is either an array shape, having rank-many integer + dimensions and an element type (represented by a Numpy dtype), or it + is a tuple shape, having a shape for every tuple component: + + type shape = + TupleShape of shape list + | ArrayShape of { dimensions: int list; element_type: dtype } + + Callers are expected to instantiate this class only via the static + constructors: tuple_shape, array_shape, and from_pyval. """ - def __init__(self, np_dtype, dimensions, minor_to_major=None): + @staticmethod + def tuple_shape(tuple_shapes): + """Construct a tuple shape.""" + if (not isinstance(tuple_shapes, (tuple, list)) or + not all(isinstance(t, Shape) for t in tuple_shapes)): + raise TypeError('tuple_shapes must be a tuple of Shapes') + return Shape(tuple_shapes, tuple) + + @staticmethod + def array_shape(element_type, dimensions, minor_to_major=None): + """Construct an array shape.""" + if (not isinstance(dimensions, tuple) or + not all(isinstance(i, int) for i in dimensions)): + dimensions = tuple(int(i) for i in dimensions) + return Shape(dimensions, np.dtype(element_type), + minor_to_major=minor_to_major) + + @staticmethod + def from_pyval(pyval): + def convert(pyval): + if isinstance(pyval, tuple): + return Shape.tuple_shape(tuple(convert(elt) for elt in pyval)) + else: + pyval = require_numpy_array_layout(pyval) + return Shape.array_shape(pyval.dtype, np.shape(pyval)) + return convert(pyval) + + def __init__(self, dimensions, dtype, minor_to_major=None): assert isinstance(dimensions, tuple) - self.np_dtype = np_dtype self._dimensions = dimensions + self._dtype = dtype + self._is_tuple = dtype == tuple self._minor_to_major = minor_to_major self._check_minor_to_major() def __eq__(self, other): # pylint: disable=protected-access - return (self.np_dtype == other.np_dtype and + return (self._dtype == other._dtype and self._dimensions == other._dimensions and self._minor_to_major == other._minor_to_major) def __repr__(self): - return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, ' - 'minor_to_major={!r})').format(self.np_dtype, self._dimensions, - self._minor_to_major) - - def element_type(self): - return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)] + return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, ' + '_is_tuple={!r}), _minor_to_major={!r}').format( + self._dtype, self._dimensions, self._is_tuple, + self._minor_to_major) def is_tuple(self): - return self.element_type() == xla_data_pb2.TUPLE + return self._is_tuple - def dimensions(self): - if self.is_tuple(): - raise ValueError('Tuple shape has no dimensions') - return self._dimensions - - def minor_to_major(self): - return self._minor_to_major + def is_array(self): + return not self._is_tuple def tuple_shapes(self): if not self.is_tuple(): - raise ValueError('Shape is not a tuple shape') + raise ValueError('not a tuple shape') + return self._dimensions + + def numpy_dtype(self): + """Like element_type(), but returns dtype('O') in case of a tuple shape.""" + if self.is_tuple(): + return np.dtype(np.object) + else: + return self.element_type() + + def xla_element_type(self): + return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.numpy_dtype())] + + def element_type(self): + if not self.is_array(): + raise ValueError('not an array shape') + return self._dtype + + def dimensions(self): + if not self.is_array(): + raise ValueError('not an array shape') return self._dimensions def rank(self): return len(self.dimensions()) + def minor_to_major(self): + return self._minor_to_major + def map_leaves(self, f): """Map f over each leaf-level array subshape. @@ -240,7 +312,7 @@ class Shape(object): """ if self.is_tuple(): children = tuple(child.map_leaves(f) for child in self.tuple_shapes()) - return Shape(np.dtype('O'), children) + return Shape.tuple_shape(children) else: mapped = f(self) return self if mapped is None else mapped @@ -254,44 +326,24 @@ class Shape(object): assert sorted(mtm) == range(len(mtm)), self def update_minor_to_major(self, minor_to_major): + if not self.is_array(): + raise ValueError('not an array shape') if not isinstance(minor_to_major, tuple): raise TypeError('minor_to_major must be a tuple') - updated = Shape(self.np_dtype, tuple(self.dimensions()), minor_to_major) + updated = Shape.array_shape( + self.element_type(), self.dimensions(), minor_to_major) updated._check_minor_to_major() # pylint: disable=protected-access return updated - @staticmethod - def from_numpy(npval): - - def convert(npval): - if isinstance(npval, tuple): - return Shape(np.dtype('O'), tuple(convert(elt) for elt in npval)) - else: - return Shape(npval.dtype, np.shape(npval)) - - return convert(require_numpy_array_layout(npval)) - def _wrap_shape(shape_info): dtype, dims = shape_info element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)] if element_type == xla_data_pb2.TUPLE: - dims = tuple(_wrap_shape(subshape_info) for subshape_info in dims) - return Shape(dtype, dims) - - -def _wrap_data_handle(handle): - cdh = xla_data_pb2.ComputationDataHandle() - cdh.handle = handle - return cdh - - -def _unwrap_data_handle(handle_proto): - return handle_proto.handle - - -def _unwrap_data_handles(handle_protos): - return [_unwrap_data_handle(cdh) for cdh in handle_protos] + shapes = tuple(_wrap_shape(subshape_info) for subshape_info in dims) + return Shape.tuple_shape(shapes) + else: + return Shape.array_shape(dtype, dims) def require_numpy_array_layout(value): @@ -310,6 +362,10 @@ class CompileOptions(object): def __init__(self): self.generate_hlo_graph = None + self.dump_optimized_hlo_proto_to = None + self.dump_unoptimized_hlo_proto_to = None + self.dump_per_pass_hlo_proto_to = None + self.hlo_profile = False def transfer_to_infeed(value, replica_number=None): @@ -366,6 +422,17 @@ class LocalComputation(object): assert isinstance(c_local_computation, c_api.LocalComputation) self._delete = c_api.DeleteLocalComputation + def GetProto(self): + """Get the HloModuleProto proto object in this local computation. + + Returns: + An HloModuleProto proto object that has the whole-graph information. + """ + + serialized = self.c_local_computation.GetSerializedProto() + proto = hlo_pb2.HloModuleProto.FromString(serialized) + return proto + def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): """Compiles an un-compiled local computation. @@ -407,7 +474,7 @@ class LocalComputation(object): compile_options=None, layout_fn=None): return self.Compile( - argument_shapes=[Shape.from_numpy(arg) for arg in arguments], + argument_shapes=[Shape.from_pyval(arg) for arg in arguments], compile_options=compile_options, layout_fn=layout_fn) @@ -415,7 +482,7 @@ class LocalComputation(object): """Execute with Python values as arguments and return value.""" if not self.is_compiled: raise ValueError('Cannot execute an uncompiled local XLA computation.') - argument_shapes = [Shape.from_numpy(arg) for arg in arguments] + argument_shapes = [Shape.from_pyval(arg) for arg in arguments] if layout_fn: argument_shapes = [ shape.map_leaves(layout_fn) for shape in argument_shapes @@ -477,9 +544,9 @@ class ComputationBuilder(object): queue for subsequent use in the computation. Returns: - A ComputationDataHandle message. + A LocalOp. """ - return _wrap_data_handle(self._client.Infeed(shape)) + return self._client.Infeed(shape) def Outfeed(self, operand): """Enqueues an outfeed op onto the computation. @@ -487,9 +554,7 @@ class ComputationBuilder(object): Outfeed operations enqueue data, using the given operand, onto the XLA outfeed queue for subsequent dequeue via the client API. """ - self._client.Outfeed( - _unwrap_data_handle(operand), self.GetShape(operand), - ''.encode('utf-8')) + self._client.Outfeed(operand, self.GetShape(operand), ''.encode('utf-8')) def Constant(self, value): """Enqueues a constant op onto the computation. @@ -499,10 +564,10 @@ class ComputationBuilder(object): to one of the supported types. Returns: - A ComputationDataHandle message. + A LocalOp. """ value = require_numpy_array_layout(value) - return _wrap_data_handle(self._client.ConstantLiteral(value)) + return self._client.ConstantLiteral(value) def ConstantF32Scalar(self, value): """Convenience method to enqueue a scalar F32 constant op. @@ -511,7 +576,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.float32)) @@ -522,7 +587,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.float64)) @@ -533,7 +598,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.int32)) @@ -544,7 +609,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.int64)) @@ -555,7 +620,7 @@ class ComputationBuilder(object): value: a boolean value. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.bool)) @@ -571,15 +636,14 @@ class ComputationBuilder(object): parameters, use it for *all* parameters to avoid clashes. Returns: - A ComputationDataHandle message. + A LocalOp. """ if name is None: name = '' if parameter_num is None: parameter_num = next(self._parameter_numbering) - return _wrap_data_handle( - self._client.Parameter(parameter_num, shape, name.encode('utf8'))) + return self._client.Parameter(parameter_num, shape, name.encode('utf8')) def ParameterFromNumpy(self, value, name=None, parameter_num=None): """Enqueues a Parameter op onto the computation. @@ -591,23 +655,22 @@ class ComputationBuilder(object): parameter_num: as in ParameterWithShape. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.ParameterWithShape( - Shape.from_numpy(value), name=name, parameter_num=parameter_num) + Shape.from_pyval(value), name=name, parameter_num=parameter_num) def Broadcast(self, operand, sizes): """Enqueues a broadcast operation onto the computation. Args: - operand: the operand ComputationDataHandle to broadcast. + operand: the operand LocalOp to broadcast. sizes: an iterable of broadcast sizes. Returns: - A ComputationDataHandle representing the added broadcast op. + A LocalOp representing the added broadcast op. """ - return _wrap_data_handle( - self._client.Broadcast(_unwrap_data_handle(operand), sizes)) + return self._client.Broadcast(operand, sizes) def Concatenate(self, operands, dimension): """Enqueues a concatenate operation onto the computation. @@ -617,10 +680,9 @@ class ComputationBuilder(object): dimension: the dimension in which to perform the concatenation. Returns: - A ComputationDataHandle representing the added concatenate op. + A LocalOp representing the added concatenate op. """ - return _wrap_data_handle( - self._client.ConcatInDim(_unwrap_data_handles(operands), dimension)) + return self._client.ConcatInDim(operands, dimension) def ConvertElementType(self, operand, new_element_type): """Enqueues an element type conversion operation onto the computation. @@ -630,14 +692,12 @@ class ComputationBuilder(object): new_element_type: the target primitive type. Returns: - A ComputationDataHandle representing the added conversion op. + A LocalOp representing the added conversion op. """ - return _wrap_data_handle( - self._client.ConvertElementType( - _unwrap_data_handle(operand), new_element_type)) + return self._client.ConvertElementType(operand, new_element_type) def GetShape(self, operand): - return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand))) + return _wrap_shape(self._client.GetShape(operand)) def GetReturnValueShape(self): return _wrap_shape(self._client.GetReturnValueShape()) @@ -649,27 +709,35 @@ class ComputationBuilder(object): """Enqueues a Pad operation onto the computation. Args: - operand: ComputationDataHandle representing the array to pad. - padding_value: ComputationDataHandle representing the scalar pad value. + operand: LocalOp representing the array to pad. + padding_value: LocalOp representing the scalar pad value. padding_config: either an xla_data_pb2.PaddingConfig or a list of integer triples (edge_padding_low, edge_padding_high, interior_padding) representing the configuration of the padding operation. Returns: - A ComputationDataHandle representing the added pad op. + A LocalOp representing the added Pad op. """ if not isinstance(padding_config, xla_data_pb2.PaddingConfig): padding_config = GetPaddingConfigFromTriples(padding_config) - return _wrap_data_handle( - self._client.Pad(_unwrap_data_handle(operand), - _unwrap_data_handle(padding_value), - padding_config)) + return self._client.Pad(operand, padding_value, padding_config) def Reshape(self, operand, dimensions, new_sizes): - """Reshape op.""" - return _wrap_data_handle( - self._client.Reshape( - _unwrap_data_handle(operand), dimensions, new_sizes)) + """Enqueues a reshape op onto the computation. + + Args: + operand: LocalOp representing the array to be reshaped. + dimensions: sequence of integers encoding the order in which dimensions + are collapsed or None, in which case dimensions are flattened in order. + new_sizes: sequence of integers encoding the new dimension sizes (shape). + + Returns: + A LocalOp representing the added Reshape op. + """ + if dimensions is None: + ndim = len(self.GetShape(operand).dimensions()) + dimensions = tuple(range(ndim)) + return self._client.Reshape(operand, dimensions, new_sizes) def CrossReplicaSum(self, operand): """CrossReplicaSum op. @@ -678,67 +746,56 @@ class ComputationBuilder(object): operand: the operand to sum across replica instances. Returns: - A ComputationDataHandle that has the sum of the value among all replicas. + A LocalOp that has the sum of the value among all replicas. """ - return _wrap_data_handle( - self._client.CrossReplicaSum(_unwrap_data_handle(operand))) + return self._client.CrossReplicaSum(operand) def Collapse(self, operand, dimensions): """Collapse op.""" - return _wrap_data_handle( - self._client.Collapse(_unwrap_data_handle(operand), dimensions)) + return self._client.Collapse(operand, dimensions) def Trans(self, operand): """Specialized matrix transpose op.""" - return _wrap_data_handle( - self._client.Transpose(_unwrap_data_handle(operand), [1, 0])) + return self._client.Transpose(operand, [1, 0]) def Transpose(self, operand, permutation): """Transpose op.""" - return _wrap_data_handle( - self._client.Transpose(_unwrap_data_handle(operand), permutation)) + return self._client.Transpose(operand, permutation) def Rev(self, operand, dimensions): """Rev op.""" - return _wrap_data_handle( - self._client.Rev(_unwrap_data_handle(operand), dimensions)) + return self._client.Rev(operand, dimensions) def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin """Clamp op.""" - return _wrap_data_handle( - self._client.Clamp(_unwrap_data_handle(min), - _unwrap_data_handle(operand), - _unwrap_data_handle(max))) + return self._client.Clamp(min, operand, max) def SelectAndScatter(self, operand, select, window_dimensions, window_strides, padding, source, init_value, scatter): """Select and scatter op, used by the gradient of ReduceWindow. Args: - operand: ComputationDataHandle for array of dimension N and type T over + operand: LocalOp for array of dimension N and type T over which the windows slide. select: Computation of type (T, T) -> Pred to apply to the elements of each window to indicate which element is selected. window_dimensions: sequence of N integers for dimensions of the window. window_strides: sequence of N integers for the strides of the window. padding: PaddingType representing either 'SAME' or 'VALID ' padding. - source: ComputationDataHandle for array of type T with values to scatter. - init_value: ComputationDataHandle of scalar type T for initial out value. + source: LocalOp for array of type T with values to scatter. + init_value: LocalOp of scalar type T for initial out value. scatter: Computation of type (T, T) -> T to apply to each scatter source element with its destination element. Returns: - A ComputationDataHandle representing the added SelectAndScatter op. + A LocalOp representing the added SelectAndScatter op. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return _wrap_data_handle( - self._client.SelectAndScatterWithGeneralPadding( - _unwrap_data_handle(operand), select.c_local_computation, - window_dimensions, window_strides, pads, - _unwrap_data_handle(source), _unwrap_data_handle(init_value), - scatter.c_local_computation)) + return self._client.SelectAndScatterWithGeneralPadding( + operand, select.c_local_computation, window_dimensions, window_strides, + pads, source, init_value, scatter.c_local_computation) def Select(self, pred, on_true, on_false): """Element-wise selection op. @@ -746,17 +803,13 @@ class ComputationBuilder(object): Constructs an output array from elements of two input arrays, based on the values of a predicate array. """ - return _wrap_data_handle( - self._client.Select( - _unwrap_data_handle(pred), - _unwrap_data_handle(on_true), - _unwrap_data_handle(on_false))) + return self._client.Select(pred, on_true, on_false) def Slice(self, operand, start_indices, limit_indices, strides=None): """Enqueues a slice operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. + operand: LocalOp for the N dimensional array to be sliced. start_indices: iterable of N integers containing the starting indices of the slice for each dimension. limit_indices: iterable of N integers containing the ending indices @@ -765,191 +818,177 @@ class ComputationBuilder(object): each dimension. Returns: - A ComputationDataHandle representing the added Slice op. + A LocalOp representing the added Slice op. """ if strides is None: start_indices = list(start_indices) strides = [1] * len(start_indices) - return _wrap_data_handle( - self._client.Slice( - _unwrap_data_handle(operand), - start_indices, - limit_indices, - strides)) + return self._client.Slice(operand, start_indices, limit_indices, strides) + + def SliceInDim(self, operand, start_index, limit_index, stride, dimno): + """Enqueues a slice-in-dimension operation onto the computation. + + Args: + operand: LocalOp for the N dimensional array to be sliced. + start_index: an integer containing the start index of the slice. + limit_index: an integer containing the end index of the slice. + stride: an integer containing the stride size for the slice. + dimno: an integer indicating the dimension along which to slice. + + Returns: + A LocalOp representing the added Slice op. + """ + return self._client.SliceInDim(operand, start_index, limit_index, stride, + dimno) def DynamicSlice(self, operand, start_indices, slice_sizes): """Enqueues a slice op with dynamic start indices onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. - start_indices: ComputationDataHandle for the 1D array of N integers + operand: LocalOp for the N dimensional array to be sliced. + start_indices: LocalOp for the 1D array of N integers containing the starting indices of the slice. slice_sizes: iterable of N integers containing the slice sizes in each dimension. Returns: - A ComputationDataHandle representing the added DynamicSlice op. + A LocalOp representing the added DynamicSlice op. """ - return _wrap_data_handle( - self._client.DynamicSlice( - _unwrap_data_handle(operand), - _unwrap_data_handle(start_indices), - slice_sizes)) + return self._client.DynamicSlice(operand, start_indices, slice_sizes) def DynamicUpdateSlice(self, operand, update, start_indices): """Enqueues a dynamic update slice operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be updated. + operand: LocalOp for the N dimensional array to be updated. update: N dimensional array comprising the slice update. start_indices: Rank-1 array of N integers comprising the starting indices of the slice along each dimension. Returns: - A ComputationDataHandle representing the added DynamicUpdateSlice op. + A LocalOp representing the added DynamicUpdateSlice op. """ - return _wrap_data_handle( - self._client.DynamicUpdateSlice( - _unwrap_data_handle(operand), - _unwrap_data_handle(update), - _unwrap_data_handle(start_indices))) + return self._client.DynamicUpdateSlice(operand, update, start_indices) def Tuple(self, *ops): """Enqueues a tuple operation onto the computation. Args: - ops: a sequence of tuple operands (each a ComputationDataHandle). + ops: a sequence of tuple operands (each a LocalOp). Returns: - A ComputationDataHandle representing the added Tuple op. + A LocalOp representing the added Tuple op. """ - return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops))) + return self._client.Tuple(ops) def GetTupleElement(self, tup, index): """Enqueues a 'get tuple element' operation onto the computation. Args: - tup: the tuple operand (a ComputationDataHandle). + tup: the tuple operand (a LocalOp). index: numeric index to select from the tuple. Returns: - A ComputationDataHandle representing the added GetTupleElement op. + A LocalOp representing the added GetTupleElement op. """ - return _wrap_data_handle( - self._client.GetTupleElement(_unwrap_data_handle(tup), index)) + return self._client.GetTupleElement(tup, index) def Call(self, computation_to_apply, operands): """Enqueues a call operation onto the computation. Args: computation_to_apply: a Computation object. - operands: an iterable of ComputationDataHandle. The number and types of + operands: an iterable of LocalOp. The number and types of operands must match the arity of computation_to_apply. Returns: - A ComputationDataHandle representing the added call op. + A LocalOp representing the added call op. """ - return _wrap_data_handle( - self._client.Call(computation_to_apply.c_local_computation, - _unwrap_data_handles(operands))) + return self._client.Call(computation_to_apply.c_local_computation, operands) def Map(self, operands, computation_to_apply, dimensions, static_operands=()): """Enqueues a map operation onto the computation. Args: - operands: an iterable of ComputationDataHandle. + operands: an iterable of LocalOp. computation_to_apply: a Computation object. dimensions: dimensions over which to apply map the function. static_operands: auxiliary arguments passed to the applied computation. Returns: - A ComputationDataHandle representing the added Map op. + A LocalOp representing the added Map op. """ - return _wrap_data_handle( - self._client.Map( - _unwrap_data_handles(operands), - computation_to_apply.c_local_computation, - dimensions, - _unwrap_data_handles(static_operands))) + return self._client.Map(operands, computation_to_apply.c_local_computation, + dimensions, static_operands) def Reduce(self, operand, init_value, computation_to_apply, dimensions): """Enqueues a reduction operation onto the computation. Args: - operand: reduction operand (ComputationDataHandle). - init_value: reduction initial value (ComputationDataHandle). + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). computation_to_apply: a Computation object - binary reduction function. dimensions: sequence of dimensions (integers) to reduce on. Returns: - A ComputationDataHandle representing the added Reduce op. + A LocalOp representing the added Reduce op. """ - return _wrap_data_handle( - self._client.Reduce( - _unwrap_data_handle(operand), - _unwrap_data_handle(init_value), - computation_to_apply.c_local_computation, - dimensions)) + return self._client.Reduce(operand, init_value, + computation_to_apply.c_local_computation, + dimensions) def ReduceWindow(self, operand, init_value, computation_to_apply, window_dimensions, window_strides, padding): """Enqueues a windowed reduction operation onto the computation. Args: - operand: reduction operand (ComputationDataHandle). - init_value: reduction initial value (ComputationDataHandle). + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). computation_to_apply: a binary reduction function (Computation). window_dimensions: dimensions of window (sequence of integers). window_strides: strides for window (sequence of integers). padding: PaddingType representing either 'SAME' or 'VALID' padding. Returns: - A ComputationDataHandle representing the added ReduceWindow op. + A LocalOp representing the added ReduceWindow op. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return _wrap_data_handle( - self._client.ReduceWindowWithGeneralPadding( - _unwrap_data_handle(operand), - _unwrap_data_handle(init_value), - computation_to_apply.c_local_computation, - window_dimensions, window_strides, pads)) + return self._client.ReduceWindowWithGeneralPadding( + operand, init_value, computation_to_apply.c_local_computation, + window_dimensions, window_strides, pads) def RngNormal(self, mu, sigma, dims): """Enqueues an RngNormal operation onto the computation. Args: - mu: A ComputationDataHandle to an F32 scalar specifying the mean. - sigma: A ComputationDataHandle to an F32 scalar specifying the standard + mu: A LocalOp to an F32 scalar specifying the mean. + sigma: A LocalOp to an F32 scalar specifying the standard deviation. dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a ComputationDataHandle to the generated array of F32 values. + Returns: a LocalOp to the generated array of F32 values. """ - shape = Shape(self.GetShape(mu).np_dtype, dims) - return _wrap_data_handle( - self._client.RngNormal( - _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape)) + shape = Shape.array_shape(self.GetShape(mu).element_type(), dims) + return self._client.RngNormal(mu, sigma, shape) def RngUniform(self, a, b, dims): """Enqueues an RngUniform operation onto the computation. Args: - a: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + a: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of b) specifying the low end of the interval [a, b) over which values are generated. - b: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + b: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of a) specifying the high end of the interval [a, b) over which values are generated. dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a ComputationDataHandle to the generated array of values with the + Returns: a LocalOp to the generated array of values with the same numeric type (F32, S32, or U32) as the arguments a and b. """ - shape = Shape(self.GetShape(a).np_dtype, dims) - return _wrap_data_handle( - self._client.RngUniform( - _unwrap_data_handle(a), _unwrap_data_handle(b), shape)) + shape = Shape.array_shape(self.GetShape(a).element_type(), dims) + return self._client.RngUniform(a, b, shape) def While(self, cond, body, init): """Enqueues a While operation onto the computation. @@ -957,98 +996,105 @@ class ComputationBuilder(object): Args: cond: a Computation for the loop condition, which has type T -> PRED body: a Computation for the loop body, which has type T -> T - init: a ComputationDataHandle for the initial parameter, which has type T + init: a LocalOp for the initial parameter, which has type T - Returns: a ComputationDataHandle representing the While operation. + Returns: a LocalOp representing the While operation. """ - return _wrap_data_handle( - self._client.While(cond.c_local_computation, - body.c_local_computation, - _unwrap_data_handle(init))) + return self._client.While(cond.c_local_computation, + body.c_local_computation, init) def Conditional(self, pred, true_operand, true_computation, false_operand, false_computation): """Enqueues a Conditional operation onto the computation. Args: - predicate: a ComputationDataHandle to test, which has scalar type PRED - true_operand: a ComputationDataHandle of type T_0 + predicate: a LocalOp to test, which has scalar type PRED + true_operand: a LocalOp of type T_0 true_computation: a Computation to apply to true_operand, type T_0 -> S false_operand: a ComputationDatahandle of type T_1 false_computation: a Computation to apply to false_operand, type T_1 -> S - Returns: a ComputationDataHandle representing the Conditional operation. + Returns: a LocalOp representing the Conditional operation. + """ + return self._client.Conditional( + pred, true_operand, true_computation.c_local_computation, false_operand, + false_computation.c_local_computation) + + def IsConstant(self, operand): + """Checks whether the given operand is a compile-time constant. + + Args: + operand: a ComputationDataHandle to test. + + Returns: bool indicating whether `operand` is a compile-time constant, + meaning its value does not depend on any parametersor, or on stateful + operators such as `RngNormal` or `Infeed`. """ - return _wrap_data_handle( - self._client.Conditional( - _unwrap_data_handle(pred), _unwrap_data_handle(true_operand), - true_computation.c_local_computation, - _unwrap_data_handle(false_operand), - false_computation.c_local_computation)) + return self._client.IsConstant(operand) + + def BuildConstantSubGraph(self, operand): + """Builds a constant sub graph. + + Args: + operand: a LocalOp to test. + Returns: a LocalComputation that is rooted on the given `operand` which is a + compile-time constant. + """ + return self._client.BuildConstantSubGraph(operand) def Dot(self, lhs, rhs): """Enqueues a dot operation onto the computation. Args: - lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array. - rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array. + lhs: LocalOp for the rank 1 or rank 2 left-hand-side array. + rhs: LocalOp for the rank 1 or rank 2 right-hand-side array. - Returns: a ComputationDataHandle representing the Dot operation. + Returns: a LocalOp representing the Dot operation. """ - return _wrap_data_handle( - self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs))) + return self._client.Dot(lhs, rhs) def DotGeneral(self, lhs, rhs, dimension_numbers): """Enqueues a general dot operation onto the computation. Args: - lhs: ComputationDataHandle for the left-hand-side array. - rhs: ComputationDataHandle for the right-hand-side array. + lhs: LocalOp for the left-hand-side array. + rhs: LocalOp for the right-hand-side array. dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of integers representing the dimensions to treat as contracting dimensions and batch dimensions on each input operand. - Returns: a ComputationDataHandle representing the DotGeneral operation. + Returns: a LocalOp representing the DotGeneral operation. """ if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers): dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) - return _wrap_data_handle( - self._client.DotGeneral( - _unwrap_data_handle(lhs), _unwrap_data_handle(rhs), - dimension_numbers)) + return self._client.DotGeneral(lhs, rhs, dimension_numbers) def Conv(self, lhs, rhs, window_strides, padding): """Enqueues a Conv operation onto the computation. Args: - lhs: ComputationDataHandle for the rank N+2 array of inputs. - rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of integer kernel strides. padding: PaddingType representing either 'SAME' or 'VALID' padding. - Returns: a ComputationDataHandle representing the Conv operation. + Returns: a LocalOp representing the Conv operation. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(lhs).dimensions()[2:], self.GetShape(rhs).dimensions()[2:], window_strides) dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return _wrap_data_handle( - self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), - _unwrap_data_handle(rhs), - window_strides, - pads, - (), - (), - dimension_numbers)) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (), + (), dimension_numbers) def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation): """Enqueues a ConvWithGeneralPadding operation onto the computation. Args: - lhs: ComputationDataHandle for the rank N+2 array of inputs. - rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of kernel strides. padding: length-N array-like of pairs of integers of (low, high) padding. lhs_dilation: length-N array-like of dilation factors. @@ -1058,14 +1104,9 @@ class ComputationBuilder(object): A ComputationdataHandle representing the added ConvWithGeneralPadding op. """ dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return _wrap_data_handle( - self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), - _unwrap_data_handle(rhs), - window_strides, - padding, - lhs_dilation, - rhs_dilation, - dimension_numbers)) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers) def _GetConvDimensionNumbers(self, num_spatial_dims): """Create ConvolutionDimensionNumbers proto for convolutions.""" @@ -1082,6 +1123,61 @@ class ComputationBuilder(object): dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) return dimension_numbers + def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation, + rhs_dilation, dimension_numbers): + """Enqueues a ConvGeneralDilated operation onto the computation. + + Args: + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. + window_strides: length-N array-like of integer kernel strides. + padding: length-N array-like of pairs of integers of (low, high) padding. + lhs_dilation: length-N array-like of integer dilation factors. + rhs_dilation: length-N array-like of integer dilation factors. + dimension_numbers: either an xla_data_pb2.ConvolutionDimensionNumbers or a + triple (lhs_spec, rhs_spec, out_spec) where each element is a string of + length N+2 identifying by position (1) batch dimensions in lhs, rhs, and + the output with the character 'N', (2) feature dimensions in lhs and the + output with the character 'C', (3) input and output feature dimensions + in rhs with the characters 'I' and 'O' respectively, and (4) spatial + dimension correspondences between lhs, rhs, and the output using any + distinct characters. For example, to indicate dimension numbers + consistent with the Conv operation with two spatial dimensions, one + could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate + dimension numbers consistent with the TensorFlow Conv2D operation, one + could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of + convolution dimension specification, window strides are associated with + spatial dimension character labels according to the order in which the + labels appear in the rhs_spec string, so that window_strides[0] is + matched with the dimension corresponding to the first character + appearing in rhs_spec that is not 'I' or 'O'. + + Returns: a LocalOp representing the ConvGenralDilated operation. + """ + if not isinstance(dimension_numbers, + xla_data_pb2.ConvolutionDimensionNumbers): + lhs_spec, rhs_spec, out_spec = dimension_numbers + dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers() + + dimension_numbers.input_batch_dimension = lhs_spec.index('N') + dimension_numbers.input_feature_dimension = lhs_spec.index('C') + dimension_numbers.output_batch_dimension = out_spec.index('N') + dimension_numbers.output_feature_dimension = out_spec.index('C') + dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') + dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') + + dimension_numbers.kernel_spatial_dimensions.extend( + i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'}) + dimension_numbers.input_spatial_dimensions.extend( + sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(lhs_spec[i]))) + dimension_numbers.output_spatial_dimensions.extend( + sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(out_spec[i]))) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers) + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. @@ -1095,15 +1191,14 @@ def _forward_methods_to_local_builder(): """Generate a forwarding method that wraps/unwraps data handles.""" def forward(self, *args, **kwargs): - unwrapped_args = [_unwrap_data_handle(arg) for arg in args] + arg_list = list(args) - if is_binop and len(unwrapped_args) < 3: - unwrapped_args.append(kwargs.get('broadcast_dimensions', ())) + if is_binop and len(arg_list) < 3: + arg_list.append(kwargs.get('broadcast_dimensions', ())) - return _wrap_data_handle( - target_method( - self._client, # pylint: disable=protected-access - *unwrapped_args)) + return target_method( + self._client, # pylint: disable=protected-access + *arg_list) return forward diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index c9d09cd5d57e001fd48d2dba9f2b0ee18374231b..71e1d60a4e23dbfef333223c396e109533da9365 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -164,6 +164,16 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + def testGetProto(self): + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), + c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) + built = c.Build() + proto = built.GetProto() # HloModuleProto + self.assertTrue(len(proto.computations) == 1) + self.assertTrue(len(proto.computations[0].instructions) == 3) + def testSum2DF64(self): c = self._NewComputation() c.Add( @@ -319,7 +329,7 @@ class LocalBufferTest(LocalComputationTest): def _Execute(self, c, arguments): compiled_c = c.Build().CompileWithExampleArguments(arguments) - arg_buffers = [xla_client.LocalBuffer.from_py(arg) for arg in arguments] + arg_buffers = [xla_client.LocalBuffer.from_pyval(arg) for arg in arguments] result_buffer = compiled_c.ExecuteWithLocalBuffers(arg_buffers) return result_buffer.to_py() @@ -350,11 +360,60 @@ class LocalBufferTest(LocalComputationTest): c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) arg = NumpyArrayF32(1.11) compiled_c = c.Build().CompileWithExampleArguments([arg]) - arg_buffer = xla_client.LocalBuffer.from_py(arg) + arg_buffer = xla_client.LocalBuffer.from_pyval(arg) arg_buffer.delete() with self.assertRaises(ValueError): compiled_c.ExecuteWithLocalBuffers([arg_buffer]) + def testDestructureTupleEmpty(self): + t = () + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 0) + + def testDestructureTupleOneArrayElement(self): + t = (np.array([1, 2, 3, 4], dtype=np.int32),) + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 1) + array = pieces[0] + got = array.to_py() + want = NumpyArrayS32([1, 2, 3, 4]) + np.testing.assert_equal(want, got) + + def testDestructureTupleTwoArrayElementDifferentType(self): + t = (np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), + np.array([2, 3, 4, 5], dtype=np.int32)) + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 2) + array0, array1 = pieces + got = array0.to_py() + want = NumpyArrayF32([1.0, 2.0, 3.0, 4.0]) + np.testing.assert_equal(want, got) + got = array1.to_py() + want = NumpyArrayS32([2, 3, 4, 5]) + np.testing.assert_equal(want, got) + + def testDestructureTupleNested(self): + t = ((NumpyArrayF32([1.0, 2.0]), NumpyArrayS32([3, 4])), NumpyArrayS32([5])) + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 2) + tuple0, array1 = pieces + got = array1.to_py() + want = NumpyArrayS32([5]) + np.testing.assert_equal(want, got) + got = tuple0.to_py() + self.assertEqual(type(got), tuple) + self.assertEqual(len(got), 2) + np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0]) + np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1]) + class SingleOpTest(LocalComputationTest): """Tests for single ops. @@ -509,6 +568,46 @@ class SingleOpTest(LocalComputationTest): [40., 50., 0.]]]]) self._ExecuteAndCompareClose(c, expected=result) + def testConvGeneralDilatedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = ("NCHW", "OIHW", "NCHW") + c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation, + dimension_numbers) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + + def testConvGeneralDilatedPermutedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + + dimension_numbers = ("NHWC", "OIHW", "CWNH") + c.ConvGeneralDilated(c.Constant(np.transpose(lhs, (0, 2, 3, 1))), + c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation, + dimension_numbers) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2))) + def testBooleanNot(self): c = self._NewComputation() arr = NumpyArrayBool([True, False, True]) @@ -521,6 +620,12 @@ class SingleOpTest(LocalComputationTest): c.Exp(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=np.exp(arr)) + def testExpm1(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Expm1(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.expm1(arr)) + def testRound(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) @@ -533,6 +638,12 @@ class SingleOpTest(LocalComputationTest): c.Log(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=np.log(arr)) + def testLog1p(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Log1p(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.log1p(arr)) + def testNeg(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) @@ -762,6 +873,23 @@ class SingleOpTest(LocalComputationTest): [3, 2]) self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) + def testSliceInDim(self): + c = self._NewComputation() + c.SliceInDim( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + start_index=1, + limit_index=2, + stride=1, + dimno=1) + self._ExecuteAndCompareExact(c, expected=[[2], [5], [8]]) + c.SliceInDim( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + start_index=0, + limit_index=3, + stride=2, + dimno=0) + self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [7, 8, 9]]) + def testDynamicSlice(self): c = self._NewComputation() c.DynamicSlice( @@ -838,6 +966,17 @@ class SingleOpTest(LocalComputationTest): self.assertTrue(np.all(lo <= result)) self.assertTrue(np.all(result < hi)) + def testIsConstant(self): + c = self._NewComputation() + a = c.ConstantS32Scalar(3) + b = c.ConstantS32Scalar(1) + x = c.ParameterFromNumpy(NumpyArrayS32(0)) + const_expr = c.Sub(b, a) + non_const_expr = c.Mul(const_expr, x) + self.assertTrue(c.IsConstant(const_expr)) + self.assertFalse(c.IsConstant(non_const_expr)) + # self.assertTrue(c.IsConstant(c.Sub(c.Add(x, a), x))) # TODO(b/77245564) + class EmbeddedComputationsTest(LocalComputationTest): """Tests for XLA graphs with embedded computations (such as maps).""" @@ -1132,7 +1271,6 @@ class EmbeddedComputationsTest(LocalComputationTest): self._ExecuteAndCompareClose( c, expected=np.sum(input_array, axis=tuple(dims))) - _ReduceAndTest(0) _ReduceAndTest(0) _ReduceAndTest(0, 1) _ReduceAndTest(0, 2) @@ -1260,7 +1398,7 @@ class EmbeddedComputationsTest(LocalComputationTest): def testInfeedS32Values(self): to_infeed = NumpyArrayS32([1, 2, 3, 4]) c = self._NewComputation() - c.Infeed(xla_client.Shape.from_numpy(to_infeed[0])) + c.Infeed(xla_client.Shape.from_pyval(to_infeed[0])) compiled_c = c.Build().CompileWithExampleArguments() for item in to_infeed: xla_client.transfer_to_infeed(item) @@ -1272,7 +1410,7 @@ class EmbeddedComputationsTest(LocalComputationTest): def testInfeedThenOutfeedS32(self): to_round_trip = NumpyArrayS32([1, 2, 3, 4]) c = self._NewComputation() - x = c.Infeed(xla_client.Shape.from_numpy(to_round_trip[0])) + x = c.Infeed(xla_client.Shape.from_pyval(to_round_trip[0])) c.Outfeed(x) compiled_c = c.Build().CompileWithExampleArguments() @@ -1282,7 +1420,7 @@ class EmbeddedComputationsTest(LocalComputationTest): execution.start() xla_client.transfer_to_infeed(want) got = xla_client.transfer_from_outfeed( - xla_client.Shape.from_numpy(to_round_trip[0])) + xla_client.Shape.from_pyval(to_round_trip[0])) execution.join() self.assertEqual(want, got) diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index a9acdae380af5b7f9efb3d08302fc717108f5e40..c289c84cff743871a7126cb932d6cda823ceb696 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -30,29 +30,23 @@ limitations under the License. namespace xla { -/* static */ std::unique_ptr> ReferenceUtil::TransposeArray2D( - const Array2D& operand) { - auto result = MakeUnique>(operand.width(), operand.height()); - for (int64 w = 0; w < operand.width(); ++w) { - for (int64 h = 0; h < operand.height(); ++h) { - (*result)(w, h) = operand(h, w); - } - } - - return result; -} +namespace { -/* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( - const Array2D& lhs, const Array2D& rhs) { +template +std::unique_ptr> MatmulArray2DImpl( + const Array2D& lhs, const Array2D& rhs, + const std::function& impl_fn) { CHECK_EQ(lhs.width(), rhs.height()); int m = lhs.height(); int n = rhs.width(); int k = lhs.width(); - auto result = MakeUnique>(m, n); + auto result = MakeUnique>(m, n); // Because Eigen is a header-oriented library, make sure that the Eigen code // is the same as the code used by the CPU backend (otherwise the linker will // randomly pick *some* definition). - __xla_cpu_runtime_EigenSingleThreadedMatMulF32( + impl_fn( /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, k, /*transpose_lhs=*/0, @@ -60,22 +54,24 @@ namespace xla { return result; } +} // namespace + +/* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16); +} + +/* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32); +} + /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - CHECK_EQ(lhs.width(), rhs.height()); - int m = lhs.height(); - int n = rhs.width(); - int k = lhs.width(); - auto result = MakeUnique>(m, n); - // Because Eigen is a header-oriented library, make sure that the Eigen code - // is the same as the code used by the CPU backend (otherwise the linker will - // randomly pick *some* definition). - __xla_cpu_runtime_EigenSingleThreadedMatMulF64( - /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, - k, - /*transpose_lhs=*/0, - /*transpose_rhs=*/0); - return result; + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64); } /* static */ std::unique_ptr> ReferenceUtil::Array2DF32ToF64( @@ -94,7 +90,7 @@ namespace xla { Padding padding) { return ConvArray3DGeneralDimensionsDilated( lhs, rhs, kernel_stride, padding, 1, 1, - ComputationBuilder::CreateDefaultConvDimensionNumbers(1)); + XlaBuilder::CreateDefaultConvDimensionNumbers(1)); } /*static*/ std::unique_ptr> @@ -144,7 +140,7 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( std::pair kernel_stride, Padding padding) { return ConvArray4DGeneralDimensions( lhs, rhs, kernel_stride, padding, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); } /* static */ std::unique_ptr> @@ -188,18 +184,6 @@ ReferenceUtil::SeparableConvArray4D(const Array4D& input, return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride); } -/* static */ std::unique_ptr> -ReferenceUtil::ReduceWindow1DGeneric( - const tensorflow::gtl::ArraySlice& operand, float init, - const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { - std::vector dim_lengths{static_cast(operand.size())}; - return ReduceWindow1DGeneric( - operand, init, reduce_func, window, stride, - xla::MakePadding(dim_lengths, window, stride, padding)); -} - /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow1DGeneric( const tensorflow::gtl::ArraySlice& operand, float init, @@ -239,23 +223,28 @@ ReferenceUtil::ReduceWindow1DAdd( const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; - return ReduceWindow1DGeneric(operand, init, add_reduce, window, stride, - padding); + std::vector dim_lengths{static_cast(operand.size())}; + return ReduceWindow1DGeneric( + operand, init, add_reduce, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); } -/* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow2DGeneric( const Array2D& operand, float init, + const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding) { std::vector dim_lengths{operand.height(), operand.width()}; - auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); for (int64 i = 0; i < window.size(); ++i) { + int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; window_counts[i] = - WindowCount(dim_lengths[i], window[i], stride[i], padding); - pad_low[i] = padding_both[i].first; + window_util::StridedBound(padded_width, window[i], stride[i]); + pad_low[i] = padding[i].first; } auto result = MakeUnique>(window_counts[0], window_counts[1]); @@ -271,7 +260,7 @@ ReferenceUtil::ReduceWindow1DAdd( if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && i0_base + i0_win < operand.n1() && i1_base + i1_win < operand.n2()) { - val += operand(i0_base + i0_win, i1_base + i1_win); + val = reduce_func(val, operand(i0_base + i0_win, i1_base + i1_win)); } } } @@ -281,6 +270,17 @@ ReferenceUtil::ReduceWindow1DAdd( return result; } +/* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( + const Array2D& operand, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; + std::vector dim_lengths{operand.height(), operand.width()}; + return ReduceWindow2DGeneric( + operand, init, add_reduce, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); +} + /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow3DAdd( const Array3D& operand, float init, const tensorflow::gtl::ArraySlice& window, @@ -472,7 +472,7 @@ ReferenceUtil::SelectAndScatter4DGePlus( i3_base + i3_win < operand.n4()) { float tmp = operand(i0_base + i0_win, i1_base + i1_win, i2_base + i2_win, i3_base + i3_win); - if (tmp >= val) { + if (tmp > val) { val = tmp; scatter_0 = i0_base + i0_win; scatter_1 = i1_base + i1_win; @@ -572,7 +572,8 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - HloModule module("ReferenceUtil"); + HloModuleConfig config; + HloModule module("ReferenceUtil", config); auto computation = module.AddEntryComputation(b.Build()); HloEvaluator evaluator; diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 3ec96f2f38b8f91e1549419b60481327fa9bbd5f..8fa6961d197dce519cf151283b8bc0836a4615c0 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -39,10 +39,22 @@ namespace xla { class ReferenceUtil { public: // Returns the result of a transpose operation on the input matrix. - static std::unique_ptr> TransposeArray2D( - const Array2D& operand); + template + static std::unique_ptr> TransposeArray2D( + const Array2D& operand) { + auto result = MakeUnique>(operand.width(), operand.height()); + for (int64 w = 0; w < operand.width(); ++w) { + for (int64 h = 0; h < operand.height(); ++h) { + (*result)(w, h) = operand(h, w); + } + } + + return result; + } // Returns the result of a matrix multiply `lhs x rhs`. + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); static std::unique_ptr> MatmulArray2D( const Array2D& lhs, const Array2D& rhs); static std::unique_ptr> MatmulArray2D( @@ -187,9 +199,10 @@ class ReferenceUtil { const tensorflow::gtl::ArraySlice& operand, float init, const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); - static std::unique_ptr> ReduceWindow1DGeneric( - const tensorflow::gtl::ArraySlice& operand, float init, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding); + static std::unique_ptr> ReduceWindow2DGeneric( + const Array2D& operand, float init, const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, @@ -215,6 +228,7 @@ class ReferenceUtil { // Performs select and scatter with Greater Than or equal as the select, plus // as the scatter, and Same Padding. + // TODO(b/74533103) Switch tests to evaluator and remove this implementation. static std::unique_ptr> SelectAndScatter4DGePlus( const Array4D& operand, const Array4D& source, float init, const tensorflow::gtl::ArraySlice& window, @@ -251,9 +265,9 @@ class ReferenceUtil { const Array3D& rhs, int concatenate_dimension) { CHECK(0 <= concatenate_dimension && concatenate_dimension < 3); - std::vector lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3()}; - std::vector rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3()}; - std::vector out_dims = {rhs.n1(), rhs.n2(), rhs.n3()}; + const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3()}; + const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()}; + int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()}; for (int i = 0; i < 3; ++i) { if (i != concatenate_dimension) { out_dims[i] = lhs_dims[i]; @@ -285,9 +299,9 @@ class ReferenceUtil { const Array4D& rhs, int concatenate_dimension) { CHECK(0 <= concatenate_dimension && concatenate_dimension < 4); - std::vector lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}; - std::vector rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; - std::vector out_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; + const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}; + const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; + int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; for (int i = 0; i < 4; ++i) { if (i != concatenate_dimension) { out_dims[i] = lhs_dims[i]; @@ -316,13 +330,14 @@ class ReferenceUtil { return result; } - // Slices with modulo-wrapping. + // Slices with index clamping template - static std::vector ModSlice1D(const tensorflow::gtl::ArraySlice& input, - int64 start, int64 size) { + static std::vector ClampSlice1D( + const tensorflow::gtl::ArraySlice& input, int64 start, int64 size) { + start = std::min(std::max(0, start), input.size() - size); std::vector result; for (int64 i = 0; i < size; ++i) { - result.push_back(input[(start + i) % input.size()]); + result.push_back(input[(start + i)]); } return result; } @@ -538,12 +553,11 @@ class ReferenceUtil { const NativeT pad) { CHECK_EQ(padding.dimensions_size(), 3); - const std::vector input_bounds = {operand.n1(), operand.n2(), - operand.n3()}; - std::vector pad_low(3); - std::vector pad_high(3); - std::vector pad_interior(3); - std::vector output_bounds(3); + const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3()}; + int64 pad_low[3]; + int64 pad_high[3]; + int64 pad_interior[3]; + int64 output_bounds[3]; for (int64 i = 0; i < 3; ++i) { pad_low[i] = padding.dimensions(i).edge_padding_low(); pad_high[i] = padding.dimensions(i).edge_padding_high(); @@ -559,7 +573,7 @@ class ReferenceUtil { Array3D result(output_bounds[0], output_bounds[1], output_bounds[2]); - std::vector indices = {0, 0, 0}; + int indices[] = {0, 0, 0}; for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) { for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) { for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) { @@ -597,12 +611,12 @@ class ReferenceUtil { const NativeT pad) { CHECK_EQ(padding.dimensions_size(), 4); - const std::vector input_bounds = {operand.n1(), operand.n2(), - operand.n3(), operand.n4()}; - std::vector pad_low(4); - std::vector pad_high(4); - std::vector pad_interior(4); - std::vector output_bounds(4); + const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3(), + operand.n4()}; + int64 pad_low[4]; + int64 pad_high[4]; + int64 pad_interior[4]; + int64 output_bounds[4]; for (int64 i = 0; i < 4; ++i) { pad_low[i] = padding.dimensions(i).edge_padding_low(); pad_high[i] = padding.dimensions(i).edge_padding_high(); diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..1775666652f303ee095a11405537106a3eb9b056 --- /dev/null +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -0,0 +1,79 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load( + "//tensorflow/compiler/xla:xla.bzl", + "xla_proto_library", + "xla_py_grpc_library", +) + +xla_proto_library( + name = "xla_service_proto", + srcs = ["xla_service.proto"], + use_grpc_plugin = True, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", + ], +) + +cc_library( + name = "grpc_stub", + srcs = ["grpc_stub.cc"], + hdrs = ["grpc_stub.h"], + deps = [ + ":xla_service_proto", + "//tensorflow/compiler/xla:service_interface", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_util", + ], +) + +tf_cc_binary( + name = "grpc_service_main_cpu", + srcs = ["grpc_service_main.cc"], + deps = [ + ":grpc_service", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@grpc//:grpc++", + ], +) + +tf_cc_test( + name = "grpc_client_test", + srcs = ["grpc_client_test.cc"], + data = [ + "//tensorflow/compiler/xla/rpc:grpc_service_main_cpu", + ], + deps = [ + ":grpc_stub", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@grpc//:grpc++", + ], +) + +cc_library( + name = "grpc_service", + srcs = ["grpc_service.cc"], + hdrs = ["grpc_service.h"], + deps = [ + ":xla_service_proto", + "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/core/distributed_runtime/rpc:grpc_util", + "@grpc//:grpc++", + ], +) diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d7dd9786a2bbde2d18ae81a9a9d4cc4b2cc38411 --- /dev/null +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -0,0 +1,109 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Simple C++ test to exercise the GRPC capabilities of XLA. +// +// Launches an RPC service in a subprocess and connects to it over a socket +// using an RPCStub. +#include +#include + +#include "grpcpp/create_channel.h" +#include "grpcpp/security/credentials.h" + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/rpc/grpc_stub.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/net.h" +#include "tensorflow/core/platform/subprocess.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class GRPCClientTestBase : public ::testing::Test { + protected: + GRPCClientTestBase() { + string test_srcdir = tensorflow::testing::TensorFlowSrcRoot(); + string service_main_path = tensorflow::io::JoinPath( + test_srcdir, "compiler/xla/rpc/grpc_service_main_cpu"); + int port = tensorflow::internal::PickUnusedPortOrDie(); + subprocess_.SetProgram( + service_main_path, + {service_main_path, tensorflow::strings::Printf("--port=%d", port)}); + subprocess_.SetChannelAction(tensorflow::CHAN_STDOUT, + tensorflow::ACTION_DUPPARENT); + subprocess_.SetChannelAction(tensorflow::CHAN_STDERR, + tensorflow::ACTION_DUPPARENT); + CHECK(subprocess_.Start()); + LOG(INFO) << "Launched subprocess"; + + auto channel = + ::grpc::CreateChannel(tensorflow::strings::Printf("localhost:%d", port), + ::grpc::InsecureChannelCredentials()); + channel->WaitForConnected(gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(10, GPR_TIMESPAN))); + LOG(INFO) << "Channel to server is connected on port " << port; + + xla_service_ = grpc::XlaService::NewStub(channel); + stub_.reset(new GRPCStub(xla_service_.get())); + client_.reset(new Client(stub_.get())); + } + + ~GRPCClientTestBase() override { + LOG(INFO) << "Killing subprocess"; + subprocess_.Kill(SIGKILL); + } + + tensorflow::SubProcess subprocess_; + std::unique_ptr xla_service_; + std::unique_ptr stub_; + std::unique_ptr client_; +}; + +TEST_F(GRPCClientTestBase, ItsAlive) { + ASSERT_NE(xla_service_, nullptr); + ASSERT_NE(stub_, nullptr); + ASSERT_NE(client_, nullptr); +} + +TEST_F(GRPCClientTestBase, AxpyTenValues) { + XlaBuilder builder("axpy_10"); + auto alpha = builder.ConstantR0(3.1415926535); + auto x = builder.ConstantR1( + {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + auto y = builder.ConstantR1( + {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); + auto ax = builder.Mul(alpha, x); + auto axpy = builder.Add(ax, y); + + std::vector expected = { + 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, + 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327}; + std::unique_ptr expected_literal = + Literal::CreateR1(expected); + TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer( + computation, {}, nullptr)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal, + ErrorSpec(0.0001))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e1435fa30a24c320ddbedb84d37b369a3158a54 --- /dev/null +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -0,0 +1,111 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/rpc/grpc_service.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" + +namespace xla { + +/* static */ StatusOr> GRPCService::NewService( + se::Platform* platform) { + std::unique_ptr grpc_service(new GRPCService()); + TF_ASSIGN_OR_RETURN(grpc_service->service_, + ::xla::Service::NewService(platform)); + return std::move(grpc_service); +} + +::grpc::Status DelegateRPC(std::function op) { + Status s = op(); + return tensorflow::ToGrpcStatus(s); +} + +::grpc::Status GRPCService::Unregister(::grpc::ServerContext* context, + const UnregisterRequest* arg, + UnregisterResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->Unregister(arg, result); }); +} + +::grpc::Status GRPCService::DeconstructTuple(::grpc::ServerContext* context, + const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->DeconstructTuple(arg, result); + }); +} + +::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/, + const ExecuteGraphRequest* arg, + ExecuteResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->ExecuteGraph(arg, result); }); +} + +::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context, + const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->WaitForExecution(arg, result); + }); +} + +::grpc::Status GRPCService::TransferToClient(::grpc::ServerContext* context, + const TransferToClientRequest* arg, + TransferToClientResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->TransferToClient(arg, result); + }); +} + +::grpc::Status GRPCService::TransferToServer(::grpc::ServerContext* context, + const TransferToServerRequest* arg, + TransferToServerResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->TransferToServer(arg, result); + }); +} + +::grpc::Status GRPCService::TransferToInfeed(::grpc::ServerContext* context, + const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->TransferToInfeed(arg, result); + }); +} + +::grpc::Status GRPCService::TransferFromOutfeed( + ::grpc::ServerContext* context, const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->TransferFromOutfeed(arg, result); + }); +} + +::grpc::Status GRPCService::ResetDevice(::grpc::ServerContext* context, + const ResetDeviceRequest* arg, + ResetDeviceResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->ResetDevice(arg, result); }); +} + +::grpc::Status GRPCService::GetShape(::grpc::ServerContext* context, + const GetShapeRequest* arg, + GetShapeResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->GetShape(arg, result); }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h new file mode 100644 index 0000000000000000000000000000000000000000..ca1b09b648013ad45d806040c5ddcf11d9e5604e --- /dev/null +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_ +#define TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_ + +#include "grpcpp/server_context.h" +#include "tensorflow/compiler/xla/rpc/xla_service.grpc.pb.h" +#include "tensorflow/compiler/xla/service/service.h" + +namespace xla { + +// Service implementation which wraps a XLA Service with a GRPC interface. +class GRPCService : public grpc::XlaService::Service { + public: + // Factory for creating a RPCService. The parameter platform is the platform + // that the service should target. If platform is null then the default + // platform is used. + static StatusOr> NewService( + se::Platform* platform = nullptr); + + ::grpc::Status Unregister(::grpc::ServerContext* context, + const UnregisterRequest* arg, + UnregisterResponse* result) override; + + ::grpc::Status DeconstructTuple(::grpc::ServerContext* context, + const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) override; + + ::grpc::Status ExecuteGraph(::grpc::ServerContext* context, + const ExecuteGraphRequest* arg, + ExecuteResponse* result) override; + + ::grpc::Status WaitForExecution(::grpc::ServerContext* context, + const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override; + + ::grpc::Status TransferToClient(::grpc::ServerContext* context, + const TransferToClientRequest* arg, + TransferToClientResponse* result) override; + + ::grpc::Status TransferToServer(::grpc::ServerContext* context, + const TransferToServerRequest* arg, + TransferToServerResponse* result) override; + + ::grpc::Status TransferToInfeed(::grpc::ServerContext* context, + const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override; + + ::grpc::Status TransferFromOutfeed( + ::grpc::ServerContext* context, const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override; + + ::grpc::Status ResetDevice(::grpc::ServerContext* context, + const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override; + + ::grpc::Status GetShape(::grpc::ServerContext* context, + const GetShapeRequest* arg, + GetShapeResponse* result) override; + + private: + std::unique_ptr<::xla::Service> service_; + + GRPCService() {} + GRPCService(const GRPCService&) = delete; + void operator=(const GRPCService&) = delete; +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_ diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc new file mode 100644 index 0000000000000000000000000000000000000000..c68c857c304138ff4318e243f66547c6acce1005 --- /dev/null +++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Basic server binary that exposes a xla::Service through a GRPC interface +// on a configurable port. +#include "grpcpp/security/server_credentials.h" +#include "grpcpp/server.h" +#include "grpcpp/server_builder.h" +#include "tensorflow/compiler/xla/rpc/grpc_service.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace { + +int RealMain(int argc, char** argv) { + int32 port = 1685; + std::vector flag_list = { + tensorflow::Flag("port", &port, "port to listen on"), + }; + string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parsed_values_ok) { + LOG(ERROR) << usage; + return 2; + } + tensorflow::port::InitMain(argv[0], &argc, &argv); + + std::unique_ptr service = + xla::GRPCService::NewService().ConsumeValueOrDie(); + + ::grpc::ServerBuilder builder; + string server_address(tensorflow::strings::Printf("localhost:%d", port)); + + builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials()); + builder.RegisterService(service.get()); + std::unique_ptr<::grpc::Server> server(builder.BuildAndStart()); + + LOG(INFO) << "Server listening on " << server_address; + server->Wait(); + + return 0; +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { return xla::RealMain(argc, argv); } diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc new file mode 100644 index 0000000000000000000000000000000000000000..7b8ab158e1396d7087a407be180ab44d2e16e121 --- /dev/null +++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc @@ -0,0 +1,139 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/rpc/grpc_stub.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" + +namespace xla { + +GRPCStub::~GRPCStub() = default; + +Status MakeRPC( + const std::function<::grpc::Status(::grpc::ClientContext*)>& rpc_method) { + ::grpc::ClientContext context; + ::grpc::Status s = rpc_method(&context); + return tensorflow::FromGrpcStatus(s); +} + +Status GRPCStub::TransferToClient(const TransferToClientRequest* request, + TransferToClientResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->TransferToClient(context, *request, response); + }); +} + +Status GRPCStub::TransferToServer(const TransferToServerRequest* request, + TransferToServerResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->TransferToServer(context, *request, response); + }); +} + +Status GRPCStub::TransferToInfeed(const TransferToInfeedRequest* request, + TransferToInfeedResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->TransferToInfeed(context, *request, response); + }); +} + +Status GRPCStub::TransferFromOutfeed(const TransferFromOutfeedRequest* request, + TransferFromOutfeedResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->TransferFromOutfeed(context, *request, response); + }); +} + +Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, + ResetDeviceResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->ResetDevice(context, *request, response); + }); +} + +Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, + ExecuteResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->ExecuteGraph(context, *request, response); + }); +} + +Status GRPCStub::ExecuteGraphParallel( + const ExecuteGraphParallelRequest* request, + ExecuteParallelResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->ExecuteGraphParallel(context, *request, response); + }); +} + +Status GRPCStub::WaitForExecution(const WaitForExecutionRequest* request, + WaitForExecutionResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->WaitForExecution(context, *request, response); + }); +} + +Status GRPCStub::DeconstructTuple(const DeconstructTupleRequest* request, + DeconstructTupleResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->DeconstructTuple(context, *request, response); + }); +} + +Status GRPCStub::GetComputationGraphStats( + const ComputationGraphStatsRequest* request, + ComputationStatsResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->GetComputationGraphStats(context, *request, response); + }); +} + +Status GRPCStub::GetShape(const GetShapeRequest* request, + GetShapeResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->GetShape(context, *request, response); + }); +} + +Status GRPCStub::GetDeviceHandles(const GetDeviceHandlesRequest* request, + GetDeviceHandlesResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->GetDeviceHandles(context, *request, response); + }); +} + +Status GRPCStub::CreateChannelHandle(const CreateChannelHandleRequest* request, + CreateChannelHandleResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->CreateChannelHandle(context, *request, response); + }); +} + +Status GRPCStub::ComputeConstantGraph( + const ComputeConstantGraphRequest* request, + ComputeConstantResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->ComputeConstantGraph(context, *request, response); + }); +} + +// Methods used by GlobalData. +Status GRPCStub::Unregister(const UnregisterRequest* request, + UnregisterResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->Unregister(context, *request, response); + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h new file mode 100644 index 0000000000000000000000000000000000000000..8dfcb761387d608abbb1f62974f49b976a7ff7ff --- /dev/null +++ b/tensorflow/compiler/xla/rpc/grpc_stub.h @@ -0,0 +1,87 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_RPC_GRPC_STUB_H_ +#define TENSORFLOW_COMPILER_XLA_RPC_GRPC_STUB_H_ + +#include "tensorflow/compiler/xla/rpc/xla_service.grpc.pb.h" +#include "tensorflow/compiler/xla/service_interface.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +class GRPCStub : public ServiceInterface { + public: + explicit GRPCStub(grpc::XlaService::Stub* stub) : grpc_stub_(stub) {} + ~GRPCStub() override; + + Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) override; + + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override; + + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override; + + Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override; + + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override; + + Status ExecuteGraph(const ExecuteGraphRequest* request, + ExecuteResponse* response) override; + + Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request, + ExecuteParallelResponse* response) override; + + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override; + + Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) override; + + Status GetComputationGraphStats(const ComputationGraphStatsRequest* request, + ComputationStatsResponse* response) override; + + Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) override; + + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override; + + Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) override; + + Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) override; + + // Methods used by GlobalData. + Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) override; + + grpc::XlaService::Stub* service() { return grpc_stub_; } + + private: + grpc::XlaService::Stub* grpc_stub_; + + TF_DISALLOW_COPY_AND_ASSIGN(GRPCStub); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_RPC_GRPC_STUB_H_ diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto new file mode 100644 index 0000000000000000000000000000000000000000..551ae895e05586daec0ffcd425f4950f76bdd50d --- /dev/null +++ b/tensorflow/compiler/xla/rpc/xla_service.proto @@ -0,0 +1,149 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA service API. +// +// Users 1) build up computations and 2) create allocations via this API. +// Computations are composed of data flowing between arbitrarily-sized +// vector-oriented operations. +// +// Users build up computations using a ComputationHandle, and talk about +// allocations using GlobalDataHandles. +// +// There are currently no checkpointing capabilities or distribution/replication +// guarantees. The service runs on a single machine (e.g. one task) and that is +// its failure domain. +// +// Canonical example of "alpha * X + Y": +// * Make a computation. +// * Add alpha and X and Y as parameters. +// * Request the multiplication of alpha and X. +// * Request the addition of that result and Y. +// +// Then, pass the computation and appropriately shaped inputs to the XLA +// service's Execute method, which provides a result as a GlobalDataHandle. +// +// All data in XLA computations are conceptually immutable. +// +// Note: this API is subject to change / refinement over time -- use the +// provided client libraries to insulate code from changes to this service API. + +syntax = "proto3"; + +import "tensorflow/compiler/xla/xla.proto"; +import "tensorflow/compiler/xla/xla_data.proto"; + +package xla; + +service XlaService { + ///////////////////////// + // Global data requests + + // Unregisters a global allocation. + // + // If the handle given is not currently allocated, a NOT_FOUND status is + // returned. + rpc Unregister(UnregisterRequest) returns (UnregisterResponse) { + } + + // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each + // element in the tuple. + rpc DeconstructTuple(DeconstructTupleRequest) + returns (DeconstructTupleResponse) { + } + + // Unpack requests that a global data handle, with a tuple shape, has global + // data handles created for each of its constituent members. This is the + // equivalent of the "destructuring assignment" present in various programming + // languages. + rpc Unpack(UnpackRequest) returns (UnpackResponse) { + } + + // Requests the shape of the referenced global data. + rpc GetShape(GetShapeRequest) returns (GetShapeResponse) { + } + + // Requests the statistics of the given computation. + rpc GetComputationGraphStats(ComputationGraphStatsRequest) + returns (ComputationStatsResponse) { + } + + // Loads a variable number of values with a given element type from ColumnIO. + rpc LoadData(LoadDataRequest) returns (LoadDataResponse) { + } + + // Transfers the given global data to the client in the form of a Literal. + rpc TransferToClient(TransferToClientRequest) + returns (TransferToClientResponse) { + } + + // Transfers the given literal to the server to be stored in a global + // allocation, which is returned. + rpc TransferToServer(TransferToServerRequest) + returns (TransferToServerResponse) { + } + + // Transfers the given literal to the Infeed buffer of the device. + rpc TransferToInfeed(TransferToInfeedRequest) + returns (TransferToInfeedResponse) { + } + + // Transferred literal from the Outfeed buffer of the device. + rpc TransferFromOutfeed(TransferFromOutfeedRequest) + returns (TransferFromOutfeedResponse) { + } + + // Resets the device, clearing all existing state on the device. + rpc ResetDevice(ResetDeviceRequest) returns (ResetDeviceResponse) { + } + + // Computes the value of a constant expression. The request contains the + // computation graph for the constant expression. + rpc ComputeConstantGraph(ComputeConstantGraphRequest) + returns (ComputeConstantResponse) { + } + + // Requests one or more device handles from the target. The returned device + // handles can be used to specify the device on which to execute computations + // or transfer data. + rpc GetDeviceHandles(GetDeviceHandlesRequest) + returns (GetDeviceHandlesResponse) { + } + + // Creates a channel handle that can be used to transfer data between + // two computations via a pair of Send and Recv instructions. + rpc CreateChannelHandle(CreateChannelHandleRequest) + returns (CreateChannelHandleResponse) { + } + + // Invokes the provided computation with the provided global data passed as + // immutable arguments. The request contains the whole computation graph. + // Returns global data output and execution timing. + rpc ExecuteGraph(ExecuteGraphRequest) returns (ExecuteResponse) { + } + + // Invokes the provided list of computations in parallel with the provided + // global data for each computation. Returns a list of global data output and + // execution timing. + rpc ExecuteGraphParallel(ExecuteGraphParallelRequest) + returns (ExecuteParallelResponse) { + } + + // Waits until the given execution (aysnchronously launched) is complete, and + // returns the global data output. + rpc WaitForExecution(WaitForExecutionRequest) + returns (WaitForExecutionResponse) { + } +} diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 4a076ac0909d24f6c7355a323d4b78151d3fe2ac..2942edbf71f29304ebb240f0547808ae0af1ac87 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -12,21 +12,26 @@ package_group( ], ) +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) xla_proto_library( - name = "session_proto", - srcs = ["session.proto"], + name = "hlo_proto", + srcs = ["hlo.proto"], visibility = ["//visibility:public"], deps = ["//tensorflow/compiler/xla:xla_data_proto"], ) -xla_proto_library( - name = "hlo_proto", +tf_proto_library_py( + name = "hlo_proto", # bzl adds a _py suffix only to the OSS target. srcs = ["hlo.proto"], - deps = ["//tensorflow/compiler/xla:xla_data_proto"], + visibility = ["//visibility:public"], ) xla_proto_library( @@ -106,6 +111,7 @@ tf_cc_test( ":bfloat16_normalization", ":bfloat16_support", ":hlo", + ":hlo_verifier", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -118,6 +124,42 @@ tf_cc_test( ], ) +cc_library( + name = "bfloat16_propagation", + srcs = ["bfloat16_propagation.cc"], + hdrs = ["bfloat16_propagation.h"], + deps = [ + ":bfloat16_support", + ":hlo", + ":hlo_dataflow_analysis", + ":hlo_dce", + ":hlo_pass", + ":tuple_simplifier", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "bfloat16_propagation_test", + srcs = ["bfloat16_propagation_test.cc"], + deps = [ + ":bfloat16_propagation", + ":bfloat16_support", + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + ], +) + cc_library( name = "shape_inference", srcs = ["shape_inference.cc"], @@ -163,7 +205,22 @@ tf_cc_test( cc_library( name = "hlo_evaluator", - srcs = ["hlo_evaluator.cc"], + srcs = [ + "hlo_evaluator.cc", + "hlo_evaluator_typed_visitor.h", + "hlo_evaluator_typed_visitor_bfloat16.cc", + "hlo_evaluator_typed_visitor_bool.cc", + "hlo_evaluator_typed_visitor_complex64.cc", + "hlo_evaluator_typed_visitor_double.cc", + "hlo_evaluator_typed_visitor_float.cc", + "hlo_evaluator_typed_visitor_half.cc", + "hlo_evaluator_typed_visitor_int32.cc", + "hlo_evaluator_typed_visitor_int64.cc", + "hlo_evaluator_typed_visitor_int8.cc", + "hlo_evaluator_typed_visitor_uint32.cc", + "hlo_evaluator_typed_visitor_uint64.cc", + "hlo_evaluator_typed_visitor_uint8.cc", + ], hdrs = ["hlo_evaluator.h"], deps = [ ":hlo", @@ -196,7 +253,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -212,6 +269,7 @@ cc_library( "dfs_hlo_visitor.cc", "hlo_computation.cc", "hlo_instruction.cc", + "hlo_instructions.cc", "hlo_module.cc", "hlo_opcode.cc", "hlo_sharding.cc", @@ -219,18 +277,21 @@ cc_library( hdrs = [ "dfs_hlo_visitor.h", "dfs_hlo_visitor_with_default.h", + "hlo_clone_context.h", "hlo_computation.h", + "hlo_domain_metadata.h", "hlo_instruction.h", + "hlo_instructions.h", "hlo_module.h", "hlo_opcode.h", "hlo_sharding.h", ], deps = [ + ":hlo_casting_utils", ":hlo_module_config", ":hlo_proto", ":hlo_reachability", ":name_uniquer", - ":versioned_computation_handle", "//tensorflow/compiler/xla:array", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", @@ -243,11 +304,52 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:human_readable_json", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], ) +tf_cc_test( + name = "dfs_hlo_visitor_with_default_test", + srcs = ["dfs_hlo_visitor_with_default_test.cc"], + deps = [ + ":hlo", + ":hlo_runner", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "pattern_matcher", + hdrs = ["pattern_matcher.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "pattern_matcher_test", + srcs = ["pattern_matcher_test.cc"], + deps = [ + ":hlo", + ":pattern_matcher", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_reachability", srcs = ["hlo_reachability.cc"], @@ -281,7 +383,9 @@ cc_library( deps = [ ":hlo", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", ], ) @@ -290,33 +394,25 @@ tf_cc_test( srcs = ["hlo_matchers_test.cc"], deps = [ ":hlo_matchers", + ":hlo_parser", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) -cc_library( - name = "versioned_computation_handle", - srcs = ["versioned_computation_handle.cc"], - hdrs = ["versioned_computation_handle.h"], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - ], -) - tf_cc_test( name = "hlo_instruction_test", srcs = ["hlo_instruction_test.cc"], deps = [ ":hlo", + ":hlo_parser", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -333,6 +429,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -437,45 +534,6 @@ tf_cc_test( ], ) -cc_library( - name = "user_computation", - srcs = ["user_computation.cc"], - hdrs = ["user_computation.h"], - deps = [ - ":hlo", - ":session_proto", - ":shape_inference", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "user_computation_test", - srcs = ["user_computation_test.cc"], - deps = [ - ":hlo_matchers", - ":user_computation", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", - ], -) - cc_library( name = "platform_util", srcs = ["platform_util.cc"], @@ -521,10 +579,8 @@ cc_library( ":allocation_tracker", ":backend", ":channel_tracker", - ":compilation_cache", ":compiler", ":computation_layout", - ":computation_tracker", ":device_memory_allocator", ":executable", ":execution_tracker", @@ -535,11 +591,8 @@ cc_library( ":hlo_module_config", ":hlo_proto_util", ":platform_util", - ":session_proto", ":source_map_util", ":transfer_manager", - ":user_computation", - ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:service_interface", @@ -566,7 +619,6 @@ cc_library( ":backend", ":compiler", ":computation_layout", - ":computation_tracker", ":device_memory_allocator", ":executable", ":hlo", @@ -575,8 +627,6 @@ cc_library( ":platform_util", ":service", ":shaped_buffer", - ":user_computation", - ":versioned_computation_handle", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -586,6 +636,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], @@ -599,7 +650,6 @@ cc_library( ":backend", ":compiler", ":computation_layout", - ":computation_tracker", ":platform_util", ":service", "//tensorflow/compiler/xla:status_macros", @@ -664,6 +714,23 @@ cc_library( ], ) +tf_cc_test( + name = "shaped_buffer_test", + srcs = ["shaped_buffer_test.cc"], + deps = [ + ":cpu_plugin", + ":device_memory_allocator", + ":platform_util", + ":shaped_buffer", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:ptr_util", + "//tensorflow/core:test", + ], +) + cc_library( name = "executable", srcs = ["executable.cc"], @@ -675,13 +742,11 @@ cc_library( ":computation_layout", ":device_memory_allocator", ":hlo", - ":hlo_cost_analysis", ":hlo_execution_profile", ":hlo_graph_dumper", + ":hlo_proto", ":pool", - ":session_proto", ":shaped_buffer", - ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -701,6 +766,7 @@ cc_library( srcs = ["compiler.cc"], hdrs = ["compiler.h"], deps = [ + ":buffer_value", ":executable", ":hlo", ":hlo_module_config", @@ -757,7 +823,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -777,34 +842,12 @@ cc_library( ], ) -cc_library( - name = "computation_tracker", - srcs = ["computation_tracker.cc"], - hdrs = ["computation_tracker.h"], - deps = [ - ":hlo", - ":hlo_module_config", - ":session_proto", - ":user_computation", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "channel_tracker", srcs = ["channel_tracker.cc"], hdrs = ["channel_tracker.h"], deps = [ ":hlo", - ":session_proto", - ":user_computation", - ":versioned_computation_handle", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -838,33 +881,6 @@ tf_cc_test( ], ) -cc_library( - name = "liveness_util", - srcs = ["liveness_util.cc"], - hdrs = ["liveness_util.h"], - deps = [ - ":hlo", - ":hlo_dataflow_analysis", - ":logical_buffer", - ":tuple_points_to_analysis", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - ], -) - -tf_cc_test( - name = "liveness_util_test", - srcs = ["liveness_util_test.cc"], - deps = [ - ":hlo", - ":liveness_util", - ":tuple_points_to_analysis", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - ], -) - cc_library( name = "buffer_liveness", srcs = [ @@ -876,7 +892,6 @@ cc_library( deps = [ ":hlo", ":hlo_ordering", - ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -913,6 +928,7 @@ cc_library( ], deps = [ ":buffer_liveness", + ":buffer_value_containers", ":heap_simulator", ":hlo", ":hlo_proto", @@ -935,8 +951,8 @@ tf_cc_test( srcs = ["buffer_assignment_test.cc"], deps = [ ":buffer_assignment", + ":buffer_value", ":call_graph", - ":computation_tracker", ":copy_insertion", ":cpu_plugin", ":flatten_call_graph", @@ -950,6 +966,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -966,7 +983,6 @@ cc_library( ":hlo_dataflow_analysis", ":hlo_proto", ":hlo_value", - ":liveness_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -988,9 +1004,9 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -999,11 +1015,11 @@ cc_library( srcs = ["heap_simulator.cc"], hdrs = ["heap_simulator.h"], deps = [ + ":buffer_value", + ":buffer_value_containers", ":hlo", ":hlo_ordering", ":hlo_proto", - ":liveness_util", - ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1015,10 +1031,11 @@ tf_cc_test( name = "heap_simulator_test", srcs = ["heap_simulator_test.cc"], deps = [ + ":buffer_value", ":heap_simulator", ":hlo", ":hlo_ordering", - ":logical_buffer", + ":hlo_value", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status_macros", @@ -1028,6 +1045,38 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_module_group_metadata", + srcs = ["hlo_module_group_metadata.cc"], + hdrs = ["hlo_module_group_metadata.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_module_group_util", + srcs = ["hlo_module_group_util.cc"], + hdrs = ["hlo_module_group_util.h"], + deps = [ + ":hlo", + ":hlo_module_group_metadata", + ":hlo_reachability", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "hlo_scheduling", srcs = ["hlo_scheduling.cc"], @@ -1051,12 +1100,14 @@ tf_cc_test( name = "hlo_scheduling_test", srcs = ["hlo_scheduling_test.cc"], deps = [ + ":buffer_value", ":hlo", ":hlo_ordering", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -1091,11 +1142,55 @@ tf_cc_test( deps = [ ":hlo_matchers", ":instruction_fusion", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) +cc_library( + name = "multi_output_fusion", + srcs = ["multi_output_fusion.cc"], + hdrs = ["multi_output_fusion.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_creation_utils", + srcs = ["hlo_creation_utils.cc"], + hdrs = ["hlo_creation_utils.h"], + deps = [ + ":hlo", + ":shape_inference", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + ], +) + +tf_cc_test( + name = "hlo_creation_utils_test", + srcs = ["hlo_creation_utils_test.cc"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_evaluator", + ":hlo_matchers", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "batchnorm_expander", srcs = ["batchnorm_expander.cc"], @@ -1103,19 +1198,30 @@ cc_library( deps = [ ":hlo", ":hlo_pass", - ":hlo_query", - ":shape_inference", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], ) +cc_library( + name = "gather_expander", + srcs = ["gather_expander.cc"], + hdrs = ["gather_expander.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + ":while_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + ], +) + tf_cc_test( name = "batchnorm_expander_test", size = "small", @@ -1143,9 +1249,10 @@ cc_library( hdrs = ["algebraic_simplifier.h"], deps = [ ":hlo", + ":hlo_creation_utils", ":hlo_pass", ":hlo_query", - ":shape_inference", + ":pattern_matcher", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1172,6 +1279,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", @@ -1180,28 +1288,112 @@ tf_cc_test( ) cc_library( - name = "while_loop_simplifier", - srcs = ["while_loop_simplifier.cc"], - hdrs = ["while_loop_simplifier.h"], + name = "batch_dot_simplification", + srcs = ["batch_dot_simplification.cc"], + hdrs = ["batch_dot_simplification.h"], deps = [ - ":call_inliner", ":hlo", - ":hlo_evaluator", + ":hlo_creation_utils", ":hlo_pass", - "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], ) tf_cc_test( - name = "while_loop_simplifier_test", - srcs = ["while_loop_simplifier_test.cc"], + name = "batch_dot_simplification_test", + srcs = ["batch_dot_simplification_test.cc"], deps = [ + ":batch_dot_simplification", + ":hlo", ":hlo_matchers", - ":while_loop_simplifier", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", - "//tensorflow/core:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +tf_cc_test( + name = "gather_expander_test", + srcs = ["gather_expander_test.cc"], + deps = [ + ":gather_expander", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:test_macros_header", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + ], +) + +cc_library( + name = "conditional_simplifier", + srcs = ["conditional_simplifier.cc"], + hdrs = ["conditional_simplifier.h"], + deps = [ + ":call_inliner", + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "conditional_simplifier_test", + srcs = ["conditional_simplifier_test.cc"], + deps = [ + ":conditional_simplifier", + ":hlo", + ":hlo_matchers", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "while_loop_simplifier", + srcs = ["while_loop_simplifier.cc"], + hdrs = ["while_loop_simplifier.h"], + deps = [ + ":call_inliner", + ":hlo", + ":hlo_evaluator", + ":hlo_pass", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "while_loop_simplifier_test", + srcs = ["while_loop_simplifier_test.cc"], + deps = [ + ":hlo_matchers", + ":while_loop_simplifier", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", ], ) @@ -1434,6 +1626,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -1444,23 +1637,20 @@ tf_cc_test( name = "hlo_cost_analysis_test", srcs = ["hlo_cost_analysis_test.cc"], deps = [ - ":computation_tracker", ":cpu_plugin", ":hlo", ":hlo_cost_analysis", ":local_service", ":service", - ":user_computation", - ":versioned_computation_handle", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1490,8 +1680,10 @@ tf_cc_test( ":cpu_plugin", ":hlo_cost_analysis", ":hlo_execution_profile", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", ], ) @@ -1541,11 +1733,38 @@ tf_cc_test( ], ) +cc_library( + name = "buffer_value", + srcs = ["buffer_value.cc"], + hdrs = ["buffer_value.h"], + deps = [ + ":hlo", + ":hlo_proto", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "buffer_value_containers", + hdrs = ["buffer_value_containers.h"], + deps = [ + ":buffer_value", + ":logical_buffer", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + cc_library( name = "logical_buffer", srcs = ["logical_buffer.cc"], hdrs = ["logical_buffer.h"], deps = [ + ":buffer_value", ":hlo", ":hlo_proto", "//tensorflow/compiler/xla:shape_util", @@ -1561,6 +1780,7 @@ cc_library( srcs = ["hlo_value.cc"], hdrs = ["hlo_value.h"], deps = [ + ":buffer_value", ":hlo", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", @@ -1613,6 +1833,44 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_liveness_analysis", + srcs = ["hlo_liveness_analysis.cc"], + hdrs = ["hlo_liveness_analysis.h"], + deps = [ + ":call_graph", + ":hlo", + ":hlo_value", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "hlo_liveness_analysis_test", + srcs = ["hlo_liveness_analysis_test.cc"], + deps = [ + ":hlo", + ":hlo_liveness_analysis", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_buffer", srcs = ["hlo_buffer.cc"], @@ -1724,20 +1982,6 @@ tf_cc_test( ], ) -cc_library( - name = "compilation_cache", - srcs = ["compilation_cache.cc"], - hdrs = ["compilation_cache.h"], - deps = [ - ":executable", - ":hlo_module_config", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "layout_assignment", srcs = [ @@ -1749,10 +1993,12 @@ cc_library( deps = [ ":computation_layout", ":hlo", + ":hlo_dce", ":hlo_graph_dumper", ":hlo_pass", ":logical_buffer", ":tuple_points_to_analysis", + ":tuple_simplifier", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1776,7 +2022,6 @@ cc_library( ":hlo_graph_dumper", ":hlo_ordering", ":hlo_pass", - ":liveness_util", ":logical_buffer", ":tuple_simplifier", "//tensorflow/compiler/xla:status_macros", @@ -1823,11 +2068,30 @@ cc_library( ], ) +cc_library( + name = "hlo_module_dce", + srcs = ["hlo_module_dce.cc"], + hdrs = ["hlo_module_dce.h"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_liveness_analysis", + ":hlo_pass", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "hlo_verifier", srcs = ["hlo_verifier.cc"], hdrs = ["hlo_verifier.h"], deps = [ + ":hlo", ":hlo_pass", ":shape_inference", "//tensorflow/compiler/xla:status_macros", @@ -1857,13 +2121,13 @@ cc_library( hdrs = ["hlo_rematerialization.h"], deps = [ ":buffer_liveness", + ":buffer_value", ":call_graph", ":flatten_call_graph", ":hlo", ":hlo_dce", ":hlo_ordering", ":hlo_scheduling", - ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1888,6 +2152,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -1911,6 +2176,27 @@ tf_cc_test( ], ) +tf_cc_test( + name = "hlo_module_dce_test", + srcs = ["hlo_module_dce_test.cc"], + deps = [ + ":hlo", + ":hlo_module_dce", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "layout_assignment_test", srcs = ["layout_assignment_test.cc"], @@ -1927,9 +2213,9 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -1977,6 +2263,7 @@ cc_library( hdrs = ["hlo_cse.h"], deps = [ ":hlo", + ":hlo_domain_map", ":hlo_pass", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -1999,6 +2286,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -2040,6 +2328,78 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_domain_map", + srcs = ["hlo_domain_map.cc"], + hdrs = ["hlo_domain_map.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_sharding_metadata", + srcs = ["hlo_sharding_metadata.cc"], + hdrs = [ + "hlo_sharding_metadata.h", + ], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_domain_isolator", + srcs = ["hlo_domain_isolator.cc"], + hdrs = ["hlo_domain_isolator.h"], + deps = [ + ":hlo", + ":hlo_graph_dumper", + ":hlo_pass", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + ], +) + +cc_library( + name = "hlo_domain_remover", + srcs = ["hlo_domain_remover.cc"], + hdrs = ["hlo_domain_remover.h"], + deps = [ + ":hlo", + ":hlo_domain_isolator", + ":hlo_domain_map", + ":hlo_graph_dumper", + ":hlo_pass", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "hlo_domain_test", + srcs = ["hlo_domain_test.cc"], + deps = [ + ":hlo", + ":hlo_domain_isolator", + ":hlo_domain_remover", + ":hlo_parser", + ":hlo_sharding_metadata", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_element_type_converter", srcs = ["hlo_element_type_converter.cc"], @@ -2068,8 +2428,14 @@ tf_cc_test( cc_library( name = "device_memory_allocator", - srcs = ["device_memory_allocator.cc"], - hdrs = ["device_memory_allocator.h"], + srcs = [ + "device_memory_allocator.cc", + "owning_device_memory.cc", + ], + hdrs = [ + "device_memory_allocator.h", + "owning_device_memory.h", + ], deps = [ "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -2104,6 +2470,24 @@ cc_library( ], ) +xla_test( + name = "elemental_ir_emitter_test", + srcs = ["elemental_ir_emitter_test.cc"], + backends = [ + "cpu", + "gpu", + ], + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "hlo_module_config", srcs = ["hlo_module_config.cc"], @@ -2158,7 +2542,6 @@ cc_library( name = "hlo_tfgraph_builder", srcs = ["hlo_tfgraph_builder.cc"], hdrs = ["hlo_tfgraph_builder.h"], - visibility = ["//tensorflow/compiler/xla/tools:__pkg__"], deps = [ ":hlo", "//tensorflow/compiler/xla:literal_util", @@ -2175,7 +2558,6 @@ tf_cc_test( srcs = ["hlo_tfgraph_builder_test.cc"], deps = [ ":hlo_tfgraph_builder", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", @@ -2190,6 +2572,7 @@ cc_library( hdrs = ["hlo_graph_dumper.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_execution_profile", ":hlo_tfgraph_builder", "//tensorflow/compiler/xla:literal_util", @@ -2228,6 +2611,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], ) @@ -2237,6 +2621,7 @@ tf_cc_test( srcs = ["transpose_folding_test.cc"], deps = [ ":hlo", + ":hlo_matchers", ":shape_inference", ":transpose_folding", "//tensorflow/compiler/xla:literal_util", @@ -2244,7 +2629,8 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -2280,7 +2666,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2315,6 +2701,24 @@ cc_library( ":hlo", ":hlo_proto", "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:util", + ], +) + +tf_cc_test( + name = "hlo_proto_util_test", + srcs = ["hlo_proto_util_test.cc"], + deps = [ + ":hlo", + ":hlo_proto", + ":hlo_proto_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", ], ) @@ -2354,6 +2758,7 @@ cc_library( srcs = ["hlo_runner.cc"], hdrs = ["hlo_runner.h"], deps = [ + ":computation_placer", ":executable", ":hlo", ":transfer_manager", @@ -2365,7 +2770,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -2401,8 +2806,8 @@ tf_cc_test( ":tuple_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -2413,7 +2818,9 @@ cc_library( deps = [ ":call_inliner", ":hlo", + ":hlo_creation_utils", ":tuple_util", + "//tensorflow/core:lib", ], ) @@ -2423,9 +2830,10 @@ tf_cc_test( deps = [ ":while_util", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -2451,11 +2859,54 @@ tf_cc_test( ":hlo_matchers", ":while_loop_invariant_code_motion", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) +cc_library( + name = "while_loop_constant_sinking", + srcs = ["while_loop_constant_sinking.cc"], + hdrs = ["while_loop_constant_sinking.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":while_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "while_loop_constant_sinking_test", + srcs = ["while_loop_constant_sinking_test.cc"], + deps = [ + ":hlo_matchers", + ":while_loop_constant_sinking", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "despecializer", + srcs = ["despecializer.cc"], + hdrs = ["despecializer.h"], + deps = [ + ":bfloat16_normalization", + ":defuser", + ":hlo", + ":hlo_pass", + ":hlo_pass_pipeline", + ":implicit_broadcast_remover", + "//tensorflow/compiler/xla:statusor", + ], +) + cc_library( name = "source_map_util", srcs = ["source_map_util.cc"], @@ -2468,16 +2919,96 @@ cc_library( ], ) -# ----------------------------------------------------------------------------- +cc_library( + name = "indexed_array_analysis", + srcs = ["indexed_array_analysis.cc"], + hdrs = ["indexed_array_analysis.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + ":hlo_pass", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:ptr_util", + ], +) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], +tf_cc_test( + name = "indexed_array_analysis_test", + srcs = ["indexed_array_analysis_test.cc"], + deps = [ + ":hlo_matchers", + ":indexed_array_analysis", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "hlo_parser", + srcs = ["hlo_parser.cc"], + hdrs = ["hlo_parser.h"], + deps = [ + ":hlo", + ":hlo_lexer", + ":hlo_sharding_metadata", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_cc_test( + name = "hlo_parser_test", + size = "small", + srcs = ["hlo_parser_test.cc"], + deps = [ + ":hlo_parser", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "hlo_lexer", + srcs = ["hlo_lexer.cc"], + hdrs = [ + "hlo_lexer.h", + "hlo_token.h", + ], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + ], +) + +cc_library( + name = "hlo_casting_utils", + hdrs = ["hlo_casting_utils.h"], + deps = ["//tensorflow/core:lib"], +) + +tf_cc_test( + name = "hlo_casting_utils_test", + srcs = ["hlo_casting_utils_test.cc"], + deps = [ + ":hlo", + ":hlo_casting_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:test", + ], ) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index fb857559f972a220a19b108baa4c441e09b90e1f..3b36939b8a6900f047bbec225aa232e0e805b5d1 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -26,10 +26,11 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -44,8 +45,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { + namespace { +namespace m = match; + // Returns whether operand is a literal with the given value. bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { return operand->opcode() == HloOpcode::kConstant && @@ -88,25 +92,6 @@ bool ReshapeIsBitcast( valid_bitcast_callback(operand->shape(), reshape->shape()); } -// Adds a scalar computation to the module to enable optimizations with dot -// converting into reduction. -HloComputation* CreateScalarBinaryComputation(HloModule* module, - PrimitiveType primitive_type, - HloOpcode opcode) { - HloComputation::Builder b("scalar_computation"); - auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs")); - auto scalar_op = b.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), - opcode, scalar_lhs, scalar_rhs)); - HloComputation* scalar_computation = - module->AddEmbeddedComputation(b.Build(scalar_op)); - return scalar_computation; -} -} // namespace - // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain // algebraic expressions to simplified forms. Note: This only supports // simplifications that simply look at the operands of an instruction. For the @@ -122,6 +107,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleBitcastConvert(HloInstruction* bitcast) override; + Status HandleBroadcast(HloInstruction* broadcast) override; Status HandleConcatenate(HloInstruction* concatenate) override; @@ -170,6 +157,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleSubtract(HloInstruction* sub) override; + Status HandleMap(HloInstruction* map) override; + Status HandleMaximum(HloInstruction* maximum) override; Status HandleMinimum(HloInstruction* minimum) override; @@ -213,8 +202,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { HloInstruction* zero = computation_->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloComputation* AddReduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); + HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( shape, hlo, zero, {dim}, AddReduce_computation)); @@ -245,10 +233,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand); - // A Reshape or Broadcast that feeds an element-wise operation with a unique - // non-scalar operand can sink to after the operation. - StatusOr TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( - HloInstruction* reshape_or_broadcast); + // A Broadcast that feeds an element-wise operation with a unique non-scalar + // operand can sink to after the operation. + StatusOr TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( + HloInstruction* broadcast); // Replaces the existing HLO instruction old_instruction, with // new_instruction, and marks the optimizer status as changed. @@ -284,6 +272,26 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped); + StatusOr OptimizeDotOfGather(HloInstruction* dot); + + HloComputation* GetOrCreateScalarAddComputation() { + if (scalar_add_computation_) { + return scalar_add_computation_; + } + + HloComputation::Builder b("scalar_add_computation"); + Shape shape = ShapeUtil::MakeShape(F32, {}); + auto scalar_lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, shape, "scalar_rhs")); + auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); + scalar_add_computation_ = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + return scalar_add_computation_; + } + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -300,10 +308,15 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable dot strength reduction on platforms where it causes a slowdown. bool enable_dot_strength_reduction_; - // Disable convolution simplication on platforms where it causes a slowdown. + // Disable convolution simplification on platforms where it causes a slowdown. bool enable_conv_simplification_; + + // Cached computation for adding two scalar F32. + HloComputation* scalar_add_computation_ = nullptr; }; +} // namespace + bool AlgebraicSimplifierVisitor::Run( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, @@ -348,8 +361,9 @@ bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( } Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { - auto lhs = add->mutable_operand(0); - auto rhs = add->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs)))); + // A + 0 => A VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString(); if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) { @@ -364,7 +378,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { // Canonicalization: Put constants on the right. This makes the reassociation // rules below simpler. VLOG(10) << "trying transform [Const + A => A + Const]"; - if (lhs->IsConstant() && !rhs->IsConstant()) { + if (Match(add, m::Add(m::Constant(), m::NonConstant()))) { return ReplaceWithNewInstruction( add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs)); @@ -377,20 +391,13 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { // (A + C1) + (B + C2) => A + B + (C1 + C2). // VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]"; - if (rhs->IsConstant() && lhs->opcode() == HloOpcode::kAdd && - !lhs->operand(0)->IsConstant() && lhs->operand(1)->IsConstant()) { - auto* c1 = lhs->mutable_operand(1); - auto* c2 = rhs; - TF_ASSIGN_OR_RETURN( - Shape sum_of_constants_shape, - ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, c1, c2)); - - auto* sum_of_constants = - computation_->AddInstruction(HloInstruction::CreateBinary( - sum_of_constants_shape, HloOpcode::kAdd, c1, c2)); + HloInstruction *a, *c1, *c2; + if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)), + m::Constant(&c2)))) { + TF_ASSIGN_OR_RETURN(auto* sum_of_constants, + MakeBinaryHlo(HloOpcode::kAdd, c1, c2)); return ReplaceWithNewInstruction( - add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, - lhs->mutable_operand(0), + add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a, sum_of_constants)); } @@ -399,11 +406,11 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { // If a bitcast feeds a bitcast, make it a single bitcast. - if (bitcast->operand(0)->opcode() == HloOpcode::kBitcast) { + HloInstruction* op; + if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) { return ReplaceWithNewInstruction( - bitcast, HloInstruction::CreateUnary( - bitcast->shape(), HloOpcode::kBitcast, - bitcast->mutable_operand(0)->mutable_operand(0))); + bitcast, + HloInstruction::CreateUnary(bitcast->shape(), HloOpcode::kBitcast, op)); } // All bitcasts can be eliminated (assuming layout constraints are // satisified). @@ -411,13 +418,19 @@ Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleBitcastConvert( + HloInstruction* bitcast) { + // Eliminate bitcast converts between same shape. + ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0)); + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { // If a copy feeds a copy, make it a single copy. - if (copy->operand(0)->opcode() == HloOpcode::kCopy) { + HloInstruction* op; + if (Match(copy, m::Copy(m::Copy(m::Op(&op))))) { return ReplaceWithNewInstruction( - copy, HloInstruction::CreateUnary( - copy->shape(), HloOpcode::kCopy, - copy->mutable_operand(0)->mutable_operand(0))); + copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op)); } // All copies can be eliminated (assuming layout constraints are satisified). ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0)); @@ -457,12 +470,10 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( } else if (operands.size() == 2) { // A binary concat with a broadcasted scalar as an operand can be converted // into a pad which is simpler to fold into other operations. - bool is_effective_low_pad = - operands[0]->opcode() == HloOpcode::kBroadcast && - ShapeUtil::IsScalar(operands[0]->operand(0)->shape()); - bool is_effective_high_pad = - operands[1]->opcode() == HloOpcode::kBroadcast && - ShapeUtil::IsScalar(operands[1]->operand(0)->shape()); + bool is_effective_low_pad = Match( + operands[0], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar()))); + bool is_effective_high_pad = Match( + operands[1], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar()))); if (!is_effective_low_pad && !is_effective_high_pad) { return Status::OK(); } @@ -494,13 +505,13 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( } static HloInstruction* BuildTupleConstant(HloComputation* computation, - const Literal& literal) { + const LiteralSlice& literal) { if (ShapeUtil::IsTuple(literal.shape())) { std::vector elems; elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { elems.push_back( - BuildTupleConstant(computation, LiteralView::Create(literal, {i}))); + BuildTupleConstant(computation, LiteralSlice(literal, {i}))); } return computation->AddInstruction(HloInstruction::CreateTuple(elems)); } else { @@ -516,12 +527,24 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { return ReplaceInstruction( constant, BuildTupleConstant(computation_, constant->literal())); } + + // If a literal is all the same element replace it with a scalar broadcast. + if (ShapeUtil::ElementsIn(constant->shape()) > 1 && + constant->literal().IsAllFirst()) { + std::unique_ptr unique_scalar = + MakeUnique(constant->literal().GetFirstScalarLiteral()); + HloInstruction* scalar = computation_->AddInstruction( + HloInstruction::CreateConstant(std::move(unique_scalar))); + return ReplaceWithNewInstruction( + constant, + HloInstruction::CreateBroadcast(constant->shape(), scalar, {})); + } return Status::OK(); } Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { - auto lhs = sub->mutable_operand(0); - auto rhs = sub->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs)))); // A - 0 => A VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString(); if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) { @@ -530,7 +553,7 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { // Canonicalize subtraction of a constant to addition. VLOG(10) << "trying transform [A - Const => A + (-Const)]"; - if (rhs->IsConstant() && !lhs->IsConstant()) { + if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs)))) { HloInstruction* negative_const = computation_->AddInstruction( HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs)); return ReplaceWithNewInstruction( @@ -542,56 +565,53 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { } Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { - auto lhs = divide->mutable_operand(0); - auto rhs = divide->mutable_operand(1); + Shape* shape; + HloInstruction *a, *b, *c, *d; + CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); // A/1 => A VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString(); - if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) { + if (IsAll(b, 1) && ReplaceInstructionIfSameShape(divide, a)) { return Status::OK(); } // exp(A)/exp(B) => exp(A-B) - if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) { + if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b))) + .WithShape(m::Shape(&shape)))) { VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString(); - HloInstruction* subtract = - computation_->AddInstruction(HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kSubtract, lhs->mutable_operand(0), - rhs->mutable_operand(0))); + HloInstruction* subtract = computation_->AddInstruction( + HloInstruction::CreateBinary(*shape, HloOpcode::kSubtract, a, b)); return ReplaceWithNewInstruction( - divide, HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, - subtract)); + divide, HloInstruction::CreateUnary(*shape, HloOpcode::kExp, subtract)); } // A/exp(B) => A*exp(-B) - if (rhs->opcode() == HloOpcode::kExp) { + if (Match(divide, m::Divide(m::Op(&a), m::Exp(m::Op(&b))))) { VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString(); - HloInstruction* negate = - computation_->AddInstruction(HloInstruction::CreateUnary( - divide->shape(), HloOpcode::kNegate, rhs->mutable_operand(0))); + HloInstruction* negate = computation_->AddInstruction( + HloInstruction::CreateUnary(divide->shape(), HloOpcode::kNegate, b)); HloInstruction* new_exp = computation_->AddInstruction( HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate)); return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kMultiply, lhs, new_exp)); + divide, HloInstruction::CreateBinary(divide->shape(), + HloOpcode::kMultiply, a, new_exp)); } // A/pow(B,C) => A*pow(B,-C) - if (rhs->opcode() == HloOpcode::kPower) { + if (Match(divide, m::Divide(m::Op(&a), m::Power(m::Op(&b), m::Op(&c))))) { VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString(); // The output shape of the created negate operator should be the same as the // input. - const Shape& negate_shape = rhs->operand(1)->shape(); - HloInstruction* negate = - computation_->AddInstruction(HloInstruction::CreateUnary( - negate_shape, HloOpcode::kNegate, rhs->mutable_operand(1))); + const Shape& negate_shape = c->shape(); + HloInstruction* negate = computation_->AddInstruction( + HloInstruction::CreateUnary(negate_shape, HloOpcode::kNegate, c)); // And the power operator should retain the output shape of the old one. - const Shape& new_power_shape = rhs->shape(); - HloInstruction* new_power = computation_->AddInstruction( - HloInstruction::CreateBinary(new_power_shape, HloOpcode::kPower, - rhs->mutable_operand(0), negate)); + const Shape& new_power_shape = b->shape(); + HloInstruction* new_power = + computation_->AddInstruction(HloInstruction::CreateBinary( + new_power_shape, HloOpcode::kPower, b, negate)); return ReplaceWithNewInstruction( divide, HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kMultiply, lhs, new_power)); + divide->shape(), HloOpcode::kMultiply, a, new_power)); } // Simplifying integral division would produce unexpected results. @@ -603,65 +623,46 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // // (Backends can do this transformation, but generally only if the constant is // a scalar.) - if (lhs->opcode() != HloOpcode::kConstant && - rhs->opcode() == HloOpcode::kConstant) { + if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) { HloInstruction* one = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::One(lhs->shape().element_type()).CloneToUnique())); - HloInstruction* inverse = - computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kDivide, one, rhs)); + Literal::One(a->shape().element_type()).CloneToUnique())); + HloInstruction* inverse = computation_->AddInstruction( + HloInstruction::CreateBinary(b->shape(), HloOpcode::kDivide, one, b)); return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kMultiply, lhs, inverse)); + divide, HloInstruction::CreateBinary(divide->shape(), + HloOpcode::kMultiply, a, inverse)); } // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) - if (lhs->opcode() == HloOpcode::kDivide && - rhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN( - const Shape a_times_d_shape, - ShapeInference::InferBinaryOpShape(HloOpcode::kMultiply, - lhs->operand(0), rhs->operand(1))); - auto a_times_d = computation_->AddInstruction(HloInstruction::CreateBinary( - a_times_d_shape, HloOpcode::kMultiply, lhs->mutable_operand(0), - rhs->mutable_operand(1))); - TF_ASSIGN_OR_RETURN( - const Shape b_times_c_shape, - ShapeInference::InferBinaryOpShape(HloOpcode::kMultiply, - lhs->operand(1), rhs->operand(0))); - auto b_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( - b_times_c_shape, HloOpcode::kMultiply, lhs->mutable_operand(1), - rhs->mutable_operand(0))); - return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kDivide, a_times_d, b_times_c)); + if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), + m::Divide(m::Op(&c), m::Op(&d))))) { + TF_ASSIGN_OR_RETURN(auto a_times_d, + MakeBinaryHlo(HloOpcode::kMultiply, a, d)); + TF_ASSIGN_OR_RETURN(auto b_times_c, + MakeBinaryHlo(HloOpcode::kMultiply, b, c)); + TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide, + a_times_d, b_times_c)); + + return ReplaceInstruction(divide, new_divide); } // (A / B) / C => A / (B * C) - if (lhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN(const Shape b_times_c_shape, - ShapeInference::InferBinaryOpShape( - HloOpcode::kMultiply, lhs->operand(1), rhs)); - auto b_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( - b_times_c_shape, HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); + if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) { + TF_ASSIGN_OR_RETURN(auto b_times_c, + MakeBinaryHlo(HloOpcode::kMultiply, b, c)); return ReplaceWithNewInstruction( - divide, - HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, - lhs->mutable_operand(0), b_times_c)); + divide, HloInstruction::CreateBinary(divide->shape(), + HloOpcode::kDivide, a, b_times_c)); } // A / (B / C) => (A*C) / B - if (rhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN(const Shape a_times_c_shape, - ShapeInference::InferBinaryOpShape( - HloOpcode::kMultiply, lhs, rhs->operand(1))); - auto a_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( - a_times_c_shape, HloOpcode::kMultiply, lhs, rhs->mutable_operand(1))); + if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) { + TF_ASSIGN_OR_RETURN(auto a_times_c, + MakeBinaryHlo(HloOpcode::kMultiply, a, c)); return ReplaceWithNewInstruction( - divide, - HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, - a_times_c, rhs->mutable_operand(0))); + divide, HloInstruction::CreateBinary(divide->shape(), + HloOpcode::kDivide, a_times_c, b)); } return Status::OK(); @@ -669,8 +670,8 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( HloInstruction* dot) { - HloInstruction* lhs = dot->mutable_operand(0); - HloInstruction* rhs = dot->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); int64 lhs_collapsing_dim = dot->dot_dimension_numbers().lhs_contracting_dimensions(0); if (lhs->IsRank2Transpose()) { @@ -787,8 +788,8 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0); const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0); - HloInstruction* lhs = dot->mutable_operand(0); - HloInstruction* rhs = dot->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); TF_ASSIGN_OR_RETURN( HloInstruction * optimized_lhs_concat, @@ -917,9 +918,137 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( return add_result; } +StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( + HloInstruction* dot) { + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_contracting_dimensions_size() != 1 || + dnums.rhs_contracting_dimensions_size() != 1 || + dnums.lhs_batch_dimensions_size() != 0 || + dnums.rhs_batch_dimensions_size() != 0 || + dot->shape().dimensions_size() != 2) { // dot output 2D + VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations."; + return nullptr; + } + + // Optimize either dot(DS(ctA), ctB)) or dot(ctB, DS(ctA)). + // Currently a Gather is a DynamicSlice. + auto is_dynamic_slice_constant_combination = + [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) { + // First operand is a DynamicSlice(Constant). + if (a->opcode() != HloOpcode::kDynamicSlice) { + return false; + } + auto* dynamic_slice_op = a->operand(0); + if (dynamic_slice_op->opcode() != HloOpcode::kConstant) { + return false; + } + // Second operand is a Constant. + if (b->opcode() != HloOpcode::kConstant) { + return false; + } + // The DynamicSlice output is a vector. + const Shape& dynamic_slice_shape = a->shape(); + if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) { + return false; + } + // Constant size is the same before and after slice in the contracting + // dimension, otherwise we either must precompute for all possible slice + // indices or dot is invalid. + const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape(); + if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) != + dynamic_slice_shape.dimensions(a_contracting_dimension)) { + return false; + } + return true; + }; + + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); + int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); + + if (!is_dynamic_slice_constant_combination( + lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) && + !is_dynamic_slice_constant_combination( + rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) { + VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB)) or " + "dot(ctB, DS(ctA)), where the two constants have equal " + "contracting dimensions."; + return nullptr; + } + + // LHS is DynamicSlice: + // input: dot(DS(ctA), ctB)) + // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}. + // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}. + // output: DS(dot(ctA, ctB)) + // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}. + + // RHS is DynamicSlice: + // input: dot(ctA, DS(ctB)) + // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}). + // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}. + // output: DS(dot(ctA, ctB)) + // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}. + + bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice; + + // ctA: + HloInstruction* left_operand = + lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs; + // ctB: + HloInstruction* right_operand = + lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0); + // Build ctA x ctB. + const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension); + const int n = + right_operand->shape().dimensions(1 - rhs_contracting_dimension); + auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n}); + auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot( + memoized_shape, left_operand, right_operand, dnums)); + // Get pair {start, 0} or {0, start}. + HloInstruction* original_start_indices = + lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); + // Position of start: + int index_of_non_zero_start = lhs_is_dynamic_slice + ? 1 - lhs_contracting_dimension + : 1 - rhs_contracting_dimension; + // Position of zero: + int index_of_zero_start = 1 - index_of_non_zero_start; + + // Slice out start and 0 components and reorder if necessary. + auto indices_type = original_start_indices->shape().element_type(); + Shape s_shape = ShapeUtil::MakeShape(indices_type, {1}); + Shape d_shape = ShapeUtil::MakeShape(indices_type, {2}); + HloInstruction* non_zero_start = + computation_->AddInstruction(HloInstruction::CreateSlice( + s_shape, original_start_indices, {index_of_non_zero_start}, + {index_of_non_zero_start + 1}, {1})); + HloInstruction* zero_start = + computation_->AddInstruction(HloInstruction::CreateSlice( + s_shape, original_start_indices, {index_of_zero_start}, + {index_of_zero_start + 1}, {1})); + HloInstruction* new_start_indices = + lhs_is_dynamic_slice + ? computation_->AddInstruction(HloInstruction::CreateConcatenate( + d_shape, {non_zero_start, zero_start}, 0)) + : computation_->AddInstruction(HloInstruction::CreateConcatenate( + d_shape, {zero_start, non_zero_start}, 0)); + + // Build DynamicSlice(ctA x ctB). + const int new_slice_m = lhs_is_dynamic_slice ? 1 : m; + const int new_slice_n = lhs_is_dynamic_slice ? n : 1; + auto* memoized_lookup = + computation_->AddInstruction(HloInstruction::CreateDynamicSlice( + dot->shape(), memoized_inst, new_start_indices, + {new_slice_m, new_slice_n})); + + return memoized_lookup; +} + Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { - auto lhs = dot->mutable_operand(0); - auto rhs = dot->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or // below. @@ -946,6 +1075,17 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return ReplaceInstruction(dot, dot_of_concat_optimized); } + // Simplify dot(ConstA, Gather(Index, ConstB)) to: + // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately + // batched version of dot. + TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized, + OptimizeDotOfGather(dot)); + if (dot_of_gather_optimized) { + VLOG(10) << "Replaced dot(constA, gather(i, constB)) with " + "gather(i, dot*(constA, constB))"; + return ReplaceInstruction(dot, dot_of_gather_optimized); + } + if (enable_dot_strength_reduction_ && !is_layout_sensitive_) { TF_ASSIGN_OR_RETURN(bool did_strength_reduction, HandleDotStrengthReduction(dot)); @@ -971,8 +1111,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { } Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { - auto lhs = multiply->mutable_operand(0); - auto rhs = multiply->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs)))); // A*1 => A VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString(); if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) { @@ -985,10 +1125,9 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { } // exp(A) * exp(B) => exp(A+B) - if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) { + if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) { auto add = computation_->AddInstruction(HloInstruction::CreateBinary( - multiply->shape(), HloOpcode::kAdd, lhs->mutable_operand(0), - rhs->mutable_operand(0))); + multiply->shape(), HloOpcode::kAdd, lhs, rhs)); return ReplaceWithNewInstruction( multiply, HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add)); @@ -999,20 +1138,19 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) { // ln(exp(A)) => A VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString(); - auto operand = log->mutable_operand(0); - if (operand->opcode() == HloOpcode::kExp && - ReplaceInstructionIfSameShape(log, operand->mutable_operand(0))) { + HloInstruction *a, *b; + if (Match(log, m::Log(m::Exp(m::Op(&a)))) && + ReplaceInstructionIfSameShape(log, a)) { return Status::OK(); } // ln(pow(A,B)) => B*ln(A) - if (operand->opcode() == HloOpcode::kPower) { - auto new_log = computation_->AddInstruction(HloInstruction::CreateUnary( - log->shape(), HloOpcode::kLog, operand->mutable_operand(0))); + if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) { + auto new_log = computation_->AddInstruction( + HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a)); return ReplaceWithNewInstruction( - log, - HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, - new_log, operand->mutable_operand(1))); + log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, + new_log, b)); } return Status::OK(); @@ -1115,11 +1253,12 @@ bool OutputIsSubsetOfOperandElements(HloInstruction* instruction, } // namespace Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { - auto operand = broadcast->mutable_operand(0); + HloInstruction* operand; + CHECK(Match(broadcast, m::Broadcast(m::Op(&operand)))); + auto dims = broadcast->dimensions(); // A degenerate broadcast of a reshape that does not change the number of // elements can be replaced by a reshape. - if (std::is_sorted(broadcast->dimensions().begin(), - broadcast->dimensions().end()) && + if (std::is_sorted(dims.begin(), dims.end()) && ShapeUtil::ElementsIn(broadcast->shape()) == ShapeUtil::ElementsIn(operand->shape())) { VLOG(10) << "transform broadcast(X) -> reshape(X) where " @@ -1137,8 +1276,8 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { VLOG(10) << "transform broadcast(X) -> transpose(X) where " "n(broadcast(X)) == n(X)"; return ReplaceWithNewInstruction( - broadcast, HloInstruction::CreateTranspose(broadcast->shape(), operand, - broadcast->dimensions())); + broadcast, + HloInstruction::CreateTranspose(broadcast->shape(), operand, dims)); } // A broadcast of a reshape which merely inserts 1-sized dimensions can @@ -1152,7 +1291,6 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { if (merely_inserts_or_deletes_1_sized_dimensions && deleted_indices.empty()) { std::reverse(inserted_indices.begin(), inserted_indices.end()); - auto dims = broadcast->dimensions(); for (auto inserted_index : inserted_indices) { dims.erase(dims.begin() + inserted_index); } @@ -1167,7 +1305,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // broadcast after the unary element-wise operation. TF_ASSIGN_OR_RETURN( bool sink_succeeded, - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); + TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); changed_ |= sink_succeeded; if (sink_succeeded) { return Status::OK(); @@ -1196,6 +1334,19 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { return user->ReplaceAllUsesWith(new_broadcast); } } + return Status::OK(); + } + + // Merge two consecutive broadcasts into a single one. + if (operand->opcode() == HloOpcode::kBroadcast) { + std::vector new_dimensions; + for (auto dim : operand->dimensions()) { + new_dimensions.push_back(dims[dim]); + } + return ReplaceWithNewInstruction( + broadcast, + HloInstruction::CreateBroadcast( + broadcast->shape(), operand->mutable_operand(0), new_dimensions)); } return Status::OK(); } @@ -1214,30 +1365,28 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { // Complex(Real(c), Imag(c)) -> c Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) { - auto real = complex->mutable_operand(0); - auto imag = complex->mutable_operand(1); - if (real->opcode() == HloOpcode::kReal && - imag->opcode() == HloOpcode::kImag && - real->operand(0) == imag->operand(0)) { - return ReplaceInstruction(complex, real->mutable_operand(0)); + HloInstruction *c0, *c1; + if (Match(complex, m::Complex(m::Real(m::Op(&c0)), m::Imag(m::Op(&c1)))) && + c0 == c1) { + return ReplaceInstruction(complex, c0); } return Status::OK(); } // Real(Complex(r, i)) -> r Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) { - auto operand = real->mutable_operand(0); - if (operand->opcode() == HloOpcode::kComplex) { - return ReplaceInstruction(real, operand->mutable_operand(0)); + HloInstruction* op; + if (Match(real, m::Real(m::Complex(m::Op(&op), m::Op())))) { + return ReplaceInstruction(real, op); } return Status::OK(); } // Imag(Complex(r, i)) -> i Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { - auto operand = imag->mutable_operand(0); - if (operand->opcode() == HloOpcode::kComplex) { - return ReplaceInstruction(imag, operand->mutable_operand(1)); + HloInstruction* op; + if (Match(imag, m::Imag(m::Complex(m::Op(), m::Op(&op))))) { + return ReplaceInstruction(imag, op); } return Status::OK(); } @@ -1290,17 +1439,14 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { padding_dimension->set_edge_padding_high(0); } } - TF_ASSIGN_OR_RETURN(Shape nonzero_pad_shape, - ShapeInference::InferPadShape(pad->operand(0)->shape(), - pad->operand(1)->shape(), - nonzero_padding)); + + TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad, + MakePadHlo(pad->mutable_operand(0), + pad->mutable_operand(1), nonzero_padding)); // Copy the layout from the original pad instructions. The new pad and the // slice instruction should all have the same layout. - TF_RETURN_IF_ERROR( - LayoutUtil::CopyLayoutBetweenShapes(pad->shape(), &nonzero_pad_shape)); - HloInstruction* nonzero_pad = computation_->AddInstruction( - HloInstruction::CreatePad(nonzero_pad_shape, pad->mutable_operand(0), - pad->mutable_operand(1), nonzero_padding)); + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + pad->shape(), nonzero_pad->mutable_shape())); // Second, construct the slice instruction to perform the negative padding. std::vector start_indices; @@ -1313,7 +1459,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { if (padding_dimension.edge_padding_low() < 0) { start = -1 * padding_dimension.edge_padding_low(); } - int64 end = nonzero_pad_shape.dimensions(i); + int64 end = nonzero_pad->shape().dimensions(i); if (padding_dimension.edge_padding_high() < 0) { end += padding_dimension.edge_padding_high(); } @@ -1322,16 +1468,14 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { strides.push_back(1); } - // Verify that the slice shape matches the pad shape. TF_ASSIGN_OR_RETURN( - Shape inferred_slice_shape, - ShapeInference::InferSliceShape(nonzero_pad_shape, start_indices, - end_indices, strides)); - TF_RET_CHECK(ShapeUtil::Compatible(inferred_slice_shape, pad->shape())); + HloInstruction * slice, + MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides)); - std::unique_ptr slice = HloInstruction::CreateSlice( - pad->shape(), nonzero_pad, start_indices, end_indices, strides); - return ReplaceWithNewInstruction(pad, std::move(slice)); + // Verify that the slice shape matches the pad shape. + TF_RET_CHECK(ShapeUtil::Compatible(slice->shape(), pad->shape())); + + return ReplaceInstruction(pad, slice); } return Status::OK(); @@ -1339,8 +1483,8 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString(); - auto lhs = power->mutable_operand(0); - auto rhs = power->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant( Literal::One(power->shape().element_type()).CloneToUnique()); @@ -1360,9 +1504,10 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { } // pow(exp(A),B) => exp(A*B) - if (lhs->opcode() == HloOpcode::kExp) { + HloInstruction *a, *b; + if (Match(power, m::Power(m::Exp(m::Op(&a)), m::Op(&b)))) { auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary( - power->shape(), HloOpcode::kMultiply, lhs->operands()[0], rhs)); + power->shape(), HloOpcode::kMultiply, a, b)); return ReplaceWithNewInstruction( power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp, a_times_b)); @@ -1412,15 +1557,16 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { return Status::OK(); } -StatusOr AlgebraicSimplifierVisitor:: - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( - HloInstruction* reshape_or_broadcast) { +StatusOr +AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( + HloInstruction* broadcast) { + TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast); bool changed = false; - if (ShapeUtil::IsScalar(reshape_or_broadcast->shape())) { + if (ShapeUtil::IsScalar(broadcast->shape())) { return false; } - HloInstruction* operand = reshape_or_broadcast->mutable_operand(0); - for (HloInstruction* user : reshape_or_broadcast->users()) { + HloInstruction* operand = broadcast->mutable_operand(0); + for (HloInstruction* user : broadcast->users()) { if (user->user_count() == 0 && user != computation_->root_instruction()) { continue; } @@ -1438,55 +1584,50 @@ StatusOr AlgebraicSimplifierVisitor:: continue; } - int64 reshape_or_broadcast_operand_index = -1; // Find the unique non-scalar operand or continue if there isn't one. - int64 scalar_count = 0; - for (int64 i = 0; i < user->operand_count(); ++i) { - if (ShapeUtil::IsScalar(user->operand(i)->shape())) { - ++scalar_count; - } else { - reshape_or_broadcast_operand_index = i; + int64 scalar_broadcast_count = 0; + int64 broadcast_use_count = 0; + for (HloInstruction* user_operand : user->operands()) { + if (user_operand->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { + ++scalar_broadcast_count; + } else if (broadcast == user_operand) { + ++broadcast_use_count; } } - if (scalar_count != user->operand_count() - 1) { + if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) { continue; } - VLOG(4) << "Sinking reshape or broadcast after user:"; - VLOG(4) << " old reshape/broadcast: " << reshape_or_broadcast->ToString(); + std::vector new_operands; + new_operands.reserve(user->operand_count()); + + for (HloInstruction* user_operand : user->operands()) { + if (user_operand->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { + new_operands.push_back( + computation_->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::ChangeElementType( + operand->shape(), user_operand->shape().element_type()), + user_operand->mutable_operand(0), {}))); + } else { + CHECK_EQ(broadcast, user_operand); + new_operands.push_back(operand); + } + } + VLOG(4) << "Sinking broadcast after user:"; + VLOG(4) << " old broadcast: " << broadcast->ToString(); VLOG(4) << " old user: " << user->ToString(); - CHECK_EQ(user->operand(reshape_or_broadcast_operand_index), - reshape_or_broadcast); - auto new_user_operands = user->operands(); - new_user_operands[reshape_or_broadcast_operand_index] = operand; - auto new_user = computation_->AddInstruction(user->CloneWithNewOperands( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(operand->shape().dimensions()), - LayoutUtil::MinorToMajor(operand->shape())), - new_user_operands)); + HloInstruction* new_user = + computation_->AddInstruction(user->CloneWithNewOperands( + ShapeUtil::ChangeElementType(operand->shape(), + user->shape().element_type()), + new_operands)); VLOG(4) << " new user: " << new_user->ToString(); - HloInstruction* new_reshape_or_broadcast = nullptr; - if (reshape_or_broadcast->opcode() == HloOpcode::kReshape) { - new_reshape_or_broadcast = - computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), - new_user)); - } else { - TF_RET_CHECK(reshape_or_broadcast->opcode() == HloOpcode::kBroadcast); - new_reshape_or_broadcast = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), - new_user, reshape_or_broadcast->dimensions())); - } - VLOG(4) << " new reshape/broadcast: " - << new_reshape_or_broadcast->ToString(); - TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_reshape_or_broadcast)); + HloInstruction* new_broadcast = + computation_->AddInstruction(HloInstruction::CreateBroadcast( + user->shape(), new_user, broadcast->dimensions())); + VLOG(4) << " new broadcast: " << new_broadcast->ToString(); + TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast)); changed = true; } return changed; @@ -1529,16 +1670,6 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { } } - // A Reshape that feeds a unary element-wise operation can sink the - // reshape after the unary element-wise operation. - TF_ASSIGN_OR_RETURN( - bool sink_succeeded, - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(reshape)); - changed_ |= sink_succeeded; - if (sink_succeeded) { - return Status::OK(); - } - // Make this a bitcast if possible. if (is_layout_sensitive_ && ReshapeIsBitcast(reshape, valid_bitcast_callback_)) { @@ -1604,6 +1735,14 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( if (IsAll(start_indices, 0) && SameShape(dynamic_update_slice, update)) { return ReplaceInstruction(dynamic_update_slice, update); } + + // If any dimension of update is 0, elide the DynamicUpdateSlice. This + // optimization becomes invalid should we later prefer to warn about out of + // bound indices. + if (ShapeUtil::HasZeroElements(update->shape())) { + return ReplaceInstruction(dynamic_update_slice, + dynamic_update_slice->mutable_operand(0)); + } return Status::OK(); } @@ -1635,6 +1774,46 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { new_reduce_dimensions, function)); } + // If the reduction results in the same number of elements, then the only + // possible side effect would be a reshape. Since the init_value is an + // identity of the reduction function, we can therefore replace the reduce + // with a simple reshape, ignoring the reduction function completely. + if (ShapeUtil::ElementsIn(reduce->shape()) == + ShapeUtil::ElementsIn(arg->shape())) { + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateReshape(reduce->shape(), arg)); + } + + // If a reduce feeds a reduce with the same computation and initial value, + // they can be combined into a single reduce. + if (arg->opcode() == HloOpcode::kReduce && + init_value->Identical(*arg->operand(1)) && + *function == *arg->to_apply()) { + // Create a new reduce with the combined reduction dimensions of both + // reduces. + std::vector arg_dims = arg->dimensions(); + std::sort(arg_dims.begin(), arg_dims.end()); + std::vector reduce_dims = reduce->dimensions(); + std::sort(reduce_dims.begin(), reduce_dims.end()); + // Transform reduce_dims to the same rank as the operand of the operand. + for (int64 arg_dim : arg_dims) { + for (int64& dim : reduce_dims) { + if (dim >= arg_dim) { + ++dim; + } + } + } + std::vector new_dimensions; + new_dimensions.reserve(arg->dimensions().size() + + reduce->dimensions().size()); + std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(), + reduce_dims.end(), std::back_inserter(new_dimensions)); + return ReplaceWithNewInstruction( + reduce, + HloInstruction::CreateReduce(reduce->shape(), arg->mutable_operand(0), + init_value, new_dimensions, function)); + } + // A reshape that collapses multiple dimensions into a dimension being // reduced can just reduce all of those dimensions instead of doing a // collapsing reshape before a reduction. @@ -1679,15 +1858,6 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { new_reduce_dimensions, function)); } } - if (ShapeUtil::ElementsIn(reduce->shape()) == - ShapeUtil::ElementsIn(arg->shape()) || - ShapeUtil::HasZeroElements(arg->shape())) { - auto reshape = computation_->AddInstruction( - HloInstruction::CreateReshape(reduce->shape(), arg)); - return ReplaceWithNewInstruction( - reduce, HloInstruction::CreateMap(reduce->shape(), - {reshape, init_value}, function)); - } return Status::OK(); } @@ -1707,22 +1877,33 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateMap(reduce_window->shape(), - {operand, reduce_window->mutable_operand(1)}, + {reduce_window->mutable_operand(1), operand}, function)); } - VLOG(10) << "Considering folding Pad: " << operand->ToString() - << "\ninto reduce-window: " << reduce_window->ToString(); - // This optimization folds a pad op into reduce_window. - if (operand->opcode() != HloOpcode::kPad) { + HloInstruction* pad; + const HloInstruction* convert = nullptr; + if (operand->opcode() == HloOpcode::kPad) { + pad = operand; + } else if (operand->opcode() == HloOpcode::kConvert && + operand->operand(0)->opcode() == HloOpcode::kPad) { + convert = operand; + pad = operand->mutable_operand(0); + } else { VLOG(10) << "Not folding pad into reduce-window as there is no pad."; return Status::OK(); } + VLOG(10) << "Considering folding Pad: " << pad->ToString() + << "\ninto reduce-window: " << reduce_window->ToString() + << (convert != nullptr ? tensorflow::strings::StrCat( + "\nvia convert: ", convert->ToString()) + : ""); + // Do not fold interior padding into ReduceWindow since the backends do not // support it. - const PaddingConfig& pad_config = operand->padding_config(); + const PaddingConfig& pad_config = pad->padding_config(); if (HasInteriorPadding(pad_config)) { VLOG(10) << "Not folding pad into reduce-window due to interior padding."; return Status::OK(); @@ -1730,14 +1911,27 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( // If reduce_window already has padding, the pad value of the pad op and the // init value of reduce_window must match to allow folding the pad. - const HloInstruction* pad_value = operand->operand(1); + const HloInstruction* pad_value = pad->operand(1); const HloInstruction* reduce_init_value = reduce_window->operand(1); if (pad_value != reduce_init_value) { + auto literals_are_equivalent = [&] { + auto& pad_literal = pad_value->literal(); + auto& reduce_init_literal = reduce_init_value->literal(); + if (pad_literal == reduce_init_literal) { + return true; + } + auto converted_pad_literal = pad_literal.ConvertToShape( + reduce_init_value->shape(), /*round_f32_to_bf16=*/true); + if (!converted_pad_literal.ok()) { + return false; + } + return *converted_pad_literal.ValueOrDie() == reduce_init_literal; + }; // The pad value is usually a constant, so we handle that case and do not // try to get more fancy about proving equivalence in cases beyond that. if (pad_value->opcode() != HloOpcode::kConstant || reduce_init_value->opcode() != HloOpcode::kConstant || - pad_value->literal() != reduce_init_value->literal()) { + !literals_are_equivalent()) { VLOG(10) << "Not folding pad into reduce-window due to different pad " "values."; return Status::OK(); @@ -1746,7 +1940,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( // If the pad puts a single non-identity value in each window that we're // reducing, then this is a broadcast. - HloInstruction* pad_operand = operand->mutable_operand(0); + HloInstruction* pad_operand = pad->mutable_operand(0); auto is_effective_broadcast = [&] { if (window_util::HasStride(window)) { VLOG(10) << "Window has stride."; @@ -1790,6 +1984,18 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( VLOG(10) << "Found window covers a single unpadded element."; return true; }; + + HloInstruction* new_reduce_window_operand; + if (convert != nullptr) { + new_reduce_window_operand = + computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(pad_operand->shape(), + convert->shape().element_type()), + pad_operand)); + } else { + new_reduce_window_operand = pad_operand; + } + if (is_effective_broadcast()) { VLOG(10) << "Replacing pad/reduce-window with (implicit) broadcast."; auto fadd = [this](std::unique_ptr x) { @@ -1798,7 +2004,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateBroadcastSequence( /*output_shape=*/reduce_window->shape(), - /*operand=*/pad_operand, fadd)); + /*operand=*/new_reduce_window_operand, fadd)); } // Carry out the folding of the pad into reduce_window. @@ -1815,10 +2021,11 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( window_dim.set_padding_high(window_dim.padding_high() + pad_dim.edge_padding_high()); } + return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateReduceWindow( /*shape=*/reduce_window->shape(), - /*operand=*/pad_operand, + /*operand=*/new_reduce_window_operand, /*init_value=*/reduce_window->mutable_operand(1), /*window=*/new_window, /*reduce_computation=*/function)); @@ -2000,6 +2207,39 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( return true; } +Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) { + auto* map_computation = map->to_apply(); + auto* map_root = map_computation->root_instruction(); + if (map_root->opcode() == HloOpcode::kParameter) { + ReplaceInstructionIfSameShape( + map, map->mutable_operand(map_root->parameter_number())); + return Status::OK(); + } + if (map_root->opcode() == HloOpcode::kConstant) { + if (!ShapeUtil::IsScalar(map_root->shape())) { + return Status::OK(); + } + auto clone = map_root->CloneWithNewOperands(map_root->shape(), {}); + if (ShapeUtil::IsScalar(map->shape())) { + return ReplaceWithNewInstruction(map, std::move(clone)); + } + return ReplaceWithNewInstruction( + map, + HloInstruction::CreateBroadcast( + map->shape(), computation_->AddInstruction(std::move(clone)), {})); + } + std::vector new_operands; + for (auto* root_operand : map_root->operands()) { + if (root_operand->opcode() != HloOpcode::kParameter) { + return Status::OK(); + } + new_operands.push_back( + map->mutable_operand(root_operand->parameter_number())); + } + auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands); + return ReplaceWithNewInstruction(map, std::move(clone)); +} + Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) { // Match the following tree: // min_operand operand diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 43315f5cdc7afbe79039420320f4a0d0535e11f1..c48196e861a559a5abfa360841ec70b39356fa2b 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -23,7 +23,7 @@ limitations under the License. namespace xla { -// A pass which performs AlgebraicSimplications. +// A pass which performs algebraic simplifications. class AlgebraicSimplifier : public HloPassInterface { public: // Given shapes 'from_shape' and 'to_shape', determines if it is valid to @@ -57,10 +57,10 @@ class AlgebraicSimplifier : public HloPassInterface { bool is_layout_sensitive_; ValidBitcastCallback valid_bitcast_callback_; - // Enable dot simplication on platforms where it is profitable. + // Enable dot simplification on platforms where it is profitable. bool enable_dot_strength_reduction_; - // Enable convolution simplication on platforms where it is profitable. + // Enable convolution simplification on platforms where it is profitable. bool enable_conv_simplification_; }; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 0f08eb3a3267c4b7b04958270a5788fc48d3fa04..2605b0488cb7c6850746df94c4ab05d6b5d35de5 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" @@ -35,6 +36,8 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/str_util.h" +using ::testing::ElementsAre; + namespace xla { namespace { @@ -71,6 +74,44 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { EXPECT_EQ(root, param0); } +// Test that Reduce(Reduce(A)) -> Reduce(A) +TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { + HloComputation::Builder builder(TestName()); + // Create add computation. + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module().AddEmbeddedComputation(builder.Build()); + } + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r4f32, "param")); + std::vector dims0({0}); + Shape r3f32 = ShapeUtil::MakeShape(F32, {5, 6, 7}); + HloInstruction* reduce0 = builder.AddInstruction( + HloInstruction::CreateReduce(r3f32, param, zero, dims0, add_computation)); + std::vector dims1({1, 2}); + Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); + builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero, + dims1, add_computation)); + module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reduce(param, zero)); + EXPECT_EQ(root->dimensions(), std::vector({0, 2, 3})); +} + // Test that Const + A is canonicalized to A + Const. TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -140,6 +181,39 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { + HloComputation::Builder builder(TestName()); + // Create add computation. + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module().AddEmbeddedComputation(builder.Build()); + } + Shape r2f32 = ShapeUtil::MakeShape(F32, {32, 1}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + builder.AddInstruction( + HloInstruction::CreateMap(r2f32, {param0, zero}, add_computation)); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kMap); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(param0, zero)); +} + TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); HloComputation::Builder builder(TestName()); @@ -162,6 +236,37 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { + HloComputation::Builder builder(TestName()); + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({3.14f, 3.14f, 3.14f}))); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_EQ(3.14f, root->operand(0)->literal().GetFirstElement()); +} + +TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { + HloComputation::Builder builder(TestName()); + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({3.14, 3.14, 4}))); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); +} + // Test that A - 0 is simplified to A TEST_F(AlgebraicSimplifierTest, SubZero) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -1284,32 +1389,6 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape)); } -TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { - HloComputation::Builder builder(TestName()); - HloInstruction* param = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), "param")); - HloInstruction* movable_reshape = - builder.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param)); - HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), - HloOpcode::kMaximum, movable_reshape, zero)); - auto computation = module().AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Reshape(param), zero)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); - - simplifier.Run(&module()).ValueOrDie(); - EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Maximum(param, zero))); -} - // Regression test for a bug in the reshape sinking transformation, where // moving a reshape to a scalar led to a crash. TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { @@ -1666,14 +1745,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1699,8 +1778,8 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); @@ -1718,7 +1797,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); EXPECT_TRUE(has_negative_padding(pad)); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); EXPECT_FALSE( @@ -1733,14 +1812,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1756,14 +1835,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0}, /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1})); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Slice(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1891,12 +1970,13 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, window, dnums)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(b.Build()); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); + auto* computation = module->AddEntryComputation(b.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - if (!simplifier.Run(&module).ValueOrDie()) { + if (!simplifier.Run(module.get()).ValueOrDie()) { return "NO_CHANGE"; } auto* root = computation->root_instruction(); @@ -2011,15 +2091,15 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Maximum(op::Minimum(param0, min_value), max_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2041,15 +2121,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2072,15 +2152,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2102,15 +2182,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); @@ -2134,8 +2214,8 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMinimum, fmax, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), @@ -2143,7 +2223,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), @@ -2159,17 +2239,15 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { HloInstruction::CreateParameter(0, r0f32, "scalar_param")); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); - HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, scalar_param, - AsInt64Slice(broadcast_shape.dimensions()))); + HloInstruction* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {})); Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3}); HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, slice); @@ -2178,10 +2256,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(scalar_param)); @@ -2196,10 +2274,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6}); - HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, forty_two, - AsInt64Slice(broadcast_shape.dimensions()))); + HloInstruction* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(broadcast_shape, forty_two, {})); HloInstruction* transpose = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -2209,8 +2285,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction* reshape = builder.AddInstruction( HloInstruction::CreateReshape(reshape_shape, transpose)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reshape); @@ -2218,7 +2294,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(forty_two)); @@ -2227,7 +2303,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { - HloModule module(TestName()); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2256,7 +2333,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module.AddEmbeddedComputation(builder.Build()); + add_computation = module->AddEmbeddedComputation(builder.Build()); } // Create the reduce-window. @@ -2279,15 +2356,15 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { add_computation)); // Build the computation and run the simplifier. - auto computation = module.AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); // Verify the result root = computation->root_instruction(); @@ -2305,6 +2382,92 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { EXPECT_EQ(root->window().dimensions(3).padding_high(), 102); } +// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to +// ReduceWindow(Convert(op), x). +TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); + HloComputation::Builder builder(TestName()); + + // Create operand to the pad. + HloInstruction* parameter = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(BF16, {1, 2, 3, 4}), "p0")); + + // Create the pad. + PaddingConfig padding = MakeNoPaddingConfig(4); + padding.mutable_dimensions(1)->set_edge_padding_low(1); + padding.mutable_dimensions(3)->set_edge_padding_high(2); + + HloInstruction* pad_value = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding)); + + HloInstruction* convert = + builder.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(pad->shape(), F32), pad)); + + // Create add computation. + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module->AddEmbeddedComputation(builder.Build()); + } + + // Create the reduce-window. + Window window; + for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + auto* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_padding_low(10); + dim->set_padding_high(100); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + const Shape reduce_window_shape = + ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); + HloInstruction* reduce_init_value = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction* reduce_window = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + reduce_window_shape, convert, reduce_init_value, window, + add_computation)); + + // Build the computation and run the simplifier. + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, reduce_window); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + // Running simplification again should not result in any further changes. + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + + // Verify the result + root = computation->root_instruction(); + EXPECT_THAT(root, op::ReduceWindow(op::Convert(parameter), op::Constant())); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) + << ShapeUtil::HumanString(root->shape()) << " vs " + << ShapeUtil::HumanString(reduce_window_shape); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(1).padding_low(), 11); + EXPECT_EQ(root->window().dimensions(2).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(3).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(1).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(2).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(3).padding_high(), 102); +} + TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { HloComputation::Builder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1}); @@ -2313,12 +2476,12 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { builder.AddInstruction( HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(a, root); @@ -2431,6 +2594,55 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { op::DynamicSlice(op::Parameter(), op::Parameter())); } +// Test that two consecutive broadcasts can be merged to one. +TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { + HloComputation::Builder builder(TestName()); + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* input_array = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({3, 4}))); + HloInstruction* inner_bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, input_array, {1})); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); + builder.AddInstruction( + HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root->dimensions(), ElementsAre(2)); +} + +// Test that two consecutive broadcasts can be merged to one. +TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { + HloComputation::Builder builder(TestName()); + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3}); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + // The initial dimensions go to places 0 and 2 in the 3-dim array, + // and to places 1 and 3 in the 4-dim array, + HloInstruction* inner_bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r3f32, param0, {0, 2})); + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3}); + builder.AddInstruction( + HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Parameter(0))); + EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); +} + struct PadReduceWindowEffectiveBroadcastCase { std::vector input_spatials; std::vector symmetric_pad_spatials; @@ -2769,8 +2981,234 @@ DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { {/*m=*/1, /*k=*/16, /*n=*/1}, // }; +// Test that DynamicUpdateSlice update param with any dimension equal to zero +// gets removed. +TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { + HloComputation::Builder builder(TestName()); + const Shape dslice_shape = ShapeUtil::MakeShape(F32, {10}); + HloInstruction* const operand = builder.AddInstruction( + HloInstruction::CreateParameter(0, dslice_shape, "operand")); + const Shape update_shape = ShapeUtil::MakeShape(F32, {0}); + HloInstruction* const update = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "update")); + HloInstruction* const start_indices = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({0}))); + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + dslice_shape, operand, update, start_indices)); + const HloComputation* const computation = + module().AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), operand); +} + INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation, DotOfConcatSimplificationTest, ::testing::ValuesIn(kDotOfConcatTestSpecs)); + +struct DotOfGatherTestSpec { + int64 m; + int64 k; + int64 n; + int s; // start index for dynamic slice on the non-contracting dimension + int64 lcd; // left contracting dimension + int64 rcd; // right contracting dimension + bool neg; // is negative testcase +}; + +class DotOfGatherSimplificationTest + : public HloVerifiedTestBase, + public ::testing::WithParamInterface {}; + +// input: dot(DS(ctA), ctB)) +// where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}. +// => input dimensions: dot({1 x K}, {K x N}) => {1 x N}. +// output: DS(dot(ctA, ctB)) +// => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}. +TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { + HloComputation::Builder builder(TestName()); + + DotOfGatherTestSpec spec = GetParam(); + + ASSERT_LE(spec.s, spec.m); + + // For negative tests, increase k of the dynamic slice argument to prevent the + // optimization (constants ctA, ctB must have equal contracting dimensions). + int64 k_increase = spec.neg ? 5 : 0; + int64 lhs_rows = (spec.lcd == 0) ? (spec.k + k_increase) : spec.m; + int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase); + Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols}); + auto* lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows, + /*cols=*/lhs_cols))); + + int32 start_row = (spec.lcd == 0) ? 0 : spec.s; + int32 start_col = (spec.lcd == 0) ? spec.s : 0; + const auto start_indices = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({start_row, start_col}))); + int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1; + int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k; + Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ds_shape, lhs, start_indices, {slice_row_size, slice_col_size})); + + int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n; + int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k; + Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols}); + auto* rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows, + /*cols=*/rhs_cols))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(spec.lcd); + dot_dnums.add_rhs_contracting_dimensions(spec.rcd); + + int64 dot_row_size = 1; + int64 dot_col_size = spec.n; + Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums)); + + auto computation = module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + ASSERT_TRUE(run_successful); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + if (spec.neg) { + EXPECT_NE(computation->root_instruction()->opcode(), + HloOpcode::kDynamicSlice); + } else { + EXPECT_THAT(computation->root_instruction(), + op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), + op::Concatenate())); + } +} + +// input: dot(ctA, DS(ctB)) +// where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, s}, {K, 1}). +// => input dimensions: dot({M x K}, {K x 1}) => {M x 1}. +// output: DS(dot(ctA, ctB)) +// => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}. +TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { + HloComputation::Builder builder(TestName()); + + DotOfGatherTestSpec spec = GetParam(); + + ASSERT_LE(spec.s, spec.n); + + int64 lhs_rows = (spec.lcd == 0) ? spec.k : spec.m; + int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k; + Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols}); + auto* lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows, + /*cols=*/lhs_cols))); + + // For negative tests increase k of the dynamic slice argument to prevent the + // optimization + int64 k_increase = spec.neg ? 5 : 0; + int64 rhs_rows = (spec.rcd == 0) ? (spec.k + k_increase) : spec.n; + int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols}); + auto* rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows, + /*cols=*/rhs_cols))); + + int32 start_row = (spec.rcd == 0) ? 0 : spec.s; + int32 start_col = (spec.rcd == 0) ? spec.s : 0; + const auto start_indices = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({start_row, start_col}))); + int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1; + int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k; + Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ds_shape, rhs, start_indices, {slice_row_size, slice_col_size})); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(spec.lcd); + dot_dnums.add_rhs_contracting_dimensions(spec.rcd); + + int64 dot_row_size = spec.m; + int64 dot_col_size = 1; + Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums)); + + auto computation = module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + ASSERT_TRUE(run_successful); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + if (spec.neg) { + EXPECT_NE(computation->root_instruction()->opcode(), + HloOpcode::kDynamicSlice); + } else { + EXPECT_THAT(computation->root_instruction(), + op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), + op::Concatenate())); + } +} + +std::vector DotOfGatherPositiveNegativeTests() { + std::vector positives = { + // "Classical dot", i.e. matrix multiply: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/0, + /*neg=*/false}, + // Note: testing for m=1 and n=1 is unnecessary, as this optimizes to + // dot(ct, ct) before DotOfGather optimization kicks in. + // Contract on rows: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/0, + /*neg=*/false}, + // Reverse matrix multiply: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/1, + /*neg=*/false}, + // Contract on columns: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/1, + /*neg=*/false}, + }; + std::vector all; + for (int i = 0; i < positives.size(); i++) { + DotOfGatherTestSpec positive_test = positives[i]; + all.push_back(positive_test); + DotOfGatherTestSpec negative_test = positive_test; + negative_test.neg = true; + all.push_back(negative_test); + } + return all; +} + +INSTANTIATE_TEST_CASE_P( + DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest, + ::testing::ValuesIn(DotOfGatherPositiveNegativeTests())); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 4e80679c11dfdf7fdf8077a9f354139a4cab6803..95b4cb6d2e694063b648b264bd2454ae0a5469ff 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -31,75 +31,117 @@ limitations under the License. namespace xla { StatusOr AllocationTracker::Register( - std::unique_ptr shaped_buffer, const string& tag) { + ScopedShapedBuffer shaped_buffer, const string& tag) { tensorflow::mutex_lock lock(mutex_); VLOG(2) << "Register"; - return RegisterInternal(std::move(shaped_buffer), tag); + std::vector replicated_buffers; + replicated_buffers.emplace_back(std::move(shaped_buffer)); + return RegisterInternal(std::move(replicated_buffers), tag); } +StatusOr AllocationTracker::RegisterReplicatedBuffers( + std::vector replicated_buffers, const string& tag) { + tensorflow::mutex_lock lock(mutex_); + VLOG(2) << "RegisterReplicatedBuffers"; + return RegisterInternal(std::move(replicated_buffers), tag); +} + +// ReleaseIfScopedShapedBuffer lets RegisterInternal(b) call +// b.release() if b is a ScopedShapedBuffer, or otherwise pass b through +// unmodified. +static ShapedBuffer ReleaseIfScopedShapedBuffer(ShapedBuffer b) { return b; } +static ShapedBuffer ReleaseIfScopedShapedBuffer(ScopedShapedBuffer b) { + return b.release(); +} + +template StatusOr AllocationTracker::RegisterInternal( - std::unique_ptr shaped_buffer, const string& tag) { + std::vector replicated_buffers, const string& tag) { + static_assert(std::is_same::value || + std::is_same::value, + "ShapedBufferTy must be ShapedBuffer or ScopedShapedBuffer."); VLOG(2) << "RegisterInternal(" - << "tag: \"" << tag << "\" " - << "shaped_buffer: " << *shaped_buffer; - if (shaped_buffer->platform() != backend_->platform()) { - return InvalidArgument( - "AllocationTracker for platform %s cannot register buffer from " - "platform %s", - backend_->platform()->Name().c_str(), - shaped_buffer->platform()->Name().c_str()); + << "tag: \"" << tag << "\" with " << replicated_buffers.size() + << " shaped_buffers."; + for (const auto& shaped_buffer : replicated_buffers) { + VLOG(2) << "shaped_buffer:" << shaped_buffer; + if (shaped_buffer.platform() != backend_->platform()) { + return InvalidArgument( + "AllocationTracker for platform %s cannot register buffer from " + "platform %s", + backend_->platform()->Name().c_str(), + shaped_buffer.platform()->Name().c_str()); + } } int64 handle = next_handle_++; - std::vector shape_indices; - ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), - [this, &shape_indices](const Shape& /*subshape*/, - const ShapeIndex& index) { - shape_indices.push_back(index); - }); - for (const ShapeIndex& index : shape_indices) { - AddAllocationOrIncrementRefCount(shaped_buffer->buffer(index), - shaped_buffer->device_ordinal()); + for (auto& shaped_buffer : replicated_buffers) { + std::vector shape_indices; + ShapeUtil::ForEachSubshape( + shaped_buffer.on_device_shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) { + shape_indices.push_back(index); + }); + // Add shaped_buffer's buffers to opaque_to_allocation_map_, which owns + // them. + for (const ShapeIndex& index : shape_indices) { + AddAllocationOrIncrementRefCount(shaped_buffer.buffer(index), + shaped_buffer.device_ordinal()); + } + // If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer + // into a regular ShapedBuffer, which is stored in + // handle_to_shaped_buffers_. + handle_to_shaped_buffers_[handle].emplace_back(MakeUnique( + ReleaseIfScopedShapedBuffer(std::move(shaped_buffer)))); } + GlobalDataHandle result; result.set_handle(handle); - - handle_to_shaped_buffer_[handle] = std::move(shaped_buffer); - VLOG(2) << "handle: " << handle; - return result; } -tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { +Status AllocationTracker::Unregister(const GlobalDataHandle& data) { tensorflow::mutex_lock lock(mutex_); VLOG(2) << "Unregister(" << "handle: " << data.handle() << ")"; - TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data)); - std::vector shape_indices; - ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), - [this, &shape_indices](const Shape& /*subshape*/, - const ShapeIndex& index) { - shape_indices.push_back(index); - }); - for (const ShapeIndex& index : shape_indices) { - TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index), - shaped_buffer->device_ordinal())); + TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, + ResolveInternal(data)); + for (const auto& shaped_buffer : replicated_buffers) { + std::vector shape_indices; + ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), + [this, &shape_indices](const Shape& /*subshape*/, + const ShapeIndex& index) { + shape_indices.push_back(index); + }); + for (const ShapeIndex& index : shape_indices) { + TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index), + shaped_buffer->device_ordinal())); + } } - - // Keep a nullptr as a tombstone for unregistered handles. This enables better - // error messages. That is, "handle has been deallocated" versus "handle does - // not exist". - handle_to_shaped_buffer_.at(data.handle()).reset(); - - return tensorflow::Status::OK(); + // Keep a nullptr as a tombstone for unregistered handles. This enables + // better error messages. That is, "handle has been deallocated" versus + // "handle does not exist". + auto it = handle_to_shaped_buffers_.find(data.handle()); + if (it == handle_to_shaped_buffers_.end()) { + return NotFound("no allocation record for global data handle: %lld", + data.handle()); + } + for (auto& shaped_buffer : it->second) { + shaped_buffer.reset(); + } + return Status::OK(); } StatusOr> AllocationTracker::DeconstructTuple( const GlobalDataHandle& data) { tensorflow::mutex_lock lock(mutex_); - TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data)); + TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, + ResolveInternal(data)); + // We only need to care about replica id 0 here, since the GlobalDataHandle is + // the same for all buffers across replicas. + const ShapedBuffer* shaped_buffer = replicated_buffers[0]; if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) { return InvalidArgument("global data handle %lld is not a tuple", data.handle()); @@ -109,79 +151,98 @@ StatusOr> AllocationTracker::DeconstructTuple( TF_RET_CHECK(ShapeUtil::IsTuple(shaped_buffer->on_device_shape())); if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) { - return Unimplemented("deconstructing nested tuples not yet supported"); + return Unimplemented("Deconstructing nested tuples is not implemented."); } std::vector element_handles; for (int i = 0; i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape()); ++i) { - auto element_buffer = MakeUnique( + auto element_buffer = ShapedBuffer( ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i), ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i), shaped_buffer->platform(), shaped_buffer->device_ordinal()); - element_buffer->set_buffer(shaped_buffer->buffer(/*index=*/{i}), - /*index=*/{}); + element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}), + /*index=*/{}); + std::vector replicated_buffers; + replicated_buffers.push_back(std::move(element_buffer)); TF_ASSIGN_OR_RETURN( GlobalDataHandle element_handle, - RegisterInternal(std::move(element_buffer), "deconstructed tuple")); + RegisterInternal(std::move(replicated_buffers), "deconstructed tuple")); element_handles.push_back(element_handle); } return std::move(element_handles); } -StatusOr AllocationTracker::Resolve( +StatusOr> AllocationTracker::Resolve( const GlobalDataHandle& data) { tensorflow::mutex_lock lock(mutex_); return AllocationTracker::ResolveInternal(data); } -StatusOr AllocationTracker::ResolveInternal( +StatusOr AllocationTracker::ResolveForReplica( + const GlobalDataHandle& data, int replica_id) { + tensorflow::mutex_lock lock(mutex_); + TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, + ResolveInternal(data)); + if (replica_id >= replicated_buffers.size()) { + return InvalidArgument( + "Requesting buffer for replica %d, but found buffers only for %lu " + "replicas.", + replica_id, replicated_buffers.size()); + } + return replicated_buffers[replica_id]; +} + +StatusOr> AllocationTracker::ResolveInternal( const GlobalDataHandle& data) { VLOG(2) << "resolve:" << data.handle(); - auto it = handle_to_shaped_buffer_.find(data.handle()); - if (it == handle_to_shaped_buffer_.end()) { + auto it = handle_to_shaped_buffers_.find(data.handle()); + if (it == handle_to_shaped_buffers_.end()) { return NotFound("no allocation record for global data handle: %lld", data.handle()); } - ShapedBuffer* shaped_buffer = it->second.get(); - - if (shaped_buffer == nullptr) { - return InvalidArgument("global data handle %lld was previously deallocated", - data.handle()); + std::vector replicated_buffers; + for (const auto& shaped_buffer : it->second) { + if (shaped_buffer == nullptr) { + return InvalidArgument( + "global data handle %lld was previously deallocated", data.handle()); + } + replicated_buffers.push_back(shaped_buffer.get()); } - return shaped_buffer; + return replicated_buffers; } void AllocationTracker::AddAllocationOrIncrementRefCount( - perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) { + se::DeviceMemoryBase device_memory, int device_ordinal) { AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal]; auto it = allocation_map.find(device_memory.opaque()); if (it == allocation_map.end()) { - allocation_map[device_memory.opaque()] = {device_memory, device_ordinal, - /*ref_count=*/1}; + allocation_map[device_memory.opaque()] = { + OwningDeviceMemory(device_memory, device_ordinal, + backend_->memory_allocator()), + /*ref_count=*/1}; } else { it->second.ref_count++; } } -Status AllocationTracker::DecrementRefCount( - perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) { +Status AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory, + int device_ordinal) { AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal]; auto it = allocation_map.find(device_memory.opaque()); TF_RET_CHECK(it != allocation_map.end()); Allocation& allocation = it->second; TF_RET_CHECK(allocation.ref_count >= 1); if (allocation.ref_count == 1) { - TF_RETURN_IF_ERROR(backend_->memory_allocator()->Deallocate( - device_ordinal, &device_memory)); + allocation.device_memory.Free(); allocation_map.erase(it); } else { allocation.ref_count--; } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index 807af8694972083d097604a67ee46d2f73d9545a..a7d8927cf7e90d764ff8046df16c71922b11478e 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -43,9 +43,15 @@ class AllocationTracker { AllocationTracker(Backend* backend) : backend_(backend), next_handle_(1) {} // Registers a shaped buffer of device memory, and returns a corresponding - // handle that can be used for talking to XLA clients. - StatusOr Register( - std::unique_ptr shaped_buffer, const string& tag); + // handle that can be used for talking to XLA clients. The given shaped buffer + // will be treated as the buffer corresponding to the only replica. + StatusOr Register(ScopedShapedBuffer shaped_buffer, + const string& tag); + + // Registers a vector of shaped buffers of device memory, one per replica, and + // returns a corresponding handle that can be used for talking to XLA clients. + StatusOr RegisterReplicatedBuffers( + std::vector replicated_buffers, const string& tag); // Unregister the allocation for the given data handle. Status Unregister(const GlobalDataHandle& data); @@ -54,18 +60,23 @@ class AllocationTracker { StatusOr> DeconstructTuple( const GlobalDataHandle& Data); - // Resolve a handle from an XLA client to a shaped buffer, or provide an error - // status to say whether it was not found (or found, but found deallocated). - StatusOr Resolve(const GlobalDataHandle& data); + // Resolve a handle from an XLA client to a vector of shaped buffers, one per + // replica, or provide an error status to say whether any of those buffers + // were not found (or found, but found deallocated). + StatusOr> Resolve( + const GlobalDataHandle& data); + + // Resolves a handle from an XLA client and replica id to a shaped buffer, or + // provide an error status to say whether it was not found (or found, but + // found deallocated). + StatusOr ResolveForReplica(const GlobalDataHandle& data, + int replica_id); private: // Data structure encapsulating single memory allocation on the device. struct Allocation { // The pointer to this allocation. - perftools::gputools::DeviceMemoryBase device_memory; - - // The device that the memory is allocated on. - int device_ordinal; + OwningDeviceMemory device_memory; // This is the number of times this memory allocation is referred to by // registered data handles. @@ -73,24 +84,28 @@ class AllocationTracker { }; // Internal helper which resolves the given GlobalDataHandle to a - // ShapedBuffer. - StatusOr ResolveInternal(const GlobalDataHandle& data) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Internal helper which registers a shaped buffer. + // list of ScopedShapedBuffers. + StatusOr> ResolveInternal( + const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Internal helper which registers a vector of shaped buffers, one per + // replica. ShapedBufferTy is either ScopedShapedBuffer or ShapedBuffer. If + // it's ShapedBuffer, all of the given buffers must already be tracked by this + // object -- presumably this is a call from DeconstructTuple. + template StatusOr RegisterInternal( - std::unique_ptr shaped_buffer, const string& tag) + std::vector replicated_buffers, const string& tag) EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Adds the given device address to the allocation tracker, or if it already - // exists, then increment it's reference count. - void AddAllocationOrIncrementRefCount( - perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) + // exists, then increment its reference count. + void AddAllocationOrIncrementRefCount(se::DeviceMemoryBase device_memory, + int device_ordinal) EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Decrements the reference count of the given device memory. Then, if it is // zero, deallocate the memory. - Status DecrementRefCount(perftools::gputools::DeviceMemoryBase device_memory, + Status DecrementRefCount(se::DeviceMemoryBase device_memory, int device_ordinal) EXCLUSIVE_LOCKS_REQUIRED(mutex_); // A map from device memory opaque value to allocation. One such map is @@ -108,12 +123,31 @@ class AllocationTracker { int64 next_handle_ GUARDED_BY(mutex_); // A map from device ordinal to AllocationMap. - tensorflow::gtl::FlatMap opaque_to_allocation_map_ + // + // This is not a TF FlatMap because (currently) FlatMap (and therefore + // AllocationMap) is not movable. + std::unordered_map opaque_to_allocation_map_ GUARDED_BY(mutex_); - // A map from data handle to ShapedBuffer. - tensorflow::gtl::FlatMap> - handle_to_shaped_buffer_ GUARDED_BY(mutex_); + // A map from data handle to a vector of shaped buffers that represent the + // buffers for different replicas. + // + // The ShapedBuffers in this map's vectors need to be unique_ptrs, because our + // public API returns pointers to them. We expect the concrete class to be + // ShapedBuffer and never ScopedShapedBuffer; deallocation of buffers is + // handled by opaque_to_allocation_map_. + // + // The elements of the vectors need to be unique_ptrs because we return + // pointers to them. (In theory we could use std::list or something instead, + // but we also want to be able to null out these elements.) + // + // The reason that the elements can't be unique_ptrs is + // the existence of DeconstructTuple(). This function allows us to create a + // non-owning "view" into a tuple's sub-buffers. The sub-buffers are then + // free'd when both the view *and* the original tuple are Unregistered. This + // refcounting is managed in opaque_to_allocation_map_. + tensorflow::gtl::FlatMap>> + handle_to_shaped_buffers_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker); }; diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 05f2d062784147108a94ffb7bb0ca42ddfe4f010..349b32451a697dbd6804b44cd1a36419c753bb14 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -31,24 +31,20 @@ limitations under the License. #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { -BackendOptions& BackendOptions::set_platform( - perftools::gputools::Platform* platform) { +BackendOptions& BackendOptions::set_platform(se::Platform* platform) { platform_ = platform; return *this; } -perftools::gputools::Platform* BackendOptions::platform() const { - return platform_; -} +se::Platform* BackendOptions::platform() const { return platform_; } BackendOptions& BackendOptions::set_intra_op_parallelism_threads( int num_threads) { @@ -77,7 +73,7 @@ struct Backend::EigenThreadPoolWrapper { /* static */ StatusOr> Backend::CreateBackend( const BackendOptions& options) { - perftools::gputools::Platform* platform = options.platform(); + se::Platform* platform = options.platform(); TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); TF_ASSIGN_OR_RETURN(auto stream_executors, PlatformUtil::GetStreamExecutors(platform)); @@ -121,7 +117,7 @@ StatusOr Backend::BorrowStream( } Backend::Backend( - perftools::gputools::Platform* platform, Compiler* compiler, + se::Platform* platform, Compiler* compiler, tensorflow::gtl::ArraySlice stream_executors, TransferManager* transfer_manager, ComputationPlacer* computation_placer, int intra_op_parallelism_threads) @@ -142,9 +138,6 @@ Backend::Backend( << "Service found no devices for backend " << platform_->Name() << '.'; if (platform->id() == se::host::kHostPlatformId) { - inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool( - tensorflow::Env::Default(), "xla_inter_op", - tensorflow::port::NumSchedulableCPUs())); const int num_threads = intra_op_parallelism_threads > 0 ? intra_op_parallelism_threads : tensorflow::port::NumSchedulableCPUs(); @@ -159,10 +152,6 @@ int Backend::default_device_ordinal() const { return default_stream_executor()->device_ordinal(); } -tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const { - return inter_op_thread_pool_.get(); -} - const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device() const { if (intra_op_thread_pool_wrapper_ == nullptr) { @@ -178,7 +167,7 @@ tensorflow::thread::ThreadPool* Backend::eigen_intra_op_thread_pool() const { return intra_op_thread_pool_wrapper_->pool.get(); } -StatusOr Backend::stream_executor( +StatusOr Backend::stream_executor( int device_ordinal) const { if (device_ordinal < 0 || device_ordinal > stream_executors_.back()->device_ordinal()) { @@ -201,9 +190,9 @@ StatusOr Backend::devices_equivalent(int device_ordinal_a, // bit crude but works for GPUs which is the important case where we compile // an executable for one GPU and want to know if it will run (well) on // another. - TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * executor_a, + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor_a, stream_executor(device_ordinal_a)); - TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * executor_b, + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor_b, stream_executor(device_ordinal_b)); return (executor_a->GetDeviceDescription().name() == executor_b->GetDeviceDescription().name()); diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index b5ca483b7274d20c31e932d748b6a4c9dea926f9..6546602473e3381cf13879ddebd05d34d1f7a055 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -44,8 +44,8 @@ namespace xla { class BackendOptions { public: // Set the platform backing the backend, or nullptr for the default platform. - BackendOptions& set_platform(perftools::gputools::Platform* platform); - perftools::gputools::Platform* platform() const; + BackendOptions& set_platform(se::Platform* platform); + se::Platform* platform() const; // Sets the thread pool size for parallel execution of an individual operator. // The default value of -1 will result in initializing the thread pool with @@ -54,7 +54,7 @@ class BackendOptions { int intra_op_parallelism_threads() const; private: - perftools::gputools::Platform* platform_ = nullptr; + se::Platform* platform_ = nullptr; int intra_op_parallelism_threads_ = -1; }; @@ -66,7 +66,7 @@ class BackendOptions { // StreamPtr stream = backend->BorrowStream().ConsumeValueOrDie(); class Backend { public: - using StreamPtr = Pool::SmartPtr; + using StreamPtr = Pool::SmartPtr; // Creates a new backend. static StatusOr> CreateBackend( @@ -79,7 +79,7 @@ class Backend { ~Backend(); // Accessors for the various objects. - perftools::gputools::Platform* platform() const { return platform_; } + se::Platform* platform() const { return platform_; } Compiler* compiler() const { return compiler_; } DeviceMemoryAllocator* memory_allocator() const { return memory_allocator_.get(); @@ -96,19 +96,17 @@ class Backend { // Returns stream executors of all supported devices for this backend. The // executors are ordered by the device ordinal. - const std::vector& stream_executors() - const { + const std::vector& stream_executors() const { return stream_executors_; } // Returns the stream executor for the given device ordinal. - StatusOr stream_executor( - int device_ordinal) const; + StatusOr stream_executor(int device_ordinal) const; // Returns the stream executor for the default device ordinal. This stream // executor can only be used when the number of computations is 1 (replication // can be > 1). - perftools::gputools::StreamExecutor* default_stream_executor() const { + se::StreamExecutor* default_stream_executor() const { CHECK(!stream_executors_.empty()); return stream_executors_[0]; } @@ -117,8 +115,7 @@ class Backend { // internal pool, or by constructing/initializating it, and returns the result // to the caller. StatusOr BorrowStream(int device_ordinal); - StatusOr BorrowStream( - perftools::gputools::StreamExecutor* executor); + StatusOr BorrowStream(se::StreamExecutor* executor); // Returns a function to borrow a stream, as `BorrowStream` above does. // Purely for convenience, the caller could rather make this anonymous @@ -143,10 +140,6 @@ class Backend { // be equivalent to an executable compiled for the other. StatusOr devices_equivalent(int device_ordinal_a, int device_ordinal_b); - // For the host platform, returns the threadpool to use when scheduling - // parallel operators. For other platforms, returns NULL. - tensorflow::thread::ThreadPool* inter_op_thread_pool() const; - // For the host platform, returns the configured eigen threadpool device to be // used for scheduling work. For other platforms, returns NULL. const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const; @@ -157,36 +150,30 @@ class Backend { private: struct EigenThreadPoolWrapper; - Backend(perftools::gputools::Platform* platform, Compiler* compiler, - tensorflow::gtl::ArraySlice - stream_executors, + Backend(se::Platform* platform, Compiler* compiler, + tensorflow::gtl::ArraySlice stream_executors, TransferManager* transfer_manager, ComputationPlacer* computation_placer, int intra_op_parallelism_threads); Backend(const Backend&) = delete; Backend& operator=(const Backend&) = delete; - perftools::gputools::Platform* platform_; + se::Platform* platform_; Compiler* compiler_; TransferManager* transfer_manager_; ComputationPlacer* computation_placer_; // Vector of stream executors. stream_executors_[0] is the default executor. - std::vector stream_executors_; + std::vector stream_executors_; tensorflow::mutex mu_; // Mapping from stream executor to stream pools, used by `BorrowStream` above. - std::map> - stream_pools_ GUARDED_BY(mu_); + std::map> stream_pools_ GUARDED_BY(mu_); // The default memory allocator to use. std::unique_ptr memory_allocator_; - // For the CPU backend, a threadpool for scheduling parallel operators. - std::unique_ptr inter_op_thread_pool_; - // For the CPU backend, an Eigen threadpool device for use by Eigen code. std::unique_ptr intra_op_thread_pool_wrapper_; }; diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc new file mode 100644 index 0000000000000000000000000000000000000000..2099916509acdbc2680cc2b5bd405e96f2f7bfb8 --- /dev/null +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" + +namespace xla { +StatusOr +BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( + HloInstruction* batch_dot) { + const DotDimensionNumbers& dim_numbers = batch_dot->dot_dimension_numbers(); + HloInstruction *lhs = batch_dot->mutable_operand(0), + *rhs = batch_dot->mutable_operand(1); + const Shape& lhs_shape = lhs->shape(); + + std::vector degenerate_dims; + for (int64 batch_dim : dim_numbers.lhs_batch_dimensions()) { + if (lhs_shape.dimensions(batch_dim) == 1) { + degenerate_dims.push_back(batch_dim); + } + } + + if (degenerate_dims.empty()) { + return false; + } + + TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs, + ElideDegenerateDims(lhs, degenerate_dims)); + TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs, + ElideDegenerateDims(rhs, degenerate_dims)); + + DotDimensionNumbers new_dim_numbers = dim_numbers; + new_dim_numbers.clear_lhs_batch_dimensions(); + new_dim_numbers.clear_rhs_batch_dimensions(); + + for (int64 i = 0, e = dim_numbers.lhs_batch_dimensions_size() - + degenerate_dims.size(); + i < e; i++) { + new_dim_numbers.add_lhs_batch_dimensions(i); + new_dim_numbers.add_rhs_batch_dimensions(i); + } + + new_dim_numbers.set_lhs_contracting_dimensions( + 0, + new_dim_numbers.lhs_contracting_dimensions(0) - degenerate_dims.size()); + new_dim_numbers.set_rhs_contracting_dimensions( + 0, + new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size()); + + TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, + MakeDotHlo(new_lhs, new_rhs, new_dim_numbers)); + + TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped, + MakeReshapeHlo(batch_dot->shape(), new_dot)); + + VLOG(2) << "Replaced " << batch_dot->ToString() << " with " + << new_dot->ToString(); + + TF_RETURN_IF_ERROR( + batch_dot->parent()->ReplaceInstruction(batch_dot, new_dot_reshaped)); + + return true; +} + +tensorflow::StringPiece BatchDotSimplification::name() const { + return "batch-dot-simplification"; +} + +StatusOr BatchDotSimplification::Run(HloModule* module) { + bool changed = false; + std::vector dot_instrs; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + c_copy_if(computation->instructions(), std::back_inserter(dot_instrs), + [](HloInstruction* instr) { + return instr->opcode() == HloOpcode::kDot; + }); + } + for (HloInstruction* dot_instr : dot_instrs) { + TF_ASSIGN_OR_RETURN(bool elided_batch_dim_from_one, + ElideDegenerateBatchDimensionFromBatchDot(dot_instr)); + changed |= elided_batch_dim_from_one; + } + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h new file mode 100644 index 0000000000000000000000000000000000000000..c0ca8d8ebac1a3b218e7bd4d6db02b69cfb6916f --- /dev/null +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +// Simplifies batch dot operations. +// +// Normally these would live in the algebraic simplifier, but we want to run +// this to fixpoint (this pass reaches fixed point in one execution) before we +// run the DotDecomposer. +class BatchDotSimplification : public HloPassInterface { + public: + StatusOr Run(HloModule* module) override; + tensorflow::StringPiece name() const override; + + private: + StatusOr ElideDegenerateBatchDimensionFromBatchDot( + HloInstruction* batch_dot); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..38f1a5d3a645f98220ec445bb9bbdf2b9b842109 --- /dev/null +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -0,0 +1,168 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class BatchDotSimplificationTest : public HloVerifiedTestBase {}; + +TEST_F(BatchDotSimplificationTest, + ElideSingleDegenerateBatchDotDim_VectorVector) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,3] parameter(0) + b = f32[1,3] parameter(1) + ROOT dot = f32[1] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0))); +} + +TEST_F(BatchDotSimplificationTest, + ElideSingleDegenerateBatchDotDim_MatrixVector) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,9,3] parameter(0) + b = f32[1,3] parameter(1) + ROOT dot = f32[1,9] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0))); +} + +TEST_F(BatchDotSimplificationTest, + ElideSingleDegenerateBatchDotDim_MatrixMatrix) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,9,3] parameter(0) + b = f32[1,3,7] parameter(1) + ROOT dot = f32[1,9,7] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0))); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDims_VectorVector) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[9,1,7,1,3] parameter(0) + b = f32[9,1,7,1,3] parameter(1) + ROOT dot = f32[9,1,7,1] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={4}, rhs_contracting_dims={4} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/2))); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDims_VectorMatrix) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[9,1,7,1,3] parameter(0) + b = f32[9,1,7,1,20,3] parameter(1) + ROOT dot = f32[9,1,7,1,20] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={4}, rhs_contracting_dims={5} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/3))); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDims_MatrixMatrix) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[9,1,7,1,19,3] parameter(0) + b = f32[9,1,7,1,3,20] parameter(1) + ROOT dot = f32[9,1,7,1,19,20] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={5}, rhs_contracting_dims={4} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/2))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 27ddfd47aa3096afd3e245af1ac3cedd9b48ce4a..ec13fadbc75e2315d1d6ef72e24a0faca0c7de40 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -15,36 +15,32 @@ limitations under the License. #include "tensorflow/compiler/xla/service/batchnorm_expander.h" -#include #include -#include -#include #include #include #include -#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_query.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +namespace { + // BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm // operations into smaller operations. class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { @@ -62,8 +58,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { // Runs the visitor on a computation. static bool Run(HloComputation* computation, bool rewrite_training_op, - bool rewrite_inference_op, bool rewrite_grad_op, - bool use_fusion); + bool rewrite_inference_op, bool rewrite_grad_op); // Returns whether any batch norm ops were rewritten. const bool changed() const { return changed_; } @@ -74,37 +69,55 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { explicit BatchNormExpanderVisitor(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, - bool rewrite_grad_op, bool use_fusion) + bool rewrite_grad_op) : computation_(computation), rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), - rewrite_grad_op_(rewrite_grad_op), - use_fusion_(use_fusion) {} - - HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type, - HloOpcode opcode) { - HloComputation::Builder b("scalar_computation"); - auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs")); - auto scalar_op = b.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), - opcode, scalar_lhs, scalar_rhs)); + rewrite_grad_op_(rewrite_grad_op) {} + + HloComputation* GetOrCreateScalarAddComputation( + PrimitiveType primitive_type) { + HloComputation::Builder b("scalar_add_computation"); + Shape shape = ShapeUtil::MakeShape(primitive_type, {}); + auto scalar_lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, shape, "scalar_rhs")); + auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); } - // Current HloComputation instance the BatchNormExpander is - // traversing. - HloComputation* computation_; - - bool rewrite_training_op_; - bool rewrite_inference_op_; - bool rewrite_grad_op_; - bool use_fusion_; + std::unique_ptr Rsqrt( + HloInstruction* operand, + const std::function)>& + add_instruction) { + HloInstruction* exponent = add_instruction(HloInstruction::CreateBroadcast( + operand->shape(), + add_instruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + add_instruction(HloInstruction::CreateConstant( + Literal::CreateR0(-0.5f))))), + {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower, + operand, exponent); + } - // Whether rewrite has occurred. - bool changed_ = false; + std::unique_ptr Mean( + int64 element_count, HloInstruction* operand, + const std::function)>& + add_instruction) { + HloInstruction* elem_count_recip = + add_instruction(HloInstruction::CreateBroadcast( + operand->shape(), + add_instruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + add_instruction(HloInstruction::CreateConstant( + Literal::CreateR0(1.0 / element_count))))), + {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply, + operand, elem_count_recip); + } // Replaces the existing HLO instruction old_instruction, with // new_instruction, and marks the optimizer status as changed. @@ -128,18 +141,29 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { changed_ = true; return Status::OK(); } + // Current HloComputation instance the BatchNormExpander is + // traversing. + HloComputation* computation_; + + bool rewrite_training_op_; + bool rewrite_inference_op_; + bool rewrite_grad_op_; + + // Whether rewrite has occurred. + bool changed_ = false; }; +} // namespace + bool BatchNormExpanderVisitor::Run(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, - bool rewrite_grad_op, bool use_fusion) { + bool rewrite_grad_op) { BatchNormExpanderVisitor visitor( computation, /*rewrite_training_op=*/rewrite_training_op, /*rewrite_inference_op=*/rewrite_inference_op, - /*rewrite_grad_op=*/rewrite_grad_op, - /*use_fusion=*/use_fusion); + /*rewrite_grad_op=*/rewrite_grad_op); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -153,9 +177,14 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( std::vector added_instructions; auto add = [&](std::unique_ptr inst) { HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_inst->set_metadata(batch_norm->metadata()); added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); // Expand batch norm training into smaller HLO ops. @@ -165,12 +194,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( int64 feature_index = batch_norm->feature_index(); const int64 feature_count = operand_shape.dimensions(feature_index); const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape); - auto elements_per_feature_literal = - Literal::CreateR0(size_in_elements / feature_count); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + int64 elements_per_feature_int64 = size_in_elements / feature_count; HloInstruction* scale = batch_norm->mutable_operand(1); HloInstruction* offset = batch_norm->mutable_operand(2); @@ -182,8 +206,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = - add(HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon = add(HloInstruction::CreateBroadcast( + operand_shape, + add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); std::vector dimensions_without_feature; for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { @@ -199,11 +224,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); HloComputation* add_reduce_computation = - GetScalarBinaryComputation(ptype, HloOpcode::kAdd); + GetOrCreateScalarAddComputation(ptype); // X^2. - auto operand_squared = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, operand, operand)); + auto operand_squared = + add_binary(operand_shape, HloOpcode::kMultiply, operand, operand); // Sum[X]. auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero, dimensions_without_feature, @@ -214,71 +239,48 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( feature_shape, operand_squared, zero, dimensions_without_feature, add_reduce_computation)); - // Fuse two parallel reduces together to improve performance. - if (use_fusion_ && !batch_norm->has_sharding()) { - auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum})); - - auto fused = computation_->CreateFusionInstruction( - {tuple, sum, squared_sum, operand_squared}, - HloInstruction::FusionKind::kInput); - - sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - - squared_sum = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); - } - // E[X]. - auto mean = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kDivide, sum, elements_per_feature)); + auto mean = add(Mean(elements_per_feature_int64, sum, add)); auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. - auto square_mean = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature)); + auto square_mean = add(Mean(elements_per_feature_int64, squared_sum, add)); // E^2[X]. - auto mean_square = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kMultiply, mean, mean)); + auto mean_square = + add_binary(feature_shape, HloOpcode::kMultiply, mean, mean); // Var[X]. - auto var = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kSubtract, square_mean, mean_square)); + auto var = + add_binary(feature_shape, HloOpcode::kSubtract, square_mean, mean_square); auto var_broadcasted = add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto var_add_epsilon = + add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add)); // X - E[X]. - auto operand_minus_mean = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract, + operand, mean_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = add( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, - operand_minus_mean, rsqrt_var_add_epsilon)); + auto normalized = add_binary(operand_shape, HloOpcode::kMultiply, + operand_minus_mean, rsqrt_var_add_epsilon); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply, + normalized, scale_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. - auto shifted_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted)); + auto shifted_normalized = add_binary(operand_shape, HloOpcode::kAdd, + scaled_normalized, offset_broadcasted); auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var}); @@ -320,8 +322,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast( + operand_shape, + computation_->AddInstruction( + HloInstruction::CreateConstant(std::move(epsilon_literal))), + {})); std::vector dimensions_without_feature; @@ -334,9 +339,14 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( std::vector added_instructions; auto add = [&](std::unique_ptr inst) { HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_inst->set_metadata(batch_norm->metadata()); added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); auto scale_broadcasted = add( @@ -352,30 +362,23 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto var_add_epsilon = + add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add)); // X - E[X]. - auto operand_minus_mean = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract, + operand, mean_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = add( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, - operand_minus_mean, rsqrt_var_add_epsilon)); + auto normalized = add_binary(operand_shape, HloOpcode::kMultiply, + operand_minus_mean, rsqrt_var_add_epsilon); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply, + normalized, scale_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. auto shifted_normalized = HloInstruction::CreateBinary( @@ -419,9 +422,14 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( std::vector added_instructions; auto add = [&](std::unique_ptr inst) { HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_inst->set_metadata(batch_norm->metadata()); added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); HloInstruction* activation = batch_norm->mutable_operand(0); @@ -437,26 +445,20 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape); const int64 feature_count = activation_shape.dimensions(feature_index); - auto elements_per_feature_literal = - Literal::CreateR0(size_in_elements / feature_count); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + const int64 elements_per_feature_int64 = size_in_elements / feature_count; auto zero_literal = Literal::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); - auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = + auto epsilon_scalar = add(HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon_activation = add( + HloInstruction::CreateBroadcast(activation_shape, epsilon_scalar, {})); + auto epsilon_feature = + add(HloInstruction::CreateBroadcast(feature_shape, epsilon_scalar, {})); std::vector dimensions_without_feature; @@ -476,29 +478,26 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index})); // rsqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kPower, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)), - neg_half)); - - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kPower, - add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance, - epsilon)), - neg_half)); + auto rsqrt_var_add_epsilon_broadcasted = + add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon_activation), + add)); + + auto rsqrt_var_add_epsilon = add(Rsqrt( + add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature), + add)); // X - E[X]. - auto activation_minus_mean = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted)); + auto activation_minus_mean = add_binary( + activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted); // Grad[Y] * (X - E[X]). auto grad_output_times_activiation_minus_mean = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, activation_minus_mean)); + add_binary(activation_shape, HloOpcode::kMultiply, grad_output, + activation_minus_mean); HloComputation* add_reduce_computation = - GetScalarBinaryComputation(ptype, HloOpcode::kAdd); + GetOrCreateScalarAddComputation(ptype); // sum(Grad[Y] * (X - E[X])). auto sum_grad_output_times_activiation_minus_mean = @@ -511,25 +510,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( feature_shape, grad_output, zero, dimensions_without_feature, add_reduce_computation)); - if (use_fusion_ && !batch_norm->has_sharding()) { - auto tuple = add(HloInstruction::CreateTuple( - {sum_grad_output_times_activiation_minus_mean, grad_beta})); - - auto fused = computation_->CreateFusionInstruction( - {tuple, sum_grad_output_times_activiation_minus_mean, grad_beta}, - HloInstruction::FusionKind::kInput); - - sum_grad_output_times_activiation_minus_mean = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - - grad_beta = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); - } - // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]). - auto grad_scale = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kMultiply, - sum_grad_output_times_activiation_minus_mean, rsqrt_var_add_epsilon)); + auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply, + sum_grad_output_times_activiation_minus_mean, + rsqrt_var_add_epsilon); // I2 = Sum(Grad[Y]) auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta, @@ -541,39 +525,40 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( {feature_index})); // I4 = (X - E[X]) * I3 - auto i4 = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean)); + auto i4 = add_binary(activation_shape, HloOpcode::kMultiply, i3, + activation_minus_mean); // I5 = I4 / (Var[X] + epsilon) - auto i5 = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, i4, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)))); + auto i5 = add_binary(activation_shape, HloOpcode::kDivide, i4, + add_binary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon_activation)); // scale * rsqrt[Var[X] + epsilon] * 1/N - auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, scale_broadcasted, - rsqrt_var_add_epsilon_broadcasted)); + auto scale_times_rsqrt_var_add_epsilon = + add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted, + rsqrt_var_add_epsilon_broadcasted); - scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon, - elements_per_feature)); + scale_times_rsqrt_var_add_epsilon = add( + Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add)); - auto i1 = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, elements_per_feature)); + auto elements_per_feature_literal = + Literal::CreateR0(elements_per_feature_int64); + TF_ASSIGN_OR_RETURN(elements_per_feature_literal, + elements_per_feature_literal->Convert(ptype)); + auto elements_per_feature = add( + HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, + add(HloInstruction::CreateBroadcast( + activation_shape, elements_per_feature, {}))); // I6 = I1 - I2 - I5 - auto i6 = add(HloInstruction::CreateBinary( + auto i6 = add_binary( activation_shape, HloOpcode::kSubtract, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, - i1, i2)), - i5)); + add_binary(activation_shape, HloOpcode::kSubtract, i1, i2), i5); // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6. - auto grad_activation = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - scale_times_rsqrt_var_add_epsilon, i6)); + auto grad_activation = add_binary(activation_shape, HloOpcode::kMultiply, + scale_times_rsqrt_var_add_epsilon, i6); auto tuple = HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}); if (batch_norm->has_sharding()) { @@ -602,8 +587,8 @@ StatusOr BatchNormExpander::Run(HloModule* module) { bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_, - rewrite_inference_op_, rewrite_grad_op_, - use_fusion_)) { + rewrite_inference_op_, + rewrite_grad_op_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 4ad987085da91684bb7891070afeefd19be4138f..7ae202c583516443a6263403fb5460d1adbabd97 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -31,11 +31,10 @@ class BatchNormExpander : public HloPassInterface { // When use_fusion is set, a multi-output fusion node is created. BatchNormExpander(bool rewrite_training_op = false, bool rewrite_inference_op = false, - bool rewrite_grad_op = false, bool use_fusion = true) + bool rewrite_grad_op = false) : rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), - rewrite_grad_op_(rewrite_grad_op), - use_fusion_(use_fusion) {} + rewrite_grad_op_(rewrite_grad_op) {} ~BatchNormExpander() = default; tensorflow::StringPiece name() const override { return "batchnorm_expander"; } @@ -47,7 +46,6 @@ class BatchNormExpander : public HloPassInterface { bool rewrite_training_op_; bool rewrite_inference_op_; bool rewrite_grad_op_; - bool use_fusion_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index cde990e176ddb57a8e93ecc3c60260b2dbae32a8..1b8b2d204503576c3fcb02f6d5b37f2db45e1768 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -34,6 +34,9 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; + // Special handling for cross-replica-sum which can have a tuple output. + Status HandleCrossReplicaSum(HloInstruction* crs) override; + static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support); @@ -84,6 +87,25 @@ Status BFloat16ConversionFoldingVisitor::FoldOperandConversion( return Status::OK(); } +namespace { + +// Returns whether hlo has users and all users are conversions from F32 to BF16. +bool AllUsersAreF32ToBF16Converts(const HloInstruction* hlo) { + if (hlo->user_count() == 0 || hlo->shape().element_type() != F32) { + return false; + } + for (const auto user : hlo->users()) { + if (user->opcode() == HloOpcode::kConvert && + user->shape().element_type() == BF16) { + continue; + } + return false; + } + return true; +} + +} // namespace + Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( HloInstruction* hlo) { std::vector bf16_to_f32_operands; @@ -104,22 +126,9 @@ Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( } } - bool fold_output_conversion = hlo->user_count() > 0 && - hlo->shape().element_type() == F32 && - bfloat16_support_->SupportsBF16Output(*hlo) && - hlo != computation_->root_instruction(); - if (fold_output_conversion) { - for (auto user : hlo->users()) { - if (user->opcode() == HloOpcode::kConvert && - user->shape().element_type() == BF16) { - continue; - } - // We should not change the output type if any user is not a conversion - // from F32 to BF16. - fold_output_conversion = false; - break; - } - } + const bool fold_output_conversion = + AllUsersAreF32ToBF16Converts(hlo) && + bfloat16_support_->SupportsBF16Output(*hlo); if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) { if (has_other_f32_operands || @@ -147,6 +156,10 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { hlo->opcode() == HloOpcode::kGetTupleElement || // hlo->opcode() == HloOpcode::kInfeed || // hlo->opcode() == HloOpcode::kOutfeed || // + hlo->opcode() == HloOpcode::kSend || // + hlo->opcode() == HloOpcode::kSendDone || // + hlo->opcode() == HloOpcode::kRecv || // + hlo->opcode() == HloOpcode::kRecvDone || // hlo->opcode() == HloOpcode::kConstant || // hlo->opcode() == HloOpcode::kParameter || // hlo->opcode() == HloOpcode::kFusion || // @@ -167,6 +180,63 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { return TryFoldBF16Conversions(hlo); } +Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum( + HloInstruction* crs) { + // First use DefaultAction() to handle the operands. It can't handle + // tuple-shaped output. + TF_RETURN_IF_ERROR(DefaultAction(crs)); + + if (!bfloat16_support_->SupportsMixedPrecisions(*crs)) { + return Status::OK(); + } + + // If the output is not a tuple, we don't need special handling. + if (!ShapeUtil::IsTuple(crs->shape())) { + return Status::OK(); + } + + // If crs is the root instruction, we should keep its original output type. + // The root instruction implicitly has a use from being the result of the + // computation, and the code below does not take this use into account. + if (crs == computation_->root_instruction()) { + return Status::OK(); + } + + // Then do per-tuple-element handling on the output. + std::vector> per_tuple_element_gtes( + crs->operand_count()); + for (auto user : crs->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + return Status::OK(); + } + per_tuple_element_gtes[user->tuple_index()].push_back(user); + } + + for (int64 i = 0; i < crs->operand_count(); ++i) { + // Fold conversions only when all the get-tuple-elements' users are + // conversions from F32 to BF16. + auto all_gte_users_are_bf16_convert = [&per_tuple_element_gtes, i]() { + for (auto gte : per_tuple_element_gtes[i]) { + if (!AllUsersAreF32ToBF16Converts(gte)) { + return false; + } + } + return true; + }; + if (!all_gte_users_are_bf16_convert()) { + continue; + } + + ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}) + ->set_element_type(BF16); + for (auto gte : per_tuple_element_gtes[i]) { + TF_RETURN_IF_ERROR(FoldOutputConversions(gte)); + } + } + + return Status::OK(); +} + StatusOr BFloat16ConversionFolding::Run(HloModule* module) { XLA_VLOG_LINES( 2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString()); diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index cb37759439debf41a305ec7dccaa548e1bf234cd..f7b4c1405dbc8719d8fba5476e6e41d2921ea877 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -37,7 +37,8 @@ class TestBFloat16Support : public BFloat16Support { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kSubtract || hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kCrossReplicaSum) { return true; } return false; @@ -47,7 +48,8 @@ class TestBFloat16Support : public BFloat16Support { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kSubtract || hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kCrossReplicaSum) { return true; } return false; @@ -55,7 +57,8 @@ class TestBFloat16Support : public BFloat16Support { bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kCrossReplicaSum) { return true; } return false; @@ -206,4 +209,57 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { EXPECT_EQ(tuple->operand(1), convert0); } +TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { + auto builder = HloComputation::Builder(TestName()); + + auto module = CreateNewModule(); + HloComputation::Builder sum_builder("add"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y)); + HloComputation* sum = module->AddEmbeddedComputation(sum_builder.Build()); + + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_shape, "a")); + HloInstruction* convert_a = + builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, a)); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32_shape, "b")); + + HloInstruction* crs = + builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( + ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, + sum, /*replica_group_ids=*/{}, /*barrier=*/"")); + HloInstruction* gte_a = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); + HloInstruction* gte_b = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, crs, 1)); + HloInstruction* convert_gte_b = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte_b)); + HloInstruction* tuple = builder.AddInstruction( + HloInstruction::CreateTuple({gte_a, convert_gte_b})); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), tuple); + EXPECT_EQ(tuple->operand(0), gte_a); + EXPECT_EQ(tuple->operand(1), gte_b); + EXPECT_EQ(gte_a->shape().element_type(), F32); + EXPECT_EQ(gte_b->shape().element_type(), BF16); + EXPECT_EQ(crs->operand(0), a); + EXPECT_EQ(crs->operand(1), b); + EXPECT_EQ(a->shape().element_type(), BF16); + EXPECT_EQ(b->shape().element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {0}).element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), BF16); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index b032c040e8aff49f9e0fc1ff9a1c1e79ea4bb77f..14c54ddd135af024327f63418b410da1ed3c4fd4 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -152,44 +152,64 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( std::vector operand_types(crs->operand_count()); std::vector output_types(crs->operand_count()); - bool has_f32 = false; - bool has_bf16 = false; - bool has_bf16_output = false; + int64 f32_count = 0; + int64 bf16_count = 0; + bool has_unsupported_bf16_operand = false; + bool has_unsupported_bf16_output = false; for (int64 i = 0; i < crs->operand_count(); ++i) { operand_types[i] = crs->operand(i)->shape().element_type(); output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type(); - if (operand_types[i] == F32 || output_types[i] == F32) { - has_f32 = true; + if (operand_types[i] == F32) { + f32_count += 1; } else if (operand_types[i] == BF16) { - has_bf16 = true; + bf16_count += 1; + if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) { + has_unsupported_bf16_operand = true; + } } - if (output_types[i] == BF16) { - has_bf16 = true; - has_bf16_output = true; + if (output_types[i] == F32) { + f32_count += 1; + } else if (output_types[i] == BF16) { + bf16_count += 1; + if (!bfloat16_support_->SupportsBF16Output(*crs)) { + has_unsupported_bf16_output = true; + } } } - for (int64 i = 0; i < crs->operand_count(); ++i) { + if (bf16_count == 0) { + return Status::OK(); + } + + auto should_convert_operand = [&](int64 i) { if (operand_types[i] != BF16) { - continue; + return false; } - if (bfloat16_support_->SupportsBF16Operand(*crs, i) && - (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) { - continue; + if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) { + return true; } - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_)); - has_f32 = true; - } + if (bfloat16_support_->SupportsMixedPrecisions(*crs)) { + return false; + } + return has_unsupported_bf16_operand || has_unsupported_bf16_output || + f32_count > 0; + }; - if (!has_bf16_output) { - return Status::OK(); + for (int64 i = 0; i < crs->operand_count(); ++i) { + if (should_convert_operand(i)) { + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_)); + f32_count += 1; + bf16_count -= 1; + } } - if (bfloat16_support_->SupportsBF16Output(*crs) && - (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) { + if (!has_unsupported_bf16_output && + (bfloat16_support_->SupportsMixedPrecisions(*crs) || f32_count == 0 || + bf16_count == 0)) { return Status::OK(); } + std::vector materialized_users = crs->users(); std::vector output_elements(crs->operand_count()); auto original_shape = crs->shape(); for (int64 i = 0; i < crs->operand_count(); ++i) { @@ -209,7 +229,6 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( auto tuple = computation_->AddInstruction( HloInstruction::CreateTuple(output_elements)); - std::vector materialized_users = crs->users(); // Use the crs' shape temporarily, in order to pass checks in // ReplaceUseWith. *tuple->mutable_shape() = crs->shape(); @@ -221,41 +240,37 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( } Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { - std::vector bf16_operands; - std::vector f32_operands; - bool has_f32 = false; - bool has_bf16 = false; + int f32_count = 0; + int bf16_count = 1; for (int64 i = 0; i < hlo->operand_count(); ++i) { if (hlo->operand(i)->shape().element_type() == F32) { - f32_operands.push_back(i); - has_f32 = true; + f32_count += 1; } else if (hlo->operand(i)->shape().element_type() == BF16) { - bf16_operands.push_back(i); - has_bf16 = true; + bf16_count += 1; } } if (hlo->shape().element_type() == F32) { - has_f32 = true; + f32_count += 1; } else if (hlo->shape().element_type() == BF16) { - has_bf16 = true; + bf16_count += 1; } std::vector bf16_called_comps; for (auto* comp : hlo->called_computations()) { bool comp_has_bf16 = false; if (comp->root_instruction()->shape().element_type() == F32) { - has_f32 = true; + f32_count += 1; } else if (comp->root_instruction()->shape().element_type() == BF16) { - has_bf16 = true; + bf16_count += 1; comp_has_bf16 = true; } for (auto* param : comp->parameter_instructions()) { if (param->shape().element_type() == F32) { - has_f32 = true; + f32_count += 1; } else if (param->shape().element_type() == BF16) { - has_bf16 = true; + bf16_count += 1; comp_has_bf16 = true; } } @@ -264,54 +279,69 @@ Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { } } - if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && has_bf16 && - has_f32) { - // Resolve unsupported mixed precision. - // - // See if we can change everything to BF16. - if (hlo->called_computations().empty() && - hlo->shape().element_type() == BF16) { - bool can_use_bf16 = true; - for (int i : f32_operands) { - if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, - i) && - bfloat16_support_->SupportsBF16Operand(*hlo, i)) { - continue; - } - can_use_bf16 = false; - break; - } - if (can_use_bf16) { - for (int i : f32_operands) { - TF_RETURN_IF_ERROR( - InsertConvertBeforeOperand(hlo, i, BF16, computation_)); - } - return Status::OK(); - } - } - if (hlo->shape().element_type() == BF16) { - TF_RETURN_IF_ERROR( - ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); - } - for (int i : bf16_operands) { - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); - } - return ConvertCalledComputations(hlo, bf16_called_comps); - } - - for (int i : bf16_operands) { - if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + // Resolve unsupported BF16 operands. + for (int i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().element_type() == BF16 && + !bfloat16_support_->SupportsBF16Operand(*hlo, i)) { TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); + bf16_count -= 1; + f32_count += 1; } } + // Resolve unsupported BF16 output. if (hlo->shape().element_type() == BF16 && !bfloat16_support_->SupportsBF16Output(*hlo)) { TF_RETURN_IF_ERROR( ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); + bf16_count -= 1; + f32_count += 1; } - return Status::OK(); + // Resolve unsupported mixed precision after resolving unsupported BF16 + // operands and output, because the numbers of BF16 operands/output and F32 + // operands/output may have changed. + if (bfloat16_support_->SupportsMixedPrecisions(*hlo) || bf16_count == 0 || + f32_count == 0) { + return Status::OK(); + } + // See if we can change everything to BF16. + if (hlo->called_computations().empty() && + hlo->shape().element_type() == BF16) { + bool can_use_bf16 = true; + for (int i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().element_type() == BF16) { + continue; + } + if ((bfloat16_support_->EffectiveOperandPrecisionIsBF16(*hlo, i) || + bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, + i)) && + bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + continue; + } + can_use_bf16 = false; + break; + } + if (can_use_bf16) { + for (int i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().element_type() == F32) { + TF_RETURN_IF_ERROR( + InsertConvertBeforeOperand(hlo, i, BF16, computation_)); + } + } + return Status::OK(); + } + } + if (hlo->shape().element_type() == BF16) { + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); + } + for (int i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().element_type() == BF16) { + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); + } + } + return ConvertCalledComputations(hlo, bf16_called_comps); } Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 66c3085842c4afe7ffc4d5891883e4cce9389d45..830f26422bdc2b3bd789e7d5926bcebac815d34a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -41,13 +42,17 @@ class TestBFloat16Support : public BFloat16Support { hlo.opcode() == HloOpcode::kGetTupleElement) { return true; } + if (hlo.opcode() == HloOpcode::kDot) { + // Test that only the first operand of kDot supports BF16. + return operand_index == 0; + } return false; } bool SupportsBF16Output(const HloInstruction& hlo) const override { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kSubtract || - hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kDot || hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement) { return true; } @@ -70,6 +75,10 @@ class BFloat16NormalizationTest : public HloTestBase { BFloat16Normalization normalization(&bfloat16_support_); StatusOr result = normalization.Run(module); EXPECT_IS_OK(result.status()); + + HloVerifier verifier(/*allow_mixed_precision=*/true); + EXPECT_IS_OK(verifier.Run(module).status()); + return result.ValueOrDie(); } }; @@ -166,7 +175,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { Shape f32_input_shape = ShapeUtil::MakeShape(F32, {2, 4}); Shape f32_output_shape = ShapeUtil::MakeShape(F32, {4}); - Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {}); auto reduce_comp_builder = HloComputation::Builder("reduce_comp"); auto reduce_comp_param0 = reduce_comp_builder.AddInstruction( @@ -219,6 +228,17 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { + auto module = CreateNewModule(); + HloComputation::Builder sum_builder("sum"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y)); + HloComputation* reduction = + module->AddEmbeddedComputation(sum_builder.Build()); + auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); @@ -230,11 +250,11 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b})); + ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, + /*replica_group_ids=*/{}, /*barrier=*/"")); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); - auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(Normalize(module.get())); @@ -245,4 +265,34 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32); } +// Tests that the normalization should not cause unsupported mixed precision due +// to resolving unsupported BF16 operand. +TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { + auto builder = HloComputation::Builder(TestName()); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {4, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); + EXPECT_EQ(dot->shape().element_type(), F32); + EXPECT_EQ(dot->operand(0)->shape().element_type(), F32); + EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConvert); + EXPECT_EQ(dot->operand(1)->shape().element_type(), F32); + EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConvert); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc new file mode 100644 index 0000000000000000000000000000000000000000..ed0746980f87ac2bea79c308644dc63769f9e309 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -0,0 +1,852 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_propagation.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +BFloat16Propagation::BFloat16Propagation( + const BFloat16Support* bfloat16_support) + : bfloat16_support_(bfloat16_support) {} + +void BFloat16Propagation::DetermineFusionComputationPrecision( + HloInstruction* fusion) { + CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); + if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) { + return; + } + + // We are depending on the fusion node itself having already been analyzed + // for whether it can output BF16 and this has been adjusted in the output + // shape, and now we're looking to update the interior of the fusion node to + // match the new output shape, as well as recursively process the whole fusion + // node even if the output shape was not modified. + auto root = fusion->fused_instructions_computation()->root_instruction(); + + // Adjust root's element types according to the fusion's output shape. + ShapeUtil::ForEachSubshape( + root->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.element_type() != F32) { + return; + } + if (OutputTypeAfterChange(fusion, index) == BF16) { + AddToOrRemoveFromBF16ChangeSet(root, index, BF16); + VLOG(2) << "Fused root " << root->ToString() << " at shape index " + << index << " changed to BF16 precision for fusion " + << fusion->ToString(); + } + }); + + // Propagate BF16 in the fusion computation. + auto insts = + fusion->fused_instructions_computation()->MakeInstructionPostOrder(); + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false); + } + computations_visited_in_backward_pass_.insert( + fusion->fused_instructions_computation()); + + RevertIfFusionInternalBF16Changes(fusion); +} + +void BFloat16Propagation::RevertIfFusionInternalBF16Changes( + HloInstruction* fusion) { + auto has_changes = [this](HloInstruction* inst) { + auto it = changes_to_bf16_.find(inst); + return it != changes_to_bf16_.end() && !it->second.empty(); + }; + + auto root = fusion->fused_instructions_computation()->root_instruction(); + tensorflow::gtl::FlatSet changed_root_buffers; + + auto root_changes_it = changes_to_bf16_.find(root); + if (root_changes_it != changes_to_bf16_.end()) { + for (const auto& index : root_changes_it->second) { + for (const HloValue* value : + dataflow_->GetValueSet(root, index).values()) { + changed_root_buffers.insert(value); + } + } + } + + auto aliases_changed_root_buffer = + [this, &changed_root_buffers](const HloInstruction* inst) { + bool aliasing = false; + ShapeUtil::ForEachSubshape( + inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (aliasing) { + // Skip if aliasing is already found. + return; + } + // Only F32 buffers are considered for changing to BF16 in this + // pass. + if (subshape.element_type() != F32) { + return; + } + for (const HloValue* value : + dataflow_->GetValueSet(inst, index).values()) { + if (ContainsKey(changed_root_buffers, value)) { + aliasing = true; + break; + } + } + }); + return aliasing; + }; + + for (auto inst : + fusion->fused_instructions_computation()->MakeInstructionPostOrder()) { + if (inst->opcode() == HloOpcode::kParameter) { + continue; + } + if (aliases_changed_root_buffer(inst)) { + continue; + } + if (inst->opcode() == HloOpcode::kFusion) { + bool parameter_reverted = false; + for (int64 i = 0; i < inst->operand_count(); ++i) { + if (has_changes(inst->mutable_operand(i))) { + // Changes on the operand have not been reverted. + continue; + } + auto* fused_parameter = inst->fused_parameter(i); + if (has_changes(fused_parameter)) { + changes_to_bf16_.erase(fused_parameter); + parameter_reverted = true; + } + } + if (parameter_reverted) { + RevertIfFusionInternalBF16Changes(inst); + } + } + if (!has_changes(inst)) { + continue; + } + bool revert_changes = true; + for (auto operand : inst->operands()) { + if (has_changes(operand)) { + revert_changes = false; + break; + } + } + if (revert_changes) { + changes_to_bf16_.erase(inst); + } + } +} + +void BFloat16Propagation::DetermineWhileComputationsPrecision( + HloInstruction* while_hlo) { + CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile); + + // We are depending on the while node itself having already been analyzed for + // whether it can output BF16 and this has been adjusted in the output shape, + // and now we're looking to update the body and condition computations to + // match the new output shape, as well as recursively process the whole while + // node even if the output shape was not modified. + HloComputation* body = while_hlo->while_body(); + auto body_root = body->root_instruction(); + HloComputation* condition = while_hlo->while_condition(); + + ShapeUtil::ForEachSubshape( + body_root->shape(), [this, while_hlo, body_root]( + const Shape& subshape, const ShapeIndex& index) { + if (subshape.element_type() != F32) { + return; + } + if (OutputTypeAfterChange(while_hlo, index) == BF16) { + AddToOrRemoveFromBF16ChangeSet(body_root, index, BF16); + VLOG(2) << "While body root " << body_root->ToString() + << " at shape index " << index + << " changed to BF16 precision for while " + << while_hlo->ToString(); + } + }); + + auto body_insts = body->MakeInstructionPostOrder(); + for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend(); + ++inst_it) { + DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false); + } + computations_visited_in_backward_pass_.insert(body); + + auto condition_insts = condition->MakeInstructionPostOrder(); + for (auto inst_it = condition_insts.rbegin(); + inst_it != condition_insts.rend(); ++inst_it) { + DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false); + } + computations_visited_in_backward_pass_.insert(condition); +} + +bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, + const ShapeIndex& index) const { + auto& value_set = dataflow_->GetValueSet(&hlo, index); + for (const HloValue* value : value_set.values()) { + if (ContainsKey(values_that_must_be_kept_as_f32_, value)) { + return false; + } + if (ValueTypeAfterChange(value) == BF16) { + continue; + } + for (const HloUse& use : value->uses()) { + if (!ContainsKey(instructions_visited_in_backward_pass_, + use.instruction)) { + // We don't know yet whether use.instruction will consume BF16 since it + // hasn't been visited. Although we visit instructions in reverse + // topological order, this is still possible because there may be + // unvisited instruction that alias the same buffer. In this case, we + // aggressively skip this use, and if this causes inconsistency (e.g., + // one use is in BF16 but another use is in F32), it will be resolved at + // the end of the BFloat16Propagation pass. + continue; + } + // Any visited user that can accept BF16 has already been updated if + // necessary, e.g., the output has been changed to BF16 if it propagates + // precision, or a called computation's parameters have been changed to + // BF16 for fusions or whiles. + if (use.instruction->opcode() == HloOpcode::kFusion) { + auto* fused_parameter = + use.instruction->fused_parameter(use.operand_number); + if (OutputTypeAfterChange(fused_parameter, use.operand_index) != BF16) { + return false; + } + continue; + } else if (use.instruction->opcode() == HloOpcode::kWhile) { + auto* cond_parameter = + use.instruction->while_condition()->parameter_instruction( + use.operand_number); + if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) { + return false; + } + auto* body_parameter = + use.instruction->while_body()->parameter_instruction( + use.operand_number); + if (OutputTypeAfterChange(body_parameter, use.operand_index) != BF16) { + return false; + } + continue; + } + if (bfloat16_support_->EffectiveOperandPrecisionIsBF16( + *use.instruction, use.operand_number)) { + continue; + } + // If the op propagates precision and it outputs a BF16, then it's OK to + // supply BF16 also as the input. In the backward pass, the users shapes + // should have already been processed. + PrimitiveType user_output_type = PRIMITIVE_TYPE_INVALID; + if (use.instruction->opcode() == HloOpcode::kTuple || + (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && + ShapeUtil::IsTuple(use.instruction->shape()))) { + ShapeIndex use_output_index{use.operand_number}; + for (int64 i : use.operand_index) { + use_output_index.push_back(i); + } + user_output_type = + OutputTypeAfterChange(use.instruction, use_output_index); + } else { + user_output_type = OutputTypeAfterChange(use.instruction, {}); + } + if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision( + *use.instruction, use.operand_number) && + user_output_type == BF16) { + continue; + } + return false; + } + } + return true; +} + +void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, + bool skip_parameters) { + // We handle any fusion computation or while body/condition after the + // instruction is handled, because we need to know the output shape of a + // fusion or while before propagating inside its computations. + bool postpone_processing_called_computations = false; + auto cleaner = tensorflow::gtl::MakeCleanup( + [this, hlo, &postpone_processing_called_computations] { + if (!postpone_processing_called_computations) { + if (hlo->opcode() == HloOpcode::kFusion) { + DetermineFusionComputationPrecision(hlo); + } else if (hlo->opcode() == HloOpcode::kWhile) { + DetermineWhileComputationsPrecision(hlo); + } + } + instructions_visited_in_backward_pass_.insert(hlo); + }); + + if (hlo->opcode() == HloOpcode::kWhile && + (caller_counts_[hlo->while_condition()] > 1 || + caller_counts_[hlo->while_body()] > 1)) { + postpone_processing_called_computations = true; + return; + } + + // Do not change precision for instructions related to entry and exit of a + // computation, and control flow, because this pass might break the interfaces + // or assumptions for them. + if (hlo->opcode() == HloOpcode::kInfeed || // + hlo->opcode() == HloOpcode::kOutfeed || // + hlo->opcode() == HloOpcode::kSend || // + hlo->opcode() == HloOpcode::kSendDone || // + hlo->opcode() == HloOpcode::kRecv || // + hlo->opcode() == HloOpcode::kRecvDone || // + hlo->opcode() == HloOpcode::kCustomCall || // + hlo->opcode() == HloOpcode::kCall || // + hlo->opcode() == HloOpcode::kConditional || // + (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) { + return; + } + + // Prevent root instructions from having their output modified by recording + // all F32 output values as needing to stay as F32. + CHECK(hlo->parent() != nullptr); + if (hlo == hlo->parent()->root_instruction()) { + if (!hlo->parent()->IsFusionComputation()) { + ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& /* subshape */, + const ShapeIndex& index) { + if (OutputTypeAfterChange(hlo, index) != F32) { + return; + } + for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { + // Since we use HloValues from the dataflow analysis, this can also + // affect HLO instructions beyond the root, e.g., if the root is a + // Tuple HLO, then its operands are also affected. + values_that_must_be_kept_as_f32_.insert(value); + } + }); + } + return; + } + + if (!ContainsKey(consider_using_bfloat16_, hlo)) { + return; + } + + if (!bfloat16_support_->SupportsBF16Output(*hlo)) { + return; + } + + ShapeUtil::ForEachSubshape( + hlo->shape(), + [hlo, this](const Shape& /* subshape */, const ShapeIndex& index) { + if (OutputTypeAfterChange(hlo, index) == F32 && + AllUsersConsumeBF16(*hlo, index)) { + AddToOrRemoveFromBF16ChangeSet(hlo, index, BF16); + VLOG(2) << "HloInstruction output at shape index " << index + << " changed to BF16 precision: " << hlo->ToString(); + } + }); +} + +bool BFloat16Propagation::InstructionIsCandidateForBF16Output( + HloInstruction* hlo) { + if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && + hlo->opcode() != HloOpcode::kTuple && + hlo->opcode() != HloOpcode::kGetTupleElement && + hlo->shape().element_type() != BF16) { + for (int64 i = 0; i < hlo->operand_count(); ++i) { + if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, + i) || + !ContainsKey(consider_using_bfloat16_, hlo->operand(i))) { + return false; + } + } + } + return true; +} + +void BFloat16Propagation::AdjustCalledComputationParameters( + HloInstruction* hlo) { + auto adjust_computation = + [this, hlo](HloComputation* computation, + tensorflow::gtl::ArraySlice operands) { + // Adjust parameters. + CHECK_EQ(operands.size(), computation->num_parameters()); + for (int64 i = 0; i < operands.size(); ++i) { + auto parameter = computation->parameter_instruction(i); + ShapeUtil::ForEachSubshape( + parameter->shape(), + [this, i, hlo, &operands, parameter](const Shape& /* subshape */, + const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) { + return; + } + PrimitiveType operand_type = + OutputTypeAfterChange(operands[i], index); + if (OutputTypeAfterChange(parameter, index) == operand_type) { + return; + } + AddToOrRemoveFromBF16ChangeSet(parameter, index, operand_type); + VLOG(2) << "Called computation parameter " + << parameter->ToString() << " at shape index " << index + << " adjusted to " + << (operand_type == BF16 ? "BF16" : "F32") + << " to match operand in HLO " << hlo->ToString(); + }); + } + }; + + switch (hlo->opcode()) { + case HloOpcode::kFusion: + adjust_computation(hlo->fused_instructions_computation(), + hlo->operands()); + break; + case HloOpcode::kWhile: + adjust_computation(hlo->while_condition(), hlo->operands()); + adjust_computation(hlo->while_body(), hlo->operands()); + break; + default: + break; + } +} + +void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) { + auto adjust_computation = [this, hlo](HloComputation* computation, + HloInstruction* output) { + // Adjust root. + HloInstruction* root = computation->root_instruction(); + ShapeUtil::ForEachSubshape(root->shape(), [this, hlo, root, output]( + const Shape& /* subshape */, + const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) { + return; + } + const PrimitiveType output_type = OutputTypeAfterChange(output, index); + if (OutputTypeAfterChange(root, index) == output_type) { + return; + } + AddToOrRemoveFromBF16ChangeSet(root, index, output_type); + // It's possible that output_type is F32, but the root instruction's + // type is BF16; e.g., a fusion node's output was changed to BF16 + // initially but then adjusted back to F32, and the fusion computation + // is now being adjusted after the fusion node. + if (output_type == F32) { + for (const auto* value : dataflow_->GetValueSet(root, index).values()) { + // We rely on the fact that this adjustment works in reverse + // topological order so that called computation will be + // processed later. Adding the value to + // values_that_must_be_kept_as_f32_ will ensure the + // correctness of the adjustment for HLOs that will be + // processed later. + values_that_must_be_kept_as_f32_.insert(value); + } + } + VLOG(2) << "Called computation root " << root->ToString() + << " at shape index " << index << " adjusted to " + << (output_type == BF16 ? "BF16" : "F32") + << " to match output shape of " << hlo->ToString(); + }); + }; + + switch (hlo->opcode()) { + case HloOpcode::kFusion: + adjust_computation(hlo->fused_instructions_computation(), hlo); + break; + case HloOpcode::kWhile: + adjust_computation(hlo->while_body(), hlo); + break; + default: + break; + } +} + +bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( + HloComputation* computation, + tensorflow::gtl::FlatSet* visited_computations) { + bool parameter_changed = false; + auto insts = computation->MakeInstructionPostOrder(); + // Do the adjustment on each instruction in the computation in reverse + // topological order. + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + auto hlo = *inst_it; + auto adjust_hlo_output = [this, hlo, ¶meter_changed]( + const Shape& /* subshape */, + const ShapeIndex& index) { + auto output_type = OutputTypeAfterChange(hlo, index); + if (output_type != F32 && output_type != BF16) { + return; + } + PrimitiveType type = BF16; + for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { + auto value_type = ValueTypeAfterChange(value); + if (value_type == BF16) { + continue; + } + CHECK_EQ(value_type, F32); + type = F32; + break; + } + // It's possible that a user has been changed from BF16 to F32 + // during this final adjustment pass, so we need to check + // AllUsersConsumeBF16() again. + if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) { + type = F32; + } + if (type == F32) { + for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { + // We rely on the fact that this adjustment works in reverse + // topological order. Adding the value to + // values_that_must_be_kept_as_f32_ will ensure the correctness + // of the adjustment for HLOs that will be processed later. + values_that_must_be_kept_as_f32_.insert(value); + } + } + if (type != output_type) { + AddToOrRemoveFromBF16ChangeSet(hlo, index, type); + VLOG(2) << "HloInstruction output at shape index " << index + << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": " + << hlo->ToString(); + if (hlo->opcode() == HloOpcode::kParameter) { + parameter_changed = true; + } + } + }; + ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output); + AdjustCalledComputationRoot(hlo); + if (hlo->opcode() == HloOpcode::kWhile) { + // We need to run on the while body and condition repeatedly until a fixed + // point is reached, i.e., the parameters do not change any more. We may + // need more than one iteration because the while input and output alias + // each other, so changing one input parameter requires changing the + // corresponding output element and thus may transitively require changing + // another input parameter. A fixed point will be reached because the + // parameters can only be changed from BF16 to F32, not the other way + // around. + tensorflow::gtl::FlatSet visited_in_while; + while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(), + &visited_in_while) || + ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(), + &visited_in_while)) { + visited_in_while.clear(); + ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output); + AdjustCalledComputationRoot(hlo); + } + visited_computations->insert(visited_in_while.begin(), + visited_in_while.end()); + } + } + // Now adjust parameters of called computations. + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + AdjustCalledComputationParameters(*inst_it); + } + return parameter_changed; +} + +void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( + HloModule* module) { + std::list computations_topological_order = + module->MakeComputationPostOrder(); + tensorflow::gtl::FlatSet resolved; + for (auto comp_it = computations_topological_order.rbegin(); + comp_it != computations_topological_order.rend(); ++comp_it) { + if (ContainsKey(resolved, *comp_it)) { + continue; + } + ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved); + } +} + +Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) { + // We could have changed a fusion computation's root shape to have a different + // precision than the fusion node's output, if the fusion root does not + // define a buffer (e.g., a tuple). Now we add conversions after such fusion + // roots to make them match the fusion output. If the fusion output is a + // (possibly nested) tuple, we first create get-tuple-elements, then convert + // the unmatching leaf nodes, and finally create a new tuple as the fusion + // computation's root. If tuples and get-tuple-elements are created, we will + // run tuple simplifier and dead code elimination at the end (dead code is not + // allowed in fusion computation). E.g., + // + // (1) (2) (3) + // a b a b a b + // |\ | |\ | |\ | + // \ add -> |add -> | add + // \ | \ | convert | + // tuple tuple \ | + // / \ tuple + // gte gte + // | | + // convert | + // \ / + // tuple + // (1) a is F32 but tuple is BF16 + // (2) after adding conversion + // (3) after tuple simplifier and DCE. + bool needs_tuple_simplifier = false; + for (auto computation : module->MakeComputationPostOrder()) { + auto insts = computation->MakeInstructionPostOrder(); + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + auto hlo = *inst_it; + if (hlo->opcode() != HloOpcode::kFusion) { + continue; + } + auto fusion_computation = hlo->fused_instructions_computation(); + auto fusion_root = fusion_computation->root_instruction(); + if (ShapeUtil::Compatible(fusion_root->shape(), hlo->shape())) { + continue; + } + ShapeTree converted_outputs(hlo->shape()); + // Iterate through nodes in the shape tree in pre-order and initialize + // each non-root node with a corresponding get-tuple-element. For a leaf + // node, if its shape does not match the fusion output, create a + // conversion node to overwrite the node value. + for (auto it = converted_outputs.begin(); it != converted_outputs.end(); + ++it) { + ShapeIndex output_index = it->first; + HloInstruction*& output = it->second; + const Shape subshape = + ShapeUtil::GetSubshape(hlo->shape(), output_index); + if (output_index.empty()) { + output = fusion_root; + } else { + ShapeIndex parent_index = output_index; + parent_index.pop_back(); + output = fusion_computation->AddInstruction( + HloInstruction::CreateGetTupleElement( + subshape, converted_outputs.element(parent_index), + output_index.back())); + } + if (ShapeUtil::IsTuple(subshape)) { + continue; + } + if (!ShapeUtil::Compatible( + subshape, + ShapeUtil::GetSubshape(fusion_root->shape(), output_index))) { + output = fusion_computation->AddInstruction( + HloInstruction::CreateConvert(subshape, output)); + } + } + // Iterate through nodes in the shape tree in reverse pre-order and create + // a tuple instruction for each non-leaf node where the elements are the + // values of its child nodes. + for (auto it = converted_outputs.rbegin(); it != converted_outputs.rend(); + ++it) { + ShapeIndex output_index = it->first; + HloInstruction*& output = it->second; + const Shape& subshape = + ShapeUtil::GetSubshape(hlo->shape(), output_index); + if (!ShapeUtil::IsTuple(subshape)) { + continue; + } + std::vector elements( + ShapeUtil::TupleElementCount(subshape)); + ShapeIndex child_index = output_index; + for (int64 i = 0; i < elements.size(); ++i) { + child_index.push_back(i); + elements[i] = converted_outputs.element(child_index); + child_index.pop_back(); + } + output = fusion_computation->AddInstruction( + HloInstruction::CreateTuple(elements)); + } + fusion_computation->set_root_instruction(converted_outputs.element({})); + needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape()); + } + } + if (needs_tuple_simplifier) { + TupleSimplifier tuple_simplifier; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + } + return Status::OK(); +} + +Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) { + // We may have converted some constants from F32 to BF16, so adjust the + // constant literals in such cases. We do this here instead of when the + // constant node's is changed because 1) the HloInstruction interface does not + // allow resetting the literal so we have to create a new kConstant + // instruction to replace the old one, which invalidates dataflow analysis, + // and 2) it's possible that a kConstant's output gets changed to BF16 at the + // beginning but later on adjusted back to F32, so converting literals here + // can avoid repeated conversions. + // + // TODO(b/73833576): Consider resetting literal in HloInstruction. + for (auto computation : module->MakeComputationPostOrder()) { + for (auto hlo : computation->MakeInstructionPostOrder()) { + if (hlo->opcode() != HloOpcode::kConstant) { + continue; + } + if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) { + TF_ASSIGN_OR_RETURN( + auto converted_literal, + hlo->literal().ConvertToShape(hlo->shape(), + /*round_f32_to_bf16=*/true)); + auto new_constant = computation->AddInstruction( + HloInstruction::CreateConstant(std::move(converted_literal))); + TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant)); + } + } + } + return Status::OK(); +} + +Status BFloat16Propagation::SkipNoopConversions(HloModule* module) { + for (auto computation : module->computations()) { + for (auto hlo : computation->MakeInstructionPostOrder()) { + if (hlo->opcode() != HloOpcode::kConvert) { + continue; + } + auto source = hlo->mutable_operand(0); + if (!ShapeUtil::Equal(source->shape(), hlo->shape())) { + continue; + } + const bool is_root = hlo == computation->root_instruction(); + TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(source)); + if (is_root) { + computation->set_root_instruction(source); + } + } + } + return Status::OK(); +} + +// The algorithm first does a forward pass (parameters to root) to determine a +// set of instructions to consider using bfloat16, then does a backward pass to +// determine the precisions of those instructions according to the need of +// their users. During the backward pass, the potential changes are stored in +// changes_to_bf16_ which are subject to further adjustments then applied to the +// HLOs. +StatusOr BFloat16Propagation::Run(HloModule* module) { + consider_using_bfloat16_.clear(); + instructions_visited_in_backward_pass_.clear(); + computations_visited_in_backward_pass_.clear(); + values_that_must_be_kept_as_f32_.clear(); + caller_counts_.clear(); + changes_to_bf16_.clear(); + changed_ = false; + + TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module)); + + std::list computations_topological_order = + module->MakeComputationPostOrder(); + // The first step is a forward pass (parameters to root), where we determine + // the potential candidate instructions to use bfloat16 in the outputs that + // are not likely to cause overhead from extra explicit conversions. This is + // done forwardly because we determine whether an HLO is a candidate partially + // based on whether its operands are candidates. + for (auto computation : computations_topological_order) { + for (auto inst : computation->MakeInstructionPostOrder()) { + if (InstructionIsCandidateForBF16Output(inst)) { + consider_using_bfloat16_.insert(inst); + } + } + } + + // The second step is a backward pass (root to parameters), where we modify + // the precisions of the instructions identified in the first step when + // feasible. This is done backwardly because we determine the precision of an + // HLO's output based on how it is later used. + // + // The precision of an instruction is determined by its users, so we do the + // propagation in reverse topological order. + for (auto comp_it = computations_topological_order.rbegin(); + comp_it != computations_topological_order.rend(); ++comp_it) { + if ((*comp_it)->IsFusionComputation()) { + // Fusion computations are handled when visiting the fusion instruction. + continue; + } + auto insts = (*comp_it)->MakeInstructionPostOrder(); + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + DetermineInstructionPrecision(*inst_it, + /*skip_parameters=*/true); + } + } + + // It's possible that an instruction does not define a buffer, but the + // defining instruction's shape has changed. So we need to adjust the output + // shapes of instructions according to the HLO values they refer to. + ResolveInconsistencyOfAliasingBuffers(module); + + // Apply the changes in changes_to_bf16_. + for (auto& change : changes_to_bf16_) { + auto shape = change.first->mutable_shape(); + for (const auto& index : change.second) { + auto subshape = ShapeUtil::GetMutableSubshape(shape, index); + CHECK_EQ(subshape->element_type(), F32); + subshape->set_element_type(BF16); + changed_ = true; + } + } + + if (!changed_) { + return false; + } + + TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module)); + TF_RETURN_IF_ERROR(ResolveConvertedConstants(module)); + + // This pass could have turned an F32 -> BF16 conversion to a no-op (BF16 -> + // BF16), so we skip them now. + TF_RETURN_IF_ERROR(SkipNoopConversions(module)); + + { + // We may have dead HLOs after ResolveInconsistentFusions, + // ResolveConvertedConstants and SkipNoopConversions. + HloDCE dce; + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } + return true; +} + +PrimitiveType BFloat16Propagation::OutputTypeAfterChange( + HloInstruction* hlo, const ShapeIndex& index) const { + PrimitiveType type_on_hlo = + ShapeUtil::GetSubshape(hlo->shape(), index).element_type(); + if (type_on_hlo != F32) { + return type_on_hlo; + } + auto it = changes_to_bf16_.find(hlo); + if (it == changes_to_bf16_.end()) { + return type_on_hlo; + } + return ContainsKey(it->second, index) ? BF16 : F32; +} + +PrimitiveType BFloat16Propagation::ValueTypeAfterChange( + const HloValue* value) const { + auto hlo = value->defining_instruction(); + const auto& position = value->defining_position(); + return OutputTypeAfterChange(hlo, position.index); +} + +void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet( + HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) { + if (target_type == BF16) { + auto& entry = changes_to_bf16_[hlo]; + entry.insert(index); + } else { + CHECK_EQ(target_type, F32); + auto it = changes_to_bf16_.find(hlo); + if (it == changes_to_bf16_.end()) { + return; + } + it->second.erase(index); + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h new file mode 100644 index 0000000000000000000000000000000000000000..de0355ddfca127753f90d1899b424a8e77c9b291 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -0,0 +1,219 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/hash/hash.h" + +namespace xla { + +// HLO pass which reduces the precision of some HLO instructions to BF16 +// according to the backend-specific BFloat16Support rule provided by the +// caller. +// +// This pass can be used to reduce instruction precision without affecting the +// numerical accuracy of the module, i.e., the final output of the module would +// be bitwise identical to that without this pass; this is possible if the +// backend already reduces precision to BF16 on some HLO instructions. +// +// This pass will not modify the signature of a computation, unless it is a +// fusion computation or its only caller is a while. +// +// !!! WARNING !!! This pass can introduce mixed precision in individual HLOs, +// which has two issues: +// +// 1) It does not guarantee to respect the passed-in BFloat16Support +// specification in terms of mixed precision, so the backend may not support an +// HLO that has mixed precision produced by this pass. To address this issue, +// run BFloat16Normalization with the same BFloat16Support after this pass. +// +// 2) In general, mixed precision may break the assumptions of some other HLO +// passes even if the specific backend supports the individual HLOs. Such +// assumptions include that there are no HLOs using mixed precision, or that the +// precision of an HLO's output is determined by its inputs. It should be used +// at the end of the HLO optimization pipeline but before +// BFloat16ConversionFolding. If other passes are needed after this pass, run +// BFloat16MixedPrecisionRemoval first to undo some of the changes made by this +// pass. +class BFloat16Propagation : public HloPassInterface { + public: + explicit BFloat16Propagation(const BFloat16Support* bfloat16_support); + + ~BFloat16Propagation() override = default; + + tensorflow::StringPiece name() const override { + return "bfloat16-propagation"; + } + + // Runs the pass on the given module. Returns whether the module was changed + // (precision reductions were added). + StatusOr Run(HloModule* module) override; + + private: + // *************************** + // Function called and state produced by the forward analysis pass (from + // parameters to root) that determines the candidate HLOs to use BF16 outputs. + + // Determines whether we should consider changing the precision of the given + // instruction in the forward pass. + bool InstructionIsCandidateForBF16Output(HloInstruction* hlo); + + // The set of instructions to consider using bfloat16, computed in the forward + // pass. + tensorflow::gtl::FlatSet consider_using_bfloat16_; + + // *************************** + // Functions called and state produced by the backward pass (from root to + // parameters) that finds opportunities to use BF16. + + // Determines the precision for the given instruction in the + // opportunity-finding pass. + void DetermineInstructionPrecision(HloInstruction* hlo, bool skip_parameters); + + // Special handling in the opportunity-finding pass for fusion computations. + // + // Precondition: hlo->opcode() == kFusion + void DetermineFusionComputationPrecision(HloInstruction* fusion); + + // Reverts changes to BF16 that will not propagate outside a fusion + // computation. This avoids BF16 casts overhead inside a fusion which won't + // save memory bandwidth. + // + // Precondition: hlo->opcode() == kFusion + void RevertIfFusionInternalBF16Changes(HloInstruction* fusion); + + // Special handling in the opportunity-finding pass for while computations. + // + // Precondition: hlo->opcode() == kWhile + void DetermineWhileComputationsPrecision(HloInstruction* while_hlo); + + // The set of HloInstructions that have been visited in the + // opportunity-finding pass. + tensorflow::gtl::FlatSet + instructions_visited_in_backward_pass_; + + // The set of HloComputations that have been visited in the + // opportunity-finding pass. + tensorflow::gtl::FlatSet + computations_visited_in_backward_pass_; + + // *************************** + // Functions called by the final inconsistency resolving pass. + + // Adjusts the output shapes of HloInstructions such that if two + // HloInstructions have aliasing buffers in their outputs, they must have the + // same precision. + void ResolveInconsistencyOfAliasingBuffers(HloModule* module); + + // Resolves inconsistency of aliasing buffers for the given computation, and + // recursively runs on a while instruction's condition and body until a fixed + // point is reached. + bool ResolveInconsistencyOfAliasingBuffersHelper( + HloComputation* computation, + tensorflow::gtl::FlatSet* visited_computations); + + // Makes the parameters of called computations match how they are called by + // the given HLO. + void AdjustCalledComputationParameters(HloInstruction* hlo); + + // Makes the root instructions of called computations match how they are used + // by the given HLO. + void AdjustCalledComputationRoot(HloInstruction* hlo); + + // *************************** + // Functions called after changes in changes_to_bf16_ are applied. + + // Resolves inconsistencies introduced by this pass for fusions with + // tuple-type output. + Status ResolveInconsistentFusions(HloModule* module); + + // Converts the literals in kConstant HLOs which have their types changed to + // BF16 by this pass. + Status ResolveConvertedConstants(HloModule* module); + + // Skips no-op conversions (same source and target shapes) that can be + // produced this pass, i.e., replaces them in their uses with their operands. + Status SkipNoopConversions(HloModule* module); + + // *************************** + // Functions called and state used by two or more passes. + + // Returns whether all uses of the given HloInstruction can consume BF16 + // input. + bool AllUsersConsumeBF16(const HloInstruction& hlo, + const ShapeIndex& index) const; + + // The output element type of the HLO at the given shape index after changes + // in changes_to_bf16_ are applied. + PrimitiveType OutputTypeAfterChange(HloInstruction* hlo, + const ShapeIndex& index) const; + + // The element type of the HLO value after changes in changes_to_bf16_ are + // applied. + PrimitiveType ValueTypeAfterChange(const HloValue* value) const; + + // If target_type == BF16, adds the HLO at the given index to + // changes_to_bf16_; otherwise, target_type must be F32 and this function + // removes the HLO at the given index from changes_to_bf16_ if it was earlier + // added. + void AddToOrRemoveFromBF16ChangeSet(HloInstruction* hlo, + const ShapeIndex& index, + PrimitiveType target_type); + + // The set of F32 HLO values that must be kept in F32. + tensorflow::gtl::FlatSet values_that_must_be_kept_as_f32_; + + // Mapping from each HloComputation to the number of callers to it in the + // module. Populated at the beginning of this pass. + tensorflow::gtl::FlatMap caller_counts_; + + // We first store the potential F32-to-BF16 changes to changes_to_bf16_, which + // are subject to further adjustment, then finally applied to the HLOs. This + // avoids setting changed_ to true but all changes are reverted during + // adjustment. + struct IndexHasher { + int64 operator()(const ShapeIndex& index) const { + int64 hash = 0; + for (int64 i : index) { + hash = tensorflow::Hash64Combine(hash, std::hash()(i)); + } + return hash; + } + }; + tensorflow::gtl::FlatMap> + changes_to_bf16_; + + // Whether the last processed HLO module has been changed by this pass. + bool changed_ = false; + + const BFloat16Support* bfloat16_support_; + std::unique_ptr dataflow_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_ diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e1499ee6b6ef397f95f7ed29e808d530777bd07 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -0,0 +1,745 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_propagation.h" +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// A class specifying the BF16 support used to test the propagation pass. It +// specifies that BF16 and mixed precision are supported in all HloInstructions, +// and that kDot reduces its operands precision to BF16. +class TestBFloat16Support : public BFloat16Support { + public: + TestBFloat16Support() {} + ~TestBFloat16Support() override {} + + bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const override { + return true; + } + + bool SupportsBF16Output(const HloInstruction& hlo) const override { + return true; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + return true; + } + + bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo, + int64 operand_index) const override { + return hlo.opcode() == HloOpcode::kDot; + } +}; + +class BFloat16PropagationTest : public HloTestBase { + protected: + // Runs the propagation pass on the given module, and returns whether the + // module is changed after this pass. + bool PropagatePrecision(HloModule* module) { + TestBFloat16Support bfloat16_support; + BFloat16Propagation propagation(&bfloat16_support); + StatusOr result = propagation.Run(module); + EXPECT_IS_OK(result.status()); + return result.ValueOrDie(); + } + + // Returns whether the given HloInstruction's output element type is BF16 or + // the only use of it is converting to BF16. + bool OutputsBF16(const HloInstruction* inst) { + if (inst->shape().element_type() == BF16) { + return true; + } + return inst->user_count() == 1 && + inst->users()[0]->opcode() == HloOpcode::kConvert && + inst->users()[0]->shape().element_type() == BF16; + } +}; + +// Tests that BF16 can propagate through select over non-tuple buffers, but not +// through add where reducing operand precision can affect the result. +TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* c = + builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "c")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b)); + HloInstruction* pred = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kEq, a, b)); + HloInstruction* sel = builder.AddInstruction( + HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1)); + HloInstruction* xpose = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0})); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, a)); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), root); + EXPECT_TRUE(OutputsBF16(xpose)); + EXPECT_TRUE(OutputsBF16(sel)); + EXPECT_TRUE(OutputsBF16(add1)); + EXPECT_FALSE(OutputsBF16(add0)); + EXPECT_FALSE(OutputsBF16(a)); + EXPECT_FALSE(OutputsBF16(b)); + EXPECT_FALSE(OutputsBF16(c)); +} + +// Tests that if a constant is converted to BF16 then its literal must also be +// converted. +TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + Array2D array_a(4, 4); + array_a.FillUnique(1.0f); + Array2D array_b(4, 4); + array_b.FillUnique(10.0f); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromArray(array_a))); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromArray(array_b))); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_TRUE(OutputsBF16(dot->operand(0))); + EXPECT_TRUE(OutputsBF16(dot->operand(1))); + EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); + EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); + EXPECT_TRUE(LiteralTestUtil::Equal( + dot->operand(0)->literal(), + *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)))); + EXPECT_TRUE(LiteralTestUtil::Equal( + dot->operand(1)->literal(), + *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)))); +} + +// Tests that BF16 can be propagated through nested tuples. +TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a)); + HloInstruction* add2 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, b)); + HloInstruction* xpose = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0})); + + HloInstruction* tuple0 = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1, add2})); + HloInstruction* tuple1 = + builder.AddInstruction(HloInstruction::CreateTuple({tuple0, xpose})); + + HloInstruction* lhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(xpose->shape(), tuple1, 1)); + HloInstruction* rhs = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + add0->shape(), + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + tuple0->shape(), tuple1, 0)), + 0)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs)); + + HloInstruction* output_tuple = + builder.AddInstruction(HloInstruction::CreateTuple({dot, add2})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), output_tuple); + EXPECT_TRUE(OutputsBF16(xpose)); + EXPECT_TRUE(OutputsBF16(add0)); + EXPECT_TRUE(OutputsBF16(add1)); + EXPECT_FALSE(OutputsBF16(add2)); +} + +// Tests that even if an instruction does not define a buffer in its output, its +// shape must match the defining instruction. +TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a)); + + HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0})); + + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + HloInstruction* rhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1)); + + // lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1. + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_TRUE(OutputsBF16(add0)); + EXPECT_TRUE(OutputsBF16(add1)); + EXPECT_TRUE(OutputsBF16(lhs)); + // rhs is a get-tuple-element, which does not define a buffer, but its shape + // should also be adjusted accordingly. + EXPECT_TRUE(OutputsBF16(rhs)); +} + +// Tests that a non-fusion computation's root should not be changed. +TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); + + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add, add)); + + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add, dot})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), tuple); + EXPECT_FALSE(OutputsBF16(add)); +} + +// Tests that BF16 is propagated properly through fused computations. +TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + + auto builder_f0 = HloComputation::Builder("fusion0"); + HloInstruction* a_f0 = + builder_f0.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b_f0 = + builder_f0.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* tuple_f0 = + builder_f0.AddInstruction(HloInstruction::CreateTuple({a_f0, b_f0})); + auto comp_f0 = module->AddEmbeddedComputation(builder_f0.Build()); + auto fusion0 = builder.AddInstruction(HloInstruction::CreateFusion( + tuple_f0->shape(), HloInstruction::FusionKind::kCustom, {add, add}, + comp_f0)); + + auto builder_f1 = HloComputation::Builder("fusion1"); + HloInstruction* p_f1 = builder_f1.AddInstruction( + HloInstruction::CreateParameter(0, tuple_f0->shape(), "param")); + HloInstruction* a_f1 = builder_f1.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, p_f1, 0)); + HloInstruction* b_f1 = builder_f1.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, p_f1, 1)); + HloInstruction* dot = builder_f1.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, a_f1, b_f1)); + auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build()); + auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion( + dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), fusion1); + EXPECT_TRUE(OutputsBF16(add)); + EXPECT_TRUE(OutputsBF16(a_f0)); + EXPECT_TRUE(OutputsBF16(b_f0)); + EXPECT_TRUE(OutputsBF16(a_f1)); + EXPECT_TRUE(OutputsBF16(b_f1)); +} + +// Tests that changes to BF16 that cannot be propagated outside a fusion are +// discarded. +TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + + auto builder_f = HloComputation::Builder("fusion"); + HloInstruction* a_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add_f = builder_f.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f)); + HloInstruction* dot_f = builder_f.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add_f, add_f)); + auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); + auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( + dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_EQ(computation->root_instruction(), fusion); +} + +// Tests that if 1) the root instruction of a fusion is a tuple, 2) the fusion +// outputs are only used by a dot, and 3) one element of the tuple is used by +// an add in the fusion computation, then the propagation pass should create a +// convert in the fusion computation to keep the add's operand in F32 but change +// the fusion output to BF16. E.g., the following fusion computation +// (F32, F32) fusion_computation(F32 a, F32 b) +// = tuple(F32 a, F32 add(F32 a, F32 b)) +// will be changed to +// (BF16, BF16) fusion_computation(F32 a, F32 b) +// = tuple(BF16 convert(a), BF16 add(F32 a, F32 b)) +TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + + auto builder_f = HloComputation::Builder("fusion0"); + HloInstruction* a_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add_f = builder_f.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f)); + HloInstruction* tuple_f = + builder_f.AddInstruction(HloInstruction::CreateTuple({a_f, add_f})); + auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); + auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( + tuple_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, + comp_f)); + + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, fusion, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, fusion, 1)); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, gte0, gte1)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_TRUE(OutputsBF16(gte0)); + EXPECT_TRUE(OutputsBF16(gte1)); + EXPECT_FALSE(OutputsBF16(a_f)); + EXPECT_FALSE(OutputsBF16(b_f)); + EXPECT_TRUE(OutputsBF16(add_f)); + auto new_fusion_root = comp_f->root_instruction(); + EXPECT_EQ(new_fusion_root->opcode(), HloOpcode::kTuple); + EXPECT_EQ(new_fusion_root->operand(1), add_f); + EXPECT_EQ(new_fusion_root->operand(0)->opcode(), HloOpcode::kConvert); + EXPECT_TRUE(OutputsBF16(new_fusion_root->operand(0))); +} + +// A select over tuples does not define the leaf buffers, so the types in +// on_true and on_false must match, so that as long as one of them is F32, the +// other must be F32 as well. +TEST_F(BFloat16PropagationTest, SelectOverTuples) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(PRED, {}), "pred")); + + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, param)); + HloInstruction* tuple0 = + builder.AddInstruction(HloInstruction::CreateTuple({param, add0})); + HloInstruction* tuple1 = + builder.AddInstruction(HloInstruction::CreateTuple({param, add1})); + HloInstruction* sel = builder.AddInstruction(HloInstruction::CreateTernary( + tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, sel, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, sel, 1)); + HloInstruction* xpose = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {4, 2}), gte0, {1, 0})); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, gte1)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_FALSE(OutputsBF16(add0)); + EXPECT_FALSE(OutputsBF16(add1)); + EXPECT_FALSE(OutputsBF16(gte0)); + EXPECT_FALSE(OutputsBF16(gte1)); + EXPECT_TRUE(OutputsBF16(xpose)); +} + +// Tests that BF16 is propagated properly through a while computation with +// non-tuple input/output. +TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + + auto builder_cond = HloComputation::Builder("cond"); + auto cond_param = builder_cond.AddInstruction( + HloInstruction::CreateParameter(0, shape, "cond_param")); + auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, cond_param, cond_param)); + auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1})))); + auto cond = module->AddEmbeddedComputation(builder_cond.Build()); + + auto builder_body = HloComputation::Builder("body"); + auto body_param = builder_body.AddInstruction( + HloInstruction::CreateParameter(0, shape, "body_param")); + auto body_dot = builder_body.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, body_param, body_param)); + auto body = module->AddEmbeddedComputation(builder_body.Build()); + + auto while_hlo = builder.AddInstruction( + HloInstruction::CreateWhile(shape, cond, body, add)); + + auto dot = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, while_hlo, while_hlo)); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_TRUE( + ShapeUtil::Equal(cond_root->shape(), ShapeUtil::MakeShape(PRED, {}))); + EXPECT_TRUE(OutputsBF16(add)); + EXPECT_TRUE(OutputsBF16(body_dot)); + EXPECT_TRUE(OutputsBF16(body_param)); + EXPECT_TRUE(OutputsBF16(cond_param)); + EXPECT_FALSE(OutputsBF16(dot)); +} + +// Tests that BF16 is propagated properly through while computations with +// tuple-shaped input/output. +TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + + auto builder_cond = HloComputation::Builder("cond"); + auto cond_param = builder_cond.AddInstruction( + HloInstruction::CreateParameter(0, tuple->shape(), "cond_param")); + auto cond_lhs = builder_cond.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond_param, 0)); + auto cond_rhs = builder_cond.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond_param, 1)); + // This add should prevent RHS from using BF16 + auto cond_add_rhs = builder_cond.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs)); + auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, cond_lhs, cond_add_rhs)); + builder_cond.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1})))); + auto cond = module->AddEmbeddedComputation(builder_cond.Build()); + + auto builder_body = HloComputation::Builder("body"); + auto body_param = builder_body.AddInstruction( + HloInstruction::CreateParameter(0, tuple->shape(), "body_param")); + auto body_lhs = builder_body.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, body_param, 0)); + auto body_rhs = builder_body.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, body_param, 1)); + auto body_dot = builder_body.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); + builder_body.AddInstruction( + HloInstruction::CreateTuple({body_dot, body_rhs})); + auto body = module->AddEmbeddedComputation(builder_body.Build()); + + auto while_hlo = builder.AddInstruction( + HloInstruction::CreateWhile(tuple->shape(), cond, body, tuple)); + + auto lhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while_hlo, 0)); + auto rhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while_hlo, 1)); + auto dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs)); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_TRUE(OutputsBF16(lhs)); + EXPECT_FALSE(OutputsBF16(rhs)); + EXPECT_TRUE(OutputsBF16(body_dot)); + EXPECT_TRUE(OutputsBF16(body_lhs)); + EXPECT_FALSE(OutputsBF16(body_rhs)); + EXPECT_TRUE(OutputsBF16(cond_lhs)); + EXPECT_FALSE(OutputsBF16(cond_rhs)); + EXPECT_TRUE(OutputsBF16(add0)); + EXPECT_FALSE(OutputsBF16(add1)); +} + +// Tests that BF16 is not propagated through multiple whiles that invoke the +// same computation as long as one while prevents the propagation. +TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* add2 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* add3 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* tuple0 = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + HloInstruction* tuple1 = + builder.AddInstruction(HloInstruction::CreateTuple({add2, add3})); + + // Condition computation for the first while. + auto builder_cond0 = HloComputation::Builder("cond0"); + auto cond0_param = builder_cond0.AddInstruction( + HloInstruction::CreateParameter(0, tuple0->shape(), "cond0_param")); + auto cond0_lhs = builder_cond0.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond0_param, 0)); + auto cond0_rhs = builder_cond0.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond0_param, 1)); + // This add should prevent RHS from using BF16 + auto cond0_add_rhs = + builder_cond0.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs)); + auto cond0_dot = builder_cond0.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, cond0_lhs, cond0_add_rhs)); + builder_cond0.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond0.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond0_dot, {0, 0}, {1, 1}, {1, 1})), + builder_cond0.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond0_dot, {1, 1}, {2, 2}, {1, 1})))); + auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build()); + + // Condition computation for the second while. + auto builder_cond1 = HloComputation::Builder("cond1"); + auto cond1_param = builder_cond1.AddInstruction( + HloInstruction::CreateParameter(0, tuple1->shape(), "cond1_param")); + auto cond1_lhs = builder_cond1.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond1_param, 0)); + auto cond1_rhs = builder_cond1.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond1_param, 1)); + // This add should prevent LHS from using BF16 + auto cond1_add_lhs = + builder_cond1.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs)); + auto cond1_dot = builder_cond1.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, cond1_add_lhs, cond1_rhs)); + builder_cond1.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond1.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond1_dot, {0, 0}, {1, 1}, {1, 1})), + builder_cond1.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond1_dot, {1, 1}, {2, 2}, {1, 1})))); + auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build()); + + // Body computation shared by both whiles. + auto builder_body = HloComputation::Builder("body"); + auto body_param = builder_body.AddInstruction( + HloInstruction::CreateParameter(0, tuple0->shape(), "body_param")); + auto body_lhs = builder_body.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, body_param, 0)); + auto body_rhs = builder_body.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, body_param, 1)); + auto body_dot = builder_body.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); + builder_body.AddInstruction( + HloInstruction::CreateTuple({body_dot, body_rhs})); + auto body = module->AddEmbeddedComputation(builder_body.Build()); + + auto while0 = builder.AddInstruction( + HloInstruction::CreateWhile(tuple0->shape(), cond0, body, tuple0)); + auto while1 = builder.AddInstruction( + HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1)); + + auto lhs = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while0, 0)), + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while0, 1)))); + auto rhs = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while1, 0)), + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while1, 1)))); + auto dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs)); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_FALSE(OutputsBF16(body_dot)); + EXPECT_FALSE(OutputsBF16(body_rhs)); + EXPECT_FALSE(OutputsBF16(body_lhs)); + EXPECT_FALSE(OutputsBF16(cond0_lhs)); + EXPECT_FALSE(OutputsBF16(cond0_rhs)); + EXPECT_FALSE(OutputsBF16(cond1_lhs)); + EXPECT_FALSE(OutputsBF16(cond1_rhs)); + EXPECT_TRUE(OutputsBF16(cond0_add_rhs)); + EXPECT_TRUE(OutputsBF16(cond1_add_lhs)); + EXPECT_EQ(computation->root_instruction(), dot); +} + +// Tests that if this pass turns an F32 -> BF16 conversion into a no-op (BF16 -> +// BF16 conversion), then it will remove that conversion. +TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {4, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {4, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "param")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, param, param)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, param, param)); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, tuple, 1)); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte0)); + HloInstruction* convert1 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte1)); + HloInstruction* add2 = builder.AddInstruction(HloInstruction::CreateBinary( + bf16_shape, HloOpcode::kAdd, convert0, convert1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), add2); + EXPECT_EQ(add2->operand(0), gte0); + EXPECT_EQ(add2->operand(1), gte1); + EXPECT_EQ(gte0->shape().element_type(), BF16); + EXPECT_EQ(gte1->shape().element_type(), BF16); + EXPECT_EQ(add0->shape().element_type(), BF16); + EXPECT_EQ(add1->shape().element_type(), BF16); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index 3fd9e24601f27633c8063e4574c7c4f91f30dcff..07b4b14b5ec1bdbc01345091105df69368b0b2fb 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -79,6 +79,7 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kBroadcast: case HloOpcode::kClamp: case HloOpcode::kConcatenate: + case HloOpcode::kConvert: case HloOpcode::kCopy: case HloOpcode::kGetTupleElement: case HloOpcode::kMaximum: diff --git a/tensorflow/compiler/xla/service/bfloat16_support.h b/tensorflow/compiler/xla/service/bfloat16_support.h index 29f662d22b4e5486662a1387407d41e0fd2ed1b3..82c2745f444e4f9c544c78cb36dafc11f678518a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.h +++ b/tensorflow/compiler/xla/service/bfloat16_support.h @@ -39,7 +39,7 @@ class BFloat16Support { // precisions (BF16 and F32). virtual bool SupportsMixedPrecisions(const HloInstruction& hlo) const; - // Returns whether the given HLO inherits its BF16 operand precision at the + // Returns whether the given HLO preserves its BF16 operand precision at the // given index, so even if the output is F32, elements in the output that // depend on the BF16 operand will still have BF16 effective precision even if // they have F32 format. Similarly, this also means if the output is BF16 then diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index b1e693da9d5af4babe619b8796007f2da318f6a8..5d3b0cb333928d3b7b042ef0a6f4969f87655d7f 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -48,6 +49,184 @@ using ::tensorflow::strings::HumanReadableNumBytes; using ::tensorflow::strings::Printf; using ::tensorflow::strings::StrAppend; +namespace { + +template +string ColocatedBufferSetsToString(const T& container, const char* title) { + string result; + StrAppend(&result, title, "\n"); + for (const auto& it : container) { + StrAppend(&result, "\t", it->ToString(), "\n"); + } + return result; +} + +// Walk the call graph of the HLO module and place each computation into either +// thread_local_computations or global_computations depending upon whether the +// computation requires thread-local allocations or global allocations. The +// elements in thread_local_computations and global_computations are in post +// order (if computation A has an instruction which calls computation B, then A +// will appear after B in the vector). +Status GatherComputationsByAllocationType( + const HloModule* module, + std::vector* thread_local_computations, + std::vector* global_computations) { + // Create a worklist of computations paired with whether the allocation must + // be thread-local. + std::deque> worklist; + worklist.push_back(std::make_pair(module->entry_computation(), + /*is_thread_local*/ false)); + + // Sets for quickly checking membership. Computations are returned in vectors + // for stable iteration. + FlatSet thread_local_set; + FlatSet global_set; + + while (!worklist.empty()) { + auto worklist_front = worklist.front(); + worklist.pop_front(); + const HloComputation* computation = worklist_front.first; + bool is_thread_local = worklist_front.second; + bool in_thread_local_set = thread_local_set.count(computation) > 0; + bool in_global_set = global_set.count(computation) > 0; + + // If the computation has already been added to the respective set, then + // nothing to do. + if ((is_thread_local && in_thread_local_set) || + (!is_thread_local && in_global_set)) { + continue; + } + + // If the computation has already been added to the other set this is an + // error condition because the global call to the computation (eg, + // while/call) may return a reference to one of the thread-local buffers to + // the calling computation which will become a dangling reference when the + // thread-local is deallocated with the call return. + if ((is_thread_local && in_global_set) || + (!is_thread_local && in_thread_local_set)) { + return InvalidArgument( + "computation %s has conflicting allocation requirements (global " + "and thread-local)", + computation->name().c_str()); + } + + if (is_thread_local) { + thread_local_set.insert(computation); + } else { + global_set.insert(computation); + } + + for (auto* instruction : computation->instructions()) { + for (HloComputation* subcomputation : + instruction->called_computations()) { + switch (instruction->opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kWhile: + // Call and while must be called from a computation with global + // allocations as they may return references to buffers inside the + // called computation which cannot be thread-local. + if (is_thread_local) { + return InvalidArgument( + "computation %s cannot contain call/while op because it " + "requires thread-local buffer allocations", + computation->name().c_str()); + } + worklist.push_back(std::make_pair(subcomputation, + false)); // Not thread local. + break; + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kMap: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kFusion: + // Map/reduce etc computations are always thread-local. + worklist.push_back(std::make_pair(subcomputation, + true)); // Thread local. + break; + default: + return InternalError( + "Unexpected calling opcode: %s", + HloOpcodeString(instruction->opcode()).c_str()); + } + } + } + } + + // Add the computations to the vectors in post order. + for (auto* computation : module->MakeComputationPostOrder()) { + if (thread_local_set.count(computation) > 0) { + thread_local_computations->push_back(computation); + } else if (global_set.count(computation) > 0) { + global_computations->push_back(computation); + } + // If the computation is not reachable from the entry computation, then it + // will not appear in either thread_local_set or global_set. We don't bother + // assigning buffers for these. + } + return Status::OK(); +} + +// Checks that points-to set of 'instruction' is unambiguous and distinct +// (ensured by CopyInsertion), then adds the buffer from the points-to set at +// 'index' to 'colocated_set'. +const LogicalBuffer* AddBufferToColocatedSet( + const HloInstruction* instruction, const ShapeIndex& index, + const TuplePointsToAnalysis& points_to_analysis, + std::vector* colocated_set) { + // CopyInsertion ensures root points-to set is unambiguous and distinct. + const auto& points_to = points_to_analysis.GetPointsToSet(instruction); + DCHECK(!points_to.IsAmbiguous()); + colocated_set->push_back(points_to.element(index)[0]); + return colocated_set->back(); +} + +// Given the interference map of a graph (the list of interfering node indices +// for each node), perform graph coloring such that interfering nodes are +// assigned to different colors. Returns the assigned color of the nodes, where +// the colors are represented as integer values [0, color_count). +std::vector ColorInterferenceGraph( + const std::vector>& interference_map) { + const int64 node_count = interference_map.size(); + + // Sort the nodes such that we assign nodes with more interference first. This + // relies on the common heuristic of assigning the most constrained node + // first, but it would be good to investigate other ordering heuristics too. + std::vector nodes(node_count); + std::iota(nodes.begin(), nodes.end(), 0); + std::sort(nodes.begin(), nodes.end(), + [&interference_map](const int64 i, const int64 j) { + return interference_map[i].size() > interference_map[j].size(); + }); + + const int64 kColorUnassigned = -1; + std::vector assigned_colors(node_count, kColorUnassigned); + for (int64 node : nodes) { + // Mark the colors that are already assigned to the neighbors. + std::vector available_colors(node_count, true); + for (int64 neighbor : interference_map[node]) { + int64 color = assigned_colors[neighbor]; + if (color != kColorUnassigned) { + available_colors[color] = false; + } + } + + // Find the color that is not yet assigned to the neighbors. + int64 color = kColorUnassigned; + for (color = 0; color < available_colors.size(); ++color) { + if (available_colors[color]) { + break; + } + } + CHECK_NE(color, kColorUnassigned); + assigned_colors[node] = color; + } + return assigned_colors; +} + +} // namespace + size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { uint64 h = std::hash()(s.index()); h = tensorflow::Hash64Combine(h, std::hash()(s.offset())); @@ -327,6 +506,7 @@ BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer, BufferAllocation* allocation = NewEmptyAllocation(size, is_thread_local, is_reusable, buffer.color()); AddAssignment(allocation, buffer, /*offset=*/0, size); + allocation->peak_buffers_.push_back(&buffer); return allocation; } @@ -348,6 +528,7 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation, // Combines allocations of temporary buffers of the same color into one big // BufferAllocation. void BufferAssignment::CombineTempAllocations() { + VLOG(1) << "CombineTempAllocations()"; FlatMap combined_allocation_map; @@ -369,11 +550,16 @@ void BufferAssignment::CombineTempAllocations() { if (combined_it == combined_allocation_map.end()) { // We have found the first temp allocation of this color. Collect // the other temp allocations of the same color into it. + VLOG(1) << "Combined temp allocation for color " << color + << " is: " << temp_allocation; combined_allocation_map.emplace(color, temp_allocation); continue; } auto* combined_allocation = &combined_it->second; + VLOG(1) << "Combined allocation absorbing temp allocation: " + << temp_allocation; + // Each temp allocation is placed end-to-end, accounting for alignment. // The offset of each buffer in the combined allocation is computed from // the base offset of the allocation. @@ -387,6 +573,14 @@ void BufferAssignment::CombineTempAllocations() { const int64 size = buffer_offset_size.second.size; combined_allocation->AddAssignment(*buffer, base + offset, size); } + if (!temp_allocation.HeapTraces().empty()) { + CHECK_EQ(temp_allocation.HeapTraces().size(), 1); + combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front()); + } + combined_allocation->peak_buffers_.insert( + combined_allocation->peak_buffers_.end(), + temp_allocation.peak_buffers_.begin(), + temp_allocation.peak_buffers_.end()); } // Replace all existing temporary allocations with the new combined // allocations. @@ -437,9 +631,8 @@ Status BufferAssignment::ComputeSummaryStats() { } } if (module_sequence.size() == module_->computation_count()) { - TF_ASSIGN_OR_RETURN( - const int64 min_size, - MinimumMemoryForSequence(module_sequence, buffer_size_)); + TF_ASSIGN_OR_RETURN(const int64 min_size, + MinimumMemoryForModule(module_sequence, buffer_size_)); stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size; } @@ -507,7 +700,7 @@ BufferAssignmentProto BufferAssignment::ToProto() const { BufferAssignmentProto::BufferAlias* proto_alias = proto.add_buffer_aliases(); LogicalBufferProto::Location proto_alias_location = - LogicalBuffer::ToLocationProto(*alias.instruction(), alias.index()); + BufferValue::ToLocationProto(*alias.instruction(), alias.index()); proto_alias->set_source_buffer_id(buffer.id()); proto_alias->mutable_location()->Swap(&proto_alias_location); } @@ -516,123 +709,13 @@ BufferAssignmentProto BufferAssignment::ToProto() const { for (const BufferAllocation& allocation : Allocations()) { BufferAllocationProto proto_allocation = allocation.ToProto(); proto.add_buffer_allocations()->Swap(&proto_allocation); - } - for (const HeapSimulatorTrace& trace : heap_simulator_traces_) { - *proto.add_heap_simulator_traces() = trace; - } - return proto; -} - -namespace { - -// Walk the call graph of the HLO module and place each computation into either -// thread_local_computations or global_computations depending upon whether the -// computation requires thread-local allocations or global allocations. The -// elements in thread_local_computations and global_computations are in post -// order (if computation A has an instruction which calls computation B, then A -// will appear after B in the vector). -Status GatherComputationsByAllocationType( - const HloModule* module, - std::vector* thread_local_computations, - std::vector* global_computations) { - // Create a worklist of computations paired with whether the allocation must - // be thread-local. - std::deque> worklist; - worklist.push_back(std::make_pair(module->entry_computation(), - /*is_thread_local*/ false)); - - // Sets for quickly checking membership. Computations are returned in vectors - // for stable iteration. - FlatSet thread_local_set; - FlatSet global_set; - - while (!worklist.empty()) { - auto worklist_front = worklist.front(); - worklist.pop_front(); - const HloComputation* computation = worklist_front.first; - bool is_thread_local = worklist_front.second; - bool in_thread_local_set = thread_local_set.count(computation) > 0; - bool in_global_set = global_set.count(computation) > 0; - - // If the computation has already been added to the respective set, then - // nothing to do. - if ((is_thread_local && in_thread_local_set) || - (!is_thread_local && in_global_set)) { - continue; - } - - // If the computation has already been added to the other set this is an - // error condition because the global call to the computation (eg, - // while/call) may return a reference to one of the thread-local buffers to - // the calling computation which will become a dangling reference when the - // thread-local is deallocated with the call return. - if ((is_thread_local && in_global_set) || - (!is_thread_local && in_thread_local_set)) { - return InvalidArgument( - "computation %s has conflicting allocation requirements (global " - "and thread-local)", - computation->name().c_str()); - } - - if (is_thread_local) { - thread_local_set.insert(computation); - } else { - global_set.insert(computation); - } - - for (auto* instruction : computation->instructions()) { - for (HloComputation* subcomputation : - instruction->called_computations()) { - switch (instruction->opcode()) { - case HloOpcode::kCall: - case HloOpcode::kConditional: - case HloOpcode::kWhile: - // Call and while must be called from a computation with global - // allocations as they may return references to buffers inside the - // called computation which cannot be thread-local. - if (is_thread_local) { - return InvalidArgument( - "computation %s cannot contain call/while op because it " - "requires thread-local buffer allocations", - computation->name().c_str()); - } - worklist.push_back(std::make_pair(subcomputation, - false)); // Not thread local. - break; - case HloOpcode::kMap: - case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: - case HloOpcode::kSelectAndScatter: - case HloOpcode::kFusion: - // Map/reduce etc computations are always thread-local. - worklist.push_back(std::make_pair(subcomputation, - true)); // Thread local. - break; - default: - return InternalError( - "Unexpected calling opcode: %s", - HloOpcodeString(instruction->opcode()).c_str()); - } - } - } - } - - // Add the computations to the vectors in post order. - for (auto* computation : module->MakeComputationPostOrder()) { - if (thread_local_set.count(computation) > 0) { - thread_local_computations->push_back(computation); - } else if (global_set.count(computation) > 0) { - global_computations->push_back(computation); + for (const HeapSimulatorTrace& heap_trace : allocation.HeapTraces()) { + *proto.add_heap_simulator_traces() = heap_trace; } - // If the computation is not reachable from the entry computation, then it - // will not appear in either thread_local_set or global_set. We don't bother - // assigning buffers for these. } - return Status::OK(); + return proto; } -} // namespace - /* static */ StatusOr> BufferAssigner::Run( const HloModule* module, std::unique_ptr hlo_ordering, @@ -1001,7 +1084,9 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); HeapSimulator::Options options; - options.buffers_to_assign = &single_colored_set.second; + BufferValueFlatSet buffer_value_set = + ToBufferValueFlatSet(single_colored_set.second); + options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( @@ -1029,7 +1114,9 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); HeapSimulator::Options options; - options.buffers_to_assign = &single_colored_set.second; + BufferValueFlatSet buffer_value_set = + ToBufferValueFlatSet(single_colored_set.second); + options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( @@ -1045,6 +1132,89 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( return Status::OK(); } +namespace { + +// Computes and returns the set of logical buffers live at the point of maximal +// liveness in the given heap trace. LogicalBuffers are (stabily) sorted by id. +std::vector ComputePeakMemoryLogicalBuffers( + const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) { + // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical + // buffers in this allocation. + tensorflow::gtl::FlatMap + id_to_buffer; + tensorflow::gtl::FlatMap buffer_sizes; + for (const auto& pair : allocation.assigned_buffers()) { + const LogicalBuffer* buffer = pair.first; + const BufferAllocation::OffsetSize& offset_size = pair.second; + id_to_buffer[buffer->id()] = buffer; + buffer_sizes[buffer] = offset_size.size; + } + + // Returns how much the given event increases the total size of live + // buffers. Can be negative. + auto memory_delta = [&id_to_buffer, &buffer_sizes]( + const HeapSimulatorTrace::Event& event) -> int64 { + const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); + const int64 buffer_size = buffer_sizes.at(buffer); + if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { + return buffer_size; + } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { + // Sharing a buffer does not change the live set size for the purposes of + // the heap simulator. Even though the shared-with buffer may be smaller, + // the entire allocation remains live. + return 0; + } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { + return -1 * buffer_size; + } + LOG(FATAL) << "Unknown event kind: " << event.kind(); + }; + + // First compute the size of the maximal live set. + int64 max_live_size = 0; + int64 live_size = 0; + for (const auto& event : heap_trace.events()) { + live_size += memory_delta(event); + if (max_live_size < live_size) { + max_live_size = live_size; + } + } + + // Next gather the set of logical buffers live at the earliest point of + // maximal live set size. + tensorflow::gtl::FlatSet live_buffers; + live_size = 0; + for (const auto& event : heap_trace.events()) { + const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); + if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { + InsertOrDie(&live_buffers, buffer); + } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { + // Nothing to do. + } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { + CHECK(ContainsKey(live_buffers, buffer)); + live_buffers.erase(buffer); + } + + live_size += memory_delta(event); + if (live_size == max_live_size) { + break; + } + } + CHECK_EQ(live_size, max_live_size); + + std::vector live_buffers_vector; + live_buffers_vector.insert(live_buffers_vector.end(), live_buffers.begin(), + live_buffers.end()); + + // Stabily sort the live buffers. + std::sort(live_buffers_vector.begin(), live_buffers_vector.end(), + [](const LogicalBuffer* a, const LogicalBuffer* b) { + return a->id() < b->id(); + }); + return live_buffers_vector; +} + +} // namespace + void BufferAssigner::AssignBuffersFromHeapSimulator( const HeapSimulator::Result& result, BufferAssignment* assignment, LogicalBuffer::Color color) { @@ -1059,12 +1229,18 @@ void BufferAssigner::AssignBuffersFromHeapSimulator( BufferAllocation* allocation = assignment->NewEmptyAllocation( result.heap_size, /*is_thread_local=*/false, /*is_reusable=*/true, color); for (const auto& buffer_chunk : result.chunk_map) { - const LogicalBuffer& buffer = *buffer_chunk.first; + // TODO(lauj) Remove this down_cast after downstream users of + // BufferAllocation::assigned_buffers() are updated to use BufferValue. + const LogicalBuffer& buffer = + *CHECK_NOTNULL(dynamic_cast(buffer_chunk.first)); const HeapSimulator::Chunk& chunk = buffer_chunk.second; assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size); } + allocation->peak_buffers_ = + ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace); - assignment->heap_simulator_traces_.push_back(result.debug_trace); + VLOG(1) << "Ran heap simulation for allocation: " << allocation->ToString(); + allocation->AddHeapTrace(result.debug_trace); } // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining @@ -1085,7 +1261,8 @@ void BufferAssigner::AddSetToColocatedBufferSets( if (colocated_set.empty()) { return; } - + VLOG(5) << ColocatedBufferSetsToString(colocated_set, + "Adding colocated buffer set"); // Find existing sets that overlap with at least one buffer from the // colocated_set. The resulting 'overlap_set_indices' will have at most // colocated_buffer_sets->size() entries, and will be in increasing order. @@ -1093,6 +1270,10 @@ void BufferAssigner::AddSetToColocatedBufferSets( for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) { for (const LogicalBuffer* buffer : colocated_set) { if ((*colocated_buffer_sets)[index].count(buffer) > 0) { + VLOG(5) << "Found overlap with existing set on buffer " + << buffer->ToString() << "\n" + << ColocatedBufferSetsToString((*colocated_buffer_sets)[index], + "Overlapping set"); overlap_set_indices.push_back(index); break; } @@ -1104,6 +1285,7 @@ void BufferAssigner::AddSetToColocatedBufferSets( colocated_buffer_sets->emplace_back(); colocated_buffer_sets->back().insert(colocated_set.begin(), colocated_set.end()); + VLOG(5) << "No overlap found, new group created"; return; } @@ -1115,6 +1297,8 @@ void BufferAssigner::AddSetToColocatedBufferSets( first->insert(overlap_set.begin(), overlap_set.end()); } first->insert(colocated_set.begin(), colocated_set.end()); + VLOG(5) << ColocatedBufferSetsToString( + *first, "Result of the colocated buffer set merging"); // Remove overlap sets that we just merged. The offset accounts for the fact // that as elements are erased, the indices need to be adjusted. Keep in mind @@ -1125,67 +1309,6 @@ void BufferAssigner::AddSetToColocatedBufferSets( } } -namespace { - -// Checks that points-to set of 'instruction' is unambiguous and distinct -// (ensured by CopyInsertion), then adds the buffer from the points-to set at -// 'index' to 'colocated_set'. -const LogicalBuffer* AddBufferToColocatedSet( - const HloInstruction* instruction, const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis, - std::vector* colocated_set) { - // CopyInsertion ensures root points-to set is unambiguous and distinct. - const auto& points_to = points_to_analysis.GetPointsToSet(instruction); - DCHECK(!points_to.IsAmbiguous()); - colocated_set->push_back(points_to.element(index)[0]); - return colocated_set->back(); -} - -// Given the interference map of a graph (the list of interfering node indices -// for each node), perform graph coloring such that interfering nodes are -// assigned to different colors. Returns the assigned color of the nodes, where -// the colors are represented as integer values [0, color_count). -std::vector ColorInterferenceGraph( - const std::vector>& interference_map) { - const int64 node_count = interference_map.size(); - - // Sort the nodes such that we assign nodes with more interference first. This - // relies on the common heuristic of assigning the most constrained node - // first, but it would be good to investigate other ordering heuristics too. - std::vector nodes(node_count); - std::iota(nodes.begin(), nodes.end(), 0); - std::sort(nodes.begin(), nodes.end(), - [&interference_map](const int64 i, const int64 j) { - return interference_map[i].size() > interference_map[j].size(); - }); - - const int64 kColorUnassigned = -1; - std::vector assigned_colors(node_count, kColorUnassigned); - for (int64 node : nodes) { - // Mark the colors that are already assigned to the neighbors. - std::vector available_colors(node_count, true); - for (int64 neighbor : interference_map[node]) { - int64 color = assigned_colors[neighbor]; - if (color != kColorUnassigned) { - available_colors[color] = false; - } - } - - // Find the color that is not yet assigned to the neighbors. - int64 color = kColorUnassigned; - for (color = 0; color < available_colors.size(); ++color) { - if (available_colors[color]) { - break; - } - } - CHECK_NE(color, kColorUnassigned); - assigned_colors[node] = color; - } - return assigned_colors; -} - -} // namespace - std::vector BufferAssigner::MergeColocatedBufferSets( const std::vector& colocated_buffer_sets, @@ -1208,26 +1331,35 @@ BufferAssigner::MergeColocatedBufferSets( auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness, &buffer_size, &is_entry_parameter](int64 i, int64 j) { - for (auto& buffer_a : colocated_buffer_sets[i]) { - for (auto& buffer_b : colocated_buffer_sets[j]) { - // Do not merge if the set includes live outs or entry parameters. - if ((buffer_liveness.MaybeLiveOut(*buffer_a) && - is_entry_parameter(*buffer_b)) || - (buffer_liveness.MaybeLiveOut(*buffer_b) && - is_entry_parameter(*buffer_a))) { + // Do not merge if one of the sets includes live outs or entry parameters. + for (int64 key : {i, j}) { + for (auto& buffer : colocated_buffer_sets[key]) { + if (buffer_liveness.MaybeLiveOut(*buffer) || + is_entry_parameter(*buffer)) { return true; } - // Do not merge if the buffers interfere with each other. + } + } + + // Colocated sets satisfy the invariant that all buffers within a set have + // the same size. That means we need to check whether the size is the same + // between the two sets, but also that it's enough to look at just one + // buffer within each set. + if (buffer_size(**colocated_buffer_sets[i].begin()) != + buffer_size(**colocated_buffer_sets[j].begin())) { + return true; + } + + // Do not merge if some pair of buffers interferes with each other. + for (auto& buffer_a : colocated_buffer_sets[i]) { + for (auto& buffer_b : colocated_buffer_sets[j]) { if (buffer_a->id() != buffer_b->id() && buffer_liveness.MayInterfere(*buffer_a, *buffer_b)) { return true; } - // Do not merge if the buffer sizes are different. - if (buffer_size(*buffer_a) != buffer_size(*buffer_b)) { - return true; - } } } + return false; }; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 6b7fd0014d103ef0617afcc5cb3f663554a01aa4..ad0b0bf7c25d7194a06801e4ef1c9ee961f6b915 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -192,6 +192,35 @@ class BufferAllocation { !is_thread_local(); } + // Add a heap trace which was used to assign slices to logical buffers in this + // allocation. A single BufferAllocation may include multiple heap traces + // in the case of the temporary block where there is a heap trace per + // computation. + void AddHeapTrace(const HeapSimulatorTrace& heap_trace) { + heap_traces_.push_back(heap_trace); + } + + // Return the set of heap traces used to assign slices to logical buffers in + // this allocation. + const std::vector HeapTraces() const { + return heap_traces_; + } + + // Returns the LogicalBuffers which are live at the point of peak memory usage + // for this allocation. The point of peak memory usage is the point at which + // the total size of all live logical buffers is maximal. If peak memory is + // reached at multiple points, the set of logical buffers live at the earliest + // maximal point is returned. The vector is stabily sorted by + // LogicalBuffer::Index. + const std::vector& PeakMemoryLogicalBuffers() const { + return peak_buffers_; + } + + // Get the number of bytes lost to fragmentation. This is equal to the + // difference between the size of the allocation and the size of the maximal + // live set. + int64 fragmentation_bytes() const { return fragmentation_bytes_; } + bool operator==(const BufferAllocation& other) const { return index_ == other.index_; } @@ -257,6 +286,12 @@ class BufferAllocation { // Mapping from the set of buffers assigned to this allocation to their // logical offsets and sizes. tensorflow::gtl::FlatMap assigned_buffers_; + + int64 fragmentation_bytes_ = 0; + std::vector heap_traces_; + + // Set of buffers live at the point of peak memory usage for this allocation. + std::vector peak_buffers_; }; // Add stream operators for nicer output of CHECK/RET_CHECK failures. @@ -380,10 +415,10 @@ class BufferAssignment { // Only BufferAssigner can build or modify BufferAssignments. friend class BufferAssigner; - explicit BufferAssignment(const HloModule* module, - std::unique_ptr liveness, - LogicalBuffer::SizeFunction buffer_size, - LogicalBuffer::AlignmentFunction color_alignment) + BufferAssignment(const HloModule* module, + std::unique_ptr liveness, + LogicalBuffer::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment) : module_(module), liveness_(std::move(liveness)), buffer_size_(std::move(buffer_size)), @@ -441,7 +476,6 @@ class BufferAssignment { LogicalBuffer::AlignmentFunction color_alignment_; Stats stats_; - std::vector heap_simulator_traces_; TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment); }; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index cd73654b8f666c4b96c000235cc3ad2cd0a46c17..efa4696130ffeff669b0d674438a45c5a9d48ef2 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -42,9 +43,10 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" namespace xla { - namespace { +using ::testing::UnorderedElementsAre; + // DFS visitor that collects the instructions referenced by a computation // without descending into nested computations, i.e., only from the operands. class InstructionListVisitor : public DfsHloVisitorWithDefault { @@ -79,7 +81,7 @@ const std::vector GetInstructions(HloInstruction* root) { class BufferAssignmentTest : public HloTestBase { protected: - BufferAssignmentTest() : computation_tracker_() {} + BufferAssignmentTest() {} ~BufferAssignmentTest() override {} std::unique_ptr RunBufferAssignment(HloModule* module, @@ -101,6 +103,22 @@ class BufferAssignmentTest : public HloTestBase { .ConsumeValueOrDie(); } + std::unique_ptr RunBufferAssignmentWithInstructionSequence( + HloModule* module, + tensorflow::gtl::ArraySlice instruction_sequence, + int64 alignment = 1) { + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[module->entry_computation()] = + std::vector(instruction_sequence.begin(), + instruction_sequence.end()); + return BufferAssigner::Run( + module, + xla::MakeUnique(module, module_sequence), + backend().compiler()->BufferSizeBytesFunction(), + [alignment](LogicalBuffer::Color) { return alignment; }) + .ConsumeValueOrDie(); + } + // Builds an x+1.0 computation to use in a Map. std::unique_ptr BuildMapComputationPlus1(const string& name) { auto builder = HloComputation::Builder(name); @@ -233,9 +251,6 @@ class BufferAssignmentTest : public HloTestBase { return total_size; } - // Computation tracker for nested computations. - ComputationTracker computation_tracker_; - // Shapes for use in the examples. Shape s32_ = ShapeUtil::MakeShape(xla::S32, {}); Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {}); @@ -356,11 +371,11 @@ TEST_F(BufferAssignmentTest, Basic) { // param1[100] --------------/--------/ auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -403,11 +418,11 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { // share anything. auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -462,11 +477,11 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { // have the color 0, which allows the mul and add to share buffers. auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -532,11 +547,11 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { // auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -586,7 +601,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { // Creates the main kernel and verifies instruction counts. auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10_, "")); + HloInstruction::CreateParameter(0, f32a100x10_, "p")); auto map = builder.AddInstruction( HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation)); module->AddEntryComputation(builder.Build()); @@ -639,7 +654,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10_, "")); + HloInstruction::CreateParameter(0, f32a100x10_, "p")); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0)); auto exp2 = builder.AddInstruction( @@ -803,7 +818,7 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg) auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32vec100_, "")); + HloInstruction::CreateParameter(0, f32vec100_, "p")); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0)); auto tanh = builder.AddInstruction( @@ -1370,7 +1385,7 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { auto element_slices = assignment->GetAllSlices(select, /*index=*/{0}); EXPECT_EQ(2, element_slices.size()); EXPECT_THAT(element_slices, - ::testing::UnorderedElementsAre( + UnorderedElementsAre( assignment->GetUniqueSlice(tuple_param0, /*index=*/{0}) .ConsumeValueOrDie(), assignment->GetUniqueSlice(tuple_param1, /*index=*/{0}) @@ -1473,6 +1488,156 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { } } +TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { + // paramscalar ------- (mul) -- (add) -- (sub) + // / / / + // param0[100] -------/ / / + // / / + // param1[100] --------------/--------/ + auto builder = HloComputation::Builder(TestName()); + auto paramscalar = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec100_, "p1")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32vec100_, "p2")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); + builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_, HloOpcode::kSubtract, add, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignment(module.get()); + + // Trivially, the set of peak memory logical buffer(s) of an allocation with a + // single logical buffer should be exactly the logical buffer in that + // allocation. + const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); + const std::vector& peak_buffers = + mul_buffer.PeakMemoryLogicalBuffers(); + ASSERT_EQ(peak_buffers.size(), 1); + EXPECT_EQ(peak_buffers[0]->instruction(), mul); +} + +TEST_F(BufferAssignmentTest, PeakBuffers) { + // Compute the peak liveness buffers of the following sequence: + // + // %param = ... + // %log = log(%param) + // %rev = reverse(%log) + // %neg = neg(%param) + // %concat = concat(%rev, %neg) + // ROOT %root = slice(concat) + // + // In the temporary block, the set of live buffers at peak memory use should + // be {%rev, %neg, %concat}. This occurs right at the concat itself. + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "p")); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param)); + auto rev = builder.AddInstruction( + HloInstruction::CreateReverse(f32vec100_, log, {0})); + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param)); + const Shape concat_shape = ShapeUtil::MakeShape(F32, {200}); + auto concat = builder.AddInstruction( + HloInstruction::CreateConcatenate(concat_shape, {rev, neg}, 0)); + // Make the root tiny so no interior nodes can share its buffer. + auto root = builder.AddInstruction(HloInstruction::CreateSlice( + + ShapeUtil::MakeShape(F32, {1}), concat, {0}, {1}, {1})); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignmentWithInstructionSequence( + module.get(), {param, log, rev, neg, concat, root}); + + // The temporary buffer should hold the 4 interior instructions. + const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat); + EXPECT_FALSE(buffer.IsInputOrOutput()); + EXPECT_TRUE(buffer.IsPreallocatedTempBuffer()); + ASSERT_EQ(buffer.assigned_buffers().size(), 4); + + const std::vector& peak_buffers = + buffer.PeakMemoryLogicalBuffers(); + + // The peak live set should be concat and its inputs. + ASSERT_EQ(peak_buffers.size(), 3); + std::vector peak_instructions; + for (const LogicalBuffer* logical_buffer : peak_buffers) { + peak_instructions.push_back(logical_buffer->instruction()); + } + EXPECT_THAT(peak_instructions, UnorderedElementsAre(rev, neg, concat)); +} + +TEST_F(BufferAssignmentTest, PeakBuffersWhile) { + auto module = CreateNewModule(); + const Shape shape = ShapeUtil::MakeShape(F32, {123, 123}); + HloComputation* condition; + { + auto b = HloComputation::Builder(TestName() + ".cond"); + b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + condition = module->AddEmbeddedComputation(b.Build()); + } + HloComputation* body; + { + auto b = HloComputation::Builder(TestName() + ".body"); + auto param = + b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + b.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param)); + body = module->AddEmbeddedComputation(b.Build()); + } + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param)); + auto while_op = builder.AddInstruction( + HloInstruction::CreateWhile(shape, condition, body, copy)); + // This broadcast should get a temporary allocation which is merged with the + // allocation for the while. Peak buffers should include the while and the + // broadcast. + auto bcast = builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {123, 123, 123}), while_op, {0, 1})); + builder.AddInstruction(HloInstruction::CreateReverse( + ShapeUtil::MakeShape(F32, {123, 123, 123}), bcast, {0})); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignment(module.get()); + const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, bcast); + const std::vector& peak_buffers = + buffer.PeakMemoryLogicalBuffers(); + ASSERT_EQ(peak_buffers.size(), 2); + + // The peak buffers should include the broadcast and one of the colocated + // buffers of the while (body param, condition param, body root, or the while + // itself). + const LogicalBuffer* bcast_buffer; + const LogicalBuffer* nonbcast_buffer; + if (peak_buffers[0]->instruction() == bcast) { + bcast_buffer = peak_buffers[0]; + nonbcast_buffer = peak_buffers[1]; + } else { + bcast_buffer = peak_buffers[1]; + nonbcast_buffer = peak_buffers[0]; + } + EXPECT_EQ(bcast_buffer->instruction(), bcast); + EXPECT_TRUE( + nonbcast_buffer->instruction() == copy || + nonbcast_buffer->instruction() == while_op || + nonbcast_buffer->instruction() == body->parameter_instruction(0) || + nonbcast_buffer->instruction() == body->root_instruction() || + nonbcast_buffer->instruction() == condition->parameter_instruction(0)); +} + class WhileBufferAssignmentTest : public HloTestBase { protected: std::unique_ptr BuildWhileConditionComputation( @@ -1508,7 +1673,7 @@ class WhileBufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { auto sequence = - CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( module, xla::MakeUnique(module, sequence), ByteSizeOf, @@ -1516,7 +1681,7 @@ class WhileBufferAssignmentTest : public HloTestBase { .ConsumeValueOrDie(); } - static int64 ByteSizeOf(const LogicalBuffer& buffer) { + static int64 ByteSizeOf(const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*)); } @@ -1531,7 +1696,7 @@ static void RunCopyInsertion(HloModule* module) { } TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1587,6 +1752,81 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie()); } +// Tests that two colocated buffer sets are not merged if an entry parameter +// buffer belongs to either of the colocation sets (b/73267882). +// +// %param --> %while.0 --> %mul --> %while.1 --> %broadcast +// +// %while.0 body just forwards the init value, so the loop carried variable +// remains the constant, whereas %while.1 changes the loop carried variable. +TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithEntryParameter) { + const Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + + const char* module_str = R"( +HloModule test_module + +%cond.v0 { + %param = s32[] parameter(0) + ROOT %constant = pred[] constant(true) +} + +%cond.v1 { + %param.0 = s32[] parameter(0) + ROOT %constant.0 = pred[] constant(true) +} + +%body.v0 { + ROOT %param.1 = s32[] parameter(0) +} + +%body.v1 { + %param.2 = s32[] parameter(0) + ROOT add = s32[] add(%param.2, %param.2) +} + +ENTRY %test_module { + %param.3 = s32[] parameter(0) + %while.0 = s32[] while(%param.3), condition=%cond.v0, body=%body.v0 + %mul = s32[] multiply(%while.0, %while.0) + %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1 + ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + // Run CopyInsertion and check if the graph constructed above doesn't need + // any copies inserted for BufferAssignment to run. + int64 instruction_count = module->instruction_count(); + CopyInsertion copy_insertion; + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + ASSERT_EQ(instruction_count, module->instruction_count()); + + // Get the instructions in the module. + const HloInstruction* bcast = module->entry_computation()->root_instruction(); + const HloInstruction* param = + module->entry_computation()->parameter_instruction(0); + ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); + const HloInstruction* while1 = bcast->operand(0); + ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); + const HloInstruction* while0 = while1->operand(0)->operand(0); + ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); + + // Run buffer assignment. + auto assignment = RunBufferAssignment(module.get()); + TF_ASSERT_OK_AND_ASSIGN(auto slice_param, + assignment->GetUniqueSlice(param, {})); + TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, + assignment->GetUniqueSlice(while0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto slice_while1, + assignment->GetUniqueSlice(while1, {})); + + // The parameter slice is part of the while0's colocation set (init value), + // but not merged into the while1's colocation set. + EXPECT_EQ(slice_param, slice_while0); + EXPECT_NE(slice_param, slice_while1); +} + // Tests that the colocated buffers for while instructions are properly assigned // during buffer assignment such that the result tuple elements are not assigned // to the same buffer. @@ -1631,7 +1871,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { }; // Build the entry computation as described in the comment above. - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, "")); @@ -1699,7 +1939,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { } TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1744,7 +1984,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { } TEST_F(BufferAssignmentTest, TwoCalls) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {}); HloComputation* sub_computation; { @@ -1809,7 +2049,7 @@ static bool IsPostOrderTraversal( } TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); auto zero = builder.AddInstruction( @@ -1863,7 +2103,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { RunCopyInsertion(module.get()); auto sequence = - CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); // To trigger b/38494731, we want a specific Hlo sequence for the // root computation, so we overwrite that entry with a manually @@ -1888,7 +2128,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { } TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 37982aaef9eddd64ef6b57ad5a9cf8dd6a565097..810d597e730c1823668c81598df6138655e58b55 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -44,7 +43,7 @@ StatusOr> BufferLiveness::Run( return std::move(liveness); } -tensorflow::Status BufferLiveness::Analyze() { +Status BufferLiveness::Analyze() { TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_)); for (auto* computation : module_->computations()) { if (computation->IsFusionComputation()) { @@ -71,7 +70,7 @@ tensorflow::Status BufferLiveness::Analyze() { } XLA_VLOG_LINES(3, ToString()); - return tensorflow::Status::OK(); + return Status::OK(); } string BufferLiveness::ToString() const { @@ -105,8 +104,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { // Every user of 'a' must be a predecessor of 'b' or 'b' itself. for (auto user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), user, - points_to_analysis())) { + if (points_to_analysis().DoesNotUseOperandBuffer(alias.instruction(), + alias.index(), user)) { continue; } if (user != b.instruction() && @@ -132,9 +131,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, // the qualifications specified in CanShareOperandBufferWithUser. for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { if (b.instruction()->IsUserOf(alias.instruction()) && - !CanShareOperandBufferWithUser(alias.instruction(), alias.index(), - b.instruction(), b.index(), - points_to_analysis())) { + !points_to_analysis().CanShareOperandBufferWithUser( + alias.instruction(), alias.index(), b.instruction(), b.index())) { return false; } } diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h index 11834a5127e383cc2ec2ab3fe1bb82ba86e4abed..cdd3cf4032ef6916086e1c2d148b575192503000 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.h +++ b/tensorflow/compiler/xla/service/buffer_liveness.h @@ -89,7 +89,7 @@ class BufferLiveness { // Perform buffer liveness analysis. This method must be called prior to // MayInterfere or MaybeLiveOut. - tensorflow::Status Analyze(); + Status Analyze(); // Returns true if the live range of the buffer of 'a' is strictly before the // live range of the buffer of 'b' (they do not overlap). diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc new file mode 100644 index 0000000000000000000000000000000000000000..2bc556a9e270136f5f3eaf2433f8c96eeeaea0a2 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_value.cc @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/buffer_value.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +BufferValue::BufferValue(HloInstruction* instruction, const ShapeIndex& index, + Id id) + : id_(id) { + const Shape& shape = ShapeUtil::GetSubshape(instruction->shape(), index); + is_array_ = ShapeUtil::IsArray(shape); + is_tuple_ = ShapeUtil::IsTuple(shape); +} + +BufferValue::~BufferValue() {} + +std::ostream& operator<<(std::ostream& out, const BufferValue& buffer) { + out << buffer.ToString(); + return out; +} + +/*static*/ LogicalBufferProto::Location BufferValue::ToLocationProto( + const HloInstruction& instruction, const ShapeIndex& index) { + LogicalBufferProto::Location proto; + proto.set_computation_name(instruction.parent()->name()); + proto.set_instruction_name(instruction.name()); + for (const int64 index_entry : index) { + proto.add_shape_index(index_entry); + } + return proto; +} + +LogicalBufferProto BufferValue::ToProto(const SizeFunction& size_fn) const { + LogicalBufferProto proto; + proto.set_id(id()); + proto.set_size(size_fn(*this)); + LogicalBufferProto::Location proto_location = + ToLocationProto(*instruction(), index()); + proto.mutable_defined_at()->Swap(&proto_location); + if (has_color()) { + proto.set_color(color().value()); + } + return proto; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_value.h b/tensorflow/compiler/xla/service/buffer_value.h new file mode 100644 index 0000000000000000000000000000000000000000..f4be16e0843f64f41ef27539bf263ae98ce0ebf9 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_value.h @@ -0,0 +1,177 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/int_type.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Abstract class describing a value used by one of the dataflow analyses - +// TuplePointsToAnalysis or HloDataflowAnalysis. +// TODO(b/78906445) Delete this class when TuplePointsToAnalysis is unused. +// +// XLA arrays are trivially a single BufferValue. Tuples are made up of more +// than one BufferValue: an BufferValue for the pointer vector, and an +// BufferValue for each child element. +// +// Every BufferValue is defined by a particular instruction and most +// instructions define only a single BufferValue. Instructions which define a +// single BufferValue include array-shaped instructions such as Add but also +// includes Tuple-shaped instructions such as Tuple. The Tuple instruction +// defines a single BufferValue which is a vector of pointers to the values +// containing the Tuple instruction's operands. Though the result of the Tuple +// instruction includes multiple values only the top-level BufferValue (the +// vector of pointers) is defined by the Tuple instruction. The values +// containing the tuple elements are defined by earlier instructions, usually +// the operands of the Tuple instruction. +// +// Instructions which construct both the tuple *and* the tuple elements define +// more than one BufferValue. This includes (at least) tuple-shaped Constant, +// Parameter, Infeed and While instructions. These tuple-shaped instructions do +// not assemble a tuple from existing BufferValues like the Tuple instruction +// does, but rather define all the BufferValues in the tuple. +// +// Some instructions, such as Bitcast, define no buffers. These instructions +// simply forward buffers from their operands. +// +// The BufferValue object describes which HLO instruction defines a buffer and +// where within that instruction's output shape the buffer is defined. The +// location within the output shape is indicated by BufferValue::index() which +// is defined identically to the index used in ShapeUtil::GetSubshape(). +// Examples: +// +// %add = Add(%foo, %bar) +// %tuple_constant = Constant({1, {42, 43}}) +// +// %add defines a single array-shaped buffer BufferValue(%add, {}) which holds +// the array result of the add operation. The nested-tuple-shaped +// %tuple_constant defines 5 buffers described by the following BufferValue +// objects: +// +// BufferValue(%tuple_constant, {}) // "Top-level" buffer: vector of +// // pointers to BufferValues at +// // indices {0} and {1} +// BufferValue(%tuple_constant, {0}) // Holds value "1" +// BufferValue(%tuple_constant, {1}) // Holds nested tuple: vector of +// // pointers to BufferValues at +// // indices {1, 0} and {1, 1} +// BufferValue(%tuple_constant, {1, 0}) // Holds value "42" +// BufferValue(%tuple_constant, {1, 1}) // Holds value "43" + +class BufferValue { + public: + TF_LIB_GTL_DEFINE_INT_TYPE(Color, int64); + + // Id is a unique identifier for the BufferValue to facilitate efficient + // collections of BufferValues with stable iteration order. + using Id = int64; + + // Functions which return the size and alignment of a logical buffer in bytes. + using SizeFunction = std::function; + using AlignmentFunction = std::function; + + virtual ~BufferValue(); + + Id id() const { return id_; } + + // Return the instruction that defines the buffer. + virtual HloInstruction* instruction() const = 0; + + // Return the index within the output of the instruction where the buffer is + // defined. Index used defined as in ShapeUtil::GetSubshape() + virtual const ShapeIndex& index() const = 0; + + // Return the color of the BufferValue. Differently colored buffers can not be + // parts of the same allocation. + Color color() const { + CHECK_NE(color_, kInvalidColor) + << "Should not query the color of a buffer that was never colored"; + return color_; + } + + void set_color(Color color) { + CHECK_NE(color, kInvalidColor) + << "Should not set the color of a buffer to the invalid color"; + color_ = color; + } + + bool has_color() const { return color_ != kInvalidColor; } + + // Return the shape of the buffer. This reference points into the shape field + // of the instruction defining the buffer. Therefore, the returned shape will + // contain the layout of instruction, if any. + virtual const Shape& shape() const = 0; + + // Returns true if this buffer is the top-level output buffer of the defining + // HLO instruction. This is equivalent to index == {}. + bool IsTopLevel() const { return index().empty(); } + + // Whether this buffer contains a tuple. + bool IsTuple() const { return is_tuple_; } + + // Whether this buffer contains an array. + bool IsArray() const { return is_array_; } + + // operator< is required for std::set. + bool operator<(const BufferValue& other) const { return id_ < other.id_; } + + virtual string ToString() const = 0; + + // TODO(lauj) rename LogicalBufferProto to BufferValueProto. + LogicalBufferProto ToProto(const SizeFunction& size_fn) const; + + // Returns the LogicalBufferProto::Location that serializes the given + // instruction and index. + static LogicalBufferProto::Location ToLocationProto( + const HloInstruction& instruction, const ShapeIndex& index); + + const Color kInvalidColor = Color(-1); + + protected: + BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id); + + private: + // The definining instruction and index are not stored here; they can be found + // in the LogicalBuffer and HloValue subclasses. This class exists only to + // support migrations from TuplePointsToAnalysis to HloDataflowAnalysis, by + // allowing abstract use of LogicalBuffer or HloValue. After those migrations + // are complete, this class should be deleted (b/78906445). Because we plan to + // delete LogicalBuffer and this class, we don't refactor all the shared + // features from LogicalBuffer and HloValue into this class. + Id id_ : 62; + bool is_array_ : 1; + bool is_tuple_ : 1; + Color color_ = kInvalidColor; +}; + +std::ostream& operator<<(std::ostream& out, const BufferValue& buffer); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ diff --git a/tensorflow/compiler/xla/service/buffer_value_containers.h b/tensorflow/compiler/xla/service/buffer_value_containers.h new file mode 100644 index 0000000000000000000000000000000000000000..305914fca828f110bf54239bddb1590172562b16 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_value_containers.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ + +#include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/core/lib/gtl/compactptrset.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// Define various containers of BufferValues, and utilities to convert from +// containers of LogicalBuffers to containers of BufferValues. + +using BufferValueCompactPointerSet = + tensorflow::gtl::CompactPointerSet; +template +BufferValueCompactPointerSet ToBufferValueCompactPointerSet( + const LogicalBufferContainerT& logical_buffer_container) { + BufferValueCompactPointerSet output; + for (const LogicalBuffer* buffer : logical_buffer_container) { + output.insert(buffer); + } + return output; +} + +using BufferValueFlatSet = tensorflow::gtl::FlatSet; +template +BufferValueFlatSet ToBufferValueFlatSet( + const LogicalBufferContainerT& logical_buffer_container) { + BufferValueFlatSet output; + output.reserve(logical_buffer_container.size()); + for (const LogicalBuffer* buffer : logical_buffer_container) { + output.insert(buffer); + } + return output; +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 13eb02ca012f44b2b5ed7c6f5becb7d54b07c33c..a23427f00ccd88bb0fe1d973a667f80ca54b14cd 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -51,12 +51,13 @@ std::ostream& operator<<(std::ostream& out, const CallContext& context) { return out; } -CallContext GetInstructionCallContext(const HloInstruction* instruction) { - switch (instruction->opcode()) { +CallContext GetInstructionCallContext(HloOpcode opcode) { + switch (opcode) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kWhile: return CallContext::kSequential; + case HloOpcode::kCrossReplicaSum: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: @@ -101,7 +102,7 @@ void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) { void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) { CHECK_EQ(instruction->parent(), computation()); - const CallContext context = GetInstructionCallContext(instruction); + const CallContext context = GetInstructionCallContext(instruction->opcode()); if (!instruction->called_computations().empty()) { CHECK(context == CallContext::kSequential || context == CallContext::kParallel); diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 688c4085dfb4f47d3e08a4abee5e7b645f595b11..97d3811508adee1bf2d0942bcc69e3e34a41c8c3 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -53,7 +53,7 @@ enum class CallContext { string CallContextToString(CallContext context); std::ostream& operator<<(std::ostream& out, const CallContext& context); -CallContext GetInstructionCallContext(const HloInstruction* instruction); +CallContext GetInstructionCallContext(HloOpcode opcode); // Represents an HLO instruction which calls one or more computations. class CallSite { diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index c7763f2ca3e68490cd0cd9b4ba4d7bd180134080..fac0afd672ff3ed083aacf778dd9c4f90a2ee870 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -19,9 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/compilation_cache.cc b/tensorflow/compiler/xla/service/compilation_cache.cc deleted file mode 100644 index b16907da9e9c909d2639f83895db27d724a84a7b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/compilation_cache.cc +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/compilation_cache.h" - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { - -std::shared_ptr CompilationCache::Insert( - std::unique_ptr executable, - const HloModuleConfig& module_config) { - tensorflow::mutex_lock lock(mutex_); - - CacheKey key = - BuildKey(executable->entry_computation_handle(), module_config); - VLOG(2) << "inserting cache key: " << key; - if (cache_.count(key) == 0) { - cache_.emplace(key, std::move(executable)); - } else { - // Executable already exists in the cache. This can happen if two Execute - // calls for a new computation are received simultaneously by the - // service. In this case, we discard the Executable given as a parameter and - // return what is in the cache. This is necessary because the service relies - // on the cache to keep ownership of the Executable. We only want to store - // one Executable for a given computation version and we can't discard the - // executable which is in the cache because it may be in use. - executable.reset(); - } - return cache_.at(key); -} - -std::shared_ptr CompilationCache::LookUp( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const { - tensorflow::mutex_lock lock(mutex_); - - CacheKey key = BuildKey(versioned_handle, module_config); - VLOG(2) << "looking up cache key: " << key; - if (cache_.count(key) == 0) { - VLOG(2) << "cache key not found: " << key; - return nullptr; - } else { - std::shared_ptr result = cache_.at(key); - VLOG(2) << "hit executable with module config: " - << result->module_config().compilation_cache_key(); - return result; - } -} - -CompilationCache::CacheKey CompilationCache::BuildKey( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const { - // The computation shape is represented entirely by its ProgramShape member, - // so just serialize the proto as part of the key. - return tensorflow::strings::StrCat(versioned_handle.handle.handle(), "::", - versioned_handle.version, "::", - module_config.compilation_cache_key()); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/compilation_cache.h b/tensorflow/compiler/xla/service/compilation_cache.h deleted file mode 100644 index 09989726ae6629aa65cb1dd84c16408a75019fa5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/compilation_cache.h +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ - -#include -#include -#include - -#include "tensorflow/compiler/xla/service/executable.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" - -namespace xla { - -// A cache which stores Executables indexed by computation handle and version. -class CompilationCache { - public: - CompilationCache() {} - - // Insert the given Executable into the cache. Return a bare Executable - // pointer for the caller to use. Note: the returned pointer will *not* be the - // same as the given unique pointer if the computation already exists in the - // cache. See comments in the .cc implementation for details of this case. - // - // module_config is provided by the caller, instead of being taken from the - // executable, so that we can insert keys into the compilation cache that are - // devoid of layout (where XLA gets to choose what layout to compile). - // - // A shared_ptr is returned so the caller can keep the Executable from being - // destructed in the event that the Executable is evicted from the - // computation cache (and the cache's shared_ptr to the Executable is - // destructed). - std::shared_ptr Insert(std::unique_ptr executable, - const HloModuleConfig& module_config); - - // Lookup the Executable for the specified versioned computation in the cache. - // Return a shared_ptr to the Executable if it exists in the cache. Return - // nullptr otherwise. - std::shared_ptr LookUp( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const; - - protected: - mutable tensorflow::mutex mutex_; - - // Map from versioned handle with program layout to Executable built - // for that computation version and program layout. - using CacheKey = string; - - CacheKey BuildKey(const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const; - std::map> cache_ GUARDED_BY(mutex_); - - private: - TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index dab73596e1639eed62151197048ee8d29570b20a..7426672a7a2a9102bd5ea98bd51092982e1e09b4 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -37,7 +36,7 @@ limitations under the License. namespace xla { /* static */ StatusOr> -CompileOnlyService::NewService(perftools::gputools::Platform* platform) { +CompileOnlyService::NewService(se::Platform* platform) { ServiceOptions default_options; default_options.set_platform(platform); return NewService(default_options); @@ -45,7 +44,7 @@ CompileOnlyService::NewService(perftools::gputools::Platform* platform) { /* static */ StatusOr> CompileOnlyService::NewService(const ServiceOptions& options) { - perftools::gputools::Platform* platform = options.platform(); + se::Platform* platform = options.platform(); if (platform == nullptr) { TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } @@ -63,55 +62,47 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options, StatusOr>> CompileOnlyService::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options, + std::unique_ptr* metadata) { std::vector> hlo_modules; - for (const AotComputationInstance& instance : computations) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(instance.computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); + for (const AotXlaComputationInstance& instance : computations) { + TF_RET_CHECK(instance.computation.has_program_shape()); - // TODO(b/63773457): Track DebugOptions in AotCompilationOptions. - DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + const DebugOptions& debug_options = options.debug_options(); - // Dump computation proto state if flag is set. + // Dump computation proto if flag is set. const string& directory_path = debug_options.xla_dump_computations_to(); if (!directory_path.empty()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr session_module, - computation_tracker_.SnapshotComputation(versioned_handle.handle)); + HloSnapshot hlo_snapshot; + *hlo_snapshot.mutable_hlo()->mutable_hlo_module() = instance.computation; string filename = tensorflow::strings::StrCat( - "computation_", versioned_handle.handle.handle(), "__", - session_module->entry().name(), "__version_", - versioned_handle.version); + "computation_", instance.computation.id(), "__", + instance.computation.entry_computation_name()); const string& per_host_path = tensorflow::io::JoinPath( directory_path, tensorflow::port::Hostname()); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(per_host_path, filename, - *session_module)); + TF_RETURN_IF_ERROR( + Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot)); } - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - + const auto& program_shape = instance.computation.program_shape(); ExecutionOptions execution_options; *execution_options.mutable_debug_options() = debug_options; TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, instance.argument_layouts, - &execution_options, *user_computation)); + CreateModuleConfig(program_shape, instance.argument_layouts, + &execution_options)); - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, - computation_tracker_.BuildHloModule( - versioned_handle, *module_config, - /*include_unreachable_instructions=*/true)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + HloModule::CreateFromProto(instance.computation, *module_config)); TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); hlo_modules.push_back(std::move(hlo_module)); } - return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); + return compiler_->CompileAheadOfTime(std::move(hlo_modules), options, + metadata); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index 9859941c6c17460939e5b6817f1c7c415e63443c..1ac950bdd66bd034dfdafa8598ec506221e99c2f 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -34,70 +34,56 @@ class CompileOnlyService : public Service { // platform that the service should target. If platform is null then the // default platform is used. static StatusOr> NewService( - perftools::gputools::Platform* platform); + se::Platform* platform); static StatusOr> NewService( const ServiceOptions& options); - // A description of a computation to compile using CompileAheadOfTime. - struct AotComputationInstance { - ComputationHandle computation; + // A description of a xla computation to compile using CompileAheadOfTime. + struct AotXlaComputationInstance { + HloModuleProto computation; std::vector argument_layouts; const Shape* result_layout = nullptr; }; - // Compiles a list of computations for ahead-of-time execution. This is + // Compiles a list of xla computations for ahead-of-time execution. This is // intended for use in static compilation. See // |CompileOnlyClient::CompileAheadOfTime| for additional details. StatusOr>> CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& Options); + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options); - // Override Service methods that require or imply the existence of an - // execute backend. Note that this does not include TransferToClient, as - // computing constants produces global data that we may wish to transfer. - tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override { + StatusOr>> + CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options, + std::unique_ptr* metadata); + + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); } - tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override { + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); } - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override { + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override { + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override { + Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override { + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); } diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index e2e9d2a0c048fec6c6ffbeef1223ae0e6aef50d1..0dceed853dcbae211657f00433866cfe10c51fc7 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -23,26 +23,42 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" -namespace se = ::perftools::gputools; - namespace xla { /* static */ tensorflow::mutex Compiler::platform_compiler_mutex_( tensorflow::LINKER_INITIALIZED); -/* static */ std::map* +std::vector> +Compiler::ComputeBackendConfigs(const HloInstruction& hlo, + se::StreamExecutor* executor) const { + CHECK(executor != nullptr); + return {}; +} + +// Define a default version where metadata is not used. +StatusOr>> +Compiler::CompileAheadOfTime( + std::vector> modules, + const AotCompilationOptions& options, + std::unique_ptr* metadata) { + if (metadata != nullptr) { + return Unimplemented( + "Populating AotCompilationMetadata is not implemented on this " + "compiler."); + } + return CompileAheadOfTime(std::move(modules), options); +} + +/* static */ std::map* Compiler::GetPlatformCompilerFactories() { - static auto* r = - new std::map; + static auto* r = new std::map; return r; } /* static */ -std::map>* +std::map>* Compiler::GetPlatformCompilers() { - static auto* r = new std::map>; + static auto* r = new std::map>; return r; } @@ -86,4 +102,7 @@ Compiler::GetPlatformCompilers() { return compilers->at(platform->id()).get(); } +AotCompilationOptions::AotCompilationOptions() + : debug_options_(legacy_flags::GetDebugOptionsFromFlags()) {} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 74fd24edf88d44b2dfdc87556b0af43987e69e08..d1144f97bb2ab29d3d18f3b3f65a38af46e68dd1 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -24,8 +24,11 @@ limitations under the License. #include #include #include +#include +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" @@ -33,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -70,7 +74,7 @@ class AotCompilationOptions { virtual ~AotCompilationOptions() = default; // Returns the ID of the platform to which these options apply. - virtual perftools::gputools::Platform::Id PlatformId() const = 0; + virtual se::Platform::Id PlatformId() const = 0; // Optional allocator that may be used for allocating temp space on the device // during compilation. @@ -79,11 +83,28 @@ class AotCompilationOptions { device_allocator_ = device_allocator; } + const DebugOptions& debug_options() const { return debug_options_; } + DebugOptions* mutable_debug_options() { return &debug_options_; } + protected: - AotCompilationOptions() = default; + AotCompilationOptions(); private: DeviceMemoryAllocator* device_allocator_ = nullptr; + DebugOptions debug_options_; +}; + +// Abstract superclass describing metadata produced during ahead-of-time +// compilation. +class AotCompilationMetadata { + public: + AotCompilationMetadata(const AotCompilationMetadata&) = delete; + AotCompilationMetadata& operator=(AotCompilationMetadata const&) = delete; + + virtual ~AotCompilationMetadata() = default; + + protected: + AotCompilationMetadata() = default; }; // Abstract compiler interface that is subclassed for compilation on a @@ -105,7 +126,7 @@ class Compiler { virtual ~Compiler() {} // Returns the ID of the platform that this compiler targets. - virtual perftools::gputools::Platform::Id PlatformId() const = 0; + virtual se::Platform::Id PlatformId() const = 0; // Runs Hlo passes to optimize the given Hlo module, returns the optimized // module. @@ -116,14 +137,13 @@ class Compiler { // algorithm over those buffers, to see which variant is fastest. Any space // allocated should be deallocated before this function returns. virtual StatusOr> RunHloPasses( - std::unique_ptr module, - perftools::gputools::StreamExecutor* executor, + std::unique_ptr module, se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) = 0; // Compiles the HLO module for execution on a device given by the executor, // and returns an executable object or an error status. No HLO passes are // applied to module. Generally a module should be passed through RunHloPasses - // prior to calling this method because the some HLO passes are required for + // prior to calling this method because some HLO passes are required for // correctness. Takes ownership of the HLO module and is free to transform it. // // The compiler may optionally specialize to the individual device @@ -133,8 +153,7 @@ class Compiler { // // Use the overload below to compile computations that run in parallel. virtual StatusOr> RunBackend( - std::unique_ptr module, - perftools::gputools::StreamExecutor* executor, + std::unique_ptr module, se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) = 0; // Compiles a set of HLO modules that can run in parallel, potentially @@ -147,16 +166,32 @@ class Compiler { // modules to RunHloPasses and RunBackends. virtual StatusOr>> Compile( std::vector> modules, - std::vector> - stream_exec, + std::vector> stream_exec, DeviceMemoryAllocator* device_allocator) = 0; + // Returns the backend configurations that the backend will consider for the + // given HLO. Returns no configurations if the backend does not support + // configurations for the given HLO. + // + // The stream executor is passed in to provide information about the hardware + // that the backend configurations would be targeting. + virtual std::vector> + ComputeBackendConfigs(const HloInstruction& hlo, + se::StreamExecutor* executor) const; + // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. virtual StatusOr>> CompileAheadOfTime(std::vector> modules, const AotCompilationOptions& options) = 0; + // Similar to CompileAheadOfTime above but AotCompilationMetadata + // has an argument that can be populated during compilation. + virtual StatusOr>> + CompileAheadOfTime(std::vector> modules, + const AotCompilationOptions& options, + std::unique_ptr* metadata); + ///// // The Compiler class also serves as a point to register compiler objects // for the various platforms. @@ -167,14 +202,12 @@ class Compiler { // be a singleton, so no ownership is transferred. // // Precondition: a platform kind must not be registered more than once. - static void RegisterCompilerFactory( - perftools::gputools::Platform::Id platform_id, - CompilerFactory compiler_factory); + static void RegisterCompilerFactory(se::Platform::Id platform_id, + CompilerFactory compiler_factory); // Returns the compiler singleton pointer if it is available for the given // platform, or an error status if it is not. - static StatusOr GetForPlatform( - const perftools::gputools::Platform* platform); + static StatusOr GetForPlatform(const se::Platform* platform); // Returns a function that computes the size in bytes of the logical // buffer that contains a shape. @@ -182,9 +215,9 @@ class Compiler { // Returns a function that computes the size in bytes of a given // logical buffer. - std::function BufferSizeBytesFunction() { + std::function BufferSizeBytesFunction() { HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction(); - return [shape_size](const LogicalBuffer& buffer) { + return [shape_size](const BufferValue& buffer) { return shape_size(buffer.shape()); }; } @@ -194,12 +227,12 @@ class Compiler { static tensorflow::mutex platform_compiler_mutex_; // Map from platform kind to compiler factory. - static std::map* + static std::map* GetPlatformCompilerFactories(); // Map from platform kind to compiler instance, if we made one already (based // on the factories above). - static std::map>* + static std::map>* GetPlatformCompilers(); }; diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index d2d4f14fcec35f5b51a2670a646154ce8bb9bfc1..cb61f3da39fb8eef69fd81066d87a1da91a62935 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -23,12 +23,15 @@ limitations under the License. namespace xla { -ComputationLayout::ComputationLayout(const ProgramShape& program_shape) +ComputationLayout::ComputationLayout(const ProgramShape& program_shape, + bool ignore_layouts) : result_layout_(program_shape.result()) { for (auto& shape : program_shape.parameters()) { parameter_layouts_.emplace_back(shape); } - SetToDefaultLayout(); + if (ignore_layouts) { + SetToDefaultLayout(); + } } void ComputationLayout::SetToDefaultLayout() { diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h index 80e102411c7885669947d89f378b1ec61e3e4e96..6975f387b4864bf28ea0ad23d7d4602b5b346e08 100644 --- a/tensorflow/compiler/xla/service/computation_layout.h +++ b/tensorflow/compiler/xla/service/computation_layout.h @@ -32,10 +32,20 @@ namespace xla { // mutable layouts. class ComputationLayout { public: + // Creates a new ComputationLayout with the given result layout. + explicit ComputationLayout(ShapeLayout result_layout) + : result_layout_(std::move(result_layout)) {} + // Constructs a ComputationLayout from a ProgramShape. The layouts of the // parameters and results are set to the default layout. Layouts in the - // ProgramShape are ignored. - explicit ComputationLayout(const ProgramShape& program_shape); + // ProgramShape are ignored if ignore_layouts is true. + explicit ComputationLayout(const ProgramShape& program_shape, + bool ignore_layouts = true); + + // Adds a new parameter layout to the computation layout. + void add_parameter_layout(ShapeLayout shape_layout) { + parameter_layouts_.push_back(std::move(shape_layout)); + } // Returns the layout of a particular parameter. const ShapeLayout& parameter_layout(int64 param_no) const { diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index 657fba6b6231104bf47f9dec80f7cd36a0ba3efd..7c1bacff92b231661477b9931a3066fd91110445 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -32,8 +32,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { @@ -132,11 +130,9 @@ StatusOr ComputationPlacer::AssignDevices( ComputationPlacer::platform_computation_placer_mutex_( tensorflow::LINKER_INITIALIZED); -/* static */ std::map* +/* static */ std::map* ComputationPlacer::GetPlatformComputationPlacers() { - static auto* r = - new std::map; + static auto* r = new std::map; return r; } @@ -147,10 +143,10 @@ static std::unique_ptr CreateComputationPlacer() { } static bool InitModule() { - xla::ComputationPlacer::RegisterComputationPlacer(se::host::kHostPlatformId, - &CreateComputationPlacer); - xla::ComputationPlacer::RegisterComputationPlacer(se::cuda::kCudaPlatformId, - &CreateComputationPlacer); + xla::ComputationPlacer::RegisterComputationPlacer( + stream_executor::host::kHostPlatformId, &CreateComputationPlacer); + xla::ComputationPlacer::RegisterComputationPlacer( + stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index 737ccabaa7a61931b6e2787f75b02857562d4820..737d00e93ecb51a9bd544bbcbe99d93374d108fb 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -80,13 +80,13 @@ class ComputationPlacer { // Registers a computation placer creation function for a particular platform. static void RegisterComputationPlacer( - perftools::gputools::Platform::Id platform_id, + se::Platform::Id platform_id, ComputationPlacerCreationFunction creation_function); // Returns the computation placer singleton pointer if it is available for the // given platform, or an error status if it is not. static StatusOr GetForPlatform( - const perftools::gputools::Platform* platform); + const se::Platform* platform); private: // The mutex that guards the platform-to-computation placer map. @@ -101,10 +101,9 @@ class ComputationPlacer { }; // Map from platform kind to computation placer singleton. - static std::map* - GetPlatformComputationPlacers(); + static std::map* GetPlatformComputationPlacers(); - perftools::gputools::Platform::Id platform_id_; + se::Platform::Id platform_id_; TF_DISALLOW_COPY_AND_ASSIGN(ComputationPlacer); }; diff --git a/tensorflow/compiler/xla/service/computation_tracker.cc b/tensorflow/compiler/xla/service/computation_tracker.cc deleted file mode 100644 index 70e25eebdb068db893e24aec0f72d09090ac7027..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/computation_tracker.cc +++ /dev/null @@ -1,256 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/computation_tracker.h" - -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" - -using ::tensorflow::strings::Appendf; - -namespace xla { - -ComputationTracker::ComputationTracker() : next_computation_(1) {} - -ComputationHandle ComputationTracker::NewComputation( - const string& computation_name) { - tensorflow::mutex_lock lock(computation_mutex_); - ComputationHandle computation_handle; - int64 handle_value = next_computation_++; - computation_handle.set_handle(handle_value); - opaque_to_computation_[handle_value] = - MakeUnique(computation_name, computation_handle); - return computation_handle; -} - -StatusOr ComputationTracker::LoadSessionModule( - const SessionModule& session_module) { - tensorflow::mutex_lock lock(computation_mutex_); - - // For each embedded computation, create a new computation based on its - // serialized data, and place the mapping from the old computation handle to - // the new computation handle. - - // Build a mapping from old embedded computation handles to new computation - // handles. We build the ID mapping first since the embedded computations are - // in no particular order and may refer to each other. - std::map old_to_new; - for (const SessionComputation& computation : - session_module.embedded_computations()) { - const int64 old_handle = computation.computation_handle().handle(); - if (!old_to_new.emplace(old_handle, AllocateHandle()).second) { - return InvalidArgument("Duplicate embedded computation handle %lld", - old_handle); - } - } - - // Create a new computation from each serialized embedded computation. - for (const SessionComputation& computation : - session_module.embedded_computations()) { - const int64 old_handle = computation.computation_handle().handle(); - const ComputationHandle& new_handle = old_to_new[old_handle]; - TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()], - UserComputation::MakeWithRemapping( - computation, new_handle, old_to_new)); - } - - // Finally, place the entry computation in the tracker with all of the - // remappings populated from the above. - const int64 old_handle = session_module.entry().computation_handle().handle(); - TF_ASSIGN_OR_RETURN( - old_to_new[old_handle], - LoadSessionComputation(session_module.entry(), &old_to_new)); - return old_to_new[old_handle]; -} - -StatusOr> -ComputationTracker::SnapshotComputation(const ComputationHandle& computation) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, Resolve(computation)); - const VersionedComputationHandle entry_versioned_handle = - user_computation->GetVersionedHandle(); - std::set visited; - std::list post_order; - { - tensorflow::mutex_lock lock(computation_mutex_); - ComputeComputationPostOrder(entry_versioned_handle, &visited, &post_order); - } - auto session_module = MakeUnique(); - *session_module->mutable_entry() = - Resolve(entry_versioned_handle.handle) - .ValueOrDie() - ->CloneSessionComputation(entry_versioned_handle.version); - for (auto it = ++post_order.rbegin(); it != post_order.rend(); ++it) { - *session_module->add_embedded_computations() = - Resolve(it->handle).ValueOrDie()->CloneSessionComputation(it->version); - } - return std::move(session_module); -} - -StatusOr ComputationTracker::Resolve( - const ComputationHandle& computation) const { - tensorflow::mutex_lock lock(computation_mutex_); - return ResolveInternal(computation); -} - -ComputationHandle ComputationTracker::AllocateHandle() { - int64 handle_value = next_computation_++; - ComputationHandle result; - result.set_handle(handle_value); - return result; -} - -StatusOr ComputationTracker::LoadSessionComputation( - const SessionComputation& session_computation, - std::map* old_to_new) { - TF_RET_CHECK(old_to_new != nullptr); - const ComputationHandle new_handle = AllocateHandle(); - (*old_to_new)[session_computation.computation_handle().handle()] = new_handle; - TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()], - UserComputation::MakeWithRemapping( - session_computation, new_handle, *old_to_new)); - return new_handle; -} - -StatusOr ComputationTracker::ResolveInternal( - const ComputationHandle& computation) const { - auto it = opaque_to_computation_.find(computation.handle()); - if (it == opaque_to_computation_.end()) { - return NotFound("computation handle not found: %lld", computation.handle()); - } - UserComputation* user_computation = it->second.get(); - return user_computation; -} - -void ComputationTracker::ComputeComputationPostOrder( - const VersionedComputationHandle& versioned_handle, - std::set* visited, - std::list* post_order) const { - if (visited->count(versioned_handle) > 0) { - CHECK_EQ(1, visited->count(versioned_handle)); - return; - } - - UserComputation* computation = - ResolveInternal(versioned_handle.handle).ValueOrDie(); - std::vector embedded_handles = - computation->GetEmbeddedComputations(versioned_handle.version); - - for (const auto& embedded_handle : embedded_handles) { - ComputeComputationPostOrder(embedded_handle, visited, post_order); - } - - visited->insert(versioned_handle); - post_order->push_back(versioned_handle); -} - -StatusOr> ComputationTracker::BuildHloModule( - const VersionedComputationHandle& entry_handle, - const HloModuleConfig& config, - bool include_unreachable_instructions) const { - tensorflow::mutex_lock lock(computation_mutex_); - - VLOG(1) << "BuildHloModule(" << entry_handle - << ", include_unreachable_instructions=" - << include_unreachable_instructions << ")"; - XLA_VLOG_LINES(1, ToStringInternal()); - - TF_ASSIGN_OR_RETURN(UserComputation * entry_computation, - ResolveInternal(entry_handle.handle)); - - // Build a topological sort of the entry and any embedded computations as a - // list. The root of the computation will be the last element in the list. - std::set visited; - std::list post_order; - ComputeComputationPostOrder(entry_handle, &visited, &post_order); - - // Map from ComputationHandle value and computation version to HloComputation. - std::map hlo_computations; - - // The resolver lambda resolves VersionedHandles to embedded - // HloComputation*. This is required by UserComputation::BuildHloComputation - // when lowering calling operations (map, reduce etc). - auto resolver = [&hlo_computations]( - const VersionedComputationHandle& versioned_handle) -> HloComputation* { - CHECK_GT(hlo_computations.count(versioned_handle), 0); - return hlo_computations.at(versioned_handle); - }; - - // Print the post-order list for this entry computation. - if (VLOG_IS_ON(2)) { - VLOG(2) << "Visiting UserComputations in post order:"; - for (const VersionedComputationHandle& versioned_handle : post_order) { - VLOG(2) << " " << versioned_handle; - } - } - - string module_name = - tensorflow::strings::StrCat(entry_computation->name(), "_module"); - auto module = MakeUnique(module_name, entry_handle, config); - for (auto versioned_handle : post_order) { - UserComputation* computation = - ResolveInternal(versioned_handle.handle).ValueOrDie(); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_computation, - computation->BuildHloComputation(versioned_handle.version, resolver, - config.debug_options(), - include_unreachable_instructions)); - - // Add the newly created computation to VersionedHandle-to-HloComputation - // map. - DCHECK_EQ(0, hlo_computations.count(versioned_handle)); - hlo_computations[versioned_handle] = hlo_computation.get(); - - if (computation == entry_computation) { - module->AddEntryComputation(std::move(hlo_computation)); - } else { - module->AddEmbeddedComputation(std::move(hlo_computation)); - } - } - - return std::move(module); -} - -string ComputationTracker::ToString() const { - tensorflow::mutex_lock lock(computation_mutex_); - return ToStringInternal(); -} - -string ComputationTracker::ToStringInternal() const { - string out; - Appendf(&out, "ComputationTracker(%p):\n", this); - for (const auto& handle_computation : opaque_to_computation_) { - int64 handle = handle_computation.first; - const std::unique_ptr& computation = - handle_computation.second; - Appendf(&out, " %4lld : %s \"%s\"\n", handle, - computation->GetVersionedHandle().ToString().c_str(), - computation->name().c_str()); - } - return out; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_tracker.h b/tensorflow/compiler/xla/service/computation_tracker.h deleted file mode 100644 index d42d66adefe7faa2751da4cd80b392a38917ce70..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/computation_tracker.h +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Tracks computations for the XLA service; computations can be registered -// with a UserComputation instance and can be resolved from a handle for later -// use. -// -// This class is also capable of serializing/deserializing computations that it -// tracks (and to serialize properly you need to serialize all referred-to -// computations as well). -class ComputationTracker { - public: - ComputationTracker(); - - // Creates a new UserComputation object and returns the corresponding - // ComputationHandle for it. - // - // Precondition: user_computation is not already present in the map. - ComputationHandle NewComputation(const string& computation_name); - - // Restores session data for a computation that has been serialized, and - // allocates a new computation handle for it. - StatusOr LoadSessionModule( - const SessionModule& session_module); - - // Snapshots a computation (referenced by the provided handle) at its latest - // version, returning a module where it is the entry, and any referred-to - // computations are entrained as "embedded" (non-entry) computations. - StatusOr> SnapshotComputation( - const ComputationHandle& computation); - - // Resolves a ComputationHandle to a UserComputation that is present in the - // map. - StatusOr Resolve( - const ComputationHandle& computation) const; - - // Builds an HLO module using the specified computation as the entry. The - // module will include the entry computation as well as all computations which - // are called directly or indirectly from the entry computation via operations - // like "map". config is the HLO module configuration to use for the - // constructed module. - // If include_unreachable_instructions is true, then instructions - // which are not reachable from the root are lowered into HloInstructions - // including unreachable parameters. This ensures the entry HloComputation has - // the same program shape (ProgramShape) as the entry UserComputation. - StatusOr> BuildHloModule( - const VersionedComputationHandle& entry_handle, - const HloModuleConfig& config, - bool include_unreachable_instructions = true) const; - - string ToString() const; - - private: - // Bumps the next_computation_ number and returns the allocated number wrapped - // in a ComputationHandle. - ComputationHandle AllocateHandle() - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Loads a session computation into a UserComputation, registers it, and - // returns the computation handle of the registered computation. If old_to_new - // is provided, it is used for remapping references to computations present in - // session_computation. - // - // old_to_new will be updated with the mapping from session_computation's old - // handle to the returned handle value, and may not be null. - StatusOr LoadSessionComputation( - const SessionComputation& session_computation, - std::map* old_to_new) - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Internal implementation of Resolve method which requires, but does not - // acquire the mutex. - StatusOr ResolveInternal( - const ComputationHandle& computation) const - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Builds a post order sort of a computation ("entry") and all of its embedded - // computations including all transitively embedded computations. An embedded - // computation (the callee) will always appear in the sort before the - // computation which calls the embedded computation (the caller). Necessarily, - // the entry computation is the last element in the sort. visited and - // post_order should be empty when calling. post_order contains the post order - // sort when the function return. - void ComputeComputationPostOrder( - const VersionedComputationHandle& versioned_handle, - std::set* visited, - std::list* post_order) const - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - string ToStringInternal() const EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Guards the computation mapping. Marked mutable so that the Resolve method - // can remain const; Resolve does't really modify the tracker in any way, but - // it has to lock the mutex for safety. - mutable tensorflow::mutex computation_mutex_; - - // The next sequence number to assign to a computation, guarded by the same - // mutex as the mapping as they'll be mutated at the same time. - int64 next_computation_ GUARDED_BY(computation_mutex_); - - // Mapping from ComputationHandle value to the corresponding registered - // UserComputation object. - std::map> opaque_to_computation_ - GUARDED_BY(computation_mutex_); - - TF_DISALLOW_COPY_AND_ASSIGN(ComputationTracker); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc new file mode 100644 index 0000000000000000000000000000000000000000..e9ec796121fff223474c3e81a5e973cc37f8caec --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -0,0 +1,106 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/conditional_simplifier.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { + +// Tries to replace a conditional with a call operation of the corresponding +// computation. If the given conditional has a constant predicate, tries to +// replace it with a call to its true/false computation as appropriate and then +// inline that computation. +// +// Returns true if it made a change to the graph. +static StatusOr TryRemoveConditional(HloInstruction* conditional) { + CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); + // Do not remove conditionals that contain side-effecting instructions or + // have control predecessors/successors in either true/false computation. + if (!conditional->parent()->IsRemovable(conditional) || + conditional->HasSideEffect()) { + VLOG(2) << "Not attempting to remove conditional as it is not removable or " + "has side effect: " + << conditional->ToShortString(); + return false; + } + + if (conditional->operand(0)->opcode() != HloOpcode::kConstant) { + VLOG(2) << "Not attempting to remove conditional as its predicate is not a " + "compile-time constant: " + << conditional->ToShortString(); + return false; + } + + auto computation = conditional->parent(); + HloInstruction* call_op; + if (conditional->operand(0)->literal().Get({})) { + call_op = computation->AddInstruction(HloInstruction::CreateCall( + conditional->shape(), {conditional->mutable_operand(1)}, + conditional->true_computation())); + } else { + call_op = computation->AddInstruction(HloInstruction::CreateCall( + conditional->shape(), {conditional->mutable_operand(2)}, + conditional->false_computation())); + } + conditional->SetupDerivedInstruction(call_op); + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op)); + TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status()); + + return true; +} + +StatusOr ConditionalSimplifier::Run(HloModule* module) { + XLA_VLOG_LINES( + 3, "ConditionalSimplifier::Run(), before:\n" + module->ToString()); + bool changed = false; + + // Gather all the conditional ops in our module. We do this ahead of time so + // we don't have to worry about mutating the lists of computations or + // instructions as we iterate. + std::vector conditional_ops; + for (auto* comp : module->computations()) { + for (auto* instr : comp->instructions()) { + if (instr->opcode() == HloOpcode::kConditional) { + conditional_ops.push_back(instr); + } + } + } + + for (HloInstruction* conditional_op : conditional_ops) { + TF_ASSIGN_OR_RETURN(bool result, TryRemoveConditional(conditional_op)); + changed |= result; + } + + XLA_VLOG_LINES(3, + "ConditionalSimplifier::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h new file mode 100644 index 0000000000000000000000000000000000000000..063261e26d06e21a297e8e3c405898a17221b7ca --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_simplifier.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { + +// HLO pass that removes kConditional with a constant predicate, replacing them +// with their true or false computation as appropriate. +class ConditionalSimplifier : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { + return "simplify-conditional"; + } + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..868348547d9f5cbdc7576c7fc0697d72c3a3e557 --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -0,0 +1,153 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/conditional_simplifier.h" + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class ConditionalSimplifierTest : public HloVerifiedTestBase { + public: + // Makes a computation that contains a conditional with constant predicate. + HloComputation* MakeConditional(HloModule* module); +}; + +HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) { + HloComputation::Builder builder(TestName()); + + // true_computation returns param+1. + HloComputation* true_computation; + { + HloComputation::Builder true_computation_builder(TestName() + + ".true_computation"); + auto param = + true_computation_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(S32, {}), "param")); + auto one = true_computation_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))); + + true_computation_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, one)); + + true_computation = + module->AddEmbeddedComputation(true_computation_builder.Build()); + } + + // false_computation returns param+42. + HloComputation* false_computation; + { + HloComputation::Builder false_computation_builder(TestName() + + ".false_computation"); + auto param = false_computation_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), + "param")); + auto forty_two = false_computation_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + false_computation_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, forty_two)); + false_computation = + module->AddEmbeddedComputation(false_computation_builder.Build()); + } + + auto false_instrn = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + auto false_param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(S32, {}), "false_param")); + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))); + + builder.AddInstruction(HloInstruction::CreateConditional( + ShapeUtil::MakeShape(S32, {}), false_instrn, one, true_computation, + false_param, false_computation)); + + return module->AddEntryComputation(builder.Build()); +} + +TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) { + HloComputation* computation = MakeConditional(&module()); + ASSERT_TRUE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Parameter(), op::Constant())); +} + +TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) { + HloComputation* computation = MakeConditional(&module()); + + auto* true_op = computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + TF_ASSERT_OK( + true_op->AddControlDependencyTo(computation->root_instruction())); + + EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { + HloComputation* computation = MakeConditional(&module()); + auto* conditional = computation->root_instruction(); + ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); + + auto* true_computation = conditional->true_computation(); + auto* send = true_computation->AddInstruction(HloInstruction::CreateSend( + true_computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))), + /*channel_id=*/0)); + true_computation->AddInstruction(HloInstruction::CreateSendDone(send)); + EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) { + HloComputation* computation = MakeConditional(&module()); + auto* conditional = computation->root_instruction(); + ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); + + auto* true_computation = conditional->true_computation(); + auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv( + ShapeUtil::MakeShape(F32, {1}), /*channel_id=*/0)); + true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv)); + EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { + HloComputation* computation = MakeConditional(&module()); + auto* conditional = computation->root_instruction(); + ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); + auto* false_computation = conditional->false_computation(); + false_computation->AddInstruction( + HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); + EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index cc195879a6bb490a9b49ad962aa9326cb51d9b0a..33d8338809d4e8c7c4774f062c3dda5494543ca6 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -58,6 +57,46 @@ bool ValueIsReadOnly(const HloValue& value) { return IsConstantValue(value) || IsEntryParameterValue(value); } +// Data structure describing the action which should be taken on parts of a +// computation buffers, with respect to the adding of special case copies. +struct SpecialCaseCopyPolicy { + // Insert a copy if the same buffer is found at multiple indices within the + // output tuple. + bool copy_root_replicated_buffers = false; + // If true, insert a copy if a buffer coming from a constant or a parameter + // is found within the output tuple. + bool copy_parameters_and_constants = false; +}; + +SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node, + HloModule* module, + HloComputation* computation) { + SpecialCaseCopyPolicy policy; + if (computation == module->entry_computation()) { + policy.copy_parameters_and_constants = true; + policy.copy_root_replicated_buffers = true; + } + for (const CallSite& site : node.caller_callsites()) { + // The AddCopiesForConditional() already adds copies, but the copy remover + // removes them, so we re-add them by returning the policy here. But really + // the copy remover should not be removing them. + if (site.instruction()->opcode() == HloOpcode::kConditional) { + policy.copy_parameters_and_constants = true; + policy.copy_root_replicated_buffers = true; + } + } + return policy; +} + +bool ShouldCopyRootValue(const HloValue& value, + const SpecialCaseCopyPolicy& policy) { + if (policy.copy_parameters_and_constants) { + return IsConstantValue(value) || + value.defining_instruction()->opcode() == HloOpcode::kParameter; + } + return false; +} + // Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in // 'indices_to_copy'. Add control edges from the respective kCopy instructions // in deep copy of 'from' to the respective kCopy instruction in the deep copy @@ -282,6 +321,29 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, return Status::OK(); } +// We add copies for all the indices of the true and false computaiton roots, +// in order to resolve interference. We later rely on the CopyRemover to drop +// the unnecessary ones. +Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, + HloInstruction* conditional) { + VLOG(2) << "Adding copies for kConditional instruction " + << conditional->name(); + TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional); + + for (HloComputation* computation : + {conditional->true_computation(), conditional->false_computation()}) { + HloInstruction* root = computation->root_instruction(); + std::vector users = root->users(); + TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, + computation->DeepCopyInstruction(root)); + for (HloInstruction* user : users) { + TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy)); + } + computation->set_root_instruction(deep_copy); + } + return Status::OK(); +} + // Removes any control dependencies to or from the given instruction. Status StripControlDependenciesFrom(HloInstruction* instruction) { while (!instruction->control_successors().empty()) { @@ -309,6 +371,9 @@ Status AddCopiesToResolveInterference(HloModule* module) { for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kWhile) { TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction)); + } else if (instruction->opcode() == HloOpcode::kConditional) { + TF_RETURN_IF_ERROR( + AddCopiesForConditional(*alias_analysis, instruction)); } } } @@ -557,6 +622,7 @@ class CopyRemover { auto is_live_range_before = [this](const ValueNode& a, const ValueNode& b) { + VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value; if (LiveRangeBefore(a, b)) { VLOG(2) << " Live range of " << a.value->ToShortString() << " is before " << b.value->ToShortString(); @@ -571,7 +637,7 @@ class CopyRemover { VLOG(3) << copy->name() << " copies value " << src->value->ToShortString(); VLOG(3) << "Source buffer values: " << ValueListToString(src); - VLOG(3) << "Dest buffer values: " << ValueListToString(src); + VLOG(3) << "Dest buffer values: " << ValueListToString(dest); // A kCopy instruction copies an HLO value from a source buffer and // defines an HLO value in a destination buffer. Most generally, the @@ -747,16 +813,16 @@ class CopyRemover { // updated as copies are removed. bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) { if (a.uses.empty()) { - VLOG(2) << "Empty uses"; + VLOG(2) << "Empty uses for " << *a.value; return ordering_.IsDefinedBefore(*a.value, *b.value); } for (const HloUse* use : a.uses) { - VLOG(2) << "use: " << *use; - VLOG(2) << "is before:" << *b.value; + VLOG(2) << "Checking use " << *use << " against " << *b.value; if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { - VLOG(2) << "Not before"; + VLOG(2) << "Use " << *use << " is NOT before " << *b.value; return false; } + VLOG(2) << "Use " << *use << " is before " << *b.value; } return true; } @@ -892,7 +958,6 @@ Status RemoveUnnecessaryCopies( CopyRemover copy_remover(*alias_analysis, ordering, module); XLA_VLOG_LINES(3, copy_remover.ToString()); - tensorflow::gtl::FlatSet existing_copies; for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kCopy && @@ -901,7 +966,6 @@ Status RemoveUnnecessaryCopies( } } } - return Status::OK(); } @@ -921,7 +985,7 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { // Identify which shape indices of which instructions need to be copied. Store // these results in 'instructions_to_copy'. - std::unordered_map> instructions_to_copy; + HloInstructionMap> instructions_to_copy; auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction, const ShapeIndex& index) { auto it = instructions_to_copy.find(instruction); @@ -957,7 +1021,8 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { } TF_RET_CHECK(node.context() == CallContext::kSequential); - const bool is_entry = computation == module->entry_computation(); + SpecialCaseCopyPolicy policy = + GetSpecialCaseCopyPolicy(node, module, computation); HloInstruction* root = computation->root_instruction(); // Mark nondistinct/ambiguous indices. @@ -970,27 +1035,26 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { for (const HloBuffer* buffer : buffers_at_index) { buffer_seen_before |= !seen.insert(buffer).second; } - if (buffers_at_index.size() > 1 || (buffer_seen_before && is_entry)) { - VLOG(2) << "Index " << index << " of root of computation " + if (buffers_at_index.size() > 1 || + (buffer_seen_before && policy.copy_root_replicated_buffers)) { + VLOG(2) << "Index " << index << " of computation " << computation->name() << " (" << root->name() << ") has ambiguous or non-distinct buffer. Copying."; add_index_to_copy(root, index); } }); - // For entry instructions, mark any parameter or constant values. - if (is_entry) { - for (const auto& pair : - alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) { - const ShapeIndex& index = pair.first; - const HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (ValueIsReadOnly(*value)) { - VLOG(2) << "Root of entry computation (" << root->name() - << ") has constant or entry parameter value at index " - << index << ". Copying."; - add_index_to_copy(root, index); - } + for (const auto& pair : + alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) { + const ShapeIndex& index = pair.first; + const HloValueSet& value_set = pair.second; + for (const HloValue* value : value_set.values()) { + if (ShouldCopyRootValue(*value, policy)) { + VLOG(2) << "Root of (" << root->name() << ") of computation(" + << computation->name() + << ") has constant or parameter value at index " << index + << ". Copying."; + add_index_to_copy(root, index); } } } @@ -1012,7 +1076,6 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { instruction->parent()->set_root_instruction(deep_copy); } } - return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 153f062d015e49db11c4c9ae0a2a61e76c020f02..684fff8a6fd4fd045a0de9df9660b887f32bdf40 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1636,8 +1636,7 @@ void BM_SequentialWhiles(int num_iters, int num_whiles) { for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), - config); + HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_SequentialWhiles"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1677,8 +1676,7 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) { for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), - config); + HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_ParallelWhiles"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1750,8 +1748,7 @@ void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) { std::vector tuple_params(num_tuple_inputs); for (int i = 0; i < num_iters; ++i) { auto builder = HloComputation::Builder("BM_ParallelWhiles"); - HloModule module("BM_ManyElementTuple", VersionedComputationHandle(), - config); + HloModule module("BM_ManyElementTuple", config); for (int j = 0; j < num_tuple_inputs; ++j) { tuple_params[j] = builder.AddInstruction( HloInstruction::CreateParameter(j, element_shape, "")); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index c13a0b1cdf0b5be0b69db98b2b9587f30ca4c304..b703be0f39e2032bc58479f0b957f9d8b01a77c3 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -18,6 +18,10 @@ load(":build_defs.bzl", "runtime_copts") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") +load( + "//third_party/mkl:build_defs.bzl", + "if_mkl", +) # Filegroup used to collect source files for dependency checking. filegroup( @@ -85,12 +89,10 @@ cc_library( ":cpu_instruction_fusion", ":cpu_layout_assignment", ":cpu_options", - ":cpu_parallelization_preparation", ":disassembler", ":dot_op_emitter", ":ir_emission_utils", ":ir_emitter", - ":parallel_cpu_executable", ":parallel_task_assignment", ":simple_orc_jit", "//tensorflow/compiler/xla:literal_util", @@ -101,13 +103,16 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:batch_dot_simplification", "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", + "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", @@ -121,12 +126,14 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_scheduling", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:indexed_array_analysis", "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_constant_sinking", "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", @@ -144,7 +151,14 @@ cc_library( "@llvm//:target", # fixdeps: keep "@llvm//:x86_code_gen", # fixdeps: keep "@llvm//:x86_disassembler", # fixdeps: keep - ], + ] + select({ + "@org_tensorflow//tensorflow:linux_ppc64le": [ + "@llvm//:powerpc_disassembler", + "@llvm//:powerpc_code_gen", + ], + "//conditions:default": [ + ], + }), alwayslink = True, # Contains compiler registration ) @@ -163,11 +177,15 @@ cc_library( ":disassembler", ":external_constant_pool", ":orc_jit_memory_mapper", + ":runtime_fp16", ":runtime_conv2d", + ":runtime_conv2d_mkl", ":runtime_fft", ":runtime_fork_join", ":runtime_matmul", + ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", + ":runtime_single_threaded_fft", ":runtime_single_threaded_matmul", "@llvm//:execution_engine", "@llvm//:core", @@ -182,6 +200,20 @@ cc_library( ] + ORC_JIT_MEMORY_MAPPER_TARGETS, ) +cc_library( + name = "runtime_fp16", + srcs = [ + "runtime_fp16.cc", + ], + hdrs = [ + "runtime_fp16.h", + ], + copts = runtime_copts(), + deps = [ + "//tensorflow/core:framework_lite", + ], +) + cc_library( name = "cpu_executable", srcs = ["cpu_executable.cc"], @@ -210,35 +242,6 @@ cc_library( ], ) -cc_library( - name = "parallel_cpu_executable", - srcs = ["parallel_cpu_executable.cc"], - hdrs = [ - "parallel_cpu_executable.h", - ], - deps = [ - ":cpu_runtime", - ":shape_partition", - ":simple_orc_jit", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:buffer_assignment", - "//tensorflow/compiler/xla/service:device_memory_allocator", - "//tensorflow/compiler/xla/service:executable", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_execution_profile", - "//tensorflow/compiler/xla/service:logical_buffer", - "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", - "@llvm//:orc_jit", - ], -) - cc_library( name = "ir_emitter", srcs = [ @@ -302,6 +305,15 @@ cc_library( ], ) +cc_library( + name = "target_machine_features_fake", + testonly = 1, + hdrs = ["target_machine_features_fake.h"], + deps = [ + ":target_machine_features", + ], +) + cc_library( name = "ir_function", srcs = ["ir_function.cc"], @@ -343,6 +355,7 @@ cc_library( deps = [ ":cpu_options", ":cpu_runtime", + ":ir_emission_utils", ":target_machine_features", ":vector_support_library", "//tensorflow/compiler/xla:shape_util", @@ -372,10 +385,10 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -415,7 +428,6 @@ cc_library( "//tensorflow/core:lib", "@llvm//:analysis", "@llvm//:core", - "@llvm//:execution_engine", "@llvm//:ipo", "@llvm//:mc", "@llvm//:object", @@ -479,6 +491,27 @@ cc_library( ], ) +cc_library( + name = "runtime_conv2d_mkl", + srcs = [ + "runtime_conv2d_mkl.cc", + ], + hdrs = ["runtime_conv2d_mkl.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + ":runtime_conv2d", + ":runtime_single_threaded_conv2d", + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:eigen_helpers", + "//third_party/eigen3", + ] + if_mkl([ + "@mkl_dnn", + "//third_party/mkl:intel_binary_blob", + ]), +) + cc_library( name = "runtime_fft", srcs = [ @@ -491,7 +524,6 @@ cc_library( deps = [ "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], @@ -499,7 +531,6 @@ cc_library( cc_library( name = "runtime_matvec", - srcs = ["runtime_matvec.cc"], hdrs = ["runtime_matvec.h"], copts = runtime_copts(), deps = [ @@ -522,6 +553,22 @@ cc_library( ], ) +cc_library( + name = "runtime_matmul_mkl", + srcs = ["runtime_matmul_mkl.cc"], + hdrs = ["runtime_matmul_mkl.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ] + if_mkl([ + "//third_party/mkl:intel_binary_blob", + "@mkl_dnn", + ]), +) + cc_library( name = "runtime_single_threaded_conv2d", srcs = [ @@ -538,6 +585,22 @@ cc_library( ], ) +cc_library( + name = "runtime_single_threaded_fft", + srcs = [ + "runtime_fft_impl.h", + "runtime_single_threaded_fft.cc", + ], + hdrs = ["runtime_single_threaded_fft.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + cc_library( name = "runtime_single_threaded_matmul", srcs = ["runtime_single_threaded_matmul.cc"], @@ -568,10 +631,12 @@ cc_library( tf_cc_test( name = "cpu_runtime_test", srcs = ["cpu_runtime_test.cc"], + shard_count = 10, tags = ["optonly"], deps = [ ":cpu_runtime", ":runtime_matmul", + ":runtime_matmul_mkl", ":runtime_single_threaded_matmul", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:types", @@ -591,6 +656,7 @@ tf_cc_test( deps = [ ":cpu_instruction_fusion", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -622,31 +688,13 @@ cc_library( ], ) -cc_library( - name = "cpu_parallelization_preparation", - srcs = ["cpu_parallelization_preparation.cc"], - hdrs = [ - "cpu_parallelization_preparation.h", - ], - deps = [ - ":ir_emission_utils", - ":parallel_task_assignment", - ":shape_partition", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_cost_analysis", - "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/core:lib", - ], -) - cc_library( name = "ir_emission_utils", srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], deps = [ ":cpu_runtime", + ":target_machine_features", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/service:hlo", @@ -654,6 +702,23 @@ cc_library( ], ) +tf_cc_test( + name = "ir_emission_utils_test", + srcs = ["ir_emission_utils_test.cc"], + deps = [ + ":ir_emission_utils", + ":target_machine_features_fake", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "cpu_layout_assignment", srcs = ["cpu_layout_assignment.cc"], @@ -661,6 +726,7 @@ cc_library( deps = [ ":dot_op_emitter", ":ir_emission_utils", + ":target_machine_features", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:layout_assignment", @@ -674,6 +740,7 @@ tf_cc_test( srcs = ["cpu_layout_assignment_test.cc"], deps = [ ":cpu_layout_assignment", + ":target_machine_features_fake", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -698,6 +765,7 @@ cc_library( deps = [ ":cpu_runtime", ":ir_emission_utils", + ":target_machine_features", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -712,6 +780,7 @@ tf_cc_test( srcs = ["conv_canonicalization_test.cc"], deps = [ ":conv_canonicalization", + ":target_machine_features_fake", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", @@ -750,12 +819,39 @@ cc_library( ":dot_op_emitter", ":ir_emission_utils", ":shape_partition", + ":target_machine_features", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", ], ) +tf_cc_test( + name = "parallel_task_assignment_test", + srcs = ["parallel_task_assignment_test.cc"], + deps = [ + ":cpu_executable", + ":parallel_task_assignment", + ":target_machine_features_fake", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "cpu_options", srcs = ["cpu_options.cc"], @@ -809,6 +905,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", ], @@ -860,16 +957,16 @@ tf_cc_test( ], ) -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], +tf_cc_test( + name = "cpu_eigen_tensor_alignment_test", + size = "small", + srcs = ["cpu_eigen_tensor_alignment_test.cc"], + deps = [ + ":dot_op_emitter", + ":ir_emission_utils", + ":target_machine_features_fake", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], ) diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index ed290fcdf8bb69f1bbad57fa5a0926376bc9405a..6a7eb85e3baec3517b8f3ddef6a8dcfae9c9e614 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -25,11 +25,11 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/ExecutionEngine/ObjectMemoryBuffer.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Verifier.h" #include "llvm/MC/MCContext.h" #include "llvm/Object/ObjectFile.h" +#include "llvm/Support/SmallVectorMemoryBuffer.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/IPO.h" @@ -93,8 +93,8 @@ class FilteredPassManager : public llvm::legacy::PassManager { }; } // anonymous namespace -llvm::object::OwningBinary CompilerFunctor:: -operator()(llvm::Module& module) const { +std::unique_ptr CompilerFunctor::operator()( + llvm::Module& module) const { FilteredPassManager module_passes(disable_expensive_passes_); FilteredFunctionPassManager function_passes(&module, disable_expensive_passes_); @@ -157,27 +157,8 @@ operator()(llvm::Module& module) const { codegen_passes.run(module); // Construct ObjectFile from machine code buffer. - std::unique_ptr memory_buffer( - new llvm::ObjectMemoryBuffer(std::move(stream_buffer))); - llvm::Expected> - object_file_or_error = llvm::object::ObjectFile::createObjectFile( - memory_buffer->getMemBufferRef()); - CHECK(object_file_or_error); - - std::unique_ptr object_file = - std::move(object_file_or_error.get()); - if (VLOG_IS_ON(2)) { - StatusOr disassembly_status = - disassembler_->DisassembleObjectFile(*object_file); - if (disassembly_status.ok()) { - auto result = disassembly_status.ValueOrDie(); - XLA_VLOG_LINES(2, result.text); - VLOG(2) << "compiled code size: " << result.code_size_bytes << " bytes"; - } - } - - return llvm::object::OwningBinary( - std::move(object_file), std::move(memory_buffer)); + return std::unique_ptr( + new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer))); } static std::vector VectorFunctionsForTargetLibraryInfoImpl() { diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.h b/tensorflow/compiler/xla/service/cpu/compiler_functor.h index 1a8283a702223a7414c1ffcd99c1ac42c04ac068..c38b896c5019b48fd2a16a51abd59e12ebdb29eb 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.h +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.h @@ -47,7 +47,7 @@ class CompilerFunctor { post_optimization_hook_(post_optimization_hook) {} // Compile a Module to an ObjectFile. - llvm::object::OwningBinary operator()( + std::unique_ptr operator()( llvm::Module& module) const; // NOLINT private: diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 2136aeb3877685373efaf5bf702a42b39a63f082..0985b9297fe487f3523826cb0978c17775549735 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -33,7 +33,8 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { for (HloInstruction* hlo : module->entry_computation()->MakeInstructionPostOrder()) { if (hlo->opcode() == HloOpcode::kConvolution && - !PotentiallyImplementedAsEigenConvolution(*hlo)) { + !PotentiallyImplementedAsEigenConvolution(*hlo, + target_machine_features_)) { const ConvolutionDimensionNumbers& dnums = hlo->convolution_dimension_numbers(); auto input_batch_dim = dnums.input_batch_dimension(); diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h index 9b2c3d82eb673ce542cc03ec706015967dc975b6..e6fd1499edd0095395194200a5b444ad61e7e39d 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_ +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -32,12 +33,19 @@ namespace cpu { // convolutions can run faster. class ConvCanonicalization : public HloPassInterface { public: + explicit ConvCanonicalization( + const TargetMachineFeatures* target_machine_features) + : target_machine_features_(*target_machine_features) {} + ~ConvCanonicalization() override {} tensorflow::StringPiece name() const override { return "convolution-canonicalization"; } StatusOr Run(HloModule* module) override; + + private: + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 968f53d5c706651d2a470a853e0e9b601c0ed2df..375b017b09263c20c1b1ef8329f7e2f6a573dda4 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -89,7 +90,11 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - ConvCanonicalization conv_canonicalization; + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + ConvCanonicalization conv_canonicalization(&target_machine_features); EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); const HloInstruction* output_reshape = entry_computation->root_instruction(); @@ -146,7 +151,11 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - ConvCanonicalization conv_canonicalization; + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + ConvCanonicalization conv_canonicalization(&target_machine_features); EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index f9cc9651846cca7bd6ab7e9e61590cec4e2400da..d039132535071661d047579587385210719fede3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -43,10 +43,12 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" @@ -55,17 +57,16 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" #include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" -#include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" @@ -81,12 +82,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" #include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" @@ -98,8 +101,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" -namespace se = ::perftools::gputools; - namespace xla { namespace cpu { @@ -120,10 +121,12 @@ se::Platform::Id CpuAotCompilationOptions::PlatformId() const { CpuAotCompilationResult::CpuAotCompilationResult( ObjectFileData object_file_data, BufferSizes buffer_sizes, - int64 result_buffer_index) + int64 result_buffer_index, + std::unique_ptr hlo_profile_printer_data) : object_file_data_(std::move(object_file_data)), buffer_sizes_(std::move(buffer_sizes)), - result_buffer_index_(result_buffer_index) {} + result_buffer_index_(result_buffer_index), + hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {} CpuAotCompilationResult::~CpuAotCompilationResult() = default; @@ -173,14 +176,13 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { public: static StatusOr> GetCandidatesForComputation( - HloComputation* computation, + const HloComputation& computation, const std::unordered_map& assigned_indices) { std::unordered_map hlo_to_profile_idx; CollectProfileCandidates profile_candidates_for_computation( &hlo_to_profile_idx, assigned_indices); - TF_RETURN_IF_ERROR( - computation->Accept(&profile_candidates_for_computation)); + TF_RETURN_IF_ERROR(computation.Accept(&profile_candidates_for_computation)); return hlo_to_profile_idx; } @@ -231,7 +233,10 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { }; } // namespace -Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { +Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine) { + LLVMTargetMachineFeatures target_machine_features(target_machine); + // Optimization pipeline. HloPassPipeline pipeline("CPU"); pipeline.AddInvariantChecker(); @@ -248,8 +253,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(&target_machine_features); { auto& pass = pipeline.AddPass>("simplification"); @@ -258,8 +264,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true, - /*use_fusion=*/false); + /*rewrite_grad_op=*/true); pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, @@ -271,15 +276,19 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pass.AddPass(); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); + pass.AddPass(); } + pipeline.AddPass(); pipeline.AddPass( - [](const HloInstruction& dot, - const TransposeFolding::OperandIndices& candidate_operands) { - return PotentiallyImplementedAsEigenDot(dot) + [&target_machine_features]( + const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return PotentiallyImplementedAsEigenDot(dot, target_machine_features) ? candidate_operands : TransposeFolding::OperandIndices{}; }, @@ -287,12 +296,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); + pipeline.AddPass(); + ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass( - module->mutable_entry_computation_layout()); + module->mutable_device_entry_computation_layout(), + &target_machine_features); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. pipeline.AddPass>( @@ -306,17 +318,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { module->config().intra_op_parallelism_threads() > 0 ? module->config().intra_op_parallelism_threads() : tensorflow::port::NumSchedulableCPUs(); - if (options::CpuParallelBackendRequested(module->config())) { - pipeline.AddPass(max_parallelism, - ShapeSizeBytesFunction()); - } else if (!is_aot_compile) { + if (!is_aot_compile) { // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module. // Note this is not run for AOT because it would bring in thread pool // and thread synchronization dependencies which would likely increase // binary size (and most AOT applications are single-threaded). - // TODO(29630486) Support multi-threaded AOT. - pipeline.AddPass(max_parallelism, - ShapeSizeBytesFunction()); + // TODO(b/29630486) Support multi-threaded AOT. + pipeline.AddPass( + max_parallelism, ShapeSizeBytesFunction(), &target_machine_features); } // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -327,13 +336,6 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); - if (options::CpuParallelBackendRequested(module->config())) { - // Re-run the outlining, in case any copies were inserted into the entry - // computation. - pipeline.AddPass(max_parallelism, - ShapeSizeBytesFunction()); - pipeline.AddPass(); - } pipeline.AddPass(); return pipeline.Run(module).status(); } @@ -433,16 +435,56 @@ Status VerifyLlvmModule(const llvm::Module& llvm_module) { return Status::OK(); } +Status CreateHloProfilingArtifacts( + const HloModule& module, + std::unordered_map* + instruction_to_profile_idx, + std::unordered_map* + computation_to_profile_idx, + std::unique_ptr* hlo_profile_index_map, + std::unique_ptr* hlo_profile_printer_data) { + *hlo_profile_index_map = MakeUnique(module); + const HloComputation& entry_computation = *module.entry_computation(); + + TF_ASSIGN_OR_RETURN( + *instruction_to_profile_idx, + CollectProfileCandidates::GetCandidatesForComputation( + entry_computation, + (*hlo_profile_index_map)->instruction_to_profile_idx())); + + auto shape_size_bytes = [](const Shape& shape) { + // On the cpu, opaques are pointers. + if (ShapeUtil::IsOpaque(shape)) { + return static_cast(sizeof(void*)); + } + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + }; + + HloCostAnalysis cost_analysis(shape_size_bytes); + TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis)); + *hlo_profile_printer_data = + CreateHloProfilePrinterData(**hlo_profile_index_map, cost_analysis); + *computation_to_profile_idx = + (*hlo_profile_index_map)->computation_to_profile_idx(); + + return Status::OK(); +} + } // namespace StatusOr> CpuCompiler::RunHloPasses( - std::unique_ptr module, - perftools::gputools::StreamExecutor* /*stream_exec*/, + std::unique_ptr module, se::StreamExecutor* /*stream_exec*/, DeviceMemoryAllocator* /*device_allocator*/) { VLOG(2) << "Before optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false)); + std::unique_ptr jit_target_machine = + SimpleOrcJIT::InferTargetMachineForJIT( + CompilerTargetOptions(module->config()), + CodeGenOptLevel(module->config())); + + TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false, + jit_target_machine.get())); VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); @@ -450,8 +492,7 @@ StatusOr> CpuCompiler::RunHloPasses( } StatusOr> CpuCompiler::RunBackend( - std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec, + std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* /*device_allocator*/) { const string timer_message = "Compiling [" + module->name() + "] for CPU using JIT"; @@ -489,28 +530,9 @@ StatusOr> CpuCompiler::RunBackend( std::unique_ptr hlo_profile_index_map; std::unique_ptr hlo_profile_printer_data; if (module->config().hlo_profiling_enabled()) { - hlo_profile_index_map = MakeUnique(*module); - - TF_ASSIGN_OR_RETURN( - instruction_to_profile_idx, - CollectProfileCandidates::GetCandidatesForComputation( - entry_computation, - hlo_profile_index_map->instruction_to_profile_idx())); - - auto shape_size_bytes = [](const Shape& shape) { - // On the cpu, opaques are pointers. - if (ShapeUtil::IsOpaque(shape)) { - return static_cast(sizeof(void*)); - } - return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); - }; - - HloCostAnalysis cost_analysis(shape_size_bytes); - TF_RETURN_IF_ERROR(entry_computation->Accept(&cost_analysis)); - hlo_profile_printer_data = - CreateHloProfilePrinterData(*hlo_profile_index_map, cost_analysis); - computation_to_profile_idx = - hlo_profile_index_map->computation_to_profile_idx(); + TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts( + *module, &instruction_to_profile_idx, &computation_to_profile_idx, + &hlo_profile_index_map, &hlo_profile_printer_data)); } std::unique_ptr cpu_executable; @@ -522,190 +544,82 @@ StatusOr> CpuCompiler::RunBackend( const string xla_dump_optimized_hlo_proto_to = module->config().debug_options().xla_dump_optimized_hlo_proto_to(); - if (options::CpuParallelBackendRequested(module->config())) { - VLOG(1) << "Using parallel cpu backend"; - - // Run buffer analysis on the HLO graph. This analysis figures out which - // temporary buffers are required to run the computation. - // DependencyHloOrdering is used for the parallel emitter because the order - // of HLO instruction execution is not known ahead of time. - // DependencyHloOrdering is the most conservative partial order and only - // uses data dependencies for determining order. - TF_ASSIGN_OR_RETURN( - std::unique_ptr assignment, - BufferAssigner::Run( - module.get(), xla::MakeUnique(module.get()), - BufferSizeBytesFunction(), memory_alignment)); - // BufferAssignment::ToString() includes a header, so no need for us to - // print one ourselves. - XLA_VLOG_LINES(2, assignment->ToString()); - - if (!xla_dump_optimized_hlo_proto_to.empty()) { - HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_optimized_hlo_proto_to, module->name())); - } - - // If we are using the parallel CPU backend, we need to create map from - // HloInstruction to the corresponding generated function name. - std::map parallel_computations; - std::unordered_map> - aligned_constants; - for (auto instruction : entry_computation->MakeInstructionPostOrder()) { - // Parameters and constants don't get their own computation. - if (instruction->opcode() == HloOpcode::kParameter) { - continue; - } - if (instruction->opcode() == HloOpcode::kConstant) { - // Copy the constant out of the ProtocolBuffer so that we can give it a - // higher alignment. - const void* data = instruction->literal().untyped_data(); - int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape()); - auto iter = aligned_constants.emplace( - instruction, xla::MakeUnique(size)); - CHECK_EQ(iter.second, true); - unsigned char* aligned_data = iter.first->second.get(); - memcpy(aligned_data, data, size); - continue; - } - // The parallel preparation should have ensured that the top-level - // computation consists solely of Call instructions. - TF_RET_CHECK(instruction->opcode() == HloOpcode::kCall) - << module->ToString(); - HloComputation* to_apply = instruction->to_apply(); - parallel_computations.emplace(to_apply, instruction); - } - - IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - std::move(instruction_to_profile_idx), - std::move(computation_to_profile_idx), - jit->target_machine(), jit->external_constant_pool()); - - std::unique_ptr> function_names( - new HloInstructionMap()); - for (auto embedded_computation : - entry_computation->MakeEmbeddedComputationsList()) { - if (embedded_computation->IsFusionComputation()) { - continue; - } - auto parallel_computation_iter = - parallel_computations.find(embedded_computation); - // All parallel computations are considered to be an entry computation for - // IR generation purposes. - bool computation_is_parallel = - parallel_computation_iter != parallel_computations.end(); - TF_ASSIGN_OR_RETURN( - llvm::Function * ir_function, - ir_emitter.EmitComputation( - embedded_computation, embedded_computation->name(), - /*is_top_level_computation=*/computation_is_parallel, - /*instruction_order=*/nullptr)); - // If this computation is parallel, remember it in the function name map. - // This way we know what function to execute when we try to run code for - // the Call instruction. - if (computation_is_parallel) { - HloInstruction* call_instruction = parallel_computation_iter->second; - InsertOrDie(function_names.get(), call_instruction, - llvm_ir::AsString(ir_function->getName())); - } - } - - string ir_module_string; - if (embed_ir_in_executable) { - ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); - } - TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); - - // JIT compile the LLVM IR module to in-memory machine code. - jit->AddModule(std::move(llvm_module)); - cpu_executable.reset(new ParallelCpuExecutable( - std::move(jit), std::move(assignment), std::move(module), - std::move(function_names), std::move(aligned_constants), - std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map))); - - if (embed_ir_in_executable) { - static_cast(*cpu_executable) - .set_ir_module_string(ir_module_string); - } - } else { - VLOG(1) << "Using sequential cpu backend"; - - // Select an order for emitting the HLO instructions for each - // computation. Using this sequence enables tighter buffer liveness analysis - // and reduced memory usage (as compared to using DependencyHloOrdering). - TF_ASSIGN_OR_RETURN( - SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction())); - - // Run buffer analysis on the HLO graph. This analysis figures out which - // temporary buffers are required to run the computation. - TF_ASSIGN_OR_RETURN( - std::unique_ptr assignment, - BufferAssigner::Run(module.get(), - xla::MakeUnique( - module.get(), module_sequence), - BufferSizeBytesFunction(), memory_alignment)); - // BufferAssignment::ToString() includes a header, so no need for us to - // print one ourselves. - XLA_VLOG_LINES(2, assignment->ToString()); - - if (!xla_dump_optimized_hlo_proto_to.empty()) { - HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_optimized_hlo_proto_to, module->name())); - } - - // Each computation is a single function. Emit all embedded computations - // before the entry computation. The order of computations returned from - // GetEmbeddedComputations guarantees that a called computation occurs - // before a caller computation. - - IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - std::move(instruction_to_profile_idx), - std::move(computation_to_profile_idx), - jit->target_machine(), jit->external_constant_pool()); - - for (auto embedded_computation : - entry_computation->MakeEmbeddedComputationsList()) { - if (embedded_computation->IsFusionComputation()) { - continue; - } - TF_RETURN_IF_ERROR( - ir_emitter - .EmitComputation(embedded_computation, - embedded_computation->name(), - /*is_top_level_computation=*/false, - &module_sequence.at(embedded_computation)) - .status()); - } - string function_name_prefix = entry_computation->name().empty() - ? "__compute" - : entry_computation->name(); - TF_ASSIGN_OR_RETURN( - llvm::Function * entry_function, - ir_emitter.EmitComputation(entry_computation, function_name_prefix, - /*is_top_level_computation=*/true, - &module_sequence.at(entry_computation))); + // Select an order for emitting the HLO instructions for each + // computation. Using this sequence enables tighter buffer liveness analysis + // and reduced memory usage (as compared to using DependencyHloOrdering). + TF_ASSIGN_OR_RETURN( + SequentialHloOrdering::HloModuleSequence module_sequence, + ScheduleComputationsInModule(*module, BufferSizeBytesFunction(), + DFSMemoryScheduler)); + + // Run buffer analysis on the HLO graph. This analysis figures out which + // temporary buffers are required to run the computation. + TF_ASSIGN_OR_RETURN( + std::unique_ptr assignment, + BufferAssigner::Run( + module.get(), + xla::MakeUnique(module.get(), module_sequence), + BufferSizeBytesFunction(), memory_alignment)); + // BufferAssignment::ToString() includes a header, so no need for us to + // print one ourselves. + XLA_VLOG_LINES(2, assignment->ToString()); + + if (!xla_dump_optimized_hlo_proto_to.empty()) { + HloProto proto = MakeHloProto(*module, *assignment); + TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( + proto, xla_dump_optimized_hlo_proto_to, module->name())); + } - string function_name = llvm_ir::AsString(entry_function->getName()); - string ir_module_string; - if (embed_ir_in_executable) { - ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); + // Each computation is a single function. Emit all embedded computations + // before the entry computation. The order of computations returned from + // GetEmbeddedComputations guarantees that a called computation occurs + // before a caller computation. + + LLVMTargetMachineFeatures target_machine_features(jit->target_machine()); + IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), + std::move(instruction_to_profile_idx), + std::move(computation_to_profile_idx), + &target_machine_features, jit->external_constant_pool()); + + for (auto embedded_computation : + entry_computation->MakeEmbeddedComputationsList()) { + if (embedded_computation->IsFusionComputation()) { + continue; } - TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); + TF_RETURN_IF_ERROR( + ir_emitter + .EmitComputation(embedded_computation, embedded_computation->name(), + /*is_top_level_computation=*/false, + &module_sequence.at(embedded_computation)) + .status()); + } + string function_name_prefix = entry_computation->name().empty() + ? "__compute" + : entry_computation->name(); + TF_ASSIGN_OR_RETURN( + llvm::Function * entry_function, + ir_emitter.EmitComputation(entry_computation, function_name_prefix, + /*is_top_level_computation=*/true, + &module_sequence.at(entry_computation))); + + string function_name = llvm_ir::AsString(entry_function->getName()); + string ir_module_string; + if (embed_ir_in_executable) { + ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); + } + TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); - XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module)); + XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module)); - // JIT compile the LLVM IR module to in-memory machine code. - jit->AddModule(std::move(llvm_module)); - cpu_executable.reset(new CpuExecutable( - std::move(jit), std::move(assignment), std::move(module), function_name, - std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map))); + // JIT compile the LLVM IR module to in-memory machine code. + jit->AddModule(std::move(llvm_module)); + cpu_executable.reset(new CpuExecutable( + std::move(jit), std::move(assignment), std::move(module), function_name, + std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map))); - if (embed_ir_in_executable) { - static_cast(*cpu_executable) - .set_ir_module_string(ir_module_string); - } + if (embed_ir_in_executable) { + static_cast(*cpu_executable) + .set_ir_module_string(ir_module_string); } VLOG(1) << "Compilation finished"; @@ -807,14 +721,15 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, VLOG(2) << "Before optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_RETURN_IF_ERROR(RunHloPasses(module, /*is_aot_compile=*/true)); + TF_RETURN_IF_ERROR( + RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get())); VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction())); + ScheduleComputationsInModule(*module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. @@ -836,12 +751,22 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, proto, xla_dump_optimized_hlo_proto_to, module->name())); } + std::unordered_map instruction_to_profile_idx; + std::unordered_map computation_to_profile_idx; + std::unique_ptr hlo_profile_index_map; + std::unique_ptr hlo_profile_printer_data; + + if (module->config().hlo_profiling_enabled()) { + TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts( + *module, &instruction_to_profile_idx, &computation_to_profile_idx, + &hlo_profile_index_map, &hlo_profile_printer_data)); + } + + LLVMTargetMachineFeatures target_machine_features(target_machine.get()); IrEmitter ir_emitter(*module, *assignment, &llvm_module, - /*instruction_to_profile_idx=*/ - std::unordered_map{}, - /*computation_to_profile_idx=*/ - std::unordered_map{}, - target_machine.get(), + std::move(instruction_to_profile_idx), + std::move(computation_to_profile_idx), + &target_machine_features, /*external_constant_pool=*/nullptr); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : @@ -882,6 +807,8 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_RETURN_IF_ERROR(verify_status); } + XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(llvm_module)); + Disassembler disassembler(*target_machine); CompilerFunctor compiler_functor( target_machine.get(), &disassembler, opt_level, @@ -889,11 +816,10 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, module->config().debug_options().xla_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook); - llvm::object::OwningBinary object_file = + std::unique_ptr object_file = compiler_functor(llvm_module); - llvm::StringRef object_file_data_ref = object_file.getBinary()->getData(); - ObjectFileData object_file_data(object_file_data_ref.begin(), - object_file_data_ref.end()); + ObjectFileData object_file_data(object_file->getBufferStart(), + object_file->getBufferEnd()); BufferSizes buffer_sizes; for (const BufferAllocation& allocation : assignment->Allocations()) { @@ -916,7 +842,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, results.emplace_back(MakeUnique( std::move(object_file_data), std::move(buffer_sizes), - result_slice.index())); + result_slice.index(), std::move(hlo_profile_printer_data))); } VLOG(1) << "Compilation finished"; @@ -935,9 +861,9 @@ HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const { } // namespace xla static bool InitModule() { - xla::Compiler::RegisterCompilerFactory(se::host::kHostPlatformId, []() { - return xla::MakeUnique(); - }); + xla::Compiler::RegisterCompilerFactory( + stream_executor::host::kHostPlatformId, + []() { return xla::MakeUnique(); }); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 3498139ab95d21383c6dc008ae5614b7bfe91148..e56f9f01134f84b4698c078b750b0c1fdca7748e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" @@ -53,7 +54,7 @@ class CpuAotCompilationOptions : public AotCompilationOptions { RelocationModel relocation_model); ~CpuAotCompilationOptions() override; - perftools::gputools::Platform::Id PlatformId() const override; + se::Platform::Id PlatformId() const override; // The triple used for compilation, similar to clang's -target flag. const string& triple() const { return triple_; } @@ -76,10 +77,16 @@ class CpuAotCompilationOptions : public AotCompilationOptions { class CpuAotCompilationResult : public AotCompilationResult { public: - CpuAotCompilationResult(ObjectFileData object_file_data, - BufferSizes buffer_sizes, int64 result_buffer_index); + CpuAotCompilationResult( + ObjectFileData object_file_data, BufferSizes buffer_sizes, + int64 result_buffer_index, + std::unique_ptr hlo_profile_printer_data); ~CpuAotCompilationResult(); + HloProfilePrinterData* hlo_profile_printer_data() const { + return hlo_profile_printer_data_.get(); + } + const ObjectFileData& object_file_data() const { return object_file_data_; } const BufferSizes& buffer_sizes() const { return buffer_sizes_; } int64 result_buffer_index() const { return result_buffer_index_; } @@ -97,6 +104,10 @@ class CpuAotCompilationResult : public AotCompilationResult { // result of the computation. This buffer should be passed into the output // parameter when calling the compiled computation. const int64 result_buffer_index_; + + // Contains an instance of HloProfilePrinterData if HLO profiling is enabled, + // otherwise is nullptr. + std::unique_ptr hlo_profile_printer_data_; }; // CPU-targeting implementation of the XLA Compiler interface. @@ -112,25 +123,23 @@ class CpuCompiler : public LLVMCompiler { // Bring in // StatusOr>> Compile( // std::vector> modules, - // std::vector> + // std::vector> // stream_execs) using LLVMCompiler::Compile; StatusOr> RunHloPasses( - std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec, + std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( - std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec, + std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr>> CompileAheadOfTime(std::vector> modules, const AotCompilationOptions& options) override; - perftools::gputools::Platform::Id PlatformId() const override; + se::Platform::Id PlatformId() const override; HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; @@ -140,7 +149,8 @@ class CpuCompiler : public LLVMCompiler { // Runs the HLO passes which are necessary for both optimizations and // correctness. - Status RunHloPasses(HloModule* module, bool is_aot_compile); + Status RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine); TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8727c72b6e42517b1859e98ecadb41bbceed761c --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace cpu { +namespace { + +// Test that we don't call into Eigen with tensors too small to be aligned +// reliably. + +class CpuEigenTensorAlignmentTest : public ::testing::Test {}; + +TEST_F(CpuEigenTensorAlignmentTest, EigenDotAlignment) { + string hlo_string = R"( +HloModule DotOperation + +ENTRY DotOperation { + arg0 = f32[5,256] parameter(0) + arg1 = f32[256,1024] parameter(1) + ROOT dot = f32[5,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + HloInstruction* dot = module->entry_computation()->root_instruction(); + + TargetMachineFeaturesWithFakeAlignmentLogic target_machine_with_no_alignment( + [](int64 size) { return 1; }); + + EXPECT_FALSE( + PotentiallyImplementedAsEigenDot(*dot, target_machine_with_no_alignment)); + + TargetMachineFeaturesWithFakeAlignmentLogic + target_machine_with_full_alignment([](int64 size) { + return TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + + EXPECT_TRUE(PotentiallyImplementedAsEigenDot( + *dot, target_machine_with_full_alignment)); +} + +TEST_F(CpuEigenTensorAlignmentTest, EigenConvAlignment) { + string hlo_string = R"( +HloModule ConvOperation + +ENTRY ConvOperation { + arg0 = f32[1,2,1] parameter(0) + arg1 = f32[1,1,1] parameter(1) + ROOT conv = f32[1,2,1] convolution(arg0, arg1), window={size=1}, dim_labels=b0f_0io->b0f +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + HloInstruction* conv = module->entry_computation()->root_instruction(); + + TargetMachineFeaturesWithFakeAlignmentLogic target_machine_with_no_alignment( + [](int64 size) { return 1; }); + + EXPECT_FALSE(PotentiallyImplementedAsEigenConvolution( + *conv, target_machine_with_no_alignment)); + + TargetMachineFeaturesWithFakeAlignmentLogic + target_machine_with_full_alignment([](int64 size) { + return TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + + EXPECT_TRUE(PotentiallyImplementedAsEigenConvolution( + *conv, target_machine_with_full_alignment)); +} +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index c053703c3524a47ee1de9681c1b986edbf109430..cf43b74c699ca8cbbef11a0abbaf4d69476f5d77 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -45,8 +45,6 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/host/host_stream.h" -namespace se = ::perftools::gputools; - namespace xla { namespace cpu { @@ -75,7 +73,7 @@ CpuExecutable::CpuExecutable( Status CpuExecutable::AllocateBuffers( DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers) { + std::vector* buffers) { CHECK_EQ(buffers->size(), assignment_->Allocations().size()); VLOG(3) << "Allocating " << assignment_->Allocations().size() << " allocations for module " << module().name(); @@ -203,61 +201,19 @@ Status CpuExecutable::ExecuteComputeFunction( return Status::OK(); } -static void LogLiveAddresses( - tensorflow::gtl::ArraySlice buffers, - const std::vector& buffers_in_result) { - if (!VLOG_IS_ON(3)) { - return; - } - - CHECK_EQ(buffers.size(), buffers_in_result.size()); - std::vector live_out_buffers; - for (int i = 0; i < buffers.size(); ++i) { - if (buffers_in_result[i]) { - live_out_buffers.push_back(buffers[i].opaque()); - } - } - VLOG(3) << "Live addresses in output marking found " - << live_out_buffers.size() << " addresses:\n" - << tensorflow::str_util::Join( - live_out_buffers, ", ", [](string* out, const void* address) { - tensorflow::strings::StrAppend( - out, tensorflow::strings::Printf("%p", address)); - }); -} - -static Status DeallocateTempBuffers( - DeviceMemoryAllocator* allocator, se::Stream* stream, - tensorflow::gtl::ArraySlice buffers, - const std::vector& buffers_in_result) { - // Keep those buffers in the output of the marked live because they are needed - // by the service. They will be deallocated by the service. - for (size_t i = 0; i < buffers.size(); ++i) { - se::DeviceMemoryBase alloc = buffers[i]; - if (!buffers_in_result[i] && !alloc.is_null()) { - VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR( - allocator->Deallocate(stream->parent()->device_ordinal(), &alloc)); - } - } - - return Status::OK(); -} - -StatusOr> CpuExecutable::CreateResultShapedBuffer( +StatusOr CpuExecutable::CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - allocated_buffers, - std::vector* buffers_in_result) { + tensorflow::gtl::MutableArraySlice buffers) { se::Stream* stream = run_options->stream(); - auto result_buffer = MakeUnique( - /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(), - stream->parent()->platform(), stream->parent()->device_ordinal()); - - // Copy DeviceMemoryBase values which contain the array(s) of the result into - // the respective location in ShapedBuffer which is returned to the caller. - TF_RETURN_IF_ERROR(result_buffer->buffers().ForEachMutableElementWithStatus( + ScopedShapedBuffer result_buffer( + /*on_host_shape=*/host_result_shape(), + /*on_device_shape=*/host_result_shape(), run_options->allocator(), + stream->parent()->device_ordinal()); + + // Move OwningDeviceMemory values which contain the array(s) of the result + // into the respective location in ScopedShapedBuffer which is returned to the + // caller. + TF_RETURN_IF_ERROR(result_buffer.buffers().ForEachMutableElementWithStatus( [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { const auto& sources = this->GetRootPointsToSet().element(index); // The points to set is unambiguous so the set should be a @@ -275,16 +231,15 @@ StatusOr> CpuExecutable::CreateResultShapedBuffer( CHECK(!slice.allocation()->is_entry_computation_parameter()); const BufferAllocation::Index buffer_index = slice.index(); - const se::DeviceMemoryBase& buffer = allocated_buffers[buffer_index]; + OwningDeviceMemory& buffer = buffers[buffer_index]; CHECK(!buffer.is_null() || buffer.size() == 0); - *device_memory = buffer; - (*buffers_in_result)[buffer_index] = true; + *device_memory = buffer.Forget(); return Status::OK(); })); return std::move(result_buffer); } -StatusOr> CpuExecutable::ExecuteOnStream( +StatusOr CpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { @@ -294,26 +249,24 @@ StatusOr> CpuExecutable::ExecuteOnStream( se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); + std::vector buffers(assignment_->Allocations().size()); TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - TF_RETURN_IF_ERROR(ExecuteComputeFunction( - &run_options->run_options(), arguments, buffers, hlo_execution_profile)); - - std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_ASSIGN_OR_RETURN( - std::unique_ptr result_buffer, - CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); - // Free all buffers not in the result. - TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers, - buffers_in_result)); + std::vector unowning_buffers; + unowning_buffers.reserve(buffers.size()); + for (auto& buffer : buffers) { + unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); + } + TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(), + arguments, unowning_buffers, + hlo_execution_profile)); - return std::move(result_buffer); + return CreateResultShapedBuffer(run_options, &buffers); } -StatusOr> CpuExecutable::ExecuteAsyncOnStream( +StatusOr CpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) { if (hlo_profiling_enabled()) { @@ -322,34 +275,57 @@ StatusOr> CpuExecutable::ExecuteAsyncOnStream( "supported on CPU."); } - auto* host_stream = dynamic_cast( + auto* host_stream = dynamic_cast( run_options->stream()->implementation()); se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - + std::vector buffers(assignment_->Allocations().size()); TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_ASSIGN_OR_RETURN( - std::unique_ptr result_buffer, - CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); - - LogLiveAddresses(buffers, buffers_in_result); - - host_stream->EnqueueTask([this, run_options, arguments, buffers, - buffers_in_result, memory_allocator, stream]() { - // Failing a CHECK here is not great, but I don't see an obvious way to - // return a failed Status asynchronously. - TF_CHECK_OK(ExecuteComputeFunction(&run_options->run_options(), arguments, - buffers, - /*hlo_execution_profile=*/nullptr)); - TF_CHECK_OK(DeallocateTempBuffers(memory_allocator, stream, buffers, - buffers_in_result)); - }); + std::vector unowning_buffers; + unowning_buffers.reserve(buffers.size()); + for (auto& buffer : buffers) { + unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); + } + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + CreateResultShapedBuffer(run_options, &buffers)); - return std::move(result_buffer); + // At this point, `unowning_buffers` contains unowning pointers to all of our + // buffers, and `buffers` contains owning pointers to the non-live-out + // buffers. Enqueue a task which keeps alive the non-live-out buffers. + // + // Logically we want this lambda to capture `buffers` by move, ultimately our + // functor needs to be wrapped in an std::function, and that requires its + // functor to be copyable. Thus we perpitrate the hack of capturing buffers + // "by shared pointer". + // + // We also need to change the types of some of the variables we capture: + // run_options needs to change from a pointer to a value type, and arguments + // needs to change from an ArraySlice into a vector. We use a struct instead + // of a lambda to make this explicit. + struct AsyncRunTask { + CpuExecutable* executable; + ServiceExecutableRunOptions run_options; + std::vector arguments; + std::vector unowning_buffers; + std::shared_ptr> buffers; + + void operator()() { + // Failing a CHECK here is not great, but I don't see an obvious way to + // return a failed Status asynchronously. + TF_CHECK_OK(executable->ExecuteComputeFunction( + &run_options.run_options(), arguments, unowning_buffers, + /*hlo_execution_profile=*/nullptr)); + } + }; + host_stream->EnqueueTask(AsyncRunTask{ + this, *run_options, + std::vector(arguments.begin(), arguments.end()), + unowning_buffers, + std::make_shared>(std::move(buffers))}); + + return std::move(result); } /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 267b89a10b3c038dc2048f0ad5b5b343c88ef0f9..8dd47bfb865e8a0552542f510d3365cff0d111e0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -55,12 +55,12 @@ class CpuExecutable : public Executable { std::unique_ptr hlo_profile_index_map); ~CpuExecutable() override {} - StatusOr> ExecuteOnStream( + StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr> ExecuteAsyncOnStream( + StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) override; @@ -71,11 +71,6 @@ class CpuExecutable : public Executable { ir_module_string_ = ir_module_string; } - const Status EqualOrFail(const Executable& executable) { - // TODO(b/62952745) Implement equality test on CPU executable. - return Unimplemented("Equality test on CPU executable is not implemented."); - } - static int64 ShapeSizeBytes(const Shape& shape); // Type of the computation function we expect in the JIT. @@ -95,30 +90,24 @@ class CpuExecutable : public Executable { // assignment. Each vector element corresponds to a particular Index. If // a vector element already contains a non-null DeviceMemoryBase, then no // buffer is assigned for this element. - Status AllocateBuffers( - DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers); + Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator, + int device_ordinal, + std::vector* buffers); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. Status ExecuteComputeFunction( const ExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice - buffers, + tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile); - // Create a ShapedBuffer for holding the result of the computation. The - // addresses (DeviceMemoryBases) are set according to buffer assignment. - // 'buffers_in_result' should point to a vector of the same size as - // 'allocated_buffers'. An element in buffers_in_result is set to true if the - // corresponding buffer is live out of the computation (and thus contained in - // the returned ShapedBuffer). - StatusOr> CreateResultShapedBuffer( + // Creates a ScopedShapedBuffer for holding the result of the computation, + // moving buffers out of allocated_buffers and into the result as appropriate. + // The addresses are set according to buffer assignment. + StatusOr CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - allocated_buffers, - std::vector* buffers_in_result); + tensorflow::gtl::MutableArraySlice buffers); // Returns the points-to set of the root instruction of the entry // computation. Uses points-to analysis from buffer assignment. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 482e04052d5a914eab0e5bff2c7a83f3b698052f..b40d264c03aba6e9308e8a621ae86e180e33c335 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -30,11 +30,11 @@ bool CanBeLoopFused(const HloInstruction& hlo) { // These are the only ones we fuse since we rely on effective elemental IR // generation. return hlo.IsElementwise() || // - hlo.opcode() == HloOpcode::kBitcast || hlo.opcode() == HloOpcode::kBroadcast || hlo.opcode() == HloOpcode::kConcatenate || hlo.opcode() == HloOpcode::kDynamicSlice || hlo.opcode() == HloOpcode::kDynamicUpdateSlice || + hlo.opcode() == HloOpcode::kGather || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReshape || hlo.opcode() == HloOpcode::kReverse || diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 595c3f55b321f47e2312b93e0c238c7637495d77..97e10a89a209c057685709e7a5034052ff4376ed 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -77,7 +78,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { EXPECT_THAT(computation->root_instruction(), op::Fusion()); } -TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) { +TEST_F(InstructionFusionTest, DotOperationNoFusion_Bitcast) { HloComputation::Builder builder(TestName()); HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0")); @@ -94,8 +95,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); - EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Fusion()); + EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); } TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { @@ -158,37 +158,95 @@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { EXPECT_EQ(dot, computation->root_instruction()); } -TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion) { - HloComputation::Builder builder(TestName()); - HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0")); - HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1024, 256}), "arg1")); +TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_RHS) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion - HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg1)); - HloInstruction* transpose1 = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(S32, {256, 1024}), exp1, {1, 0})); - builder.AddInstruction( - MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, transpose1)); +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[1,256] parameter(0) + arg1 = f32[1024,256] parameter(1) + exponential = s32[1024,256] exponential(arg1) + transpose = s32[256,1024] transpose(exponential), dimensions={1,0} + ROOT dot = f32[1,1024] dot(arg0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + HloComputation* computation = module->entry_computation(); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); TransposeFolding transpose_folding( [](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { return candidate_operands; }, TransposeFolding::NeverFoldTranspose); - EXPECT_TRUE(transpose_folding.Run(module.get()).ValueOrDie()); - EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kFusion); - EXPECT_EQ(computation->root_instruction()->fusion_kind(), - HloInstruction::FusionKind::kTransposeDot); - EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); - EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kFusion); - EXPECT_EQ(computation->root_instruction()->fusion_kind(), - HloInstruction::FusionKind::kTransposeDot); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + ASSERT_TRUE(changed); + ASSERT_THAT(computation->root_instruction(), + op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); +} + +TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_LHS) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion + +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[256,1] parameter(0) + arg1 = f32[256,1024] parameter(1) + transpose = s32[1,256] transpose(arg0), dimensions={1,0} + exponential = s32[256,1024] exponential(arg1) + ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + HloComputation* computation = module->entry_computation(); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + TransposeFolding::NeverFoldTranspose); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + ASSERT_TRUE(changed); + ASSERT_THAT(computation->root_instruction(), + op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), + /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0)); +} + +TEST_F(InstructionFusionTest, + DotOperationFusion_TransposeFusion_LHS_NonDefault) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion + +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[1,256] parameter(0) + arg1 = f32[256,1024] parameter(1) + transpose = s32[256,1] transpose(arg0), dimensions={1,0} + exponential = s32[256,1024] exponential(arg1) + ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + HloComputation* computation = module->entry_computation(); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + TransposeFolding::NeverFoldTranspose); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + ASSERT_TRUE(changed); + ASSERT_THAT(computation->root_instruction(), + op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0)); } class OpcodeFusionTest : public InstructionFusionTest { @@ -244,35 +302,33 @@ class OpcodeFusionTest : public InstructionFusionTest { } }; -TEST_F(OpcodeFusionTest, Exponential_Bitcast_Negate) { +TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) { HloComputation::Builder builder(TestName()); Shape param_shape = ShapeUtil::MakeShape(F32, {1, 4}); Shape result_shape = ShapeUtil::MakeShape(F32, {4}); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "param")); - // InstructionFusion::ShouldFuse() precludes fusing a bitcast whose operand - // is a parameter, so create an operand between the parameter and bitcast. HloInstruction* exp1 = builder.AddInstruction( HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0)); - HloInstruction* bitcast2 = builder.AddInstruction( - HloInstruction::CreateUnary(result_shape, HloOpcode::kBitcast, exp1)); + HloInstruction* reshape2 = + builder.AddInstruction(HloInstruction::CreateReshape(result_shape, exp1)); builder.AddInstruction( - HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, bitcast2)); + HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape2)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( - module.get(), {HloOpcode::kNegate, HloOpcode::kBitcast, HloOpcode::kExp, + module.get(), {HloOpcode::kNegate, HloOpcode::kReshape, HloOpcode::kExp, HloOpcode::kParameter}); } -TEST_F(OpcodeFusionTest, Broadcast_Bitcast_DynamicSlice_Tanh) { +TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { HloComputation::Builder builder(TestName()); Shape param_shape = ShapeUtil::MakeShape(F32, {8}); Shape starts_shape = ShapeUtil::MakeShape(F32, {2}); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 8, 8}); - Shape bitcast_shape = ShapeUtil::MakeShape(F32, {8, 8}); + Shape reshape_shape = ShapeUtil::MakeShape(F32, {8, 8}); Shape dynamic_slice_shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "param")); @@ -280,11 +336,11 @@ TEST_F(OpcodeFusionTest, Broadcast_Bitcast_DynamicSlice_Tanh) { HloInstruction::CreateParameter(1, starts_shape, "starts")); HloInstruction* broadcast2 = builder.AddInstruction( HloInstruction::CreateBroadcast(broadcast_shape, param0, {1})); - HloInstruction* bitcast3 = builder.AddInstruction(HloInstruction::CreateUnary( - bitcast_shape, HloOpcode::kBitcast, broadcast2)); + HloInstruction* reshape3 = builder.AddInstruction( + HloInstruction::CreateReshape(reshape_shape, broadcast2)); HloInstruction* dynamic_slice4 = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - dynamic_slice_shape, bitcast3, param1, {4, 4})); + dynamic_slice_shape, reshape3, param1, {4, 4})); builder.AddInstruction(HloInstruction::CreateUnary( dynamic_slice_shape, HloOpcode::kTanh, dynamic_slice4)); @@ -293,7 +349,7 @@ TEST_F(OpcodeFusionTest, Broadcast_Bitcast_DynamicSlice_Tanh) { RunFusionAndCheckOpcodesWereFused( module.get(), - {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kBitcast, + {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kReshape, HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter}); } @@ -700,6 +756,154 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { Not(op::Fusion())); } +struct GatherLoopFusionTestSpec { + string test_name; + string hlo_computation_text; + + static string Name( + const ::testing::TestParamInfo& info) { + return info.param.test_name; + } +}; + +class GatherLoopFusionTest + : public OpcodeFusionTest, + public ::testing::WithParamInterface {}; + +TEST_P(GatherLoopFusionTest, GatherLoopFusion) { + const GatherLoopFusionTestSpec& spec = GetParam(); + string hlo_string = tensorflow::strings::StrCat( + "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + RunFusionAndCheckOpcodesWereFused( + module.get(), + {HloOpcode::kGather, HloOpcode::kAdd, HloOpcode::kBroadcast, + HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter}); +} + +std::vector GetGatherLoopFusionTestSpecs() { + std::vector result; + + result.push_back({"FusedTensorFlowGatherV2", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + gather = s32[3,2] gather(operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3, 1} + one = s32[] constant(1) + one_broadcasted = s32[3,2] broadcast(one), dimensions={} + ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedTensorFlowGatherMultipleBatchDims", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,3,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} + one = s32[] constant(1) + one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} + ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedTensorFlowGatherNdMultipleBatchDims", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=2, + window_bounds={1, 1} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedTensorFlowGatherNd_0", R"( +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1,2} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedTensorFlowGatherNd_1", R"( +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1,2} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedDynamicSlice", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + gather = s32[1,1] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} + one = s32[] constant(1) + one_broadcasted = s32[1,1] broadcast(one), dimensions={} + ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedBatchDynamicSlice", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,1,1] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} + one = s32[] constant(1) + one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} + ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) +} +)"}); + + return result; +} + +INSTANTIATE_TEST_CASE_P(GatherLoopFusionTestInstantiation, GatherLoopFusionTest, + ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()), + GatherLoopFusionTestSpec::Name); } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index e8117377e61a4e21b8c45b929c518a18878fcb60..aa872d5ec9e7593b8d2f731421c17af590729529 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -100,7 +100,8 @@ Status CpuLayoutAssignment::AddBackendConstraints( const HloComputation* computation = constraints->computation(); for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction)) { + PotentiallyImplementedAsEigenConvolution(*instruction, + target_machine_features_)) { const HloInstruction* convolution = instruction; const HloInstruction* lhs_instruction = convolution->operand(0); const HloInstruction* rhs_instruction = convolution->operand(1); @@ -126,7 +127,8 @@ Status CpuLayoutAssignment::AddBackendConstraints( const HloInstruction* op = instruction->operand(*op_idx); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( ColMajorShape(op->shape()), instruction, *op_idx)); - } else if (PotentiallyImplementedAsEigenDot(*instruction)) { + } else if (PotentiallyImplementedAsEigenDot(*instruction, + target_machine_features_)) { const HloInstruction* dot = instruction; // In order to implement `dot` with Eigen dot, the layouts of the lhs, // rhs, and output need to be row-major. @@ -139,13 +141,9 @@ Status CpuLayoutAssignment::AddBackendConstraints( Shape lhs_shape(RowMajorShape(lhs_instruction->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); - // dot is a kDot or a kTransposeDot fusion node. In the latter case, if - // it represents X @ X, it may have just one operand. - if (dot->operand_count() > 1) { - const HloInstruction* rhs_instruction = dot->operand(1); - Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); - } + const HloInstruction* rhs_instruction = dot->operand(1); + Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); // Set layouts of the instructions' shapes. TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(output_shape, dot)); @@ -181,7 +179,7 @@ Status CpuLayoutAssignment::AddBackendConstraints( } } } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h index c8edbb9e15a5b6f9c574f5fe9d130d149499ebd2..3c4fe68b830d9602f009b318d4e51e9a04a27e09 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/core/lib/core/status.h" @@ -27,12 +28,17 @@ namespace cpu { // layout constraints for operands and results of library calls. class CpuLayoutAssignment : public LayoutAssignment { public: - explicit CpuLayoutAssignment(ComputationLayout* entry_computation_layout) - : LayoutAssignment(entry_computation_layout) {} + explicit CpuLayoutAssignment( + ComputationLayout* entry_computation_layout, + const TargetMachineFeatures* target_machine_features) + : LayoutAssignment(entry_computation_layout), + target_machine_features_(*target_machine_features) {} ~CpuLayoutAssignment() override {} protected: Status AddBackendConstraints(LayoutConstraints* constraints) override; + + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 6ba030fff3bbc5f413bfb133114ceb5309b77672..429fc7b78608da0e9cd794ac294851b326f5be24 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -49,7 +50,12 @@ class CpuLayoutAssignmentTest : public HloTestBase { protected: void AssignLayouts(HloModule* module, ComputationLayout* entry_computation_layout) { - cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout); + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout, + &target_machine_features); EXPECT_IS_OK(layout_assignment.Run(module).status()); } }; @@ -311,7 +317,12 @@ static StatusOr RunDotOutputFusion( result.addend_fusion_param = fusion_instruction->operand( fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number()); - cpu::CpuLayoutAssignment layout_assignment(&computation_layout); + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + cpu::CpuLayoutAssignment layout_assignment(&computation_layout, + &target_machine_features); TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something, layout_assignment.Run(module)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index 09f028463af68bbc2841fecdb2ca6c6a42498798..3ed7876715f64191f6e652d2b5cb1673df9a1b94 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -16,13 +16,16 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace { -const char* const kXlaParallelCpuOption = "xla_cpu_parallel"; const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; +const char* const kXlaEnableExperimentalLlvmIrGemm = + "xla_enable_experimental_llvm_ir_gemm"; +const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; } // namespace @@ -30,12 +33,6 @@ namespace xla { namespace cpu { namespace options { -bool CpuParallelBackendRequested(const HloModuleConfig& config) { - const auto& extra_options_map = - config.debug_options().xla_backend_extra_options(); - return extra_options_map.count(kXlaParallelCpuOption) > 0; -} - bool OptimizeForSizeRequested(const HloModuleConfig& config) { const auto& extra_options_map = config.debug_options().xla_backend_extra_options(); @@ -61,6 +58,49 @@ tensorflow::gtl::optional LlvmIrGemvTilingFactor( return tensorflow::gtl::nullopt; } +bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0; +} + +static tensorflow::StringPiece RemoveSuffix(tensorflow::StringPiece str, + tensorflow::StringPiece suffix) { + CHECK_GE(str.size(), suffix.size()); + CHECK_EQ(str.substr(str.size() - suffix.size()), suffix); + return str.substr(0, str.size() - suffix.size()); +} + +tensorflow::gtl::optional> LlvmIrGemmTileSize( + const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + auto it = extra_options_map.find(kLlvmIrGemmTileSize); + if (it == extra_options_map.end()) { + return tensorflow::gtl::nullopt; + } + + std::vector tile_components = + tensorflow::str_util::Split(it->second, ':'); + CHECK_EQ(tile_components.size(), 3); + + int64 tile_size_m; + int64 tile_size_k; + int64 tile_size_n_in_vector_width; + + CHECK(tensorflow::strings::safe_strto64(tile_components[0], &tile_size_m)); + CHECK(tensorflow::strings::safe_strto64(tile_components[1], &tile_size_k)); + + tensorflow::StringPiece tile_size_n_in_vector_width_str = + RemoveSuffix(tile_components[2], "*vectwidth"); + + CHECK(tensorflow::strings::safe_strto64(tile_size_n_in_vector_width_str, + &tile_size_n_in_vector_width)); + + return std::tuple(tile_size_m, tile_size_k, + tile_size_n_in_vector_width); +} + } // namespace options } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 6ba0fd24538b63a3da81083482e6bee3b552dfea..429b9e16cbdd6f623919533582481f1640118081 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -24,11 +24,13 @@ namespace xla { namespace cpu { namespace options { -bool CpuParallelBackendRequested(const HloModuleConfig& config); bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); +bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config); tensorflow::gtl::optional LlvmIrGemvTilingFactor( const HloModuleConfig& config); +tensorflow::gtl::optional> LlvmIrGemmTileSize( + const HloModuleConfig& config); } // namespace options } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc deleted file mode 100644 index 662ee609232f5582ce74f4f515637b2623175e94..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc +++ /dev/null @@ -1,192 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" - -#include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" -#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" - -namespace xla { -namespace cpu { - -StatusOr ParallelizationPreparation::Run(HloModule* module) { - XLA_VLOG_LINES(2, "ParallelizationPreparation ENTRY"); - XLA_VLOG_LINES(2, module->ToString()); - - bool changed = false; - TF_ASSIGN_OR_RETURN(changed, RunParallelTaskAssignment(module)); - - HloComputation* entry_computation = module->entry_computation(); - std::unordered_set outlined; - std::vector instructions_to_outline; - for (HloInstruction* instruction : - entry_computation->MakeInstructionPostOrder()) { - // If the instruction has been outlined, it no longer exists and we must not - // dereference it. - if (outlined.count(instruction) > 0) { - continue; - } - - // Skip parameters and constants, there is nothing to parallelize. - if (instruction->opcode() == HloOpcode::kParameter || - instruction->opcode() == HloOpcode::kConstant) { - continue; - } - - // Outline 'instruction' in isolation if it was assigned parallel tasks. - if (OutlineParallelizableInstruction(instruction)) { - outlined.insert(instruction); - changed = true; - continue; - } - - instructions_to_outline.clear(); - HloInstruction* outline_candidate = instruction; - instructions_to_outline.push_back(outline_candidate); - - // Outline sole users with the current instruction. - while (CanOutlineWithUser(outline_candidate)) { - HloInstruction* prior_candidate = outline_candidate; - outline_candidate = *outline_candidate->users().begin(); - if (std::any_of(outline_candidate->operands().begin(), - outline_candidate->operands().end(), - [&](const HloInstruction* operand) { - // Do not consider any candidates which have operands - // other than the prior candidate, constants or - // parameters. Otherwise, we'd increase the fan-in which - // would reduce parallelism. - return operand->opcode() != HloOpcode::kParameter && - operand->opcode() != HloOpcode::kConstant && - operand != prior_candidate; - })) { - break; - } - instructions_to_outline.push_back(outline_candidate); - } - - outlined.insert(instructions_to_outline.begin(), - instructions_to_outline.end()); - - // Optimization to avoid replacing a single existing kCall with another - // kCall that just calls the first one. - if (instructions_to_outline.size() == 1 && - instructions_to_outline[0]->opcode() == HloOpcode::kCall) { - continue; - } - - module->OutlineExpressionFromComputation( - instructions_to_outline, - tensorflow::strings::StrCat("pp_", instruction->name()), - entry_computation); - changed = true; - } - - XLA_VLOG_LINES(2, "ParallelizationPreparation EXIT"); - XLA_VLOG_LINES(2, module->ToString()); - return changed; -} - -StatusOr ParallelizationPreparation::RunParallelTaskAssignment( - HloModule* module) { - VLOG(1) << "RunParallelTaskAssignment max_parallelism_: " << max_parallelism_; - bool changed = false; - // Initialize ParallelTaskAssignment. - ParallelTaskAssignment parallel_task_assignment(max_parallelism_, shape_size_, - module); - // Assign parallel tasks to HLOs in entry computation. - HloComputation* computation = module->entry_computation(); - for (auto* instruction : computation->instructions()) { - // Calculate target parallel task count in [1, max_parallelism_]. - const int64 target_parallel_task_count = - parallel_task_assignment.GetTargetParallelTaskCount(instruction); - if (target_parallel_task_count == 1) { - continue; - } - - // Assign feasible dimension partitions (based on actual dimension sizes). - auto dim_partition_counts = ShapePartitionAssigner(instruction->shape()) - .Run(target_parallel_task_count); - const int64 total_partition_count = - ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts); - if (total_partition_count <= 1) { - // Feasible partition calculation resulting in no partitioning, so skip. - continue; - } - VLOG(2) << "Assigning parallel task count: " << total_partition_count - << " to instruction: " << instruction->name(); - // Map 'instruction' to assigned dimension partitioning. - instruction->set_outer_dimension_partitions(dim_partition_counts); - } - - return changed; -} - -bool ParallelizationPreparation::OutlineParallelizableInstruction( - HloInstruction* instruction) { - if (instruction->outer_dimension_partitions().empty()) { - return false; - } - // Store dimension partition counts before outlining (which clones - // 'instruction'). - std::vector dim_partition_counts = - instruction->outer_dimension_partitions(); - // Outline 'instruction' in its own sub-computation. - HloModule* module = instruction->parent()->parent(); - auto* call = module->OutlineExpressionFromComputation( - {instruction}, tensorflow::strings::StrCat("pp_", instruction->name()), - module->entry_computation()); - // Map previously assigned 'dim_partition_counts' to cloned root instruction. - VLOG(1) << "Outlining parallelizable" - << " caller: " << call->name() - << " callee: " << call->to_apply()->root_instruction()->name(); - call->to_apply()->root_instruction()->set_outer_dimension_partitions( - dim_partition_counts); - return true; -} - -bool ParallelizationPreparation::CanOutlineWithUser( - HloInstruction* instruction) { - if (instruction->users().size() != 1) { - // Do not outline 'instruction' with multiple users. - return false; - } - if (AssignedParallelTasks(instruction) || - AssignedParallelTasks(*instruction->users().begin())) { - // Do not outline if 'instruction' (or user) were assigned parallel tasks. - return false; - } - return true; -} - -bool ParallelizationPreparation::AssignedParallelTasks( - HloInstruction* instruction) { - return !instruction->outer_dimension_partitions().empty() || - (instruction->opcode() == HloOpcode::kCall && - !instruction->to_apply() - ->root_instruction() - ->outer_dimension_partitions() - .empty()); -} - -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h deleted file mode 100644 index 87be758ef5d0535fdce3a65e54ce225042019cdb..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_ - -#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" - -namespace xla { -namespace cpu { - -// This pass prepares an HLO module for parallel execution by transforming -// subgraphs of the top-level computation into embedded computations which can -// be executed in parallel. -// TODO(b/29630486): Currently, it is limited to turning all instructions (which -// are not constants or parameters) in the entry computation into embedded -// computations. However, it could make sense to coarsen the parallelization to -// improve cache locality. Also, we will need to do something to intelligently -// handle While constructs. -class ParallelizationPreparation : public HloPassInterface { - public: - // 'max_parallelism': the maximum parallel task count per instruction. - // 'shape_size': shape size function used by HloCostAnalysis during parallel - // task assignment. - ParallelizationPreparation( - const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size) - : max_parallelism_(max_parallelism), shape_size_(shape_size) {} - ~ParallelizationPreparation() override {} - - tensorflow::StringPiece name() const override { - return "cpu-parallel-prepare"; - } - - // Run parallel preparation on the given computation. Returns whether the - // computation was changed. - StatusOr Run(HloModule* module) override; - - private: - // Assigns parallel task partitions to conformant instructions in 'module'. - // Returns true on success or error status otherwise. - StatusOr RunParallelTaskAssignment(HloModule* module); - - // Outlines 'instruction' from entry computation, if it had - // been assigned parallel tasks in an earlier pass through the computation. - // Returns true if 'instruction' was successfully outlined, false otherwise. - bool OutlineParallelizableInstruction(HloInstruction* instruction); - - // Returns true if 'instruction' can be outlined into the same sub-computation - // with its single user (parallelizable instructions are not outlined with - // each other). Returns false otherwise. - bool CanOutlineWithUser(HloInstruction* instruction); - - // Returns true if 'instruction' (or the root of the sub-computation that - // 'instruction' calls) has had parallel tasks assigned in earlier pass. - // Returns false otherwise. - bool AssignedParallelTasks(HloInstruction* instruction); - - const int64 max_parallelism_; - const HloCostAnalysis::ShapeSizeFunction shape_size_; -}; - -} // namespace cpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 40ace963270e8cead47cc731cc326351178dff7d..54c52bc08f9c53b8c6898689b18c4cb7f4bdcfd0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -31,15 +31,30 @@ XfeedManager* GetXfeedManager() { return manager; } +extern const char* const kEigenMatMulF16SymbolName = + "__xla_cpu_runtime_EigenMatMulF16"; extern const char* const kEigenMatMulF32SymbolName = "__xla_cpu_runtime_EigenMatMulF32"; extern const char* const kEigenMatMulF64SymbolName = "__xla_cpu_runtime_EigenMatMulF64"; +extern const char* const kMKLConvF32SymbolName = "__xla_cpu_runtime_MKLConvF32"; +extern const char* const kMKLMatMulF32SymbolName = + "__xla_cpu_runtime_MKLMatMulF32"; +extern const char* const kMKLMatMulF64SymbolName = + "__xla_cpu_runtime_MKLMatMulF64"; +extern const char* const kMKLSingleThreadedMatMulF32SymbolName = + "__xla_cpu_runtime_MKLSingleThreadedMatMulF32"; +extern const char* const kMKLSingleThreadedMatMulF64SymbolName = + "__xla_cpu_runtime_MKLSingleThreadedMatMulF64"; extern const char* const kEigenConvF16SymbolName = "__xla_cpu_runtime_EigenConvF16"; extern const char* const kEigenConvF32SymbolName = "__xla_cpu_runtime_EigenConvF32"; extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft"; +extern const char* const kEigenSingleThreadedFftSymbolName = + "__xla_cpu_runtime_EigenSingleThreadedFft"; +extern const char* const kEigenSingleThreadedMatMulF16SymbolName = + "__xla_cpu_runtime_EigenSingleThreadedMatMulF16"; extern const char* const kEigenSingleThreadedMatMulF32SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF32"; extern const char* const kEigenSingleThreadedMatMulF64SymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 2141dfe1cedd6f9674acc348152574b4fd30895b..aa0e96712302e806a389c6ad05a2c1b6634ef901 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -41,11 +41,19 @@ namespace runtime { // the actual symbol. // 2. When using ahead-of-time compilation, the linker can resolve the name // because it is a symbol in the cpu_runtime library. +extern const char* const kEigenMatMulF16SymbolName; extern const char* const kEigenMatMulF32SymbolName; extern const char* const kEigenMatMulF64SymbolName; +extern const char* const kMKLConvF32SymbolName; +extern const char* const kMKLMatMulF32SymbolName; +extern const char* const kMKLMatMulF64SymbolName; +extern const char* const kMKLSingleThreadedMatMulF32SymbolName; +extern const char* const kMKLSingleThreadedMatMulF64SymbolName; extern const char* const kEigenConvF16SymbolName; extern const char* const kEigenConvF32SymbolName; extern const char* const kEigenFftSymbolName; +extern const char* const kEigenSingleThreadedFftSymbolName; +extern const char* const kEigenSingleThreadedMatMulF16SymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; extern const char* const kEigenSingleThreadedConvF16SymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index f385829cdf5cafbd35e083f47106734cdd5dde88..2ac950e6d93ade315808f2ca1d0bdd7bc85f53b9 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" @@ -130,25 +131,23 @@ MatMulShape MatMulShapes[] = { // * transpose_lhs // * transpose_rhs // * single_threaded -using EigenMatMulTestParam = std::tuple; +using MatMulTestParam = std::tuple; -class EigenMatMulTest - : public CpuRuntimeTest, - public ::testing::WithParamInterface { +class EigenMatMulTest : public CpuRuntimeTest, + public ::testing::WithParamInterface { public: - static string Name( - const ::testing::TestParamInfo& info) { + static string Name(const ::testing::TestParamInfo& info) { MatMulShape shape = std::get<0>(info.param); bool transpose_lhs = std::get<1>(info.param); bool transpose_rhs = std::get<2>(info.param); bool single_threaded = std::get<3>(info.param); return tensorflow::strings::Printf( - "MatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, + "EigenMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "", single_threaded ? "single" : "multi"); } -}; // namespace xla +}; TEST_P(EigenMatMulTest, DoIt) { MatMulShape shape = std::get<0>(GetParam()); @@ -169,5 +168,74 @@ INSTANTIATE_TEST_CASE_P(EigenMatMulTestInstantiaion, EigenMatMulTest, ::testing::Bool()), EigenMatMulTest::Name); +#ifdef INTEL_MKL +class MKLMatMulTest : public CpuRuntimeTest, + public ::testing::WithParamInterface { + public: + static string Name(const ::testing::TestParamInfo& info) { + MatMulShape shape = std::get<0>(info.param); + bool transpose_lhs = std::get<1>(info.param); + bool transpose_rhs = std::get<2>(info.param); + bool single_threaded = std::get<3>(info.param); + + return tensorflow::strings::Printf( + "MKLMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, + transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "", + single_threaded ? "single" : "multi"); + } +}; + +std::unique_ptr> MKLMatrixMultiply(const Array2D& a, + const Array2D& b, + bool transpose_lhs, + bool transpose_rhs, + bool single_threaded) { + CHECK_EQ(a.width(), b.height()); + int64 m = a.height(); + int64 n = b.width(); + int64 k = a.width(); + + // The MKL matmul runtime function expects the matrix to be in column major + // order and array2d is in row-major order. Create transposes of a and b. The + // 'data' buffer in the transposed array is the original array in column major + // order. + auto a_transpose = MaybeTransposeArray2D(a, !transpose_lhs); + auto b_transpose = MaybeTransposeArray2D(b, !transpose_rhs); + + // Since we're going to transpose c before returning it, swap the order of the + // dimension sizes to ensure the returned array is properly dimensioned. + auto c_transpose = MakeUnique>(n, m); + if (single_threaded) { + __xla_cpu_runtime_MKLSingleThreadedMatMulF32( + nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(), + m, n, k, transpose_lhs, transpose_rhs); + } else { + __xla_cpu_runtime_MKLMatMulF32(nullptr, c_transpose->data(), + a_transpose->data(), b_transpose->data(), m, + n, k, transpose_lhs, transpose_rhs); + } + return MaybeTransposeArray2D(*c_transpose, true); +} + +TEST_P(MKLMatMulTest, DoIt) { + MatMulShape shape = std::get<0>(GetParam()); + bool transpose_lhs = std::get<1>(GetParam()); + bool transpose_rhs = std::get<2>(GetParam()); + bool single_threaded = std::get<3>(GetParam()); + + auto a = MakeLinspaceArray2D(0.0, 1.0, shape.m, shape.k); + auto b = MakeLinspaceArray2D(-2.0, 2.0, shape.k, shape.n); + auto c = + MKLMatrixMultiply(*a, *b, transpose_lhs, transpose_rhs, single_threaded); + CheckMatrixMultiply(*a, *b, *c); +} + +INSTANTIATE_TEST_CASE_P(MKLMatMulTestInstantiaion, MKLMatMulTest, + ::testing::Combine(::testing::ValuesIn(MatMulShapes), + ::testing::Bool(), ::testing::Bool(), + ::testing::Bool()), + MKLMatMulTest::Name); +#endif // INTEL_MKL + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index f5e61aef534da57ce13d3ee9bbeaeaec31f53d2e..d97802ee45d6add3c466577d7624d9ca74e2f380 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -34,8 +34,6 @@ limitations under the License. #include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { namespace { @@ -90,8 +88,8 @@ CpuTransferManager::CpuTransferManager() : GenericTransferManager(se::host::kHostPlatformId, /*pointer_size=*/sizeof(void*)) {} -Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) { +Status CpuTransferManager::TransferLiteralToInfeed( + se::StreamExecutor* executor, const LiteralSlice& literal) { const Shape& shape = literal.shape(); VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); @@ -241,21 +239,20 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( } StatusOr CpuTransferManager::TransferTupleBuffersFromOutfeed( - perftools::gputools::StreamExecutor* executor, + se::StreamExecutor* executor, tensorflow::gtl::ArraySlice> buffer_data) { return TransferBuffersFromOutfeedInternal(executor, buffer_data, /*is_tuple=*/true); } StatusOr CpuTransferManager::TransferArrayBufferFromOutfeed( - perftools::gputools::StreamExecutor* executor, void* destination, - int64 size_bytes) { + se::StreamExecutor* executor, void* destination, int64 size_bytes) { return TransferBuffersFromOutfeedInternal( executor, {{destination, size_bytes}}, /*is_tuple=*/false); } StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( - perftools::gputools::StreamExecutor* executor, + se::StreamExecutor* executor, tensorflow::gtl::ArraySlice> buffer_data, bool is_tuple) { std::vector> buffers; @@ -306,8 +303,8 @@ static std::unique_ptr CreateCpuTransferManager() { } static bool InitModule() { - xla::TransferManager::RegisterTransferManager(se::host::kHostPlatformId, - &CreateCpuTransferManager); + xla::TransferManager::RegisterTransferManager( + stream_executor::host::kHostPlatformId, &CreateCpuTransferManager); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 6c7524d94716464218ba18ad9950f702d2759f89..6dfc666f09dfa6df740cd54bea0957e3144181bc 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -37,36 +37,35 @@ class CpuTransferManager : public GenericTransferManager { CpuTransferManager(); ~CpuTransferManager() override {} - Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, - const Literal& literal) override; - Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, - int64 size, const void* source) override; - Status TransferLiteralFromOutfeed( - perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + Status TransferLiteralToInfeed(se::StreamExecutor* executor, + const LiteralSlice& literal) override; + Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, + const void* source) override; + Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, + const Shape& literal_shape, + Literal* literal) override; private: // Transfers infeed data to device. InfeedBuffer->Done() must be // called to clean up the memory allocated for InfeedBuffer. StatusOr TransferBufferToInfeedInternal( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source); + se::StreamExecutor* executor, int64 size, const void* source); // Helper that transfers a tuple of element buffers from the device's outfeed. StatusOr TransferTupleBuffersFromOutfeed( - perftools::gputools::StreamExecutor* executor, + se::StreamExecutor* executor, tensorflow::gtl::ArraySlice> buffer_data); // Helper that transfers an array buffer from the device's outfeed. - StatusOr TransferArrayBufferFromOutfeed( - perftools::gputools::StreamExecutor* executor, void* destination, - int64 size_bytes); + StatusOr TransferArrayBufferFromOutfeed(se::StreamExecutor* executor, + void* destination, + int64 size_bytes); // On success, returns the shape that was transferred from the outfeed -- if // is_tuple is true, the returned shape will be a tuple of the returned shapes // for the given buffers. StatusOr TransferBuffersFromOutfeedInternal( - perftools::gputools::StreamExecutor* executor, + se::StreamExecutor* executor, tensorflow::gtl::ArraySlice> buffer_data, bool is_tuple); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index cfe7c9c3af0be109ac8a86753e880e2bcbceba41..8eb39d615fd482cdcea716ba7b105c643a2d8b87 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -41,17 +42,17 @@ using llvm_ir::SetToFirstInsertPoint; namespace cpu { namespace { -// Loads a tile of values from a 2D tensor. -class TileLoader { +// Provides tiled access to an in-memory rank 2 array. +class MemoryTile { public: - // Constructs a TileLoader that will load a tile consisting of + // Constructs a MemoryTile that can operate on tiles consisting of // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at // `major_dim_offset` in the major dimension. The tile size along the minor // dimension is the vector size, and that is implicitly determined by `vsl`. - TileLoader(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, + MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, llvm::Value* matrix, int64 matrix_size_along_minor_dim, llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) - : vsl_(vsl) { + : vsl_(vsl), ir_builder_(ir_builder) { pointers_.reserve(tile_size_along_major_dim); for (int64 i = 0; i < tile_size_along_major_dim; i++) { llvm::Value* total_offset = ir_builder->CreateMul( @@ -61,9 +62,10 @@ class TileLoader { } } - // Load a tile consisting of `tile_size_along_major_dim_` vectors starting at - // `major_dim_offset_` in the major dimension and `minor_dim_offset` in the - // minor dimension. + // Load a tile consisting of `tile_size_along_major_dim` vectors from position + // {major: `major_dim_offset`, minor: `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. std::vector LoadTile(llvm::Value* minor_dim_offset) const { std::vector result; result.reserve(pointers_.size()); @@ -73,11 +75,104 @@ class TileLoader { return result; } + // Stores `tile` to position {major: `major_dim_offset`, minor: + // `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. + void StoreTile(tensorflow::gtl::ArraySlice tile, + llvm::Value* minor_dim_offset) const { + CHECK_EQ(tile.size(), pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset); + } + } + + // Loads a tile of size [`tile_size_along_major_dim`, + // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`, + // minor: `minor_dim_offset`} and then broadcasts each element into a vector + // of size vsl_.vector_size(). The (i,j)'th element of the return value is + // the (i,j)'th element in the tile broadcasted into an LLVM vector. + // + // Note: `major_dim_offset` is a parameter to the constructor. + std::vector> LoadBroadcastTile( + llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const { + std::vector> result; + result.resize(pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + for (int64 j = 0; j < tile_size_along_middle_dim; j++) { + result[i].push_back(vsl_->LoadBroadcast( + pointers_[i], ir_builder_->CreateAdd(minor_dim_offset, + ir_builder_->getInt64(j)))); + } + } + return result; + } + private: VectorSupportLibrary* vsl_; + llvm::IRBuilder<>* ir_builder_; std::vector pointers_; }; +// The base class for the classes representing the GEMV emitter configurations. +// +// The IR emitted (modulo the LLVM values representing the input and output +// buffers) by the row major and column major GEMV emitters should be a function +// of their configuration. This is important because their configuration is +// used as a key to cache the generated IR. +class GemvConfig { + public: + // Mixin for convenience. + template + struct User { + public: + PrimitiveType scalar_type() const { + return derived().config().scalar_type(); + } + int64 tile_rows() const { return derived().config().tile_rows(); } + int64 tile_cols() const { return derived().config().tile_cols(); } + int64 m() const { return derived().config().m(); } + int64 k() const { return derived().config().k(); } + int64 has_addend() const { return derived().config().has_addend(); } + + private: + const T& derived() const { return *static_cast(this); } + }; + + PrimitiveType scalar_type() const { return scalar_type_; } + int64 tile_rows() const { return tile_rows_; } + int64 tile_cols() const { return tile_cols_; } + int64 m() const { return m_; } + int64 k() const { return k_; } + bool has_addend() const { return has_addend_; } + + string GetCacheKey() const { + return tensorflow::strings::StrCat( + name_, "_", PrimitiveType_Name(scalar_type()), "_", tile_rows(), "_", + tile_cols(), "_", m(), "_", k(), has_addend() ? "_with_addend" : ""); + } + + protected: + explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, bool has_addend) + : name_(std::move(name)), + scalar_type_(scalar_type), + tile_rows_(tile_rows), + tile_cols_(tile_cols), + m_(m), + k_(k), + has_addend_(has_addend) {} + + private: + string name_; + PrimitiveType scalar_type_; + int64 tile_rows_; + int64 tile_cols_; + int64 m_; + int64 k_; + bool has_addend_; +}; + // Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the // layout of the vector does not matter). This implementation uses a tiling // scheme to improve performance. @@ -139,38 +234,46 @@ class TileLoader { // TODO(sanjoy): We should investigate if using gather loads and scatter stores // can be used here have the same inner loop for both column-major and row-major // matrix-vector products. -class ColumnMajorMatrixVectorProductEmitter { +class ColumnMajorMatrixVectorProductEmitter + : public GemvConfig::User { public: - ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, - int64 tile_rows, int64 tile_cols, - int64 m, int64 k, llvm::Value* lhs, + class Config : public GemvConfig { + public: + explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, + int64 m, int64 k, bool has_addend) + : GemvConfig(/*name=*/"col_major_gemv", scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, + /*k=*/k, /*has_addend=*/has_addend) {} + }; + + ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result, llvm::IRBuilder<>* ir_builder) - : scalar_type_(scalar_type), - tile_rows_(tile_rows), - tile_cols_(tile_cols), - m_(m), - k_(k), + : config_(config), lhs_(lhs), rhs_(rhs), addend_(addend), result_(result), ir_builder_(ir_builder), ksl_(ir_builder_), - vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") { - CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast(tile_rows_))); + vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), + ir_builder_, "") { + CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast(tile_rows()))); + CHECK(!has_addend() || addend != nullptr); } void Emit(); + const Config& config() const { return config_; } + private: void EmitOuterLoopBody(llvm::Value* column, int64 column_count, bool is_first_column); - TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) { - return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/m_, + MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) { + return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/m(), /*major_dim_offset=*/column_start, /*tile_size_along_major_dim=*/column_count); } @@ -187,18 +290,14 @@ class ColumnMajorMatrixVectorProductEmitter { return result; } - void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, + void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, int64 columns, bool is_first_column); void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column); - PrimitiveType scalar_type_; - int64 tile_rows_; - int64 tile_cols_; - int64 m_; - int64 k_; + Config config_; llvm::Value* lhs_; llvm::Value* rhs_; llvm::Value* addend_; @@ -210,26 +309,26 @@ class ColumnMajorMatrixVectorProductEmitter { void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( llvm::Value* column, int64 column_count, bool is_first_column) { - TileLoader lhs_tile_loader = GetLhsTileLoader(/*column_start=*/column, + MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column, /*column_count=*/column_count); std::vector rhs_tile = LoadRhsTile(column, /*count=*/column_count); - EmitInnerLoopTiled(&lhs_tile_loader, rhs_tile, + EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile, /*columns=*/column_count, is_first_column); EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column); } void ColumnMajorMatrixVectorProductEmitter::Emit() { // See the comment on the class declaration for the algorithm used here. - int64 column_remainder = k_ % tile_cols_; - int64 column_limit = k_ - column_remainder; + int64 column_remainder = k() % tile_cols(); + int64 column_limit = k() - column_remainder; - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_, - [&](llvm::Value* column, bool is_first_column) { - EmitOuterLoopBody(column, tile_cols_, is_first_column); - }); + ksl_.ForReturnVoid("dot.outer.tiled", + /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), + [&](llvm::Value* column, bool is_first_column) { + EmitOuterLoopBody(column, tile_cols(), is_first_column); + }); if (column_remainder != 0) { EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder, @@ -238,29 +337,30 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() { } void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - TileLoader* lhs_tile_loader, const std::vector& rhs_tile, + MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, int64 columns, bool is_first_column) { - int64 row_limit = m_ - (m_ % tile_rows_); - - ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit, - /*step=*/tile_rows_, [&](llvm::Value* row) { - std::vector lhs_tile = - lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row); - llvm::Value* accumulator = - is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) - : vsl_.GetZeroVector()) - : vsl_.LoadVector(result_, row); - for (int i = 0; i < columns; i++) { - accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); - } - vsl_.StoreVector(accumulator, result_, row); - }); + int64 row_limit = m() - (m() % tile_rows()); + + ksl_.ForReturnVoid( + "dot.inner.tiled", /*start=*/0, /*end=*/row_limit, + /*step=*/tile_rows(), [&](llvm::Value* row) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); + llvm::Value* accumulator = + is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) + : vsl_.GetZeroVector()) + : vsl_.LoadVector(result_, row); + for (int i = 0; i < columns; i++) { + accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); + } + vsl_.StoreVector(accumulator, result_, row); + }); } void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { - int64 row_start = m_ - (m_ % tile_rows_); - if (row_start == m_) { + int64 row_start = m() - (m() % tile_rows()); + if (row_start == m()) { return; } @@ -273,25 +373,25 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( // // initialized. // } - ksl_.For( + ksl_.ForReturnVoid( "dot.inner.epilg.outer", /*start=*/current_tile_col, /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col), /*step=*/1, /*peel_first_iteration=*/false, [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); llvm::Value* total_offset = - ir_builder_->CreateMul(col, ir_builder_->getInt64(m_)); + ir_builder_->CreateMul(col, ir_builder_->getInt64(m())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For( - "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_, + ksl_.ForReturnVoid( + "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), /*step=*/1, [&](llvm::Value* scalar_row) { llvm::Value* product = vsl_.Mul( vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); llvm::Value* setting_result_first_time = ir_builder_->CreateAnd( is_first_scalar_col, ir_builder_->getInt1(is_first_tiled_column)); - ksl_.If( + ksl_.IfReturnVoid( setting_result_first_time, /*true_block_generator=*/ [&]() { @@ -364,51 +464,55 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( // // We have an inner epilogue loop to deal with the "B" sub-matrix and an outer // epilogue loop to deal with the C,D submatrix. -class RowMajorMatrixVectorProductEmitter { +class RowMajorMatrixVectorProductEmitter + : public GemvConfig::User { public: - RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows, - int64 tile_cols, int64 m, int64 k, - llvm::Value* lhs, llvm::Value* rhs, - llvm::Value* addend, llvm::Value* result, + class Config : public GemvConfig { + public: + explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, + int64 m, int64 k, bool has_addend) + : GemvConfig(/*name=*/"row_major_gemv", scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, + /*k=*/k, /*has_addend=*/has_addend) {} + }; + + RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* ir_builder) - : scalar_type_(scalar_type), - tile_rows_(tile_rows), - tile_cols_(tile_cols), - m_(m), - k_(k), + : config_(config), lhs_(lhs), rhs_(rhs), addend_(addend), result_(result), ir_builder_(ir_builder), ksl_(ir_builder_), - vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") { - CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast(tile_cols_))); + vsl_(scalar_type(), /*vector_size=*/tile_cols(), ir_builder_, "") { + CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast(tile_cols()))); + CHECK(!has_addend() || addend != nullptr); } void Emit(); + const Config& config() const { return config_; } + private: - TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) { - return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/k_, + MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) { + return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/k(), /*major_dim_offset=*/row_start, /*tile_size_along_major_dim=*/row_count); } void EmitOuterLoopBody(llvm::Value* row, int64 row_count); - void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, int64 rows, + void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows, std::vector* vector_accumulators); void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, std::vector* scalar_accumulators); - PrimitiveType scalar_type_; - int64 tile_rows_; - int64 tile_cols_; - int64 m_; - int64 k_; + Config config_; llvm::Value* lhs_; llvm::Value* rhs_; llvm::Value* addend_; @@ -420,7 +524,7 @@ class RowMajorMatrixVectorProductEmitter { void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, int64 row_count) { - TileLoader lhs_tile_loader = GetLhsTileLoader(/*row_start=*/row, + MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row, /*row_count=*/row_count); std::vector vector_accumulators; std::vector scalar_accumulators; @@ -428,7 +532,7 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); } - EmitInnerLoopTiled(&lhs_tile_loader, /*rows=*/row_count, + EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count, &vector_accumulators); EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, &scalar_accumulators); @@ -465,12 +569,13 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, void RowMajorMatrixVectorProductEmitter::Emit() { // See the comment on the class declaration for the algorithm used here. - int64 row_remainder = m_ % tile_rows_; - int64 row_limit = m_ - row_remainder; + int64 row_remainder = m() % tile_rows(); + int64 row_limit = m() - row_remainder; - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_, - [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); }); + ksl_.ForReturnVoid( + "dot.outer.tiled", + /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), + [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); if (row_remainder != 0) { EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder); @@ -478,60 +583,407 @@ void RowMajorMatrixVectorProductEmitter::Emit() { } void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - TileLoader* lhs_tile_loader, int64 rows, + MemoryTile* lhs_memory_tile, int64 rows, std::vector* vector_accumulators) { - int64 column_limit = k_ - (k_ % tile_cols_); - - ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, - /*step=*/tile_cols_, [&](llvm::Value* col) { - std::vector lhs_tile = - lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col); - llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); - for (int i = 0; i < rows; i++) { - llvm::Value* old_sum = (*vector_accumulators)[i].Get(); - (*vector_accumulators)[i].Set( - vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); - } - }); + int64 column_limit = k() - (k() % tile_cols()); + + ksl_.ForReturnVoid("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, + /*step=*/tile_cols(), [&](llvm::Value* col) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); + llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); + for (int i = 0; i < rows; i++) { + llvm::Value* old_sum = (*vector_accumulators)[i].Get(); + (*vector_accumulators)[i].Set(vsl_.Add( + old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); + } + }); } void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( llvm::Value* current_tile_row, int64 rows, std::vector* scalar_accumulators) { - int64 column_start = k_ - (k_ % tile_cols_); - if (column_start == k_) { + int64 column_start = k() - (k() % tile_cols()); + if (column_start == k()) { return; } for (int r = 0; r < rows; r++) { llvm::Value* total_offset = ir_builder_->CreateMul( ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row), - ir_builder_->getInt64(k_)); + ir_builder_->getInt64(k())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_, - /*step=*/1, [&](llvm::Value* scalar_col) { - llvm::Value* product = - vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), - vsl_.LoadScalar(rhs_, scalar_col)); - llvm::Value* old_value = (*scalar_accumulators)[r].Get(); - (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); - }); + ksl_.ForReturnVoid( + "dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), + /*step=*/1, [&](llvm::Value* scalar_col) { + llvm::Value* product = + vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), + vsl_.LoadScalar(rhs_, scalar_col)); + llvm::Value* old_value = (*scalar_accumulators)[r].Get(); + (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); + }); } } +// This class implements a tiled matrix multiplication algorithm, intended for +// use as the innermost GEBP loop in a GEMM kernel (GEBP is described in "Goto, +// Kazushige, and Robert Van De Geijn. "High-performance implementation of the +// level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 (2008): +// 4). +// +// This only supports canonical dot operations (i.e. where the lhs contraction +// dimension is 1 and the rhs contraction dimension is 0) over row major +// matrices. +class MatrixMatrixBlockPanelEmitter { + public: + // Describe the dimensions of the GEBP kernel. These will usually not be the + // dimensions of the GEMM itself, the GEMM will usually be broken up into GEBP + // kernels with smaller dimensions. + class Dimensions { + public: + explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} + + int64 m() const { return m_; } + int64 k() const { return k_; } + int64 n() const { return n_; } + + string ToString() const { + return tensorflow::strings::StrCat(m(), "x", k(), "x", n()); + } + + private: + const int64 m_; + const int64 k_; + const int64 n_; + }; + + // Represents the configuration of the GEBP emitter. The LLVM IR emitted by + // the emitter, modulo the LLVM values holding the input and output buffers, + // must be a function of the instance of `Config` passed to it. + // + // `dims` holds the matrix multiplication dimensions. + // + // `max_vectorization_width` is the maximum vector width (i.e. the width of + // the largest vector register we will use). This can be larger than the + // largest vector register supported by the machine -- LLVM will legalize + // these large vector widths into legally sized vectors. + // + // `max_vector_count` is the maximum number of vectors of size + // `max_vectorization_width` that we will attempt to process at once. + // + // `min_vectorization_width` is the smallest vector width the emitter will use + // -- below that it will devolve to using a scalar loop. + // + // The innermost reduction loop executes the matrix multiply in tiles of size + // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`, + // ] in the RHS. + class Config { + public: + explicit Config(PrimitiveType scalar_type, Dimensions dims, + int64 max_vectorization_width, int64 max_vector_count, + int64 min_vectorization_width, int64 tile_size_m, + int64 tile_size_k) + : scalar_type_(scalar_type), + dims_(dims), + max_vectorization_width_(max_vectorization_width), + max_vector_count_(max_vector_count), + min_vectorization_width_(min_vectorization_width), + tile_size_m_(tile_size_m), + tile_size_k_(tile_size_k) {} + + string GetCacheKey() const { + return tensorflow::strings::StrCat( + "gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(), + "_", max_vectorization_width(), "_", min_vectorization_width(), "_", + tile_size_m(), "_", tile_size_k()); + } + + PrimitiveType scalar_type() const { return scalar_type_; } + Dimensions dims() const { return dims_; } + int64 max_vectorization_width() const { return max_vectorization_width_; } + int64 max_vector_count() const { return max_vector_count_; } + int64 min_vectorization_width() const { return min_vectorization_width_; } + + int64 tile_size_m() const { return tile_size_m_; } + int64 tile_size_k() const { return tile_size_k_; } + + private: + PrimitiveType scalar_type_; + Dimensions dims_; + int64 max_vectorization_width_; + int64 max_vector_count_; + int64 min_vectorization_width_; + int64 tile_size_m_; + int64 tile_size_k_; + }; + + // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies + // `lhs` with `rhs` and stores the result in `result`. + explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* ir_builder) + : lhs_(lhs), + rhs_(rhs), + result_(result), + config_(config), + ir_builder_(ir_builder), + ksl_(ir_builder_) { + CHECK(max_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(max_vectorization_width()))); + CHECK_GT(max_vector_count(), 0); + CHECK(min_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(min_vectorization_width()))); + CHECK_GE(max_vectorization_width(), min_vectorization_width()); + CHECK_GT(tile_size_k(), 0); + } + + void Emit(); + + private: + // The HandleResiduesOnX helpers split the iteration space for dimension X + // into a multiple of the tile size on dimension X and an epilogue. These + // helpers ultimately call into `EmitTiledGemm` for emitting the + // tiled GEMM kernel. + + void HandleResiduesOnN(); + void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start, + llvm::Value* n_end); + void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end); + + // This emits a tiled GEMM kernel. For a detailed description see the comment + // on the implementation. + void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, + llvm::Value* m_end); + + llvm::Value* GetInt64(int64 value) { return ir_builder_->getInt64(value); } + + Config config() const { return config_; } + Dimensions dims() const { return config().dims(); } + + int64 max_vectorization_width() const { + return config().max_vectorization_width(); + } + int64 max_vector_count() const { return config().max_vector_count(); } + int64 min_vectorization_width() const { + return config().min_vectorization_width(); + } + int64 tile_size_m() const { return config().tile_size_m(); } + int64 tile_size_k() const { return config().tile_size_k(); } + PrimitiveType scalar_type() const { return config().scalar_type(); } + + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* result_; + Config config_; + + llvm::IRBuilder<>* ir_builder_; + KernelSupportLibrary ksl_; +}; + +void MatrixMatrixBlockPanelEmitter::Emit() { HandleResiduesOnN(); } + +void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { + // We can only iterate the `n` dimension for an extent that is divisible by + // the vectorization width. So we emit an outer loop that first processes the + // largest extent in `n` that is divisible by max_vectorization_width, then + // the largest remaining extent that is divisible by max_vectorization_width / + // 2 etc. + + int64 current_vectorization_width = + max_vector_count() * max_vectorization_width(); + int64 current_vector_count = max_vector_count(); + + int64 n_start = 0; + while (n_start != dims().n() && + current_vectorization_width >= min_vectorization_width()) { + int64 n_end = dims().n() - (dims().n() % current_vectorization_width); + if (n_start != n_end) { + VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, + ir_builder_, "gebp"); + HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); + n_start = n_end; + } + if (current_vector_count == 1) { + current_vectorization_width /= 2; + } else { + current_vector_count--; + current_vectorization_width = + current_vector_count * max_vectorization_width(); + } + } + + if (n_start != dims().n()) { + VectorSupportLibrary vsl(scalar_type(), 1, ir_builder_, "gebp"); + ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { + llvm::Value* n_i_next = + ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1)); + HandleResiduesOnK(&vsl, n_i, n_i_next); + }); + } +} + +void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, + llvm::Value* n_start, + llvm::Value* n_end) { + int64 k_start = 0; + int64 k_end = dims().k() - (dims().k() % tile_size_k()); + if (k_end != k_start) { + HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), + n_start, n_end); + k_start = k_end; + } + + if (k_start != dims().k()) { + HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start), + GetInt64(dims().k()), n_start, n_end); + } +} + +void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { + const int64 m_end = dims().m() - dims().m() % tile_size_m(); + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(), + GetInt64(0), GetInt64(m_end)); + + if (m_end != dims().m()) { + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, + dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m())); + } +} + +// The loop structure is: +// +// Iterate over dimension M as m: +// Iterate over dimension N as n: +// Iterate over dimension K as k: +// OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n]) +// +// I.e. a just a tiled version of a "naive" GEMM. +// +// The tiling scheme is as follows: +// +// Let the LHS be: +// +// +----+----+----+ +// | a0 | b0 | c0 | . +// +----+----+----+ . +// | a1 | b1 | c1 | . +// +----+----+----+ +// .. .. +// +// and the RHS be: +// +// +----+----+----+----+ +// | p0 | p1 | p2 | p3 | . +// +----+----+----+----+ . +// | q0 | q1 | q2 | q3 | . +// +----+----+----+----+ +// | r0 | r1 | r2 | r3 | . +// +----+----+----+----+ . +// ...... ...... +// +// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted +// by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4] +// matrix that we can increment the result matrix by. +// +// First broadcast the rows row in LHS to 3 vectors of width 4, giving us a rank +// 3 array, L, of dimension [2,3,4]: +// +// L[0,_,_] * L[1,_,_] +// * +// +----+----+----+----+ * +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 | +// +----+----+----+----+ * +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 | +// +----+----+----+----+ * +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 | +// +----+----+----+----+ * +----+----+----+----+ +// +// +// Then we FMA L[0,_,_] with the RHS to get the first row of the result and +// L[1,_,_] with the RHS to get the second row of the result. For example, +// L[0,_,_] is computed as: +// +// +----+----+----+----+ +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | + +// +----+----+----+----+ +----+----+----+----+ +// +// +----+----+----+----+ +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | + +// +----+----+----+----+ +----+----+----+----+ +// +// +----+----+----+----+ +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 | +// +----+----+----+----+ +----+----+----+----+ +// +// to get: +// +// +-------------------+-------------------+-------------------+--------- +// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... +// +-------------------+-------------------+-------------------+--------- +void MatrixMatrixBlockPanelEmitter::EmitTiledGemm( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { + ksl_.ForReturnVoid( + "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { + MemoryTile result_memory_tile( + vsl, ir_builder_, /*matrix=*/result_, + /*matrix_size_along_minor_dim=*/dims().n(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/dims().k(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + ksl_.ForReturnVoid( + "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { + TileVariable result_tile_var(vsl, + result_memory_tile.LoadTile(n_i)); + ksl_.ForReturnVoid( + "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { + MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, + dims().n(), k_i, tile_size_k); + std::vector> lhs_tile = + lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); + std::vector rhs_tile = + rhs_memory_tile.LoadTile(n_i); + std::vector result_tile = + result_tile_var.Get(); + for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) { + for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { + result_tile[r_m_i] = + vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], + result_tile[r_m_i]); + } + } + result_tile_var.Set(result_tile); + }); + + result_memory_tile.StoreTile(result_tile_var.Get(), n_i); + }); + }); +} + } // namespace -DotOpEmitter::DotOpEmitter( - const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, - const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features) +DotOpEmitter::DotOpEmitter(const HloInstruction& dot, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) : dot_(dot), - transpose_lhs_(transpose_lhs), - transpose_rhs_(transpose_rhs), target_array_(target_array), lhs_array_(lhs_array), rhs_array_(rhs_array), @@ -541,23 +993,104 @@ DotOpEmitter::DotOpEmitter( hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} -/* static */ tensorflow::Status DotOpEmitter::EmitDotOperation( - const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, - const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, +/* static */ Status DotOpEmitter::EmitDotOperation( + const HloInstruction& dot, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type); - DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array, - lhs_array, rhs_array, addend_array, - executable_run_options_value, ir_builder, - hlo_module_config, target_machine_features); + DotOpEmitter dot_emitter(dot, target_array, lhs_array, rhs_array, + addend_array, executable_run_options_value, + ir_builder, hlo_module_config, + target_machine_features); return dot_emitter.Emit(); } -bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; } +bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( + const DotOpEmitter::MatMultDims& mat_mult_dims) { + if (!EnableExperimentalLlvmIrGemm() || ShouldUseMultiThreadedEigen()) { + return false; + } + + if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) { + return false; + } + + PrimitiveType primitive_type = dot_.shape().element_type(); + + switch (primitive_type) { + default: + return false; + + case F32: + case F64: + case S32: + case S64: + break; + } + + if (!(mat_mult_dims.lhs_column_major == mat_mult_dims.rhs_column_major && + mat_mult_dims.rhs_column_major == mat_mult_dims.target_column_major)) { + return false; + } + + llvm::Value* lhs = lhs_array_.GetBasePointer(); + llvm::Value* rhs = rhs_array_.GetBasePointer(); + llvm::Value* target = target_array_.GetBasePointer(); + int64 m = mat_mult_dims.m; + int64 k = mat_mult_dims.k; + int64 n = mat_mult_dims.n; + + if (mat_mult_dims.lhs_column_major) { + std::swap(lhs, rhs); + std::swap(m, n); + } + + int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); + ir_builder_->CreateMemSet( + target, ir_builder_->getInt8(0), size_bytes, + target_machine_features_.minimum_alignment_for_allocation(size_bytes)); + + int64 max_target_vector_width = + target_machine_features_.vector_register_num_elements( + *ir_builder_->GetInsertBlock()->getParent(), primitive_type); + + int64 tile_size_m, tile_size_k, tile_size_n_in_vector_width; + std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) = + GetGemmTileSize(); + + MatrixMatrixBlockPanelEmitter::Config config( + /*scalar_type=*/primitive_type, + MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, + /*max_vectorization_width=*/max_target_vector_width, + /*max_vector_count=*/tile_size_n_in_vector_width, + /*min_vectorization_width=*/std::min(4, max_target_vector_width), + /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); + + VLOG(2) << "Emitting GEBP kernel in LLVM IR with config " + << config.GetCacheKey(); + + const bool enable_fast_math = + hlo_module_config_.debug_options().xla_enable_fast_math(); + const bool optimize_for_size = + options::OptimizeForSizeRequested(hlo_module_config_); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, ir_builder_, + config.GetCacheKey(), lhs, rhs, target, + [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) { + MatrixMatrixBlockPanelEmitter gebp_emitter( + config, /*lhs=*/lhs, /*rhs=*/rhs, + /*result=*/target, ir_builder_); + gebp_emitter.Emit(); + }); + + return true; +} bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (dot_.shape().dimensions_size() != 2) { @@ -580,7 +1113,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (mat_mult_dims.m == 1) { bool rhs_effectively_row_major = - transpose_rhs_ ^ !mat_mult_dims.rhs_column_major; + mat_mult_dims.rhs_non_canonical ^ !mat_mult_dims.rhs_column_major; if (rhs_effectively_row_major) { k = mat_mult_dims.k; m = mat_mult_dims.n; @@ -596,7 +1129,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (mat_mult_dims.n == 1) { bool lhs_effectively_column_major = - transpose_lhs_ ^ mat_mult_dims.lhs_column_major; + mat_mult_dims.lhs_non_canonical ^ mat_mult_dims.lhs_column_major; if (lhs_effectively_column_major) { m = mat_mult_dims.m; k = mat_mult_dims.k; @@ -611,7 +1144,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { } if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) { - return false; + return EmitExperimentalGebpDotIfEnabled(mat_mult_dims); } int64 tiling_factor = GetGemvTilingFactor(); @@ -644,47 +1177,39 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (is_column_major_matrix_vector) { VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m << " and k = " << k; - int64 tile_rows = vector_register_element_size; - int64 tile_cols = tiling_factor; - - string kernel_name = tensorflow::strings::StrCat( - "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, - "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : ""); + ColumnMajorMatrixVectorProductEmitter::Config config( + /*scalar_type=*/primitive_type, + /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor, + /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, - lhs_op, rhs_op, + /*optimize_for_size=*/optimize_for_size, ir_builder_, + config.GetCacheKey(), lhs_op, rhs_op, addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, - [this, tile_rows, tile_cols, m, k, primitive_type]( - llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, - llvm::Value* result_op) { + [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, + llvm::Value* addend_op, llvm::Value* result_op) { ColumnMajorMatrixVectorProductEmitter emitter( - primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, - addend_op, result_op, ir_builder_); + config, lhs_op, rhs_op, addend_op, result_op, ir_builder_); emitter.Emit(); }); } else { VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m << " and k = " << k; - int64 tile_rows = tiling_factor; - int64 tile_cols = vector_register_element_size; - - string kernel_name = tensorflow::strings::StrCat( - "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, - "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : ""); + RowMajorMatrixVectorProductEmitter::Config config( + /*scalar_type=*/primitive_type, + /*tile_rows=*/tiling_factor, /*tile_cols=*/vector_register_element_size, + /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, - lhs_op, rhs_op, + /*optimize_for_size=*/optimize_for_size, ir_builder_, + config.GetCacheKey(), lhs_op, rhs_op, addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, - [this, tile_rows, tile_cols, m, k, primitive_type]( - llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, - llvm::Value* result_op) { + [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, + llvm::Value* addend_op, llvm::Value* result_op) { RowMajorMatrixVectorProductEmitter emitter( - primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, - addend_op, result_op, ir_builder_); + config, lhs_op, rhs_op, addend_op, result_op, ir_builder_); emitter.Emit(); }); } @@ -692,7 +1217,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { return true; } -tensorflow::Status DotOpEmitter::Emit() { +Status DotOpEmitter::Emit() { // The dot operation performs a sum of products over dimension 0 of the left // hand side operand and dimension 1 of the right hand side operand. // @@ -715,6 +1240,11 @@ tensorflow::Status DotOpEmitter::Emit() { // which performs the sum-of-products (the reduction loop) before storing // the result in the output buffer. + // This routine assumes that the dot operation is not in a parallelized + // enclosing computation. + CHECK( + dot_.parent()->root_instruction()->outer_dimension_partitions().empty()); + const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); @@ -731,23 +1261,17 @@ tensorflow::Status DotOpEmitter::Emit() { CHECK_EQ(addend_array_, nullptr); - if (PotentiallyImplementedAsEigenDot(dot_)) { + if (PotentiallyImplementedAsEigenDot(dot_, target_machine_features_)) { return EmitCallToRuntime(); } // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special // case where the reduction dimension is 0 for both LHS and RHS. This results // in a vector dot product producing a scalar. - int64 lhs_reduction_dimension = 0; - if (ShapeUtil::Rank(lhs_shape) >= 2) { - lhs_reduction_dimension = - ShapeUtil::GetDimensionNumber(lhs_shape, transpose_lhs_ ? -2 : -1); - } - int64 rhs_reduction_dimension = 0; - if (ShapeUtil::Rank(rhs_shape) >= 2) { - rhs_reduction_dimension = - ShapeUtil::GetDimensionNumber(rhs_shape, transpose_rhs_ ? -1 : -2); - } + int64 lhs_reduction_dimension = + dot_.dot_dimension_numbers().lhs_contracting_dimensions(0); + int64 rhs_reduction_dimension = + dot_.dot_dimension_numbers().rhs_contracting_dimensions(0); // Verify the reduction dimension in the two operands are the same size. TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == @@ -777,8 +1301,11 @@ tensorflow::Status DotOpEmitter::Emit() { // from messing up the vectorization. std::unique_ptr reduction_loop = loop_nest.AddLoop( 0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction", - /*prevent_unrolling=*/lhs_reduction_along_minor_dimension && - rhs_reduction_along_minor_dimension); + /*unroll_mode=*/ + (lhs_reduction_along_minor_dimension && + rhs_reduction_along_minor_dimension) + ? xla::llvm_ir::UnrollMode::kNoUnroll + : xla::llvm_ir::UnrollMode::kDefaultUnroll); // The final entry in the rhs and lhs indexes is the indvar of the // reduction loop. @@ -871,10 +1398,10 @@ tensorflow::Status DotOpEmitter::Emit() { // loop. ir_builder_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status DotOpEmitter::EmitScalarDot() { +Status DotOpEmitter::EmitScalarDot() { // A scalar dot is just a scalar multiply. llvm::Value* result; llvm::Value* lhs_value = @@ -899,12 +1426,10 @@ tensorflow::Status DotOpEmitter::EmitScalarDot() { result = ir_builder_->CreateFMul(lhs_value, rhs_value); } target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status DotOpEmitter::EmitCallToRuntime() { - DCHECK(ShapesAreLegalForRuntimeDot()); - +Status DotOpEmitter::EmitCallToRuntime() { // The signature of the Eigen runtime matmul function is: // // (void)(void* run_options, float* out, float* lhs, float* rhs, @@ -913,22 +1438,34 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { // The two transpose_... parameters are actually booleans, but we use int32 // to avoid target-dependent calling convention details. - bool multi_threaded_eigen = - hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + bool multi_threaded = ShouldUseMultiThreadedEigen(); + bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn(); PrimitiveType type = target_array_.GetShape().element_type(); llvm::Type* float_type; const char* fn_name; switch (type) { + case F16: + fn_name = multi_threaded + ? runtime::kEigenMatMulF16SymbolName + : runtime::kEigenSingleThreadedMatMulF16SymbolName; + float_type = ir_builder_->getHalfTy(); + break; case F32: - fn_name = multi_threaded_eigen - ? runtime::kEigenMatMulF32SymbolName - : runtime::kEigenSingleThreadedMatMulF32SymbolName; + fn_name = multi_threaded + ? (use_mkl_dnn ? runtime::kMKLMatMulF32SymbolName + : runtime::kEigenMatMulF32SymbolName) + : (use_mkl_dnn + ? runtime::kMKLSingleThreadedMatMulF32SymbolName + : runtime::kEigenSingleThreadedMatMulF32SymbolName); float_type = ir_builder_->getFloatTy(); break; case F64: - fn_name = multi_threaded_eigen - ? runtime::kEigenMatMulF64SymbolName - : runtime::kEigenSingleThreadedMatMulF64SymbolName; + fn_name = multi_threaded + ? (use_mkl_dnn ? runtime::kMKLMatMulF64SymbolName + : runtime::kEigenMatMulF64SymbolName) + : (use_mkl_dnn + ? runtime::kMKLSingleThreadedMatMulF64SymbolName + : runtime::kEigenSingleThreadedMatMulF64SymbolName); float_type = ir_builder_->getDoubleTy(); break; default: @@ -972,8 +1509,8 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { const llvm_ir::IrArray* lhs = &lhs_array_; const llvm_ir::IrArray* rhs = &rhs_array_; - bool transpose_lhs = transpose_lhs_; - bool transpose_rhs = transpose_rhs_; + bool transpose_lhs = mat_mult_dims.lhs_non_canonical; + bool transpose_rhs = mat_mult_dims.rhs_non_canonical; if (!mat_mult_dims.lhs_column_major) { std::swap(mat_mult_dims.m, mat_mult_dims.n); @@ -993,7 +1530,7 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { ir_builder_->getInt64(mat_mult_dims.k), ir_builder_->getInt32(transpose_lhs), ir_builder_->getInt32(transpose_rhs)}); - return tensorflow::Status::OK(); + return Status::OK(); } DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { @@ -1001,12 +1538,18 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); - - return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0), - lhs_shape.dimensions(transpose_lhs_ ? 0 : 1), - rhs_shape.dimensions(transpose_rhs_ ? 0 : 1), - LayoutUtil::Minor(lhs_shape.layout(), 0) == 0, - LayoutUtil::Minor(rhs_shape.layout(), 0) == 0}; + const DotDimensionNumbers& dim_nums = dot_.dot_dimension_numbers(); + + return { + /*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)), + /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)), + /*n=*/rhs_shape.dimensions(1 - dim_nums.rhs_contracting_dimensions(0)), + /*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0, + /*lhs_non_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 0, + /*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0, + /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1, + /*target_column_major=*/ + LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0}; } llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( @@ -1046,17 +1589,39 @@ static bool IsRank2WithNoPadding(const Shape& shape) { // In a gemm operation where output = lhs * rhs, check whether the given shapes // are valid for the operation. -static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape) { +static bool AreValidGemmShapes( + const Shape& lhs_shape, const Shape& rhs_shape, const Shape& output_shape, + const TargetMachineFeatures& target_machine_features) { // The inputs and the output must // 1) be matrices with no padding, and // 2) have an allowed element type. - return output_shape.element_type() == F32 && - IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && - IsRank2WithNoPadding(output_shape); + PrimitiveType output_primitive_type = output_shape.element_type(); + if (!(output_primitive_type == F64 || output_primitive_type == F32 || + output_primitive_type == F16)) { + return false; + } + + if (!(IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && + IsRank2WithNoPadding(output_shape))) { + return false; + } + + auto is_aligned = [&](const Shape& shape) { + return GetMinimumAlignmentForArray(shape, target_machine_features) >= + TargetMachineFeatures::kEigenExpectedTensorAlignment; + }; + + if (!is_aligned(lhs_shape) || !is_aligned(rhs_shape) || + !is_aligned(output_shape)) { + return false; + } + + return true; } -bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { +bool PotentiallyImplementedAsEigenDot( + const HloInstruction& hlo, + const TargetMachineFeatures& target_machine_features) { // For certain types of Dot, we can call Eigen if (hlo.opcode() == HloOpcode::kDot) { const Shape& lhs_shape = hlo.operand(0)->shape(); @@ -1073,28 +1638,18 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { // If gemm can accept the operand shapes, use it rather than a custom // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { + if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape(), + target_machine_features)) { + const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers(); // The size of the reduction dimension should match. The shape inference // guarantees this invariant, so the check here is for programming // errors. - CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); + CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), + rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); return true; } } - if (hlo.opcode() == HloOpcode::kFusion && - hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && - hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { - auto* dot = hlo.fused_expression_root(); - const Shape& lhs_shape = dot->operand(0)->shape(); - const Shape& rhs_shape = dot->operand(1)->shape(); - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { - return false; - } - return true; - } - return false; } diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 9d748eb81f7850f3ccdb10f076eecfdc8326c05f..ed2a18976a0f1a88e7bb4632d3a63167d5c146ad 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -31,7 +31,9 @@ limitations under the License. namespace xla { namespace cpu { -bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo); +bool PotentiallyImplementedAsEigenDot( + const HloInstruction& hlo, + const TargetMachineFeatures& target_machine_features); // Returns the index for an operand to `hlo` that should ideally be column // major. Returns nullopt if there is no such operand or if `hlo` is not a dot @@ -55,17 +57,16 @@ class DotOpEmitter { // dimensions as the result, and the result is computed as `addend_array` + // dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported // for Matrix-vector products. - static tensorflow::Status EmitDotOperation( - const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, - const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, + static Status EmitDotOperation( + const HloInstruction& dot, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); private: - DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, - bool transpose_rhs, const llvm_ir::IrArray& target_array, + DotOpEmitter(const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, @@ -75,18 +76,18 @@ class DotOpEmitter { const TargetMachineFeatures& target_machine_features); // Emits the IR to perform the dot operation. - tensorflow::Status Emit(); + Status Emit(); // Emits instructions to perform a scalar dot product (a multiply of the // LHS and RHS) and store the results in the target. - tensorflow::Status EmitScalarDot(); + Status EmitScalarDot(); // Emit an LLVM IR implementation of the dot operation if we can. Returns // true if an LLVM IR implementation was emitted. bool EmitLlvmIrDotIfProfitable(); // Emits a call to the CPU runtime to perform the matrix multiply. - tensorflow::Status EmitCallToRuntime(); + Status EmitCallToRuntime(); // Emits a series of nested loops for iterating over an operand array in the // dot operation. Loops are constructed in major to minor dimension layout @@ -99,10 +100,6 @@ class DotOpEmitter { llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, int64 reduction_dimension, tensorflow::StringPiece name_suffix); - // Our runtime operation requires that all arrays have the same layout, - // no padding, and a rank of two. - bool ShapesAreLegalForRuntimeDot() const; - // Represents the dimensions of a matrix-matrix multiply operation. struct MatMultDims { // The number of rows in the LHS. @@ -115,11 +112,20 @@ class DotOpEmitter { // The number of columns on the RHS. int64 n; - // True if the LHS matrix column major. + // True if the LHS matrix is column major. bool lhs_column_major; - // True if the RHS matrix column major. + // True if the LHS contraction dimension is not 1. + bool lhs_non_canonical; + + // True if the RHS matrix is column major. bool rhs_column_major; + + // True if the RHS contraction dimension is not 0. + bool rhs_non_canonical; + + // True if the result matrix is column major. + bool target_column_major; }; // Get the MatMultDims instance for the dot product this DotOpEmitter @@ -127,6 +133,8 @@ class DotOpEmitter { // of rank 2 as well). MatMultDims GetMatMultDims() const; + bool EmitExperimentalGebpDotIfEnabled(const MatMultDims& mat_mult_dims); + // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector // registers. int64 GetGemvTilingFactor() const { @@ -135,9 +143,29 @@ class DotOpEmitter { .value_or(kDefaultTilingFactor); } + std::tuple GetGemmTileSize() const { + // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz + // + // TODO(b/80093688): Tune for other architectures and centralize this + // information in one place. + const std::tuple kDefaultTileSize = + std::tuple(11, 9, 1); + return options::LlvmIrGemmTileSize(hlo_module_config_) + .value_or(kDefaultTileSize); + } + + // Returns true if we should use an experimental implementation of GEMM + // (general matrix matrix multiplication) if possible. + bool EnableExperimentalLlvmIrGemm() const { + return options::EnableExperimentalLlvmIrGemm(hlo_module_config_); + } + + // Returns true if we should call into multi-threaded Eigen routines. + bool ShouldUseMultiThreadedEigen() { + return hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + } + const HloInstruction& dot_; - const bool transpose_lhs_; - const bool transpose_rhs_; const llvm_ir::IrArray& target_array_; const llvm_ir::IrArray& lhs_array_; const llvm_ir::IrArray& rhs_array_; diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index 99c5e16db70c6a203b4751c1ed8a106c0dc260e6..e97113dfa0f59e791d614c0093d0781e49c48ee4 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -115,7 +115,7 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( for (int i = 0; i < hlo->operand_count(); i++) { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, operand_to_generator.at(hlo->operand(i))( - ElementwiseSourceIndex(index, *hlo, 0))); + ElementwiseSourceIndex(index, *hlo, i))); operands.push_back(operand_value); } return ir_emitter_->EmitScalarCall(hlo->shape().element_type(), diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc index 7dcc4ca7fa08b478f24065275ffa69725dc51682..c56286559158758ca6db5ae097729286bde346f0 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc @@ -26,13 +26,13 @@ limitations under the License. namespace xla { namespace cpu { -void ExternalConstantPool::Insert(string name, const Literal& literal, +void ExternalConstantPool::Insert(string name, const LiteralSlice& literal, int64 alignment) { CHECK(!ShapeUtil::IsTuple(literal.shape())); CHECK(alignment > 0 && IsPowerOfTwo(static_cast(alignment))); CHECK(entries_.find(name) == entries_.end()); - int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); + const int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); void* raw_pointer = tensorflow::port::AlignedMalloc( literal_size, std::max(alignment, sizeof(void*))); CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h index 8008a56df4dbf16e7b57aee8a344058bb0d5883d..0677f5f0b58005079890052a426e5f48c5d09ed1 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h @@ -43,7 +43,7 @@ class ExternalConstantPool { // The constant pool copies out the contents of `literal` into a buffer it // owns -- it does not keep pointers to `literal`, or to memory owned by // `literal`. - void Insert(string name, const Literal& literal, int64 alignment); + void Insert(string name, const LiteralSlice& literal, int64 alignment); // Find the constant with name `name` in this constant pool. If there isn't // such constant, return nullptr. diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 788217aab6172b4e548452b3f6ffd4197c163ce4..b560b7531c0d24e6f670e61a15dce295d9fa2a49 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -24,8 +24,25 @@ limitations under the License. namespace xla { namespace cpu { +int64 GetMinimumAlignmentForArray( + const Shape& shape, const TargetMachineFeatures& target_machine_features) { + CHECK(ShapeUtil::IsArray(shape)); + CHECK(!LayoutUtil::HasLayout(shape) || LayoutUtil::IsDense(shape.layout())); + + // We don't require a layout to be set on `shape`. This only works on CPU + // because we don't pad our tensors or otherwise have complicated data tiling + // schemes. + + int64 allocation_size_bytes = + ShapeUtil::ElementsIn(shape) * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()); + return target_machine_features.minimum_alignment_for_allocation( + allocation_size_bytes); +} + bool PotentiallyImplementedAsEigenConvolution( - const HloInstruction& convolution) { + const HloInstruction& convolution, + const TargetMachineFeatures& target_machine_features) { // The following conditions are necessary (but not sufficient) for // implementing `convolution` with Eigen convolution: // - the input and kernel have a non-zero number of elements. @@ -34,14 +51,28 @@ bool PotentiallyImplementedAsEigenConvolution( // // To be sufficient, certain layout constraints need to be satisfied as well. const Shape& input_shape = convolution.operand(0)->shape(); - const Shape& kernel_shape = convolution.operand(0)->shape(); + const Shape& kernel_shape = convolution.operand(1)->shape(); + const Shape& output_shape = convolution.shape(); + + auto is_aligned = [&](const Shape& shape) { + return GetMinimumAlignmentForArray(shape, target_machine_features) >= + TargetMachineFeatures::kEigenExpectedTensorAlignment; + }; + + if (!is_aligned(input_shape) || !is_aligned(kernel_shape) || + !is_aligned(output_shape)) { + return false; + } + if (ShapeUtil::HasZeroElements(input_shape) || ShapeUtil::HasZeroElements(kernel_shape)) { return false; } + // Make sure input and kernel has the same data type. + CHECK( + ShapeUtil::SameElementTypeIgnoringFpPrecision(input_shape, kernel_shape)); // TODO(b/65408531): Explore using Eigen dot for complex64 type. - if (ShapeUtil::ElementIsComplex(input_shape) || - ShapeUtil::ElementIsComplex(kernel_shape)) { + if (ShapeUtil::ElementIsComplex(input_shape)) { return false; } if (window_util::HasWindowReversal(convolution.window())) { @@ -69,7 +100,6 @@ bool PotentiallyImplementedAsEigenConvolution( } } - const Shape& output_shape = convolution.shape(); return dnums.input_batch_dimension() == 0 && dnums.input_feature_dimension() == input_shape.dimensions_size() - 1 && dnums.output_batch_dimension() == 0 && diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h index 34b2003916933f5ec0a15d9e219063c0a912fa40..68fbc7caaa9bfec0ecd7cc7f473c8ca8afce19db 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h @@ -17,13 +17,20 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { namespace cpu { bool PotentiallyImplementedAsEigenConvolution( - const HloInstruction& convolution); + const HloInstruction& convolution, + const TargetMachineFeatures& target_machine_features); + +// Computes the minimum alignment guaranteed for a tensor of shape `shape` on +// the target machine. +int64 GetMinimumAlignmentForArray( + const Shape& shape, const TargetMachineFeatures& target_machine_features); // Dynamic loop bounds are specified as an array of dimension index // [start, limit) pairs of ir values (one for each partitioned outer dimension). diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..530ebce854fedf4e4db12139d5b56087b1176a6c --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" + +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace { + +TEST(IrEmitterTest, ConvWithZeroSizedKernelNotImplementedAsEigen) { + const char* const hlo_string = R"( +HloModule ModuleWithConv + +ENTRY Conv { + input = f32[32,50,28,28]{3,2,1,0} parameter(0) + kernel = f32[0,32,5,5]{3,2,1,0} parameter(1) + ROOT convolution = f32[64,50,24,24]{3,2,1,0} convolution(input, kernel), + window={size=5x5}, + dim_labels=b01f_01io->b01f +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + HloComputation* entry_computation = module->entry_computation(); + + HloInstruction* conv_instr = entry_computation->root_instruction(); + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution( + *conv_instr, target_machine_features)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 4dffaee87f6b33933b58c8c58478eec918569197..59223fddac2f5f7e2e85de4d37e4b6c5760ae697 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -83,7 +83,7 @@ IrEmitter::IrEmitter( llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - llvm::TargetMachine* target_machine, + const TargetMachineFeatures* target_machine_features, ExternalConstantPool* external_constant_pool) : assignment_(assignment), module_(llvm_module), @@ -93,10 +93,8 @@ IrEmitter::IrEmitter( computation_to_profile_idx_(std::move(computation_to_profile_idx)), alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), - parallel_cpu_backend_( - options::CpuParallelBackendRequested(hlo_module_config_)), is_top_level_computation_(false), - target_machine_features_(target_machine), + target_machine_features_(*target_machine_features), external_constant_pool_(external_constant_pool) { ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() @@ -162,41 +160,59 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } -Status IrEmitter::HandleConstant(HloInstruction* constant) { - VLOG(2) << "HandleConstant: " << constant->ToString(); - const Literal& literal = constant->literal(); - llvm::GlobalVariable* global_for_const; +llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { + llvm::Constant* result; // We avoid creating large constants in the LLVM IR since LLVM is not // efficient for large constant arrays. We still emit "small enough" constant // arrays into the Ir, in the off chance the LLVM optimizer can do something // interesting with it. + // + // TODO(b/29904935): Remove the large constant pool. const int kMaxInternalConstantSizeInBytes = 128; if (external_constant_pool_ && ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) { string global_name = tensorflow::strings::StrCat( "constant_global_", external_global_constant_counter_++); - global_for_const = new llvm::GlobalVariable( + llvm::GlobalVariable* result_global = new llvm::GlobalVariable( /*Module=*/*module_, /*Type=*/IrShapeType(literal.shape()), /*isConstant=*/true, /*Linkage=*/llvm::GlobalValue::ExternalLinkage, /*Initializer=*/nullptr, /*Name=*/AsStringRef(global_name)); - global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape())); + result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); external_constant_pool_->Insert(global_name, literal, MinimumAlignmentForShape(literal.shape())); + result = result_global; } else { llvm::Constant* initializer = llvm_ir::ConvertLiteralToIrConstant(literal, module_); - global_for_const = new llvm::GlobalVariable( + llvm::GlobalVariable* result_global = new llvm::GlobalVariable( /*Module=*/*module_, /*Type=*/initializer->getType(), /*isConstant=*/true, /*Linkage=*/llvm::GlobalValue::PrivateLinkage, /*Initializer=*/initializer, /*Name=*/""); - global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape())); + result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); + result = llvm::ConstantExpr::getBitCast( + result_global, IrShapeType(literal.shape())->getPointerTo()); + } + return result; +} + +Status IrEmitter::HandleConstant(HloInstruction* constant) { + VLOG(2) << "HandleConstant: " << constant->ToString(); + const Literal& literal = constant->literal(); + llvm::Constant* global_for_const; + + auto it = emitted_literals_.find(&literal); + if (it != emitted_literals_.end()) { + global_for_const = it->second; + } else { + global_for_const = EmitGlobalForLiteral(literal); + emitted_literals_[&literal] = global_for_const; } emitted_value_[constant] = global_for_const; VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const); @@ -216,32 +232,6 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { } } -// Calculate the alignment of a buffer with a particular size. -int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) { - // GLibc returns a pointer with alignment 8 on 32-bit platforms and 16 on - // 64-bit platforms. TCMalloc returns a pointer with alignment 8 for - // allocations smaller than kMallocAlignmentThreshold bytes and at least - // alignment 16 for allocations greater than or equal to - // kMallocAlignmentThreshold bytes. N.B. We could improve on this lower bound - // by explicitly allocating the memory with posix_memalign. This is - // complicated by our desire to allow parameter buffers created by clients to - // be consumed directly by the JIT. - if (buffer_size == 0) { - // No need to align empty buffers. - return 1; - } - - const int64 kMallocAlignmentThreshold = 512; - - int pointer_size = module_->getDataLayout().getPointerSize(); - int buffer_alignment = buffer_size >= kMallocAlignmentThreshold - ? 2 * pointer_size - : pointer_size; - DCHECK_GT(buffer_alignment, 0); - - return buffer_alignment; -} - // Calculate the alignment of a buffer allocated for a given primitive type. int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) { int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); @@ -266,7 +256,7 @@ int IrEmitter::MinimumAlignmentForShape(const Shape& shape) { DCHECK_GE(buffer_size, 0); DCHECK_LE(buffer_size, SIZE_MAX); - return MinimumAlignmentForBufferSize(buffer_size); + return target_machine_features_.minimum_alignment_for_allocation(buffer_size); } void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, @@ -279,7 +269,8 @@ void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, int64 buffer_size) { - int alignment = MinimumAlignmentForBufferSize(buffer_size); + int alignment = + target_machine_features_.minimum_alignment_for_allocation(buffer_size); if (alignment > 1) { llvm_ir::SetAlignmentMetadataForLoad(load, alignment); } @@ -438,12 +429,14 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, if (kind == XfeedKind::kInfeed) { // Copy to the program buffer address from the acquired buffer. - ir_builder_.CreateMemCpy(program_buffer_address, acquired_pointer, - length_32, 1); + ir_builder_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, + acquired_pointer, + /*SrcAlign=*/1, length_32); } else { // Outfeed -- copy from the in-program address to the acquired buffer. - ir_builder_.CreateMemCpy(acquired_pointer, program_buffer_address, - length_32, 1); + ir_builder_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, + program_buffer_address, + /*SrcAlign=*/1, length_32); } ir_builder_.CreateCall(release_func, @@ -517,7 +510,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { HloComputation* function = reduce_window->to_apply(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*reduce_window, /*operands=*/{operand}, - /*supported_types=*/{F32, BF16})); + /*supported_types=*/{F32, BF16, S32})); // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(window)) { @@ -814,13 +807,6 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { "Dot with multiple contracting dimensions not implemented."); } - if (dnums.lhs_contracting_dimensions(0) != - std::min(lhs->shape().dimensions_size() - 1, 1) || - dnums.rhs_contracting_dimensions(0) != 0) { - return Unimplemented( - "Dot with non-standard contracting dimensions not implemented."); - } - llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); @@ -837,8 +823,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Dot operation is complicated so we delegate to a helper class. return DotOpEmitter::EmitDotOperation( - *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, - lhs_array, rhs_array, /*addend_array=*/nullptr, + *dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, target_machine_features_); } @@ -854,7 +839,10 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); - if (PotentiallyImplementedAsEigenConvolution(*convolution)) { + // TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support + // different data layouts. + if (PotentiallyImplementedAsEigenConvolution(*convolution, + target_machine_features_)) { const Shape& lhs_shape = lhs->shape(); const Shape& rhs_shape = rhs->shape(); const Shape& convolution_shape = convolution->shape(); @@ -942,16 +930,26 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); - bool multi_threaded_eigen = + bool multi_threaded = hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + bool use_mkl_dnn = + hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn(); + + // TODO(b/78639006) Singlethread MKL conv2d is not implemented due to the + // potential race condition by setting the omp_num_threads. const char* fn_name = primitive_type == F16 - ? (multi_threaded_eigen + ? (multi_threaded ? runtime::kEigenConvF16SymbolName : runtime::kEigenSingleThreadedConvF16SymbolName) - : (multi_threaded_eigen - ? runtime::kEigenConvF32SymbolName + : (multi_threaded + ? (use_mkl_dnn ? runtime::kMKLConvF32SymbolName + : runtime::kEigenConvF32SymbolName) : runtime::kEigenSingleThreadedConvF32SymbolName); + if (!multi_threaded && use_mkl_dnn) { + LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded " + "conv2d function."; + } llvm::Function* conv_func = llvm::cast( module_->getOrInsertFunction(fn_name, conv_type)); conv_func->setCallingConv(llvm::CallingConv::C); @@ -1010,12 +1008,14 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // We will accumulate the products into this sum to calculate // the output entry at the given index. PrimitiveType lhs_element_type = lhs->shape().element_type(); + llvm::Type* lhs_llvm_type = + llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_); llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_), - "convolution_sum_address", &ir_builder_, + lhs_llvm_type, "convolution_sum_address", &ir_builder_, MinimumAlignmentForPrimitiveType(lhs_element_type)); - ir_builder_.CreateStore( - llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), sum_address); + llvm::Value* constant_zero = + llvm::Constant::getNullValue(lhs_llvm_type); + ir_builder_.CreateStore(constant_zero, sum_address); llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &ir_builder_); std::vector kernel_spatial(num_spatial_dims); @@ -1169,7 +1169,13 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); - const char* fn_name = runtime::kEigenFftSymbolName; + + bool multi_threaded_eigen = + hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + const char* fn_name = multi_threaded_eigen + ? runtime::kEigenFftSymbolName + : runtime::kEigenSingleThreadedFftSymbolName; + llvm::Function* fft_func = llvm::cast( module_->getOrInsertFunction(fn_name, fft_type)); fft_func->setCallingConv(llvm::CallingConv::C); @@ -1191,16 +1197,45 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { } Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { - if (hlo_module_config_.replica_count() == 1) { - // When there is a single replica, a cross replica sum is the identity - // function, and the buffer assignment expects a copy (we could eliminate - // these at the HLO level as an optimization). - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); + if (hlo_module_config_.replica_count() != 1) { + // TODO(b/33011107): Support nontrivial cross replica sum on CPU. + return Unimplemented( + "CrossReplicaSum with >1 replica is not implemented on CPU."); + } + + // When there is a single replica, a cross replica sum is the identity + // function, and the buffer assignment expects a copy. + // + // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely + // in algebraic-simplifier, but currently on some platforms + // HloModuleConfig::num_replicas changes between when the module is compiled + // and when it's run. + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); + + // CRS with one operand and one replica is simply the identity function. + if (crs->operand_count() == 1) { return EmitMemcpy(*crs->operand(0), *crs); } - // TODO(b/33011107): Support cross replica sum on CPU. - return Unimplemented("CrossReplicaSum is not implemented on CPU."); + // CRS with multiple operands and one replica produces a (one-deep) tuple. + std::vector operand_ptrs; + for (int64 i = 0; i < crs->operand_count(); ++i) { + llvm::Value* in_ptr = GetEmittedValueFor(crs->operand(i)); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice, + assignment_.GetUniqueSlice(crs, {i})); + + const Shape& operand_shape = crs->operand(i)->shape(); + CHECK(ShapeUtil::IsArray(operand_shape)) + << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); + operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape)); + + // TODO(b/63762267): Be more aggressive about specifying alignment. + ir_builder_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, + /*SrcAlign=*/1, + ShapeUtil::ByteSizeOf(operand_shape)); + } + llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &ir_builder_, module_); + return Status::OK(); } // Fills up the free variables in 'index_with_free_var' with values from @@ -2061,44 +2096,7 @@ static const HloInstruction* StripTranspose(const HloInstruction& hlo) { Status IrEmitter::HandleFusion(HloInstruction* fusion) { auto* root = fusion->fused_expression_root(); - if (fusion->fusion_kind() == HloInstruction::FusionKind::kTransposeDot) { - DCHECK(root->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*root->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*root->operand(1)); - DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && - rhs_parameter->opcode() == HloOpcode::kParameter); - const HloInstruction* lhs = - fusion->operand(lhs_parameter->parameter_number()); - const HloInstruction* rhs = - fusion->operand(rhs_parameter->parameter_number()); - - TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( - /*instruction=*/*root, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F32})); - - llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); - llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); - - Shape target_shape = fusion->shape(); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); - llvm_ir::IrArray target_array = GetIrArrayFor(fusion); - VLOG(2) << "HandleFusion kTransposeDot: "; - VLOG(2) << " lhs operand: " - << llvm_ir::DumpToString(*lhs_array.GetBasePointer()); - VLOG(2) << " rhs operand: " - << llvm_ir::DumpToString(*rhs_array.GetBasePointer()); - VLOG(2) << " target: " - << llvm_ir::DumpToString(*target_array.GetBasePointer()); - - // Dot operation is complicated so we delegate to a helper class. - TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *root, root->operand(0)->IsRank2Transpose(), - root->operand(1)->IsRank2Transpose(), target_array, lhs_array, - rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(), - &ir_builder_, hlo_module_config_, target_machine_features_)); - return Status::OK(); - } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, - assignment_)) { + if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace"; CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); @@ -2141,9 +2139,9 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { GetIrArrayFor(fusion->operand(addend_param_number))); TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, - lhs_array, rhs_array, &addend_array, GetExecutableRunOptionsArgument(), - &ir_builder_, hlo_module_config_, target_machine_features_)); + *dot, target_array, lhs_array, rhs_array, &addend_array, + GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, + target_machine_features_)); return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); @@ -2161,8 +2159,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call)); - if (!computation->root_instruction()->outer_dimension_partitions().empty() && - !parallel_cpu_backend_) { + if (!computation->root_instruction()->outer_dimension_partitions().empty()) { // ParallelTaskAssignment assigned partitions, emit call to // ParallelForkJoin. std::vector call_args = GetArrayFunctionCallArguments( @@ -2441,7 +2438,8 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); } else { auto* memcpy_instruction = ir_builder_.CreateMemCpy( - target, source, element_count * primitive_type_size, element_alignment); + target, /*DstAlign=*/element_alignment, source, + /*SrcAlign=*/element_alignment, element_count * primitive_type_size); // The memcpy does the load and the store internally. The aliasing related // metadata has to reflect that. @@ -2538,8 +2536,12 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { // nothing to do since the result was already written directly into the output // buffer. VLOG(2) << "FinishVisit root: " << root->ToString(); - llvm::Value* root_value = GetEmittedValueFor(root); - VLOG(2) << " value: " << llvm_ir::DumpToString(*root_value); + if (root->opcode() == HloOpcode::kOutfeed) { + VLOG(2) << " outfeed with value: " + << llvm_ir::DumpToString(*GetEmittedValueFor(root->operand(0))); + } else { + VLOG(2) << " value: " << llvm_ir::DumpToString(*GetEmittedValueFor(root)); + } auto record_complete_computation = [&](llvm::Value* prof_counter) { if (prof_counter) { @@ -2547,22 +2549,6 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { } }; - // For the parallel cpu backend, we record the total for each embedded - // computation callee with its caller kCall HLO. - if (parallel_cpu_backend_ && is_top_level_computation_) { - auto* computation = root->parent(); - auto* entry_computation = computation->parent()->entry_computation(); - if (computation != entry_computation) { - for (HloInstruction* instruction : entry_computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCall && - instruction->to_apply()->root_instruction() == root) { - record_complete_computation(GetProfileCounterFor(*instruction)); - return Status::OK(); - } - } - } - } - // For the entry computation this increment is cumulative of embedded // computations since it includes cycles spent in computations invoked by // While, Call etc. @@ -2905,7 +2891,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source, llvm::Value* destination_value = GetEmittedValueFor(&destination); int64 source_size = ByteSizeOf(source.shape()); // TODO(b/63762267): Be more aggressive about specifying alignment. - ir_builder_.CreateMemCpy(destination_value, source_value, source_size, 1); + ir_builder_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value, + /*SrcAlign=*/1, source_size); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 509440251497cd7337284c39dae05c5f6c28e7c2..32c536e18fee86cc60067ba3b25ab1eb0e4233df 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -76,7 +76,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - llvm::TargetMachine* target_machine, + const TargetMachineFeatures* target_machine, ExternalConstantPool* external_constant_pool); ~IrEmitter() override; @@ -514,9 +514,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Calculate the alignment of a buffer allocated for a given primitive type. int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type); - // Calculate the alignment of a buffer with a particular size. - int MinimumAlignmentForBufferSize(int64 buffer_size); - // Returns the number of bytes within the shape. int64 ByteSizeOf(const Shape& shape) const; @@ -530,17 +527,32 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value* program_buffer_address); - const HloModuleConfig& hlo_module_config_; + // Returns a ConstExpr bitcast. + llvm::Constant* EmitGlobalForLiteral(const Literal& literal); - const bool parallel_cpu_backend_; + const HloModuleConfig& hlo_module_config_; bool is_top_level_computation_; - TargetMachineFeatures target_machine_features_; + const TargetMachineFeatures& target_machine_features_; int64 external_global_constant_counter_ = 0; ExternalConstantPool* external_constant_pool_; + struct LiteralPtrHashFunctor { + size_t operator()(const Literal* literal) const { return literal->Hash(); } + }; + + struct LiteralPtrEqualityFunctor { + bool operator()(const Literal* lhs, const Literal* rhs) const { + return *lhs == *rhs; + } + }; + + tensorflow::gtl::FlatMap + emitted_literals_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index 557aa4a6bfc2ef70cafca4b226f8d8f15ea01e2b..2e55181eed867aca762f2b9b8310624ea12c7487 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -33,8 +33,8 @@ namespace cpu { // emitters for function and function argument access. // The llvm::Function is created with the standard function signature // used in the XLA CPU backend (see ir_function.cc for argument details). -// In addtion IrFunction saves the callers IR insert point during contruction, -// and restores it after desctruction. +// In addition IrFunction saves the callers IR insert point during construction, +// and restores it after destruction. // // Example usage: // diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc deleted file mode 100644 index 07a9f0efcb64db4b2ff0c6518d4b48eee9a505e0..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ /dev/null @@ -1,531 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" -#include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mem.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/types.h" - -namespace se = ::perftools::gputools; - -namespace xla { -namespace cpu { - -ParallelCpuExecutable::ParallelCpuExecutable( - std::unique_ptr jit, - std::unique_ptr assignment, - std::unique_ptr hlo_module, - std::unique_ptr> function_names, - std::unordered_map> - aligned_constants, - std::unique_ptr hlo_profile_printer_data, - std::unique_ptr hlo_profile_index_map) - : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), - std::move(hlo_profile_index_map)), - jit_(std::move(jit)), - assignment_(std::move(assignment)), - function_names_(std::move(function_names)), - aligned_constants_(std::move(aligned_constants)) {} - -// Type of the computation function we expect in the JIT. -using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, - int64*, int64*); - -// Given a pointer to an output buffer (following the CPU JIT calling -// conventions), mark addresses that are "live". The initial pointer itself is -// trivially live. If the shape of the buffer is a tuple, this analysis looks -// into the tuple's elements and marks them live as well (since tuples keep -// pointers to buffers) and also works recursively. -// address is an in-memory buffer address that contains some runtime XLA object. -// shape is its shape. marked_addresses is the set of live addresses to -// populate. -static void MarkLiveAddressesInOutput( - const void* address, const Shape& shape, - std::unordered_set* marked_addresses) { - marked_addresses->insert(address); - const uintptr_t* address_buffer = static_cast(address); - if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const uintptr_t* element_address = address_buffer + i; - const void* element = reinterpret_cast(*element_address); - MarkLiveAddressesInOutput( - element, ShapeUtil::GetTupleElementShape(shape, i), marked_addresses); - } - } -} - -namespace { - -// Executor manages the concurrent execution of 'functions' for instructions -// in 'pending' on 'thread_pool' (storing resulting data in 'results'). -class Executor { - public: - Executor(const HloInstructionMap& functions, - const ServiceExecutableRunOptions* run_options, - std::list* pending, - HloInstructionMap* results, void** temps_array, - int64* profile_counters_array, const BufferAssignment* assignment) - : functions_(functions), - run_options_(run_options), - pending_(pending), - results_(results), - temps_array_(temps_array), - profile_counters_array_(profile_counters_array), - thread_pool_(CHECK_NOTNULL(run_options_->xla_intra_op_thread_pool())), - assignment_(assignment) {} - - // Executes pending list of instructions on thread pool. - // Returns OK status on success, error status otherwise. - Status Run(); - - private: - // Schedules a parallel invocation of compute function for 'instruction' on - // 'thread_pool_', storing result in 'result_buffer'. - // If 'partition_buffers' is non-null, parallel task will be invoked on - // per-dimension partition [start, limit) values stored in - // 'partition_buffers'. - void Schedule(HloInstruction* instruction, int64* partition_buffers, - void* result_buffer); - - // Returns true if 'instruction' has been assigned parallel tasks (returns - // false otherwise). - bool HasParallelTasks(HloInstruction* instruction); - - // Returns in 'partition_buffers' the partition [size, limit) for each - // dimension. - int64* GetPartitionBuffers( - const std::vector>& partition); - - // Returns array of result buffers for all operands in 'instruction'. - const void** GetOperandBuffers(HloInstruction* instruction); - - // Arguments passed into Executor. - const HloInstructionMap& functions_; - const ServiceExecutableRunOptions* run_options_; - std::list* pending_; - HloInstructionMap* results_; - void** temps_array_; - int64* profile_counters_array_; - tensorflow::thread::ThreadPool* thread_pool_; - const BufferAssignment* assignment_; - - // Members used to manage instruction execution. - tensorflow::mutex completion_queue_lock_; - tensorflow::condition_variable completion_queue_cv_; - std::deque completion_queue_; - int64 instructions_in_flight_ = 0; - std::unordered_map tasks_in_flight_; -}; - -Status Executor::Run() { - while (!pending_->empty() || instructions_in_flight_ > 0) { - auto pending_it = pending_->begin(); - while (pending_it != pending_->end()) { - HloInstruction* instruction = *pending_it; - // Skip pending instructions whose operands aren't ready. - if (std::any_of(instruction->operands().begin(), - instruction->operands().end(), - [&](HloInstruction* operand) { - return !ContainsKey(*results_, operand); - })) { - ++pending_it; - continue; - } - - // Get 'result_buffer' reference to result buffer for 'instruction'. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelSlice(instruction)); - void* result_buffer = - static_cast(temps_array_[result_slice.index()]) + - result_slice.offset(); - - if (HasParallelTasks(instruction)) { - // 'instruction' has been assigned parallel task partitions. - CHECK_EQ(HloOpcode::kCall, instruction->opcode()); - HloInstruction* root = instruction->to_apply()->root_instruction(); - - // Create ShapePartitionIterator to iterate through all outer dimension - // partitions of 'instruction'. - ShapePartitionIterator partition_iterator( - root->shape(), root->outer_dimension_partitions()); - - const int64 partition_count = - partition_iterator.GetTotalPartitionCount(); - - // Record total parallel task count for 'instruction' before dispatch. - { - tensorflow::mutex_lock l(completion_queue_lock_); - tasks_in_flight_.insert(std::make_pair(instruction, partition_count)); - VLOG(2) << "Schedule PARALLEL" - << " instruction: " << instruction->name() - << " instruction.callee: " - << instruction->to_apply()->root_instruction()->name() - << " partition_count: " << partition_count; - } - - for (int64 i = 0; i < partition_count; ++i) { - // Get partition [start, limit) for each dimension. - auto partition_buffers = - GetPartitionBuffers(partition_iterator.GetPartition(i)); - Schedule(instruction, partition_buffers, result_buffer); - } - - } else { - // Set tasks in-flight to '1' for sequential instruction execution. - { - tensorflow::mutex_lock l(completion_queue_lock_); - tasks_in_flight_.insert(std::make_pair(instruction, 1)); - VLOG(2) << "Schedule SEQUENTIAL" - << " instruction: " << instruction->name() - << " instruction.callee: " - << instruction->to_apply()->root_instruction()->name(); - } - Schedule(instruction, nullptr, result_buffer); - } - - ++instructions_in_flight_; - pending_it = pending_->erase(pending_it); - } - // Wait for a completed HLO instruction to be present in the queue. We will - // pop it out of the queue and make the result available to its users. - HloInstruction* instruction; - do { - tensorflow::mutex_lock l(completion_queue_lock_); - if (completion_queue_.empty()) { - completion_queue_cv_.wait(l); - } - if (!completion_queue_.empty()) { - instruction = completion_queue_.front(); - completion_queue_.pop_front(); - break; - } - } while (true); - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelSlice(instruction)); - void* result_buffer = - static_cast(temps_array_[result_slice.index()]) + - result_slice.offset(); - InsertOrDie(results_, instruction, result_buffer); - --instructions_in_flight_; - } - return Status::OK(); -} - -void Executor::Schedule(HloInstruction* instruction, int64* partition_buffers, - void* result_buffer) { - // The thread pool entry takes ownership of |operand_buffers|. - auto operand_buffers = GetOperandBuffers(instruction); - - auto function = FindOrDie(functions_, instruction); - const auto* exec_run_options = &run_options_->run_options(); - thread_pool_->Schedule([this, instruction, result_buffer, operand_buffers, - partition_buffers, exec_run_options, function]() { - function(result_buffer, exec_run_options, operand_buffers, temps_array_, - partition_buffers, profile_counters_array_); - - delete[] operand_buffers; - delete[] partition_buffers; - // Push the completed HLO instruction on the queue, the main - // thread will pop it off and potentially launch more work which - // uses the result. - // TODO(b/27458679) Consider alternative task scheduling and synchronization - // schemes. For example, we could avoid the overhead associate with the - // condvar here if the thread just dequed the next instruction to execute - // on completion. - { - tensorflow::mutex_lock l(completion_queue_lock_); - // Decrement in-flight task count for this completion. - if (--FindOrDie(tasks_in_flight_, instruction) == 0) { - completion_queue_.push_back(instruction); - completion_queue_cv_.notify_all(); - tasks_in_flight_.erase(instruction); - } - } - }); -} - -int64* Executor::GetPartitionBuffers( - const std::vector>& partition) { - // Return in 'partition_buffers' partition [size, limit) for each dimension. - auto partition_buffers = new int64[partition.size() * 2]; - for (int i = 0; i < partition.size(); ++i) { - partition_buffers[2 * i + 0] = partition[i].first; - partition_buffers[2 * i + 1] = partition[i].first + partition[i].second; - } - return partition_buffers; -} - -bool Executor::HasParallelTasks(HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kCall && - !instruction->to_apply() - ->root_instruction() - ->outer_dimension_partitions() - .empty(); -} - -const void** Executor::GetOperandBuffers(HloInstruction* instruction) { - // We cannot use a move-only RAII type like std::unique_ptr because the - // list of operands is allocated on the main thread and transferred to the - // worker via the lambda passed to enqueue_function. In order for the - // lambda to take ownership, we would need to use generalized lambda - // capture which is a feature new to C++14. - // TODO(b/27458679) Avoid dynamic allocations in Executor. - auto operand_buffers = new const void*[instruction->operand_count()]; - std::transform(instruction->operands().begin(), instruction->operands().end(), - operand_buffers, [this](HloInstruction* operand) { - return FindOrDie(*results_, operand); - }); - return operand_buffers; -} - -} // namespace - -Status ParallelCpuExecutable::AllocateBuffers( - DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers) { - CHECK_EQ(buffers->size(), assignment_->Allocations().size()); - VLOG(3) << "Allocating " << assignment_->Allocations().size() - << " allocations for module " << module().name(); - for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); - ++i) { - auto& allocation = assignment_->GetAllocation(i); - - VLOG(3) << allocation.ToString(); - - if (allocation.is_entry_computation_parameter()) { - VLOG(3) << "allocation #" << i << " is a parameter"; - continue; - } - - if (allocation.is_thread_local()) { - VLOG(3) << "buffer #" << i << " is thread-local"; - continue; - } - - int64 buffer_size = allocation.size(); - if (!(*buffers)[i].is_null()) { - VLOG(3) << "buffer #" << i - << " is in the preallocated result ShapedBuffer"; - } else { - TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate( - device_ordinal, buffer_size)); - - VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes [" - << (*buffers)[i].opaque() << "]"; - } - - // Since the output buffer and all the temporary buffers were written into - // by the JITed code, msan has no way of knowing their memory was - // initialized. Mark them initialized so that msan doesn't flag loads from - // these buffers. - TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size); - } - - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - VLOG(3) << "result index: " << result_slice.index(); - - return Status::OK(); -} - -Status ParallelCpuExecutable::ExecuteComputeFunctions( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice buffers, - HloExecutionProfile* hlo_execution_profile) { - // Allocate profiling counters for each hlo instruction that we would like to - // profile. - std::vector* profile_counters = nullptr; - if (hlo_execution_profile) { - profile_counters = hlo_execution_profile->mutable_profile_counters(); - } - - std::vector buffer_pointers; - buffer_pointers.reserve(buffers.size()); - for (auto device_allocation : buffers) { - buffer_pointers.push_back(device_allocation.opaque()); - } - - // Resolve functions for all the HLO instructions ahead of time. - HloInstructionMap functions; - for (auto& entry : *function_names_) { - tensorflow::mutex_lock lock(jit_mutex_); - HloInstruction* instruction = entry.first; - llvm::JITSymbol sym = jit_->FindCompiledSymbol(entry.second); - TF_RET_CHECK(sym); - InsertOrDie( - &functions, instruction, - reinterpret_cast(cantFail(sym.getAddress()))); - } - - // Map containing pointers to result buffers for each instruction. - HloInstructionMap results; - - uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - - std::list pending; - - // Call the function for each HLO instruction in topological order. - const HloComputation& entry_computation = *module().entry_computation(); - for (auto* instruction : entry_computation.MakeInstructionPostOrder()) { - // Parameters and constants have no functions associated with them. Instead - // just copy the existing buffer into the map containing instruction - // results.. - if (instruction->opcode() == HloOpcode::kParameter) { - InsertOrDie( - &results, instruction, - arguments[instruction->parameter_number()]->root_buffer().opaque()); - } else if (instruction->opcode() == HloOpcode::kConstant) { - unsigned char* aligned_data = - FindOrDie(aligned_constants_, instruction).get(); - InsertOrDie(&results, instruction, aligned_data); - } else { - TF_RET_CHECK(instruction->opcode() == HloOpcode::kCall); - pending.push_back(instruction); - } - } - - // TODO(b/27458679) Manage scheduling based on in-flight concurrency limits. - // For example, if we expect a library conv/matmul call to run at max - // concurrency, we should not dispatch runnable instructions until the - // library call is finished (to avoid expensive cache invalidation). - Executor executor( - functions, run_options, &pending, &results, buffer_pointers.data(), - profile_counters ? profile_counters->data() : nullptr, assignment_.get()); - - TF_RETURN_IF_ERROR(executor.Run()); - - uint64 end_micros = tensorflow::Env::Default()->NowMicros(); - - { - tensorflow::mutex_lock lock(mutex_); - double nanoseconds = (end_micros - start_micros) * 1000.0; - execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); - } - - return Status::OK(); -} - -StatusOr> ParallelCpuExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - if (GetRootPointsToSet().IsAmbiguous()) { - return Unimplemented("Points-to set of root instruction is ambiguous"); - } - - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - - auto result_buffer = MakeUnique( - /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(), - stream->parent()->platform(), stream->parent()->device_ordinal()); - - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); - - TF_RETURN_IF_ERROR(ExecuteComputeFunctions(run_options, arguments, buffers, - hlo_execution_profile)); - - // Copy DeviceMemoryBase values which into the respective location in - // ShapedBuffer which is returned to the caller. - std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_RETURN_IF_ERROR(result_buffer->buffers().ForEachMutableElementWithStatus( - [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { - const auto& sources = this->GetRootPointsToSet().element(index); - - // The points to set is unambiguous so the set should be a singleton. - CHECK_EQ(1, sources.size()); - const LogicalBuffer* buffer_source = sources[0]; - HloInstruction* src = buffer_source->instruction(); - - // The source for this result buffer can be a nested buffer such as a - // tuple element. The source instruction should have a non-parameter - // buffer assigned. - TF_ASSIGN_OR_RETURN( - const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice(src, buffer_source->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - - const BufferAllocation::Index buffer_index = slice.index(); - const se::DeviceMemoryBase& buffer = buffers[buffer_index]; - CHECK(!buffer.is_null() || buffer.size() == 0); - *device_memory = buffer; - buffers_in_result[buffer_index] = true; - return Status::OK(); - })); - - // Free all buffers not in the result. - for (size_t i = 0; i < buffers.size(); ++i) { - se::DeviceMemoryBase alloc = buffers[i]; - if (!buffers_in_result[i] && !alloc.is_null()) { - VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR(memory_allocator->Deallocate( - stream->parent()->device_ordinal(), &alloc)); - } - } - - return std::move(result_buffer); -} - -StatusOr> -ParallelCpuExecutable::ExecuteAsyncOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { - // TODO(b/30671675): Implement asynchronous execution mode. - return Unimplemented( - "Asynchronous execution on stream is not yet supported on CPU."); -} - -const PointsToSet& ParallelCpuExecutable::GetRootPointsToSet() const { - return assignment_->points_to_analysis().GetPointsToSet( - module().entry_computation()->root_instruction()); -} - -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h deleted file mode 100644 index c393e9b8ea39bfb4c605ebba8e2cd29726bc4af9..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ /dev/null @@ -1,144 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_CPU_EXECUTABLE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_CPU_EXECUTABLE_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" -#include "tensorflow/compiler/xla/service/executable.h" -#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" -#include "tensorflow/core/platform/thread_annotations.h" - -namespace xla { -namespace cpu { - -// CPU-targeting parallel implementation of the XLA Executable interface. -// -// Wraps a JIT-ed object that can be executed "on device". We JIT for the host -// architecture, so JIT-ed code and host code share the same ABI. -class ParallelCpuExecutable : public Executable { - public: - ParallelCpuExecutable( - std::unique_ptr jit, - std::unique_ptr assignment, - std::unique_ptr hlo_module, - std::unique_ptr> function_names, - std::unordered_map> - aligned_constants, - std::unique_ptr hlo_profile_printer_data, - std::unique_ptr hlo_profile_index_map); - ~ParallelCpuExecutable() override {} - - StatusOr> ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) override; - - StatusOr> ExecuteAsyncOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) override; - - // This should be called after set_ir_module_string. - const string& ir_module_string() const { return ir_module_string_; } - - void set_ir_module_string(const string& ir_module_string) { - ir_module_string_ = ir_module_string; - } - - static int64 ShapeSizeBytes(const Shape& shape) { - // On the cpu, opaques are pointers. - if (ShapeUtil::IsOpaque(shape)) { - return sizeof(void*); - } - return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); - } - - const Status EqualOrFail(const Executable& executable) { - // TODO(b/62952745) Implement equality test on CPU parallel executable. - return Unimplemented( - "Equality test on CPU parallel executable is not implemented."); - } - - private: - // Allocate buffers required for execution and assign them to the elements of - // "buffers". "buffers" should be sized to the number of buffers in buffer - // assignment. Each vector element corresponds to a particular Index. If - // a vector element already contains a non-null DeviceMemoryBase, then no - // buffer is assigned for this element. - Status AllocateBuffers( - DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers); - - // Calls the generated functions in 'function_names_', performing the - // computation with the given arguments using the supplied buffers. - Status ExecuteComputeFunctions( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice - buffers, - HloExecutionProfile* hlo_execution_profile); - - // Returns the points-to set of the root instruction of the entry - // computation. Uses points-to analysis from buffer assignment. - const PointsToSet& GetRootPointsToSet() const; - - // The JIT containing compiled modules. - tensorflow::mutex jit_mutex_; - const std::unique_ptr jit_ GUARDED_BY(jit_mutex_); - - // Buffer assignment for the buffers we need to allocate. - const std::unique_ptr assignment_; - - // The LLVM IR, in string format, of the unoptimized module generated for this - // ParallelCpuExecutable. We save a string instead of an llvm::Module* because - // leaving llvm::Module* in a singleton can cause the heap checker to emit - // false positives. - string ir_module_string_; - - // Map containing the JITted function names for each HLO instruction. - const std::unique_ptr> function_names_; - - // Map from HLO Constant instructions to a pointer to their literal data. - // The data stored in the protocol buffer might be insufficiently aligned, - // we create a sufficiently aligned copy and store it in this map. - const std::unordered_map> - aligned_constants_; - - TF_DISALLOW_COPY_AND_ASSIGN(ParallelCpuExecutable); -}; - -} // namespace cpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_CPU_EXECUTABLE_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index 1e439cde11cf74272101b80c867a308e51ab26a6..54af40506dab48b3c2a3a44eb0b5f5fb213a32ec 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -29,7 +29,8 @@ ParallelLoopEmitter::ParallelLoopEmitter( : LoopEmitter(target_element_generator, target_array, ir_builder), dynamic_loop_bounds_(dynamic_loop_bounds) {} -llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( +std::vector +ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( tensorflow::StringPiece loop_name) { CHECK(!ShapeUtil::IsTuple(shape_)); CHECK(!ShapeUtil::IsScalar(shape_)); @@ -69,7 +70,7 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); CHECK(exit_bb_ != nullptr); - return array_index; + return {array_index}; } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h index ce92e36a944de33b991d97460f0b2e859ad56081..755715634aa70a822b21d25dcae20a8fe053477a 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -60,7 +60,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; ~ParallelLoopEmitter() override = default; - llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock( + std::vector EmitIndexAndSetExitBasicBlock( tensorflow::StringPiece loop_name) override; private: diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index deb21bf4ef5895cfdbec5c2449b6ce7b306a7008..4fa5984b0466b178a587e97cbced97deac749f74 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -38,7 +38,7 @@ class SimpleCostModel : public ParallelCostModel { const int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. // Return target parallel task count in [1, max_parallelism_]. return std::min(max_parallelism_, - std::max(1LL, instruction_cost / min_cost_per_thread)); + std::max(int64{1}, instruction_cost / min_cost_per_thread)); } private: @@ -63,7 +63,7 @@ class DefaultCostModel : public ParallelCostModel { int64 max_parallelism; // Calculate flops-to-bytes-ratio for 'instruction'. const int64 bytes_accessed = - std::max(1LL, cost_analysis_->bytes_accessed(*instruction)); + std::max(int64{1}, cost_analysis_->bytes_accessed(*instruction)); const float flops_to_bytes_ratio = cost_analysis_->flop_count(*instruction) / static_cast(bytes_accessed); @@ -71,7 +71,7 @@ class DefaultCostModel : public ParallelCostModel { if (flops_to_bytes_ratio <= 1.0) { // Limit max parallelism for I/O bound instructions by assuming a // sub-linear scaling function (fit based on empirical benchmark results). - // TODO(29630486) Develop system bandwidth model. + // TODO(b/29630486) Develop system bandwidth model. max_parallelism = std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs())); // Use shape size instruction cost and L2 cache size min per-thread cost. @@ -81,7 +81,7 @@ class DefaultCostModel : public ParallelCostModel { // Use max parallelism for compute bound instructions. max_parallelism = max_parallelism_; // Calculate the instruction cost in cycles. - // TODO(29630486) Improve on this linear cost model. + // TODO(b/29630486) Improve on this linear cost model. // Consider making 'min_cost_per_thread' be a function of the target // bandwidth limit for instructions with low arithmetic complexity. instruction_cost = @@ -93,7 +93,7 @@ class DefaultCostModel : public ParallelCostModel { } // Return target parallel task count in [1, max_parallelism_]. return std::min(max_parallelism, - std::max(1LL, instruction_cost / min_cost_per_thread)); + std::max(int64{1}, instruction_cost / min_cost_per_thread)); } private: @@ -104,7 +104,9 @@ class DefaultCostModel : public ParallelCostModel { ParallelTaskAssignment::ParallelTaskAssignment( const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module) { + const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module, + const TargetMachineFeatures* target_machine_features) + : target_machine_features_(*target_machine_features) { VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism; // Run cost analysis on 'module'. auto cost_analysis = MakeUnique(shape_size); @@ -127,25 +129,28 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( // Currently, we do not assign parallel tasks to instructions with at least // one of the following properties: // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall). - // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). + // *) Emit custom loops (kSelectAndScatter). + // *) Operations that are not thread safe (like infeed and rng). // *) Tuple-shaped. // TODO(b/27458679) Parallelize instructions which are skipped here. - if (instruction->opcode() == HloOpcode::kParameter || - instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kCall || - instruction->opcode() == HloOpcode::kCustomCall || - instruction->opcode() == HloOpcode::kSelectAndScatter || - instruction->opcode() == HloOpcode::kGetTupleElement || - instruction->opcode() == HloOpcode::kBitcast || - instruction->opcode() == HloOpcode::kFft || - (instruction->opcode() == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction)) || - PotentiallyImplementedAsEigenDot(*instruction) || - (instruction->opcode() == HloOpcode::kFusion && + auto opcode = instruction->opcode(); + if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant || + opcode == HloOpcode::kCall || opcode == HloOpcode::kCustomCall || + opcode == HloOpcode::kDot || opcode == HloOpcode::kSelectAndScatter || + opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast || + opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed || + opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng || + (opcode == HloOpcode::kConvolution && + PotentiallyImplementedAsEigenConvolution(*instruction, + target_machine_features_)) || + PotentiallyImplementedAsEigenDot(*instruction, + target_machine_features_) || + (opcode == HloOpcode::kFusion && instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || ShapeUtil::IsTuple(instruction->shape())) { return 1; } + // Consult 'cost_model_' to compute target parallel task count. return cost_model_->GetParallelTaskCount(instruction); } @@ -230,7 +235,8 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper( void ParallelTaskAssigner::ComputeTargetParallelTasks( HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) { ParallelTaskAssignment parallel_task_assignment(max_parallelism_, - shape_size_function_, module); + shape_size_function_, module, + &target_machine_features_); // Compute parallel task counts for all instructions in 'module'. for (auto* computation : module->computations()) { diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index 7140dabe516cd7ea9260456e994e8b63b68c60d6..8becc8fa23424d7454cc783eb9d853aecb5d053b 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -39,7 +40,8 @@ class ParallelTaskAssignment { // 'module': the containing HloModule. ParallelTaskAssignment(const int64 max_parallelism, const HloCostAnalysis::ShapeSizeFunction& shape_size, - HloModule* module); + HloModule* module, + const TargetMachineFeatures* target_machine_features); ~ParallelTaskAssignment() {} // Computes and returns the target parallel task count for 'instruction'. @@ -47,6 +49,7 @@ class ParallelTaskAssignment { private: std::unique_ptr cost_model_; + const TargetMachineFeatures& target_machine_features_; }; // ParallelTaskAssigner computes target parallel task counts for all HLOs @@ -63,8 +66,11 @@ class ParallelTaskAssigner : public HloPassInterface { // 'shape_size': shape size function used by HloCostAnalysis during parallel // task assignment. ParallelTaskAssigner(const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size) - : max_parallelism_(max_parallelism), shape_size_function_(shape_size) {} + const HloCostAnalysis::ShapeSizeFunction& shape_size, + const TargetMachineFeatures* target_machine_features) + : max_parallelism_(max_parallelism), + shape_size_function_(shape_size), + target_machine_features_(*target_machine_features) {} ~ParallelTaskAssigner() override {} tensorflow::StringPiece name() const override { @@ -94,6 +100,7 @@ class ParallelTaskAssigner : public HloPassInterface { int64 max_parallelism_; HloCostAnalysis::ShapeSizeFunction shape_size_function_; + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fc2efbaf9a22b02cd729da2f367d53bc15506836 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace { + +class ParallelTaskAssignmentTest : public HloVerifiedTestBase { + protected: + const HloCostAnalysis::ShapeSizeFunction shape_size_func_ = + cpu::CpuExecutable::ShapeSizeBytes; + + // Use any value larger than 2 since we only test whether a module is + // parallelized or not + const int max_parallelism_ = 10; + + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; + + ParallelTaskAssignmentTest() + : target_machine_features_([](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }) {} + + StatusOr RunParallelTaskAssigner(HloModule* module) { + return cpu::ParallelTaskAssigner(max_parallelism_, shape_size_func_, + &target_machine_features_) + .Run(module); + } +}; + +TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_Dot + ENTRY Dot { + dot_lhs = f32[196614,2]{1,0} parameter(0) + dot_rhs = f32[2,1]{1,0} parameter(1) + ROOT dot = f32[196614,1]{1,0} dot(dot_lhs, dot_rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ParallelTaskAssignmentTest, + FusedComputationWithDotOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_DotNestedInFusedComp + fused_computation.0 { + parameter.0 = f32[196614,2]{1,0} parameter(0) + parameter.0.1 = f32[2,1]{1,0} parameter(1) + parameter.0.2 = f32[196614,1]{1,0} parameter(2) + dot.0 = f32[196614,1]{1,0} dot(parameter.0, parameter.0.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT add.0 = f32[196614,1]{1,0} add(dot.0, parameter.0.2) + + } + ENTRY DotNestedInFusedComp { + parameter = f32[196614,2]{1,0} parameter(0) + parameter.1 = f32[2,1]{1,0} parameter(1) + parameter.2 = f32[196614,1]{1,0} parameter(2) + ROOT fusion = f32[196614,1]{1,0} fusion(parameter, parameter.1, + parameter.2), kind=kOutput, calls=fused_computation.0 + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ParallelTaskAssignmentTest, RngOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_rng + ENTRY Rng { + src0 = f32[] parameter(0) + src1 = f32[] parameter(1) + ROOT rng0 = f32[1234567,2]{1,0} rng(f32[] src0, f32[] src1), + distribution=rng_uniform + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_infeed_outfeed + ENTRY InfeedOutfeed { + infeed0 = u32[12345678,2]{1,0} infeed() + ROOT outfeed0 = u32[12345678,2]{1,0} outfeed(infeed0) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + EXPECT_FALSE(changed); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h index 39e20ed45639040110b99ddb52eb6f6dab26dfaa..7337c907f5c83d608641b7382e75902e6f6c05d4 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_H_ +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" extern "C" { diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.cc new file mode 100644 index 0000000000000000000000000000000000000000..c60580d6e763c659102b570ed044706f87899437 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h" +#include +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int64; + +#ifdef INTEL_MKL +#include +#include "mkldnn.hpp" +#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" + +namespace { + +// Downcast an int64 to int and check if value is in range. +int ToInt(int64 input) { + int output = static_cast(input); + if (static_cast(output) != input) { + std::cerr << "Error occurred in downcasting int64 to int32: Value " << input + << " is out-of-range for type int32. \n"; + exit(1); + } + return output; +} + +using mkldnn::convolution_direct; +using mkldnn::convolution_forward; +using mkldnn::engine; +using mkldnn::memory; +using mkldnn::padding_kind; +using mkldnn::primitive; +using mkldnn::prop_kind; +using mkldnn::reorder; +using mkldnn::stream; + +template +void MKLConvImpl(const EigenDevice& device, ScalarType* out, ScalarType* lhs, + ScalarType* rhs, int64 input_batch, int64 input_rows, + int64 input_cols, int64 input_channels, int64 kernel_rows, + int64 kernel_cols, int64 kernel_channels, int64 kernel_filters, + int64 output_rows, int64 output_cols, int64 row_stride, + int64 col_stride, int64 padding_top, int64 padding_bottom, + int64 padding_left, int64 padding_right, + int64 lhs_row_dilation, int64 lhs_col_dilation, + int64 rhs_row_dilation, int64 rhs_col_dilation) { + auto cpu_engine = engine(engine::cpu, 0); + + // Create a vector primitive to hold the network. + std::vector net; + + // Since memory::dims takes int for each dimension, we downcast the int64 + // values to int using the ToInt function defined above. + memory::dims conv1_src_dim = {ToInt(input_batch), ToInt(input_channels), + ToInt(input_rows), ToInt(input_cols)}; + memory::dims conv1_weights_dim = {ToInt(kernel_filters), + ToInt(kernel_channels), ToInt(kernel_rows), + ToInt(kernel_cols)}; + memory::dims conv1_dst_dim = {ToInt(input_batch), ToInt(kernel_filters), + ToInt(output_rows), ToInt(output_cols)}; + memory::dims conv1_strides = {ToInt(row_stride), ToInt(col_stride)}; + // Note: In MKL_DNN dilation starts from 0. + memory::dims conv1_dilates = {ToInt(rhs_row_dilation - 1), + ToInt(rhs_col_dilation - 1)}; + memory::dims conv1_padding_l = {ToInt(padding_top), ToInt(padding_left)}; + memory::dims conv1_padding_r = {ToInt(padding_bottom), ToInt(padding_right)}; + + // Create memory for user data. Input and output data have format of NHWC and + // kernel data has format of HWIO. + // Note that as a convention in MKL-DNN, the dimensions of the data is always + // described in NCHW/IOHW, regardless of the actual layout of the data. + auto user_src_memory = + memory({{{conv1_src_dim}, memory::data_type::f32, memory::format::nhwc}, + cpu_engine}, + lhs); + auto user_weights_memory = memory( + {{{conv1_weights_dim}, memory::data_type::f32, memory::format::hwio}, + cpu_engine}, + rhs); + auto user_dst_memory = + memory({{{conv1_dst_dim}, memory::data_type::f32, memory::format::nhwc}, + cpu_engine}, + out); + + // Create memory descriptors for convolution data with no specified format for + // best performance. + auto conv1_src_mem_desc = memory::desc( + {conv1_src_dim}, memory::data_type::f32, memory::format::any); + auto conv1_weights_mem_desc = memory::desc( + {conv1_weights_dim}, memory::data_type::f32, memory::format::any); + auto conv1_dst_mem_desc = memory::desc( + {conv1_dst_dim}, memory::data_type::f32, memory::format::any); + + // Create a convolution. + auto conv1_desc = convolution_forward::desc( + prop_kind::forward_inference, convolution_direct, conv1_src_mem_desc, + conv1_weights_mem_desc, conv1_dst_mem_desc, conv1_strides, conv1_dilates, + conv1_padding_l, conv1_padding_r, padding_kind::zero); + auto conv1_prim_desc = + convolution_forward::primitive_desc(conv1_desc, cpu_engine); + + // Create reorders for data and weights if layout requested by convolution is + // different from NCHW/OIHW. + auto conv1_src_memory = user_src_memory; + if (memory::primitive_desc(conv1_prim_desc.src_primitive_desc()) != + user_src_memory.get_primitive_desc()) { + conv1_src_memory = memory(conv1_prim_desc.src_primitive_desc()); + net.push_back(reorder(user_src_memory, conv1_src_memory)); + } + + auto conv1_weights_memory = user_weights_memory; + if (memory::primitive_desc(conv1_prim_desc.weights_primitive_desc()) != + user_weights_memory.get_primitive_desc()) { + conv1_weights_memory = memory(conv1_prim_desc.weights_primitive_desc()); + net.push_back(reorder(user_weights_memory, conv1_weights_memory)); + } + + // Check if output need layout conversion. If yes, create memory for + // intermediate layer of conv1_dst_memory. + bool need_output_conversion = + (memory::primitive_desc(conv1_prim_desc.dst_primitive_desc()) != + user_dst_memory.get_primitive_desc()); + auto conv1_dst_memory = need_output_conversion + ? memory(conv1_prim_desc.dst_primitive_desc()) + : user_dst_memory; + + // Create convolution primitive and add it to net. + net.push_back(convolution_forward(conv1_prim_desc, conv1_src_memory, + conv1_weights_memory, conv1_dst_memory)); + if (need_output_conversion) { + net.push_back(reorder(conv1_dst_memory, user_dst_memory)); + } + stream(stream::kind::eager).submit(net).wait(); +} +} // namespace +#endif // INTEL_MKL + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLConvF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, + int64 input_batch, int64 input_rows, int64 input_cols, int64 input_channels, + int64 kernel_rows, int64 kernel_cols, int64 kernel_channels, + int64 kernel_filters, int64 output_rows, int64 output_cols, + int64 row_stride, int64 col_stride, int64 padding_top, int64 padding_bottom, + int64 padding_left, int64 padding_right, int64 lhs_row_dilation, + int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) { +#ifdef INTEL_MKL + // Since MKL_DNN cannot handle transposed convolution, this is handled by + // Eigen. + if (lhs_row_dilation > 1 || lhs_col_dilation > 1) { + __xla_cpu_runtime_EigenConvF32( + run_options_ptr, out, lhs, rhs, input_batch, input_rows, input_cols, + input_channels, kernel_rows, kernel_cols, kernel_channels, + kernel_filters, output_rows, output_cols, row_stride, col_stride, + padding_top, padding_bottom, padding_left, padding_right, + lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation); + } else { + MKLConvImpl(nullptr, out, lhs, rhs, input_batch, input_rows, input_cols, + input_channels, kernel_rows, kernel_cols, kernel_channels, + kernel_filters, output_rows, output_cols, row_stride, + col_stride, padding_top, padding_bottom, padding_left, + padding_right, lhs_row_dilation, lhs_col_dilation, + rhs_row_dilation, rhs_col_dilation); + } +#else + std::cerr << "Attempt to call MKL Conv2D runtime library without defining " + "INTEL_MKL. Add --config=mkl to build with MKL."; + exit(1); +#endif // INTEL_MKL +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h new file mode 100644 index 0000000000000000000000000000000000000000..b239e71d231c5237a51a7048025bc2dcbd54fbe5 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_MKL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_MKL_H_ + +#include +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_MKLConvF32( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, + float* lhs, float* rhs, tensorflow::int64 input_batch, + tensorflow::int64 input_rows, tensorflow::int64 input_cols, + tensorflow::int64 input_channels, tensorflow::int64 kernel_rows, + tensorflow::int64 kernel_cols, tensorflow::int64 kernel_channels, + tensorflow::int64 kernel_filters, tensorflow::int64 output_rows, + tensorflow::int64 output_cols, tensorflow::int64 row_stride, + tensorflow::int64 col_stride, tensorflow::int64 padding_top, + tensorflow::int64 padding_bottom, tensorflow::int64 padding_left, + tensorflow::int64 padding_right, tensorflow::int64 lhs_row_dilation, + tensorflow::int64 lhs_col_dilation, tensorflow::int64 rhs_row_dilation, + tensorflow::int64 rhs_col_dilation); +} + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_MKL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h index 984cb0616e02475babad7160d0f43bb23de0b50e..0bf693edd0b985a4e62c16414646cc6a17db26ee 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -21,8 +21,6 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/numeric_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/types.h" // 'tensorflow' namespace is used so that int64 and other types don't require @@ -71,11 +69,9 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = fft_shape[i]; out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -88,8 +84,8 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); // Compute the full FFT using a temporary tensor. - Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(in_dims); + const Eigen::DSizes zero_start_indices; full_fft.device(device) = input.template fft(axes); @@ -112,11 +108,9 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; out_dims[i + 1] = fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -129,8 +123,7 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, // region we will slice from input given fft_shape. We slice input to // fft_shape on its inner-most dimensions, except the last (which we // slice to fft_shape[-1] / 2 + 1). - Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(out_dims); // Calculate the starting point and range of the source of // negative frequency part. @@ -179,7 +172,6 @@ template void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, int32 fft_type, int64 input_batch, int64 fft_length0, int64 fft_length1, int64 fft_length2) { - CHECK(::xla::FftType_IsValid(fft_type)) << fft_type; switch (fft_type) { case ::xla::FftType::FFT: EigenFftC2C( @@ -204,7 +196,8 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, input_batch, fft_length0, fft_length1, fft_length2); break; default: - LOG(FATAL) << "Unsupported FFT type: " << fft_type; + // Unsupported FFT type + abort(); } } @@ -230,7 +223,8 @@ void EigenFftImpl(const EigenDevice& device, void* out, void* operand, fft_length1, fft_length2); break; default: - LOG(FATAL) << "Unsupported FFT rank " << fft_rank; + // Unsupported FFT rank + abort(); } } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fp16.cc b/tensorflow/compiler/xla/service/cpu/runtime_fp16.cc new file mode 100644 index 0000000000000000000000000000000000000000..af0275c8bd00c82220fbe116eb90d2692393713b --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fp16.cc @@ -0,0 +1,133 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h" +#include "tensorflow/core/platform/macros.h" + +namespace { +using tensorflow::uint16; +using tensorflow::uint32; + +// Helper class that lets us access the underlying bit representation +// of a float without breaking C++ strict aliasing. +class AliasedFloatInt { + public: + static_assert(sizeof(float) == sizeof(uint32), ""); + + static AliasedFloatInt FromFloat(float f) { + AliasedFloatInt value; + value.set_float(f); + return value; + } + + static AliasedFloatInt FromUInt(uint32 u) { + AliasedFloatInt value; + value.set_uint(u); + return value; + } + + void set_float(float f) { memcpy(&value_, &f, sizeof(f)); } + float as_float() const { + float f; + memcpy(&f, &value_, sizeof(f)); + return f; + } + + void set_uint(uint32 u) { value_ = u; } + uint32 as_uint() const { return value_; } + + private: + uint32 value_; +}; +} // namespace + +// __gnu_f2h_ieee and __gnu_h2f_ieee are marked as weak symbols so if XLA is +// built with compiler-rt (that also defines these symbols) we don't get a +// duplicate definition linker error. Making these symbols weak also ensures +// that the compiler-rt definitions "win", but that isn't essential. + +// Algorithm copied from Eigen. +uint16 TF_ATTRIBUTE_WEAK __gnu_f2h_ieee(float float_value) { + AliasedFloatInt f = AliasedFloatInt::FromFloat(float_value); + + const AliasedFloatInt f32infty = AliasedFloatInt::FromUInt(255 << 23); + const AliasedFloatInt f16max = AliasedFloatInt::FromUInt((127 + 16) << 23); + const AliasedFloatInt denorm_magic = + AliasedFloatInt::FromUInt(((127 - 15) + (23 - 10) + 1) << 23); + unsigned int sign_mask = 0x80000000u; + uint32 o = static_cast(0x0u); + + unsigned int sign = f.as_uint() & sign_mask; + f.set_uint(f.as_uint() ^ sign); + + // NOTE all the integer compares in this function can be safely + // compiled into signed compares since all operands are below + // 0x80000000. Important if you want fast straight SSE2 code + // (since there's no unsigned PCMPGTD). + + if (f.as_uint() >= + f16max.as_uint()) { // result is Inf or NaN (all exponent bits set) + o = (f.as_uint() > f32infty.as_uint()) ? 0x7e00 + : 0x7c00; // NaN->qNaN and Inf->Inf + } else { // (De)normalized number or zero + if (f.as_uint() < (113 << 23)) { // resulting FP16 is subnormal or zero + // use a magic value to align our 10 mantissa bits at the bottom of + // the float. as long as FP addition is round-to-nearest-even this + // just works. + f.set_float(f.as_float() + denorm_magic.as_float()); + + // and one integer subtract of the bias later, we have our final float! + o = static_cast(f.as_uint() - denorm_magic.as_uint()); + } else { + unsigned int mant_odd = + (f.as_uint() >> 13) & 1; // resulting mantissa is odd + + // update exponent, rounding bias part 1 + f.set_uint(f.as_uint() + (static_cast(15 - 127) << 23) + + 0xfff); + // rounding bias part 2 + f.set_uint(f.as_uint() + mant_odd); + // take the bits! + o = static_cast(f.as_uint() >> 13); + } + } + + o |= static_cast(sign >> 16); + return o; +} + +// Algorithm copied from Eigen. +float TF_ATTRIBUTE_WEAK __gnu_h2f_ieee(uint16 h) { + const AliasedFloatInt magic = AliasedFloatInt::FromUInt(113 << 23); + const unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift + AliasedFloatInt o; + + o.set_uint((h & 0x7fff) << 13); // exponent/mantissa bits + unsigned int exp = shifted_exp & o.as_uint(); // just the exponent + o.set_uint(o.as_uint() + ((127 - 15) << 23)); // exponent adjust + + // handle exponent special cases + if (exp == shifted_exp) { // Inf/NaN? + o.set_uint(o.as_uint() + ((128 - 16) << 23)); // extra exp adjust + } else if (exp == 0) { // Zero/Denormal? + o.set_uint(o.as_uint() + (1 << 23)); // extra exp adjust + o.set_float(o.as_float() - magic.as_float()); // renormalize + } + + o.set_uint(o.as_uint() | (h & 0x8000) << 16); // sign bit + return o.as_float(); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fp16.h b/tensorflow/compiler/xla/service/cpu/runtime_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..01d92d031904af99884c2583a8c7b5086b289d44 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fp16.h @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FP16_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FP16_H_ + +#include "tensorflow/core/platform/types.h" + +// Converts an F32 value to a F16. +extern "C" tensorflow::uint16 __gnu_f2h_ieee(float); + +// Converts an F16 value to a F32. +extern "C" float __gnu_h2f_ieee(tensorflow::uint16); + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FP16_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index bff57d33ae23fbba8c664cbd18df77e4c35eb592..39b13183ff093611a42b3931d45f64eadb420622 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -63,30 +63,41 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, C.device(*run_options->intra_op_thread_pool()) = A.contract(B, dims); } +template +void MatMulImpl(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { + if (m == 1 || n == 1) { + // Despite being single threaded, this version of matrix * vector is faster. + xla::EigenMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); + } else { + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); + } +} + } // namespace +void __xla_cpu_runtime_EigenMatMulF16(const void* run_options_ptr, + Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); +} + void __xla_cpu_runtime_EigenMatMulF32(const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - if (m == 1 || n == 1) { - // Despite being single threaded, this version of matrix * vector is faster. - xla::EigenMatVecF32(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); - } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); - } + MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); } void __xla_cpu_runtime_EigenMatMulF64(const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - if (m == 1 || n == 1) { - // Despite being single threaded, this version of matrix * vector is faster. - xla::EigenMatVecF64(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); - } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); - } + MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.h b/tensorflow/compiler/xla/service/cpu/runtime_matmul.h index fdb644651dd5d0fa0345580f52ed0fb051672285..d96fe3d58bd5ffbad347e3ede3534d1d47be697a 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_ +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" extern "C" { @@ -25,6 +26,12 @@ extern "C" { // order. 'out' is a pointer to a buffer sufficiently large to hold the result // of the operation. Following standard nomenclature: lhs is m x k, // rhs is k x n, and out is m x n. +extern void __xla_cpu_runtime_EigenMatMulF16( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, tensorflow::int64 m, + tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); + extern void __xla_cpu_runtime_EigenMatMulF32( const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc new file mode 100644 index 0000000000000000000000000000000000000000..f8c8dd5e93d53db8d87be0208b5cf4daac3464f1 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc @@ -0,0 +1,128 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML) +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" +#include "third_party/intel_mkl_ml/include/mkl_cblas.h" +#include "third_party/intel_mkl_ml/include/mkl_service.h" + +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/core/platform/types.h" + +#define EIGEN_USE_THREADS +#include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool" + +using tensorflow::int32; +using tensorflow::int64; + +namespace { +// BLAS GEMM API for 32-bit Matrix Multiplication. + +// MatMul function is defined as: c = alpha * op(a) * op(b) + beta * c. +// Since XLA MatMul does not used alpha, beta, we set them to 1.0 and 0.0. +// Matrix lhs, rhs and out are all colum-major. +void MatMulF32(const void* run_options_ptr, float* out, float* lhs, float* rhs, + int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + const float alpha = 1.0f, beta = 0.0f; + // lda, ldb, and ldc are the leading dimensions of matrices a, b, and c, + // respectively. For column-major matrices, the leading dimension is the + // stride between consecutive columns (which equals the number of rows). If + // the matrix is transposed, the leading dimension is the stride between + // consecutive rows (which equals the number of columns). + int lda = transpose_lhs ? k : m; + int ldb = transpose_rhs ? n : k; + int ldc = m; + cblas_sgemm(CblasColMajor, transpose_lhs ? CblasTrans : CblasNoTrans, + transpose_rhs ? CblasTrans : CblasNoTrans, m, n, k, alpha, lhs, + lda, rhs, ldb, beta, out, ldc); +} + +// BLAS GEMM API for 64-bit Matrix Multiplication. + +// MatMul function is defined as: c = alpha * op(a) * op(b) + beta * c. +// Since XLA MatMul does not used alpha, beta, we set them to 1.0 and 0.0. +// Matrix lhs, rhs and out are all colum-major. +void MatMulF64(const void* run_options_ptr, double* out, double* lhs, + double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + const float alpha = 1.0f, beta = 0.0f; + // lda, ldb, and ldc are the leading dimensions of matrices a, b, and c, + // respectively. For a column-major matrix, the leading dimension is the + // stride between consecutive columns (which equals the number of rows). If + // the matrix is transposed, the leading dimension is the stride between + // consecutive rows (which equals the number of columns). + int lda = transpose_lhs ? k : m; + int ldb = transpose_rhs ? n : k; + int ldc = m; + cblas_dgemm(CblasColMajor, transpose_lhs ? CblasTrans : CblasNoTrans, + transpose_rhs ? CblasTrans : CblasNoTrans, m, n, k, alpha, lhs, + lda, rhs, ldb, beta, out, ldc); +} + +} // namespace + +void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out, + float* lhs, float* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + const xla::ExecutableRunOptions* run_options = + static_cast(run_options_ptr); + // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread + // number specified in intra_op_thread_pool to MKL. + int prev_num_threads = mkl_set_num_threads_local( + run_options->intra_op_thread_pool()->numThreads()); + MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); + // Set thread number back to the previous number. + mkl_set_num_threads_local(prev_num_threads); +} +// BLAS GEMM API for 64-bit Matrix Multiplication +void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out, + double* lhs, double* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + const xla::ExecutableRunOptions* run_options = + static_cast(run_options_ptr); + // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread + // number specified in intra_op_thread_pool to MKL. + int prev_num_threads = mkl_set_num_threads_local( + run_options->intra_op_thread_pool()->numThreads()); + MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); + // Set thread number back to the previous number. + mkl_set_num_threads_local(prev_num_threads); +} +void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr, + float* out, float* lhs, + float* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + // Set the thread number to 1 for single threaded excution. + int prev_num_threads = mkl_set_num_threads_local(1); + MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); + // Set thread number back to the previous number. + mkl_set_num_threads_local(prev_num_threads); +} +void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr, + double* out, double* lhs, + double* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + // Set the thread number to 1 for single threaded excution. + int prev_num_threads = mkl_set_num_threads_local(1); + MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); + // Set thread number back to the previous number. + mkl_set_num_threads_local(prev_num_threads); +} +#endif // INTEL_MKL diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h new file mode 100644 index 0000000000000000000000000000000000000000..831b796efb971f6fb0170e2321c00ac415f2830f --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_ + +#include +#include "tensorflow/core/platform/types.h" +#ifdef INTEL_MKL +#include "third_party/intel_mkl_ml/include/mkl_cblas.h" + +extern void __xla_cpu_runtime_MKLMatMulF32( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, + float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); +extern void __xla_cpu_runtime_MKLMatMulF64( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out, + double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); +extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF32( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, + float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); +extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF64( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out, + double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); + +#else +extern void __xla_cpu_runtime_MKLMatMulF32( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, + float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { + std::cerr << "Attempt to call MKL MatMul runtime library without defining " + "INTEL_MKL. Add --config=mkl to build with MKL."; + exit(1); +} +extern void __xla_cpu_runtime_MKLMatMulF64( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out, + double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { + std::cerr << "Attempt to call MKL MatMul runtime library without defining " + "INTEL_MKL. Add --config=mkl to build with MKL."; + exit(1); +} +extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF32( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, + float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { + std::cerr << "Attempt to call MKL MatMul runtime library without defining " + "INTEL_MKL. Add --config=mkl to build with MKL."; + exit(1); +} +extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF64( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out, + double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { + std::cerr << "Attempt to call MKL MatMul runtime library without defining " + "INTEL_MKL. Add --config=mkl to build with MKL."; + exit(1); +} + +#endif // INTEL_MKL +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matvec.cc b/tensorflow/compiler/xla/service/cpu/runtime_matvec.cc deleted file mode 100644 index 435820cdd36e2a906d9dfbe2555f4c0df623c729..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/runtime_matvec.cc +++ /dev/null @@ -1,110 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "third_party/eigen3/Eigen/Core" -#include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h" - -using tensorflow::int32; -using tensorflow::int64; - -namespace { - -// Does mat * x or mat^T * x. -template -void MatVec(T* out_buf, T* mat_buf, T* x_buf, int64 rows, int64 cols, - int32 transpose) { - // Use an Eigen Matrix instead of a Tensor, as the GEMV from Matrix seems to - // be faster (b/30223679). See also: the matmul op kernel in TensorFlow, - // which implements the same optimization. - using Matrix = Eigen::Matrix; - using MatrixMap = Eigen::Map; - - using Vector = Eigen::Matrix; - using VectorMap = Eigen::Map; - - auto x = VectorMap(x_buf, cols); - auto out = VectorMap(out_buf, rows); - - int64 mat_rows = rows; - int64 mat_cols = cols; - - if (transpose) { - std::swap(mat_rows, mat_cols); - } - - auto mat = MatrixMap(mat_buf, mat_rows, mat_cols); - - if (transpose) { - out = mat.transpose() * x; - } else { - out = mat * x; - } -} - -// Converts matmul-style args to matvec. -template -void DispatchMatVec(T* out, T* lhs, T* rhs, int64 m, int64 n, int64 k, - int32 transpose_lhs, int32 transpose_rhs) { - // If the input is in the form x * A, where x is the vector, then bring A back - // over to the left hand side. We make use of the identity - // - // (x * A)^T = A^T * x^T - // - // We do not need to take the transpose of x or of the result since taking - // the transpose of a vector does not change the memory layout. - const int64 cols = k; - - T* mat; - T* vec; - int64 rows; - bool transpose_mat; - - bool is_mat_vec = (n == 1); - - if (is_mat_vec) { - mat = lhs; - vec = rhs; - rows = m; - transpose_mat = transpose_lhs; - } else { - mat = rhs; - vec = lhs; - rows = n; - transpose_mat = !transpose_rhs; - } - - MatVec(out, mat, vec, rows, cols, transpose_mat); -} - -} // namespace - -namespace xla { - -void EigenMatVecF32(float* out, float* lhs, float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, int32 transpose_rhs) { - assert((m == 1 || n == 1) && "not a matrix-vector multiply"); - DispatchMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); -} - -void EigenMatVecF64(double* out, double* lhs, double* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, int32 transpose_rhs) { - assert((m == 1 || n == 1) && "not a matrix-vector multiply"); - DispatchMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matvec.h b/tensorflow/compiler/xla/service/cpu/runtime_matvec.h index 1bd8dfb377acc1f7cfbe9a92773f87f0ef25de3a..70eb98c54169824e220d9287753c0849362eade6 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matvec.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_matvec.h @@ -16,10 +16,86 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ +#include "third_party/eigen3/Eigen/Core" + #include "tensorflow/core/platform/types.h" namespace xla { +namespace detail { + +using tensorflow::int32; +using tensorflow::int64; + +// Does mat * x or mat^T * x. +template +void MatVec(T* out_buf, T* mat_buf, T* x_buf, int64 rows, int64 cols, + int32 transpose) { + // Use an Eigen Matrix instead of a Tensor, as the GEMV from Matrix seems to + // be faster (b/30223679). See also: the matmul op kernel in TensorFlow, + // which implements the same optimization. + using Matrix = Eigen::Matrix; + using MatrixMap = Eigen::Map; + + using Vector = Eigen::Matrix; + using VectorMap = Eigen::Map; + + auto x = VectorMap(x_buf, cols); + auto out = VectorMap(out_buf, rows); + + int64 mat_rows = rows; + int64 mat_cols = cols; + + if (transpose) { + std::swap(mat_rows, mat_cols); + } + + auto mat = MatrixMap(mat_buf, mat_rows, mat_cols); + + if (transpose) { + out = mat.transpose() * x; + } else { + out = mat * x; + } +} + +// Converts matmul-style args to matvec. +template +void DispatchMatVec(T* out, T* lhs, T* rhs, int64 m, int64 n, int64 k, + int32 transpose_lhs, int32 transpose_rhs) { + // If the input is in the form x * A, where x is the vector, then bring A back + // over to the left hand side. We make use of the identity + // + // (x * A)^T = A^T * x^T + // + // We do not need to take the transpose of x or of the result since taking + // the transpose of a vector does not change the memory layout. + const int64 cols = k; + + T* mat; + T* vec; + int64 rows; + bool transpose_mat; + + bool is_mat_vec = (n == 1); + + if (is_mat_vec) { + mat = lhs; + vec = rhs; + rows = m; + transpose_mat = transpose_lhs; + } else { + mat = rhs; + vec = lhs; + rows = n; + transpose_mat = !transpose_rhs; + } + + MatVec(out, mat, vec, rows, cols, transpose_mat); +} + +} // namespace detail + // Performs a matrix-vector multiplication using Eigen. 'lhs' and 'rhs' are // pointers to buffers containing input matrices in column-major order. 'out' is // a pointer to a buffer sufficiently large to hold the result of the @@ -30,15 +106,15 @@ namespace xla { // // TODO(b/64684907): Compare runtime performance of these functions with dot // simplification. -void EigenMatVecF32(float* out, float* lhs, float* rhs, tensorflow::int64 m, - tensorflow::int64 n, tensorflow::int64 k, - tensorflow::int32 transpose_lhs, - tensorflow::int32 transpose_rhs); - -void EigenMatVecF64(double* out, double* lhs, double* rhs, tensorflow::int64 m, - tensorflow::int64 n, tensorflow::int64 k, - tensorflow::int32 transpose_lhs, - tensorflow::int32 transpose_rhs); +template +void EigenMatVec(T* out, T* lhs, T* rhs, tensorflow::int64 m, + tensorflow::int64 n, tensorflow::int64 k, + tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { + assert((m == 1 || n == 1) && "not a matrix-vector multiply"); + detail::DispatchMatVec(out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h index f216bd0152aa93b8753d881938c63a9cabea899b..44b201725b2c724f48c1a3f0373c41e76211e0c2 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_CONV2D_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_CONV2D_H_ +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" extern "C" { diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc new file mode 100644 index 0000000000000000000000000000000000000000..2613ddb12704aea7d0884c6c8c062dc028383639 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" + +#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int32; +using tensorflow::int64; + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* run_options_ptr, void* out, void* operand, int32 fft_type, + int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand, fft_type, + fft_rank, input_batch, fft_length0, fft_length1, + fft_length2); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h new file mode 100644 index 0000000000000000000000000000000000000000..dcd133d012cf074a4cd2f550585881388bea6156 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out, + void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank, + tensorflow::int64 input_batch, tensorflow::int64 fft_length0, + tensorflow::int64 fft_length1, tensorflow::int64 fft_length2); + +} // extern "C" + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index ee8eb081556d60fcf6537b1036a4a5825c4c7bf6..17303e2f0d34e531a3a56aa147608b949e0f43ae 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -57,26 +57,38 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, C = A.contract(B, dims); } +template +void SingleThreadedMatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, + int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + if (m == 1 || n == 1) { + xla::EigenMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); + } else { + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); + } +} + } // namespace +void __xla_cpu_runtime_EigenSingleThreadedMatMulF16( + const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); +} + void __xla_cpu_runtime_EigenSingleThreadedMatMulF32( const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - if (m == 1 || n == 1) { - xla::EigenMatVecF32(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); - } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); - } + SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } void __xla_cpu_runtime_EigenSingleThreadedMatMulF64( const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - if (m == 1 || n == 1) { - xla::EigenMatVecF64(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); - } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); - } + SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h index 029eb9514287d8c69cde2cfb06e0d56e78d6f165..82a1fcce594fa5b04f4fe459870991863c32a91a 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_ +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" extern "C" { @@ -25,6 +26,12 @@ extern "C" { // 'out' is a pointer to a buffer sufficiently large to hold the result of the // operation. Following standard nomenclature: lhs is m x k, rhs is k x n, and // out is m x n. +extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF16( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, tensorflow::int64 m, + tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); + extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF32( const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index b3f4609d465efb4df8921abb684bafd263fe040f..167aa4adda995a259190a932a76a34ca5883444c 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -19,10 +19,10 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -48,13 +48,13 @@ int main(int argc, char** argv) { client->TransferToServer(*param1_literal).ConsumeValueOrDie(); // Build computation. - xla::ComputationBuilder builder(client, ""); + xla::XlaBuilder builder(""); auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); auto add = builder.Add(p1, p0, {0}); - xla::StatusOr computation_status = builder.Build(); - xla::Computation computation = computation_status.ConsumeValueOrDie(); + xla::StatusOr computation_status = builder.Build(); + xla::XlaComputation computation = computation_status.ConsumeValueOrDie(); // Execute and transfer result of computation. xla::ExecutionProfile profile; diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.cc b/tensorflow/compiler/xla/service/cpu/shape_partition.cc index 61b408b8c24dded134218110d4e219c31f1685a8..d12c5396148d32adb178b955a34e050cc56784da 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.cc @@ -20,12 +20,13 @@ namespace cpu { std::vector ShapePartitionAssigner::Run(int64 target_partition_count) { // Gather outer-most dims where dim_size >= 'target_partition_count'. - // Note: always leave inner-dim static for vectorization/optimizations. + // This may include the inner-dim as LLVM can vectorize loops with dynamic + // bounds. std::vector outer_dims; int64 outer_dim_size = 1; // TODO(b/27458679) Consider reserving enough minor dimensions (based on // target vector register width) to enable vector instructions. - for (int i = shape_.layout().minor_to_major_size() - 1; i >= 1; --i) { + for (int i = shape_.layout().minor_to_major_size() - 1; i >= 0; --i) { const int64 dimension = shape_.layout().minor_to_major(i); outer_dims.push_back(dimension); outer_dim_size *= shape_.dimensions(dimension); @@ -114,7 +115,7 @@ ShapePartitionIterator::ShapePartitionIterator( for (int i = 0; i < dimension_partition_sizes_.size(); ++i) { const int64 dim_size = shape_.dimensions(dimensions_[i]); dimension_partition_sizes_[i] = - std::max(1LL, dim_size / dimension_partition_counts_[i]); + std::max(int64{1}, dim_size / dimension_partition_counts_[i]); } // Calculate the partition strides for each dimension. diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.h b/tensorflow/compiler/xla/service/cpu/shape_partition.h index 33d02b70e61e3311c9af934e80874939fbe3adae..db2cda2936c834ad79a529bef6596d2f33822a3d 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition.h +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.h @@ -38,7 +38,7 @@ namespace cpu { // // [0, 1), [1, 2), [2, 3), [3, 4), [4, 5) [5, 8) // -// Note that the last partition has residule because the dimension size is +// Note that the last partition has residual because the dimension size is // not a multiple of the partition count. // // diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc index ee0c53fa6d7c41481a53350e57e5844dea2644c1..ae80a6f4977f85cfd9f872734fd0a69432a1f382 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -30,105 +30,65 @@ class ShapePartitionAssignerTest : public HloTestBase { protected: typedef std::vector Vec; - void RunR2Test(const Shape& shape, const int64 expected_max_partition_count) { + void RunR2Test(const Shape& shape, int64 max_target_partition_count, + const std::vector* expected_partitions) { ShapePartitionAssigner assigner(shape); - // Check all partitions of outer dimension. - for (int64 i = 1; i <= expected_max_partition_count; ++i) { - EXPECT_TRUE(ContainersEqual(Vec({i}), - assigner.Run(/*target_partition_count=*/i))); + // Iterate through 1..max_target_partition_count. + for (int64 i = 1; i <= max_target_partition_count; ++i) { + std::vector actual_partitions = + assigner.Run(/*target_partition_count=*/i); + EXPECT_THAT(actual_partitions, expected_partitions[i - 1]); } - // Check target_partition_count > outer dimension size. - EXPECT_TRUE(ContainersEqual( - Vec({expected_max_partition_count}), - assigner.Run( - /*target_partition_count=*/expected_max_partition_count + 1))); } }; TEST_F(ShapePartitionAssignerTest, Shape13WithLayout10) { - RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {1, 3}, {1, 0}), 1); + std::vector expected_partitions[] = {{1} /* 1 */, {1, 2} /* 2 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {1, 3}, {1, 0}), 2, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape31WithLayout01) { - RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {3, 1}, {0, 1}), 1); + std::vector expected_partitions[] = { + {1} /* 1 */, {1, 2} /* 2 */ + }; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {3, 1}, {0, 1}), 2, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape53WithLayout10) { - RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}), 5); + std::vector expected_partitions[] = {{1} /* 1 */, {2} /* 2 */, + {3} /* 3 */, {4} /* 4 */, + {5} /* 5 */, {3, 2} /* 6 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}), 6, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape53WithLayout01) { - RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {0, 1}), 3); + std::vector expected_partitions[] = { + {1} /* 1 */, {2} /* 2 */, {3} /* 3 */, {2, 2} /* 4 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {0, 1}), 4, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape532WithLayout210) { - Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}); - ShapePartitionAssigner assigner(shape); - - for (int64 i = 1; i <= 5; ++i) { - EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run( - /*target_partition_count=*/i))); - } - - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/7))); - EXPECT_TRUE( - ContainersEqual(Vec({4, 2}), assigner.Run(/*target_partition_count=*/8))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9))); - EXPECT_TRUE(ContainersEqual(Vec({3, 3}), - assigner.Run(/*target_partition_count=*/10))); - EXPECT_TRUE(ContainersEqual(Vec({3, 3}), - assigner.Run(/*target_partition_count=*/11))); - EXPECT_TRUE(ContainersEqual(Vec({4, 3}), - assigner.Run(/*target_partition_count=*/12))); - EXPECT_TRUE(ContainersEqual(Vec({4, 3}), - assigner.Run(/*target_partition_count=*/13))); - EXPECT_TRUE(ContainersEqual(Vec({4, 3}), - assigner.Run(/*target_partition_count=*/14))); - EXPECT_TRUE(ContainersEqual(Vec({5, 3}), - assigner.Run(/*target_partition_count=*/15))); - EXPECT_TRUE(ContainersEqual(Vec({5, 3}), - assigner.Run(/*target_partition_count=*/16))); + std::vector expected_partitions[] = { + {1} /* 1 */, {2} /* 2 */, {3} /* 3 */, {4} /* 4 */, + {5} /* 5 */, {3, 2} /* 6 */, {3, 2} /* 7 */, {4, 2} /* 8 */, + {3, 3} /* 9 */, {3, 3} /* 10 */, {3, 3} /* 11 */, {4, 3} /* 12 */, + {4, 3} /* 13 */, {4, 3} /* 14 */, {5, 3} /* 15 */, {4, 2, 2} /* 16 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}), 16, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { - Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 0, 1}); - ShapePartitionAssigner assigner(shape); - - for (int64 i = 1; i <= 3; ++i) { - EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run( - /*target_partition_count=*/i))); - } - - EXPECT_TRUE( - ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/4))); - EXPECT_TRUE( - ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/5))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/7))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/8))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9))); - EXPECT_TRUE(ContainersEqual(Vec({3, 3}), - assigner.Run(/*target_partition_count=*/10))); - EXPECT_TRUE(ContainersEqual(Vec({3, 3}), - assigner.Run(/*target_partition_count=*/11))); - EXPECT_TRUE(ContainersEqual(Vec({3, 4}), - assigner.Run(/*target_partition_count=*/12))); - EXPECT_TRUE(ContainersEqual(Vec({3, 4}), - assigner.Run(/*target_partition_count=*/13))); - EXPECT_TRUE(ContainersEqual(Vec({3, 4}), - assigner.Run(/*target_partition_count=*/14))); - EXPECT_TRUE(ContainersEqual(Vec({3, 5}), - assigner.Run(/*target_partition_count=*/15))); - EXPECT_TRUE(ContainersEqual(Vec({3, 5}), - assigner.Run(/*target_partition_count=*/16))); + std::vector expected_partitions[] = { + {1} /* 1 */, {2} /* 2 */, {3} /* 3 */, {2, 2} /* 4 */, + {2, 2} /* 5 */, {3, 2} /* 6 */, {3, 2} /* 7 */, {3, 2} /* 8 */, + {3, 3} /* 9 */, {3, 3} /* 10 */, {3, 3} /* 11 */, {3, 4} /* 12 */, + {3, 4} /* 13 */, {3, 4} /* 14 */, {3, 5} /* 15 */, {3, 2, 2} /* 16 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 0, 1}), 16, + expected_partitions); } class ShapePartitionIteratorTest : public HloTestBase { diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index aa8d4ad9dc51b2c1f500898f8bbd2c548f710643..c4c90515ac7ec2721cb9ea48d42e3c5080e249af 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -31,10 +31,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" #include "tensorflow/compiler/xla/types.h" @@ -70,24 +74,33 @@ llvm::StringRef GetHostCpuName() { } } // namespace +/*static*/ std::unique_ptr +SimpleOrcJIT::InferTargetMachineForJIT( + const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level) { + std::unique_ptr target_machine( + llvm::EngineBuilder() + .setTargetOptions(target_options) + .setOptLevel(opt_level) + .selectTarget( + /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", + /*MCPU=*/GetHostCpuName(), + /*MAttrs=*/DetectMachineAttributes())); + CHECK(target_machine != nullptr); + return target_machine; +} + SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, bool enable_fast_math, bool disable_expensive_passes, LLVMCompiler::ModuleHook pre_optimization_hook, LLVMCompiler::ModuleHook post_optimization_hook) - : target_machine_( - CHECK_NOTNULL(llvm::EngineBuilder() - .setTargetOptions(target_options) - .setOptLevel(opt_level) - .selectTarget( - /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", - /*MCPU=*/GetHostCpuName(), - /*MAttrs=*/DetectMachineAttributes()))), + : target_machine_(InferTargetMachineForJIT(target_options, opt_level)), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), - execution_session_(string_pool_), symbol_resolver_(llvm::orc::createLegacyLookupResolver( + execution_session_, [this](const std::string& name) -> llvm::JITSymbol { return this->ResolveRuntimeSymbol(name); }, @@ -177,19 +190,30 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation); + REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenFft); + REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64); + REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF32); + REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF64); + REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF32); + REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedFft); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); + registry->Register("__gnu_f2h_ieee", reinterpret_cast(__gnu_f2h_ieee)); + registry->Register("__gnu_h2f_ieee", reinterpret_cast(__gnu_h2f_ieee)); + #undef REGISTER_CPU_RUNTIME_SYMBOL // Register both the f32 (float) and f64 (double) versions of a libm symbol. diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index d0011e0a185cd0284d2f9334594f6e06d9284be7..1851a3ee0bb97b4860605d7211a6ae70ac88686b 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -46,9 +46,7 @@ namespace cpu { class SimpleOrcJIT { public: using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer; - using CompileFtor = - std::function( - llvm::Module&)>; + using CompileFtor = std::function; using CompileLayerT = llvm::orc::IRCompileLayer; using VModuleKeyT = llvm::orc::VModuleKey; @@ -97,6 +95,12 @@ class SimpleOrcJIT { return &external_constant_pool_; } + // Creates an llvm::TargetMachine suitable for JITting code that will run on + // the current machine. + static std::unique_ptr InferTargetMachineForJIT( + const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level); + private: llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name); @@ -104,7 +108,6 @@ class SimpleOrcJIT { std::unique_ptr target_machine_; const Disassembler disassembler_; const llvm::DataLayout data_layout_; - llvm::orc::SymbolStringPool string_pool_; llvm::orc::ExecutionSession execution_session_; std::shared_ptr symbol_resolver_; ObjLayerT object_layer_; diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc index eeb049737dddd11ef2ce229df772baec3ac03dd8..a0cd8ee2d2be10bcee9c2e216e24908d949e2d7b 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc @@ -18,7 +18,7 @@ limitations under the License. namespace xla { namespace cpu { -llvm::TargetTransformInfo* TargetMachineFeatures::GetTargetTransformInfoFor( +llvm::TargetTransformInfo* LLVMTargetMachineFeatures::GetTargetTransformInfoFor( const llvm::Function& function) const { auto it = target_transform_info_cache_.find(&function); if (it == target_transform_info_cache_.end()) { @@ -31,5 +31,30 @@ llvm::TargetTransformInfo* TargetMachineFeatures::GetTargetTransformInfoFor( return &it->second; } +int64 LLVMTargetMachineFeatures::minimum_alignment_for_allocation( + int64 size_bytes) const { + // GLibc malloc returns a pointer with alignment 8 on 32-bit platforms and 16 + // on 64-bit platforms. TCMalloc returns a pointer with alignment 8 for + // allocations smaller than kMallocAlignmentThreshold bytes and at least + // alignment 16 for allocations greater than or equal to + // kMallocAlignmentThreshold bytes. N.B. We could improve on this lower bound + // by explicitly allocating the memory with posix_memalign. This is + // complicated by our desire to allow parameter buffers created by clients to + // be consumed directly by the JIT. + if (size_bytes == 0) { + // No need to align empty buffers. + return 1; + } + + const int64 kMallocAlignmentThreshold = 512; + + int pointer_size = target_machine_->getPointerSize(0); + int buffer_alignment = + size_bytes >= kMallocAlignmentThreshold ? 2 * pointer_size : pointer_size; + DCHECK_GT(buffer_alignment, 0); + + return buffer_alignment; +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.h b/tensorflow/compiler/xla/service/cpu/target_machine_features.h index 703942615e552dccde7ddec8c8b90e8a486652af..8b00ae9e47eeed26ffe80707b89593b267e8dbb8 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.h +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.h @@ -24,43 +24,68 @@ limitations under the License. namespace xla { namespace cpu { -// Wraps an llvm::TargetMachine and parses out some information that feeds into -// LLVM IR code generation decisions. +// Abstract interface for classes providing information about the target we're +// compiling for. class TargetMachineFeatures { public: static constexpr int kX86AvxVectorByteSize = 32; - TargetMachineFeatures(llvm::TargetMachine* target_machine) - : target_machine_(target_machine) {} + // Input and output tensor buffers must be aligned to this many bytes if we + // want to call an Eigen backed GEMM or Convolution. + static constexpr int kEigenExpectedTensorAlignment = 16; // Return the vectorization factor, which is the number of bytes of data // explicitly vectorized routines will try to process at once. - int vectorization_factor_in_bytes() const { - // Ideally this should be a function of the cache line size (which we can - // get from llvm::TargetTransformInfo::getCacheLineSize) of the target - // machine. Guess a value of 128 bytes for now. - return 128; - } + virtual int vectorization_factor_in_bytes() const = 0; // Return the size of the largest vector size in bytes. We need to pass in // "function" since llvm functions can contain annotations for specializing // them to specific micro-architectures (though currently XLA does not use // this functionality). - int vector_register_byte_size(const llvm::Function& function) const { - llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function); - return tti->getRegisterBitWidth(/*Vector=*/true) / 8; - } + virtual int vector_register_byte_size( + const llvm::Function& function) const = 0; // Return the number of elements of type `type` that can fit into the largest // vector register available. We need to pass in "function" since llvm // functions can contain annotations for specializing them to specific // micro-architectures (though currently XLA does not use this functionality). + virtual int vector_register_num_elements(const llvm::Function& function, + PrimitiveType type) const = 0; + + // Returns the minimum alignment for a buffer of size size_bytes. + virtual int64 minimum_alignment_for_allocation(int64 size_bytes) const = 0; + + virtual ~TargetMachineFeatures() = default; +}; + +// Implements the TargetMachineFeatures interface using an llvm::TargetMachine. +class LLVMTargetMachineFeatures : public TargetMachineFeatures { + public: + static constexpr int kX86AvxVectorByteSize = 32; + + LLVMTargetMachineFeatures(llvm::TargetMachine* target_machine) + : target_machine_(target_machine) {} + + int vectorization_factor_in_bytes() const override { + // Ideally this should be a function of the cache line size (which we can + // get from llvm::TargetTransformInfo::getCacheLineSize) of the target + // machine. Guess a value of 128 bytes for now. + return 128; + } + + int vector_register_byte_size(const llvm::Function& function) const override { + llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function); + return tti->getRegisterBitWidth(/*Vector=*/true) / 8; + } + int vector_register_num_elements(const llvm::Function& function, - PrimitiveType type) const { + PrimitiveType type) const override { return vector_register_byte_size(function) / (primitive_util::BitWidth(type) / 8); } + int64 minimum_alignment_for_allocation(int64 size_bytes) const override; + private: llvm::TargetTransformInfo* GetTargetTransformInfoFor( const llvm::Function& function) const; diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h b/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h new file mode 100644 index 0000000000000000000000000000000000000000..ffc6927cbe1a2b6fd1a1ca3aac9b6e047741c2af --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_ + +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" + +namespace xla { +namespace cpu { +// Delegates calls to minimum_alignment_for_allocation to a user provided +// std::function, crashes on all other methods. +// +// Primarily useful for testing. +class TargetMachineFeaturesWithFakeAlignmentLogic + : public TargetMachineFeatures { + public: + explicit TargetMachineFeaturesWithFakeAlignmentLogic( + std::function fake_alignment_logic) + : fake_alignment_logic_(std::move(fake_alignment_logic)) {} + + int vectorization_factor_in_bytes() const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + + int vector_register_byte_size(const llvm::Function& function) const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + + int vector_register_num_elements(const llvm::Function& function, + PrimitiveType type) const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + + int64 minimum_alignment_for_allocation(int64 size_bytes) const override { + return fake_alignment_logic_(size_bytes); + } + + private: + std::function fake_alignment_logic_; +}; +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_ diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..66ae5ef0f66e90982102d73e474f5d0582f5415c --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -0,0 +1,176 @@ +# Description: +# Tests for LLVM-based CPU backend for XLA. + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [":friends"], +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +cc_library( + name = "cpu_codegen_test", + testonly = True, + hdrs = ["cpu_codegen_test.h"], + deps = [ + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_fusion_test", + srcs = ["cpu_fusion_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_bytesizeof_test", + srcs = ["cpu_bytesizeof_test.cc"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_external_constants_test", + srcs = ["cpu_external_constants_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/core:test", + ], +) + +tf_cc_test( + name = "cpu_noalias_test", + srcs = ["cpu_noalias_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@llvm//:core", + ], +) + +tf_cc_test( + name = "cpu_intrinsic_test", + srcs = ["cpu_intrinsic_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_eigen_dot_operation_test", + srcs = ["cpu_eigen_dot_operation_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_infeed_test", + srcs = ["cpu_infeed_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_literal_caching_test", + srcs = ["cpu_literal_caching_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_outfeed_test", + srcs = ["cpu_outfeed_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_bytesizeof_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_bytesizeof_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d5bbe7677ace67c0500750d1911bf98ff791aa60 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_bytesizeof_test.cc @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/test.h" + +class CpuByteSizeOfTest : public ::testing::Test {}; + +TEST_F(CpuByteSizeOfTest, ARM32) { + llvm::DataLayout data_layout( + "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64"); + auto tuple_shape = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {})}); + EXPECT_EQ(xla::llvm_ir::ByteSizeOf(tuple_shape, data_layout), + data_layout.getPointerSize(0 /* default address space */)); +} + +TEST_F(CpuByteSizeOfTest, ARM64) { + llvm::DataLayout data_layout("e-m:e-i64:64-i128:128-n32:64-S128"); + auto tuple_shape = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {})}); + EXPECT_EQ(xla::llvm_ir::ByteSizeOf(tuple_shape, data_layout), + data_layout.getPointerSize(0 /* default address space */)); +} diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h b/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h new file mode 100644 index 0000000000000000000000000000000000000000..7c8d07a10baf55dba8cbd347ebe1459b78e268e0 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h @@ -0,0 +1,30 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TESTS_CPU_CODEGEN_TEST_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TESTS_CPU_CODEGEN_TEST_H_ + +#include "tensorflow/compiler/xla/tests/llvm_irgen_test_base.h" + +namespace xla { +namespace cpu { + +// Tests that verify IR emitted by the CPU backend is as expected. +class CpuCodegenTest : public LLVMIRGenTestBase {}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TESTS_CPU_CODEGEN_TEST_H_ diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6fcce42eaa4599eb8a6dacc1bd39eefd39aa5e50 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests that we call into Eigen for dot operations as needed. + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { + +struct DotTestSpec { + PrimitiveType primitive_type; + string filecheck_lines; +}; + +string DotTestSpecToString(const ::testing::TestParamInfo& info) { + return PrimitiveType_Name(info.param.primitive_type); +} + +class CpuEigenDotOperationTest + : public CpuCodegenTest, + public ::testing::WithParamInterface { + protected: + void CompileAndCheck(std::unique_ptr entry_computation, + const string& filecheck_lines) { + CpuAotCompilationOptions options{ + /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(entry_computation)); + + CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options, + filecheck_lines, + /*match_optimized_ir=*/true); + } +}; + +TEST_P(CpuEigenDotOperationTest, SimpleDotOp) { + HloComputation::Builder builder(TestName()); + DotTestSpec spec = GetParam(); + + auto param_shape = ShapeUtil::MakeShape(spec.primitive_type, {128, 128}); + + HloInstruction* lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "input")); + HloInstruction* rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "input")); + + builder.AddInstruction( + HloInstruction::CreateCanonicalDot(param_shape, lhs, rhs)); + CompileAndCheck(builder.Build(), spec.filecheck_lines); +} + +TEST_P(CpuEigenDotOperationTest, DotTransposeOp) { + HloComputation::Builder builder(TestName()); + DotTestSpec spec = GetParam(); + + auto param_shape = ShapeUtil::MakeShape(spec.primitive_type, {128, 128}); + + HloInstruction* lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "input")); + HloInstruction* rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "input")); + HloInstruction* lhs_transposed = builder.AddInstruction( + HloInstruction::CreateTranspose(param_shape, lhs, {1, 0})); + + builder.AddInstruction( + HloInstruction::CreateCanonicalDot(param_shape, lhs_transposed, rhs)); + CompileAndCheck(builder.Build(), spec.filecheck_lines); +} + +std::vector GetDotTestCases() { + std::vector result; + result.push_back( + {F16, R"(CHECK: call void @__xla_cpu_runtime_EigenMatMulF16)"}); + result.push_back( + {F32, R"(CHECK: call void @__xla_cpu_runtime_EigenMatMulF32)"}); + result.push_back( + {F64, R"(CHECK: call void @__xla_cpu_runtime_EigenMatMulF64)"}); + return result; +} + +INSTANTIATE_TEST_CASE_P(CpuEigenDotOperationTestInstantiation, + CpuEigenDotOperationTest, + ::testing::ValuesIn(GetDotTestCases()), + DotTestSpecToString); + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..faac927027c48e44eb8ff1fcc4109fbc177fc579 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { +class CpuExternalConstantsTest : public CpuCodegenTest { + public: + void TestWithArray(int64 rows, int64 cols, const char* filecheck_pattern) { + HloComputation::Builder builder(TestName()); + + Array2D backing_array(rows, cols); + backing_array.FillUnique(); + + auto shape = ShapeUtil::MakeShape(F32, {rows, cols}); + + HloInstruction* constant = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2FromArray2D(backing_array))); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, constant)); + + std::unique_ptr module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + CompileAndVerifyIr(std::move(module), filecheck_pattern, + /*match_optimized_ir=*/false); + } +}; + +TEST_F(CpuExternalConstantsTest, Basic) { + TestWithArray(/*rows=*/1024, /*cols=*/1024, R"( +CHECK: @constant_global_0 = external constant [1024 x [1024 x float]], align 16 +)"); +} + +TEST_F(CpuExternalConstantsTest, BasicNegative) { + // The constant array in this test case is small enough that there is no need + // to externalize it. + TestWithArray(/*rows=*/4, /*cols=*/4, R"( +CHECK-NOT: @constant_global_0 = external constant [16 x float], align 8 +CHECK: @0 = private constant [16 x float] {{.*}}, align 8 +)"); +} +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..23e7a3de4d8188a3add259582e11030539e154c1 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -0,0 +1,330 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { + +class CpuFusionTest : public HloTestBase { + protected: + CpuFusionTest() {} + + ErrorSpec error_spec_{0.0001, 1e-5}; +}; + +TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { + auto builder = HloComputation::Builder(TestName()); + auto input_literal1 = Literal::CreateR1({1.0, 2.0, 3.0}); + auto input_literal2 = Literal::CreateR1({-2.0, -42.0, 2.0}); + Shape vshape = input_literal1->shape(); + + auto input1 = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal1))); + auto input2 = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal2))); + + auto add1 = builder.AddInstruction( + HloInstruction::CreateBinary(vshape, HloOpcode::kAdd, input1, input2)); + builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, add1)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + CpuInstructionFusion fusion; + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + + // The computation root instruction was fused. Verify the fusion instruction + // is now the root. + auto computation = module->entry_computation(); + auto fusion_instruction = computation->root_instruction(); + EXPECT_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); + EXPECT_EQ(HloOpcode::kNegate, + fusion_instruction->fused_expression_root()->opcode()); + // There should be four fused instructions: 2 parameters, the add, and the + // negate. + EXPECT_EQ(4, fusion_instruction->fused_instruction_count()); + + // Compile and execute the computation. + auto result = ExecuteAndTransfer(std::move(module), {}); + + // Check the output correctness. + LiteralTestUtil::ExpectR1Near({1.0, 40.0, -5.0}, *result, error_spec_); +} + +TEST_F(CpuFusionTest, FuseElementwiseOpChain) { + auto builder = HloComputation::Builder(TestName()); + auto input_literal = Literal::CreateR1({-1.5, -2.5, -3.0}); + Shape vshape = input_literal->shape(); + + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, input)); + auto ceil = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negate)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kExp, ceil)); + auto floor = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kFloor, exp)); + auto two = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + builder.AddInstruction( + HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + CpuInstructionFusion fusion; + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + + // The computation root instruction was fused. Verify the fusion instruction + // is now the root. + auto computation = module->entry_computation(); + auto fusion_instruction = computation->root_instruction(); + EXPECT_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); + EXPECT_EQ(HloOpcode::kMultiply, + fusion_instruction->fused_expression_root()->opcode()); + // There should be 7 fused instructions: 2 parameters and the fused + // operations. + EXPECT_EQ(7, fusion_instruction->fused_instruction_count()); + + // Compile and execute the computation. + auto result = ExecuteAndTransfer(std::move(module), {}); + + // Check the output correctness. + LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0}, *result, + error_spec_); +} + +TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { + // Test a chain of fusable ops with a non-fusable op (a reduce) thrown in the + // middle. + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto input_literal = Literal::CreateR1({-1.5, -2.5, -3.0}); + Shape vshape = input_literal->shape(); + + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, input)); + auto ceil = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negate)); + + auto cshape = ShapeUtil::MakeShape(F32, {6}); + auto concatenate = builder.AddInstruction( + HloInstruction::CreateConcatenate(cshape, {ceil, ceil}, /*dimension=*/0)); + + // Build an x+y computation to use in a reduce. + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + auto embedded_builder = HloComputation::Builder("f32+f32"); + embedded_builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kAdd, + embedded_builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "x")), + embedded_builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "y")))); + auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); + + // This is a nop reduction. + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + cshape, + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {6, 1}), concatenate)), + /*init_value=*/ + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + /*dimensions_to_reduce=*/{1}, add_f32)); + + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(cshape, HloOpcode::kExp, reduce)); + auto floor = builder.AddInstruction( + HloInstruction::CreateUnary(cshape, HloOpcode::kFloor, exp)); + auto two = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + builder.AddInstruction( + HloInstruction::CreateBinary(cshape, HloOpcode::kMultiply, two, floor)); + + module->AddEntryComputation(builder.Build()); + + CpuInstructionFusion fusion; + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + + // The computation root instruction was fused. Verify the fusion instruction + // is now the root. + auto computation = module->entry_computation(); + + auto fusion_instruction1 = computation->root_instruction(); + EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode()); + EXPECT_EQ(HloOpcode::kMultiply, + fusion_instruction1->fused_expression_root()->opcode()); + // There should be 5 fused instructions in the root fusion instruction: 2 + // parameters, multiply, floor, and exp. + EXPECT_EQ(5, fusion_instruction1->fused_instruction_count()) + << fusion_instruction1->fused_instructions_computation()->ToString(); + + auto fusion_instruction2 = reduce->operand(0); + EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode()); + EXPECT_EQ(HloOpcode::kReshape, + fusion_instruction2->fused_expression_root()->opcode()); + // There should be 5 fused instructions in the second fusion instruction: 1 + // parameter, negate, ceil, concat, and reshape. + EXPECT_EQ(5, fusion_instruction2->fused_instruction_count()) + << fusion_instruction2->fused_instructions_computation()->ToString(); + + // Compile and execute the computation. + auto result = ExecuteAndTransfer(std::move(module), {}); + + // Check the output correctness. + LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0, 14.0, 40.0, 40.0}, + *result, error_spec_); +} + +TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { + // Test that the operands of an instruction to be fused are considered in the + // proper order to avoid duplication. Test input: + // + // constant = {...} + // negate = neg(constant) + // ceil = ceil(negate) + // add1 = add(negate, ceil) + // add2 = add(ceil, negate) + // + // In this example, the operands of both add1 and add2 should be fused in the + // order {ceil, negate} even though they have different orders in their + // operand vectors. Test for this problem by counting the number of nodes in + // each fusion instruction to ensure that negate is not duplicated. + auto builder = HloComputation::Builder(TestName()); + auto input_literal = Literal::CreateR1({1.0, 2.0, 3.0}); + Shape vshape = input_literal->shape(); + + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, constant)); + auto ceil = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negate)); + + auto add1 = builder.AddInstruction( + HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, negate, ceil)); + auto add2 = builder.AddInstruction( + HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, ceil, negate)); + + // Tie together the two adds with a tuple to create a single root. + auto result = + builder.AddInstruction(HloInstruction::CreateTuple({add1, add2})); + + // Create computation and module. + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + // Run fusion. + CpuInstructionFusion fusion; + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + + auto fusion1 = result->operand(0); + auto fusion2 = result->operand(1); + EXPECT_EQ(HloOpcode::kFusion, fusion1->opcode()); + EXPECT_EQ(HloOpcode::kFusion, fusion2->opcode()); + + // Each fusion instruction should have 4 fused instruction inside: add, ceil, + // negate, and the fused parameter. + EXPECT_EQ(4, fusion1->fused_instruction_count()); + EXPECT_EQ(4, fusion2->fused_instruction_count()); + + // Each fusion instruction should have one parameter and the parameter should + // be the constant. + EXPECT_EQ(1, fusion1->operand_count()); + EXPECT_EQ(constant, fusion1->operand(0)); + EXPECT_EQ(1, fusion2->operand_count()); + EXPECT_EQ(constant, fusion2->operand(0)); +} + +TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { + // Verify that expensive operations will not be fused if the fusion results in + // duplication. Test code: + // + // constant = 42.0 + // exp1 = exp(constant) + // negate1 = negate(exp1) + // exp2 = exp(constant) + // negate2 = negate(exp2) + // tuple = tuple(negate1, negate2, exp2) + // + // exp1 should be fused down into negate1, but exp2 will not be fused into + // negate2 because this will result in duplication of the expensive exp + // computation. The duplication is caused by the other use of exp2 in the + // tuple. + auto builder = HloComputation::Builder(TestName()); + auto input_literal1 = Literal::CreateR1({1.0, 2.0, 3.0}); + auto input_literal2 = Literal::CreateR1({-2.0, -42.0, 2.0}); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + Shape shape = constant->shape(); + + auto exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, constant)); + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, exp1)); + + auto exp2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, constant)); + auto negate2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, exp2)); + + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({negate1, negate2, exp2})); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + CpuInstructionFusion fusion; + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + + // The only fusion instruction should be operand 0 of the tuple (formerly + // negate1). + EXPECT_EQ(HloOpcode::kFusion, tuple->operand(0)->opcode()); + EXPECT_EQ(HloOpcode::kNegate, tuple->operand(1)->opcode()); + EXPECT_EQ(HloOpcode::kExp, tuple->operand(2)->opcode()); + + auto fusion_inst = tuple->operand(0); + // There should be three fused instructions: negate2, exp2, and the fused + // parameter. + EXPECT_EQ(3, fusion_inst->fused_instruction_count()); + EXPECT_EQ(1, fusion_inst->operand_count()); + EXPECT_EQ(constant, fusion_inst->operand(0)); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..dd63b998e9b6d04981ec6f7300c883c9b23b154f --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -0,0 +1,294 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class InfeedTest : public ClientLibraryTestBase { + protected: + // Transfers the given literal to the infeed interface of the device, and + // check if the returned data from Infeed HLO is same as the literal. + void TestInfeedRoundTrip(const Literal& literal) { + // TODO(b/31037751) Explicitly reset the Infeed state so that the + // test is not affected by the state from the previous tests by + // adding ClearInfeed if necessary when it is implemented. For now + // don't use ResetDevice since it is not implemented on CPU. + ASSERT_IS_OK(client_->TransferToInfeed(literal)); + XlaBuilder builder(TestName()); + builder.Infeed(literal.shape()); + if (ShapeUtil::IsTuple(literal.shape())) { + // TODO(b/30609564): Use ComputeAndCompareLiteral instead. + ComputeAndCompareTuple(&builder, literal, {}); + } else { + ComputeAndCompareLiteral(&builder, literal, {}); + } + } +}; + +TEST_F(InfeedTest, SingleInfeedR0Bool) { + TestInfeedRoundTrip(*Literal::CreateR0(true)); +} + +TEST_F(InfeedTest, SingleInfeedR1U32) { + TestInfeedRoundTrip(*Literal::CreateR1({1, 2, 3})); +} + +TEST_F(InfeedTest, SingleInfeedR2F32) { + TestInfeedRoundTrip(*Literal::CreateR2F32Linspace(0.0, 1.0, 128, 64)); +} + +TEST_F(InfeedTest, SingleInfeedR3F32) { + TestInfeedRoundTrip( + *Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); +} + +TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { + const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); + const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); + + TestInfeedRoundTrip( + *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0minor)); + + TestInfeedRoundTrip( + *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0major)); +} + +TEST_F(InfeedTest, SingleInfeedR4S32) { + TestInfeedRoundTrip(*Literal::CreateR4( + {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, + {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); +} + +TEST_F(InfeedTest, SingleInfeedTuple) { + TestInfeedRoundTrip( + *Literal::MakeTuple({Literal::CreateR1({1, 2, 3}).get(), + Literal::CreateR0(false).get()})); +} + +TEST_F(InfeedTest, SingleInfeedEmptyTuple) { + TestInfeedRoundTrip(*Literal::MakeTuple({})); +} + +// Tests Infeed operation used in a while loop, as in the code below. The +// computation is launched asynchronously, and then infeed data is transferred. +// +// float acc = 0.0f; +// while (acc < 40.0f) { +// acc += reduce_add(Infeed()); +// } +// return acc; +// TODO(b/30671675) enable this test once asynchronous execution is +// implemented for CPU. +TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { + XlaBuilder builder(TestName()); + const auto infeed_shape = ShapeUtil::MakeShape(F32, {3}); + const auto result_shape = ShapeUtil::MakeShape(F32, {}); + + // Create a computation for the condition: repeat until (prev < 40.0f) holds. + XlaComputation condition; + { + XlaBuilder builder("condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + builder.Gt(builder.ConstantR0(40.0f), prev); + condition = builder.Build().ConsumeValueOrDie(); + } + // Create a computation for the body: add the reduced value of the Infeed + // data to the result variable. + XlaComputation body; + { + XlaBuilder builder("body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto infeed = builder.Infeed(infeed_shape); + auto addend = + builder.Reduce(infeed, builder.ConstantR0(0.0f), + CreateScalarAddComputation(F32, &builder), {0}); + builder.Add(prev, addend); + body = builder.Build().ConsumeValueOrDie(); + } + // Create a While node with computations for the condition and the body. + auto init = builder.ConstantR0(0.0f); + builder.While(condition, body, init); + + // Build and asynchronously launch the computation. + auto computation = builder.Build().ConsumeValueOrDie(); + std::unique_ptr result; + tensorflow::Thread* computation_thread = + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions{}, "computation_thread", [&] { + result = client_->Execute(computation, {}, &execution_options_) + .ValueOrDie(); + }); + + // Send 5 Infeed data of shape F32[3]. + ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({1, 2, 3}))); + ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({4, 5, 6}))); + ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({7, 8, 9}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*Literal::CreateR1({10, 11, 12}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*Literal::CreateR1({13, 14, 15}))); + + delete computation_thread; // Joins the thread. + auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); + + // Only the first 3 infeed data should be added. + LiteralTestUtil::ExpectR0Near(45.0f, *result_literal, ErrorSpec{1e-7}); +} + +// Tests two Infeed operations with a total order. The order is enforced by +// using the result of the first while loop as the initial value of the second +// while loop. The shapes of both Infeeds are Tuples, where the first tuple +// element (R1F32) is for the data to reduce and accumulate, and the second +// tuple element (PRED) to indicate whether the loop should continue. The +// computation is launched asynchronously, and then infeed data is transferred. +// +// float acc = 0.0f; +// continue = true; +// while (!continue) { +// (data, continue) = Infeed(shape1); +// acc += reduce_add(data) +// } +// continue = true; +// while(!continue) { +// (data, continue) = Infeed(shape2); +// acc += reduce_add(data) +// } +// return acc; +// TODO(b/30671675) enable this test once asynchronous execution is +// implemented for CPU. +TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { + XlaBuilder builder(TestName()); + const auto infeed1_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(PRED, {})}); + const auto infeed2_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(PRED, {})}); + const auto result_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(PRED, {})}); + + // Create a computation for the condition: repeat until the second tuple + // element is false. + XlaComputation condition; + { + XlaBuilder builder("condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + builder.GetTupleElement(prev, 1); + condition = builder.Build().ConsumeValueOrDie(); + } + + // A lambda that builds the body computation of a while loop with the given + // infeed shape, and returns the computation with the ownership. + // + // The body adds the reduced value of the Infeed data (first tuple element) + // to the previous accumulator, and returns the accumulator and the continue + // flag (second tuple element) as a tuple. + const auto build_body = [this, &result_shape](const Shape& infeed_shape) { + XlaComputation body; + XlaBuilder builder("body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto infeed = builder.Infeed(infeed_shape); + auto addend = builder.Reduce( + builder.GetTupleElement(infeed, 0), builder.ConstantR0(0.0f), + CreateScalarAddComputation(F32, &builder), {0}); + auto result = builder.Add(builder.GetTupleElement(prev, 0), addend); + builder.Tuple({result, builder.GetTupleElement(infeed, 1)}); + return builder.Build().ConsumeValueOrDie(); + }; + + // Create the first while loop with infeed1_shape. + auto init = builder.Tuple( + {builder.ConstantR0(0.0f), builder.ConstantR0(true)}); + auto while1 = builder.While(condition, build_body(infeed1_shape), init); + auto result1 = builder.Tuple( + {builder.GetTupleElement(while1, 0), builder.ConstantR0(true)}); + + // Create the second while loop with infeed2_shape. Note that the result from + // the first while loop is used as the initial value. + auto while2 = builder.While(condition, build_body(infeed2_shape), result1); + builder.GetTupleElement(while2, 0); + + // Build the computation. + auto computation = builder.Build().ConsumeValueOrDie(); + + // Send the first 4 Infeed data of shape Tuple(F32[2], PRED). + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({1, 2}).get(), + Literal::CreateR0(true).get()}))); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({3, 4}).get(), + Literal::CreateR0(true).get()}))); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({5, 6}).get(), + Literal::CreateR0(true).get()}))); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({7, 8}).get(), + Literal::CreateR0(false).get()}))); + + // Asynchronously launch the execution on the device. + std::unique_ptr result; + tensorflow::Thread* computation_thread = + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions{}, "computation_thread", [&] { + result = client_->Execute(computation, {}, &execution_options_) + .ValueOrDie(); + }); + + // Wait for a second to ensure testing that the execution is waiting on the + // Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED). + sleep(1); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({1, 2, 3}).get(), + Literal::CreateR0(true).get()}))); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({7, 8, 9}).get(), + Literal::CreateR0(false).get()}))); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({4, 5, 6}).get(), + Literal::CreateR0(true).get()}))); + + // Wait for the execution to be done, and transfer the result. + delete computation_thread; // Joins the thread. + auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); + + // Only the first 6 infeed data should be added. + LiteralTestUtil::ExpectR0Near(66.0f, *result_literal, ErrorSpec{1e-7}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..973aac8766f5aabca15e5173b43480c113c100dd --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { + +const char* const kTriple_x86_64 = "x86_64-pc-linux"; +const char* const kTriple_android_arm = "armv7-none-android"; + +struct IntrinsicTestSpec { + HloOpcode opcode; + tensorflow::StringPiece triple; + tensorflow::StringPiece features; + tensorflow::StringPiece check_lines; +}; + +// Tests that unary functions get lowered using intrinsic calls. +class CpuUnaryIntrinsicTest + : public CpuCodegenTest, + public ::testing::WithParamInterface { + public: + static string Name(const ::testing::TestParamInfo& info) { + auto spec = info.param; + + string opcode = HloOpcodeString(spec.opcode); + opcode[0] = toupper(opcode[0]); + + string triple{spec.triple.data(), spec.triple.size()}; + if (triple == kTriple_x86_64) { + triple = "x86_64"; + } else if (triple == kTriple_android_arm) { + triple = "android_arm"; + } else { + triple = "Unknown"; + } + + string features{spec.features.data(), spec.features.size()}; + if (!features.empty()) { + std::replace_if(features.begin(), features.end(), + [](char c) { return c != '_' && !isalnum(c); }, '_'); + } else { + features = ""; + } + + return tensorflow::strings::StrCat(opcode.c_str(), "_On_", triple.c_str(), + features.empty() ? "" : "_With", + features.c_str()); + } +}; + +// Creates a module with a call to the unary op, and tests if the +// compiler replaced it with a call to the intrinsic. +TEST_P(CpuUnaryIntrinsicTest, DoIt) { + HloComputation::Builder builder(TestName()); + IntrinsicTestSpec spec = GetParam(); + + auto param_shape = ShapeUtil::MakeShape(F32, {1024}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "input")); + builder.AddInstruction( + HloInstruction::CreateUnary(param_shape, spec.opcode, param)); + std::unique_ptr computation = builder.Build(); + + string triple{spec.triple.data(), spec.triple.size()}; + string features{spec.features.data(), spec.features.size()}; + + CpuAotCompilationOptions options{ + /*triple=*/triple, /*cpu_name=*/"", /*features=*/features, + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + string check_lines{spec.check_lines.data(), spec.check_lines.size()}; + + CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options, check_lines, + /*match_optimized_ir=*/true); +} + +IntrinsicTestSpec CpuUnaryIntrinsicTestCases[] = { + // The intrinsics are always inlined, so we match a line from it instead of + // a function call. + + IntrinsicTestSpec{ + HloOpcode::kExp, kTriple_x86_64, "", + R"(CHECK: fmul fast <4 x float> )"}, + + IntrinsicTestSpec{ + HloOpcode::kExp, kTriple_x86_64, "+avx", + R"(CHECK: fmul fast <8 x float> )"}, + + IntrinsicTestSpec{ + HloOpcode::kExp, kTriple_android_arm, "+neon", + R"(CHECK: fmul fast <4 x float> )"}, + + IntrinsicTestSpec{ + HloOpcode::kTanh, kTriple_x86_64, "", + R"(CHECK: fcmp fast uge <4 x float> %wide.load, )"}, + + IntrinsicTestSpec{ + HloOpcode::kTanh, kTriple_x86_64, "+avx", + R"(CHECK: fcmp fast uge <8 x float> %wide.load, )"}, + + IntrinsicTestSpec{ + HloOpcode::kTanh, kTriple_android_arm, "", + R"(CHECK: fcmp fast uge <4 x float> %wide.load, )"}, + + IntrinsicTestSpec{ + HloOpcode::kLog, kTriple_x86_64, "", + R"(CHECK: fadd fast <4 x float> )"}, + + IntrinsicTestSpec{ + HloOpcode::kLog, kTriple_x86_64, "+avx", + R"(CHECK: fadd fast <8 x float> )"}, + + IntrinsicTestSpec{ + HloOpcode::kLog, kTriple_android_arm, "", + R"(CHECK: fadd fast <4 x float> )"}}; + +INSTANTIATE_TEST_CASE_P(CpuUnaryIntrinsicTestInstantiation, + CpuUnaryIntrinsicTest, + ::testing::ValuesIn(CpuUnaryIntrinsicTestCases), + CpuUnaryIntrinsicTest::Name); + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..27044b1d62027e3b83744c486cb790269e505aff --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" + +namespace xla { +namespace cpu { +namespace { +class CpuDuplicateConstantsTest : public CpuCodegenTest {}; + +TEST_F(CpuDuplicateConstantsTest, RepeatedArrayConstants) { + // We use a while loop here to force the two constant HloInstructions to be in + // different computations. Otherwise the HLO optimizer itself CSEs them. + const string hlo_text = R"( +HloModule RepeatedConstants + +while_body { + arg_body = f32[2,3,2] parameter(0) + ROOT const = f32[2,3,2] constant( + f32[2,3,2] + {{{1, 2}, {1001, 1002}, {2001, 2002}}, + {{2, 1}, {2001, 3002}, {2001, 2002}}}) +} + +while_cond { + arg_cond = f32[2,3,2] parameter(0) + ROOT unknown = pred[] infeed() +} + +ENTRY main { + param = f32[2,3,2] parameter(0) + const_a = f32[2,3,2] constant( + f32[2,3,2] + {{{1, 2}, {1001, 1002}, {2001, 2002}}, + {{2, 1}, {2001, 3002}, {2001, 2002}}}) + const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body + + out0 = () outfeed(f32[2,3,2] const_a) + ROOT out1 = () outfeed(f32[2,3,2] const_b) +} +)"; + + string filecheck_pattern = R"( +CHECK: private constant [12 x float] +CHECK-NOT: private constant [12 x float] +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/false); +} + +TEST_F(CpuDuplicateConstantsTest, RepeatedTupleConstants) { + // We use a while loop here to force the two constant HloInstructions to be in + // different computations. Otherwise the HLO optimizer itself CSEs them. + const string hlo_text = R"( +HloModule RepeatedConstants + +while_body { + arg_body = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) + ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) +} + +while_cond { + arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) + ROOT unknown = pred[] infeed() +} + +ENTRY main { + param = f32[2,3,2] parameter(0) + const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body + + out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b) +} +)"; + + string filecheck_pattern = R"( +CHECK: private constant [1 x float] +CHECK: private constant [2 x float] +CHECK-NOT: private constant [1 x float] +CHECK-NOT: private constant [2 x float] +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/false); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3b6b0ed74065615fb9e47a0ec3c6c4ab078e45c4 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -0,0 +1,136 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/IR/Module.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { + +class CpuNoAliasTest : public CpuCodegenTest {}; + +// Creates a simple HLO ir_module (runs concat(concat(x, y), x)), and then +// inspects the aliasing information for loads to its buffers. +TEST_F(CpuNoAliasTest, Concat) { + HloComputation::Builder builder(TestName()); + + std::unique_ptr literal = + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* param_x = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "x")); + HloInstruction* param_y = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "y")); + HloInstruction* concat1 = + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(F32, {2, 4}), {param_x, param_y}, 1)); + HloInstruction* concat2 = + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(F32, {2, 6}), {concat1, param_x}, 1)); + + std::unique_ptr computation = builder.Build(); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + // Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it. + auto status_or_buffer_assn = BufferAssigner::Run( + hlo_module.get(), MakeUnique(hlo_module.get()), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return /*alignment=*/1; }); + ASSERT_EQ(status_or_buffer_assn.status(), Status::OK()); + + llvm::LLVMContext context; + llvm_ir::AliasAnalysis aa(*hlo_module, *status_or_buffer_assn.ValueOrDie(), + &context); + + // Construct an LLVM module containing loads that we annotate as being from + // the buffers in the HLO module. We'll inspect these loads to ensure that + // they have the expected alias information. + llvm::Module ir_module("test", context); + llvm::Function* func = llvm::cast( + ir_module.getOrInsertFunction("test_fn", llvm::Type::getVoidTy(context))); + llvm::BasicBlock* bb = llvm::BasicBlock::Create(context, "body", func); + llvm::IRBuilder<> ir_builder(bb); + auto* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0); + llvm_ir::IrArray::Index zero2D({zero, zero}); + + llvm::ArrayType* array2d_type = llvm::ArrayType::get( + llvm::ArrayType::get(llvm::Type::getFloatTy(context), 100), 100); + + { + llvm::Value* param_x_val = + ir_module.getOrInsertGlobal("param_x", array2d_type); + llvm_ir::IrArray param_x_array(param_x_val, param_shape); + aa.AddAliasingInformationToIrArray(*param_x, ¶m_x_array); + param_x_array.EmitReadArrayElement(zero2D, &ir_builder) + ->setName("read_param_x_array"); + } + + { + llvm::Value* concat1_val = + ir_module.getOrInsertGlobal("concat1", array2d_type); + auto shape = ShapeUtil::MakeShape(F32, {2, 4}); + llvm_ir::IrArray concat1_array(concat1_val, shape); + aa.AddAliasingInformationToIrArray(*concat1, &concat1_array); + concat1_array.EmitReadArrayElement(zero2D, &ir_builder) + ->setName("read_concat1_array"); + } + + { + llvm::Value* concat2_val = + ir_module.getOrInsertGlobal("concat2", array2d_type); + auto shape = ShapeUtil::MakeShape(F32, {2, 6}); + llvm_ir::IrArray concat2_array(concat2_val, shape); + aa.AddAliasingInformationToIrArray(*concat2, &concat2_array); + concat2_array.EmitReadArrayElement(zero2D, &ir_builder) + ->setName("read_concat2_array"); + } + + // Check the AA info in the loads. + const char* filecheck_pattern = R"( + CHECK: %read_param_x_array = load {{.*}} !noalias [[param_x_noalias:![0-9]+]] + CHECK: %read_concat1_array = load {{.*}} !alias.scope [[concat1_scope:![0-9]+]], !noalias [[concat1_noalias:![0-9]+]] + CHECK: %read_concat2_array = load {{.*}} !alias.scope [[concat1_noalias]], !noalias [[concat1_scope]] + CHECK-DAG: [[buf_size32:![0-9]+]] = !{!"buffer:{{.*}} size:32 + CHECK-DAG: [[buf_size48:![0-9]+]] = !{!"buffer:{{.*}} size:48 + CHECK-DAG: [[param_x_noalias]] = !{[[buf_size32]], [[buf_size48]]} + CHECK-DAG: [[concat1_scope]] = !{[[buf_size32]]} + CHECK-DAG: [[concat1_noalias]] = !{[[buf_size48]]} + )"; + + TF_ASSERT_OK_AND_ASSIGN( + bool filecheck_match, + RunFileCheck(llvm_ir::DumpModuleToString(ir_module), filecheck_pattern)); + EXPECT_TRUE(filecheck_match); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1ee279290b6fcfe775ce9867d424b1c031f5d2bd --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" + +namespace xla { +namespace cpu { +namespace { +class CpuOutfeedTest : public CpuCodegenTest {}; + +TEST_F(CpuOutfeedTest, OutfeedRoot) { + const string hlo_text = R"( +HloModule Outfeed + +ENTRY main { + const_a = f32[2,3,2] constant( + f32[2,3,2] + {{{1, 2}, {1001, 1002}, {2001, 2002}}, + {{2, 1}, {2001, 3002}, {2001, 2002}}}) + + ROOT out = () outfeed(f32[2,3,2] const_a) +} +)"; + + string filecheck_pattern = R"( +CHECK: private constant [12 x float] +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/false); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index 150db1cb6edec1af6724a8bca6a5f6272f1a7416..c444d151858d3a152a01b99657ffae89ebc6b487 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -370,6 +370,9 @@ std::vector VectorSupportLibrary::ComputeHorizontalSums( std::vector VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums( std::vector vectors, llvm::Value* init_values) { + // vectors are N llvm vector values, each with N elements. + int64 lane_width = vectors.size(); + while (vectors.size() != 2) { std::vector new_vectors; for (int i = 0; i < vectors.size(); i += 2) { @@ -390,10 +393,14 @@ VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums( high = AddInternal(ExtractHighHalf(init_values), high); } + // `low` has the first `lane_width / 2` horizontal reductions, and `high` has + // the next `lane_width / 2` horizontal reductions. + std::vector results; - for (int i = 0; i < 8; i++) { + for (int i = 0; i < lane_width; i++) { llvm::Value* scalar_result = ir_builder()->CreateExtractElement( - i < 4 ? low : high, ir_builder()->getInt32(i % 4), name()); + i < (lane_width / 2) ? low : high, + ir_builder()->getInt32(i % (lane_width / 2)), name()); results.push_back(scalar_result); } @@ -420,5 +427,27 @@ llvm::Value* LlvmVariable::Get() const { void LlvmVariable::Set(llvm::Value* new_value) { ir_builder_->CreateStore(new_value, alloca_); } + +TileVariable::TileVariable(VectorSupportLibrary* vector_support, + std::vector initial_value) { + for (llvm::Value* initial_vector_value : initial_value) { + storage_.emplace_back(vector_support, initial_vector_value); + } +} + +std::vector TileVariable::Get() const { + std::vector result; + c_transform(storage_, std::back_inserter(result), + [&](VectorVariable vect_var) { return vect_var.Get(); }); + return result; +} + +void TileVariable::Set(tensorflow::gtl::ArraySlice value) { + CHECK_EQ(value.size(), storage_.size()); + for (int64 i = 0, e = value.size(); i < e; i++) { + storage_[i].Set(value[i]); + } +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index 6479bf76aab581ae3ec2923d98dab53720cab203..49c2a4e2f4bae9e1672b7d2fe891301bce08bd4b 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace cpu { @@ -143,6 +144,12 @@ class VectorSupportLibrary { llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, llvm::Value* offset_elements); + llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, + llvm::Value* offset_elements, int64 scale) { + return ComputeOffsetPointer( + base_pointer, + ir_builder_->CreateMul(ir_builder_->getInt64(scale), offset_elements)); + } llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, int64 offset_elements) { return ComputeOffsetPointer(base_pointer, @@ -311,6 +318,21 @@ class ScalarVariable : public LlvmVariable { Set(initial_value); } }; + +// This wraps a set of alloca-backed stack variables that can, as a whole, store +// a tile. A "tile" is a sequence of vectors that is typically used as a 2D +// grid of scalar values (e.g. for tiled GEMMs). +class TileVariable { + public: + TileVariable(VectorSupportLibrary* vector_support, + std::vector initial_value); + + std::vector Get() const; + void Set(tensorflow::gtl::ArraySlice value); + + private: + std::vector storage_; +}; } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc new file mode 100644 index 0000000000000000000000000000000000000000..d938f3a2c4b5bfdd70d5a614b9890b4d7bf050f7 --- /dev/null +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/despecializer.h" + +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/defuser.h" +#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" + +namespace xla { + +Despecializer::Despecializer() : pipeline_("despecializer") { + // TODO(b/70588125): Also deal with window reversal in a fast way. + pipeline_.AddPass(); + pipeline_.AddPass(); + pipeline_.AddPass(); +} + +StatusOr Despecializer::Run(HloModule* module) { + return pipeline_.Run(module); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h new file mode 100644 index 0000000000000000000000000000000000000000..cc1695b7f863805e0b483478639c17cb9061310a --- /dev/null +++ b/tensorflow/compiler/xla/service/despecializer.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DESPECIALIZER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DESPECIALIZER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Creates an HloPassPipeline containing multiple HloPasses that can +// despecialize an optimized HloModule. This is useful to run an HloModule +// optimized for one specific platform on a different platform (undoing platform +// specific passes) with matching numerics for comparison. +// +// Current despecialization passes are Defuser, ImplicitBroadcastRemover, +// and BFloat16MixedPrecisionRemoval. +class Despecializer : public HloPassInterface { + public: + Despecializer(); + tensorflow::StringPiece name() const override { return "despecializer"; } + StatusOr Run(HloModule* module) override; + + private: + HloPassPipeline pipeline_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DESPECIALIZER_H_ diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc index 78e7aa48accdbb51a8477455f5f9c004828c068f..e228bb56bce8febcca28ae171f6de90973d020ab 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -24,45 +24,37 @@ limitations under the License. namespace xla { StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator( - const perftools::gputools::Platform* platform, - tensorflow::gtl::ArraySlice - stream_executors) + const se::Platform* platform, + tensorflow::gtl::ArraySlice stream_executors) : DeviceMemoryAllocator(platform), stream_executors_(stream_executors.begin(), stream_executors.end()) {} -StatusOr -StreamExecutorMemoryAllocator::Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) { - TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * stream_executor, +StatusOr StreamExecutorMemoryAllocator::Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) { + TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor, GetStreamExecutor(device_ordinal)); - perftools::gputools::DeviceMemoryBase result = - stream_executor->AllocateArray(size); + se::DeviceMemoryBase result = stream_executor->AllocateArray(size); if (size > 0 && result == nullptr) { return ResourceExhausted( "Failed to allocate request for %s (%lluB) on device ordinal %d", tensorflow::strings::HumanReadableNumBytes(size).c_str(), size, device_ordinal); } - return result; + return OwningDeviceMemory(result, device_ordinal, this); } -tensorflow::Status StreamExecutorMemoryAllocator::Deallocate( - int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) { - if (!mem->is_null()) { - TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * stream_executor, +Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal, + se::DeviceMemoryBase mem) { + if (!mem.is_null()) { + TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor, GetStreamExecutor(device_ordinal)); - // We make a local copy of 'mem' so the original is not zeroed out by the - // Deallocate() call below. This gives us a better chance of - // catching double-free bugs, since Deallocate silently succeeds for null - // values. - perftools::gputools::DeviceMemoryBase mem_copy(*mem); - stream_executor->Deallocate(&mem_copy); + stream_executor->Deallocate(&mem); } - return tensorflow::Status::OK(); + return Status::OK(); } -StatusOr -StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) { +StatusOr StreamExecutorMemoryAllocator::GetStreamExecutor( + int device_ordinal) { if (device_ordinal < 0) { return InvalidArgument("device ordinal value (%d) must be non-negative", device_ordinal); diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h index 39dfad84c1c1c1c461c24de555ecd919cea47d83..d87b86caf0d3acaa5bf9a455cff2315cedb2496d 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.h +++ b/tensorflow/compiler/xla/service/device_memory_allocator.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -33,30 +34,44 @@ class DeviceMemoryAllocator { public: // Parameter platform indicates which platform the allocator allocates memory // on. Must be non-null. - explicit DeviceMemoryAllocator(const perftools::gputools::Platform* platform) + explicit DeviceMemoryAllocator(const se::Platform* platform) : platform_(platform) {} virtual ~DeviceMemoryAllocator() {} + // Allocates memory on the device. + // + // If size > 0 and the returned StatusOr is OK, the wrapped OwningDeviceMemory + // must not be null. If size == 0, must return a null OwningDeviceMemory. + // // 'retry_on_failure': If false, and the first attempt to allocate the memory - // fails, the allocation should return immediately without retrying. - // An example use case is optional scratch spaces where a failure - // has only performance impact. - // Allocate() should return a null pointer for a size-0 allocation. - // Deallocate() must be a no-op for null pointers. - virtual StatusOr Allocate( - int device_ordinal, uint64 size, bool retry_on_failure = true) = 0; - virtual tensorflow::Status Deallocate( - int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) = 0; + // fails, the allocation should return immediately without retrying. An + // example use case is optional scratch spaces where a failure has only + // performance impact. + virtual StatusOr Allocate(int device_ordinal, uint64 size, + bool retry_on_failure) = 0; + + // Two-arg version of Allocate(), which sets retry-on-failure to true. + // + // (We don't simply use a default argument on the virtual Allocate function + // because default args on virtual functions are disallowed by the Google + // style guide.) + StatusOr Allocate(int device_ordinal, uint64 size) { + return Allocate(device_ordinal, size, /*retry_on_failure=*/true); + } + + // Must be a nop for null pointers. + virtual Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) = 0; // Return the platform that the allocator allocates memory on. - const perftools::gputools::Platform* platform() const { return platform_; } + const se::Platform* platform() const { return platform_; } // Can we call Deallocate() as soon as a computation has been scheduled on // a stream, or do we have to wait for the computation to complete first? virtual bool AllowsAsynchronousDeallocation() const = 0; protected: - const perftools::gputools::Platform* platform_; + friend class OwningDeviceMemory; + const se::Platform* platform_; }; // Default memory allocator for a platform which uses @@ -64,25 +79,26 @@ class DeviceMemoryAllocator { class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { public: StreamExecutorMemoryAllocator( - const perftools::gputools::Platform* platform, - tensorflow::gtl::ArraySlice - stream_executors); + const se::Platform* platform, + tensorflow::gtl::ArraySlice stream_executors); - StatusOr Allocate( - int device_ordinal, uint64 size, bool retry_on_failure = true) override; - tensorflow::Status Deallocate( - int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) override; + StatusOr Allocate(int device_ordinal, uint64 size, + bool retry_on_failure) override; + + // Pull in two-arg overload that sets retry_on_failure to true. + using DeviceMemoryAllocator::Allocate; + + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; bool AllowsAsynchronousDeallocation() const override; private: - StatusOr GetStreamExecutor( - int device_ordinal); + StatusOr GetStreamExecutor(int device_ordinal); // A vector indexed by device ordinal of StreamExecutors for each device of // the allocator's platform type. If an element is nullptr, then the device // with the respective device ordinal is not supported by XLA. - std::vector stream_executors_; + std::vector stream_executors_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 56723e765048698baedc50ae7b189d0287ee56b8..ee2b455730f8f520db6652f0352f8a96291cac73 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -138,6 +138,9 @@ class DfsHloVisitorBase { virtual Status HandleExp(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleExpm1(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleFloor(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } @@ -147,6 +150,12 @@ class DfsHloVisitorBase { virtual Status HandleLog(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleClz(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleLog1p(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleCos(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } @@ -188,6 +197,10 @@ class DfsHloVisitorBase { return HandleElementwiseUnary(hlo); } + virtual Status HandleDomain(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0; @@ -230,6 +243,8 @@ class DfsHloVisitorBase { virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; + virtual Status HandleGenerateToken(HloInstructionPtr token) = 0; + // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". virtual Status FinishVisit(HloInstructionPtr root) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index ecda5288ee17a3856ce95f0caa327c3524fd180b..6934e00a4b665e9e6a4302e0c0a8ce1d5bb94373 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -35,6 +35,12 @@ class HloInstruction; // DfsHloVisitor with default action based on the HloInstruction being visited. // Users should not use this class directly, but use the type aliases // DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead. +// +// Do *not* add an override to this class if the opcode is covered by +// HandleElementwiseUnary/Binary. These opcode handlers dispatch to +// HandleElementwiseUnary/Binary in DfsHloVisitorBase. Adding such a handler +// here will break passes which rely on the HandleElementwiseUnary/Binary +// handling these opcodes. template class DfsHloVisitorWithDefaultBase : public DfsHloVisitorBase { @@ -70,12 +76,6 @@ class DfsHloVisitorWithDefaultBase Status HandleConcatenate(HloInstructionPtr concatenate) override { return DefaultAction(concatenate); } - Status HandleConvert(HloInstructionPtr convert) override { - return DefaultAction(convert); - } - Status HandleCopy(HloInstructionPtr copy) override { - return DefaultAction(copy); - } Status HandleSelect(HloInstructionPtr select) override { return DefaultAction(select); } @@ -91,9 +91,6 @@ class DfsHloVisitorWithDefaultBase Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } - Status HandleCompare(HloInstructionPtr compare) override { - return DefaultAction(compare); - } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); } @@ -191,6 +188,9 @@ class DfsHloVisitorWithDefaultBase Status HandleGather(HloInstructionPtr gather) override { return DefaultAction(gather); } + Status HandleGenerateToken(HloInstructionPtr token) override { + return DefaultAction(token); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..825e1436f0ec6d49b555e5e3e9c2c7a19fb7b062 --- /dev/null +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class DfsHloVisitorWithDefaultTest : public HloTestBase {}; + +TEST_F(DfsHloVisitorWithDefaultTest, DefaultElementwiseTest) { + // Verify that HandleElementwiseBinary and HandleElementwiseUnary are called + // on the appropriate HLO ops (elementwise binary/unary ops). + + class ElementwiseTestVisitor : public DfsHloVisitorWithDefault { + public: + Status DefaultAction(HloInstruction* hlo) override { + // The HLO should be neither an elementwise unary nor binary op. These + // cases are handled in HandleElementwiseBinary/Unary. + TF_RET_CHECK(!(hlo->IsElementwise() && hlo->operand_count() == 2)) + << hlo->ToString(); + TF_RET_CHECK(!(hlo->IsElementwise() && hlo->operand_count() == 1)) + << hlo->ToString(); + return Status::OK(); + } + + Status HandleElementwiseBinary(HloInstruction* hlo) override { + // HLO should be elementwise binary. + TF_RET_CHECK(hlo->IsElementwise() && hlo->operand_count() == 2) + << hlo->ToString(); + return Status::OK(); + } + Status HandleElementwiseUnary(HloInstruction* hlo) override { + // HLO should be elementwise unary. + TF_RET_CHECK(hlo->IsElementwise() && hlo->operand_count() == 1) + << hlo->ToString(); + return Status::OK(); + } + }; + + // HLO module contains are arbitrary mix of elementwise and non-elementwise + // operations. + const string& hlo_string = R"( +HloModule TestModule + +ENTRY TestComputation { + arg = f32[] parameter(0) + tuple = (f32[]) tuple(arg) + gte = f32[] get-tuple-element(tuple), index=0 + abs = f32[] abs(arg) + add = f32[] add(arg, gte) + broadcast = f32[42] broadcast(add), dimensions={} + slice = f32[0] slice(broadcast), slice={[1:2]} + copy = f32[] copy(arg) + eq = pred[] equal-to(arg, gte) + neg = f32[] negate(arg) + ROOT convert = f64[] convert(f32[] arg) +})"; + std::unique_ptr module = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()) + .ConsumeValueOrDie(); + ElementwiseTestVisitor visitor; + TF_EXPECT_OK(module->entry_computation()->Accept(&visitor)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 4468adbadbf823f1420a8b665a26f66cb7d36b43..93fea7ead7a86bb34c449668fd88a58145681eb1 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -52,6 +52,13 @@ using tensorflow::strings::StrCat; namespace { +int64 GlobalRandomValue() { + static auto* mu = new tensorflow::mutex(); + static std::mt19937_64 rng{42}; + tensorflow::mutex_lock l(*mu); + return rng(); +} + llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits, int64 mantissa_bits, llvm::IRBuilder<>* ir_builder) { @@ -226,7 +233,7 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( if (primitive_util::IsIntegralType(to_type)) { return ir_builder_->CreateIntCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_), - primitive_util::IsSignedIntegralType(to_type)); + primitive_util::IsSignedIntegralType(from_type)); } if (primitive_util::IsFloatingPointType(to_type)) { if (to_type == BF16) { @@ -293,6 +300,12 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( return operand_value; } } + case HloOpcode::kClz: { + auto is_zero_undef = ir_builder_->getFalse(); + return llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::ctlz, {operand_value, is_zero_undef}, + {operand_value->getType()}, ir_builder_); + } case HloOpcode::kSign: { bool is_signed = primitive_util::IsSignedIntegralType(op->shape().element_type()); @@ -405,8 +418,12 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } case HloOpcode::kExp: return EmitExp(op->shape().element_type(), operand_value); + case HloOpcode::kExpm1: + return EmitExpm1(op->shape().element_type(), operand_value); case HloOpcode::kLog: return EmitLog(op->shape().element_type(), operand_value); + case HloOpcode::kLog1p: + return EmitLog1p(op->shape().element_type(), operand_value); case HloOpcode::kCos: return EmitCos(op->shape().element_type(), operand_value); case HloOpcode::kSin: @@ -439,17 +456,15 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( llvm::ConstantFP::get(type, 1.0))); } case HloOpcode::kIsFinite: { - // (x == x) && abs(x) != inf + // abs(x) o!= inf, this works because the comparison returns false if + // either operand is NaN. auto type = operand_value->getType(); - auto equal_self = - ir_builder_->CreateFCmpOEQ(operand_value, operand_value); auto abs_value = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_); auto infinity = llvm::ConstantFP::getInfinity(type); auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity); - auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite); return ir_builder_->CreateZExt( - result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + not_infinite, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } case HloOpcode::kNegate: return ir_builder_->CreateFNeg(operand_value); @@ -480,6 +495,22 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( return EmitComposeComplex( op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); } + case HloOpcode::kLog1p: { + // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + llvm::Type* llvm_ty = a->getType(); + auto one = llvm::ConstantFP::get(llvm_ty, 1.0); + auto a_plus_one = ir_builder_->CreateFAdd(a, one); + auto sum_sq = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(a_plus_one, a_plus_one), + ir_builder_->CreateFMul(b, b)); + TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); + TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one)); + auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); + return EmitComposeComplex( + op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); + } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); TF_RET_CHECK(primitive_util::IsComplexType(from_type)); @@ -510,6 +541,20 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), ir_builder_->CreateFMul(exp_a, sin_b)); } + case HloOpcode::kExpm1: { + // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i + TF_ASSIGN_OR_RETURN( + auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value))); + TF_ASSIGN_OR_RETURN( + auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); + TF_ASSIGN_OR_RETURN( + auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); + auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0); + auto real_result = + ir_builder_->CreateFSub(ir_builder_->CreateFMul(exp_a, cos_b), one); + auto imag_result = ir_builder_->CreateFMul(exp_a, sin_b); + return EmitComposeComplex(op, real_result, imag_result); + } case HloOpcode::kCos: { // cos(z) = .5(e^(iz) + e^(-iz)) // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai)) @@ -962,6 +1007,28 @@ StatusOr ElementalIrEmitter::EmitLog(PrimitiveType prim_type, {value->getType()}, ir_builder_); } +StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) const { + auto x = value; + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto one = llvm::ConstantFP::get(type, 1.0); + auto negative_half = llvm::ConstantFP::get(type, -0.5); + // When x is large, the naive evaluation of ln(x + 1) is more + // accurate than the Taylor series. + TF_ASSIGN_OR_RETURN(auto for_large_x, + EmitLog(prim_type, ir_builder_->CreateFAdd(x, one))); + // The Taylor series for ln(x+1) is x - x^2/2 - x^3/3 + …. + auto for_small_x = ir_builder_->CreateFMul( + ir_builder_->CreateFAdd(ir_builder_->CreateFMul(negative_half, x), one), + x); + const auto kAntilogarithmIsSmallThreshold = 1e-4; + auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, + {type}, ir_builder_); + auto x_is_small = ir_builder_->CreateFCmpOLT( + abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold)); + return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x); +} + StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, llvm::Value* value) const { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, @@ -980,6 +1047,29 @@ StatusOr ElementalIrEmitter::EmitExp(PrimitiveType prim_type, {value->getType()}, ir_builder_); } +StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) const { + auto x = value; + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto one = llvm::ConstantFP::get(type, 1.0); + auto half = llvm::ConstantFP::get(type, 0.5); + // When the exponent is large, the naive evaluation of e^(x) - 1 is more + // accurate than the Taylor series. + TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value)); + auto for_large_x = ir_builder_->CreateFSub(exp_x, one); + // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + …. + // We want exp(x)-1 which is x + x^2/2 + x^3/6 + …. + auto x_squared = ir_builder_->CreateFAdd(x, x); + auto x_squared_over_two = ir_builder_->CreateFMul(x_squared, half); + auto for_small_x = ir_builder_->CreateFAdd(x, x_squared_over_two); + const auto kExponentIsSmallThreshold = 1e-5; + auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, + {type}, ir_builder_); + auto x_is_small = ir_builder_->CreateFCmpOLT( + abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); + return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x); +} + StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { @@ -1003,6 +1093,30 @@ StatusOr ElementalIrEmitter::EmitReducePrecision( ir_builder_); } +static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* ir_builder, + llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* shift_result, + bool saturate_to_sign_bit) { + llvm::IntegerType* integer_type = + llvm::cast(lhs->getType()); + unsigned integer_bitsize = integer_type->getBitWidth(); + llvm::ConstantInt* integer_bitsize_constant = + llvm::ConstantInt::get(integer_type, integer_bitsize); + llvm::ConstantInt* zero = llvm::ConstantInt::get(integer_type, 0); + llvm::ConstantInt* minus_one = llvm::ConstantInt::get(integer_type, -1); + llvm::Value* saturated_value; + if (saturate_to_sign_bit) { + saturated_value = ir_builder->CreateSelect( + ir_builder->CreateICmpSLT(lhs, zero), minus_one, zero); + } else { + saturated_value = zero; + } + llvm::Value* shift_amt_in_range = + ir_builder->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk"); + return ir_builder->CreateSelect(shift_amt_in_range, shift_result, + saturated_value); +} + StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, bool is_signed) const { @@ -1050,12 +1164,27 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( return ir_builder_->CreateAnd(lhs_value, rhs_value); case HloOpcode::kOr: return ir_builder_->CreateOr(lhs_value, rhs_value); - case HloOpcode::kShiftLeft: - return ir_builder_->CreateShl(lhs_value, rhs_value); + + // Shifting out bits >= the number of bits in the type being shifted + // produces a poison value in LLVM which is basically "deferred undefined + // behavior" -- doing something observable with such a value precipitates + // UB. We replace the poison value with a constant to avoid this deferred + // UB. case HloOpcode::kShiftRightArithmetic: - return ir_builder_->CreateAShr(lhs_value, rhs_value); + return SaturateShiftIfNecessary( + ir_builder_, lhs_value, rhs_value, + ir_builder_->CreateAShr(lhs_value, rhs_value), + /*saturate_to_sign_bit=*/true); + case HloOpcode::kShiftLeft: + return SaturateShiftIfNecessary( + ir_builder_, lhs_value, rhs_value, + ir_builder_->CreateShl(lhs_value, rhs_value), + /*saturate_to_sign_bit=*/false); case HloOpcode::kShiftRightLogical: - return ir_builder_->CreateLShr(lhs_value, rhs_value); + return SaturateShiftIfNecessary( + ir_builder_, lhs_value, rhs_value, + ir_builder_->CreateLShr(lhs_value, rhs_value), + /*saturate_to_sign_bit=*/false); default: return Unimplemented("binary integer op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -1130,7 +1259,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( llvm::Value* increment = ir_builder_->getInt( llvm::APInt(128, {0x14057B7EF767814F, 0x5851F42D4C957F2D})); - auto random_value = [hlo]() { + auto random_value_from_hlo = [hlo]() { const HloModule* module = hlo->IsFused() ? hlo->parent()->FusionInstruction()->parent()->parent() : hlo->parent()->parent(); @@ -1152,10 +1281,15 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( /*Ty=*/ir_builder_->getInt64Ty(), /*isConstant=*/false, /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/ir_builder_->getInt64(random_value()), + /*Initializer=*/ir_builder_->getInt64(random_value_from_hlo()), /*Name=*/"state_ptr0"); + + // When the module config seed is 0, the expected result of a prng is a random + // value. Instead of using the random_value_from_hlo, we need a global random + // value as the graph seed. This is because if we use random_value_from_hlo + // here, then for a newly built hlo graph, it always gives the same number. uint64 graph_seed = hlo_module_config_.seed() != 0 ? hlo_module_config_.seed() - : random_value(); + : GlobalRandomValue(); llvm::GlobalVariable* state_ptr1 = new llvm::GlobalVariable( /*M=*/*module_, /*Ty=*/ir_builder_->getInt64Ty(), @@ -1287,6 +1421,482 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( }; } +StatusOr ElementalIrEmitter::EmitElementalSelect( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, + operand_to_generator.at(hlo->operand(0))( + ElementwiseSourceIndex(index, *hlo, 0))); + TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value, + operand_to_generator.at(hlo->operand(1))( + ElementwiseSourceIndex(index, *hlo, 1))); + TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, + operand_to_generator.at(hlo->operand(2))( + ElementwiseSourceIndex(index, *hlo, 2))); + return ir_builder_->CreateSelect( + ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()), + on_true_value, on_false_value); +} + +StatusOr ElementalIrEmitter::EmitElementalClamp( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + TF_ASSIGN_OR_RETURN(llvm::Value * min_value, + operand_to_generator.at(hlo->operand(0))( + ElementwiseSourceIndex(index, *hlo, 0))); + TF_ASSIGN_OR_RETURN(llvm::Value * arg_value, + operand_to_generator.at(hlo->operand(1))( + ElementwiseSourceIndex(index, *hlo, 1))); + TF_ASSIGN_OR_RETURN(llvm::Value * max_value, + operand_to_generator.at(hlo->operand(2))( + ElementwiseSourceIndex(index, *hlo, 2))); + PrimitiveType prim_type = hlo->shape().element_type(); + if (primitive_util::IsFloatingPointType(prim_type)) { + return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); + } else if (primitive_util::IsIntegralType(prim_type)) { + bool is_signed = primitive_util::IsSignedIntegralType(prim_type); + return EmitIntegralMin( + max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed); + } else { + return Unimplemented("Clamp unimplemented for %s", + PrimitiveType_Name(prim_type).c_str()); + } +} + +StatusOr ElementalIrEmitter::EmitElementalConcatenate( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& target_index) const { + const int64 concat_dim = hlo->dimensions(0); + auto source_index = target_index; + + llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock(); + + // A terminator should be present iff we're emitting code + // into the middle (as opposed to the end) of a basic block. + CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(), + init_block->getTerminator() == nullptr); + + llvm::BasicBlock* exit_block; + if (ir_builder_->GetInsertPoint() == init_block->end()) { + exit_block = llvm_ir::CreateBasicBlock( + /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_); + } else { + exit_block = init_block->splitBasicBlock(ir_builder_->GetInsertPoint(), + AsStringRef(IrName(hlo, "merge"))); + init_block->getTerminator()->eraseFromParent(); + } + + llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_); + llvm::PHINode* output = ir_builder_->CreatePHI( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + hlo->operands().size()); + auto prior_insert_point = ir_builder_->GetInsertPoint(); + + ir_builder_->SetInsertPoint(init_block); + + for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); + ++operand_idx) { + const HloInstruction* operand = hlo->operand(operand_idx); + auto true_block = llvm_ir::CreateBasicBlock( + exit_block, StrCat("concat_index_from_operand", operand_idx), + ir_builder_); + auto false_block = llvm_ir::CreateBasicBlock( + exit_block, StrCat("concat_index_not_from_operand", operand_idx), + ir_builder_); + auto concat_dim_size = + llvm::ConstantInt::get(source_index[concat_dim]->getType(), + operand->shape().dimensions(concat_dim)); + ir_builder_->CreateCondBr( + ir_builder_->CreateICmpULT(source_index[concat_dim], concat_dim_size), + true_block, false_block); + + // Create the terminator of the true block before calling operand + // generators, because they require non-degenerate basic blocks. + ir_builder_->SetInsertPoint( + llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block)); + TF_ASSIGN_OR_RETURN(llvm::Value * value, + operand_to_generator.at(operand)(source_index)); + output->addIncoming(value, ir_builder_->GetInsertBlock()); + + // Subtract the size of the concat dimension of the current operand + // from the source index. + ir_builder_->SetInsertPoint(false_block); + source_index[concat_dim] = + ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size); + } + + ir_builder_->CreateUnreachable(); + ir_builder_->SetInsertPoint(exit_block, prior_insert_point); + return output; +} + +StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + // Emit IR to read dynamic start indices from hlo->operand(1). + const HloInstruction* input_hlo = hlo->operand(0); + const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + llvm_ir::IrArray::Index slice_start_index(rank); + for (int64 i = 0; i < rank; ++i) { + llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, + operand_to_generator.at(hlo->operand(1))(dim_index)); + + // Clamp the start index so that the sliced portion fits in the operand: + // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. + start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value, + index[i]->getType()); + llvm::Value* operand_dim_size = llvm::ConstantInt::get( + start_index_value->getType(), input_hlo->shape().dimensions(i)); + llvm::Value* output_dim_size = llvm::ConstantInt::get( + start_index_value->getType(), hlo->shape().dimensions(i)); + + start_index_value = EmitIntegralMin( + ir_builder_->CreateSub(operand_dim_size, output_dim_size), + EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0), + start_index_value, /*is_signed=*/true), + /*is_signed=*/true); + + start_index_value->setName( + AsStringRef(IrName(hlo, StrCat("start_idx", i)))); + slice_start_index[i] = start_index_value; + } + + llvm_ir::IrArray::Index input_index(rank); + for (int64 i = 0; i < rank; ++i) { + // Emit IR which computes: + // input_index = start_index + offset_index + input_index[i] = ir_builder_->CreateAdd(slice_start_index[i], index[i]); + } + return operand_to_generator.at(input_hlo)(input_index); +} + +StatusOr ElementalIrEmitter::EmitElementalGather( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + const Shape& operand_shape = hlo->operand(0)->shape(); + const Shape& indices_shape = hlo->operand(1)->shape(); + const Shape& output_shape = hlo->shape(); + + const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers(); + + const llvm_ir::ElementGenerator& operand_generator = + operand_to_generator.at(hlo->operand(0)); + const llvm_ir::ElementGenerator& indices_generator = + operand_to_generator.at(hlo->operand(1)); + + // This is the index into `operand` that holds the element we want to + // generate. This index "unsafe" as in the components in here may be + // out of bounds. + IrArray::Index unsafe_operand_index; + + // First copy in the window indices to unsafe_operand_index. + for (int64 i = 0, e = operand_shape.dimensions_size(), + unsafe_operand_index_dim = 0; + i < e; i++) { + if (c_binary_search(dim_numbers.elided_window_dims(), i)) { + unsafe_operand_index.push_back(ir_builder_->getInt64(0)); + } else { + unsafe_operand_index.push_back( + index[dim_numbers.output_window_dims(unsafe_operand_index_dim++)]); + } + } + + // This is the index of the index vector in the gather_indices tensor. + IrArray::Index gather_index_index; + { + std::vector gather_index_index_components; + for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { + if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + gather_index_index.push_back(index[i]); + } + } + + if (gather_index_index.size() != indices_shape.dimensions_size()) { + gather_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr); + } + } + + auto add_to_unsafe_operand_index = [&](llvm::Value* index_component, + int64 dim) { + llvm::Value* gather_dim_component_extended = ir_builder_->CreateSExtOrTrunc( + index_component, ir_builder_->getInt64Ty()); + unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] = + ir_builder_->CreateAdd( + unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)], + gather_dim_component_extended); + }; + + if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { + TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, + indices_generator(gather_index_index)); + add_to_unsafe_operand_index(gather_dim_component, 0); + } else { + int64 index_vector_size = + indices_shape.dimensions(dim_numbers.index_vector_dim()); + for (int64 i = 0; i < index_vector_size; i++) { + gather_index_index[dim_numbers.index_vector_dim()] = + ir_builder_->getInt64(i); + TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, + indices_generator(gather_index_index)); + add_to_unsafe_operand_index(gather_dim_component, i); + } + } + + IrArray::Index safe_operand_index; + for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) { + safe_operand_index.push_back(ir_builder_->CreateURem( + unsafe_operand_index[i], + ir_builder_->getInt64(operand_shape.dimensions(i)))); + } + + return operand_generator(safe_operand_index); +} + +StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + const HloInstruction* input_hlo = hlo->operand(0); + const HloInstruction* update_hlo = hlo->operand(1); + const HloInstruction* start_hlo = hlo->operand(2); + // Calculate slice start/end indices. + const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + llvm_ir::IrArray::Index slice_start_index(rank); + llvm_ir::IrArray::Index slice_limit_index(rank); + // Slice intersection gathers (ANDs) conditions on all ranks for which + // 'input' is set to 'update' + llvm::Value* slice_intersection = ir_builder_->getTrue(); + + for (int64 i = 0; i < rank; ++i) { + llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, + operand_to_generator.at(start_hlo)(dim_index)); + + // Clamp the start index so that the update region fits in the operand. + // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. + start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value, + index[i]->getType()); + llvm::Value* input_dim_size = llvm::ConstantInt::get( + index[i]->getType(), input_hlo->shape().dimensions(i)); + llvm::Value* update_dim_size = llvm::ConstantInt::get( + index[i]->getType(), update_hlo->shape().dimensions(i)); + + start_index_value = EmitIntegralMin( + ir_builder_->CreateSub(input_dim_size, update_dim_size), + EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0), + start_index_value, /*is_signed=*/true), + /*is_signed=*/true); + + start_index_value->setName( + AsStringRef(IrName(hlo, StrCat("start_idx", i)))); + slice_start_index[i] = start_index_value; + slice_limit_index[i] = + ir_builder_->CreateAdd(slice_start_index[i], update_dim_size); + + slice_intersection = ir_builder_->CreateAnd( + slice_intersection, + ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), + "slice_intersection"); + slice_intersection = ir_builder_->CreateAnd( + slice_intersection, + ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]), + "slice_intersection"); + } + + // Emit: + // if (slice_intersection) -> return data from 'update'. + // else -> return data from 'input'. + llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + "ret_value_addr", ir_builder_); + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + slice_intersection, "slice_intersection", ir_builder_); + + // Handle true BB (return data from 'update') + SetToFirstInsertPoint(if_data.true_block, ir_builder_); + // Compute update index for intersection case. + llvm_ir::IrArray::Index update_index(rank); + for (int64 i = 0; i < rank; ++i) { + update_index[i] = ir_builder_->CreateSub(index[i], slice_start_index[i]); + } + TF_ASSIGN_OR_RETURN(llvm::Value * true_value, + operand_to_generator.at(update_hlo)(update_index)); + ir_builder_->CreateStore(true_value, ret_value_addr); + + // Handle false BB (return data from 'input') + SetToFirstInsertPoint(if_data.false_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * false_value, + operand_to_generator.at(input_hlo)(index)); + ir_builder_->CreateStore(false_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.after_block, ir_builder_); + return ir_builder_->CreateLoad(ret_value_addr); +} + +StatusOr ElementalIrEmitter::EmitElementalPad( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& padded_index) const { + auto index = padded_index; + llvm::Value* in_bounds = ir_builder_->getTrue(); + for (size_t i = 0; i < index.size(); ++i) { + auto index_typed_const = [=](int64 n) { + return llvm::ConstantInt::get(index[i]->getType(), n); + }; + const auto& pad_dim = hlo->padding_config().dimensions(i); + index[i] = ir_builder_->CreateSub( + index[i], index_typed_const(pad_dim.edge_padding_low())); + in_bounds = ir_builder_->CreateAnd( + in_bounds, ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)), + "in_bounds"); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpEQ( + index_typed_const(0), + ir_builder_->CreateURem( + index[i], index_typed_const(pad_dim.interior_padding() + 1))), + "in_bounds"); + index[i] = ir_builder_->CreateSDiv( + index[i], index_typed_const(pad_dim.interior_padding() + 1)); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpSLT( + index[i], + index_typed_const(hlo->operand(0)->shape().dimensions(i))), + "in_bounds"); + } + + // if (in_bounds) { + // ret_value = operand0[index]; // source + // } else { + // ret_value = *operand1; // padding + // } + llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + "pad_result_addr", ir_builder_); + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); + SetToFirstInsertPoint(if_data.true_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(0))(index)); + ir_builder_->CreateStore(operand_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.false_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, + operand_to_generator.at(hlo->operand(1))({})); + ir_builder_->CreateStore(padding_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.after_block, ir_builder_); + // Don't create phi(operand_value, padding_value) here, because invoking + // operand_to_generator may create new basic blocks, making the parent + // of operand_value or padding_value no longer a predecessor of + // if_data.after_block. + return ir_builder_->CreateLoad(ret_value_addr); +} + +StatusOr ElementalIrEmitter::EmitElementalDot( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& dot_result_index) const { + auto lhs_generator = operand_to_generator.at(hlo->operand(0)); + auto rhs_generator = operand_to_generator.at(hlo->operand(1)); + + const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers(); + int64 lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0); + int64 rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0); + + int64 contracted_dim_size = + hlo->operand(0)->shape().dimensions(lhs_contracting_dim); + int64 lhs_dims = hlo->operand(0)->shape().dimensions_size(); + int64 rhs_dims = hlo->operand(1)->shape().dimensions_size(); + + std::unique_ptr inner_loop = llvm_ir::ForLoop::EmitForLoop( + IrName(hlo, "inner"), ir_builder_->getInt64(0), + ir_builder_->getInt64(contracted_dim_size), ir_builder_->getInt64(1), + ir_builder_); + + SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), ir_builder_); + PrimitiveType primitive_type = hlo->shape().element_type(); + llvm::Type* primitive_type_llvm = + llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); + llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry( + primitive_type_llvm, "dot_acc", ir_builder_); + ir_builder_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm), + accumulator_alloca); + + SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_); + + // This is the inner reduction loop for a dot operation that produces + // one element in the output. If the operands to the dot operation have + // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E]. + // Given an output index [a,b,c,d,e] in the result, we compute: + // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T)) + + IrArray::Index lhs_index, rhs_index; + + for (int64 i = 0; i < lhs_dims - 1; i++) { + lhs_index.push_back(dot_result_index[i]); + } + lhs_index.InsertAt(lhs_contracting_dim, inner_loop->GetIndVarValue()); + + for (int64 i = 0; i < rhs_dims - 1; i++) { + rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]); + } + rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue()); + + llvm::Value* current_accumulator = + ir_builder_->CreateLoad(accumulator_alloca); + TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); + TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); + llvm::Value* next_accumulator; + if (primitive_util::IsComplexType(primitive_type)) { + llvm::Value* product_real = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); + llvm::Value* product_imag = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value))); + next_accumulator = ir_builder_->CreateInsertValue( + current_accumulator, + ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator), + product_real), + {0}); + next_accumulator = ir_builder_->CreateInsertValue( + next_accumulator, + ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator), + product_imag), + {1}); + } else if (primitive_util::IsFloatingPointType(primitive_type)) { + next_accumulator = ir_builder_->CreateFAdd( + current_accumulator, ir_builder_->CreateFMul(lhs_value, rhs_value)); + } else { + next_accumulator = ir_builder_->CreateAdd( + current_accumulator, ir_builder_->CreateMul(lhs_value, rhs_value)); + } + ir_builder_->CreateStore(next_accumulator, accumulator_alloca); + + SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_); + return ir_builder_->CreateLoad(accumulator_alloca); +} + llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) @@ -1295,15 +1905,18 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: case HloOpcode::kCeil: + case HloOpcode::kClz: case HloOpcode::kConvert: case HloOpcode::kBitcastConvert: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kReal: @@ -1353,43 +1966,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kSelect: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, - operand_to_generator.at(hlo->operand(0))( - ElementwiseSourceIndex(index, *hlo, 0))); - TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value, - operand_to_generator.at(hlo->operand(1))( - ElementwiseSourceIndex(index, *hlo, 1))); - TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, - operand_to_generator.at(hlo->operand(2))( - ElementwiseSourceIndex(index, *hlo, 2))); - return ir_builder_->CreateSelect( - ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()), - on_true_value, on_false_value); + return EmitElementalSelect(hlo, operand_to_generator, index); }; case HloOpcode::kClamp: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - TF_ASSIGN_OR_RETURN(llvm::Value * min_value, - operand_to_generator.at(hlo->operand(0))( - ElementwiseSourceIndex(index, *hlo, 0))); - TF_ASSIGN_OR_RETURN(llvm::Value * arg_value, - operand_to_generator.at(hlo->operand(1))( - ElementwiseSourceIndex(index, *hlo, 1))); - TF_ASSIGN_OR_RETURN(llvm::Value * max_value, - operand_to_generator.at(hlo->operand(2))( - ElementwiseSourceIndex(index, *hlo, 2))); - PrimitiveType prim_type = hlo->shape().element_type(); - if (primitive_util::IsFloatingPointType(prim_type)) { - return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); - } else if (primitive_util::IsIntegralType(prim_type)) { - bool is_signed = primitive_util::IsSignedIntegralType(prim_type); - return EmitIntegralMin( - max_value, EmitIntegralMax(min_value, arg_value, is_signed), - is_signed); - } else { - return Unimplemented("Clamp unimplemented for %s", - PrimitiveType_Name(prim_type).c_str()); - } + return EmitElementalClamp(hlo, operand_to_generator, index); }; case HloOpcode::kReducePrecision: return [this, hlo, &operand_to_generator]( @@ -1402,70 +1984,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kConcatenate: return [this, hlo, &operand_to_generator]( const IrArray::Index target_index) -> StatusOr { - const int64 concat_dim = hlo->dimensions(0); - auto source_index = target_index; - - llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock(); - - // A terminator should be present iff we're emitting code - // into the middle (as opposed to the end) of a basic block. - CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(), - init_block->getTerminator() == nullptr); - - llvm::BasicBlock* exit_block; - if (ir_builder_->GetInsertPoint() == init_block->end()) { - exit_block = llvm_ir::CreateBasicBlock( - /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_); - } else { - exit_block = init_block->splitBasicBlock( - ir_builder_->GetInsertPoint(), AsStringRef(IrName(hlo, "merge"))); - init_block->getTerminator()->eraseFromParent(); - } - - llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_); - llvm::PHINode* output = - ir_builder_->CreatePHI(llvm_ir::PrimitiveTypeToIrType( - hlo->shape().element_type(), module_), - hlo->operands().size()); - auto prior_insert_point = ir_builder_->GetInsertPoint(); - - ir_builder_->SetInsertPoint(init_block); - - for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); - ++operand_idx) { - const HloInstruction* operand = hlo->operand(operand_idx); - auto true_block = llvm_ir::CreateBasicBlock( - exit_block, StrCat("concat_index_from_operand", operand_idx), - ir_builder_); - auto false_block = llvm_ir::CreateBasicBlock( - exit_block, StrCat("concat_index_not_from_operand", operand_idx), - ir_builder_); - auto concat_dim_size = - llvm::ConstantInt::get(source_index[concat_dim]->getType(), - operand->shape().dimensions(concat_dim)); - ir_builder_->CreateCondBr( - ir_builder_->CreateICmpULT(source_index[concat_dim], - concat_dim_size), - true_block, false_block); - - // Create the terminator of the true block before calling operand - // generators, because they require non-degenerate basic blocks. - ir_builder_->SetInsertPoint( - llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block)); - TF_ASSIGN_OR_RETURN(llvm::Value * value, - operand_to_generator.at(operand)(source_index)); - output->addIncoming(value, ir_builder_->GetInsertBlock()); - - // Subtract the size of the concat dimension of the current operand - // from the source index. - ir_builder_->SetInsertPoint(false_block); - source_index[concat_dim] = - ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size); - } - - ir_builder_->CreateUnreachable(); - ir_builder_->SetInsertPoint(exit_block, prior_insert_point); - return output; + return EmitElementalConcatenate(hlo, operand_to_generator, + target_index); }; case HloOpcode::kReverse: return [this, hlo, &operand_to_generator]( @@ -1483,15 +2003,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kBroadcast: return [this, hlo, &operand_to_generator]( const IrArray::Index& target_index) -> StatusOr { + const HloInstruction* operand = hlo->operand(0); // The `dimensions` member of the broadcast instruction maps from // input dimensions to output dimensions. - const HloInstruction* operand = hlo->operand(0); - int64 rank = ShapeUtil::Rank(operand->shape()); - IrArray::Index source_index(rank); - for (int64 i = 0; i < rank; ++i) { - source_index[i] = target_index[hlo->dimensions(i)]; - } - return operand_to_generator.at(operand)(source_index); + return operand_to_generator.at( + operand)(target_index.SourceIndexOfBroadcast( + hlo->shape(), operand->shape(), hlo->dimensions(), ir_builder_)); }; case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( @@ -1504,184 +2021,27 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kDynamicSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - // Emit IR to read dynamic start indices from hlo->operand(1). - const HloInstruction* input_hlo = hlo->operand(0); - const int64 rank = ShapeUtil::Rank(input_hlo->shape()); - llvm_ir::IrArray::Index slice_start_index(rank); - for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); - TF_ASSIGN_OR_RETURN( - llvm::Value * start_index_value, - operand_to_generator.at(hlo->operand(1))(dim_index)); - start_index_value->setName( - AsStringRef(IrName(hlo, StrCat("start_idx", i)))); - slice_start_index[i] = start_index_value; - } + return EmitElementalDynamicSlice(hlo, operand_to_generator, index); + }; - llvm_ir::IrArray::Index input_index(rank); - for (int64 i = 0; i < rank; ++i) { - // Emit IR which computes: - // input_index = (start_index + offset_index) % dim_size - // Security note: this is the code that keeps the indices in-bounds. - llvm::Value* dim_size = llvm::ConstantInt::get( - index[i]->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast( - slice_start_index[i], index[i]->getType()); - input_index[i] = ir_builder_->CreateURem( - ir_builder_->CreateAdd(start_index, index[i]), dim_size); - } - return operand_to_generator.at(input_hlo)(input_index); + case HloOpcode::kGather: + return [this, hlo, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + return EmitElementalGather(hlo, operand_to_generator, index); }; case HloOpcode::kDynamicUpdateSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - const HloInstruction* input_hlo = hlo->operand(0); - const HloInstruction* update_hlo = hlo->operand(1); - const HloInstruction* start_hlo = hlo->operand(2); - // Calculate slice start/end indices. - const int64 rank = ShapeUtil::Rank(input_hlo->shape()); - llvm_ir::IrArray::Index slice_start_index(rank); - llvm_ir::IrArray::Index slice_limit_index(rank); - // Slice starts at update[index - slice_start_index_adjusted], - // where adjusted value = slice_start_index when in bounds, and - // adjusted value = slice_start_index - input_dim, when wrapping. - llvm_ir::IrArray::Index slice_start_index_adjusted(rank); - - // Slice intersection gathers (ANDs) conditions on all ranks for which - // 'input' is set to 'update' - llvm::Value* slice_intersection = ir_builder_->getTrue(); - - for (int64 i = 0; i < rank; ++i) { - // Emit IR to read dynamic start indices from 'start_hlo'. - llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); - TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, - operand_to_generator.at(start_hlo)(dim_index)); - start_index_value->setName( - AsStringRef(IrName(hlo, StrCat("start_idx", i)))); - slice_start_index[i] = ir_builder_->CreateZExtOrBitCast( - start_index_value, index[i]->getType()); - - llvm::Value* input_dim_size = llvm::ConstantInt::get( - index[i]->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* update_dim_size = llvm::ConstantInt::get( - index[i]->getType(), update_hlo->shape().dimensions(i)); - - // Generate code to handle wrapping semantics: - // slice_start_index[i] = slice_start_index[i] % input_dim_size; - // slice_limit_index[i] = slice_start_index[i] + update_dim_size. - // slice_start_index[i] is updated in place and it will now be in - // range. slice_limit_index[i] may be out of range, and it's being - // URem-ed below if so. - slice_start_index[i] = - ir_builder_->CreateURem(slice_start_index[i], input_dim_size); - slice_limit_index[i] = - ir_builder_->CreateAdd(slice_start_index[i], update_dim_size); - - // Test if slice_limit_index[i] is in bounds - llvm::Value* in_bounds = - ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size); - llvm_ir::LlvmIfData if_in_bounds = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - - // Handle true BB (slice_limit_index[i] <= input_dim_size). - SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_); - // Check that index[i] >= slice_start_index[i] && - // index[i] < slice_limit_index[i] - llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd( - slice_intersection, - ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), - "slice_intersection_in"); - slice_intersection_in_bounds = ir_builder_->CreateAnd( - slice_intersection_in_bounds, - ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]), - "slice_intersection_in"); - - // Handle false BB (slice_limit_index[i] > input_dim_size). - SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_); - // Check that index[i] >= slice_start_index[i] || - // index[i] < slice_limit_index[i]%input_dim_size. - llvm::Value* index_wraps = ir_builder_->CreateICmpSLT( - index[i], - ir_builder_->CreateURem(slice_limit_index[i], input_dim_size)); - llvm::Value* slice_intersection_or = ir_builder_->CreateOr( - ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), - index_wraps, "slice_intersection_out"); - llvm::Value* slice_intersection_out_of_bounds = - ir_builder_->CreateAnd(slice_intersection, slice_intersection_or, - "slice_intersection_out"); - // Create value for slice_start_index_adjusted[i] when out of bounds. - // If within out-of-bounds if. - llvm_ir::LlvmIfData if_start_needs_adjustment = - llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_); - SetToFirstInsertPoint(if_start_needs_adjustment.true_block, - ir_builder_); - llvm::Value* slice_start_index_adjusted_oob = - ir_builder_->CreateSub(slice_start_index[i], input_dim_size); - SetToFirstInsertPoint(if_start_needs_adjustment.after_block, - ir_builder_); - llvm::PHINode* slice_start_index_adjusted_phi = - ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(), - 2); - slice_start_index_adjusted_phi->addIncoming( - slice_start_index_adjusted_oob, - if_start_needs_adjustment.true_block); - slice_start_index_adjusted_phi->addIncoming( - slice_start_index[i], if_start_needs_adjustment.false_block); - // End of if within if. - - // After checking in/out of bounds. - SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_); - llvm::PHINode* phi_slice_intersection = - ir_builder_->CreatePHI(slice_intersection->getType(), 2); - phi_slice_intersection->addIncoming(slice_intersection_in_bounds, - if_in_bounds.true_block); - phi_slice_intersection->addIncoming( - slice_intersection_out_of_bounds, - if_start_needs_adjustment.after_block); - slice_intersection = phi_slice_intersection; - - llvm::PHINode* phi_index = - ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2); - phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block); - phi_index->addIncoming(slice_start_index_adjusted_phi, - if_start_needs_adjustment.after_block); - slice_start_index_adjusted[i] = phi_index; - } - - // Emit: - // if (slice_intersection) -> return data from 'update'. - // else -> return data from 'input'. - llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - module_), - "ret_value_addr", ir_builder_); - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - slice_intersection, "slice_intersection", ir_builder_); - - // Handle true BB (return data from 'update') - SetToFirstInsertPoint(if_data.true_block, ir_builder_); - // Compute update index for intersection case. - llvm_ir::IrArray::Index update_index(rank); - for (int64 i = 0; i < rank; ++i) { - llvm::Value* update_dim_size = llvm::ConstantInt::get( - index[i]->getType(), update_hlo->shape().dimensions(i)); - // NOTE: Subtraction will be positive due to bounds checking above. - update_index[i] = ir_builder_->CreateURem( - ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]), - update_dim_size); - } - TF_ASSIGN_OR_RETURN(llvm::Value * true_value, - operand_to_generator.at(update_hlo)(update_index)); - ir_builder_->CreateStore(true_value, ret_value_addr); - - // Handle false BB (return data from 'input') - SetToFirstInsertPoint(if_data.false_block, ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * false_value, - operand_to_generator.at(input_hlo)(index)); - ir_builder_->CreateStore(false_value, ret_value_addr); - - SetToFirstInsertPoint(if_data.after_block, ir_builder_); - return ir_builder_->CreateLoad(ret_value_addr); + return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator, + index); + }; + case HloOpcode::kBitcast: + CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()), + ShapeUtil::ElementsIn(hlo->operand(0)->shape())); + return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + const HloInstruction* operand = hlo->operand(0); + return operand_to_generator.at(operand)(index.SourceIndexOfBitcast( + hlo->shape(), operand->shape(), ir_builder_)); }; case HloOpcode::kReshape: CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()), @@ -1702,155 +2062,16 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kRng: return MakeRngElementGenerator(hlo, operand_to_generator); case HloOpcode::kPad: - return [=, &operand_to_generator]( + return [this, hlo, &operand_to_generator]( const IrArray::Index& padded_index) -> StatusOr { - auto index = padded_index; - llvm::Value* in_bounds = ir_builder_->getTrue(); - for (size_t i = 0; i < index.size(); ++i) { - auto index_typed_const = [=](int64 n) { - return llvm::ConstantInt::get(index[i]->getType(), n); - }; - const auto& pad_dim = hlo->padding_config().dimensions(i); - index[i] = ir_builder_->CreateSub( - index[i], index_typed_const(pad_dim.edge_padding_low())); - in_bounds = ir_builder_->CreateAnd( - in_bounds, - ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)), - "in_bounds"); - in_bounds = ir_builder_->CreateAnd( - in_bounds, - ir_builder_->CreateICmpEQ( - index_typed_const(0), - ir_builder_->CreateURem( - index[i], - index_typed_const(pad_dim.interior_padding() + 1))), - "in_bounds"); - index[i] = ir_builder_->CreateSDiv( - index[i], index_typed_const(pad_dim.interior_padding() + 1)); - in_bounds = ir_builder_->CreateAnd( - in_bounds, - ir_builder_->CreateICmpSLT( - index[i], - index_typed_const(hlo->operand(0)->shape().dimensions(i))), - "in_bounds"); - } - - // if (in_bounds) { - // ret_value = operand0[index]; // source - // } else { - // ret_value = *operand1; // padding - // } - llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - module_), - "pad_result_addr", ir_builder_); - llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - SetToFirstInsertPoint(if_data.true_block, ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(0))(index)); - ir_builder_->CreateStore(operand_value, ret_value_addr); - - SetToFirstInsertPoint(if_data.false_block, ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, - operand_to_generator.at(hlo->operand(1))({})); - ir_builder_->CreateStore(padding_value, ret_value_addr); - - SetToFirstInsertPoint(if_data.after_block, ir_builder_); - // Don't create phi(operand_value, padding_value) here, because invoking - // operand_to_generator may create new basic blocks, making the parent - // of operand_value or padding_value no longer a predecessor of - // if_data.after_block. - return ir_builder_->CreateLoad(ret_value_addr); + return EmitElementalPad(hlo, operand_to_generator, padded_index); }; case HloOpcode::kDot: - return [=, &operand_to_generator](const IrArray::Index& dot_result_index) + return [this, hlo, + &operand_to_generator](const IrArray::Index& dot_result_index) -> StatusOr { - auto lhs_generator = operand_to_generator.at(hlo->operand(0)); - auto rhs_generator = operand_to_generator.at(hlo->operand(1)); - int64 contracted_dim_size = hlo->operand(0)->shape().dimensions( - hlo->operand(0)->shape().dimensions_size() - 1); - int64 lhs_dims = hlo->operand(0)->shape().dimensions_size(); - int64 rhs_dims = hlo->operand(1)->shape().dimensions_size(); - - std::unique_ptr inner_loop = - llvm_ir::ForLoop::EmitForLoop( - IrName(hlo, "inner"), ir_builder_->getInt64(0), - ir_builder_->getInt64(contracted_dim_size), - ir_builder_->getInt64(1), ir_builder_); - - SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), - ir_builder_); - PrimitiveType primitive_type = hlo->shape().element_type(); - llvm::Type* primitive_type_llvm = - llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); - llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry( - primitive_type_llvm, "dot_acc", ir_builder_); - ir_builder_->CreateStore( - llvm::Constant::getNullValue(primitive_type_llvm), - accumulator_alloca); - - SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_); - - // This is the inner reduction loop for a dot operation that produces - // one element in the output. If the operands to the dot operation have - // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E]. - // Given an output index [a,b,c,d,e] in the result, we compute: - // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T)) - - IrArray::Index lhs_index, rhs_index; - - for (int64 i = 0; i < lhs_dims - 1; i++) { - lhs_index.push_back(dot_result_index[i]); - } - lhs_index.push_back(inner_loop->GetIndVarValue()); - - for (int64 i = 0; i < rhs_dims - 2; i++) { - rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]); - } - rhs_index.push_back(inner_loop->GetIndVarValue()); - rhs_index.push_back(dot_result_index.back()); - - llvm::Value* current_accumulator = - ir_builder_->CreateLoad(accumulator_alloca); - TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); - TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); - llvm::Value* next_accumulator; - if (primitive_util::IsComplexType(primitive_type)) { - llvm::Value* product_real = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); - llvm::Value* product_imag = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value))); - next_accumulator = ir_builder_->CreateInsertValue( - current_accumulator, - ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator), - product_real), - {0}); - next_accumulator = ir_builder_->CreateInsertValue( - next_accumulator, - ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator), - product_imag), - {1}); - } else if (primitive_util::IsFloatingPointType(primitive_type)) { - next_accumulator = ir_builder_->CreateFAdd( - current_accumulator, - ir_builder_->CreateFMul(lhs_value, rhs_value)); - } else { - next_accumulator = ir_builder_->CreateAdd( - current_accumulator, - ir_builder_->CreateMul(lhs_value, rhs_value)); - } - ir_builder_->CreateStore(next_accumulator, accumulator_alloca); - - SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_); - return ir_builder_->CreateLoad(accumulator_alloca); + return EmitElementalDot(hlo, operand_to_generator, dot_result_index); }; default: return [this, hlo, &operand_to_generator](const IrArray::Index& index) { diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index c516a826d9e382bc738e54635426db639d17108c..d199473374ad394913413a7d3fe805f8782936f7 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -105,6 +105,9 @@ class ElementalIrEmitter { virtual StatusOr EmitLog(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) const; + virtual StatusOr EmitSin(PrimitiveType prim_type, llvm::Value* value) const; @@ -114,6 +117,9 @@ class ElementalIrEmitter { virtual StatusOr EmitExp(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) const; + virtual StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const; @@ -142,6 +148,46 @@ class ElementalIrEmitter { return ir_builder_->getIntN(128, 0); } + StatusOr EmitElementalSelect( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalClamp( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalConcatenate( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& target_index) const; + + StatusOr EmitElementalDynamicSlice( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalGather( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalDynamicUpdateSlice( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalPad( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& padded_index) const; + + StatusOr EmitElementalDot( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& dot_result_index) const; + llvm::IRBuilder<>* const ir_builder_; llvm::Module* module_; diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8980d4303353a132ada2b3c685b4f2856c33c6a1 --- /dev/null +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class ElementalIrEmitterExecutionTest : public HloTestBase { + protected: + void RunTest(const string& hlo_text, + tensorflow::gtl::ArraySlice args) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), args, nullopt)); + } +}; + +XLA_TEST_F(ElementalIrEmitterExecutionTest, DotFusion) { + const string hlo_text = R"( +HloModule FusedDot + +fused_computation { + arg0 = s32[1,2,1]{2,1,0} parameter(0) + reshape.lhs = s32[2,1]{1,0} reshape(arg0) + arg1 = s32[1,2,1]{2,1,0} parameter(1) + reshape.rhs = s32[2,1]{1,0} reshape(arg1) + ROOT dot = s32[1,1]{1,0} dot(reshape.lhs, reshape.rhs), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +ENTRY main { + entry_arg0 = s32[1,2,1]{2,1,0} parameter(0) + entry_arg1 = s32[1,2,1]{2,1,0} parameter(1) + ROOT fusion = s32[1,1]{1,0} fusion(entry_arg0, entry_arg1), kind=kLoop, calls=fused_computation +} +)"; + + std::unique_ptr lhs = Literal::CreateR3({{{1}, {2}}}); + std::unique_ptr rhs = Literal::CreateR3({{{3}, {4}}}); + RunTest(hlo_text, {lhs.get(), rhs.get()}); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 90481c7a88f90edea5399ee44aee2d2c77fc115f..6df172db8e541c5cef7aab04f3d8611fc735e8b0 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" @@ -28,18 +29,19 @@ using tensorflow::gtl::ArraySlice; namespace xla { -StatusOr>> -Executable::ExecuteOnStreams( +StatusOr> Executable::ExecuteOnStreams( ArraySlice run_options, ArraySlice> arguments) { TF_RET_CHECK(run_options.size() == arguments.size()); - std::vector> return_values(run_options.size()); + std::vector return_values; + return_values.reserve(run_options.size()); if (run_options.size() == 1) { - TF_ASSIGN_OR_RETURN(return_values[0], + TF_ASSIGN_OR_RETURN(auto rv, ExecuteOnStream(&run_options[0], arguments[0], /*hlo_execution_profile=*/nullptr)); + return_values.push_back(std::move(rv)); return std::move(return_values); } @@ -47,8 +49,9 @@ Executable::ExecuteOnStreams( // We cannot BlockHostUntilDone() on the already-launched executions in case // of error, since if the executions communicate, the initially launched // executions may never complete if not all executions are running. - TF_ASSIGN_OR_RETURN(return_values[i], + TF_ASSIGN_OR_RETURN(auto rv, ExecuteAsyncOnStream(&run_options[i], arguments[i])); + return_values.push_back(std::move(rv)); } for (const auto& options : run_options) { TF_RET_CHECK(options.stream() != nullptr); @@ -57,13 +60,13 @@ Executable::ExecuteOnStreams( return std::move(return_values); } -StatusOr> Executable::ExecuteOnStreamWrapper( +StatusOr Executable::ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, ArraySlice arguments) { - perftools::gputools::Stream* stream = run_options->stream(); - std::unique_ptr timer; + se::Stream* stream = run_options->stream(); + std::unique_ptr timer; if (profile != nullptr) { - timer.reset(new perftools::gputools::Timer(stream->parent())); + timer.reset(new se::Timer(stream->parent())); stream->InitTimer(timer.get()).ThenStartTimer(timer.get()); } @@ -77,8 +80,9 @@ StatusOr> Executable::ExecuteOnStreamWrapper( &hlo_profile_index_map()) : nullptr; - StatusOr> return_value = + StatusOr return_value = ExecuteOnStream(run_options, arguments, profile_ptr.get()); + TF_RETURN_IF_ERROR(return_value.status()); if (profile != nullptr) { VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; @@ -125,23 +129,22 @@ StatusOr> Executable::ExecuteOnStreamWrapper( return return_value; } -Status Executable::DumpSessionModule() { - TF_RET_CHECK(dumping()); +Status Executable::DumpHloSnapshot() { + TF_RET_CHECK(dumping_snapshot()); + TF_RET_CHECK(hlo_snapshot_->has_hlo() && + hlo_snapshot_->hlo().has_hlo_module()); const string& directory_path = module_config().debug_options().xla_dump_executions_to(); - VersionedComputationHandle versioned_handle = entry_computation_handle(); - // This filename does not include the version number because the computation - // is only ever executed at one version. + const auto& module = hlo_snapshot_->hlo().hlo_module(); string filename = tensorflow::strings::Printf( - "computation_%lld__%s__execution_%lld", versioned_handle.handle.handle(), - session_module_->entry().name().c_str(), ++execution_count_); - return Executable::DumpToDirectory(directory_path, filename, - *session_module_); + "computation_%lld__%s__execution_%lld", module.id(), + module.entry_computation_name().c_str(), ++execution_count_); + return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_); } /* static */ Status Executable::DumpToDirectory( const string& directory_path, string filename, - const SessionModule& session_module) { + const HloSnapshot& hlo_session) { tensorflow::Env* env = tensorflow::Env::Default(); if (!env->IsDirectory(directory_path).ok()) { // NB! CreateDir does not work reliably with multiple XLA threads -- two @@ -154,7 +157,7 @@ Status Executable::DumpSessionModule() { string file_path = tensorflow::io::JoinPath(directory_path, filename); string result; TF_RET_CHECK( - tensorflow::SerializeToStringDeterministic(session_module, &result)); + tensorflow::SerializeToStringDeterministic(hlo_session, &result)); return tensorflow::WriteStringToFile(tensorflow::Env::Default(), file_path, result); } diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 0aee535ee780ef000bc5e9963ff48786b3a61eb2..dc1f26ea65cc707d4f0522af2aa3ec40621632f1 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -22,14 +22,12 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" -#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -63,14 +61,14 @@ class Executable { // enabled. // // Returns a shaped buffer containing the result of the computation. - virtual StatusOr> ExecuteOnStream( + virtual StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) = 0; // Same as ExecuteOnStream(), but this call is non-blocking and returns as // soon as all of the operations are enqueued for launch on the stream. - virtual StatusOr> ExecuteAsyncOnStream( + virtual StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) = 0; @@ -78,7 +76,7 @@ class Executable { // streams. arguments[i] contains the arguments to the execution on // run_options[i]->stream() and the returned value is at index i of the // returned vector. - virtual StatusOr>> ExecuteOnStreams( + virtual StatusOr> ExecuteOnStreams( tensorflow::gtl::ArraySlice run_options, tensorflow::gtl::ArraySlice< @@ -91,14 +89,14 @@ class Executable { // has completed. virtual Status PopulateExecutionProfile( HloExecutionProfile* hlo_execution_profile, - perftools::gputools::StreamExecutor* executor) { + se::StreamExecutor* executor) { return Status::OK(); } // Convenience wrapper for calling Executable::ExecuteOnStream. Sets up a // timer for the execution, sets up HLO profiling if enabled, and fills in the // given ExecutionProfile if non-null. - StatusOr> ExecuteOnStreamWrapper( + StatusOr ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, tensorflow::gtl::ArraySlice arguments); @@ -109,14 +107,6 @@ class Executable { return execution_profile_; } - // Returns Status::ok() if the two executables are equal to each other. - // - // An error status is returned otherwise. - virtual const Status EqualOrFail(const Executable& executable) { - return Unimplemented( - "Equality test on this executable is not implemented."); - } - const HloProfilePrinterData& hlo_profile_printer_data() const { CHECK(hlo_profiling_enabled()); return *hlo_profile_printer_data_; @@ -140,29 +130,23 @@ class Executable { const HloModuleConfig& module_config() const { return hlo_module_->config(); } - // Returns the versioned computation handle of the computation computed by - // this executable. - const VersionedComputationHandle& entry_computation_handle() const { - return hlo_module_->entry_computation_handle(); - } - // The shape (including layout) that results from this execution. This is the // shape of the DeviceMemoryBase result value in ExecuteOnStream above. - const Shape& result_shape() const { - return hlo_module_->config().entry_computation_layout().result_shape(); + const Shape& host_result_shape() const { + return hlo_module_->config().host_entry_computation_layout().result_shape(); } // Dumping helpers. - void set_session_module(std::unique_ptr session_module) { - session_module_ = std::move(session_module); + void set_hlo_snapshot(std::unique_ptr hlo_snapshot) { + hlo_snapshot_ = std::move(hlo_snapshot); } - bool dumping() const { return session_module_ != nullptr; } - SessionModule* session_module() const { return session_module_.get(); } - Status DumpSessionModule(); + bool dumping_snapshot() const { return hlo_snapshot_ != nullptr; } + HloSnapshot* hlo_snapshot() const { return hlo_snapshot_.get(); } + Status DumpHloSnapshot(); - // Dump session_module to directory_path/filename. + // Dump hlo snapshot to directory_path/filename. static Status DumpToDirectory(const string& directory_path, string filename, - const SessionModule& session_module); + const HloSnapshot& hlo_session); protected: mutable tensorflow::mutex mutex_; @@ -175,8 +159,8 @@ class Executable { // around. const std::unique_ptr hlo_module_; - // SessionModule this was compiled from. Null if not dumping executions. - std::unique_ptr session_module_; + // HloSnapshot this was compiled from. Null if not dumping executions. + std::unique_ptr hlo_snapshot_; // Execution count, used to generate a unique filename for each dumped // execution. diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 2f0b9ed2bd98fbea4e67c0a30d5aa41ff6a06979..6794cfe297b0fb9a15eb9b7e6906d225f9597d07 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -37,11 +37,11 @@ AsyncExecution::AsyncExecution(Backend* backend, } } -tensorflow::Status AsyncExecution::BlockUntilDone() const { +Status AsyncExecution::BlockUntilDone() const { for (auto& stream : streams_) { TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); } - return tensorflow::Status::OK(); + return Status::OK(); } ExecutionTracker::ExecutionTracker() : next_handle_(1) {} @@ -61,7 +61,7 @@ ExecutionHandle ExecutionTracker::Register( return execution_handle; } -tensorflow::Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { +Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { tensorflow::mutex_lock lock(execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); if (it == handle_to_execution_.end()) { @@ -69,7 +69,7 @@ tensorflow::Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { handle.handle()); } handle_to_execution_.erase(handle.handle()); - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr ExecutionTracker::Resolve( diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h index 5b6bddf9f16a85f7863f4d05c39c7d4c99209af1..4458152dd9a98890fc3a3e7f324245ec68821467 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.h +++ b/tensorflow/compiler/xla/service/execution_tracker.h @@ -43,7 +43,7 @@ class AsyncExecution { AsyncExecution(Backend* backend, std::vector streams, const ExecutionProfile& profile, GlobalDataHandle result); - tensorflow::Status BlockUntilDone() const; + Status BlockUntilDone() const; const GlobalDataHandle& result() const { return result_; } @@ -77,7 +77,7 @@ class ExecutionTracker { GlobalDataHandle data); // Unregisters the execution for the given handle. - tensorflow::Status Unregister(const ExecutionHandle& handle); + Status Unregister(const ExecutionHandle& handle); // Resolves the given ExecutionHandle to an AsyncExecution. Returns an // error status if the given handle is not found, which means that the diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc index 2b6caa149439a86d6d047605099bc3ff7b295a8e..85409b330b11537158059dcce8c2a96c98d38f30 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc @@ -93,7 +93,7 @@ Status FlattenNode(const CallGraphNode& node) { auto current = worklist.back(); worklist.pop_back(); for (auto* instruction : current->instructions()) { - if (GetInstructionCallContext(instruction) != + if (GetInstructionCallContext(instruction->opcode()) != CallContext::kSequential) { continue; } diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/service/g3doc/hlo_parser.md similarity index 100% rename from tensorflow/compiler/xla/tools/parser/README.md rename to tensorflow/compiler/xla/service/g3doc/hlo_parser.md diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc new file mode 100644 index 0000000000000000000000000000000000000000..2d3e4b1fcdf6675955714cab262a8b2ca8ff4297 --- /dev/null +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -0,0 +1,388 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gather_expander.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +using tensorflow::gtl::ArraySlice; + +static StatusOr TransposeIndexVectorDimToLast( + HloInstruction* gather_indices, int64 index_vector_dim) { + const Shape& gather_indices_shape = gather_indices->shape(); + + if (gather_indices_shape.dimensions_size() == index_vector_dim) { + return gather_indices; + } + + if (index_vector_dim == (gather_indices_shape.dimensions_size() - 1)) { + return gather_indices; + } + + std::vector permutation; + permutation.reserve(gather_indices_shape.dimensions_size()); + for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + if (i != index_vector_dim) { + permutation.push_back(i); + } + } + permutation.push_back(index_vector_dim); + return MakeTransposeHlo(gather_indices, permutation); +} + +// Canonicalizes the gather_indices tensors so that we only have deal with some +// specific cases in the while loop that does the heavy lifting. +// +// See the "High Level Algorithm" section for a broader picture. +static StatusOr CanonicalizeGatherIndices( + HloInstruction* gather_indices, int64 index_vector_dim) { + // Transpose the non-index-vector dimensions to the front. + TF_ASSIGN_OR_RETURN( + HloInstruction * transposed_gather_indices, + TransposeIndexVectorDimToLast(gather_indices, index_vector_dim)); + bool indices_are_scalar = + index_vector_dim == gather_indices->shape().dimensions_size(); + + // The number of dimensions in gather_indices that are index dimensions. + const int64 index_dims_in_gather_indices = indices_are_scalar ? 0 : 1; + + // If there is only one index (i.e. gather_indices has rank 1 and this gather + // is really just a dynamic slice) add a leading degenerate dimension for + // uniformity. Otherwise create a "collapsed" leading dimension that subsumes + // all of the non-index-vector dimensions. + const Shape& shape = transposed_gather_indices->shape(); + if (shape.dimensions_size() == index_dims_in_gather_indices) { + return PrependDegenerateDims(transposed_gather_indices, 1); + } else { + // Collapse all but the dimensions (0 or 1) in gather_indices containing the + // index vectors. + return CollapseFirstNDims( + transposed_gather_indices, + shape.dimensions_size() - index_dims_in_gather_indices); + } +} + +// Expands out or contracts away the gather dimensions in the accumulator +// produced by the while loop. +static StatusOr AdjustGatherDimsInAccumulator( + const Shape& gather_indices_shape, HloInstruction* accumulator, + int64 index_vector_dim) { + std::vector output_gather_dim_bounds; + output_gather_dim_bounds.reserve(gather_indices_shape.dimensions_size()); + for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + if (i != index_vector_dim) { + output_gather_dim_bounds.push_back(gather_indices_shape.dimensions(i)); + } + } + + if (output_gather_dim_bounds.empty()) { + // If output_gather_dim_bounds is empty we must be lowering a (effectively) + // dynamic-slice. In that case, there is a leading degenerate gather + // dimension that we added to make this special case play well with the + // general while loop which we need to remove now. + return ElideDegenerateDims(accumulator, {0}); + } + + return ExpandFirstDimIntoNDims(accumulator, output_gather_dim_bounds); +} + +// Expand an index vector from the gather_indices tensor into a vector that can +// be used to dynamic-slice out of the gather operand. +static StatusOr ExpandIndexVectorIntoOperandSpace( + HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers, + int64 operand_rank) { + HloComputation* computation = index_vector->parent(); + const Shape& index_shape = index_vector->shape(); + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + Literal::CreateFromDimensions(index_shape.element_type(), {1}))); + + // We extract out individual components from the smaller index and concatenate + // them (interspersing zeros as needed) into the larger index. + std::vector expanded_index_components; + + for (int i = 0; i < operand_rank; i++) { + int64 index_vector_dim_index = + FindIndex(dim_numbers.gather_dims_to_operand_dims(), i); + if (index_vector_dim_index != + dim_numbers.gather_dims_to_operand_dims_size()) { + TF_ASSIGN_OR_RETURN( + HloInstruction * component_to_concat, + MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, + /*limit_indices=*/{index_vector_dim_index + 1}, + /*strides=*/{1})); + expanded_index_components.push_back(component_to_concat); + } else { + expanded_index_components.push_back(zero); + } + } + + return MakeConcatHlo(expanded_index_components, /*dimension=*/0); +} + +// This generates the body of the while that implements the main data movement +// behavior of gather using dynamic-slice and dynamic-update-slice. +static StatusOr> GatherLoopBody( + const HloInstruction& gather, HloInstruction* induction_var, + const std::vector& incoming_loop_state) { + const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers(); + CHECK_EQ(incoming_loop_state.size(), 3); + HloInstruction* const operand = incoming_loop_state[0]; + HloInstruction* const gather_indices = incoming_loop_state[1]; + HloInstruction* const output_accumulator = incoming_loop_state[2]; + + bool has_scalar_indices = gather_indices->shape().dimensions_size() == 1; + CHECK_EQ(has_scalar_indices, + dim_numbers.index_vector_dim() == + gather.operand(1)->shape().dimensions_size()); + + TF_ASSIGN_OR_RETURN( + HloInstruction * induction_var_as_vector, + MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, + /*result_shape_bounds=*/{1})); + + HloInstruction* index_vector; + + if (has_scalar_indices) { + // In this case gather_indices has rank 1 and induction_var_as_vector (of + // shape {1}) is an index into this rank 1 tensor. + TF_ASSIGN_OR_RETURN( + index_vector, + MakeDynamicSliceHlo(gather_indices, induction_var_as_vector, {1})); + } else { + // In this case gather_indices has rank 2 and induction_var_as_vector (of + // shape {1}) is an index into just the first dimension of this rank 2 + // tensor. + TF_ASSIGN_OR_RETURN( + HloInstruction * index_into_gather_indices, + PadVectorWithZeros(induction_var_as_vector, + /*zeros_to_prepend=*/0, /*zeros_to_append=*/1)); + + int64 index_vector_size = gather_indices->shape().dimensions(1); + TF_ASSIGN_OR_RETURN( + HloInstruction * index_vector_2d, + MakeDynamicSliceHlo(gather_indices, index_into_gather_indices, + {1, index_vector_size})); + + TF_ASSIGN_OR_RETURN(index_vector, + ElideDegenerateDims(index_vector_2d, {0})); + } + + TF_ASSIGN_OR_RETURN( + HloInstruction * gathered_slice_start, + ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers, + operand->shape().dimensions_size())); + + TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice, + MakeDynamicSliceHlo(operand, gathered_slice_start, + gather.gather_window_bounds())); + + TF_ASSIGN_OR_RETURN( + HloInstruction * gathered_slice_with_dims_elided, + ElideDegenerateDims(gathered_slice, + AsInt64Slice(dim_numbers.elided_window_dims()))); + + TF_ASSIGN_OR_RETURN( + HloInstruction * gathered_slice_for_update, + PrependDegenerateDims(gathered_slice_with_dims_elided, 1)); + + TF_ASSIGN_OR_RETURN( + HloInstruction * index_vector_into_accumulator, + PadVectorWithZeros( + induction_var_as_vector, /*zeros_to_prepend=*/0, + /*zeros_to_append=*/ + gathered_slice_with_dims_elided->shape().dimensions_size())); + + TF_ASSIGN_OR_RETURN( + HloInstruction * updated_accumulator, + MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update, + index_vector_into_accumulator)); + + // New loop state -- only the accumulator has changed. The + // WhileUtil::MakeCountedLoop functions takes care of the induction variable + // and the while loop exit condition. + return StatusOr>{ + {operand, gather_indices, updated_accumulator}}; +} + +static StatusOr CreateGatherLoopAccumulatorInitValue( + HloComputation* computation, PrimitiveType element_type, + ArraySlice window_bounds, int64 gather_loop_trip_count, + const GatherDimensionNumbers& dim_numbers) { + std::vector accumulator_state_shape_dims; + accumulator_state_shape_dims.reserve(1 + window_bounds.size()); + accumulator_state_shape_dims.push_back(gather_loop_trip_count); + for (int64 i = 0; i < window_bounds.size(); i++) { + if (!c_binary_search(dim_numbers.elided_window_dims(), i)) { + accumulator_state_shape_dims.push_back(window_bounds[i]); + } + } + return BroadcastZeros(computation, element_type, + accumulator_state_shape_dims); +} + +// `accumulator` is almost the tensor the gather operation would have produced, +// except that it has the dimensions in the wrong order -- the gather dimensions +// are the major dimensions and the window dimensions are the minor dimensions. +// Fix this up with a transpose. +static StatusOr PermuteGatherAndWindowDims( + HloInstruction* accumulator, ArraySlice output_window_dims, + int64 output_rank) { + std::vector permutation; + permutation.reserve(output_rank); + + int64 gather_idx_counter = 0; + int64 window_idx_counter = output_rank - output_window_dims.size(); + for (int64 i = 0; i < output_rank; i++) { + bool is_window_dim = c_binary_search(output_window_dims, i); + if (is_window_dim) { + permutation.push_back(window_idx_counter++); + } else { + permutation.push_back(gather_idx_counter++); + } + } + + return MakeTransposeHlo(accumulator, permutation); +} + +// High Level Algorithm +// +// We follow the following steps in sequence: +// +// 1. We canonicalize the gather_indices tensor such that it has rank +// 2 (i.e. is a matrix) where each row is an index vector into the +// operand. +// 2. We iterate over the set of indices in the canonicalized +// gather_indices tensor using a while loop, accumulating slices +// of the operand tensor into an accumulator using +// DynamicUpdateSlice. +// 3. The accumulator result from the while loop from (2) is then +// reshaped to split out all the individual gather dimensions and +// then transposed to give the final result. +// +// As an example, if we started with the following operation: +// +// HloModule TensorFlowGatherMultipleBatchDims +// +// ENTRY main { +// operand = s32[3,3] parameter(0) +// indices = s32[2,2] parameter(1) +// ROOT gather = s32[2,3,2] gather(operand, indices), +// output_window_dims={1}, +// elided_window_dims={1}, +// gather_dims_to_operand_dims={1}, +// index_vector_dim=2, +// window_bounds={3, 1} +// } +// +// We'd first reshape indices to s32[4,1], where each row is an index +// into operand. We'd then run a loop to slice out 4 tensors of shape +// [3,1] out of operand into an accumulator of shape [4,3,1]. We then +// reshape this result to [2,2,3] and finally transpose it to [2,3,2]. + +StatusOr GatherExpander::ExpandGather( + HloInstruction* gather_instr) { + CHECK(!ShapeUtil::HasZeroElements(gather_instr->shape())); + + HloComputation* computation = gather_instr->parent(); + HloInstruction* operand = gather_instr->mutable_operand(0); + HloInstruction* gather_indices = gather_instr->mutable_operand(1); + const Shape& gather_indices_shape = gather_indices->shape(); + const Shape& output_shape = gather_instr->shape(); + int64 output_rank = output_shape.dimensions_size(); + + const GatherDimensionNumbers& dim_numbers = + gather_instr->gather_dimension_numbers(); + + int64 gather_loop_trip_count = 1; + for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + gather_loop_trip_count *= gather_indices_shape.dimensions(i); + } + } + + if (!IsInt32(gather_loop_trip_count)) { + return Unimplemented( + "Gather operations with more than 2147483647 gather indices are not " + "supported. This error occurred for %s.", + gather_instr->ToString().c_str()); + } + + TF_ASSIGN_OR_RETURN(HloInstruction * canonical_gather_indices, + CanonicalizeGatherIndices( + gather_indices, dim_numbers.index_vector_dim())); + + CHECK_EQ(gather_loop_trip_count, + canonical_gather_indices->shape().dimensions(0)); + + TF_ASSIGN_OR_RETURN( + HloInstruction * accumulator_init, + CreateGatherLoopAccumulatorInitValue( + computation, output_shape.element_type(), + gather_instr->gather_window_bounds(), gather_loop_trip_count, + gather_instr->gather_dimension_numbers())); + + StatusOr> gather_loop_result_or_error = + WhileUtil::MakeCountedLoop( + computation, gather_loop_trip_count, + {operand, canonical_gather_indices, accumulator_init}, + [&](HloInstruction* indvar, + const std::vector& loop_state) { + return GatherLoopBody(*gather_instr, indvar, loop_state); + }); + + TF_ASSIGN_OR_RETURN(std::vector gather_loop_result, + gather_loop_result_or_error); + + HloInstruction* accumulator_result = gather_loop_result.back(); + + TF_ASSIGN_OR_RETURN( + HloInstruction * accumulator_with_output_gather_dims_decanonicalized, + AdjustGatherDimsInAccumulator(gather_indices->shape(), accumulator_result, + dim_numbers.index_vector_dim())); + + return PermuteGatherAndWindowDims( + accumulator_with_output_gather_dims_decanonicalized, + AsInt64Slice(dim_numbers.output_window_dims()), output_rank); +} + +StatusOr GatherExpander::Run(HloModule* module) { + auto is_nontrivial_gather = [](HloInstruction* inst) { + return inst->opcode() == HloOpcode::kGather && + // Avoid expanding gather ops that produce zero sized tensors, + // instead punt these to ZeroSizedHloElimination. + !ShapeUtil::HasZeroElements(inst->shape()); + }; + + std::vector gather_instrs; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + c_copy_if(computation->instructions(), std::back_inserter(gather_instrs), + is_nontrivial_gather); + } + + for (HloInstruction* inst : gather_instrs) { + TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandGather(inst)); + TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root)); + } + + return !gather_instrs.empty(); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h new file mode 100644 index 0000000000000000000000000000000000000000..c1fc8574da99fff223c7dbb570b4533f76905b9a --- /dev/null +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GATHER_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GATHER_EXPANDER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// This pass rewrites gather operations into (roughly) while loops of dynamic +// slices. This lets backends that don't support gather directly to +// nevertheless have a minimum level of support. +class GatherExpander : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { return "gather_expander"; } + StatusOr Run(HloModule* module) override; + + private: + StatusOr ExpandGather(HloInstruction* gather_instr); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GATHER_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..020ffcd106862cb2641a9f3bceb70acdd969a458 --- /dev/null +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gather_expander.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { +TEST(GatherExpanderTest, ErrorStatusOnTooManyIndices) { + const string hlo_text = R"( +HloModule TensorFlowGatherMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2147483647,5] parameter(1) + ROOT gather = s32[2147483647,3,5] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + Status status = GatherExpander{}.Run(module.get()).status(); + EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); + + ASSERT_THAT( + status.error_message(), + ::testing::HasSubstr("Gather operations with more than 2147483647 gather " + "indices are not supported.")); +} + +TEST(GatherExpanderTest, AvoidDegenerateDims) { + const string hlo_text = R"( +HloModule TensorFlowGatherV2 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[3,2] gather(operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3, 1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get())); + ASSERT_TRUE(changed); + + HloInstruction* while_instr = nullptr; + for (auto* instr : module->entry_computation()->instructions()) { + if (instr->opcode() == HloOpcode::kWhile) { + ASSERT_EQ(while_instr, nullptr) + << "Expected exactly one while instruction in the entry computation " + "after gather expansion"; + while_instr = instr; + } + } + + ASSERT_NE(while_instr, nullptr) + << "Expected exactly one while instruction in the entry computation " + "after gather expansion"; + + // We want to avoid create while loop with shapes that have degenerate + // dimensions for TF gather. In this case we expect the loop state to be of + // the shape (sNN[], s32[3,3]{1,0}, s32[2]{0}, s32[2,3]{1,0}). The leading + // sNN is an implementation detail from WhileUtil::MakeCountedLoop so we don't + // check it here (though in theory the form of the while loop state is itself + // an implementation detail from WhileUtil::MakeCountedLoop). + + const Shape& while_shape = while_instr->shape(); + ASSERT_TRUE(ShapeUtil::IsTuple(while_shape)); + ASSERT_EQ(ShapeUtil::TupleElementCount(while_shape), 4); + + EXPECT_TRUE(ShapeUtil::SameDimensions( + ShapeUtil::MakeShape(S32, {3, 3}), + ShapeUtil::GetTupleElementShape(while_shape, 1))); + + EXPECT_TRUE(ShapeUtil::SameDimensions( + ShapeUtil::MakeShape(S32, {2}), + ShapeUtil::GetTupleElementShape(while_shape, 2))); + + EXPECT_TRUE(ShapeUtil::SameDimensions( + ShapeUtil::MakeShape(S32, {2, 3}), + ShapeUtil::GetTupleElementShape(while_shape, 3))); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 78dc0ad4fcd167c93f19d0c2b18ea72d666897ef..5ee67ccb4ae147683c7b41941670c6fc413a0d09 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -32,29 +32,20 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id, size_t pointer_size) - : platform_id_(platform_id), pointer_size_(pointer_size) { - // We currently only support kHostPlatformId for CPU, kCudaPlatformId for - // GPU and kInterpreterPlatformId for Interpreter. Before supporting other - // platforms, we need to test this transfer manager on them. - CHECK(platform_id_ == se::host::kHostPlatformId || - platform_id_ == se::interpreter::kInterpreterPlatformId || - platform_id_ == se::cuda::kCudaPlatformId); -} + : platform_id_(platform_id), pointer_size_(pointer_size) {} se::Platform::Id GenericTransferManager::PlatformId() const { return platform_id_; } Status GenericTransferManager::WriteSingleTupleIndexTable( - perftools::gputools::StreamExecutor* executor, + se::StreamExecutor* executor, tensorflow::gtl::ArraySlice elements, - const Shape& shape, perftools::gputools::DeviceMemoryBase* region) { + const Shape& shape, se::DeviceMemoryBase* region) { TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape)); std::vector element_pointers; @@ -98,7 +89,7 @@ GenericTransferManager::TransferLiteralFromDevice( } Status GenericTransferManager::TransferLiteralToDevice( - se::StreamExecutor* executor, const Literal& literal, + se::StreamExecutor* executor, const LiteralSlice& literal, const ShapedBuffer& device_buffer) { const Shape& shape = literal.shape(); VLOG(2) << "transferring literal shape to device: " @@ -124,7 +115,7 @@ Status GenericTransferManager::TransferLiteralToDevice( TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == device_memory.size()); // Element is array-shaped: transfer array data to device buffer. - const auto subliteral = LiteralView::Create(literal, index); + const auto subliteral = LiteralSlice(literal, index); std::unique_ptr relayed_out_literal; const void* source; if (LayoutUtil::Equal(device_subshape.layout(), @@ -146,25 +137,24 @@ Status GenericTransferManager::TransferLiteralToDevice( } Status GenericTransferManager::TransferLiteralToInfeed( - se::StreamExecutor* executor, const Literal& literal) { + se::StreamExecutor* executor, const LiteralSlice& literal) { return Unimplemented("Generic transfer to Infeed"); } Status GenericTransferManager::TransferBufferToInfeed( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source) { + se::StreamExecutor* executor, int64 size, const void* source) { return Unimplemented("Generic transfer to Infeed"); } Status GenericTransferManager::TransferLiteralFromOutfeed( - perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, + se::StreamExecutor* executor, const Shape& literal_shape, Literal* literal) { return Unimplemented( "Outfeed is not supported on this platform (b/30467474)"); } Status GenericTransferManager::ResetDevices( - tensorflow::gtl::ArraySlice + tensorflow::gtl::ArraySlice /*executors*/) { return Unimplemented( "Device reset is not yet supported on this platform (b/30481585)"); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 63a7c820cf4e5fbbdf870086a4fb5316ac50d10b..3da9570ef7eebcdf618439f628fb4d5589993e4f 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -36,46 +36,41 @@ namespace xla { // infeed. class GenericTransferManager : public TransferManager { public: - GenericTransferManager(perftools::gputools::Platform::Id platform_id, - size_t pointer_size); + GenericTransferManager(se::Platform::Id platform_id, size_t pointer_size); ~GenericTransferManager() override {} - perftools::gputools::Platform::Id PlatformId() const override; + se::Platform::Id PlatformId() const override; StatusOr> TransferLiteralFromDevice( - perftools::gputools::StreamExecutor* executor, - const ShapedBuffer& device_buffer) override; + se::StreamExecutor* executor, const ShapedBuffer& device_buffer) override; - Status TransferLiteralToDevice(perftools::gputools::StreamExecutor* executor, - const Literal& literal, + Status TransferLiteralToDevice(se::StreamExecutor* executor, + const LiteralSlice& literal, const ShapedBuffer& device_buffer) override; - Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, - const Literal& literal) override; - Status TransferLiteralFromOutfeed( - perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + Status TransferLiteralToInfeed(se::StreamExecutor* executor, + const LiteralSlice& literal) override; + Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, + const Shape& literal_shape, + Literal* literal) override; Status ResetDevices( - tensorflow::gtl::ArraySlice - executors) override; + tensorflow::gtl::ArraySlice executors) override; int64 GetByteSizeRequirement(const Shape& shape) const override; protected: - Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, - int64 size, const void* source) override; + Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, + const void* source) override; Status WriteSingleTupleIndexTable( - perftools::gputools::StreamExecutor* executor, - tensorflow::gtl::ArraySlice - elements, - const Shape& shape, - perftools::gputools::DeviceMemoryBase* region) override; + se::StreamExecutor* executor, + tensorflow::gtl::ArraySlice elements, + const Shape& shape, se::DeviceMemoryBase* region) override; private: // The platform this transfer manager targets. - const perftools::gputools::Platform::Id platform_id_; + const se::Platform::Id platform_id_; // The size in bytes of pointers on this platform. const size_t pointer_size_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 9da4fb97fa27a238fead74985cb481a9be1f4a65..5e02631a5856a45762027d246144d6f6dd913c26 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1,6 +1,8 @@ # Description: # GPU-specific components in XLA service implementation. +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") + licenses(["notice"]) # Apache 2.0 package(default_visibility = [":friends"]) @@ -23,6 +25,11 @@ filegroup( load("//tensorflow:tensorflow.bzl", "tf_cc_test") +xla_proto_library( + name = "backend_configs", + srcs = ["backend_configs.proto"], +) + cc_library( name = "gpu_constants", srcs = ["gpu_constants.cc"], @@ -133,6 +140,7 @@ cc_library( "ir_emitter_unnested.h", ], deps = [ + ":backend_configs", ":cudnn_convolution_runner", ":elemental_ir_emitter", ":gpu_constants", @@ -156,6 +164,7 @@ cc_library( "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", @@ -241,6 +250,7 @@ cc_library( "gpu_executable.cc", "infeed_thunk.cc", "kernel_thunk.cc", + "memset_thunk.cc", "sequential_thunk.cc", "thunk_schedule.cc", "tuple_thunk.cc", @@ -257,6 +267,7 @@ cc_library( "gpu_executable.h", "infeed_thunk.h", "kernel_thunk.h", + "memset_thunk.h", "sequential_thunk.h", "thunk.h", "thunk_schedule.h", @@ -264,6 +275,7 @@ cc_library( "while_thunk.h", ], deps = [ + ":backend_configs", ":buffer_allocations", ":cudnn_convolution_runner", ":infeed_manager", @@ -273,6 +285,7 @@ cc_library( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -288,11 +301,13 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform/default/build_config:cublas_plugin", "//tensorflow/core/platform/default/build_config:cudnn_plugin", "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep + "//tensorflow/stream_executor", ], ) @@ -317,6 +332,7 @@ cc_library( srcs = ["cudnn_convolution_algorithm_picker.cc"], hdrs = ["cudnn_convolution_algorithm_picker.h"], deps = [ + ":backend_configs", ":cudnn_convolution_runner", ":gpu_executable", ":ir_emission_utils", @@ -333,6 +349,7 @@ cc_library( srcs = ["cudnn_convolution_runner.cc"], hdrs = ["cudnn_convolution_runner.h"], deps = [ + ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -384,8 +401,10 @@ cc_library( deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", + "//tensorflow/compiler/xla/service:pattern_matcher", ], ) @@ -394,12 +413,44 @@ tf_cc_test( srcs = ["instruction_fusion_test.cc"], deps = [ ":instruction_fusion", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) +cc_library( + name = "multi_output_fusion", + srcs = ["multi_output_fusion.cc"], + hdrs = ["multi_output_fusion.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:multi_output_fusion", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "multi_output_fusion_test", + srcs = ["multi_output_fusion_test.cc"], + deps = [ + ":multi_output_fusion", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + cc_library( name = "gpu_copy_insertion", srcs = ["gpu_copy_insertion.cc"], @@ -437,6 +488,8 @@ tf_cc_test( ":fusion_merger", ":instruction_fusion", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -452,6 +505,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:shape_inference", ], @@ -497,6 +551,7 @@ cc_library( ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", + ":multi_output_fusion", ":pad_insertion", ":partition_assignment", ":stream_assignment", @@ -510,9 +565,11 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", + "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", @@ -529,6 +586,8 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_constant_sinking", + "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter", @@ -574,14 +633,18 @@ cc_library( srcs = ["gpu_layout_assignment.cc"], hdrs = ["gpu_layout_assignment.h"], deps = [ + ":gpu_options", ":ir_emission_utils", + ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -610,6 +673,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:buffer_value", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", @@ -677,6 +741,27 @@ cc_library( ], ) +cc_library( + name = "gpu_options", + srcs = ["gpu_options.cc"], + hdrs = ["gpu_options.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "stream_executor_util", + srcs = ["stream_executor_util.cc"], + hdrs = ["stream_executor_util.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + tf_cc_test( name = "gpu_hlo_support_checker_test", srcs = ["gpu_hlo_support_checker_test.cc"], @@ -690,17 +775,3 @@ tf_cc_test( "//tensorflow/core:test", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto new file mode 100644 index 0000000000000000000000000000000000000000..640c6392b8b820c708b853c2a3cea4d4116e85a8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package xla.gpu; + +// Backend configs for XLA:GPU. +// +// These are metadata that the GPU backend attaches to HloInstrucitons and later +// uses during e.g. codegen. +// +// Remember that proto3 doesn't give clients a way to tell the difference +// between a field not being present and a field having the default value. +// Choose your defaults carefully. +// +// No guarantee is made about the stability of these protos. +// +// See HloInstruction::backend_config() for more info. + +// Backend config for a convolution that runs through cudnn. +message CudnnConvBackendConfig { + // Opaque algorithm number of cudnn algorithm chosen for this conv. + int64 algorithm = 1; + + // Whether we may use tensor cores when running this conv. Even if this is + // true, cudnn may choose not to use tensor cores, e.g. because the GPU or + // selected algorithm doesn't support it. + bool tensor_ops_enabled = 2; +} diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index 2029c303d47e9a62135b003c3bd9be6f8b3438d4..ab5149dcdb09290cd0c0b2233029d0988a95f036 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -28,8 +28,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { @@ -39,11 +37,11 @@ void BufferAllocations::Builder::RegisterBuffer(BufferAllocation::Index index, } StatusOr> BufferAllocations::Builder::Build( - const BufferAssignment& buffer_assignment, int device_ordinal, + const BufferAssignment* buffer_assignment, int device_ordinal, DeviceMemoryAllocator* memory_allocator) { - const int64 num_buffers = buffer_assignment.Allocations().size(); - auto buffer_allocations = WrapUnique( - new BufferAllocations(num_buffers, device_ordinal, memory_allocator)); + const int64 num_buffers = buffer_assignment->Allocations().size(); + auto buffer_allocations = WrapUnique(new BufferAllocations( + num_buffers, device_ordinal, memory_allocator, buffer_assignment)); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { // If buffer #i's address is already registered (e.g. external arguments or @@ -64,28 +62,28 @@ StatusOr> BufferAllocations::Builder::Build( // Allocate each allocation that might escape, or is the temp buffer. bool seen_temp_buffer = false; - const BufferAllocation& allocation = buffer_assignment.GetAllocation(i); + const BufferAllocation& allocation = buffer_assignment->GetAllocation(i); if (allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()) { const int64 buffer_size = allocation.size(); se::DeviceMemoryBase buffer_address; if (buffer_size > 0) { - TF_ASSIGN_OR_RETURN(buffer_address, memory_allocator->Allocate( - device_ordinal, buffer_size)); - if (buffer_address == nullptr) { - return ResourceExhausted( - "Out of memory when allocating %s for buffer %lld.", - tensorflow::strings::HumanReadableNumBytes(buffer_size).c_str(), - i); - } - if (reinterpret_cast(buffer_address.opaque()) % + OwningDeviceMemory buffer; + TF_ASSIGN_OR_RETURN( + buffer, memory_allocator->Allocate(device_ordinal, buffer_size)); + if (reinterpret_cast(buffer.opaque()) % kCudaMallocAlignBytes != 0) { return InternalError( "Address returned by memory_allocator->Allocate must be a " "multiple of %llx, but was %p", - kCudaMallocAlignBytes, buffer_address.opaque()); + kCudaMallocAlignBytes, buffer.opaque()); } + // We do manual memory management within BufferAllocations. Be sure not + // to do a TF_RETURN_IF_ERROR between this line and the + // buffer_allocations->SetBuffer(buffer_address) call below! + buffer_address = buffer.Forget(); } + buffer_allocations->SetBuffer(i, buffer_address); if (allocation.IsPreallocatedTempBuffer()) { if (seen_temp_buffer) { @@ -105,28 +103,42 @@ StatusOr> BufferAllocations::Builder::Build( << "B)"; } } - return std::move(buffer_allocations); } -tensorflow::Status BufferAllocations::TearDown( - const std::set& live_addresses, - const BufferAssignment& buffer_assignment) { - // Deallocate temporary buffers. - const int64 num_buffers = buffer_assignment.Allocations().size(); +BufferAllocations::~BufferAllocations() { + if (!torn_down_) { + // Presumably if we're executing this branch, the caller is in an error + // state, otherwise it would have explicitly called TearDown so it could + // save some set of live addresses. So ignoring any errors in TearDown is + // sensible. + TearDown(/*live_addresses=*/{}).IgnoreError(); + } +} + +Status BufferAllocations::TearDown( + const std::set& live_addresses) { + // Deallocate temporary buffers, taking care to try to deallocate all of them + // even if one of the deallocations fails. + Status status; + const int64 num_buffers = buffer_assignment_->Allocations().size(); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { - const BufferAllocation& allocation = buffer_assignment.GetAllocation(i); + const BufferAllocation& allocation = buffer_assignment_->GetAllocation(i); se::DeviceMemoryBase buffer_address = GetDeviceAddress(allocation.index()); // Deallocate buffers marked "maybe_live_out" but aren't actually live out, // and temp buffers. if ((allocation.maybe_live_out() && !live_addresses.count(buffer_address)) || allocation.IsPreallocatedTempBuffer()) { - TF_RETURN_IF_ERROR( - memory_allocator_->Deallocate(device_ordinal_, &buffer_address)); + auto dealloc_result = + memory_allocator_->Deallocate(device_ordinal_, buffer_address); + if (!dealloc_result.ok() && status.ok()) { + status = dealloc_result; + } } } - return tensorflow::Status::OK(); + torn_down_ = true; + return status; } se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index ea7f0eb3745f2e0e0bfd88c3dca79d6ad25884ed..636623502597b3a66523938ba430e9d5a82f796c 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -41,21 +41,22 @@ class BufferAllocations { // user-specified result buffers) to the given buffer index. The builder // will skip allocating buffers for registered buffer indices. void RegisterBuffer(BufferAllocation::Index index, - perftools::gputools::DeviceMemoryBase address); + se::DeviceMemoryBase address); // Builds a BufferAllocations object from the given buffer assignment. // `memory_allocator` is what this function uses to allocate device memory. // `device_ordinal` is the number of the device this function allocates // memory on. StatusOr> Build( - const BufferAssignment& buffer_assignment, int device_ordinal, + const BufferAssignment* buffer_assignment, int device_ordinal, DeviceMemoryAllocator* memory_allocator); private: - std::map - registered_buffers_; + std::map registered_buffers_; }; + ~BufferAllocations(); + BufferAllocations(const BufferAllocations&) = delete; BufferAllocations& operator=(const BufferAllocations&) = delete; @@ -65,46 +66,45 @@ class BufferAllocations { // Returns the device address of buffer `buffer_index`. `buffer_index` must be // a valid index, i.e., in [0, buffer_count). This function returns null if // `buffer_index` is not assigned to a buffer address. - perftools::gputools::DeviceMemoryBase GetDeviceAddress( + se::DeviceMemoryBase GetDeviceAddress( BufferAllocation::Index buffer_index) const; // Same as above, but also adjusts the returned address for the offset and // size contained in the given slice. - perftools::gputools::DeviceMemoryBase GetDeviceAddress( + se::DeviceMemoryBase GetDeviceAddress( const BufferAllocation::Slice& buffer_slice) const; - perftools::gputools::DeviceMemoryBase GetTempBufferBase() const { - return temp_buffer_base_; - } + se::DeviceMemoryBase GetTempBufferBase() const { return temp_buffer_base_; } // Tears down all buffers allocated by this object that are not in // `live_addresses`. - tensorflow::Status TearDown( - const std::set& live_addresses, - const BufferAssignment& buffer_assignment); + Status TearDown(const std::set& live_addresses); private: BufferAllocations(BufferAllocation::Index buffer_count, int device_ordinal, - DeviceMemoryAllocator* memory_allocator) + DeviceMemoryAllocator* memory_allocator, + const BufferAssignment* buffer_assignment) : buffers_(buffer_count), device_ordinal_(device_ordinal), - memory_allocator_(memory_allocator) {} + memory_allocator_(memory_allocator), + buffer_assignment_(buffer_assignment) {} // Sets the device address of buffer `buffer_index`. void SetBuffer(BufferAllocation::Index buffer_index, - perftools::gputools::DeviceMemoryBase buffer); + se::DeviceMemoryBase buffer); // An array of device pointers that stores the address of each buffer // indexed by Index. Each element can point to a temporary buffer, an // input buffer, or nullptr if no buffer is needed for that Index. - std::vector buffers_; + std::vector buffers_; // The base address of the memory block that contains all temporary buffers. - perftools::gputools::DeviceMemoryBase temp_buffer_base_; + se::DeviceMemoryBase temp_buffer_base_; int device_ordinal_; - DeviceMemoryAllocator* memory_allocator_; + const BufferAssignment* buffer_assignment_; + bool torn_down_ = false; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index 790ca535b11ee47724ef6227de40726d940d6153..77a48965e031349b045a956fd3f28c58607328e5 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -35,18 +35,18 @@ ConditionalThunk::ConditionalThunk( true_thunk_(std::move(true_thunk_sequence), hlo), false_thunk_(std::move(false_thunk_sequence), hlo) {} -Status ConditionalThunk::Initialize(const GpuExecutable& executable) { - TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable)); - TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable)); +Status ConditionalThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable, executor)); + TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable, executor)); return Status::OK(); } Status ConditionalThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream) { // Copy the predicate value from device. bool predicate; - perftools::gputools::DeviceMemoryBase predicate_address = + se::DeviceMemoryBase predicate_address = buffer_allocations.GetDeviceAddress(predicate_buffer_index_); stream->ThenMemcpy(&predicate, predicate_address, sizeof(bool)); diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h index 7725c46a3b4b51af34a4dd977885353ff32c21f6..ee03865d174469285a9e98b8a30fea90d997df37 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -47,9 +47,10 @@ class ConditionalThunk : public Thunk { ConditionalThunk(const ConditionalThunk&) = delete; ConditionalThunk& operator=(const ConditionalThunk&) = delete; - Status Initialize(const GpuExecutable& executable) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + se::Stream* stream) override; private: BufferAllocation::Slice predicate_buffer_index_; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 461747b699b542ae0c8735aea34cc9e57c1fb387..f0881124128c9b043392ffc4fa3aee2cd5b754c7 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -25,17 +25,10 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { using se::dnn::AlgorithmDesc; -using se::dnn::BatchDescriptor; -using se::dnn::ConvolutionDescriptor; -using se::dnn::DataLayout; -using se::dnn::FilterDescriptor; -using se::dnn::FilterLayout; ConvolutionThunk::ConvolutionThunk( CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 900d9cb6243088b56a1825fb3ab8c06cf8d74726..6d845025b1aef2b0a5f147401b6db0598ba94d6d 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -66,23 +66,21 @@ class ConvolutionThunk : public Thunk { // Does the convolution for the thunk on "stream". Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + se::Stream* stream) override; private: class ScratchAllocator; - Status Convolve( - const perftools::gputools::dnn::BatchDescriptor& input_descriptor, - perftools::gputools::DeviceMemory input_data, - const perftools::gputools::dnn::FilterDescriptor& filter_descriptor, - perftools::gputools::DeviceMemory filter_data, - const perftools::gputools::dnn::BatchDescriptor& output_descriptor, - perftools::gputools::DeviceMemory output_data, - const perftools::gputools::dnn::ConvolutionDescriptor& - convolution_descriptor, - const perftools::gputools::dnn::AlgorithmConfig& algorithm_config, - perftools::gputools::Stream* stream, ScratchAllocator* scratch_allocator, - perftools::gputools::dnn::ProfileResult* profile_result); + Status Convolve(const se::dnn::BatchDescriptor& input_descriptor, + se::DeviceMemory input_data, + const se::dnn::FilterDescriptor& filter_descriptor, + se::DeviceMemory filter_data, + const se::dnn::BatchDescriptor& output_descriptor, + se::DeviceMemory output_data, + const se::dnn::ConvolutionDescriptor& convolution_descriptor, + const se::dnn::AlgorithmConfig& algorithm_config, + se::Stream* stream, ScratchAllocator* scratch_allocator, + se::dnn::ProfileResult* profile_result); const CudnnConvKind convolution_kind_; diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc index f4498663b1c039b3175376baf8f27c4ecec678ec..ee38c0318a878c7bcdc02afdcd146bfb4498d9a2 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc @@ -29,13 +29,12 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk( destination_buffer_(destination_buffer), mem_size_(mem_size) {} -tensorflow::Status HostToDeviceCopyThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { - perftools::gputools::DeviceMemoryBase destination_data = +Status HostToDeviceCopyThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); stream->ThenMemcpy(&destination_data, source_address_, mem_size_); - return tensorflow::Status::OK(); + return Status::OK(); } DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( @@ -47,15 +46,14 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( destination_buffer_(destination_buffer), mem_size_(mem_size) {} -tensorflow::Status DeviceToDeviceCopyThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { - perftools::gputools::DeviceMemoryBase destination_data = +Status DeviceToDeviceCopyThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); - perftools::gputools::DeviceMemoryBase source_data = + se::DeviceMemoryBase source_data = buffer_allocations.GetDeviceAddress(source_buffer_); stream->ThenMemcpy(&destination_data, source_data, mem_size_); - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h index e2783fd255239d31edc89701ea208f33ebb8d3fb..8b128386f61636de9ac41e856a2b00c578e05735 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h @@ -39,9 +39,8 @@ class HostToDeviceCopyThunk : public Thunk { HostToDeviceCopyThunk(const HostToDeviceCopyThunk&) = delete; HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const void* source_address_; @@ -63,9 +62,8 @@ class DeviceToDeviceCopyThunk : public Thunk { DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete; DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const BufferAllocation::Slice source_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index 58d9c8caff31e878487fbef01afce566e6187fd9..68099fd63847ef9993f9bc7ac0e28b2939631b35 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -28,7 +28,6 @@ limitations under the License. namespace xla { namespace gpu { -namespace se = ::perftools::gputools; namespace dnn = se::dnn; static std::pair> AllocateBytes( - se::Stream* stream, int64 byte_size) override; + StatusOr> AllocateBytes(se::Stream* stream, + int64 byte_size) override; private: const int device_ordinal_; DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; + std::vector allocated_buffers_; int64 total_allocated_bytes_ = 0; }; -ScratchAllocator::~ScratchAllocator() { - for (auto& allocated_buffer : allocated_buffers_) { - if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) - .ok()) { - // The program can still continue with failed deallocation. - LOG(ERROR) << "Failed to deallocate the allocated buffer: " - << allocated_buffer.opaque(); - } - } -} - -se::port::StatusOr> ScratchAllocator::AllocateBytes( +StatusOr> ScratchAllocator::AllocateBytes( se::Stream* stream, int64 byte_size) { CHECK_GE(byte_size, 0) << "byte_size must be positive."; if (byte_size > GetMemoryLimitInBytes(stream)) { @@ -76,29 +62,30 @@ se::port::StatusOr> ScratchAllocator::AllocateBytes( byte_size, GetMemoryLimitInBytes(stream))); } - auto status_or_memory = - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false); - if (!status_or_memory.ok()) { - return se::port::Status(se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Failed to allocate %lld bytes on device %d.", - byte_size, device_ordinal_)); - } - se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); - allocated_buffers_.push_back(allocated_buffer); + TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer, + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false)); total_allocated_bytes_ += byte_size; - return se::DeviceMemory(allocated_buffer); + + se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase(); + allocated_buffers_.push_back(std::move(allocated_buffer)); + return se::DeviceMemory(buffer_addr); } // Determines whether we can safely perform a winograd non-fused convolution for // the given input and output shapes. This works around b/68264959, an integer // overflow in cuDNNv5 and cuDNNv6. -// -// TODO(jlebar): We shouldn't need this check for cuDNNv7. -bool ShouldIncludeWinogradNonfusedAlgo( - const Shape& input_shape, const Shape& output_shape, - const ConvolutionDimensionNumbers& dnums) { +bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape, + const Shape& output_shape, + const ConvolutionDimensionNumbers& dnums, + se::StreamExecutor* stream_exec) { + // Skip this check for cudnn7 and newer. + auto version = + stream_exec->AsDnn()->GetVersion(); + if (version.ok() && version.ValueOrDie().major_version() >= 7) { + return true; + } + int64 batch = input_shape.dimensions(dnums.input_batch_dimension()); int64 in_depths = input_shape.dimensions(dnums.input_feature_dimension()); int64 in_rows = input_shape.dimensions(dnums.input_spatial_dimensions(0)); @@ -118,20 +105,20 @@ bool ShouldIncludeWinogradNonfusedAlgo( std::vector GetAlgorithms(CudnnConvKind kind, bool with_winograd_nonfused, - se::StreamExecutor* stream_exec_) { + se::StreamExecutor* stream_exec) { std::vector algorithms; switch (kind) { case CudnnConvKind::kBackwardFilter: - CHECK(stream_exec_->GetConvolveBackwardFilterAlgorithms( + CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms( with_winograd_nonfused, &algorithms)); break; case CudnnConvKind::kBackwardInput: - CHECK(stream_exec_->GetConvolveBackwardDataAlgorithms( + CHECK(stream_exec->GetConvolveBackwardDataAlgorithms( with_winograd_nonfused, &algorithms)); break; case CudnnConvKind::kForward: - CHECK(stream_exec_->GetConvolveAlgorithms(with_winograd_nonfused, - &algorithms)); + CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused, + &algorithms)); break; } @@ -193,24 +180,44 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // We don't put any data in these buffers, because (in theory, anyway) the // speed of a conv isn't affected by the data being convolved. ScratchAllocator input_output_allocator(device_ordinal, allocator); - se::port::StatusOr input_buf = + StatusOr maybe_input_buf = input_output_allocator.AllocateBytes(&stream, ShapeUtil::ByteSizeOf(input_shape)); - se::port::StatusOr filter_buf = + StatusOr maybe_filter_buf = input_output_allocator.AllocateBytes(&stream, ShapeUtil::ByteSizeOf(filter_shape)); - se::port::StatusOr output_buf = + StatusOr maybe_output_buf = input_output_allocator.AllocateBytes(&stream, ShapeUtil::ByteSizeOf(output_shape)); - if (!input_buf.ok() || !filter_buf.ok() || !output_buf.ok()) { + if (!maybe_input_buf.ok() || !maybe_filter_buf.ok() || + !maybe_output_buf.ok()) { LOG(WARNING) << "Couldn't allocate space for input/filter/output of convolution " << instr->ToString() << ". Falling back to default algorithm."; return nullopt; } - const bool use_winograd_nonfused = - ShouldIncludeWinogradNonfusedAlgo(input_shape, output_shape, dnums); + DeviceMemoryBase input_buf = maybe_input_buf.ValueOrDie(); + DeviceMemoryBase filter_buf = maybe_filter_buf.ValueOrDie(); + DeviceMemoryBase output_buf = maybe_output_buf.ValueOrDie(); + + // Although we don't have evidence this matters, zero out the buffers before + // autotuning. It's conceivable that using uninitialized memory as the inputs + // might affect performance if e.g. the inputs contain denormals, and this is + // easy enough. + if (!stream.ThenMemZero(&input_buf, input_buf.size()) + .ThenMemZero(&filter_buf, filter_buf.size()) + .ThenMemZero(&output_buf, output_buf.size()) + .BlockHostUntilDone() + .ok()) { + LOG(WARNING) + << "Couldn't zero out input/filter/output buffer for convolution " + << instr->ToString() << ". Falling back to default algorithm."; + return nullopt; + } + + const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( + input_shape, output_shape, dnums, stream_exec_); se::dnn::ProfileResult best_result; int64 best_result_bytes_used = 0; @@ -221,12 +228,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - bool launch_ok = RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - input_buf.ValueOrDie(), filter_buf.ValueOrDie(), - output_buf.ValueOrDie(), &scratch_allocator, window, - dnums, AlgorithmConfig(alg), &stream, &profile_result) - .ok(); + bool launch_ok = + RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + input_buf, filter_buf, output_buf, + &scratch_allocator, window, dnums, + AlgorithmConfig(alg), &stream, &profile_result) + .ok(); if (launch_ok && profile_result.is_valid()) { int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); @@ -310,21 +317,20 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( Shape new_call_shape = ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0), ShapeUtil::MakeShape(U8, {scratch_bytes})}); - HloInstruction* algorithm_hlo = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(algorithm))); - HloInstruction* tensor_ops_enabled_hlo = - computation->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(tensor_ops_enabled))); + + CudnnConvBackendConfig backend_config; + backend_config.set_algorithm(algorithm); + backend_config.set_tensor_ops_enabled(tensor_ops_enabled); HloInstruction* new_call = computation->AddInstruction(HloInstruction::CreateCustomCall( new_call_shape, - {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo, - tensor_ops_enabled_hlo}, + {instr->mutable_operand(0), instr->mutable_operand(1)}, instr->custom_call_target())); new_call->set_window(instr->window()); new_call->set_convolution_dimension_numbers( instr->convolution_dimension_numbers()); + TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); // Repackage new_call so it has the same shape as the original call, namely // (conv_result, u8[0]). diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index 516210ec2e500cf03774d27408300ac3346e7b4f..bc5d1ce94afd2075a006899f0f6bcf64352e5e99 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -33,9 +33,8 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { // If the `allocator` parameter is not null, we will use it to allocate temp // memory while timing the various convolution algorithms. If it's null, // we'll use the default allocator on the StreamExecutor. - CudnnConvolutionAlgorithmPicker( - perftools::gputools::StreamExecutor* stream_exec, - DeviceMemoryAllocator* allocator) + CudnnConvolutionAlgorithmPicker(se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* allocator) : stream_exec_(stream_exec), allocator_(allocator) {} tensorflow::StringPiece name() const override { @@ -52,7 +51,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { const Shape& output_shape, const Window& window, const ConvolutionDimensionNumbers& dnums, HloInstruction* instr); - perftools::gputools::StreamExecutor* stream_exec_; // never null + se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null }; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index e4ae839e1dd4cb3a744a3f6a3329cabdaeb3f38d..0645fbb3ad39f1f1649caf45a6068b5a196c30b9 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -22,8 +24,6 @@ namespace xla { namespace gpu { namespace { -namespace se = ::perftools::gputools; - using se::DeviceMemory; using se::DeviceMemoryBase; using se::Stream; @@ -115,8 +115,17 @@ Status RunCudnnConvolution( // cuDNN's convolution APIs support the BDYX layout for activations/output and // the OIYX layout for weights. + DataLayout input_dl; + FilterLayout filter_dl; + DataLayout output_dl; + + TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl), + XlaConvLayoutsToStreamExecutorLayouts( + dnums, input_shape.layout(), filter_shape.layout(), + output_shape.layout())); + BatchDescriptor input_descriptor(effective_num_dimensions); - input_descriptor.set_layout(DataLayout::kBatchDepthYX) + input_descriptor.set_layout(input_dl) .set_feature_map_count( input_shape.dimensions(dnums.input_feature_dimension())) .set_count(input_shape.dimensions(dnums.input_batch_dimension())); @@ -128,7 +137,7 @@ Status RunCudnnConvolution( } FilterDescriptor filter_descriptor(effective_num_dimensions); - filter_descriptor.set_layout(FilterLayout::kOutputInputYX) + filter_descriptor.set_layout(filter_dl) .set_input_feature_map_count( filter_shape.dimensions(dnums.kernel_input_feature_dimension())) .set_output_feature_map_count( @@ -151,7 +160,7 @@ Status RunCudnnConvolution( } BatchDescriptor output_descriptor(effective_num_dimensions); - output_descriptor.set_layout(DataLayout::kBatchDepthYX) + output_descriptor.set_layout(output_dl) .set_feature_map_count( output_shape.dimensions(dnums.output_feature_dimension())) .set_count(output_shape.dimensions(dnums.output_batch_dimension())); @@ -215,14 +224,12 @@ string CudnnConvKindToString(CudnnConvKind kind) { Status RunCudnnConvolution( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf, - perftools::gputools::DeviceMemoryBase filter_buf, - perftools::gputools::DeviceMemoryBase output_buf, - perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window, + const Shape& output_shape, se::DeviceMemoryBase input_buf, + se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, + se::DeviceMemoryBase scratch_buf, const Window& window, const ConvolutionDimensionNumbers& dnums, - perftools::gputools::dnn::AlgorithmConfig algorithm, - perftools::gputools::Stream* stream, - perftools::gputools::dnn::ProfileResult* profile_result) { + se::dnn::AlgorithmConfig algorithm, se::Stream* stream, + se::dnn::ProfileResult* profile_result) { ScratchBufAllocator scratch_allocator(scratch_buf); return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, input_buf, filter_buf, output_buf, @@ -232,14 +239,12 @@ Status RunCudnnConvolution( Status RunCudnnConvolution( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf, - perftools::gputools::DeviceMemoryBase filter_buf, - perftools::gputools::DeviceMemoryBase output_buf, - perftools::gputools::ScratchAllocator* scratch_allocator, - const Window& window, const ConvolutionDimensionNumbers& dnums, - perftools::gputools::dnn::AlgorithmConfig algorithm, - perftools::gputools::Stream* stream, - perftools::gputools::dnn::ProfileResult* profile_result) { + const Shape& output_shape, se::DeviceMemoryBase input_buf, + se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, + se::ScratchAllocator* scratch_allocator, const Window& window, + const ConvolutionDimensionNumbers& dnums, + se::dnn::AlgorithmConfig algorithm, se::Stream* stream, + se::dnn::ProfileResult* profile_result) { PrimitiveType output_primitive_type = output_shape.element_type(); CHECK(output_primitive_type == F32 || output_primitive_type == F16) << ShapeUtil::HumanString(output_shape); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h index 3dbfa2730da359d3c7937140508017c4a7b02d6c..944e4ac686d45408b08ff1faa321510c1c8920ba 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -72,25 +72,21 @@ string CudnnConvKindToString(CudnnConvKind kind); // that size, if you like. Status RunCudnnConvolution( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf, - perftools::gputools::DeviceMemoryBase filter_buf, - perftools::gputools::DeviceMemoryBase output_buf, - perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window, + const Shape& output_shape, se::DeviceMemoryBase input_buf, + se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, + se::DeviceMemoryBase scratch_buf, const Window& window, const ConvolutionDimensionNumbers& dnums, - perftools::gputools::dnn::AlgorithmConfig algorithm, - perftools::gputools::Stream* stream, - perftools::gputools::dnn::ProfileResult* profile_result = nullptr); + se::dnn::AlgorithmConfig algorithm, se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); Status RunCudnnConvolution( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf, - perftools::gputools::DeviceMemoryBase filter_buf, - perftools::gputools::DeviceMemoryBase output_buf, - perftools::gputools::ScratchAllocator* scratch_allocator, - const Window& window, const ConvolutionDimensionNumbers& dnums, - perftools::gputools::dnn::AlgorithmConfig algorithm, - perftools::gputools::Stream* stream, - perftools::gputools::dnn::ProfileResult* profile_result = nullptr); + const Shape& output_shape, se::DeviceMemoryBase input_buf, + se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, + se::ScratchAllocator* scratch_allocator, const Window& window, + const ConvolutionDimensionNumbers& dnums, + se::dnn::AlgorithmConfig algorithm, se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 5af7a77ea858563fbea05af8efd54f96a74aee93..b812dd7d3fbb25f279e87f79b647e299f29073ea 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -53,11 +53,17 @@ using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; using tensorflow::strings::StrAppend; +namespace { // Returns whether operand is a floating-point literal with the given value. bool IsFPLiteralWithValue(const HloInstruction* operand, float value) { - return operand->opcode() == HloOpcode::kConstant && - operand->literal().IsAllFloat(value); + if (operand->opcode() == HloOpcode::kConstant && + operand->literal().IsAllFloat(value)) { + return true; + } + return operand->opcode() == HloOpcode::kBroadcast && + IsFPLiteralWithValue(operand->operand(0), value); } +} // namespace GpuElementalIrEmitter::GpuElementalIrEmitter( const HloModuleConfig& hlo_module_config, llvm::Module* module, @@ -227,6 +233,11 @@ StatusOr GpuElementalIrEmitter::EmitLog( return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type); } +StatusOr GpuElementalIrEmitter::EmitLog1p( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_log1p", {value}, {prim_type}, prim_type); +} + StatusOr GpuElementalIrEmitter::EmitSin( PrimitiveType prim_type, llvm::Value* value) const { return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type); @@ -242,6 +253,11 @@ StatusOr GpuElementalIrEmitter::EmitExp( return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type); } +StatusOr GpuElementalIrEmitter::EmitExpm1( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_expm1", {value}, {prim_type}, prim_type); +} + StatusOr GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 77d4569b1e8e398005e8f517ff086a77aedd382d..91f4d960aa62fff3e0699ece37a8c74d7dcf2f59 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -64,6 +64,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitLog(PrimitiveType prim_type, llvm::Value* value) const override; + StatusOr EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) const override; + StatusOr EmitSin(PrimitiveType prim_type, llvm::Value* value) const override; @@ -73,6 +76,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitExp(PrimitiveType prim_type, llvm::Value* value) const override; + StatusOr EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) const override; + StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const override; diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index 66931bdc8b1030b2b2e7731ce6327c1e908d4ee6..e14ee6918bf148861ecccac99355fccf7ae93103 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -24,8 +24,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { @@ -33,23 +31,12 @@ FftScratchAllocator::FftScratchAllocator( int device_ordinal, DeviceMemoryAllocator* memory_allocator) : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} -FftScratchAllocator::~FftScratchAllocator() { - for (auto& allocated_buffer : allocated_buffers_) { - if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) - .ok()) { - // The program can still continue with failed deallocation. - LOG(ERROR) << "Failed to deallocate the allocated buffer: " - << allocated_buffer.opaque(); - } - } -} - int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { constexpr int64 kFftScratchSize = 1LL << 32; // 4GB by default. return kFftScratchSize; } -se::port::StatusOr> FftScratchAllocator::AllocateBytes( +StatusOr> FftScratchAllocator::AllocateBytes( se::Stream* stream, int64 byte_size) { CHECK_GE(byte_size, 0) << "byte_size must be positive."; if (byte_size > GetMemoryLimitInBytes(stream)) { @@ -60,18 +47,14 @@ se::port::StatusOr> FftScratchAllocator::AllocateBytes( byte_size, GetMemoryLimitInBytes(stream))); } - auto status_or_memory = - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false); - if (!status_or_memory.ok()) { - return tensorflow::errors::ResourceExhausted( - "Failed to allocate %lld bytes on device %d.", byte_size, - device_ordinal_); - } - se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); - allocated_buffers_.push_back(allocated_buffer); + TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer, + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false)); total_allocated_bytes_ += byte_size; - return se::DeviceMemory(allocated_buffer); + + se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase(); + allocated_buffers_.push_back(std::move(allocated_buffer)); + return se::DeviceMemory(buffer_addr); } namespace { @@ -123,8 +106,8 @@ FftThunk::FftThunk(FftType fft_type, input_shape_(input_shape), output_shape_(output_shape) {} -tensorflow::Status FftThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { VLOG(3) << "FFT type: " << FftTypeToString(fft_type_); VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_); VLOG(3) << "Output shape: " @@ -224,7 +207,7 @@ tensorflow::Status FftThunk::ExecuteOnStream( LOG(FATAL) << "unsupported fft type"; } if (launch_ok) { - return tensorflow::Status::OK(); + return Status::OK(); } return InternalError("Unable to launch fft for thunk %p with type %s", this, FftTypeToString(fft_type_).c_str()); diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index 52fb8c376d7acea0f15aaa865c23fa2382717338..b0a22564f3a09bb67a3c01723f6e37c604656d45 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -34,24 +34,22 @@ namespace gpu { // released on destruction. // // Not thread-safe in that AllocateBytes, destructor are not locked. -class FftScratchAllocator : public perftools::gputools::ScratchAllocator { +class FftScratchAllocator : public se::ScratchAllocator { public: FftScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator); - ~FftScratchAllocator() override; - - int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override; + int64 GetMemoryLimitInBytes(se::Stream* stream) override; int64 TotalAllocatedBytes() { return total_allocated_bytes_; } - perftools::gputools::port::StatusOr> - AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override; + se::port::StatusOr> AllocateBytes( + se::Stream* stream, int64 byte_size) override; private: const int device_ordinal_; DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; + std::vector allocated_buffers_; int64 total_allocated_bytes_ = 0; }; @@ -73,17 +71,16 @@ class FftThunk : public Thunk { FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_ // Does the FFT for the thunk on "stream". - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: - const perftools::gputools::fft::Type fft_type_; + const se::fft::Type fft_type_; const std::vector fft_length_; float scale_factor_; - std::unique_ptr fft_plan_; + std::unique_ptr fft_plan_; const BufferAllocation::Slice input_buffer_; const BufferAllocation::Slice output_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index 283d21ca222a236a69e4bab1b6504665d4d1cdd3..b36539e0cb8d0a2f4758dd90acbdd8fc7181b8ca 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -30,20 +30,20 @@ ForThunk::ForThunk(const int64 loop_limit, body_thunk_sequence_( MakeUnique(std::move(*body_thunk_sequence), hlo)) {} -tensorflow::Status ForThunk::Initialize(const GpuExecutable& executable) { - TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable)); - return tensorflow::Status::OK(); +Status ForThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor)); + return Status::OK(); } -tensorflow::Status ForThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { +Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { for (int64 i = 0; i < loop_limit_; ++i) { // Invoke loop body thunk sequence. TF_RETURN_IF_ERROR( body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index 832494d17e9c4e1d9e92e18ef331df1cf3689024..41ddfe0ceb1d0516c1c64feca53212a925632209 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -36,10 +36,10 @@ class ForThunk : public Thunk { ForThunk(const ForThunk&) = delete; ForThunk& operator=(const ForThunk&) = delete; - tensorflow::Status Initialize(const GpuExecutable& executable) override; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const int64 loop_limit_; diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index c137fbc97e29e24edb3603c611a5c8f093bc62a6..3cd30b754c3242f00c704de1afab2282ed827b41 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -45,6 +45,7 @@ void MaybeResolveTupleElements(HloInstruction* instruction, // Returns the bytes read by fusion parameter 'param', by returning the byte // size of 'param' shape (or the cumulative byte sizes of all leaf tuple // elements if 'param' is tuple-shaped). +// // In the special case where all users of 'param' (or all users of a leaf // tuple element if 'param' is tuple-shaped) are Slice instructions, the size // of each slice instruction is accumulated instead, to give a more accurate @@ -63,11 +64,10 @@ double CalculateBytesReadByFusionParameter(HloInstruction* param) { // Slice for a more accurate estimate of bytes read. double bytes = 0.0; for (auto& instruction : instructions) { - if (std::all_of(instruction->users().begin(), instruction->users().end(), - [](const HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kSlice || - instruction->opcode() == HloOpcode::kDynamicSlice; - })) { + if (c_all_of(instruction->users(), [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSlice || + instruction->opcode() == HloOpcode::kDynamicSlice; + })) { // All users are slice: accumulate bytes of all user slice instructions. for (auto& user : instruction->users()) { bytes += ShapeUtil::ByteSizeOf(user->shape()); @@ -199,6 +199,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { ++total_visited_; // Skip 'fusion' instruction if there are no users into which we can merge. if (fusion->users().empty()) { + VLOG(3) << "Not merging " << fusion->name() << ": Has no users."; ++num_fail_no_users_; return Status::OK(); } @@ -208,24 +209,27 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // Input fusion instructions need to be rooted at a particular HLO (e.g. // kReduce), so they shouldn't be further fused either. if (fusion->fusion_kind() != HloInstruction::FusionKind::kLoop) { + VLOG(3) << "Not merging " << fusion->name() << ": Is not loop fusion."; ++num_fail_not_loop_fusion_; return Status::OK(); } // Skip multiple output fusion. It's not yet supported. if (fusion->IsMultiOutputFusion()) { + VLOG(3) << "Not merging " << fusion->name() << ": Is multi-output fusion."; ++num_fail_not_loop_fusion_; return Status::OK(); } // Skip 'fusion' instruction if we cannot merge into all of its users. // Merging into all users enables the removal of 'fusion' from the // computation. - if (!std::all_of(fusion->users().begin(), fusion->users().end(), - [](const HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kFusion && - instruction->fusion_kind() == - HloInstruction::FusionKind::kLoop; - })) { + if (!c_all_of(fusion->users(), [](const HloInstruction* user) { + return user->opcode() == HloOpcode::kFusion && + (user->fusion_kind() == HloInstruction::FusionKind::kLoop || + user->fusion_kind() == HloInstruction::FusionKind::kInput); + })) { + VLOG(3) << "Not merging " << fusion->name() + << ": Some of its users are not loop/input fusion kernels."; ++num_fail_merge_all_users_; return Status::OK(); } @@ -233,18 +237,17 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // Skip 'fusion' instruction if any of its fused instructions are expensive. // This is done to avoid the duplication of expensive instructions, which // would occur if 'fusion' were merged into multiple users. + // // If 'fusion' has just one user, then an earlier fusion pass chose not to // fuse this producer/comsumer pair (likely because of expensive instruction // re-use by the consumer), and so we honor that choice here as well. - if (!std::all_of(fusion->fused_instructions().begin(), - fusion->fused_instructions().end(), - [](const HloInstruction* instruction) { - if (instruction->opcode() != HloOpcode::kParameter && - GpuInstructionFusion::IsExpensive(*instruction)) { - return false; - } - return true; - })) { + if (c_any_of(fusion->fused_instructions(), + [](const HloInstruction* instruction) { + return instruction->opcode() != HloOpcode::kParameter && + GpuInstructionFusion::IsExpensive(*instruction); + })) { + VLOG(3) << "Not merging " << fusion->name() + << ": Contains one or more expensive instructions."; ++num_fail_expensive_fused_instruction_; return Status::OK(); } @@ -253,6 +256,8 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // exceeds the threshold value. if (CalculateFlopsToBytesRatio(fusion) > FusionMerger::GetThresholdFlopsToBytesRatio()) { + VLOG(3) << "Not merging " << fusion->name() + << ": flops-to-bytes ratio is not favorable."; ++num_fail_flops_to_byte_ratio_; return Status::OK(); } @@ -265,6 +270,9 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { const double merged_to_current_bytes_ratio = merged_bytes_transferred / std::max(1.0, current_bytes_transferred); if (merged_to_current_bytes_ratio > 1.10) { + VLOG(3) << "Not merging " << fusion->name() + << ": merged-to-current-bytes-ratio of " + << merged_to_current_bytes_ratio << " is not favorable."; ++num_fail_net_bytes_transferred_ratio_; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index deef5966b80d1b7f16e9982eed9ac5c7131e9d73..b22bb1d39ba177ef42673c7a3755694b43c15d14 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -23,250 +25,12 @@ namespace xla { namespace gpu { namespace { -class FusionMergerTest : public HloTestBase { - protected: - FusionMergerTest() : module_(CreateNewModule()) {} - - // Builds the following computation: - // - // Param - // / | \ - // / | \ - // OnesVec GTE(0) GTE(1) GTE(2) - // \ / \ / - // Add Add OnesVec - // \ / \ / - // \ Add Mul OnesVec - // \ | | / - // \ Mul Add - // \ | / - // \ | / - // Tuple - // - HloComputation* BuildComputation0() { - auto builder = HloComputation::Builder(TestName() + ".Computation0"); - // Create param instruction to access computation state. - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape3_, "param")); - - // Create GetTupleElement instructions for each tuple element. - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, param, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, param, 1)); - auto gte2 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, param, 2)); - - // Create const vector of ones to be used in element-wise computations. - auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); - - // Create simple fusable computation for tuple element 0 (wont get merged). - auto out0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, one_vec, gte0)); - - // Create fusable computation which is dependent on second and third tuple - // elements (will initially be fused on its own). - auto add1 = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte1, gte2)); - - // Create two sub-computations, both of which are users of 'add1'. - - // First sub-computation: out1 = Mul(Add(add1, one_vec), one_vec) - auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, add1, one_vec)); - auto out1 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, add2, one_vec)); - - // Second sub-computation: out2 = Add(Mul(add1, one_vec), one_vec) - auto mul0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, add1, one_vec)); - auto out2 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, mul0, one_vec)); - - // Create output Tuple. - builder.AddInstruction(HloInstruction::CreateTuple({out0, out1, out2})); - return module_->AddEntryComputation(builder.Build()); - } - - // Builds the following computation: - // - // Param - // / \ - // GTE(0) GTE(1) - // | | \ / - // | | Mul - // \ \ | - // \ Mul - // \ | - // OnesVec Mul OnesVec - // \ / \ / - // OnesVec Add Mul OnesVec - // \ | | / - // Mul Add - // \ / - // \ / - // Tuple - // - HloComputation* BuildComputation1() { - auto builder = HloComputation::Builder(TestName() + ".Computation1"); - Shape tuple_shape2_ = ShapeUtil::MakeTupleShape({data_shape_, data_shape_}); - // Create param instruction to access computation state. - auto state = builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape2_, "state")); - - // Create shared sub-computation (will initially be fused on its own). - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 2)); - // Calculate the flops we need to generate for this shared computation - // to exceed the threshold flops_to_bytes_ratio. - // Note that bytes transferred is multiplied by 3 because there are two - // operands and one output of size 'data_shape_'. - const int64 flops_needed = FusionMerger::GetThresholdFlopsToBytesRatio() * - ShapeUtil::ByteSizeOf(data_shape_) * 3; - const int64 vec_elements = ShapeUtil::ElementsIn(data_shape_); - const int64 iters = (flops_needed + vec_elements - 1) / vec_elements; - - auto mul0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, gte0, gte1)); - for (int i = 0; i < iters; ++i) { - mul0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, gte0, mul0)); - } - - // Create two sub-computations, both of which are users of 'mul0'. - auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); - - // First sub-computation: out0 = Mul(Add(mul0, one_vec), one_vec) - auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, mul0, one_vec)); - auto out0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, add0, one_vec)); - - // Second sub-computation: out1 = Add(Mul(mul0, one_vec), one_vec) - auto mul1 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, mul0, one_vec)); - auto out1 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, mul1, one_vec)); - - // Create output Tuple. - builder.AddInstruction(HloInstruction::CreateTuple({out0, out1})); - return module_->AddEntryComputation(builder.Build()); - } - - // Builds the following computation: - // - // Param - // / | | \ - // / | | \ - // / | | \ - // GTE(0) GTE(1) GTE(2) GTE(3) - // \ / / / - // Add / / - // \ / / - // Add / - // \ / - // \ / - // OnesVec Add OnesVec - // \ / \ / - // OnesVec Add Mul OnesVec - // \ | | / - // Mul Add - // \ / - // \ / - // Tuple - // - HloComputation* BuildComputation2(bool add_extra_input) { - auto builder = HloComputation::Builder(TestName() + ".Computation2"); - Shape state_shape = add_extra_input ? tuple_shape4_ : tuple_shape3_; - // Create param instruction to access computation state. - auto state = builder.AddInstruction( - HloInstruction::CreateParameter(0, state_shape, "state")); - - // Create GetTupleElement instructions for each tuple element. - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 1)); - auto gte2 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 2)); - - // Create shared fusable computation that reduces its operands. - auto reduce0 = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte0, gte1)); - auto reduce_out = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, reduce0, gte2)); - if (add_extra_input) { - auto gte3 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 3)); - reduce_out = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, reduce_out, gte3)); - } +namespace op = xla::testing::opcode_matchers; - // Create two fusable sub-computations which are dependent on shared - // computation 'reduce_out'. - auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); - - // First sub-computation: out0 = Mul(Add(reduce_out, one_vec), one_vec) - auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, reduce_out, one_vec)); - auto out0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, add2, one_vec)); - - // Second sub-computation: out1 = Add(Mul(reduce_out, one_vec), one_vec) - auto mul0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, reduce_out, one_vec)); - auto out1 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, mul0, one_vec)); - - // Create output Tuple. - builder.AddInstruction(HloInstruction::CreateTuple({out0, out1})); - return module_->AddEntryComputation(builder.Build()); - } - - Shape data_shape_ = ShapeUtil::MakeShape(F32, {4}); - Shape tuple_shape2_ = ShapeUtil::MakeTupleShape({data_shape_, data_shape_}); - Shape tuple_shape3_ = - ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_}); - Shape tuple_shape4_ = ShapeUtil::MakeTupleShape( - {data_shape_, data_shape_, data_shape_, data_shape_}); - - std::unique_ptr module_; -}; +class FusionMergerTest : public HloTestBase {}; // Tests that we can merge a fusion instruction that is below threshold. // -// Original computation: -// -// Param -// / | \ -// / | \ -// OnesVec GTE(0) GTE(1) GTE(2) -// \ / \ / -// Add Add OnesVec -// \ / \ / -// \ Add Mul OnesVec -// \ | | / -// \ Mul Add -// \ | / -// \ | / -// Tuple -// -// Computation after fusion passes: -// -// Param -// / \ -// Fusion3 Fusion2 -// | / \ -// \ Fusion0 Fusion1 -// \ | / -// \ | / -// Tuple -// // Computation after fusion merger pass (Fusion2 is merged into Fusion0 and // Fusion1): // Param @@ -276,19 +40,50 @@ class FusionMergerTest : public HloTestBase { // Tuple // TEST_F(FusionMergerTest, MergeSharedFusionInstruction) { - auto computation = BuildComputation0(); - // Run standard fusion passes. - EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); - // Run fusion merger pass, which should merge the shared fusion instruction - // into its two users. - EXPECT_TRUE(FusionMerger().Run(module_.get()).ValueOrDie()); - - auto* root = computation->root_instruction(); + auto module = ParseHloString(R"( +HloModule MergeSharedFusionInstruction + +comp.3 { + constant.param_0 = f32[4]{0} parameter(0) + param.param_1.2 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(1) + get-tuple-element.6 = f32[4]{0} get-tuple-element(param.param_1.2), index=0 + ROOT add.7 = f32[4]{0} add(constant.param_0, get-tuple-element.6) +} + +comp.2 { + param.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + get-tuple-element.4 = f32[4]{0} get-tuple-element(param.param_1.1), index=1 + get-tuple-element.5 = f32[4]{0} get-tuple-element(param.param_1.1), index=2 + ROOT add.6 = f32[4]{0} add(get-tuple-element.4, get-tuple-element.5) +} + +comp.1 { + add.1.param_1.1 = f32[4]{0} parameter(1) + constant.param_1.3 = f32[4]{0} parameter(0) + add.5 = f32[4]{0} add(add.1.param_1.1, constant.param_1.3) + ROOT multiply.3 = f32[4]{0} multiply(add.5, constant.param_1.3) +} + +comp { + add.1.param_1 = f32[4]{0} parameter(1) + constant.param_1.1 = f32[4]{0} parameter(0) + multiply.2 = f32[4]{0} multiply(add.1.param_1, constant.param_1.1) + ROOT add.4 = f32[4]{0} add(multiply.2, constant.param_1.1) +} + +ENTRY MergeSharedFusionInstruction.Computation0 { + constant = f32[4]{0} constant({1, 1, 1, 1}) + param = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + fusion.3 = f32[4]{0} fusion(constant, param), kind=kLoop, calls=comp.3 + fusion.4 = f32[4]{0} fusion(param), kind=kLoop, calls=comp.2 + fusion.5 = f32[4]{0} fusion(constant, fusion.4), kind=kLoop, calls=comp.1 + fusion.6 = f32[4]{0} fusion(constant, fusion.4), kind=kLoop, calls=comp + ROOT tuple = (f32[4]{0}, f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.5, fusion.6) +})") + .ValueOrDie(); + EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); + + auto* root = module->entry_computation()->root_instruction(); EXPECT_EQ(HloOpcode::kTuple, root->opcode()); // Check operand 0 (not merged). Should have 4 instructions. auto* operand0 = root->operand(0); @@ -307,156 +102,188 @@ TEST_F(FusionMergerTest, MergeSharedFusionInstruction) { // Tests that we do not merge a fusion instruction that above flops to bytes // threshold. // -// Original computation: -// -// Param -// / \ -// GTE(0) GTE(1) -// | | \ / -// | | Mul -// \ \ | -// \ Mul -// \ | -// OnesVec Mul OnesVec -// \ / \ / -// OnesVec Add Mul OnesVec -// \ | | / -// Mul Add -// \ / -// \ / -// Tuple -// -// Computation after fusion passes and fusion merger pass (Fusion2 is not -// merged because it exceeds the threshold flops to bytes ratio). -// -// Param -// | -// Fusion2 -// / \ -// Fusion0 Fusion1 -// \ / -// Tuple -// +// Fusion2 is not merged because it exceeds the threshold flops-to-bytes ratio. TEST_F(FusionMergerTest, FlopsToBytesRatioThresholdExceeded) { - BuildComputation1(); - // Run standard fusion passes. - EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); + auto module = ParseHloString(R"( +HloModule FlopsToBytesRatioThresholdExceeded + +comp.2 { + state.param_1.1 = (f32[4]{0}, f32[4]{0}) parameter(0) + get-tuple-element.3 = f32[4]{0} get-tuple-element(state.param_1.1), index=0 + get-tuple-element.4 = f32[4]{0} get-tuple-element(state.param_1.1), index=2 + multiply.29 = f32[4]{0} multiply(get-tuple-element.3, get-tuple-element.4) + multiply.30 = f32[4]{0} multiply(get-tuple-element.3, multiply.29) + multiply.31 = f32[4]{0} multiply(get-tuple-element.3, multiply.30) + multiply.32 = f32[4]{0} multiply(get-tuple-element.3, multiply.31) + multiply.33 = f32[4]{0} multiply(get-tuple-element.3, multiply.32) + multiply.34 = f32[4]{0} multiply(get-tuple-element.3, multiply.33) + multiply.35 = f32[4]{0} multiply(get-tuple-element.3, multiply.34) + multiply.36 = f32[4]{0} multiply(get-tuple-element.3, multiply.35) + multiply.37 = f32[4]{0} multiply(get-tuple-element.3, multiply.36) + multiply.38 = f32[4]{0} multiply(get-tuple-element.3, multiply.37) + multiply.39 = f32[4]{0} multiply(get-tuple-element.3, multiply.38) + multiply.40 = f32[4]{0} multiply(get-tuple-element.3, multiply.39) + ROOT multiply.41 = f32[4]{0} multiply(get-tuple-element.3, multiply.40) +} + +comp.1 { + multiply.12.param_1.1 = f32[4]{0} parameter(1) + constant.param_1.3 = f32[4]{0} parameter(0) + add.3 = f32[4]{0} add(multiply.12.param_1.1, constant.param_1.3) + ROOT multiply.16 = f32[4]{0} multiply(add.3, constant.param_1.3) +} + +comp { + multiply.12.param_1 = f32[4]{0} parameter(1) + constant.param_1.1 = f32[4]{0} parameter(0) + multiply.15 = f32[4]{0} multiply(multiply.12.param_1, constant.param_1.1) + ROOT add.2 = f32[4]{0} add(multiply.15, constant.param_1.1) +} + +ENTRY FlopsToBytesRatioThresholdExceeded.Computation1 { + constant = f32[4]{0} constant({1, 1, 1, 1}) + state = (f32[4]{0}, f32[4]{0}) parameter(0) + fusion.2 = f32[4]{0} fusion(state), kind=kLoop, calls=comp.2 + fusion.3 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp.1 + fusion.4 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp + ROOT tuple = (f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.4) +})") + .ValueOrDie(); // Run fusion merger pass, which should detect that the flops/bytes of the // shared fusion instruction exceeds the threshold ratio, and therefore // cannot be merged with other fusion instructions. - EXPECT_FALSE(FusionMerger().Run(module_.get()).ValueOrDie()); + EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); } // Tests that threshold for bytes transferred if merged is exceeded. // -// Original computation: -// -// Param -// / | | \ -// / | | \ -// / | | \ -// GTE(0) GTE(1) GTE(2) GTE(3) -// \ / / / -// Add / / -// \ / / -// Add / -// \ / -// \ / -// OnesVec Add OnesVec -// \ / \ / -// OnesVec Add Mul OnesVec -// \ | | / -// Mul Add -// \ / -// \ / -// Tuple -// -// Computation after fusion passes and fusion merger pass. Fusion2 is not -// merged because it exceeds the threshold bytes transferred. This is because -// the bytes read by Fusion2 (when replicated if the instruction is merged -// into Fusion0 and Fusion1) would exceed the bytes transferred threshold. -// -// Param -// | -// Fusion2 -// / \ -// Fusion0 Fusion1 -// \ / -// Tuple -// +// Fusion2 is not merged because it exceeds the threshold bytes transferred. +// This is because the bytes read by Fusion2 (when replicated if the instruction +// is merged into Fusion0 and Fusion1) would exceed the bytes transferred +// threshold. TEST_F(FusionMergerTest, BytesTransferredThresholdExeceeded) { - BuildComputation2(/*add_extra_input=*/true); - // Run standard fusion passes. - EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); + auto module = ParseHloString(R"( +HloModule BytesTransferredThresholdExeceeded + +comp.2 { + state.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + get-tuple-element.7 = f32[4]{0} get-tuple-element(state.param_1.1), index=0 + get-tuple-element.8 = f32[4]{0} get-tuple-element(state.param_1.1), index=1 + add.9 = f32[4]{0} add(get-tuple-element.7, get-tuple-element.8) + get-tuple-element.9 = f32[4]{0} get-tuple-element(state.param_1.1), index=2 + add.10 = f32[4]{0} add(add.9, get-tuple-element.9) + get-tuple-element.10 = f32[4]{0} get-tuple-element(state.param_1.1), index=3 + ROOT add.11 = f32[4]{0} add(add.10, get-tuple-element.10) +} + +comp.1 { + add.2.param_1.1 = f32[4]{0} parameter(1) + constant.param_1.3 = f32[4]{0} parameter(0) + add.6 = f32[4]{0} add(add.2.param_1.1, constant.param_1.3) + ROOT multiply.3 = f32[4]{0} multiply(add.6, constant.param_1.3) +} + +comp { + add.2.param_1 = f32[4]{0} parameter(1) + constant.param_1.1 = f32[4]{0} parameter(0) + multiply.2 = f32[4]{0} multiply(add.2.param_1, constant.param_1.1) + ROOT add.5 = f32[4]{0} add(multiply.2, constant.param_1.1) +} + +ENTRY BytesTransferredThresholdExeceeded.Computation2 { + constant = f32[4]{0} constant({1, 1, 1, 1}) + state = (f32[4]{0}, f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + fusion.2 = f32[4]{0} fusion(state), kind=kLoop, calls=comp.2 + fusion.3 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp.1 + fusion.4 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp + ROOT tuple = (f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.4) +})") + .ValueOrDie(); // Run fusion merger pass, which should detect that the net bytes transferred // (if merged) would increase. - EXPECT_FALSE(FusionMerger().Run(module_.get()).ValueOrDie()); + EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); } // Tests that threshold for bytes transferred if merged is not exceeded. // -// Original computation: -// -// Param -// / | \ -// / | \ -// / | \ -// GTE(0) GTE(1) GTE(2) -// \ / / -// Add / -// \ / -// OnesVec Add OnesVec -// \ / \ / -// OnesVec Add Mul OnesVec -// \ / \ / -// Mul Add -// \ / -// \ / -// Tuple -// -// Computation after fusion passes: -// -// Param -// | -// Fusion2 -// / \ -// Fusion0 Fusion1 -// \ / -// Tuple -// -// Computation after fusion merger pass (Fusion2 is merged into Fusion0 and -// Fusion1, because bytes read from Param by Fusion2 is reduced for this test -// which makes the merge operation into its operand below the bytes -// transferred threshold. -// -// Param -// / \ -// Fusion0 Fusion1 -// \ / -// Tuple -// +// Fusion2 is merged into Fusion0 and Fusion1, because bytes read from Param by +// Fusion2 is reduced for this test which makes the merge operation into its +// operand below the bytes transferred threshold. TEST_F(FusionMergerTest, BytesTransferredThresholdNotExeceeded) { - BuildComputation2(/*add_extra_input=*/false); - // Run standard fusion passes. - EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); + auto module = ParseHloString(R"( +HloModule BytesTransferredThresholdNotExeceeded + +comp.2 { + state.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + get-tuple-element.5 = f32[4]{0} get-tuple-element(state.param_1.1), index=0 + get-tuple-element.6 = f32[4]{0} get-tuple-element(state.param_1.1), index=1 + add.7 = f32[4]{0} add(get-tuple-element.5, get-tuple-element.6) + get-tuple-element.7 = f32[4]{0} get-tuple-element(state.param_1.1), index=2 + ROOT add.8 = f32[4]{0} add(add.7, get-tuple-element.7) +} + +comp.1 { + add.1.param_1.1 = f32[4]{0} parameter(1) + constant.param_1.3 = f32[4]{0} parameter(0) + add.5 = f32[4]{0} add(add.1.param_1.1, constant.param_1.3) + ROOT multiply.3 = f32[4]{0} multiply(add.5, constant.param_1.3) +} + +comp { + add.1.param_1 = f32[4]{0} parameter(1) + constant.param_1.1 = f32[4]{0} parameter(0) + multiply.2 = f32[4]{0} multiply(add.1.param_1, constant.param_1.1) + ROOT add.4 = f32[4]{0} add(multiply.2, constant.param_1.1) +} + +ENTRY BytesTransferredThresholdNotExeceeded.Computation2 { + constant = f32[4]{0} constant({1, 1, 1, 1}) + state = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + fusion.2 = f32[4]{0} fusion(state), kind=kLoop, calls=comp.2 + fusion.3 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp.1 + fusion.4 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp + ROOT tuple = (f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.4) +})") + .ValueOrDie(); // Run fusion merger pass, which should detect that the net bytes transferred // (if merged) would not increase. - EXPECT_TRUE(FusionMerger().Run(module_.get()).ValueOrDie()); + EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); +} + +// Check that we're willing to merge f1_computation into f2_computation, even +// though f2 is an input fusion node. +TEST_F(FusionMergerTest, WillMergeIntoInputFusion) { + auto module = ParseHloString(R"( + HloModule m + + f1_computation { + f1_p0 = f32[10]{0} parameter(0) + ROOT f1_root = f32[10]{0} add(f1_p0, f1_p0) + } + + add_computation { + add_lhs = f32[] parameter(0) + add_rhs = f32[] parameter(1) + ROOT add_root = f32[] add(add_lhs, add_rhs) + } + + f2_computation { + f2_p0 = f32[10]{0} parameter(0) + f2_mul = f32[10]{0} multiply(f2_p0, f2_p0) + f2_zero = f32[] constant(0) + ROOT f2_root = f32[] reduce(f2_mul, f2_zero), dimensions={0}, + to_apply=add_computation + } + + ENTRY entry { + p0 = f32[10]{0} parameter(0) + f1 = f32[10]{0} fusion(p0), kind=kLoop, calls=f1_computation + ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation + })") + .ValueOrDie(); + EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Fusion(op::Parameter())); } } // namespace diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index ba482793e7632f0f423cc9da0dd9620bdf29c642..79fca43d022816645b8a07b9e806fe9cc3745e7c 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -22,8 +22,6 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { @@ -49,7 +47,7 @@ struct MatrixDescriptor { // rhs_matrix, and stores the result to output_matrix. template bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, se::Stream* stream) { + MatrixDescriptor output_matrix, double alpha, se::Stream* stream) { DCHECK(!output_matrix.transpose); se::DeviceMemory lhs_data(lhs_matrix.data); @@ -65,7 +63,7 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, return stream ->ThenBlasGemm( lhs_transpose, rhs_transpose, output_matrix.num_rows, - output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/1.0, + output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha, lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, &output_data, /*leading dim of output=*/output_matrix.num_rows) @@ -89,7 +87,7 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, template bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, + MatrixDescriptor output_matrix, double alpha, se::blas::ComputationType computation_type, se::blas::AlgorithmType algorithm, se::Stream* stream, se::blas::ProfileResult* output_profile_result) { @@ -108,11 +106,13 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, return stream ->ThenBlasGemmWithAlgorithm( lhs_transpose, rhs_transpose, output_matrix.num_rows, - output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/1.0, - lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, - /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, - &output_data, /*leading dim of output=*/output_matrix.num_rows, - computation_type, algorithm, output_profile_result) + output_matrix.num_cols, /*size of reduce dim=*/k, + /*alpha=*/static_cast(alpha), lhs_data, + /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, + /*beta=*/static_cast(0.0f), &output_data, + /*leading dim of output=*/output_matrix.num_rows, computation_type, + algorithm, output_profile_result) .ok(); } @@ -125,8 +125,8 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, template StatusOr DoGemmAutotune( MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, se::blas::ComputationType computation_type, - se::Stream* stream) { + MatrixDescriptor output_matrix, double alpha, + se::blas::ComputationType computation_type, se::Stream* stream) { std::vector algorithms; CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms)); @@ -138,8 +138,8 @@ StatusOr DoGemmAutotune( // non-null ProfileResult, DoGemmWithAlgorithm should always return true, // and the actual success-ness is returned in ProfileResult::is_valid. CHECK(DoGemmWithAlgorithm(lhs_matrix, rhs_matrix, output_matrix, - computation_type, algorithm, stream, - &profile_result)); + alpha, computation_type, algorithm, + stream, &profile_result)); if (profile_result.is_valid() && profile_result.elapsed_time_in_ms() < best_result.elapsed_time_in_ms()) { @@ -161,6 +161,8 @@ StatusOr DoGemmAutotune( // DoGemm/DoGemmWithAlgorithm/DoGemmAutotune. auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm) { switch (type) { + case F16: + return &DoGemm; case F32: return &DoGemm; case F64: @@ -172,6 +174,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm) { auto GetGemmWithAlgorithmFn(PrimitiveType type) -> decltype(&DoGemmWithAlgorithm) { switch (type) { + case F16: + return &DoGemmWithAlgorithm; case F32: return &DoGemmWithAlgorithm; case F64: @@ -182,6 +186,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type) } auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune) { switch (type) { + case F16: + return &DoGemmAutotune; case F32: return &DoGemmAutotune; case F64: @@ -196,6 +202,10 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune) { // separately from the precision of the inputs and result. se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { switch (type) { + case F16: + // Use F32 as computation type for F16 as we currently only implement the + // cuDNN pseudo half configuration for half precision. + return se::blas::ComputationType::kF32; case F32: return se::blas::ComputationType::kF32; case F64: @@ -205,14 +215,33 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { } } +DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) { + if (hlo_instruction.opcode() == HloOpcode::kDot) { + return hlo_instruction.dot_dimension_numbers(); + } + CHECK_EQ(hlo_instruction.opcode(), HloOpcode::kFusion); + CHECK_EQ(hlo_instruction.fusion_kind(), HloInstruction::FusionKind::kOutput); + CHECK_EQ(hlo_instruction.fused_expression_root()->opcode(), + HloOpcode::kMultiply); + // Try to find the dot inside the output fusion node. + const HloInstruction* dot = + hlo_instruction.fused_expression_root()->operand(0); + if (dot->opcode() != HloOpcode::kDot) { + dot = hlo_instruction.fused_expression_root()->operand(1); + } + CHECK_EQ(dot->opcode(), HloOpcode::kDot); + + return dot->dot_dimension_numbers(); +} + } // namespace GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape, bool transpose_lhs, - bool transpose_rhs, const HloInstruction* hlo_instruction) + const Shape& output_shape, double alpha, + const HloInstruction* hlo_instruction) : Thunk(Kind::kGemm, hlo_instruction), lhs_buffer_(lhs_buffer), rhs_buffer_(rhs_buffer), @@ -220,11 +249,10 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, lhs_shape_(lhs_shape), rhs_shape_(rhs_shape), output_shape_(output_shape), - transpose_lhs_(transpose_lhs), - transpose_rhs_(transpose_rhs) {} + alpha_(alpha) {} -tensorflow::Status GemmThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { VLOG(2) << "Executing a GemmThunk"; se::DeviceMemoryBase lhs_data = @@ -272,10 +300,12 @@ tensorflow::Status GemmThunk::ExecuteOnStream( shape.dimensions(!is_row_major)); }; - const MatrixDescriptor lhs_descriptor = - make_descriptor(lhs_data, lhs_shape_, transpose_lhs_); - const MatrixDescriptor rhs_descriptor = - make_descriptor(rhs_data, rhs_shape_, transpose_rhs_); + DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); + + const MatrixDescriptor lhs_descriptor = make_descriptor( + lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0); + const MatrixDescriptor rhs_descriptor = make_descriptor( + rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == 1); // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to // autotune this gemm to figure out the best algorithm. @@ -290,7 +320,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream( if (autotune_it == autotune_results_.end()) { StatusOr best_algorithm = GetGemmAutotuneFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, - computation_type, stream); + alpha_, computation_type, stream); autotune_it = autotune_results_.insert({device_name, best_algorithm}).first; @@ -311,12 +341,15 @@ tensorflow::Status GemmThunk::ExecuteOnStream( VLOG(2) << "Using algorithm " << algorithm << " chosen by autotuning on GemmThunk " << this; return GetGemmWithAlgorithmFn(element_type)( - lhs_matrix, rhs_matrix, output_matrix, computation_type, algorithm, - stream, + lhs_matrix, rhs_matrix, output_matrix, alpha_, computation_type, + algorithm, stream, /*output_profile_result=*/nullptr); } + + // Autotune will fail when CUDA 8 and GPU sm_50 or older are used. + // Use the older Gemm API in this case. return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, - stream); + alpha_, stream); }; bool launch_ok; @@ -335,7 +368,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream( if (!launch_ok) { return InternalError("Unable to launch cuBLAS gemm on stream %p", stream); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 8c6a1f51a8a09ef78950dfe7e89994a3fe247f49..7a4830d64e7caef5a1170cbdbf8ab373fdaf16e2 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -34,29 +34,26 @@ namespace gpu { // This is thread-compatible. class GemmThunk : public Thunk { public: - // Constructs a thunk that computes "output = lhs rhs" using BLAS gemm. - // transpose_lhs and transpose_rhs indicate whether gemm should transpose the - // lhs and rhs operand. hlo_instruction is as in Thunk. + // Constructs a thunk that computes "output = (lhs rhs) * alpha" using + // BLAS gemm. hlo_instruction is as in Thunk. alpha is a constant. GemmThunk(const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape, bool transpose_lhs, bool transpose_rhs, + const Shape& output_shape, double alpha, const HloInstruction* hlo_instruction); GemmThunk(const GemmThunk&) = delete; GemmThunk& operator=(const GemmThunk&) = delete; // Does the gemm operation for the thunk on "stream", which must be non-null. - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; // Returns true if we'll perform autotuning if run on the given stream. If // so, we want the GPU to be quiescent during autotuning, so as not to // introduce noise in our results. - bool ShouldHaltAllActivityBeforeRunning( - perftools::gputools::Stream* stream) override { + bool ShouldHaltAllActivityBeforeRunning(se::Stream* stream) override { return autotune_results_.count( stream->parent()->GetDeviceDescription().name()) != 0; } @@ -70,15 +67,13 @@ class GemmThunk : public Thunk { const Shape rhs_shape_; const Shape output_shape_; - const bool transpose_lhs_; - const bool transpose_rhs_; + const double alpha_; // Maps device names (StreamExecutor::DeviceDescription::name()) to autotune // results. The map's value is the best algorithm we've found for this thunk // on this device, or an error if none of the algorithms worked and we should // use the regular gemm without an algorithm. - std::unordered_map> + std::unordered_map> autotune_results_; }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 28ebd034ee0c89137f4e6eb417d8a37f4a00af7a..afefc740d707cd6fd01420e950eafd37abe80119 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -33,8 +33,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" @@ -50,6 +52,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" @@ -71,6 +74,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" +#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -89,8 +94,6 @@ limitations under the License. #include "tensorflow/core/platform/tracing.h" #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { @@ -100,7 +103,7 @@ namespace gpu { namespace { -using tensorflow::port::Tracing; +namespace tracing = tensorflow::tracing; // Returns the directory containing nvvm libdevice files. config_cuda_data_dir // should be equal to config().debug_options().xla_gpu_cuda_data_dir() of the @@ -128,9 +131,8 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. -tensorflow::Status OptimizeHloModule(HloModule* hlo_module, - se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) { +Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) { { HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(); @@ -161,8 +163,10 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true, - /*use_fusion=*/false); + /*rewrite_grad_op=*/true); + + // Rewrite gather ops into smaller ones. + pass.AddPass(); // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. @@ -172,10 +176,12 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); + pass.AddPass(); } pipeline.AddPass( @@ -197,18 +203,28 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, pipeline.AddInvariantChecker(); pipeline.AddPass(); pipeline.AddPass(); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + + { + HloPassPipeline pipeline("layout_assignment"); + pipeline.AddPass( + hlo_module->mutable_device_entry_computation_layout(), stream_exec); + + // The LayoutAssignment pass may leave behind kCopy instructions which are + // duplicate or NOPs, so remove them with algebraic simplification and CSE. + pipeline.AddPass>( + /*is_layout_sensitive=*/true, + /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { + return true; + }); // Choose the fastest algorithm for each conv. // - // In theory doing this here is way too early: It needs to happen after - // layout assignment, because the layout of the inputs/outputs affects the - // speed of the conv. But currently we only allow only one input/output - // layout when calling cudnn, so there's no ambiguity. - // - // We pick the algorithm at this early stage so we can generate better HLO. - // After CudnnConvolutionRewriter, our convolutions are CustomCalls which - // return a tuple (conv_result, scratch_memory), and the each conv uses 0 - // bytes of scratch: + // We pick the algorithm before fusion so we can generate better HLO. After + // CudnnConvolutionRewriter, our convolutions are CustomCalls which return a + // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of + // scratch: // // customcall = (f32[...], f32[0]) // return gte(customcall, 0) @@ -224,20 +240,16 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, // The new tuple and gte instructions then be simplified away, because // nobody is expected to use the scratch value. // - // However, if we were to run CudnnConvolutionAlgorithmPicker after layout - // assignment, fusion would already have run, and the gte(customcall, 0) - // would probably already be into a fusion node. We can't simplify across - // HloComputation boundaries, so in this case we wouldn't be able to - // simplify away the new_tuple bits. - // - // We'll need to revisit this if we ever allow multiple layouts for the - // inputs/outputs of a cudnn convolution. + // However, if we were to run CudnnConvolutionAlgorithmPicker after fusion + // the gte(customcall, 0) would probably already be into a fusion node. We + // can't simplify across HloComputation boundaries, so in this case we + // wouldn't be able to simplify away the new_tuple bits. pipeline.AddPass(stream_exec, device_allocator); // Clean up new_tuple described above. pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(/*is_layout_sensitive=*/true); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -247,6 +259,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); + fusion.AddPass(); TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); @@ -263,12 +276,21 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); } } - return tensorflow::Status::OK(); + + { + // Do an aggressive LICM pass over while loops. In particular, this hoists + // constants that were sunk by WhileLoopConstantSinking. Leaving them in + // the while loop may result in unnecessary copies. + HloPassPipeline pipeline("while-loop-licm"); + pipeline.AddPass(true); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + return Status::OK(); } // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. -tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { +Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // In some cases, we have to place the result of an instruction in a temporary // buffer. For instance, the buffer that holds an external parameter is // assumed immutable at this point, and should not be reused for output @@ -277,15 +299,6 @@ tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { HloPassPipeline pipeline("GPU-ir-emit-prepare"); pipeline.AddInvariantChecker(); - pipeline.AddPass( - hlo_module->mutable_entry_computation_layout()); - - // The LayoutAssignment pass may leave behind kCopy instructions which are - // duplicate or NOPs, so remove them with algebraic simplification and CSE. - pipeline.AddPass>( - /*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return true; }); - pipeline.AddPass(/*is_layout_sensitive=*/true); // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which // materializes the value) or missing a necessary copy (later pass removes an @@ -399,7 +412,7 @@ void WarnIfBadDriverJITVersion() { // code (i.e. a cubin) as a byte array. StatusOr> CompilePtx(const string& ptx, int cc_major, int cc_minor) { - Tracing::TraceMe annotation("Compile PTX", /*is_expensive=*/true); + tracing::ScopedActivity activity("Compile PTX", /*is_expensive=*/true); const string ptxas_path = tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); VLOG(2) << "Using ptxas at " << ptxas_path; @@ -470,8 +483,8 @@ StatusOr> GpuCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses"); - Tracing::TraceMe annotation("HLO Transforms", module->name(), - /*is_expensive=*/true); + tracing::ScopedActivity activity("HLO Transforms", module->name(), + /*is_expensive=*/true); TF_RETURN_IF_ERROR( OptimizeHloModule(module.get(), stream_exec, device_allocator)); return std::move(module); @@ -658,6 +671,8 @@ StatusOr> GpuCompiler::RunBackend( if (module->config().hlo_profiling_enabled()) { HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); + cost_analysis.set_bytes_per_second( + stream_exec->GetDeviceDescription().memory_bandwidth()); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); profile_index_map = MakeUnique(*module); profile_printer = @@ -679,7 +694,7 @@ std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, int cc_major, int cc_minor) { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::CompilePtxOrGetCachedResult"); - Tracing::TraceMe annotation("PTX->CUBIN", /*is_expensive=*/true); + tracing::ScopedActivity activity("PTX->CUBIN", /*is_expensive=*/true); bool inserted; decltype(compilation_cache_.begin()) iter; // Pointers into compilation_cache_ where the ptx and (optional) cubin are @@ -764,9 +779,9 @@ se::Platform::Id GpuCompiler::PlatformId() const { } // namespace xla static bool InitModule() { - xla::Compiler::RegisterCompilerFactory(se::cuda::kCudaPlatformId, []() { - return xla::MakeUnique(); - }); + xla::Compiler::RegisterCompilerFactory( + stream_executor::cuda::kCudaPlatformId, + []() { return xla::MakeUnique(); }); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index c352d4d8462fadb266c55ad437de998e86a6528e..f3b02ae5d8867bdf1d970e809bff95a15d9f54d2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -45,25 +45,23 @@ class GpuCompiler : public LLVMCompiler { // Bring in // StatusOr>> Compile( // std::vector> modules, - // std::vector> + // std::vector> // stream_execs) using LLVMCompiler::Compile; StatusOr> RunHloPasses( - std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec, + std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( - std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec, + std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr>> CompileAheadOfTime(std::vector> module, AotCompilationOptions const& options) override; - perftools::gputools::Platform::Id PlatformId() const override; + se::Platform::Id PlatformId() const override; HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { // Capture just the pointer size, not the entire GpuCompiler object. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 9db85bc788bde46c890a46ce9b0902ddce3f5675..c5ccdd4a7dcec02ddab8a1f748659de41f6202d2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -78,14 +78,13 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } - } else if (IsCustomCallToDnnConvolution(*hlo)) { - // The last two arguments to a CUDNN convolution are two HLO constants for - // cudnn algorithm and tensor_ops_enabled flag, which shouldn't be copied. - for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } - } else if (ImplementedAsLibraryCall(*hlo)) { - // For all other library calls, materialize all the operands into memory. + } else if (ImplementedAsLibraryCall(*hlo) || + hlo->opcode() == HloOpcode::kCrossReplicaSum) { + // For all other library calls and cross-replica-sum, materialize all the + // operands into memory. (Cross-replica-sum gets its constant args + // materialized even if it's not implemented as a libcall to simplify the + // implementation. It's slower, but we can constant fold away constant + // args *anyway*, so we just need to make it work.) for (int64 i = 0; i < hlo->operand_count(); ++i) { TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 623d6714de501000e38b7698620925f66425f157..25d8f720ea4791a4c94efcad6909cd0c113fbe70 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -32,26 +32,29 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { namespace { +using tensorflow::tracing::ScopedAnnotation; + // A helper class for profiling HLO in the course of GPU program execution. // All of the profiling is guarded internally, to avoid the caller needing to // have lots of conditionals sprinkled around. class HloExecutionProfiler { public: // If profiling is enabled, start an execution timer running. - explicit HloExecutionProfiler(bool do_profile, HloExecutionProfile* profile, - se::Stream* stream, - const HloComputation* computation) + explicit HloExecutionProfiler( + bool do_profile, HloExecutionProfile* profile, se::Stream* stream, + const std::vector::SmartPtr>& sub_streams, + const HloComputation* computation) : do_profile_(do_profile), profile_(profile), stream_(stream), + sub_streams_(sub_streams), computation_(computation) { if (do_profile_) { clock_rate_ghz_ = @@ -70,6 +73,7 @@ class HloExecutionProfiler { CHECK(!finished_execution_) << "Call FinishExecution only once!"; finished_execution_ = true; if (do_profile_) { + stream_->ThenWaitFor(&sub_streams_); stream_->ThenStopTimer(execution_timer_.get()); stream_->BlockHostUntilDone().IgnoreError(); profile_->set_total_cycles_executed( @@ -88,6 +92,7 @@ class HloExecutionProfiler { // that the hlo_instruction took to execute in the profile. void FinishOperation(const HloInstruction* hlo_instruction) { if (do_profile_) { + stream_->ThenWaitFor(&sub_streams_); stream_->ThenStopTimer(per_op_timer_.get()); stream_->BlockHostUntilDone().IgnoreError(); profile_->SetCyclesTakenBy( @@ -100,6 +105,7 @@ class HloExecutionProfiler { double clock_rate_ghz_; HloExecutionProfile* profile_; se::Stream* stream_; + const std::vector::SmartPtr>& sub_streams_; const HloComputation* computation_; std::unique_ptr execution_timer_; std::unique_ptr per_op_timer_; @@ -131,9 +137,10 @@ Status GpuExecutable::ExecuteThunks( const BufferAllocations& buffer_allocations, bool block_host_until_done, HloExecutionProfile* hlo_execution_profile) { se::Stream* main_stream = run_options->stream(); + se::StreamExecutor* executor = main_stream->parent(); std::pair stream_compute_compatibility; - main_stream->parent()->GetDeviceDescription().cuda_compute_capability( + executor->GetDeviceDescription().cuda_compute_capability( &stream_compute_compatibility.first, &stream_compute_compatibility.second); TF_RET_CHECK(stream_compute_compatibility == compute_capability_) @@ -147,26 +154,44 @@ Status GpuExecutable::ExecuteThunks( LOG(WARNING) << "PROFILING: profiling is enabled"; } - HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream, - hlo_module_->entry_computation()); - - uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - // Stream 0 indicates `main_stream` and substreams start from stream 1. std::vector::SmartPtr> sub_streams; + sub_streams.reserve(thunk_schedule_->StreamCount() - 1); while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) { sub_streams.emplace_back(); - TF_ASSIGN_OR_RETURN( - sub_streams.back(), - run_options->BorrowStream(main_stream->parent()->device_ordinal())); + TF_ASSIGN_OR_RETURN(sub_streams.back(), + run_options->BorrowStream(executor->device_ordinal())); } - // The next event enqueued on stream N must not run until the thunk at - // last_blocking_thunk_for_stream[N] completes. - std::map last_blocking_thunk_for_stream; + HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream, + sub_streams, hlo_module_->entry_computation()); + uint64 start_micros = tensorflow::Env::Default()->NowMicros(); + + // This top-level trace serves two purposes: + // 1) It marks the scope of the whole XLA module. + // 2) It tells us whether tracing is enabled. We use this to avoid the + // expensive HloInstruction::ToString() calls inside the loop below if + // tracing is disabled. + ScopedAnnotation top_level_annotation(hlo_module_->name(), "XLA GPU module"); + std::map> thunk_to_finish_event; for (Thunk* thunk : thunk_schedule_->TotalOrder()) { - TF_RETURN_IF_ERROR(thunk->Initialize(*this)); + // Annotate execution of this op if tracing was enabled when we started + // running this module. If tracing is enabled *while* we're running the + // module, we won't get any data, but that's probably an OK trade-off. + // + // TODO(jlebar): Should we cache the results of HloInstruction::ToString(), + // since we expect it to be an expensive call? + tensorflow::gtl::optional op_annotation; + if (top_level_annotation.IsEnabled()) { + op_annotation.emplace( + thunk->hlo_instruction() != nullptr + ? thunk->hlo_instruction()->ToString(HloPrintOptions::Canonical()) + : "", + "XLA op"); + } + + TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor)); int32 stream_no = thunk_schedule_->StreamNumberForHlo(*thunk->hlo_instruction()); se::Stream* stream = @@ -176,18 +201,10 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } - if (last_blocking_thunk_for_stream.count(stream_no)) { - stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, - last_blocking_thunk_for_stream[stream_no]) - .get()); - last_blocking_thunk_for_stream.erase(stream_no); - } - // If this thunk requests it, wait for all currently-executing thunks to // finish. This is useful e.g. if the thunk is about to perform autotuning. if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) { TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); - last_blocking_thunk_for_stream.clear(); } profiler.StartOperation(); @@ -195,22 +212,11 @@ Status GpuExecutable::ExecuteThunks( << thunk->hlo_instruction()->ToString() << " on stream " << stream_no; TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); - if (thunk_schedule_->Depended(thunk) || thunk->ShouldBlockFutureThunks()) { + if (thunk_schedule_->Depended(thunk)) { auto finish_event = MakeUnique(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); - - if (thunk->ShouldBlockFutureThunks()) { - // Set last_blocking_thunk_for_stream on all streams other than this one - // so that all other streams will wait for this thunk to complete before - // executing any events that occur later in the total order. - for (int32 i = 0; i < sub_streams.size() + 1; ++i) { - if (i != stream_no) { - last_blocking_thunk_for_stream[i] = thunk; - } - } - } } profiler.FinishOperation(thunk->hlo_instruction()); } @@ -247,7 +253,7 @@ Status GpuExecutable::ExecuteThunks( return Status::OK(); } -StatusOr> GpuExecutable::ExecuteOnStream( +StatusOr GpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { @@ -262,23 +268,29 @@ StatusOr> GpuExecutable::ExecuteOnStream( ++i) { const BufferAllocation& allocation = assignment_->GetAllocation(i); if (allocation.is_entry_computation_parameter()) { - // The caller must give us a buffer for ShapeIndex {} of every parameter. - // It can optionally give us a buffer for other ShapeIndices, but we - // ignore them: Because we can't rely on these sub-buffers' addresses - // being available, our generated code can't use them. Instead, it must - // chase pointers starting at the tuple root. - if (allocation.param_shape_index().empty()) { - auto param_no = allocation.parameter_number(); - buffer_allocations_builder.RegisterBuffer( - i, arguments[param_no]->root_buffer()); + auto param_no = allocation.parameter_number(); + se::DeviceMemoryBase buffer = + arguments[param_no]->buffer(allocation.param_shape_index()); + + // All top-level buffers and sub-buffers must have an explicit, non-null + // pointer, except for zero-sized buffers, which may be null. + if (buffer.is_null() && buffer.size() > 0) { + return FailedPrecondition( + "Cannot run XLA computation because pointer to (sub-)buffer at " + "index %s of parameter %lld was null. All pointers to " + "(sub-)buffers must not be null, unless the (sub-)buffer has zero " + "elements.", + allocation.param_shape_index().ToString().c_str(), param_no); } + + buffer_allocations_builder.RegisterBuffer(i, buffer); } } se::StreamExecutor* executor = run_options->stream()->parent(); TF_ASSIGN_OR_RETURN( auto buffer_allocations, - buffer_allocations_builder.Build(*assignment_, executor->device_ordinal(), - memory_allocator)); + buffer_allocations_builder.Build( + assignment_.get(), executor->device_ordinal(), memory_allocator)); bool block_host_until_done = !memory_allocator->AllowsAsynchronousDeallocation(); @@ -288,13 +300,13 @@ StatusOr> GpuExecutable::ExecuteOnStream( HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); auto device_ordinal = executor->device_ordinal(); - auto shaped_buffer = MakeUnique( - root->shape(), root->shape(), executor->platform(), device_ordinal); + ScopedShapedBuffer shaped_buffer(root->shape(), root->shape(), + memory_allocator, device_ordinal); // Copy DeviceMemoryBase values which contain the array(s) of the result into // the respective location in ShapedBuffer. std::set buffers_in_result; - TF_RETURN_IF_ERROR(shaped_buffer->buffers().ForEachMutableElementWithStatus( + TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachMutableElementWithStatus( [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { const auto& sources = this->GetRootPointsToSet().element(index); @@ -313,20 +325,19 @@ StatusOr> GpuExecutable::ExecuteOnStream( this->assignment_->GetUniqueSlice(src_hlo, sources[0]->index())); CHECK(!slice.allocation()->is_entry_computation_parameter()); - perftools::gputools::DeviceMemoryBase src_base = + se::DeviceMemoryBase src_base = buffer_allocations->GetDeviceAddress(slice.index()); CHECK(!src_base.is_null() || src_base.size() == 0); *device_memory = src_base; buffers_in_result.insert(src_base); return Status::OK(); })); - TF_RETURN_IF_ERROR( - buffer_allocations->TearDown(buffers_in_result, *assignment_)); + TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result)); return std::move(shaped_buffer); } -StatusOr> GpuExecutable::ExecuteAsyncOnStream( +StatusOr GpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) { // TODO(b/30671675): Implement asynchronous execution mode. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index b19cfd43debd0a5490495d176fa2f1fcd625da07..80ec38c3ac114fe4ad9d56784330c1144d913db1 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -74,20 +74,15 @@ class GpuExecutable : public Executable { // ExecuteOnStream will fail if the compute capability of the stream doesn't // match the compute capability passed to this object's constructor. - StatusOr> ExecuteOnStream( + StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr> ExecuteAsyncOnStream( + StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) override; - const Status EqualOrFail(const Executable& executable) { - // TODO(b/62952745) Implement equality test on GPU executable. - return Unimplemented("Equality test on GPU executable is not implemented."); - } - private: // If `block_host_until_done` is false, execution will not block the host // until the kernels have completed. This is used as an optimization for diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 89f1e625884568bf7370b3801d851ef4846c2a98..8bf62dde8b9948375fc493fd1a524cfa7b062502 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -18,31 +18,72 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { namespace gpu { -// cuDNN convolutions are called with specific layouts on the input, output, -// and filter: -// -// input: DataLayout::kBatchDepthYX -// output: DataLayout::kBatchDepthYX -// filter: FilterLayout::kOutputInputYX -// -// The order dimensions in the constant name is major-to-minor (eg, the -// most-major dimension of the input is batch, most-minor is X). The -// specific dimension numbers these named dimensions correspond to is -// determined by the ConvolutionDimensionNumbers argument. Y is spatial -// dimension 0, and X is spatial dimension 1. -// -// TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls. -static Status AddBackendConstraintsToDnnConvCustomCall( +using stream_executor::dnn::DataLayout; +using stream_executor::dnn::FilterLayout; + +static bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) { + int major, minor; + CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major, + &minor)); + return major >= 7; +} + +// Returns (input, filter, output) layouts. +static std::tuple +HeuristicLayoutAssignment(const HloInstruction* instr, + stream_executor::StreamExecutor* stream_executor) { + // DataLayout and FilterLayout uses weird enum names. Translations: + // N <=> Batch or Output + // C <=> Depth or Input + // H <=> Y + // W <=> X + // + // Therefore kOutputInputYX means NHWC; kBatchDepthYX means NCHW. + + // As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x + // fp16 with the mostly-NHWC layout. The heuristic may change as cudnn version + // changes, as well as the hardware updates. + if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 && + IsVoltaOrLater(*stream_executor))) { + return std::make_tuple(DataLayout::kBatchDepthYX, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX); + } + VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString(); + // For BackwardInput that has stride, full NHWC layouts run significantly + // slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC). + // + // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW, + // NHWC). + if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && + window_util::HasStride(instr->window())) { + return std::make_tuple(DataLayout::kBatchYXDepth, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX); + } + return std::make_tuple(DataLayout::kBatchYXDepth, + FilterLayout::kOutputYXInput, + DataLayout::kBatchYXDepth); +} + +// Adds layout constraints on the cudnn custom-call instruction. The layout +// constraints are represented in terms of minor_to_major fields of both +// operands and the output shape. Depending on the underlying algorithm, one of +// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen. +Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( HloInstruction* instr, LayoutConstraints* constraints) { CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString(); Shape input_shape; @@ -66,39 +107,25 @@ static Status AddBackendConstraintsToDnnConvCustomCall( << instr->custom_call_target(); } - // Construct minor-to-major dimension orders for operands and result. - // cuDNN's convolution APIs support the BDYX layout for activations/output - // and the OIYX layout for weights. - // TODO(b/29399649): Be more flexible about handling layouts of cuDNN - // calls after we switch to cuDNN v5. - const ConvolutionDimensionNumbers& dimension_numbers = - instr->convolution_dimension_numbers(); - std::vector input_layout; - for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; i >= 0; - --i) { - input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); - } - input_layout.push_back(dimension_numbers.input_feature_dimension()); - input_layout.push_back(dimension_numbers.input_batch_dimension()); - *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); - - std::vector filter_layout; - for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; i >= 0; - --i) { - filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i)); - } - filter_layout.push_back(dimension_numbers.kernel_input_feature_dimension()); - filter_layout.push_back(dimension_numbers.kernel_output_feature_dimension()); - *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); - - std::vector output_layout; - for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; i >= 0; - --i) { - output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); + { + DataLayout input; + FilterLayout filter; + DataLayout output; + if (ConvUseLayoutHeuristic(instr->GetModule()->config())) { + std::tie(input, filter, output) = + HeuristicLayoutAssignment(instr, stream_executor_); + } else { + input = DataLayout::kBatchDepthYX; + filter = FilterLayout::kOutputInputYX; + output = DataLayout::kBatchDepthYX; + } + + TF_ASSIGN_OR_RETURN( + std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(), + *output_shape.mutable_layout()), + StreamExecutorConvLayoutsToXlaLayouts( + instr->convolution_dimension_numbers(), input, filter, output)); } - output_layout.push_back(dimension_numbers.output_feature_dimension()); - output_layout.push_back(dimension_numbers.output_batch_dimension()); - *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); // The custom call returns a tuple of (actual_result, scratch_buffer); // call_result_buf is the logical buffer for actual_result, the thing that @@ -132,7 +159,13 @@ static Status AddBackendConstraintsToDnnConvCustomCall( Status GpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { - for (auto* instruction : constraints->computation()->instructions()) { + // Add convolution constraints in reverse postorder that the earliest + // convolution layout propagates first. This reduces the likelihood of fusion + // nodes with copies. + auto post_order = constraints->computation()->MakeInstructionPostOrder(); + for (auto iterator = post_order.rbegin(); iterator != post_order.rend(); + ++iterator) { + HloInstruction* instruction = *iterator; if (IsCustomCallToDnnConvolution(*instruction)) { TF_RETURN_IF_ERROR( AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index 86a3a7111fd79494e469beecf3234f6cec9adb9c..ce24af1cf8856920ccf438b5bbd2ef28cfa8ba6f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace gpu { @@ -27,8 +28,10 @@ namespace gpu { // layout constraints for operands and results of library calls. class GpuLayoutAssignment : public LayoutAssignment { public: - explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout) - : LayoutAssignment(entry_computation_layout) {} + explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout, + se::StreamExecutor* stream_executor) + : LayoutAssignment(entry_computation_layout), + stream_executor_(stream_executor) {} ~GpuLayoutAssignment() override {} protected: @@ -41,6 +44,12 @@ class GpuLayoutAssignment : public LayoutAssignment { LayoutConstraints* constraints) override; bool CustomCallRequiresMajorFirstLayout( const HloInstruction* instruction) override; + + private: + Status AddBackendConstraintsToDnnConvCustomCall( + HloInstruction* instr, LayoutConstraints* constraints); + + se::StreamExecutor* stream_executor_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 4c45d2e94aebce5496da94841f6a1ae9015615c1..e48165c1426ea04839c245bc20b851a0f1710246 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -69,7 +69,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) { *computation_layout.mutable_result_layout() = ShapeLayout(result_shape_with_layout); - GpuLayoutAssignment layout_assignment(&computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); for (const HloInstruction* operand : add->operands()) { @@ -156,7 +157,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { *computation_layout.mutable_result_layout() = ShapeLayout(result_shape); } - GpuLayoutAssignment layout_assignment(&computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -225,7 +227,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { {result_shape, offset_scale_shape, offset_scale_shape})); } - GpuLayoutAssignment layout_assignment(&computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -305,7 +308,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { {result_shape, scale_shape, scale_shape})); } - GpuLayoutAssignment layout_assignment(&computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first and fourth operands to the batchnorm call should have the diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.cc b/tensorflow/compiler/xla/service/gpu/gpu_options.cc new file mode 100644 index 0000000000000000000000000000000000000000..35b4b4e20b633792de4251a4b0e89f4b579053ce --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_options.cc @@ -0,0 +1,28 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace xla { +namespace gpu { + +bool ConvUseLayoutHeuristic(const HloModuleConfig& config) { + return !config.debug_options().xla_backend_extra_options().count( + "xla_gpu_experimental_conv_disable_layout_heuristic"); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/compiler/xla/service/gpu/gpu_options.h new file mode 100644 index 0000000000000000000000000000000000000000..498d4a94955cb2c50e0b165f28ded44ac1c0bfff --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_options.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ + +#include "tensorflow/compiler/xla/service/hlo_module_config.h" + +// Helper functions for querying options that are specific to the GPU backend. + +namespace xla { +namespace gpu { + +// Returns true if we should use heuristics to assign convolution layouts, as +// opposed to always assigning NCHW. +bool ConvUseLayoutHeuristic(const HloModuleConfig& config); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index af9897769fda371e47af06c19abce9a06015e094..7bb8df6581b49b1bf8c84a972f715e8dc119d8de 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -33,8 +33,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { // TODO(b/30467474) Once GPU infeed implementation settles, consider @@ -46,8 +44,8 @@ GpuTransferManager::GpuTransferManager() /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout) .getPointerSize(0 /* default address space */)) {} -Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) { +Status GpuTransferManager::TransferLiteralToInfeed( + se::StreamExecutor* executor, const LiteralSlice& literal) { const Shape& shape = literal.shape(); VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); @@ -153,8 +151,8 @@ static std::unique_ptr CreateGpuTransferManager() { } static bool InitModule() { - xla::TransferManager::RegisterTransferManager(se::cuda::kCudaPlatformId, - &CreateGpuTransferManager); + xla::TransferManager::RegisterTransferManager( + stream_executor::cuda::kCudaPlatformId, &CreateGpuTransferManager); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h index 9aa369c668364079504ead3491903e2590a142cc..09f8227f508a3159f3def285898e15bfad544552 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h @@ -36,21 +36,20 @@ class GpuTransferManager : public GenericTransferManager { GpuTransferManager(); ~GpuTransferManager() override {} - Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, - const Literal& literal) override; - Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, - int64 size, const void* source) override; + Status TransferLiteralToInfeed(se::StreamExecutor* executor, + const LiteralSlice& literal) override; + Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, + const void* source) override; private: // Initiates the infeed data transfers. InfeedBuffer->Done() must be // called to clean up the memory allocated for InfeedBuffer. StatusOr TransferBufferToInfeedInternal( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source); + se::StreamExecutor* executor, int64 size, const void* source); // Enqueues infeed data buffers with the infeed manager after their // transfer completes. - Status EnqueueBuffersToInfeed(perftools::gputools::StreamExecutor* executor, + Status EnqueueBuffersToInfeed(se::StreamExecutor* executor, std::vector buffers); TF_DISALLOW_COPY_AND_ASSIGN(GpuTransferManager); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index 42c1539e86c2ab162fa473852b80b28b57d0e370..375709150e08996ea6a40f5e9e66a8f8d9287008 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/types.h" @@ -198,8 +199,8 @@ StatusOr> HloSchedule::Build( // concurrency by optimizing for minimal memory usage. TF_ASSIGN_OR_RETURN( schedule->thunk_launch_order_, - CreateMemoryMinimizingSequence( - *entry_computation, [pointer_size](const LogicalBuffer& buffer) { + ScheduleOneComputation( + *entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); } else { diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index ece9fa04dce3fd12713fb7e58097dc16ebba83df..45f0a1c645b2875cf90d2c11cfb66c3dd855d097 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -42,6 +42,14 @@ class HloScheduleTest : public HloTestBase { .ConsumeValueOrDie(); } + std::unique_ptr CreateNewModule() { + HloModuleConfig config; + auto debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_disable_multi_streaming(false); + config.set_debug_options(debug_options); + return MakeUnique("test_module", config); + } + HloVec RemoveHlo(const HloVec& input, const std::unordered_set& remove) { HloVec result(input); @@ -65,9 +73,9 @@ TEST_F(HloScheduleTest, SequentialMatMul) { HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); + HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -193,11 +201,11 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) { HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x)); + HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); + HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(add)); @@ -259,24 +267,24 @@ TEST_F(HloScheduleTest, LatticeMatMul) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); } - HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary( - f32_2x2_, HloOpcode::kDot, params[2], params[3])); + HloInstruction* d00 = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc index ee5b447c9cd0b1fde4d3a0943d5d4cb8cc5b3376..ae310beefad0c81c17fd4140b441b3a19a002e2c 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -19,8 +19,6 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/core/platform/logging.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { @@ -51,13 +49,25 @@ void InfeedManager::EnqueueBuffers(const std::vector& buffers) { } InfeedBuffer* InfeedManager::BlockingDequeueBuffer() { - tensorflow::mutex_lock l(mu_); - while (enqueued_buffer_.empty()) { - cv_.wait(l); + bool became_empty = false; + InfeedBuffer* current_buffer; + { + tensorflow::mutex_lock l(mu_); + while (enqueued_buffer_.empty()) { + cv_.wait(l); + } + current_buffer = enqueued_buffer_.front(); + enqueued_buffer_.pop_front(); + dequeued_buffer_.insert(current_buffer); + if (enqueued_buffer_.empty()) { + became_empty = true; + } + } + if (became_empty) { + for (const auto& callback : on_empty_callbacks_) { + callback(); + } } - InfeedBuffer* current_buffer = enqueued_buffer_.front(); - enqueued_buffer_.pop_front(); - dequeued_buffer_.insert(current_buffer); return current_buffer; } @@ -90,6 +100,10 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { return host_to_device_stream_.get(); } +void InfeedManager::RegisterOnEmptyCallback(std::function callback) { + on_empty_callbacks_.push_back(std::move(callback)); +} + InfeedManager* GetOrCreateInfeedManager() { static InfeedManager* manager = new InfeedManager; return manager; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.h b/tensorflow/compiler/xla/service/gpu/infeed_manager.h index 73d5a5ce35497f156a181371bfb97fc37a8eb09e..a3fc15cfe36a490f38daabca9ff36fbb1012aead 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.h @@ -21,6 +21,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_ #include +#include #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -46,7 +47,7 @@ namespace gpu { // the client. The client manages the memory of the buffer. class InfeedBuffer { public: - InfeedBuffer(perftools::gputools::StreamExecutor* executor, int64 length) + InfeedBuffer(se::StreamExecutor* executor, int64 length) : executor_(executor), length_(length) { device_memory_ = executor_->AllocateArray(length); CHECK(!device_memory_.is_null()); @@ -60,14 +61,12 @@ class InfeedBuffer { // client to manage memory for the infeed buffers. void Done() { delete this; } - perftools::gputools::DeviceMemoryBase* device_memory() { - return &device_memory_; - } + se::DeviceMemoryBase* device_memory() { return &device_memory_; } private: - perftools::gputools::StreamExecutor* executor_; // Not owned. + se::StreamExecutor* executor_; // Not owned. const int64 length_; - perftools::gputools::DeviceMemoryBase device_memory_; + se::DeviceMemoryBase device_memory_; }; // Client-side class used to enqueue infeed buffers. @@ -100,8 +99,11 @@ class InfeedManager { // new stream on the first invocation. On subsequent invocations, if // the cached executor is not the same as the requested executor, // returns null. - perftools::gputools::Stream* GetStream( - perftools::gputools::StreamExecutor* executor); + se::Stream* GetStream(se::StreamExecutor* executor); + + // Registers a callback that will be called when 'enqueued_buffer_' becomes + // empty. + void RegisterOnEmptyCallback(std::function callback); private: // TODO(b/30467474): Revisit if this mutex becomes a point of @@ -121,10 +123,14 @@ class InfeedManager { tensorflow::gtl::FlatSet dequeued_buffer_; // Cached host to device stream for queuing infeed data. - std::unique_ptr host_to_device_stream_; + std::unique_ptr host_to_device_stream_; // Executor that the host_to_device_stream belongs to. Not owned. - perftools::gputools::StreamExecutor* host_to_device_executor_; + se::StreamExecutor* host_to_device_executor_; + + // List of callbacks which will be called when 'enqueued_buffer_' becomes + // empty. + std::vector> on_empty_callbacks_; }; // Singleton creator-or-accessor: Returns the GPU infeed manager. diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 2ac95ceb692447c7ac6dbbcd8b9a38876f7a77b6..ea34d5b30c91e8b809e3e17a904e27e589fd6b5f 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -31,10 +31,10 @@ InfeedThunk::InfeedThunk( destination_buffer_(destination_buffer) {} Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { + se::Stream* stream) { VLOG(2) << "Infeeding to GPU "; - perftools::gputools::DeviceMemoryBase destination_address = + se::DeviceMemoryBase destination_address = buffer_allocations.GetDeviceAddress(destination_buffer_); InfeedManager* infeed_manager = GetOrCreateInfeedManager(); @@ -45,7 +45,7 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, std::vector tuple_element_addresses; for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) { - perftools::gputools::DeviceMemoryBase tuple_element_address = + se::DeviceMemoryBase tuple_element_address = buffer_allocations.GetDeviceAddress(tuple_element_buffer); InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index 86918705fa0305217f11753e383200c7bd71474b..93713cb12defd95bdd69cb0aa7ad7b4e37fc8fae 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -44,7 +44,7 @@ class InfeedThunk : public Thunk { InfeedThunk& operator=(const InfeedThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + se::Stream* stream) override; private: const std::vector tuple_element_buffers_; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index b5962f069bf499c913bd5479f263a7cb77c00555..6c4519185b34989eb53c884ba214d69b824b113c 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -17,7 +17,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace gpu { @@ -25,13 +27,19 @@ namespace gpu { namespace { bool IsFusile(const HloInstruction& hlo) { + // Don't fuse get-tuple-element on GPU: We can, but it's slower than not + // fusing. We never generate kernels for unfused GTEs. Instead, if an + // unfused GTE is an input to a kernel (including a fusion kernel), we + // compute the address of the GTE at the top of the kernel. Often we know the + // address of the GTE result statically, so we can do this without chasing any + // pointers. return (hlo.IsElementwise() && hlo.operand_count() > 0) || + hlo.opcode() == HloOpcode::kBitcast || hlo.opcode() == HloOpcode::kBroadcast || hlo.opcode() == HloOpcode::kConcatenate || hlo.opcode() == HloOpcode::kDynamicSlice || hlo.opcode() == HloOpcode::kDynamicUpdateSlice || hlo.opcode() == HloOpcode::kFusion || - hlo.opcode() == HloOpcode::kGetTupleElement || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReduceWindow || @@ -40,13 +48,101 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kTranspose; } +bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { + if (constant->opcode() != HloOpcode::kConstant || + !ShapeUtil::IsScalar(constant->shape())) { + return false; + } + auto type = constant->shape().element_type(); + return type == F16 || type == F32 || type == F64; +} + } // namespace +/*static*/ bool GpuInstructionFusion::IsExpensive( + const HloInstruction& instruction) { + switch (instruction.opcode()) { + // We say that floating-point division is cheap on the GPU. + case HloOpcode::kDivide: + return !ShapeUtil::ElementIsFloating(instruction.shape()) && + InstructionFusion::IsExpensive(instruction); + + default: + return InstructionFusion::IsExpensive(instruction); + } +} + bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); - // Output fusion is not currently supported on GPUs. + // Check if we can use output fusion for (A @ B) * alpha + if (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot)) { + int64 other_operand_index = 1 - operand_index; + HloInstruction* op1 = nullptr; + HloInstruction* op2 = nullptr; + if (consumer->operand_count() == 1 && + consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() == HloInstruction::FusionKind::kLoop && + Match(consumer->fused_expression_root(), + match::Op() + .WithOpcode(HloOpcode::kMultiply) + .WithOperand(0, match::Op(&op1)) + .WithOperand(1, match::Op(&op2)))) { + CHECK(op1 != nullptr && op2 != nullptr); + // If 'consumer' is a fusion node, it should consist of a broadcast of a + // scalar constant fused into a multiply, but nothing more. So one operand + // should be a parameter, and the other should be a broadcast. + if (op1->opcode() != HloOpcode::kParameter) { + std::swap(op1, op2); + } + if (op1->opcode() != HloOpcode::kParameter || + op2->opcode() != HloOpcode::kBroadcast) { + return false; + } + if (IsIEEEFloatingPointScalarConstant(op2->operand(0))) { + return true; + } + } else if (consumer->operand_count() == 2 && + consumer->opcode() == HloOpcode::kMultiply) { + const HloInstruction* alpha = consumer->operand(other_operand_index); + // Fuse if 'alpha' is a broadcast of a scalar constant. + if (alpha->opcode() == HloOpcode::kBroadcast && + alpha->dimensions().empty() && + IsIEEEFloatingPointScalarConstant(alpha->operand(0))) { + return true; + } + } + } + + // Only allow fusing transpose or broadcast into an output fusion that is + // implemented as a Gemm call. + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() == HloInstruction::FusionKind::kOutput && + ImplementedAsGemm(*consumer)) { + auto producer_operand_index = consumer->operand_index(producer); + auto fused_parameter = consumer->fused_parameter(producer_operand_index); + const std::vector& fused_parameter_users = + fused_parameter->users(); + if (fused_parameter_users.size() != 1) { + return false; + } + if (producer->opcode() == HloOpcode::kTranspose) { + // Check that the transpose is an operand of a dot. + return fused_parameter_users[0]->opcode() == HloOpcode::kDot; + } + if (producer->opcode() == HloOpcode::kBroadcast) { + // Check that the broadcast is a broadcast of a scalar constant into a + // multiply. + return producer->dimensions().empty() && + IsIEEEFloatingPointScalarConstant(producer->operand(0)) && + fused_parameter_users[0]->opcode() == HloOpcode::kMultiply; + } + } + + // Other output fusions are not currently supported on GPUs. if (producer->opcode() == HloOpcode::kFusion) { return false; } @@ -70,17 +166,6 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } - // We may need to know original operand layout to emit input fusion, and so - // far, we merely use the layout of an operand of the fusion node, which means - // we must fuse only elementwise operations. This restriction should be lifted - // later if we need to fuse other operations, e.g. transpose, for performance. - if ((IsReductionToVector(*consumer) || - (HloOpcode::kFusion == consumer->opcode() && - HloInstruction::FusionKind::kInput == consumer->fusion_kind())) && - !producer->IsElementwise()) { - return false; - } - // Cost condition: not fuse (simple, expensive producers) and (consumers who // reuse operand elements). if (producer->opcode() != HloOpcode::kFusion && @@ -89,15 +174,48 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } + // Fuse scalar constants into loop fusion nodes, this reduces the number of + // parameters and makes matching scalar broadcasts easier. + if (ShapeUtil::IsEffectiveScalar(producer->shape()) && + consumer->opcode() == HloOpcode::kFusion && + producer->opcode() == HloOpcode::kConstant) { + return true; + } + return IsFusile(*producer) && IsFusile(*consumer) && InstructionFusion::ShouldFuse(consumer, operand_index); } +bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, + int64 operand_index) { + const HloInstruction* producer = consumer->operand(operand_index); + // The IR emitter has limited support for non-loop fusions with multi output + // at present. + // TODO(tjoerg): Relax this constraint to allow for arbitraty kinds of fusion. + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() != HloInstruction::FusionKind::kLoop) { + return false; + } + // Multi-output fusion requires instructions with compatible shapes. + if (!ShapeUtil::Compatible(producer->shape(), consumer->shape())) { + return false; + } + // TODO(tjoerg): Stop calling `ShouldFuse` to relax the criteria for + // multi-output fusion. In particular, do not check whether an instruction is + // expensive to duplicate, since this doesn't matter here. + return GpuInstructionFusion::ShouldFuse(consumer, operand_index); +} + HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) { if (IsReductionToVector(*consumer)) { return HloInstruction::FusionKind::kInput; } + if (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot)) { + return HloInstruction::FusionKind::kOutput; + } if (HloOpcode::kFusion == consumer->opcode()) { return consumer->fusion_kind(); } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index bb2990e6dfc9de0a11566bb3a2fb3a1b62498ffa..f629d9ff2c7165b652369612c30979150f93bd24 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -27,8 +27,13 @@ class GpuInstructionFusion : public InstructionFusion { explicit GpuInstructionFusion(bool may_duplicate) : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {} + static bool IsExpensive(const HloInstruction& instruction); + bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; + bool ShouldFuseIntoMultiOutput(HloInstruction* consumer, + int64 operand_index) override; + HloInstruction::FusionKind ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) override; }; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 2d6dad27a59978da6e4719afc50ebee5e641dde0..1963d9eef72d41fa0a275bea98f959671fa7e737 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -15,8 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/util.h" namespace op = xla::testing::opcode_matchers; @@ -107,8 +111,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0)); + auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( + ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); @@ -124,8 +128,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0)); + auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( + ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); @@ -137,30 +141,469 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { .ValueOrDie()); } -TEST_F(InstructionFusionTest, GetTupleElementFused) { - HloComputation::Builder builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - Shape tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape}); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "param")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, param, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, param, 1)); - builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, gte0, gte1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); +// Tests that broadcasts fused into a fusion with a reduce root. +TEST_F(InstructionFusionTest, BroadcastIntoReduce) { + auto module = ParseHloString(R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY BroadcastIntoReduce { + constant = f32[] constant(1) + broadcast = f32[16,16,16,16]{3,2,1,0} broadcast(constant), dimensions={} + constant.1 = f32[] constant(0) + ROOT reduce = f32[] reduce(broadcast, constant.1), dimensions={0,1,2,3}, + to_apply=add + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Reduce(op::Broadcast(op::Constant()), op::Constant())); +} + +TEST_F(InstructionFusionTest, BitcastIntoAdd) { + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY BroadcastIntoAdd { + p0 = f32[4,1,1]{2,1,0} parameter(0) + p1 = f32[4,1]{1,0} parameter(1) + bitcast = f32[4,1]{1,0} bitcast(p0) + ROOT add = f32[4,1] add(bitcast, p1) + })") + .ValueOrDie(); + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) .ValueOrDie()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(HloOpcode::kFusion, root->opcode()); - HloInstruction* fused_root = root->fused_expression_root(); - EXPECT_EQ(HloOpcode::kAdd, fused_root->opcode()); - // Check that operands of 'fused_root' are GTE. - EXPECT_EQ(HloOpcode::kGetTupleElement, fused_root->operand(0)->opcode()); - EXPECT_EQ(HloOpcode::kGetTupleElement, fused_root->operand(1)->opcode()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Add(op::Bitcast(op::Parameter()), op::Parameter())); +} + +TEST_F(InstructionFusionTest, AddIntoBitcast) { + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY BroadcastIntoAdd { + p0 = f32[4,1,1]{2,1,0} parameter(0) + p1 = f32[4,1]{1,0} parameter(1) + add = f32[4,1] add(p0, p1) + ROOT bitcast = f32[4,1,1] bitcast(add) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Bitcast(op::Add(op::Parameter(), op::Parameter()))); +} + +TEST_F(InstructionFusionTest, DontFuseGTE) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY DontFuseGTE { + p0 = (f32[10], f32[10]) parameter(0) + gte0 = f32[10] get-tuple-element(p0), index=0 + gte1 = f32[10] get-tuple-element(p0), index=1 + ROOT add = f32[10] add(gte0, gte1) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, DotOutputFusion) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + alpha = f32[] constant(3) + broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={} + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0} + dot = f32[4,4]{1,0} dot(p0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT mul = f32[4,4] multiply(dot, broadcast) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kOutput); + EXPECT_THAT( + root->fused_expression_root(), + op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())), + op::Broadcast(op::Constant()))); +} + +// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is +// duplicated and fused into both reduces. +TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { + auto module = ParseHloString(R"( + HloModule test_module + Add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + ENTRY TestComputation { + zero = f32[] constant(0) + one = f32[] constant(1) + p0 = f32[100] parameter(0) + recip = f32[100] divide(one, p0) + sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add + sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add + ROOT root = (f32[], f32[]) tuple(sum1, sum2) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion())) + << module->ToString(); +} + +// Compute sum(100/p0), where p0 has type s32, twice. Check that the division +// is *not* duplicated and fused into both reduces, because we say that integer +// division is not cheap. +TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { + auto module = ParseHloString(R"( + HloModule test_module + Add { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) + } + ENTRY TestComputation { + zero = s32[] constant(0) + one_hundred = s32[] constant(100) + p0 = s32[100] parameter(0) + recip = s32[100] divide(one_hundred, p0) + sum1 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add + sum2 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add + ROOT mul = (s32[], s32[]) tuple(sum1, sum2) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, DotOutputFusionImpossible) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY NoOutputFusion { + alpha = f32[] constant(3) + broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={} + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[3,4]{1,0} parameter(1) + dot = f32[4,4]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + d = f32[4,4]{1,0} multiply(dot, dot) + ROOT mul = f32[4,4] multiply(d, broadcast) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); + EXPECT_THAT(root->fused_expression_root(), + op::Multiply(op::Multiply(op::Parameter(), op::Parameter()), + op::Broadcast(op::Constant()))); +} + +// Counts the HLO ops with a given op code in the specified module. +static int Count(const HloModule& module, HloOpcode op) { + int count = 0; + for (const auto* computation : module.computations()) { + for (const auto* instruction : computation->instructions()) { + if (instruction->opcode() == op) { + ++count; + } + } + } + return count; +} + +// Returns an HLO instruction from the given computation with the op code. +static StatusOr FindHloInstruction( + const HloComputation& computation, HloOpcode op) { + for (const auto* instruction : computation.instructions()) { + if (instruction->opcode() == op) { + return instruction; + } + } + return NotFound( + "Computation '%s' does not contain an instruction with op code '%s'.", + computation.name().c_str(), HloOpcodeString(op).c_str()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusion) { + // sub --> add --> tuple + // \---------------/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,3]{1,0} parameter(2) + sub = f32[4,3]{1,0} subtract(p0, p2) + add = f32[4,3]{1,0} add(sub, p1) + ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + // Expect that there is one multi-output fusion and subtract has not been + // duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); + EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); + TF_ASSERT_OK_AND_ASSIGN( + const HloInstruction* fusion, + FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); + EXPECT_THAT( + fusion->fused_expression_root(), + op::Tuple(op::Add(op::Subtract(), op::Parameter()), op::Subtract())); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionExpensiveOp) { + // tanh --> add --> tuple + // \---------------/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + tanh = f32[4,3]{1,0} tanh(p0) + add = f32[4,3]{1,0} add(tanh, p1) + ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(tanh, add) + })") + .ValueOrDie(); + + // TODO(tjoerg): Allow multi-output fusion for expensive operations like tanh. + ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, MultiOutputFusion2) { + // sub --> add1 --\--------\ + // \----------> add2 --> tuple + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,3]{1,0} parameter(2) + sub = f32[4,3]{1,0} subtract(p0, p2) + add1 = f32[4,3]{1,0} add(sub, p1) + add2 = f32[4,3]{1,0} add(sub, add1) + ROOT tuple = (f32[4,3]{1,0}) tuple(add1, add2) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + // Expect that there is one multi-output fusion and subtract has not been + // duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); + EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); + TF_ASSERT_OK_AND_ASSIGN( + const HloInstruction* fusion, + FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Add(op::Subtract(), op::Add()), + op::Add(op::Subtract(), op::Parameter()))); +} + +TEST_F(InstructionFusionTest, MultiOutputFusion3) { + // sub --> add1 ----\--------\ + // \ --> add2 --> add3 --> tuple + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,3]{1,0} parameter(2) + p3 = f32[4,3]{1,0} parameter(3) + sub = f32[4,3]{1,0} subtract(p0, p2) + add1 = f32[4,3]{1,0} add(sub, p1) + add2 = f32[4,3]{1,0} add(p2, sub) + add3 = f32[4,3]{1,0} add(add1, add2) + ROOT tuple = (f32[4,3]{1,0}) tuple(add3, add2) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + // Expect that there is one multi-output fusion and subtract has not been + // duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); + EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); + TF_ASSERT_OK_AND_ASSIGN( + const HloInstruction* fusion, + FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Add(op::Add(), op::Add()), + op::Add(op::Parameter(), op::Subtract()))); +} + +TEST_F(InstructionFusionTest, NoCyclesDueToMultiOutputFusion) { + // sub --> mul ---\ + // \--> call --> add --> tuple + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + c = f32[] constant(42) + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + sub = f32[4,3]{1,0} subtract(p0, p1) + mul = f32[4,3]{1,0} multiply(sub, c) + call = f32[4,3]{1,0} custom-call(sub), custom_call_target="foo" + add = f32[4,3]{1,0} add(mul, call) + ROOT tuple = (f32[4,3]{1,0}) tuple(add) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + // Visit instructions in post order to detect cycles. + // TODO(tjoerg): Add cycle detection to the HloVerifier. + class DummyVisitor : public DfsHloVisitorWithDefault { + public: + DummyVisitor() {} + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + } visitor; + for (const HloComputation* computation : module->MakeComputationPostOrder()) { + // Accept will return a FailedPrecondition when a cycle is detected. + EXPECT_TRUE(computation->root_instruction()->Accept(&visitor).ok()); + } +} + +TEST_F(InstructionFusionTest, NoMultiOutputFusionWithIncompatibleShapes) { + // sub[2,3] --> add[4,3] --> tuple([2,3], [4,3]) + // \-------------------------/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[2,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[2,3]{1,0} parameter(2) + sub = f32[2,3]{1,0} subtract(p0, p2) + add = f32[4,3]{1,0} add(sub, p1) + ROOT tuple = (f32[2,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add) + })") + .ValueOrDie(); + + // Multi-output fusion requires shapes to be compatible. Since `sub` and `add` + // have incompatible shapes, expect that no multi-output fusion happens. + ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) { + auto module = ParseHloString(R"( + HloModule test_module + + add_computation { + add_lhs = f32[] parameter(0) + add_rhs = f32[] parameter(1) + ROOT add_root = f32[] add(add_lhs, add_rhs) + } + + fused_computation { + p1 = f32[10] parameter(0) + zero = f32[] constant(0) + ROOT f2_root = f32[] reduce(p1, zero), dimensions={0}, + to_apply=add_computation + } + + ENTRY entry { + p0 = f32[10] parameter(0) + mul = f32[10] multiply(p0, p0) + fusion = f32[] fusion(mul), kind=kInput, calls=fused_computation + ROOT tuple = (f32[10], f32[]) tuple(fusion, mul) + })") + .ValueOrDie(); + + // Multi-output fusion is not supported for non-loop fusions at present. Since + // `fused_computation` is a input fusion, expect no multi-output fusion to + // happen. + ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseScalarConstant) { + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY FuseScalarConstant { + p0 = f32[] parameter(0) + c0 = f32[] constant(1) + add1 = f32[] add(p0, c0) + b0 = f32[2]{0} broadcast(add1), dimensions={} + c1 = f32[2]{0} constant({1, 2}) + ROOT add2 = f32[2]{0} add(b0, c1) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Add(op::Broadcast(op::Add(op::Parameter(), op::Constant())), + op::Parameter())); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 2f65edffea81db7dba1f8545f92b27ea622044e7..67890bfed1136796c83c7ef6912ffc1ab1b7e332 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -49,14 +49,35 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, // The inputs and the output must // 1) be matrices with no padding and a non-zero number of elements, // 2) have an allowed element type. - bool type_is_allowed = (output_shape.element_type() == F32 || - output_shape.element_type() == F64); + PrimitiveType output_primitive_type = output_shape.element_type(); + bool type_is_allowed = + (output_primitive_type == F16 || output_primitive_type == F32 || + output_primitive_type == F64); return type_is_allowed && IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && IsRank2WithNoPadding(output_shape) && !ShapeUtil::HasZeroElements(lhs_shape) && !ShapeUtil::HasZeroElements(rhs_shape); } + +bool DotImplementedAsGemm(const HloInstruction& dot) { + CHECK_EQ(dot.opcode(), HloOpcode::kDot); + const Shape& lhs_shape = dot.operand(0)->shape(); + const Shape& rhs_shape = dot.operand(1)->shape(); + + // If gemm can accept the operand shapes, use it rather than a custom + // kernel. + if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) { + // The size of the reduction dimension should match. The shape inference + // guarantees this invariant, so the check here is for programming + // errors. + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); + CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), + rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); + return true; + } + return false; +} } // namespace bool ImplementedAsGemm(const HloInstruction& hlo) { @@ -67,24 +88,20 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { // For certain types of Dot, we can call pre-canned BLAS gemm. if (hlo.opcode() == HloOpcode::kDot) { - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); - - // If gemm can accept the operand shapes, use it rather than a custom - // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { - // The size of the reduction dimension should match. The shape inference - // guarantees this invariant, so the check here is for programming - // errors. - CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); - return true; - } + return DotImplementedAsGemm(hlo); } if (hlo.opcode() == HloOpcode::kFusion && - hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && - hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { - return true; + hlo.fusion_kind() == HloInstruction::FusionKind::kOutput && + hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply) { + // Try to find the dot inside the output fusion node. + const HloInstruction* dot = hlo.fused_expression_root()->operand(0); + if (dot->opcode() != HloOpcode::kDot) { + dot = hlo.fused_expression_root()->operand(1); + } + if (dot->opcode() == HloOpcode::kDot) { + return DotImplementedAsGemm(*dot); + } } return false; @@ -145,14 +162,8 @@ static HloInstruction* CreateCudnnConv( Shape call_shape = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); - // Our CustomCall takes three arguments: The conv lhs and rhs, and the cudnn - // algorithm to use. It's up to a later pass to choose the algorithm, so to - // indicate that we haven't yet made a choice, we speicfy -1 for that arg. - HloInstruction* negative_one = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-1))); - HloInstruction* custom_call = - computation->AddInstruction(HloInstruction::CreateCustomCall( - call_shape, {lhs, rhs, negative_one}, call_target)); + HloInstruction* custom_call = computation->AddInstruction( + HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); custom_call->set_window(window); custom_call->set_convolution_dimension_numbers(dnums); return custom_call; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index a3df67a87344d6ece2ea9047321ad9542c13f8cf..547af33e9a98c03e1429366172f9a401e385a9d1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" @@ -93,7 +94,10 @@ Status IrEmitter::HandleConstant(HloInstruction* constant) { << std::endl << " its type: " << llvm_ir::DumpToString(*global_for_const->getType()); - bindings_.BindHloToIrValue(*constant, global_for_const); + llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( + global_for_const, + llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); + bindings_.BindHloToIrValue(*constant, shape_constant); return Status::OK(); } @@ -438,6 +442,32 @@ Status IrEmitter::HandleSelect(HloInstruction* select) { return IrEmitter::DefaultAction(select); } +namespace { +llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* ir_builder) { + return ir_builder->CreateExtractValue(x, {0}); +} + +llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* ir_builder) { + return ir_builder->CreateExtractValue(x, {1}); +} + +std::pair MultiplyComplex( + llvm::Value* lhs_value, llvm::Value* rhs_value, + llvm::IRBuilder<>* ir_builder) { + llvm::Value* lhs_real = Real(lhs_value, ir_builder); + llvm::Value* lhs_imag = Imag(lhs_value, ir_builder); + llvm::Value* rhs_real = Real(rhs_value, ir_builder); + llvm::Value* rhs_imag = Imag(rhs_value, ir_builder); + llvm::Value* real_result1 = ir_builder->CreateFMul(lhs_real, rhs_real); + llvm::Value* real_result2 = ir_builder->CreateFMul(lhs_imag, rhs_imag); + llvm::Value* real_result = ir_builder->CreateFSub(real_result1, real_result2); + llvm::Value* imag_result1 = ir_builder->CreateFMul(lhs_real, rhs_imag); + llvm::Value* imag_result2 = ir_builder->CreateFMul(lhs_imag, rhs_real); + llvm::Value* imag_result = ir_builder->CreateFAdd(imag_result1, imag_result2); + return {real_result, imag_result}; +} +} // namespace + Status IrEmitter::HandleDot(HloInstruction* dot) { auto lhs_instruction = dot->operand(0); auto rhs_instruction = dot->operand(1); @@ -456,21 +486,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { rhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); llvm::Value* result; if (ShapeUtil::ElementIsComplex(lhs_shape)) { - auto real = [&](llvm::Value* x) { - return ir_builder_.CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_.CreateExtractValue(x, {1}); - }; - llvm::Value* real_result = ir_builder_.CreateFSub( - ir_builder_.CreateFMul(real(lhs_value), real(rhs_value)), - ir_builder_.CreateFMul(imag(lhs_value), imag(rhs_value))); - llvm::Value* imag_result = ir_builder_.CreateFAdd( - ir_builder_.CreateFMul(real(lhs_value), imag(rhs_value)), - ir_builder_.CreateFMul(imag(lhs_value), real(rhs_value))); + auto value = MultiplyComplex(lhs_value, rhs_value, &ir_builder_); result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); - result = ir_builder_.CreateInsertValue(result, real_result, {0}); - result = ir_builder_.CreateInsertValue(result, imag_result, {1}); + result = ir_builder_.CreateInsertValue(result, value.first, {0}); + result = ir_builder_.CreateInsertValue(result, value.second, {1}); } else { result = ir_builder_.CreateFMul(lhs_value, rhs_value); } @@ -548,20 +567,13 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { llvm::Value* accum = ir_builder_.CreateLoad(accum_address); llvm::Value* updated_accum; if (ShapeUtil::ElementIsComplex(lhs_shape)) { -#define REAL(x) ir_builder_.CreateExtractValue(x, {0}) -#define IMAG(x) ir_builder_.CreateExtractValue(x, {1}) - llvm::Value* product_real = ir_builder_.CreateFSub( - ir_builder_.CreateFMul(REAL(lhs_element), REAL(rhs_element)), - ir_builder_.CreateFMul(IMAG(lhs_element), IMAG(rhs_element))); - llvm::Value* product_imag = ir_builder_.CreateFAdd( - ir_builder_.CreateFMul(REAL(lhs_element), IMAG(rhs_element)), - ir_builder_.CreateFMul(IMAG(lhs_element), REAL(rhs_element))); - updated_accum = ir_builder_.CreateInsertValue( - accum, ir_builder_.CreateFAdd(REAL(accum), product_real), {0}); - updated_accum = ir_builder_.CreateInsertValue( - updated_accum, ir_builder_.CreateFAdd(IMAG(accum), product_imag), {1}); -#undef IMAG -#undef REAL + auto value = MultiplyComplex(lhs_element, rhs_element, &ir_builder_); + llvm::Value* accum_real = Real(accum, &ir_builder_); + llvm::Value* real_sum = ir_builder_.CreateFAdd(accum_real, value.first); + updated_accum = ir_builder_.CreateInsertValue(accum, real_sum, {0}); + llvm::Value* accum_imag = Imag(accum, &ir_builder_); + llvm::Value* imag_sum = ir_builder_.CreateFAdd(accum_imag, value.second); + updated_accum = ir_builder_.CreateInsertValue(updated_accum, imag_sum, {1}); } else { llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element); updated_accum = ir_builder_.CreateFAdd(accum, product); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index b0accc08d479258d65a18202122e4c9e90ff78d0..e55dfc6dae844ceb1d28ad389d133c80823bad9a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -120,10 +120,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* GetBasePointer(const HloInstruction& inst) const { return bindings_.GetBasePointer(inst); } - // A convenient helper for calling BufferAssignment::GetUniqueTopLevelSlice. - BufferAllocation::Slice GetAllocationSlice(const HloInstruction& hlo) const { + // A convenient helper for calling BufferAssignment::GetUniqueSlice. + BufferAllocation::Slice GetAllocationSlice( + const HloInstruction& hlo, const ShapeIndex& index = {}) const { return ir_emitter_context_->buffer_assignment() - .GetUniqueTopLevelSlice(&hlo) + .GetUniqueSlice(&hlo, index) .ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h index 3790ed313b9d0e167185a8b12c812132ee78811f..a78b4ff83075fd7ef330bb97ce217a198d450cf8 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h @@ -32,7 +32,7 @@ class IrEmitterContext { public: IrEmitterContext(const HloModule* hlo_module, const BufferAssignment* buffer_assignment, - const perftools::gputools::DeviceDescription* device_desc, + const se::DeviceDescription* device_desc, llvm::Module* llvm_module) : hlo_module_(hlo_module), buffer_assignment_(buffer_assignment), @@ -47,7 +47,7 @@ class IrEmitterContext { const BufferAssignment& buffer_assignment() const { return *buffer_assignment_; } - const perftools::gputools::DeviceDescription& device_description() const { + const se::DeviceDescription& device_description() const { return *device_desc_; } llvm::Module* llvm_module() { return llvm_module_; } @@ -56,7 +56,7 @@ class IrEmitterContext { private: const HloModule* hlo_module_; const BufferAssignment* buffer_assignment_; - const perftools::gputools::DeviceDescription* device_desc_; + const se::DeviceDescription* device_desc_; llvm::Module* llvm_module_; NameUniquer name_uniquer_; }; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 71aada080ae8df70bffce3e1854b5fbd833efd23..bb47a4280541ce2806472aa9365bb0ef38c0c3b3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/core/lib/core/status.h" @@ -116,6 +117,26 @@ Status IrEmitterNested::HandleParameter(HloInstruction* parameter) { Status IrEmitterNested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { + // For MOF we give the loop emitter an array for every output it should + // generate. + if (hlo.IsMultiOutputFusion()) { + std::vector target_arrays; + for (int64 i = 0, e = ShapeUtil::TupleElementCount(hlo.shape()); i != e; + ++i) { + target_arrays.push_back(GetIrArray(hlo, hlo, {i})); + } + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter(element_generator, target_arrays, &ir_builder_) + .EmitLoop()); + + std::vector tuple_operand_ptrs; + for (const llvm_ir::IrArray& array : target_arrays) { + tuple_operand_ptrs.push_back(array.GetBasePointer()); + } + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_, + module_); + return Status::OK(); + } return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &ir_builder_) .EmitLoop(); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 30c88c0a5d38f6ea3f94d3b47b7b69c7122bf6ac..726434c3dfd4f1ef866d2eb9b6d7eb8b659e0984 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include #include @@ -30,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" @@ -44,6 +47,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" @@ -55,6 +59,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/ops.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" @@ -76,6 +81,7 @@ namespace { using llvm_ir::IrName; using tensorflow::gtl::ArraySlice; +using tensorflow::gtl::InlinedVector; using tensorflow::gtl::nullopt; using tensorflow::gtl::optional; using tensorflow::strings::StrCat; @@ -142,37 +148,6 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::ConstantAsMetadata::get(threads_per_block_ir_value)})); } -// Tries to get a Slice for the given instruction at the given index, but -// returns nullopt if we might not know the slice's address at runtime without -// dereferencing a containing tuple. -// -// In particular, when XLA accepts a parameter of tuple type, the caller has the -// option of telling XLA what are the values inside of the tuple, or just giving -// XLA a pointer to the top-level tuple and letting us chase the pointers on the -// GPU. We therefore cannot rely having these pointers to parameter sub-buffers -// being present when we run the program. -optional GetKnownAtRuntimeSlice( - const HloInstruction* instr, const ShapeIndex& index, - const BufferAssignment& buffer_assn) { - auto maybe_slice = buffer_assn.GetUniqueSlice(instr, index); - if (!maybe_slice.ok()) { - return nullopt; - } - // BufferAllocation gives a slice and alloc to every buffer accessed by XLA, - // but we don't necessarily know the runtime address of sub-buffers of input - // parameters. - const BufferAllocation::Slice& slice = maybe_slice.ValueOrDie(); - const BufferAllocation* alloc = slice.allocation(); - if (alloc->IsInputOrOutput() && !alloc->maybe_live_out() && - !alloc->param_shape_index().empty()) { - return nullopt; - } - - // Otherwise, we will know the address of this slice at runtime without having - // to dereference a tuple. - return slice; -} - } // namespace IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config, @@ -203,7 +178,7 @@ bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment, return hlo.opcode() == HloOpcode::kCopy && hlo.operand(0)->opcode() == HloOpcode::kConstant && ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value(); + buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok(); } bool ImplementedAsDeviceToDeviceMemcpy( @@ -213,13 +188,13 @@ bool ImplementedAsDeviceToDeviceMemcpy( // // 1. `hlo` is a kCopy instruction. // 2. `hlo` and its operand have the same shape (thus the same layout too). - // 3. The operand to `hlo` has a buffer assignment (constants do not, for - // instance) which means the source buffer also resides on the device. + // 3. `hlo` and its operand have a statically-known buffer assignment + // (constants do not, for instance), which means the source buffer also + // resides on the device. return hlo.opcode() == HloOpcode::kCopy && ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value() && - GetKnownAtRuntimeSlice(hlo.operand(0), {}, buffer_assignment) - .has_value(); + buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok() && + buffer_assignment.GetUniqueTopLevelSlice(hlo.operand(0)).ok(); } } // namespace @@ -285,8 +260,39 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( return kernel; } +namespace { +// Computes the maximum valid unroll factor for a given instruction. +int ComputeMaxUnrollFactor(const HloInstruction* hlo) { + int max_unroll_factor = hlo->GetModule() + ->config() + .debug_options() + .xla_gpu_max_kernel_unroll_factor(); + + // Find the largest possible power of two to unroll by. + // TODO(kramerb): Make this smarter. + const Shape& element_shape = hlo->IsMultiOutputFusion() + ? ShapeUtil::GetSubshape(hlo->shape(), {0}) + : hlo->shape(); + int64 num_elements = ShapeUtil::ElementsIn(element_shape); + for (int i = max_unroll_factor; i > 1; i /= 2) { + if (num_elements % i == 0) { + return i; + } + } + + // Cannot unroll. + return 1; +} +} // namespace + Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { - thunk_sequence_->emplace_back(BuildKernelThunk(hlo)); + int unroll_factor = 1; + // Unfused elementwise operations are usually memory bound, unroll them. + if (hlo->IsElementwise()) { + unroll_factor = ComputeMaxUnrollFactor(hlo); + } + + thunk_sequence_->emplace_back(BuildKernelThunk(hlo, unroll_factor)); return IrEmitter::DefaultAction(hlo); } @@ -419,15 +425,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - const HloInstruction* algorithm_inst = custom_call->operand(2); - CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString(); - int64 algorithm = algorithm_inst->literal().Get({}); - - const HloInstruction* tensor_ops_enabled_inst = custom_call->operand(3); - CHECK(tensor_ops_enabled_inst->IsConstant()) - << tensor_ops_enabled_inst->ToString(); - bool tensor_ops_enabled = tensor_ops_enabled_inst->literal().Get({}); - + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + custom_call->backend_config()); const auto& target = custom_call->custom_call_target(); std::unique_ptr thunk; if (target == kCudnnConvForwardCallTarget) { @@ -442,7 +441,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else if (target == kCudnnConvBackwardInputCallTarget) { thunk = MakeUnique( CudnnConvKind::kBackwardInput, @@ -455,7 +455,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else if (target == kCudnnConvBackwardFilterCallTarget) { thunk = MakeUnique( CudnnConvKind::kBackwardFilter, @@ -468,7 +469,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else { LOG(FATAL) << "Unexpected custom call target: " << custom_call->custom_call_target(); @@ -496,14 +498,32 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // initializes the output array to the initial value of the reduce. if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) { switch (root->opcode()) { + case HloOpcode::kTuple: case HloOpcode::kReduce: { VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); std::vector> thunks; - thunks.emplace_back(BuildKernelThunk(fusion)); - TF_RETURN_IF_ERROR(EmitInitializer( - fusion, static_cast(thunks.back().get()))); - bindings_.UnbindAllLocalIrValues(); - thunks.emplace_back(BuildKernelThunk(fusion)); + ArraySlice output_instructions = + root->opcode() == HloOpcode::kTuple + ? root->operands() + : ArraySlice(&root, 1); + + // For multi-output fusion emit an initializer for each tuple element. + // Otherwise it's sufficient to just initialize the single output. + HloInstruction* first_reduce = nullptr; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + if (output_instructions[i]->opcode() == HloOpcode::kReduce) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr initializer_thunk, + BuildInitializerThunk(fusion, output_instructions[i] == root + ? ShapeIndex() + : ShapeIndex({i}))); + thunks.push_back(std::move(initializer_thunk)); + first_reduce = + first_reduce == nullptr ? output_instructions[i] : first_reduce; + } + } + CHECK(first_reduce != nullptr); + thunks.push_back(BuildKernelThunk(fusion)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), fusion)); std::vector parameter_arrays; @@ -516,44 +536,50 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); - Shape input_shape = root->operand(0)->shape(); - // EmitReductionToVector requires the input shape to have a layout, but - // fused instructions don't have one. So we determine its layout from - // the fusion's operands. The choice of the layout only affects - // performance but not correctness. - auto choose_input_layout = []( - tensorflow::gtl::ArraySlice operands, - Shape* input_shape) -> Status { - // Prefer the layout of an operand whose shape is compatible with - // input_shape. - for (const HloInstruction* operand : operands) { - if (ShapeUtil::Compatible(*input_shape, operand->shape())) { - return LayoutUtil::CopyLayoutBetweenShapes(operand->shape(), - input_shape); - } + // For multi-output fusion CHECK the constraints and feed all the + // reduces into a single loop code generator. Single-output reduce + // fusion is a special case of that. + InlinedVector input_gens; + InlinedVector init_value_gens; + std::vector> + extra_output_gens; + InlinedVector reducers; + InlinedVector reduce_output_shapes; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + const HloInstruction* inst = output_instructions[i]; + ShapeIndex output_shape_index; + if (root->opcode() == HloOpcode::kTuple) { + output_shape_index = {i}; } - // If no operand has a compatible shape, prefer an operand that has - // the same rank at least. - for (const HloInstruction* operand : operands) { - if (ShapeUtil::Rank(*input_shape) == - ShapeUtil::Rank(operand->shape())) { - // Do not use CopyLayoutBetweenShapes because input_shape and - // operand->shape() may be incompatible. - *input_shape->mutable_layout() = operand->shape().layout(); - return Status::OK(); - } + // TODO(kramerb): CHECK that layouts are equal. Currently this + // breaks multioutputfusion_test. The test has pre-fused + // instructions, but layout_assignment will not assign any layouts + // for instructions inside of a fused computation. It just removes + // the layouts instead. + if (inst->opcode() == HloOpcode::kReduce) { + CHECK(ShapeUtil::Compatible(first_reduce->shape(), inst->shape())); + CHECK(ShapeUtil::Compatible(first_reduce->operand(0)->shape(), + inst->operand(0)->shape())); + CHECK(ShapeUtil::Compatible(first_reduce->operand(1)->shape(), + inst->operand(1)->shape())); + CHECK(first_reduce->dimensions() == inst->dimensions()); + input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); + init_value_gens.push_back( + fused_emitter.GetGenerator(inst->operand(1))); + reducers.push_back(inst->to_apply()); + reduce_output_shapes.push_back(std::move(output_shape_index)); + } else { + CHECK(ShapeUtil::Compatible(first_reduce->operand(0)->shape(), + inst->shape())); + extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), + std::move(output_shape_index)); } - // When all the above fails, which is rare, set the default layout. - LayoutUtil::SetToDefaultLayout(input_shape); - return Status::OK(); - }; - TF_RETURN_IF_ERROR( - choose_input_layout(fusion->operands(), &input_shape)); - - return EmitReductionToVector( - root, input_shape, fused_emitter.GetGenerator(root->operand(0)), - fused_emitter.GetGenerator(root->operand(1)), root->dimensions(), - root->to_apply()); + } + const Shape& input_shape = first_reduce->operand(0)->shape(); + return EmitReductionToVector(first_reduce, input_shape, input_gens, + init_value_gens, + first_reduce->dimensions(), reducers, + reduce_output_shapes, extra_output_gens); } default: LOG(FATAL) << "Bad opcode for input fusion: " @@ -598,7 +624,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { thunk_sequence_->emplace_back(BuildGemmThunk(fusion)); return Status::OK(); } - thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); + + CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop); + int unroll_factor = ComputeMaxUnrollFactor(fusion); + + thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor)); return IrEmitter::HandleFusion(fusion); } @@ -934,10 +964,33 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { return IrEmitter::HandleCopy(copy); } +Status IrEmitterUnnested::EmitExtraOutputsForReduce( + const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { + for (int i = 0; i != extra_output_gens.size(); ++i) { + const HloInstruction* output = reduce->parent()->FusionInstruction(); + llvm::Value* extra_output_address = + GetIrArray(*output, *output, extra_output_gens[i].second) + .EmitArrayElementAddress(index, &ir_builder_, + "extra_output_element_address"); + TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, + extra_output_gens[i].first(index)); + ir_builder_.CreateStore(extra_output_ir_value, extra_output_address); + } + return Status::OK(); +} + Status IrEmitterUnnested::EmitReductionToScalar( HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // Number of elements processed by a single thread. constexpr int64 kTileSize = 16; int64 num_elems = ShapeUtil::ElementsIn(input_shape); @@ -989,14 +1042,19 @@ Status IrEmitterUnnested::EmitReductionToScalar( // auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + const int num_reduces = reducers.size(); llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, - init_value_gen(llvm_ir::IrArray::Index({}))); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index({}))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); } llvm::Value* x_in_tiles = tile_index[0]; @@ -1029,11 +1087,16 @@ Status IrEmitterUnnested::EmitReductionToScalar( llvm_ir::IrArray::Index input_index( /*linear=*/x, input_shape, &ir_builder_); llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); - TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, input_gen(input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); - return (EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, input_address}, - partial_reduction_result_address)); + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return EmitExtraOutputsForReduce(reduce, input_index, extra_output_gens); }; // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's @@ -1068,20 +1131,24 @@ Status IrEmitterUnnested::EmitReductionToScalar( : element_ir_type; for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; shuffle_distance /= 2) { - llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(partial_reduction_result_address, - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( element_ir_type, nullptr, "result_from_other_lane"); - ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), &ir_builder_), - ir_builder_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, result_from_other_lane}, - partial_reduction_result_address)); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( + ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); + ir_builder_.CreateStore( + EmitShuffleDown(partial_reduction_result, + ir_builder_.getInt32(shuffle_distance), + &ir_builder_), + ir_builder_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], result_from_other_lane}, + partial_reduction_result_addresses[i])); + } } const HloInstruction* output = @@ -1097,14 +1164,21 @@ Status IrEmitterUnnested::EmitReductionToScalar( "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); - llvm::Value* output_address = - GetIrArray(*output, *output) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index(/*linear=*/ir_builder_.getInt64(0), - output->shape(), &ir_builder_), - &ir_builder_, "output_element_address"); - return EmitAtomicOperationForNestedComputation( - *reducer, output_address, partial_reduction_result_address); + + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index( + /*linear=*/ir_builder_.getInt64(0), + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &ir_builder_), + &ir_builder_, "output_element_address"); + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, partial_reduction_result_addresses[i])); + } + return Status::OK(); }; // Emit a parallel loop that iterates through all input tiles, one per thread. @@ -1124,8 +1198,13 @@ Status IrEmitterUnnested::EmitReductionToScalar( Status IrEmitterUnnested::EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // Divide the input matrix into tiles of size Kx1. For example, when the // input matrix is 4x4 and K=2, the tiled matrix looks like // @@ -1135,9 +1214,13 @@ Status IrEmitterUnnested::EmitColumnReduction( // 4567 // Numbers indicate tile IDs. // // Each tile is first partially reduced to a scalar by a thread, and then the - // scalar is accumulated to the output vector using atomic operations. We - // choose 16 as the tile size, which matches Eigen's ColumnReduceKernel. - constexpr int64 kTileSize = 16; + // scalar is accumulated to the output vector using atomic operations. + // + // We choose 128 as the tile size based on empirical evidence. It's big enough + // to reduce the amount of atomic adds in the end, maximizing the memory + // bandwidth. + constexpr int64 kTileSize = 128; + // If the height is not a multiple of the tile size, we pad the bottom of the // input matrix. const int64 height_in_tiles = CeilOfRatio(height, kTileSize); @@ -1167,15 +1250,20 @@ Status IrEmitterUnnested::EmitColumnReduction( // } auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + const int num_reduces = reducers.size(); // Emit the loop body that reduces one tile. llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, - init_value_gen(llvm_ir::IrArray::Index({}))); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index({}))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); } // Emit an inner for-loop that partially reduces the elements in the given @@ -1233,13 +1321,18 @@ Status IrEmitterUnnested::EmitColumnReduction( .SourceIndexOfTranspose(normalized_input_shape, input_shape, transpose_dimension_mapping, &ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, - input_gen(input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return EmitExtraOutputsForReduce(reduce, input_index, + extra_output_gens); } - return (EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, input_address}, - partial_reduction_result_address)); }; // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's @@ -1268,13 +1361,20 @@ Status IrEmitterUnnested::EmitColumnReduction( &ir_builder_); const HloInstruction* output = reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - llvm::Value* output_address = - GetIrArray(*output, *output) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), - &ir_builder_, "output_element_address"); - return EmitAtomicOperationForNestedComputation( - *reducer, output_address, partial_reduction_result_address); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index( + x, + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &ir_builder_), + &ir_builder_, "output_element_address"); + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, partial_reduction_result_addresses[i])); + } + return Status::OK(); }; // Emit a parallel loop that iterate through all input tiles. @@ -1292,12 +1392,42 @@ Status IrEmitterUnnested::EmitColumnReduction( .EmitLoop(IrName(reduce)); } +static std::pair ComputeTilingSchemeForReduction( + int64 depth, int64 width, int64 kWarpSize) { + constexpr int64 kTargetNumElementsPerThread = 64; + int64 x_tile_size = kTargetNumElementsPerThread; + int64 z_tile_size = 1; + + // Only tile along the x dimension with tile size kTargetNumElementsPerThread + // if doing so doesn't require a slow version of loop with bound check on each + // dimension. A more sophisticated heuristics is to enable tile along the + // x dimension with tile size kTargetNumElementsPerThread when either width is + // a factor of (kWarpSize * kTargetNumElementsPerThread) or width is big + // enough so that only a small fraction of the threads execute the slow + // version of loop with bound check. + if (width % (kWarpSize * kTargetNumElementsPerThread) != 0) { + x_tile_size = 8; + z_tile_size = 8; + while (depth % z_tile_size != 0) { + z_tile_size -= 1; + } + } + + return std::pair(x_tile_size, z_tile_size); +} + Status IrEmitterUnnested::EmitRowReduction( int64 depth, int64 height, int64 width, HloInstruction* reduce, - const Shape& input_shape, const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // A naive algorithm is: - // 1. Divide the input tensor into tiles of size 1x1xK. + // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX. // 2. Partially reduces each tile to a scalar using one thread. // 3. Accumulates that scalar to the output vector using atomic operations. // @@ -1308,15 +1438,15 @@ Status IrEmitterUnnested::EmitRowReduction( // int y = linear_index / width_in_tiles % height; // int z = linear_index / (height * width_in_tiles); // float partial_result = 0; - // for (element_id_in_tile : range(kTileSize)) { - // int x = x_in_tiles * kTileSize + element_id_in_tile; + // for (element_id_in_tile : range(x_tile_size)) { + // int x = x_in_tiles * x_tile_size + element_id_in_tile; // if (x < width) // partial_result = reducer(partial_result, input[z][y][z]); // } // AtomicReducer(&output[y], partial_result); // } // - // Three optimizations are performed. + // Four optimizations are performed. // // 1. To coalesce global memory accesses, dilate the tile with a factor of 32 // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead @@ -1343,29 +1473,44 @@ Status IrEmitterUnnested::EmitRowReduction( // element_id_in_tile, which makes the code more friendly to optimizations // such as LICM. // + // 4. When the width is too small and x_tile_size is less than the target + // number of elements per thread and use a small factor of depth as + // z_tile_size to increase the number of elements calculated by each + // partial sum. This can reduce the needed number of dynamic shfl_down and + // atomic operations. + // // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; // linear_index < depth * height * width_in_tiles; // linear_index += blockDim.x * gridDim.x) { // int x_in_tiles = linear_index % width_in_tiles; // int y = linear_index / width_in_tiles % height; - // int z = linear_index / (height * width_in_tiles); + // int z_in_tiles = linear_index / (height * width_in_tiles); // int warp_id = x_in_tiles / warpSize; // int lane_id = x_in_tiles % warpSize; // float partial_result = 0; // int x = warp_id * kTileSize * warpSize + lane_id; - // if (width % (kTileSize * warpSize) == 0 || - // x + (kTileSize - 1) * warpSize < width) { - // // The entire tile is in bounds. - // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize; - // ++element_id_in_tile, x += warpSize) { - // partial_result = Reducer(partial_result, input[z][y][x]); + // if (width % (x_tile_size * warpSize) == 0 || + // x + (x_tile_size - 1) * warpSize < width) { + // // The entire x_tile is in bounds. + // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size; + // ++element_id_in_z_tile) { + // z = z_in_tiles * z_tile_size + element_id_in_z_tile; + // for (int element_id_in_x_tile = 0;element_id_in_x_tile < x_tile_size; + // ++element_id_in_x_tile, x += warpSize) { + // partial_result = Reducer(partial_result, input[z][y][x]); + // } // } // } else { // // The tile is partially in bounds. - // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize; + // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size; + // ++element_id_in_z_tile) { + // z = z_in_tiles * z_tile_size + element_id_in_z_tile; + // for (int element_id_in_x_tile = 0; element_id_in_x_tile < + // x_tile_size; // ++element_id_in_tile, x += warpSize) { - // if (x < width) - // partial_result = Reducer(partial_result, input[z][y][x]); + // if (x < width) + // partial_result = Reducer(partial_result, input[z][y][x]); + // } // } // } // for (shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2) @@ -1376,29 +1521,35 @@ Status IrEmitterUnnested::EmitRowReduction( // AtomicReducer(&output[y], partial_result); // } // - // Choose 8 as the tile size, which matches Eigen's RowReduceKernel. - constexpr int64 kTileSize = 8; + + int64 x_tile_size; + int64 z_tile_size; + std::tie(x_tile_size, z_tile_size) = + ComputeTilingSchemeForReduction(depth, width, kWarpSize); + // Round the width in tiles up to the nearest multiple of kWarpSize, so that // the use of shfl_down is valid. const int64 width_in_tiles = - RoundUpToNearest(CeilOfRatio(width, kTileSize), kWarpSize); + RoundUpToNearest(CeilOfRatio(width, x_tile_size), kWarpSize); - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& tile_index) -> Status { - // Emit the loop body that reduces one tile. + auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) { + // Emit the loop body that reduces one z-x-tile. + const int num_reduces = reducers.size(); llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( input_shape.element_type(), ir_emitter_context_->llvm_module()); - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, - init_value_gen(llvm_ir::IrArray::Index({}))); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index({}))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); } - // Emit an inner for-loop that partially reduces the elements in the given - // tile. - llvm::Value* z = tile_index[0]; + llvm::Value* z_tile = tile_index[0]; llvm::Value* y = tile_index[1]; llvm::Value* x_tile = tile_index[2]; llvm::Value* warp_id = ir_builder_.CreateUDiv( @@ -1406,102 +1557,132 @@ Status IrEmitterUnnested::EmitRowReduction( llvm::Value* lane_id = ir_builder_.CreateURem( x_tile, ir_builder_.getInt64(kWarpSize), "lane_id"); - // The x-location of the last element in this tile. - // last_x = lane_id + warpSize * (kTileSize - 1 + warp_id * kTileSize); + // The x-location of the last element in this z-x-tile. + // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * + // x_tile_size); llvm::Value* last_x = ir_builder_.CreateNSWAdd( - lane_id, - ir_builder_.CreateNSWMul( - ir_builder_.getInt64(kWarpSize), - ir_builder_.CreateNSWAdd( - ir_builder_.getInt64(kTileSize - 1), - ir_builder_.CreateNSWMul(warp_id, - ir_builder_.getInt64(kTileSize))))); - - auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { - std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", - ir_builder_.getInt64(0), - ir_builder_.getInt64(kTileSize), - ir_builder_.getInt64(1), &ir_builder_); - - // Emit the body of the partial reduction loop. - llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &ir_builder_); - // x = lane_id + warpSize * (element_id_in_tile + warp_id * kTileSize); - llvm::Value* x = ir_builder_.CreateNSWAdd( - lane_id, - ir_builder_.CreateNSWMul( - ir_builder_.getInt64(kWarpSize), - ir_builder_.CreateNSWAdd( - tile_element_loop->GetIndVarValue(), - ir_builder_.CreateNSWMul(warp_id, - ir_builder_.getInt64(kTileSize))))); - - // Unless we know the tile is entirely in bounds, we have to emit a - // x-in-bounds check before reading from the input. - if (!tile_in_bounds) { - llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(width)), - "x_in_bounds", &ir_builder_); - - // Points ir_builder_ to the then-block. - llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, - &ir_builder_); - } + lane_id, ir_builder_.CreateNSWMul( + ir_builder_.getInt64(kWarpSize), + ir_builder_.CreateNSWAdd( + ir_builder_.getInt64(x_tile_size - 1), + ir_builder_.CreateNSWMul( + warp_id, ir_builder_.getInt64(x_tile_size))))); + + KernelSupportLibrary ksl( + &ir_builder_, + /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll, + /*prevent_vectorization=*/false); + + // Emit a for-loop that partially reduces the elements in the given + // z-x-tile. + auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, + int64 x_tile_loop_bound) -> Status { + auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { + llvm::Value* z = ir_builder_.CreateNSWAdd( + z_indvar, ir_builder_.CreateNSWMul( + ir_builder_.getInt64(z_tile_size), z_tile)); + + TF_RETURN_IF_ERROR(ksl.For( + "x_tile", + /*start=*/0, /*end=*/x_tile_loop_bound, /*step=*/1, + [&](llvm::Value* x_indvar) -> Status { + // x = lane_id + warpSize * (element_id_in_x_tile + warp_id * + // x_tile_size); + llvm::Value* x = ir_builder_.CreateNSWAdd( + lane_id, + ir_builder_.CreateNSWMul( + ir_builder_.getInt64(kWarpSize), + ir_builder_.CreateNSWAdd( + x_indvar, + ir_builder_.CreateNSWMul( + warp_id, ir_builder_.getInt64(x_tile_size))))); + + // Unless we know the x-tile is entirely in bounds, we have to + // emit a x-in-bounds check before reading from the input. + if (!x_tile_in_bounds) { + llvm_ir::LlvmIfData if_x_in_bounds_data = + llvm_ir::EmitIfThenElse(ir_builder_.CreateICmpULT( + x, ir_builder_.getInt64(width)), + "x_in_bounds", &ir_builder_); + // Points ir_builder_ to the then-block. + llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, + &ir_builder_); + } + + // Emit code that reads the input element and accumulates it + // to the partial reduction result. + llvm::Value* input_address = + ir_builder_.CreateAlloca(element_ir_type); + { + // {z,y,x} is an index to input_3d_tensor_shape + // [depth,height,width]. We need to convert that to an index + // to input_shape (the shape of the operand of "reduce"). + // This conversion is composed of a transposition from + // input_shape to normalized_input_shape and a reshape from + // normalized_input_shape to input_3d_tensor_shape. + const Shape normalized_input_shape = ShapeUtil:: + MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input_shape); + auto input_shape_min2maj = + LayoutUtil::MinorToMajor(input_shape); + const std::vector transpose_dimension_mapping( + input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); + const Shape input_3d_tensor_shape = + ShapeUtil::MakeShapeWithDescendingLayout( + input_shape.element_type(), {depth, height, width}); + const llvm_ir::IrArray::Index input_3d_tensor_index( + {z, y, x}, input_3d_tensor_shape, &ir_builder_); + const llvm_ir::IrArray::Index input_index = + input_3d_tensor_index + .SourceIndexOfReshape(input_3d_tensor_shape, + normalized_input_shape, + &ir_builder_) + .SourceIndexOfTranspose( + normalized_input_shape, input_shape, + transpose_dimension_mapping, &ir_builder_); + + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return EmitExtraOutputsForReduce(reduce, input_index, + extra_output_gens); + } + })); + return Status::OK(); + }; - // Emit code that reads the input element and accumulates it to the - // partial reduction result. - llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); - { - // {z,y,x} is an index to input_3d_tensor_shape [depth,height,width]. We - // need to convert that to an index to input_shape (the shape of the - // operand of "reduce"). This conversion is composed of a transposition - // from input_shape to normalized_input_shape and a reshape from - // normalized_input_shape to input_3d_tensor_shape. - const Shape normalized_input_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input_shape); - auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); - const std::vector transpose_dimension_mapping( - input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); - const Shape input_3d_tensor_shape = - ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), - {depth, height, width}); - const llvm_ir::IrArray::Index input_3d_tensor_index( - {z, y, x}, input_3d_tensor_shape, &ir_builder_); - const llvm_ir::IrArray::Index input_index = - input_3d_tensor_index - .SourceIndexOfReshape(input_3d_tensor_shape, - normalized_input_shape, &ir_builder_) - .SourceIndexOfTranspose(normalized_input_shape, input_shape, - transpose_dimension_mapping, - &ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, - input_gen(input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); - } - return EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, input_address}, - partial_reduction_result_address); + return ksl.For("z_tile", + /*start=*/0, /*end=*/z_tile_size, /*step=*/1, + emit_z_tile_element_loop); }; llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.getInt1(width % (kTileSize * kWarpSize) == 0), + ir_builder_.getInt1(width % (x_tile_size * kWarpSize) == 0), ir_builder_.CreateICmpULT(last_x, ir_builder_.getInt64(width))); - llvm_ir::LlvmIfData if_tile_in_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); - // After the if-then-else statement on tile_in_bounds, emit calls to - // shfl_down that accumulate the partial reduction results of all threads - // from the warp. - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, - &ir_builder_); + TF_RETURN_IF_ERROR( + ksl.If(tile_in_bounds, + /*true_block_generator=*/ + [&]() -> Status { + return emit_z_x_tile_element_loop(/*x_tile_in_bounds=*/true, + x_tile_size); + }, + /*false_block_generator=*/ + [&]() -> Status { + return emit_z_x_tile_element_loop( + /*x_tile_in_bounds=*/false, + CeilOfRatio(width % (x_tile_size * kWarpSize), kWarpSize)); + })); + + // After accumulating the elements of the z_x_tile, emit calls to + // shfl_down that accumulate the partial reduction results of all + // threads in a warp. int bit_width = llvm_ir::GetSizeInBits(element_ir_type); // bitcast cannot be applied to aggregate types (even packed ones), so we // instead bitcast addresses of load/store to intN* of the same bit-width. @@ -1510,20 +1691,24 @@ Status IrEmitterUnnested::EmitRowReduction( : element_ir_type; for (int shuffle_distance = 16; shuffle_distance >= 1; shuffle_distance /= 2) { - llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(partial_reduction_result_address, - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( element_ir_type, nullptr, "result_from_other_lane"); - ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), &ir_builder_), - ir_builder_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, result_from_other_lane}, - partial_reduction_result_address)); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( + ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); + ir_builder_.CreateStore( + EmitShuffleDown(partial_reduction_result, + ir_builder_.getInt32(shuffle_distance), + &ir_builder_), + ir_builder_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], result_from_other_lane}, + partial_reduction_result_addresses[i])); + } } const HloInstruction* output = @@ -1537,19 +1722,34 @@ Status IrEmitterUnnested::EmitRowReduction( "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); - llvm::Value* output_address = - GetIrArray(*output, *output) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), - &ir_builder_, "output_element_address"); - return EmitAtomicOperationForNestedComputation( - *reducer, output_address, partial_reduction_result_address); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index( + y, + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &ir_builder_), + &ir_builder_, "output_element_address"); + if (x_tile_size * z_tile_size < depth * width) { + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + partial_reduction_result_addresses[i])); + } else { + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {output_address, partial_reduction_result_addresses[i]}, + output_address)); + } + } + return Status::OK(); }; // Emit a parallel loop that iterates through every input tiles. Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {depth, height, width_in_tiles}, - {2, 1, 0}); + reduce->shape().element_type(), + {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( tiled_input_shape, ir_emitter_context_->device_description()); CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); @@ -1570,10 +1770,14 @@ Status IrEmitterUnnested::EmitRowReduction( // elementwise. Status IrEmitterUnnested::EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reducer) { + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // This emission requires "reduce" to have an input layout. It is either set // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for // a fused kReduce). @@ -1608,8 +1812,9 @@ Status IrEmitterUnnested::EmitReductionToVector( // `EmitReductionToVector`, we only need to check whether the minormost // dimension of the input is to keep. if (input_dims_to_keep.empty()) { - return EmitReductionToScalar(reduce, input_shape, input_gen, init_value_gen, - reducer); + return EmitReductionToScalar(reduce, input_shape, input_gens, + init_value_gens, reducers, + reduce_output_shapes, extra_output_gens); } else if (input_dims_to_keep.front() == LayoutUtil::Minor(input_shape.layout(), 0)) { // Column reduction. Treat the result of "input" as a matrix whose width @@ -1626,8 +1831,9 @@ Status IrEmitterUnnested::EmitReductionToVector( height *= input_shape.dimensions(input_dim); } } - return EmitColumnReduction(height, width, reduce, input_shape, input_gen, - init_value_gen, reducer); + return EmitColumnReduction(height, width, reduce, input_shape, input_gens, + init_value_gens, reducers, reduce_output_shapes, + extra_output_gens); } else { // Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a // 3D tensor. The size of dimension 1 (the height) is the size of the @@ -1653,7 +1859,8 @@ Status IrEmitterUnnested::EmitReductionToVector( } const int64 height = ShapeUtil::ElementsIn(reduce->shape()); return EmitRowReduction(depth, height, width, reduce, input_shape, - input_gen, init_value_gen, reducer); + input_gens, init_value_gens, reducers, + reduce_output_shapes, extra_output_gens); } } @@ -1668,25 +1875,24 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { if (IsReductionToVector(*reduce) && // NVPTX backend can't do atomic cmpxchg any narrower than 32 bits 32 <= primitive_util::BitWidth(reduce->shape().element_type())) { + TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, + BuildInitializerThunk(reduce)); std::vector> thunks; - thunks.emplace_back(BuildKernelThunk(reduce)); - TF_RETURN_IF_ERROR(EmitInitializer( - reduce, static_cast(thunks.back().get()))); - bindings_.UnbindAllLocalIrValues(); - thunks.emplace_back(BuildKernelThunk(reduce)); + thunks.push_back(std::move(initializer_thunk)); + thunks.push_back(BuildKernelThunk(reduce)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), reduce)); + return EmitReductionToVector( - reduce, input->shape(), - [&](const llvm_ir::IrArray::Index& index) { + reduce, input->shape(), {[&](const llvm_ir::IrArray::Index& index) { return GetIrArray(*input, *reduce) .EmitReadArrayElement(index, &ir_builder_); - }, - [&](const llvm_ir::IrArray::Index& index) { + }}, + {[&](const llvm_ir::IrArray::Index& index) { return GetIrArray(*init_value, *reduce) .EmitReadArrayElement(index, &ir_builder_); - }, - dimensions_to_reduce, reducer); + }}, + dimensions_to_reduce, {reducer}, {{}}, {}); } thunk_sequence_->emplace_back(BuildKernelThunk(reduce)); @@ -1739,16 +1945,13 @@ Status IrEmitterUnnested::HandleSelectAndScatter( CHECK_EQ(rank, ShapeUtil::Rank(source->shape())); CHECK_EQ(rank, window.dimensions_size()); - { - std::vector> thunks; - thunks.emplace_back(BuildKernelThunk(select_and_scatter)); - TF_RETURN_IF_ERROR(EmitInitializer( - select_and_scatter, static_cast(thunks.back().get()))); - bindings_.UnbindAllLocalIrValues(); - thunks.emplace_back(BuildKernelThunk(select_and_scatter)); - thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), select_and_scatter)); - } + TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, + BuildInitializerThunk(select_and_scatter)); + std::vector> thunks; + thunks.push_back(std::move(initializer_thunk)); + thunks.push_back(BuildKernelThunk(select_and_scatter)); + thunk_sequence_->emplace_back( + MakeUnique(std::move(thunks), select_and_scatter)); // TODO(b/31410564): Implement dilation rate for select-and-scatter. if (window_util::HasDilation(window)) { @@ -1957,6 +2160,52 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { return IrEmitter::HandleSelect(select); } +Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { + if (hlo_module_config_.replica_count() != 1) { + // TODO(b/33011107): Support nontrivial cross replica sum on GPU. + return Unimplemented( + "CrossReplicaSum with >1 replica is not implemented on GPU."); + } + + // CRS with one operand and one replica is simply the identity function. + // Buffer assignment expects a copy, so that's what we do. + // + // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely + // in algebraic-simplifier, but currently on some platforms + // HloModuleConfig::num_replicas changes between when the module is compiled + // and when it's run. + if (crs->operand_count() == 1) { + CHECK(ShapeUtil::IsArray(crs->operand(0)->shape())) + << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); + thunk_sequence_->push_back(MakeUnique( + /*source_address=*/GetAllocationSlice(*crs->operand(0)), + /*destination_buffer=*/GetAllocationSlice(*crs), + /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs)); + return Status::OK(); + } + + // One-replica CRS with multiple operands produces a tuple of the inputs. + // Again, buffer assignment expects us to copy each. + std::vector> thunks; + std::vector tuple_element_buffers; + for (int64 i = 0; i < crs->operand_count(); ++i) { + tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment() + .GetUniqueSlice(crs, {i}) + .ValueOrDie()); + thunks.push_back(MakeUnique( + /*source_address=*/GetAllocationSlice(*crs->operand(i)), + /*destination_buffer=*/tuple_element_buffers.back(), + /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), crs)); + } + + // Output a tuple of the buffers above. + thunks.push_back(MakeUnique(tuple_element_buffers, + GetAllocationSlice(*crs), crs)); + thunk_sequence_->push_back( + MakeUnique(std::move(thunks), crs)); + return Status::OK(); +} + Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { thunk_sequence_->emplace_back(BuildInfeedThunk(infeed)); return Status::OK(); @@ -1993,38 +2242,54 @@ GetHloBufferSlices(const HloInstruction* hlo, -> optional> { // Simple, common case: Is the buffer for instr known at runtime? If so, // we're done. - auto slice = GetKnownAtRuntimeSlice(instr, index, buffer_assn); - if (slice.has_value()) { - return {{*slice, ShapeIndex()}}; + auto slice = buffer_assn.GetUniqueSlice(instr, index); + if (slice.ok()) { + return {{slice.ValueOrDie(), ShapeIndex()}}; } - // If we don't know the buffer for instr at index, see if we know the buffer - // for instr at index without its last element. If so, we can dynamically - // find the buffer for instr by dereferencing a pointer in that buffer. - // Continue looking this way until we run out of elements in 'index'. - ShapeIndex new_index = index; - ShapeIndex gte_indices; - while (!new_index.empty()) { - gte_indices.push_front(new_index.back()); - new_index.pop_back(); - auto slice = GetKnownAtRuntimeSlice(instr, new_index, buffer_assn); - if (slice.has_value()) { - return {{*slice, gte_indices}}; + // If that didn't work, walk up any bitcasts that we might see. These must + // appear before any GTE instructions, because it's illegal to bitcast to a + // tuple type. + const HloInstruction* parent = instr; + while (parent->opcode() == HloOpcode::kBitcast) { + parent = parent->operand(0); + + auto slice = buffer_assn.GetUniqueSlice(parent, {}); + if (slice.ok()) { + return {{slice.ValueOrDie(), ShapeIndex()}}; } } - // If *that* didn't work, check whether instr is a GTE instruction. If it - // is, see if we can get a buffer for its parent, and continue walking up - // parents until we find a defined buffer or we hit something that's not a - // GTE. - const HloInstruction* parent = instr; + // Check whether instr is a GTE instruction. If it is, see if we can get a + // buffer for its parent, and continue walking up parents until we find a + // defined buffer or we hit something that's not a GTE. + ShapeIndex gte_indices; while (parent->opcode() == HloOpcode::kGetTupleElement) { gte_indices.push_front(parent->tuple_index()); parent = parent->operand(0); - auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn); - if (slice.has_value()) { - return {{*slice, gte_indices}}; + auto slice = buffer_assn.GetUniqueSlice(parent, {}); + if (slice.ok()) { + return {{slice.ValueOrDie(), gte_indices}}; + } + } + + // Finally, if we don't know the buffer for instr at index, see if we know + // the buffer for instr at index without its last element. If so, we can + // dynamically find the buffer for instr by dereferencing a pointer in that + // buffer. Continue looking this way until we run out of elements in + // 'index'. + // + // We can almost always get a buffer without resorting to this. The only + // exception is for cases where the relevant sub-buffer is truly unknowable, + // for example the sub-buffer of a tuple-shaped select. + ShapeIndex new_index = index; + while (!new_index.empty()) { + gte_indices.push_front(new_index.back()); + new_index.pop_back(); + auto slice = buffer_assn.GetUniqueSlice(instr, new_index); + if (slice.ok()) { + return {{slice.ValueOrDie(), gte_indices}}; } } @@ -2069,8 +2334,8 @@ Status IrEmitterUnnested::HandleGather(HloInstruction* gather) { return Unimplemented("Gather is not implemented on GPUs."); } -std::unique_ptr IrEmitterUnnested::BuildKernelThunk( - const HloInstruction* inst) { +std::unique_ptr IrEmitterUnnested::BuildKernelThunk( + const HloInstruction* inst, int unroll_factor) { const BufferAssignment& buffer_assn = ir_emitter_context_->buffer_assignment(); @@ -2162,7 +2427,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( } return MakeUnique(buffers, llvm_ir::AsString(kernel->getName()), - inst); + inst, unroll_factor); } std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( @@ -2207,6 +2472,21 @@ std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( /*destination_buffer=*/GetAllocationSlice(*inst), inst); } +namespace { +double GetScalarConstantAsDouble(const Literal& literal) { + switch (literal.shape().element_type()) { + case F16: + return static_cast(literal.Get({})); + case F32: + return literal.Get({}); + case F64: + return literal.Get({}); + default: + LOG(FATAL) << "Unsupported type."; + } +} +} // namespace + std::unique_ptr IrEmitterUnnested::BuildGemmThunk( const HloInstruction* inst) { if (inst->opcode() == HloOpcode::kDot) { @@ -2219,13 +2499,31 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( lhs->shape(), // The shape of LHS. rhs->shape(), // The shape of RHS. inst->shape(), // The shape of the output. - false, // Do not transpose LHS. - false, // Do not transpose RHS. + 1.0, // alpha. inst); } if (inst->opcode() == HloOpcode::kFusion) { - const HloInstruction* dot = inst->fused_expression_root(); + CHECK_EQ(inst->fusion_kind(), HloInstruction::FusionKind::kOutput); + const HloInstruction* mul = inst->fused_expression_root(); + const HloInstruction* dot = mul->operand(0); + const HloInstruction* alpha = mul->operand(1); + if (dot->opcode() != HloOpcode::kDot) { + std::swap(dot, alpha); + } + if (alpha->opcode() == HloOpcode::kBroadcast) { + alpha = alpha->operand(0); + } + if (alpha->opcode() == HloOpcode::kParameter) { + alpha = inst->operand(alpha->parameter_number()); + } + // TODO(b/74185543): Remove the following if block once we support fusion + // with a non-constant as well. Then we will just always use the constant + // on the device. + if (alpha->opcode() == HloOpcode::kCopy) { + alpha = alpha->operand(0); + } + DCHECK(dot->opcode() == HloOpcode::kDot); const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); @@ -2237,14 +2535,13 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( inst->operand(rhs_parameter->parameter_number()); return MakeUnique( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*inst), // The output buffer. - lhs->shape(), // The shape of LHS. - rhs->shape(), // The shape of RHS. - inst->shape(), // The shape of the output. - dot->operand(0)->IsRank2Transpose(), // Transpose LHS. - dot->operand(1)->IsRank2Transpose(), // Trasnpose RHS. + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. + lhs->shape(), // The shape of LHS. + rhs->shape(), // The shape of RHS. + inst->shape(), // The shape of the output. + GetScalarConstantAsDouble(alpha->literal()), // alpha. inst); } @@ -2261,37 +2558,107 @@ std::unique_ptr IrEmitterUnnested::BuildFftThunk( /*output_shape=*/inst->shape(), inst); } -Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo, - KernelThunk* thunk) { +StatusOr> IrEmitterUnnested::BuildInitializerThunk( + const HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); - const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; - CHECK(inst->opcode() == HloOpcode::kSelectAndScatter || - inst->opcode() == HloOpcode::kReduce); - const HloInstruction* init_value = nullptr; - switch (inst->opcode()) { - case HloOpcode::kSelectAndScatter: - init_value = inst->operand(2); - break; - case HloOpcode::kReduce: - init_value = inst->operand(1); - break; - default: - LOG(FATAL) << "Opcode " << inst->opcode() - << " should not need an initializer."; - } + const HloInstruction* init_value_operand = [&] { + switch (inst->opcode()) { + case HloOpcode::kSelectAndScatter: + return inst->operand(2); + case HloOpcode::kReduce: + return inst->operand(1); + case HloOpcode::kTuple: + CHECK(hlo->IsMultiOutputFusion()) + << ": " << hlo->ToString() << " is not a multi-output fusion."; + CHECK(inst->operand(index.back())->opcode() == HloOpcode::kReduce) + << ": Found '" << inst->operand(index.back())->opcode() << "' in " + << inst->ToString() << " but expected 'reduce'."; + // For multi-output fusion look through the tuple. + return inst->operand(index.back())->operand(1); + default: + LOG(FATAL) << "Opcode " << inst->opcode() + << " should not need an initializer."; + } + }(); + const HloInstruction* init_value = init_value_operand; if (fused && init_value->opcode() == HloOpcode::kParameter) { init_value = hlo->operand(init_value->parameter_number()); } - return EmitTargetElementLoopInThunk( - *hlo, - [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*init_value, *hlo) - .EmitReadArrayElement(index, &ir_builder_); - }, - thunk); + // In the common case, the initializer is a constant. In this case, emit a + // device-memset call if we can. Currently StreamExecutor only supports + // zeroing and 32-bit memsets. + if (init_value->IsConstant()) { + CHECK(ShapeUtil::IsScalar(init_value->shape())); + int64 num_bytes = ShapeUtil::ByteSizeOfElements(init_value->shape()); + const auto& literal = init_value->literal(); + + // Are all the bytes of this scalar equal to 0? If so, we can create a + // MemzeroThunk. + ArraySlice literal_bytes( + reinterpret_cast(literal.untyped_data()), num_bytes); + if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { + return {MakeUnique(GetAllocationSlice(*hlo, index), hlo)}; + } + + // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by + // repeating the literal 4 or 2 times, so long as the destination buffer is + // an even multiple of 32 bits long. + if ((num_bytes == 1 || num_bytes == 2) && + ShapeUtil::ByteSizeOf(hlo->shape()) % 4 == 0) { + uint16 pattern16; + if (num_bytes == 1) { + uint8 b = literal_bytes.front(); + pattern16 = uint16{b} | (uint16{b} << 8); + } else { + pattern16 = literal_bytes.front(); + } + uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); + return {MakeUnique( + pattern32, GetAllocationSlice(*hlo, index), hlo)}; + } + + // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit + // memset so long as all 32-bit words of the scalar are equal to each other. + if (num_bytes >= 4 && num_bytes % 4 == 0 && + memcmp(literal_bytes.data(), literal_bytes.data() + 4, + literal_bytes.size() - 4) == 0) { + uint32 word; + memcpy(&word, literal_bytes.data(), sizeof(word)); + return {MakeUnique( + word, GetAllocationSlice(*hlo, index), hlo)}; + } + } + + // Otherwise fall back to our slow initializer code. + std::unique_ptr kernel_thunk = BuildKernelThunk(hlo); + LaunchDimensions launch_dimensions = + CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index), + ir_emitter_context_->device_description()); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), + ir_emitter_context_->llvm_module()); + // If the init_value was fused into this reduce we have to generate it first. + if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { + CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); + TF_RETURN_IF_ERROR(HandleConstant(const_cast(init_value))); + } + TF_RETURN_IF_ERROR(ParallelLoopEmitter( + [=](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*init_value, *hlo) + .EmitReadArrayElement(index, &ir_builder_); + }, + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &ir_builder_) + .EmitLoop(IrName(hlo))); + + // Clean up state left behind by emitting the loop above. (This is normally + // done in IrEmitterUnnested::Postprocess().) + bindings_.UnbindAllLocalIrValues(); + + // Convert unique_ptr to StatusOr>. + return {std::move(kernel_thunk)}; } namespace { @@ -2452,18 +2819,22 @@ std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( Status IrEmitterUnnested::EmitTargetElementLoopInThunk( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) { + int unroll_factor = thunk->unroll_factor(); VLOG(3) << bindings_.ToString(); const Shape& element_shape = hlo.IsMultiOutputFusion() ? ShapeUtil::GetSubshape(hlo.shape(), {0}) : hlo.shape(); + VLOG(3) << "EmitTargetElementLoopInThunk " + << ShapeUtil::HumanStringWithLayout(hlo.shape()) + << " for unroll_factor " << unroll_factor; LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - element_shape, ir_emitter_context_->device_description()); + element_shape, ir_emitter_context_->device_description(), unroll_factor); UpdateLaunchDimensions(launch_dimensions, thunk, ir_emitter_context_->llvm_module()); if (!hlo.IsMultiOutputFusion()) { return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo), - launch_dimensions, &ir_builder_) + launch_dimensions, &ir_builder_, unroll_factor) .EmitLoop(IrName(&hlo)); } @@ -2473,7 +2844,8 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( output_arrays.push_back(GetIrArray(hlo, hlo, {i})); } TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays, - launch_dimensions, &ir_builder_) + launch_dimensions, &ir_builder_, + unroll_factor) .EmitLoop(IrName(&hlo))); std::vector tuple_operand_ptrs; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index b83a2337e2decd9d4fba3d40fcf33f131fca8a3c..202231b82f3877c11cf932bd00a8aac350fd0afa 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -38,7 +38,7 @@ namespace gpu { // // Examples of things that are not unnested computations: // -// - The reducer of a kReduce HLO. This is emited using IrEmitterNested. +// - The reducer of a kReduce HLO. This is emitted using IrEmitterNested. // - The body of a fusion node. IrEmitterUnenested emits the relevant code // within a kernel function using FusedIrEmitter. (FusedIrEmitter is not // really an IrEmitter, but is more an "IR generator generator".) @@ -76,6 +76,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleInfeed(HloInstruction* xla_infeed) override; Status HandleRng(HloInstruction* random) override; Status HandleSelect(HloInstruction* select) override; + Status HandleCrossReplicaSum(HloInstruction* crs) override; Status EmitTargetElementLoop( const HloInstruction& hlo, @@ -99,6 +100,13 @@ class IrEmitterUnnested : public IrEmitter { const HloInstruction& inst, tensorflow::gtl::ArraySlice args); + // Helper for writing extra outputs from inside a reduce kernel. + Status EmitExtraOutputsForReduce( + const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); + // EmitColumnReduction and EmitRowReduction emit code for column and row // reduction of a matrix and/or 3D tensor. Row and column reduction have // different memory access pattern, so for performance their implementations @@ -109,28 +117,43 @@ class IrEmitterUnnested : public IrEmitter { // `EmitReductionToVector`. Note that input shape might not be // [height x width], but can be bitcast to [height x weight] with "height" // being the major dimension. - Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce, - const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); + Status EmitColumnReduction( + int64 height, int64 width, HloInstruction* reduce, + const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Emits code that reduces a 3D tensor of shape [depth x height x width] to a // vector of shape [height]. Other parameters have the same meaning as those // of `EmitReductionToVector`. Note that input shape might not be // [depth x height x width], but can be bitcast to [depth x height x weight] // with "depth" being the most major dimension. - Status EmitRowReduction(int64 depth, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); + Status EmitRowReduction( + int64 depth, int64 height, int64 width, HloInstruction* reduce, + const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Emits code that reduces a tensor of arbitrary rank to a scalar. - Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); + Status EmitReductionToScalar( + HloInstruction* reduce, const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Figures out whether `reduce` is a row or column reduction, and which // dimensions to reduce, and calls either `EmitRowReduction` or @@ -140,21 +163,31 @@ class IrEmitterUnnested : public IrEmitter { // generate elements of the input and the initial value. Other parameters mean // the same as for `HandleReduce`. // + // Multiple reduces can be emitted in the same loop, assuming they have the + // same input and output shapes, and the same reduce dimensions. + // + // extra_output_gens can contain extra generators for intermediate outputs. + // These must have the same shape as the reduce input as they are computed + // when the reduce inputs are being read. + // // Prerequisite: `IsReductionToVector(*reduce)` Status EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reducer); - - // Emits code to initialize buffer of `inst` in given `thunk`. - Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk); + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Returns a KernelThunk that invokes the kernel emitted for `inst`. The // caller needs to make sure `inst` outlives the lifetime of the returned - // Thunk object. - std::unique_ptr BuildKernelThunk(const HloInstruction* inst); + // Thunk object. The kernel implementation will be unrolled if unroll_factor + // is greater than one. + std::unique_ptr BuildKernelThunk(const HloInstruction* inst, + int unroll_factor = 1); // Returns a FftThunk that calls cuFFT to implement `inst`. std::unique_ptr BuildFftThunk(const HloInstruction* inst); @@ -163,6 +196,11 @@ class IrEmitterUnnested : public IrEmitter { // to make sure `inst` outlives the lifetime of the returned Thunk object. std::unique_ptr BuildGemmThunk(const HloInstruction* inst); + // Returns a thunk that, given a reduce or select-and-scatter op, initializes + // its memory to the appropriate initial value. + StatusOr> BuildInitializerThunk( + const HloInstruction* hlo, const ShapeIndex& index = {}); + // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index c20a781a33fe89af4740ed31dd5bfb1a64473057..f56c1ce69f11ed79c8be76834269f29de93a9645 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -23,38 +23,50 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { KernelThunk::KernelThunk( tensorflow::gtl::ArraySlice args, - const string& kernel_name, const HloInstruction* hlo_instruction) + const string& kernel_name, const HloInstruction* hlo_instruction, + int unroll_factor) : Thunk(Kind::kKernel, hlo_instruction), args_(args.begin(), args.end()), - kernel_name_(kernel_name) {} + kernel_name_(kernel_name), + unroll_factor_(unroll_factor) {} -tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) { +Status KernelThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { tensorflow::mutex_lock lock(mutex_); - if (loader_spec_) { - // Already initialized by another thread. - return tensorflow::Status::OK(); + if (!loader_spec_) { + loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); + tensorflow::StringPiece ptx = executable.ptx(); + // Convert tensorflow::StringPiece to se::port::StringPiece because + // StreamExecutor uses the latter. + loader_spec_->AddCudaPtxInMemory( + se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); + + if (!executable.cubin().empty()) { + loader_spec_->AddCudaCubinInMemory( + reinterpret_cast(executable.cubin().data()), + kernel_name_); + } } - loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); - tensorflow::StringPiece ptx = executable.ptx(); - // Convert tensorflow::StringPiece to se::port::StringPiece because - // StreamExecutor uses the latter. - loader_spec_->AddCudaPtxInMemory( - se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); - - if (!executable.cubin().empty()) { - loader_spec_->AddCudaCubinInMemory( - reinterpret_cast(executable.cubin().data()), kernel_name_); + // Load the kernel into the device if necessary. + // + // We could alternatively do this within ExecuteOnStream, but doing it here + // lets the time spent loading the kernel not count towards our execution + // profiles. + auto it = kernel_cache_.find(executor); + if (kernel_cache_.end() == it) { + it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; + if (!executor->GetKernel(*loader_spec_, &it->second)) { + return InternalError("Unable to load kernel %s", kernel_name_.c_str()); + } } - return tensorflow::Status::OK(); + return Status::OK(); } void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { @@ -62,21 +74,18 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { launch_dimensions_ = launch_dims; } -tensorflow::Status KernelThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { // Load the kernel. se::StreamExecutor* executor = stream->parent(); LaunchDimensions launch_dimensions; const se::KernelBase* kernel = nullptr; + { tensorflow::mutex_lock lock(mutex_); auto it = kernel_cache_.find(executor); - if (kernel_cache_.end() == it) { - it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; - if (!executor->GetKernel(*loader_spec_, &it->second)) { - return InternalError("Unable to load kernel %s", kernel_name_.c_str()); - } - } + CHECK(it != kernel_cache_.end()) + << "Initialize() not called for StreamExecutor " << executor; launch_dimensions = launch_dimensions_; kernel = &it->second; } @@ -97,7 +106,7 @@ tensorflow::Status KernelThunk::ExecuteOnStream( *kernel_args)) { return InternalError("Unable to launch kernel %s", kernel_name_.c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index 9ae455e2fcc253a7a08ff95764721048a16b0bf7..7def27e189b66747569344a3dbe5c0c446f903be 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -47,20 +47,22 @@ class KernelThunk : public Thunk { // // `hlo_instruction` is as in Thunk. Other arguments are as the class members. KernelThunk(tensorflow::gtl::ArraySlice args, - const string& kernel_name, const HloInstruction* hlo_instruction); + const string& kernel_name, const HloInstruction* hlo_instruction, + int unroll_factor); KernelThunk(const KernelThunk&) = delete; KernelThunk& operator=(const KernelThunk&) = delete; ~KernelThunk() override = default; const string& kernel_name() const { return kernel_name_; } + int unroll_factor() const { return unroll_factor_; } void SetLaunchDimensions(const LaunchDimensions& launch_dims); - tensorflow::Status Initialize(const GpuExecutable& executable) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; // Executes the kernel for the thunk on "stream", which must be non-null. - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: // Buffers passed to the kernel as arguments. @@ -69,6 +71,10 @@ class KernelThunk : public Thunk { // Entry kernel name for the computation. const string kernel_name_; + // The number of times this kernel should be unrolled. This works as a + // multiplier on the number of elements produced by a GPU thread. + const int unroll_factor_; + // The thread and block dimension used to launch the kernel. // Will be set by IrEmitterUnnested. LaunchDimensions launch_dimensions_; @@ -76,13 +82,12 @@ class KernelThunk : public Thunk { // Describes how to load this kernel. ExecuteOnStream reuses this loader // specification for all executions. mutable tensorflow::mutex mutex_; - std::unique_ptr loader_spec_ - GUARDED_BY(mutex_); + std::unique_ptr loader_spec_ GUARDED_BY(mutex_); - // Loaded kernels for each `StreamExecutor` - std::unordered_map - kernel_cache_ GUARDED_BY(mutex_); + // Loaded kernels for each `StreamExecutor`. Requires pointer stability of + // values. + std::unordered_map kernel_cache_ + GUARDED_BY(mutex_); }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index f4c4dcdafd6cc0cd64da5a8d1f23c8c0e7b2a9cb..7de8f9e1ee922bdbf65fd1299702482e1843f17e 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -47,7 +47,6 @@ cc_library( "@llvm//:scalar", "@llvm//:support", "@llvm//:target", - "@llvm//:transform_utils", ], ) @@ -68,17 +67,3 @@ tf_cc_test( "@llvm//:support", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index defd281d74bd38f7da3f268e0f55970fc1af8263..a4e4e85bf3d2c197cfc691b7fca0920aa6571729 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -34,7 +34,7 @@ limitations under the License. #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/Bitcode/BitcodeWriter.h" -#include "llvm/CodeGen/CommandFlags.def" +#include "llvm/CodeGen/CommandFlags.inc" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" @@ -77,8 +77,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, // Since CUDA 9.0, all GPU versions are included in a single file const char* unified_libdevice_filename = "libdevice.10.bc"; std::vector unified_libdevice_files; - const tensorflow::Status status = - tensorflow::Env::Default()->GetMatchingPaths( + const Status status = tensorflow::Env::Default()->GetMatchingPaths( tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename), &unified_libdevice_files); if (status.ok() && unified_libdevice_files.size() == 1) { @@ -273,7 +272,7 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); - target_machine->addPassesToEmitFile(codegen_passes, pstream, + target_machine->addPassesToEmitFile(codegen_passes, pstream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile); codegen_passes.run(*module); } @@ -311,11 +310,11 @@ bool CouldNeedLibdevice(const llvm::Module& module) { } // Links libdevice into the given module if the module needs libdevice. -tensorflow::Status LinkLibdeviceIfNecessary( - llvm::Module* module, std::pair compute_capability, - const string& libdevice_dir_path) { +Status LinkLibdeviceIfNecessary(llvm::Module* module, + std::pair compute_capability, + const string& libdevice_dir_path) { if (!CouldNeedLibdevice(*module)) { - return tensorflow::Status::OK(); + return Status::OK(); } llvm::Linker linker(*module); @@ -336,7 +335,7 @@ tensorflow::Status LinkLibdeviceIfNecessary( return tensorflow::errors::Internal(tensorflow::strings::StrCat( "Error linking libdevice from ", libdevice_path)); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr CompileModuleToPtx(llvm::Module* module, @@ -491,7 +490,7 @@ StatusOr CompileToPtx(llvm::Module* module, string ptx; { - tensorflow::port::Tracing::TraceMe annotation( + tensorflow::tracing::ScopedActivity activity( "Compiling IR", llvm_ir::AsString(module->getName()), /*is_expensive=*/true); XLA_SCOPED_LOGGING_TIMER("Compile module " + diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..d4100a898b5bb9eec382c34932c2db104c9e985b --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" +#include "tensorflow/stream_executor/stream_executor.h" + +namespace xla { +namespace gpu { + +Status MemzeroThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); + stream->ThenMemZero(&dest_data, dest_data.size()); + return Status::OK(); +} + +Status Memset32BitValueThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); + stream->ThenMemset32(&dest_data, value_, dest_data.size()); + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.h b/tensorflow/compiler/xla/service/gpu/memset_thunk.h new file mode 100644 index 0000000000000000000000000000000000000000..51c332d287d139335b356fc66411b5ffaa448b5a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.h @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/stream_executor/stream_executor.h" + +// This file contains thunks that set a buffer's elements to a particular value. +// This can be faster than emitting a kernel to set the elements. + +namespace xla { +namespace gpu { + +// Thunk that zeroes out a given chunk of memory. +class MemzeroThunk : public Thunk { + public: + explicit MemzeroThunk(const BufferAllocation::Slice& dest, + const HloInstruction* hlo) + : Thunk(Kind::kMemzero, hlo), dest_(dest) {} + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; + + private: + const BufferAllocation::Slice dest_; +}; + +// Thunk that sets a given chunk of memory to a particular 32-bit value. The +// destination chunk must have size divisible by 32 bits. +class Memset32BitValueThunk : public Thunk { + public: + explicit Memset32BitValueThunk(uint32 value, + const BufferAllocation::Slice& dest, + const HloInstruction* hlo) + : Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {} + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; + + private: + uint32 value_; + const BufferAllocation::Slice dest_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc new file mode 100644 index 0000000000000000000000000000000000000000..e3f444a1268cf8e3e551a3c5b986fee0339fa327 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -0,0 +1,144 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace gpu { + +GpuMultiOutputFusion::GpuMultiOutputFusion() : MultiOutputFusion(INT64_MAX) {} + +bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) { + auto get_element_instr = + [&](const HloInstruction* instr) -> const HloInstruction* { + const HloInstruction* element_instr = instr; + if (instr->opcode() == HloOpcode::kFusion) { + auto fused_expression_root = instr->fused_expression_root(); + if (instr->IsMultiOutputFusion()) { + // If possible, we want to pick a reduce operand of the fusion root, + // because it has the most constraints. + for (const auto* inst : fused_expression_root->operands()) { + if (inst->opcode() == HloOpcode::kReduce) { + return inst; + } + } + return fused_expression_root->operands()[0]; + } else { + element_instr = fused_expression_root; + } + } + return element_instr; + }; + + auto get_element_shape = [&](const HloInstruction* element_instr) { + // Special handling of kReduce instructions -- the fusion + // applies to the first operand. + if (element_instr->opcode() == HloOpcode::kReduce) { + return element_instr->operand(0)->shape(); + } + return element_instr->shape(); + }; + + // The shapes in all tuple operands should agree, unless it is a reduce. + // In that case, the operand of the reduce needs to have the same shape + // as the other tuple operands, but also we need to compare the output + // shapes of the reduces. + auto* element_instr_1 = get_element_instr(instr1); + auto* element_instr_2 = get_element_instr(instr2); + if (element_instr_1->opcode() == HloOpcode::kReduce && + element_instr_2->opcode() == HloOpcode::kReduce && + !ShapeUtil::Equal(element_instr_1->shape(), element_instr_2->shape())) { + return false; + } + // The elementwise output shapes must be the same (including layout). + return ShapeUtil::Equal(get_element_shape(element_instr_1), + get_element_shape(element_instr_2)); +} + +bool GpuMultiOutputFusion::IsProfitableOperand(HloInstruction* instr) { + // kConstant instruction will not have memory reads, so it won't be a profit + // source. Skip them. + if (instr->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(instr->shape())) { + return false; + } + // We don't target to fuse producer/consumer instructions -- this should + // be taken care of by the instruction_fusion pass. If instr has only + // one user, it will not have sibling instructions. We won't consider it. + if (instr->user_count() < 2) { + return false; + } + return true; +} + +namespace { +bool IsReduction(HloInstruction* instr) { + if (instr->IsMultiOutputFusion()) { + for (const HloInstruction* operand : + instr->fused_expression_root()->operands()) { + if (operand->opcode() == HloOpcode::kReduce) { + return true; + } + } + return false; + } else if (instr->opcode() == HloOpcode::kFusion) { + return instr->fused_expression_root()->opcode() == HloOpcode::kReduce; + } else { + return instr->opcode() == HloOpcode::kReduce; + } +} +} // namespace + +bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { + return IsReduction(instr); +} + +int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, + HloInstruction* instr2) { + tensorflow::gtl::FlatSet in_list; + for (auto instr : instr1->operands()) { + if (!IsProfitableOperand(instr)) { + continue; + } + in_list.insert(instr); + } + int64 profit = 0; + for (auto instr : instr2->operands()) { + if (!IsProfitableOperand(instr) || in_list.count(instr) == 0) { + continue; + } + profit += ShapeUtil::ByteSizeOf(instr->shape()); + } + VLOG(2) << "Fusing instr1=" << instr1->name() << " instr2=" << instr2->name() + << ", the profit is =" << profit; + return profit; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..5451a93cec5e4aeca05717c181edb9dad0305c83 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ + +#include "tensorflow/compiler/xla/service/multi_output_fusion.h" + +namespace xla { +namespace gpu { + +// Multi-output fusion of sibling and producer-consumer instructions for the +// Jellyfish backend. +class GpuMultiOutputFusion : public MultiOutputFusion { + public: + GpuMultiOutputFusion(); + + protected: + // Test if instr1 and instr2 have the compatible shapes that can be legally + // fused. + bool ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) override; + + // We currently only consider reduce and reduce fusion nodes as candidates. + bool IsFusible(HloInstruction* instr) override; + + // This function estimates the amount of memory reads saved by merging + // instr1 and instr2 into one multi-output fusion instruction. For a fusion + // instruction, all the operands need to be loaded from memory. If we merge + // instr1 and instr2, common operands will not be loaded twice. The profit is + // estimated as the size of the common operands b/w instr1 and instr2. + int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) override; + + // Whether fusing the instruction can reduce memory reads. + // + // TODO(tjoerg): Move this method up into the MultiOutputFusion base class. + bool IsProfitableOperand(HloInstruction* instr) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..924cfb11f3da76c475458ea14f201e809be61be8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -0,0 +1,230 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace gpu { + +using InstructionFusionTest = HloTestBase; + +const char kModulePrefix[] = R"( + HloModule test_module + + scalar_add_computation { + scalar_lhs.0 = f32[] parameter(0) + scalar_rhs.0 = f32[] parameter(1) + ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0) + } + scalar_mul_computation { + scalar_lhs.1 = f32[] parameter(0) + scalar_rhs.1 = f32[] parameter(1) + ROOT mul.1 = f32[] add(scalar_lhs.1, scalar_rhs.1) + })"; + +TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { + // Fusion with reduce instruction root and a sibling reduce instruction + // sharing the same input param. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation { + p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + const.2 = f32[] constant(1) + fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation + reduce.2 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion, reduce.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[6400]{0} parameter(1) + mul = f32[6400]{0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[6400]{0} parameter(1) + r1 = f32[64,100]{0,1} reshape(p1.2) + const.2 = f32[] parameter(0) + ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[6400]{0} parameter(1) + fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1 + fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[10,10]{1,0} parameter(1) + mul = f32[10,10]{1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0,1}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[10,10]{1,0} parameter(1) + const.2 = f32[10]{0} parameter(0) + ROOT reduce.2 = f32[10]{0} reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1.3 = f32[10,10]{1,0} parameter(1) + fusion.1 = f32[] fusion(p0, p1.3), kind=kInput, calls=fused_computation_1 + p2 = f32[] parameter(2) + fusion.2 = f32[10]{0} fusion(p2, p1.3), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[], f32[10]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) { + // Two sibling fusions with reduce instruction roots sharing the same input + // param. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[128,512,28,28]{3,2,1,0} parameter(1) + const.2 = f32[] parameter(0) + ROOT reduce.2 = f32[512]{0} reduce(p1.2, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + fusion.1 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_1 + fusion.2 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, + MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) { + // Multi-output fusion with two reduce instructions root and a sibling reduce + // instruction sharing the same input param. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) { + const.1 = f32[] constant(1) + p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0) + mul = f32[128,512,28,28]{3,2,1,0} multiply(f32[128,512,28,28]{3,2,1,0} p0.1, f32[128,512,28,28]{3,2,1,0} p0.1) + reduce.1 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} mul, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + reduce.2 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} p0.1, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT tuple = (f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} reduce.1, f32[512]{0} reduce.2) + } + + ENTRY entry (p0: f32[128,512,28,28]) -> (f32[512], f32[512], f32[512]) { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + const = f32[] constant(1) + fusion = (f32[512]{0}, f32[512]{0}) fusion(f32[128,512,28,28]{3,2,1,0} p0), kind=kInput, calls=fused_computation + get-tuple-element = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=0 + get-tuple-element.1 = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=1 + reduce.3 = f32[512]{0} reduce(p0, const), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT root = (f32[512]{0}, f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} get-tuple-element, f32[512]{0} get-tuple-element.1, f32[512]{0} reduce.3) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, + MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) { + // Verify that if we already have a multi-output fusion that we prefer to pick + // a reduce op from its operands for checking shape compatibility. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[10,10]{1,0} parameter(1) + mul = f32[10,10]{1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0,1}, to_apply=scalar_add_computation + ROOT tuple = (f32[10,10], f32[]) tuple(mul, reduce.1) + } + + fused_computation_2 { + p1.2 = f32[10,10]{1,0} parameter(1) + const.2 = f32[10] parameter(0) + ROOT reduce.2 = f32[10] reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[10,10]{1,0} parameter(1) + p2 = f32[10]{0} parameter(2) + fusion.1 = (f32[10,10], f32[10]) fusion(p0, p1), kind=kInput, calls=fused_computation_1 + get-tuple-element.1 = f32[10,10] get-tuple-element((f32[10,10], f32[10]) fusion.1), index=0 + get-tuple-element.2 = f32[] get-tuple-element((f32[10,10], f32[10]) fusion.1), index=1 + fusion.2 = f32[10] fusion(p2, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[10,10], f32[], f32[10]) tuple(get-tuple-element.1, get-tuple-element.2, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 25846dc6cd4633c7becb6e62d6bc9585348a6eac..c8f0d4185c63c5bafca6f30acab31cbe8e987277 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -68,13 +69,7 @@ HloInstruction* MaybePaddedAndSlicedInput( HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( MakeUnique(Literal::Zero(element_type)))); - input = computation->AddInstruction(HloInstruction::CreatePad( - ShapeInference::InferPadShape( - /*operand_shape=*/input->shape(), - /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}), - padding_config) - .ConsumeValueOrDie(), - input, padding, padding_config)); + input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } if (window_util::HasNegativePadding(conv_window)) { @@ -97,11 +92,8 @@ HloInstruction* MaybePaddedAndSlicedInput( std::max(0LL, -conv_window.dimensions(i).padding_high()); } - input = computation->AddInstruction(HloInstruction::CreateSlice( - ShapeInference::InferSliceShape(input->shape(), start_indices, - limit_indices, strides) - .ConsumeValueOrDie(), - input, start_indices, limit_indices, strides)); + input = + MakeSliceHlo(input, start_indices, limit_indices, strides).ValueOrDie(); } return input; @@ -134,13 +126,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( MakeUnique(Literal::Zero(element_type)))); - return computation->AddInstruction(HloInstruction::CreatePad( - ShapeInference::InferPadShape( - /*operand_shape=*/kernel->shape(), - /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}), - padding_config) - .ConsumeValueOrDie(), - kernel, padding, padding_config)); + return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace @@ -252,11 +238,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( computation->AddInstruction(HloInstruction::CreateConstant( MakeUnique(Literal::Zero(input->shape().element_type())))); HloInstruction* padded_input = - computation->AddInstruction(HloInstruction::CreatePad( - ShapeInference::InferPadShape(input->shape(), padding->shape(), - input_padding_config) - .ConsumeValueOrDie(), - input, padding, input_padding_config)); + MakePadHlo(input, padding, input_padding_config).ValueOrDie(); // The shape of the backward_conv CustomCall is a tuple (conv_result, // scratch_buffer). Extract out the shape of conv_result. @@ -388,26 +370,38 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( return true; } -StatusOr PadInsertion::Run(HloModule* module) { +StatusOr PadInsertion::RunOnComputation(HloComputation* computation) { bool changed = false; - for (HloInstruction* instruction : - module->entry_computation()->MakeInstructionPostOrder()) { - if (IsCustomCallToDnnConvolution(*instruction)) { - const auto& target = instruction->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { - changed |= CanonicalizeForwardConvolution(instruction); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - changed |= CanonicalizeBackwardFilterConvolution(instruction); - } else if (target == kCudnnConvBackwardInputCallTarget) { - changed |= CanonicalizeBackwardInputConvolution(instruction); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instruction->ToString(); - } + std::vector convs; + for (auto* instr : computation->instructions()) { + if (IsCustomCallToDnnConvolution(*instr)) { + convs.push_back(instr); + } + } + for (HloInstruction* instruction : convs) { + const auto& target = instruction->custom_call_target(); + if (target == kCudnnConvForwardCallTarget) { + changed |= CanonicalizeForwardConvolution(instruction); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + changed |= CanonicalizeBackwardFilterConvolution(instruction); + } else if (target == kCudnnConvBackwardInputCallTarget) { + changed |= CanonicalizeBackwardInputConvolution(instruction); + } else { + LOG(FATAL) << "Unknown custom call target for cudnn conv: " + << instruction->ToString(); } } return changed; } +StatusOr PadInsertion::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; + } + return changed; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h index 5e1c68701daa02eba64f3e34933ce373a496c1b8..67e51509e4c717951c83c7e41943af1de762dee0 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h @@ -31,6 +31,7 @@ class PadInsertion : public HloPassInterface { StatusOr Run(HloModule* module) override; private: + StatusOr RunOnComputation(HloComputation* computation); // Returns if any changes are made to the parent computation. bool CanonicalizeForwardConvolution(HloInstruction* conv); bool CanonicalizeBackwardFilterConvolution(HloInstruction* backward_conv); diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index 388dcc008b07a76ff9ed07df04181e49a8734f51..d8c07dc3119fb81a3ef22822acb11b7c4d5bbca5 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -32,25 +32,32 @@ namespace gpu { ParallelLoopEmitter::ParallelLoopEmitter( BodyEmitter body_emitter, const Shape& shape, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder) + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + int unroll_factor) : LoopEmitter(body_emitter, shape, ir_builder), - launch_dimensions_(launch_dimensions) {} + launch_dimensions_(launch_dimensions), + unroll_factor_(unroll_factor) {} ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder) + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + int unroll_factor) : LoopEmitter(target_element_generator, target_arrays, ir_builder), - launch_dimensions_(launch_dimensions) {} + launch_dimensions_(launch_dimensions), + unroll_factor_(unroll_factor) {} ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, const llvm_ir::IrArray& target_array, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder) + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + int unroll_factor) : LoopEmitter(target_element_generator, target_array, ir_builder), - launch_dimensions_(launch_dimensions) {} + launch_dimensions_(launch_dimensions), + unroll_factor_(unroll_factor) {} -llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( +std::vector +ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( tensorflow::StringPiece loop_name) { // Emit the following code in LLVM IR: // linear_index = blockIdx.x * blockDim.x + threadIdx.x; @@ -63,6 +70,9 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( // "It is guaranteed that [...] 0 <= %ctaid.x < %nctaid.x" // // %nctaid.x is currently specified as 2147483647. + VLOG(3) << "EmitIndexAndSetExitBasicBlock unroll_factor " << unroll_factor_; + std::vector array_indices; + llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, ir_builder_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(), @@ -81,7 +91,7 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( thread_id = ir_builder_->CreateZExt(thread_id, ir_builder_->getInt64Ty(), "thread_id"); - llvm::Value* linear_index = ir_builder_->CreateAdd( + llvm::Value* linear_index_base = ir_builder_->CreateAdd( ir_builder_->CreateMul( block_id, ir_builder_->getInt64(launch_dimensions_.threads_per_block()), "", @@ -99,15 +109,30 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::assume, {ir_builder_->CreateICmpULT( - linear_index, + linear_index_base, ir_builder_->getInt64(launch_dimensions_.threads_per_block() * launch_dimensions_.block_count()), "linear_index_in_range")}, {}, ir_builder_); + if (unroll_factor_ > 1) { + linear_index_base = ir_builder_->CreateMul( + linear_index_base, ir_builder_->getInt64(unroll_factor_), + "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true); + } + + array_indices.emplace_back(linear_index_base, shape_, ir_builder_); + for (int i = 1; i < unroll_factor_; ++i) { + llvm::Value* linear_index = ir_builder_->CreateAdd( + linear_index_base, ir_builder_->getInt64(i), "linear_index", + /*HasNUW=*/true, /*HasNSW=*/true); + array_indices.emplace_back(linear_index, shape_, ir_builder_); + } + auto if_in_bounds = llvm_ir::EmitIfThenElse( ir_builder_->CreateICmpULT( - linear_index, ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))), + linear_index_base, + ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))), llvm_ir::IrName(loop_name, "in_bounds"), ir_builder_, false); // Set exit_bb_ to the exit block of the if structure. @@ -116,7 +141,8 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( // Set IR builder insertion point to the body of the if structure. llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_); - return llvm_ir::IrArray::Index(linear_index, shape_, ir_builder_); + + return array_indices; } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index 8ed63a854a74fc06c3c389f40fe1f5970885deac..25318b3bed8bf4a2dfe3a4a974269d0405c3bfec 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -34,13 +34,13 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { // The meanings of other parameters are the same as LoopEmitter. ParallelLoopEmitter(BodyEmitter body_emitter, const Shape& shape, const LaunchDimensions& launch_dimensions, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* ir_builder, int unroll_factor = 1); // Constructs a ParallelLoopEmitter from an element generator that generates // each element of the given target array. ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, const llvm_ir::IrArray& target_array, const LaunchDimensions& launch_dimensions, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* ir_builder, int unroll_factor = 1); // Constructs a loop emitter for a loop that generates on element of each of N // arrays on each iteration. @@ -50,18 +50,20 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder); + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + int unroll_factor = 1); ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; ~ParallelLoopEmitter() override = default; - llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock( + std::vector EmitIndexAndSetExitBasicBlock( tensorflow::StringPiece loop_name) override; private: // The thread and block dimension to parallelize the loop on. const LaunchDimensions launch_dimensions_; + const int unroll_factor_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index 6cf280df05496716a0780d61ded92efd9982734c..d3fd0544fb68809125e9b9f7a5e5b7eff8c6ef43 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -29,8 +29,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { @@ -44,12 +42,16 @@ std::ostream& operator<<(std::ostream& out, // Calculates the launch dimensions used to invoke `hlo`. LaunchDimensions CalculateLaunchDimensions( - const Shape& shape, const se::DeviceDescription& device_desc) { + const Shape& shape, const se::DeviceDescription& device_desc, + int unroll_factor) { int64 num_elements = ShapeUtil::ElementsIn(shape); if (num_elements <= 1) { return LaunchDimensions(); } + CHECK_EQ(num_elements % unroll_factor, 0); + num_elements = num_elements / unroll_factor; + // Since we don't do any inter-warp communication, we're free to choose any // block size we want, subject to hardware constraints. We choose the // smallest block size that allows the GPU to reach full occupancy (assuming diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index 0bf463a6ef95d5a32784838c08ad239752fd1acf..c125474edb1036090a926020f2b1e7fcf64c751a 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -57,8 +57,8 @@ std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims); LaunchDimensions CalculateLaunchDimensions( - const Shape& shape, - const perftools::gputools::DeviceDescription& device_desc); + const Shape& shape, const se::DeviceDescription& device_desc, + int unroll_factor = 1); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc index d8a43091d4037a0edd125a4a1b6cb5ad7c7065f0..88cb10883e97ae663dc492ad088e6daf9133d7f5 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc @@ -20,25 +20,24 @@ limitations under the License. namespace xla { namespace gpu { -SequentialThunk::SequentialThunk(std::vector>&& thunks, +SequentialThunk::SequentialThunk(std::vector> thunks, const HloInstruction* hlo) : Thunk(Kind::kSequential, hlo), thunks_(std::move(thunks)) {} -tensorflow::Status SequentialThunk::Initialize( - const GpuExecutable& executable) { +Status SequentialThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { for (auto& thunk : thunks_) { - TF_RETURN_IF_ERROR(thunk->Initialize(executable)); + TF_RETURN_IF_ERROR(thunk->Initialize(executable, executor)); } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status SequentialThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { +Status SequentialThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { for (const auto& thunk : thunks_) { TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h index 32c5b748aba14239d6795d14e442c1c3b43d010e..135f79e413dfaa27f2f2264e0daa3beb3c305e0f 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h @@ -31,17 +31,17 @@ namespace gpu { // require multiple kernel launches or library calls. class SequentialThunk : public Thunk { public: - SequentialThunk(std::vector>&& thunks, + SequentialThunk(std::vector> thunks, const HloInstruction* hlo); SequentialThunk(const SequentialThunk&) = delete; SequentialThunk& operator=(const SequentialThunk&) = delete; const std::vector>& thunks() const { return thunks_; } - tensorflow::Status Initialize(const GpuExecutable& executable) override; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: // The list of sub-thunks. diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 8c98956f1a9b2a0bb1d304a27eb8c8cfcf610784..6f4bb0580e8dfc1dce1cca0a60cc3dd9ea600fb3 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -28,6 +28,14 @@ namespace gpu { class StreamAssignmentTest : public HloTestBase { protected: + std::unique_ptr CreateNewModule() { + HloModuleConfig config; + auto debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_disable_multi_streaming(false); + config.set_debug_options(debug_options); + return MakeUnique("test_module", config); + } + // Pre-canned shapes. Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); }; @@ -41,9 +49,9 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); + HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -60,9 +68,9 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x)); + HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); @@ -91,24 +99,24 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); } - HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary( - f32_2x2_, HloOpcode::kDot, params[2], params[3])); + HloInstruction* d00 = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..a50ddf6ac63c7fa7ccace94bc7f40f438aedccf8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" + +#include "tensorflow/compiler/xla/layout_util.h" + +namespace xla { +namespace gpu { + +using stream_executor::dnn::DataLayout; +using stream_executor::dnn::DataLayoutString; +using stream_executor::dnn::FilterLayout; +using stream_executor::dnn::FilterLayoutString; + +StatusOr> +StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, + DataLayout input, FilterLayout filter, + DataLayout output) { + std::vector input_layout; + switch (input) { + case DataLayout::kBatchDepthYX: + input_layout.push_back(dnums.input_batch_dimension()); + input_layout.push_back(dnums.input_feature_dimension()); + input_layout.insert(input_layout.end(), + dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end()); + break; + case DataLayout::kBatchYXDepth: + input_layout.push_back(dnums.input_batch_dimension()); + input_layout.insert(input_layout.end(), + dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end()); + input_layout.push_back(dnums.input_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid input layout: ", + DataLayoutString(input)); + } + + std::vector filter_layout; + switch (filter) { + case FilterLayout::kOutputInputYX: + filter_layout.push_back(dnums.kernel_output_feature_dimension()); + filter_layout.push_back(dnums.kernel_input_feature_dimension()); + filter_layout.insert(filter_layout.end(), + dnums.kernel_spatial_dimensions().begin(), + dnums.kernel_spatial_dimensions().end()); + break; + case FilterLayout::kOutputYXInput: + filter_layout.push_back(dnums.kernel_output_feature_dimension()); + filter_layout.insert(filter_layout.end(), + dnums.kernel_spatial_dimensions().begin(), + dnums.kernel_spatial_dimensions().end()); + filter_layout.push_back(dnums.kernel_input_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid filter layout: ", + FilterLayoutString(filter)); + } + + std::vector output_layout; + switch (output) { + case DataLayout::kBatchDepthYX: + output_layout.push_back(dnums.output_batch_dimension()); + output_layout.push_back(dnums.output_feature_dimension()); + output_layout.insert(output_layout.end(), + dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end()); + break; + case DataLayout::kBatchYXDepth: + output_layout.push_back(dnums.output_batch_dimension()); + output_layout.insert(output_layout.end(), + dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end()); + output_layout.push_back(dnums.output_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid output layout: ", + DataLayoutString(output)); + } + + return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout), + LayoutUtil::MakeLayoutFromMajorToMinor(filter_layout), + LayoutUtil::MakeLayoutFromMajorToMinor(output_layout)); +} + +StatusOr> +XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, + const Layout& input, const Layout& filter, + const Layout& output) { + Layout nchw_input, nchw_filter, nchw_output; + std::tie(nchw_input, nchw_filter, nchw_output) = + StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX) + .ConsumeValueOrDie(); + + Layout nhwc_input, nhwc_filter, nhwc_output; + std::tie(nhwc_input, nhwc_filter, nhwc_output) = + StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchYXDepth, + FilterLayout::kOutputYXInput, + DataLayout::kBatchYXDepth) + .ConsumeValueOrDie(); + + DataLayout input_layout; + if (LayoutUtil::Equal(input, nchw_input)) { + input_layout = DataLayout::kBatchDepthYX; + } else if (LayoutUtil::Equal(input, nhwc_input)) { + input_layout = DataLayout::kBatchYXDepth; + } else { + return tensorflow::errors::Internal("Invalid input layout: ", + input.ShortDebugString()); + } + + FilterLayout filter_layout; + if (LayoutUtil::Equal(filter, nchw_filter)) { + filter_layout = FilterLayout::kOutputInputYX; + } else if (LayoutUtil::Equal(filter, nhwc_filter)) { + filter_layout = FilterLayout::kOutputYXInput; + } else { + return tensorflow::errors::Internal("Invalid filter layout: ", + filter.ShortDebugString()); + } + + DataLayout output_layout; + if (LayoutUtil::Equal(output, nchw_output)) { + output_layout = DataLayout::kBatchDepthYX; + } else if (LayoutUtil::Equal(output, nhwc_output)) { + output_layout = DataLayout::kBatchYXDepth; + } else { + return tensorflow::errors::Internal("Invalid output layout: ", + output.ShortDebugString()); + } + + return std::make_tuple(input_layout, filter_layout, output_layout); +} +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h new file mode 100644 index 0000000000000000000000000000000000000000..8218f4fd11d3978d0ecc53fc15e287aea4b69ec3 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ + +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +// Helper functions for interacting with StreamExecutor. + +namespace xla { +namespace gpu { + +// Returns (input, filter, output) XLA Layout protos given the StreamExecutor +// layouts. +StatusOr> +StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, + stream_executor::dnn::DataLayout input, + stream_executor::dnn::FilterLayout filter, + stream_executor::dnn::DataLayout output); + +// Returns (input, filter, output) StreamExecutor layouts given the XLA layouts. +StatusOr> +XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, + const Layout& input, const Layout& filter, + const Layout& output); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 2c3032d79be221e8cacb178ffb1817459b603cc0..931c0bffab850362dbd2df975657dd47d9cbd3ae 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -51,6 +51,8 @@ class Thunk { kGemm, kInfeed, kKernel, + kMemset32BitValue, + kMemzero, kSequential, kTuple, kWhile, @@ -68,11 +70,14 @@ class Thunk { Kind kind() const { return kind_; } const HloInstruction* hlo_instruction() const { return hlo_instruction_; } - // Prepares for executing the thunk. This method is called only once over - // Thunk's lifetime. For example, KernelThunk::Initialize loads the PTX of a - // kernel, which is the same in every execution. - virtual tensorflow::Status Initialize(const GpuExecutable& executable) { - return tensorflow::Status::OK(); + // Prepares the thunk for execution on the given StreamExecutor. + // + // This may be called multiple times. Its main purpose is to give us a chance + // to do initialization outside of ExecuteOnStream() so that the + // time spent initializing doesn't count towards our execution profile. + virtual Status Initialize(const GpuExecutable& /*executable*/, + se::StreamExecutor* /*executor*/) { + return Status::OK(); } // Users of Thunk should call ShouldHaltAllActivityBeforeRunning(stream) @@ -83,27 +88,17 @@ class Thunk { // This value is not required to be constant for a given Thunk. For example, // a Thunk that performs autotuning may return true for its first run and // false thereafter. - virtual bool ShouldHaltAllActivityBeforeRunning( - perftools::gputools::Stream* /*stream*/) { + virtual bool ShouldHaltAllActivityBeforeRunning(se::Stream* /*stream*/) { return false; } - // Indicates whether thunks scheduled after this one should wait for this one - // to complete before running. For example, a convolution thunk creates a - // scratch allocator, then kicks off a convolution in cudnn via the stream - // executor. When the stream executor call returns, the scratch allocator goes - // out of scope, and the scratch memory is deallocated. In this case, the - // convolution thunk needs to return true so that future thunks wait for the - // convolution thunk to avoid reusing the deallocated memory until the - // convolution thunk is done with it. - virtual bool ShouldBlockFutureThunks() { return false; } - // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's // lifetime. Stream argument must be non-null. - virtual tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) = 0; + // + // Precondition: Initialize(stream->parent()) has been called. + virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) = 0; private: Kind kind_; diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index bd65e72393a59e72671ff0cc32c37eaa48856255..97cb04c38fbf18e516857f5269c984696ca204c3 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -17,13 +17,11 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { -tensorflow::Status TupleThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { std::vector tuple_element_buffer_addresses; for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) { tuple_element_buffer_addresses.push_back( @@ -42,7 +40,7 @@ tensorflow::Status TupleThunk::ExecuteOnStream( tuple_element_buffer_addresses.data(), dest_buffer_address.opaque(), sizeof(void*) * tuple_element_buffer_addresses.size()); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index 3b1a496328540ae69a449e7080903d31284885d1..951f809b51937c97a6e7de0345ec58a8b66a4242 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -45,9 +45,8 @@ class TupleThunk : public Thunk { TupleThunk(const TupleThunk&) = delete; TupleThunk& operator=(const TupleThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const std::vector tuple_element_buffers_; diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index c21559af6d2e5dfb5aaf62afcdcaed514e0914c9..30b9640c4c75dae61e9a90da5fb10e9d4a90cd26 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -34,15 +34,17 @@ WhileThunk::WhileThunk( body_thunk_sequence_( MakeUnique(std::move(*body_thunk_sequence), hlo)) {} -Status WhileThunk::Initialize(const GpuExecutable& executable) { - TF_RETURN_IF_ERROR(condition_thunk_sequence_->Initialize(executable)); - TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable)); +Status WhileThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + TF_RETURN_IF_ERROR( + condition_thunk_sequence_->Initialize(executable, executor)); + TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor)); return Status::OK(); } Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { - perftools::gputools::DeviceMemoryBase condition_result_data = + se::Stream* stream) { + se::DeviceMemoryBase condition_result_data = buffer_allocations.GetDeviceAddress(condition_result_buffer_index_); while (true) { diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 4c9f45de9e42494df58706d0a4a3eb0c4220b8b8..22176685a92df9c95b10f755b209309843c0fa3a 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -45,9 +45,10 @@ class WhileThunk : public Thunk { WhileThunk(const WhileThunk&) = delete; WhileThunk& operator=(const WhileThunk&) = delete; - Status Initialize(const GpuExecutable& executable) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + se::Stream* stream) override; private: const BufferAllocation::Slice condition_result_buffer_index_; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index e6caec8625f0d622dbb92bcc20802d254fe23f94..7749201cbceece216a2db2569936949eb7de5125 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -144,7 +144,7 @@ class ExprTree { TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first), tagged_instructions)); } - return tensorflow::Status::OK(); + return Status::OK(); } private: @@ -169,7 +169,7 @@ class MatcherBase { // Attempts to match each ExprTree in 'expr_trees_'. // Returns OK on the first successful match, error status otherwise. - virtual tensorflow::Status Run() { + virtual Status Run() { Status status; for (const ExprTree& expr_tree : expr_trees_) { status = MatchExprTree(expr_tree); @@ -201,7 +201,7 @@ class MatcherBase { } else if (type == S64) { *const_value = literal.GetFirstElement(); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr GetTaggedInstruction( @@ -315,7 +315,7 @@ class WhileConditionComputationMatcher : public MatcherBase { gte_fusion_param0->name().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } const HloComputation* computation_; @@ -379,7 +379,7 @@ class WhileInitOperandMatcher : public MatcherBase { GetTaggedInstruction("loop_start", tagged_instructions)); TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_)); - return tensorflow::Status::OK(); + return Status::OK(); } const HloInstruction* while_hlo_; @@ -457,8 +457,8 @@ class WhileBodyComputationMatcher : public MatcherBase { return InvalidArgument("Unexpected tuple index instruction : %s", inst->name().c_str()); } else if (tag == "loop_increment") { - // Parse the constant which represents the loop induction variable - // increment value. + // ParseHloString the constant which represents the loop induction + // variable increment value. TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_)); } else if (tag == "param0" && inst != computation_->parameter_instruction(0)) { @@ -477,7 +477,7 @@ class WhileBodyComputationMatcher : public MatcherBase { } } } - return tensorflow::Status::OK(); + return Status::OK(); } const HloComputation* computation_; diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index 05017008e2ddbe0b9e78d06275fdec5d08d94bfa..acf661148699dab18916e3065ee647d37fda6208 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -82,7 +82,8 @@ HloComputation* CallForwardingComputation(HloComputation* computation, // instructions. Sets the computation as the entry to an HLO module and returns // the module. std::unique_ptr MakeBigGraph() { - auto module = MakeUnique("BigGraph"); + HloModuleConfig config; + auto module = MakeUnique("BigGraph", config); auto builder = HloComputation::Builder("TestBigGraphvizGraph"); diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index a2d13c013c56059148ccd04dba2137a5b2badc42..5dba50a63b1d77bda0835e0333cc7dd9ddbe2dcf 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -27,44 +26,47 @@ namespace xla { using tensorflow::gtl::FlatMap; using tensorflow::gtl::FlatSet; -namespace { - -// Returns the set of buffers that may be sources of all operands of the given -// instruction. The returned buffers are guaranteed to have no duplicates, and -// to be sorted in a deterministic order. -std::vector UniqueOperandSourceBuffers( - const HloInstruction* instruction, - const TuplePointsToAnalysis& points_to_analysis) { - std::vector buffers; - for (const HloInstruction* operand : instruction->operands()) { - points_to_analysis.GetPointsToSet(operand).ForEachElement( - [&](const ShapeIndex& /*index*/, - const PointsToSet::BufferList& points_to) { - buffers.insert(buffers.end(), points_to.begin(), points_to.end()); - }); +StatusOr MinimumMemoryForModule( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function) { + if (module_sequence.empty()) { + return 0; } - // Sort and then remove duplicates from buffers. - std::sort(buffers.begin(), buffers.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() < b->id(); - }); - buffers.erase(std::unique(buffers.begin(), buffers.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() == b->id(); - }), - buffers.end()); - return buffers; + const HloModule* module = module_sequence.begin()->first->parent(); + TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(module)); + + // The absolute minimum memory required for a given sequence of instructions + // is determined by the sequence of Alloc and Free calls on a simulated heap, + // ignoring fragmentation. We run the heap simulation on the whole module, + // rather than summing each computation, since it gives us a better lower + // bound, by minimizing the liveness of sub-computations. + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), *module, + module_sequence, *points_to_analysis, size_function)); + return result.heap_size; } -} // namespace +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), computation, + sequence, points_to_analysis, size_function)); + return result.heap_size; +} /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, const Options& options) { + const BufferValue::SizeFunction& size_fn, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence); const HloComputation* entry_computation = module.entry_computation(); const std::vector& instruction_sequence = @@ -79,7 +81,7 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, const Options& options) { + const BufferValue::SizeFunction& size_fn, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, /*module_sequence=*/nullptr); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, @@ -93,6 +95,7 @@ Status HeapSimulator::RunComputation( const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis) { + VLOG(3) << "Computation:\n" << computation.ToString(); // The goal here is to minimize memory usage, assuming the given sequential // ordering of instructions. The strategy is to walk through the instruction // sequence, calling Alloc and Free on the underlying heap algorithm. The @@ -101,47 +104,75 @@ Status HeapSimulator::RunComputation( // 'live_buffers' tracks the liveness of each buffer that we assign, by // associating it with a set of HloInstructions that need to be visited. When // the set becomes empty, the buffer is no longer used, and can be freed. - FlatMap> live_buffers; + // 'used_buffers' is the reverse map - it tracks which buffers were used by an + // instruction, so that we can remove the instructions from a buffer's live + // set after they are visited. + FlatMap> live_buffers; + FlatMap> used_buffers; + auto add_user_to_buffer = [this, &live_buffers, &used_buffers]( + const HloInstruction* user, + const BufferValue* buffer) { + if (!IgnoreBuffer(buffer)) { + VLOG(4) << " Adding user " << user->name() << " to buffer " + << buffer->ToString(); + live_buffers[buffer].insert(user); + used_buffers[user].insert(buffer); + } + }; + + // Initialize live_buffers for each buffer that we're going to assign. The + // set of instructions that need to be visited contains all users of all + // aliases, that is, all users of all instructions that have the buffer + // contained in their points-to set. + for (const HloInstruction* instruction : instruction_sequence) { + const PointsToSet& points_to = + points_to_analysis.GetPointsToSet(instruction); + const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet(); + for (const HloInstruction* user : instruction->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + for (const BufferValue* buffer : buffer_set) { + add_user_to_buffer(user, buffer); + } + } else { + // A GetTupleElement doesn't need to keep all of its operand's buffers + // alive. It only needs the buffers that relate to the element its + // extracting, and the tuple it's extracting from, but not the buffers + // for the other elements. + for (const BufferValue* buffer : points_to.element({})) { + add_user_to_buffer(user, buffer); + } + const PointsToSet& gte_points_to = + points_to_analysis.GetPointsToSet(user); + for (const BufferValue* buffer : gte_points_to.CreateFlattenedSet()) { + add_user_to_buffer(user, buffer); + } + } + } + } const HloInstruction* root = computation.root_instruction(); - auto output_source_buffers = - points_to_analysis.GetPointsToSet(root).CreateFlattenedSet(); + BufferValueCompactPointerSet output_source_buffers = + ToBufferValueCompactPointerSet( + points_to_analysis.GetPointsToSet(root).CreateFlattenedSet()); - std::vector dead_buffers_to_free; - std::vector operand_buffers_to_free; + std::vector dead_buffers_to_free; + std::vector operand_buffers_to_free; for (const HloInstruction* instruction : instruction_sequence) { const TuplePointsToAnalysis::BufferDefinitionVector& buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); - // Initialize live_buffers for each buffer that we're going to assign. The - // set of instructions that need to be visited contains all users of all - // aliases. The alias itself is not necessary; if it has users, the users - // are necessarily scheduled after the alias. And if it has no users, it is - // either a dead value or an output, both of which are handled below. - // - // We ignore control dependencies here. The reasoning is that the control - // dependencies have already been accounted for in the ordering of the given - // 'instruction_sequence', and should not otherwise artificially extend the - // lifetime of buffers that aren't already connected by a data dependency. + VLOG(3) << "Instruction: " << instruction->ToString(); + for (const BufferValue* buffer : buffers_defined_by_instruction) { + VLOG(4) << " Defines: " << buffer->ToString() + << (IgnoreBuffer(buffer) ? " (Ignored)" : ""); + } + dead_buffers_to_free.clear(); - for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; } - FlatSet* live_set = nullptr; - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - const std::vector& users = - alias.instruction()->users(); - if (!users.empty()) { - if (live_set == nullptr) { - live_set = &live_buffers[buffer]; - } - live_set->insert(users.begin(), users.end()); - } - } - // Add a nullptr sentry to ensure entry parameters and output source // buffers are not freed until the very end. const bool entry_parameter = @@ -165,11 +196,12 @@ Status HeapSimulator::RunComputation( // have no instructions left to visit are moved from live_buffers to // operand_buffers_to_free. operand_buffers_to_free.clear(); - for (const LogicalBuffer* operand_buffer : - UniqueOperandSourceBuffers(instruction, points_to_analysis)) { + for (const BufferValue* operand_buffer : used_buffers[instruction]) { if (IgnoreBuffer(operand_buffer)) { continue; } + VLOG(4) << " Removing user " << instruction->name() << " from buffer " + << operand_buffer->ToString(); auto it = live_buffers.find(operand_buffer); FlatSet* live_set = &it->second; live_set->erase(instruction); @@ -178,6 +210,11 @@ Status HeapSimulator::RunComputation( operand_buffers_to_free.push_back(operand_buffer); } } + // Sort to get a deterministic iteration order. + std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(), + [](const BufferValue* x, const BufferValue* y) { + return x->id() < y->id(); + }); // Allocate buffers defined by this instruction. This is the latest point // that we can allocate; right before the buffer is first used. This must @@ -186,7 +223,7 @@ Status HeapSimulator::RunComputation( // // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer // that we should assign. - for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; } @@ -197,12 +234,14 @@ Status HeapSimulator::RunComputation( // we must be the last user of the buffer. bool shared = false; if (options_.may_reuse_operand_buffers) { - for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) { + for (const BufferValue* operand_buffer : operand_buffers_to_free) { if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && buffer->instruction()->opcode() != HloOpcode::kCopy && - CanShareOperandBufferWithUser( + points_to_analysis.CanShareOperandBufferWithUser( operand_buffer->instruction(), operand_buffer->index(), - buffer->instruction(), buffer->index(), points_to_analysis)) { + buffer->instruction(), buffer->index())) { + VLOG(3) << " Sharing: " << buffer->ToString() << " with " + << operand_buffer->ToString(); ShareBuffer(buffer, operand_buffer, instruction); shared = true; break; @@ -211,6 +250,7 @@ Status HeapSimulator::RunComputation( } if (!shared) { + VLOG(3) << " Allocating: " << buffer->ToString(); Alloc(buffer, instruction); } } @@ -243,21 +283,35 @@ Status HeapSimulator::RunComputation( // Free buffers that are no longer live. This is the earliest point that we // can de-allocate; right after the last use of the buffer. - for (const LogicalBuffer* buffer : dead_buffers_to_free) { + for (const BufferValue* buffer : dead_buffers_to_free) { + VLOG(3) << " Freeing dead: " << buffer->ToString(); Free(buffer, instruction); } - for (const LogicalBuffer* buffer : operand_buffers_to_free) { + for (const BufferValue* buffer : operand_buffers_to_free) { + VLOG(3) << " Freeing operand: " << buffer->ToString(); Free(buffer, instruction); } } // Any remaining live buffers must be entry parameters or output source - // buffers, which had a nullptr sentry added. Free them now. + // buffers, which had a nullptr sentry added. Free them now, in a + // deterministic order. + std::vector to_free; + to_free.reserve(live_buffers.size()); for (const auto& buffer_pending : live_buffers) { - const LogicalBuffer* buffer = buffer_pending.first; + const BufferValue* buffer = buffer_pending.first; const FlatSet& pending = buffer_pending.second; CHECK_EQ(pending.size(), 1) << *buffer; CHECK(*pending.begin() == nullptr) << *buffer; + to_free.push_back(buffer); + } + + std::sort(to_free.begin(), to_free.end(), + [](const BufferValue* x, const BufferValue* y) { + return x->id() < y->id(); + }); + for (const BufferValue* buffer : to_free) { + VLOG(3) << "Freeing pending: " << buffer->ToString(); Free(buffer, root); } @@ -266,7 +320,7 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, - const LogicalBuffer::SizeFunction& size_fn, const Options& options, + const BufferValue::SizeFunction& size_fn, const Options& options, const SequentialHloOrdering::HloModuleSequence* module_sequence) : no_fragmentation_stats_(MakeUnique()), algorithm_(std::move(algorithm)), @@ -278,7 +332,7 @@ HeapSimulator::HeapSimulator( HeapSimulator::~HeapSimulator() {} -bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const { +bool HeapSimulator::IgnoreBuffer(const BufferValue* buffer) const { // Buffers for constants are ignored unless the alloc_constants option is // set. Also ignore buffers that we're not meant to assign. // @@ -292,7 +346,7 @@ bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const { } // Alloc always calls the underlying heap algorithm. -void HeapSimulator::Alloc(const LogicalBuffer* buffer, +void HeapSimulator::Alloc(const BufferValue* buffer, const HloInstruction* instruction) { CHECK(allocated_buffers_.count(buffer) == 0) << "Alloc called on allocated buffer: " << *buffer; @@ -312,7 +366,7 @@ void HeapSimulator::Alloc(const LogicalBuffer* buffer, // buffers whose group liveness has expired. Shared group liveness is tracked // by maintaining a refcount; the Free call on the last buffer in the group // causes Free to be called on the underlying algorithm. -void HeapSimulator::Free(const LogicalBuffer* buffer, +void HeapSimulator::Free(const BufferValue* buffer, const HloInstruction* instruction) { auto shared_it = shared_buffers_.find(buffer); if (shared_it != shared_buffers_.end()) { @@ -343,8 +397,8 @@ void HeapSimulator::Free(const LogicalBuffer* buffer, // The 'buffer' must be a non-allocated, non-freed buffer, just like in calls to // Alloc. The 'shared' buffer must be a previously allocated or shared buffer. // Both 'buffer' and 'shared' will be associated with the same SharedGroup. -void HeapSimulator::ShareBuffer(const LogicalBuffer* buffer, - const LogicalBuffer* shared, +void HeapSimulator::ShareBuffer(const BufferValue* buffer, + const BufferValue* shared, const HloInstruction* instruction) { CHECK_LE(size_fn_(*buffer), size_fn_(*shared)) << "ShareBuffer oversized buffer" << *buffer << " shared: " << *shared; @@ -355,7 +409,7 @@ void HeapSimulator::ShareBuffer(const LogicalBuffer* buffer, CHECK(freed_buffers_.count(shared) == 0) << "ShareBuffer called on freed shared buffer: " << *shared; - const LogicalBuffer* canonical = nullptr; + const BufferValue* canonical = nullptr; auto shared_it = shared_buffers_.find(shared); if (shared_it != shared_buffers_.end()) { // The 'shared' buffer already has a group; it might be the canonical, but @@ -389,7 +443,7 @@ HeapSimulator::Result HeapSimulator::Finish() { // collecting statistics, e.g. NoFragmentationStatsHeap. if (!result.chunk_map.empty()) { for (const auto& share_pair : shared_buffers_) { - const LogicalBuffer* buffer = share_pair.first; + const BufferValue* buffer = share_pair.first; std::shared_ptr group = share_pair.second; if (buffer != group->canonical) { // The canonical must already exist in the chunk_map, since we called @@ -418,9 +472,9 @@ HeapSimulator::Result HeapSimulator::Finish() { } void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, - const LogicalBuffer* buffer, + const BufferValue* buffer, const HloInstruction* instruction, - const LogicalBuffer* share_with_canonical) { + const BufferValue* share_with_canonical) { HeapSimulatorTrace::Event* event = debug_trace_.add_events(); event->set_kind(kind); event->set_buffer_id(buffer->id()); @@ -434,14 +488,14 @@ void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, } } -void NoFragmentationStatsHeap::Alloc(const LogicalBuffer* buffer, int64 size) { +void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { current_heap_size_ += size; if (current_heap_size_ > max_heap_size_) { max_heap_size_ = current_heap_size_; } } -void NoFragmentationStatsHeap::Free(const LogicalBuffer* buffer, int64 size) { +void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) { current_heap_size_ -= size; } @@ -453,12 +507,12 @@ HeapSimulator::Result NoFragmentationStatsHeap::Finish() { return result; } -void DecreasingSizeRunsHeap::Alloc(const LogicalBuffer* buffer, int64 size) { +void DecreasingSizeRunsHeap::Alloc(const BufferValue* buffer, int64 size) { SetMode(kAlloc); run_.emplace_back(Op{buffer, size}); } -void DecreasingSizeRunsHeap::Free(const LogicalBuffer* buffer, int64 size) { +void DecreasingSizeRunsHeap::Free(const BufferValue* buffer, int64 size) { CHECK(mode_ != kInit) << "Free called on empty heap: " << *buffer; SetMode(kFree); run_.emplace_back(Op{buffer, size}); @@ -499,7 +553,7 @@ void DecreasingSizeRunsHeap::CallAndDrainRun() { run_.clear(); } -void LazyBestFitHeap::Alloc(const LogicalBuffer* buffer, int64 size) { +void LazyBestFitHeap::Alloc(const BufferValue* buffer, int64 size) { // Degenerate case: 0-sized buffers are always allocated at offset 0. if (size == 0) { result_.chunk_map.emplace(buffer, Chunk{0, 0}); @@ -567,7 +621,7 @@ void LazyBestFitHeap::Alloc(const LogicalBuffer* buffer, int64 size) { result_.chunk_map.emplace(buffer, Chunk{kLazyAllocOffset, size}); } -void LazyBestFitHeap::Free(const LogicalBuffer* buffer, int64 size) { +void LazyBestFitHeap::Free(const BufferValue* buffer, int64 size) { auto alloc_it = result_.chunk_map.find(buffer); CHECK(alloc_it != result_.chunk_map.end()) << "Free called on non-allocated buffer: " << *buffer; diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 636f19dd39f09721bd82fc4b44785f196f281ad7..3be3bb8e7fae9d22c9be00f3f81d2b93638c22ef 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -21,11 +21,12 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -33,6 +34,21 @@ limitations under the License. namespace xla { +// Returns the minimum memory required to compute an HLO module where all +// computations have been scheduled (represented by the given module_sequence), +// assuming no fragmentation. +StatusOr MinimumMemoryForModule( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function); + +// Returns the minimum memory required to compute the given computation, +// assuming no fragmentation. +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); + // Forward declare classes defined below. class HeapAlgorithm; @@ -43,7 +59,7 @@ class HeapAlgorithm; // don't need to return the assignment of buffer offsets until the very end. class HeapSimulator { public: - // Chunk represents a contiguous piece of memory. Each LogicalBuffer will be + // Chunk represents a contiguous piece of memory. Each BufferValue will be // associated with a chunk in the assignment result. struct Chunk { int64 offset; @@ -55,7 +71,7 @@ class HeapSimulator { // Result represents the result of the heap simulation. struct Result { // The assignment of buffers to chunks. - tensorflow::gtl::FlatMap chunk_map; + tensorflow::gtl::FlatMap chunk_map; // The total size in bytes of the heap, containing all assigned chunks. int64 heap_size = 0; @@ -81,7 +97,7 @@ class HeapSimulator { bool alloc_constants; // If 'buffers_to_assign' is provided, only those buffers are assigned // offsets, otherwise all buffers defined by the instructions are assigned. - const tensorflow::gtl::FlatSet* buffers_to_assign; + const BufferValueFlatSet* buffers_to_assign; }; // Run the heap simulation with the given algorithm, assuming the given @@ -97,7 +113,7 @@ class HeapSimulator { std::unique_ptr algorithm, const HloModule& module, const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, + const BufferValue::SizeFunction& size_fn, const Options& options = Options()); // Same as above, but runs on a single computation. The 'instruction_sequence' @@ -109,7 +125,7 @@ class HeapSimulator { const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, + const BufferValue::SizeFunction& size_fn, const Options& options = Options()); private: @@ -118,7 +134,7 @@ class HeapSimulator { // be run recursively. I.e. the simulation is run over the whole module. HeapSimulator( std::unique_ptr algorithm, - const LogicalBuffer::SizeFunction& size_fn, const Options& options, + const BufferValue::SizeFunction& size_fn, const Options& options, const SequentialHloOrdering::HloModuleSequence* module_sequence); ~HeapSimulator(); @@ -127,21 +143,21 @@ class HeapSimulator { const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis); - bool IgnoreBuffer(const LogicalBuffer* buffer) const; - void Alloc(const LogicalBuffer* buffer, const HloInstruction* instruction); - void Free(const LogicalBuffer* buffer, const HloInstruction* instruction); - void ShareBuffer(const LogicalBuffer* buffer, const LogicalBuffer* shared, + bool IgnoreBuffer(const BufferValue* buffer) const; + void Alloc(const BufferValue* buffer, const HloInstruction* instruction); + void Free(const BufferValue* buffer, const HloInstruction* instruction); + void ShareBuffer(const BufferValue* buffer, const BufferValue* shared, const HloInstruction* instruction); Result Finish(); void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, - const LogicalBuffer* buffer, + const BufferValue* buffer, const HloInstruction* instruction, - const LogicalBuffer* shared_with_canonical); + const BufferValue* shared_with_canonical); const std::unique_ptr no_fragmentation_stats_; const std::unique_ptr algorithm_; - const LogicalBuffer::SizeFunction size_fn_; + const BufferValue::SizeFunction size_fn_; const Options options_; const SequentialHloOrdering::HloModuleSequence* module_sequence_; @@ -160,15 +176,15 @@ class HeapSimulator { // The shared_buffers_ map associates each shared buffer (including the // canonical) to its SharedGroup control block. struct SharedGroup { - const LogicalBuffer* canonical = nullptr; + const BufferValue* canonical = nullptr; int64 refcount = 0; }; - tensorflow::gtl::FlatMap> + tensorflow::gtl::FlatMap> shared_buffers_; // Hold some sets for error-checking the sequence of Alloc and Free calls. - tensorflow::gtl::FlatSet allocated_buffers_; - tensorflow::gtl::FlatSet freed_buffers_; + tensorflow::gtl::FlatSet allocated_buffers_; + tensorflow::gtl::FlatSet freed_buffers_; // Debugging information filled in while the heap simulator runs. HeapSimulatorTrace debug_trace_; @@ -186,10 +202,10 @@ class HeapAlgorithm { virtual ~HeapAlgorithm() = default; // Alloc allocates a buffer of 'size' bytes. - virtual void Alloc(const LogicalBuffer* buffer, int64 size) = 0; + virtual void Alloc(const BufferValue* buffer, int64 size) = 0; // Free de-allocates a previously allocated buffer. - virtual void Free(const LogicalBuffer* buffer, int64 size) = 0; + virtual void Free(const BufferValue* buffer, int64 size) = 0; // Finish collects the buffer offset assignment results. Free may only be // called once, after the Alloc and Free calls. @@ -205,8 +221,8 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { NoFragmentationStatsHeap() = default; ~NoFragmentationStatsHeap() override = default; - void Alloc(const LogicalBuffer* buffer, int64 size) override; - void Free(const LogicalBuffer* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; Result Finish() override; private: @@ -223,14 +239,14 @@ class DecreasingSizeRunsHeap : public HeapAlgorithm { : algorithm_(std::move(algorithm)) {} ~DecreasingSizeRunsHeap() override {} - void Alloc(const LogicalBuffer* buffer, int64 size) override; - void Free(const LogicalBuffer* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; Result Finish() override; private: // A single Alloc or Free operation that we've buffered in run_. struct Op { - const LogicalBuffer* buffer; + const BufferValue* buffer; int64 size; }; @@ -266,8 +282,8 @@ class LazyBestFitHeap : public HeapAlgorithm { LazyBestFitHeap(int64 alignment) : alignment_(alignment) {} ~LazyBestFitHeap() override {} - void Alloc(const LogicalBuffer* buffer, int64 size) override; - void Free(const LogicalBuffer* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; Result Finish() override; private: diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 387b649a731ebcbfd8307807469f39f22d192b06..309ab85f784274835904015472f3f0d601885763 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -20,11 +20,12 @@ limitations under the License. #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -33,12 +34,70 @@ limitations under the License. namespace xla { namespace { +class MinimumMemoryForSequenceTest : public HloTestBase {}; + +TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); + HloInstruction* cond_data = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kLt, cond_iter, cond_data)); + HloComputation* cond_computation = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloComputation* body_computation = + module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + // Entry params: 8 bytes (4 bytes per param), TOTAL=8 + HloInstruction* iter = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); + HloInstruction* data = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param_data")); + // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); + // While: 8 bytes (4 bytes per element), TOTAL=32 + // Both cond and body use a max of 24 bytes, TOTAL=56 + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, tuple)); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, + cond_lt}; + module_sequence[body_computation] = {body_param}; + module_sequence[entry_computation] = {iter, data, tuple, while_op}; + EXPECT_EQ(56, MinimumMemoryForModule(module_sequence, size_fn).ValueOrDie()); +} + const char kAlloc[] = "Alloc"; const char kFree[] = "Free"; const char kFinish[] = "Finish"; // CallSequence records a sequence of Alloc/Free/Finish calls. -using CallSequence = std::vector>; +using CallSequence = std::vector>; // HeapCallRecorder is a dummy heap algorithm that simply records its calls. class HeapCallRecorder : public HeapAlgorithm { @@ -46,7 +105,7 @@ class HeapCallRecorder : public HeapAlgorithm { explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {} ~HeapCallRecorder() override {} - void Alloc(const LogicalBuffer* buffer, int64 size) override { + void Alloc(const BufferValue* buffer, int64 size) override { calls_->emplace_back(kAlloc, buffer); // Instead of assigning a real offset, we set the cardinality of the Alloc // call. This isn't a valid assignment, but allows us to easily test for @@ -54,7 +113,7 @@ class HeapCallRecorder : public HeapAlgorithm { const int64 offset = result_.chunk_map.size(); result_.chunk_map.emplace(buffer, Chunk{offset, size}); } - void Free(const LogicalBuffer* buffer, int64 size) override { + void Free(const BufferValue* buffer, int64 size) override { calls_->emplace_back(kFree, buffer); } Result Finish() override { @@ -76,7 +135,8 @@ class HeapSimulatorTracker { HeapSimulatorTracker( const string& name, std::unique_ptr computation, const std::vector& instruction_sequence) { - module_ = MakeUnique(name); + HloModuleConfig config; + module_ = MakeUnique(name, config); module_->AddEntryComputation(std::move(computation)); points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); @@ -84,7 +144,7 @@ class HeapSimulatorTracker { // size of the buffers doesn't matter, so we always return 0. We rely on // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls by // buffer id, for determinism in the tests. - auto zero_size = [](const LogicalBuffer& buffer) { return 0; }; + auto zero_size = [](const BufferValue& buffer) { return 0; }; auto algorithm = MakeUnique( MakeUnique(&actual_calls_)); result_ = HeapSimulator::Run( @@ -94,7 +154,8 @@ class HeapSimulatorTracker { } explicit HeapSimulatorTracker(const string& name) { - module_ = MakeUnique(name); + HloModuleConfig config; + module_ = MakeUnique(name, config); } // Similar to the single entry computation constructor above, but runs the @@ -115,9 +176,9 @@ class HeapSimulatorTracker { // Hack the size_fn so that it returns a decreasing value as we step through // the sequence. This lets us ensure the Alloc calls are in the sequence - // order. The Free calls are sorted by LogicalBuffer.id, which is at least + // order. The Free calls are sorted by BufferValue.id, which is at least // deterministic. - auto size_fn = [&reverse_position](const LogicalBuffer& buffer) { + auto size_fn = [&reverse_position](const BufferValue& buffer) { return reverse_position[buffer.instruction()]; }; auto algorithm = MakeUnique( @@ -130,8 +191,8 @@ class HeapSimulatorTracker { HloModule* module() { return module_.get(); } // Returns the buffer defined at the given instruction and index. - const LogicalBuffer* BufferAt(const HloInstruction* instruction, - const ShapeIndex& index) const { + const BufferValue* BufferAt(const HloInstruction* instruction, + const ShapeIndex& index) const { return points_to_analysis_->GetBufferDefinedAt(instruction, index) .ConsumeValueOrDie(); } @@ -147,8 +208,8 @@ class HeapSimulatorTracker { const ShapeIndex& index_a, const HloInstruction* instruction_b, const ShapeIndex& index_b) { - const LogicalBuffer* a = BufferAt(instruction_a, index_a); - const LogicalBuffer* b = BufferAt(instruction_b, index_b); + const BufferValue* a = BufferAt(instruction_a, index_a); + const BufferValue* b = BufferAt(instruction_b, index_b); EXPECT_EQ(result_.chunk_map[a].offset, result_.chunk_map[b].offset) << *a << ", " << *b; } @@ -410,6 +471,56 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { }); } +TEST_F(HeapSimulatorTest, IndependentTupleElements) { + auto builder = HloComputation::Builder(TestName()); + auto paramA = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32scalar_, "paramA")); + auto paramB = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32scalar_, "paramB")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32scalar_, HloOpcode::kMultiply, paramA, paramB)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + f32scalar_, HloOpcode::kAdd, paramA, paramB)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add})); + auto element0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 0)); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec4_, element0, {0})); + auto sub = builder.AddInstruction(HloInstruction::CreateBinary( + f32scalar_, HloOpcode::kSubtract, paramA, paramB)); + auto element1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 1)); + auto output = builder.AddInstruction( + HloInstruction::CreateTuple({broadcast, sub, element1})); + + HeapSimulatorTracker tracker(TestName(), builder.Build(), + {paramA, paramB, mul, add, tuple, element0, + broadcast, sub, element1, output}); + tracker.ExpectCallSequence({ + {kAlloc, tracker.BufferAt(paramA, {})}, + {kAlloc, tracker.BufferAt(paramB, {})}, + {kAlloc, tracker.BufferAt(mul, {})}, + {kAlloc, tracker.BufferAt(add, {})}, + {kAlloc, tracker.BufferAt(tuple, {})}, + {kAlloc, tracker.BufferAt(broadcast, {})}, + // The mul can be freed right after the broadcast happens, even though + // The other GetTupleElement is still alive. + {kFree, tracker.BufferAt(mul, {})}, + {kAlloc, tracker.BufferAt(sub, {})}, + // The temporary tuple is now dead. + {kFree, tracker.BufferAt(tuple, {})}, + {kAlloc, tracker.BufferAt(output, {})}, + // All params and outputs are freed at the end. + {kFree, tracker.BufferAt(paramA, {})}, + {kFree, tracker.BufferAt(paramB, {})}, + {kFree, tracker.BufferAt(add, {})}, + {kFree, tracker.BufferAt(broadcast, {})}, + {kFree, tracker.BufferAt(sub, {})}, + {kFree, tracker.BufferAt(output, {})}, + {kFinish, nullptr}, + }); +} + TEST_F(HeapSimulatorTest, WholeModule) { HeapSimulatorTracker tracker(TestName()); @@ -472,7 +583,7 @@ TEST_F(HeapSimulatorTest, WholeModule) { // Now the final cond less-than buffer is allocated. {kAlloc, tracker.BufferAt(cond_lt, {})}, - // The order of the remaining Free calls is based on the LogicalBuffer.id, + // The order of the remaining Free calls is based on the BufferValue.id, // which is deterministic, but not obvious. {kFree, tracker.BufferAt(param, {})}, {kFree, tracker.BufferAt(param, {0})}, @@ -494,40 +605,40 @@ TEST_F(HeapSimulatorTest, WholeModule) { class HeapAlgorithmTestBase : public ::testing::Test { protected: HeapAlgorithmTestBase() : builder_("heap_simulator_test") { - buffer_a_ = DummyLogicalBuffer(); - buffer_b_ = DummyLogicalBuffer(); - buffer_c_ = DummyLogicalBuffer(); - buffer_d_ = DummyLogicalBuffer(); - buffer_e_ = DummyLogicalBuffer(); - buffer_f_ = DummyLogicalBuffer(); - buffer_g_ = DummyLogicalBuffer(); - buffer_h_ = DummyLogicalBuffer(); - buffer_i_ = DummyLogicalBuffer(); + buffer_a_ = DummyBufferValue(); + buffer_b_ = DummyBufferValue(); + buffer_c_ = DummyBufferValue(); + buffer_d_ = DummyBufferValue(); + buffer_e_ = DummyBufferValue(); + buffer_f_ = DummyBufferValue(); + buffer_g_ = DummyBufferValue(); + buffer_h_ = DummyBufferValue(); + buffer_i_ = DummyBufferValue(); } ~HeapAlgorithmTestBase() override {} - const LogicalBuffer* buffer_a_; - const LogicalBuffer* buffer_b_; - const LogicalBuffer* buffer_c_; - const LogicalBuffer* buffer_d_; - const LogicalBuffer* buffer_e_; - const LogicalBuffer* buffer_f_; - const LogicalBuffer* buffer_g_; - const LogicalBuffer* buffer_h_; - const LogicalBuffer* buffer_i_; + const BufferValue* buffer_a_; + const BufferValue* buffer_b_; + const BufferValue* buffer_c_; + const BufferValue* buffer_d_; + const BufferValue* buffer_e_; + const BufferValue* buffer_f_; + const BufferValue* buffer_g_; + const BufferValue* buffer_h_; + const BufferValue* buffer_i_; private: - // Create a dummy LogicalBuffer to pass to the heap algorithm. - const LogicalBuffer* DummyLogicalBuffer() { - const LogicalBuffer::Id id = buffers_.size(); + // Create a dummy BufferValue to pass to the heap algorithm. + const BufferValue* DummyBufferValue() { + const BufferValue::Id id = buffers_.size(); auto const0 = builder_.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - buffers_.emplace_back(MakeUnique(const0, ShapeIndex{}, id)); + buffers_.emplace_back(MakeUnique(id, const0, ShapeIndex{})); return buffers_.back().get(); } HloComputation::Builder builder_; - std::vector> buffers_; + std::vector> buffers_; }; class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {}; diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index a43785b4a9701369ae315f67d4d64d03dc6c081d..e201359d3d25b7d2dda852762c6de1fcb75685d7 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -13,13 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// DO NOT USE THESE PROTO MESSAGES FOR ANYTHING OTHER THAN DEBUGGING. -// -// Don't use these protos in the real compilation or execution codepaths. The -// data format is meant for debugging only, and may change without notice. +// This proto file defines messages which represent the HLO module. This is a +// full fidelity serialization of the c++ HLO constructs. // // Many of the protos below are simple 1-to-1 serializations of the -// corresponding C++ classes. +// corresponding C++ classes, e.g., HloModule, HloComputation, and +// HloInstruction. // // FIELD NAMES ARE IMPORTANT // @@ -38,16 +37,19 @@ option cc_enable_arenas = true; message HloInstructionProto { reserved 10; reserved "parameter_name"; + reserved 12; + reserved "fused_instructions_computation"; + reserved 4; + reserved "operand_names"; + reserved 5; + reserved "control_predecessor_names"; + reserved 6; + reserved "called_computation_names"; string name = 1; string opcode = 2; xla.Shape shape = 3; - // TODO(b/67782397): Replace instruction names with HloInstruction ids. - repeated string operand_names = 4; - repeated string control_predecessor_names = 5; - repeated string called_computation_names = 6; - xla.OpMetadata metadata = 7; // Literal, only present for kConstant. @@ -58,7 +60,6 @@ message HloInstructionProto { // Fusion state, only present for kFusion. string fusion_kind = 11; - HloComputationProto fused_instructions_computation = 12; // Index for kGetTupleElement. int64 tuple_index = 13; @@ -133,28 +134,61 @@ message HloInstructionProto { // Gather dimension numbers. xla.GatherDimensionNumbers gather_dimension_numbers = 33; repeated int64 gather_window_bounds = 34; + + // Compute Host. + string channel_name = 41; + int64 cost_estimate_ns = 42; + + // The id of this instruction. + int64 id = 35; + + repeated int64 operand_ids = 36; + repeated int64 control_predecessor_ids = 37; + repeated int64 called_computation_ids = 38; + repeated int64 replica_group_ids = 44; + + xla.OpSharding sharding = 40; + + // Backend configuration for the instruction. Has backend-specific meaning. + string backend_config = 43; } // Serialization of HloComputation. message HloComputationProto { + reserved 3; + reserved "root_name"; + string name = 1; // The array of instructions is always in a valid dependency order, where // operands appear before their users. repeated HloInstructionProto instructions = 2; - // The name of the root of the computation. - string root_name = 3; + // The program shape (with layout) of this computation. + xla.ProgramShape program_shape = 4; + + // The id of this computation. + int64 id = 5; + + // The id of the root of the computation. + int64 root_id = 6; } // Serialization of HloModule. message HloModuleProto { string name = 1; string entry_computation_name = 2; + int64 entry_computation_id = 6; // The array of computations is always in a valid dependency order, where // callees appear before their callers. repeated HloComputationProto computations = 3; + + // The program shape (with layout) of the entry computation. + xla.ProgramShape program_shape = 4; + + // The id of this module. + int64 id = 5; } // Serialization of HloOrdering. @@ -266,3 +300,20 @@ message HloProto { HloOrderingProto hlo_ordering = 2; BufferAssignmentProto buffer_assignment = 3; } + +// Encapsulates HloProto together with the arguments, result, and +// execution_platform. This message is used for purposes such as +// analysis/replay/file-storage. +message HloSnapshot { + // The hlo graph. + HloProto hlo = 1; + + // The arguments passed to the graph. + repeated LiteralProto arguments = 2; + + // The result of the graph. + LiteralProto result = 3; + + // The name of the platform used to run the graph. + string execution_platform = 4; +} diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 30e32a46d7dd0923f738939c33407ac7484b5bbe..a88283ed9a6459b4fa9310e160b59c77d51f1027 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -171,24 +171,21 @@ class BufferValueMap { return value_to_buffer_number_.at(&value); } - // Compute and return a vector of buffers that the given value must be - // contained in due to HLO aliasing rules. - std::vector ComputeAliasedBuffers(const HloValue& value) { + void ComputeWhileAliasedBuffers(const HloValue& value, + std::vector* aliased_buffers) { + VLOG(3) << "Compute kWhile aliases"; // Value is init of a while (use is while). - std::vector aliased_buffers; for (const HloUse& use : value.uses()) { - VLOG(2) << "use of value " << value.ToShortString() << ": " << use; if (use.instruction->opcode() == HloOpcode::kWhile) { // Determine the while value that this shares a buffer with. const HloValue& while_value = dataflow_.GetUniqueValueAt(use.instruction, use.operand_index); - aliased_buffers.push_back(GetBufferForValue(while_value)); + aliased_buffers->push_back(GetBufferForValue(while_value)); VLOG(3) << " value is init value to a while; must share buffer with " "while value " << while_value.ToShortString(); } } - // Value is a parameter of a while body/condition. if (value.defining_instruction()->opcode() == HloOpcode::kParameter) { const HloComputation* computation = @@ -205,11 +202,10 @@ class BufferValueMap { VLOG(3) << " value is parameter value of the body or condition of a " "while; must share buffer with while value " << while_value.ToShortString(); - aliased_buffers.push_back(GetBufferForValue(while_value)); + aliased_buffers->push_back(GetBufferForValue(while_value)); } } } - // Value is the root of a while body. for (const HloPosition& position : value.positions()) { const HloComputation* computation = position.instruction->parent(); @@ -224,27 +220,71 @@ class BufferValueMap { const HloValue& while_value = dataflow_.GetUniqueValueAt( callsite.instruction(), position.index); - VLOG(3) << " value is root the body computation of a while; must " - "share buffer with while value " + VLOG(3) << " value @ " << position << " is root of " + << callsite.instruction()->name() + << "; body root and while value root must share buffer " + "among them : " << while_value.ToShortString(); - aliased_buffers.push_back(GetBufferForValue(while_value)); + aliased_buffers->push_back(GetBufferForValue(while_value)); } } } } - // Value is the output of the while instruction itself. if (value.defining_instruction()->opcode() == HloOpcode::kWhile) { VLOG(3) << " value is output of a while instruction"; - aliased_buffers.push_back(GetBufferForValue(value)); + aliased_buffers->push_back(GetBufferForValue(value)); + } + } + + void ComputeConditionalAliasedBuffers( + const HloValue& value, std::vector* aliased_buffers) { + VLOG(3) << "Compute kConditional aliases"; + // Aliases the buffers of the true/false computations roots, with the one of + // the conditional. + for (const HloPosition& position : value.positions()) { + const HloComputation* computation = position.instruction->parent(); + const CallGraphNode& call_graph_node = + dataflow_.call_graph().GetNode(computation); + if (position.instruction == computation->root_instruction()) { + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + if (callsite.instruction()->opcode() == HloOpcode::kConditional) { + // Call graph must have been flattened. + CHECK_EQ(call_graph_node.caller_callsites().size(), 1); + + const HloValue& cond_value = dataflow_.GetUniqueValueAt( + callsite.instruction(), position.index); + VLOG(3) + << " value @ " << position << " is root of " + << callsite.instruction()->name() + << "; true/false branch roots must share buffer among them : " + << cond_value.ToShortString(); + aliased_buffers->push_back(GetBufferForValue(cond_value)); + } + } + } + } + // Value is the output of the conditional instruction itself. + if (value.defining_instruction()->opcode() == HloOpcode::kConditional) { + VLOG(3) << " value is output of a conditional instruction"; + aliased_buffers->push_back(GetBufferForValue(value)); } + } + // Compute and return a vector of buffers that the given value must be + // contained in due to HLO aliasing rules. + std::vector ComputeAliasedBuffers(const HloValue& value) { + for (const HloUse& use : value.uses()) { + VLOG(2) << "Use of value " << value.ToShortString() << ": " << use; + } + std::vector aliased_buffers; + ComputeWhileAliasedBuffers(value, &aliased_buffers); + ComputeConditionalAliasedBuffers(value, &aliased_buffers); // Uniquify aliased buffers. std::sort(aliased_buffers.begin(), aliased_buffers.end()); aliased_buffers.erase( std::unique(aliased_buffers.begin(), aliased_buffers.end()), aliased_buffers.end()); - return aliased_buffers; } diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils.h b/tensorflow/compiler/xla/service/hlo_casting_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..7f73bba036534a62a70a80431236cffa766c9b38 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_casting_utils.h @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Casting utilitiy functions for HLO instructions. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ + +#include +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +class HloInstruction; + +template +using EnableIfDerivedFromHlo = + typename std::enable_if::value>::type; + +// TODO(b/93238915): Switch implementation from C++'s dynamic_cast to LLVM-like +// RTTI if it turns out to be a performance issue. +// Casts an HloInstruction pointer to one of its subclasses, dies if argument is +// nullptr or runtime information does not match. +// +// Similar to LLVM's cast. +template * = nullptr> +const T* Cast(const HloInstruction* instruction) { + CHECK(instruction != nullptr); + const T* casted = dynamic_cast(instruction); + CHECK(casted != nullptr); + return casted; +} + +// Non-const overload of Cast. +template * = nullptr> +T* Cast(HloInstruction* instruction) { + return const_cast( + Cast(const_cast(instruction))); +} + +// Works just like the Cast, except that it allows for a null pointer as an +// argument which it then propagates. +// +// Similar to LLVM's cast_or_null. +template * = nullptr> +const T* CastOrNull(const HloInstruction* instruction) { + return instruction != nullptr ? Cast(instruction) : nullptr; +} + +// Non-const overload of CastOrNull. +template * = nullptr> +T* CastOrNull(HloInstruction* instruction) { + return const_cast( + CastOrNull(const_cast(instruction))); +} + +// Casts an HloInstruction pointer to one of its subclasses, dies if argument is +// nullptr, returns nullptr if runtime information does not match. +// +// Similar to LLVM's dyn_cast. +template * = nullptr> +const T* DynCast(const HloInstruction* instruction) { + CHECK(instruction != nullptr); + return dynamic_cast(instruction); +} + +// Non-const overload of DynCast. +template * = nullptr> +T* DynCast(HloInstruction* instruction) { + return const_cast( + DynCast(const_cast(instruction))); +} + +// Works just like the DynCast, except that it allows for a null pointer as an +// argument which it then propagates. +// +// Similar to LLVM's dyn_cast_or_null. +template * = nullptr> +const T* DynCastOrNull(const HloInstruction* instruction) { + return instruction != nullptr ? DynCast(instruction) : nullptr; +} + +// Non-const overload of DynCastOrNull. +template * = nullptr> +T* DynCastOrNull(HloInstruction* instruction) { + return const_cast( + DynCastOrNull(const_cast(instruction))); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a3364275409122254bf99b40a7d2fcbb2d7564cc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class DummyInstruction : public HloInstruction { + public: + DummyInstruction() + : HloInstruction(HloOpcode::kConstant, ShapeUtil::MakeShape(F32, {})) {} +}; + +class AnotherDummyInstruction : public HloInstruction { + public: + AnotherDummyInstruction() + : HloInstruction(HloOpcode::kParameter, ShapeUtil::MakeShape(F32, {})) {} +}; + +TEST(HloCastingUtilsTest, CastSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + Cast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, CastDiesForWrongType) { + AnotherDummyInstruction instruction; + ASSERT_DEATH( + Cast(static_cast(&instruction)), ""); +} + +TEST(HloCastingUtilsTest, CastDiesForNullptr) { + HloInstruction* null = nullptr; + ASSERT_DEATH(Cast(null), ""); +} + +TEST(HloCastingUtilsTest, CastOrNullSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + Cast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, CastOrNullDiesForWrongType) { + AnotherDummyInstruction instruction; + ASSERT_DEATH( + Cast(static_cast(&instruction)), ""); +} + +TEST(HloCastingUtilsTest, CastOrNullReturnsNullptrForNullptr) { + HloInstruction* null = nullptr; + DummyInstruction* casted = CastOrNull(null); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + DynCast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, DynCastReturnsNullptrForWrongType) { + AnotherDummyInstruction instruction; + DummyInstruction* casted = + DynCast(static_cast(&instruction)); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastDiesForNullptr) { + HloInstruction* null = nullptr; + ASSERT_DEATH(DynCast(null), ""); +} + +TEST(HloCastingUtilsTest, DynCastOrNullSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = DynCastOrNull( + static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, DynCastOrNullReturnsNullptrForWrongType) { + AnotherDummyInstruction instruction; + DummyInstruction* casted = DynCastOrNull( + static_cast(&instruction)); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastOrNullReturnsNullptrForNullptr) { + HloInstruction* null = nullptr; + DummyInstruction* casted = DynCastOrNull(null); + ASSERT_EQ(casted, nullptr); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_clone_context.h b/tensorflow/compiler/xla/service/hlo_clone_context.h new file mode 100644 index 0000000000000000000000000000000000000000..658643b427a9625fac1166151a89cbd669f817d5 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_clone_context.h @@ -0,0 +1,97 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { + +class HloInstruction; +class HloComputation; +class HloModule; + +// Data structure used to track the cloning of HloInstruction and HloComputation +// objects. +class HloCloneContext { + public: + // Creates a new HloCloneContext object to clone HloInstruction and + // HloComputation objects to be added to the module specified as argument. + // The suffix string will be appended to computation names. + explicit HloCloneContext(HloModule* module, const string& suffix = "") + : module_(module), suffix_(suffix) {} + + HloModule* module() const { return module_; } + + const string& suffix() const { return suffix_; } + + void MapInstruction(const HloInstruction* old_instruction, + HloInstruction* new_instruction) { + instructions_[old_instruction] = new_instruction; + } + + void MapComputation(const HloComputation* old_computation, + HloComputation* new_computation) { + computations_[old_computation] = new_computation; + } + + // Finds the new instruction mapped to its old copy, or return nullptr in case + // it is not found. + HloInstruction* FindInstruction(const HloInstruction* old_instruction) const { + return FindOrDefault(instructions_, old_instruction, nullptr); + } + + // Finds the new computation mapped to its old copy, or return nullptr in case + // it is not found. + HloComputation* FindComputation(const HloComputation* old_computation) const { + return FindOrDefault(computations_, old_computation, nullptr); + } + + // Retrieves the new instruction mapped to its old copy, or fail if not found. + HloInstruction* GetInstruction(const HloInstruction* old_instruction) const { + return FindOrDie(instructions_, old_instruction); + } + + // Retrieves the new computation mapped to its old copy, or fail if not found. + HloComputation* GetComputation(const HloComputation* old_computation) const { + return FindOrDie(computations_, old_computation); + } + + const tensorflow::gtl::FlatMap& + cloned_instructions() const { + return instructions_; + } + + const tensorflow::gtl::FlatMap& + cloned_computations() const { + return computations_; + } + + private: + HloModule* module_; + string suffix_; + tensorflow::gtl::FlatMap + instructions_; + tensorflow::gtl::FlatMap + computations_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 21e6b2ca730f6347af902097e6496826b861e8a3..b158f44923982642615543cebbd54f00596e2641 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -64,7 +64,8 @@ HloComputation::HloComputation( const string& name, int parameter_count, std::vector>* instructions, HloInstruction* root_instruction, HloInstruction* fusion_instruction) - : name_(name), + : name_(NameUniquer::GetSanitizedName(name)), + unique_id_(-1), root_instruction_(root_instruction), fusion_instruction_(fusion_instruction) { param_instructions_.resize(parameter_count, nullptr); @@ -101,7 +102,7 @@ HloInstruction* HloComputation::AddInstructionInternal( instruction->UniquifyName(&parent()->instruction_name_uniquer()); instruction->SetUniqueId(parent()->NewUniqueInstructionId()); } - Reparent(instruction.get()); + instruction->set_parent(this); HloInstruction* pinst = instruction.get(); instruction_iterators_[pinst] = instructions_.insert(instructions_.end(), std::move(instruction)); @@ -158,10 +159,6 @@ Status HloComputation::RemoveParameter(int64 param_no) { return Status::OK(); } -void HloComputation::Reparent(HloInstruction* instruction) { - instruction->set_parent(this); -} - bool HloComputation::IsRemovable(const HloInstruction* instruction) { // If the instruction has control predecessors or successors then we cannot // remove the instruction without violating ordering constraints (added, for @@ -307,19 +304,51 @@ void ComputeComputationPostOrder( HloComputation* computation, tensorflow::gtl::FlatSet* visited, std::list* post_order) { - if (visited->count(computation) > 0) { - return; + if (visited->insert(computation).second) { + for (auto* instruction : computation->instructions()) { + for (HloComputation* called_computation : + instruction->called_computations()) { + ComputeComputationPostOrder(called_computation, visited, post_order); + } + } + post_order->push_back(computation); } +} - for (auto* instruction : computation->instructions()) { - for (HloComputation* called_computation : - instruction->called_computations()) { - ComputeComputationPostOrder(called_computation, visited, post_order); +std::list ComputeInstructionPostOrder( + HloInstruction* root, tensorflow::gtl::FlatSet* visited) { + std::list post_order; + std::vector> dfs_stack; + dfs_stack.emplace_back(root, false); + while (!dfs_stack.empty()) { + const auto current = dfs_stack.back(); + if (current.second) { + dfs_stack.pop_back(); + if (!visited->insert(current.first).second) { + continue; + } + post_order.push_back(current.first); + } else { + if (visited->count(current.first)) { + dfs_stack.pop_back(); + continue; + } + dfs_stack.back().second = true; + + // Add the operands to the stack in reverse order so the first operand is + // processed first. This will produce a more natural ordering and a nicer + // result for thigns like HLO stringification. + const auto& operands = current.first->operands(); + for (int64 i = operands.size() - 1; i >= 0; --i) { + dfs_stack.emplace_back(operands[i], false); + } + + for (HloInstruction* op : current.first->control_predecessors()) { + dfs_stack.emplace_back(op, false); + } } } - - visited->insert(computation); - post_order->push_back(computation); + return post_order; } } // namespace @@ -335,9 +364,9 @@ std::list HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - post_order.splice(post_order.end(), - InstructionPostOrderer::GetOrder(instruction.get(), - &added_instructions)); + post_order.splice( + post_order.end(), + ComputeInstructionPostOrder(instruction.get(), &added_instructions)); } } post_order.splice(post_order.end(), trace_instructions); @@ -354,6 +383,11 @@ std::list HloComputation::MakeEmbeddedComputationsList() // To avoid special handling of this computation, cast away const of // 'this'. 'this' is immediately removed from the post order after // construction. + // + // TODO(b/78350259): This violates const-correctness, since while the original + // computation is not returned, we still retrieve non-const computations from + // a const one. Consider also avoiding const for HloComputation, or review XLA + // for const-correctness of non-HloInstruction* types like this. ComputeComputationPostOrder(const_cast(this), &visited, &post_order); @@ -367,25 +401,38 @@ std::list HloComputation::MakeEmbeddedComputationsList() string HloComputation::ToString(const HloPrintOptions& options) const { std::ostringstream s; for (int i = 0; i < options.indent_amount(); i++) { - s << " "; + s << " "; } - if (options.print_percent()) { - s << "%"; + + if (!options.is_in_nested_computation()) { + if (options.print_percent()) { + s << "%"; + } + s << name() << " "; } - s << name(); + if (options.print_program_shape()) { - s << " " << ShapeUtil::HumanString(ComputeProgramShape()); - } - s << " {\n"; - for (const HloInstruction* instruction : MakeInstructionPostOrder()) { - for (int i = 0; i < options.indent_amount(); i++) { - s << " "; + s << ShapeUtil::HumanString(ComputeProgramShape()) << " "; + } + s << "{\n"; + { + // Print the instructions in this computation. + HloPrintOptions new_options = options; + new_options.set_indent_amount(options.indent_amount() + 1) + .set_is_in_nested_computation(true); + CanonicalNameMap name_map; + for (const HloInstruction* instruction : MakeInstructionPostOrder()) { + for (int i = 0; i < new_options.indent_amount(); i++) { + s << " "; + } + s << (instruction == root_instruction_ ? "ROOT " : "") + << instruction->ToStringWithCanonicalNameMap(new_options, &name_map) + << "\n"; } - s << " " << (instruction == root_instruction_ ? "ROOT " : "") - << instruction->ToString(options) << "\n"; } + for (int i = 0; i < options.indent_amount(); i++) { - s << " "; + s << " "; } s << "}"; return s.str(); @@ -393,43 +440,56 @@ string HloComputation::ToString(const HloPrintOptions& options) const { HloComputationProto HloComputation::ToProto() const { HloComputationProto proto; + CHECK(unique_id_ != -1) + << "This computation does not have a valid id. Please make sure the " + "computation is inside a module before dumping it."; + proto.set_id(unique_id_); proto.set_name(name_); for (const HloInstruction* instruction : MakeInstructionPostOrder()) { HloInstructionProto instruction_proto = instruction->ToProto(); proto.add_instructions()->Swap(&instruction_proto); } - proto.set_root_name(root_instruction()->name()); + proto.set_root_id(root_instruction()->unique_id()); + *proto.mutable_program_shape() = ComputeProgramShape(); return proto; } /* static */ StatusOr> HloComputation::CreateFromProto( - HloModule* module, const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation, - HloInstruction* fusion_instruction) { + const HloComputationProto& proto, + const tensorflow::gtl::FlatMap& computation_map) { + tensorflow::gtl::FlatMap instruction_map; + tensorflow::gtl::FlatMap to_proto_id; std::vector> instructions; - tensorflow::gtl::FlatMap instruction_map; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr instruction, - HloInstruction::CreateFromProto( - module, instruction_proto, instruction_map, - computation_map, add_fused_computation)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr instruction, + HloInstruction::CreateFromProto(instruction_proto, instruction_map, + computation_map)); if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } - TF_RET_CHECK(!ContainsKey(instruction_map, instruction->name())); - instruction_map[instruction->name()] = instruction.get(); + TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id())); + instruction_map[instruction_proto.id()] = instruction.get(); + to_proto_id[instruction.get()] = instruction_proto.id(); instructions.push_back(std::move(instruction)); } - TF_RET_CHECK(!proto.root_name().empty()); - TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name())); - HloInstruction* root = instruction_map.at(proto.root_name()); - return WrapUnique(new HloComputation( - proto.name(), parameter_count, &instructions, root, fusion_instruction)); + TF_RET_CHECK(proto.root_id() != -1); + TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id())); + HloInstruction* root = instruction_map.at(proto.root_id()); + + // Sort the instructions in the proto id's order. + std::sort(instructions.begin(), instructions.end(), + [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); + + return WrapUnique(new HloComputation(proto.name(), parameter_count, + &instructions, root, + /*fusion_instruction=*/nullptr)); } void HloComputation::FuseInstructionsInto( @@ -532,7 +592,6 @@ ProgramShape HloComputation::ComputeProgramShape() const { } *program_shape.mutable_result() = root_instruction_->shape(); - LayoutUtil::ClearLayout(&program_shape); return program_shape; } @@ -728,18 +787,24 @@ Status HloComputation::Accept( return this->Accept(&visitor); } -std::unique_ptr HloComputation::Clone(const string& suffix, - HloModule* module) { +std::unique_ptr HloComputation::Clone( + const string& suffix, HloCloneContext* context) { return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - module, suffix); + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - HloModule* module, const string& suffix) { + HloCloneContext* context, const string& suffix) { + std::unique_ptr context_ptr; + if (context == nullptr) { + context_ptr = MakeUnique(parent(), suffix); + context = context_ptr.get(); + } + // Look up instr in the replacements map, and return either the replacement, // or instr, if the replacement isn't present. // @@ -761,24 +826,19 @@ std::unique_ptr HloComputation::CloneWithReplacements( } } - std::unordered_map clone_map; std::vector> instructions; - std::unique_ptr new_instr = nullptr; + std::unique_ptr new_instr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { auto replaced_operand = replace(operand); - // If replaced_operand is null, that means 'replacements' asked us not to - // include operand in the new computation. But we can't do that, because - // operand is used by instr. CHECK_NE(replaced_operand, nullptr) << "replacements map tried to eliminate a used instruction " << operand->ToString() << ", used by " << instr->ToString(); - new_operands.push_back(FindOrDie(clone_map, replaced_operand)); + new_operands.push_back(context->GetInstruction(replaced_operand)); } new_instr = - instr->CloneWithNewOperands(instr->shape(), new_operands, module); - InsertOrDie(&clone_map, instr, new_instr.get()); + instr->CloneWithNewOperands(instr->shape(), new_operands, context); instructions.push_back(std::move(new_instr)); } Builder builder(name() + "." + suffix); @@ -786,27 +846,25 @@ std::unique_ptr HloComputation::CloneWithReplacements( builder.AddInstruction(std::move(instr)); } auto result = builder.Build( - /*root_instruction=*/FindOrDie(clone_map, replace(root_instruction()))); + /*root_instruction=*/context->GetInstruction( + replace(root_instruction()))); // Clone control dependencies. for (auto instr : postorder) { - HloInstruction* new_instr = FindOrDie(clone_map, instr); + HloInstruction* new_instr = context->GetInstruction(instr); for (auto successor : instr->control_successors()) { auto replaced_successor = replace(successor); - - // successor may not be in clone_map, because it might have been + // successor may not have been remapped, because it might have been // removed by the replacements map. - if (replaced_successor == nullptr) { - continue; + if (replaced_successor != nullptr) { + TF_CHECK_OK(new_instr->AddControlDependencyTo( + context->GetInstruction(replaced_successor))); } - - TF_CHECK_OK(new_instr->AddControlDependencyTo( - FindOrDie(clone_map, replaced_successor))); } } - + context->MapComputation(this, result.get()); // We cloned the elements of 'replacements', so they're all going to be - // destroyed. HloInstructions need to be detached from their operands before + // destroyed. HloInstructions need to be detached from their operands before // they're destroyed, otherwise they stick around in the operands' users lists // and cause use-after-frees. for (auto& kv : replacements) { @@ -814,7 +872,6 @@ std::unique_ptr HloComputation::CloneWithReplacements( new_instr->DetachFromOperands(); } } - return result; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 39d864efcb70382b6f8e631d7e6e452ea6410104..0da4a305f3d5d694a1918fed294337100b0a27fd 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -49,9 +50,20 @@ class HloModule; // Describes a computation at the HLO level. // -// An HloComputation contains a directed acyclic graph of HLO instructions. The -// computation has a single root instruction which produces the output of the -// computation. +// You can think of an HloComputation like a function. It has some inputs +// (parameters) and returns exactly one value (the value of its root node). If +// you want to return multiple values, you can return a tuple. +// +// The instructions inside of a computation do not have an explicit total order. +// Instead, they have a partial order determined by their data and control +// dependencies. +// +// An HloModule contains one "entry computation" -- this is like main() in a C +// program. Every other computation inside of a module is attached to one or +// more HloInstructions, as a "nested computation". For example, the kMap +// instruction has a nested computation and "applies" it to every element of its +// input, elementwise. (That is, the input [x, y, z] is transformed to [f(x), +// f(y), f(z)].) class HloComputation { public: // Builder class for HloComputation. @@ -157,23 +169,13 @@ class HloComputation { // Creates a computation from the given proto. Arguments: // - // module: the module which will contain the computation. The newly created - // computation is *not* added to the module, however. // proto: the proto to convert from. - // computation_map: a map from computation name to HloComputation*. This map + // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed computation // calls. - // add_fused_computation: A function to call to add a fused - // computation. Used only when the instruction is a fusion instruction. - // fusion_instruction: if non-null then the newly created computation will - // be constructed as a fused computation with this instruction as its - // fusion parent. static StatusOr> CreateFromProto( - HloModule* module, const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation, - HloInstruction* fusion_instruction = nullptr); + const HloComputationProto& proto, + const tensorflow::gtl::FlatMap& computation_map); // Gets the instructions in this computation. // @@ -248,7 +250,7 @@ class HloComputation { ShapeTree* copies_added = nullptr); // Computes and returns the ProgramShape of this computation (shape of - // parameters and result without layout). + // parameters and result with layout). ProgramShape ComputeProgramShape() const; // Return whether `*this` and `other` are functionally equivalent. @@ -299,11 +301,11 @@ class HloComputation { const std::function& visitor_func) const; // Returns a deep copy of this computation including all instructions. - // If the module pointer is not nullptr, it will be the module where - // the cloned computations will be added to (in order to support deep - // cloning). + // If the clone context is specified, it will be populated with the cloned + // object mappings, and its module() will be used to add new computations + // into. std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr); + HloCloneContext* context = nullptr); // Like Clone(), but if an instruction is present in replacement_map, we use // the map's value to replace that instruction in the cloned computation. @@ -313,7 +315,7 @@ class HloComputation { std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - HloModule* module = nullptr, const string& suffix = "clone"); + HloCloneContext* context = nullptr, const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of @@ -342,6 +344,15 @@ class HloComputation { fusion_instruction_ = fusion_instruction; } + // The id of this computation should be unique within the module. + void SetUniqueId(int64 id) { + CHECK_EQ(unique_id_, -1); + CHECK_GE(id, 0); + unique_id_ = id; + } + + int64 unique_id() const { return unique_id_; } + private: explicit HloComputation( const string& name, int parameter_count, @@ -352,10 +363,6 @@ class HloComputation { HloInstruction* AddInstructionInternal( std::unique_ptr instruction); - // Helper for setting the parent of instructions that are added to this - // computation. - void Reparent(HloInstruction* instruction); - // Fuses HLOs in instructions_to_fuse into fusion_instruction. // // Pre-condition: fusion_instruction's opcode is kFusion. @@ -373,6 +380,7 @@ class HloComputation { std::vector CollectUnreachableRoots() const; string name_; + int64 unique_id_; HloInstruction* root_instruction_; // If this computation is a fusion computation, this field points to the diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 7b7588f4ba9aa622677db6f9d5022cc8cc029e04..25469a54c48f4f5cab478aba929f1cc18de8b81f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -550,6 +550,108 @@ TEST_F(HloComputationTest, Reachability) { EXPECT_FALSE(reachability->IsReachable(constant2, copy)); } +TEST_F(HloComputationTest, Stringification) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto options = HloPrintOptions().set_print_metadata(false); + EXPECT_EQ(computation->ToString(options), + R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + %x = f32[5,10]{1,0} parameter(0) + %y = f32[20,10]{1,0} parameter(1) + %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); +} + +TEST_F(HloComputationTest, StringificationIndent) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto options = + HloPrintOptions().set_print_metadata(false).set_indent_amount(2); + EXPECT_EQ(computation->ToString(options), + R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + %x = f32[5,10]{1,0} parameter(0) + %y = f32[20,10]{1,0} parameter(1) + %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })"); +} + +TEST_F(HloComputationTest, StringificationCanonical) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto options = HloPrintOptions().set_print_metadata(false); + EXPECT_EQ(computation->ToString(options), + R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + %x = f32[5,10]{1,0} parameter(0) + %y = f32[20,10]{1,0} parameter(1) + %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); + + options = HloPrintOptions().Canonical(); + EXPECT_EQ(computation->ToString(options), R"(TransposeDot { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 53450991b6fad5b9651d9d23b55c908e6b68e5dd..35ecd4428d0dfde2de445ea34472d2c78148c6c9 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -35,7 +35,10 @@ limitations under the License. namespace xla { StatusOr HloConstantFolding::Run(HloModule* module) { - auto evaluator = MakeUnique(); + // Limit the constant folding to 0 iterations to skip folding loops. This + // retains the behavior from before while loop support in HloEvaluator and may + // be revised. + auto evaluator = MakeUnique(/*max_loop_iterations=*/0); XLA_VLOG_LINES(2, "HloConstantFolding::Run(), before:\n" + module->ToString()); diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 7b552ee5b1798c4c7e24884a392c5982d7fb17ff..5d05ccfc0b223d8749a2577ba1bf96b1ab3e761b 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -149,7 +149,7 @@ TEST_F(HloConstantFoldingTest, Slice) { const int64 slice_limits[] = {10, 8, 6, 5, 9}; const int64 slice_strides[] = {1, 1, 1, 1, 1}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - LiteralTestUtil::CreateRandomLiteral( + Literal::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -172,7 +172,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { HloComputation::Builder builder(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - LiteralTestUtil::CreateRandomLiteral( + Literal::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->Literal::CloneToUnique(); HloInstruction* literal_instruction = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 4ec2ef27bf59b0c877ec38e55ef5c12debeec227..762e1afc71b108b2e32b5a7f7f1bbeb783fc6fbd 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -141,19 +142,25 @@ Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) { } Status HloCostAnalysis::HandleParameter(const HloInstruction*) { + current_should_compute_bottleneck_time_ = false; current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } Status HloCostAnalysis::HandleConstant(const HloInstruction*) { + current_should_compute_bottleneck_time_ = false; current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) { // GetTupleElement forwards a pointer and does not touch each element in the // output. + current_should_compute_bottleneck_time_ = false; current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } @@ -165,15 +172,22 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) { + current_properties_[kBytesAccessedKey] = shape_size_(slice->shape()) * 2; return Status::OK(); } -Status HloCostAnalysis::HandleDynamicSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleDynamicSlice( + const HloInstruction* dynamic_slice) { + current_properties_[kBytesAccessedKey] = + shape_size_(dynamic_slice->shape()) * 2; return Status::OK(); } -Status HloCostAnalysis::HandleDynamicUpdateSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleDynamicUpdateSlice( + const HloInstruction* dynamic_update_slice) { + current_properties_[kBytesAccessedKey] = + shape_size_(dynamic_update_slice->operand(1)->shape()) * 2; return Status::OK(); } @@ -328,6 +342,7 @@ Status HloCostAnalysis::HandleSelectAndScatter( Status HloCostAnalysis::HandleBitcast(const HloInstruction*) { // A bitcast does no computation and touches no memory. current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } @@ -378,21 +393,106 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleGenerateToken(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { - auto rhs_instruction = convolution->operand(1); + auto lhs = convolution->operand(0); + auto rhs = convolution->operand(1); + Window window = convolution->window(); + const auto& result_shape = convolution->shape(); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); + const auto& dnums = convolution->convolution_dimension_numbers(); - const int64 output_features = - convolution->shape().dimensions(dnums.output_feature_dimension()); - // For each output element, we do one fma per element in the kernel at some - // given output feature index. - const int64 fmas_per_output_element = - output_features > 0 - ? ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features - : 0; - const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape()); - current_properties_[kFlopsKey] = - output_elements * fmas_per_output_element * kFmaFlops; + const int64 input_batch_dim = dnums.input_batch_dimension(); + const int64 input_feature_dim = dnums.input_feature_dimension(); + const int64 output_feature_dim = dnums.output_feature_dimension(); + const int64 input_feature = + ShapeUtil::GetDimension(lhs_shape, input_feature_dim); + const int64 output_feature = + ShapeUtil::GetDimension(result_shape, output_feature_dim); + const int64 batch = ShapeUtil::GetDimension(lhs_shape, input_batch_dim); + + DimensionVector kernel_limits; + DimensionVector output_limits; + DimensionVector input_limits; + if (window.dimensions().empty()) { + window = window_util::MakeWindow({1}); + kernel_limits.push_back(1); + output_limits.push_back(1); + input_limits.push_back(1); + } else { + for (int64 spatial_dimension = 0; + spatial_dimension < window.dimensions_size(); ++spatial_dimension) { + // Spatial dimension number for kernel (rhs). + const int64 kernel_spatial_dim = + dnums.kernel_spatial_dimensions(spatial_dimension); + const int64 kernel_limit = rhs_shape.dimensions(kernel_spatial_dim); + kernel_limits.push_back(kernel_limit); + + // Spatial dimension number for output. + const int64 output_spatial_dim = + dnums.output_spatial_dimensions(spatial_dimension); + const int64 output_limit = result_shape.dimensions(output_spatial_dim); + output_limits.push_back(output_limit); + + // Spatial dimension number for input (lhs). + const int64 input_spatial_dim = + dnums.input_spatial_dimensions(spatial_dimension); + const int64 input_limit = lhs_shape.dimensions(input_spatial_dim); + input_limits.push_back(input_limit); + } + } + + DimensionVector valid_position_counts; + + // Loop over each spatial dimension. + for (int64 spatial_dimension = 0; + spatial_dimension < window.dimensions_size(); ++spatial_dimension) { + int64 valid_position_count = 0; + // Loop over each point in the kernel. + for (int64 kernel_idx = 0; kernel_idx < kernel_limits[spatial_dimension]; + ++kernel_idx) { + // Loop over each point in the output. + for (int64 output_idx = 0; output_idx < output_limits[spatial_dimension]; + ++output_idx) { + // Calculate lhs (input) index without taking base dilation into + // account. + const auto& window_dim = window.dimensions(spatial_dimension); + const int64 undilated_index = output_idx * window_dim.stride() - + window_dim.padding_low() + + kernel_idx * window_dim.window_dilation(); + + // Calculate the actual lhs (input) index after dilation. Avoid the + // division as an optimization. + const int64 lhs_spatial_index = + window_dim.base_dilation() > 1 + ? undilated_index / window_dim.base_dilation() + : undilated_index; + + // Skip if the lhs (input) index is to be dilated. + if (undilated_index != lhs_spatial_index * window_dim.base_dilation()) { + continue; + } + + // Skip if input index is not in bound. + if (lhs_spatial_index < 0 || + lhs_spatial_index >= input_limits[spatial_dimension]) { + continue; + } + + valid_position_count += 1; + } + } + valid_position_counts.push_back(valid_position_count); + } + + const int64 fma_count = + input_feature * output_feature * batch * Product(valid_position_counts); + current_properties_[kFlopsKey] = fma_count * kFmaFlops; return Status::OK(); } @@ -473,11 +573,13 @@ Status HloCostAnalysis::HandleCall(const HloInstruction* call) { } Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) { - // We can't do anything sane with CustomCalls, since we don't know what they - // do, and returning an error status will stop iteration over this - // computation, which is probably also not what we want. So just punt and - // return OK. This will cause all of the properties to be reported as 0, - // which is fine. + // Mark applicable fields as "unknown", since we don't know what CustomCall + // does. This is better than returning an error, which would stop iteration, + // and therefore would prevent us from getting *any* stats for a computation + // which contains a CustomCall. + current_properties_[kOptimalSecondsKey] = -1; + current_properties_[kBytesAccessedKey] = -1; + current_properties_[kFlopsKey] = -1; current_should_compute_bottleneck_time_ = false; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index d17678d20f2a23fd98d18b77d5fb25853901a789..0d66736fe1d0677d13a63ede7a203d6ac20c76f5 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -97,6 +97,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleBroadcast(const HloInstruction* broadcast) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; + Status HandleGenerateToken(const HloInstruction* token) override; Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 3b289c240a45e8f3df8156ed89e879da2132d01a..d22bef56730da194816b4ee89dc3196439b350f9 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -20,16 +20,13 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/logging.h" @@ -58,11 +55,10 @@ class HloCostAnalysisTest : public ::testing::Test { // whitebox accesses to the user computation built from the client, // as shown in the BuildHloGraph functions below. service_(static_cast(ClientLibrary::GetXlaService( - static_cast(client_)->platform()))), - computation_tracker_(service_->computation_tracker()) { + static_cast(client_)->platform()))) { // Create a computation for a unary user function: x => exp(x + 0.5) { - ComputationBuilder builder(client_, "add_and_exp"); + XlaBuilder builder("add_and_exp"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto half = builder.ConstantR0(0.5); builder.Exp(builder.Add(x, half)); @@ -73,7 +69,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary user function: (x, y) => x + y { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -84,7 +80,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x)) { - ComputationBuilder builder(client_, "sigmoid"); + XlaBuilder builder("sigmoid"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto one = builder.ConstantR0(1.0); builder.Div(one, builder.Add(one, builder.Exp(builder.Neg(x)))); @@ -95,7 +91,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary max function: (x, y) => max (x, y) { - ComputationBuilder builder(client_, "max"); + XlaBuilder builder("max"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Max(x, y); @@ -106,7 +102,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary GT function: (x, y) => x > y { - ComputationBuilder builder(client_, "gt"); + XlaBuilder builder("gt"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Gt(x, y); @@ -117,35 +113,30 @@ class HloCostAnalysisTest : public ::testing::Test { } // Build HLO graph from the given builder and return the HLO module. - std::unique_ptr BuildHloGraph(ComputationBuilder* builder) { + std::unique_ptr BuildHloGraph(XlaBuilder* builder) { auto computation_status = builder->Build(); TF_CHECK_OK(computation_status.status()); auto computation = computation_status.ConsumeValueOrDie(); - auto user_computation_status = - computation_tracker_.Resolve(computation.handle()); - TF_CHECK_OK(user_computation_status.status()); - auto user_computation = user_computation_status.ConsumeValueOrDie(); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - return std::move( - computation_tracker_.BuildHloModule(versioned_handle, HloModuleConfig()) - .ValueOrDie()); + auto config = HloModule::CreateModuleConfigFromProto(computation.proto(), + DebugOptions()) + .ConsumeValueOrDie(); + return HloModule::CreateFromProto(computation.proto(), config) + .ConsumeValueOrDie(); } Client* client_; Service* service_; - const ComputationTracker& computation_tracker_; // User computations used for higher order operations (e.g., Map, Reduce). - Computation add_; - Computation add_and_exp_; - Computation sigmoid_; - Computation max_; - Computation gt_; + XlaComputation add_; + XlaComputation add_and_exp_; + XlaComputation sigmoid_; + XlaComputation max_; + XlaComputation gt_; }; TEST_F(HloCostAnalysisTest, MatrixMultiply) { - ComputationBuilder builder(client_, "matrix_multiply"); + XlaBuilder builder("matrix_multiply"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); auto result = builder.Dot(lhs, rhs); @@ -167,7 +158,7 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { } TEST_F(HloCostAnalysisTest, Map) { - ComputationBuilder builder(client_, "map"); + XlaBuilder builder("map"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in"); auto result = builder.Map({input}, add_and_exp_, {0}); @@ -184,14 +175,16 @@ TEST_F(HloCostAnalysisTest, Map) { } TEST_F(HloCostAnalysisTest, Convolution) { - ComputationBuilder builder(client_, "convolution"); + XlaBuilder builder("convolution"); auto input = builder.Parameter( - 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, - /*x_dim=*/20}), + 0, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, + /*x_dim=*/20}), "input"); auto kernel = builder.Parameter( - 1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, - /*x_dim=*/3}), + 1, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, + /*x_dim=*/3}), "kernel"); auto result = builder.Conv(input, kernel, {1, 1}, Padding::kValid); @@ -211,7 +204,7 @@ TEST_F(HloCostAnalysisTest, Convolution) { } TEST_F(HloCostAnalysisTest, Reduce) { - ComputationBuilder builder(client_, "reduce"); + XlaBuilder builder("reduce"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto result = @@ -229,7 +222,7 @@ TEST_F(HloCostAnalysisTest, Reduce) { } TEST_F(HloCostAnalysisTest, ReduceWindow) { - ComputationBuilder builder(client_, "reduce_window"); + XlaBuilder builder("reduce_window"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto result = builder.ReduceWindow(input, builder.ConstantR0(0), add_, @@ -246,7 +239,7 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) { } TEST_F(HloCostAnalysisTest, SelectAndScatter) { - ComputationBuilder builder(client_, "select_and_scatter"); + XlaBuilder builder("select_and_scatter"); auto operand = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto source = @@ -267,7 +260,7 @@ TEST_F(HloCostAnalysisTest, SelectAndScatter) { } TEST_F(HloCostAnalysisTest, Broadcast) { - ComputationBuilder b(client_, "broadcast"); + XlaBuilder b("broadcast"); b.Broadcast(b.ConstantR0(42), {10, 7}); auto hlo_module = BuildHloGraph(&b); HloCostAnalysis analysis(ShapeSize); @@ -278,7 +271,7 @@ TEST_F(HloCostAnalysisTest, Broadcast) { // Calculates the computation cost of a graph with more than one HLO node. TEST_F(HloCostAnalysisTest, FullyConnectedForward) { - ComputationBuilder builder(client_, "fully_connected_forward"); + XlaBuilder builder("fully_connected_forward"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input"); auto weight = @@ -303,7 +296,7 @@ TEST_F(HloCostAnalysisTest, FullyConnectedForward) { TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { HloCostAnalysis conv_analysis(ShapeSize); { - ComputationBuilder builder(client_, "conv_looking_matmul"); + XlaBuilder builder("conv_looking_matmul"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), "input"); auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), @@ -316,7 +309,7 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { HloCostAnalysis matmul_analysis(ShapeSize); { - ComputationBuilder builder(client_, "matmul"); + XlaBuilder builder("matmul"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64}), "input"); auto rhs = @@ -368,8 +361,8 @@ TEST_F(FusionCostAnalysis, LoopFusion) { HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp)); auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1}); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -410,8 +403,8 @@ TEST_F(FusionCostAnalysis, NoLayout) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( shape_with_layout, HloOpcode::kAdd, c1, broadcast)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add, broadcast}, HloInstruction::FusionKind::kLoop); @@ -425,7 +418,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { TEST_F(HloCostAnalysisTest, TupleCost) { HloCostAnalysis analysis(ShapeSize); { - ComputationBuilder builder(client_, "matmul"); + XlaBuilder builder("matmul"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {123}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {42}), "y"); auto tuple = builder.Tuple({x, y}); @@ -440,5 +433,78 @@ TEST_F(HloCostAnalysisTest, TupleCost) { EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2); } +TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { + XlaBuilder builder("BaseDilatedConvolution"); + auto input = builder.Parameter( + 0, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, + /*x_dim=*/20}), + "input"); + auto kernel = builder.Parameter( + 1, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, + /*x_dim=*/3}), + "kernel"); + + auto result = builder.ConvGeneralDilated( + input, kernel, /*window_strides=*/{1, 1}, /*padding=*/{{1, 1}, {1, 1}}, + /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11}, + XlaBuilder::CreateDefaultConvDimensionNumbers(2)); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.flop_count(), 1472); +} + +TEST_F(HloCostAnalysisTest, Slice) { + // Test the analysis on a slice. + XlaBuilder builder("slice"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); + auto slice = builder.Slice(x, {0}, {1}, {1}); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + +TEST_F(HloCostAnalysisTest, DynamicSlice) { + // Test the analysis on a slice. + XlaBuilder builder("dynamic-slice"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); + auto slice = builder.DynamicSlice(x, builder.ConstantR1({1}), {1}); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + +TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) { + // Test the analysis on a slice. + XlaBuilder builder("dynamic-update-slice"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); + auto slice = builder.DynamicUpdateSlice(x, builder.ConstantR1({1.0}), + builder.ConstantR1({1})); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..0fb65c845a6d4407c81171f6c1569fee98b1d16d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -0,0 +1,301 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +using tensorflow::gtl::ArraySlice; +using tensorflow::strings::StrCat; + +StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, + HloInstruction* rhs) { + HloComputation* computation = lhs->parent(); + CHECK_EQ(computation, rhs->parent()); + TF_ASSIGN_OR_RETURN(Shape binary_op_shape, + ShapeInference::InferBinaryOpShape(opcode, lhs, rhs)); + return computation->AddInstruction( + HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs)); +} + +StatusOr MakePadHlo(HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config) { + HloComputation* computation = operand->parent(); + CHECK_EQ(computation, padding_value->parent()); + TF_ASSIGN_OR_RETURN( + Shape pad_shape, + ShapeInference::InferPadShape(operand->shape(), padding_value->shape(), + padding_config)); + return computation->AddInstruction(HloInstruction::CreatePad( + pad_shape, operand, padding_value, padding_config)); +} + +StatusOr MakeSliceHlo(HloInstruction* operand, + ArraySlice start_indices, + ArraySlice limit_indices, + ArraySlice strides) { + HloComputation* computation = operand->parent(); + TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape( + operand->shape(), start_indices, + limit_indices, strides)); + return computation->AddInstruction(HloInstruction::CreateSlice( + slice_shape, operand, start_indices, limit_indices, strides)); +} + +StatusOr MakeConvolveHlo( + HloInstruction* lhs, HloInstruction* rhs, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers) { + HloComputation* computation = lhs->parent(); + CHECK_EQ(computation, rhs->parent()); + TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape( + lhs->shape(), rhs->shape(), + window, dimension_numbers)); + return computation->AddInstruction(HloInstruction::CreateConvolve( + convolve_shape, lhs, rhs, window, dimension_numbers)); +} + +StatusOr MakeTransposeHlo(HloInstruction* operand, + ArraySlice dimensions) { + HloComputation* computation = operand->parent(); + TF_ASSIGN_OR_RETURN( + Shape transpose_shape, + ShapeInference::InferTransposeShape(operand->shape(), dimensions)); + return computation->AddInstruction( + HloInstruction::CreateTranspose(transpose_shape, operand, dimensions)); +} + +StatusOr MakeReshapeHlo(const Shape& result_shape, + HloInstruction* operand) { + HloComputation* computation = operand->parent(); + return computation->AddInstruction( + HloInstruction::CreateReshape(result_shape, operand)); +} + +StatusOr MakeReshapeHlo( + ArraySlice result_shape_dim_bounds, HloInstruction* operand) { + Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + result_shape_dim_bounds); + return MakeReshapeHlo(new_shape, operand); +} + +StatusOr MakeDynamicSliceHlo(HloInstruction* operand, + HloInstruction* start_indices, + ArraySlice slice_sizes) { + HloComputation* computation = operand->parent(); + CHECK_EQ(computation, start_indices->parent()); + TF_ASSIGN_OR_RETURN( + Shape dynamic_slice_shape, + ShapeInference::InferDynamicSliceShape( + operand->shape(), start_indices->shape(), slice_sizes)); + return computation->AddInstruction(HloInstruction::CreateDynamicSlice( + dynamic_slice_shape, operand, start_indices, slice_sizes)); +} + +StatusOr MakeDynamicUpdateSliceHlo( + HloInstruction* operand, HloInstruction* update, + HloInstruction* start_indices) { + HloComputation* computation = operand->parent(); + CHECK_EQ(computation, update->parent()); + CHECK_EQ(computation, start_indices->parent()); + TF_ASSIGN_OR_RETURN( + Shape dynamic_update_slice_shape, + ShapeInference::InferDynamicUpdateSliceShape( + operand->shape(), update->shape(), start_indices->shape())); + return computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + dynamic_update_slice_shape, operand, update, start_indices)); +} + +StatusOr MakeBroadcastHlo( + HloInstruction* operand, ArraySlice broadcast_dimensions, + ArraySlice result_shape_bounds) { + HloComputation* computation = operand->parent(); + Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + result_shape_bounds); + + return computation->AddInstruction(HloInstruction::CreateBroadcast( + broadcast_shape, operand, broadcast_dimensions)); +} + +StatusOr MakeGetTupleElementHlo(HloInstruction* operand, + int64 index) { + HloComputation* computation = operand->parent(); + + TF_ASSIGN_OR_RETURN( + Shape gte_shape, + ShapeInference::InferGetTupleElementShape(operand->shape(), index)); + return computation->AddInstruction( + HloInstruction::CreateGetTupleElement(gte_shape, operand, index)); +} + +StatusOr MakeConcatHlo(ArraySlice operands, + int64 dimension) { + CHECK_GT(operands.size(), 0); + + HloComputation* computation = operands[0]->parent(); + CHECK(c_all_of(operands, [&](HloInstruction* instr) { + return instr->parent() == computation; + })); + + std::vector operand_shapes; + c_transform(operands, std::back_inserter(operand_shapes), + [](HloInstruction* instr) { return &instr->shape(); }); + + TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape( + operand_shapes, dimension)); + return computation->AddInstruction( + HloInstruction::CreateConcatenate(concat_shape, operands, dimension)); +} + +StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers) { + HloComputation* computation = lhs->parent(); + CHECK_EQ(computation, rhs->parent()); + TF_ASSIGN_OR_RETURN( + Shape dot_shape, + ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers)); + return computation->AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); +} + +StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { + CHECK_GT(n, 0); + + const Shape& operand_shape = operand->shape(); + CHECK_GE(operand_shape.dimensions_size(), n); + int64 new_shape_leading_bound = 1; + for (int64 i = 0; i < n; i++) { + new_shape_leading_bound *= operand_shape.dimensions(i); + } + + std::vector new_shape_dims; + new_shape_dims.reserve(operand_shape.dimensions_size() - n + 1); + new_shape_dims.push_back(new_shape_leading_bound); + + std::copy(operand_shape.dimensions().begin() + n, + operand_shape.dimensions().end(), + std::back_inserter(new_shape_dims)); + + Shape output_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims); + + return MakeReshapeHlo(output_shape, operand); +} + +StatusOr PrependDegenerateDims(HloInstruction* operand, + int64 n) { + CHECK_GT(n, 0); + std::vector new_shape_dims; + const Shape& operand_shape = operand->shape(); + new_shape_dims.reserve(n + operand_shape.dimensions_size()); + new_shape_dims.insert(new_shape_dims.begin(), n, 1); + c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims)); + return MakeReshapeHlo(new_shape_dims, operand); +} + +StatusOr ExpandFirstDimIntoNDims( + HloInstruction* operand, ArraySlice expanded_dims) { + CHECK_GT(operand->shape().dimensions_size(), 0); + CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims)); + + std::vector expanded_shape_dim_bounds; + expanded_shape_dim_bounds.reserve(expanded_dims.size() + + operand->shape().dimensions_size() - 1); + c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds)); + std::copy(operand->shape().dimensions().begin() + 1, + operand->shape().dimensions().end(), + std::back_inserter(expanded_shape_dim_bounds)); + Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + expanded_shape_dim_bounds); + return MakeReshapeHlo(new_shape, operand); +} + +StatusOr ElideDegenerateDims(HloInstruction* operand, + ArraySlice dims_to_elide) { + CHECK(c_is_sorted(dims_to_elide)); + + const Shape& input_shape = operand->shape(); + // First accumulate in reverse + std::vector new_shape_dim_bounds; + new_shape_dim_bounds.reserve(input_shape.dimensions_size() - + dims_to_elide.size()); + int64 dims_to_elide_idx = dims_to_elide.size() - 1; + for (int64 i = input_shape.dimensions_size() - 1; i >= 0; i--) { + if (dims_to_elide_idx >= 0 && i == dims_to_elide[dims_to_elide_idx]) { + CHECK_EQ(input_shape.dimensions(i), 1); + dims_to_elide_idx--; + } else { + new_shape_dim_bounds.push_back(input_shape.dimensions(i)); + } + } + + c_reverse(new_shape_dim_bounds); + Shape output_shape = + ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds); + return MakeReshapeHlo(output_shape, operand); +} + +StatusOr PadVectorWithZeros(HloInstruction* operand, + int64 zeros_to_prepend, + int64 zeros_to_append) { + HloComputation* computation = operand->parent(); + CHECK_EQ(operand->shape().dimensions_size(), 1); + PaddingConfig padding_config; + PaddingConfig::PaddingConfigDimension padding_config_dim; + padding_config_dim.set_edge_padding_low(zeros_to_prepend); + padding_config_dim.set_edge_padding_high(zeros_to_append); + *padding_config.add_dimensions() = padding_config_dim; + + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + MakeUnique(Literal::Zero(operand->shape().element_type())))); + return MakePadHlo(operand, zero, padding_config); +} + +StatusOr BroadcastZeros( + HloComputation* computation, PrimitiveType element_type, + ArraySlice broadcast_dimensions) { + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + MakeUnique(Literal::Zero(element_type)))); + return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, + /*result_shape_bounds=*/broadcast_dimensions); +} + +StatusOr> CreateComputationWithSignature( + ArraySlice domain, const Shape& range, + tensorflow::StringPiece name) { + HloComputation::Builder b{std::string(name)}; + int64 param_idx = 0; + for (const Shape* param_shape : domain) { + b.AddInstruction(HloInstruction::CreateParameter( + param_idx, *param_shape, StrCat("param.", param_idx))); + param_idx++; + } + + // We can't change the root type of a computation once it is created so create + // a dummy root instruction to give the computation the right root shape. In + // the future we may want to use a (recursive) broadcast here to avoid + // creating large constants. + b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromShape(range))); + + return b.Build(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..49b1402d689a74874e34423a1832a0b6aa15f469 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -0,0 +1,168 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Some lightweight utilities intended to make HLO instruction creation more +// ergonomic. We don't have a complete set of helpers yet -- I expect we'll +// expand this interface as needed on an ad-hoc basis. + +// Creates a binary HLO instruction and adds it to the computation containing +// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). +StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, + HloInstruction* rhs); + +// Creates a pad HLO instruction and adds it to the computation containing +// `operand` and `padding_value` (`operand` and `padding_value` must be in the +// same computation). +StatusOr MakePadHlo(HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config); + +// Creates a slice HLO instruction and adds it to the computation containing +// `operand`. +StatusOr MakeSliceHlo( + HloInstruction* operand, tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + +// Creates a convolution HLO instruction and adds it to the computation +// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). +StatusOr MakeConvolveHlo( + HloInstruction* lhs, HloInstruction* rhs, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers); + +// Creates a transpose HLO instruction and adds it to the computation containing +// `operand`. +StatusOr MakeTransposeHlo( + HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); + +// Creates a reshape HLO instruction and adds it to the computation containing +// `operand`. +StatusOr MakeReshapeHlo(const Shape& result_shape, + HloInstruction* operand); + +StatusOr MakeReshapeHlo( + tensorflow::gtl::ArraySlice result_shape_dim_bounds, + HloInstruction* operand); + +// Creates a dynamic-slice HLO instruction and adds it to the computation +// containing `operand` and `start_indices` (`operand` and `start_indices` must +// be in the same computation). +StatusOr MakeDynamicSliceHlo( + HloInstruction* operand, HloInstruction* start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + +// Creates a dynamic-update-slice HLO instruction and adds it to the computation +// containing `operand`, `update` and `start_indices` (`operand`, `update` and +// `start_indices` must be in the same computation). +StatusOr MakeDynamicUpdateSliceHlo( + HloInstruction* operand, HloInstruction* update, + HloInstruction* start_indices); + +// Creates a broadcast HLO instruction and adds it to the computation containing +// `operand`. +StatusOr MakeBroadcastHlo( + HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimensions, + tensorflow::gtl::ArraySlice result_shape_bounds); + +// Creates a GetTupleElement HLO instruction and adds it to the computation +// containing `operand`. +StatusOr MakeGetTupleElementHlo(HloInstruction* operand, + int64 index); + +// Creates a Concatenate HLO instruction and adds it to the computation +// containing `operands` (`operands` must be non-empty and every element must be +// contained in the same computation). +StatusOr MakeConcatHlo( + tensorflow::gtl::ArraySlice operands, int64 dimension); + +// Creates a Dot HLO instruction and adds it to the computation containing `lhs` +// and `rhs` (both must be in the same computation). +StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers); + +// ----------------------------------------------------------------------------- +// Some other miscellaneous helpers to generate common HLO patterns. All of +// these add all the instructions they generate into the computation containing +// their operand(s). + +// Collapses (via reshape) the first N (logical) dimensions of `operand` into a +// single leading dimension. `operand` must have rank > `n` and `n` must not be +// 0. +// +// For instance if `operand` has shape f32[7,8,9] and n is 2 then the output is +// the `operand` reshaped to [56,9]. +StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n); + +// Prepends `n` degenerate dimensions (dimensions with bound = 1) to `operand` +// using a reshape. +// +// For instance if operand has shape f32[3,4,5] then this returns the operand +// reshaped to f32[1,3,4,5]. If the operand is a f32 scalar (i.e. has shape +// f32[]) then this returns the operand reshaped to f32[1]. +StatusOr PrependDegenerateDims(HloInstruction* operand, + int64 n); + +// Expands (via reshape) the first (logical) dimension of `operand` into a +// sequence of `expanded_dims` dimensions. `operand` must at least be of rank 1 +// and the number of elements in its first dimension must be equal to the +// product of `expanded_dims`. +// +// For instance if `operand` has shape f32[200,9,7] and expanded_dims is +// {2,5,20} the result is `operand` reshaped to [2,5,20,9,7]. +StatusOr ExpandFirstDimIntoNDims( + HloInstruction* operand, tensorflow::gtl::ArraySlice expanded_dims); + +// Elides (via reshape) a set of degenerate dimensions (dimensions containing +// exactly one element), `dims_to_elide` from `operand`. Every dimension in +// `dims_to_elide` must be a degenerate dimension. `dims_to_elide` must be +// sorted and not contain duplicates. +// +// For example if `operand` is of shape f32[19,1,20,1,7,1,9] and dims_to_elide +// is {1,5} then the result is `operand` reshaped to [19,20,1,7,9]. +StatusOr ElideDegenerateDims( + HloInstruction* operand, tensorflow::gtl::ArraySlice dims_to_elide); + +// Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the +// front and `zeros_to_append` zeros in the back. +StatusOr PadVectorWithZeros(HloInstruction* operand, + int64 zeros_to_prepend, + int64 zeros_to_append); + +// Broadcasts a zero value of type `element_type` into a tensor with element +// type `element_type` and dimension bounds `broadcast_dimensions`. The +// broadcast instruction is emitted into `computation`. +StatusOr BroadcastZeros( + HloComputation* computation, PrimitiveType element_type, + tensorflow::gtl::ArraySlice broadcast_dimensions); + +// Creates a HLO computation that takes arguments of type `domain` and produces +// a value of type `range`. +StatusOr> CreateComputationWithSignature( + tensorflow::gtl::ArraySlice domain, const Shape& range, + tensorflow::StringPiece name); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7e7c4f95fed737f40064224717f409b934e4ff27 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -0,0 +1,239 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { +using tensorflow::gtl::ArraySlice; + +class HloCreationUtilsTest : public HloTestBase { + protected: + static std::unique_ptr CreateModuleWithProgramShape( + PrimitiveType primitive_type, ArraySlice input_shape_dims, + ArraySlice output_shape_dims, HloInstruction** param, + HloComputation** entry_computation) { + Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims); + Shape output_shape = + ShapeUtil::MakeShape(primitive_type, output_shape_dims); + auto module = CreateNewModule("test"); + *entry_computation = module->AddEntryComputation( + CreateComputationWithSignature({&input_shape}, output_shape, "entry") + .ValueOrDie()); + *param = (*entry_computation)->parameter_instruction(0); + return module; + } +}; + +TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { + HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + S32, + /*input_shape_dims=*/{2}, /*output_shape_dims=*/{2}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_1_dims_collapsed, + CollapseFirstNDims(param, 1)); + entry_computation->set_root_instruction(first_1_dims_collapsed); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {Literal::CreateR1({3, 4})})); + CHECK_EQ(*result_literal, *Literal::CreateR1({3, 4})); +} + +TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { + HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + S32, + /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_2_dims_collapsed, + CollapseFirstNDims(param, 2)); + entry_computation->set_root_instruction(first_2_dims_collapsed); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, + {Literal::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})})); + CHECK_EQ(*result_literal, + *Literal::CreateR2( + {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); +} + +TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { + HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + S32, + /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 2}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_1_degenerate_dim_prepended, + PrependDegenerateDims(param, 1)); + entry_computation->set_root_instruction(with_1_degenerate_dim_prepended); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {Literal::CreateR1({9, 10})})); + CHECK_EQ(*result_literal, *Literal::CreateR2({{9, 10}})); +} + +TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { + HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + S32, + /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 1, 2}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended, + PrependDegenerateDims(param, 2)); + entry_computation->set_root_instruction(with_2_degenerate_dims_prepended); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {Literal::CreateR1({9, 10})})); + CHECK_EQ(*result_literal, *Literal::CreateR3({{{9, 10}}})); +} + +TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { + HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + S32, + /*input_shape_dims=*/{}, /*output_shape_dims=*/{1, 1}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended, + PrependDegenerateDims(param, 2)); + entry_computation->set_root_instruction(with_2_degenerate_dims_prepended); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {Literal::CreateR0(9)})); + CHECK_EQ(*result_literal, *Literal::CreateR2({{9}})); +} + +TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { + HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + S32, + /*input_shape_dims=*/{6}, /*output_shape_dims=*/{3, 1, 2}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_dim_expanded, + ExpandFirstDimIntoNDims(param, {3, 1, 2})); + entry_computation->set_root_instruction(first_dim_expanded); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {Literal::CreateR1({1, 2, 3, 4, 5, 6})})); + CHECK_EQ(*result_literal, + *Literal::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); +} + +TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { + HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + S32, + /*input_shape_dims=*/{2}, /*output_shape_dims=*/{6}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN( + HloInstruction * zero_padded_param, + PadVectorWithZeros(param, /*zeros_to_prepend=*/3, /*zeros_to_append=*/1)); + entry_computation->set_root_instruction(zero_padded_param); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {Literal::CreateR1({3, 4})})); + CHECK_EQ(*result_literal, *Literal::CreateR1({0, 0, 0, 3, 4, 0})); +} + +TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { + HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + S32, + /*input_shape_dims=*/{}, /*output_shape_dims=*/{2, 2}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN( + HloInstruction * zeros, + BroadcastZeros(module->entry_computation(), S32, {2, 2})); + entry_computation->set_root_instruction(zeros); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {Literal::CreateR0(0)})); + CHECK_EQ(*result_literal, *Literal::CreateR2({{0, 0}, {0, 0}})); +} + +TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { + HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + F32, + /*input_shape_dims=*/{}, /*output_shape_dims=*/{2, 2}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN( + HloInstruction * zeros, + BroadcastZeros(module->entry_computation(), F32, {2, 2})); + entry_computation->set_root_instruction(zeros); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {Literal::CreateR0(0.0f)})); + CHECK_EQ(*result_literal, + *Literal::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 279edd4ba8772a9c576f76f554de8ec68631b953..a0ee8896230d6dcacb5a8eb607fc00ae5226cfa5 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -26,12 +26,14 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" namespace xla { @@ -40,16 +42,16 @@ namespace { // Find and combine identical constants. Constants are identical if they have // the same type and value. -bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { - bool changed = false; - +StatusOr CombineConstants(HloComputation* computation, + bool is_layout_sensitive) { + TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, "")); // Map from ShortDebugString of the layoutless shape of the constant to the // set of constant instructions with that shape. Layoutless shape is used to // bin possible common constants together to reduce number of constant // comparisons. If we end up having too many constant comparisons, a more // precise binning might have to be used. std::multimap constants; - + int64 combined = 0; auto inst_it = computation->instructions().begin(); while (inst_it != computation->instructions().end()) { HloInstruction* instruction = *inst_it; @@ -69,7 +71,8 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { auto range = constants.equal_range(shape_string); HloInstruction* match = nullptr; for (auto it = range.first; it != range.second; ++it) { - if (instruction->literal() == it->second->literal()) { + if (instruction->literal() == it->second->literal() && + domain_map->InSameDomain(it->second, instruction)) { match = it->second; break; } @@ -80,12 +83,27 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { // Match found, replace this instruction with the one in the multimap. TF_CHECK_OK(instruction->ReplaceAllUsesWith(match)); TF_CHECK_OK(computation->RemoveInstruction(instruction)); - changed = true; + ++combined; } } } + VLOG(4) << "Combined " << combined << " constants in " << computation->name() + << " computation"; + return combined > 0; +} - return changed; +// An instruction is considered to be equivalent to another only if they +// share the exact same set of operands. +int64 CseHash(const HloInstruction* instruction) { + int64 hash = std::hash()(static_cast(instruction->opcode())); + hash = tensorflow::Hash64Combine( + hash, instruction->opcode() == HloOpcode::kGetTupleElement + ? instruction->tuple_index() + : -1); + for (auto operand : instruction->operands()) { + hash = tensorflow::Hash64Combine(hash, operand->unique_id()); + } + return hash; } } // namespace @@ -95,45 +113,53 @@ StatusOr HloCSE::Run(HloModule* module) { const std::function eq_instructions = std::equal_to(); const std::function - eq_computations = std::equal_to(); + eq_computations = [](const HloComputation* lhs, + const HloComputation* rhs) { return *lhs == *rhs; }; + + auto cse_equal = [&](const HloInstruction* lhs, const HloInstruction* rhs) { + return lhs->Identical(*rhs, eq_instructions, eq_computations, + is_layout_sensitive_); + }; + for (auto* computation : module->computations()) { - changed |= CombineConstants(computation, is_layout_sensitive_); - - std::list post_order = - computation->MakeInstructionPostOrder(); - std::set removed_instructions; - for (auto instruction : post_order) { - // If the instruction has already been removed by CSE skip over it. - if (removed_instructions.count(instruction) > 0 || - instruction->operand_count() == 0) { + if (only_fusion_computations_ && !computation->IsFusionComputation()) { + continue; + } + + TF_ASSIGN_OR_RETURN(bool combined, + CombineConstants(computation, is_layout_sensitive_)); + changed |= combined; + + // HLO instructions are grouped into equivalency classes by using the + // cse_equal predicate defined above. This set holds a representative + // instruction for each class. + tensorflow::gtl::FlatSet + representatives(/*N=*/computation->instruction_count() + 1, &CseHash, + cse_equal); + for (auto instruction : computation->MakeInstructionPostOrder()) { + // If the instruction has zero operands (constants, parameters, etc.) skip + // over it. + if (instruction->operand_count() == 0) { continue; } - - // An instruction is considered to be equivalent to another only if they - // share the exact same set of operands. So to find equivalent - // instructions, we just search among instructions which share operand(0) - // of this instruction. - const HloInstruction* operand = instruction->operand(0); - - tensorflow::gtl::InlinedVector - equivalent_instructions; - for (HloInstruction* user : operand->users()) { - if (user != instruction && - user->Identical(*instruction, eq_instructions, eq_computations, - is_layout_sensitive_)) { - equivalent_instructions.push_back(user); - } + // Skip instructions which have side effects or are a domain (which must + // not be CSE-ed). + if (instruction->HasSideEffect() || + instruction->opcode() == HloOpcode::kDomain) { + continue; } - // Replace all equivalent instructions with this instruction. - for (HloInstruction* equivalent_instruction : equivalent_instructions) { + auto it = representatives.find(instruction); + if (it != representatives.end()) { + HloInstruction* equivalent_instruction = *it; TF_RETURN_IF_ERROR( - equivalent_instruction->ReplaceAllUsesWith(instruction)); - TF_RETURN_IF_ERROR( - computation->RemoveInstruction(equivalent_instruction)); - removed_instructions.insert(equivalent_instruction); + instruction->ReplaceAllUsesWith(equivalent_instruction)); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); changed = true; + continue; } + representatives.insert(instruction); } } return changed; diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h index 70096e07a2493763a9d4b0dc8e1c31510718c6c2..5e2b348bdda2b31556fb692e24d2bad2e4173ef5 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.h +++ b/tensorflow/compiler/xla/service/hlo_cse.h @@ -29,9 +29,11 @@ class HloCSE : public HloPassInterface { public: // If is_layout_sensitive is true, then the simplifier preserves layout during // transformation. Otherwise, layout is ignored. - explicit HloCSE(bool is_layout_sensitive) - : is_layout_sensitive_(is_layout_sensitive) {} - ~HloCSE() override {} + explicit HloCSE(bool is_layout_sensitive, + bool only_fusion_computations = false) + : is_layout_sensitive_(is_layout_sensitive), + only_fusion_computations_(only_fusion_computations) {} + ~HloCSE() override = default; tensorflow::StringPiece name() const override { return "cse"; } // Run CSE on the given module. Returns whether the module was changed (common @@ -39,7 +41,8 @@ class HloCSE : public HloPassInterface { StatusOr Run(HloModule* module) override; private: - bool is_layout_sensitive_; + const bool is_layout_sensitive_; + const bool only_fusion_computations_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 3601a790c4428ee39c264b217a4b9a991ad8456c..16db374566c727f1f3efe2a6d419f1f3caf0aaf1 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/types.h" @@ -72,7 +73,7 @@ TEST_F(HloCseTest, CombineTwoConstants) { auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = Literal::CreateR0(84.0); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { @@ -104,7 +105,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { @@ -134,38 +135,53 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, ConstantsSameValueDifferentType) { // Test that constants with the same value but different type are *not* // commoned. auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + std::vector constants; + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); // Duplicate the float constant to verify something happens. - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); + + const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); + for (int64 i = 0; i < constants.size(); ++i) { + constants[i] = builder.AddInstruction( + HloInstruction::CreateConvert(shape_r0, constants[i])); + } + HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( + shape_r0, HloOpcode::kAdd, constants[0], constants[1])); + for (int64 i = 2; i < constants.size(); ++i) { + root = builder.AddInstruction(HloInstruction::CreateBinary( + shape_r0, HloOpcode::kAdd, root, constants[i])); + } auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(7, computation->instruction_count()); + EXPECT_EQ(20, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); - EXPECT_EQ(6, computation->instruction_count()); + // CSE will remove both the second float(42.0f) and the corresponding + // convert/cast. + EXPECT_EQ(18, computation->instruction_count()); } TEST_F(HloCseTest, NonscalarConstants) { @@ -414,8 +430,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { EXPECT_THAT(root, op::Add(rng1, rng2)); } -// TODO(b/28245743): Handle impure functions correctly in CSE. -TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { +TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { // Test that two calls to an impure function are not commoned. RNG // is the source of the impurity. @@ -458,14 +473,67 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Add(op::Map(), op::Map())); + VLOG(3) << "before: " << module->ToString(); + HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + + VLOG(3) << "after: " << module->ToString(); EXPECT_EQ(4, computation->instruction_count()); root = computation->root_instruction(); - auto operand = root->operand(0)->operand(0); - EXPECT_THAT(operand, op::Map()); - EXPECT_THAT(root, op::Add(operand, operand)); + EXPECT_THAT(root, op::Add(op::Map(op::Constant()), op::Map(op::Constant()))); +} + +TEST_F(HloCseTest, CompareComputations) { + auto module = ParseHloString(R"( + HloModule m + + add_computation { + add_lhs = f32[] parameter(0) + add_rhs = f32[] parameter(1) + ROOT add_root = f32[] add(add_lhs, add_rhs) + } + + add_computation2 { + add_lhs2 = f32[] parameter(0) + add_rhs2 = f32[] parameter(1) + ROOT add_root2 = f32[] add(add_lhs2, add_rhs2) + } + + ENTRY entry { + p = f32[10]{0} parameter(0) + c = f32[] constant(0) + r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation + r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2 + ROOT f2 = (f32[],f32[]) tuple(r1, r2) + })") + .ValueOrDie(); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->operand(0), root->operand(1)); +} + +TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { + // Test that constants with the same value but in different domains (disjoint + // in this case) are not collapsed. + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(2, computation->instruction_count()); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(2, computation->instruction_count()); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 934e43ba4879628362009267c671ec4cb0d79c52..d0200058683b2db8f5f0469d6c643014881f741e 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -363,16 +363,16 @@ bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { bool HloDataflowAnalysis::UpdateConditionalValueSet( HloInstruction* conditional) { CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); - std::vector inputs = { + const InstructionValueSet* const inputs[] = { &GetInstructionValueSet( conditional->true_computation()->root_instruction()), &GetInstructionValueSet( conditional->false_computation()->root_instruction())}; - // A phi-node is not defined for a kConditional instruction even though it - // represents a join point. This is because the current approach is to define - // a phi-node only for kWhile to account for the dataflow through back-edges - // and deal with the ambiguity in other cases. - return GetInstructionValueSet(conditional).AssignUnionOf(inputs); + if (ssa_form_) { + return Phi(conditional, inputs); + } else { + return GetInstructionValueSet(conditional).AssignUnionOf(inputs); + } } bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) { @@ -538,7 +538,7 @@ bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) { bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) { CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile); - std::vector inputs = { + const InstructionValueSet* const inputs[] = { &GetInstructionValueSet(xla_while->while_body()->root_instruction()), &GetInstructionValueSet(xla_while->operand(0))}; if (ssa_form_) { @@ -878,4 +878,135 @@ Status HloDataflowAnalysis::Verify() const { return Status::OK(); } +bool HloDataflowAnalysis::DoesNotUseOperandBuffer( + const HloInstruction* operand, const ShapeIndex& index, + const HloInstruction* user) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + // Find fusion parameter associated with 'operand'. + HloInstruction* fusion_param = + user->fused_parameter(user->operand_index(operand)); + // Iterate through all users of all uses of the fusion parameter value. + // Return false if any uses are detected, returns true otherwise. + const HloValue& value = GetValueDefinedAt(fusion_param, index); + return value.uses().empty(); + } else { + // Return false if no value at 'operand' and 'index' is used at 'user'. + for (const HloValue* value : GetValueSet(operand, index).values()) { + for (const HloUse& use : value->uses()) { + if (use.instruction == user) { + return false; + } + } + } + } + + return true; +} + +bool HloDataflowAnalysis::CanShareOperandBufferWithUser( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* user, const ShapeIndex& user_index) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + const Shape& operand_subshape = + ShapeUtil::GetSubshape(operand->shape(), operand_index); + const Shape& user_subshape = + ShapeUtil::GetSubshape(user->shape(), user_index); + // Check that operand and user emit the same shape and layout. + if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { + return false; + } + + if (user->opcode() == HloOpcode::kFusion) { + // Get the parameter associated with 'operand'; + HloInstruction* fusion_param = + user->fused_parameter(user->operand_index(operand)); + + const HloValue& value = GetValueDefinedAt(fusion_param, operand_index); + if (value.uses().size() != 1) { + return false; + } + const HloUse& use = value.uses()[0]; + + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return use.instruction == user->fused_expression_root() && + use.operand_number == 0; + } + } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + // Output fusion with kAdd fused root. + + // Check if one operand of kAdd fused root is kDot or kConvolution. + auto* add = user->fused_expression_root(); + auto add_operand_it = + std::find_if(add->operands().begin(), add->operands().end(), + [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); + if (add_operand_it == add->operands().end()) { + return false; + } + auto* matched_add_operand = *add_operand_it; + // Calculate operand index of 'add' operand which was not matched above. + const int64 other_add_operand_index = + matched_add_operand == add->operand(0) ? 1 : 0; + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root (at operand + // index 'other_add_operand_index'). + return use.instruction == user->fused_expression_root() && + use.operand_number == other_add_operand_index; + } + } + + if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kWhile) { + // We eliminated other users in BufferLiveness::live_range_strictly_before, + // so here we just need to check that the use is at operand index 0. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && operand_indices[0] == 0; + } + if (user->opcode() == HloOpcode::kCall) { + // Get all uses of value defined by 'operand' at 'operand_index'. + const auto& uses = GetValueDefinedAt(operand, operand_index).uses(); + // Return true iff: + // *) There exists two uses of 'operand'. + // *) One use is by 'user' (caller). + // *) One use is by root instruction of called computation (callee root). + // (Note: we check the root of the called computation, because the + // root result buffer is required to alias with the Call result buffer). + // *) The root instruction of the called computation is element-wise on + // 'operand'. + const bool found_caller_use = + std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { + return use.instruction == user; + }) != uses.end(); + auto* callee_root = user->to_apply()->root_instruction(); + const bool found_elementwise_callee_use = + std::find_if( + uses.begin(), uses.end(), [callee_root](const HloUse& use) { + return use.instruction == callee_root && + callee_root->IsElementwiseOnOperand(use.operand_number); + }) != uses.end(); + return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; + } + + // Loop fusions that contain transposing copies won't reach here as they have + // different layouts, which fails the check in the beginning of this function. + // + // Multi-output fusion will fail the check here as tuples are not considered + // an elementwise operation. + return user->IsElementwiseOnOperand(user->operand_index(operand)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 7b8a74b096ff48733717e78ada5bb56a28caed72..9868746b6113881949e388cd2a4aa9f610b1fdb7 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -118,6 +118,23 @@ class HloDataflowAnalysis { string ToString() const; + // Returns true if 'user' cannot possibly use the buffer at 'index' in + // 'operand'. Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. + bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user) const; + + // Returns true if 'user' (at 'user_index') can share a buffer with its + // operand 'operand' (at 'operand_index'). Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. + bool CanShareOperandBufferWithUser(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* user, + const ShapeIndex& user_index) const; + protected: HloDataflowAnalysis(const HloModule& module, bool ssa_form, bool bitcast_defines_value = false); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 7bf3a1a06045c79621d75b653bf42220705a69d4..db1822ec47a7f52e2c3ef8dcbf433cd787ef75ab 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1602,11 +1602,17 @@ TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) { EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), ElementsAre(HloUse{conditional, 2, {}})); - EXPECT_EQ(analysis.values().size(), 3); - EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); - EXPECT_THAT(HloValuesAt(conditional), - UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), - analysis.GetValueDefinedAt(constant2))); + bool ssa_form = GetParam(); + if (ssa_form) { + EXPECT_EQ(analysis.values().size(), 4); + EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional)); + } else { + EXPECT_EQ(analysis.values().size(), 3); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), + analysis.GetValueDefinedAt(constant2))); + } } TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { @@ -1713,11 +1719,17 @@ TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}}, HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}})); - EXPECT_EQ(analysis.values().size(), 6); - EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); - EXPECT_THAT(HloValuesAt(conditional), - UnorderedElementsAre(analysis.GetValueDefinedAt(add), - analysis.GetValueDefinedAt(sub))); + bool ssa_form = GetParam(); + if (ssa_form) { + EXPECT_EQ(analysis.values().size(), 7); + EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional)); + } else { + EXPECT_EQ(analysis.values().size(), 6); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(add), + analysis.GetValueDefinedAt(sub))); + } } TEST_P(HloDataflowAnalysisTest, NestedConditionals) { @@ -1834,25 +1846,496 @@ TEST_P(HloDataflowAnalysisTest, NestedConditionals) { EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond), analysis.GetValueDefinedAt(constant2)); - EXPECT_EQ(analysis.values().size(), 9); - EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional)); - EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); - EXPECT_THAT( - HloValuesAt(inner_conditional), - UnorderedElementsAre( - analysis.GetValueDefinedAt(computation1->root_instruction()), - analysis.GetValueDefinedAt(computation2->root_instruction()))); - EXPECT_THAT( - HloValuesAt(conditional), - UnorderedElementsAre( - analysis.GetValueDefinedAt(computation1->root_instruction()), - analysis.GetValueDefinedAt(computation2->root_instruction()), - analysis.GetValueDefinedAt(computation3->root_instruction()))); + bool ssa_form = GetParam(); + if (ssa_form) { + EXPECT_EQ(analysis.values().size(), 11); + EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_conditional)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional)); + } else { + EXPECT_EQ(analysis.values().size(), 9); + EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT( + HloValuesAt(inner_conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()))); + EXPECT_THAT( + HloValuesAt(conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()), + analysis.GetValueDefinedAt(computation3->root_instruction()))); + } } INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); +class HloDataflowAnalysisTestBase : public HloTestBase { + protected: + void BuildModule(std::unique_ptr computation) { + module_ = CreateNewModule(); + computation_ = module_->AddEntryComputation(std::move(computation)); + } + + void RunAnalysis() { + CHECK_NOTNULL(module_.get()); + dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); + } + + void BuildModuleAndRunAnalysis(std::unique_ptr computation) { + BuildModule(std::move(computation)); + RunAnalysis(); + } + + std::unique_ptr module_; + HloComputation* computation_ = nullptr; + std::unique_ptr dataflow_analysis_; +}; + +class DoesNotUseOperandBufferTest : public HloDataflowAnalysisTestBase {}; + +TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { + auto builder = HloComputation::Builder(TestName()); + + Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); + builder.AddInstruction( + HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // GetTupleElement instructions only access the top-level buffer of their + // operand. + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, gte0)); + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, gte1)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte0)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte1)); +} + +TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction never uses tuple element 0, but does use element 1. + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); +} + +class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {}; + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + NonElementwiseLoopFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "param0")); + + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, param0)); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, neg, {0, 1})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {reverse, neg}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + MultiOutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + Shape in_shape = ShapeUtil::MakeShape(F32, {8}); + Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, in_shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, in_shape, "param1")); + + auto copy0 = builder.AddInstruction( + HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param0)); + auto copy1 = builder.AddInstruction( + HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param1)); + + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({copy1, copy0})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {0})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {1})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {0})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {1})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + ElementwiseLoopFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand)); + + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kExp, neg)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {exp, neg}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape in_shape = ShapeUtil::MakeShape(F32, {8}); + Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, in_shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, in_shape, "param1")); + auto result = builder.AddInstruction( + HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + result, {})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + result, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, copy, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction can share with tuple element 1. + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {0}, + fusion, {})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {1}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + FusedDynamicUpdateSliceWithConvertCantShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape data_shape_bf16 = ShapeUtil::MakeShape(BF16, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + auto convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(data_shape_bf16, gte1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape_bf16, convert1, update, starts)); + + auto convert2 = builder.AddInstruction( + HloInstruction::CreateConvert(data_shape, dynamic_update_slice)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, convert2})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {convert2, dynamic_update_slice, starts, update, convert1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction can't share with tuple element 1. + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(gte1, {}, fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {4}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto update = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "update")); + auto starts = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "starts")); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, data, update, starts)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The DynamicUpdateSlice instruction can share with the data operand, but not + // with update or starts. + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {})); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {})); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, dot, add_operand)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, dot}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused dot add should be able to share buffer with 'add_operand'. + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(add_operand, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, operand, {0, 1})); + + auto two = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, two, reverse}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused operand->reverse->add cannot alias operand buffer 'operand'. + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + + auto make_cond = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Cond"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + return builder.Build(); + }; + + auto make_body = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); + return builder.Build(); + }; + + module_ = CreateNewModule(); + HloComputation* cond_computation = + module_->AddEmbeddedComputation(make_cond()); + HloComputation* body_computation = + module_->AddEmbeddedComputation(make_body()); + + auto builder = HloComputation::Builder(TestName()); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto whil = builder.AddInstruction(HloInstruction::CreateWhile( + data_shape, cond_computation, body_computation, data)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + // The While instruction can share with the data operand. + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, whil, {})); +} + +// Tests that Call can alias operand buffer if the only use of the operand +// in the called computation is an elementwise instruction. +TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { + Shape shape = ShapeUtil::MakeShape(F32, {8}); + // Build sub-computation with fusion root. + auto sub_builder = HloComputation::Builder(TestName() + "_sub"); + auto sub_param = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "sub_param")); + auto one = sub_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto ones = sub_builder.AddInstruction( + HloInstruction::CreateBroadcast(shape, one, {1})); + auto add = sub_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); + + module_ = CreateNewModule(); + auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); + sub_computation->CreateFusionInstruction({add, ones}, + HloInstruction::FusionKind::kLoop); + + // Build entry-computation with kCall which calls 'sub_computation'. + auto builder = HloComputation::Builder(TestName()); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto reverse = + builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(shape, {reverse}, sub_computation)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(reverse, {}, call, {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 1e5f0f797a13fd7e7ce1cc934387a274a74153bc..fcd723af146e2227b8661b1a4993f1338f7de389 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -40,7 +40,7 @@ StatusOr HloDCE::Run(HloModule* module) { VLOG(2) << "Before dce:"; XLA_VLOG_LINES(2, module->ToString()); - for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* computation : module->MakeComputationPostOrder()) { std::unordered_set live_instructions; TF_RETURN_IF_ERROR(computation->root_instruction()->Accept( [&live_instructions](HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc new file mode 100644 index 0000000000000000000000000000000000000000..78955db0da02f16eb93689db947dc1190ab7049a --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +class HloDomainIsolator::RunContext { + public: + RunContext(HloModule* module, HloDomainIsolator* isolator) + : module_(module), isolator_(isolator) {} + + StatusOr Run(); + + private: + // Inserts a kDomain instruction between parent and operand, in case + // the attribute (ie, sharding) values change between instruction and operand. + // Returns the newly inserted kDomain instruction, or nullptr if no kDomain + // instruction was necessary. + StatusOr CreateDomain(HloInstruction* instruction, + HloInstruction* parent, + HloInstruction* operand); + + HloModule* module_; + HloDomainIsolator* isolator_; +}; + +StatusOr HloDomainIsolator::RunContext::CreateDomain( + HloInstruction* instruction, HloInstruction* parent, + HloInstruction* operand) { + HloInstruction* domain = nullptr; + std::unique_ptr domain_instruction = + isolator_->creator_(instruction, operand); + if (domain_instruction != nullptr) { + domain = operand->parent()->AddInstruction(std::move(domain_instruction)); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(parent, domain)); + } + return domain; +} + +StatusOr HloDomainIsolator::RunContext::Run() { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator"); + + int64 added_domains = 0; + for (HloComputation* computation : module_->computations()) { + // Walk in post order and place all the required kDomain instructions. + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kDomain) { + continue; + } + for (HloInstruction* operand : instruction->unique_operands()) { + // When applying multiple domains, we could end up stacking more than + // one in one edge, so here we want to build the effective + // (kDomain-less) instruction->operand edge. + HloInstruction* parent = instruction; + while (operand->opcode() == HloOpcode::kDomain) { + parent = operand; + operand = operand->mutable_operand(0); + } + // Check whether a kDomain is necessary between instruction and operand. + TF_ASSIGN_OR_RETURN(HloInstruction * domain, + CreateDomain(instruction, parent, operand)); + if (domain != nullptr) { + VLOG(4) << "New domain: " << domain->ToString(); + ++added_domains; + } + } + } + } + VLOG(3) << "Added " << added_domains << " kDomain instructions"; + if (added_domains > 0) { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Isolator"); + } + return added_domains > 0; +} + +HloDomainIsolator::HloDomainIsolator(DomainCreator creator) + : creator_(std::move(creator)) {} + +StatusOr HloDomainIsolator::Run(HloModule* module) { + RunContext run_context(module, this); + return run_context.Run(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h new file mode 100644 index 0000000000000000000000000000000000000000..e0c5718509dabebb7b9307bf764b0ea1ce7369a0 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Domain isolation is the task of placing kDomain instructions between HLO +// instructions having different shrading. A kDomain instruction is essentially +// used to break an HLO graph edge connecting two instructions with different +// sharding. If a set of connected instructions have all the same sharding, no +// kDomain instruciton will be placed. +class HloDomainIsolator : public HloPassInterface { + public: + // Creates a new kDomain instruction for the edge between the use instruction + // (the first HloInstruction argument), and the operand instruction (the + // second HloInstruction argument). + // Returns nullptr in case no domain separation is necessary. + using DomainCreator = std::function( + HloInstruction*, HloInstruction*)>; + + explicit HloDomainIsolator(DomainCreator creator); + + tensorflow::StringPiece name() const override { return "domain_isolator"; } + + StatusOr Run(HloModule* module) override; + + private: + class RunContext; + + DomainCreator creator_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc new file mode 100644 index 0000000000000000000000000000000000000000..ebd5adb5d573ce4b556046f85eb26a6ad59efcb9 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -0,0 +1,176 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +/* static */ StatusOr> HloDomainMap::Create( + HloComputation* computation, string domain_kind) { + auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + TF_RETURN_IF_ERROR(domain_map->Populate(computation)); + return std::move(domain_map); +} + +/* static */ StatusOr> HloDomainMap::Create( + HloModule* module, string domain_kind) { + auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + for (HloComputation* computation : module->computations()) { + TF_RETURN_IF_ERROR(domain_map->Populate(computation)); + } + return std::move(domain_map); +} + +bool HloDomainMap::InSameDomain(HloInstruction* instruction1, + HloInstruction* instruction2) const { + int64 domain_id1 = FindOrDefault(instruction_to_domain_, instruction1, -1); + int64 domain_id2 = FindOrDefault(instruction_to_domain_, instruction2, -1); + return domain_id1 >= 0 && domain_id1 == domain_id2; +} + +Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { + TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain); + // We only check operands, so we are sure to not process the empty domain from + // both sides. + for (HloInstruction* operand : instruction->unique_operands()) { + if (IsDomainInstruction(operand)) { + auto domain = MakeUnique(); + domain->enter_domains.insert(operand); + domain->exit_domains.insert(instruction); + TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + } + } + return Status::OK(); +} + +Status HloDomainMap::Populate(HloComputation* computation) { + for (HloInstruction* instruction : computation->instructions()) { + if (IsDomainInstruction(instruction)) { + // If this is a kDomain of the kind we are currently processing, check + // whether this is an "empty domain". + TF_RETURN_IF_ERROR(TryProcessEmptyDomain(instruction)); + continue; + } + int64 domain_id = FindOrDefault(instruction_to_domain_, instruction, -1); + if (domain_id >= 0) { + // We have already processed this instruction. + continue; + } + TF_ASSIGN_OR_RETURN(std::unique_ptr domain, + CreateDomain(instruction)); + TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + } + return Status::OK(); +} + +Status HloDomainMap::InsertDomain( + std::unique_ptr domain) { + int64 domain_id = instruction_domains_.size(); + instruction_domains_.push_back(std::move(domain)); + for (HloInstruction* instruction : instruction_domains_.back()->reach_set) { + instruction_to_domain_[instruction] = domain_id; + } + return Status::OK(); +} + +Status HloDomainMap::ExpandDomain(HloInstruction* instruction, + DomainMetadata::Domain* domain) const { + std::vector in_queue; + in_queue.push_back(instruction); + while (!in_queue.empty()) { + HloInstruction* current_instruction = in_queue.back(); + in_queue.pop_back(); + if (domain->reach_set.insert(current_instruction).second) { + // We should not be finding instructions with assigned domain here. + // If we assigned a domain to the instruction, it means that all the + // instructions reached by it, should have a domain as well. + int64 domain_id = + FindOrDefault(instruction_to_domain_, current_instruction, -1); + TF_RET_CHECK(domain_id < 0) + << "Instruction " << current_instruction->ToString() + << " already has domain " << domain_id; + for (HloInstruction* operand : current_instruction->operands()) { + if (IsDomainInstruction(operand)) { + // The reach set instruction is a user of the domain instruction + // (the instruction sees the kDomain as operand). + // IOW the dataflow enters the domain through the kDomain instruction. + domain->enter_domains.insert(operand); + } else { + in_queue.push_back(operand); + } + } + for (HloInstruction* user : current_instruction->users()) { + if (IsDomainInstruction(user)) { + // The reach set instruction is an operand of the domain instruction + // (the instruction sees the kDomain as user). + // IOW the dataflow exits the domain through the kDomain instruction. + domain->exit_domains.insert(user); + } else { + in_queue.push_back(user); + } + } + } + } + return Status::OK(); +} + +StatusOr> HloDomainMap::CreateDomain( + HloInstruction* instruction) const { + auto domain = MakeUnique(); + TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get())); + domain->instructions = MakeNonDomainInstructions(domain->reach_set); + return std::move(domain); +} + +bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { + if (instruction->opcode() != HloOpcode::kDomain) { + return false; + } + if (!domain_kind_.empty()) { + if (instruction->user_side_metadata().Kind() != domain_kind_) { + return false; + } + // Both user and operand side of the metadata must be of the same kind. + CHECK(instruction->operand_side_metadata().Kind() == domain_kind_) + << "Instruction " << instruction->ToString() + << " has mismatching metadata kinds"; + } + return true; +} + +/* static */ std::vector +HloDomainMap::MakeNonDomainInstructions( + const tensorflow::gtl::FlatSet& instruction_set) { + std::vector instructions; + instructions.reserve(instruction_set.size()); + for (HloInstruction* instruction : instruction_set) { + if (instruction->opcode() != HloOpcode::kDomain) { + instructions.push_back(instruction); + } + } + std::sort(instructions.begin(), instructions.end(), + [](HloInstruction* a, HloInstruction* b) { + return a->unique_id() < b->unique_id(); + }); + return instructions; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h new file mode 100644 index 0000000000000000000000000000000000000000..e62ef763fb3881ab6030b1f6a66266ac80a3d84d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// The HloDomainMap splits a set of instructions within a module or computation, +// into different domains, separated by kDomain instructions. +// A domain is composed by a set of instructions which can reach each other via +// operand/user edges, without crossing a kDomain insutrction of a given kind. +// A domain never crosses computation boundaries. +class HloDomainMap { + public: + // Creates a new HloDomainMap, creating all the domains within the input + // computation, of the given kind. If domain_kind is not empty, only the + // kDomain instructions of domain_kind will be considered as separators. + // Otherwise every kDomain instruction will be splitting domains. + static StatusOr> Create( + HloComputation* computation, string domain_kind); + + // Creates a new HloDomainMap, creating all the domains within the input + // module, of the given kind. If domain_kind is not empty, only the + // kDomain instructions of domain_kind will be considered as separators. + // Otherwise every kDomain instruction will be splitting domains. + static StatusOr> Create(HloModule* module, + string domain_kind); + + // Retrieves all the domains the input module or computation are composed by. + const std::vector>& GetDomains() + const { + return instruction_domains_; + } + + // Checks whether two instructions are within the same domain. + bool InSameDomain(HloInstruction* instruction1, + HloInstruction* instruction2) const; + + // Checks whether instruction is a kDomain instruction of the kind we are + // currently processing. + bool IsDomainInstruction(HloInstruction* instruction) const; + + private: + HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {} + + // Check if the kDomain instruction is facing (via its operand link) another + // kDomain instruction of the same kind, hence defining an empty domain. + // If that is the case, create the empty domain and call the proper + // normalizer. + Status TryProcessEmptyDomain(HloInstruction* instruction); + + Status Populate(HloComputation* computation); + + // Inserts the provided domain into the ones tracked by this object, + // creating a new domain ID. + Status InsertDomain(std::unique_ptr domain); + + // From the given instruction, epxands operand and user wise, the set of + // instructions which can be reached without crossing a kDomain instruction + // of the kind specified by domain_kind_. + // The domain data structure will be populated with all the reached + // instructions, and the boundaries of the domain, with the kDomain + // instructions encountered while expanding the reach. + Status ExpandDomain(HloInstruction* instruction, + DomainMetadata::Domain* domain) const; + + // Creates a domain data structure using the ExpandDomain() API. + StatusOr> CreateDomain( + HloInstruction* instruction) const; + + // Out of an instruction set, returns a vector of all the ones which are not + // a kDomain kind. + static std::vector MakeNonDomainInstructions( + const tensorflow::gtl::FlatSet& instruction_set); + + string domain_kind_; + std::vector> instruction_domains_; + tensorflow::gtl::FlatMap instruction_to_domain_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h new file mode 100644 index 0000000000000000000000000000000000000000..aa0308100a21f109579de75788fce7d242d6a6b0 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// Cannot include hlo_instruction.h as this file is included from there. +class HloInstruction; + +// The DomainMetadata represents the base class for metadata which can be +// attached to kDomain HLO instructions. +class DomainMetadata { + public: + // A Domain data structure captures all the information about a kDomain + // bounded instruction set. + struct Domain { + // The set of instructions which are reachable from each other via + // operand/user pathways, without crossing a kDomain instruction of a given + // kind. The reach_set can contain kDomain instructions of other kinds, if + // two domains of different kind intersect each other. + tensorflow::gtl::FlatSet reach_set; + + // The same instructions in reach_set, but purged from kDomain instructions. + std::vector instructions; + + // If we consider a graph edge as an arrow oriented from the operand to the + // user, the enter_domains will contain the set of kDomain instructions + // whose dataflow enters the reach set (domain), while the exit_domains + // contains the set of kDomain instructions whose dataflow exit the reach + // set. + tensorflow::gtl::FlatSet enter_domains; + tensorflow::gtl::FlatSet exit_domains; + }; + + virtual ~DomainMetadata() = default; + + // Clones the metadata object. + virtual std::unique_ptr Clone() const = 0; + + // Returns the metadata type. A unique identifier which describes the real + // metadata type. + virtual tensorflow::StringPiece Kind() const = 0; + + // Compares the metadata object with another one and returns true if the + // two matches. + virtual bool Matches(const DomainMetadata& other) const = 0; + + // Returns a string representation of the metadata. + virtual string ToString() const = 0; + + // Given a reachable set (the set of instructions which are reachable from + // each other via user/operand pathways, without crossing a kDomain + // instruciton), makes sure that all of them have metadata attributes which + // are coherent with this metadata object. + virtual Status NormalizeInstructions(const Domain& domain) const = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.cc b/tensorflow/compiler/xla/service/hlo_domain_remover.cc new file mode 100644 index 0000000000000000000000000000000000000000..1d06040b0e7c92b03f4cb5481bdee73a0f74f939 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.cc @@ -0,0 +1,149 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_remover.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +class HloDomainRemover::RunContext { + public: + RunContext(HloModule* module, HloDomainRemover* remover) + : module_(module), remover_(remover) {} + + StatusOr Run(); + + private: + // Verifies the consistency of the domain, and normalizes the instructions + // within it. + Status VerifyAndNormalizeDomain(const DomainMetadata::Domain& domain); + + HloModule* module_; + HloDomainRemover* remover_; +}; + +Status HloDomainRemover::RunContext::VerifyAndNormalizeDomain( + const DomainMetadata::Domain& domain) { + // Verify that the whole kDomain frontier bounding the instruction reach set, + // has matching metadata. + // A kDomain instruction has two sides of metadata, a user facing and an + // operand facing. + // A reachable instruction set can make contact with a kDomain instruction on + // a user facing side (the kDomain is operand of the instruction), or on a + // operand facing side (the kDomain is user of the instruction). + // And depending on the contact side, the proper metadata object + // (user_side_metadata() vs. operand_side_metadata()) needs to be used for + // consistency checks. + const DomainMetadata* ref_metadata = nullptr; + VLOG(4) << "Reach set:"; + for (HloInstruction* instruction : domain.instructions) { + VLOG(4) << " " << instruction->name(); + } + VLOG(4) << " Domains:"; + for (HloInstruction* instruction : domain.enter_domains) { + const DomainMetadata& meta = instruction->user_side_metadata(); + VLOG(4) << " User side: " << instruction->name(); + VLOG(4) << " " << meta.ToString(); + if (ref_metadata == nullptr) { + ref_metadata = &meta; + } else { + TF_RET_CHECK(meta.Matches(*ref_metadata)) + << "Metadata mismatch at instruction " << instruction->name() << " : " + << meta.ToString() << " vs " << ref_metadata->ToString(); + } + } + for (HloInstruction* instruction : domain.exit_domains) { + const DomainMetadata& meta = instruction->operand_side_metadata(); + VLOG(4) << " Operand side: " << instruction->name(); + VLOG(4) << " " << meta.ToString(); + if (ref_metadata == nullptr) { + ref_metadata = &meta; + } else { + TF_RET_CHECK(meta.Matches(*ref_metadata)) + << "Metadata mismatch at instruction " << instruction->name() << " : " + << meta.ToString() << " vs " << ref_metadata->ToString(); + } + } + if (ref_metadata != nullptr) { + VLOG(4) << "Applying domain normalization: " << ref_metadata->ToString(); + TF_RETURN_IF_ERROR(ref_metadata->NormalizeInstructions(domain)); + } else { + // No kDomain instruction was present within this domain, so call the + // generic normalization functions and have them apply their heuristic. + VLOG(2) << "Applying domain-less normalization"; + TF_RETURN_IF_ERROR(remover_->normalizer_(domain)); + } + return Status::OK(); +} + +StatusOr HloDomainRemover::RunContext::Run() { + VLOG(4) << "Processing metadata domain: '" << remover_->kind_ << "'"; + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Remover"); + + int64 removed_domains = 0; + for (HloComputation* computation : module_->computations()) { + // First create the domain instruciton sets. A domain instruction set is + // the set of instructions whose edges never cross a kDomain instruction. + TF_ASSIGN_OR_RETURN(std::unique_ptr domain_map, + HloDomainMap::Create(computation, remover_->kind_)); + // Verify and normalize every domain populated within the map. + for (auto& domain : domain_map->GetDomains()) { + TF_RETURN_IF_ERROR(VerifyAndNormalizeDomain(*domain)); + } + + // Now remove all the kDomain instructions of the kind specified by the + // remover, that are within the currently processed computation from the + // graph. + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + for (HloInstruction* operand : instruction->unique_operands()) { + if (domain_map->IsDomainInstruction(operand)) { + VLOG(5) << "Removing " << operand->name(); + TF_RETURN_IF_ERROR( + operand->ReplaceAllUsesWith(operand->mutable_operand(0))); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(operand)); + ++removed_domains; + } + } + } + HloInstruction* root = computation->root_instruction(); + if (root != nullptr && domain_map->IsDomainInstruction(root)) { + VLOG(5) << "Removing " << root->name(); + computation->set_root_instruction(root->mutable_operand(0)); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(root)); + ++removed_domains; + } + } + VLOG(3) << "Removed " << removed_domains << " kDomain instructions of '" + << remover_->kind_ << "' kind"; + if (removed_domains > 0) { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Remover"); + } + return removed_domains > 0; +} + +StatusOr HloDomainRemover::Run(HloModule* module) { + RunContext run_context(module, this); + return run_context.Run(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h new file mode 100644 index 0000000000000000000000000000000000000000..0c71dd34fd4d2944037dc965a2c9ad2c592d6e3e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ + +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { + +// Removes all the kDomain instructions of a given kind from the input module, +// and calls the normalizer to propagate the properties on the possibly new born +// instructions. +class HloDomainRemover : public HloPassInterface { + public: + // Creates a new HloDomainRemover object tasked at removing all the kDomain + // instructions of a given kind. + // In case a reachable set (the set of instructions within a computation, + // which are mutually reachable via operand/user pathways) has all the + // instructions in it with the same attributes (ie, sharding), a normalizer + // function is tasked at applying attribute normalization on the instructions + // within such domain. + HloDomainRemover( + tensorflow::StringPiece kind, + std::function normalizer) + : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {} + + tensorflow::StringPiece name() const override { return "domain_remover"; } + + StatusOr Run(HloModule* module) override; + + private: + class RunContext; + + string kind_; + std::function normalizer_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5553ddb153f7f1f2e6a790890c11f35e192488c4 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -0,0 +1,432 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_domain_remover.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloDomainTest : public HloTestBase { + protected: + bool FindUserViaDomainPath(HloInstruction* instruction, + HloInstruction* operand) const { + for (HloInstruction* user : operand->users()) { + if (user == instruction) { + return true; + } + if (user->opcode() == HloOpcode::kDomain && + FindUserViaDomainPath(instruction, user)) { + return true; + } + } + return false; + } + + // Checks whether there is a kDomain instruction in the edge between the + // instruction and the operand. + bool HasDomainEdge(HloModule* module, + tensorflow::StringPiece instruction_name, + tensorflow::StringPiece operand_name) { + HloInstruction* instruction = FindInstruction(module, instruction_name); + HloInstruction* operand = FindInstruction(module, operand_name); + CHECK_NE(instruction, nullptr); + CHECK_NE(operand, nullptr); + if (!instruction->IsUserOf(operand)) { + // If instruction is not an immediate user, we must find a path from + // operand to instruction anyway, otherwise there is a corruption. + if (FindUserViaDomainPath(instruction, operand)) { + return true; + } + LOG(FATAL) << "Bad HLO module generated across the '" << instruction_name + << "' and '" << operand_name << "' instructions:\n" + << module->ToString(); + } + return false; + } + + StatusOr> ParseModule( + tensorflow::StringPiece hlo_string) { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + return ParseHloString(hlo_string, config); + } +}; + +// Dummy DomainMetadata implementation which create kDomain boundaries around +// HLO instructions with the same metadata().op_name() values. +class OpNameMetadata : public DomainMetadata { + public: + explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {} + + std::unique_ptr Clone() const override { + return MakeUnique(opname_); + } + + tensorflow::StringPiece Kind() const override { return KindName(); } + + bool Matches(const DomainMetadata& other) const override { + const OpNameMetadata* other_ptr = + dynamic_cast(&other); + if (other_ptr == nullptr) { + // If other is not a OpNameMetadata, then it is clearly a no match. + return false; + } + return opname_ == other_ptr->opname_; + } + + string ToString() const override { return opname_; } + + Status NormalizeInstructions( + const DomainMetadata::Domain& domain) const override { + // For the purposes of this test, nothing to do. + return Status::OK(); + } + + static tensorflow::StringPiece KindName() { return "opname"; } + + private: + string opname_; +}; + +// Creator function for OpNameMetadata domains. +std::unique_ptr OpNameDomainCreator(HloInstruction* instruction, + HloInstruction* operand) { + if (instruction->metadata().op_name() == operand->metadata().op_name()) { + return nullptr; + } + std::unique_ptr operand_side_metadata = + MakeUnique(operand->metadata().op_name()); + std::unique_ptr user_side_metadata = + MakeUnique(instruction->metadata().op_name()); + return HloInstruction::CreateDomain(operand->shape(), operand, + std::move(operand_side_metadata), + std::move(user_side_metadata)); +} + +Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain) { + // Nothing to do for the particular use this test make of the OpName domains. + return Status::OK(); +} + +TEST_F(HloDomainTest, CheckDomainLinks) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(f32[4] a, f32[4] b), sharding={maximal device=1} + d = f32[4] subtract(a, b), sharding={maximal device=1} + e = f32[4] multiply(c, d), sharding={maximal device=1} + ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_TRUE(remover_changed); + + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); +} + +TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(f32[4] a, f32[4] b) + d = f32[4] subtract(a, b) + e = f32[4] multiply(c, d) + ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(!isolator_changed); +} + +TEST_F(HloDomainTest, CheckDomainAroundIO) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = (f32[4], u32[]) send(a), channel_id=1, sharding={maximal device=0} + c = () send-done(b), channel_id=1, sharding={maximal device=0} + d = (f32[4], u32[]) recv(), channel_id=2, sharding={maximal device=0} + e = f32[4] recv-done(d), channel_id=2, sharding={maximal device=0} + f = f32[4] add(a, e) + g = f32[4] subtract(a, e) + ROOT h = (f32[4], f32[4]) tuple(f, g) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "f", "e")); + EXPECT_FALSE(HasDomainEdge(module.get(), "a", "p0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_TRUE(remover_changed); + + EXPECT_FALSE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "f", "e")); +} + +TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=-1} + b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=-1} + c = f32[4] add(b, b), sharding={maximal device=-1} + d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=-1} + ROOT e = () send-done(d), channel_id=2, sharding={maximal device=-1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_FALSE(isolator_changed); +} + +TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=0} + b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=0} + c = f32[4] add(b, b) + d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=0} + ROOT e = () send-done(d), channel_id=2, sharding={maximal device=0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_FALSE(remover_changed); + + HloInstruction* add = FindInstruction(module.get(), "c"); + ASSERT_NE(add, nullptr); + auto device = add->sharding_unique_device(); + EXPECT_TRUE(device.has_value()); + EXPECT_EQ(*device, 0); +} + +TEST_F(HloDomainTest, CheckMultiDomainLinks) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(a, b), sharding={maximal device=1} + d = f32[4] subtract(a, c), sharding={maximal device=1}, metadata={op_name="D"} + e = f32[4] multiply(c, d), sharding={maximal device=1}, metadata={op_name="D"} + f = f32[4] add(e, c), sharding={maximal device=1} + ROOT g = (f32[4], f32[4], f32[4]) tuple(c, d, f) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator sharding_isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, + sharding_isolator.Run(module.get())); + EXPECT_TRUE(sharding_isolator_changed); + + HloDomainIsolator opname_isolator(OpNameDomainCreator); + TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, + opname_isolator.Run(module.get())); + EXPECT_TRUE(opname_isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + + HloDomainRemover sharding_remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, + sharding_remover.Run(module.get())); + EXPECT_TRUE(sharding_remover_changed); + + HloDomainRemover opname_remover(OpNameMetadata::KindName(), + OpNameDomainNormalizer); + TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, + opname_remover.Run(module.get())); + EXPECT_TRUE(opname_remover_changed); + + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c")); +} + +TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + infeed = (f32[4], f32[4]) infeed(), + sharding={{maximal device=1}, {maximal device=0}} + gte0 = f32[4] get-tuple-element(infeed), index=0 + gte1 = f32[4] get-tuple-element(infeed), index=1 + copy0 = f32[4] copy(gte0) + copy1 = f32[4] copy(gte1) + ROOT add = f32[4] add(copy0, copy1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "gte0", "infeed")); + EXPECT_TRUE(HasDomainEdge(module.get(), "gte1", "infeed")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy0", "gte0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy1", "gte1")); + + // Inject unassigned tuple/gte within the infeed domain, to simulate the + // HLO passes adding unexpected instructions. + // + // infeed + // / \ + // GTE0 GTE1 + // / \ + // COPY0 COPY1 + // \ / + // \ / + // TUPLE + // | + // DOMAIN + HloInstruction* infeed = FindInstruction(module.get(), "infeed"); + ASSERT_NE(infeed, nullptr); + auto infeed_users = infeed->users(); + HloInstruction* new_gte0 = + infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0)); + HloInstruction* new_copy0 = + infeed->parent()->AddInstruction(HloInstruction::CreateUnary( + new_gte0->shape(), HloOpcode::kCopy, new_gte0)); + HloInstruction* new_gte1 = + infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed->shape(), 1), infeed, 1)); + HloInstruction* new_copy1 = + infeed->parent()->AddInstruction(HloInstruction::CreateUnary( + new_gte1->shape(), HloOpcode::kCopy, new_gte1)); + HloInstruction* new_tuple = infeed->parent()->AddInstruction( + HloInstruction::CreateTuple({new_copy0, new_copy1})); + for (HloInstruction* user : infeed_users) { + TF_EXPECT_OK(infeed->ReplaceUseWith(user, new_tuple)); + } + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_TRUE(remover_changed); + + struct Assignment { + HloInstruction* instruction; + int64 device; + } assignments[] = { + {new_gte0, 1}, + {new_copy0, 1}, + {new_gte1, 0}, + {new_copy1, 0}, + }; + for (auto& assignment : assignments) { + auto device = assignment.instruction->sharding_unique_device(); + EXPECT_TRUE(device.has_value()); + EXPECT_EQ(*device, assignment.device); + } + EXPECT_TRUE(new_tuple->has_sharding()); + EXPECT_EQ( + new_tuple->sharding(), + HloSharding::Tuple(new_tuple->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index c782d1b0add17c70e0f54826917df251d5a613e2..4ed1508d7067684a15d0fb7d86e69b055bc1333b 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -119,6 +119,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { return false; } + HloCloneContext context(module); bool changed = false; for (auto* computation : module->computations()) { for (auto* hlo : computation->MakeInstructionPostOrder()) { @@ -140,6 +141,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { // These are ops with embedded computations where it suffices to convert // the embedded computations instead of converting the ops themselves. if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall || + opcode == HloOpcode::kCrossReplicaSum || opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kSelectAndScatter || @@ -178,24 +180,37 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { if (hlo->shape().element_type() == eliminate_type_) { Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_); + new_hlo = computation->AddInstruction( - hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule())); + hlo->CloneWithNewOperands(shape, new_operands, &context)); + TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); + new_hlo = ToElementType(new_hlo, eliminate_type_); } else if (ShapeUtil::IsTuple(hlo->shape())) { Shape old_shape = hlo->shape(); Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_, replace_with_type_); - new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( - new_shape, new_operands, hlo->GetModule())); + + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(new_shape, new_operands, &context)); + TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); + // Convert the elements of the result of `new_hlo` to produce a new // tuple with shape `old_shape`. new_hlo = ConvertTupleElements(new_hlo, old_shape); } else { - new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( - hlo->shape(), new_operands, hlo->GetModule())); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), new_operands, &context)); + TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); } - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, new_hlo)); + TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_hlo)); + TF_RETURN_IF_ERROR(hlo->DropAllControlDeps()); + + // NB! We want to replace and remove side effecting instructions like Rng + // as well so we can't rely HloComputation::ReplaceInstruction to reliably + // remove the replaced instruction. + TF_RETURN_IF_ERROR(computation->RemoveInstruction(hlo)); changed = true; } } diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc index cb94d9f19b825d1321263a4737b66a6bf198a772..5c5a059e0fd895f03bc26a975609b57333237faf 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -22,6 +22,12 @@ namespace { namespace op = xla::testing::opcode_matchers; +using ::testing::Contains; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Not; +using ::testing::ResultOf; + class HloElementTypeConverterTest : public HloTestBase { public: std::unique_ptr CreateModuleFromHloString( @@ -117,5 +123,65 @@ TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) { op::Convert(op::GetTupleElement(batch_norm, 2)))); } +TEST_F(HloElementTypeConverterTest, RngIsRemoved) { + const string& hlo_string = R"( +HloModule RngIsRemoved + +ENTRY main { + constant.3 = bf16[] constant(0) + constant.4 = bf16[] constant(1) + ROOT rng = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), distribution=rng_uniform +} + )"; + auto module = CreateModuleFromHloString(hlo_string); + HloElementTypeConverter type_converter(BF16, F32); + TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); + EXPECT_TRUE(converted); + + std::function is_bf16_rng = + [](const HloInstruction* inst) { + return inst->shape().element_type() == BF16 && + inst->opcode() == HloOpcode::kRng; + }; + + EXPECT_THAT(module->entry_computation()->instructions(), + Not(Contains(ResultOf(is_bf16_rng, Eq(true))))); +} + +TEST_F(HloElementTypeConverterTest, RngCtrlDep) { + const string& hlo_string = R"( +HloModule RngIsRemoved + +ENTRY main { + constant.3 = bf16[] constant(0) + constant.4 = bf16[] constant(1) + rng0 = bf16[1,2000,20]{2,1,0} rng(constant.3, constant.4), distribution=rng_uniform + ROOT rng1 = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), control-predecessors={%rng0}, distribution=rng_uniform +} + )"; + auto module = CreateModuleFromHloString(hlo_string); + + HloElementTypeConverter type_converter(BF16, F32); + TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); + EXPECT_TRUE(converted); + + HloInstruction *rng0, *rng1; + for (auto* inst : module->entry_computation()->instructions()) { + if (inst->opcode() == HloOpcode::kRng) { + const Shape& shape = inst->shape(); + ASSERT_EQ(shape.dimensions_size(), 3); + ASSERT_TRUE(shape.dimensions(1) == 2000 || shape.dimensions(1) == 1000); + if (shape.dimensions(1) == 2000) { + rng0 = inst; + } else { + rng1 = inst; + } + } + } + + EXPECT_THAT(rng0->control_successors(), ElementsAre(rng1)); + EXPECT_THAT(rng1->control_predecessors(), ElementsAre(rng0)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 296f010a920a801ef0a4dc5e40bf0dbc07898196..e0648e14672c45e9a691fd6a674c9a2cd7605a12 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" @@ -42,7 +43,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -51,16 +51,12 @@ namespace xla { namespace { -template -struct is_complex_t : public std::false_type {}; - -template <> -struct is_complex_t : public std::true_type {}; +using tensorflow::gtl::ArraySlice; template StatusOr> Compare(const Shape& shape, HloOpcode opcode, - const Literal& lhs_literal, - const Literal& rhs_literal) { + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -98,20 +94,19 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - return compare_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); + auto result = MakeUnique(shape); + TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); return std::move(result); } template <> StatusOr> Compare( - const Shape& shape, HloOpcode opcode, const Literal& lhs_literal, - const Literal& rhs_literal) { + const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -129,1947 +124,63 @@ StatusOr> Compare( << HloOpcodeString(opcode); } - auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - return compare_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); - - return std::move(result); -} - -template -StatusOr> ElementWiseUnaryOpImpl( - HloInstruction* instruction, - const std::function& unary_op, - const Literal& operand_literal) { - const auto shape = instruction->shape(); - const auto* operand = instruction->operand(0); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. - if (!ShapeUtil::SameDimensions(shape, operand->shape())) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(operand->shape()).c_str()); - } - - auto result = Literal::CreateFromShape(shape); + auto result = MakeUnique(shape); + TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - return unary_op(operand_literal.Get(multi_index)); - })); return std::move(result); } -// For one particular placement of a window in a base shape (the placement is -// represented as `window_count_index`), iterates inside the window. Translates -// the window index into base index. If the base index is within bound, call `f` -// with the base index. -void IterateThroughWindow( - const Shape& window_shape, const Window& window, const Shape& base_shape, - const tensorflow::gtl::ArraySlice& window_count_index, - const std::function&)>& f) { - const int64 rank = ShapeUtil::Rank(base_shape); - DimensionVector window_index(rank); - std::fill(window_index.begin(), window_index.end(), 0); - do { - std::vector base_index(rank); - bool out_of_bound = false; - for (int64 i = 0; i < rank; ++i) { - base_index[i] = window_count_index[i] * window.dimensions(i).stride() + - window_index[i] - window.dimensions(i).padding_low(); - if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { - out_of_bound = true; - break; - } - } - if (!out_of_bound) { - f(base_index); - } - } while (IndexUtil::BumpIndices(window_shape, &window_index)); -} - } // namespace -template -class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { - public: - explicit TypedVisitor(HloEvaluator* p) : parent_(p) {} - - // The following higher-order functions convert a function with ElementwiseT - // to a function with ReturnT. - std::function ConvertUnaryFunction( - const std::function& unary_op) { - return [&unary_op](ReturnT arg) { - return static_cast(unary_op(static_cast(arg))); - }; - } - std::function ConvertBinaryFunction( - const std::function& - binary_op) { - return [&binary_op](ReturnT arg1, ReturnT arg2) { - return static_cast(binary_op(static_cast(arg1), - static_cast(arg2))); - }; - } - std::function ConvertTernaryFunction( - const std::function& ternary_op) { - return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) { - return static_cast(ternary_op(static_cast(arg1), - static_cast(arg2), - static_cast(arg3))); - }; - } - - Status DefaultAction(HloInstruction* hlo_instruction) override { - return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", - HloOpcodeString(hlo_instruction->opcode()).c_str()); - } - - // TODO(b/35950897): many of the stl functions used in the handlers are not - // overloaded for every XLA primitive types. - - template ::value>::type* = - nullptr> - Status HandleAbs(HloInstruction* abs) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], - ElementWiseUnaryOp(abs, [](NativeT elem_operand) { - return elem_operand; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value || - is_complex_t::value>::type* = nullptr> - Status HandleAbs(HloInstruction* abs) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], - ElementWiseUnaryOp(abs, [](ElementwiseT elem_operand) { - return std::abs(elem_operand); - })); - return Status::OK(); - } - - Status HandleAbs(HloInstruction* abs) override { - return HandleAbs(abs); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleRound(HloInstruction* round) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[round], - ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) { - return std::round(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleRound(HloInstruction* round) { - return InvalidArgument("Unsupported type for Round"); - } - - Status HandleRound(HloInstruction* round) override { - return HandleRound(round); - } - - Status HandleBroadcast(HloInstruction* broadcast) override { - parent_->evaluated_[broadcast] = - Literal::CreateFromShape(broadcast->shape()); - auto output = parent_->evaluated_[broadcast].get(); - const Literal& operand_to_broadcast = - parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); - std::vector broadcast_indices( - ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); - - TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(operand_to_broadcast.shape())) - << "broadcast dimensions is of size: " << broadcast->dimensions().size() - << " and rank of operand_to_broadcast is: " - << ShapeUtil::Rank(operand_to_broadcast.shape()); - // Checks that operand's dimensions are the same as the broadcast's - // dimensions along the dimensions to be broadcasted. - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == - operand_to_broadcast.shape().dimensions(i)); - } - - return output->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; - } - return operand_to_broadcast.Get(broadcast_indices); - }); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleCeil(HloInstruction* ceil) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], - ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) { - return std::ceil(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleCeil(HloInstruction* ceil) { - return InvalidArgument("Unsupported type for Ceil"); - } - - Status HandleCeil(HloInstruction* ceil) override { - return HandleCeil(ceil); - } - - Status HandleConvert(HloInstruction* convert) override { - const HloInstruction* operand = convert->operand(0); - TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, - parent_->GetEvaluatedLiteralFor(operand).Convert( - convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } - return Status::OK(); - } - - Status HandleExp(HloInstruction* exp) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], - ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { - return std::exp(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleFloor(HloInstruction* floor) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[floor], - ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) { - return std::floor(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleFloor(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Floor"); - } - - Status HandleFloor(HloInstruction* floor) override { - return HandleFloor(floor); - } - - Status HandleLog(HloInstruction* log) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], - ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { - return std::log(elem_operand); - })); - return Status::OK(); - } - - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleNot(HloInstruction* not_) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { - return ~elem_operand; - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleNot(HloInstruction* not_) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { - return !elem_operand; - })); - return Status::OK(); - } - - template ::value>::type* = - nullptr> - Status HandleNot(HloInstruction* not_) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { - return !elem_operand; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleNot(HloInstruction* not_) { - return InvalidArgument("Unsupported type for Not"); - } - - Status HandleNot(HloInstruction* not_) override { - return HandleNot(not_); - } - - template ::value && - !std::is_floating_point::value>::type* = nullptr> - Status HandleNegate(HloInstruction* negate) { - using type = typename std::make_unsigned::type; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[negate], - ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) { - return NativeT(-type(elem_operand)); - })); - return Status::OK(); - } - - template ::value || - std::is_floating_point::value>::type* = nullptr> - Status HandleNegate(HloInstruction* negate) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[negate], - ElementWiseUnaryOp( - negate, [](ElementwiseT elem_operand) { return -elem_operand; })); - return Status::OK(); - } - - Status HandleNegate(HloInstruction* negate) override { - return HandleNegate(negate); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleSign(HloInstruction* sign) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], - ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { - return (ElementwiseT(0) < elem_operand) - - (elem_operand < ElementwiseT(0)); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleSign(HloInstruction* sign) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], - ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { - auto abs_val = std::abs(elem_operand); - return 0 == abs_val ? ElementwiseT(0) - : elem_operand / abs_val; - })); - return Status::OK(); - } - - Status HandleSign(HloInstruction* sign) override { - return HandleSign(sign); - } - - template ::value>::type* = nullptr> - Status HandleAtan2(HloInstruction* atan2) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2], - ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem, - ElementwiseT rhs_elem) { - return std::atan2(lhs_elem, rhs_elem); - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleAtan2(HloInstruction* atan2) { - return InvalidArgument("Unsupported type for Atan2"); - } - - Status HandleAtan2(HloInstruction* atan2) override { - return HandleAtan2(atan2); - } - - Status HandleTanh(HloInstruction* tanh) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], - ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) { - return std::tanh(elem_operand); - })); - return Status::OK(); - } - - template ::value && - !std::is_floating_point::value>::type* = nullptr> - Status HandleMultiply(HloInstruction* multiply) { - using type = typename std::make_unsigned::type; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return NativeT(type(lhs_elem) * type(rhs_elem)); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value || - std::is_floating_point::value || - is_complex_t::value>::type* = nullptr> - Status HandleMultiply(HloInstruction* multiply) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem * rhs_elem; - })); - return Status::OK(); - } - - Status HandleMultiply(HloInstruction* multiply) override { - return HandleMultiply(multiply); - } - - Status HandleSubtract(HloInstruction* subtract) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[subtract], - ElementWiseBinaryOp(subtract, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem - rhs_elem; - })); - return Status::OK(); - } - - Status HandleAdd(HloInstruction* add) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], - ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem, - ElementwiseT rhs_elem) { - return lhs_elem + rhs_elem; - })); - return Status::OK(); - } - - Status HandleDivide(HloInstruction* divide) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], - ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, - ElementwiseT rhs_elem) { - return lhs_elem / rhs_elem; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleMaximum(HloInstruction* maximum) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[maximum], - ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { - return std::fmax(lhs, rhs); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleMaximum(HloInstruction* maximum) { - return InvalidArgument("Unsupported type for Maximum"); - } - - Status HandleMaximum(HloInstruction* maximum) override { - return HandleMaximum(maximum); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleMinimum(HloInstruction* minimum) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], - ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::fmin(lhs_el, rhs_el); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleMinimum(HloInstruction* minimum) { - return InvalidArgument("Unsupported type for Minimum"); - } - - Status HandleMinimum(HloInstruction* minimum) override { - return HandleMinimum(minimum); - } - - Status HandlePower(HloInstruction* power) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[power], - ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::pow(lhs_el, rhs_el); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleRemainder(HloInstruction* remainder) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], - ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::fmod(lhs_el, rhs_el); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleRemainder(HloInstruction* remainder) { - return InvalidArgument("Unsupported type for Remainder"); - } - - Status HandleRemainder(HloInstruction* remainder) override { - return HandleRemainder(remainder); - } - - template ::value>::type* = - nullptr> - Status HandleAnd(HloInstruction* and_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[and_], - ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el & rhs_el; - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleAnd(HloInstruction* and_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[and_], - ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el && rhs_el; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleAnd(HloInstruction* and_) { - return InvalidArgument("Unsupported type for And"); - } - - Status HandleAnd(HloInstruction* and_) override { - return HandleAnd(and_); - } - - template ::value>::type* = - nullptr> - Status HandleOr(HloInstruction* or_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[or_], - ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el | rhs_el; - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleOr(HloInstruction* or_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[or_], - ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el || rhs_el; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleOr(HloInstruction* or_) { - return InvalidArgument("Unsupported type for Or"); - } - - Status HandleOr(HloInstruction* or_) override { - return HandleOr(or_); - } - - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleShiftLeft(HloInstruction* shl) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[shl], - ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { - return lhs_elem << rhs_elem; - })); - return Status::OK(); - } - - template ::value || - std::is_same::value>::type* = - nullptr> - Status HandleShiftLeft(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftLeft"); - } - - Status HandleShiftLeft(HloInstruction* shl) override { - return HandleShiftLeft(shl); - } - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleShiftRightArithmetic(HloInstruction* shr) { - typedef typename std::make_signed::type SignedT; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[shr], - ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { - return static_cast(static_cast(lhs_elem) >> - rhs_elem); - })); - return Status::OK(); - } - - template ::value || - std::is_same::value>::type* = - nullptr> - Status HandleShiftRightArithmetic(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightArithmetic"); - } - - Status HandleShiftRightArithmetic(HloInstruction* shra) override { - return HandleShiftRightArithmetic(shra); - } - - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleShiftRightLogical(HloInstruction* shr) { - typedef typename std::make_unsigned::type UnsignedT; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[shr], - ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { - // If shift amount is greater than the number of bits, then return 0. - if (rhs_elem >= sizeof(UnsignedT) * CHAR_BIT) { - return static_cast(0); - } - return static_cast(static_cast(lhs_elem) >> - rhs_elem); - })); - return Status::OK(); - } - - template ::value || - std::is_same::value>::type* = - nullptr> - Status HandleShiftRightLogical(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightLogical"); - } - - Status HandleShiftRightLogical(HloInstruction* shrl) override { - return HandleShiftRightLogical(shrl); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleClamp(HloInstruction* clamp) { - std::function - clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { - return std::fmax(low, std::fmin(value, high)); - }; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[clamp], - ElementwiseTernaryOp(clamp, - std::move(ConvertTernaryFunction(clamp_op)))); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleClamp(HloInstruction*) { - return InvalidArgument("Unsupported type for Clamp"); - } - - Status HandleClamp(HloInstruction* clamp) override { - return HandleClamp(clamp); - } - - Status HandleSelect(HloInstruction* select) override { - CHECK(!ShapeUtil::IsTuple(select->shape())); - std::function select_op = - [](bool pred, ReturnT on_true, ReturnT on_false) { - if (pred) { - return on_true; - } - return on_false; - }; - TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], - ElementwiseTernaryOp(select, std::move(select_op))); - return Status::OK(); - } - - Status HandleReverse(HloInstruction* reverse) override { - const auto result_shape = reverse->shape(); - const auto reverse_dimensions = reverse->dimensions(); - - auto operand = reverse->operand(0); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferReverseShape(operand->shape(), - reverse_dimensions)); - - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape set to: " << ShapeUtil::HumanString(result_shape) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = Literal::CreateFromShape(result_shape); - - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice out_index) { - std::vector from_index(out_index.begin(), out_index.end()); - for (const int64 dim : reverse_dimensions) { - from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; - } - return operand_literal.Get(from_index); - })); - - parent_->evaluated_[reverse] = std::move(result); - return Status::OK(); - } - - Status HandleConvolution(HloInstruction* conv) override { - auto lhs = conv->operand(0); - auto rhs = conv->operand(1); - const auto& window = conv->window(); - const Shape& result_shape = conv->shape(); - const Shape& lhs_shape = lhs->shape(); - const Shape& rhs_shape = rhs->shape(); - - TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); - TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); - CHECK(ShapeUtil::IsArray(lhs_shape)); - CHECK(ShapeUtil::IsArray(rhs_shape)); - CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape)); - CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); - - const auto& dnums = conv->convolution_dimension_numbers(); - const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); - CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size()); - CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size()); - CHECK_GE(num_spatial_dims, 0); - CHECK_EQ(window.dimensions_size(), num_spatial_dims); - - const auto lhs_rank = ShapeUtil::Rank(lhs_shape); - const auto rhs_rank = ShapeUtil::Rank(rhs_shape); - - CHECK_EQ(num_spatial_dims + 2, lhs_rank); - CHECK_EQ(num_spatial_dims + 2, rhs_rank); - - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, - window, dnums)); - CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape set to: " << ShapeUtil::HumanString(result_shape) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - // Dimension number applicable for input (lhs). - const int64 input_batch_dim = dnums.input_batch_dimension(); - const int64 input_z_dim = dnums.input_feature_dimension(); - // Dimension number applicable for kernel (rhs). - const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); - const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); - // Dimension number applicable for output. - const int64 output_batch_dim = dnums.output_batch_dimension(); - const int64 output_z_dim = dnums.output_feature_dimension(); - - const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); - - std::vector window_dimension_sizes; - for (auto i : dnums.kernel_spatial_dimensions()) { - window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i)); - } - - const Shape& window_shape = - ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes); - - DimensionVector lhs_index(lhs_rank); - DimensionVector rhs_index(rhs_rank); - DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size()); - - auto func = [&](tensorflow::gtl::ArraySlice out_index) { - ElementwiseT result_val = static_cast(0); - - std::fill(lhs_index.begin(), lhs_index.end(), 0); - std::fill(rhs_index.begin(), rhs_index.end(), 0); - std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0); - - lhs_index[input_batch_dim] = out_index[output_batch_dim]; - rhs_index[kernel_output_z_dim] = out_index[output_z_dim]; - - // Convolve input feature with kernel. - do { - for (int64 iz = 0; iz < z_size; ++iz) { - lhs_index[input_z_dim] = iz; - rhs_index[kernel_input_z_dim] = iz; - - // Find corresponding spatial dimension index for input (lhs). - for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { - // Spatial dimension number for input (lhs) and output. - const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); - const int64 output_spatial_dim = - dnums.output_spatial_dimensions(ki); - - // Calculate lhs (input) index without taking base dilation into - // account. - const auto& window_dim = window.dimensions(ki); - const int64 undilated_index = - out_index[output_spatial_dim] * window_dim.stride() - - window_dim.padding_low() + - rhs_spatial_index[ki] * window_dim.window_dilation(); - // Skip if the lhs (input) index is to be dilated. As an - // optimization, skip this mod if there's no dilation. - if (window_dim.base_dilation() > 1 && - undilated_index % window_dim.base_dilation() != 0) { - goto cnt; - } - - // Calculate the actual lhs (input) index after dilation. As an - // optimization, skip this integer divide if there's no dilation. - if (window_dim.base_dilation() > 1) { - lhs_index[input_spatial_dim] = - undilated_index / window_dim.base_dilation(); - } else { - lhs_index[input_spatial_dim] = undilated_index; - } - - // Skip if input index is not in bound. - if (!(lhs_index[input_spatial_dim] >= 0 && - lhs_index[input_spatial_dim] < - lhs_shape.dimensions(input_spatial_dim))) { - goto cnt; - } - - rhs_index[dnums.kernel_spatial_dimensions(ki)] = - window_dim.window_reversal() - ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) - : rhs_spatial_index[ki]; - } - - result_val += - static_cast(lhs_literal.Get(lhs_index)) * - static_cast(rhs_literal.Get(rhs_index)); - } - cnt : {} - } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); - - return static_cast(result_val); - }; - - auto result = Literal::CreateFromShape(result_shape); - TF_RETURN_IF_ERROR(result->Populate(func)); - - parent_->evaluated_[conv] = std::move(result); - return Status::OK(); - } - - Status HandleDot(HloInstruction* dot) override { - auto lhs = dot->operand(0); - auto rhs = dot->operand(1); - CHECK(ShapeUtil::IsArray(dot->shape())); - CHECK(ShapeUtil::IsArray(lhs->shape())); - CHECK(ShapeUtil::IsArray(rhs->shape())); - - const auto& dnums = dot->dot_dimension_numbers(); - - const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); - const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); - - CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); - CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); - - // There must be 1 and only 1 Contracting dimension for lhs and rhs. - CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); - CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); - const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); - const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); - // Contracted dimension sizes must be the same. - CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension), - rhs->shape().dimensions(rhs_contracting_dimension)) - << "lhs contracted dimension: " - << lhs->shape().dimensions(lhs_contracting_dimension) - << " rhs contracted dimension: " - << rhs->shape().dimensions(rhs_contracting_dimension); - const int64 contracted_dimension_size = - lhs->shape().dimensions(lhs_contracting_dimension); - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - auto result = Literal::CreateFromShape(dot->shape()); - - CHECK_EQ(dnums.lhs_batch_dimensions_size(), - dnums.rhs_batch_dimensions_size()); - - std::vector lhs_non_contracting_dims; - for (int64 i = 0; i < lhs_rank; i++) { - if (i != lhs_contracting_dimension) { - lhs_non_contracting_dims.push_back(i); - } - } - - std::vector rhs_non_batch_non_contracting_dims; - tensorflow::gtl::FlatSet batch_dims_set( - dnums.rhs_batch_dimensions().begin(), - dnums.rhs_batch_dimensions().end()); - for (int64 i = 0; i < rhs_rank; i++) { - if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { - rhs_non_batch_non_contracting_dims.push_back(i); - } - } - - const int64 batch_dim_size = dnums.lhs_batch_dimensions_size(); - const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size(); - - DimensionVector lhs_index(lhs_rank); - DimensionVector rhs_index(rhs_rank); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice result_index) { - ElementwiseT result_val = static_cast(0); - - // Find the corresponding non-contracting indices for lhs and rhs. - // - // For `result_index`, its batch dimension, if exists, will be at the - // same dimension as the batch dimension of lhs and rhs. More - // specifically: - // - For lhs, the non-contracting dimensions, including the batch - // dimension have the same index as the `result_index`. - // - For rhs, the batch dimension is set seperately from other - // non-contracting dimensions, since these other non-contracting - // dimensions in rhs follow the non-contracting dimensions of lhs in - // the resulting index. - // - // As an example, for a resulting index: - // result_index [result_batch, result_x, result_y] - // the effecting lhs and rhs indices are: - // lhs [result_batch, lhs_non_contracting_dim, contracting_dim - // rhs [result_batch, contracting_dim, rhs_non_contracting_dim] - // `result_x` is only affected by the lhs_non_contracting_dim and - // likewise `result_y` only depends on rhs_non_contracting_dim. - // - // so we can look up the lhs and rhs indices by: - // - // lhs: - // batch index is the same as `result_batch`. - // non-contracting dimension is the same as - // result_index[lhs_non_contracting_dim] - // rhs: - // batch index: the same as `result_batch`. - // non-contracting dimension index: *not* the same as - // result_index[rhs_non_contractng_dim], since the - // non-contracting dimensions of lhs are included in the - // result_index first. Instead, the non_contracting_dim of rhs must - // be calculated as following: - // lhs_non_contracting_dimensions_size + - // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1 - // - // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is - // the index offset to the result_index that only depends on - // the non_batch and non-contracting dimensions of rhs. -1 at the - // end translates size to index. - for (auto i : lhs_non_contracting_dims) { - lhs_index[i] = result_index[i]; - } - for (auto i : dnums.rhs_batch_dimensions()) { - rhs_index[i] = result_index[i]; - } - for (auto i : rhs_non_batch_non_contracting_dims) { - const int64 rhs_non_batch_non_contracting_dim = - lhs_non_contracting_size + (i - batch_dim_size) - 1; - rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim]; - } - - // Accumulates resulting product along the contracted dimension. - for (int64 i = 0; i < contracted_dimension_size; ++i) { - lhs_index[lhs_contracting_dimension] = i; - rhs_index[rhs_contracting_dimension] = i; - - result_val += - static_cast(lhs_literal.Get(lhs_index)) * - static_cast(rhs_literal.Get(rhs_index)); - } - - return static_cast(result_val); - })); - - parent_->evaluated_[dot] = std::move(result); - return Status::OK(); - } - - Status HandlePad(HloInstruction* pad) override { - CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape())); - // Padding value must be scalar. - CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); - CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), - pad->padding_config().dimensions_size()); - - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferPadShape( - /*operand_shape=*/pad->operand(0)->shape(), - /*padding_value_shape=*/pad->operand(1)->shape(), - /*padding_config=*/pad->padding_config())); - CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(pad->shape()) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - // Create new HLO of padded shape with padding value. - ReturnT scalar = - parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); - auto result = Literal::CreateFromShape(pad->shape()); - TF_RETURN_IF_ERROR(result->Populate( - [&scalar](tensorflow::gtl::ArraySlice multi_index) { - return scalar; - })); - - const Literal& evaluated_operand = - parent_->GetEvaluatedLiteralFor(pad->operand(0)); - - std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), - 0); - std::vector target_index(ShapeUtil::Rank(result->shape()), 0); - - // Loop through each element of the operand, assign them to the - // corresponding index of the resulting padded literal. - const PaddingConfig& pad_config = pad->padding_config(); - - auto func = [&](const std::vector& input_index) { - for (auto i = 0; i < input_index.size(); ++i) { - // Interior padding occurs logically before edge padding, so in the case - // of negative edge padding elements are removed from the - // interior-padded operand. - target_index[i] = - pad_config.dimensions(i).edge_padding_low() + - input_index[i] * (pad_config.dimensions(i).interior_padding() + 1); - - // Account for negative low and high padding: skip assignment if the - // any target index is out of range. - if (!(target_index[i] >= 0 && - target_index[i] < pad->shape().dimensions(i))) { - return true; - } - } - result->Set(target_index, - evaluated_operand.Get(input_index)); - return true; - }; - - std::vector zero_base(evaluated_operand.shape().dimensions_size(), - 0); - std::vector step(evaluated_operand.shape().dimensions_size(), 1); - - ShapeUtil::ForEachIndex( - evaluated_operand.shape(), zero_base, - AsInt64Slice(evaluated_operand.shape().dimensions()), step, func); - - parent_->evaluated_[pad] = std::move(result); - return Status::OK(); - } - - Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { - auto operand = dynamic_slice->operand(0); - auto start_indices = dynamic_slice->operand(1); - auto result_shape = dynamic_slice->shape(); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferDynamicSliceShape( - operand->shape(), start_indices->shape(), - dynamic_slice->dynamic_slice_sizes())); - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - TF_RET_CHECK( - primitive_util::IsIntegralType(start_indices->shape().element_type())); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); - - switch (start_indices->shape().element_type()) { - case S32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); - } break; - case S64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); - } break; - case U32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); - } break; - case U64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); - } break; - default: - LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for " - "start_indices: " - << PrimitiveType_Name(start_indices->shape().element_type()); - } - - return Status::OK(); - } - Status HandleDynamicUpdateSlice( - HloInstruction* dynamic_update_slice) override { - auto operand = dynamic_update_slice->operand(0); - auto update = dynamic_update_slice->operand(1); - auto start_indices = dynamic_update_slice->operand(2); - auto result_shape = dynamic_update_slice->shape(); - TF_ASSIGN_OR_RETURN( - auto inferred_return_shape, - ShapeInference::InferDynamicUpdateSliceShape( - operand->shape(), update->shape(), start_indices->shape())); - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - TF_RET_CHECK( - primitive_util::IsIntegralType(start_indices->shape().element_type())); - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape())); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); - - switch (start_indices->shape().element_type()) { - case S32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); - } break; - case S64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); - } break; - case U32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); - } break; - case U64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); - } break; - default: - LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for " - "start_indices: " - << PrimitiveType_Name(start_indices->shape().element_type()); - } - - return Status::OK(); - } - - template - StatusOr> MapImpl(HloInstruction* map) { - auto operands = map->operands(); - HloComputation* computation = map->to_apply(); - - auto result = Literal::CreateFromShape(map->shape()); - - HloEvaluator embedded_evaluator; - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - std::vector> arg_literals; - arg_literals.reserve(operands.size()); - - // Construct scalar literal parameters to be passed to the map - // computation. - for (auto operand : operands) { - const Literal& arg_literal = - parent_->GetEvaluatedLiteralFor(operand); - - auto curr_val = arg_literal.Get(multi_index); - auto curr_val_literal = Literal::CreateR0(curr_val); - - arg_literals.push_back(std::move(curr_val_literal)); - } - - std::unique_ptr computed_result = - embedded_evaluator - .Evaluate>(*computation, - arg_literals) - .ConsumeValueOrDie(); - // Clear visit states so that the we can use the evaluate again on - // the same computation. - embedded_evaluator.ResetVisitStates(); - - return computed_result->Get({}); - })); - return std::move(result); - } - - Status HandleMap(HloInstruction* map) override { - switch (map->operand(0)->shape().element_type()) { - case PRED: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case U8: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case U32: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case U64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case S8: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case S32: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case S64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case F16: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], - MapImpl(map)); - break; - } - case F32: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case F64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case C64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - default: - LOG(FATAL) << "HandleMap: unhandled primitive type for " - "input operand: " - << PrimitiveType_Name( - map->operand(0)->shape().element_type()); - } - - return Status::OK(); - } - - Status HandleReduce(HloInstruction* reduce) override { - auto arg = reduce->operand(0); - auto init_value = reduce->operand(1); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); - HloComputation* function = reduce->to_apply(); - TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) == - ShapeUtil::Rank(arg->shape()) - dimensions.size()); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferReduceShape( - /*arg=*/arg->shape(), - /*init_value=*/init_value->shape(), - /*dimensions_to_reduce=*/dimensions, - /*to_apply=*/function->ComputeProgramShape())); - TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape()) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); - VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString(); - const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value); - VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString(); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get({}); - - auto result = Literal::CreateFromShape(reduce->shape()); - - const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions()); - std::vector arg_dim_steps(arg_dimensions.size()); - std::vector arg_dim_counts(arg_dimensions.size()); - for (const int64 dim : dimensions) { - arg_dim_steps[dim] = 1; - arg_dim_counts[dim] = arg_dimensions[dim]; - } - - // Create mapping from result index to arg index. - const int64 result_rank = ShapeUtil::Rank(result->shape()); - int64 result_dim = 0; - std::vector result_to_arg_index(result_rank); - for (int64 i = 0; i < arg_dimensions.size(); ++i) { - if (arg_dim_steps[i] == 0) { - result_to_arg_index[result_dim] = i; - ++result_dim; - } - } - - HloEvaluator embedded_evaluator; - // For each resulting dimension, calculate and assign computed value. - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - ReturnT result_val = init_scalar; - - std::vector base(arg_dimensions.size()); - for (int64 i = 0; i < multi_index.size(); ++i) { - base[result_to_arg_index[i]] = multi_index[i]; - } - - auto func = [&](const std::vector& input_index) { - auto curr_val = arg_literal.Get(input_index); - - // Evaluate computation with specified literal operands. - auto curr_val_literal = Literal::CreateR0(curr_val); - auto result_val_literal = Literal::CreateR0(result_val); - std::vector args = {curr_val_literal.get(), - result_val_literal.get()}; - - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) - .ConsumeValueOrDie(); - // Clear visit states so that the we can use the evaluate again on - // the same computation. - embedded_evaluator.ResetVisitStates(); - - // Assign computed result to result_val. - result_val = computed_result->Get({}); - - return true; - }; - - ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, - arg_dim_steps, func); - - return result_val; - })); - - parent_->evaluated_[reduce] = std::move(result); - return Status::OK(); - } - - Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { - auto operand = select_and_scatter->operand(0); - auto source = select_and_scatter->operand(1); - const Window& window = select_and_scatter->window(); - - const Literal& init_literal = - parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2)); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get({}); - - auto result = Literal::CreateFromShape(select_and_scatter->shape()); - - // Initialize result array with the init value. - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice output_index) { - return init_scalar; - })); - - std::vector window_dimension_sizes; - for (const auto& window_dimension : window.dimensions()) { - window_dimension_sizes.push_back(window_dimension.size()); - } - const Shape window_shape = ShapeUtil::MakeShape( - operand->shape().element_type(), window_dimension_sizes); - - HloComputation* select = select_and_scatter->select(); - HloComputation* scatter = select_and_scatter->scatter(); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source); - - int64 rank = ShapeUtil::Rank(operand_literal.shape()); - - HloEvaluator embedded_evaluator; - DimensionVector source_index(rank); - - std::fill(source_index.begin(), source_index.end(), 0); - do { - // For each element in `source`, we place a window in `operand`. For each - // window placement, we iterate inside the window twice: - // - // 1. Find the selected index by applying `select` function to all - // elements. E.g., If the `select` function is GreaterEqual, the first - // iteration through the window finds the biggest value and returns its - // index. - // - // 2. Using the selected index, scatter value from `source` to result. We - // do this by iterating through the window, and compare each index with - // the selected index. - tensorflow::gtl::optional selected_val; - tensorflow::gtl::optional> selected_index; - - IterateThroughWindow( - window_shape, window, operand_literal.shape(), source_index, - [&](const std::vector& operand_index) { - auto curr_val = operand_literal.Get(operand_index); - if (!selected_val) { - selected_val = curr_val; - selected_index = operand_index; - } - const auto curr_val_literal = Literal::CreateR0(curr_val); - const auto selected_val_literal = - Literal::CreateR0(*selected_val); - - const std::vector args = { - curr_val_literal.get(), selected_val_literal.get()}; - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*select, args) - .ConsumeValueOrDie(); - bool selected = computed_result->Get({}); - if (selected) { - selected_val = curr_val; - selected_index = operand_index; - } - embedded_evaluator.ResetVisitStates(); - }); - - IterateThroughWindow( - window_shape, window, operand_literal.shape(), source_index, - [&](const std::vector& operand_index) { - if (std::equal(operand_index.begin(), operand_index.end(), - selected_index->begin())) { - auto source = source_literal.Get(source_index); - auto scattered = result->Get(operand_index); - const auto source_literal = Literal::CreateR0(source); - const auto scattered_literal = - Literal::CreateR0(scattered); - - const std::vector args = { - source_literal.get(), scattered_literal.get()}; - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*scatter, args) - .ConsumeValueOrDie(); - result->Set(operand_index, computed_result->Get({})); - // Clear visit states so that the we can use the evaluator again - // on the same computation. - embedded_evaluator.ResetVisitStates(); - } - }); - } while (IndexUtil::BumpIndices(source->shape(), &source_index)); - - parent_->evaluated_[select_and_scatter] = std::move(result); - return Status::OK(); - } - - Status HandleReduceWindow(HloInstruction* reduce_window) override { - auto operand = reduce_window->operand(0); - const Window& window = reduce_window->window(); - HloComputation* function = reduce_window->to_apply(); - TF_ASSIGN_OR_RETURN( - auto inferred_return_shape, - ShapeInference::InferReduceWindowShape( - /*operand_shape=*/reduce_window->operand(0)->shape(), - /*init_value=*/reduce_window->operand(1)->shape(), window, - /*to_apply_shape=*/function->ComputeProgramShape())); - TF_RET_CHECK( - ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) - << "return shape is set to: " - << ShapeUtil::HumanStringWithLayout(reduce_window->shape()) - << "but is inferred to be: " - << ShapeUtil::HumanStringWithLayout(inferred_return_shape); - - const Literal& operand_literal = - parent_->GetEvaluatedLiteralFor(reduce_window->operand(0)); - VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString(); - const Literal& init_literal = - parent_->GetEvaluatedLiteralFor(reduce_window->operand(1)); - VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString(); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get({}); - - auto result = Literal::CreateFromShape(reduce_window->shape()); - - // Creates a Shape object from window, for iteration below. - std::vector window_dimension_sizes; - for (const auto& window_dimension : window.dimensions()) { - window_dimension_sizes.push_back(window_dimension.size()); - } - const Shape window_shape = ShapeUtil::MakeShape( - operand->shape().element_type(), window_dimension_sizes); - - DimensionVector window_index(window.dimensions_size()); - DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); - - HloEvaluator embedded_evaluator; - // For each resulting dimension, calculate and assign computed value. - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice output_index) { - ReturnT result_val = init_scalar; - - std::fill(window_index.begin(), window_index.end(), 0); - std::fill(operand_index.begin(), operand_index.end(), 0); - - IterateThroughWindow( - window_shape, window, operand_literal.shape(), output_index, - [&](const std::vector& operand_index) { - auto curr_val = operand_literal.Get(operand_index); - - // Evaluate computation with specified literal operands. - const auto curr_val_literal = - Literal::CreateR0(curr_val); - const auto result_val_literal = - Literal::CreateR0(result_val); - const std::vector args = { - curr_val_literal.get(), result_val_literal.get()}; - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) - .ConsumeValueOrDie(); - - // Clear visit states so that the we can use the evaluate again - // on the same computation. - embedded_evaluator.ResetVisitStates(); - - result_val = computed_result->Get({}); - }); - - return result_val; - })); - - parent_->evaluated_[reduce_window] = std::move(result); - return Status::OK(); - } - - Status HandleSlice(HloInstruction* slice) override { - auto operand = slice->operand(0); - const Shape& shape = slice->shape(); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferSliceShape( - operand->shape(), slice->slice_starts(), - slice->slice_limits(), slice->slice_strides())); - TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape)) - << "return shape set to: " << ShapeUtil::HumanString(shape) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const int64 rank = ShapeUtil::Rank(operand->shape()); - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto func = [&](tensorflow::gtl::ArraySlice out_index) { - DimensionVector operand_index(rank); - for (int64 i = 0; i < rank; ++i) { - operand_index[i] = - slice->slice_starts(i) + out_index[i] * slice->slice_strides(i); - } - return operand_literal.Get(operand_index); - }; - - auto result = Literal::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); - TF_RETURN_IF_ERROR(result->Populate(func)); - parent_->evaluated_[slice] = std::move(result); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleSin(HloInstruction* sin) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], - ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { - return std::sin(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value || - is_complex_t::value>::type* = nullptr> - Status HandleSin(HloInstruction* sin) { - return InvalidArgument("Unsupported type for Sin"); - } - - Status HandleSin(HloInstruction* sin) override { - return HandleSin(sin); - } - - template ::value>::type* = nullptr> - Status HandleCos(HloInstruction* cos) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], - ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { - return std::cos(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value || - is_complex_t::value>::type* = nullptr> - Status HandleCos(HloInstruction* cos) { - return InvalidArgument("Unsupported type for Cos"); - } - - Status HandleCos(HloInstruction* cos) override { - return HandleCos(cos); - } - - template ::value>::type* = nullptr> - Status HandleReducePrecision(HloInstruction* reduce_precision) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[reduce_precision], - ElementWiseUnaryOp(reduce_precision, [reduce_precision]( - ElementwiseT elem) { - uint32_t value_as_int = tensorflow::bit_cast(elem); - const uint32_t mantissa_bits = reduce_precision->mantissa_bits(); - const uint32_t exponent_bits = reduce_precision->exponent_bits(); - - // Code is based on the CPU/GPU implementation in LLVM-emitting code. - // - // Bits in float type: - // mantissa : bits [0:22] - // exponent : bits [23:30] - // sign : bits [31] - if (mantissa_bits < 23) { - const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); - - // Compute rounding bias for round-to-nearest with ties to even. - // This is equal to a base value of 0111... plus one bit if the last - // remaining mantissa bit is 1. - const uint32_t base_rounding_bias = - (last_mantissa_bit_mask >> 1) - 1; - const uint32_t x_last_mantissa_bit = - (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits); - const uint32_t x_rounding_bias = - x_last_mantissa_bit + base_rounding_bias; - - // Add rounding bias, and mask out truncated bits. Note that the - // case where adding the rounding bias overflows into the exponent - // bits is correct; the non-masked mantissa bits will all be zero, - // and the exponent will be incremented by one. - const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); - value_as_int = value_as_int + x_rounding_bias; - value_as_int = value_as_int & truncation_mask; - } - if (exponent_bits < 8) { - // Masks for f32 values. - const uint32_t f32_sign_bit_mask = 1u << 31; - const uint32_t f32_exp_bits_mask = 0xffu << 23; - - // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the - // most- significant bit -- is equal to 1.0f for all exponent sizes. - // Adding 2^(n-1)-1 to this gives us the highest non-infinite - // exponent for a bit- size of n, and subtracting 2^(n-1)-1 from - // this gives us the lowest' exponent (corresponding to 0.0f). - // - // Thus, the f32 exponent corresponding to the highest non-infinite - // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 - // exponent corresponding to the lowest exponent for a bit size of n - // is (2^7-1) - 2^(n-1)-1. - // - // Note that we have already checked that exponents_bits >= 1. - const uint32_t f32_exponent_bias = (1 << 7) - 1; - const uint32_t reduced_exponent_bias = - (1 << (exponent_bits - 1)) - 1; - const uint32_t reduced_max_exponent = - f32_exponent_bias + reduced_exponent_bias; - const uint32_t reduced_min_exponent = - f32_exponent_bias - reduced_exponent_bias; - - // Do we overflow or underflow? - const uint32_t x_exponent = value_as_int & f32_exp_bits_mask; - const bool x_overflows = x_exponent > (reduced_max_exponent << 23); - const bool x_underflows = - x_exponent <= (reduced_min_exponent << 23); - - // Compute appropriately-signed values of zero and infinity. - const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask; - const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask; - - // Force to zero or infinity if overflow or underflow. (Note that - // this truncates all denormal values to zero, rather than rounding - // them.) - value_as_int = x_overflows ? x_signed_inf : value_as_int; - value_as_int = x_underflows ? x_signed_zero : value_as_int; - } - - float reduced_result = tensorflow::bit_cast(value_as_int); - if (std::isnan(elem)) { - reduced_result = mantissa_bits > 0 - ? elem - : std::numeric_limits::infinity(); - } - return reduced_result; - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Double not supported for reduce precision"); - } - - template < - typename NativeT, - typename std::enable_if::value || - is_complex_t::value>::type* = nullptr> - Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Unsupported type for reduce precision"); - } - - Status HandleReducePrecision(HloInstruction* reduce_precision) override { - return HandleReducePrecision(reduce_precision); - } - - private: - template - StatusOr> DynamicSlice( - const Literal& operand_literal, const Literal& start_indices_literal, - const Shape& result_shape) { - auto start_indices_typed = start_indices_literal.data(); - std::vector start(start_indices_typed.begin(), - start_indices_typed.end()); - - std::vector operand_indices(start.size()); - - auto result = Literal::CreateFromShape(result_shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - for (int64 i = 0; i < operand_indices.size(); ++i) { - CHECK_GE(multi_index[i] + start[i], 0); - // Mod is only used here to be consistent with the existing - // backends' behavior. - operand_indices[i] = (multi_index[i] + start[i]) % - operand_literal.shape().dimensions(i); - } - - auto result = operand_literal.Get(operand_indices); - return result; - })); - - return std::move(result); - } - - template - StatusOr> DynamicUpdateSlice( - const Literal& operand_literal, const Literal& update_literal, - const Literal& start_indices_literal) { - auto start_indices_typed = start_indices_literal.data(); - const std::vector start(start_indices_typed.begin(), - start_indices_typed.end()); - - auto result = operand_literal.CloneToUnique(); - std::vector result_index(ShapeUtil::Rank(result->shape()), 0); - - auto func = [&](const std::vector& update_index) { - std::transform(update_index.begin(), update_index.end(), start.begin(), - result_index.begin(), std::plus()); - - result->Set(result_index, - update_literal.Get(update_index)); - return true; - }; - - std::vector base(update_literal.shape().dimensions_size(), 0); - std::vector step(update_literal.shape().dimensions_size(), 1); - ShapeUtil::ForEachIndex(update_literal.shape(), base, - AsInt64Slice(update_literal.shape().dimensions()), - step, func); - - return std::move(result); - } - - StatusOr> ElementWiseUnaryOp( - HloInstruction* instruction, - const std::function& unary_op) { - const Literal& operand_literal = - parent_->GetEvaluatedLiteralFor(instruction->operand(0)); - TF_ASSIGN_OR_RETURN( - auto result_literal, - (ElementWiseUnaryOpImpl( - instruction, ConvertUnaryFunction(unary_op), operand_literal))); - - return std::move(result_literal); - } - - StatusOr> ElementWiseBinaryOp( - HloInstruction* instruction, - const std::function& - binary_op) { - const auto shape = instruction->shape(); - const auto* lhs = instruction->operand(0); - const auto* rhs = instruction->operand(1); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast - // is removed. - if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str()); - } - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - auto result = Literal::CreateFromShape(shape); - - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - return ConvertBinaryFunction(binary_op)( - lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); - return std::move(result); - } - - template - StatusOr> ElementwiseTernaryOp( - HloInstruction* instruction, - const std::function& ternary_op) { - const auto shape = instruction->shape(); - const auto* lhs = instruction->operand(0); - const auto* rhs = instruction->operand(1); - const auto* ehs = instruction->operand(2); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit - // broadcast is removed. - if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && - ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str(), - ShapeUtil::HumanString(ehs->shape()).c_str()); - } - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - - auto result = Literal::CreateFromShape(shape); - - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - return ternary_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index), - ehs_literal.Get(multi_index)); - })); - - return std::move(result); - } - - HloEvaluator* parent_; -}; // class HloEvaluator::TypedVisitor - -HloEvaluator::HloEvaluator() { - typed_visitors_[PRED] = MakeUnique>(this); - typed_visitors_[U8] = MakeUnique>(this); +HloEvaluator::HloEvaluator(int64 max_loop_iterations) + : max_loop_iterations_(max_loop_iterations) { + typed_visitors_[PRED] = MakeUnique>(this); + typed_visitors_[U8] = MakeUnique>(this); typed_visitors_[U16] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: U16."); + return Unimplemented( + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "U16."); }); - typed_visitors_[U32] = MakeUnique>(this); - typed_visitors_[U64] = MakeUnique>(this); - typed_visitors_[S8] = MakeUnique>(this); + typed_visitors_[U32] = MakeUnique>(this); + typed_visitors_[U64] = MakeUnique>(this); + typed_visitors_[S8] = MakeUnique>(this); typed_visitors_[S16] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: S16."); + return Unimplemented( + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "S16."); }); - typed_visitors_[S32] = MakeUnique>(this); - typed_visitors_[S64] = MakeUnique>(this); - typed_visitors_[F16] = MakeUnique>(this); - typed_visitors_[F32] = MakeUnique>(this); - typed_visitors_[F64] = MakeUnique>(this); - typed_visitors_[C64] = MakeUnique>(this); + typed_visitors_[S32] = MakeUnique>(this); + typed_visitors_[S64] = MakeUnique>(this); + typed_visitors_[F16] = + MakeUnique>(this); + typed_visitors_[F32] = MakeUnique>(this); + typed_visitors_[F64] = MakeUnique>(this); + typed_visitors_[C64] = MakeUnique>(this); // Most of the evaluator computations we use don't support BF16 (e.g., // std::ceil, std::tanh). To make evaluator work with BF16, we set all // elementwise computations to be done in F32 and do BF16<->F32 conversion // around the input and the output of the computations. - typed_visitors_[BF16] = MakeUnique>(this); + typed_visitors_[BF16] = + MakeUnique>(this); + typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE."); + return Unimplemented( + "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE."); }); typed_visitors_[OPAQUE] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: OPAQUE."); + return Unimplemented( + "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); }); } template StatusOr> HloEvaluator::Evaluate( - const HloModule& module, - tensorflow::gtl::ArraySlice arg_literals) { + const HloModule& module, ArraySlice arg_literals) { XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); evaluated_.clear(); @@ -2086,8 +197,8 @@ StatusOr> HloEvaluator::Evaluate( template StatusOr> HloEvaluator::Evaluate( - const HloComputation& computation, - tensorflow::gtl::ArraySlice arg_literals) { + const HloComputation& computation, ArraySlice arg_literals) { + CHECK(computation.parent() != nullptr); XLA_VLOG_LINES( 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); @@ -2103,8 +214,7 @@ StatusOr> HloEvaluator::Evaluate( template StatusOr> HloEvaluator::Evaluate( - HloInstruction* instruction, - tensorflow::gtl::ArraySlice arg_literals) { + HloInstruction* instruction, ArraySlice arg_literals) { TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); @@ -2199,6 +309,35 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( return result; } +StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( + HloOpcode opcode, const Literal& lhs, const Literal& rhs) { + std::unique_ptr lhs_instr = + HloInstruction::CreateConstant(lhs.CloneToUnique()); + std::unique_ptr rhs_instr = + HloInstruction::CreateConstant(rhs.CloneToUnique()); + + std::unique_ptr cloned_instruction = + HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(), + rhs_instr.get()); + auto result = Evaluate(cloned_instruction.get()); + + cloned_instruction->DetachFromOperands(); + return result; +} + +StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( + HloOpcode opcode, const Literal& operand) { + std::unique_ptr operand_instr = + HloInstruction::CreateConstant(operand.CloneToUnique()); + + std::unique_ptr cloned_instruction = + HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get()); + auto result = Evaluate(cloned_instruction.get()); + + cloned_instruction->DetachFromOperands(); + return result; +} + Status HloEvaluator::HandleParameter(HloInstruction* parameter) { CHECK_LT(parameter->parameter_number(), arg_literals_.size()); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; @@ -2229,8 +368,7 @@ Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { } Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { - tensorflow::gtl::ArraySlice operands( - concatenate->operands()); + ArraySlice operands(concatenate->operands()); // The result concatenate dimension is going to be the sum of all // concatenate dimensions of the operands taking part of the operation. const Shape& reference_shape = operands[0]->shape(); @@ -2369,6 +507,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { } break; case F16: return Unimplemented("unhandled primitive type: F16."); + case BF16: { + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), opcode, + lhs_literal, rhs_literal)); + } break; case F32: { TF_ASSIGN_OR_RETURN( evaluated_[compare], @@ -2402,6 +545,379 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) { return Status::OK(); } +// Returns an ShapeUtil::IndexIterationSpace that iterates over the output +// gather dimensions while keeping the rest of the output dimensions clamped to +// 0. +ShapeUtil::IndexIterationSpace IterationSpaceForOutputGatherIndices( + const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) { + int64 output_rank = output_shape.dimensions_size(); + std::vector index_base(output_rank, 0); + std::vector index_count; + index_count.reserve(output_rank); + for (int64 i = 0; i < output_rank; i++) { + bool is_output_gather_dim = + !c_binary_search(dim_numbers.output_window_dims(), i); + index_count.push_back(is_output_gather_dim ? output_shape.dimensions(i) + : 1); + } + + return {std::move(index_base), std::move(index_count), + std::vector(output_rank, 1)}; +} + +// Return an ShapeUtil::IndexIterationSpace that iterates over the output window +// dimensions while keeping the rest of the output dimensions clamped to 0. +ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices( + int64 output_rank, ArraySlice window_bounds, + const GatherDimensionNumbers& dim_numbers) { + std::vector index_base(output_rank, 0); + std::vector index_count(output_rank, 1); + int64 window_bounds_idx = 0; + for (int64 i = 0; i < output_rank; i++) { + bool is_output_window_dim = + c_binary_search(dim_numbers.output_window_dims(), i); + if (is_output_window_dim) { + while (c_binary_search(dim_numbers.elided_window_dims(), + window_bounds_idx)) { + window_bounds_idx++; + } + index_count[i] = window_bounds[window_bounds_idx++]; + } + } + + return {std::move(index_base), std::move(index_count), + std::vector(output_rank, 1)}; +} + +// This functor computes the contribution of gather_indices to an input index +// corresponding to an output index. That is, given an output index I, it picks +// out the gather output indices in I and uses them to look up a gather index, +// G, from the gather indices tensor, and expands G into the input space +// according to gather_dims_to_operand_dims. +class OutputGatherIndexToInputIndex { + public: + // The constructor does some setup work that is amortized across all + // iterations. + explicit OutputGatherIndexToInputIndex( + const GatherDimensionNumbers* dim_numbers, const Shape& input_shape, + const Shape& output_shape, const Literal* gather_indices) + : dim_numbers_(*dim_numbers), gather_indices_(*gather_indices) { + for (int64 i = 0; i < output_shape.dimensions_size(); i++) { + output_dim_is_gather_dims_.push_back( + !c_binary_search(dim_numbers_.output_window_dims(), i)); + } + + for (int64 i = 0; i < input_shape.dimensions_size(); i++) { + int64 index_of_input_dim_in_index_vector = + std::distance(dim_numbers_.gather_dims_to_operand_dims().begin(), + c_find(dim_numbers_.gather_dims_to_operand_dims(), i)); + if (index_of_input_dim_in_index_vector == + dim_numbers_.gather_dims_to_operand_dims_size()) { + input_dim_value_to_index_vector_.push_back(-1); + } else { + input_dim_value_to_index_vector_.push_back( + index_of_input_dim_in_index_vector); + } + } + + index_vector_index_.resize(gather_indices_.shape().dimensions_size()); + input_index_.resize(input_shape.dimensions_size()); + int64 index_vector_size = + gather_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); + index_vector_.resize(index_vector_size); + } + + // Returns the contribution of gather_indices to the input index corresponding + // to output_index. See gather_inner_loop_body. + // + // This is conceptually a stateless transformation from output_index to the + // gather input index, but: + // + // - Instead of allocating memory to represent the gather input index on + // every invocation we reuse the same storage for the result + // (input_index_), mutating it in place. + // - Instead of allocating buffers for temporary values like + // index_vector_index_ and index_vector on every invocation, we reuse the + // same storage for all invocations. + // + // This returns an arrayslice into memory owned by the class. + StatusOr> operator()(ArraySlice output_index) { + PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index); + TF_RETURN_IF_ERROR(FetchIndexVector()); + PropagateIndexVectorToInputIndex(); + return ArraySlice(input_index_); + } + + private: + // Propagates the gather index dimensions from the output index into + // index_vector_index_ by mutating index_vector_index_ in place. Does not + // update the dim_numbers.index_vector_dim() dimension -- that's the dimension + // we iterate over in FetchIndexVector. + void PropagateOutputIndexGatherDimsToIndexVectorIndex( + ArraySlice output_index) { + int64 index_vector_index_i = 0; + for (int64 i = 0, e = output_index.size(); i < e; i++) { + if (!output_dim_is_gather_dims_[i]) { + continue; + } + + if (index_vector_index_i == dim_numbers_.index_vector_dim()) { + index_vector_index_i++; + } + + index_vector_index_[index_vector_index_i++] = output_index[i]; + } + } + + // Populates index_vector_ by iterating over gather_indices_ according to + // index_vector_index_. + Status FetchIndexVector() { + int64 index_vector_dim = dim_numbers_.index_vector_dim(); + for (int64 i = 0, e = index_vector_.size(); i < e; i++) { + index_vector_index_[index_vector_dim] = i; + TF_ASSIGN_OR_RETURN(index_vector_[i], gather_indices_.GetIntegralAsS64( + index_vector_index_)); + } + return Status::OK(); + } + + // Populates input_index_. + void PropagateIndexVectorToInputIndex() { + for (int64 i = 0, e = input_index_.size(); i < e; i++) { + if (input_dim_value_to_index_vector_[i] != -1) { + input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]]; + } + + // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] + // remains 0, as set by the constructor. + } + } + + // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of + // the input index from the index vector. See + // PropagateIndexVectorToInputIndex. + std::vector input_dim_value_to_index_vector_; + + // output_dim_is_gather_dims_[i] is true iff the output index i is a gather + // dimension. + std::vector output_dim_is_gather_dims_; + + // The buffer into which we construct an index into gather_indices_ to fetch + // the index vector. + std::vector index_vector_index_; + + // The index vector fetched from gather_indices_. + std::vector index_vector_; + + // The result computed by this functor. operator() returns an ArraySlice into + // this vector. + std::vector input_index_; + + const GatherDimensionNumbers& dim_numbers_; + const Literal& gather_indices_; +}; + +// This functor computes the contribution of the window indices in an output +// index to an input index. That is, given an output index I it picks out the +// output window indices in I and expands it into a window index into the input +// shape. +class OutputWindowIndexToInputIndex { + public: + // The constructor does some setup work that is amortized across all + // iterations. + explicit OutputWindowIndexToInputIndex( + const GatherDimensionNumbers& dim_numbers, const Shape& input_shape, + const Shape& output_shape) { + std::vector window_index_to_output_index; + int64 output_index_count = 0; + for (int64 i = 0; i < output_shape.dimensions_size(); i++) { + if (c_binary_search(dim_numbers.output_window_dims(), i)) { + window_index_to_output_index.push_back(output_index_count++); + } else { + output_index_count++; + } + } + + int64 window_dim_count = 0; + for (int64 i = 0; i < input_shape.dimensions_size(); i++) { + if (c_binary_search(dim_numbers.elided_window_dims(), i)) { + input_dim_value_to_output_index_.push_back(-1); + } else { + input_dim_value_to_output_index_.push_back( + window_index_to_output_index[window_dim_count++]); + } + } + + input_index_.resize(input_shape.dimensions_size()); + } + + // Returns the contribution of the window indices to the input index + // corresponding to output_index. See gather_inner_loop_body. + // + // This is conceptually a stateless transformation from output_index to the + // window input index, but instead of allocating memory to represent the + // gather input index on every invocation we reuse the same storage for the + // result (input_index_), mutating it in place. + // + // This returns an arrayslice into memory owned by the class. + StatusOr> operator()(ArraySlice output_index) { + PropagateOutputIndexWindowDimsToInputIndex(output_index); + return ArraySlice(input_index_); + } + + private: + // Propagates window dimensions from the output index to input_index_ by + // mutating input_index_ in place. + void PropagateOutputIndexWindowDimsToInputIndex( + ArraySlice output_index) { + for (int64 i = 0, e = input_index_.size(); i < e; i++) { + if (input_dim_value_to_output_index_[i] != -1) { + input_index_[i] = output_index[input_dim_value_to_output_index_[i]]; + } + + // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] + // remains 0, as set by the constructor. + } + } + + // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of + // the input index from the output index. See + // PropagateOutputIndexToInputIndex. + std::vector input_dim_value_to_output_index_; + + // The result computed by this functor. operator() returns an ArraySlice into + // this vector. + std::vector input_index_; +}; + +// Rehapes the gather indices input to have a trailing degenerate `1` dimension +// if necessary. Hands over the ownership of the newly created literal (if +// there is one) to `reshaped_gather_indices`. +static StatusOr> ReshapedGatherIndices( + int64 index_vector_dim, const Literal& gather_indices, + std::unique_ptr* reshaped_gather_indices) { + if (gather_indices.shape().dimensions_size() != index_vector_dim) { + return std::cref(gather_indices); + } + + std::vector new_shape(gather_indices.shape().dimensions().begin(), + gather_indices.shape().dimensions().end()); + new_shape.push_back(1); + TF_ASSIGN_OR_RETURN(*reshaped_gather_indices, + gather_indices.Reshape(new_shape)); + return std::cref(**reshaped_gather_indices); +} + +Status HloEvaluator::HandleGather(HloInstruction* gather) { + std::unique_ptr result = Literal::CreateFromShape(gather->shape()); + const Shape& shape = gather->shape(); + const GatherDimensionNumbers& dim_numbers = + gather->gather_dimension_numbers(); + const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0)); + std::unique_ptr reshaped_gather_indices; + TF_ASSIGN_OR_RETURN( + const Literal& gather_indices, + ReshapedGatherIndices(dim_numbers.index_vector_dim(), + GetEvaluatedLiteralFor(gather->operand(1)), + &reshaped_gather_indices)); + + // We iterate over the gather dimensions in the output shape in an outer loop + // nest, and iterate over the window dimensions in the output shape in an + // inner loop nest. + + ShapeUtil::IndexIterationSpace gather_indices_iteration_space = + IterationSpaceForOutputGatherIndices(shape, dim_numbers); + ShapeUtil::IndexIterationSpace window_indices_iteration_space = + IterationSpaceForOutputWindowIndices( + shape.dimensions_size(), gather->gather_window_bounds(), dim_numbers); + + // Scratch buffers that hold an index in the output shape and the + // corresponding index in the input shape. + std::vector input_index(operand.shape().dimensions_size()); + std::vector output_index(gather->shape().dimensions_size()); + + OutputGatherIndexToInputIndex output_gather_index_to_input_index( + &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), + /*output_shape=*/shape, &gather_indices); + OutputWindowIndexToInputIndex output_window_index_to_input_index( + gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), + /*output_shape=*/shape); + + const Shape& operand_shape = operand.shape(); + + auto gather_inner_loop_body = + [&](ArraySlice output_window_index, + ArraySlice input_gather_index, + ArraySlice output_gather_index) -> StatusOr { + TF_ASSIGN_OR_RETURN( + ArraySlice input_window_index, + output_window_index_to_input_index(output_window_index)); + for (int i = 0, e = output_index.size(); i < e; i++) { + output_index[i] = output_gather_index[i] + output_window_index[i]; + DCHECK_LT(output_index[i], shape.dimensions(i)); + } + for (int i = 0, e = input_index.size(); i < e; i++) { + // TODO(b/74360564): We should implement whatever out of bounds behavior + // we decide for dynamic-slice here as well. + input_index[i] = (input_gather_index[i] + input_window_index[i]) % + operand_shape.dimensions(i); + if (input_index[i] < 0) { + input_index[i] += operand_shape.dimensions(i); + } + } + TF_RETURN_IF_ERROR( + result->CopyElementFrom(operand, input_index, output_index)); + return true; + }; + + auto gather_outer_loop_body = + [&](ArraySlice output_gather_index) -> StatusOr { + TF_ASSIGN_OR_RETURN( + ArraySlice input_gather_index, + output_gather_index_to_input_index(output_gather_index)); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + shape, window_indices_iteration_space, + std::bind(gather_inner_loop_body, std::placeholders::_1, + input_gather_index, output_gather_index))); + return true; + }; + + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + shape, gather_indices_iteration_space, gather_outer_loop_body)); + evaluated_[gather] = std::move(result); + return Status::OK(); +} + +Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { + const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0)); + + TF_RET_CHECK(broadcast->dimensions().size() == + ShapeUtil::Rank(operand.shape())) + << "broadcast dimensions is of size: " << broadcast->dimensions().size() + << " and rank of operand_to_broadcast is: " + << ShapeUtil::Rank(operand.shape()); + // Checks that operand's dimensions are the same as the broadcast's + // dimensions along the dimensions to be broadcasted. + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == + operand.shape().dimensions(i)); + } + + TF_ASSIGN_OR_RETURN( + evaluated_[broadcast], + operand.Broadcast(broadcast->shape(), broadcast->dimensions())); + + return Status::OK(); +} + +Status HloEvaluator::HandleGenerateToken(HloInstruction* token) { + // Literals cannot represent a TOKEN shape so just create an empty tuple as + // the "result" of the kGenerateToken operation. + // TODO(b/109929053): Add support for TOKENs in Literals. + evaluated_[token] = Literal::MakeTuple({}); + return Status::OK(); +} + Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const auto result_shape = get_tuple_element->shape(); const int64 index = get_tuple_element->tuple_index(); @@ -2432,6 +948,137 @@ Status HloEvaluator::HandleCopy(HloInstruction* copy) { return Status::OK(); } +Status HloEvaluator::HandleCall(HloInstruction* call) { + auto* computation = call->to_apply(); + auto operands = call->operands(); + + std::vector arg_literals; + arg_literals.reserve(operands.size()); + for (auto operand : operands) { + const Literal& arg_literal = GetEvaluatedLiteralFor(operand); + arg_literals.push_back(&arg_literal); + } + + HloEvaluator embedded_evaluator; + std::unique_ptr result = + embedded_evaluator.Evaluate(*computation, arg_literals) + .ConsumeValueOrDie(); + + evaluated_[call] = std::move(result); + return Status::OK(); +} + +Status HloEvaluator::HandleFusion(HloInstruction* fusion) { + HloModuleConfig config; + // Attach cloned computation to an empty HLO module so the existing ones are + // not modified. + HloModule empty_hlo_module("EmptyModuleForFusion", config); + HloCloneContext context(&empty_hlo_module); + auto cloned_fused_computation = + fusion->fused_instructions_computation()->Clone( + /*suffix=*/"clone_with_layout", &context); + for (auto* instruction : cloned_fused_computation->instructions()) { + LayoutUtil::SetToDefaultLayout(instruction->mutable_shape()); + } + auto readded_computation = + empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation)); + + auto operands = fusion->operands(); + std::vector arg_literals; + arg_literals.reserve(operands.size()); + for (auto operand : operands) { + const Literal& arg_literal = GetEvaluatedLiteralFor(operand); + arg_literals.push_back(&arg_literal); + } + + HloEvaluator embedded_evaluator; + std::unique_ptr result = + embedded_evaluator + .Evaluate(*readded_computation, arg_literals) + .ConsumeValueOrDie(); + + evaluated_[fusion] = std::move(result); + return Status::OK(); +} + +Status HloEvaluator::HandleConditional(HloInstruction* conditional) { + const auto& pred = GetEvaluatedLiteralFor(conditional->operand(0)); + const auto& true_computation_arg = + GetEvaluatedLiteralFor(conditional->operand(1)); + const auto& false_computation_arg = + GetEvaluatedLiteralFor(conditional->operand(2)); + + auto* true_computation = conditional->true_computation(); + auto* false_computation = conditional->false_computation(); + + HloEvaluator embedded_evaluator; + std::unique_ptr result; + if (pred.Get({})) { + result = embedded_evaluator + .Evaluate(*true_computation, + {&true_computation_arg}) + .ConsumeValueOrDie(); + } else { + result = embedded_evaluator + .Evaluate(*false_computation, + {&false_computation_arg}) + .ConsumeValueOrDie(); + } + + evaluated_[conditional] = std::move(result); + return Status::OK(); +} + +Status HloEvaluator::HandleSelect(HloInstruction* select) { + const auto& pred = GetEvaluatedLiteralFor(select->operand(0)); + const auto& on_true = GetEvaluatedLiteralFor(select->operand(1)); + const auto& on_false = GetEvaluatedLiteralFor(select->operand(2)); + + // If predicate is of scalar type, no element-wise selection would be needed. + // This would also handle output array of tuple types as the DefaultAction + // would go through the HloEvaluatorTypedVisitor which doesn't handle tuples. + if (ShapeUtil::IsScalar(pred.shape())) { + if (pred.Get({})) { + evaluated_[select] = on_true.CloneToUnique(); + } else { + evaluated_[select] = on_false.CloneToUnique(); + } + return Status::OK(); + } + + return DefaultAction(select); +} + +Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { + HloComputation* cond_comp = while_hlo->while_condition(); + HloComputation* body_comp = while_hlo->while_body(); + // Initialize the loop carried valued with the input to the While instruction. + auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique(); + bool keep_going = true; + int64 iteration_count = 0; + HloEvaluator cond_evaluator(max_loop_iterations_); + HloEvaluator loop_body_evaluator(max_loop_iterations_); + while (keep_going) { + if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) { + return InvalidArgument("Loop %s exceeded loop iteration limit (%lld).", + while_hlo->name().c_str(), max_loop_iterations_); + } + TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate( + *cond_comp, {lcv.get()})); + keep_going = cond_val->GetFirstElement(); + if (keep_going) { + TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate( + *body_comp, {lcv.get()})); + VLOG(3) << "Loop iteration result: " << body_val->ToString(); + lcv = std::move(body_val); + cond_evaluator.ResetVisitStates(); + loop_body_evaluator.ResetVisitStates(); + } + } + evaluated_[while_hlo] = std::move(lcv); + return Status::OK(); +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); return Status::OK(); @@ -2445,28 +1092,27 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) { // Explicit instantiation of templatized Evaluate* methods. // -template StatusOr> HloEvaluator::Evaluate< - const Literal*>(const HloModule& module, - tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate(const HloModule& module, + ArraySlice arg_literals); template StatusOr> HloEvaluator::Evaluate>( - const HloModule& module, - tensorflow::gtl::ArraySlice> arg_literals); + const HloModule& module, ArraySlice> arg_literals); -template StatusOr> HloEvaluator::Evaluate< - const Literal*>(const HloComputation& computation, - tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate(const HloComputation& computation, + ArraySlice arg_literals); template StatusOr> HloEvaluator::Evaluate>( const HloComputation& computation, - tensorflow::gtl::ArraySlice> arg_literals); + ArraySlice> arg_literals); -template StatusOr> HloEvaluator::Evaluate< - const Literal*>(HloInstruction* instruction, - tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate(HloInstruction* instruction, + ArraySlice arg_literals); template StatusOr> HloEvaluator::Evaluate>( HloInstruction* instruction, - tensorflow::gtl::ArraySlice> arg_literals); + ArraySlice> arg_literals); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 3b2b697e492a78a06a4e5ae6bf056ff8676f2ff5..fc2fc9437b238a2e519401b2b121dfbef070e2dc 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -36,7 +37,10 @@ namespace xla { // This class is not thread-safe. class HloEvaluator : public DfsHloVisitorWithDefault { public: - HloEvaluator(); + // Only evaluate up to max_loop_iterations per while-loop execution if + // specified. + explicit HloEvaluator(int64 max_loop_iterations = -1); + // Evaluates an HLO module and an array of pointers to literals. // Returns the evaluated result as a literal if successful. // Precondition: The indices of arg_literals correspond to the parameter @@ -105,20 +109,23 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const std::unordered_map& substitutions); + StatusOr> EvaluateElementwiseBinaryOp( + HloOpcode opcode, const Literal& lhs, const Literal& rhs); + + StatusOr> EvaluateElementwiseUnaryOp( + HloOpcode opcode, const Literal& operand); + protected: - // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting - // literal type of each evaluated Handle* method of a TypedVisitor. - // There are however a few notable exceptions to this rule, notably: - // - HandleCompare and HandleIsFinite: where the resulting literal type is - // always boolean. - // These operations are handled outside of the parent HloEvaluator handlers - // instead of from within TypedVisitor. + // Make HloEvaluatorTypedVisitor a friend because it is logically part of this + // class. // - // Type params: - // - ReturnT: The type of input and output of each operation. - // - ElementwiseT: The type in which internal computation are done. - template - class TypedVisitor; + // A straightforward implementation would be to make it a nested class + // declared and defined in hlo_evaluator.cc. Instead HloEvaluatorTypedVisitor + // lives as a separate class with its own header because its template gets + // instantiated many times and we want to use extern templates to shard out + // the compilation of those instantiations across multiple cc files. + template + friend class HloEvaluatorTypedVisitor; // Wraps around instruction handling to infer types before dispatching to // the corresponding typed Visitor. @@ -149,11 +156,26 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; + Status HandleGather(HloInstruction* gather) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleCopy(HloInstruction* copy) override; - private: + Status HandleConditional(HloInstruction* conditional) override; + + Status HandleCall(HloInstruction* call) override; + + Status HandleFusion(HloInstruction* fusion) override; + + Status HandleWhile(HloInstruction* while_hlo) override; + + Status HandleSelect(HloInstruction* select) override; + + Status HandleBroadcast(HloInstruction* broadcast) override; + + Status HandleGenerateToken(HloInstruction* token) override; + // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. @@ -168,14 +190,6 @@ class HloEvaluator : public DfsHloVisitorWithDefault { return *(it->second); } - // Map from a primitive type to its associated (templated) DfsHloVisitor. - // Note: the hash function here is only needed because current gcc std::hash - // does not specialize for enum types. This should however be fixed in the - // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5 - tensorflow::gtl::FlatMap, - std::hash> - typed_visitors_; - // Tracks the HLO instruction and its evaluated literal result. // TODO(b/35950897): have better memory management here to free instructions // that are no longer a parent for any other subsequent instruction in @@ -184,12 +198,50 @@ class HloEvaluator : public DfsHloVisitorWithDefault { tensorflow::gtl::FlatMap> evaluated_; + private: + template + static StatusOr> ElementWiseUnaryOpImpl( + HloInstruction* instruction, + const std::function& unary_op, + const Literal& operand_literal) { + const auto shape = instruction->shape(); + const auto* operand = instruction->operand(0); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!ShapeUtil::SameDimensions(shape, operand->shape())) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(operand->shape()).c_str()); + } + + auto result = MakeUnique(shape); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return unary_op(operand_literal.Get(multi_index)); + })); + return std::move(result); + } + + // Map from a primitive type to its associated (templated) DfsHloVisitor. + // Note: the hash function here is only needed because current gcc std::hash + // does not specialize for enum types. This should however be fixed in the + // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5 + tensorflow::gtl::FlatMap, + std::hash> + typed_visitors_; + // Caches pointers to input literals, assuming they are in post-order. // Literals are not owned by this class, and they must outlive the lifetime of // each invocation to the Evaluate* method. // Must be cleared for each evaluation. std::vector arg_literals_; + // Max loop iterations to execute with no maximum if negative. + int64 max_loop_iterations_; + TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); }; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 97765d65909cee192f65069777f8f195081603b2..72eb9930e92c340ab9f42cd563c27507623b2ba7 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -81,9 +82,9 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, auto element_type = expected->shape().element_type(); if (element_type == F32 || element_type == F64) { ErrorSpec error(aabs); - LiteralTestUtil::ExpectNear(*expected, *result, error); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error)); } else { - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } } @@ -99,7 +100,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, std::unique_ptr result = Evaluate(); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } bool use_bfloat16_; @@ -128,7 +129,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { auto expected = Literal::CreateR2({{0, 4}, {2, 4}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { @@ -149,7 +150,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto expected = Literal::CreateR2({{0, 0}, {1, 1}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs select @@ -174,7 +175,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { auto expected = Literal::CreateR2({{2, 5}, {0, 4}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -261,13 +262,13 @@ TEST_P(HloEvaluatorTest, DoesCosR2) { auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = Literal::CreateR2({{1, -1}, {-1, 1}}); TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand), - use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20); + use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } TEST_P(HloEvaluatorTest, DoesSinR2) { auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = Literal::CreateR2({{0, 0}, {0, 0}}); TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), - use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20); + use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } TEST_P(HloEvaluatorTest, DoesNotR2) { auto operand = @@ -306,7 +307,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { auto expected = Literal::CreateR2({{4, -16}, {-196, 12}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies Reshape operation is correctly evaluated. @@ -314,7 +315,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { HloComputation::Builder b(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - LiteralTestUtil::CreateRandomLiteral( + Literal::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->CloneToUnique(); HloInstruction* literal_instruction = @@ -332,7 +333,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { result->EachCell( [&](tensorflow::gtl::ArraySlice indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - EXPECT_NEAR(value, literal_clone->Get(rindexes), 0x1.0P-5); + EXPECT_NEAR(value, literal_clone->Get(rindexes), 0.031250); }); } @@ -350,7 +351,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { std::unique_ptr result = Evaluate({}); - LiteralTestUtil::ExpectEqual(*result, *output_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); } TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { @@ -369,7 +370,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { std::unique_ptr result = Evaluate({}); - LiteralTestUtil::ExpectEqual(*result, *output_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); } TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { @@ -391,7 +392,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { auto expected = Literal::CreateR2({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { @@ -412,7 +413,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR1({100, 200}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { @@ -431,7 +432,7 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { std::unique_ptr result = Evaluate(); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { @@ -451,7 +452,7 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { std::unique_ptr result = Evaluate(); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } PaddingConfig CreatePaddingConfig( @@ -489,7 +490,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { auto expected = Literal::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { @@ -524,7 +525,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { auto expected = Literal::CreateR4FromArray4D(*expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, NegativePadding2D) { @@ -566,7 +567,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 4) = 2.718f; auto expected = Literal::CreateR2FromArray2D(*expected_array); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(0x1.0P-5)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250))); } TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { @@ -605,7 +606,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { auto expected_array = MakeUnique>(0, 9); auto expected = Literal::CreateR2FromArray2D(*expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank1) { @@ -650,7 +651,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // clang-format on auto expected = Literal::CreateR2FromArray2D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DotRank1AndRank2) { @@ -687,7 +688,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { auto expected = Literal::CreateR1({22.f, 28.f}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank2) { @@ -736,7 +737,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { }); auto expected = Literal::CreateR2FromArray2D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, SimpleConv1D) { @@ -784,7 +785,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { Array3D expected_array = {{{11.f, 18.f, 9.f}}}; auto expected = Literal::CreateR3FromArray3D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { @@ -826,7 +827,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -846,7 +847,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { // clang-format on auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { @@ -926,7 +927,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { auto expected = Literal::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { @@ -1003,7 +1004,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { auto expected = Literal::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { @@ -1045,7 +1046,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -1066,7 +1067,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { })); auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { @@ -1108,7 +1109,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -1130,7 +1131,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { })); auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, @@ -1179,7 +1180,7 @@ TEST_P(HloEvaluatorTest, *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -1202,9 +1203,83 @@ TEST_P(HloEvaluatorTest, })); auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } +class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; + +// Tests that Reduce doesn't lose precision when adding many numbers (because +// it accumulates its result in a double). +TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { + HloComputation::Builder b(TestName()); + + constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 + std::vector v(kNumElements, 1.0f); + HloInstruction* arg_instruction = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1(v))); + HloInstruction* init_value = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + + HloComputation::Builder add_computation("add"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + add_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); + auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + + HloInstruction* reduce_instruction = b.AddInstruction( + HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value, + /*dimensions_to_reduce=*/{0}, add_func)); + module().AddEntryComputation(b.Build()); + + HloEvaluator hlo_eval; + std::unique_ptr result = + hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); + LiteralTestUtil::ExpectR0Equal(kNumElements, *result); +} + +// Reducing many numbers should be fast because it doesn't create +// intermediate Literals; the microbenchmark should finish in < 1 msec. +void BM_ReducePrecisely(int num_iters) { + tensorflow::testing::StopTiming(); + HloComputation::Builder b("BM_ReducePrecisely"); + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + HloModule module("BM_ReducePrecisely", config); + + constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 + std::vector v(kNumElements, 1.0f); + HloInstruction* arg_instruction = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1(v))); + auto init_value = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + + HloComputation::Builder add_computation("add"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + add_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); + auto add_func = module.AddEmbeddedComputation(add_computation.Build()); + + HloInstruction* reduce_instruction = b.AddInstruction( + HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value, + /*dimensions_to_reduce=*/{0}, add_func)); + module.AddEntryComputation(b.Build()); + + HloEvaluator hlo_eval; + tensorflow::testing::StartTiming(); + hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); + tensorflow::testing::StopTiming(); +} + +BENCHMARK(BM_ReducePrecisely); + TEST_P(HloEvaluatorTest, ReduceAdd) { HloComputation::Builder b(TestName()); @@ -1244,7 +1319,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { auto expected = Literal::CreateR1({6, 18}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ReduceWindowMax) { @@ -1295,7 +1370,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{6, 7}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd) { @@ -1352,7 +1427,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{1, 3, 5}, {5, 11, 13}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { @@ -1415,7 +1490,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { std::vector output_dims = {4, 3, 3, 3, 4, 4}; std::unique_ptr result_literal = Literal::CreateFullWithDescendingLayout(output_dims, 8.0f); - LiteralTestUtil::ExpectEqual(*result_literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result)); } TEST_P(HloEvaluatorTest, StridedSlice) { @@ -1448,7 +1523,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { {19}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DynamicSlice) { @@ -1481,7 +1556,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { {6, 7, 8}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies that the HloEvaluator's implementation goes along with existing @@ -1516,7 +1591,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { {6, 7, 8}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { @@ -1552,7 +1627,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { {5, -6, -7}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, SetAndGetTuples) { @@ -1587,7 +1662,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { {5, 6, 7}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { @@ -1628,7 +1703,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { result_inner_literal.get(), }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Reverse) { @@ -1681,7 +1756,7 @@ TEST_P(HloEvaluatorTest, Reverse) { }); // clang-format on - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { @@ -1701,8 +1776,8 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { add, {{param0, Literal::CreateR1({1, 2, 3, 4}).get()}, {square, Literal::CreateR1({10, 20, 30, 40}).get()}}); TF_ASSERT_OK(result.status()); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({11, 22, 33, 44}), - *result.ValueOrDie()); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); } // Check that EvaluateWithSubstitutions works if one of the operands to the op @@ -1725,8 +1800,250 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { auto result = evaluator.EvaluateWithSubstitutions( add, {{square, Literal::CreateR1({10, 20, 30, 40}).get()}}); TF_ASSERT_OK(result.status()); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({11, 22, 33, 44}), - *result.ValueOrDie()); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { + const char* hlo_text = R"( +HloModule TensorFlowGatherV1 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[2,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 3} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}), + *Evaluate({operand.get(), gather_indices.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { + const char* hlo_text = R"( +HloModule TensorFlowGatherV2 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[3,2] gather(operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3, 1} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::CreateR2({{1, 3}, {4, 6}, {7, 9}}), + *Evaluate({operand.get(), gather_indices.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { + const char* hlo_text = R"( +HloModule TensorFlowGatherMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,3,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 2}, {2, 1}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::CreateR3( + {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), + *Evaluate({operand.get(), gather_indices.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { + const char* hlo_text = R"( +HloModule TensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1,2} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-1, 1}, {-4, 4}}), + *Evaluate({operand.get(), gather_indices.get()}))); +} + +TEST_P(HloEvaluatorTest, + EvaluateGather_TensorFlowGatherNdNonDefaultIndexVectorDim) { + const char* hlo_text = R"( +HloModule TensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1,2} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-2, 2}, {-1, 1}}), + *Evaluate({operand.get(), gather_indices.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { + const char* hlo_text = R"( +HloModule DynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[1,1] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{5}}), + *Evaluate({operand.get(), gather_indices.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { + const char* hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,1,1] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{2, 1}, {1, 1}}); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR3({{{8}}, {{5}}}), + *Evaluate({operand.get(), gather_indices.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { + const char* hlo_text = R"( +HloModule TensorFlowGatherV1 + +ENTRY main { + operand = s32[3,0] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[2,0] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 0} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = Literal::CreateR2({{}, {}, {}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{}, {}}), + *Evaluate({operand.get(), gather_indices.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { + const string hlo_text = R"( +HloModule GatherXd + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[2,2,1] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1} +} +)"; + ParseAndVerifyModule(hlo_text); + + std::unique_ptr operand = Literal::CreateR1({0, 1, 2}); + std::unique_ptr gather_indices = + Literal::CreateR3({{{0}, {1}}, {{2}, {1}}}); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{0, 1}, {2, 1}}), + *Evaluate({operand.get(), gather_indices.get()}))); +} + +// Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise comparison with 2 bfloat16 operands. +TEST_P(HloEvaluatorTest, DoesCompareBF16) { + // lhs >= rhs + auto lhs = Literal::CreateR2( + {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)}, + {bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}}); + auto rhs = Literal::CreateR2( + {{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)}, + {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}}); + auto expected = + Literal::CreateR2({{false, true, true}, {false, true, true}}); + TestBinaryOp(HloOpcode::kGe, std::move(expected), std::move(lhs), + std::move(rhs)); } INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h new file mode 100644 index 0000000000000000000000000000000000000000..13f46407e33e36bdbef4c9032630101d6c18268f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -0,0 +1,2142 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/gtl/optional.h" + +namespace xla { + +// TODO(b/79274244): We'd like these type traits to live inside of +// HloEvaluatorTypedVisitor so they don't pollute namespace xla, but that +// crashes clang in the frontend. +// +// Anyway this is relatively safe as-is because hlo_evaluator_typed_visitor.h is +// a "private" header that's not exposed outside of hlo_evaluator.cc. +template +using is_complex_t = std::is_same; +template +using is_complex64_t = std::is_same; + +// Templated DfsHloVisitor for use by HloEvaluator. +// +// Typically ReturnT here indicates the resulting literal type of each evaluated +// Handle* method of a TypedVisitor. There are however a few notable exceptions +// to this rule, notably: +// - HandleCompare and HandleIsFinite: where the resulting literal type is +// always boolean. +// These operations are handled outside of the parent HloEvaluator handlers +// instead of from within TypedVisitor. +// +// Type params: +// - ReturnT: The type of input and output of each operation. +// - ElementwiseT: The type in which internal computation are done. +// +// This a logically a private part of HloEvaluator. It lives in this header +// file rather than in hlo_evaluator.cc because we use extern templates and a +// bunch of independent cc files to speed up compiling the many instantiations +// of this class. +template +class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { + public: + explicit HloEvaluatorTypedVisitor(HloEvaluator* p) : parent_(p) {} + + // The following higher-order functions convert a function with ElementwiseT + // to a function with ReturnT. + std::function ConvertUnaryFunction( + const std::function& unary_op) { + return [&unary_op](ReturnT arg) { + return static_cast(unary_op(static_cast(arg))); + }; + } + std::function ConvertBinaryFunction( + const std::function& + binary_op) { + return [&binary_op](ReturnT arg1, ReturnT arg2) { + return static_cast(binary_op(static_cast(arg1), + static_cast(arg2))); + }; + } + std::function ConvertTernaryFunction( + const std::function& ternary_op) { + return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) { + return static_cast(ternary_op(static_cast(arg1), + static_cast(arg2), + static_cast(arg3))); + }; + } + + Status DefaultAction(HloInstruction* hlo_instruction) override { + return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", + HloOpcodeString(hlo_instruction->opcode()).c_str()); + } + + // TODO(b/35950897): many of the stl functions used in the handlers are not + // overloaded for every XLA primitive type. + + template ::value>::type* = + nullptr> + Status HandleAbs(HloInstruction* abs) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + return elem_operand; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + return std::abs(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs) { + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(abs->operand(0)); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[abs], + (HloEvaluator::ElementWiseUnaryOpImpl( + abs, [](NativeT elem_operand) { return std::abs(elem_operand); }, + operand_literal))); + + return Status::OK(); + } + + Status HandleAbs(HloInstruction* abs) override { + // If the operand is of C64 type, the return type of abs will be F32. + // However, ElementwiseT would still be the return type, F32, and thus + // specifying the ElementwiseT explicitly as C64 is needed below. + if (abs->operand(0)->shape().element_type() == C64) { + return HandleAbs(abs); + } + return HandleAbs(abs); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRound(HloInstruction* round) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[round], + ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) { + return std::round(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRound(HloInstruction* round) { + return InvalidArgument("Unsupported type for Round"); + } + + Status HandleRound(HloInstruction* round) override { + return HandleRound(round); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCeil(HloInstruction* ceil) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], + ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) { + return std::ceil(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCeil(HloInstruction* ceil) { + return InvalidArgument("Unsupported type for Ceil"); + } + + Status HandleCeil(HloInstruction* ceil) override { + return HandleCeil(ceil); + } + + Status HandleConvert(HloInstruction* convert) override { + const HloInstruction* operand = convert->operand(0); + TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + parent_->GetEvaluatedLiteralFor(operand).Convert( + convert->shape().element_type())); + + if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { + parent_->evaluated_[convert] = std::move(result); + } else { + parent_->evaluated_[convert] = + result->Relayout(convert->shape().layout()); + } + return Status::OK(); + } + + Status HandleBitcastConvert(HloInstruction* convert) override { + const HloInstruction* operand = convert->operand(0); + TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( + convert->shape().element_type())); + + if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { + parent_->evaluated_[convert] = std::move(result); + } else { + parent_->evaluated_[convert] = + result->Relayout(convert->shape().layout()); + } + return Status::OK(); + } + + Status HandleExp(HloInstruction* exp) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], + ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { + return std::exp(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleExpm1(HloInstruction* expm1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[expm1], + ElementWiseUnaryOp(expm1, [](ElementwiseT elem_operand) { + return std::expm1(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleExpm1(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Expm1"); + } + + Status HandleExpm1(HloInstruction* floor) override { + return HandleExpm1(floor); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleFloor(HloInstruction* floor) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[floor], + ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) { + return std::floor(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleFloor(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Floor"); + } + + Status HandleFloor(HloInstruction* floor) override { + return HandleFloor(floor); + } + + Status HandleLog(HloInstruction* log) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], + ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { + return std::log(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleLog1p(HloInstruction* expm1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[expm1], + ElementWiseUnaryOp(expm1, [](ElementwiseT elem_operand) { + return std::log1p(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleLog1p(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Log1p"); + } + + Status HandleLog1p(HloInstruction* floor) override { + return HandleLog1p(floor); + } + + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return ~elem_operand; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return !elem_operand; + })); + return Status::OK(); + } + + template ::value>::type* = + nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return !elem_operand; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + return InvalidArgument("Unsupported type for Not"); + } + + Status HandleNot(HloInstruction* not_) override { + return HandleNot(not_); + } + + template ::value && + !std::is_floating_point::value>::type* = nullptr> + Status HandleNegate(HloInstruction* negate) { + using type = typename std::make_unsigned::type; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[negate], + ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) { + return NativeT(-type(elem_operand)); + })); + return Status::OK(); + } + + template ::value || + std::is_floating_point::value>::type* = nullptr> + Status HandleNegate(HloInstruction* negate) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[negate], + ElementWiseUnaryOp( + negate, [](ElementwiseT elem_operand) { return -elem_operand; })); + return Status::OK(); + } + + Status HandleNegate(HloInstruction* negate) override { + return HandleNegate(negate); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { + return (ElementwiseT(0) < elem_operand) - + (elem_operand < ElementwiseT(0)); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { + auto abs_val = std::abs(elem_operand); + return 0 == abs_val ? ElementwiseT(0) + : elem_operand / abs_val; + })); + return Status::OK(); + } + + Status HandleSign(HloInstruction* sign) override { + return HandleSign(sign); + } + + template ::value>::type* = nullptr> + Status HandleAtan2(HloInstruction* atan2) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2], + ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return std::atan2(lhs_elem, rhs_elem); + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleAtan2(HloInstruction* atan2) { + return InvalidArgument("Unsupported type for Atan2"); + } + + Status HandleAtan2(HloInstruction* atan2) override { + return HandleAtan2(atan2); + } + + Status HandleTanh(HloInstruction* tanh) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], + ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) { + return std::tanh(elem_operand); + })); + return Status::OK(); + } + + template ::value && + !std::is_floating_point::value>::type* = nullptr> + Status HandleMultiply(HloInstruction* multiply) { + using type = typename std::make_unsigned::type; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[multiply], + ElementWiseBinaryOp(multiply, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return NativeT(type(lhs_elem) * type(rhs_elem)); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + std::is_floating_point::value || + is_complex_t::value>::type* = nullptr> + Status HandleMultiply(HloInstruction* multiply) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[multiply], + ElementWiseBinaryOp(multiply, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return lhs_elem * rhs_elem; + })); + return Status::OK(); + } + + Status HandleMultiply(HloInstruction* multiply) override { + return HandleMultiply(multiply); + } + + Status HandleSubtract(HloInstruction* subtract) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[subtract], + ElementWiseBinaryOp(subtract, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return lhs_elem - rhs_elem; + })); + return Status::OK(); + } + + Status HandleAdd(HloInstruction* add) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], + ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return lhs_elem + rhs_elem; + })); + return Status::OK(); + } + + Status HandleDivide(HloInstruction* divide) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], + ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return lhs_elem / rhs_elem; + })); + return Status::OK(); + } + + template ::value>::type* = + nullptr> + Status HandleMaximum(HloInstruction* maximum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[maximum], + ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { + return std::max(lhs, rhs); + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleMaximum(HloInstruction* maximum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[maximum], + ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { + return ((lhs >= rhs) || std::isnan(lhs)) ? lhs : rhs; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleMaximum(HloInstruction* maximum) { + return InvalidArgument("Unsupported type for Maximum"); + } + + Status HandleMaximum(HloInstruction* maximum) override { + return HandleMaximum(maximum); + } + + template ::value>::type* = + nullptr> + Status HandleMinimum(HloInstruction* minimum) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::min(lhs_el, rhs_el); + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleMinimum(HloInstruction* minimum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return ((lhs_el <= rhs_el) || std::isnan(lhs_el)) ? lhs_el : rhs_el; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleMinimum(HloInstruction* minimum) { + return InvalidArgument("Unsupported type for Minimum"); + } + + Status HandleMinimum(HloInstruction* minimum) override { + return HandleMinimum(minimum); + } + + Status HandlePower(HloInstruction* power) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[power], + ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::pow(lhs_el, rhs_el); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRemainder(HloInstruction* remainder) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], + ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::fmod(lhs_el, rhs_el); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRemainder(HloInstruction* remainder) { + return InvalidArgument("Unsupported type for Remainder"); + } + + Status HandleRemainder(HloInstruction* remainder) override { + return HandleRemainder(remainder); + } + + template ::value>::type* = + nullptr> + Status HandleAnd(HloInstruction* and_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[and_], + ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el & rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleAnd(HloInstruction* and_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[and_], + ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el && rhs_el; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAnd(HloInstruction* and_) { + return InvalidArgument("Unsupported type for And"); + } + + Status HandleAnd(HloInstruction* and_) override { + return HandleAnd(and_); + } + + template ::value>::type* = + nullptr> + Status HandleOr(HloInstruction* or_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[or_], + ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el | rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleOr(HloInstruction* or_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[or_], + ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el || rhs_el; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleOr(HloInstruction* or_) { + return InvalidArgument("Unsupported type for Or"); + } + + Status HandleOr(HloInstruction* or_) override { + return HandleOr(or_); + } + + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftLeft(HloInstruction* shl) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shl], + ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { + return IsShiftOutOfBounds(rhs_elem) ? 0 + : (lhs_elem << rhs_elem); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftLeft(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftLeft"); + } + + Status HandleShiftLeft(HloInstruction* shl) override { + return HandleShiftLeft(shl); + } + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftRightArithmetic(HloInstruction* shr) { + typedef typename std::make_signed::type SignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + SignedT lhs_signed = static_cast(lhs_elem); + if (IsShiftOutOfBounds(rhs_elem)) { + return lhs_signed < 0 ? static_cast(-1) : 0; + } else { + return lhs_signed >> rhs_elem; + } + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftRightArithmetic(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftRightArithmetic"); + } + + Status HandleShiftRightArithmetic(HloInstruction* shra) override { + return HandleShiftRightArithmetic(shra); + } + + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftRightLogical(HloInstruction* shr) { + typedef typename std::make_unsigned::type UnsignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + // If shift amount is greater than the number of bits, then return 0. + if (IsShiftOutOfBounds(rhs_elem)) { + return static_cast(0); + } + return static_cast(static_cast(lhs_elem) >> + rhs_elem); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftRightLogical(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftRightLogical"); + } + + Status HandleShiftRightLogical(HloInstruction* shrl) override { + return HandleShiftRightLogical(shrl); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleClamp(HloInstruction* clamp) { + std::function + clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { + return std::fmin(high, std::fmax(value, low)); + }; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[clamp], + ElementwiseTernaryOp(clamp, + std::move(ConvertTernaryFunction(clamp_op)))); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleClamp(HloInstruction*) { + return InvalidArgument("Unsupported type for Clamp"); + } + + Status HandleClamp(HloInstruction* clamp) override { + return HandleClamp(clamp); + } + + Status HandleSelect(HloInstruction* select) override { + CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); + CHECK(!ShapeUtil::IsTuple(select->shape())); + std::function select_op = + [](bool pred, ReturnT on_true, ReturnT on_false) { + if (pred) { + return on_true; + } + return on_false; + }; + TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], + ElementwiseTernaryOp(select, std::move(select_op))); + return Status::OK(); + } + + Status HandleReverse(HloInstruction* reverse) override { + const auto result_shape = reverse->shape(); + const auto reverse_dimensions = reverse->dimensions(); + + auto operand = reverse->operand(0); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferReverseShape(operand->shape(), + reverse_dimensions)); + + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(result_shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + auto result = MakeUnique(result_shape); + + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice out_index) { + std::vector from_index(out_index.begin(), out_index.end()); + for (const int64 dim : reverse_dimensions) { + from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; + } + return operand_literal.Get(from_index); + })); + + parent_->evaluated_[reverse] = std::move(result); + return Status::OK(); + } + + Status HandleConvolution(HloInstruction* conv) override { + auto lhs = conv->operand(0); + auto rhs = conv->operand(1); + const auto& window = conv->window(); + const Shape& result_shape = conv->shape(); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); + + TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); + TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); + CHECK(ShapeUtil::IsArray(lhs_shape)); + CHECK(ShapeUtil::IsArray(rhs_shape)); + CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape)); + CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); + + const auto& dnums = conv->convolution_dimension_numbers(); + const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); + CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size()); + CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size()); + CHECK_GE(num_spatial_dims, 0); + CHECK_EQ(window.dimensions_size(), num_spatial_dims); + + const auto lhs_rank = ShapeUtil::Rank(lhs_shape); + const auto rhs_rank = ShapeUtil::Rank(rhs_shape); + + CHECK_EQ(num_spatial_dims + 2, lhs_rank); + CHECK_EQ(num_spatial_dims + 2, rhs_rank); + + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, + window, dnums)); + CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(result_shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + std::vector window_dimension_sizes; + for (auto i : dnums.kernel_spatial_dimensions()) { + window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i)); + } + + const Shape& window_shape = + ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes); + + DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape); + DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape); + + auto lhs_literal_data = lhs_literal.data(); + auto rhs_literal_data = rhs_literal.data(); + + auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, + &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, + rhs_literal_data]( + tensorflow::gtl::ArraySlice out_index) { + // Dimension number applicable for input (lhs). + const int64 input_batch_dim = dnums.input_batch_dimension(); + const int64 input_z_dim = dnums.input_feature_dimension(); + // Dimension number applicable for kernel (rhs). + const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); + const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); + // Dimension number applicable for output. + const int64 output_batch_dim = dnums.output_batch_dimension(); + const int64 output_z_dim = dnums.output_feature_dimension(); + + const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + + ElementwiseT result_val = static_cast(0); + DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), + 0); + + // Convolve input feature with kernel. + do { + for (int64 iz = 0; iz < z_size; ++iz) { + int64 lhs_linear_index = 0; + lhs_linear_index += out_index[output_batch_dim] * + lhs_dim_multipliers[input_batch_dim]; + lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; + + int64 rhs_linear_index = 0; + rhs_linear_index += out_index[output_z_dim] * + rhs_dim_multipliers[kernel_output_z_dim]; + rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim]; + + // Find corresponding spatial dimension index for input (lhs). + for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { + // Spatial dimension number for input (lhs) and output. + const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); + const int64 output_spatial_dim = + dnums.output_spatial_dimensions(ki); + + // Calculate lhs (input) index without taking base dilation into + // account. + const auto& window_dim = window.dimensions(ki); + const int64 undilated_index = + out_index[output_spatial_dim] * window_dim.stride() - + window_dim.padding_low() + + rhs_spatial_index[ki] * window_dim.window_dilation(); + // Skip if the lhs (input) index is to be dilated. As an + // optimization, skip this mod if there's no dilation. + if (window_dim.base_dilation() > 1 && + undilated_index % window_dim.base_dilation() != 0) { + goto cnt; + } + + // Calculate the actual lhs (input) index after dilation. As an + // optimization, skip this integer divide if there's no dilation. + int64 lhs_spatial_index; + if (window_dim.base_dilation() > 1) { + lhs_spatial_index = undilated_index / window_dim.base_dilation(); + } else { + lhs_spatial_index = undilated_index; + } + lhs_linear_index += + lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; + + // Skip if input index is not in bounds. + if (!(lhs_spatial_index >= 0 && + lhs_spatial_index < + lhs_shape.dimensions(input_spatial_dim))) { + goto cnt; + } + + rhs_linear_index += + (window_dim.window_reversal() + ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) + : rhs_spatial_index[ki]) * + rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; + } + + result_val += + static_cast(lhs_literal_data[lhs_linear_index]) * + static_cast(rhs_literal_data[rhs_linear_index]); + } + cnt : {} + } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); + + return static_cast(result_val); + }; + + auto result = MakeUnique(result_shape); + TF_RETURN_IF_ERROR(result->PopulateParallel(func)); + + parent_->evaluated_[conv] = std::move(result); + return Status::OK(); + } + + Status HandleDot(HloInstruction* dot) override { + auto lhs = dot->operand(0); + auto rhs = dot->operand(1); + CHECK(ShapeUtil::IsArray(dot->shape())); + CHECK(ShapeUtil::IsArray(lhs->shape())); + CHECK(ShapeUtil::IsArray(rhs->shape())); + + const auto& dnums = dot->dot_dimension_numbers(); + + const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); + const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); + + CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); + CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); + + // There must be 1 and only 1 Contracting dimension for lhs and rhs. + CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); + const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); + // Contracted dimension sizes must be the same. + CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension), + rhs->shape().dimensions(rhs_contracting_dimension)) + << "lhs contracted dimension: " + << lhs->shape().dimensions(lhs_contracting_dimension) + << " rhs contracted dimension: " + << rhs->shape().dimensions(rhs_contracting_dimension); + const int64 contracted_dimension_size = + lhs->shape().dimensions(lhs_contracting_dimension); + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + CHECK_EQ(dnums.lhs_batch_dimensions_size(), + dnums.rhs_batch_dimensions_size()); + + std::vector lhs_non_contracting_dims; + for (int64 i = 0; i < lhs_rank; i++) { + if (i != lhs_contracting_dimension) { + lhs_non_contracting_dims.push_back(i); + } + } + + std::vector rhs_non_batch_non_contracting_dims; + tensorflow::gtl::FlatSet batch_dims_set( + dnums.rhs_batch_dimensions().begin(), + dnums.rhs_batch_dimensions().end()); + for (int64 i = 0; i < rhs_rank; i++) { + if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { + rhs_non_batch_non_contracting_dims.push_back(i); + } + } + + const int64 batch_dim_size = dnums.lhs_batch_dimensions_size(); + const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size(); + + DimensionVector lhs_index(lhs_rank); + DimensionVector rhs_index(rhs_rank); + auto result = MakeUnique(dot->shape()); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice result_index) { + ElementwiseT result_val = static_cast(0); + + // Find the corresponding non-contracting indices for lhs and rhs. + // + // For `result_index`, its batch dimension, if exists, will be at the + // same dimension as the batch dimension of lhs and rhs. More + // specifically: + // - For lhs, the non-contracting dimensions, including the batch + // dimension have the same index as the `result_index`. + // - For rhs, the batch dimension is set seperately from other + // non-contracting dimensions, since these other non-contracting + // dimensions in rhs follow the non-contracting dimensions of lhs in + // the resulting index. + // + // As an example, for a resulting index: + // result_index [result_batch, result_x, result_y] + // the effecting lhs and rhs indices are: + // lhs [result_batch, lhs_non_contracting_dim, contracting_dim + // rhs [result_batch, contracting_dim, rhs_non_contracting_dim] + // `result_x` is only affected by the lhs_non_contracting_dim and + // likewise `result_y` only depends on rhs_non_contracting_dim. + // + // so we can look up the lhs and rhs indices by: + // + // lhs: + // batch index is the same as `result_batch`. + // non-contracting dimension is the same as + // result_index[lhs_non_contracting_dim] + // rhs: + // batch index: the same as `result_batch`. + // non-contracting dimension index: *not* the same as + // result_index[rhs_non_contractng_dim], since the + // non-contracting dimensions of lhs are included in the + // result_index first. Instead, the non_contracting_dim of rhs must + // be calculated as following: + // lhs_non_contracting_dimensions_size + + // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1 + // + // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is + // the index offset to the result_index that only depends on + // the non_batch and non-contracting dimensions of rhs. -1 at the + // end translates size to index. + for (auto i : lhs_non_contracting_dims) { + lhs_index[i] = result_index[i]; + } + for (auto i : dnums.rhs_batch_dimensions()) { + rhs_index[i] = result_index[i]; + } + for (auto i : rhs_non_batch_non_contracting_dims) { + const int64 rhs_non_batch_non_contracting_dim = + lhs_non_contracting_size + (i - batch_dim_size) - 1; + rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim]; + } + + // Accumulates resulting product along the contracted dimension. + for (int64 i = 0; i < contracted_dimension_size; ++i) { + lhs_index[lhs_contracting_dimension] = i; + rhs_index[rhs_contracting_dimension] = i; + + result_val += + static_cast(lhs_literal.Get(lhs_index)) * + static_cast(rhs_literal.Get(rhs_index)); + } + + return static_cast(result_val); + })); + + parent_->evaluated_[dot] = std::move(result); + return Status::OK(); + } + + Status HandlePad(HloInstruction* pad) override { + CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape())); + // Padding value must be scalar. + CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); + CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), + pad->padding_config().dimensions_size()); + + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferPadShape( + /*operand_shape=*/pad->operand(0)->shape(), + /*padding_value_shape=*/pad->operand(1)->shape(), + /*padding_config=*/pad->padding_config())); + CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(pad->shape()) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + // Create new HLO of padded shape with padding value. + ReturnT scalar = + parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); + auto result = MakeUnique(pad->shape()); + TF_RETURN_IF_ERROR(result->Populate( + [&scalar](tensorflow::gtl::ArraySlice multi_index) { + return scalar; + })); + + const Literal& evaluated_operand = + parent_->GetEvaluatedLiteralFor(pad->operand(0)); + + std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), + 0); + std::vector target_index(ShapeUtil::Rank(result->shape()), 0); + + // Loop through each element of the operand, assign them to the + // corresponding index of the resulting padded literal. + const PaddingConfig& pad_config = pad->padding_config(); + + auto func = [&](tensorflow::gtl::ArraySlice input_index) { + for (auto i = 0; i < input_index.size(); ++i) { + // Interior padding occurs logically before edge padding, so in the case + // of negative edge padding elements are removed from the + // interior-padded operand. + target_index[i] = + pad_config.dimensions(i).edge_padding_low() + + input_index[i] * (pad_config.dimensions(i).interior_padding() + 1); + + // Account for negative low and high padding: skip assignment if the + // any target index is out of range. + if (!(target_index[i] >= 0 && + target_index[i] < pad->shape().dimensions(i))) { + return true; + } + } + result->Set(target_index, + evaluated_operand.Get(input_index)); + return true; + }; + + std::vector zero_base(evaluated_operand.shape().dimensions_size(), + 0); + std::vector step(evaluated_operand.shape().dimensions_size(), 1); + + ShapeUtil::ForEachIndex( + evaluated_operand.shape(), zero_base, + AsInt64Slice(evaluated_operand.shape().dimensions()), step, func); + + parent_->evaluated_[pad] = std::move(result); + return Status::OK(); + } + + Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { + auto operand = dynamic_slice->operand(0); + auto start_indices = dynamic_slice->operand(1); + auto result_shape = dynamic_slice->shape(); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferDynamicSliceShape( + operand->shape(), start_indices->shape(), + dynamic_slice->dynamic_slice_sizes())); + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(result_shape) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + TF_RET_CHECK( + primitive_util::IsIntegralType(start_indices->shape().element_type())); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& start_indices_literal = + parent_->GetEvaluatedLiteralFor(start_indices); + + switch (start_indices->shape().element_type()) { + case S32: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_slice], + DynamicSlice(operand_literal, start_indices_literal, + result_shape)); + } break; + case S64: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_slice], + DynamicSlice(operand_literal, start_indices_literal, + result_shape)); + } break; + case U32: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_slice], + DynamicSlice(operand_literal, start_indices_literal, + result_shape)); + } break; + case U64: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_slice], + DynamicSlice(operand_literal, start_indices_literal, + result_shape)); + } break; + default: + LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for " + "start_indices: " + << PrimitiveType_Name(start_indices->shape().element_type()); + } + + return Status::OK(); + } + + Status HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) override { + auto operand = dynamic_update_slice->operand(0); + auto update = dynamic_update_slice->operand(1); + auto start_indices = dynamic_update_slice->operand(2); + auto result_shape = dynamic_update_slice->shape(); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferDynamicUpdateSliceShape( + operand->shape(), update->shape(), start_indices->shape())); + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(result_shape) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + TF_RET_CHECK( + primitive_util::IsIntegralType(start_indices->shape().element_type())); + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape())); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update); + const Literal& start_indices_literal = + parent_->GetEvaluatedLiteralFor(start_indices); + + switch (start_indices->shape().element_type()) { + case S32: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_update_slice], + DynamicUpdateSlice(operand_literal, update_literal, + start_indices_literal)); + } break; + case S64: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_update_slice], + DynamicUpdateSlice(operand_literal, update_literal, + start_indices_literal)); + } break; + case U32: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_update_slice], + DynamicUpdateSlice(operand_literal, update_literal, + start_indices_literal)); + } break; + case U64: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_update_slice], + DynamicUpdateSlice(operand_literal, update_literal, + start_indices_literal)); + } break; + default: + LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for " + "start_indices: " + << PrimitiveType_Name(start_indices->shape().element_type()); + } + + return Status::OK(); + } + + template + StatusOr> MapImpl(HloInstruction* map) { + auto operands = map->operands(); + HloComputation* computation = map->to_apply(); + + auto result = MakeUnique(map->shape()); + + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + std::vector> arg_literals; + arg_literals.reserve(operands.size()); + + // Construct scalar literal parameters to be passed to the map + // computation. + for (auto operand : operands) { + const Literal& arg_literal = + parent_->GetEvaluatedLiteralFor(operand); + + auto curr_val = arg_literal.Get(multi_index); + auto curr_val_literal = Literal::CreateR0(curr_val); + + arg_literals.push_back(std::move(curr_val_literal)); + } + + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate>(*computation, + arg_literals) + .ConsumeValueOrDie(); + // Clear visit states so that the we can use the evaluate again on + // the same computation. + embedded_evaluator.ResetVisitStates(); + + return computed_result->Get({}); + })); + return std::move(result); + } + + Status HandleMap(HloInstruction* map) override { + switch (map->operand(0)->shape().element_type()) { + case PRED: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U8: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S8: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case F16: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], + MapImpl(map)); + break; + } + case F32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case F64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case C64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + default: + LOG(FATAL) << "HandleMap: unhandled primitive type for " + "input operand: " + << PrimitiveType_Name( + map->operand(0)->shape().element_type()); + } + + return Status::OK(); + } + + Status HandleReduce(HloInstruction* reduce) override { + auto arg = reduce->operand(0); + auto init_value = reduce->operand(1); + tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + HloComputation* function = reduce->to_apply(); + TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) == + ShapeUtil::Rank(arg->shape()) - dimensions.size()); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferReduceShape( + /*arg=*/arg->shape(), + /*init_value=*/init_value->shape(), + /*dimensions_to_reduce=*/dimensions, + /*to_apply=*/function->ComputeProgramShape())); + TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape()) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); + VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString(); + const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value); + VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString(); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get({}); + + const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions()); + std::vector arg_dim_steps(arg_dimensions.size()); + std::vector arg_dim_counts(arg_dimensions.size()); + for (const int64 dim : dimensions) { + arg_dim_steps[dim] = 1; + arg_dim_counts[dim] = arg_dimensions[dim]; + } + + // Map each dimension in the result to a dimension in arg that isn't + // being reduced. + std::vector result_to_arg_index; + for (int64 i = 0; i < arg_dimensions.size(); ++i) { + if (arg_dim_steps[i] == 0) { + result_to_arg_index.push_back(i); + } + } + + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + auto result = MakeUnique(reduce->shape()); + // For each resulting dimension, calculate and assign computed value. + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + ReturnT result_val = init_scalar; + + std::vector base(arg_dimensions.size()); + for (int64 i = 0; i < multi_index.size(); ++i) { + base[result_to_arg_index[i]] = multi_index[i]; + } + + // When the reduction is addition of floats, accumulate in a double + // for better precision. Also, avoid creating Literals for the + // intermediate results; it's much faster. + if (ShapeUtil::ElementIsFloating(init_literal.shape()) && + IsScalarAdd(function)) { + double computed_result = 0; + auto func = [&](tensorflow::gtl::ArraySlice input_index) { + computed_result += arg_literal.Get(input_index); + return true; + }; + ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, + arg_dim_steps, func); + return static_cast(computed_result); + } + auto func = [&](tensorflow::gtl::ArraySlice input_index) { + auto curr_val = arg_literal.Get(input_index); + + // Evaluate computation with specified literal operands. + auto curr_val_literal = Literal::CreateR0(curr_val); + auto result_val_literal = Literal::CreateR0(result_val); + + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate( + *function, + {result_val_literal.get(), curr_val_literal.get()}) + .ConsumeValueOrDie(); + // Clear visit states so that we can use the evaluator again on + // the same computation. + embedded_evaluator.ResetVisitStates(); + // Assign computed result to result_val. + result_val = computed_result->Get({}); + return true; + }; + // Computes one element of the result, reducing all dimensions that + // contribute to that element. + ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, + arg_dim_steps, func); + return result_val; + })); + + parent_->evaluated_[reduce] = std::move(result); + return Status::OK(); + } + + bool IsScalarAdd(HloComputation* computation) { + HloInstruction* instruction = computation->root_instruction(); + if (instruction->opcode() == HloOpcode::kAdd && + computation->num_parameters() == 2) { + const HloInstruction* lhs = instruction->operand(0); + const HloInstruction* rhs = instruction->operand(1); + return lhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(lhs->shape()) && + rhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs; + } + return false; + } + + Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { + auto operand = select_and_scatter->operand(0); + auto source = select_and_scatter->operand(1); + const Window& window = select_and_scatter->window(); + + const Literal& init_literal = + parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2)); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get({}); + + auto result = MakeUnique(select_and_scatter->shape()); + + // Initialize result array with the init value. + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice output_index) { + return init_scalar; + })); + + std::vector window_dimension_sizes; + for (const auto& window_dimension : window.dimensions()) { + window_dimension_sizes.push_back(window_dimension.size()); + } + const Shape window_shape = ShapeUtil::MakeShape( + operand->shape().element_type(), window_dimension_sizes); + + HloComputation* select = select_and_scatter->select(); + HloComputation* scatter = select_and_scatter->scatter(); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source); + + int64 rank = ShapeUtil::Rank(operand_literal.shape()); + + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + DimensionVector source_index(rank, 0); + + // Used in the dual IterateThroughWindow lambdas below. Hoisted to avoid + // dynamic memory allocations. + auto curr_val_literal = Literal::CreateR0(ReturnT()); + auto selected_val_literal = Literal::CreateR0(ReturnT()); + auto source_literal_scatter = Literal::CreateR0(ReturnT()); + auto scattered_literal = Literal::CreateR0(ReturnT()); + do { + // For each element in `source`, we place a window in `operand`. For each + // window placement, we iterate inside the window twice: + // + // 1. Find the selected index by applying `select` function to all + // elements. E.g., If the `select` function is GreaterEqual, the first + // iteration through the window finds the biggest value and returns its + // index. + // + // 2. Using the selected index, scatter value from `source` to result. We + // do this by iterating through the window, and compare each index with + // the selected index. + tensorflow::gtl::optional selected_val; + tensorflow::gtl::optional> selected_index; + + IterateThroughWindow( + window_shape, window, operand_literal.shape(), source_index, + [&](const std::vector& operand_index) { + auto curr_val = operand_literal.Get(operand_index); + if (!selected_val) { + selected_val = curr_val; + selected_index = operand_index; + } + curr_val_literal->Set({}, curr_val); + selected_val_literal->Set({}, *selected_val); + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate( + *select, + {selected_val_literal.get(), curr_val_literal.get()}) + .ConsumeValueOrDie(); + bool selected = !computed_result->Get({}); + if (selected) { + selected_val = curr_val; + selected_index = operand_index; + } + embedded_evaluator.ResetVisitStates(); + }); + + IterateThroughWindow( + window_shape, window, operand_literal.shape(), source_index, + [&](const std::vector& operand_index) { + if (std::equal(operand_index.begin(), operand_index.end(), + selected_index->begin())) { + auto source = source_literal.Get(source_index); + auto scattered = result->Get(operand_index); + source_literal_scatter->Set({}, source); + scattered_literal->Set({}, scattered); + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate(*scatter, + {source_literal_scatter.get(), + scattered_literal.get()}) + .ConsumeValueOrDie(); + result->Set(operand_index, computed_result->Get({})); + // Clear visit states so that the we can use the evaluator again + // on the same computation. + embedded_evaluator.ResetVisitStates(); + } + }); + } while (IndexUtil::BumpIndices(source->shape(), &source_index)); + + parent_->evaluated_[select_and_scatter] = std::move(result); + return Status::OK(); + } + + Status HandleReduceWindow(HloInstruction* reduce_window) override { + auto operand = reduce_window->operand(0); + const Window& window = reduce_window->window(); + HloComputation* function = reduce_window->to_apply(); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferReduceWindowShape( + /*operand_shape=*/reduce_window->operand(0)->shape(), + /*init_value=*/reduce_window->operand(1)->shape(), window, + /*to_apply_shape=*/function->ComputeProgramShape())); + TF_RET_CHECK( + ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) + << "return shape is set to: " + << ShapeUtil::HumanStringWithLayout(reduce_window->shape()) + << "but is inferred to be: " + << ShapeUtil::HumanStringWithLayout(inferred_return_shape); + + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(reduce_window->operand(0)); + VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString(); + const Literal& init_literal = + parent_->GetEvaluatedLiteralFor(reduce_window->operand(1)); + VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString(); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get({}); + + // Creates a Shape object from window, for iteration below. + std::vector window_dimension_sizes; + for (const auto& window_dimension : window.dimensions()) { + window_dimension_sizes.push_back(window_dimension.size()); + } + const Shape window_shape = ShapeUtil::MakeShape( + operand->shape().element_type(), window_dimension_sizes); + + DimensionVector window_index(window.dimensions_size()); + DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); + + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + auto result = MakeUnique(reduce_window->shape()); + // For each resulting dimension, calculate and assign computed value. + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice output_index) { + ReturnT result_val = init_scalar; + + std::fill(window_index.begin(), window_index.end(), 0); + std::fill(operand_index.begin(), operand_index.end(), 0); + + IterateThroughWindow( + window_shape, window, operand_literal.shape(), output_index, + [&](const std::vector& operand_index) { + auto curr_val = operand_literal.Get(operand_index); + + // Evaluate computation with specified literal operands. + const auto curr_val_literal = + Literal::CreateR0(curr_val); + const auto result_val_literal = + Literal::CreateR0(result_val); + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate( + *function, + {result_val_literal.get(), curr_val_literal.get()}) + .ConsumeValueOrDie(); + + // Clear visit states so that the we can use the evaluate again + // on the same computation. + embedded_evaluator.ResetVisitStates(); + + result_val = computed_result->Get({}); + }); + + return result_val; + })); + + parent_->evaluated_[reduce_window] = std::move(result); + return Status::OK(); + } + + Status HandleSlice(HloInstruction* slice) override { + auto operand = slice->operand(0); + const Shape& shape = slice->shape(); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferSliceShape( + operand->shape(), slice->slice_starts(), + slice->slice_limits(), slice->slice_strides())); + TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const int64 rank = ShapeUtil::Rank(operand->shape()); + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + auto func = [&](tensorflow::gtl::ArraySlice out_index) { + DimensionVector operand_index(rank); + for (int64 i = 0; i < rank; ++i) { + operand_index[i] = + slice->slice_starts(i) + out_index[i] * slice->slice_strides(i); + } + return operand_literal.Get(operand_index); + }; + + auto result = Literal::CreateFromDimensions( + shape.element_type(), AsInt64Slice(shape.dimensions())); + TF_RETURN_IF_ERROR(result->Populate(func)); + parent_->evaluated_[slice] = std::move(result); + return Status::OK(); + } + + // Enable CLZ only for int32, uint32, int64 and uint64. + template < + typename NativeT, + typename std::enable_if< + (std::is_floating_point::value || + std::is_integral::value || is_complex_t::value) && + !(std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value)>::type* = nullptr> + Status HandleClz(HloInstruction* clz) { + return InvalidArgument("Unsupported type for Clz"); + } + + template ::value || + std::is_same::value>::type* = nullptr> + Status HandleClz(HloInstruction* clz) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], + ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { + return 31 - tensorflow::Log2Floor(elem_operand); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = nullptr> + Status HandleClz(HloInstruction* clz) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], + ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { + return 63 - tensorflow::Log2Floor64(elem_operand); + })); + return Status::OK(); + } + + Status HandleClz(HloInstruction* clz) override { + return HandleClz(clz); + } + + template ::value>::type* = nullptr> + Status HandleSin(HloInstruction* sin) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], + ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { + return std::sin(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleSin(HloInstruction* sin) { + return InvalidArgument("Unsupported type for Sin"); + } + + Status HandleSin(HloInstruction* sin) override { + return HandleSin(sin); + } + + template ::value>::type* = nullptr> + Status HandleCos(HloInstruction* cos) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], + ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { + return std::cos(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleCos(HloInstruction* cos) { + return InvalidArgument("Unsupported type for Cos"); + } + + Status HandleCos(HloInstruction* cos) override { + return HandleCos(cos); + } + + template ::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[reduce_precision], + ElementWiseUnaryOp(reduce_precision, [reduce_precision]( + ElementwiseT elem) { + uint32_t value_as_int = tensorflow::bit_cast(elem); + const uint32_t mantissa_bits = reduce_precision->mantissa_bits(); + const uint32_t exponent_bits = reduce_precision->exponent_bits(); + + // Code is based on the CPU/GPU implementation in LLVM-emitting code. + // + // Bits in float type: + // mantissa : bits [0:22] + // exponent : bits [23:30] + // sign : bits [31] + if (mantissa_bits < 23) { + const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); + + // Compute rounding bias for round-to-nearest with ties to even. + // This is equal to a base value of 0111... plus one bit if the last + // remaining mantissa bit is 1. + const uint32_t base_rounding_bias = + (last_mantissa_bit_mask >> 1) - 1; + const uint32_t x_last_mantissa_bit = + (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits); + const uint32_t x_rounding_bias = + x_last_mantissa_bit + base_rounding_bias; + + // Add rounding bias, and mask out truncated bits. Note that the + // case where adding the rounding bias overflows into the exponent + // bits is correct; the non-masked mantissa bits will all be zero, + // and the exponent will be incremented by one. + const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); + value_as_int = value_as_int + x_rounding_bias; + value_as_int = value_as_int & truncation_mask; + } + if (exponent_bits < 8) { + // Masks for f32 values. + const uint32_t f32_sign_bit_mask = 1u << 31; + const uint32_t f32_exp_bits_mask = 0xffu << 23; + + // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the + // most- significant bit -- is equal to 1.0f for all exponent sizes. + // Adding 2^(n-1)-1 to this gives us the highest non-infinite + // exponent for a bit- size of n, and subtracting 2^(n-1)-1 from + // this gives us the lowest' exponent (corresponding to 0.0f). + // + // Thus, the f32 exponent corresponding to the highest non-infinite + // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 + // exponent corresponding to the lowest exponent for a bit size of n + // is (2^7-1) - 2^(n-1)-1. + // + // Note that we have already checked that exponents_bits >= 1. + const uint32_t f32_exponent_bias = (1 << 7) - 1; + const uint32_t reduced_exponent_bias = + (1 << (exponent_bits - 1)) - 1; + const uint32_t reduced_max_exponent = + f32_exponent_bias + reduced_exponent_bias; + const uint32_t reduced_min_exponent = + f32_exponent_bias - reduced_exponent_bias; + + // Do we overflow or underflow? + const uint32_t x_exponent = value_as_int & f32_exp_bits_mask; + const bool x_overflows = x_exponent > (reduced_max_exponent << 23); + const bool x_underflows = + x_exponent <= (reduced_min_exponent << 23); + + // Compute appropriately-signed values of zero and infinity. + const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask; + const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask; + + // Force to zero or infinity if overflow or underflow. (Note that + // this truncates all denormal values to zero, rather than rounding + // them.) + value_as_int = x_overflows ? x_signed_inf : value_as_int; + value_as_int = x_underflows ? x_signed_zero : value_as_int; + } + + float reduced_result = tensorflow::bit_cast(value_as_int); + if (std::isnan(elem)) { + reduced_result = mantissa_bits > 0 + ? elem + : std::numeric_limits::infinity(); + } + return reduced_result; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + return InvalidArgument("Double not supported for reduce precision"); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + return InvalidArgument("Unsupported type for reduce precision"); + } + + Status HandleReducePrecision(HloInstruction* reduce_precision) override { + return HandleReducePrecision(reduce_precision); + } + + private: + // Creates a vector of multipliers which can be used to create a linear index + // into shape. + // + // Given the multidimensional index {i1, ..., iN} and + // M = MakeDimMultipliers(shape), the corresponding linear index LI is simply + // + // LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N]. + // + // This lets you calculate LI given the multidimensional indices in any order. + static DimensionVector MakeDimMultipliers(const Shape& shape) { + DimensionVector v(ShapeUtil::Rank(shape)); + int64 scale = 1; + for (auto dim : LayoutUtil::MinorToMajor(shape)) { + v[dim] = scale; + scale *= shape.dimensions(dim); + } + return v; + } + + // For one particular placement of a window in a base shape (the placement is + // represented as `window_count_index`), iterates inside the window. + // Translates the window index into base index. If the base index is within + // bound, call `f` with the base index. + static void IterateThroughWindow( + const Shape& window_shape, const Window& window, const Shape& base_shape, + const tensorflow::gtl::ArraySlice& window_count_index, + const std::function&)>& f) { + const int64 rank = ShapeUtil::Rank(base_shape); + DimensionVector window_index(rank); + std::fill(window_index.begin(), window_index.end(), 0); + do { + std::vector base_index(rank); + bool out_of_bound = false; + for (int64 i = 0; i < rank; ++i) { + base_index[i] = window_count_index[i] * window.dimensions(i).stride() + + window_index[i] - window.dimensions(i).padding_low(); + if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { + out_of_bound = true; + break; + } + } + if (!out_of_bound) { + f(base_index); + } + } while (IndexUtil::BumpIndices(window_shape, &window_index)); + } + + template + StatusOr> DynamicSlice( + const Literal& operand_literal, const Literal& start_indices_literal, + const Shape& result_shape) { + auto start_indices_typed = start_indices_literal.data(); + std::vector start(start_indices_typed.begin(), + start_indices_typed.end()); + + // Clamp the start indices so the slice is in-bounds w.r.t the operand. + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to officially document different behavior. + for (int64 i = 0; i < start.size(); ++i) { + start[i] = std::min( + std::max(int64{0}, start[i]), + operand_literal.shape().dimensions(i) - result_shape.dimensions(i)); + } + + std::vector operand_indices(start.size()); + auto result = MakeUnique(result_shape); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + for (int64 i = 0; i < operand_indices.size(); ++i) { + CHECK_GE(multi_index[i] + start[i], 0); + operand_indices[i] = multi_index[i] + start[i]; + } + + auto result = operand_literal.Get(operand_indices); + return result; + })); + + return std::move(result); + } + + template + StatusOr> DynamicUpdateSlice( + const Literal& operand_literal, const Literal& update_literal, + const Literal& start_indices_literal) { + auto result = operand_literal.CloneToUnique(); + auto start_indices_typed = start_indices_literal.data(); + const auto rank = ShapeUtil::Rank(result->shape()); + std::vector start(start_indices_typed.begin(), + start_indices_typed.end()); + // Clamp the update start indices so the slice is in-bounds w.r.t the + // operand. + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. + for (int64 i = 0; i < rank; ++i) { + start[i] = std::min( + std::max(0, start[i]), + result->shape().dimensions(i) - update_literal.shape().dimensions(i)); + } + std::vector result_index(rank, 0); + + auto func = [&](tensorflow::gtl::ArraySlice update_index) { + std::transform(update_index.begin(), update_index.end(), start.begin(), + result_index.begin(), std::plus()); + result->Set(result_index, + update_literal.Get(update_index)); + return true; + }; + + std::vector base(update_literal.shape().dimensions_size(), 0); + std::vector step(update_literal.shape().dimensions_size(), 1); + ShapeUtil::ForEachIndex(update_literal.shape(), base, + AsInt64Slice(update_literal.shape().dimensions()), + step, func); + + return std::move(result); + } + + StatusOr> ElementWiseUnaryOp( + HloInstruction* instruction, + const std::function& unary_op) { + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(instruction->operand(0)); + TF_ASSIGN_OR_RETURN( + auto result_literal, + (HloEvaluator::ElementWiseUnaryOpImpl( + instruction, ConvertUnaryFunction(unary_op), operand_literal))); + + return std::move(result_literal); + } + + StatusOr> ElementWiseBinaryOp( + HloInstruction* instruction, + const std::function& + binary_op) { + const auto shape = instruction->shape(); + const auto* lhs = instruction->operand(0); + const auto* rhs = instruction->operand(1); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast + // is removed. + if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s vs %s: ", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(lhs->shape()).c_str(), + ShapeUtil::HumanString(rhs->shape()).c_str()); + } + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + auto result = MakeUnique(shape); + + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return ConvertBinaryFunction(binary_op)( + lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); + return std::move(result); + } + + template + StatusOr> ElementwiseTernaryOp( + HloInstruction* instruction, + const std::function& ternary_op) { + const auto shape = instruction->shape(); + const auto* lhs = instruction->operand(0); + const auto* rhs = instruction->operand(1); + const auto* ehs = instruction->operand(2); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit + // broadcast is removed. + if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && + ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s vs %s vs %s: ", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(lhs->shape()).c_str(), + ShapeUtil::HumanString(rhs->shape()).c_str(), + ShapeUtil::HumanString(ehs->shape()).c_str()); + } + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); + + auto result = MakeUnique(shape); + + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return ternary_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index), + ehs_literal.Get(multi_index)); + })); + + return std::move(result); + } + + template + static bool IsShiftOutOfBounds(NativeT rhs) { + typedef typename std::make_unsigned::type UnsignedT; + UnsignedT lhs_size_unsigned = sizeof(NativeT) * CHAR_BIT; + UnsignedT rhs_unsigned = static_cast(rhs); + return rhs_unsigned >= lhs_size_unsigned; + } + + HloEvaluator* parent_; +}; + +// These extern templates prevent users of this class from implicitly +// instantiating it. We explicitly instantiate this class in the various +// hlo_evaluator_typed_visitor*.cc files. +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc new file mode 100644 index 0000000000000000000000000000000000000000..39c352dfb966af4ad9f1874d078b92dd2a321783 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc new file mode 100644 index 0000000000000000000000000000000000000000..289b40fa06d37b8f5b2705e7de2f479c4a30e89d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc new file mode 100644 index 0000000000000000000000000000000000000000..9cb4eb921fd3af566de5998a097423c90f0cb860 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e6252fbf8c24a7b79c7e656040a6be7be8d777f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc new file mode 100644 index 0000000000000000000000000000000000000000..ee793ae77b1b432daece31697ad436de1683bc08 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc new file mode 100644 index 0000000000000000000000000000000000000000..038d9d39e4a5881b9f0fb1d98732132aab3aaa2c --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc new file mode 100644 index 0000000000000000000000000000000000000000..b1952ca6193958eec49fd15297f73a6c6ac22b83 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc new file mode 100644 index 0000000000000000000000000000000000000000..0cbaffb40b7128fb6e99308fbc2b48e63a3d6fac --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc new file mode 100644 index 0000000000000000000000000000000000000000..6f4bf2a392b51abc4d37db4beab6d1ea2b0c4e3a --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc new file mode 100644 index 0000000000000000000000000000000000000000..10235447e0d266a6071097e38913c3856939509b --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc new file mode 100644 index 0000000000000000000000000000000000000000..8abeaa6ffca4409d2664de6f55850622e95bbc9d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc new file mode 100644 index 0000000000000000000000000000000000000000..6dabd1c176eabcf6656d6de9683bbf0131456d96 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index f0df93b61d29c1535d8a89fbd65e669de5b43729..c3ccbf0f0c75b569b49652807dea52faebdccc31 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -111,8 +111,8 @@ HloExecutionProfile::HloExecutionProfile( : hlo_profile_printer_data_(*hlo_profile_printer_data), hlo_profile_index_map_(*hlo_profile_index_map), profile_counters_( - /*count*/ hlo_profile_index_map_.total_count(), - /*value*/ 0) {} + /*count=*/hlo_profile_index_map_.total_count(), + /*value=*/0) {} void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken) { diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h index 6fb91b9bef9d1df82b8806ce79cc147823edeb3d..be989846ef5cd2645da88ac9bbfea9534dd47821 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.h +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -88,7 +88,7 @@ std::unique_ptr CreateHloProfilePrinterData( // down how much time each HLO took. class HloExecutionProfile { public: - using DeviceDescription = perftools::gputools::DeviceDescription; + using DeviceDescription = se::DeviceDescription; HloExecutionProfile(const HloProfilePrinterData* hlo_profile_printer_data, const HloProfileIndexMap* hlo_profile_index_map); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index a0cb28246d3be541e798e85552436f64a3521f22..eba80c0f199f6224f4b46ac19af482c713585154 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -15,53 +15,33 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { -class HloExecutionProfileTest : public HloTestBase { - protected: - static constexpr int64 kInstructionCyclesIndex = 0; - static constexpr int64 kInstructionNameIndex = 19; -}; +using tensorflow::strings::StrCat; +using ::testing::AllOf; +using ::testing::ContainsRegex; -// Splits `lines` into a sequence of lines delimited by newlines and then split -// each of those lines into a sequence of words delimited by spaces. Filter out -// empty words. -std::vector> SplitIntoLinesAndWords( - tensorflow::StringPiece lines) { - std::vector> result; - for (const string& line : tensorflow::str_util::Split(lines, '\n')) { - std::vector words; - for (const string& word : tensorflow::str_util::Split(line, ' ')) { - if (!word.empty()) { - words.push_back(word); - } - } - result.push_back(std::move(words)); - } - - return result; -} +class HloExecutionProfileTest : public HloTestBase {}; TEST_F(HloExecutionProfileTest, Basic) { - std::unique_ptr hlo_module = CreateNewModule(); - - HloComputation::Builder builder(TestName()); + auto hlo_module = ParseHloString(R"( + HloModule test_module + ENTRY entry_computation { + lhs = f32[30,30]{1,0} parameter(0) + rhs = f32[30,30]{1,0} parameter(1) + add = f32[30,30]{1,0} add(lhs, rhs) + ROOT dot = f32[30,30]{1,0} dot(lhs, add), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })") + .ValueOrDie(); + const HloInstruction* dot_instruction = + hlo_module->entry_computation()->root_instruction(); + const HloInstruction* add_instruction = dot_instruction->operand(1); Shape shape = ShapeUtil::MakeShape(F32, {30, 30}); - HloInstruction* param_lhs = - builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs")); - HloInstruction* param_rhs = - builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs")); - HloInstruction* add_instruction = - builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kAdd, param_lhs, param_rhs)); - HloInstruction* dot_instruction = - builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, param_lhs, add_instruction)); - - hlo_module->AddEntryComputation(builder.Build()); auto shape_size_function = [&](const Shape& shape) { const int64 pointer_size = 8; @@ -84,20 +64,12 @@ TEST_F(HloExecutionProfileTest, Basic) { execution_profile.SetCyclesTakenBy(add_instruction, add_cycles); execution_profile.SetCyclesTakenBy(dot_instruction, dot_cycles); - string rendered_profile = execution_profile.ToString( - backend().default_stream_executor()->GetDeviceDescription()); - std::vector> lines_and_words = - SplitIntoLinesAndWords(rendered_profile); - ASSERT_EQ(lines_and_words.size(), 8); - - const std::vector& line_2 = lines_and_words[2]; - const std::vector& line_3 = lines_and_words[3]; - - EXPECT_EQ(line_2[kInstructionCyclesIndex], std::to_string(dot_cycles)); - EXPECT_EQ(line_2[kInstructionNameIndex], '%' + dot_instruction->name()); - - EXPECT_EQ(line_3[kInstructionCyclesIndex], std::to_string(add_cycles)); - EXPECT_EQ(line_3[kInstructionNameIndex], '%' + add_instruction->name()); + EXPECT_THAT(execution_profile.ToString( + backend().default_stream_executor()->GetDeviceDescription()), + AllOf(ContainsRegex(StrCat(dot_cycles, R"(\b.*%)", + dot_instruction->name())), + ContainsRegex(StrCat(add_cycles, R"(\b.*%)", + add_instruction->name())))); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 2861fec39ef0c92fdfbcee04584f9bd36d3cb4d8..28fc6c4209bcc14d890f28d6a9935c55b76586c0 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -28,6 +28,8 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -157,52 +159,60 @@ enum ColorScheme { kDashedBorder, }; +// Graphviz attributes/colors that make up a color scheme. +struct NodeColors { + const char* style; + const char* fill_color; + const char* stroke_color; + const char* font_color; +}; + +NodeColors NodeColorsForScheme(ColorScheme color) { + switch (color) { + case kBlue: + return NodeColors{"filled", "#bbdefb", "#8aacc8", "black"}; + case kBrown: + return NodeColors{"filled", "#bcaaa4", "#8c7b75", "black"}; + case kDarkBlue: + return NodeColors{"filled", "#1565c0", "#003c8f", "white"}; + case kDarkGreen: + return NodeColors{"filled", "#2e7d32", "#005005", "white"}; + case kDarkRed: + return NodeColors{"filled", "#b71c1c", "#7f0000", "white"}; + case kGray: + return NodeColors{"filled", "#cfd8dc", "#9ea7aa", "black"}; + case kGreen: + return NodeColors{"filled", "#c8e6c9", "#97b498", "black"}; + case kOrange: + return NodeColors{"filled", "#ffe0b2", "#cbae82", "black"}; + case kPurple: + return NodeColors{"filled", "#e1bee7", "#af8eb5", "black"}; + case kRed: + return NodeColors{"filled", "#ffcdd2", "#cb9ca1", "black"}; + case kWhite: + return NodeColors{"filled", "white", "black", "black"}; + case kYellow: + return NodeColors{"filled", "#fff9c4", "#cbc693", "black"}; + case kDashedBorder: + // "filled,dashed" looks the same as "dashed", since we have a white + // background. But we use "filled,dashed" so that when you hover over + // any part of the node (not just the text inside the node), our css + // :hover rule is triggered. + return NodeColors{"filled,dashed", "white", "#757575", "#757575"}; + } +} + // Given a ColorScheme, returns an attribute string for a node of that color. // Sets the node's style and fill/stroke/text colors. // // Colors are from https://material.io/color. string NodeColorAttributes(ColorScheme color) { - using std::make_tuple; - - const char *style, *fill_color, *stroke_color, *font_color; - std::tie(style, fill_color, stroke_color, font_color) = [color] { - switch (color) { - case kBlue: - return make_tuple("filled", "#bbdefb", "#8aacc8", "black"); - case kBrown: - return make_tuple("filled", "#bcaaa4", "#8c7b75", "black"); - case kDarkBlue: - return make_tuple("filled", "#1565c0", "#003c8f", "white"); - case kDarkGreen: - return make_tuple("filled", "#2e7d32", "#005005", "white"); - case kDarkRed: - return make_tuple("filled", "#b71c1c", "#7f0000", "white"); - case kGray: - return make_tuple("filled", "#cfd8dc", "#9ea7aa", "black"); - case kGreen: - return make_tuple("filled", "#c8e6c9", "#97b498", "black"); - case kOrange: - return make_tuple("filled", "#ffe0b2", "#cbae82", "black"); - case kPurple: - return make_tuple("filled", "#e1bee7", "#af8eb5", "black"); - case kRed: - return make_tuple("filled", "#ffcdd2", "#cb9ca1", "black"); - case kWhite: - return make_tuple("filled", "white", "black", "black"); - case kYellow: - return make_tuple("filled", "#fff9c4", "#cbc693", "black"); - case kDashedBorder: - // "filled,dashed" looks the same as "dashed", since we have a white - // background. But we use "filled,dashed" so that when you hover over - // any part of the node (not just the text inside the node), our css - // :hover rule is triggered. - return make_tuple("filled,dashed", "white", "#757575", "#757575"); - } - }(); + NodeColors node_colors = NodeColorsForScheme(color); return Printf( - R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", style, - font_color, stroke_color, fill_color); + R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", + node_colors.style, node_colors.font_color, node_colors.stroke_color, + node_colors.fill_color); } // Replaces <> with <>, so that this string is safe(er) for use in a @@ -313,12 +323,12 @@ optional MatchTrivialComputation(const HloComputation* computation) { class HloDotDumper { public: HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, - const DebugOptions& debug_options, bool show_metadata, + const DebugOptions& debug_options, bool show_backend_config, const HloExecutionProfile* profile, NodeFilter filter) : computation_(computation), - label_(label.ToString()), + label_(std::string(label)), debug_options_(debug_options), - show_metadata_(show_metadata), + show_backend_config_(show_backend_config), profile_(profile), filter_(std::move(filter)) {} @@ -357,6 +367,7 @@ class HloDotDumper { string GetInstructionNodeShape(const HloInstruction* instr); string GetInstructionNodeLabel(const HloInstruction* instr); string GetInstructionNodeMetadata(const HloInstruction* instr); + string GetInstructionNodeBackendConfig(const HloInstruction* instr); string GetInstructionNodeExtraInfo(const HloInstruction* instr); string GetInstructionNodeInlinedOperands(const HloInstruction* instr); void AddInstructionIncomingEdges(const HloInstruction* instr); @@ -384,7 +395,7 @@ class HloDotDumper { const HloComputation* computation_; // never null const string label_; // overall name for the graph const DebugOptions& debug_options_; - const bool show_metadata_; + const bool show_backend_config_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -418,7 +429,8 @@ class HloDotDumper { // When coloring by sharding information, we track the sharding string // representation to color association, by round-robin the color schemes. - std::unordered_map sharding_colors_; + std::unordered_map + sharding_colors_; int64 next_shard_color_ = 0; }; @@ -580,15 +592,26 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, const HloInstruction* parent_instr) { VLOG(2) << "Dumping subcomputation " << subcomp->name(); - const char* computation_fmt = R"(subgraph %s { -%s -label = <%s>; -labelloc = t; -tooltip = " "; -%s -} // %s + // Add an edge from the subcomputation to its parent node. If subcomp + // belongs to a fusion node, it's drawn in place of the fusion instruction, + // so there's no need to link those. + if (parent_instr->opcode() != HloOpcode::kFusion) { + const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction()); + VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() + << " as " << next_edge_id_; + edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); + const char* edge_fmt = + R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; + edges_.push_back(Printf( + edge_fmt, InstructionId(from), InstructionId(parent_instr), + SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); + } -)"; + // Have we already dumped this subcomputation? If so, generating the edge + // linking it and parent_instr is all we want to do in this function. + if (cluster_ids_.find(subcomp) != cluster_ids_.end()) { + return ""; + } cluster_ids_[subcomp] = next_cluster_id_++; @@ -603,12 +626,26 @@ tooltip = " "; if (!extra_info.empty()) { StrAppend(&subcomp_label, "
", extra_info); } + string node_backend_config = GetInstructionNodeBackendConfig(parent_instr); + if (!node_backend_config.empty()) { + StrAppend(&subcomp_label, "
", node_backend_config); + } - // Subcomputation's fill/stroke color is light/dark red/gray, depending on - // whether or not the subcomputation's fusion node is highlighted. bool highlight = filter_.Highlight(parent_instr); - const char* fillcolor = highlight ? "#ffcdd2" : "#f5f5f5"; - const char* strokecolor = highlight ? "#b71c1c" : "#c2c2c2"; + const char* fillcolor; + const char* strokecolor; + if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) { + // Use the sharding color, if the node isn't highlighted. + NodeColors node_colors = + NodeColorsForScheme(GetInstructionColor(parent_instr)); + fillcolor = node_colors.fill_color; + strokecolor = node_colors.stroke_color; + } else { + // Subcomputation's fill/stroke color is light/dark red/gray, depending on + // whether or not the subcomputation's fusion node is highlighted. + fillcolor = highlight ? "#ffcdd2" : "#f5f5f5"; + strokecolor = highlight ? "#b71c1c" : "#c2c2c2"; + } style = Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")", fillcolor, strokecolor); @@ -621,25 +658,16 @@ tooltip = " "; string comp_body = DumpComputation(subcomp); - // Add an edge from the subcomputation to its parent node. If subcomp - // belongs to a fusion node, it's drawn in place of the fusion instruction, - // so there's no need to link those. - if (parent_instr->opcode() != HloOpcode::kFusion) { - const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction()); - VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() - << " as " << next_edge_id_; - edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); - const char* edge_fmt = - R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; - edges_.push_back(Printf( - edge_fmt, InstructionId(from), InstructionId(parent_instr), - SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); - } - - string computation = - Printf(computation_fmt, id, style, subcomp_label, comp_body, id); + const char* computation_fmt = R"(subgraph %s { +%s +label = <%s>; +labelloc = t; +tooltip = " "; +%s +} // %s - return computation; +)"; + return Printf(computation_fmt, id, style, subcomp_label, comp_body, id); } string HloDotDumper::DumpComputation(const HloComputation* comp) { @@ -697,11 +725,25 @@ string HloDotDumper::DumpRootTag() { to_id, node_body, node_shape, NodeColorAttributes(color)); } +static const HloConstantInstruction* TryGetFusionParameterConstant( + const HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) { + return nullptr; + } + const HloInstruction* fusion = instr->parent()->FusionInstruction(); + const HloInstruction* operand = fusion->operand(instr->parameter_number()); + return DynCast(operand); +} + bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { // If a node: // - // - is a tuple-shaped parameter, - // - is not a parameter to a fusion node, + // - is a parameter of a fusion node which is bound to a constant, + // + // or + // + // - is a tuple-shaped parameter, and + // - is not a parameter to a fusion node, and // - has at least kMinUsersToOmit users shown, and // - all of the shown users are get-tuple-elements, // @@ -709,6 +751,9 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { // // This helps us handle the common case where a while loop body has one big // tuple-shaped parameter. + if (TryGetFusionParameterConstant(instr) != nullptr) { + return true; + } const int kMinUsersToOmit = 3; return instr->opcode() == HloOpcode::kParameter && ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() && @@ -747,6 +792,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string node_shape = GetInstructionNodeShape(instr); string node_label = GetInstructionNodeLabel(instr); string node_metadata = GetInstructionNodeMetadata(instr); + string node_backend_config = GetInstructionNodeBackendConfig(instr); string extra_info = GetInstructionNodeExtraInfo(instr); string inlined_constants = GetInstructionNodeInlinedOperands(instr); string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); @@ -764,24 +810,32 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { } // Build the text that will be displayed inside the node. string node_body = node_label; - for (const string& s : - {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) { + for (const string& s : {trivial_subcomputation, node_backend_config, + extra_info, inlined_constants}) { if (!s.empty()) { StrAppend(&node_body, "
", s); } } - return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" + return Printf(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" "\n", - InstructionId(instr), node_body, node_shape, + InstructionId(instr), node_body, node_shape, node_metadata, NodeColorAttributes(color)); } string HloDotDumper::GetInstructionNodeInlinedOperands( const HloInstruction* instr) { - auto stringify_constant = [](const HloInstruction* constant) { + auto stringify_constant = [](const HloConstantInstruction* constant) { const auto& shape = constant->shape(); + // If the shape has a dimension of size zero, print it as e.g. + // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(), + // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which + // is just noise. + if (!ShapeUtil::IsTuple(shape) && ShapeUtil::HasZeroElements(shape)) { + return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape())); + } + // Print the literal value of constants with <= K elements. optional elem_count; if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) { @@ -797,7 +851,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // Otherwise, print e.g. "%constant.42 (s32[100])". string constant_name; - if (tensorflow::StringPiece(constant->name()).starts_with("constant")) { + if (tensorflow::str_util::StartsWith(constant->name(), "constant")) { constant_name = constant->name(); } else { constant_name = StrCat("constant ", constant->name()); @@ -806,29 +860,26 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( ShapeUtil::HumanString(constant->shape())); }; - // Special case: If instr is a parameter to a fusion node, check whether the - // corresponding operand to the fusion node is a constant. - if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { - const HloInstruction* fusion = instr->parent()->FusionInstruction(); - const HloInstruction* operand = fusion->operand(instr->parameter_number()); - if (operand->opcode() != HloOpcode::kConstant) { - return ""; - } - return StrCat("constant ", stringify_constant(operand)); - } - std::vector lines; for (int64 i = 0; i < instr->operand_count(); ++i) { const HloInstruction* operand = instr->operand(i); + const auto* constant_operand = DynCast(operand); optional operand_str; - if (operand->opcode() == HloOpcode::kConstant) { - operand_str = stringify_constant(operand); + if (constant_operand != nullptr) { + operand_str = stringify_constant(constant_operand); } else if (ShouldMergeIntoUsers(operand)) { - // Special case: If the operand is a parameter, use its parameter number - // rather than its name, because that's generally how people think of the - // node. + // Special case: If the operand is a parameter to a fusion node and it + // always has a constant value, display it like a regular constant. + // + // For other parameters, use the parameter number rather than the proper + // name, because that's generally how people think of the node. if (operand->opcode() == HloOpcode::kParameter) { - operand_str = Printf("Parameter %lld", operand->parameter_number()); + if (const HloConstantInstruction* constant = + TryGetFusionParameterConstant(operand)) { + operand_str = stringify_constant(constant); + } else { + operand_str = Printf("Parameter %lld", operand->parameter_number()); + } } else { operand_str = operand->name(); } @@ -850,24 +901,26 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { if (!instr->has_sharding()) { return kDashedBorder; } - string shard_str = instr->sharding().ToString(); - auto it = sharding_colors_.find(shard_str); + auto it = sharding_colors_.find(instr->sharding()); if (it != sharding_colors_.end()) { return it->second; } ColorScheme color = static_cast( kBlue + (next_shard_color_++ % (kDashedBorder - kBlue))); - sharding_colors_.emplace(shard_str, color); + sharding_colors_.emplace(instr->sharding(), color); return color; } const auto kParameterColor = kOrange; // Special case: If this instruction has a parameter merged into it, paint it - // the same color as a parameter. + // the same color as a parameter. Unless the merged-in parameter is a + // parameter to a fusion node that is bound to a constant -- these aren't + // "real" parameters from the user's perspective. if (std::any_of(instr->operands().begin(), instr->operands().end(), [&](const HloInstruction* operand) { return operand->opcode() == HloOpcode::kParameter && - ShouldMergeIntoUsers(operand); + ShouldMergeIntoUsers(operand) && + TryGetFusionParameterConstant(operand) == nullptr; })) { return kParameterColor; } @@ -883,12 +936,14 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kBitcastConvert: case HloOpcode::kCeil: case HloOpcode::kClamp: + case HloOpcode::kClz: case HloOpcode::kComplex: case HloOpcode::kConvert: case HloOpcode::kCos: case HloOpcode::kDivide: case HloOpcode::kEq: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: @@ -896,6 +951,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -927,6 +983,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kBitcast: case HloOpcode::kGetTupleElement: case HloOpcode::kTrace: + case HloOpcode::kGenerateToken: case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: @@ -938,7 +995,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { } return kGreen; case HloOpcode::kConcatenate: - case HloOpcode::kCopy: case HloOpcode::kDynamicSlice: case HloOpcode::kGather: case HloOpcode::kPad: @@ -960,6 +1016,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kWhite; } return kGreen; + case HloOpcode::kCopy: + // Emphasize copy nodes, which are either physical transposes (and thus + // significant), or copies of read-only buffers (and thus dead weight). + return kGreen; case HloOpcode::kConvolution: case HloOpcode::kDot: case HloOpcode::kFft: @@ -975,6 +1035,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kReduceWindow: case HloOpcode::kSelectAndScatter: return kPurple; + case HloOpcode::kDomain: case HloOpcode::kFusion: case HloOpcode::kMap: return kGray; @@ -1015,8 +1076,8 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { // The HLO instruction name contains usually the opcode, e.g. "%add.42" is // an add instruction. In this case we render just the name. - if (tensorflow::StringPiece(instr->name()) - .starts_with(HloOpcodeString(instr->opcode()))) { + if (tensorflow::str_util::StartsWith(instr->name(), + HloOpcodeString(instr->opcode()))) { return Printf("%s", HtmlLikeStringSanitize(instr->name())); } string extended_opcode = @@ -1030,10 +1091,6 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { } string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { - if (!show_metadata_) { - return ""; - } - std::vector lines; if (!instr->metadata().op_name().empty()) { lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name())); @@ -1051,13 +1108,23 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { return Join(lines, "
"); } +string HloDotDumper::GetInstructionNodeBackendConfig( + const HloInstruction* instr) { + if (!show_backend_config_ || instr->raw_backend_config_string().empty()) { + return ""; + } + + return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\""); +} + string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { std::vector lines; // Get the instruction's extra attributes excluding the names of its // subcomputations, since those are drawn explicitly in the graph. for (const auto& line : instr->ExtraAttributesToString( - HloPrintOptions().set_print_subcomputation_references(false))) { + HloPrintOptions().set_print_subcomputation_mode( + HloPrintOptions::PrintSubcomputationMode::kOff))) { lines.push_back(HtmlLikeStringSanitize(line)); } @@ -1106,6 +1173,20 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { return Join(lines, "
"); } +// Gets the total number of array elements in the given shape. For tuples, this +// is the sum of all the sizes of all of the array elements recursively in the +// tuple. +static int64 TotalElementsInShape(const Shape& shape) { + int64 elems = 0; + ShapeUtil::ForEachSubshape( + shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(subshape)) { + elems += ShapeUtil::ElementsIn(subshape); + } + }); + return elems; +} + void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { auto add_edge = [&](const HloInstruction* from, const HloInstruction* to, int64 operand_num, bool control_edge = false) { @@ -1125,9 +1206,16 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { } else if (control_edge) { edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\""; } - const char* kEdgeFmt = R"(%s -> %s [tooltip="%s -> %s" %s];)"; + + // We print "small" arrays using a hollow arrowhead and "large" arrays using + // a filled arrowhead. For now, we use an arbitrary cutoff for what "big" + // means. + bool is_big_array = TotalElementsInShape(from->shape()) >= 4096; + + const char* kEdgeFmt = R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to), - from->name(), to->name(), edge_label)); + (is_big_array ? "normal" : "empty"), from->name(), + to->name(), edge_label)); }; // Add edges from instr's operands to instr. Parameters within fusion @@ -1377,7 +1465,7 @@ string ExportGraph(const string& graph, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, - bool show_metadata) { + bool show_backend_config) { GraphRendererInterface::GraphKind graph_kind; string graph; if (debug_options.xla_hlo_dump_as_graphdef()) { @@ -1387,9 +1475,10 @@ string DumpGraph(const HloComputation& computation, const string& label, &graph)); graph_kind = GraphRendererInterface::TF_GRAPHDEF; } else { - graph = HloDotDumper(&computation, label, debug_options, show_metadata, - hlo_execution_profile, NodeFilter()) - .Dump(); + graph = + HloDotDumper(&computation, label, debug_options, show_backend_config, + hlo_execution_profile, NodeFilter()) + .Dump(); graph_kind = GraphRendererInterface::DOT_GRAPH; } @@ -1400,13 +1489,13 @@ string DumpGraph(const HloComputation& computation, const string& label, } string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata) { + bool show_backend_config) { auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); NodeFilter filter = MakeNodeFilter(&node, radius); string graph = - HloDotDumper(node.parent(), label, debug_options, show_metadata, + HloDotDumper(node.parent(), label, debug_options, show_backend_config, /*profile=*/nullptr, filter) .Dump(); return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 2704aae1e3ba7fb131bfcb1287d807d785fd9774..0b11f34abb7f0d937a24d11f4dc5d2d6a0aae6e7 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -56,7 +56,7 @@ string MaybeDumpHloModule(const HloModule& module, const string& label, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile = nullptr, - bool show_metadata = false); + bool show_backend_config = false); // Like DumpGraph, but renders only nodes "near" the given node in the graph. // @@ -64,7 +64,7 @@ string DumpGraph(const HloComputation& computation, const string& label, // (roughly) corresponds to the max distance a node may be from the primary node // before it's omitted from the graph. string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata = false); + bool show_backend_config = false); // Dumps the HloModule::ToString() as a file into the provided directory path // suffixed with the provided label. diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 1f00aa41dc783f9e5657f5fa654884a31fae0fe7..68f41a1cbb4db228f5dcf8b4a6130f05e81262a8 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -64,8 +64,8 @@ TEST(HloGraphDumperTest, NestedFusion) { sums.push_back(b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, sums[i], params[i + 2]))); } - - HloModule m(TestName()); + HloModuleConfig config; + HloModule m(TestName(), config); m.AddEntryComputation(b.Build()); HloComputation* root_computation = m.entry_computation(); @@ -121,8 +121,9 @@ TEST(HloGraphDumperTest, Constant) { HloComputation::Builder b("b"); auto instruction = b.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(-42))); - instruction->set_name("i_am_a_constant_root_instruction"); - HloModule m(TestName()); + instruction->SetAndSanitizeName("i_am_a_constant_root_instruction"); + HloModuleConfig config; + HloModule m(TestName(), config); HloComputation* root_computation = m.AddEntryComputation(b.Build()); string graph = hlo_graph_dumper::DumpGraph( *root_computation, /*label=*/"an_empty_graph", DebugOptions()); @@ -130,5 +131,23 @@ TEST(HloGraphDumperTest, Constant) { EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction"))); } +TEST(HloGraphDumperTest, TupleConstant) { + Shape tuple_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(S32, {4, 5})}); + HloComputation::Builder b("b"); + auto constant = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromShape(tuple_shape))); + auto gte = b.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::MakeShape(F32, {3, 2}), constant, 0)); + + HloModuleConfig config; + HloModule m(TestName(), config); + HloComputation* root_computation = m.AddEntryComputation(b.Build(gte)); + string graph = hlo_graph_dumper::DumpGraph( + *root_computation, /*label=*/"tuple_constant", DebugOptions()); + EXPECT_THAT(graph, HasSubstr("tuple_constant")); + EXPECT_THAT(graph, HasSubstr("constant (f32[3,2], s32[4,5])")); +} + } // anonymous namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index b7dd055d7cd78eb759a2b24bcbbbc948159f9425..39662d1735e4b411ef36cbc8421eabe52be6165a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include -#include #include #include #include @@ -27,7 +26,9 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -37,8 +38,11 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -50,66 +54,225 @@ using ::tensorflow::strings::StrCat; /* static */ StatusOr> HloInstruction::CreateFromProto( - HloModule* module, const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation) { + const HloInstructionProto& proto, + const tensorflow::gtl::FlatMap& instruction_map, + const tensorflow::gtl::FlatMap& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); - auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); - for (const string& operand_name : proto.operand_names()) { - TF_RET_CHECK(ContainsKey(instruction_map, operand_name)) - << "No instruction named " << operand_name; - instruction->AppendOperand(instruction_map.at(operand_name)); - } - for (const string& predecessor_name : proto.control_predecessor_names()) { - TF_RET_CHECK(ContainsKey(instruction_map, predecessor_name)) - << "No instruction named " << predecessor_name; - TF_RETURN_IF_ERROR(instruction_map.at(predecessor_name) - ->AddControlDependencyTo(instruction.get())); - } - - // In the proto, fused computations are held exclusively within the - // HloInstructionProto and do not appear as an HloComputationProto within the - // HloModuleProto. - if (instruction->opcode() == HloOpcode::kFusion) { - TF_RET_CHECK(proto.has_fused_instructions_computation()); - TF_RET_CHECK(!proto.fusion_kind().empty()); - TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, - StringToFusionKind(proto.fusion_kind())); - TF_ASSIGN_OR_RETURN(std::unique_ptr fused_computation, - HloComputation::CreateFromProto( - module, proto.fused_instructions_computation(), - computation_map, add_fused_computation, - /*fusion_instruction=*/instruction.get())); - instruction->called_computations_.push_back(fused_computation.get()); - add_fused_computation(std::move(fused_computation)); - } else { - for (const string& computation_name : proto.called_computation_names()) { - TF_RET_CHECK(ContainsKey(computation_map, computation_name)) - << "No computation named " << computation_name; - instruction->called_computations_.push_back( - computation_map.at(computation_name)); + std::unique_ptr instruction; + const auto operands = [&instruction_map, &proto](int index) { + return instruction_map.at(proto.operand_ids(index)); + }; + const auto computations = [&computation_map, &proto](int index) { + return computation_map.at(proto.called_computation_ids(index)); + }; + switch (opcode) { + // Ops migrated to subclasses. + case HloOpcode::kBatchNormTraining: + CHECK_EQ(proto.operand_ids_size(), 3); + instruction = CreateBatchNormTraining( + proto.shape(), operands(0), operands(1), operands(2), proto.epsilon(), + proto.feature_index()); + break; + case HloOpcode::kBatchNormInference: + CHECK_EQ(proto.operand_ids_size(), 5); + instruction = CreateBatchNormInference( + proto.shape(), operands(0), operands(1), operands(2), operands(3), + operands(4), proto.epsilon(), proto.feature_index()); + break; + case HloOpcode::kBatchNormGrad: + CHECK_EQ(proto.operand_ids_size(), 5); + instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1), + operands(2), operands(3), operands(4), + proto.epsilon(), proto.feature_index()); + break; + case HloOpcode::kFft: { + CHECK_EQ(proto.operand_ids_size(), 1); + std::vector fft_length(proto.fft_length().begin(), + proto.fft_length().end()); + instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(), + tensorflow::gtl::ArraySlice(fft_length)); + break; + } + case HloOpcode::kSend: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateSend(operands(0), proto.channel_id()); + break; + case HloOpcode::kSendDone: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateSendDone(operands(0)); + break; + case HloOpcode::kRecv: + CHECK_EQ(proto.operand_ids_size(), 0); + instruction = + CreateRecv(proto.shape().tuple_shapes(0), proto.channel_id()); + break; + case HloOpcode::kRecvDone: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateRecvDone(operands(0)); + break; + case HloOpcode::kReverse: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateReverse(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kConcatenate: { + CHECK_EQ(proto.dimensions_size(), 1); + std::vector concat_operands(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + concat_operands.begin(), + [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + instruction = CreateConcatenate(proto.shape(), concat_operands, + proto.dimensions(0)); + break; + } + case HloOpcode::kReduce: + CHECK_EQ(proto.operand_ids_size(), 2); + CHECK_EQ(proto.called_computation_ids_size(), 1); + instruction = CreateReduce(proto.shape(), operands(0), operands(1), + std::vector(proto.dimensions().begin(), + proto.dimensions().end()), + computations(0)); + break; + case HloOpcode::kTranspose: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = + CreateTranspose(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kBroadcast: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = + CreateBroadcast(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kMap: { + CHECK_EQ(proto.called_computation_ids_size(), 1); + std::vector map_operands(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + map_operands.begin(), + [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + instruction = CreateMap(proto.shape(), map_operands, computations(0)); + break; + } + case HloOpcode::kSlice: { + CHECK_EQ(proto.operand_ids_size(), 1); + std::vector slice_starts, slice_limits, slice_strides; + for (const HloInstructionProto::SliceDimensions& slice_dimensions : + proto.slice_dimensions()) { + slice_starts.push_back(slice_dimensions.start()); + slice_limits.push_back(slice_dimensions.limit()); + slice_strides.push_back(slice_dimensions.stride()); + } + instruction = CreateSlice(proto.shape(), operands(0), slice_starts, + slice_limits, slice_strides); + break; + } + case HloOpcode::kConstant: { + CHECK(proto.has_literal()); + TF_ASSIGN_OR_RETURN(auto literal, + Literal::CreateFromProto(proto.literal())); + instruction = CreateConstant(std::move(literal)); + break; + } + case HloOpcode::kTrace: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Trace instruction should have 1 operand but sees " + << proto.operand_ids_size(); + CHECK(proto.has_literal()); + TF_ASSIGN_OR_RETURN(auto literal, + Literal::CreateFromProto(proto.literal())); + instruction = CreateTrace(literal->GetR1U8AsString(), operands(0)); + break; + } + case HloOpcode::kFusion: { + // In the proto, fused computations are held exclusively within the + // HloInstructionProto and do not appear as an HloComputationProto within + // the HloModuleProto. + TF_RET_CHECK(!proto.fusion_kind().empty()); + TF_ASSIGN_OR_RETURN(FusionKind fusion_kind, + StringToFusionKind(proto.fusion_kind())); + + // Find the fused computation and set its fusion instruction. + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Expect 1 called computation for fusion instruction, but sees " + << proto.called_computation_ids_size(); + const int64 fusion_id = proto.called_computation_ids(0); + auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); + TF_RET_CHECK(fused_computation != nullptr) + << "No fusion computation with id " << fusion_id; + std::vector fusion_operands(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + fusion_operands.begin(), + [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + instruction = CreateFusion(proto.shape(), fusion_kind, fusion_operands, + fused_computation); + break; + } + case HloOpcode::kRng: { + std::vector rng_parms(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + rng_parms.begin(), [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + instruction = CreateRng(proto.shape(), proto.distribution(), rng_parms); + break; + } + case HloOpcode::kParameter: + instruction = CreateParameter(proto.parameter_number(), proto.shape(), + proto.name()); + break; + case HloOpcode::kGetTupleElement: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateGetTupleElement(proto.shape(), operands(0), + proto.tuple_index()); + break; + case HloOpcode::kReducePrecision: + instruction = + CreateReducePrecision(proto.shape(), operands(0), + proto.exponent_bits(), proto.mantissa_bits()); + break; + default: { + instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); + for (const int64 operand_id : proto.operand_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) + << "No instruction with id " << operand_id; + instruction->AppendOperand(instruction_map.at(operand_id)); + } + for (const int64 predecessor_id : proto.control_predecessor_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) + << "No instruction with id " << predecessor_id; + TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) + ->AddControlDependencyTo(instruction.get())); + } + if (instruction->opcode() != HloOpcode::kFusion) { + for (const int64 computation_id : proto.called_computation_ids()) { + TF_RET_CHECK(ContainsKey(computation_map, computation_id)) + << "No computation with id " << computation_id; + instruction->called_computations_.push_back( + computation_map.at(computation_id)); + } + } + break; } } TF_RET_CHECK(!proto.name().empty()); - instruction->name_ = proto.name(); - + instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); - if (proto.has_literal()) { - TF_ASSIGN_OR_RETURN(instruction->literal_, - Literal::CreateFromProto(proto.literal())); - } - instruction->parameter_number_ = proto.parameter_number(); + instruction->backend_config_ = proto.backend_config(); - instruction->tuple_index_ = proto.tuple_index(); - for (int64 dimension : proto.dimensions()) { - instruction->dimensions_.push_back(dimension); - } if (proto.has_window()) { instruction->window_ = MakeUnique(proto.window()); } @@ -122,14 +285,7 @@ StatusOr> HloInstruction::CreateFromProto( instruction->dot_dimension_numbers_ = MakeUnique(proto.dot_dimension_numbers()); } - for (const HloInstructionProto::SliceDimensions& slice_dimensions : - proto.slice_dimensions()) { - instruction->slice_starts_.push_back(slice_dimensions.start()); - instruction->slice_limits_.push_back(slice_dimensions.limit()); - instruction->slice_strides_.push_back(slice_dimensions.stride()); - } - instruction->exponent_bits_ = proto.exponent_bits(); - instruction->mantissa_bits_ = proto.mantissa_bits(); + for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) { instruction->dynamic_slice_sizes_.push_back(dynamic_slice_size); } @@ -138,16 +294,29 @@ StatusOr> HloInstruction::CreateFromProto( MakeUnique(proto.padding_config()); } instruction->outfeed_config_ = proto.outfeed_config(); - instruction->distribution_ = proto.distribution(); - instruction->epsilon_ = proto.epsilon(); - instruction->feature_index_ = proto.feature_index(); - instruction->channel_id_ = proto.channel_id(); instruction->infeed_config_ = proto.infeed_config(); instruction->custom_call_target_ = proto.custom_call_target(); instruction->outfeed_shape_ = proto.outfeed_shape(); - instruction->fft_type_ = proto.fft_type(); - for (int64 fft_len : proto.fft_length()) { - instruction->fft_length_.push_back(fft_len); + + if (proto.has_sharding()) { + TF_ASSIGN_OR_RETURN(const auto& sharding, + HloSharding::FromProto(proto.sharding())); + instruction->set_sharding(sharding); + } + + if (proto.has_gather_dimension_numbers()) { + instruction->gather_dimension_numbers_ = + MakeUnique(proto.gather_dimension_numbers()); + } + for (int64 bound : proto.gather_window_bounds()) { + instruction->gather_window_bounds_.push_back(bound); + } + + instruction->channel_name_ = proto.channel_name(); + instruction->cost_estimate_ns_ = proto.cost_estimate_ns(); + + for (int64 replica_group_id : proto.replica_group_ids()) { + instruction->replica_group_ids_.push_back(replica_group_id); } return std::move(instruction); @@ -155,50 +324,29 @@ StatusOr> HloInstruction::CreateFromProto( /* static */ std::unique_ptr HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kParameter, shape)); - instruction->parameter_number_ = parameter_number; - instruction->name_ = name; - return instruction; + return MakeUnique(parameter_number, shape, name); } /* static */ std::unique_ptr HloInstruction::CreateTrace( const string& tag, HloInstruction* operand) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); - instruction->operands_.push_back(operand); - instruction->literal_ = Literal::CreateR1U8(tag); - return instruction; + return MakeUnique(tag, operand); } /* static */ std::unique_ptr HloInstruction::CreateConstant( std::unique_ptr literal) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConstant, literal->shape())); - instruction->literal_ = std::move(literal); - return instruction; + return MakeUnique(std::move(literal)); } /* static */ std::unique_ptr HloInstruction::CreateGetTupleElement(const Shape& shape, HloInstruction* operand, int64 index) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kGetTupleElement, shape)); - instruction->tuple_index_ = index; - instruction->AppendOperand(operand); - return instruction; + return MakeUnique(shape, operand, index); } /* static */ std::unique_ptr HloInstruction::CreateRng( const Shape& shape, RandomDistribution distribution, tensorflow::gtl::ArraySlice parameters) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRng, shape)); - instruction->distribution_ = distribution; - instruction->shape_ = shape; - for (HloInstruction* param : parameters) { - instruction->AppendOperand(param); - } - return instruction; + return MakeUnique(shape, distribution, parameters); } /* static */ std::unique_ptr HloInstruction::CreateNary( @@ -226,11 +374,15 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCeil: case HloOpcode::kCopy: case HloOpcode::kCos: + case HloOpcode::kClz: + case HloOpcode::kDomain: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -309,13 +461,8 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, const Shape& shape, tensorflow::gtl::ArraySlice operands, HloComputation* map_computation, tensorflow::gtl::ArraySlice static_operands) { - CHECK(static_operands.empty()) << "static_operands not yet supported"; - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kMap, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->called_computations_.push_back(map_computation); - return instruction; + return MakeUnique(shape, operands, map_computation, + static_operands); } /* static */ std::unique_ptr HloInstruction::CreateConvolve( @@ -341,11 +488,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, tensorflow::gtl::ArraySlice fft_length) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFft, shape)); - instruction->AppendOperand(operand); - instruction->fft_type_ = fft_type; - instruction->fft_length_.assign(fft_length.begin(), fft_length.end()); - return instruction; + return MakeUnique(shape, operand, fft_type, fft_length); } /* static */ std::unique_ptr HloInstruction::CreateDot( @@ -378,18 +521,29 @@ HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kReducePrecision, shape)); - instruction->AppendOperand(operand); - instruction->exponent_bits_ = exponent_bits; - instruction->mantissa_bits_ = mantissa_bits; - return instruction; + return MakeUnique( + shape, operand, exponent_bits, mantissa_bits); } /* static */ std::unique_ptr HloInstruction::CreateCrossReplicaSum( - const Shape& shape, tensorflow::gtl::ArraySlice operands) { - return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands); + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& channel_id) { + // TODO(b/79737069): Remove the CHECK when supported. + CHECK(!channel_id.has_value()); + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + instruction->called_computations_.push_back(reduce_computation); + instruction->replica_group_ids_.assign(replica_group_ids.begin(), + replica_group_ids.end()); + instruction->cross_replica_sum_barrier_ = std::string(barrier); + return instruction; } /* static */ std::unique_ptr HloInstruction::CreateInfeed( @@ -408,63 +562,51 @@ HloInstruction::CreateCrossReplicaSum( << "Outfeed shape " << shape << " must be compatible with operand shape " << operand->shape(); instruction->AppendOperand(operand); - instruction->outfeed_config_ = outfeed_config.ToString(); + instruction->outfeed_config_ = std::string(outfeed_config); instruction->outfeed_shape_ = shape; return instruction; } /* static */ std::unique_ptr HloInstruction::CreateSend( HloInstruction* operand, int64 channel_id) { - // Send instruction produces a tuple of {aliased operand, U32 context}. - Shape output_shape = ShapeUtil::MakeTupleShape( - {operand->shape(), ShapeUtil::MakeShape(U32, {})}); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape)); - instruction->AppendOperand(operand); - instruction->channel_id_ = channel_id; - return instruction; + return MakeUnique(operand, channel_id); } /* static */ std::unique_ptr HloInstruction::CreateSendDone( HloInstruction* operand) { - CHECK(operand->opcode() == HloOpcode::kSend) + auto send_operand = DynCast(operand); + CHECK(send_operand != nullptr) << "SendDone must take the context operand from Send"; - auto instruction = WrapUnique( - new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil())); - instruction->AppendOperand(operand); - instruction->channel_id_ = operand->channel_id(); - return instruction; + return MakeUnique(send_operand); } /* static */ std::unique_ptr HloInstruction::CreateRecv( const Shape& shape, int64 channel_id) { - // Recv instruction produces a tuple of {receive buffer, U32 context}. - Shape output_shape = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape)); - instruction->channel_id_ = channel_id; - return instruction; + return MakeUnique(shape, channel_id); } /* static */ std::unique_ptr HloInstruction::CreateRecvDone( HloInstruction* operand) { - CHECK(operand->opcode() == HloOpcode::kRecv) + auto recv_operand = DynCast(operand); + CHECK(recv_operand != nullptr) << "RecvDone must take the context operand from Recv"; - Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape)); - instruction->AppendOperand(operand); - instruction->channel_id_ = operand->channel_id(); - return instruction; + return MakeUnique(recv_operand); } /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReverse, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(dimensions.begin(), dimensions.end()); + return MakeUnique(shape, operand, dimensions); +} + +/* static */ std::unique_ptr +HloInstruction::CreateGenerateToken( + tensorflow::gtl::ArraySlice operands) { + auto instruction = WrapUnique(new HloInstruction( + HloOpcode::kGenerateToken, ShapeUtil::MakeTokenShape())); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } return instruction; } @@ -501,18 +643,8 @@ HloInstruction::CreateCrossReplicaSum( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape)); - instruction->AppendOperand(operand); - instruction->slice_starts_.assign(start_indices.begin(), start_indices.end()); - instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end()); - instruction->slice_strides_.assign(strides.begin(), strides.end()); - // For backward compatibility with old serialized computations: if there are - // no strides, assume all strides are 1. - // TODO(b/63317920): remove this code. - if (instruction->slice_strides_.empty()) { - instruction->slice_strides_ = std::vector(start_indices.size(), 1LL); - } - return instruction; + return MakeUnique(shape, operand, start_indices, + limit_indices, strides); } /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( @@ -543,13 +675,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateConcatenate( const Shape& shape, tensorflow::gtl::ArraySlice operands, int64 dimension) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConcatenate, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->dimensions_.push_back(dimension); - return instruction; + return MakeUnique(shape, operands, dimension); } /* static */ std::unique_ptr HloInstruction::CreateConvert( @@ -572,13 +698,8 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, const Shape& shape, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReduce, shape)); - instruction->AppendOperand(arg); - instruction->AppendOperand(init_value); - instruction->dimensions_.assign(dimensions_to_reduce.begin(), - dimensions_to_reduce.end()); - instruction->called_computations_.push_back(reduce_computation); - return instruction; + return MakeUnique( + shape, arg, init_value, dimensions_to_reduce, reduce_computation); } /* static */ std::unique_ptr HloInstruction::CreateReduceWindow( @@ -599,14 +720,8 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape, HloInstruction* scale, HloInstruction* offset, float epsilon, int64 feature_index) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBatchNormTraining, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(scale); - instruction->AppendOperand(offset); - instruction->epsilon_ = epsilon; - instruction->feature_index_ = feature_index; - return instruction; + return MakeUnique( + shape, operand, scale, offset, epsilon, feature_index); } /* static */ std::unique_ptr @@ -614,16 +729,8 @@ HloInstruction::CreateBatchNormInference( const Shape& shape, HloInstruction* operand, HloInstruction* scale, HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, float epsilon, int64 feature_index) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(scale); - instruction->AppendOperand(offset); - instruction->AppendOperand(mean); - instruction->AppendOperand(variance); - instruction->epsilon_ = epsilon; - instruction->feature_index_ = feature_index; - return instruction; + return MakeUnique( + shape, operand, scale, offset, mean, variance, epsilon, feature_index); } /* static */ std::unique_ptr @@ -632,16 +739,9 @@ HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand, HloInstruction* variance, HloInstruction* grad_output, float epsilon, int64 feature_index) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBatchNormGrad, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(scale); - instruction->AppendOperand(mean); - instruction->AppendOperand(variance); - instruction->AppendOperand(grad_output); - instruction->epsilon_ = epsilon; - instruction->feature_index_ = feature_index; - return instruction; + return MakeUnique(shape, operand, scale, mean, + variance, grad_output, epsilon, + feature_index); } /* static */ std::unique_ptr @@ -664,12 +764,8 @@ HloInstruction::CreateSelectAndScatter( /* static */ std::unique_ptr HloInstruction::CreateBroadcast( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice broadcast_dimensions) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBroadcast, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(broadcast_dimensions.begin(), - broadcast_dimensions.end()); - return instruction; + return MakeUnique(shape, operand, + broadcast_dimensions); } /* static */ std::unique_ptr @@ -748,343 +844,42 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateTranspose( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { - CHECK_EQ(shape.dimensions().size(), dimensions.size()); - CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); - CHECK(std::equal(operand->shape().dimensions().begin(), - operand->shape().dimensions().end(), - Permute(dimensions, shape.dimensions()).begin())) - << "shape: " << ShapeUtil::HumanString(shape) - << ", operand->shape(): " << ShapeUtil::HumanString(shape) - << ", dimensions: {" << Join(dimensions, ", ") << "}"; - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(dimensions.begin(), dimensions.end()); - return instruction; -} - -// We put the fusion kind into the instruction's name for transpose-dot fusions, -// since those fusions are really just describing a type of dot rather than -// generating a novel computation. -static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { - switch (fusion_kind) { - case HloInstruction::FusionKind::kTransposeDot: - return "dot_fusion"; - default: - return "fusion"; - } + return MakeUnique(shape, operand, dimensions); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - instruction->fusion_kind_ = fusion_kind; - instruction->name_ = FusionNodeName(fusion_kind); - instruction->set_parent(fused_root->parent()); - instruction->set_metadata(fused_root->metadata()); - instruction->CloneAndFuseInternal(fused_root); - return instruction; + return MakeUnique(shape, fusion_kind, fused_root); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, tensorflow::gtl::ArraySlice operands, HloComputation* fusion_computation) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->fusion_kind_ = fusion_kind; - instruction->name_ = FusionNodeName(fusion_kind); - instruction->called_computations_.push_back(fusion_computation); - fusion_computation->SetFusionInstruction(instruction.get()); - return instruction; -} - -HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) { - CHECK_EQ(opcode(), HloOpcode::kFusion); - CHECK_EQ(operand_count(), - fused_instructions_computation()->parameter_instructions().size()); - const int64 param_no = operand_count(); - // Name the parameter after the instruction it represents in the outer - // (non-fusion) computation. - string param_name = StrCat(new_operand->name(), ".param_", param_no); - HloInstruction* fused_parameter = - fused_instructions_computation()->AddParameter( - HloInstruction::CreateParameter(param_no, new_operand->shape(), - param_name)); - AppendOperand(new_operand); - return fused_parameter; -} - -void HloInstruction::MergeFusionInstruction( - HloInstruction* instruction_to_merge) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion); - CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != - operands().end()); - // Clone the instruction from which to merge fused instructions. - std::unique_ptr clone = instruction_to_merge->Clone(); - // Replace uses of fused parameters with the corresponding operand of the - // fusion. Add all non-parameter fused instructions to 'unfused_instructions' - // to be merged into 'this'. This is done in reverse post order. - std::vector unfused_instructions; - auto fused_instructions = - clone->fused_instructions_computation()->MakeInstructionPostOrder(); - for (auto fused_it = fused_instructions.rbegin(); - fused_it != fused_instructions.rend(); ++fused_it) { - auto fused_instruction = *fused_it; - if (fused_instruction->opcode() == HloOpcode::kParameter) { - TF_CHECK_OK(fused_instruction->ReplaceAllUsesWith( - clone->mutable_operand(fused_instruction->parameter_number()))); - } else { - unfused_instructions.push_back(fused_instruction); - } - } - CHECK(unfused_instructions.front() == clone->fused_expression_root()); - // Replace instruction_to_merge use of 'this' with unfused_root. - TF_CHECK_OK( - instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front())); - // Fuse 'unfused_instructions' into 'this'. - for (auto& instruction : unfused_instructions) { - FuseInstruction(instruction); - instruction->DetachFromOperands(); - } - CHECK_EQ(0, clone->user_count()); - clone->DetachFromOperands(); - TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation( - clone->fused_instructions_computation())); -} - -void HloInstruction::MergeFusionInstructionIntoMultiOutput( - HloInstruction* instruction_to_merge) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion); - // Add all non-parameter fused instructions to 'unfused_instructions' to be - // merged into 'this'. `old_to_new' maps the instructions in the fused node - // to the disaseembled fusion instructions. - // Note that we add the unfused instructions to this->parent_ computation. - // This is necessary because the unique_id needs for an instruction and - // it's only added when inserting to the computation. - tensorflow::gtl::FlatMap old_to_new; - std::vector unfused_instructions; - auto computation_to_merge = - instruction_to_merge->fused_instructions_computation(); - auto post_order = computation_to_merge->MakeInstructionPostOrder(); - for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) { - auto fused_instruction = *rit; - if (fused_instruction->opcode() == HloOpcode::kParameter) { - InsertOrDie(&old_to_new, fused_instruction, - instruction_to_merge->mutable_operand( - fused_instruction->parameter_number())); - continue; - } - - // Here we clone the insertion and call FuseInstructionIntoMultiOutput() - // which clones again. This can be improved. - auto cloned_instruction = - parent_->AddInstruction(fused_instruction->Clone()); - unfused_instructions.push_back(cloned_instruction); - InsertOrDie(&old_to_new, fused_instruction, cloned_instruction); - } - for (auto unfused_instruction : unfused_instructions) { - for (int64 index = 0; index < unfused_instruction->operand_count(); - index++) { - auto new_operand = - FindOrDie(old_to_new, unfused_instruction->mutable_operand(index)); - TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand)); - } - } - - HloInstruction* unfused_root = unfused_instructions.front(); - TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root)); - - TF_CHECK_OK( - instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge)); - if (GetModule()) { - TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge)); - } - - // Fuse the root instruction and generate multiple outputs. - FuseInstructionIntoMultiOutput(unfused_root); - TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root)); - // The rest instructions are of normal fusing. - for (int64 i = 1; i < unfused_instructions.size(); i++) { - auto instruction = unfused_instructions[i]; - FuseInstruction(instruction); - TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction)); - } + return MakeUnique(shape, fusion_kind, operands, + fusion_computation); } -HloInstruction* HloInstruction::FuseInstructionInternal( - HloInstruction* instruction_to_fuse, bool add_output) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - - // When add_output is false, this fusion instruction must be a user of - // instruction_to_fuse. - if (!add_output) { - CHECK(IsUserOf(instruction_to_fuse)); +void HloInstruction::set_single_sharding(const HloSharding& sharding) { + CHECK(!sharding.IsTuple()) << sharding; + if (ShapeUtil::IsTuple(shape())) { + set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape()))); + } else { + set_sharding(sharding); } - HloInstruction* fused_instruction = - CloneAndFuseInternal(instruction_to_fuse, add_output); - return fused_instruction; } -HloInstruction* HloInstruction::CloneAndFuseInternal( - HloInstruction* instruction_to_fuse, bool add_output) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); - VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); - HloInstruction* clone = nullptr; - if (called_computations_.empty()) { - // New fusion instruction. It should not be a multioutput instruction. - CHECK(!add_output); - auto builder = HloComputation::Builder("fused_computation", this); - builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); - called_computations_.push_back( - CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); - clone = fused_expression_root(); +void HloInstruction::SetupDerivedInstruction( + HloInstruction* derived_instruction) const { + if (sharding_ != nullptr) { + derived_instruction->set_sharding(*sharding_); } else { - clone = fused_instructions_computation()->AddInstruction( - instruction_to_fuse->Clone(/*suffix=*/"")); - // When add_output is false, instruction_to_fuse is necessarily an operand - // of the fusion instruction. After fusion this will no longer be the case. - // Remove the operand from the operand list and remove its corresponding - // fused parameter instruction. Renumber parameters as necessary to make - // parameter numbers consistent with their index in the - // fused_parameter_ vector. - bool in_operand_list = std::find(operands_.begin(), operands_.end(), - instruction_to_fuse) != operands_.end(); - CHECK(add_output || in_operand_list); - const std::vector& fused_parameters = - fused_instructions_computation()->parameter_instructions(); - for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { - if (instruction_to_fuse == operands_[operand_num]) { - // replace the fused parameter instruction's uses with the clone. - HloInstruction* fused_parameter = fused_parameters[operand_num]; - TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone)); - - // Remove the corresponding fused parameter and operand from their - // respective vectors. - TF_CHECK_OK( - fused_instructions_computation()->RemoveParameter(operand_num)); - operands_.erase(operands_.begin() + operand_num); - break; - } - } - // We've cloned instruction_to_fuse into this fusion instruction, so this - // fusion instruction is no longer a use of instruction_to_fuse. - if (in_operand_list) { - instruction_to_fuse->RemoveUser(this); - // When the instruction_to_fuse does not have other users, we don't need - // to generate a multioutput fusion instruction. - if (instruction_to_fuse->user_count() == 0) { - add_output = false; - } - } - } - - // Reread the parameters in the computation. - const std::vector& fused_parameters = - fused_instructions_computation()->parameter_instructions(); - - // Add each operand of the clone as an operand of the fusion instruction. A - // complication is that some clone operands may already be operands of the - // fusion instruction. - for (int64 operand_num = 0; operand_num < clone->operand_count(); - ++operand_num) { - HloInstruction* operand = clone->mutable_operand(operand_num); - - // See if this operand is already an operand of the fusion node. - CHECK_EQ(operands_.size(), fused_parameters.size()); - HloInstruction* fused_param = nullptr; - for (int64 i = 0; i < operands_.size(); ++i) { - if (operands_[i] == operand) { - fused_param = fused_parameters[i]; - break; - } - } - - if (fused_param == nullptr) { - // Clone's operand was not already an operand of the fusion - // instruction. Add it as an operand and add a corresponding fused - // parameter instruction. - fused_param = AddFusionOperand(operand); - } - TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); - } - - if (add_output) { - CHECK_GT(instruction_to_fuse->user_count(), 0); - // If this is already a multioutput fusion instruction, expand the root - // tuple by 1. - HloInstruction* fused_root = fused_expression_root(); - HloInstruction::InstructionVector tuple_elements; - bool newly_created_tuple_instr = false; - if (fused_root->opcode() == HloOpcode::kTuple) { - tuple_elements = fused_root->operands(); - } else { - tuple_elements.push_back(fused_root); - newly_created_tuple_instr = true; - } - if (clone->opcode() == HloOpcode::kTuple) { - for (auto inst : clone->operands()) { - tuple_elements.push_back(inst); - } - } else { - tuple_elements.push_back(clone); - } - HloInstruction* new_root = fused_instructions_computation()->AddInstruction( - HloInstruction::CreateTuple(tuple_elements)); - fused_instructions_computation()->set_root_instruction(new_root); - shape_ = new_root->shape(); - if (fused_root->opcode() == HloOpcode::kTuple) { - TF_CHECK_OK( - fused_instructions_computation()->RemoveInstruction(fused_root)); - } - - // If this is a newly created multioutput instruction, we need to update - // the use of the original fusion instruction. - if (newly_created_tuple_instr) { - HloInstruction* new_instr = parent_->AddInstruction( - HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0)); - TF_CHECK_OK(ReplaceAllUsesWith(new_instr)); - } - int64 index = tuple_elements.size(); - if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { - index -= instruction_to_fuse->operand_count(); - std::vector to_be_removed; - for (auto old_gte : instruction_to_fuse->users()) { - CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement); - int64 old_tuple_index = old_gte->tuple_index(); - HloInstruction* new_gte = - parent_->AddInstruction(HloInstruction::CreateGetTupleElement( - old_gte->shape(), this, index + old_tuple_index)); - TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte)); - to_be_removed.push_back(old_gte); - } - for (auto old_gte : to_be_removed) { - TF_CHECK_OK(parent_->RemoveInstruction(old_gte)); - } - TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone)); - } else { - HloInstruction* new_gte = - parent_->AddInstruction(HloInstruction::CreateGetTupleElement( - clone->shape(), this, index - 1)); - TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte)); - } + derived_instruction->clear_sharding(); } - - VLOG(2) << "New clone:\n" << clone->ToString(); - return clone; -} - -RandomDistribution HloInstruction::random_distribution() const { - CHECK_EQ(opcode_, HloOpcode::kRng); - return distribution_; + derived_instruction->set_metadata(metadata_); } -bool HloInstruction::HasSideEffect() const { +bool HloInstruction::HasSideEffectNoRecurse() const { switch (opcode_) { case HloOpcode::kSend: case HloOpcode::kSendDone: @@ -1096,16 +891,22 @@ bool HloInstruction::HasSideEffect() const { case HloOpcode::kTrace: case HloOpcode::kHostCompute: return true; - default: { - // Check if any of the called computations has a side effect. - for (const auto& computation : called_computations()) { - if (computation->HasSideEffect()) { - return true; - } - } + default: return false; + } +} + +bool HloInstruction::HasSideEffect() const { + if (HasSideEffectNoRecurse()) { + return true; + } + // Check if any of the called computations has a side effect. + for (const auto& computation : called_computations()) { + if (computation->HasSideEffect()) { + return true; } } + return false; } /* static */ std::unique_ptr HloInstruction::CreateCall( @@ -1128,7 +929,7 @@ bool HloInstruction::HasSideEffect() const { for (auto operand : operands) { instruction->AppendOperand(operand); } - instruction->custom_call_target_ = custom_call_target.ToString(); + instruction->custom_call_target_ = std::string(custom_call_target); return instruction; } @@ -1140,7 +941,7 @@ bool HloInstruction::HasSideEffect() const { for (auto operand : operands) { instruction->AppendOperand(operand); } - instruction->channel_name_ = channel_name.ToString(); + instruction->channel_name_ = std::string(channel_name); instruction->cost_estimate_ns_ = cost_estimate_ns; return instruction; } @@ -1172,7 +973,8 @@ bool HloInstruction::HasSideEffect() const { /* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers( tensorflow::gtl::ArraySlice output_window_dims, tensorflow::gtl::ArraySlice elided_window_dims, - tensorflow::gtl::ArraySlice gather_dims_to_operand_dims) { + tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, + int64 index_vector_dim) { GatherDimensionNumbers gather_dim_numbers; for (int64 output_window_dim : output_window_dims) { gather_dim_numbers.add_output_window_dims(output_window_dim); @@ -1184,13 +986,25 @@ bool HloInstruction::HasSideEffect() const { gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); } + gather_dim_numbers.set_index_vector_dim(index_vector_dim); return gather_dim_numbers; } +/* static */ std::unique_ptr HloInstruction::CreateDomain( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); + instruction->operand_side_metadata_ = std::move(operand_side_metadata); + instruction->user_side_metadata_ = std::move(user_side_metadata); + instruction->AppendOperand(operand); + return instruction; +} + std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, - HloModule* module) const { + HloCloneContext* context) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { @@ -1198,23 +1012,51 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( } std::unique_ptr clone; - // Explicitly call the factory for the instruction type. This is more robust // in the face of code changes than copying fields explicitly. This also // properly sets the user fields of the operands. switch (opcode_) { + // Ops migrated to subclasses. + // TODO(b/80131774): Remove this switch when migration is complete. + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kFft: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReverse: + case HloOpcode::kConcatenate: + case HloOpcode::kReduce: + case HloOpcode::kTranspose: + case HloOpcode::kBroadcast: + case HloOpcode::kMap: + case HloOpcode::kSlice: + case HloOpcode::kConstant: + case HloOpcode::kTrace: + case HloOpcode::kFusion: + case HloOpcode::kRng: + case HloOpcode::kParameter: + case HloOpcode::kGetTupleElement: + case HloOpcode::kReducePrecision: + clone = CloneWithNewOperandsImpl(shape, new_operands, context); + break; // Unary ops. case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: case HloOpcode::kBitcast: case HloOpcode::kCeil: + case HloOpcode::kClz: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kFloor: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -1258,23 +1100,24 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( new_operands[2]); break; // Other supported ops. - case HloOpcode::kBroadcast: - CHECK_EQ(new_operands.size(), 1); - clone = CreateBroadcast(shape, new_operands[0], dimensions_); - break; case HloOpcode::kCall: clone = CreateCall(shape, new_operands, to_apply()); break; case HloOpcode::kCustomCall: clone = CreateCustomCall(shape, new_operands, custom_call_target_); + if (window_ != nullptr) { + clone->window_ = MakeUnique(*window_); + } + if (convolution_dimension_numbers_ != nullptr) { + clone->convolution_dimension_numbers_ = + MakeUnique( + *convolution_dimension_numbers_); + } break; case HloOpcode::kHostCompute: clone = CreateHostCompute(shape, new_operands, channel_name_, cost_estimate_ns_); break; - case HloOpcode::kConcatenate: - clone = CreateConcatenate(shape, new_operands, dimensions(0)); - break; case HloOpcode::kConvert: CHECK_EQ(new_operands.size(), 1); clone = CreateConvert(shape, new_operands[0]); @@ -1283,11 +1126,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateBitcastConvert(shape, new_operands[0]); break; - case HloOpcode::kReducePrecision: - CHECK_EQ(new_operands.size(), 1); - clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_, - mantissa_bits_); - break; case HloOpcode::kConvolution: CHECK_EQ(new_operands.size(), 2); clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_, @@ -1298,29 +1136,16 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateDot(shape, new_operands[0], new_operands[1], *dot_dimension_numbers_); break; - case HloOpcode::kFft: - CHECK_EQ(new_operands.size(), 1); - return CreateFft(shape, new_operands[0], fft_type_, fft_length_); case HloOpcode::kCrossReplicaSum: - clone = CreateCrossReplicaSum(shape, new_operands); - break; - case HloOpcode::kGetTupleElement: - CHECK_EQ(new_operands.size(), 1); - clone = CreateGetTupleElement(shape, new_operands[0], tuple_index()); - break; - case HloOpcode::kMap: - clone = CreateMap(shape, new_operands, to_apply()); + clone = + CreateCrossReplicaSum(shape, new_operands, to_apply(), + replica_group_ids_, cross_replica_sum_barrier_); break; case HloOpcode::kPad: CHECK_EQ(new_operands.size(), 2); clone = CreatePad(shape, new_operands[0], new_operands[1], *padding_config_); break; - case HloOpcode::kReduce: - CHECK_EQ(new_operands.size(), 2); - clone = CreateReduce(shape, new_operands[0], new_operands[1], dimensions_, - to_apply()); - break; case HloOpcode::kReduceWindow: CHECK_EQ(new_operands.size(), 2); clone = CreateReduceWindow(shape, new_operands[0], new_operands[1], @@ -1332,22 +1157,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CreateSelectAndScatter(shape, new_operands[0], select(), *window_, new_operands[1], new_operands[2], scatter()); break; - case HloOpcode::kReverse: - CHECK_EQ(new_operands.size(), 1); - clone = CreateReverse(shape, new_operands[0], dimensions_); - break; - case HloOpcode::kRng: - clone = CreateRng(shape, distribution_, new_operands); - break; case HloOpcode::kReshape: CHECK_EQ(new_operands.size(), 1); clone = CreateReshape(shape, new_operands[0]); break; - case HloOpcode::kSlice: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_, - slice_strides_); - break; case HloOpcode::kDynamicSlice: clone = CreateDynamicSlice(shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); @@ -1357,10 +1170,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], new_operands[2]); break; - case HloOpcode::kTranspose: - CHECK_EQ(new_operands.size(), 1); - clone = CreateTranspose(shape, new_operands[0], dimensions_); - break; case HloOpcode::kTuple: clone = CreateTuple(new_operands); *clone->mutable_shape() = shape; @@ -1370,27 +1179,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateWhile(shape, while_condition(), while_body(), new_operands[0]); break; - case HloOpcode::kConstant: - clone = CreateConstant(literal_->CloneToUnique()); - break; - case HloOpcode::kFusion: - clone = CloneFusionWithNewOperands(shape, new_operands, module); - break; - case HloOpcode::kParameter: - clone = CreateParameter(parameter_number_, shape, name_); - break; - case HloOpcode::kBatchNormTraining: - CHECK_EQ(new_operands.size(), 3); - clone = - CreateBatchNormTraining(shape, new_operands[0], new_operands[1], - new_operands[2], epsilon(), feature_index()); - break; - case HloOpcode::kBatchNormInference: - CHECK_EQ(new_operands.size(), 5); - clone = CreateBatchNormInference( - shape, new_operands[0], new_operands[1], new_operands[2], - new_operands[3], new_operands[4], epsilon(), feature_index()); - break; case HloOpcode::kInfeed: CHECK_EQ(new_operands.size(), 0); clone = CreateInfeed(shape, infeed_config()); @@ -1399,58 +1187,47 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config()); break; - case HloOpcode::kBatchNormGrad: - CHECK_EQ(new_operands.size(), 5); - clone = CreateBatchNormGrad(shape, new_operands[0], new_operands[1], - new_operands[2], new_operands[3], - new_operands[4], epsilon(), feature_index()); - break; case HloOpcode::kConditional: CHECK_EQ(new_operands.size(), 3); clone = CreateConditional(shape, new_operands[0], new_operands[1], true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kSend: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSend(new_operands[0], channel_id()); - break; - case HloOpcode::kSendDone: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSendDone(new_operands[0]); - break; - case HloOpcode::kRecv: - CHECK_EQ(new_operands.size(), 0); - // The shape is a tuple, but CreateRecv() wants the raw data shape. - clone = - CreateRecv(ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); - break; - case HloOpcode::kRecvDone: - CHECK_EQ(new_operands.size(), 1); - clone = CreateRecvDone(new_operands[0]); - break; case HloOpcode::kGather: CHECK_EQ(new_operands.size(), 2); clone = CreateGather(shape, new_operands[0], new_operands[1], *gather_dimension_numbers_, gather_window_bounds_); break; - case HloOpcode::kTrace: - LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); - } - clone->set_metadata(metadata_); - if (has_sharding()) { - clone->set_sharding(sharding()); + case HloOpcode::kDomain: + CHECK_EQ(new_operands.size(), 1); + clone = + CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), + user_side_metadata_->Clone()); + break; + case HloOpcode::kGenerateToken: + clone = CreateGenerateToken(new_operands); + break; } + SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); + clone->set_raw_backend_config_string(backend_config_); + if (context != nullptr) { + context->MapInstruction(this, clone.get()); + clone->ReplaceCalledComputations([&](HloComputation* callee) { + return callee->parent() != context->module() + ? context->module()->DeepCloneComputation(callee, context) + : callee; + }); + } return clone; } HloInstruction::~HloInstruction() {} -std::unique_ptr HloInstruction::Clone(const string& suffix, - HloModule* module) const { +std::unique_ptr HloInstruction::Clone( + const string& suffix, HloCloneContext* context) const { std::unique_ptr clone = - CloneWithNewOperands(shape_, operands_, module); + CloneWithNewOperands(shape_, operands_, context); if (suffix.empty()) { clone->name_ = name(); } else { @@ -1487,71 +1264,6 @@ std::unique_ptr HloInstruction::Clone(const string& suffix, return clone; } -std::unique_ptr HloInstruction::CloneFusionWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module) const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(parent() != nullptr); - - auto new_instruction = - WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - // Add the operands to our new fusion instruction. - for (HloInstruction* new_operand : operands) { - new_instruction->AppendOperand(new_operand); - } - // Clone all the fused instructions for the new fusion instruction. - HloInstructionMap old_to_new; - std::list> new_fused_instructions; - // Create the list of fused parameters by mapping through the cloned, - // fused instructions. - for (HloInstruction* old_fused_parameter : - fused_instructions_computation()->parameter_instructions()) { - new_fused_instructions.push_back( - old_fused_parameter->Clone("clone", module)); - HloInstruction* new_fusion_parameter = new_fused_instructions.back().get(); - InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter); - } - for (auto old_fused_instruction : - fused_instructions_computation()->MakeInstructionPostOrder()) { - if (old_fused_instruction->opcode() == HloOpcode::kParameter) { - FindOrDie(old_to_new, old_fused_instruction); - continue; - } - std::vector new_operands; - for (int64 operand_idx = 0; - operand_idx < old_fused_instruction->operand_count(); ++operand_idx) { - HloInstruction* old_operand = - old_fused_instruction->mutable_operand(operand_idx); - new_operands.push_back(FindOrDie(old_to_new, old_operand)); - } - new_fused_instructions.push_back( - old_fused_instruction->CloneWithNewOperands( - old_fused_instruction->shape(), new_operands, module)); - HloInstruction* new_fused_instruction = new_fused_instructions.back().get(); - new_fused_instruction->set_parent(parent_); - InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); - } - new_instruction->fusion_kind_ = fusion_kind_; - auto computation_builder = HloComputation::Builder( - fused_instructions_computation()->name() + ".clone", - new_instruction.get()); - // We iterated the fusion instructions in reverse post order which means - // that we must reverse our new list of fusion instructions. - for (auto new_fused_instruction_iter = new_fused_instructions.rbegin(); - new_fused_instruction_iter != new_fused_instructions.rend(); - ++new_fused_instruction_iter) { - computation_builder.AddInstruction(std::move(*new_fused_instruction_iter)); - } - if (module == nullptr) { - module = GetModule(); - } - auto fused_root_ = fused_expression_root(); - new_instruction->called_computations_.push_back( - CHECK_NOTNULL(module)->AddEmbeddedComputation( - computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); - return new_instruction; -} - std::pair HloInstruction::LatestNonGteAncestorAndIndex() const { const HloInstruction* hlo = this; @@ -1575,38 +1287,6 @@ const HloInstruction* HloInstruction::LatestNonGteAncestor() const { return hlo; } -const Literal& HloInstruction::literal() const { - CHECK_EQ(HloOpcode::kConstant, opcode_); - return *literal_; -} - -bool HloInstruction::CanHaveDimensionsField() const { - return (opcode() == HloOpcode::kReverse || - opcode() == HloOpcode::kConcatenate || - opcode() == HloOpcode::kReduce || opcode() == HloOpcode::kBroadcast || - opcode() == HloOpcode::kTranspose); -} - -const std::vector& HloInstruction::dimensions() const { - CHECK(CanHaveDimensionsField()); - return dimensions_; -} - -int64 HloInstruction::dimensions(int64 index) const { - return dimensions()[index]; -} - -int64 HloInstruction::concatenate_dimension() const { - CHECK(opcode() == HloOpcode::kConcatenate); - CHECK_EQ(1, dimensions_.size()); - return dimensions(0); -} - -int64 HloInstruction::tuple_index() const { - CHECK_EQ(HloOpcode::kGetTupleElement, opcode_); - return tuple_index_; -} - const HloInstruction* HloInstruction::operand(int64 i) const { return operands_[i]; } @@ -1625,6 +1305,17 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { LOG(FATAL) << "target was not an operand: " << target->ToString(); } +HloInstruction::InstructionVector HloInstruction::unique_operands() const { + InstructionVector unique; + tensorflow::gtl::FlatSet seen; + for (HloInstruction* operand : operands()) { + if (seen.insert(operand).second) { + unique.push_back(operand); + } + } + return unique; +} + Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { TF_RET_CHECK(instruction->parent() == parent()); if (std::find(control_successors_.begin(), control_successors_.end(), @@ -1639,14 +1330,35 @@ Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { } Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) { - auto succ_it = std::find(control_successors_.begin(), - control_successors_.end(), instruction); - TF_RET_CHECK(succ_it != control_successors_.end()); - control_successors_.erase(succ_it); - auto pred_it = std::find(instruction->control_predecessors_.begin(), - instruction->control_predecessors_.end(), this); - TF_RET_CHECK(pred_it != instruction->control_predecessors_.end()); - instruction->control_predecessors_.erase(pred_it); + TF_RET_CHECK(instruction->parent() == parent()); + TF_RETURN_IF_ERROR(EraseElementFromVector(&control_successors_, instruction)); + TF_RETURN_IF_ERROR( + EraseElementFromVector(&instruction->control_predecessors_, this)); + return Status::OK(); +} + +Status HloInstruction::DropAllControlDeps() { + for (auto* ctrl_succ : control_successors_) { + TF_RETURN_IF_ERROR( + EraseElementFromVector(&ctrl_succ->control_predecessors_, this)); + } + for (auto* ctrl_pred : control_predecessors_) { + TF_RETURN_IF_ERROR( + EraseElementFromVector(&ctrl_pred->control_successors_, this)); + } + control_successors_.clear(); + control_predecessors_.clear(); + return Status::OK(); +} + +Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) { + for (auto* ctrl_pred : inst->control_predecessors()) { + TF_RETURN_IF_ERROR(ctrl_pred->AddControlDependencyTo(this)); + } + + for (auto* ctrl_succ : inst->control_successors()) { + TF_RETURN_IF_ERROR(this->AddControlDependencyTo(ctrl_succ)); + } return Status::OK(); } @@ -1663,10 +1375,6 @@ void HloInstruction::AddUser(HloInstruction* user) { } } -bool HloInstruction::IsConstant() const { - return opcode_ == HloOpcode::kConstant; -} - bool HloInstruction::HasConstantOperand() const { for (const HloInstruction* operand : operands_) { if (operand->IsConstant()) { @@ -1679,25 +1387,29 @@ bool HloInstruction::HasConstantOperand() const { bool HloInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations, - const std::function& eq_shapes) const { + eq_computations) const { // Perform opcode specific checks. switch (opcode()) { // The result of these instructions only depend upon their opcode and // operands. case HloOpcode::kAbs: case HloOpcode::kAtan2: - case HloOpcode::kRoundNearestAfz: case HloOpcode::kAdd: + case HloOpcode::kBitcast: + case HloOpcode::kBitcastConvert: case HloOpcode::kCeil: case HloOpcode::kClamp: + case HloOpcode::kClz: case HloOpcode::kComplex: + case HloOpcode::kConvert: case HloOpcode::kCopy: case HloOpcode::kCos: - case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: @@ -1705,6 +1417,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kAnd: case HloOpcode::kNot: case HloOpcode::kOr: @@ -1717,6 +1430,8 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kPower: case HloOpcode::kReal: case HloOpcode::kRemainder: + case HloOpcode::kReshape: + case HloOpcode::kRoundNearestAfz: case HloOpcode::kSelect: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: @@ -1728,44 +1443,12 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kTuple: return true; - case HloOpcode::kFusion: - return fusion_kind() == other.fusion_kind() && - eq_computations(fused_instructions_computation(), - other.fused_instructions_computation()); - // These opcodes have complex or special behavior so just return false. - case HloOpcode::kRng: - case HloOpcode::kTrace: + case HloOpcode::kDomain: case HloOpcode::kWhile: + case HloOpcode::kGenerateToken: return false; - case HloOpcode::kParameter: - return parameter_number() == other.parameter_number() && - // Check the shape too because `this` and `other` may be in - // different HloComputations. - eq_shapes(shape(), other.shape()); - - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormInference: - case HloOpcode::kBatchNormGrad: - return feature_index() == other.feature_index() && - epsilon() == other.epsilon(); - - // A constant is defined by the value in the literal. - case HloOpcode::kConstant: - return literal() == other.literal(); - - // A convert result is determined by the primitive type that the operand is - // converted into. - case HloOpcode::kConvert: - case HloOpcode::kBitcastConvert: - return shape().element_type() == other.shape().element_type(); - - // A reduce-precision operation is determined by the bit sizes. - case HloOpcode::kReducePrecision: - return exponent_bits() == other.exponent_bits() && - mantissa_bits() == other.mantissa_bits(); - // Convolution has a window and dimensions. case HloOpcode::kConvolution: return protobuf_util::ProtobufEquals(window(), other.window()) && @@ -1782,16 +1465,6 @@ bool HloInstruction::IdenticalSlowPath( other.gather_dimension_numbers()) && gather_window_bounds() == other.gather_window_bounds(); - // FFT has various types & lengths. - case HloOpcode::kFft: - return fft_type() == other.fft_type() && - fft_length() == other.fft_length(); - - // Reduction results are determined by the reduction dimension and the - // reduction computation. - case HloOpcode::kReduce: - return dimensions() == other.dimensions() && - eq_computations(to_apply(), other.to_apply()); case HloOpcode::kReduceWindow: return eq_computations(to_apply(), other.to_apply()) && protobuf_util::ProtobufEquals(window(), other.window()); @@ -1803,43 +1476,30 @@ bool HloInstruction::IdenticalSlowPath( eq_computations(scatter(), other.scatter()) && protobuf_util::ProtobufEquals(window(), other.window()); - case HloOpcode::kReshape: - return eq_shapes(shape(), other.shape()); - - // Transpose result is determined by the final shape and the permutation. - case HloOpcode::kTranspose: - return eq_shapes(shape(), other.shape()) && - dimensions() == other.dimensions(); - // Remaining instructions with special values. - case HloOpcode::kBitcast: - return eq_shapes(shape(), other.shape()); - case HloOpcode::kBroadcast: - return eq_shapes(shape(), other.shape()) && - dimensions() == other.dimensions(); - case HloOpcode::kConcatenate: - return dimensions() == other.dimensions(); - case HloOpcode::kGetTupleElement: - return tuple_index() == other.tuple_index(); case HloOpcode::kPad: return protobuf_util::ProtobufEquals(padding_config(), other.padding_config()); - case HloOpcode::kSlice: - return slice_starts_ == other.slice_starts_ && - slice_limits_ == other.slice_limits_ && - slice_strides_ == other.slice_strides_; - case HloOpcode::kDynamicSlice: - return eq_shapes(shape(), other.shape()) && - dynamic_slice_sizes_ == other.dynamic_slice_sizes_; - case HloOpcode::kDynamicUpdateSlice: - return eq_shapes(shape(), other.shape()); case HloOpcode::kCall: - case HloOpcode::kMap: - return eq_computations(to_apply(), other.to_apply()); + case HloOpcode::kCrossReplicaSum: + return replica_group_ids() == other.replica_group_ids() && + cross_replica_sum_barrier() == other.cross_replica_sum_barrier() && + eq_computations(to_apply(), other.to_apply()); case HloOpcode::kCustomCall: + if ((window_ == nullptr) != (other.window_ == nullptr) || + (window_ != nullptr && + !protobuf_util::ProtobufEquals(window(), other.window()))) { + return false; + } + if ((convolution_dimension_numbers_ == nullptr) != + (other.convolution_dimension_numbers_ == nullptr) || + (convolution_dimension_numbers_ != nullptr && + !protobuf_util::ProtobufEquals( + convolution_dimension_numbers(), + other.convolution_dimension_numbers()))) { + return false; + } return custom_call_target_ == other.custom_call_target_; - case HloOpcode::kReverse: - return dimensions() == other.dimensions(); case HloOpcode::kConditional: return eq_computations(true_computation(), other.true_computation()) && eq_computations(false_computation(), other.false_computation()); @@ -1848,21 +1508,36 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kSort: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: - case HloOpcode::kSend: - case HloOpcode::kSendDone: case HloOpcode::kHostCompute: return false; - } -} -bool HloInstruction::IsRank2Transpose() const { - return (opcode_ == HloOpcode::kTranspose) && - dimensions_ == std::vector({1, 0}) && - shape_.dimensions_size() == 2 && - std::equal(shape_.dimensions().begin(), shape_.dimensions().end(), - operands_[0]->shape_.dimensions().rbegin()); + // Ops migrated to subclasses should never come to this line. + // TODO(b/80131774): Remove this switch when migration is complete. + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kFft: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReverse: + case HloOpcode::kConcatenate: + case HloOpcode::kReduce: + case HloOpcode::kTranspose: + case HloOpcode::kBroadcast: + case HloOpcode::kMap: + case HloOpcode::kSlice: + case HloOpcode::kConstant: + case HloOpcode::kTrace: + case HloOpcode::kFusion: + case HloOpcode::kRng: + case HloOpcode::kParameter: + case HloOpcode::kGetTupleElement: + case HloOpcode::kReducePrecision: + LOG(FATAL) << "Base class impl called for opcode with subclass: " + << opcode(); + } } void HloInstruction::RemoveUser(HloInstruction* user) { @@ -1968,6 +1643,7 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: + case HloOpcode::kCrossReplicaSum: CHECK_EQ(called_computations_.size(), 1); return called_computations_[0]; default: @@ -2099,83 +1775,158 @@ string PrintName(const string& name, const HloPrintOptions& options) { } // namespace string HloInstruction::ToString(const HloPrintOptions& options) const { - string result = - StrCat(PrintName(name(), options), " = ", - ShapeUtil::HumanStringWithLayout(shape()), " ", - HloOpcodeString(opcode()), "(", OperandsToString(options), ")"); + CanonicalNameMap new_map; + return ToStringWithCanonicalNameMap(options, &new_map); +} + +bool HloInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + switch (opcode_) { + // Unary elementwise operations. + case HloOpcode::kAbs: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kCeil: + case HloOpcode::kClz: + case HloOpcode::kConvert: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCopy: + case HloOpcode::kCos: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFloor: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kNot: + case HloOpcode::kNegate: + case HloOpcode::kReal: + case HloOpcode::kReducePrecision: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kTanh: + CHECK_EQ(1, operand_count()); + return true; + + // Binary elementwise operations, the same as in IsElementwiseBinary(). + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kComplex: + case HloOpcode::kDivide: + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSubtract: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + CHECK_EQ(2, operand_count()); + return true; + + // Ternary elementwise operations. + case HloOpcode::kSelect: + return !ShapeUtil::IsTuple(shape_); + case HloOpcode::kClamp: + return true; + + default: + return false; + } +} + +string HloInstruction::ToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { + string result = ""; + + // Logic to print the instruction name (e.g. "%foo = "). + if (options.canonicalize_instruction_names()) { + if (options.is_in_nested_computation()) { + // If we are canonicalizing instruction names and this is a top-level + // HloInstruction::ToString() call, don't print an instruction name. + StrAppend(&result, + PrintName(canonical_name_map->LookupOrInsert(name()), options), + " = "); + } + } else { + StrAppend(&result, PrintName(name(), options), " = "); + } + + // Print opcode, operand(s) and shape. + StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ", + HloOpcodeString(opcode()), "(", + OperandsToStringWithCanonicalNameMap(options, canonical_name_map), + ")"); + + // Print additional attributes. If an instruction contains a subcomputation, + // the subcomputation is also printed here. for (const string& extra : ExtraAttributesToString(options)) { StrAppend(&result, ", ", extra); } + if (options.print_metadata() && (!metadata_.op_type().empty() || !metadata_.op_name().empty() || !metadata_.source_file().empty())) { StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } + if (options.print_backend_config() && !backend_config_.empty()) { + StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\""); + } return result; } string HloInstruction::OperandsToString(const HloPrintOptions& options) const { + CanonicalNameMap new_map; + return OperandsToStringWithCanonicalNameMap(options, &new_map); +} + +string HloInstruction::OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { string operands; - if (opcode() == HloOpcode::kConstant) { - // For constants, show the actual value in place of an empty operand list. - if ((!ShapeUtil::IsTuple(shape()) && - ShapeUtil::ElementsIn(shape()) <= 10) || - options.print_large_constants()) { - // Literal::ToString emits multidimensional arrays over multiple - // lines. Compact this into one line by stripping out white space. - string tmp = literal().ToString(); - std::replace(tmp.begin(), tmp.end(), '\n', ' '); - std::vector v = tensorflow::str_util::Split(tmp, ' '); - bool first = true; - // Concatenate elements in "v" with spaces separating them, but ignoring - // empty entries. - for (const auto& s : v) { - if (s.empty()) { - continue; - } - StrAppend(&operands, (first ? "" : " "), s); - first = false; - } - } else { - // Do not show large constants or tuples. - operands = "{...}"; - } - } else if (opcode() == HloOpcode::kParameter) { - StrAppend(&operands, parameter_number_); - } else { - tensorflow::gtl::ArraySlice slice(operands_); - const int64 kMaxOperandsToShowIfCompact = 4; - if (options.compact_operands() && - slice.size() > kMaxOperandsToShowIfCompact) { - slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); + tensorflow::gtl::ArraySlice slice(operands_); + const int64 kMaxOperandsToShowIfCompact = 4; + if (options.compact_operands() && + slice.size() > kMaxOperandsToShowIfCompact) { + slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); + } + operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { + std::vector str; + if (options.print_operand_shape()) { + str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); } - operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { - std::vector str; - if (options.print_operand_shape()) { - str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); - } - if (!options.compact_operands()) { - str.push_back(PrintName(operand->name(), options)); - } - StrAppend(out, Join(str, " ")); - }); - const int64 remaining = operands_.size() - slice.size(); - if (slice.size() != operands_.size()) { - StrAppend(&operands, ", ...(+", remaining, ")"); + + // In a top-level HloInstruction::ToString() call, the operand name is not + // part of the canonical string. + if (options.canonicalize_instruction_names() && + options.is_in_nested_computation()) { + str.push_back(PrintName( + canonical_name_map->LookupOrInsert(operand->name()), options)); + } else if (!options.compact_operands()) { + str.push_back(PrintName(operand->name(), options)); } + StrAppend(out, Join(str, " ")); + }); + const int64 remaining = operands_.size() - slice.size(); + if (slice.size() != operands_.size()) { + StrAppend(&operands, ", ...(+", remaining, ")"); } return operands; } std::vector HloInstruction::ExtraAttributesToString( const HloPrintOptions& options) const { - std::vector extra; - if (opcode() == HloOpcode::kFusion) { - extra.push_back(StrCat("kind=", xla::ToString(fusion_kind()))); - } - if (CanHaveDimensionsField()) { - extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}")); - } + std::vector extra = ExtraAttributesToStringImpl(options); if (window_ != nullptr && window_->dimensions_size() != 0) { extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); } @@ -2183,32 +1934,16 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("padding=", xla::PaddingConfigToString(*padding_config_))); } - if (opcode() == HloOpcode::kSlice) { - std::vector bounds; - bounds.reserve(slice_starts_.size()); - const bool omit_stride = - std::all_of(slice_strides_.begin(), slice_strides_.end(), - [](int64 stride) { return stride == 1; }); - for (int i = 0; i < slice_starts_.size(); ++i) { - string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]); - bounds.push_back(StrCat("[", slice_starts_[i], ":", slice_limits_[i], - stride_str, "]")); - } - extra.push_back(StrCat("slice={", Join(bounds, ", "), "}")); - } + if (opcode() == HloOpcode::kDynamicSlice) { extra.push_back( StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")); } - if (opcode() == HloOpcode::kBatchNormTraining || - opcode() == HloOpcode::kBatchNormInference || - opcode() == HloOpcode::kBatchNormGrad) { - extra.push_back(StrCat("epsilon=", epsilon())); - extra.push_back(StrCat("feature_index=", feature_index())); - } if (convolution_dimension_numbers_ != nullptr) { - extra.push_back(ConvolutionDimensionNumbersToString()); + extra.push_back(StrCat( + "dim_labels=", + ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); } if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); @@ -2218,12 +1953,9 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")); } - if (opcode() == HloOpcode::kFft) { - extra.push_back(StrCat("fft_type=", FftType_Name(fft_type()))); - extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}")); - } - if (options.print_subcomputation_references()) { + if (options.print_subcomputation_mode() == + HloPrintOptions::PrintSubcomputationMode::kNameOnly) { if (opcode() == HloOpcode::kWhile) { extra.push_back( StrCat("condition=", PrintName(while_condition()->name(), options))); @@ -2240,7 +1972,8 @@ std::vector HloInstruction::ExtraAttributesToString( PrintName(false_computation()->name(), options))); } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || - opcode() == HloOpcode::kReduce) { + opcode() == HloOpcode::kReduce || + opcode() == HloOpcode::kCrossReplicaSum) { extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { @@ -2251,16 +1984,46 @@ std::vector HloInstruction::ExtraAttributesToString( PrintName(computation->name(), options)); }))); } + } else if (options.print_subcomputation_mode() == + HloPrintOptions::PrintSubcomputationMode::kFullBodies) { + HloPrintOptions new_options = options; + new_options.set_is_in_nested_computation(true); + switch (opcode()) { + case HloOpcode::kWhile: + extra.push_back( + StrCat("condition=\n", while_condition()->ToString(new_options))); + extra.push_back(StrCat("body=\n", while_body()->ToString(new_options))); + break; + case HloOpcode::kSelectAndScatter: + extra.push_back(StrCat("select=\n", select()->ToString(new_options))); + extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options))); + break; + case HloOpcode::kConditional: + extra.push_back(StrCat("true_computation=\n", + true_computation()->ToString(new_options))); + extra.push_back(StrCat("false_computation=\n", + false_computation()->ToString(new_options))); + break; + case HloOpcode::kCall: + case HloOpcode::kMap: + case HloOpcode::kReduceWindow: + case HloOpcode::kReduce: + extra.push_back( + StrCat("to_apply=\n", to_apply()->ToString(new_options))); + break; + default: + if (!called_computations().empty()) { + extra.push_back( + StrCat("calls=\n", + Join(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, computation->ToString(new_options)); + }))); + } + break; + } } - if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv || - opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) { - extra.push_back(StrCat("channel_id=", channel_id_)); - } - - if (opcode() == HloOpcode::kGetTupleElement) { - extra.push_back(StrCat("index=", tuple_index())); - } if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } @@ -2280,22 +2043,27 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")); } - if (opcode() == HloOpcode::kRng) { + if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { + extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), + "\", entry=", operand_side_metadata_->ToString(), + ", exit=", user_side_metadata_->ToString(), "}")); + } + if (!replica_group_ids().empty()) { extra.push_back( - StrCat("distribution=", RandomDistributionToString(distribution_))); + StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")); } - if (opcode() == HloOpcode::kReducePrecision) { - extra.push_back(StrCat("exponent_bits=", exponent_bits_)); - extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); + if (!cross_replica_sum_barrier().empty()) { + extra.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); } // By contract, we print the custom call target even if - // !options.print_subcomputation_references(), because the call target is not + // options.print_subcomputation_mode() == kOff, because the call target is not // an HloComputation. if (opcode() == HloOpcode::kCustomCall) { extra.push_back( StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); } + return extra; } @@ -2310,35 +2078,28 @@ string HloInstruction::ToShortString() const { HloInstructionProto HloInstruction::ToProto() const { HloInstructionProto proto; + CHECK(unique_id_ != -1) + << "This instruction does not have a valid id. Please make sure the " + "instruction is inside a module before dumping it."; + proto.set_id(unique_id_); proto.set_name(name_); proto.set_opcode(HloOpcodeString(opcode_)); *proto.mutable_shape() = shape_; for (const HloInstruction* operand : operands_) { - *proto.add_operand_names() = operand->name(); + proto.add_operand_ids(operand->unique_id()); } for (const HloInstruction* control : control_predecessors_) { - *proto.add_control_predecessor_names() = control->name(); + proto.add_control_predecessor_ids(control->unique_id()); } *proto.mutable_metadata() = metadata_; - if (literal_ != nullptr) { - *proto.mutable_literal() = literal_->ToProto(); - } - proto.set_parameter_number(parameter_number_); - if (opcode() == HloOpcode::kFusion) { - proto.set_fusion_kind(xla::ToString(fusion_kind())); - *proto.mutable_fused_instructions_computation() = - fused_instructions_computation()->ToProto(); - } else { + proto.set_backend_config(backend_config_); + if (opcode() != HloOpcode::kFusion) { for (const HloComputation* computation : called_computations_) { - *proto.add_called_computation_names() = computation->name(); + proto.add_called_computation_ids(computation->unique_id()); } } - proto.set_tuple_index(tuple_index_); - for (int64 dimension : dimensions_) { - proto.add_dimensions(dimension); - } if (window_ != nullptr) { *proto.mutable_window() = *window_; } @@ -2357,14 +2118,7 @@ HloInstructionProto HloInstruction::ToProto() const { proto.add_gather_window_bounds(bound); } } - for (int i = 0; i < slice_starts_.size(); ++i) { - auto* slice_dimension = proto.add_slice_dimensions(); - slice_dimension->set_start(slice_starts_[i]); - slice_dimension->set_limit(slice_limits_[i]); - slice_dimension->set_stride(slice_strides_[i]); - } - proto.set_exponent_bits(exponent_bits_); - proto.set_mantissa_bits(mantissa_bits_); + for (int64 slice_size : dynamic_slice_sizes_) { proto.add_dynamic_slice_sizes(slice_size); } @@ -2372,18 +2126,18 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_padding_config() = *padding_config_; } proto.set_outfeed_config(outfeed_config_); - if (opcode() == HloOpcode::kRng) { - proto.set_distribution(distribution_); - } - proto.set_epsilon(epsilon_); - proto.set_feature_index(feature_index_); - proto.set_channel_id(channel_id_); proto.set_infeed_config(infeed_config_); proto.set_custom_call_target(custom_call_target_); *proto.mutable_outfeed_shape() = outfeed_shape_; - proto.set_fft_type(fft_type_); - for (int64 fft_len : fft_length_) { - proto.add_fft_length(fft_len); + + if (has_sharding()) { + *proto.mutable_sharding() = sharding().ToProto(); + } + + proto.set_channel_name(channel_name_); + proto.set_cost_estimate_ns(cost_estimate_ns_); + for (int64 replica_group_id : replica_group_ids_) { + proto.add_replica_group_ids(replica_group_id); } return proto; @@ -2419,8 +2173,6 @@ string HloInstruction::ToCategory() const { return "input fusion"; case FusionKind::kOutput: return "output fusion"; - case FusionKind::kTransposeDot: - return "dot"; case FusionKind::kCustom: return "custom fusion"; } @@ -2439,70 +2191,22 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) { trace_instruction_ = trace_instruction; } -string HloInstruction::TracingTag() const { - CHECK_EQ(HloOpcode::kTrace, opcode()); - CHECK(literal_ != nullptr); - return literal_->GetR1U8AsString(); -} - bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); } bool HloInstruction::IsFusable() const { // Instructions which are traced should not be fused. if (tracing()) { return false; - } - // Some kinds of instructions don't make sense to fuse. - switch (opcode_) { - case HloOpcode::kParameter: - return false; - // Side effecting instrutions cannot be fused. - default: - return !HasSideEffect(); - } -} - -HloComputation* HloInstruction::fused_instructions_computation() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(!called_computations_.empty()); - auto* fused_instructions_computation = called_computations_.front(); - CHECK(fused_instructions_computation->IsFusionComputation()); - return fused_instructions_computation; -} - -HloInstruction* HloInstruction::fused_expression_root() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->root_instruction(); -} - -HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->parameter_instruction( - parameter_number); -} - -const std::vector& HloInstruction::fused_parameters() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->parameter_instructions(); -} - -const tensorflow::gtl::iterator_range>::const_iterator>> -HloInstruction::fused_instructions() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - const HloComputation* subcomp = fused_instructions_computation(); - return subcomp->instructions(); -} - -const tensorflow::gtl::iterator_range< - UnwrappingIterator>::iterator>> -HloInstruction::fused_instructions() { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->instructions(); -} - -int64 HloInstruction::fused_instruction_count() const { - return fused_instructions_computation()->instruction_count(); + } + // Some kinds of instructions don't make sense to fuse. + switch (opcode_) { + case HloOpcode::kDomain: + case HloOpcode::kParameter: + return false; + // Side effecting instrutions cannot be fused. + default: + return !HasSideEffect(); + } } HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) @@ -2605,12 +2309,18 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleNegate(this); case HloOpcode::kExp: return visitor->HandleExp(this); + case HloOpcode::kExpm1: + return visitor->HandleExpm1(this); case HloOpcode::kFloor: return visitor->HandleFloor(this); case HloOpcode::kCeil: return visitor->HandleCeil(this); + case HloOpcode::kClz: + return visitor->HandleClz(this); case HloOpcode::kLog: return visitor->HandleLog(this); + case HloOpcode::kLog1p: + return visitor->HandleLog1p(this); case HloOpcode::kTanh: return visitor->HandleTanh(this); case HloOpcode::kCos: @@ -2675,13 +2385,19 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSendDone(this); case HloOpcode::kGather: return visitor->HandleGather(this); + case HloOpcode::kDomain: + return visitor->HandleDomain(this); + case HloOpcode::kGenerateToken: + return visitor->HandleGenerateToken(this); // These opcodes are not handled here. case HloOpcode::kTrace: break; } - return Unimplemented("unhandled HloOpcode for DfsHloVisitor: %s", - HloOpcodeString(opcode_).c_str()); + return InternalError( + "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - " + "please file a bug for XLA.", + HloOpcodeString(opcode_).c_str()); } // Explicit instantiations. @@ -2899,6 +2615,7 @@ Status HloInstruction::AcceptOrdered( continue; } + // TODO(b/78350259): Eliminate const laundering. HloInstruction* instruction = const_cast(const_instruction); @@ -2939,139 +2656,16 @@ bool HloInstruction::IsElementwiseBinary() const { } bool HloInstruction::IsElementwise() const { - switch (opcode_) { - // Nullary elementwise operations. - case HloOpcode::kConstant: - return true; - - // Unary elementwise operations. - case HloOpcode::kAbs: - case HloOpcode::kRoundNearestAfz: - case HloOpcode::kCeil: - case HloOpcode::kConvert: - case HloOpcode::kBitcastConvert: - case HloOpcode::kCopy: - case HloOpcode::kCos: - case HloOpcode::kExp: - case HloOpcode::kFloor: - case HloOpcode::kImag: - case HloOpcode::kIsFinite: - case HloOpcode::kLog: - case HloOpcode::kNot: - case HloOpcode::kNegate: - case HloOpcode::kReal: - case HloOpcode::kReducePrecision: - case HloOpcode::kSign: - case HloOpcode::kSin: - case HloOpcode::kTanh: - CHECK_EQ(1, operand_count()); - return true; - - // Binary elementwise operations, the same as in IsElementwiseBinary(). - case HloOpcode::kAdd: - case HloOpcode::kAtan2: - case HloOpcode::kComplex: - case HloOpcode::kDivide: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kMaximum: - case HloOpcode::kMinimum: - case HloOpcode::kMultiply: - case HloOpcode::kNe: - case HloOpcode::kPower: - case HloOpcode::kRemainder: - case HloOpcode::kSubtract: - case HloOpcode::kAnd: - case HloOpcode::kOr: - case HloOpcode::kShiftLeft: - case HloOpcode::kShiftRightArithmetic: - case HloOpcode::kShiftRightLogical: - CHECK_EQ(2, operand_count()); - return true; - - // Ternary elementwise operations. - case HloOpcode::kSelect: - return !ShapeUtil::IsTuple(shape_); - case HloOpcode::kClamp: - return true; - - // Other operations. - case HloOpcode::kRng: - case HloOpcode::kMap: - return true; - case HloOpcode::kFusion: - if (fusion_kind() != FusionKind::kLoop) { - return false; - } - for (auto* fused : fused_instructions()) { - if (fused->opcode() != HloOpcode::kParameter && - !fused->IsElementwise()) { - return false; - } - } - return true; - - default: - return false; - } + return IsElementwiseImpl(tensorflow::gtl::nullopt); } bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { CHECK(IsElementwise()); - return !ShapeUtil::Equal(shape(), operand(operand_idx)->shape()); + return !ShapeUtil::SameDimensions(shape(), operand(operand_idx)->shape()); } -namespace { -bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, - const HloInstruction* operand) { - std::vector operand_indices = instruction->OperandIndices(operand); - return std::all_of( - operand_indices.begin(), operand_indices.end(), - [instruction](int64 operand_index) { - return instruction->IsElementwiseOnOperand(operand_index); - }); -} -} // namespace - bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const { - // For all instructions other than kFusion, being elementwise on one of the - // operands is equivalent to being elementwise on all the operands. - if (opcode() != HloOpcode::kFusion) { - return IsElementwise(); - } - - CHECK_EQ(HloOpcode::kFusion, opcode()); - if (fusion_kind() != FusionKind::kLoop) { - return false; - } - - // A loop-fusion is elementwise on an operand if all operations (computed - // using BFS) between the operand and the fused root are elementwise. - std::deque worklist; - std::unordered_set visited; - worklist.push_back(fused_parameter(operand_idx)); - visited.insert(fused_parameter(operand_idx)); - while (!worklist.empty()) { - HloInstruction* operand = worklist.front(); - worklist.pop_front(); - for (HloInstruction* user : operand->users()) { - CHECK_GE(user->unique_id(), 0); - if (ContainsKey(visited, user)) { - continue; - } - if (user->IsElementwise() || - IsInstructionElementwiseOnOperand(user, operand)) { - worklist.push_back(user); - visited.insert(user); - } else { - return false; - } - } - } - return true; + return IsElementwiseImpl(operand_idx); } // A helper class for memoized, recursive computation of HloOpcode::kFusion @@ -3093,8 +2687,10 @@ class HloInstruction::FusionReusesParamElements { static UseKind ComputeInternal( int64 i, const HloInstruction& hlo, tensorflow::gtl::FlatMap* cache) { - if (hlo.opcode_ == HloOpcode::kParameter && hlo.parameter_number_ == i) { - return UseKind::kUse; + if (auto hlo_param = DynCast(&hlo)) { + if (hlo_param->parameter_number() == i) { + return UseKind::kUse; + } } auto p = cache->emplace(&hlo, UseKind{}); @@ -3197,8 +2793,6 @@ string ToString(HloInstruction::FusionKind kind) { return "kInput"; case HloInstruction::FusionKind::kOutput: return "kOutput"; - case HloInstruction::FusionKind::kTransposeDot: - return "kTransposeDot"; case HloInstruction::FusionKind::kCustom: return "kCustom"; } @@ -3215,9 +2809,6 @@ StatusOr StringToFusionKind( if (kind_name == "kOutput") { return HloInstruction::FusionKind::kOutput; } - if (kind_name == "kTransposeDot") { - return HloInstruction::FusionKind::kTransposeDot; - } if (kind_name == "kCustom") { return HloInstruction::FusionKind::kCustom; } @@ -3261,42 +2852,8 @@ string RandomDistributionToString(const RandomDistribution& distribution) { return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); } -StatusOr StringToRandomDistribution(const string& name) { - static std::unordered_map* map = [] { - static auto* map = new std::unordered_map; - for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { - if (RandomDistribution_IsValid(i)) { - auto value = static_cast(i); - (*map)[RandomDistributionToString(value)] = value; - } - } - return map; - }(); - auto found = map->find(tensorflow::str_util::Lowercase(name)); - if (found == map->end()) { - return InvalidArgument("Unknown distribution"); - } - return found->second; -} - -std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { - return os << ToString(kind); -} - -string HloInstruction::ConvolutionDimensionNumbersToString() const { - string result; - if (convolution_dimension_numbers_ == nullptr) { - return result; - } - const ConvolutionDimensionNumbers& dnums = *convolution_dimension_numbers_; - // Show the given dimension labels in order of major to minor based on the - // shape's layout. - const auto append_dims = [&](const std::vector& dims, - const Shape& shape) { - CHECK_EQ(dims.size(), ShapeUtil::Rank(shape)); - StrAppend(&result, Join(dims, "")); - }; - +string ConvolutionDimensionNumbersToString( + const ConvolutionDimensionNumbers& dnums) { // lhs_dims[i] is the symbol of the logical dimension i for the lhs // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b". std::vector lhs_dims(2 + dnums.input_spatial_dimensions().size()); @@ -3320,19 +2877,8 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i); } - result += "dim_labels="; - append_dims(lhs_dims, operand(0)->shape()); - result += "_"; - append_dims(rhs_dims, operand(1)->shape()); - result += "->"; - - // A convolution can be represented as a kConvolution HLO or as a CustomCall - // that returns a tuple, the first element of which is the result of the - // convolution. - Shape this_shape = - ShapeUtil::IsTuple(shape()) ? shape().tuple_shapes(0) : shape(); - append_dims(output_dims, this_shape); - return result; + return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->", + Join(output_dims, "")); } string HloInstruction::DotDimensionNumbersToString() const { @@ -3358,6 +2904,28 @@ string HloInstruction::DotDimensionNumbersToString() const { return Join(result, ", "); } +StatusOr StringToRandomDistribution(const string& name) { + static std::unordered_map* map = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { + if (RandomDistribution_IsValid(i)) { + auto value = static_cast(i); + (*map)[RandomDistributionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(tensorflow::str_util::Lowercase(name)); + if (found == map->end()) { + return InvalidArgument("Unknown distribution"); + } + return found->second; +} + +std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { + return os << ToString(kind); +} + string HloInstruction::GatherDimensionNumbersToString() const { CHECK_NE(gather_dimension_numbers_.get(), nullptr); string output_window_dims = @@ -3369,9 +2937,12 @@ string HloInstruction::GatherDimensionNumbersToString() const { string gather_dims_to_operand_dims = StrCat( "gather_dims_to_operand_dims={", Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); return Join>( - {output_window_dims, elided_window_dims, gather_dims_to_operand_dims}, + {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, + index_vector_dim}, ", "); } @@ -3386,6 +2957,31 @@ bool HloInstruction::CouldBeBitcast() const { } } +Status HloInstruction::GetBackendConfigInternal( + tensorflow::protobuf::Message* proto) const { + proto->Clear(); + + // Empty string does not parse as valid JSON, but it's a valid backend config, + // corresponding to the empty proto. + if (backend_config_.empty()) { + return Status::OK(); + } + return tensorflow::HumanReadableJsonToProto(backend_config_, proto); +} + +Status HloInstruction::set_backend_config( + const tensorflow::protobuf::Message& proto) { + TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto)); + return Status::OK(); +} + +/* static */ StatusOr HloInstruction::BackendConfigToRawString( + const tensorflow::protobuf::Message& proto) { + string ret; + TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(proto, &ret)); + return ret; +} + HloModule* HloInstruction::GetModule() const { if (parent_) { return parent_->parent(); @@ -3403,21 +2999,173 @@ void HloInstruction::set_outer_dimension_partitions( outer_dimension_partitions_ = outer_dimension_partitions; } +// TODO(b/80131774): Remove these temporary methods after transition. +int64 HloInstruction::feature_index() const { + return Cast(this)->feature_index(); +} + +float HloInstruction::epsilon() const { + return Cast(this)->epsilon(); +} + +FftType HloInstruction::fft_type() const { + return Cast(this)->fft_type(); +} + +const std::vector& HloInstruction::fft_length() const { + return Cast(this)->fft_length(); +} + +int64 HloInstruction::channel_id() const { + return Cast(this)->channel_id(); +} + +int64 HloInstruction::concatenate_dimension() const { + return Cast(this)->concatenate_dimension(); +} + +bool HloInstruction::IsRank2Transpose() const { + auto transpose = DynCast(this); + return transpose != nullptr && transpose->IsRank2Transpose(); +} + +int64 HloInstruction::slice_starts(int64 dimension) const { + return Cast(this)->slice_starts(dimension); +} + +const std::vector& HloInstruction::slice_starts() const { + return Cast(this)->slice_starts(); +} + +int64 HloInstruction::slice_limits(int64 dimension) const { + return Cast(this)->slice_limits(dimension); +} + +const std::vector& HloInstruction::slice_limits() const { + return Cast(this)->slice_limits(); +} + +int64 HloInstruction::slice_strides(int64 dimension) const { + return Cast(this)->slice_strides(dimension); +} + +const std::vector& HloInstruction::slice_strides() const { + return Cast(this)->slice_strides(); +} + +bool HloInstruction::IsInPlaceSlice() const { + return Cast(this)->IsInPlaceSlice(); +} + +const Literal& HloInstruction::literal() const { + return Cast(this)->literal(); +} + +bool HloInstruction::IsConstant() const { + return DynCast(this) != nullptr; +} + void HloInstruction::RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index) { - CHECK_EQ(opcode(), HloOpcode::kConstant); - Shape* mutable_array_subshape = - ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index); - CHECK(ShapeUtil::IsArray(*mutable_array_subshape)); + Cast(this)->RelayoutConstant(new_layout, shape_index); +} + +string HloInstruction::TracingTag() const { + return Cast(this)->TracingTag(); +} - // Normally array_subshape will always have a layout, but this invariant is - // temporarily broken in LayoutAssignment::AssignLayouts. +HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) { + return Cast(this)->AddFusionOperand(new_operand); +} - if (!mutable_array_subshape->has_layout() || - !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { - literal_ = literal_->Relayout(new_layout, shape_index); - *mutable_array_subshape->mutable_layout() = new_layout; - } +// Delegates to HloFusionInstruction::MergeFusionInstruction. +void HloInstruction::MergeFusionInstruction( + HloInstruction* instruction_to_merge) { + return Cast(this)->MergeFusionInstruction( + Cast(instruction_to_merge)); +} + +// Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput. +void HloInstruction::MergeFusionInstructionIntoMultiOutput( + HloInstruction* instruction_to_merge) { + return Cast(this) + ->MergeFusionInstructionIntoMultiOutput( + Cast(instruction_to_merge)); +} + +HloInstruction* HloInstruction::FuseInstruction( + HloInstruction* instruction_to_fuse) { + return Cast(this)->FuseInstruction(instruction_to_fuse); +} + +HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput( + HloInstruction* instruction_to_fuse) { + return Cast(this)->FuseInstructionIntoMultiOutput( + instruction_to_fuse); +} + +HloComputation* HloInstruction::fused_instructions_computation() const { + return Cast(this)->fused_instructions_computation(); +} + +HloInstruction* HloInstruction::fused_expression_root() const { + return Cast(this)->fused_expression_root(); +} + +const tensorflow::gtl::iterator_range>::const_iterator>> +HloInstruction::fused_instructions() const { + return Cast(this)->fused_instructions(); +} + +const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> +HloInstruction::fused_instructions() { + return Cast(this)->fused_instructions(); +} + +int64 HloInstruction::fused_instruction_count() const { + return Cast(this)->fused_instruction_count(); +} + +HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const { + return Cast(this)->fused_parameter(parameter_number); +} + +const std::vector& HloInstruction::fused_parameters() const { + return Cast(this)->fused_parameters(); +} + +const bool HloInstruction::IsMultiOutputFusion() const { + const HloFusionInstruction* fusion = DynCast(this); + return fusion != nullptr && fusion->IsMultiOutputFusion(); +} + +HloInstruction::FusionKind HloInstruction::fusion_kind() const { + return Cast(this)->fusion_kind(); +} + +void HloInstruction::set_fusion_kind(FusionKind kind) { + return Cast(this)->set_fusion_kind(kind); +} + +RandomDistribution HloInstruction::random_distribution() const { + return Cast(this)->random_distribution(); } +int64 HloInstruction::parameter_number() const { + return Cast(this)->parameter_number(); +} + +int64 HloInstruction::tuple_index() const { + return Cast(this)->tuple_index(); +} + +int32 HloInstruction::exponent_bits() const { + return Cast(this)->exponent_bits(); +} + +int32 HloInstruction::mantissa_bits() const { + return Cast(this)->mantissa_bits(); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index c4fe132d1d52d6071914869cd50a035ace3389b2..a206cdab2739cd5046295179ead5d8bf19d521c4 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -37,6 +37,8 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -50,6 +52,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -60,51 +63,75 @@ class HloModule; // A bunch of switches that control how the hlo text should be printed. class HloPrintOptions { public: + enum class PrintSubcomputationMode { + kOff, // Do not print anything about subcomputations. + kNameOnly, // Only print the name of subcomputations. + kFullBodies, // Print the full bodies of subcomputations. + }; + // Constructs the default print options: don't print large constants, don't // compact operands, no indentation. HloPrintOptions() : print_large_constants_(false), - print_subcomputation_references_(true), + print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly), print_metadata_(true), + print_backend_config_(true), compact_operands_(false), print_operand_shape_(true), print_program_shape_(true), print_percent_(true), - indent_amount_(0) {} + canonicalize_instruction_names_(false), + indent_amount_(0), + is_in_nested_computation_(false) {} static HloPrintOptions ShortParsable() { return HloPrintOptions() .set_print_large_constants(true) - .set_print_subcomputation_references(true) + .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly) .set_print_metadata(false) + .set_print_backend_config(false) .set_print_operand_shape(false) .set_print_program_shape(false) .set_print_percent(false); } + // Options to produce the canonical string representing an isomorphic + // computation graph. + static HloPrintOptions Canonical() { + return HloPrintOptions() + .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies) + .set_print_metadata(false) + .set_compact_operands(true) + .set_print_operand_shape(true) + .set_print_program_shape(false) + .set_print_percent(false) + .set_canonicalize_instruction_names(true); + } + // If true, large constants will be printed out. HloPrintOptions& set_print_large_constants(bool value) { print_large_constants_ = value; return *this; } - // If true, the names of subcomputations (e.g. a fusion node's fused - // computation) won't be printed. This makes the resulting text not parsable. - // - // A CustomCall's call target is printed even if - // print_subcomputation_references is false, because the call target isn't an - // HloComputation. - HloPrintOptions& set_print_subcomputation_references(bool value) { - print_subcomputation_references_ = value; + HloPrintOptions& set_print_subcomputation_mode( + PrintSubcomputationMode value) { + print_subcomputation_mode_ = value; return *this; } - // If true, metatdata will be printed. + // If true, metadata will be printed. HloPrintOptions& set_print_metadata(bool value) { print_metadata_ = value; return *this; } + // If true, backend_config will be printed. + HloPrintOptions& set_print_backend_config(bool value) { + print_backend_config_ = value; + return *this; + } + // If true, operands' shapes will be printed. HloPrintOptions& set_print_operand_shape(bool value) { print_operand_shape_ = value; @@ -130,69 +157,185 @@ class HloPrintOptions { return *this; } + // If true, canonicalizes instructions' name. Instead of using "%foo.1" as + // the name of an instruction, we use "%tmp_1", "%tmp_2" etc. + HloPrintOptions& set_canonicalize_instruction_names(bool value) { + canonicalize_instruction_names_ = value; + return *this; + } + // The indent of the hlo text block. HloPrintOptions& set_indent_amount(int value) { indent_amount_ = value; return *this; } + // If true, indicates the instruction being printed is inside a nested + // computation. + HloPrintOptions& set_is_in_nested_computation(bool value) { + is_in_nested_computation_ = value; + return *this; + } + bool print_large_constants() const { return print_large_constants_; } - bool print_subcomputation_references() const { - return print_subcomputation_references_; + PrintSubcomputationMode print_subcomputation_mode() const { + return print_subcomputation_mode_; } bool print_metadata() const { return print_metadata_; } + bool print_backend_config() const { return print_metadata_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } bool print_percent() const { return print_percent_; } + bool canonicalize_instruction_names() const { + return canonicalize_instruction_names_; + } int indent_amount() const { return indent_amount_; } + int is_in_nested_computation() const { return is_in_nested_computation_; } private: bool print_large_constants_; - bool print_subcomputation_references_; + PrintSubcomputationMode print_subcomputation_mode_; bool print_metadata_; + bool print_backend_config_; bool compact_operands_; bool print_operand_shape_; bool print_program_shape_; bool print_percent_; + bool canonicalize_instruction_names_; int indent_amount_; + bool is_in_nested_computation_; +}; + +// For canonical string output, we need to have a canonical way to rename +// each instruction and its operands. Each operand is renamed as "tmp_", +// where is an index starting from 0. +class CanonicalNameMap { + public: + CanonicalNameMap() : index(0) {} + + string LookupOrInsert(const string& old_name) { + auto iter = canonical_name_map.find(old_name); + if (iter != canonical_name_map.end()) { + return iter->second; + } + + string new_name = tensorflow::strings::StrCat("tmp_", index++); + canonical_name_map[old_name] = new_name; + return new_name; + } + void Clear() { + canonical_name_map.clear(); + index = 0; + } + + private: + int64 index; + tensorflow::gtl::FlatMap canonical_name_map; }; -// HLO instructions are the IR used by the high-level compiler. +// HLO instructions are the atomic unit of the high-level compiler's IR. +// +// HloInstructions live inside of an HloComputation, which is analogous to a +// function in other programming languages. Nodes have no total order within +// their computation. Instead, they have a partial ordering determined by their +// data and control dependencies. +// +// HLO does not have basic blocks or explicit "branch" instructions. Instead, +// certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode +// control flow. For example, the kConditional HLO executes one of two possible +// computations, depending on the runtime value of a predicate. +// +// HLO is pure (mostly). It has no concept of mutable state. Instead, data +// values are produced by one HLO and flow into consumers across dependency +// edges. class HloInstruction { public: + // A fusion node computes the same value a call to its fusion computation + // would compute. However, the choice of fusion kind dictates codegen + // strategy for the backend. + // + // To generate code for a kFusion HloInstruction, most backends do something + // like the following: + // + // 1) Identify the "primary" HloInstruction of the fused computation. + // 2) Emit code that does the work of the primary node, creating its inputs + // and transforming its outputs as specified by the fused computation. + // + // In step (2), the code emitted is usually similar to the code that would be + // emitted for an *unfused* version of the primary node, except that + // + // - when the primary node reads an element of one of its operands, instead + // of loading the value from memory, it *computes* the value based on the + // contents of the fused computation. + // - when the primary node outputs a value, instead of storing it to memory, + // it forwards the value to its users, which then perform additional + // computations before the value is finally stored to memory at the root of + // the fusion node. + // + // An HloInstruction's FusionKind helps us find the kFusion instruction's + // primary node, and can also affect how we generate code in step (2). + // + // - kInput: The primary node is the root of the fused instruction. + // + // - kOutput: The primary node is not the root of the fused instruction. + // This fusion kind requires that one operand buffer of the fusion + // instruction be able to alias the output buffer. This constraint is + // usually enough to let backends find the primary node unambiguously. + // + // - kLoop: The primary node is the root of the fused computation, but, + // unlike in input fusion, we prescribe a specific implementation for + // codegen. Rather than generating code that looks like the code we'd emit + // for an unfused version of the primary/root node, we emit code that + // generates one element of the root at a time. + // + // - kCustom: Custom category for backend-specific fusions that don't fit + // into the above patterns. + // + // Not all backends support all fusion kinds, and given a particular fused + // computation, it's not in general safe to change its fusion kind. Creation + // of fusion nodes is always backend-specific. + // + // For elementwise ops (e.g. kAdd), most backends would emit a + // one-element-at-a-time implementation for the unfused version, so loop + // fusion and input fusion are probably equivalent if the root node is + // elementwise. They're not necessarily equivalent e.g. for kReduce, where an + // implementation might emit something more sophisticated for an unfused or + // input-fusion reduce, but will emit the naive code that reduces one element + // at a time for loop fusion with a reduce as the root. + // + // Another way to think of loop fusion is that it's equivalent to input + // fusion, but where the root node is an implicit identity node, whose + // unfused implementation is "read one element, write one element". + // + // TODO(b/79869434): This categorization scheme is not great. For one thing, + // input and loop fusion are basically the same thing: There is no reason for + // the HLO to encode backend-specific decisions about how e.g. a reduce that's + // the root of a fusion should be lowered. In addition, this scheme as + // written doesn't work for multi-output fusion, where the primary node is + // never actually the root (which is a kTuple instruction that gathers the + // multiple outputs of the fusion). enum class FusionKind { - kLoop, // Fused into a loop. - kInput, // Op's input is fused into the op itself. - kOutput, // Op's output is fused into the op itself. - // REQUIRES: At least one operand buffer must be able - // to alias the output buffer. - kTransposeDot, // Fused into a dot with transposed operands. - kCustom, // Custom category for backend-specific fusions that - // do not match any of the more specific ones. + kLoop, + kInput, + kOutput, + kCustom, }; - ~HloInstruction(); + virtual ~HloInstruction(); // Creates an instruction from the given proto. Arguments: // - // module: the module which will contain the instruction. The newly created - // instruction is *not* added to the module or any computation, however. // proto: the proto to convert from. - // instruction_map: a map from instruction name to HloInstruction*. This map + // instruction_map: a map from instruction id to HloInstruction*. This map // must contain all operands of the newly constructed instruction. - // computation_map: a map from computation name to HloComputation*. This map + // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed instruction // calls. - // add_fused_computation: A function to call to add a fused - // computation. Used (clearly) when the instruction is a fusion - // instruction. static StatusOr> CreateFromProto( - HloModule* module, const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation); + const HloInstructionProto& proto, + const tensorflow::gtl::FlatMap& instruction_map, + const tensorflow::gtl::FlatMap& computation_map); // Creates a parameter-retrieving instruction. static std::unique_ptr CreateParameter(int64 parameter_number, @@ -283,10 +426,27 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits); - // Creates a cross replica sum op. + // Creates a cross replica reduction op. + // + // `reduction_computation`: the reduction function. + // + // `replica_group_ids`: maps replica ids to subgroup ids. If empty, all + // replicas belong to one group. Allreduce will be applied within subgroups. + // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, + // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. + // + // `channel_id`: for Allreduce nodes from different models, if they have the + // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be + // applied cross models. + // + // TODO(b/79737069): Rename this to AllReduce. static std::unique_ptr CreateCrossReplicaSum( - const Shape& shape, - tensorflow::gtl::ArraySlice operands); + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& channel_id = + tensorflow::gtl::nullopt); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. @@ -457,6 +617,13 @@ class HloInstruction { const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds); + // Creates a kDomain instruction which delimits an HLO domain which have + // the provided user and operand side metadata. + static std::unique_ptr CreateDomain( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata); + // Creates a fusion instruction. A fusion instruction contains one or more // fused instructions forming an expression with a single root // "fused_root". Additional instructions can be added to the fusion @@ -498,15 +665,25 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); + // Creates a token instruction used for joining or creating token types which + // thread through side-effecting operations. + static std::unique_ptr CreateGenerateToken( + tensorflow::gtl::ArraySlice operands); + // Creates an instance of GatherDimensionNumbers. static GatherDimensionNumbers MakeGatherDimNumbers( tensorflow::gtl::ArraySlice output_window_dims, tensorflow::gtl::ArraySlice elided_window_dims, - tensorflow::gtl::ArraySlice gather_dims_to_operand_dims); + tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, + int64 index_vector_dim); // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } + // Returns true if this instruction has a side effect, irrespective of whether + // any called computations may contain an instruction with side effects. + bool HasSideEffectNoRecurse() const; + // Returns true if this instruction has a side effect. An instruction has a // side effect if it uses certain opcodes or calls a computation with a side // effect. @@ -531,6 +708,10 @@ class HloInstruction { using InstructionVector = tensorflow::gtl::InlinedVector; const InstructionVector& operands() const { return operands_; } + // Returns the vector of unique operands, in the same order they are found + // within the operand vector. + InstructionVector unique_operands() const; + // Returns the index of 'target' in the operands sequence. // Precondition: target must be an operand (or a fatal error will occur). int64 operand_index(const HloInstruction* target) const; @@ -561,6 +742,18 @@ class HloInstruction { // 'instruction'. Status RemoveControlDependencyTo(HloInstruction* instruction); + // Drops all control predecessors and successors from this HLO instruction. + Status DropAllControlDeps(); + + // Copies the control predecessors and successors on this HLO instruction to + // `inst`. Does not do a deep copy so this makes sense only if `inst` and + // this HLO are in the same module. + // + // Depending on the use cases we see in practice, in the future we may + // consider folding the logic here into Clone, CloneWithNewOperands and + // ReplaceAllUsesWith by treating control dependencies like data dependencies. + Status CopyAllControlDepsFrom(const HloInstruction* inst); + // Returns the set of control predecessors (successors) of this // instruction. Control predecessors (successors) must execute before (after) // the current instruction. @@ -589,10 +782,8 @@ class HloInstruction { if (opcode() != other.opcode()) { return false; } - using EqShapeFuncType = bool (*)(const Shape&, const Shape&); - EqShapeFuncType eq_shapes = - layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible; - if (!eq_shapes(shape(), other.shape())) { + if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape()) + : ShapeUtil::Compatible(shape(), other.shape()))) { return false; } if (operands().size() != other.operands().size()) { @@ -607,15 +798,16 @@ class HloInstruction { } } - return IdenticalSlowPath(other, eq_computations, eq_shapes); + if (backend_config_ != other.backend_config_) { + return false; + } + + return IdenticalSlowPath(other, eq_computations); } // Returns whether the instruction has a constant operand. bool HasConstantOperand() const; - // Returns whether this instruction does a rank-2 transposition. - bool IsRank2Transpose() const; - // Replaces the use of this instruction in "user" with "new_producer". Note // that there might be multiple uses of this instruction in "user"; all will // be replaced. @@ -635,6 +827,8 @@ class HloInstruction { // Detaches an instruction from its operands. That is, remove the instruction // from each operand's user set. This should only be called prior to // deallocating the instruction. + // + // TODO(b/78305363): Make this automatic when deleting an instruction. void DetachFromOperands(); // Performs a postorder DFS visit using this node as the root. If @@ -682,35 +876,6 @@ class HloInstruction { template Status Visit(DfsHloVisitorBase* visitor); - // Returns the literal associated with this instruction. - // - // Note: only constant and parameter opcodes have an associated literal. - const Literal& literal() const; - - // Returns the parameter number associated with this instruction. - // - // Note: only parameter opcodes have an associated parameter number. - int64 parameter_number() const { - CHECK_EQ(HloOpcode::kParameter, opcode_); - return parameter_number_; - } - - // Returns the dimension sizes or numbers associated with this instruction. - // - // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape, - // and reverse. - const std::vector& dimensions() const; - int64 dimensions(int64 index) const; - - // Accessor for the dimension in which a concatenate HLO should occur. - // Precondition: opcode() == HloOpcode::kConcatenate - int64 concatenate_dimension() const; - - // Returns the tuple index associated with this instruction. - // - // Precondition: opcode() == HloOpcode::kGetTupleElement - int64 tuple_index() const; - // Returns the first non-GetTupleElement ancestor instruction of 'hlo'. // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the // (possibly nested) tuple indices used on the path from ancestor to 'hlo'. @@ -805,7 +970,7 @@ class HloInstruction { string ToShortString() const; // Returns a serialized representation of this instruction. - HloInstructionProto ToProto() const; + virtual HloInstructionProto ToProto() const; // Returns a category for the HLO. This could be something like "convolution" // or "elementwise". @@ -817,26 +982,11 @@ class HloInstruction { HloInstruction* tracing() const; void set_tracing(HloInstruction* trace_instruction); - // Returns the channel id associated with the instruction. The id is - // shared between each Send/Recv pair and is globally unique to identify each - // channel. - // - // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv - int64 channel_id() const { return channel_id_; } - - // Returns feature_index field associated with the instruction. The index - // represents the index of the feature dimension. + // Returns the channel name associated with the instruction. The name is + // used to identify host Send/Recv operations. // - // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, - // or kBatchNormGrad. - int64 feature_index() const { return feature_index_; } - - // Returns a epsilon value associated with the instruction. The is a small - // number added to the variance to avoid divide-by-zero error. - // - // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, - // or kBatchNormGrad. - float epsilon() const { return epsilon_; } + // Precondition: opcode() == HloOpcode::kHostCompute + string channel_name() const { return channel_name_; } // Returns the infeed configuration string. The infeed configuration includes // any metadata needed for the backend compiler (e.g., infeed buffer address) @@ -844,78 +994,14 @@ class HloInstruction { string infeed_config() const { return infeed_config_; } void set_infeed_config(const string& config) { infeed_config_ = config; } - // Returns a tag to be used in tracing. - // - // Precondition: opcode() == HloOpcode::kTrace - string TracingTag() const; - - // Returns whether the instruction is a constant. - bool IsConstant() const; - // Returns true if this instruction is fused, ie contained within a fusion // instruction. bool IsFused() const; - // Returns the computation for this fused instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - HloComputation* fused_instructions_computation() const; - // Returns true if this instruction can be legally fused into a fusion // instruction. bool IsFusable() const; - // Returns the root instruction of the fused expression contained within this - // fusion instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - HloInstruction* fused_expression_root() const; - - // Returns the list of fused instructions inside this fusion instruction. The - // returned type is a range of HloInstruction*s. - // - // Precondition: opcode() == HloOpcode::kFusion - const tensorflow::gtl::iterator_range>::const_iterator>> - fused_instructions() const; - - const tensorflow::gtl::iterator_range< - UnwrappingIterator>::iterator>> - fused_instructions(); - - // Gets the number of instructions inside this fusion instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - int64 fused_instruction_count() const; - - // Returns the fused parameter instruction in this fusion instruction - // corresponding to the given parameter number. - // - // Precondition: opcode() == HloOpcode::kFusion - HloInstruction* fused_parameter(int64 parameter_number) const; - - // Returns the vector of fused parameters inside this fusion instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - const std::vector& fused_parameters() const; - - // Returns true if this instruction is a fusion instruction that generates - // multiple outputs. - const bool IsMultiOutputFusion() const { - return opcode() == HloOpcode::kFusion && - fused_expression_root()->opcode() == HloOpcode::kTuple; - } - - FusionKind fusion_kind() const { - CHECK_EQ(HloOpcode::kFusion, opcode_); - return fusion_kind_; - } - - void set_fusion_kind(FusionKind kind) { - CHECK_EQ(HloOpcode::kFusion, opcode_); - fusion_kind_ = kind; - } - // Returns the sharding applied to this operator. // REQUIRES: has_sharding() is true. const HloSharding& sharding() const { @@ -926,102 +1012,53 @@ class HloInstruction { const HloSharding& sharding_or_default(const HloSharding& default_) const { return sharding_ ? *sharding_ : default_; } + // Returns the sharding unique device, if any. + tensorflow::gtl::optional sharding_unique_device() const { + if (sharding_ == nullptr) { + return tensorflow::gtl::optional(); + } + auto device = sharding_->UniqueDevice(); + return device.ok() ? device.ValueOrDie() + : tensorflow::gtl::optional(); + } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { sharding_ = MakeUnique(sharding); } + void set_single_sharding(const HloSharding& sharding); + // Sets a sharding that assigns the current instruction to device. + void set_device_sharding(int64 device) { + set_single_sharding(HloSharding::AssignDevice(device)); + } // Remove any sharding from this operator. void clear_sharding() { sharding_ = nullptr; } // Return true if this operator has a sharding assigned. bool has_sharding() const { return sharding_ != nullptr; } - - // Adds a new operand the fusion instruction. - HloInstruction* AddFusionOperand(HloInstruction* new_operand); - - // Merges the fused instructions from 'instruction_to_merge' into the - // fused instruction set of 'this', updating operands as necessary. - // - // Precondition: opcode() == HloOpcode::kFusion - // Predondition: 'instruction_to_merge' must be an operand of 'this'. - void MergeFusionInstruction(HloInstruction* instruction_to_merge); - - // Merges the fused instructions from instruction_to_merge into the fused - // instruction set of 'this' and generates multioutput fusion instructions. - // All the users of instruction_to_merge will be redirected to 'this' - // instruction. instruction_to_merge will be removed from its parent - // computation. - // - // Precondition: opcode() == HloOpcode::kFusion - void MergeFusionInstructionIntoMultiOutput( - HloInstruction* instruction_to_merge); - - // Fuses the given instruction in this fusion instruction. instruction_to_fuse - // is cloned and the clone is placed in the fusion - // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather - // than moved to cleanly handle the case where the instruction has a use - // outside the fusion instruction. Moving such an instruction into a fusion - // instruction would violate the single-result invariant of HLO instructions - // and significantly complicate code generation. - // - // Precondition: this->opcode() == HloOpcode::kFusion - HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) { - return FuseInstructionInternal(instruction_to_fuse); - } - - // Fuses the given instruction in this fusion instruction and generate - // multioutput fusion instruction. A clone of the instruction_to_fuse will - // be part of the output of fusion instructions. The users of - // instruction_to_fuse will be redirected to this fusion instructions. - // instruction_to_fuse will be removed from its parent computation. - // - // Precondition: this->opcode() == HloOpcode::kFusion - HloInstruction* FuseInstructionIntoMultiOutput( - HloInstruction* instruction_to_fuse) { - return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true); - } - - // Returns the start index in the given dimension for a slice node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_starts(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_starts_[dimension]; + // Checks whether the instruction has compatible sharding with the other + // instruction. + bool has_compatible_sharding(const HloInstruction* other) const { + if (!has_sharding()) { + return !other->has_sharding(); + } + return other->has_sharding() ? sharding() == other->sharding() : false; } - const std::vector& slice_starts() const { return slice_starts_; } - // Returns the (exclusive) limit index in the given dimension for a slice - // node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_limits(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_limits_[dimension]; + // Retrieves the operand side metadata of a kDomain instruction. + const DomainMetadata& operand_side_metadata() const { + return *operand_side_metadata_; } - const std::vector& slice_limits() const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_limits_; + // Retrieves the user side metadata of a kDomain instruction. + const DomainMetadata& user_side_metadata() const { + return *user_side_metadata_; } - // Returns the stride in the given dimension for a slice node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_strides(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_strides_[dimension]; - } - const std::vector& slice_strides() const { return slice_strides_; } - - // Returns the flag that describes whether a slice must be lowered into an - // offset into the original operand. - bool IsInPlaceSlice() const { return is_in_place_slice_; } - - // Sets and returns the flag that describes whether a slice must be lowered - // into an offset into the original operand. - bool SetIsInPlaceSlice(bool value) { - is_in_place_slice_ = value; - return value; - } + // When creating a new instruction which either replaces, or shifts up (kCopy + // insertion case), another instruction, we need to make sure the certain + // properties of the new instruction are copied into the derived one. As of + // today, the metadata and sharding will be propagated to the derived + // instruction. + void SetupDerivedInstruction(HloInstruction* derived_instruction) const; // Returns the size of the slice in the given dimension for a dynamic // slice node. @@ -1036,22 +1073,6 @@ class HloInstruction { return dynamic_slice_sizes_; } - // Returns the number of exponent bits for a reduce-precision node. - // - // Precondition: opcode() == HloOpcode::kReducePrecision - int32 exponent_bits() const { - CHECK_EQ(HloOpcode::kReducePrecision, opcode_); - return exponent_bits_; - } - - // Returns the number of mantissa bits for a reduce-precision node. - // - // Precondition: opcode() == HloOpcode::kReducePrecision - int32 mantissa_bits() const { - CHECK_EQ(HloOpcode::kReducePrecision, opcode_); - return mantissa_bits_; - } - // Returns data on the window in a windowed operation such as // convolution. const Window& window() const { @@ -1089,19 +1110,6 @@ class HloInstruction { MakeUnique(dnums); } - FftType fft_type() const { - CHECK_EQ(HloOpcode::kFft, opcode_); - return fft_type_; - } - - const std::vector& fft_length() const { - CHECK_EQ(HloOpcode::kFft, opcode_); - return fft_length_; - } - - // Returns the dump string of the convolution dimension numbers. - string ConvolutionDimensionNumbersToString() const; - // Returns data on the dimension numbers used for a dot operation. const DotDimensionNumbers& dot_dimension_numbers() const { CHECK(dot_dimension_numbers_ != nullptr); @@ -1124,28 +1132,19 @@ class HloInstruction { // Returns the dump string of the gather dimension numbers. string GatherDimensionNumbersToString() const; - // Returns the random distribution for this rng node. - // - // Precondition: opcode() == HloOpcode::kRng - RandomDistribution random_distribution() const; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of // the instruction to form the name of the cloned instruction. - // If the module pointer is not nullptr, it will be the module where - // the cloned computations will be added to (in order to support deep - // cloning). - std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr) const; + // Ignores the control predecessors and successors of this HLO instruction. + std::unique_ptr Clone( + const string& suffix = "clone", HloCloneContext* context = nullptr) const; // Clones the HLO instruction as above but with new shape and operands. - // If the module pointer is not nullptr, it will be the module where - // the cloned computations will be added to (in order to support deep - // cloning). std::unique_ptr CloneWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module = nullptr) const; + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context = nullptr) const; // Returns the computations this instruction directly calls (if any). const std::vector& called_computations() const { @@ -1215,9 +1214,14 @@ class HloInstruction { std::tuple, std::vector> ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; - // Gets/sets the string identifier for this instruction. + // Gets the string identifier for this instruction. const string& name() const { return name_; } - void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); } + + // Sets the string identifier for this instruction. Name will be sanitized to + // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". + void SetAndSanitizeName(const string& name) { + name_ = NameUniquer::GetSanitizedName(name); + } // Use the given NameUniquer to select a unique name for the instruction based // on the instruction's existing name. @@ -1234,6 +1238,40 @@ class HloInstruction { // if no id has been assigned yet). int unique_id() const { return unique_id_; } + // Returns the backend-specific configuration for how a backend should compile + // this HLO. The meaning of the field is backend specific. Not for use before + // or during general HLO optimization, since HLO optimizations do not preserve + // this field and they cannot interpret it due to its meaning being backend + // specific. + // + // ConfigProto should be a protobuf Message type. + template + StatusOr backend_config() const { + ConfigProto proto; + TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto)); + return std::move(proto); + } + Status set_backend_config(const tensorflow::protobuf::Message& proto); + + // Getter/setter for raw JSON-encoded backend config. Prefer the + // functions above that deal in proto Messages where possible. + const string& raw_backend_config_string() const { return backend_config_; } + void set_raw_backend_config_string(string config_str) { + backend_config_ = std::move(config_str); + } + + // Returns a string representation of a proto in the format used by + // raw_backend_config_string. + // + // This is morally equivalent to: + // + // HloInstruction instr; + // TF_RETURN_IF_ERROR(instr.set_backend_config(proto)); + // return instr.raw_backend_config_string(); + // + static StatusOr BackendConfigToRawString( + const tensorflow::protobuf::Message& proto); + // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } const OpMetadata& metadata() const { return metadata_; } @@ -1255,6 +1293,7 @@ class HloInstruction { // Get/Set the number of partitions per outer dimension (in order, starting // with outer-most dimension first). Currently used by the parallel cpu // backend to partition HLOs into parallel tasks. + // // TODO(b/62783254) Replace these methods with a more general way to // annotate HLOs with backend-specific information. const std::vector& outer_dimension_partitions() const { @@ -1263,82 +1302,235 @@ class HloInstruction { void set_outer_dimension_partitions( const std::vector& outer_dimension_partitions); - // Change the layout for an Constant Hlo instruction to match new_layout. For - // tuple shaped constants shape_index is the path to the internal array - // subshape whose layout needs to be changed. + // Old methods kept for smooth subclassing transition BEGIN. + // TODO(b/80131774): Remove this code. + + // Delegates to HloBatchNormInstruction::feature_index. + int64 feature_index() const; + + // Delegates to HloBatchNormInstruction::epsilon. + float epsilon() const; + + // Delegates to HloFftInstruction::fft_type. + FftType fft_type() const; + + // Delegates to HloFftInstruction::fft_length. + const std::vector& fft_length() const; + + // Delegates to HloSendRecvInstruction::channel_id. + int64 channel_id() const; + + // Returns the dimension sizes or numbers associated with this instruction. + virtual const std::vector& dimensions() const { + LOG(FATAL) << "Unimplemented method."; + } + virtual int64 dimensions(int64 index) const { + LOG(FATAL) << "Unimplemented method."; + } + + // Delegates to HloConcatenateInstruction::concatenate_dimension. + int64 concatenate_dimension() const; + + // Returns whether this instruction does a rank-2 transposition. + bool IsRank2Transpose() const; + + // Delegates to HloSliceInstruction::slice_start. + int64 slice_starts(int64 dimension) const; + const std::vector& slice_starts() const; + + // Delegates to HloSliceInstruction::slice_limits. + int64 slice_limits(int64 dimension) const; + const std::vector& slice_limits() const; + + // Delegates to HloSliceInstruction::slice_strides. + int64 slice_strides(int64 dimension) const; + const std::vector& slice_strides() const; + + // Delegates to HloSliceInstruction::IsInPlaceSlice. + bool IsInPlaceSlice() const; + + // Returns the literal associated with this instruction. + const Literal& literal() const; + + // Returns whether the instruction is a constant. + bool IsConstant() const; + + // Delegate to HloConstantInstruction::RelayoutConstant. void RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index = {}); - private: - enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; + // Delegates to HloTraceInstruction::TracingTag. + string TracingTag() const; + + // Delegates to HloFusionInstruction::AddFusionOperand. + HloInstruction* AddFusionOperand(HloInstruction* new_operand); + + // Delegates to HloFusionInstruction::MergeFusionInstruction. + void MergeFusionInstruction(HloInstruction* instruction_to_merge); + + // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput. + void MergeFusionInstructionIntoMultiOutput( + HloInstruction* instruction_to_merge); + + // Delegates to HloFusionInstruction::FuseInstruction. + HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse); + + // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput. + HloInstruction* FuseInstructionIntoMultiOutput( + HloInstruction* instruction_to_fuse); + + // Delegates to HloFusionInstruction::fused_instruction. + HloComputation* fused_instructions_computation() const; + + // Delegates to HloFusionInstruction::fused_expression_root. + HloInstruction* fused_expression_root() const; + + // Delegates to HloFusionInstruction::fused_instructions. + const tensorflow::gtl::iterator_range>::const_iterator>> + fused_instructions() const; + + const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> + fused_instructions(); + + // Delegates to HloFusionInstruction::fused_instruction_count. + int64 fused_instruction_count() const; + + // Delegates to HloFusionInstruction::fused_parameter. + HloInstruction* fused_parameter(int64 parameter_number) const; + // Delegates to HloFusionInstruction::fused_parameters. + const std::vector& fused_parameters() const; + + // Returns true if this instruction is a fusion instruction that generates + // multiple outputs. + const bool IsMultiOutputFusion() const; + + // Delegates to HloFusionInstruction::fusion_kind. + FusionKind fusion_kind() const; + + // Delegates to HloFusionInstruction::set_fusion_kind. + void set_fusion_kind(FusionKind kind); + + // Delegates to HloRngInstruction::random_distribution. + RandomDistribution random_distribution() const; + + // Delegates to HloParameterInstruction::parameter_number. + int64 parameter_number() const; + + // Delegates to HloGetTupleElementInstruction::tuple_index. + int64 tuple_index() const; + + // Returns the number of exponent bits for a reduce-precision node. + int32 exponent_bits() const; + + // Returns the number of mantissa bits for a reduce-precision node. + int32 mantissa_bits() const; + // Old methods kept for smooth subclassing transition END. + + // Returns the group ids of each replica for CrossReplicaSum op. + const std::vector& replica_group_ids() const { + return replica_group_ids_; + } + + // Returns the barrier config used for the CrossReplicaSum implementation of + // each backend. + string cross_replica_sum_barrier() const { + return cross_replica_sum_barrier_; + } + void set_cross_replica_sum_barrier(string barrier) { + cross_replica_sum_barrier_ = barrier; + } + + protected: + enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; // Helper class for computing OperandElementUse for kFusion. class FusionReusesParamElements; + // Internal constructor for a given opcode/shape, other fields must be filled + // by factory methods. + HloInstruction(HloOpcode opcode, const Shape& shape); + + // Appends operand to the list of operands and adds this instruction as a user + // of the operand. + void AppendOperand(HloInstruction* operand); + + void RemoveOperandAt(int index) { + operands_.erase(operands_.begin() + index); + } + + void AppendComputation(HloComputation* computation) { + called_computations_.push_back(computation); + } + + void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); } + + private: + // Implementation for non-common logic of CloneWithNewOperands. + virtual std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + // TODO(b/80131774): This should be pure virtual. + LOG(FATAL) << "Unimplemented method."; + } + + // Implementation for non-common logic of ExtraAttributesToString. + virtual std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {}; + } + + // Implementation for IsElementwise if operand_idx is nullopt and for + // IsElementwiseOnOperand if otherwise. + // + // NOTE: For all instructions other than kFusion, being elementwise on one of + // the operands is equivalent to being elementwise on all the operands. + virtual bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const; + // Prints an instruction to a string. + // + // The canonical string representation needs to name operands and instruction + // names in a consistent way. This is implemented through the + // canonical_name_map. + string ToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const; + + // Prints an operand to a string. + virtual string OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const; + + // Allow HloInstruction to access the ToStringWithCanonicalNameMap() and + // OperandsToStringWithCanonicalNameMap() functions. + friend class HloComputation; + // See comments on Identical(). - // eq_shapes() is used to check shapes for equality, and would normally be - // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on - // whether we want a layout-sensitive check or not. - bool IdenticalSlowPath( + virtual bool IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations, - const std::function& eq_shapes) const; + eq_computations) const; // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( const Shape& shape, HloOpcode opcode, tensorflow::gtl::ArraySlice operands); - // Appends operand to the list of operands and adds this instruction as a user - // of the operand. - void AppendOperand(HloInstruction* operand); - // Adds a user for this instruction. void AddUser(HloInstruction* user); // Removes a user for this instruction. void RemoveUser(HloInstruction* user); - // Internal constructor for a given opcode/shape, other fields must be filled - // by factory methods. - HloInstruction(HloOpcode opcode, const Shape& shape); - - // Fuses the given instruction into this fusion instruction. When add_output - // is false (which is the default), instruction_to_fuse is cloned and the - // clone is placed in the fusion instruction. instruction_to_fuse is - // unchanged. - // - // When add_output is true, a clone of the instruction_to_fuse will be part - // of the output of fusion instructions. The users of instruction_to_fuse - // will be redirected to this fusion instructions. instruction_to_fuse will - // be removed from its parent computation. - // - // Precondition: this->opcode() == HloOpcode::kFusion - HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse, - bool add_output = false); - - // Clones the given instruction_to_fuse and insert the clone into this fusion - // instruction. If add_output is true, a clone of instruction_to_fuse will - // be in the output of the this fusion instruction (part of the tuple of the - // fusion root). - // - // Precondition: opcode() == HloOpcode::kFusion - HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse, - bool add_output = false); - - // Clones a fusion instruction with a new shape and operands. - std::unique_ptr CloneFusionWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module = nullptr) const; - - // Returns true if this instruction can legally have the dimensions field - // set. Used for checking precondition of dimensions field accessors. - bool CanHaveDimensionsField() const; - // Returns how this instruction uses elements of its `i`th operand. UseKind OperandElementUse(int64 i) const; + // Helper for implementing backend_config(). Parses backend_config_ into the + // given proto. + Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const; + int unique_id_; // Unique to this HloInstruction within a HloModule // Opcode for this instruction. @@ -1369,16 +1561,6 @@ class HloInstruction { // Result shape of this instruction. Shape shape_; - // Literal, only present for kConstant. - std::unique_ptr literal_; - - // Constant index, only present for kGetTupleElement. - int64 tuple_index_ = -1; - - // Dimensions present for some operations that require reshaping or - // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. - std::vector dimensions_; - // Describes the window in a windowed operation such as convolution. std::unique_ptr window_; @@ -1391,24 +1573,6 @@ class HloInstruction { std::unique_ptr gather_dimension_numbers_; std::vector gather_window_bounds_; - // Describes FFT type for an FFT instruction. - FftType fft_type_ = FftType::FFT; - - // Indicates the FFT length for an FFT instruction. - std::vector fft_length_; - - // Describes the [begin, end) index range for a slice. - std::vector slice_starts_; - std::vector slice_limits_; - std::vector slice_strides_; - - // Describes whether the slice can be lowered to an offset into the operand. - bool is_in_place_slice_ = false; - - // The bit sizes for a reduce-precision operation. - int32 exponent_bits_ = 0; - int32 mantissa_bits_ = 0; - // Describes the [start, start + size) range size for a dynamic slice // ('start' is specified dynamically in the second operand of the operation). std::vector dynamic_slice_sizes_; @@ -1417,14 +1581,12 @@ class HloInstruction { // padding of this pad instruction. Only set for pad instructions. std::unique_ptr padding_config_; - // The type of the fusion. Used by kFusion only. - FusionKind fusion_kind_; - // The sharding, if one exists. std::unique_ptr sharding_; - // For parameter instructions this field holds the parameter number. - int64 parameter_number_ = 0; + // Fields used by the kDomain instruction. + std::unique_ptr operand_side_metadata_; + std::unique_ptr user_side_metadata_; // Name of a global symbol to call, only present for kCustomCall. string custom_call_target_; @@ -1433,7 +1595,7 @@ class HloInstruction { string channel_name_; // Estimate of the duration of a host computation in nanoseconds. - int64 cost_estimate_ns_; + int64 cost_estimate_ns_ = 0; // Computations called by this instruction. std::vector called_computations_; @@ -1463,24 +1625,18 @@ class HloInstruction { // an operand. HloInstruction* trace_instruction_ = nullptr; - // The distribution requested for random number generation. - // Only present for kRng. - RandomDistribution distribution_; - - // A small float number added to the variance to avoid divide-by-zero error. - // Only present for kBatchNormTraining. - float epsilon_ = 0.0f; + // The string representation of the infeed configuration. + string infeed_config_; - // An integer value representing the index of the feature dimension. - // Only present for kBatchNormTraining. - int64 feature_index_ = -1; + // The backend-specific configuration for how a backend should compile this + // HLO. See the documentation on backend_config(). + string backend_config_; - // Represents a unique identifier for each Send/Recv instruction pair. - // Only present for kSend or kRecv. - int64 channel_id_ = -1; + // The group id of each replica for CrossReplicaSum. + std::vector replica_group_ids_; - // The string representation of the infeed configuration. - string infeed_config_; + // The string representation of the barrier config used for CrossReplicaSum. + string cross_replica_sum_barrier_; // String identifier for instruction. string name_; @@ -1504,6 +1660,9 @@ StatusOr StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); +string ConvolutionDimensionNumbersToString( + const ConvolutionDimensionNumbers& dnums); + StatusOr StringToRandomDistribution(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); @@ -1512,13 +1671,20 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); // an HloInstruction* or a const HloInstruction*. // To make the iteration order over the map deterministic, the comparator // should not be using the pointer values, but rather an intrinsic property of -// the hlo. +// the hlo. Exception: null pointer values compare less than non-null. // // Note that this cannot be used for HLO instructions across multiple modules // since the id of HLO instructions are only unique within each HLO module. struct HloPtrComparator { bool operator()(const HloInstruction* const& lhs, const HloInstruction* const& rhs) const { + if (rhs == nullptr) { + // Nothing compares less than nullptr. + return false; + } + if (lhs == nullptr) { + return true; + } return lhs->unique_id() < rhs->unique_id(); } }; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 32d3ed272bd6b239918076999ecae6c1b3ded2fd..5d6f8b931f0c665fba03e1c845214fa83aabf12e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -24,11 +24,13 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" namespace xla { namespace { @@ -149,8 +151,8 @@ TEST_F(HloInstructionTest, UserWithTwoOperands) { builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_THAT(add->operands(), UnorderedElementsAre(foo, bar)); EXPECT_THAT(foo->users(), UnorderedElementsAre(add)); @@ -186,8 +188,8 @@ TEST_F(HloInstructionTest, MultipleUsers) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, foo->user_count()); EXPECT_EQ(1, bar->user_count()); @@ -219,8 +221,8 @@ TEST_F(HloInstructionTest, RepeatedUser) { builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(1, foo->user_count()); @@ -254,8 +256,8 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c0, param1)); auto addtotal = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(addtotal->Accept(&visitor)); @@ -303,8 +305,8 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); auto neg2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, addtotal)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(neg2->Accept(&visitor)); @@ -325,7 +327,7 @@ TEST_F(HloInstructionTest, TrivialMap) { // Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10}); - HloModule module(TestName()); + auto module = CreateNewModule(); // Builds an x+1.0 computation to use in a Map. auto embedded_builder = HloComputation::Builder("f32+1"); @@ -335,15 +337,15 @@ TEST_F(HloInstructionTest, TrivialMap) { HloInstruction::CreateConstant(Literal::CreateR0(1.0))); embedded_builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value)); - auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build()); + auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); // Builds a parameter and feeds it to the map. HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10, "")); + HloInstruction::CreateParameter(0, f32a100x10, "p")); auto map = builder.AddInstruction( HloInstruction::CreateMap(f32a100x10, {param0}, add_f32)); - module.AddEntryComputation(builder.Build()); + module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(map->Accept(&visitor)); @@ -373,13 +375,13 @@ TEST_F(HloInstructionTest, TrivialReduce) { HloInstruction::CreateParameter(1, r0f32, "y")); embedded_builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, paramx, paramy)); - HloModule module(TestName()); - auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build()); + auto module = CreateNewModule(); + auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); // Builds a parameter and an initial value and feeds them to the reduce. HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10, "")); + HloInstruction::CreateParameter(0, f32a100x10, "p")); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction( @@ -387,7 +389,7 @@ TEST_F(HloInstructionTest, TrivialReduce) { auto reduce = builder.AddInstruction( HloInstruction::CreateReduce(f32v100, param0, const0, /*dimensions_to_reduce=*/{1}, add_f32)); - module.AddEntryComputation(builder.Build()); + module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(reduce->Accept(&visitor)); @@ -414,8 +416,8 @@ TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); EXPECT_EQ(1, bar->user_count()); @@ -449,8 +451,8 @@ TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) { builder.AddInstruction(HloInstruction::CreateTuple({foo, bar, baz, foo})); auto add_foobar = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); EXPECT_THAT(foo->users(), UnorderedElementsAre(tuple, add_foobar)); @@ -477,8 +479,8 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto log = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); EXPECT_THAT(foo->users(), UnorderedElementsAre(exp, log)); @@ -514,8 +516,8 @@ TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); EXPECT_EQ(1, bar->user_count()); @@ -544,8 +546,8 @@ TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) { auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({foo, bar})); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, foo->user_count()); EXPECT_EQ(2, bar->user_count()); @@ -609,8 +611,8 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp, log)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); NodeCollectorAndPostProcessor visitor; ASSERT_IS_OK(add->Accept(&visitor)); @@ -627,8 +629,8 @@ TEST_F(HloInstructionTest, SingletonFusionOp) { HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp}, HloInstruction::FusionKind::kLoop); @@ -645,8 +647,8 @@ TEST_F(HloInstructionTest, BinaryFusionOp) { HloInstruction::CreateConstant(Literal::CreateR0(42.1f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add}, HloInstruction::FusionKind::kLoop); @@ -667,8 +669,8 @@ TEST_F(HloInstructionTest, ChainFusionOp) { auto exp3 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop); @@ -690,8 +692,8 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { exp1->set_metadata(metadata); exp2->set_metadata(metadata); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp2, exp1}, HloInstruction::FusionKind::kLoop); @@ -746,13 +748,13 @@ TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) { TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { // Create a fusion instruction containing a single unary operation. const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - HloModule module(TestName()); + auto module = CreateNewModule(); auto make_map_computation = [&]() { auto builder = HloComputation::Builder("FusionMap"); builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape, "param")); - return module.AddEmbeddedComputation(builder.Build()); + return module->AddEmbeddedComputation(builder.Build()); }; HloComputation* computation_x = make_map_computation(); @@ -767,7 +769,7 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { scalar_shape, {map_1_x}, computation_x, /*static_operands=*/{})); auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap( scalar_shape, {map_2_x}, computation_y, /*static_operands=*/{})); - auto* computation = module.AddEntryComputation(builder.Build()); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {map_3_y}, HloInstruction::FusionKind::kLoop); @@ -814,8 +816,8 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1})); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -940,8 +942,8 @@ TEST_F(HloInstructionTest, FunctionVisitor) { HloInstruction::CreateUnary(f32, HloOpcode::kExp, param)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate, exp)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); int visit_num = 0; std::unordered_map visit_order; @@ -969,8 +971,8 @@ TEST_F(HloInstructionTest, FullyElementwise) { builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x, y)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_TRUE(add->IsElementwise()); for (int i = 0; i < add->operand_count(); ++i) { @@ -978,6 +980,23 @@ TEST_F(HloInstructionTest, FullyElementwise) { } } +TEST_F(HloInstructionTest, MapIsElementwise) { + auto module = CreateNewModule(); + const Shape r2f32 = ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0}); + HloComputation::Builder builder(TestName()); + HloComputation::Builder map_builder("id"); + map_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + auto map_computation = module->AddEmbeddedComputation(map_builder.Build()); + auto x = + builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "x")); + auto map = builder.AddInstruction( + HloInstruction::CreateMap(r2f32, {x}, map_computation)); + module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(map->IsElementwise()); +} + TEST_F(HloInstructionTest, PartiallyElementwise) { const Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5}); @@ -1013,8 +1032,8 @@ TEST_F(HloInstructionTest, PartiallyElementwise) { HloInstruction* max = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop); EXPECT_FALSE(fusion->IsElementwise()); @@ -1056,8 +1075,8 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, min, broadcast)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {sub, broadcast, min}, HloInstruction::FusionKind::kLoop); EXPECT_FALSE(fusion->IsElementwise()); @@ -1099,10 +1118,10 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { HloInstruction* dot = builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( - {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); + {dot, reshape}, HloInstruction::FusionKind::kLoop); auto fusion2 = fusion->Clone(); const HloInstruction* root = fusion->fused_expression_root(); @@ -1118,7 +1137,7 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { } TEST_F(HloInstructionTest, FusionEquality) { - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Create two fusion instructions containing a single unary operation. @@ -1128,7 +1147,7 @@ TEST_F(HloInstructionTest, FusionEquality) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, parameter)); auto neg = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, parameter)); - auto* computation = module.AddEntryComputation(builder.Build()); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp}, HloInstruction::FusionKind::kLoop); auto* fusion2 = computation->CreateFusionInstruction( @@ -1140,7 +1159,7 @@ TEST_F(HloInstructionTest, FusionEquality) { } TEST_F(HloInstructionTest, NestedFusionEquality) { - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Build a nested fusion computation. @@ -1166,10 +1185,10 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { data_shape, HloOpcode::kSubtract, dot, add_operand)); builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kMultiply, add, sub)); - auto computation = module.AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation(builder.Build()); auto nested_fusion = computation->CreateFusionInstruction( - {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); + {dot, b_t}, HloInstruction::FusionKind::kLoop); auto fusion = computation->CreateFusionInstruction( {add, nested_fusion}, HloInstruction::FusionKind::kOutput); @@ -1244,15 +1263,8 @@ TEST_F(HloInstructionTest, Stringification) { "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}"); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); - HloInstruction* fusion = computation->CreateFusionInstruction( - {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); - - EXPECT_EQ( - fusion->ToString(options), - "%dot_fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, " - "f32[20,10]{1,0} %y), kind=kTransposeDot, calls=%fused_computation"); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* loop = builder.AddInstruction( HloInstruction::CreateWhile(sout, computation, computation, x)); @@ -1271,7 +1283,7 @@ TEST_F(HloInstructionTest, Stringification) { "true_computation=%TransposeDot, false_computation=%TransposeDot"); } -TEST_F(HloInstructionTest, StringifyGather) { +TEST_F(HloInstructionTest, StringifyGather_0) { Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); Shape gather_indices_tensor_shape = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); @@ -1291,11 +1303,12 @@ TEST_F(HloInstructionTest, StringifyGather) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26})); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " @@ -1303,7 +1316,313 @@ TEST_F(HloInstructionTest, StringifyGather) { "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), " "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " "gather_dims_to_operand_dims={0,1,2,3,4}, " - "window_bounds={30,29,28,27,26}"); + "index_vector_dim=4, window_bounds={30,29,28,27,26}"); +} + +TEST_F(HloInstructionTest, StringifyGather_1) { + Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + Shape gather_indices_tensor_shape = + ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); + Shape gather_result_shape = + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}); + + HloComputation::Builder builder("Gather"); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); + HloInstruction* gather_indices = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, gather_indices_tensor_shape, "gather_indices")); + + HloInstruction* gather_instruction = + builder.AddInstruction(HloInstruction::CreateGather( + gather_result_shape, input, gather_indices, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(gather_instruction->ToString(), + "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " + "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " + "s64[10,9,5,7,6]{4,3,2,1,0} %gather_indices), " + "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " + "gather_dims_to_operand_dims={0,1,2,3,4}, " + "index_vector_dim=2, window_bounds={30,29,28,27,26}"); +} + +TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { + // Tests stringification of a simple op, fusion, while, and conditional. + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto options = HloPrintOptions().Canonical(); + + EXPECT_EQ(dot->ToString(options), + "f32[5,20]{1,0} dot(f32[5,10]{1,0}, f32[10,20]{1,0}), " + "lhs_contracting_dims={1}, rhs_contracting_dims={0}"); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + HloInstruction* fusion = computation->CreateFusionInstruction( + {dot, reshape}, HloInstruction::FusionKind::kLoop); + + EXPECT_EQ( + fusion->ToString(options), + R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); +} + +TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { + // Tests stringification of a simple op, fusion, while, and conditional. + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({dot, reshape}, + HloInstruction::FusionKind::kLoop); + + HloInstruction* loop = builder.AddInstruction( + HloInstruction::CreateWhile(sout, computation, computation, x)); + + auto options = HloPrintOptions().Canonical(); + EXPECT_EQ(loop->ToString(options), + R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +}, body= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +})"); +} + +TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { + // Tests stringification of a simple op, fusion, while, and conditional. + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({dot, reshape}, + HloInstruction::FusionKind::kLoop); + + builder.AddInstruction( + HloInstruction::CreateWhile(sout, computation, computation, x)); + + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction* conditional = + builder.AddInstruction(HloInstruction::CreateConditional( + sout, pred, x, computation, x, computation)); + auto options = HloPrintOptions().Canonical(); + EXPECT_EQ( + conditional->ToString(options), + R"(f32[5,20]{1,0} conditional(pred[], f32[5,10]{1,0}, f32[5,10]{1,0}), true_computation= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +}, false_computation= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +})"); +} + +TEST_F(HloInstructionTest, CheckDeepClone) { + const char* const hlo_string = R"( +HloModule Module + +addy (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT zadd = s32[] add(lhs, rhs) +} + +calla (x: s32[]) -> s32[] { + x = s32[] parameter(0) + reduce = s32[] reduce-window(x, x), to_apply=addy + ROOT xadd = s32[] add(x, reduce) +} + +body (bparam: s32[]) -> s32[] { + constant = s32[] constant(1) + bparam = s32[] parameter(0) + v = s32[] call(bparam), to_apply=calla + ROOT add = s32[] add(constant, bparam) +} + +condition (cparam: s32[]) -> pred[] { + xconstant = s32[] constant(5) + cparam = s32[] parameter(0) + ROOT greater-than = pred[] greater-than(xconstant, cparam) +} + +ENTRY entry (param: s32[]) -> s32[] { + eparam = s32[] parameter(0) + ROOT while = s32[] while(eparam), condition=condition, body=body + } +)"; + // Check that deep clones really deep clones every instruction and + // computations, without leaving dangling pointers to the old module. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + std::unique_ptr clone = module->Clone(); + for (HloComputation* computation : clone->computations()) { + EXPECT_EQ(computation->parent(), clone.get()); + for (HloInstruction* instruction : computation->instructions()) { + EXPECT_EQ(instruction->parent()->parent(), clone.get()); + } + } +} + +TEST_F(HloInstructionTest, IdenticalAccountsForBackendConfig) { + const Shape shape = ShapeUtil::MakeShape(F32, {42}); + HloComputation::Builder builder("test"); + HloInstruction* p = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p")); + + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p, p)); + HloInstruction* add2 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p, p)); + + EXPECT_TRUE(add1->Identical(*add2)); + add1->set_raw_backend_config_string("abc"); + EXPECT_FALSE(add1->Identical(*add2)); +} + +TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallWindow) { + auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + auto instr2 = instr1->Clone(); + EXPECT_TRUE(instr1->Identical(*instr2)); + + Window w = window_util::MakeWindow({1, 2, 3}); + instr1->set_window(w); + EXPECT_FALSE(instr1->Identical(*instr2)); +} + +TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallDnums) { + auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + auto instr2 = instr1->Clone(); + EXPECT_TRUE(instr1->Identical(*instr2)); + + ConvolutionDimensionNumbers dnums; + dnums.set_output_batch_dimension(42); + instr1->set_convolution_dimension_numbers(dnums); + EXPECT_FALSE(instr1->Identical(*instr2)); +} + +TEST_F(HloInstructionTest, CloneWindowOnCustomCall) { + auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + Window w = window_util::MakeWindow({1, 2, 3}); + instr->set_window(w); + auto clone = instr->Clone(); + EXPECT_TRUE(protobuf_util::ProtobufEquals(clone->window(), w)) + << clone->window().DebugString(); +} + +TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) { + auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + ConvolutionDimensionNumbers dnums; + dnums.set_output_batch_dimension(42); + instr->set_convolution_dimension_numbers(dnums); + auto clone = instr->Clone(); + EXPECT_TRUE(protobuf_util::ProtobufEquals( + clone->convolution_dimension_numbers(), dnums)) + << clone->convolution_dimension_numbers().DebugString(); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc new file mode 100644 index 0000000000000000000000000000000000000000..d326d5d009fa3a378cfe92e1b081445f09d982e1 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -0,0 +1,1287 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_instructions.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { +namespace { + +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, + const HloInstruction* operand) { + std::vector operand_indices = instruction->OperandIndices(operand); + return std::all_of( + operand_indices.begin(), operand_indices.end(), + [instruction](int64 operand_index) { + return instruction->IsElementwiseOnOperand(operand_index); + }); +} +} // namespace + +HloBatchNormInstruction::HloBatchNormInstruction( + HloOpcode opcode, const Shape& shape, HloInstruction* operand, + HloInstruction* scale, float epsilon, int64 feature_index) + : HloInstruction(opcode, shape), + epsilon_(epsilon), + feature_index_(feature_index) { + AppendOperand(operand); + AppendOperand(scale); +} + +bool HloBatchNormInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return feature_index() == casted_other.feature_index() && + epsilon() == casted_other.epsilon(); +} + +HloInstructionProto HloBatchNormInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_epsilon(epsilon_); + proto.set_feature_index(feature_index_); + return proto; +} + +std::vector HloBatchNormInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("epsilon=", epsilon()), + StrCat("feature_index=", feature_index())}; +} + +HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, float epsilon, int64 feature_index) + : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand, + scale, epsilon, feature_index) { + AppendOperand(offset); +} + +std::unique_ptr +HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 3); + return MakeUnique( + shape, new_operands[0], new_operands[1], new_operands[2], epsilon(), + feature_index()); +} + +HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, + float epsilon, int64 feature_index) + : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand, + scale, epsilon, feature_index) { + AppendOperand(offset); + AppendOperand(mean); + AppendOperand(variance); +} + +std::unique_ptr +HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 5); + return MakeUnique( + shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], + new_operands[4], epsilon(), feature_index()); +} + +HloBatchNormGradInstruction::HloBatchNormGradInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output, + float epsilon, int64 feature_index) + : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale, + epsilon, feature_index) { + AppendOperand(mean); + AppendOperand(variance); + AppendOperand(grad_output); +} + +std::unique_ptr +HloBatchNormGradInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 5); + return MakeUnique( + shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], + new_operands[4], epsilon(), feature_index()); +} + +HloFftInstruction::HloFftInstruction( + const Shape& shape, HloInstruction* operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length) + : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) { + fft_length_.assign(fft_length.begin(), fft_length.end()); + AppendOperand(operand); +} + +HloInstructionProto HloFftInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_fft_type(fft_type_); + for (int64 fft_len : fft_length_) { + proto.add_fft_length(fft_len); + } + return proto; +} + +std::vector HloFftInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("fft_type=", FftType_Name(fft_type())), + StrCat("fft_length={", Join(fft_length(), ","), "}")}; +} + +bool HloFftInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return fft_type() == casted_other.fft_type() && + fft_length() == casted_other.fft_length(); +} + +std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], fft_type_, + fft_length_); +} + +HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, + const Shape& shape, + int64 channel_id) + : HloInstruction(opcode, shape), channel_id_(channel_id) {} + +HloInstructionProto HloSendRecvInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_channel_id(channel_id_); + return proto; +} + +std::vector HloSendRecvInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("channel_id=", channel_id_)}; +} + +bool HloSendRecvInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + // Not yet supported. + return false; +} + +// Send instruction produces a tuple of {aliased operand, U32 context}. +HloSendInstruction::HloSendInstruction(HloInstruction* operand, + int64 channel_id) + : HloSendRecvInstruction( + HloOpcode::kSend, + ShapeUtil::MakeTupleShape( + {CHECK_NOTNULL(operand)->shape(), ShapeUtil::MakeShape(U32, {})}), + channel_id) { + AppendOperand(operand); +} + +std::unique_ptr HloSendInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(new_operands[0], channel_id()); +} + +HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand) + : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil(), + CHECK_NOTNULL(operand)->channel_id()) { + AppendOperand(operand); +} + +std::unique_ptr +HloSendDoneInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + Cast(new_operands[0])); +} + +// Recv instruction produces a tuple of {receive buffer, U32 context}. +HloRecvInstruction::HloRecvInstruction(const Shape& shape, int64 channel_id) + : HloSendRecvInstruction( + HloOpcode::kRecv, + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}), + channel_id) {} + +std::unique_ptr HloRecvInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 0); + return MakeUnique( + ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); +} + +HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand) + : HloSendRecvInstruction( + HloOpcode::kRecvDone, + ShapeUtil::GetTupleElementShape(operand->shape(), 0), + CHECK_NOTNULL(operand)->channel_id()) { + AppendOperand(operand); +} + +std::unique_ptr +HloRecvDoneInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + Cast(new_operands[0])); +} + +HloReverseInstruction::HloReverseInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions) + : HloInstruction(HloOpcode::kReverse, shape), + dimensions_(dimensions.begin(), dimensions.end()) { + AppendOperand(operand); +} + +HloInstructionProto HloReverseInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloReverseInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloReverseInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr HloReverseInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + +HloConcatenateInstruction::HloConcatenateInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + int64 dimension) + : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloInstructionProto HloConcatenateInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloConcatenateInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloConcatenateInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloConcatenateInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(shape, new_operands, + dimensions(0)); +} + +HloReduceInstruction::HloReduceInstruction( + const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation) + : HloInstruction(HloOpcode::kReduce, shape), + dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) { + AppendOperand(arg); + AppendOperand(init_value); + AppendComputation(reduce_computation); +} + +HloInstructionProto HloReduceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloReduceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloReduceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + // Reduction results are determined by the reduction dimension and the + // reduction computation. + return dimensions() == casted_other.dimensions() && + eq_computations(to_apply(), casted_other.to_apply()); +} + +std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique( + shape, new_operands[0], new_operands[1], dimensions(), to_apply()); +} + +HloTransposeInstruction::HloTransposeInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions) + : HloInstruction(HloOpcode::kTranspose, shape), + dimensions_(dimensions.begin(), dimensions.end()) { + CHECK_EQ(shape.dimensions().size(), dimensions.size()); + CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); + CHECK(std::equal(operand->shape().dimensions().begin(), + operand->shape().dimensions().end(), + Permute(dimensions, shape.dimensions()).begin())) + << "shape: " << ShapeUtil::HumanString(shape) + << ", operand->shape(): " << ShapeUtil::HumanString(shape) + << ", dimensions: {" << Join(dimensions, ", ") << "}"; + AppendOperand(operand); +} + +bool HloTransposeInstruction::IsRank2Transpose() const { + return dimensions() == std::vector({1, 0}) && + shape().dimensions_size() == 2 && + std::equal(shape().dimensions().begin(), shape().dimensions().end(), + operand(0)->shape().dimensions().rbegin()); +} + +HloInstructionProto HloTransposeInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloTransposeInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloTransposeInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloTransposeInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + +HloBroadcastInstruction::HloBroadcastInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimension) + : HloInstruction(HloOpcode::kBroadcast, shape), + dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) { + AppendOperand(operand); +} + +HloInstructionProto HloBroadcastInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloBroadcastInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloBroadcastInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloBroadcastInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + +HloMapInstruction::HloMapInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation, + tensorflow::gtl::ArraySlice static_operands) + : HloInstruction(HloOpcode::kMap, shape) { + CHECK(static_operands.empty()) << "static_operands not yet supported"; + for (auto operand : operands) { + AppendOperand(operand); + } + AppendComputation(map_computation); + // TODO(b/65689298) Remove code below once Map is generalized to accept + // arbitrary map dimensions. + dimensions_.resize(ShapeUtil::Rank(shape)); + std::iota(dimensions_.begin(), dimensions_.end(), 0); +} + +HloInstructionProto HloMapInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +bool HloMapInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + if (!dimensions().empty()) { + // Check that the map is executed in elementwise compatible dimensions. + if (dimensions().size() != shape().dimensions_size()) { + return false; + } + for (int i = 0; i < dimensions().size(); ++i) { + if (dimensions()[i] != i) { + return false; + } + } + } + return true; +} + +std::vector HloMapInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloMapInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return eq_computations(to_apply(), other.to_apply()); +} + +std::unique_ptr HloMapInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(shape, new_operands, to_apply()); +} + +HloSliceInstruction::HloSliceInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) + : HloInstruction(HloOpcode::kSlice, shape), + slice_starts_(start_indices.begin(), start_indices.end()), + slice_limits_(limit_indices.begin(), limit_indices.end()), + slice_strides_(strides.begin(), strides.end()) { + AppendOperand(operand); + // For backward compatibility with old serialized computations: if there are + // no strides, assume all strides are 1. + // TODO(b/63317920): remove this code. + if (slice_strides_.empty()) { + slice_strides_ = std::vector(start_indices.size(), 1LL); + } +} + +HloInstructionProto HloSliceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int i = 0; i < slice_starts_.size(); ++i) { + auto* slice_dimension = proto.add_slice_dimensions(); + slice_dimension->set_start(slice_starts_[i]); + slice_dimension->set_limit(slice_limits_[i]); + slice_dimension->set_stride(slice_strides_[i]); + } + return proto; +} + +std::vector HloSliceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector bounds; + bounds.reserve(slice_starts_.size()); + const bool omit_stride = + std::all_of(slice_strides_.begin(), slice_strides_.end(), + [](int64 stride) { return stride == 1; }); + for (int i = 0; i < slice_starts_.size(); ++i) { + string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]); + bounds.push_back( + StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]")); + } + return {StrCat("slice={", Join(bounds, ", "), "}")}; +} + +bool HloSliceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& other_slice = static_cast(other); + return slice_starts_ == other_slice.slice_starts_ && + slice_limits_ == other_slice.slice_limits_ && + slice_strides_ == other_slice.slice_strides_; +} + +std::unique_ptr HloSliceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], slice_starts_, + slice_limits_, slice_strides_); +} + +HloConstantInstruction::HloConstantInstruction(std::unique_ptr literal) + : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()), + literal_(std::move(literal)) {} + +HloInstructionProto HloConstantInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_literal() = literal_->ToProto(); + return proto; +} + +bool HloConstantInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + return true; +} + +void HloConstantInstruction::RelayoutConstant(const Layout& new_layout, + const ShapeIndex& shape_index) { + Shape* mutable_array_subshape = + ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index); + CHECK(ShapeUtil::IsArray(*mutable_array_subshape)); + + // Normally array_subshape will always have a layout, but this invariant is + // temporarily broken in LayoutAssignment::AssignLayouts. + + if (!mutable_array_subshape->has_layout() || + !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { + literal_ = literal_->Relayout(new_layout, shape_index); + *mutable_array_subshape->mutable_layout() = new_layout; + } +} + +bool HloConstantInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& other_slice = static_cast(other); + return literal() == other_slice.literal(); +} + +std::unique_ptr +HloConstantInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(literal_->CloneToUnique()); +} + +string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { + string operands; + // For constants, show the actual value in place of an empty operand list. + if ((!ShapeUtil::IsTuple(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || + options.print_large_constants()) { + // Literal::ToString emits multidimensional arrays over multiple + // lines. Compact this into one line by stripping out white space. + string tmp = literal().ToString(); + std::replace(tmp.begin(), tmp.end(), '\n', ' '); + std::vector v = tensorflow::str_util::Split(tmp, ' '); + bool first = true; + // Concatenate elements in "v" with spaces separating them, but ignoring + // empty entries. + for (const auto& s : v) { + if (s.empty()) { + continue; + } + StrAppend(&operands, (first ? "" : " "), s); + first = false; + } + } else { + // Do not show large constants or tuples. + operands = "{...}"; + } + return operands; +} + +HloTraceInstruction::HloTraceInstruction(const string& tag, + HloInstruction* operand) + : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()), + literal_(Literal::CreateR1U8(tag)) { + AppendOperand(operand); + operand->set_tracing(this); +} + +HloInstructionProto HloTraceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_literal() = literal_->ToProto(); + return proto; +} + +bool HloTraceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return false; +} + +std::unique_ptr HloTraceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode()); +} + +HloFusionInstruction::HloFusionInstruction(const Shape& shape, + FusionKind fusion_kind, + HloInstruction* fused_root) + : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) { + CHECK(fused_root != nullptr); + SetAndSanitizeName("fusion"); + set_parent(fused_root->parent()); + set_metadata(fused_root->metadata()); + CloneAndFuseInternal(fused_root); +} + +HloFusionInstruction::HloFusionInstruction( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice operands, + HloComputation* fusion_computation) + : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) { + for (auto operand : operands) { + AppendOperand(operand); + } + SetAndSanitizeName("fusion"); + AppendComputation(fusion_computation); + fusion_computation->SetFusionInstruction(this); +} + +HloInstructionProto HloFusionInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_fusion_kind(xla::ToString(fusion_kind())); + proto.add_called_computation_ids( + fused_instructions_computation()->unique_id()); + return proto; +} + +bool HloFusionInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + if (fusion_kind() != FusionKind::kLoop) { + return false; + } + + if (!operand_idx.has_value()) { + for (auto* fused : fused_instructions()) { + if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) { + return false; + } + } + return true; + } + // A loop-fusion is elementwise on an operand if all operations (computed + // using BFS) between the operand and the fused root are elementwise. + std::deque worklist; + std::unordered_set visited; + worklist.push_back(fused_parameter(operand_idx.value())); + visited.insert(fused_parameter(operand_idx.value())); + while (!worklist.empty()) { + HloInstruction* operand = worklist.front(); + worklist.pop_front(); + for (HloInstruction* user : operand->users()) { + CHECK_GE(user->unique_id(), 0); + if (ContainsKey(visited, user)) { + continue; + } + if (user->IsElementwise() || + IsInstructionElementwiseOnOperand(user, operand)) { + worklist.push_back(user); + visited.insert(user); + } else { + return false; + } + } + } + return true; +} + +HloInstruction* HloFusionInstruction::AddFusionOperand( + HloInstruction* new_operand) { + CHECK_EQ(operand_count(), + fused_instructions_computation()->parameter_instructions().size()); + const int64 param_no = operand_count(); + // Name the parameter after the instruction it represents in the outer + // (non-fusion) computation. + string param_name = StrCat(new_operand->name(), ".param_", param_no); + HloInstruction* fused_parameter = + fused_instructions_computation()->AddParameter( + HloInstruction::CreateParameter(param_no, new_operand->shape(), + param_name)); + AppendOperand(new_operand); + return fused_parameter; +} + +void HloFusionInstruction::MergeFusionInstruction( + HloFusionInstruction* instruction_to_merge) { + CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != + operands().end()); + // Clone the instruction from which to merge fused instructions. + std::unique_ptr cloned = instruction_to_merge->Clone(); + HloFusionInstruction* cloned_fusion = + static_cast(cloned.get()); + // Replace uses of fused parameters with the corresponding operand of the + // fusion. Add all non-parameter fused instructions to + // 'unfused_instructions' to be merged into 'this'. This is done in reverse + // post order. + std::vector unfused_instructions; + auto fused_instructions = cloned_fusion->fused_instructions_computation() + ->MakeInstructionPostOrder(); + for (auto fused_it = fused_instructions.rbegin(); + fused_it != fused_instructions.rend(); ++fused_it) { + auto fused_instruction = *fused_it; + if (fused_instruction->opcode() == HloOpcode::kParameter) { + TF_CHECK_OK( + fused_instruction->ReplaceAllUsesWith(cloned_fusion->mutable_operand( + fused_instruction->parameter_number()))); + } else { + unfused_instructions.push_back(fused_instruction); + } + } + CHECK(unfused_instructions.front() == cloned_fusion->fused_expression_root()); + // Replace instruction_to_merge use of 'this' with unfused_root. + TF_CHECK_OK( + instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front())); + // Fuse 'unfused_instructions' into 'this'. + for (auto& instruction : unfused_instructions) { + FuseInstruction(instruction); + instruction->DetachFromOperands(); + } + CHECK_EQ(0, cloned_fusion->user_count()); + cloned_fusion->DetachFromOperands(); + TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation( + cloned_fusion->fused_instructions_computation())); +} + +void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput( + HloFusionInstruction* instruction_to_merge) { + // Add all non-parameter fused instructions to 'unfused_instructions' to be + // merged into 'this'. `old_to_new' maps the instructions in the fused node + // to the disaseembled fusion instructions. + // Note that we add the unfused instructions to this->parent_ computation. + // This is necessary because the unique_id needs for an instruction and + // it's only added when inserting to the computation. + tensorflow::gtl::FlatMap old_to_new; + std::vector unfused_instructions; + auto computation_to_merge = + instruction_to_merge->fused_instructions_computation(); + auto post_order = computation_to_merge->MakeInstructionPostOrder(); + for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) { + auto fused_instruction = *rit; + if (fused_instruction->opcode() == HloOpcode::kParameter) { + InsertOrDie(&old_to_new, fused_instruction, + instruction_to_merge->mutable_operand( + fused_instruction->parameter_number())); + continue; + } + + // Here we clone the insertion and call FuseInstructionIntoMultiOutput() + // which clones again. This can be improved. + auto cloned_instruction = + parent()->AddInstruction(fused_instruction->Clone()); + unfused_instructions.push_back(cloned_instruction); + InsertOrDie(&old_to_new, fused_instruction, cloned_instruction); + } + for (auto unfused_instruction : unfused_instructions) { + for (int64 index = 0; index < unfused_instruction->operand_count(); + index++) { + auto new_operand = + FindOrDie(old_to_new, unfused_instruction->mutable_operand(index)); + TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand)); + } + } + + HloInstruction* unfused_root = unfused_instructions.front(); + TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root)); + + TF_CHECK_OK( + instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge)); + if (GetModule()) { + TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge)); + } + + // Fuse the root instruction and generate multiple outputs. + FuseInstructionIntoMultiOutput(unfused_root); + TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root)); + // The rest instructions are of normal fusing. + for (int64 i = 1; i < unfused_instructions.size(); i++) { + auto instruction = unfused_instructions[i]; + FuseInstruction(instruction); + TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction)); + } +} + +HloComputation* HloFusionInstruction::fused_instructions_computation() const { + CHECK(!called_computations().empty()); + auto* fused_instructions_computation = called_computations().front(); + CHECK(fused_instructions_computation->IsFusionComputation()) + << "Computation " << fused_instructions_computation->name() + << " is not a fusion kind"; + return fused_instructions_computation; +} + +HloInstruction* HloFusionInstruction::fused_expression_root() const { + return fused_instructions_computation()->root_instruction(); +} + +HloInstruction* HloFusionInstruction::fused_parameter( + int64 parameter_number) const { + return fused_instructions_computation()->parameter_instruction( + parameter_number); +} + +const std::vector& HloFusionInstruction::fused_parameters() + const { + return fused_instructions_computation()->parameter_instructions(); +} + +const tensorflow::gtl::iterator_range>::const_iterator>> +HloFusionInstruction::fused_instructions() const { + const HloComputation* subcomp = fused_instructions_computation(); + return subcomp->instructions(); +} + +const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> +HloFusionInstruction::fused_instructions() { + return fused_instructions_computation()->instructions(); +} + +int64 HloFusionInstruction::fused_instruction_count() const { + return fused_instructions_computation()->instruction_count(); +} + +HloInstruction* HloFusionInstruction::FuseInstructionInternal( + HloInstruction* instruction_to_fuse, bool add_output) { + // When add_output is false, this fusion instruction must be a user of + // instruction_to_fuse. + if (!add_output) { + CHECK(IsUserOf(instruction_to_fuse)); + } + HloInstruction* fused_instruction = + CloneAndFuseInternal(instruction_to_fuse, add_output); + return fused_instruction; +} + +HloInstruction* HloFusionInstruction::CloneAndFuseInternal( + HloInstruction* instruction_to_fuse, bool add_output) { + CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); + VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); + HloInstruction* clone = nullptr; + if (called_computations().empty()) { + // New fusion instruction. It should not be a multioutput instruction. + CHECK(!add_output); + auto builder = HloComputation::Builder("fused_computation", this); + builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); + AppendComputation( + CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); + clone = fused_expression_root(); + } else { + clone = fused_instructions_computation()->AddInstruction( + instruction_to_fuse->Clone(/*suffix=*/"")); + // When add_output is false, instruction_to_fuse is necessarily an operand + // of the fusion instruction. After fusion this will no longer be the + // case. Remove the operand from the operand list and remove its + // corresponding fused parameter instruction. Renumber parameters as + // necessary to make parameter numbers consistent with their index in the + // fused_parameter_ vector. + bool in_operand_list = std::find(operands().begin(), operands().end(), + instruction_to_fuse) != operands().end(); + CHECK(add_output || in_operand_list); + const std::vector& fused_parameters = + fused_instructions_computation()->parameter_instructions(); + for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { + if (instruction_to_fuse == operand(operand_num)) { + // replace the fused parameter instruction's uses with the clone. + HloInstruction* fused_parameter = fused_parameters[operand_num]; + TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone)); + + // Remove the corresponding fused parameter and operand from their + // respective vectors. + TF_CHECK_OK( + fused_instructions_computation()->RemoveParameter(operand_num)); + RemoveOperandAt(operand_num); + break; + } + } + // We've cloned instruction_to_fuse into this fusion instruction, so this + // fusion instruction is no longer a use of instruction_to_fuse. + if (in_operand_list) { + DetachFrom(instruction_to_fuse); + // When the instruction_to_fuse does not have other users, we don't need + // to generate a multioutput fusion instruction. + if (instruction_to_fuse->user_count() == 0) { + add_output = false; + } + } + } + + // Reread the parameters in the computation. + const std::vector& fused_parameters = + fused_instructions_computation()->parameter_instructions(); + + // Add each operand of the clone as an operand of the fusion instruction. A + // complication is that some clone operands may already be operands of the + // fusion instruction. + for (int64 operand_num = 0; operand_num < clone->operand_count(); + ++operand_num) { + HloInstruction* operand = clone->mutable_operand(operand_num); + + // See if this operand is already an operand of the fusion node. + CHECK_EQ(operands().size(), fused_parameters.size()); + HloInstruction* fused_param = nullptr; + for (int64 i = 0; i < operands().size(); ++i) { + if (this->operand(i) == operand) { + fused_param = fused_parameters[i]; + break; + } + } + + if (fused_param == nullptr) { + // Clone's operand was not already an operand of the fusion + // instruction. Add it as an operand and add a corresponding fused + // parameter instruction. + fused_param = AddFusionOperand(operand); + } + TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); + } + + if (add_output) { + CHECK_GT(instruction_to_fuse->user_count(), 0); + // If this is already a multioutput fusion instruction, expand the root + // tuple by 1. + HloInstruction* fused_root = fused_expression_root(); + HloInstruction::InstructionVector tuple_elements; + bool newly_created_tuple_instr = false; + if (fused_root->opcode() == HloOpcode::kTuple) { + tuple_elements = fused_root->operands(); + } else { + tuple_elements.push_back(fused_root); + newly_created_tuple_instr = true; + } + if (clone->opcode() == HloOpcode::kTuple) { + for (auto inst : clone->operands()) { + tuple_elements.push_back(inst); + } + } else { + tuple_elements.push_back(clone); + } + HloInstruction* new_root = fused_instructions_computation()->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); + fused_instructions_computation()->set_root_instruction(new_root); + *mutable_shape() = new_root->shape(); + if (fused_root->opcode() == HloOpcode::kTuple) { + TF_CHECK_OK( + fused_instructions_computation()->RemoveInstruction(fused_root)); + } + + // If this is a newly created multioutput instruction, we need to update + // the use of the original fusion instruction. + if (newly_created_tuple_instr) { + HloInstruction* new_instr = parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0)); + TF_CHECK_OK(ReplaceAllUsesWith(new_instr)); + } + int64 index = tuple_elements.size(); + if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { + index -= instruction_to_fuse->operand_count(); + std::vector to_be_removed; + for (auto old_gte : instruction_to_fuse->users()) { + CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement); + int64 old_tuple_index = old_gte->tuple_index(); + HloInstruction* new_gte = + parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + old_gte->shape(), this, index + old_tuple_index)); + TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte)); + to_be_removed.push_back(old_gte); + } + for (auto old_gte : to_be_removed) { + TF_CHECK_OK(parent()->RemoveInstruction(old_gte)); + } + TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone)); + } else { + HloInstruction* new_gte = + parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + clone->shape(), this, index - 1)); + TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte)); + } + } + + VLOG(2) << "New clone:\n" << clone->ToString(); + return clone; +} + +std::vector HloFusionInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("kind=", xla::ToString(fusion_kind()))}; +} + +bool HloFusionInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return fusion_kind() == other.fusion_kind() && + eq_computations(fused_instructions_computation(), + other.fused_instructions_computation()); +} + +std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + HloModule* module = context != nullptr ? context->module() : GetModule(); + HloComputation* new_fused_computation = nullptr; + if (context != nullptr) { + new_fused_computation = + context->FindComputation(fused_instructions_computation()); + } + if (new_fused_computation == nullptr) { + new_fused_computation = module->AddEmbeddedComputation( + fused_instructions_computation()->Clone("clone", context)); + } + return MakeUnique(shape, fusion_kind(), new_operands, + new_fused_computation); +} + +HloRngInstruction::HloRngInstruction( + const Shape& shape, RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters) + : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) { + for (HloInstruction* param : parameters) { + AppendOperand(param); + } +} + +HloInstructionProto HloRngInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_distribution(distribution_); + return proto; +} + +std::vector HloRngInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("distribution=", RandomDistributionToString(distribution_))}; +} + +bool HloRngInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + return true; +} + +bool HloRngInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return false; +} + +std::unique_ptr HloRngInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(shape, distribution_, new_operands); +} + +HloParameterInstruction::HloParameterInstruction(int64 parameter_number, + const Shape& shape, + const string& name) + : HloInstruction(HloOpcode::kParameter, shape), + parameter_number_(parameter_number) { + SetAndSanitizeName(name); +} + +HloInstructionProto HloParameterInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_parameter_number(parameter_number_); + return proto; +} + +string HloParameterInstruction::OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { + return StrCat(parameter_number_); +} + +bool HloParameterInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return parameter_number() == casted_other.parameter_number(); +} + +std::unique_ptr +HloParameterInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(parameter_number_, shape, name()); +} + +HloGetTupleElementInstruction::HloGetTupleElementInstruction( + const Shape& shape, HloInstruction* operand, int64 index) + : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) { + CHECK(ShapeUtil::IsTuple(operand->shape())); + AppendOperand(operand); +} + +HloInstructionProto HloGetTupleElementInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_tuple_index(tuple_index_); + return proto; +} + +std::vector HloGetTupleElementInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("index=", tuple_index())}; +} + +bool HloGetTupleElementInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return tuple_index() == casted_other.tuple_index(); +} + +std::unique_ptr +HloGetTupleElementInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + tuple_index()); +} + +HloReducePrecisionInstruction::HloReducePrecisionInstruction( + const Shape& shape, HloInstruction* operand, const int exponent_bits, + const int mantissa_bits) + : HloInstruction(HloOpcode::kReducePrecision, shape), + exponent_bits_(exponent_bits), + mantissa_bits_(mantissa_bits) { + AppendOperand(operand); +} + +HloInstructionProto HloReducePrecisionInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_exponent_bits(exponent_bits_); + proto.set_mantissa_bits(mantissa_bits_); + return proto; +} + +std::vector HloReducePrecisionInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("exponent_bits=", exponent_bits_), + StrCat("mantissa_bits=", mantissa_bits_)}; +} + +bool HloReducePrecisionInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + // A reduce-precision operation is determined by the bit sizes. + return exponent_bits() == casted_other.exponent_bits() && + mantissa_bits() == casted_other.mantissa_bits(); +} + +std::unique_ptr +HloReducePrecisionInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + shape, new_operands[0], exponent_bits(), mantissa_bits()); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h new file mode 100644 index 0000000000000000000000000000000000000000..6749d875559008b3a2bd479ff075c83d85d87509 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -0,0 +1,727 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// All HloInstruction subclasses are put in this file. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { + +class HloBatchNormInstruction : public HloInstruction { + public: + // Returns feature_index field associated with the instruction. The index + // represents the index of the feature dimension. + int64 feature_index() const { return feature_index_; } + + // Returns a epsilon value associated with the instruction. The is a small + // number added to the variance to avoid divide-by-zero error. + float epsilon() const { return epsilon_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + protected: + explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape, + HloInstruction* operand, + HloInstruction* scale, float epsilon, + int64 feature_index); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // A small float number added to the variance to avoid divide-by-zero error. + float epsilon_ = 0.0f; + + // An integer value representing the index of the feature dimension. + int64 feature_index_ = -1; +}; + +class HloBatchNormTrainingInstruction : public HloBatchNormInstruction { + public: + explicit HloBatchNormTrainingInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* scale, + HloInstruction* offset, + float epsilon, int64 feature_index); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloBatchNormInferenceInstruction : public HloBatchNormInstruction { + public: + explicit HloBatchNormInferenceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, + float epsilon, int64 feature_index); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloBatchNormGradInstruction : public HloBatchNormInstruction { + public: + explicit HloBatchNormGradInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* mean, HloInstruction* variance, + HloInstruction* grad_output, float epsilon, int64 feature_index); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloFftInstruction : public HloInstruction { + public: + explicit HloFftInstruction(const Shape& shape, HloInstruction* operand, + FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + FftType fft_type() const { return fft_type_; } + + const std::vector& fft_length() const { return fft_length_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Describes FFT type for an FFT instruction. + FftType fft_type_ = FftType::FFT; + + // Indicates the FFT length for an FFT instruction. + std::vector fft_length_; +}; + +class HloSendRecvInstruction : public HloInstruction { + public: + // Returns the channel id associated with the instruction. The id is + // shared between each Send/Recv pair and is globally unique to identify each + // channel. + int64 channel_id() const { return channel_id_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + protected: + explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, + int64 channel_id); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Represents a unique identifier for each Send/Recv instruction pair. + int64 channel_id_; +}; + +class HloSendInstruction : public HloSendRecvInstruction { + public: + explicit HloSendInstruction(HloInstruction* operand, int64 channel_id); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloSendDoneInstruction : public HloSendRecvInstruction { + public: + explicit HloSendDoneInstruction(HloSendInstruction* operand); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloRecvInstruction : public HloSendRecvInstruction { + public: + explicit HloRecvInstruction(const Shape& shape, int64 channel_id); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloRecvDoneInstruction : public HloSendRecvInstruction { + public: + explicit HloRecvDoneInstruction(HloRecvInstruction* operand); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloReverseInstruction : public HloInstruction { + public: + explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloConcatenateInstruction : public HloInstruction { + public: + explicit HloConcatenateInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + int64 dimension); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Accessor for the dimension in which a concatenate HLO should occur. + int64 concatenate_dimension() const { return dimensions(0); } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloReduceInstruction : public HloInstruction { + public: + explicit HloReduceInstruction( + const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloTransposeInstruction : public HloInstruction { + public: + explicit HloTransposeInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns whether this instruction does a rank-2 transposition. + bool IsRank2Transpose() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloBroadcastInstruction : public HloInstruction { + public: + explicit HloBroadcastInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimension); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloMapInstruction : public HloInstruction { + public: + explicit HloMapInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation, + tensorflow::gtl::ArraySlice static_operands = {}); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloSliceInstruction : public HloInstruction { + public: + explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + HloInstructionProto ToProto() const override; + + // Returns the start index in the given dimension for a slice node. + int64 slice_starts(int64 dimension) const { return slice_starts_[dimension]; } + const std::vector& slice_starts() const { return slice_starts_; } + + // Returns the (exclusive) limit index in the given dimension for a slice + // node. + int64 slice_limits(int64 dimension) const { return slice_limits_[dimension]; } + const std::vector& slice_limits() const { return slice_limits_; } + + // Returns the stride in the given dimension for a slice node. + int64 slice_strides(int64 dimension) const { + return slice_strides_[dimension]; + } + const std::vector& slice_strides() const { return slice_strides_; } + + // Returns the flag that describes whether a slice must be lowered into an + // offset into the original operand. + bool IsInPlaceSlice() const { return is_in_place_slice_; } + + // Sets and returns the flag that describes whether a slice must be lowered + // into an offset into the original operand. + bool SetIsInPlaceSlice(bool value) { + is_in_place_slice_ = value; + return value; + } + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Describes the [begin, end) index range for a slice. + std::vector slice_starts_; + std::vector slice_limits_; + std::vector slice_strides_; + + // Describes whether the slice can be lowered to an offset into the operand. + bool is_in_place_slice_ = false; +}; + +class HloConstantInstruction : public HloInstruction { + public: + explicit HloConstantInstruction(std::unique_ptr literal); + // Returns the literal associated with this instruction. + const Literal& literal() const { return *literal_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Change the layout for an Constant Hlo instruction to match new_layout. For + // tuple shaped constants shape_index is the path to the internal array + // subshape whose layout needs to be changed. + void RelayoutConstant(const Layout& new_layout, + const ShapeIndex& shape_index = {}); + + private: + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + string OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // TODO(b/36360764): Remove unique_ptr wrapping. + std::unique_ptr literal_; +}; + +class HloTraceInstruction : public HloInstruction { + public: + explicit HloTraceInstruction(const string& tag, HloInstruction* operand); + // Returns a tag to be used in tracing. + string TracingTag() const { return literal_->GetR1U8AsString(); } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // TODO(b/36360764): Remove unique_ptr wrapping. + std::unique_ptr literal_; +}; + +class HloFusionInstruction : public HloInstruction { + public: + explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, + HloInstruction* fused_root); + + explicit HloFusionInstruction( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice operands, + HloComputation* fusion_computation); + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Adds a new operand the fusion instruction. + HloInstruction* AddFusionOperand(HloInstruction* new_operand); + + // Merges the fused instructions from 'instruction_to_merge' into the + // fused instruction set of 'this', updating operands as necessary. + // + // Predondition: 'instruction_to_merge' must be an operand of 'this'. + void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge); + + // Merges the fused instructions from instruction_to_merge into the fused + // instruction set of 'this' and generates multioutput fusion instructions. + // All the users of instruction_to_merge will be redirected to 'this' + // instruction. instruction_to_merge will be removed from its parent + // computation. + void MergeFusionInstructionIntoMultiOutput( + HloFusionInstruction* instruction_to_merge); + + // Fuses the given instruction in this fusion instruction. instruction_to_fuse + // is cloned and the clone is placed in the fusion + // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather + // than moved to cleanly handle the case where the instruction has a use + // outside the fusion instruction. Moving such an instruction into a fusion + // instruction would violate the single-result invariant of HLO instructions + // and significantly complicate code generation. + HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) { + return FuseInstructionInternal(instruction_to_fuse); + } + + // Fuses the given instruction in this fusion instruction and generate + // multioutput fusion instruction. A clone of the instruction_to_fuse will + // be part of the output of fusion instructions. The users of + // instruction_to_fuse will be redirected to this fusion instructions. + // instruction_to_fuse will be removed from its parent computation. + HloInstruction* FuseInstructionIntoMultiOutput( + HloInstruction* instruction_to_fuse) { + return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true); + } + + // Returns the computation for this fused instruction. + HloComputation* fused_instructions_computation() const; + + // Returns the root instruction of the fused expression contained within this + // fusion instruction. + HloInstruction* fused_expression_root() const; + + // Returns the list of fused instructions inside this fusion instruction. The + // returned type is a range of HloInstruction*s. + const tensorflow::gtl::iterator_range>::const_iterator>> + fused_instructions() const; + + const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> + fused_instructions(); + + // Gets the number of instructions inside this fusion instruction. + int64 fused_instruction_count() const; + + // Returns the fused parameter instruction in this fusion instruction + // corresponding to the given parameter number. + HloInstruction* fused_parameter(int64 parameter_number) const; + + // Returns the vector of fused parameters inside this fusion instruction. + const std::vector& fused_parameters() const; + + // Returns true if this instruction is a fusion instruction that generates + // multiple outputs. + const bool IsMultiOutputFusion() const { + return fused_expression_root()->opcode() == HloOpcode::kTuple; + } + + FusionKind fusion_kind() const { return fusion_kind_; } + + void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; } + + private: + // Fuses the given instruction into this fusion instruction. When add_output + // is false (which is the default), instruction_to_fuse is cloned and the + // clone is placed in the fusion instruction. instruction_to_fuse is + // unchanged. + // + // When add_output is true, a clone of the instruction_to_fuse will be part + // of the output of fusion instructions. The users of instruction_to_fuse + // will be redirected to this fusion instructions. instruction_to_fuse will + // be removed from its parent computation. + HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse, + bool add_output = false); + // Clones the given instruction_to_fuse and insert the clone into this fusion + // instruction. If add_output is true, a clone of instruction_to_fuse will + // be in the output of the this fusion instruction (part of the tuple of the + // fusion root). + HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse, + bool add_output = false); + + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The type of the fusion. Used by kFusion only. + FusionKind fusion_kind_; +}; + +class HloRngInstruction : public HloInstruction { + public: + explicit HloRngInstruction( + const Shape& shape, RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters); + // Returns the random distribution for this rng node. + RandomDistribution random_distribution() const { return distribution_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The distribution requested for random number generation. + RandomDistribution distribution_; +}; + +class HloParameterInstruction : public HloInstruction { + public: + explicit HloParameterInstruction(int64 parameter_number, const Shape& shape, + const string& name); + int64 parameter_number() const { return parameter_number_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + string OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + int64 parameter_number_ = 0; +}; + +class HloGetTupleElementInstruction : public HloInstruction { + public: + explicit HloGetTupleElementInstruction(const Shape& shape, + HloInstruction* operand, int64 index); + // Returns the tuple index associated with this instruction. + int64 tuple_index() const { return tuple_index_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + int64 tuple_index_ = -1; +}; + +class HloReducePrecisionInstruction : public HloInstruction { + public: + explicit HloReducePrecisionInstruction(const Shape& shape, + HloInstruction* operand, + const int exponent_bits, + const int mantissa_bits); + // Returns the number of exponent bits for a reduce-precision node. + int32 exponent_bits() const { return exponent_bits_; } + // Returns the number of mantissa bits for a reduce-precision node. + int32 mantissa_bits() const { return mantissa_bits_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The bit sizes for a reduce-precision operation. + int32 exponent_bits_ = 0; + int32 mantissa_bits_ = 0; +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc similarity index 94% rename from tensorflow/compiler/xla/tools/parser/hlo_lexer.cc rename to tensorflow/compiler/xla/service/hlo_lexer.cc index fc0e4444521247734fc240a03da669244fe1a6a4..f0d9fdbc8f86da0bb9d7f9235239df677c9506bc 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" +#include "tensorflow/compiler/xla/service/hlo_lexer.h" #include @@ -26,9 +26,8 @@ limitations under the License. #include "tensorflow/core/platform/regexp.h" namespace xla { -namespace tools { -using tensorflow::StringPiece; +using ::tensorflow::StringPiece; namespace { @@ -67,12 +66,12 @@ bool HloLexer::CanDereference(const char* ptr) const { return ptr < buf_.end() && ptr >= buf_.begin(); } -StringPiece HloLexer::StringPieceFromPointers(const char* begin, - const char* end) const { +tensorflow::StringPiece HloLexer::StringPieceFromPointers( + const char* begin, const char* end) const { CHECK(begin <= end); CHECK(begin == buf_.end() || CanDereference(begin)); CHECK(end == buf_.end() || CanDereference(end)); - return StringPiece(begin, end - begin); + return tensorflow::StringPiece(begin, end - begin); } tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( @@ -197,7 +196,8 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kAttributeName; } - StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_); + tensorflow::StringPiece identifier = + StringPieceFromPointers(token_start_, current_ptr_); // See if this is a keyword. #define KEYWORD(STR) \ @@ -230,7 +230,7 @@ TokKind HloLexer::LexIdentifier() { } } - str_val_ = identifier.ToString(); + str_val_ = std::string(identifier); return TokKind::kIdent; } @@ -332,23 +332,24 @@ std::pair HloLexer::GetLineAndColumn(LocTy location) const { line_no_cache_.last_query = ptr; line_no_cache_.line_no_of_query = line_no; size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n'); - if (line_offset == StringPiece::npos) { + if (line_offset == tensorflow::StringPiece::npos) { line_offset = 0; } return {line_no, ptr - start - line_offset}; } -StringPiece HloLexer::GetLine(LocTy loc) const { +tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const { if (!CanDereference(loc)) { return "LINE OUT OF RANGE"; } size_t line_start = StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n'); - const char* start = line_start == StringPiece::npos + const char* start = line_start == tensorflow::StringPiece::npos ? buf_.begin() : buf_.begin() + line_start + 1; size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n'); - const char* end = line_end == StringPiece::npos ? buf_.end() : loc + line_end; + const char* end = + line_end == tensorflow::StringPiece::npos ? buf_.end() : loc + line_end; return StringPieceFromPointers(start, end); } @@ -370,7 +371,7 @@ TokKind HloLexer::LexString() { static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; if (RE2::Consume(&consumable, *escaping_pattern)) { current_ptr_ = consumable.begin(); - StringPiece raw = + tensorflow::StringPiece raw = StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); string error; if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) { @@ -453,5 +454,4 @@ string TokKindToString(TokKind kind) { } } -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h similarity index 90% rename from tensorflow/compiler/xla/tools/parser/hlo_lexer.h rename to tensorflow/compiler/xla/service/hlo_lexer.h index 27880b9b8afbfa58abfedc3b2cecd5236b78a6d6..ceb674f25e94ac3ac2e6a4a0687a93ffdcd065e0 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ #include -#include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/service/hlo_token.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -27,9 +27,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { -namespace tools { // Lexer for the HloModule::ToString() format text. +// +// This class is meant to be used by hlo_parser.cc. You shouldn't need to use +// it directly. class HloLexer { public: explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { @@ -57,7 +59,7 @@ class HloLexer { CHECK(GetKind() == TokKind::kShape); return shape_val_; } - int64 GetInt64Val() const { + tensorflow::int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); return int64_val_; } @@ -114,7 +116,7 @@ class HloLexer { TokKind current_kind_; string str_val_; Shape shape_val_; - int64 int64_val_; + tensorflow::int64 int64_val_; double decimal_val_; struct LineNoCacheTy { @@ -125,7 +127,6 @@ class HloLexer { mutable LineNoCacheTy line_no_cache_{nullptr, 0}; }; -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..43c41ece6efc4f9e8ca74f16e0f63d29abc4de4e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -0,0 +1,306 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +using Worklist = std::deque; +using Workset = std::unordered_set; + +namespace { + +void AddToWorklist(const HloInstruction* instruction, Worklist* worklist, + Workset* workset) { + if (workset->count(instruction) == 0) { + worklist->push_back(instruction); + workset->insert(instruction); + VLOG(3) << "ADD instruction: " << instruction->name(); + } +} + +using VisitorFunction = std::function; + +void ForEachLiveIndex(const ShapeTree& index_tree, + const VisitorFunction& func) { + index_tree.ForEachElement([&](const ShapeIndex& shape_index, bool live) { + if (live) { + func(shape_index); + } + }); +} + +// Marks 'instruction' output live at 'shape_index'. +// Adds to 'worklist' iff: +// *) 'instruction' is not already on worklist. +// *) 'shape_index' has not yet been visited. +void MarkLiveAtIndex(const HloInstruction* instruction, + const ShapeIndex& shape_index, + HloLivenessAnalysis::HloIndexMap* live_index_map, + Worklist* worklist, Workset* workset) { + auto it = live_index_map->find(instruction); + if (it == live_index_map->end()) { + auto it_added = live_index_map->emplace( + std::piecewise_construct, std::forward_as_tuple(instruction), + std::forward_as_tuple(instruction->shape(), /*init_value=*/false)); + it = it_added.first; + } + if (it->second.element(shape_index) == false) { + AddToWorklist(instruction, worklist, workset); + *it->second.mutable_element(shape_index) = true; + VLOG(3) << "MARK instruction: " << instruction->name() + << " shape_index: " << shape_index.ToString(); + } +} + +// Marks 'instruction' live at all shape indices in its output. +void MarkLiveAtAllIndices(const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, + Worklist* worklist, Workset* workset) { + bool add_to_worklist = false; + auto it = live_index_map->find(instruction); + if (it == live_index_map->end()) { + live_index_map->emplace( + std::piecewise_construct, std::forward_as_tuple(instruction), + std::forward_as_tuple(instruction->shape(), /*init_value=*/true)); + add_to_worklist = true; + } else { + ShapeUtil::ForEachSubshape( + instruction->shape(), + [&](const Shape& sub_shape, const ShapeIndex& shape_index) { + if (it->second.element(shape_index) == false) { + add_to_worklist = true; + *it->second.mutable_element(shape_index) = true; + VLOG(3) << "MARK instruction: " << instruction->name() + << " shape_index: " << shape_index.ToString(); + } + }); + } + if (add_to_worklist) { + AddToWorklist(instruction, worklist, workset); + } +} + +// Propagates liveness through Tuple instructions. +// *) For each tuple operand: +// *) For tuple output shape index associated with operand: +// *) Propgate live shape indices to tuple operand at the associated +// shape index in the operands output, and add to worklist. +void PropagateLivenessThroughTuple( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset) { + CHECK_EQ(instruction->opcode(), HloOpcode::kTuple); + for (int64 operand_index = 0; operand_index < instruction->operand_count(); + ++operand_index) { + const ShapeTree& index_tree = FindOrDie(*live_index_map, instruction); + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + if (shape_index.empty() || shape_index[0] != operand_index) { + return; + } + // Mark top-level index of operand at 'operand_index'. + MarkLiveAtIndex(instruction->operand(operand_index), {}, live_index_map, + worklist, workset); + // Mark sub-shape index of operand at 'operand_index'. + ShapeIndex operand_shape_index; + for (int i = 1; i < shape_index.size(); ++i) { + operand_shape_index.push_back(shape_index[i]); + } + MarkLiveAtIndex(instruction->operand(operand_index), operand_shape_index, + live_index_map, worklist, workset); + }); + } +} + +// Propagates liveness through GetTupleElement instructions. +// *) For each live index in GetTupleElement output, mark output of GTE operand +// at associated shape index in its output, and add to worklist. +void PropagateLivenessThroughGTE( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset) { + CHECK_EQ(instruction->opcode(), HloOpcode::kGetTupleElement); + // Mark operand top-level index. + MarkLiveAtIndex(instruction->operand(0), {}, live_index_map, worklist, + workset); + const ShapeTree& index_tree = FindOrDie(*live_index_map, instruction); + // Propagate live shape indices along GTE -> Tuple edge. + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + ShapeIndex operand_shape_index(shape_index); + operand_shape_index.push_front(instruction->tuple_index()); + MarkLiveAtIndex(instruction->operand(0), operand_shape_index, + live_index_map, worklist, workset); + }); +} + +// Propagates liveness through While instructions. +// *) For each live index in While output, mark shape index of while.body.root +// and while.operand (adding each to worklist). +// *) Mark while.cond.root and add to worklist. +void PropagateLivenessThroughWhile( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset) { + CHECK_EQ(instruction->opcode(), HloOpcode::kWhile); + const ShapeTree& index_tree = FindOrDie(*live_index_map, instruction); + + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + // Propagate liveness to while body computation root instruction. + MarkLiveAtIndex(instruction->while_body()->root_instruction(), shape_index, + live_index_map, worklist, workset); + // Propagate liveness to tuple-shaped operand. + MarkLiveAtIndex(instruction->operand(0), shape_index, live_index_map, + worklist, workset); + }); + + // Propagate liveness to while condition computation root instruction. + MarkLiveAtIndex(instruction->while_condition()->root_instruction(), {}, + live_index_map, worklist, workset); +} + +// Propagates liveness out of Parameter instructions to callers and aliasing +// positions. This can occur if liveness propagates to a parameter in the +// while.condition computation, requiring liveness to propagate out to caller +// callsite while (and while.body.root). +void PropagateLivenessToParameterCallers( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset, CallGraph* call_graph) { + CHECK_EQ(instruction->opcode(), HloOpcode::kParameter); + const CallGraphNode& call_graph_node = + call_graph->GetNode(instruction->parent()); + if (call_graph_node.context() == CallContext::kSequential) { + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + if (callsite.instruction()->opcode() == HloOpcode::kWhile) { + auto* xla_while = callsite.instruction(); + const ShapeTree& index_tree = + FindOrDie(*live_index_map, instruction); + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + // Propagate liveness to while result{shape_index} + MarkLiveAtIndex(xla_while, shape_index, live_index_map, worklist, + workset); + // Propagate liveness to while body root{shape_index}. + MarkLiveAtIndex(xla_while->while_body()->root_instruction(), + shape_index, live_index_map, worklist, workset); + // Propagate liveness to operand(0){shape_index}. + MarkLiveAtIndex(xla_while->operand(0), shape_index, live_index_map, + worklist, workset); + }); + } + } + } +} + +} // namespace + +HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module) + : module_(module), call_graph_(CallGraph::Build(&module)) {} + +// Runs liveness analysis on 'module_'. +// Initializes worklist with entry root instruction (and any instruction with +// side-effects), marking all of their output shape indices live. +// Visits elements on worklist, propagating liveness from an instructions +// live output shape indices to its called computations and operands. +void HloLivenessAnalysis::RunAnalysis() { + Worklist worklist; + Workset workset; + // Add entry compuation root instruction. + MarkLiveAtAllIndices(module_.entry_computation()->root_instruction(), + &live_index_map_, &worklist, &workset); + for (auto* computation : module_.computations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->HasSideEffectNoRecurse()) { + // Add instructions with side effects. + MarkLiveAtAllIndices(instruction, &live_index_map_, &worklist, + &workset); + } + } + } + + while (!worklist.empty()) { + const HloInstruction* instruction = worklist.front(); + worklist.pop_front(); + workset.erase(workset.find(instruction)); + VLOG(1) << "VISIT instruction: " << instruction->name(); + + if (instruction->opcode() == HloOpcode::kTuple) { + PropagateLivenessThroughTuple(instruction, &live_index_map_, &worklist, + &workset); + } else if (instruction->opcode() == HloOpcode::kGetTupleElement) { + PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist, + &workset); + } else if (instruction->opcode() == HloOpcode::kWhile && + ShapeUtil::IsTuple(instruction->shape())) { + PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist, + &workset); + } else if (instruction->opcode() == HloOpcode::kParameter && + ShapeUtil::IsTuple(instruction->shape())) { + PropagateLivenessToParameterCallers(instruction, &live_index_map_, + &worklist, &workset, + call_graph_.get()); + } else { + // Propagate liveness to called computations. + for (auto* called_computation : instruction->called_computations()) { + MarkLiveAtAllIndices(called_computation->root_instruction(), + &live_index_map_, &worklist, &workset); + } + // Propagate liveness to operands. + for (HloInstruction* operand : instruction->operands()) { + MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset); + } + } + } +} + +bool HloLivenessAnalysis::IsLive(const HloInstruction* instruction, + const ShapeIndex& shape_index) const { + if (ContainsKey(live_index_map_, instruction)) { + return FindOrDie(live_index_map_, instruction).element(shape_index); + } + return false; +} + +/* static */ +StatusOr> HloLivenessAnalysis::Run( + const HloModule& module) { + VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name(); + XLA_VLOG_LINES(2, module.ToString()); + + auto liveness_analysis = WrapUnique(new HloLivenessAnalysis(module)); + + liveness_analysis->RunAnalysis(); + + return std::move(liveness_analysis); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.h b/tensorflow/compiler/xla/service/hlo_liveness_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..fe55a8070a42a3d68836dd32cf7ce5823dd77951 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.h @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ + +#include + +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Analysis which identifies all live {HloInstruction, ShapeIndex} pairs in +// an HLO module. +// +// HloLivenessAnalysis marks the shape index of each live output of each +// instruction in the module, by propagating live shape index information +// from an instruction to its called computations and operands. +class HloLivenessAnalysis { + public: + // Maps from an HloInstruction to its live/dead output shape indices. + using HloIndexMap = + std::unordered_map>; + + // Runs liveness analysis on 'module'. Returns HloLivenessAnalysis object + // which exports liveness for each {HloInstruction, ShapeIndex} in 'module'. + static StatusOr> Run( + const HloModule& module); + + // Returns true if output of 'instruction' at 'shape_index' is live. + // Returns false otherwise. + bool IsLive(const HloInstruction* instruction, + const ShapeIndex& shape_index) const; + + private: + HloLivenessAnalysis(const HloModule& module); + + void RunAnalysis(); + + const HloModule& module_; + std::unique_ptr call_graph_; + HloIndexMap live_index_map_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0275294a1a86cef13e5b267ad578f30cc18858dc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -0,0 +1,402 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class HloLivenessAnalysisTest : public HloTestBase { + protected: + HloLivenessAnalysisTest() {} + + // Run liveness analysis on the member module. For convenience returns a + // reference to the generated analysis stored in analysis_. + const HloLivenessAnalysis& RunLiveness(HloModule* module) { + liveness_ = HloLivenessAnalysis::Run(*module).ConsumeValueOrDie(); + return *liveness_; + } + + HloInstruction* GetInstruction(HloModule* module, const string& name) { + HloInstruction* to_return = nullptr; + for (auto* comp : module->computations()) { + for (auto* inst : comp->instructions()) { + if (inst->name() == name) { + to_return = inst; + break; + } + } + } + return CHECK_NOTNULL(to_return); + } + + std::unique_ptr liveness_; +}; + +// Test that add instruction at entry root is live at all output shape indices. +TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + ROOT add = s32[] add(constant.1, constant.2) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); +} + +// Test that a dead add instruction is marked as dead by analysis. +TEST_F(HloLivenessAnalysisTest, DeadAdd) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + add.1 = s32[] add(constant.1, constant.2) + ROOT add.2 = s32[] add(constant.1, constant.2) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "add.1"), {})); +} + +// Test that all output shape indices of entry root tuple (and defining +// instruction in its output) are marked live. +TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + ROOT tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); +} + +// Tests that all outputs of nested tuple and entry root (and defining +// instruction values appearing in its output) are marked live. +TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(1) + constant.2 = s32[] constant(2) + constant.3 = s32[] constant(3) + tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3) + ROOT tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +// Tests that GTE at entry root of Tuple instruction only propgates liveness +// to the live elements in tuple. +TEST_F(HloLivenessAnalysisTest, GteOfTuple) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2) + ROOT get-tuple-element.1 = s32[] get-tuple-element(tuple.1), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); +} + +// Tests that GTE at entry root of nested Tuple instruction only propgates +// liveness to the live elements in tuple. +TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + constant.3 = s32[] constant(2) + tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3) + tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1) + ROOT get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {0})); + EXPECT_TRUE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +// Tests that GTE of GTE (at entry root) of nested Tuple instruction only +// propgates liveness to the live elements in tuple. +TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + constant.3 = s32[] constant(2) + tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3) + tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1) + get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1 + ROOT get-tuple-element.2 = s32[] get-tuple-element(get-tuple-element.1), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.2"), {})); + + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {0})); + EXPECT_FALSE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0})); + EXPECT_FALSE( + liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +// Test that live/dead while tuple elements are marked live/dead correctly. +TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while.0 = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while.0), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.4"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {1})); + + // While operand. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); + + // While body. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {})); +} + +// Tests that a tuple element live in while.cond computation, propagates +// liveness to while.body.root/while.result/while.operand (where it is unused). +TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1 + add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4) + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(add.1, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while.0 = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.5 = s32[] get-tuple-element(while.0), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.5"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {1})); + + // While operand. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.4"), {})); + + // While body. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {})); +} + +// Tests that a use of while.result{0} propagates liveness to +// while.body.param{1} to while.body.root{1}, and then to while.body.param{2}. +TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1 + add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.1), index=2 + multiply.1 = s32[] multiply(get-tuple-element.3, get-tuple-element.3) + ROOT tuple.1 = (s32[], s32[], s32[]) tuple(add.1, get-tuple-element.3, multiply.1) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[], s32[]) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0 + constant.1 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.4, constant.1) + } + ENTRY SimpleLoop { + constant.2 = s32[] constant(0) + constant.3 = s32[] constant(1) + constant.4 = s32[] constant(2) + tuple.2 = (s32[], s32[], s32[]) tuple(constant.2, constant.3, constant.4) + while.1 = (s32[], s32[], s32[]) while(tuple.2), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.5 = s32[] get-tuple-element(while.1), index=0 + })") + .ValueOrDie(); + + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.5"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {2})); + // While operand. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {2})); + // While body root. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {2})); + // While body param. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index bc74c4bc10cad20eab20b5caf8550b17048a5276..7e4b8834357d39099f76450b849d6b5624e4e3b4 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -17,10 +17,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace testing { +using ::tensorflow::str_util::Join; + bool HloMatcher::MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const { @@ -132,6 +135,104 @@ bool HloCustomCallMatcher::MatchAndExplain( return result; } +bool HloShapeMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (ShapeUtil::Compatible(instruction->shape(), shape_)) { + return true; + } + *listener << instruction->ToString() << " has incorrect shape (expected: " + << ShapeUtil::HumanString(shape_) << ")"; + return false; +} + +void HloShapeMatcher::DescribeTo(std::ostream* os) const { + *os << ShapeUtil::HumanString(shape_); +} + +bool HloShapeAndLayoutMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (ShapeUtil::Equal(instruction->shape(), shape_)) { + return true; + } + *listener << instruction->ToString() << " has incorrect shape (expected: " + << ShapeUtil::HumanStringWithLayout(shape_) << ")"; + return false; +} + +void HloShapeAndLayoutMatcher::DescribeTo(std::ostream* os) const { + *os << ShapeUtil::HumanStringWithLayout(shape_); +} + +bool HloShardingMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!sharding_.has_value()) { + if (!instruction->has_sharding()) { + return true; + } + *listener << instruction->ToString() << " expected to have no sharding."; + return false; + } + if (instruction->has_sharding()) { + if (instruction->sharding() == sharding_.value()) { + return true; + } + *listener << instruction->ToString() + << " has incorrect sharding (expected: " << sharding_->ToString() + << ")"; + return false; + } else { + *listener << instruction->ToString() + << " has no sharding (expected: " << sharding_->ToString() << ")"; + return false; + } +} + +void HloShardingMatcher::DescribeTo(std::ostream* os) const { + if (sharding_.has_value()) { + *os << sharding_->ToString(); + } else { + *os << ""; + } +} + +bool HloDotWithContractingDimsMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + + const DotDimensionNumbers& dim_nums = instruction->dot_dimension_numbers(); + if (dim_nums.lhs_contracting_dimensions_size() != 1 || + dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) { + *listener << instruction->ToString() + << " has wrong lhs_contracting_dimensions (got {" + << Join(dim_nums.lhs_contracting_dimensions(), ",") << "} want {" + << lhs_contracting_dim_ << "})"; + return false; + } + + if (dim_nums.rhs_contracting_dimensions_size() != 1 || + dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) { + *listener << instruction->ToString() + << " has wrong rhs_contracting_dimensions (got {" + << Join(dim_nums.rhs_contracting_dimensions(), ",") << "} want {" + << rhs_contracting_dim_ << "})"; + return false; + } + + return true; +} + +void HloDotWithContractingDimsMatcher::DescribeTo(std::ostream* os) const { + HloMatcher::DescribeTo(os); + *os << " with lhs_contracting_dims={" << lhs_contracting_dim_ + << "} and rhs_contracting_dims={" << rhs_contracting_dim_ << "}"; +} + } // namespace testing void PrintTo(const HloInstruction* inst, ::std::ostream* os) { diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 103f04a2cb7a1a5ae877d8bf259692f7cbed3408..c570b420c21fed4d7828feb24ee5c7859db94a79 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -17,7 +17,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/lib/gtl/optional.h" namespace xla { namespace testing { @@ -86,6 +88,71 @@ class HloCustomCallMatcher : public HloMatcher { ::testing::Matcher call_target_matcher_; }; +class HloShapeMatcher + : public ::testing::MatcherInterface { + public: + explicit HloShapeMatcher(const Shape& shape) : shape_(shape) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + Shape shape_; +}; + +class HloShapeAndLayoutMatcher + : public ::testing::MatcherInterface { + public: + explicit HloShapeAndLayoutMatcher(const Shape& shape) : shape_(shape) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + Shape shape_; +}; + +// Verify the sharding of an instruction against the provided HloSharding. If a +// nullopt is provided for the expected sharding then it checks that no sharding +// is present for an instruction. +class HloShardingMatcher + : public ::testing::MatcherInterface { + public: + explicit HloShardingMatcher( + const tensorflow::gtl::optional& sharding) + : sharding_(sharding) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + tensorflow::gtl::optional sharding_; +}; + +// Matches a Dot HLO instruction with specific LHS and RHS contracting +// dimensions. +class HloDotWithContractingDimsMatcher : public HloMatcher { + public: + explicit HloDotWithContractingDimsMatcher( + ::testing::Matcher lhs, + ::testing::Matcher rhs, int64 lhs_contracting_dim, + int64 rhs_contracting_dim) + : HloMatcher(HloOpcode::kDot, /*operands=*/{lhs, rhs}), + lhs_contracting_dim_(lhs_contracting_dim), + rhs_contracting_dim_(rhs_contracting_dim) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + int64 lhs_contracting_dim_; + int64 rhs_contracting_dim_; +}; + // HloInstruction* matchers for opcode and operands. Example: // namespace op = xla::opcode_matchers; // EXPECT_THAT(instruction, @@ -113,7 +180,6 @@ HLO_MATCHER(Convolution); HLO_MATCHER(Copy); HLO_MATCHER(CrossReplicaSum); HLO_MATCHER(Divide); -HLO_MATCHER(Dot); HLO_MATCHER(DynamicSlice); HLO_MATCHER(DynamicUpdateSlice); HLO_MATCHER(Eq); @@ -231,6 +297,70 @@ inline ::testing::Matcher CustomCall() { new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {})); } +// Verifies the shape or the shape and the layout of an HLO instruction against +// the provided shape object. +inline ::testing::Matcher Shape( + const class Shape& shape) { + return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape)); +} +inline ::testing::Matcher Shape( + tensorflow::StringPiece shape) { + return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher( + ShapeUtil::ParseShapeString(shape).ValueOrDie())); +} +inline ::testing::Matcher ShapeWithLayout( + const class Shape& shape) { + return ::testing::MakeMatcher( + new ::xla::testing::HloShapeAndLayoutMatcher(shape)); +} +inline ::testing::Matcher ShapeWithLayout( + tensorflow::StringPiece shape) { + return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher( + ShapeUtil::ParseShapeString(shape).ValueOrDie())); +} + +// Verifies the value of the HloSharing against the provided sharding object. +inline ::testing::Matcher Sharding( + const HloSharding& sharding) { + return ::testing::MakeMatcher( + new ::xla::testing::HloShardingMatcher(sharding)); +} +// Matcher for Sharding from sharding string +inline ::testing::Matcher Sharding( + tensorflow::StringPiece sharding) { + return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher( + ParseSharding(sharding).ValueOrDie())); +} +// Verifies that no HloSharding is set for an HLO instruction. +inline ::testing::Matcher NoSharding() { + return ::testing::MakeMatcher( + new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt)); +} + +inline ::testing::Matcher Dot( + ::testing::Matcher lhs_matcher, + ::testing::Matcher rhs_matcher) { + return ::testing::MakeMatcher(new ::xla::testing::HloMatcher( + ::xla::HloOpcode::kDot, {lhs_matcher, rhs_matcher})); +} + +// Matches a Dot HLO instruction if it has exactly one lhs contracting dimension +// equal to `lhs_contracting_dim` and exactly one rhs contracting dimension +// equal to `rhs_contracting_dim`. +// +// Currently the HLO verifier rejects Dot operations with more than one +// contracting dimension (even though we can represent these in the +// DotDimensionNumbers proto) so there is no need to generalize this to support +// multiple contracting dimensions. +inline ::testing::Matcher Dot( + ::testing::Matcher lhs_matcher, + ::testing::Matcher rhs_matcher, + int64 lhs_contracting_dim, int64 rhs_contracting_dim) { + return ::testing::MakeMatcher( + new ::xla::testing::HloDotWithContractingDimsMatcher( + lhs_matcher, rhs_matcher, lhs_contracting_dim, rhs_contracting_dim)); +} + #undef HLO_MATCHER } // namespace opcode_matchers diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 1c21703a45e11914854153bc14fabd85e9ea57f2..9a3010cf1ff75e840130d8442bbe26d6041cef25 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" namespace op = xla::testing::opcode_matchers; @@ -100,5 +101,123 @@ TEST(HloMatchersTest, CustomCallMatcher) { R"(custom-call with call target that is equal to "foo_target")"); } +TEST(HloMatchersTest, ShapeMatcher) { + auto p0 = HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {0, 1}), "param"); + + EXPECT_THAT(p0.get(), op::Shape(ShapeUtil::MakeShape(F32, {5, 7}))); + EXPECT_THAT(p0.get(), op::Shape("f32[5,7]")); + EXPECT_THAT( + p0.get(), + ::testing::Not(op::ShapeWithLayout(ShapeUtil::MakeShape(F32, {5, 7})))); + EXPECT_THAT(p0.get(), ::testing::Not(op::ShapeWithLayout("f32[5,7]"))); + EXPECT_THAT(p0.get(), + ::testing::Not(op::Shape(ShapeUtil::MakeShape(F32, {7, 5})))); + EXPECT_THAT(p0.get(), ::testing::Not(op::Shape("f32[7,5]"))); + EXPECT_THAT( + p0.get(), + ::testing::Not(op::ShapeWithLayout(ShapeUtil::MakeShape(F32, {7, 5})))); + EXPECT_THAT(p0.get(), ::testing::Not(op::ShapeWithLayout("f32[7,5]"))); + EXPECT_THAT(p0.get(), + op::Shape(ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {0, 1}))); + EXPECT_THAT(p0.get(), op::Shape("f32[5,7]{0,1}")); + EXPECT_THAT(p0.get(), op::ShapeWithLayout(ShapeUtil::MakeShapeWithLayout( + F32, {5, 7}, {0, 1}))); + EXPECT_THAT(p0.get(), op::ShapeWithLayout("f32[5,7]{0,1}")); + EXPECT_THAT(p0.get(), + ::testing::Not(op::ShapeWithLayout( + ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {1, 0})))); + EXPECT_THAT(p0.get(), ::testing::Not(op::ShapeWithLayout("f32[5,7]{1,0}"))); + + EXPECT_THAT(Explain(p0.get(), op::Shape(ShapeUtil::MakeShape(F32, {7, 5}))), + "%param = f32[5,7]{0,1} parameter(0) has incorrect shape " + "(expected: f32[7,5])"); + EXPECT_THAT( + Explain(p0.get(), op::ShapeWithLayout(ShapeUtil::MakeShapeWithLayout( + F32, {7, 5}, {1, 0}))), + "%param = f32[5,7]{0,1} parameter(0) has incorrect shape " + "(expected: f32[7,5]{1,0})"); +} + +TEST(HloMatchersTest, ShardingMatcher) { + auto p0 = HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {5}), + "param.0"); + p0->clear_sharding(); + auto p1 = HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {7}), + "param.1"); + p1->set_sharding(HloSharding::AssignDevice(1)); + + auto tuple_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {7}), ShapeUtil::MakeShape(S32, {9}), + ShapeUtil::MakeShape(F32, {11})}); + auto p2 = HloInstruction::CreateParameter(1, tuple_shape, "param.2"); + Array assignment({2}); + assignment.SetValues({0, 1}); + auto sharding = HloSharding::Tuple( + tuple_shape, + {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment), + HloSharding::AssignDevice(1), HloSharding::Replicate()}); + p2->set_sharding(sharding); + + EXPECT_THAT(p0.get(), op::NoSharding()); + EXPECT_THAT(p0.get(), + ::testing::Not(op::Sharding(HloSharding::AssignDevice(1)))); + EXPECT_THAT(p1.get(), ::testing::Not(op::NoSharding())); + EXPECT_THAT(p1.get(), + ::testing::Not(op::Sharding(HloSharding::AssignDevice(0)))); + EXPECT_THAT(p1.get(), op::Sharding(HloSharding::AssignDevice(1))); + + EXPECT_THAT( + p2.get(), + op::Sharding( + "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}")); + + EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))), + "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: " + "{maximal device=1})"); + EXPECT_THAT(Explain(p1.get(), op::NoSharding()), + "%param.1 = f32[7]{0} parameter(1), sharding={maximal device=1} " + "expected to have no sharding."); + EXPECT_THAT(Explain(p1.get(), op::Sharding(HloSharding::AssignDevice(0))), + "%param.1 = f32[7]{0} parameter(1), sharding={maximal device=1} " + "has incorrect sharding (expected: {maximal device=0})"); +} + +TEST(HloMatchersTest, DotMatcher) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion + +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[1,256] parameter(0) + arg1 = f32[256,1024] parameter(1) + ROOT dot = f32[1,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + HloInstruction* root = module->entry_computation()->root_instruction(); + + EXPECT_THAT(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/1, + /*rhs_contracting_dim=*/0)); + + EXPECT_THAT( + Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/0, + /*rhs_contracting_dim=*/0)), + "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " + "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong " + "lhs_contracting_dimensions (got {1} want {0})"); + + EXPECT_THAT( + Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/1, + /*rhs_contracting_dim=*/1)), + "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " + "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong " + "rhs_contracting_dimensions (got {0} want {1})"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index cb2fe9f874012a51e1e6cbd1dd086dbb26994bde..9c59374b4a9d7e3dbfb99d8a6b30d4230e553658 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -32,23 +32,23 @@ limitations under the License. namespace xla { -HloModule::HloModule(const string& name, - const VersionedComputationHandle& entry_computation_handle, - const HloModuleConfig& config) - : name_(NameUniquer::GetSanitizedName(name)), - config_(config), - has_entry_computation_handle_(true), - entry_computation_handle_(entry_computation_handle), - unique_id_(next_unique_module_id_++) {} - -HloModule::HloModule(const string& name) - : name_(NameUniquer::GetSanitizedName(name)), - unique_id_(next_unique_module_id_++) {} HloModule::HloModule(const string& name, const HloModuleConfig& config) : name_(NameUniquer::GetSanitizedName(name)), config_(config), unique_id_(next_unique_module_id_++) {} +StatusOr HloModule::LaunderConstInstructionFromModule( + const HloInstruction* hlo) { + if (hlo == nullptr) { + return nullptr; + } + + TF_RET_CHECK(hlo->GetModule() == this); + + // TODO(b/78350259): Eliminate const laundering. + return const_cast(hlo); +} + HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, bool uniquify_names) { @@ -58,7 +58,7 @@ HloComputation* HloModule::AddComputationInternal( // If the module configuration has no entry layout computation set, create a // default one based on the program shape. - if (!config_.has_entry_computation_layout()) { + if (!config_.has_host_entry_computation_layout()) { config_.SetDefaultComputationLayout( entry_computation_->ComputeProgramShape()); } @@ -83,6 +83,11 @@ HloComputation* HloModule::AddComputationInternal( for (auto* instruction : computation->instructions()) { instruction->SetUniqueId(NewUniqueInstructionId()); } + // Set unique id to this computation. + CHECK_NE(computation->root_instruction()->unique_id(), -1) + << "Root has no valid id: " << computation->ToString(); + computation->SetUniqueId(computation->root_instruction()->unique_id()); + computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); @@ -204,90 +209,38 @@ string HloModule::ToString(const HloPrintOptions& options) const { HloModuleProto HloModule::ToProto() const { HloModuleProto proto; + proto.set_id(unique_id_); proto.set_name(name_); proto.set_entry_computation_name(entry_computation_->name()); + proto.set_entry_computation_id(entry_computation_->unique_id()); for (const HloComputation* computation : MakeComputationPostOrder()) { - // Fusion computations are added when the fusion instructions are created by - // HloInstruction::CreateFromProto. - if (computation->IsFusionComputation()) { - continue; - } HloComputationProto computation_proto = computation->ToProto(); + if (computation->name() == entry_computation_->name()) { + *proto.mutable_program_shape() = computation_proto.program_shape(); + } proto.add_computations()->Swap(&computation_proto); } return proto; } -namespace { - -// Construct a ProgramShape matching the shape of the parameters and root of the -// given module's entry computation. -StatusOr ProgramShapeFromProto(const HloModuleProto& module) { - const HloComputationProto* entry_computation = nullptr; - for (const HloComputationProto& computation : module.computations()) { - if (computation.name() == module.entry_computation_name()) { - entry_computation = &computation; - break; - } - } - TF_RET_CHECK(entry_computation != nullptr) - << "No computation with entry computation name" - << module.entry_computation_name(); - - tensorflow::gtl::FlatMap> parameters; - const HloInstructionProto* root = nullptr; - for (const HloInstructionProto& instruction : - entry_computation->instructions()) { - if (instruction.name() == entry_computation->root_name()) { - TF_RET_CHECK(root == nullptr) << "Entry computation has more than " - "one instruction with (root) name " - << instruction.name(); - root = &instruction; - } - if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) { - TF_RET_CHECK(!ContainsKey(parameters, instruction.parameter_number())) - << "Entry computation has more than one parameter instruction " - "with parameter number " - << instruction.parameter_number(); - parameters[instruction.parameter_number()] = {instruction.name(), - &instruction.shape()}; - } - } - TF_RET_CHECK(root != nullptr) - << "Entry computation is missing root instruction named " - << entry_computation->root_name(); - - ProgramShape program_shape; - *program_shape.mutable_result() = root->shape(); - for (int64 i = 0; i < parameters.size(); ++i) { - TF_RET_CHECK(ContainsKey(parameters, i)) - << "Entry computation missing parameter number " << i; - const string& name = parameters.at(i).first; - const Shape& shape = *parameters.at(i).second; - *program_shape.add_parameters() = shape; - program_shape.add_parameter_names(name); - } - - return std::move(program_shape); -} - -} // namespace - /* static */ StatusOr> HloModule::CreateFromProto( - const HloModuleProto& proto, const HloModuleConfig& module_config, - const VersionedComputationHandle& entry_computation_handle) { + const HloModuleProto& proto, const HloModuleConfig& module_config) { // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. - TF_ASSIGN_OR_RETURN(ProgramShape expected_program_shape, - ProgramShapeFromProto(proto)); - TF_RET_CHECK(expected_program_shape.parameters_size() == - module_config.entry_computation_layout().parameter_count()); + TF_RET_CHECK(proto.has_program_shape()) + << "No program shape found in the proto"; + const auto& expected_program_shape = proto.program_shape(); + TF_RET_CHECK( + expected_program_shape.parameters_size() == + module_config.device_entry_computation_layout().parameter_count()); for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { const Shape& parameter_shape = - module_config.entry_computation_layout().parameter_layout(i).shape(); - TF_RET_CHECK( - ShapeUtil::Equal(expected_program_shape.parameters(i), parameter_shape)) + module_config.device_entry_computation_layout() + .parameter_layout(i) + .shape(); + TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i), + parameter_shape)) << "HloModuleConfig has different shape for parameter " << i << " than the HLO module. Expected: " << ShapeUtil::HumanStringWithLayout( @@ -295,37 +248,51 @@ StatusOr> HloModule::CreateFromProto( << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape); } const Shape& result_shape = - module_config.entry_computation_layout().result_layout().shape(); - TF_RET_CHECK(ShapeUtil::Equal(expected_program_shape.result(), result_shape)) + module_config.device_entry_computation_layout().result_layout().shape(); + TF_RET_CHECK( + ShapeUtil::Compatible(expected_program_shape.result(), result_shape)) << "HloModuleConfig has different result shape than the HLO module. " "Expected: " << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape); - auto module = MakeUnique(proto.name(), entry_computation_handle, - module_config); - - tensorflow::gtl::FlatMap computation_map; + tensorflow::gtl::FlatMap computation_map; + tensorflow::gtl::FlatMap to_proto_id; + std::vector> computations; + HloComputation* entry = nullptr; for (const HloComputationProto& computation_proto : proto.computations()) { TF_ASSIGN_OR_RETURN( std::unique_ptr computation, - HloComputation::CreateFromProto( - module.get(), computation_proto, computation_map, - /*add_fused_computation=*/ - [&module](std::unique_ptr fused_computation) { - module->AddComputationInternal(std::move(fused_computation), - /*is_entry=*/false, - /*uniquify_names=*/false); - })); + HloComputation::CreateFromProto(computation_proto, computation_map)); CHECK_NE(computation.get(), nullptr); - TF_RET_CHECK(!ContainsKey(computation_map, computation->name())); - string computation_name = computation->name(); + int64 computation_id = computation_proto.id(); + TF_RET_CHECK(computation_id != -1); + TF_RET_CHECK(!ContainsKey(computation_map, computation_id)); + computation_map[computation_id] = computation.get(); + to_proto_id[computation.get()] = computation_id; + if (computation_id == proto.entry_computation_id()) { + entry = computation.get(); + } + computations.push_back(std::move(computation)); + } + TF_RET_CHECK(entry != nullptr); + + auto module = MakeUnique(proto.name(), module_config); + + // Sort the computations in the proto id's order. + std::sort(computations.begin(), computations.end(), + [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); + + // Add sorted computations to the module. + for (auto& computation : computations) { + bool is_entry = computation.get() == entry; // Don't uniquify names because we want names to be stable across // serialization and deserialization. - computation_map[computation_name] = module->AddComputationInternal( - std::move(computation), - /*is_entry=*/proto.entry_computation_name() == computation_name, - /*uniquify_names=*/false); + module->AddComputationInternal(std::move(computation), is_entry, + /*uniquify_names=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr); @@ -334,10 +301,6 @@ StatusOr> HloModule::CreateFromProto( tensorflow::gtl::FlatSet computation_names; tensorflow::gtl::FlatSet instruction_names; for (HloComputation* computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } - TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) << "Computation name is not unique: " << computation->name(); computation_names.insert(computation->name()); @@ -353,16 +316,18 @@ StatusOr> HloModule::CreateFromProto( /* static */ StatusOr HloModule::CreateModuleConfigFromProto( - const HloModuleProto& module) { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - ProgramShapeFromProto(module)); + const HloModuleProto& module, const DebugOptions& debug_options) { + TF_RET_CHECK(module.has_program_shape()) + << "No program shape found in the proto"; + const auto& program_shape = module.program_shape(); HloModuleConfig module_config(program_shape); + module_config.set_debug_options(debug_options); // The module config is constructed with default layouts regardless of what is // passed in via the ProgramShape. Set the layouts to the appropriate values. ComputationLayout* entry_layout = - module_config.mutable_entry_computation_layout(); + module_config.mutable_host_entry_computation_layout(); for (int64 i = 0; i < entry_layout->parameter_count(); ++i) { TF_RETURN_IF_ERROR( entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( @@ -370,6 +335,8 @@ StatusOr HloModule::CreateModuleConfigFromProto( } TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape( program_shape.result())); + *module_config.mutable_device_entry_computation_layout() = + module_config.host_entry_computation_layout(); return module_config; } @@ -423,7 +390,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( // as a parameter in the new function. arguments.push_back(old_operand); *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter( - parameter_count, old_operand->shape(), "")); + parameter_count, old_operand->shape(), "p")); ++parameter_count; } TF_CHECK_OK( @@ -518,7 +485,18 @@ std::list HloModule::MakeComputationPostOrder() const { added_computations.insert(computation.get()); } } - CHECK_EQ(post_order.size(), computations_.size()); + if (post_order.size() != computations_.size()) { + for (HloComputation* computation : post_order) { + LOG(ERROR) << "Post Order: " << computation->name() << " (" + << computation->parent()->name() << ")"; + } + for (auto& computation : computations_) { + LOG(ERROR) << "Computations: " << computation->name() << " (" + << computation->parent()->name() << ")"; + } + LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size() + << " computation_count=" << computations_.size(); + } return post_order; } @@ -535,59 +513,27 @@ std::vector HloModule::MakeNonfusionComputations() const { std::unique_ptr HloModule::Clone(const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; - auto module = MakeUnique(name_ + "-" + suffix); - module->config_ = config_; - module->entry_computation_handle_ = entry_computation_handle_; - module->has_entry_computation_handle_ = has_entry_computation_handle_; + auto module = MakeUnique(name_ + "-" + suffix, config_); - std::unordered_map clone_map; - for (auto& computation : computations_) { - if (computation->IsFusionComputation()) { - // Cloning of a fused computation is handled by its fusion instruction. - continue; - } - - // When cloning a computation, pass in the new module, so that for any - // fusion instruction in this computation, the fused computation will be - // deep cloned to the new module. - auto cloned_computation = computation->Clone(suffix, module.get()); - InsertOrDie(&clone_map, computation.get(), cloned_computation.get()); - - if (entry_computation_ == computation.get()) { - module->AddEntryComputation(std::move(cloned_computation)); - } else { - module->AddEmbeddedComputation(std::move(cloned_computation)); - } - } - - for (auto& cloned_computation : module->computations_) { - for (auto* instruction : cloned_computation->instructions()) { - // Rewrite instruction's called_computation to point to the cloned - // computations. - instruction->ReplaceCalledComputations([&](HloComputation* hlo) { - if (hlo->IsFusionComputation()) { - // Cloning of a fused computation has already been handled when its - // fusion instruction is cloned. So this hlo computation is already - // the cloned one. - return hlo; - } - return FindOrDie(clone_map, hlo); - }); - } - } + HloCloneContext context(module.get(), suffix); + auto cloned_computation = entry_computation_->Clone(suffix, &context); + module->AddEntryComputation(std::move(cloned_computation)); return module; } -HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) { - HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this)); - TF_CHECK_OK( - clone->root_instruction()->Accept([this](HloInstruction* instruction) { - instruction->ReplaceCalledComputations([this](HloComputation* callee) { - return DeepCloneComputation(callee); - }); - return Status::OK(); - })); - return clone; +HloComputation* HloModule::DeepCloneComputation(HloComputation* computation, + HloCloneContext* context) { + HloComputation* new_computation; + if (context != nullptr) { + if ((new_computation = context->FindComputation(computation)) != nullptr) { + return new_computation; + } + new_computation = + AddEmbeddedComputation(computation->Clone(context->suffix(), context)); + } else { + new_computation = AddEmbeddedComputation(computation->Clone("")); + } + return new_computation; } uint64 HloModule::RandomNew64() const { @@ -595,6 +541,14 @@ uint64 HloModule::RandomNew64() const { return rng_(); } +HloComputation* HloModule::GetComputationWithName( + tensorflow::StringPiece name) { + auto it = c_find_if(computations(), [&](HloComputation* computation) { + return computation->name() == name; + }); + return it == computations().end() ? nullptr : *it; +} + /* static */ std::atomic HloModule::next_unique_module_id_(0); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 06d92f94fd6f62162b22575e9cc341f2906cd0db..757e65bda286d983d05e5a791aa7dffe97bac945 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -26,12 +26,13 @@ limitations under the License. #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" @@ -41,21 +42,24 @@ namespace xla { // Describes a compilation unit at the HLO level. // -// A HLO module contains one or more HLO computations. The module contains one -// "entry" computation which produces the result. The module also includes any -// embedded computations used by instructions such as "map" and "reduce". All -// computations are owned by the module. +// HloModule is the top-level unit in the HLO IR. It corresponds to a whole +// "program". Running a module, from beginning to end, is the only way to run +// an XLA program. +// +// A module contains one "entry computation"; this HloComputation is like main() +// in a C program. The result of running the module is the result of running +// this computation. +// +// A module also contains some number of "nested computations". Each nested +// computation is attached to an HloInstruction within some other computation. +// The meaning of the nested computation depends on the instruction it's +// attached to. class HloModule { public: - HloModule(const string& name, - const VersionedComputationHandle& entry_computation_handle, - const HloModuleConfig& config); - // Constructor without a versioned computation handle. This constructor should // only be used for HloModules used outside of the XLA service (eg // tests). The versioned handle is used by the service in the compilation // cache. A default configuration is created for this module. - explicit HloModule(const string& name); explicit HloModule(const string& name, const HloModuleConfig& config); // Adds an entry computation to the module. A module can only have one entry @@ -86,8 +90,10 @@ class HloModule { std::unique_ptr Clone(const string& suffix = "clone") const; // Performs a deep clone of the computation, by recursively cloning all - // the called computations as well. - HloComputation* DeepCloneComputation(HloComputation* computation); + // the called computations as well. If the clone context is specified, it + // will be populated with the cloned object mappings. + HloComputation* DeepCloneComputation(HloComputation* computation, + HloCloneContext* context = nullptr); // Return a pointer to the entry computation of the module.. const HloComputation* entry_computation() const { @@ -99,16 +105,20 @@ class HloModule { return entry_computation_; } - ComputationLayout* mutable_entry_computation_layout() { - return config_.mutable_entry_computation_layout(); + ComputationLayout* mutable_host_entry_computation_layout() { + return config_.mutable_host_entry_computation_layout(); } - ComputationLayout entry_computation_layout() const { - return config_.entry_computation_layout(); + const ComputationLayout& host_entry_computation_layout() const { + return config_.host_entry_computation_layout(); } - const VersionedComputationHandle& entry_computation_handle() const { - return entry_computation_handle_; + ComputationLayout* mutable_device_entry_computation_layout() { + return config_.mutable_device_entry_computation_layout(); + } + + const ComputationLayout& device_entry_computation_layout() const { + return config_.device_entry_computation_layout(); } // Gets the computations in this module. @@ -131,6 +141,10 @@ class HloModule { MakeUnwrappingIterator(computations_.end())}; } + // Returns the computation in this module that has the name `name`. Returns + // null if there is no such computation. + HloComputation* GetComputationWithName(tensorflow::StringPiece name); + // Gets the number of computations in this module. int64 computation_count() const { return computations_.size(); } @@ -165,14 +179,12 @@ class HloModule { // Convert an HloModule to or from a proto. HloModuleProto ToProto() const; static StatusOr> CreateFromProto( - const HloModuleProto& proto, const HloModuleConfig& module_config, - const VersionedComputationHandle& entry_computation_handle = - VersionedComputationHandle()); + const HloModuleProto& proto, const HloModuleConfig& module_config); // Creates and returns an HloModuleConfig with an appropriate program shape // for the HLO module in the given proto. static StatusOr CreateModuleConfigFromProto( - const HloModuleProto& module); + const HloModuleProto& module, const DebugOptions& debug_options); // Outlines the given expression from the given computation. // instructions_to_outline contains the instructions that form the expression. @@ -187,11 +199,6 @@ class HloModule { // Returns a randomly generated uint64. uint64 RandomNew64() const; - // Returns the unique name for a computation in this module. - string GetUniqueCompuationName(const string& prefix) { - return computation_name_uniquer_.GetUniqueName(prefix); - } - // Returns the NameUniquer for uniquing instruction names in this module. NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; } @@ -210,6 +217,25 @@ class HloModule { // the lifetime of this process. int unique_id() const { return unique_id_; } + // Returns a non-const version of the passed-in const HloInstruction*. This is + // safe on the argument that if you have a non-const module, then you can + // access all instructions in the module as non-const. + // + // Returns an error if the passed-in instruction is not from this module, + // except that it is allowed to pass in a null pointer. + // + // TODO(b/78350259): Eliminate const laundering. The argument above is not + // reliable since at any time someone could add or discover a way for a + // non-const module to transitively contain a const HloInstruction. The + // reliable way to do this would be to create a const laundering map from a + // module, mapping each encountered HloInstruction to its non-const version + // and then look up each instruction in need of laundering in that map, but + // this is much more expensive and complicated. This returns a Status instead + // of doing a CHECK-failure in part to make it strongly apparent that this is + // something that can fail. + StatusOr LaunderConstInstructionFromModule( + const HloInstruction* hlo); + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, @@ -227,10 +253,6 @@ class HloModule { mutable std::mt19937_64 rng_{42}; mutable tensorflow::mutex rng_mutex_; - // Versioned handle of the entry computation of the module. - bool has_entry_computation_handle_ = false; - VersionedComputationHandle entry_computation_handle_; - // Unique name generator for computation and instruction names, which are // unique per module. NameUniquer computation_name_uniquer_{/*separator=*/"."}; diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index 822e2f1f53e5ee460b88c2241ecf7f6b91ef608b..dae5578a3158fecb8219e518841dec1020b2ca98 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -31,24 +31,33 @@ using tensorflow::strings::StrAppend; HloModuleConfig::HloModuleConfig() {} HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape) - : entry_computation_layout_(program_shape) {} + : host_entry_computation_layout_(program_shape), + device_entry_computation_layout_(program_shape) {} void HloModuleConfig::SetDefaultComputationLayout( const ProgramShape& program_shape) { - entry_computation_layout_ = ComputationLayout(program_shape); + host_entry_computation_layout_ = ComputationLayout(program_shape); + device_entry_computation_layout_ = ComputationLayout(program_shape); } string HloModuleConfig::compilation_cache_key() const { string key = - tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_); + tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled()); StrAppend(&key, "::("); std::vector params; for (const ShapeLayout& param_layout : - entry_computation_layout_->parameter_layouts()) { + host_entry_computation_layout_->parameter_layouts()) { params.push_back(param_layout.shape().DebugString()); } StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ", - entry_computation_layout_->result_shape().SerializeAsString()); + host_entry_computation_layout_->result_shape().SerializeAsString()); + for (const ShapeLayout& param_layout : + device_entry_computation_layout_->parameter_layouts()) { + params.push_back(param_layout.shape().DebugString()); + } + StrAppend( + &key, tensorflow::str_util::Join(params, ", "), ") => ", + device_entry_computation_layout_->result_shape().SerializeAsString()); if (seed() != 0) { // TODO(b/32083678): force recompilation to reset global state. static std::atomic counter{0}; diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index a5ee895e48448fbb8fa3879dc1b6764c1f9f6966..cdb0b29a2399b387bc617262032e9083ba079625 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -41,31 +41,59 @@ class HloModuleConfig { explicit HloModuleConfig(const ProgramShape& program_shape); // Checks if this config has an entry computation layout already. - bool has_entry_computation_layout() const { - return entry_computation_layout_.has_value(); + bool has_host_entry_computation_layout() const { + return host_entry_computation_layout_.has_value(); + } + + bool has_device_entry_computation_layout() const { + return device_entry_computation_layout_.has_value(); } // Sets the entry computation layout for this config. If the entry computation // layout already exists, it is silently replaced. void SetDefaultComputationLayout(const ProgramShape& program_shape); - // Returns a constant reference to the layout of the entry computation. + // Returns a constant reference to the on-host layout of the entry + // computation. Assumes the layout was set. + const ComputationLayout& host_entry_computation_layout() const { + CHECK(host_entry_computation_layout_.has_value()); + return *host_entry_computation_layout_; + } + + // Returns a mutable pointer to the layout of the on-host entry computation. + // Assumes the layout was set. + ComputationLayout* mutable_host_entry_computation_layout() { + CHECK(host_entry_computation_layout_.has_value()); + return &(*host_entry_computation_layout_); + } + + // Returns a constant reference to the on-device layout of the entry + // computation. Assumes the layout was set. + const ComputationLayout& device_entry_computation_layout() const { + CHECK(device_entry_computation_layout_.has_value()); + return *device_entry_computation_layout_; + } + + // Returns a mutable pointer to the layout of the on-device entry computation. // Assumes the layout was set. - const ComputationLayout& entry_computation_layout() const { - CHECK(entry_computation_layout_.has_value()); - return *entry_computation_layout_; + ComputationLayout* mutable_device_entry_computation_layout() { + CHECK(device_entry_computation_layout_.has_value()); + return &(*device_entry_computation_layout_); } - // Returns a mutable pointer to the layout of the entry computation. Assumes - // the layout was set. - ComputationLayout* mutable_entry_computation_layout() { - CHECK(entry_computation_layout_.has_value()); - return &(*entry_computation_layout_); + // Returns whether to enable HLO-level profiling. + bool hlo_profiling_enabled() const { + return debug_options_.xla_hlo_profile(); } - // Sets/returns whether to enable HLO-level profiling. - bool hlo_profiling_enabled() const { return hlo_profiling_enabled_; } - void enable_hlo_profiling(bool enabled) { hlo_profiling_enabled_ = enabled; } + // Sets/returns whether this is a "host module". Host modules are used to + // record the data- and control-flow dependencies of host side computation + // that communicates with compiled code. They are used for analysis and + // scheduling purposes, but no code is generated. + bool is_host_module() const { return is_host_module_; } + void set_is_host_module(bool is_host_module) { + is_host_module_ = is_host_module; + } // Sets/returns the module seed set during execution. void set_seed(uint64 seed) { seed_ = seed; } @@ -99,10 +127,11 @@ class HloModuleConfig { private: // If you add new members, be sure to update compilation_cache_key. - tensorflow::gtl::optional entry_computation_layout_; + tensorflow::gtl::optional host_entry_computation_layout_; + tensorflow::gtl::optional device_entry_computation_layout_; - // Whether to enable HLO-level profiling. - bool hlo_profiling_enabled_ = false; + // Whether this is a 'host module'. + bool is_host_module_ = false; // Module/graph-level seed handle. uint64 seed_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc new file mode 100644 index 0000000000000000000000000000000000000000..98d20315e399c6b1a3979b5d11a89ef93869f4d9 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_dce.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +bool HasSendRecv(HloComputation* computation) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kSendDone || + instruction->opcode() == HloOpcode::kRecv || + instruction->opcode() == HloOpcode::kRecvDone) { + return true; + } + for (auto* sub_computation : instruction->called_computations()) { + if (HasSendRecv(sub_computation)) { + return true; + } + } + } + return false; +} + +StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { + bool changed = false; + for (auto* computation : module->computations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kWhile) { + continue; + } + + const auto* xla_while = instruction; + auto* while_body_comp = xla_while->while_body(); + auto* while_body_param = while_body_comp->parameter_instruction(0); + auto* while_body_root = while_body_comp->root_instruction(); + + if (!ShapeUtil::IsTuple(xla_while->shape()) || + while_body_root->opcode() != HloOpcode::kTuple || + HasSendRecv(while_body_comp)) { + // Only run DCE on tuple-shaped while loops where body root is Tuple, + // with no send/recv instructions. + VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString(); + continue; + } + + // Remove dead tuple elements. + const int64 tuple_element_count = + ShapeUtil::TupleElementCount(xla_while->shape()); + for (int64 i = 0; i < tuple_element_count; ++i) { + if (liveness->IsLive(xla_while, {i})) { + continue; + } + VLOG(1) << "WhileDCE Dead while tuple element." + << " while: " << xla_while->name() << " tuple_index: " << i; + // Transform while.body computation to make tuple element at + // 'shape_index' as simple pass-through parameter (which candidate + // be removed later by simplification pass). + HloInstruction* pass_thru_gte = while_body_comp->AddInstruction( + HloInstruction::CreateGetTupleElement( + while_body_param->shape().tuple_shapes(i), while_body_param, + i)); + // Replace while.body.root Tuple operand at 'tuple_index' with + // 'pass_thru_gte', making prior operand a dead root (to be cleaned + // up with a subsequent DCE pass). + TF_RETURN_IF_ERROR( + while_body_root->ReplaceOperandWith(i, pass_thru_gte)); + changed = true; + } + } + } + return changed; +} + +} // namespace + +StatusOr HloModuleDCE::Run(HloModule* module) { + VLOG(2) << "Before HloModuleDCE:"; + XLA_VLOG_LINES(3, module->ToString()); + + std::unique_ptr liveness; + TF_ASSIGN_OR_RETURN(liveness, HloLivenessAnalysis::Run(*module)); + + // Sweep through while instructions, transforming dead while tuple element + // computations to pass through tuple values (creating dead roots in while + // body computation in the process). + TF_ASSIGN_OR_RETURN(bool hlo_module_dce_changed, + RunWhileDCE(module, liveness.get())); + + // Run HloDCE to clean up any dead code created during HloModuleDCE. + HloDCE hlo_dce; + TF_ASSIGN_OR_RETURN(bool hlo_dce_changed, hlo_dce.Run(module)); + + VLOG(2) << "After HloModuleDCE:"; + XLA_VLOG_LINES(3, module->ToString()); + + return hlo_module_dce_changed | hlo_dce_changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h new file mode 100644 index 0000000000000000000000000000000000000000..29024085c1038961ef2b3721de1ce0e8a55ccf45 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass which removes dead code from computations in the module using +// HloModule-scoped analysis (HloLivenessAnalysis). +// +// Sweeps through live instructions which cross computation boundaries (kWhile), +// and removes code at dead shape indices. +// +class HloModuleDCE : public HloPassInterface { + public: + ~HloModuleDCE() override {} + tensorflow::StringPiece name() const override { return "hlo-module-dce"; } + + // Run the pass on the given module. Returns whether the module was changed + // (instructions were removed). + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..363862e4905fc13a4ef07aeaac255259fc6b86ba --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -0,0 +1,371 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_dce.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class HloModuleDceTest : public HloTestBase { + protected: + HloModuleDceTest() {} + + // Returns whether the given instruction exists in the given computation. + bool HasInstruction(const HloComputation& computation, + const HloInstruction* instruction) { + return std::find(computation.instructions().begin(), + computation.instructions().end(), + instruction) != computation.instructions().end(); + } + + // Returns whether the while instruction with name 'while_name' in + // 'computation' passes through its tuple element at 'tuple_index' from + // parameter to root instruction. + bool WhileBodyHasPassThroughTupleElement(const HloComputation* computation, + const string& while_name, + const int64 tuple_index) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kWhile && + instruction->name() == while_name) { + auto* while_body_comp = instruction->while_body(); + auto* while_body_param = while_body_comp->parameter_instruction(0); + auto* while_body_root = while_body_comp->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + return false; + } + auto* operand = while_body_root->operand(tuple_index); + if (operand->opcode() == HloOpcode::kGetTupleElement && + operand->tuple_index() == tuple_index && + operand->operand(0) == while_body_param) { + return true; + } + return false; + } + } + return false; + } +}; + +// Tests that a while with all outputs live is unmodified. +TEST_F(HloModuleDceTest, WhileWithLiveOutputs) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests a while loop with one unused output (which is used in the while loop +// body by an instruction with side-effects: rng) is unmodified. +TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], f32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = f32[] get-tuple-element(loop_var.1), index=1 + constant.2 = f32[] constant(1.0) + rng = f32[] rng(constant.2, get-tuple-element.2), distribution=rng_uniform + add.1 = s32[] add(get-tuple-element.2, constant.2) + ROOT tuple = (s32[], f32[]) tuple(add, add.1) + } + SimpleLoop.condition { + loop_var.2 = (s32[], f32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.3 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.3) + } + ENTRY SimpleLoop { + constant.4 = s32[] constant(0) + constant.5 = f32[] constant(0.0) + tuple.1 = (s32[], f32[]) tuple(constant.4, constant.5) + while = (s32[], f32[]) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that a while loop with one dead tuple element at {1} has its while +// loop body modified to make that tuple element pass-through the while body. +TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // While tuple element {1} should not be pass-through before ModuleDCE. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + // While tuple element {1} should now be pass-through after ModuleDCE. + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that a tuple element {1} used by condition computation (which appears +// dead in while.body{1} and at while.result{1}) propgates liveness of this +// tuple element to while.body{1} and at while.result{1}. +TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1 + multiply = s32[] multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[]) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[] constant(0) + tuple.1 = (s32[], s32[]) tuple(constant.3, constant.4) + while = (s32[], s32[]) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // While tuple element {1} should not be pass-through before ModuleDCE. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + // While tuple element {1} still be pass-through after ModuleDCE. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that HloModuleDCE can remove a dead tuple element at index {1} between +// two dependent while loops. +TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body0 { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition0 { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + SimpleLoop.body1 { + loop_var.3 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0 + constant.3 = s32[] constant(1) + add.1 = s32[] add(get-tuple-element.4, constant.3) + get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1 + multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5) + ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1) + } + SimpleLoop.condition1 { + loop_var.4 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 + constant.4 = s32[] constant(5) + ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + } + ENTRY SimpleLoop { + constant.5 = s32[] constant(0) + constant.6 = s32[3]{0} constant({0, 1, 2}) + tuple.2 = (s32[], s32[3]{0}) tuple(constant.5, constant.6) + while.1 = (s32[], s32[3]{0}) while(tuple.2), condition= + SimpleLoop.condition0, body=SimpleLoop.body0 + get-tuple-element.7 = s32[] get-tuple-element(while.1), index=0 + tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6) + while.2 = (s32[], s32[3]{0}) while(tuple.3), condition= + SimpleLoop.condition1, body=SimpleLoop.body1 + ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // Before HloModuleDCE while.1 and while.2 should not have pass-thru elements. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + // After HloModuleDCE while.1 and while.2 should have pass-thru elements, + // after being modified to pass through unused tuple element {1}. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); +} + +// Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and +// while.2{1}, between two dependent while loops. +TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body0 { + loop_var.1 = (s32[3]{0}, s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=1 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=0 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[3]{0}, s32[]) tuple(multiply, add) + } + SimpleLoop.condition0 { + loop_var.2 = (s32[3]{0}, s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + SimpleLoop.body1 { + loop_var.3 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0 + constant.3 = s32[] constant(1) + add.1 = s32[] add(get-tuple-element.4, constant.3) + get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1 + multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5) + ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1) + } + SimpleLoop.condition1 { + loop_var.4 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 + constant.4 = s32[] constant(5) + ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + } + ENTRY SimpleLoop { + constant.5 = s32[] constant(0) + constant.6 = s32[3]{0} constant({0, 1, 2}) + tuple.2 = (s32[3]{0}, s32[]) tuple(constant.6, constant.5) + while.1 = (s32[3]{0}, s32[]) while(tuple.2), condition= + SimpleLoop.condition0, body=SimpleLoop.body0 + get-tuple-element.7 = s32[] get-tuple-element(while.1), index=1 + tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6) + while.2 = (s32[], s32[3]{0}) while(tuple.3), condition= + SimpleLoop.condition1, body=SimpleLoop.body1 + ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // Before HloModuleDCE while.1{0} and while.2{1} should not be pass-thru. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + // After HloModuleDCE while.1{0} and while.2{1} not be pass-thru elements. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf33640db16638803f4f8e6c66f35d6bb6e2c9fe --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -0,0 +1,489 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +string HloModuleGroupMetadata::TrackedInstruction::ToString() const { + string repr = + (instruction_ != nullptr) ? instruction_->ToShortString() : "NULL"; + switch (kind_) { + case ComputationKind::kInvalid: + repr += ":INVALID"; + break; + case ComputationKind::kWhileCondition: + repr += ":WHILE_CONDITION"; + break; + case ComputationKind::kWhileBody: + repr += ":WHILE_BODY"; + break; + case ComputationKind::kConditionalTrue: + repr += ":CONDITIONAL_TRUE"; + break; + case ComputationKind::kConditionalFalse: + repr += ":CONDITIONAL_FALSE"; + break; + case ComputationKind::kCallFunction: + repr += ":CALL"; + break; + } + return repr; +} + +/* static */ StatusOr> +HloModuleGroupMetadata::Build(const std::vector& modules) { + auto metadata = MakeUnique(modules); + TF_RETURN_IF_ERROR(metadata->Build()); + return std::move(metadata); +} + +Status HloModuleGroupMetadata::Build() { + TF_RETURN_IF_ERROR(RecordInstructions()); + TF_RETURN_IF_ERROR(VerifyChannelInstructions()); + + // Record all companion while instructions. + const auto visitor = [this](HloInstruction* hlo) -> Status { + // We only need to process if the instruction is within the computation + // of a companion instruction, like in the condition or body computation + // of a While. + const TrackedInstruction* tracked = GetTrackedInstruction(hlo->parent()); + if (tracked == nullptr) { + return Status::OK(); + } + // Add the parent computation of this channel instruction and its peer + // computation (both must be while computations) as companions. + if (IsChannelInstruction(hlo)) { + HloComputation* peer_computation = PeerComputation(hlo); + const TrackedInstruction* peer_tracked = + GetTrackedInstruction(peer_computation); + TF_RET_CHECK(peer_tracked != nullptr) + << "Peer instruction is not a possible companion"; + TF_RET_CHECK(*tracked == *peer_tracked) + << "Peer instruction does not match the computation kind"; + TF_RETURN_IF_ERROR( + AddCompanion(tracked->instruction(), peer_tracked->instruction())); + tracked_instructions_comms_[tracked->instruction()].push_back(hlo); + } + + // Add the parents of companion instructions (they must be all of the same + // kind of instructions, opcode wise) as companions. + if (IsCompanionInstruction(hlo)) { + for (HloInstruction* companion : Companions(hlo)) { + const TrackedInstruction* companion_tracked = + GetTrackedInstruction(companion->parent()); + TF_RET_CHECK(companion_tracked != nullptr); + TF_RET_CHECK(*tracked == *companion_tracked); + TF_RETURN_IF_ERROR(AddCompanion(tracked->instruction(), + companion_tracked->instruction())); + } + } + return Status::OK(); + }; + + // Visit the computations in postorder so that the companion information grows + // from inner computations to outer ones. + for (HloModule* module : modules_) { + for (HloComputation* computation : module->MakeComputationPostOrder()) { + TF_RETURN_IF_ERROR(computation->Accept(visitor)); + } + } + TF_RETURN_IF_ERROR(VerifyCompanionSets()); + if (VLOG_IS_ON(4)) { + DumpCollectedStats(); + } + return Status::OK(); +} + +Status HloModuleGroupMetadata::VerifyCompanionSets() const { + for (const auto& companions : companion_sets_) { + // A companion set must be composed at most of an instruction per + // device/module. + std::unordered_set devices; + for (HloInstruction* instruction : *companions) { + // Go through all the communicating instructions (send, recv) of the given + // companion, and record their device. + auto it = tracked_instructions_comms_.find(instruction); + if (it == tracked_instructions_comms_.end()) { + // Companions can be added even if they have no communicating + // instructions, if they are parent of companions. + continue; + } + std::unordered_set comm_devices; + for (HloInstruction* comm_instruction : it->second) { + auto device = GetInstructionDevice(*comm_instruction); + TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString() + << " does not have a device"; + comm_devices.insert(*device); + } + for (int64 device : comm_devices) { + if (!devices.insert(device).second) { + std::stringstream ss; + ss << "Companion set:" << std::endl; + for (HloInstruction* hlo : *companions) { + ss << " " << hlo->name() << std::endl; + } + ss << "has multiple instructions on the same device"; + return FailedPrecondition("%s", ss.str().c_str()); + } + } + } + } + return Status::OK(); +} + +bool HloModuleGroupMetadata::IsChannelInstruction( + const HloInstruction* instruction) const { + switch (instruction->opcode()) { + case HloOpcode::kSend: + case HloOpcode::kRecv: + case HloOpcode::kSendDone: + case HloOpcode::kRecvDone: + return true; + default: + return false; + } +} + +bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const { + return companion_set_index_.count(hlo) > 0; +} + +bool HloModuleGroupMetadata::InstructionCommunicates( + HloInstruction* hlo) const { + return IsChannelInstruction(hlo) || IsCompanionInstruction(hlo); +} + +const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel( + int64 channel_id) const { + CHECK(channel_id_map_.find(channel_id) != channel_id_map_.end()); + return channels_[channel_id_map_.at(channel_id)]; +} + +HloComputation* HloModuleGroupMetadata::PeerComputation( + const HloInstruction* instruction) const { + CHECK(IsChannelInstruction(instruction)); + const Channel& channel = GetChannel(instruction->channel_id()); + switch (instruction->opcode()) { + case HloOpcode::kSend: + case HloOpcode::kSendDone: + return channel.recv->parent(); + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + return channel.send->parent(); + default: + LOG(FATAL) << "opcode not supported"; + } +} + +std::vector +HloModuleGroupMetadata::GetCompanionsPath(const HloInstruction* hlo) const { + std::vector path; + const HloComputation* parent = hlo->parent(); + const TrackedInstruction* companion; + while ((companion = GetTrackedInstruction(parent)) != nullptr) { + parent = companion->instruction()->parent(); + path.push_back(*companion); + } + return path; +} + +bool HloModuleGroupMetadata::CheckCompanionPathsCompatibility( + const std::vector& path0, + const std::vector& path1) const { + if (path0.size() != path1.size()) { + VLOG(5) << "Companion path size do not match: " << path0.size() + << " != " << path1.size(); + return false; + } + for (int64 i = 0; i < path0.size(); ++i) { + if (path0[i] != path1[i]) { + VLOG(5) << "Companion instructions at path index " << i + << " do not have the same opcode: " << path0[i].ToString() + << " vs " << path1[i].ToString(); + return false; + } + } + return true; +} + +int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const { + for (int64 i = 0; i < modules_.size(); ++i) { + if (modules_[i] == module) { + return i; + } + } + LOG(FATAL) << "unknown module"; +} + +tensorflow::gtl::optional HloModuleGroupMetadata::GetInstructionDevice( + const HloInstruction& instruction) const { + // The module group metadata can be created in both "single module, multiple + // devices" and "multiple modules, no explicit devices" fashions. + // The API returns an optional even though the current implementation always + // returns a device, to account for cases where we cannot guess a device. + // In such cases the VerifyChannelInstructions() will return proper errors. + tensorflow::gtl::optional device = + instruction.sharding_unique_device(); + if (!device) { + device = GetModuleId(instruction.parent()->parent()); + } + return device; +} + +int64 HloModuleGroupMetadata::GetDeviceModulesCount() const { + return std::count_if(modules_.begin(), modules_.end(), + [](const HloModule* module) { + return !module->config().is_host_module(); + }); +} + +Status HloModuleGroupMetadata::RecordInstructions() { + const auto visitor = [this](HloInstruction* hlo) -> Status { + if (hlo->opcode() == HloOpcode::kWhile) { + tracked_instructions_[hlo->while_condition()] = + TrackedInstruction(hlo, ComputationKind::kWhileCondition); + tracked_instructions_[hlo->while_body()] = + TrackedInstruction(hlo, ComputationKind::kWhileBody); + } else if (hlo->opcode() == HloOpcode::kConditional) { + tracked_instructions_[hlo->true_computation()] = + TrackedInstruction(hlo, ComputationKind::kConditionalTrue); + tracked_instructions_[hlo->false_computation()] = + TrackedInstruction(hlo, ComputationKind::kConditionalFalse); + } else if (hlo->opcode() == HloOpcode::kCall) { + tracked_instructions_[hlo->to_apply()] = + TrackedInstruction(hlo, ComputationKind::kCallFunction); + } + if (!IsChannelInstruction(hlo)) { + return Status::OK(); + } + + // Add a new channel if needed. + if (channel_id_map_.find(hlo->channel_id()) == channel_id_map_.end()) { + channels_.emplace_back(); + channels_.back().id = hlo->channel_id(); + channel_id_map_[hlo->channel_id()] = channels_.size() - 1; + max_channel_id_ = std::max(max_channel_id_, hlo->channel_id()); + } + Channel& channel = channels_[channel_id_map_[hlo->channel_id()]]; + + if (hlo->opcode() == HloOpcode::kSend) { + TF_RET_CHECK(channel.send == nullptr) + << "channel id " << hlo->channel_id() + << " is used by multiple send instructions"; + channel.send = hlo; + } + if (hlo->opcode() == HloOpcode::kRecv) { + TF_RET_CHECK(channel.recv == nullptr) + << "channel id " << hlo->channel_id() + << " is used by multiple recv instructions"; + channel.recv = hlo; + } + if (hlo->opcode() == HloOpcode::kSendDone) { + TF_RET_CHECK(channel.send_done == nullptr) + << "channel id " << hlo->channel_id() + << " is used by multiple send-done instructions"; + channel.send_done = hlo; + } + if (hlo->opcode() == HloOpcode::kRecvDone) { + TF_RET_CHECK(channel.recv_done == nullptr) + << "channel id " << hlo->channel_id() + << " is used by multiple recv-done instructions"; + channel.recv_done = hlo; + } + return Status::OK(); + }; + + for (HloModule* module : modules_) { + for (auto* computation : module->computations()) { + TF_RETURN_IF_ERROR(computation->Accept(visitor)); + } + } + VLOG(2) << "Created " << channels_.size() << " channels"; + return Status::OK(); +} + +Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, + HloInstruction* instruction2) { + TF_RET_CHECK(instruction1->opcode() == HloOpcode::kWhile || + instruction1->opcode() == HloOpcode::kConditional || + instruction1->opcode() == HloOpcode::kCall); + VLOG(2) << "adding as companions:" << instruction1->ToString() << " and " + << instruction2->ToString(); + + if (!ContainsKey(companion_set_index_, instruction1) && + !ContainsKey(companion_set_index_, instruction2)) { + companion_sets_.push_back( + tensorflow::MakeUnique>()); + auto companion_set = companion_sets_.back().get(); + companion_set->insert(instruction1); + companion_set->insert(instruction2); + companion_set_index_[instruction1] = companion_sets_.size() - 1; + companion_set_index_[instruction2] = companion_sets_.size() - 1; + } else if (!ContainsKey(companion_set_index_, instruction1)) { + companion_sets_[companion_set_index_[instruction2]]->insert(instruction1); + companion_set_index_[instruction1] = companion_set_index_[instruction2]; + } else if (!ContainsKey(companion_set_index_, instruction2)) { + companion_sets_[companion_set_index_[instruction1]]->insert(instruction2); + companion_set_index_[instruction2] = companion_set_index_[instruction1]; + } else if (companion_set_index_[instruction1] != + companion_set_index_[instruction2]) { + companion_sets_[companion_set_index_[instruction1]]->insert( + Companions(instruction2).begin(), Companions(instruction2).end()); + int64 index_to_remove = companion_set_index_[instruction2]; + for (HloInstruction* hlo : Companions(instruction2)) { + companion_set_index_[hlo] = companion_set_index_[instruction1]; + } + companion_sets_.erase(companion_sets_.begin() + index_to_remove); + } + return Status::OK(); +} + +Status HloModuleGroupMetadata::VerifyChannelInstructions() { + for (const Channel& channel : channels_) { + if (channel.send == nullptr) { + return FailedPrecondition("missing send for id : %lld", channel.id); + } + if (channel.recv == nullptr) { + return FailedPrecondition("missing recv for id : %lld", channel.id); + } + if (channel.send_done == nullptr) { + return FailedPrecondition("missing send-done for id : %lld", channel.id); + } + if (channel.recv_done == nullptr) { + return FailedPrecondition("missing recv-done for id : %lld", channel.id); + } + } + + // Check if the shapes match for each channel. + for (const Channel& channel : channels_) { + const Shape& send_shape = channel.send->operand(0)->shape(); + const Shape& recv_shape = channel.recv_done->shape(); + if (!ShapeUtil::Compatible(send_shape, recv_shape)) { + return FailedPrecondition("send/recv shapes do not match"); + } + auto send_device = GetInstructionDevice(*channel.send); + auto send_done_device = GetInstructionDevice(*channel.send_done); + if (!send_device) { + return FailedPrecondition("send instruction must have a device: %s", + channel.send->ToString().c_str()); + } + if (!send_done_device) { + return FailedPrecondition("send_done instruction must have a device: %s", + channel.send_done->ToString().c_str()); + } + if (*send_device != *send_done_device) { + return FailedPrecondition( + "send and send-done (channel=%lld) must be on the same device: %lld " + "vs. %lld", + channel.id, *send_device, *send_done_device); + } + auto recv_device = GetInstructionDevice(*channel.recv); + auto recv_done_device = GetInstructionDevice(*channel.recv_done); + if (!recv_done_device) { + return FailedPrecondition("recv_done instruction must have a device: %s", + channel.recv_done->ToString().c_str()); + } + if (*recv_device != *recv_done_device) { + return FailedPrecondition( + "recv and recv-done (channel=%lld) must be on the same device: %lld " + "vs. %lld", + channel.id, *recv_device, *recv_done_device); + } + if (*send_device == *recv_device) { + return FailedPrecondition( + "send and recv (channel=%lld) must be on different devices: %lld", + channel.id, *send_device); + } + } + + for (const Channel& channel : channels_) { + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send)); + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send_done)); + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv)); + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv_done)); + } + // Check if the nest levels match for each channel. + for (const Channel& channel : channels_) { + std::vector path = GetCompanionsPath(channel.send); + if (!CheckCompanionPathsCompatibility( + path, GetCompanionsPath(channel.send_done)) || + !CheckCompanionPathsCompatibility(path, + GetCompanionsPath(channel.recv)) || + !CheckCompanionPathsCompatibility( + path, GetCompanionsPath(channel.recv_done))) { + return FailedPrecondition( + "Nest companion paths do not match for channel %lld", channel.id); + } + } + return Status::OK(); +} + +Status HloModuleGroupMetadata::CheckCommunicatingInstruction( + HloInstruction* instruction) const { + HloComputation* computation = instruction->parent(); + const HloModule* module = computation->parent(); + if (module->entry_computation() == computation || + tracked_instructions_.count(computation) > 0) { + return Status::OK(); + } + return FailedPrecondition("channel is used in disallowed computation"); +} + +void HloModuleGroupMetadata::DumpCollectedStats() const { + std::map, int64> communication_histogram; + for (auto& channel : channels_) { + auto from_device = GetInstructionDevice(*channel.send); + auto to_device = GetInstructionDevice(*channel.recv); + LOG(INFO) << "Channel " << channel.id << ": from_device=" << *from_device + << " to_device=" << *to_device << " send=" << channel.send->name() + << " send_done=" << channel.send_done->name() + << " recv=" << channel.recv->name() + << " recv_done=" << channel.recv_done->name(); + communication_histogram[std::pair(*from_device, + *to_device)] += 1; + } + for (auto& fromto_count : communication_histogram) { + LOG(INFO) << "From " << fromto_count.first.first << " to " + << fromto_count.first.second << ": " << fromto_count.second; + } + for (auto& companion_set : companion_sets_) { + LOG(INFO) << "Companion set:"; + for (HloInstruction* instruction : *companion_set) { + LOG(INFO) << " " << instruction->name(); + } + } + for (auto& instruction_comm : tracked_instructions_comms_) { + LOG(INFO) << "Communicating instruction " << instruction_comm.first->name(); + for (HloInstruction* instruction : instruction_comm.second) { + auto device = GetInstructionDevice(*instruction); + LOG(INFO) << " " << instruction->name() << " on device " << *device; + } + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h new file mode 100644 index 0000000000000000000000000000000000000000..ffde3a332dfc141ca928a44cfdf4686900e9f47b --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -0,0 +1,267 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Class for bookkeeping the information on the given modules, in particular on +// the interaction between computations. +// +// Companion instructions are one of the information collected as we build the +// metadata. For example, for each While instruction, companion instructions +// refer to a set of While instructions in other computations that communicate +// with each other. +// In the example below with 3 modules, {While_0, While_2, While_5}, {While_1, +// While_4}, {While_3, While_6} are companion sets. +// +// +// While_0() { While_2() { While_5() { +// While_1() { Send(0) } While_3() { Send(1) } While_6() { Recv(1) } +// } While_4() { Recv(0) } +// } +// +// Companion instructions are used to detect cycles in the graph and also for +// global scheduling. +class HloModuleGroupMetadata { + public: + // The kind of companion computation a given instruction can be within. + enum class ComputationKind { + kInvalid, + kWhileCondition, + kWhileBody, + kConditionalTrue, + kConditionalFalse, + kCallFunction, + }; + + // Tracks the instruction mapped to a given computation, and the computation + // kind. + // For example, a body computation of a while instruction, will generate a + // TrackedInstruction with instruction being the while instruction, and + // kind being ComputationKind::kWhileBody. + class TrackedInstruction { + public: + TrackedInstruction() = default; + TrackedInstruction(HloInstruction* instruction, ComputationKind kind) + : instruction_(instruction), kind_(kind) {} + + bool operator==(const TrackedInstruction& rhs) const { + return instruction_->opcode() == rhs.instruction_->opcode() && + kind_ == rhs.kind_; + } + bool operator!=(const TrackedInstruction& rhs) const { + return !operator==(rhs); + } + + HloInstruction* instruction() const { return instruction_; } + + string ToString() const; + + private: + HloInstruction* instruction_ = nullptr; + ComputationKind kind_ = ComputationKind::kInvalid; + }; + + // Represents a channel and the 4 instructions that form the channel. + struct Channel { + int64 id = -1; + HloInstruction* send = nullptr; + HloInstruction* recv = nullptr; + HloInstruction* send_done = nullptr; + HloInstruction* recv_done = nullptr; + }; + + explicit HloModuleGroupMetadata(const std::vector& modules) + : modules_(modules) {} + + ~HloModuleGroupMetadata() = default; + + // Build and return the metadata for the given modules. + static StatusOr> Build( + const std::vector& modules); + + // Returns true if the instruction is one of the 4 channel instructions (Send, + // Recv, SendDone, RecvDone). + bool IsChannelInstruction(const HloInstruction* instruction) const; + + // Returns true if the instruction is a companion instruction. See the class + // comment above on companion instructions. + bool IsCompanionInstruction(HloInstruction* hlo) const; + + // Returns true if the instruction is either a channel instruction or a + // companion instruction. + bool InstructionCommunicates(HloInstruction* hlo) const; + + // Returns the Channel instance for the given channel id. + const Channel& GetChannel(int64 channel_id) const; + + // Returns the computation that contains the peer channel instructions for + // the given instruction. + // + // Precondition: IsChannelInstruction(instruction) is true. + HloComputation* PeerComputation(const HloInstruction* instruction) const; + + // Returns the path of the nested companion instructions, in terms of HLO + // instructions. The path goes from inner to outer companions. + // The returned path does not include the input hlo instruction, in case it + // is a companion instruction. + std::vector GetCompanionsPath( + const HloInstruction* hlo) const; + + // Checks whether two companion paths (as returned by the GetCompanionsPath() + // API) are compatible. The two paths are compatible if the sequence of + // opcodes, and the companion kinds, of the two paths matches. + bool CheckCompanionPathsCompatibility( + const std::vector& path0, + const std::vector& path1) const; + + // Returns the unique integer for each module. The returned id is the index of + // the module in the module vector. + int64 GetModuleId(const HloModule* module) const; + + // Retrieves the device an instruction is assigned to. Either from the + // sharding information, or from the ordinal of the module the instruction + // is in. + tensorflow::gtl::optional GetInstructionDevice( + const HloInstruction& instruction) const; + + // Returns the number of modules for devices (excluding the host module). + int64 GetDeviceModulesCount() const; + + // Returns the companion instructions for the given instruction. + // + // Precondition: IsCompanionWhile(instruction) is true. + const std::unordered_set& Companions( + HloInstruction* instruction) const { + CHECK_EQ(companion_set_index_.count(instruction), 1); + return companion_set(companion_set_index_.at(instruction)); + } + + // Returns the companion set at the given index. + const std::unordered_set& companion_set(int64 index) const { + CHECK_LT(index, companion_sets_.size()); + return *companion_sets_[index]; + } + + // Returns the companion set index of the given instruction. + int64 companion_set_index(HloInstruction* instruction) const { + return companion_set_index_.at(instruction); + } + + // Returns the list of all companion sets in the HLO module group. + const std::vector>>& + companion_sets() const { + return companion_sets_; + } + + // Returns all channels in the module group. + const std::vector& channels() const { return channels_; } + + // Returns the maximum channel id used in the module group. + int64 max_channel_id() const { return max_channel_id_; } + + private: + Status Build(); + + // Record all channel instructions and While instructions. + Status RecordInstructions(); + + // Verifies the given HloModules are well-formed and follow the specification, + // in particular with respect to using channel instructions. + // + // * Each channel has all 4 instructions (Send, Recv, SendDone, RecvDone). + // * The shape of channel instructions match. + // * The nest level of channel instructions match. + // * Channel instructions are used in allowed computations; i.e., in the + // entry computation of the module or condition/body of While computations. + // + // TODO(b/62064342): Currently, HloModuleGroupScheduler checks if there is a + // cycle in the graph, but it would be good to verify here. + Status VerifyChannelInstructions(); + + // Adds metadata that the given two instructions are companions. + Status AddCompanion(HloInstruction* instruction1, + HloInstruction* instruction2); + + // Checks whether a communicating instruction is placed in a valid position + // within the graph. + Status CheckCommunicatingInstruction(HloInstruction* instruction) const; + + // Performs a consistency check on the companion sets built for the input + // modules. Check that a companion set does not include instructions from the + // same module/device. + Status VerifyCompanionSets() const; + + // Retrieves a pointer to the stored TrackedInstruction associated with a + // tracked computation, or nullptr in case such computation is not tracked. + const TrackedInstruction* GetTrackedInstruction( + const HloComputation* computation) const { + auto it = tracked_instructions_.find(computation); + return it != tracked_instructions_.end() ? &it->second : nullptr; + } + + // Dump all the collected module group statistics to the logs. + void DumpCollectedStats() const; + + // List of all companion instructions sets in the module. + std::vector>> + companion_sets_; + + // Map from each companion while instruction to the index into companion_set_. + tensorflow::gtl::FlatMap companion_set_index_; + + // Map from computation to the instruction using it (a kWhile, kConditional). + tensorflow::gtl::FlatMap + tracked_instructions_; + + // Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of + // communicating instructions within the proper called computation(s). + tensorflow::gtl::FlatMap> + tracked_instructions_comms_; + + // All channels in the module. + std::vector channels_; + + // Map from channel ids to the index in channels_. + tensorflow::gtl::FlatMap channel_id_map_; + + // The maximum channel id used in the module group. + int64 max_channel_id_ = -1; + + // The modules that this metadata was built from. + const std::vector& modules_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..5a0d1e264eb5095ff53721416ebcf4842a063f97 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -0,0 +1,317 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group_util.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +std::vector HloModuleGroupUtil::GlobalPredecessors( + HloInstruction* instruction) { + std::vector predecessors; + + // Adds to the unique predecessors list and also add companion instructions + // if the given predecessor has those. + auto add_unique_predecessor = [&](HloInstruction* predecessor) { + if (std::find(predecessors.begin(), predecessors.end(), predecessor) != + predecessors.end()) { + return; + } + if (!metadata_.IsCompanionInstruction(predecessor)) { + predecessors.push_back(predecessor); + return; + } + for (HloInstruction* companion : metadata_.Companions(predecessor)) { + predecessors.push_back(companion); + } + }; + + // If the given instruction is a companion instruction, we need to find the + // predecessors of all of its companion instructions. + std::vector instruction_group; + if (metadata_.IsCompanionInstruction(instruction)) { + for (HloInstruction* companion : metadata_.Companions(instruction)) { + instruction_group.push_back(companion); + } + } else { + instruction_group.push_back(instruction); + } + + for (HloInstruction* hlo : instruction_group) { + for (HloInstruction* operand : hlo->operands()) { + add_unique_predecessor(operand); + } + for (HloInstruction* control_predecessor : hlo->control_predecessors()) { + add_unique_predecessor(control_predecessor); + } + } + if (instruction->opcode() == HloOpcode::kRecvDone) { + // Send is a remote predecessor of RecvDone. + HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; + add_unique_predecessor(send); + } + if (instruction->opcode() == HloOpcode::kSend) { + // Recv is a remote predecessor of Send. + HloInstruction* recv_done = + metadata_.GetChannel(instruction->channel_id()).recv_done; + CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + CHECK_EQ(recv_done->operand_count(), 1); + HloInstruction* recv = recv_done->mutable_operand(0); + add_unique_predecessor(recv); + } + return predecessors; +} + +std::vector HloModuleGroupUtil::GlobalSuccessors( + HloInstruction* instruction) { + std::vector successors; + + // Adds to the unique successors list and also add companion instructions + // if the given successor has those. + auto add_unique_successor = [&](HloInstruction* successor) { + if (std::find(successors.begin(), successors.end(), successor) != + successors.end()) { + return; + } + if (!metadata_.IsCompanionInstruction(successor)) { + successors.push_back(successor); + return; + } + for (HloInstruction* companion : metadata_.Companions(successor)) { + successors.push_back(companion); + } + }; + + // If the given instruction is a companion instruction, we need to find the + // successors of all of its companion instructions. + std::vector instruction_group; + if (metadata_.IsCompanionInstruction(instruction)) { + for (HloInstruction* companion : metadata_.Companions(instruction)) { + instruction_group.push_back(companion); + } + } else { + instruction_group.push_back(instruction); + } + + for (HloInstruction* hlo : instruction_group) { + for (HloInstruction* user : hlo->users()) { + add_unique_successor(user); + } + for (HloInstruction* control_successor : hlo->control_successors()) { + add_unique_successor(control_successor); + } + } + if (instruction->opcode() == HloOpcode::kRecv) { + // Send is a remote successor of Recv. + const HloInstruction* recv_done = instruction->users().front(); + CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; + add_unique_successor(send); + } + if (instruction->opcode() == HloOpcode::kSend) { + // RecvDone is a remote successor of Send. + HloInstruction* recv_done = + metadata_.GetChannel(instruction->channel_id()).recv_done; + add_unique_successor(recv_done); + } + return successors; +} + +std::vector HloModuleGroupUtil::RootInstructions( + tensorflow::gtl::ArraySlice computations) { + std::vector roots; + for (HloComputation* computation : computations) { + for (HloInstruction* instruction : computation->instructions()) { + if (GlobalSuccessors(instruction).empty()) { + roots.push_back(instruction); + } + } + } + return roots; +} + +Status HloModuleGroupUtil::VisitTopologicalOrder( + VisitStates* visit_state, const VisitFunction& visit_function, + HloInstruction* root) { + // Stack of HLO instructions visited in DFS order. + std::stack stack; + stack.push(root); + + while (!stack.empty()) { + HloInstruction* hlo = stack.top(); + + // Find the instruction group of the currently visited instruction. The + // instruction group represents all companion instructions of the + // current instruction, and are considered to be a single entity for the + // purpose of the traversal (i.e., they must always be in the same visit + // state). + std::vector instruction_group; + if (metadata_.IsCompanionInstruction(hlo)) { + for (HloInstruction* companion : metadata_.Companions(hlo)) { + instruction_group.push_back(companion); + } + } else { + instruction_group.push_back(hlo); + } + + if ((*visit_state)[hlo] == VisitState::kVisited) { + // All instructions in the group must be in the same state. + for (HloInstruction* instruction : instruction_group) { + TF_RET_CHECK((*visit_state)[instruction] == VisitState::kVisited); + } + stack.pop(); + continue; + } + + if ((*visit_state)[hlo] == VisitState::kVisiting) { + TF_RETURN_IF_ERROR(visit_function(hlo, instruction_group)); + + // Set the visit state of all instructions in the group to kVisited. + for (HloInstruction* instruction : instruction_group) { + TF_RET_CHECK((*visit_state)[instruction] == VisitState::kVisiting); + (*visit_state)[instruction] = VisitState::kVisited; + } + stack.pop(); + continue; + } + + // Set the visit state of all instructions in the group to kVisiting. + for (HloInstruction* instruction : instruction_group) { + TF_RET_CHECK((*visit_state)[instruction] == VisitState::kNotVisited) + << instruction->ToString(); + (*visit_state)[instruction] = VisitState::kVisiting; + } + + // For each instruction in the group, visit its predecessors (operands, + // control predecessors and remote predecessors). + for (HloInstruction* instruction : instruction_group) { + for (HloInstruction* predecessor : GlobalPredecessors(instruction)) { + // Visiting a node that is already being visited implies that there is + // a cycle. Generate an error with the list of instructions in the + // cycle. + if ((*visit_state)[predecessor] == VisitState::kVisiting) { + string cyclic_instructions; + for (const auto& state : *visit_state) { + if (state.second == VisitState::kVisiting) { + tensorflow::strings::StrAppend(&cyclic_instructions, + state.first->ToString(), "\n"); + } + } + // TODO(b/64305524): Improve the error message to print out the + // instructions in a deterministic order that forms the cycle. + return FailedPrecondition( + "Cross-computation cycle detected via communicating nodes. The " + "cycle contains the node %s. The cycle is found among the " + "following nodes. Note that the order of the nodes is arbitrary " + "and that the list may include nodes that are not part of the " + "cycle.\n%s", + predecessor->ToString().c_str(), cyclic_instructions.c_str()); + } + stack.push(predecessor); + } + } + } + + return Status::OK(); +} + +Status HloModuleGroupUtil::VerifyComputations( + tensorflow::gtl::ArraySlice computations) { + auto visit_function = + [&](HloInstruction* instruction, + const std::vector& instruction_group) { + return Status::OK(); + }; + int64 instructions_count = 0; + VisitStates visit_states; + for (HloComputation* computation : computations) { + // Visit all instructions, and not just from the root instruction of the + // computation. This allows us to detect dead cycles (i.e., cycles that + // are not reachable from the root) or to enforce an order for the + // communication instructions that are not reachable from any roots. + for (HloInstruction* instruction : computation->instructions()) { + TF_RETURN_IF_ERROR( + VisitTopologicalOrder(&visit_states, visit_function, instruction)); + } + instructions_count += computation->instruction_count(); + } + + // Check if all instructions are visited and are in the visited state. + TF_RET_CHECK(visit_states.size() == instructions_count); + for (auto& state : visit_states) { + TF_RET_CHECK(state.second == VisitState::kVisited); + } + + return Status::OK(); +} + +StatusOr> +HloModuleGroupUtil::ComputeReachability( + tensorflow::gtl::ArraySlice computations) { + std::list post_order; + auto visit_function = + [&](HloInstruction* instruction, + const std::vector& instruction_group) { + post_order.insert(post_order.end(), instruction_group.begin(), + instruction_group.end()); + return Status::OK(); + }; + HloModuleGroupUtil::VisitStates visit_states; + for (HloInstruction* root : RootInstructions(computations)) { + TF_RETURN_IF_ERROR( + VisitTopologicalOrder(&visit_states, visit_function, root)); + } + auto reachability = MakeUnique(post_order); + for (HloInstruction* hlo : post_order) { + reachability->SetReachabilityToUnion(GlobalPredecessors(hlo), hlo); + } + return std::move(reachability); +} + +void HloModuleGroupUtil::UpdateReachabilityThroughInstruction( + HloInstruction* instruction, HloReachabilityMap* reachability_map) { + std::queue worklist; + worklist.push(instruction); + + while (!worklist.empty()) { + HloInstruction* item = worklist.front(); + worklist.pop(); + if (reachability_map->SetReachabilityToUnion(GlobalPredecessors(item), + item)) { + for (HloInstruction* successor : GlobalSuccessors(item)) { + worklist.push(successor); + } + } + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h new file mode 100644 index 0000000000000000000000000000000000000000..c25ca1aff50b288f3ac3885cbed53e7ba9768430 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h @@ -0,0 +1,117 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { + +// Collection of utilities for handling HloModuleGroups. +class HloModuleGroupUtil { + public: + explicit HloModuleGroupUtil(const HloModuleGroupMetadata& metadata) + : metadata_(metadata) {} + + // Returns all unique predecessors of the instruction. This includes: + // * predecessors in the same computation: operands and control predecessors + // * Recv is a predecessor of Send + // * Send is a predecessor of RecvDone + // * predecessors of companions (if the instruction is a companion while) + // * predecessors' companions (for any predecessor that is a companion while) + std::vector GlobalPredecessors(HloInstruction* instruction); + + // Returns all unique successors of the instruction. This includes: + // * successors in the same computation: users and control successors + // * Send is a successor of Recv + // * RecvDone is a predecessor of Send + // * successors of companions (if the instruction is a companion while) + // * successors' companions (for any successor that is a companion while) + std::vector GlobalSuccessors(HloInstruction* instruction); + + // Returns the root instructions of the computations. + std::vector RootInstructions( + tensorflow::gtl::ArraySlice computations); + + // Visit state of each instruction during DFS traversal. + enum VisitState { + kNotVisited = 0, + kVisiting, + kVisited, + }; + + // Function called on each instruction group during the DFS traversal. See the + // comment for VisitTopologicalOrder()). + using VisitFunction = std::function& instruction_group)>; + + // Given the hlo instruction as the root, recursively visits all its + // predecessor instructions in DFS order to visit nodes in topological order. + // + // Note that the DFS traversal does not only visit nodes in the same + // computation (parent of the root instruction), but also visits nodes in + // different computations connected via communication instructions. During the + // traversal, companion While instructions (see the class comment in + // HloModuleGroupMetadata) are treated as a single instruction (called + // instruction group, which contains only a single instruction if the visiting + // node is not a companion while) -- visiting one of the instructions in the + // group effectively visits all other instructions in the group, and then all + // predecessor instructions of the group are visited. + // + // * visit_state: map from each instruction to its visit state. + // * visit_function: function called when each instruction group. + // * root: the root instruction of the traversal. + using VisitStates = tensorflow::gtl::FlatMap; + Status VisitTopologicalOrder(VisitStates* visit_state, + const VisitFunction& visit_function, + HloInstruction* root); + + // Verifies that the computations are well-formed (e.g., no cycles). + Status VerifyComputations( + tensorflow::gtl::ArraySlice computations); + + // Below Reachability utils resemble those in HloComputation, except that + // they can handle instructions across multiple computations. + // + // Creates the reachability map for the instructions in the computations. + StatusOr> ComputeReachability( + tensorflow::gtl::ArraySlice computations); + + // Updates the reachability of the given instruction, taking the global + // predeccessorss and successors into account. + void UpdateReachabilityThroughInstruction( + HloInstruction* instruction, HloReachabilityMap* reachability_map); + + private: + const HloModuleGroupMetadata& metadata_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index af24604c39b554f146793594958f373999844b4c..a35546f5f41b149d119ee141fd734da8bfd055b2 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -57,6 +57,7 @@ namespace xla { V(kCall, "call", kHloOpcodeIsVariadic) \ V(kCeil, "ceil") \ V(kClamp, "clamp") \ + V(kClz, "count-leading-zeros") \ V(kComplex, "complex") \ V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ V(kConditional, "conditional") \ @@ -68,16 +69,19 @@ namespace xla { V(kCrossReplicaSum, "cross-replica-sum") \ V(kCustomCall, "custom-call") \ V(kDivide, "divide") \ + V(kDomain, "domain") \ V(kDot, "dot") \ V(kDynamicSlice, "dynamic-slice") \ V(kDynamicUpdateSlice, "dynamic-update-slice") \ V(kEq, "equal-to", kHloOpcodeIsComparison) \ V(kExp, "exponential") \ + V(kExpm1, "exponential-minus-one") \ V(kFft, "fft") \ V(kFloor, "floor") \ V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ + V(kGenerateToken, "generate-token", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ V(kHostCompute, "host-compute") \ @@ -86,6 +90,7 @@ namespace xla { V(kIsFinite, "is-finite") \ V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \ V(kLog, "log") \ + V(kLog1p, "log-plus-one") \ V(kAnd, "and") \ V(kNot, "not") \ V(kOr, "or") \ diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index cd2ce5c69f030c65b889d67e082a3677b8739ddb..774345124b4ad62e35d9423a23f1dbaa28e44d80 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -58,6 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { case HloOpcode::kConcatenate: case HloOpcode::kFusion: case HloOpcode::kMap: + case HloOpcode::kGenerateToken: case HloOpcode::kTuple: EXPECT_TRUE(HloOpcodeIsVariadic(opcode)); break; diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 1b24d8da9e832e6847cb6f405e15af3c455f695a..dcd4725fe78e8b9b5d14437e964cb5aaf1664117 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -66,6 +65,28 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, } } + // If the common ancestor is a conditional instruction, even though the true + // and false computations are not really ordered per-se, we define the true + // computation to be ordered before the false one. + // This ensures that buffers can still be shared among the two computations + // as they will forcibly have disjoint liveness. + if (a_ancestor == b_ancestor && + a_ancestor->opcode() == HloOpcode::kConditional) { + const HloComputation* true_computation = a_ancestor->true_computation(); + const HloComputation* false_computation = a_ancestor->false_computation(); + if (call_graph_->InstructionIsNestedIn(a, true_computation) && + call_graph_->InstructionIsNestedIn(b, false_computation)) { + return true; + } + // If 'b' is the conditional ancestor, and 'a' is within the true or false + // computations, 'a' executes before 'b'. + if (b == a_ancestor && + (call_graph_->InstructionIsNestedIn(a, true_computation) || + call_graph_->InstructionIsNestedIn(a, false_computation))) { + return true; + } + } + return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor); } @@ -118,7 +139,18 @@ bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const { b.defining_instruction()->while_condition()))) { return true; } - + // If 'b' is a conditional phi and 'a' is in the true or false computation, + // then 'a' executes before 'b'. + if (b.is_phi() && + b.defining_instruction()->opcode() == HloOpcode::kConditional && + (call_graph_->InstructionIsNestedIn( + a.defining_instruction(), + b.defining_instruction()->true_computation()) || + call_graph_->InstructionIsNestedIn( + a.defining_instruction(), + b.defining_instruction()->false_computation()))) { + return true; + } return ExecutesBefore(a.defining_instruction(), b.defining_instruction()); } @@ -137,10 +169,10 @@ bool HloOrdering::UseIsBeforeValueDefinition( // is before the def if the instruction allows buffer sharing (in place // computation). if (use.instruction == value.defining_instruction() && - CanShareOperandBufferWithUser( + dataflow.CanShareOperandBufferWithUser( use.instruction->mutable_operand(use.operand_number), use.operand_index, value.defining_instruction(), - value.defining_index(), dataflow)) { + value.defining_index())) { VLOG(4) << " use is value def, and instruction can share use buffer"; return true; } @@ -212,18 +244,17 @@ bool HloOrdering::LiveRangeStrictlyBefore( VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString() << ", b = " << b.ToShortString() << ")"; if (!IsDefinedBefore(a, b)) { - VLOG(4) << "a not defined before b"; + VLOG(4) << a << " not defined before " << b; return false; } - // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { if (!UseIsBeforeValueDefinition(use, b, dataflow)) { - VLOG(4) << "use of a (" << use << ") not before b is defined"; + VLOG(4) << "use of " << a << " (" << use << ") not before " << b + << " is defined"; return false; } } - return true; } diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index ee526d8dd7f7e81b3a846741d3e452935f486bd2..985f3fa64d8767b0c0063ee900f7d11c3b7f6d4a 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -183,6 +183,10 @@ class DependencyHloOrdering : public PredecessorHloOrdering { // interference is reduced relative to DependencyHloOrdering. class SequentialHloOrdering : public HloOrdering { public: + // TODO(dimvar): HloModuleSequence is not a good name because it sounds like + // a sequence of modules, instead of a map of schedules for all computations + // in a module. We should change it at some point. + // // A sequence of instructions for each computation in the module. using HloModuleSequence = tensorflow::gtl::FlatMapAddEntryComputation(builder.Build()); - - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, [](const LogicalBuffer& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - - // The first instruction should be the parameter and the last the root "sub". - EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); - EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); - - SequentialHloOrdering ordering(module.get(), sequence); - EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); -} - TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { // Tests the ordering of instructions in different computations using the // following HLO code: @@ -357,10 +310,71 @@ ENTRY while.v11 { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); DependencyHloOrdering ordering(module.get()); ordering.ToString(); // Shouldn't crash. } +TEST_F(HloOrderingTest, ConditionalInstructionOrdering) { + const char* module_str = R"( +HloModule test_conditional_module + +true_branch { + param.1 = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(param.1), index=0 + get-tuple-element.2 = s32[] get-tuple-element(param.1), index=1 + add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2) + ROOT tuple.1 = (s32[], s32[]) tuple(add.1, get-tuple-element.1) +} + +false_branch { + param.2 = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(param.2), index=0 + get-tuple-element.4 = s32[] get-tuple-element(param.2), index=1 + add.2 = s32[] add(get-tuple-element.3, get-tuple-element.4) + ROOT tuple.2 = (s32[], s32[]) tuple(add.2, get-tuple-element.4) +} + +ENTRY root { + param.3 = (pred[], (s32[], s32[])) parameter(0) + pred.1 = pred[] get-tuple-element(param.3), index=0 + cond_arg.1 = (s32[], s32[]) get-tuple-element(param.3), index=1 + conditional = (s32[], s32[]) conditional(pred.1, cond_arg.1, cond_arg.1), true_computation=true_branch, false_computation=false_branch + cond_res.1 = s32[] get-tuple-element(conditional), index=0 + cond_res.2 = s32[] get-tuple-element(conditional), index=1 + add.3 = s32[] add(cond_res.1, cond_res.2) + ROOT result = (s32[], s32[], s32[]) tuple(add.3, cond_res.1, cond_res.2) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + DependencyHloOrdering ordering(module.get()); + + // Even though the true and false branches has no ordering, since they do not + // interfere (as they are mutually exclusive), we define the true computation + // to be before the false one. + // Similarly, any instruction in the true or false branches are considered + // before the conditional instruction. The roots are effectively "at the same + // time" WRT the conditional, but they are Phi-ed anyway. + HloInstruction* add_1 = FindInstruction(module.get(), "add.1"); + HloInstruction* add_2 = FindInstruction(module.get(), "add.2"); + HloInstruction* add_3 = FindInstruction(module.get(), "add.3"); + HloInstruction* conditional = FindInstruction(module.get(), "conditional"); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1), + dataflow->GetValueDefinedAt(add_2))); + EXPECT_TRUE( + ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_2), + dataflow->GetValueDefinedAt(conditional))); + EXPECT_TRUE( + ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1), + dataflow->GetValueDefinedAt(conditional))); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1), + dataflow->GetValueDefinedAt(add_3))); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_2), + dataflow->GetValueDefinedAt(add_3))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc similarity index 83% rename from tensorflow/compiler/xla/tools/parser/hlo_parser.cc rename to tensorflow/compiler/xla/service/hlo_parser.cc index cd2b843ad36013ae83818ecbc184fb823093f037..fef475380c5c810e1c4712406dde6b1135be3d97 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -24,17 +26,17 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { -namespace tools { namespace { -using tensorflow::StringPiece; -using tensorflow::gtl::optional; -using tensorflow::str_util::Split; -using tensorflow::str_util::SplitAndParseAsInts; -using tensorflow::strings::Printf; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using ::tensorflow::StringPiece; +using ::tensorflow::gtl::optional; +using ::tensorflow::str_util::Join; +using ::tensorflow::str_util::Split; +using ::tensorflow::str_util::SplitAndParseAsInts; +using ::tensorflow::strings::Printf; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; const double kF16max = 65504; @@ -53,7 +55,12 @@ class HloParser { std::unique_ptr ConsumeHloModule() { return std::move(module_); } // Returns the error information. - string GetError() const { return tensorflow::str_util::Join(error_, "\n"); } + string GetError() const { return Join(error_, "\n"); } + + // Stand alone parsing utils for various aggregate data types. + StatusOr ParseShardingOnly(); + StatusOr ParseWindowOnly(); + StatusOr ParseConvolutionDimensionNumbersOnly(); private: // ParseXXX returns false if an error occurred. @@ -77,11 +84,15 @@ class HloParser { // Sets the sub-value of literal at the given index to the given value. The // literal's shape must have the default layout. - bool SetValueInLiteral(int64 value, int64 linear_index, Literal* literal); - bool SetValueInLiteral(double value, int64 linear_index, Literal* literal); - bool SetValueInLiteral(bool value, int64 linear_index, Literal* literal); + bool SetValueInLiteral(tensorflow::int64 value, + tensorflow::int64 linear_index, Literal* literal); + bool SetValueInLiteral(double value, tensorflow::int64 linear_index, + Literal* literal); + bool SetValueInLiteral(bool value, tensorflow::int64 linear_index, + Literal* literal); template - bool SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, + bool SetValueInLiteralHelper(ParsedElemT value, + tensorflow::int64 linear_index, Literal* literal); bool ParseOperands(std::vector* operands); @@ -93,9 +104,15 @@ class HloParser { // Describes the start, limit, and stride on every dimension of the operand // being sliced. struct SliceRanges { - std::vector starts; - std::vector limits; - std::vector strides; + std::vector starts; + std::vector limits; + std::vector strides; + }; + + // The data parsed for the kDomain instruction. + struct DomainData { + std::unique_ptr entry_metadata; + std::unique_ptr exit_metadata; }; // Types of attributes. @@ -116,6 +133,7 @@ class HloParser { kMetadata, kFusionKind, kDistribution, + kDomain, }; struct AttrConfig { @@ -163,21 +181,27 @@ class HloParser { bool ParseComputationName(HloComputation** value); // Parses a list of names and finds the corresponding hlo instructions. bool ParseInstructionNames(std::vector* instructions); - bool ParseWindow(Window* window); + // Pass expect_outer_curlies == true when parsing a Window in the context of a + // larger computation. Pass false when parsing a stand-alone Window string. + bool ParseWindow(Window* window, bool expect_outer_curlies); bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums); bool ParsePaddingConfig(PaddingConfig* padding); bool ParseMetadata(OpMetadata* metadata); bool ParseSharding(OpSharding* sharding); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); + // Parses the metadata behind a kDOmain instruction. + bool ParseDomain(DomainData* domain); + // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3. - bool ParseDxD(const string& name, std::vector* result); + bool ParseDxD(const string& name, std::vector* result); // Parses window's pad sub-attriute, e.g., pad=0_0x3x3. - bool ParseWindowPad(std::vector>* pad); + bool ParseWindowPad(std::vector>* pad); bool ParseSliceRanges(SliceRanges* result); bool ParseInt64List(const TokKind start, const TokKind end, - const TokKind delim, std::vector* result); + const TokKind delim, + std::vector* result); bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); @@ -189,7 +213,7 @@ class HloParser { bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); - bool ParseInt64(int64* result); + bool ParseInt64(tensorflow::int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); bool ParseToken(TokKind kind, const string& msg); @@ -242,10 +266,10 @@ bool HloParser::Error(LocTy loc, StringPiece msg) { std::vector error_lines; error_lines.push_back( StrCat("was parsing ", line, ":", col, ": error: ", msg)); - error_lines.push_back(lexer_.GetLine(loc).ToString()); + error_lines.push_back(std::string(lexer_.GetLine(loc))); error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); - error_.push_back(tensorflow::str_util::Join(error_lines, "\n")); + error_.push_back(Join(error_lines, "\n")); VLOG(1) << "Error: " << error_.back(); return false; } @@ -303,18 +327,20 @@ bool HloParser::ParseComputations() { // set the layouts to what the hlo text says. for (int p = 0; p < computation->num_parameters(); p++) { const Shape& param_shape = computation->parameter_instruction(p)->shape(); - if (param_shape.has_layout()) { - module_->mutable_entry_computation_layout() - ->mutable_parameter_layout(p) - ->ResetLayout(param_shape.layout()); - } + TF_CHECK_OK(module_->mutable_host_entry_computation_layout() + ->mutable_parameter_layout(p) + ->CopyLayoutFromShape(param_shape)); + TF_CHECK_OK(module_->mutable_device_entry_computation_layout() + ->mutable_parameter_layout(p) + ->CopyLayoutFromShape(param_shape)); } const Shape& result_shape = computation->root_instruction()->shape(); - if (result_shape.has_layout()) { - module_->mutable_entry_computation_layout() - ->mutable_result_layout() - ->ResetLayout(result_shape.layout()); - } + TF_CHECK_OK(module_->mutable_host_entry_computation_layout() + ->mutable_result_layout() + ->CopyLayoutFromShape(result_shape)); + TF_CHECK_OK(module_->mutable_device_entry_computation_layout() + ->mutable_result_layout() + ->CopyLayoutFromShape(result_shape)); } return true; @@ -381,6 +407,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { } *entry_computation = computation; } + instruction_pool_.clear(); return AddComputation(name, computation, name_loc); } @@ -437,10 +464,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional metadata; attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; + optional backend_config; + attrs["backend_config"] = {/*required=*/false, AttrTy::kString, + &backend_config}; + HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { - int64 parameter_number; + tensorflow::int64 parameter_number; if (!ParseToken(TokKind::kLparen, "expects '(' before parameter number") || !ParseInt64(¶meter_number) || @@ -470,13 +501,16 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kRoundNearestAfz: case HloOpcode::kBitcast: case HloOpcode::kCeil: + case HloOpcode::kClz: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kFloor: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -553,11 +587,28 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kCrossReplicaSum: { + optional to_apply; + optional> replica_group_ids; + optional barrier; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &to_apply}; + attrs["replica_group_ids"] = { + /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids}; + attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateCrossReplicaSum(shape, operands)); + + if (replica_group_ids) { + instruction = + builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( + shape, operands, *to_apply, *replica_group_ids, + barrier ? *barrier : "")); + } else { + instruction = + builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( + shape, operands, *to_apply, {}, barrier ? *barrier : "")); + } break; } case HloOpcode::kReshape: { @@ -569,6 +620,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateReshape(shape, operands[0])); break; } + case HloOpcode::kGenerateToken: { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateGenerateToken(operands)); + break; + } case HloOpcode::kTuple: { if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; @@ -592,7 +651,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kRecv: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/0) || !ParseAttributes(attrs)) { @@ -603,7 +662,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kRecvDone: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -617,7 +676,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kSend: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -628,7 +687,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kSendDone: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -642,7 +701,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kGetTupleElement: { - optional index; + optional index; attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -700,7 +759,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kFft: { optional fft_type; - optional> fft_length; + optional> fft_length; attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type}; attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List, &fft_length}; @@ -713,7 +772,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kBroadcast: { - optional> broadcast_dimensions; + optional> broadcast_dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &broadcast_dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -725,7 +784,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kConcatenate: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs) || @@ -740,6 +799,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional to_apply; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; + optional> dimensions; + attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List, + &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -751,7 +813,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional reduce_computation; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; - optional> dimensions_to_reduce; + optional> dimensions_to_reduce; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions_to_reduce}; if (!ParseOperands(&operands, /*expected_size=*/2) || @@ -764,7 +826,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReverse: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -808,7 +870,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kDynamicSlice: { - optional> dynamic_slice_sizes; + optional> dynamic_slice_sizes; attrs["dynamic_slice_sizes"] = { /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; if (!ParseOperands(&operands, /*expected_size=*/2) || @@ -832,7 +894,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kTranspose: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -846,7 +908,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormTraining: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/3) || @@ -862,7 +924,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormInference: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -879,7 +941,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormGrad: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -950,8 +1012,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReducePrecision: { - optional exponent_bits; - optional mantissa_bits; + optional exponent_bits; + optional mantissa_bits; attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64, &exponent_bits}; attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64, @@ -996,7 +1058,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kHostCompute: { optional channel_name; - optional cost_estimate_ns; + optional cost_estimate_ns; attrs["channel_name"] = {/*required=*/true, AttrTy::kString, &channel_name}; attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64, @@ -1009,16 +1071,16 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kDot: { - optional> lhs_contracting_dims; + optional> lhs_contracting_dims; attrs["lhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims}; - optional> rhs_contracting_dims; + optional> rhs_contracting_dims; attrs["rhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims}; - optional> lhs_batch_dims; + optional> lhs_batch_dims; attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &lhs_batch_dims}; - optional> rhs_batch_dims; + optional> rhs_batch_dims; attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &rhs_batch_dims}; @@ -1049,18 +1111,65 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateDot(shape, operands[0], operands[1], dnum)); break; } - case HloOpcode::kGather: - // TODO(b/72710576): HLO parsing is not implemented for Gather. - return TokenError("HLO parsing is not implemented for Gather"); + case HloOpcode::kGather: { + optional> output_window_dims; + attrs["output_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims}; + optional> elided_window_dims; + attrs["elided_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims}; + optional> gather_dims_to_operand_dims; + attrs["gather_dims_to_operand_dims"] = {/*required=*/true, + AttrTy::kBracedInt64List, + &gather_dims_to_operand_dims}; + optional index_vector_dim; + attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, + &index_vector_dim}; + optional> window_bounds; + attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List, + &window_bounds}; + + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + + GatherDimensionNumbers dim_numbers = HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/*output_window_dims, + /*elided_window_dims=*/*elided_window_dims, + /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, + /*index_vector_dim=*/*index_vector_dim); + + instruction = builder->AddInstruction(HloInstruction::CreateGather( + shape, /*operand=*/operands[0], /*gather_indices=*/operands[1], + dim_numbers, *window_bounds)); + break; + } + case HloOpcode::kDomain: { + DomainData domain; + attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateDomain( + shape, operands[0], std::move(domain.entry_metadata), + std::move(domain.exit_metadata))); + break; + } case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); } - instruction->set_name(name); + instruction->SetAndSanitizeName(name); + if (instruction->name() != name) { + return Error(name_loc, + StrCat("illegal instruction name: ", name, + "; suggest renaming to: ", instruction->name())); + } - // Add common attrs (sharding, control predecessors) to the instruction, if - // they were seen. + // Add shared attributes like metadata to the instruction, if they were seen. if (sharding) { instruction->set_sharding( HloSharding::FromProto(sharding.value()).ValueOrDie()); @@ -1077,6 +1186,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (metadata) { instruction->set_metadata(*metadata); } + if (backend_config) { + instruction->set_raw_backend_config_string(std::move(*backend_config)); + } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -1126,8 +1238,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; - std::vector devices; - std::vector tile_assignment_dimensions; + std::vector devices; + std::vector tile_assignment_dimensions; Shape tile_shape; while (lexer_.GetKind() != TokKind::kRbrace) { switch (lexer_.GetKind()) { @@ -1154,7 +1266,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } do { - int64 dim; + tensorflow::int64 dim; if (!ParseInt64(&dim)) { return false; } @@ -1166,7 +1278,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return false; } do { - int64 device; + tensorflow::int64 device; if (!ParseInt64(&device)) { return false; } @@ -1225,10 +1337,10 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); *sharding->mutable_tile_shape() = tile_shape; - for (int64 dim : tile_assignment_dimensions) { + for (tensorflow::int64 dim : tile_assignment_dimensions) { sharding->add_tile_assignment_dimensions(dim); } - for (int64 device : devices) { + for (tensorflow::int64 device : devices) { sharding->add_tile_assignment_devices(device); } } @@ -1237,6 +1349,34 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return true; } +// domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ',' +// 'exit=' exit_sharding '}' +bool HloParser::ParseDomain(DomainData* domain) { + std::unordered_map attrs; + optional kind; + optional entry_sharding; + optional exit_sharding; + attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind}; + attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding}; + attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding}; + if (!ParseSubAttributes(attrs)) { + return false; + } + if (*kind == ShardingMetadata::KindName()) { + auto entry_sharding_ptr = MakeUnique( + HloSharding::FromProto(*entry_sharding).ValueOrDie()); + auto exit_sharding_ptr = MakeUnique( + HloSharding::FromProto(*exit_sharding).ValueOrDie()); + domain->entry_metadata = + MakeUnique(std::move(entry_sharding_ptr)); + domain->exit_metadata = + MakeUnique(std::move(exit_sharding_ptr)); + } else { + return TokenError(StrCat("unsupported domain kind: ", *kind)); + } + return true; +} + // '{' name+ '}' bool HloParser::ParseInstructionNames( std::vector* instructions) { @@ -1263,40 +1403,50 @@ bool HloParser::ParseInstructionNames( "expects '}' at the end of instruction name list"); } -bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, +bool HloParser::SetValueInLiteral(tensorflow::int64 value, + tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case S8: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S32: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S64: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U8: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U32: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U64: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); default: LOG(FATAL) << "unknown integral primitive type " << PrimitiveType_Name(shape.element_type()); } } -bool HloParser::SetValueInLiteral(double value, int64 linear_index, +bool HloParser::SetValueInLiteral(double value, tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case F16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, literal); case BF16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case F32: return SetValueInLiteralHelper(value, linear_index, literal); case F64: @@ -1307,7 +1457,7 @@ bool HloParser::SetValueInLiteral(double value, int64 linear_index, } } -bool HloParser::SetValueInLiteral(bool value, int64 linear_index, +bool HloParser::SetValueInLiteral(bool value, tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { @@ -1320,7 +1470,8 @@ bool HloParser::SetValueInLiteral(bool value, int64 linear_index, } template -bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, +bool HloParser::SetValueInLiteralHelper(ParsedElemT value, + tensorflow::int64 linear_index, Literal* literal) { // Check that linear_index is in range. if (linear_index >= ShapeUtil::ElementsIn(literal->shape())) { @@ -1432,7 +1583,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, const Shape& shape) { - const int64 rank = ShapeUtil::Rank(shape); + const tensorflow::int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; } @@ -1440,8 +1591,8 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // Create a literal with the given shape in default layout. *literal = Literal::CreateFromDimensions(shape.element_type(), AsInt64Slice(shape.dimensions())); - int64 nest_level = 0; - int64 linear_index = 0; + tensorflow::int64 nest_level = 0; + tensorflow::int64 linear_index = 0; // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}}, // when we are parsing the 2nd '{' (right before '1'), we are seeing a @@ -1449,16 +1600,15 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // the first '}' (right after '3'), it means the sub-array ends, and the // sub-array is supposed to contain exactly 3 elements, so check if // elems_seen_per_dim[1] is 3. - std::vector elems_seen_per_dim(rank); + std::vector elems_seen_per_dim(rank); auto get_index_str = [&elems_seen_per_dim](int dim) -> string { - std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), - elems_seen_per_dim.begin() + dim); + std::vector elems_seen_until_dim( + elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", - tensorflow::str_util::Join( - elems_seen_until_dim, ",", - [](string* out, const int64& num_elems) { - tensorflow::strings::StrAppend(out, num_elems - 1); - }), + Join(elems_seen_until_dim, ",", + [](string* out, const tensorflow::int64& num_elems) { + StrAppend(out, num_elems - 1); + }), "]"); }; do { @@ -1533,7 +1683,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { LocTy loc = lexer_.GetLoc(); - int64 value; + tensorflow::int64 value; if (!ParseInt64(&value)) { return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); @@ -1573,29 +1723,29 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, switch (shape.element_type()) { case PRED: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S8: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S32: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S64: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U8: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U32: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U64: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F32: return ParseSparseLiteralHelper(literal, shape); case BF16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F64: return ParseSparseLiteralHelper(literal, shape); default: @@ -1608,9 +1758,9 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, template bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, const Shape& shape) { - std::vector index; + std::vector index; - int64 rank = ShapeUtil::Rank(shape); + tensorflow::int64 rank = ShapeUtil::Rank(shape); *literal = MakeUnique(shape); @@ -1628,7 +1778,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, LocTy index_loc = lexer_.GetLoc(); index.clear(); if (lexer_.GetKind() == TokKind::kInt) { - int64 single_index = lexer_.GetInt64Val(); + tensorflow::int64 single_index = lexer_.GetInt64Val(); lexer_.Lex(); if (rank != 1) { return Error( @@ -1646,7 +1796,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, return Error( index_loc, StrCat("invalid multi-dimension index for shape with rank ", rank, - ": [", tensorflow::str_util::Join(index, ", "), "]")); + ": [", Join(index, ", "), "]")); } } if (!ParseToken(TokKind::kColon, @@ -1661,7 +1811,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, value = static_cast(lexer_.GetKind() == TokKind::kw_true); lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { - int64 value_s64; + tensorflow::int64 value_s64; if (!ParseInt64(&value_s64)) { return Error(value_loc, StrCat("expects integer for primitive type: ", @@ -1814,7 +1964,19 @@ bool HloParser::ParseAttributeHelper( } auto attr_it = attrs.find(name); if (attr_it == attrs.end()) { - return Error(loc, Printf("unexpected attribute %s", name.c_str())); + string allowed_attrs; + if (attrs.empty()) { + allowed_attrs = "No attributes are allowed here."; + } else { + allowed_attrs = StrCat( + "Allowed attributes: ", + Join(attrs, ", ", + [&](string* out, const std::pair& kv) { + StrAppend(out, kv.first); + })); + } + return Error(loc, Printf("unexpected attribute \"%s\". %s", name.c_str(), + allowed_attrs.c_str())); } AttrTy attr_type = attr_it->second.attr_type; void* attr_out_ptr = attr_it->second.result; @@ -1822,23 +1984,24 @@ bool HloParser::ParseAttributeHelper( LocTy attr_loc = lexer_.GetLoc(); switch (attr_type) { case AttrTy::kInt64: { - int64 result; + tensorflow::int64 result; if (!ParseInt64(&result)) { return false; } - static_cast*>(attr_out_ptr)->emplace(result); + static_cast*>(attr_out_ptr) + ->emplace(result); return true; } case AttrTy::kInt32: { - int64 result; + tensorflow::int64 result; if (!ParseInt64(&result)) { return false; } - if (result != static_cast(result)) { + if (result != static_cast(result)) { return Error(attr_loc, "value out of range for int32"); } - static_cast*>(attr_out_ptr) - ->emplace(static_cast(result)); + static_cast*>(attr_out_ptr) + ->emplace(static_cast(result)); return true; } case AttrTy::kFloat: { @@ -1872,7 +2035,7 @@ bool HloParser::ParseAttributeHelper( } case AttrTy::kWindow: { Window result; - if (!ParseWindow(&result)) { + if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) { return false; } static_cast*>(attr_out_ptr)->emplace(result); @@ -1914,12 +2077,12 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kBracedInt64List: { - std::vector result; + std::vector result; if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, &result)) { return false; } - static_cast>*>(attr_out_ptr) + static_cast>*>(attr_out_ptr) ->emplace(result); return true; } @@ -1964,6 +2127,9 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kDomain: { + return ParseDomain(static_cast(attr_out_ptr)); + } } }(); if (!success) { @@ -1990,9 +2156,10 @@ bool HloParser::ParseComputationName(HloComputation** value) { // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}' // The subattributes can appear in any order. 'size=' is required, others are // optional. -bool HloParser::ParseWindow(Window* window) { +bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) { LocTy loc = lexer_.GetLoc(); - if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { + if (expect_outer_curlies && + !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { return false; } @@ -2002,7 +2169,9 @@ bool HloParser::ParseWindow(Window* window) { std::vector lhs_dilate; std::vector rhs_dilate; std::vector rhs_reversal; - while (lexer_.GetKind() != TokKind::kRbrace) { + const auto end_token = + expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof; + while (lexer_.GetKind() != end_token) { LocTy attr_loc = lexer_.GetLoc(); string field_name; if (!ParseAttributeName(&field_name)) { @@ -2066,7 +2235,8 @@ bool HloParser::ParseWindow(Window* window) { window->mutable_dimensions(i)->set_window_reversal( rhs_reversal.empty() ? false : (rhs_reversal[i] == 1)); } - return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); + return !expect_outer_curlies || + ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); } // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString. @@ -2090,7 +2260,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( << str; } - const int64 rank = lhs_rhs_out[0].length(); + const tensorflow::int64 rank = lhs_rhs_out[0].length(); if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) { return TokenError( "convolution lhs, rhs, and output must have the same rank"); @@ -2204,7 +2374,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) { return false; } - std::vector> ranges; + std::vector> ranges; if (lexer_.GetKind() == TokKind::kRbrace) { // empty return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); @@ -2238,7 +2408,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { // ::= int64_val (delim int64_val)* bool HloParser::ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, - std::vector* result) { + std::vector* result) { if (!ParseToken(start, StrCat("expects an int64 list starting with ", TokKindToString(start)))) { return false; @@ -2247,7 +2417,7 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, // empty } else { do { - int64 i; + tensorflow::int64 i; if (!ParseInt64(&i)) { return false; } @@ -2364,7 +2534,8 @@ bool HloParser::ParseString(string* result) { return true; } -bool HloParser::ParseDxD(const string& name, std::vector* result) { +bool HloParser::ParseDxD(const string& name, + std::vector* result) { LocTy loc = lexer_.GetLoc(); if (!result->empty()) { return Error(loc, @@ -2372,7 +2543,7 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { } // 1D if (lexer_.GetKind() == TokKind::kInt) { - int64 number; + tensorflow::int64 number; if (!ParseInt64(&number)) { return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str())); } @@ -2392,7 +2563,8 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { return TokenError("expects token type kInt or kDxD"); } -bool HloParser::ParseWindowPad(std::vector>* pad) { +bool HloParser::ParseWindowPad( + std::vector>* pad) { LocTy loc = lexer_.GetLoc(); if (!pad->empty()) { return Error(loc, "sub-attribute 'pad=' already exists"); @@ -2403,7 +2575,7 @@ bool HloParser::ParseWindowPad(std::vector>* pad) { string str = lexer_.GetStrVal(); std::vector padding_str = Split(str, 'x'); for (int i = 0; i < padding_str.size(); i++) { - std::vector low_high; + std::vector low_high; if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || low_high.size() != 2) { return Error(loc, @@ -2427,7 +2599,7 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { string str = lexer_.GetStrVal(); std::vector padding_str = Split(str, 'x'); for (const auto& padding_dim_str : padding_str) { - std::vector padding_dim; + std::vector padding_dim; if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || (padding_dim.size() != 2 && padding_dim.size() != 3)) { return Error(loc, @@ -2449,7 +2621,7 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) { optional op_type; optional op_name; optional source_file; - optional source_line; + optional source_line; attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file}; @@ -2536,7 +2708,7 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { return true; } -bool HloParser::ParseInt64(int64* result) { +bool HloParser::ParseInt64(tensorflow::int64* result) { VLOG(1) << "ParseInt64"; if (lexer_.GetKind() != TokKind::kInt) { return TokenError("expects integer"); @@ -2619,10 +2791,48 @@ bool HloParser::AddComputation(const string& name, HloComputation* computation, return true; } +StatusOr HloParser::ParseShardingOnly() { + lexer_.Lex(); + OpSharding op_sharding; + if (!ParseSharding(&op_sharding)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after sharding"); + } + return HloSharding::FromProto(op_sharding); +} + +StatusOr HloParser::ParseWindowOnly() { + lexer_.Lex(); + Window window; + if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after window"); + } + return window; +} + +StatusOr +HloParser::ParseConvolutionDimensionNumbersOnly() { + lexer_.Lex(); + ConvolutionDimensionNumbers dnums; + if (!ParseConvolutionDimensionNumbers(&dnums)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument( + "Syntax error:\nExtra content after convolution dnums"); + } + return dnums; +} + } // namespace -StatusOr> Parse(StringPiece str, - const HloModuleConfig& config) { +StatusOr> ParseHloString( + tensorflow::StringPiece str, const HloModuleConfig& config) { HloParser parser(str, config); if (!parser.Run()) { return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); @@ -2630,10 +2840,29 @@ StatusOr> Parse(StringPiece str, return parser.ConsumeHloModule(); } -StatusOr> Parse(StringPiece str) { +StatusOr> ParseHloString( + tensorflow::StringPiece str) { + HloModuleConfig config; + return ParseHloString(str, config); +} + +StatusOr ParseSharding(tensorflow::StringPiece str) { HloModuleConfig config; - return Parse(str, config); + HloParser parser(str, config); + return parser.ParseShardingOnly(); +} + +StatusOr ParseWindow(tensorflow::StringPiece str) { + HloModuleConfig config; + HloParser parser(str, config); + return parser.ParseWindowOnly(); +} + +StatusOr ParseConvolutionDimensionNumbers( + tensorflow::StringPiece str) { + HloModuleConfig config; + HloParser parser(str, config); + return parser.ParseConvolutionDimensionNumbersOnly(); } -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h similarity index 52% rename from tensorflow/compiler/xla/tools/parser/hlo_parser.h rename to tensorflow/compiler/xla/service/hlo_parser.h index 2f97a2b9b19d0cdb64a2869913da62c55e14c1d5..3f3a51215e34bbdd667f1cb20d0ae968e0ce5efd 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -13,30 +13,47 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_lexer.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { -namespace tools { + +// For details about the syntax accepted by this parser, see +// g3doc/hlo_parser.md. // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with the given config. -StatusOr> Parse(tensorflow::StringPiece str, - const HloModuleConfig& config); +StatusOr> ParseHloString( + tensorflow::StringPiece str, const HloModuleConfig& config); // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with default config. -StatusOr> Parse(tensorflow::StringPiece str); +StatusOr> ParseHloString( + tensorflow::StringPiece str); + +// Parses the result of HloSharding::ToString(), e.g. "{replicated}". +StatusOr ParseSharding(tensorflow::StringPiece str); + +// Parses the result of window_util::ToString(const Window&). +StatusOr ParseWindow(tensorflow::StringPiece str); + +// Parses the result of ConvolutionDimensionNumbersToString(), e.g. +// "b0f_0io->b0f". +StatusOr ParseConvolutionDimensionNumbers( + tensorflow::StringPiece str); + +// ParseHloString sharding from str. str is supposed to contain the body of the +// sharding, i.e. just the rhs of the "sharding={...}" attribute string. +StatusOr ParseSharding(tensorflow::StringPiece str); -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc similarity index 83% rename from tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc rename to tensorflow/compiler/xla/service/hlo_parser_test.cc index b8c6b59204f897c7dc07b846370b5b776a19a808..f834d34d57106b11cf398f966d8c0224f00d1b8d 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -13,18 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace xla { -namespace tools { + namespace { -using tensorflow::StringPiece; +using ::tensorflow::StringPiece; struct TestData { string test_name; @@ -64,7 +66,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { R"(HloModule constant_pred_module ENTRY %constant_pred () -> pred[] { - ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68} + ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar" } )" @@ -80,13 +82,14 @@ ENTRY %constant_s32 () -> s32[] { )" }, -// f32 constant, but the value is not a decimal +// f32 constant, but the value is not a decimal and there is a backend +// configuration { "ConstantF32", R"(HloModule ConstantF32_module ENTRY %ConstantF32.v4 () -> f32[] { - ROOT %constant = f32[] constant(42) + ROOT %constant = f32[] constant(42), backend_config="this is a configuration" } )" @@ -231,6 +234,17 @@ ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f3 ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}} } +)" +}, +{ +"DomainParsing", +R"(HloModule DomainParsing_module + +ENTRY %DomainParsing (v1: f32[]) -> f32[] { + %v1 = f32[] parameter(0) + ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} +} + )" }, // int32 result = 0; @@ -716,6 +730,18 @@ ENTRY %sparse_f32_r1 () -> f32[9] { ROOT %foo = f32[9]sparse{10} constant(f32[9]{1: 2, 3: 4, 5: 6}) } +)" +}, +{ +"gather", +R"(HloModule StringifyGather + +ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] { + %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + %gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} +} + )" }, }); @@ -739,7 +765,7 @@ add_F32.v3 { ENTRY MapBinaryAdder.v3 { param0 = f32[4]{0} parameter(0) param1 = f32[4]{0} parameter(1) - ROOT map = f32[4]{0} map(param0, param1), to_apply=add_F32.v3 + ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=add_F32.v3 } )" @@ -862,6 +888,54 @@ ENTRY dot { )" }, +{ +"gather", +R"(HloModule gather + +ENTRY Gather { + input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} +} + +)" +}, +// cross-replica-sum +{ +"CrossReplicaSum", +R"(HloModule CRS + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CRS { + input = f32[8]{0} parameter(0) + ROOT crs = f32[8]{0} cross-replica-sum(input), to_apply=add +} + +)" +}, +// cross-replica-sum with subgroups +{ +"CrossReplicaSumWithSubgroups", +R"(HloModule CRS_Subgroups + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CrossReplicaSumWithSubgroups { + input = f32[128,32]{0,1} parameter(0) + ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), to_apply=add, replica_group_ids={0,0,1,1}, barrier="abc" +} + +)" +} }); // clang-format on } @@ -870,16 +944,16 @@ class HloParserTest : public ::testing::Test, public ::testing::WithParamInterface { protected: static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(StringPiece(s).contains(expected)) + EXPECT_TRUE(tensorflow::str_util::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } - // Expects "ToString(Parse(string)) == string", that is, parses the string, - // asserts that it succeeded, stringifies the parsed module, and checks that - // the it equals the original string. + // Expects "ToString(ParseHloString(string)) == string", that is, parses the + // string, asserts that it succeeded, stringifies the parsed module, and + // checks that the it equals the original string. void ExpectEqual() { const string& original = GetParam().module_string; - auto result = Parse(original); + auto result = ParseHloString(original); TF_ASSERT_OK(result.status()); EXPECT_EQ(original, result.ValueOrDie()->ToString( HloPrintOptions().set_print_large_constants(true))); @@ -890,7 +964,7 @@ class HloParserShortTest : public HloParserTest { protected: void ExpectEqualShort() { const string& original = GetParam().module_string; - auto result = Parse(original); + auto result = ParseHloString(original); TF_ASSERT_OK(result.status()); EXPECT_EQ(original, result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable())); @@ -911,14 +985,14 @@ INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, TEST_F(HloParserTest, Empty) { const string original = ""; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, Garbage) { const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongOpcode) { @@ -931,8 +1005,8 @@ ENTRY %blabla (x: f32[], y: f32[]) -> f32[] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongShape) { @@ -943,8 +1017,8 @@ ENTRY %blabla (x: g32[]) -> g32[] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongOperandsSize) { @@ -956,8 +1030,8 @@ ENTRY %blabla (x: f32[]) -> pred[] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, OperandNotFound) { @@ -967,8 +1041,8 @@ ENTRY %blabla (x: f32[]) -> pred[] { %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y) } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, MoreConstants) { @@ -982,12 +1056,25 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); TF_EXPECT_OK(result.status()); // Constant instructions have no name. The string will be parsed successfully // but the constant names will not be exactly the same. } +TEST_F(HloParserTest, ConfigurationField) { + const string original = R"(HloModule AModule +ENTRY %configuration_test() -> s32[] { + %constant = s32[] constant(42), backend_config="foo bar" +})"; + auto result = ParseHloString(original); + TF_ASSERT_OK(result.status()); + EXPECT_EQ("foo bar", result.ValueOrDie() + ->entry_computation() + ->root_instruction() + ->raw_backend_config_string()); +} + TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { const string original = R"(HloModule some_2_module @@ -996,8 +1083,8 @@ ENTRY %some_2 () -> f32[2] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 1, but sees larger"); } @@ -1010,8 +1097,8 @@ ENTRY %some_2x3 () -> f32[2,3] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 2, but sees 1"); } @@ -1024,8 +1111,8 @@ ENTRY %some_2x3x2 () -> f32[2,3,2] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects 3 elements in the [0]th element"); } @@ -1039,8 +1126,8 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "is out of range for literal's primitive type F16"); } @@ -1053,7 +1140,7 @@ ENTRY %ConstantWithExp.v4 () -> f32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); TF_EXPECT_OK(result.status()); // The string will be parsed successfully but the output strings are not // exactly the same, because "3e2" is parsed into value 300 and will be @@ -1067,11 +1154,11 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, InvalidDimLabels) { @@ -1087,17 +1174,18 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 )"; + ExpectHasSubstr(ParseHloString(tensorflow::strings::StrCat( + prefix, ",dim_labels=00_01_10", suffix)) + .status() + .error_message(), + "expects dim labels pattern"); + ExpectHasSubstr( - Parse(tensorflow::strings::StrCat(prefix, ",dim_labels=00_01_10", suffix)) + ParseHloString(tensorflow::strings::StrCat( + prefix, ",dim_labels=010_1100->010", suffix)) .status() .error_message(), - "expects dim labels pattern"); - - ExpectHasSubstr(Parse(tensorflow::strings::StrCat( - prefix, ",dim_labels=010_1100->010", suffix)) - .status() - .error_message(), - "must have the same rank"); + "must have the same rank"); } TEST_F(HloParserTest, UnexpectedAttribute) { @@ -1112,8 +1200,8 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } )"; - ExpectHasSubstr(Parse(original).status().error_message(), - "unexpected attribute calls"); + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "unexpected attribute \"calls\""); } TEST_F(HloParserTest, MissingAttribute) { @@ -1128,7 +1216,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "attribute channel_id is expected but not seen"); } @@ -1144,7 +1232,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "'done' is not defined"); } @@ -1157,7 +1245,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) { @@ -1171,7 +1259,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "expects padding_low and padding_high separated by '_'"); } @@ -1183,7 +1271,7 @@ ENTRY %test_comma.v4 () -> f32[] { } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) { @@ -1193,7 +1281,7 @@ ENTRY %CustomCall () -> f32[1] { %constant = f32[1]{0} constant({12345}) ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar" })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "Shape of computation CustomCall, f32[1], is not compatible " "with that of its root instruction foo, f32[1,2,3]"); } @@ -1212,9 +1300,9 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); - auto program_layout = module.ValueOrDie()->entry_computation_layout(); + auto program_layout = module.ValueOrDie()->host_entry_computation_layout(); ASSERT_EQ(program_layout.parameter_count(), 1); auto param_layout = program_layout.parameter_layout(0).layout(); auto result_layout = program_layout.result_layout().layout(); @@ -1235,7 +1323,7 @@ c1 { c2 { const2 = f32[1]{0} constant({67890}) })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2"); } @@ -1246,7 +1334,7 @@ ENTRY consts { first = f32[1]{0} constant({12345}) last = f32[1]{0} constant({67890}) })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); EXPECT_EQ( module.ValueOrDie()->entry_computation()->root_instruction()->name(), @@ -1261,7 +1349,7 @@ ENTRY c1 { ENTRY c2 { const2 = f32[1]{0} constant({67890}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "expects only one ENTRY"); } @@ -1271,25 +1359,10 @@ ENTRY consts { ROOT const1 = f32[1]{0} constant({12345}) ROOT const2 = f32[1]{0} constant({12345}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "one computation should have only one ROOT"); } -TEST_F(HloParserTest, InstructionExists) { - const string original = R"(HloModule comp_exists -c1 { - instr = f32[1]{0} constant({12345}) -} -c2 { - instr = f32[1]{0} constant({67890}) -})"; - - ExpectHasSubstr(Parse(original).status().error_message(), - R"(was parsing 3:3: error: instruction previously defined here - instr = f32[1]{0} constant({12345}) - ^)"); -} - TEST_F(HloParserTest, ComputationExists) { const string original = R"(HloModule comp_exists comp { @@ -1298,12 +1371,52 @@ comp { comp { const2 = f32[1]{0} constant({67890}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), R"(was parsing 2:1: error: computation previously defined here comp { ^)"); } +TEST_F(HloParserTest, CrossComputationLookup) { + const string original = R"(HloModule cross_computation_lookup: +tcalla (a: (s32[], s32[])) -> (s32[], s32[]) { + ROOT aparam = (s32[], s32[]) parameter(0) +} + +tcallb (b: (s32[], s32[])) -> s32[] { + rparam = (s32[], s32[]) parameter(0) + ROOT gte0 = s32[] get-tuple-element(aparam), index=0 +} + +ENTRY entry { + param = (s32[], s32[]) parameter(0) + call0 = (s32[], s32[]) call(param), to_apply=tcalla + ROOT call1 = s32[] call(param), to_apply=tcallb +})"; + ExpectHasSubstr( + ParseHloString(original).status().error_message(), + "was parsing 8:39: error: instruction does not exist: aparam"); +} + +TEST_F(HloParserTest, ParseSharding) { + const string original = "{maximal device=42}"; + TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original)); + EXPECT_EQ(sharding.ToString(), original); +} + +TEST_F(HloParserTest, ParseWindow) { + Window original = window_util::MakeWindow({1, 2, 3}); + TF_ASSERT_OK_AND_ASSIGN(Window parsed, + ParseWindow(window_util::ToString(original))) + EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed)); +} + +TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) { + const string original = "b0f_0io->b0f"; + TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums, + ParseConvolutionDimensionNumbers(original)); + EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums)); +} + } // namespace -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 5120775737bfa32bbb656421216f2b3fbef590ea..d8f1ab916b5c5c500c2d8dcd8605be083f95862a 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -90,7 +90,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { return Status::OK(); }; - string prefix = name().ToString() + ": pipeline start"; + string prefix = std::string(name()) + ": pipeline start"; bool changed = false; string message; TF_RETURN_IF_ERROR( @@ -98,12 +98,12 @@ StatusOr HloPassPipeline::Run(HloModule* module) { const string xla_dump_per_pass_hlo_proto_to = module->config().debug_options().xla_dump_per_pass_hlo_proto_to(); if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, name().ToString(), - "pipeline_start"); + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, + std::string(name()), "pipeline_start"); } for (auto& pass : passes_) { - if (disabled_passes.count(pass->name().ToString()) > 0) { + if (disabled_passes.count(std::string(pass->name())) > 0) { VLOG(1) << " Skipping HLO pass " << pass->name() << ", disabled by --xla_disable_hlo_passes"; continue; @@ -121,7 +121,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { run_invariant_checkers(StrCat("after running pass: ", pass->name()))); if (!xla_dump_per_pass_hlo_proto_to.empty()) { DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, - name().ToString(), pass->name().ToString()); + std::string(name()), std::string(pass->name())); } changed |= changed_this_pass; diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 78e6a101c10a1e812e3e2631d520139fd0bc425c..3460679558d185d1e022660d9a1d23176d0d96bf 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_proto_util.h" +#include + +#include "tensorflow/compiler/xla/util.h" + namespace xla { HloProto MakeHloProto(const HloModule& module, @@ -35,4 +39,35 @@ HloProto MakeHloProto(const HloModule& module) { return proto; } +StatusOr> EntryComputationParameterShapes( + const HloProto& hlo_proto) { + if (!hlo_proto.has_hlo_module()) { + return NotFound("HloProto missing HloModuleProto."); + } + if (!hlo_proto.hlo_module().has_program_shape()) { + return NotFound("HloProto missing program shape."); + } + + std::vector parameter_shapes; + const auto& program_shape = hlo_proto.hlo_module().program_shape(); + for (const Shape& shape : program_shape.parameters()) { + parameter_shapes.push_back(&shape); + } + return parameter_shapes; +} + +StatusOr EntryComputationOutputShape(const HloProto& hlo_proto) { + if (!hlo_proto.has_hlo_module()) { + return NotFound("HloProto missing HloModuleProto."); + } + if (!hlo_proto.hlo_module().has_program_shape()) { + return NotFound("HloProto missing program shape."); + } + if (!hlo_proto.hlo_module().program_shape().has_result()) { + return NotFound("HloProto missing result in its program shape"); + } + + return &hlo_proto.hlo_module().program_shape().result(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h index 320288fdb9aa0810b306b1d78bd1ff4cfc366ed2..3d9c375cd5d26f92cf8316f78789daf4fc08c927 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.h +++ b/tensorflow/compiler/xla/service/hlo_proto_util.h @@ -35,6 +35,15 @@ HloProto MakeHloProto(const HloModule& module, // will not be included in the output. HloProto MakeHloProto(const HloModule& module); +// Returns the shapes of the parameters of the entry computation. Shape pointers +// refer to shapes inside of the given HloProto. +StatusOr> EntryComputationParameterShapes( + const HloProto& hlo_proto); + +// Returns the shape of the output of the entry computation. The shape pointer +// refers to the output shape inside of the given HloProto. +StatusOr EntryComputationOutputShape(const HloProto& hlo_proto); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROTO_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b9cca138703c8fa61aadf69dd7304a215a9f4be2 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_proto_util.h" + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace { + +class HloProtoUtilTest : public ::testing::Test {}; + +TEST_F(HloProtoUtilTest, ParamsAndOutputShapeMissingModule) { + HloProto hlo_proto; + + auto status = EntryComputationParameterShapes(hlo_proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + ::testing::HasSubstr("missing HloModuleProto")); +} + +TEST_F(HloProtoUtilTest, MissingProgramShape) { + HloProto hlo_proto; + HloModuleProto* module = hlo_proto.mutable_hlo_module(); + module->set_name("entry"); + + auto status = EntryComputationParameterShapes(hlo_proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + ::testing::HasSubstr("missing program shape")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 8e167633bb13476301fa0c4afa0b123c9b47e40d..4738e46f8aeb96a4c25d04b3246bd21f644fe3ea 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -33,17 +33,27 @@ bool HloReachabilityMap::SetReachabilityToUnion( const HloInstruction* instruction) { BitVector& bit_vector = GetBitVector(instruction); tmp_bit_vector_ = bit_vector; + SetReachabilityToUnionHelper(inputs, instruction, &bit_vector); + return bit_vector != tmp_bit_vector_; +} +void HloReachabilityMap::FastSetReachabilityToUnion( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction) { + SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction)); +} + +void HloReachabilityMap::SetReachabilityToUnionHelper( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction, BitVector* bit_vector) { // If instruction is part of inputs, don't reset the bit_vector. if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) { - bit_vector.SetToZero(); + bit_vector->SetToZero(); } - bit_vector.Set(GetIndex(instruction)); + bit_vector->Set(GetIndex(instruction)); for (const HloInstruction* input : inputs) { - bit_vector.OrWith(GetBitVector(input)); + bit_vector->OrWith(GetBitVector(input)); } - - return bit_vector != tmp_bit_vector_; } void HloReachabilityMap::SetReachable(const HloInstruction* a, diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 553ec11f6f9a2997ab7113f9b8241e04c7fe20d5..69bb2b3cee6dafe058c45b4e74e93401bea2cfc9 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -57,6 +57,11 @@ class HloReachabilityMap { tensorflow::gtl::ArraySlice inputs, const HloInstruction* instruction); + // As above, but faster because it does not check if the reachability changed. + void FastSetReachabilityToUnion( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction); + // Sets entry so that IsReachable(a, b) will return true // // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency @@ -133,6 +138,11 @@ class HloReachabilityMap { return bit_vectors_[GetIndex(instruction)]; } + // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. + void SetReachabilityToUnionHelper( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction, BitVector* bit_vector); + // Return the index of the given instruction. The value is used to index into // the vector of BitVectors and the BitVectors themselves. int GetIndex(const HloInstruction* instruction) const { diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 98b8d34be1f331aaeac94e952deeae1e76379861..9c7bc7a5ea7c77dadb8772f08b823c3579cf2154 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -71,6 +71,20 @@ bool IsRematerializable(const HloInstruction* instruction) { } } +// Checks whether an instruction can be rematerialized, by looking up the +// cache before, and eventually calling the IsRematerializable() API. +bool CanBeRematerialized( + const HloInstruction* instruction, + tensorflow::gtl::FlatMap* remat_able) { + auto it = remat_able->find(instruction); + if (it != remat_able->end()) { + return it->second; + } + bool rematerializable = IsRematerializable(instruction); + (*remat_able)[instruction] = rematerializable; + return rematerializable; +} + // Type holding a unique identifier for each Buffer object. using BufferId = int64; using BufferIdList = tensorflow::gtl::InlinedVector; @@ -273,9 +287,8 @@ ItemList GetUsers(const InstructionList& instruction_list, for (const BufferAlias& buffer_alias : points_to_analysis.GetBufferAliases(*logical_buffer)) { for (const HloInstruction* user : buffer_alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(buffer_alias.instruction(), - buffer_alias.index(), user, - points_to_analysis)) { + if (points_to_analysis.DoesNotUseOperandBuffer( + buffer_alias.instruction(), buffer_alias.index(), user)) { // The alias may be an operand of 'user', but the LogicalBuffer cannot // possibly be used by the instruction so ignore 'user'. This is the // case, for example, for the tuple element buffers in a GetTupleElement @@ -844,9 +857,10 @@ int64 RematerializationCost(const HloInstruction* instruction, // candidate which reduce memory use at the program point of the current // instruction as indicated by memory_tracker. nullptr is returned if no // candidate can be found. -Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker, - const InstructionList& instruction_list, - int64 memory_limit_bytes) { +Item* PickRematerializationCandidate( + const MemoryUsageTracker& memory_tracker, + const InstructionList& instruction_list, int64 memory_limit_bytes, + tensorflow::gtl::FlatMap* remat_able) { Item* best_item = nullptr; int64 best_cost = 0; @@ -870,8 +884,7 @@ Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker, << " is excluded from rematerialization"; continue; } - - if (!IsRematerializable(candidate)) { + if (!CanBeRematerialized(candidate, remat_able)) { VLOG(5) << "candidate " << candidate->name() << " not viable: is not rematerializable"; continue; @@ -975,6 +988,9 @@ StatusOr HloRematerialization::RematerializeComputation( // blacklist. tensorflow::gtl::FlatSet remat_move_instructions; + // The map from instructions to their rematerializable status. + tensorflow::gtl::FlatMap remat_able; + // The peak memory of the computation at any point in the instruction // sequence. int64 peak_memory = memory_tracker.memory_usage(); @@ -1012,7 +1028,7 @@ StatusOr HloRematerialization::RematerializeComputation( << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); Item* best_item = PickRematerializationCandidate( - memory_tracker, instruction_list, memory_limit_bytes); + memory_tracker, instruction_list, memory_limit_bytes, &remat_able); if (best_item == nullptr) { VLOG(3) << "Unable to find rematerialization candidate at program " @@ -1214,9 +1230,9 @@ StatusOr HloRematerialization::Run( XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence( + TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( *module, - [this](const LogicalBuffer& buffer) { + [this](const BufferValue& buffer) { return size_function_(buffer.shape()); }, scheduler_algorithm_)); @@ -1320,7 +1336,7 @@ StatusOr HloRematerialization::Run( /* static */ StatusOr HloRematerialization::RematerializeAndSchedule( const HloRematerialization::ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, - SchedulerAlgorithm scheduler_algorithm, + MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes) { HloRematerialization remat(scheduler_algorithm, size_function); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 52553439033a3bcfa4b472f13f9cd4b1ecf5ed96..2ee2dd0571ae8c6604e4ca722351fd48a913bda5 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -66,12 +66,12 @@ class HloRematerialization { // code generation. static StatusOr RematerializeAndSchedule( const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, SchedulerAlgorithm scheduler_algorithm, + HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes = nullptr); protected: - HloRematerialization(SchedulerAlgorithm scheduler_algorithm, + HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, const ShapeSizeFunction& size_function) : scheduler_algorithm_(scheduler_algorithm), size_function_(size_function) {} @@ -108,7 +108,7 @@ class HloRematerialization { const HloInstruction* instruction) const; // Selects an algorithm to use for HLO scheduling. - SchedulerAlgorithm scheduler_algorithm_; + MemorySchedulerAlgorithm scheduler_algorithm_; // Function which computes the size of the top-level buffer of a shape. const ShapeSizeFunction size_function_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 1b7d26dde501a6a0955d62ea0938e0683a32d49d..e81334d5a84268a129cd4e90091e97dc23243226 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -40,7 +41,8 @@ class HloRematerializationTest : public HloTestBase { // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: // - // F32[] %param = {...} + // F32[1] %param = {...} + // F32[] %reshape = reshape(F32[], param) // F32[1024] %bcast = broadcast(%param) // F32[1024] %negate = negate(%bcast) // F32[2048] %concat_1 = concat({%negate, %negate}) @@ -57,9 +59,11 @@ class HloRematerializationTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); + HloInstruction::CreateParameter(0, vec1_shape_, "param")); + auto reshape = builder.AddInstruction( + HloInstruction::CreateReshape(scalar_shape_, param)); auto bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {})); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast)); auto concat_1 = builder.AddInstruction(HloInstruction::CreateConcatenate( @@ -100,9 +104,11 @@ class HloRematerializationTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); + HloInstruction::CreateParameter(0, vec1_shape_, "param")); + auto reshape = builder.AddInstruction( + HloInstruction::CreateReshape(scalar_shape_, param)); auto bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {})); auto slice_1 = builder.AddInstruction( HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0}, /*limit_indices=*/{1}, @@ -135,6 +141,15 @@ class HloRematerializationTest : public HloTestBase { return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } + StatusOr RunHloRematerialization( + int64 memory_limit_bytes, HloModule* module, + SequentialHloOrdering::HloModuleSequence* sequence) { + TF_EXPECT_OK(verifier().Run(module).status()); + return HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, + sequence); + } + // Various shapes used in the canned computations. const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {}); const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1}); @@ -158,11 +173,9 @@ TEST_F(HloRematerializationTest, SingleComputation) { SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/14 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/14 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // Root should not have changed. @@ -188,18 +201,16 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { HloComputation* computation = module->AddEntryComputation(MakeRematerializableComputation()); - EXPECT_EQ(computation->instruction_count(), 7); + EXPECT_EQ(computation->instruction_count(), 8); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/20 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/20 * 1024, + module.get(), &sequence)); // No instructions should have been materialized. EXPECT_FALSE(changed); - EXPECT_EQ(computation->instruction_count(), 7); + EXPECT_EQ(computation->instruction_count(), 8); } // Test rematerialization of a computation which calls another computation via a @@ -225,23 +236,21 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/body_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(body_computation->instruction_count(), 8); // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/17 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/17 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. - EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 8); + EXPECT_EQ(body_computation->instruction_count(), 8); } // Test rematerialization of a computation which calls another computation via a @@ -264,20 +273,18 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/body_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(body_computation->instruction_count(), 8); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/15 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/15 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); - // Both computations should have a rematerialized instruction added. - EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(body_computation->instruction_count(), 8); + // Both computations should have rematerialized instructions added. + EXPECT_EQ(entry_computation->instruction_count(), 9); + EXPECT_EQ(body_computation->instruction_count(), 9); } // Test rematerialization of a doubly nested computation. All computations @@ -303,24 +310,22 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/middle_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(middle_computation->instruction_count(), 6); - EXPECT_EQ(inner_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(middle_computation->instruction_count(), 7); + EXPECT_EQ(inner_computation->instruction_count(), 8); // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/13 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/13 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); - // All computations should have a rematerialized instruction added. - EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(middle_computation->instruction_count(), 7); - EXPECT_EQ(inner_computation->instruction_count(), 8); + // All computations should have rematerialized instructions added. + EXPECT_EQ(entry_computation->instruction_count(), 9); + EXPECT_EQ(middle_computation->instruction_count(), 9); + EXPECT_EQ(inner_computation->instruction_count(), 9); } TEST_F(HloRematerializationTest, RngNotRematerialized) { @@ -382,10 +387,9 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN( - bool changed, HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, + bool changed, RunHloRematerialization( /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), SchedulerAlgorithm::kAuto, &sequence)); + module.get(), &sequence)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -476,11 +480,9 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -573,11 +575,9 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, + module.get(), &sequence)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 41b079eb799d06321a31f7d7ae0630dc8d58c46b..e1f9d8efd4974055947438c8a2e15cb77d1b5c75 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -16,27 +16,19 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_runner.h" -#include #include #include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/backend.h" -#include "tensorflow/compiler/xla/service/executable.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { /*static*/ StatusOr> @@ -44,7 +36,7 @@ HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, const DebugOptions& debug_options) { HloModuleConfig config; config.set_debug_options(debug_options); - return tools::Parse(hlo_string, config); + return ParseHloString(hlo_string, config); } namespace { @@ -52,10 +44,9 @@ namespace { // Creates an HloModule from the given proto. StatusOr> HloProtoToModule( const HloProto& proto, const DebugOptions& debug_options) { - TF_ASSIGN_OR_RETURN( - HloModuleConfig config, - HloModule::CreateModuleConfigFromProto(proto.hlo_module())); - config.set_debug_options(debug_options); + TF_ASSIGN_OR_RETURN(HloModuleConfig config, + HloModule::CreateModuleConfigFromProto(proto.hlo_module(), + debug_options)); TF_ASSIGN_OR_RETURN(auto module, HloModule::CreateFromProto(proto.hlo_module(), config)); return std::move(module); @@ -89,18 +80,9 @@ HloRunner::ReadModuleFromHloTextFile(const std::string& filename, filename, &hlo_string)); HloModuleConfig config; config.set_debug_options(debug_options); - return tools::Parse(hlo_string, config); + return ParseHloString(hlo_string, config); } -// Define this in .cc file to avoid having to include eigen or forward declare -// these types in the header. -struct HloRunner::EigenThreadPoolWrapper { - std::unique_ptr pool; - std::unique_ptr device; -}; - -HloRunner::HloRunner() {} - HloRunner::HloRunner(se::Platform* platform) { BackendOptions backend_options; backend_options.set_platform(platform); @@ -110,81 +92,266 @@ HloRunner::HloRunner(se::Platform* platform) { HloRunner::~HloRunner() {} -StatusOr> HloRunner::ExecuteInternal( +StatusOr HloRunner::TransferLiteralToDevice( + const Literal& literal) { + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + literal.shape(), backend().memory_allocator(), + backend().default_device_ordinal())); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + backend().default_stream_executor(), literal, buffer)); + return std::move(buffer); +} + +StatusOr> HloRunner::TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice literals) { + std::vector buffers; + for (const Literal* literal : literals) { + CHECK(literal != nullptr); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer, + TransferLiteralToDevice(*literal)); + buffers.push_back(std::move(buffer)); + } + return std::move(buffers); +} + +StatusOr> HloRunner::TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice> literals) { + std::vector literal_pointers; + literal_pointers.reserve(literals.size()); + for (const auto& literal : literals) { + literal_pointers.push_back(literal.get()); + } + return TransferLiteralsToDevice(literal_pointers); +} + +StatusOr> HloRunner::TransferLiteralFromDevice( + const ShapedBuffer& buffer) { + return backend().transfer_manager()->TransferLiteralFromDevice( + backend().default_stream_executor(), buffer); +} + +StatusOr> HloRunner::Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes) { - if (run_hlo_passes) { - TF_ASSIGN_OR_RETURN( - module, backend().compiler()->RunHloPasses( - std::move(module), backend().default_stream_executor(), - /*device_allocator=*/nullptr)); + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + TF_ASSIGN_OR_RETURN(std::vector argument_buffers, + TransferLiteralsToDevice(arguments)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + ExecuteWithDeviceBuffers( + /*module=*/std::move(module), + /*arguments=*/argument_buffers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile)); + return TransferLiteralFromDevice(result); +} + +StatusOr> HloRunner::Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice> arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + // Construct a vector of plain pointers for the arguments. + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(argument.get()); } - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - backend().compiler()->RunBackend(std::move(module), - backend().default_stream_executor(), - /*device_allocator=*/nullptr)); + return Execute( + /*module=*/std::move(module), + /*arguments=*/argument_pointers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); +} +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + // Get service run options. se::Stream stream(backend().default_stream_executor()); stream.Init(); + ServiceExecutableRunOptions service_run_options = + GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream, + nullptr); - ExecutableRunOptions run_options; - run_options.set_device_ordinal(backend().default_device_ordinal()); - run_options.set_stream(&stream); - run_options.set_allocator(backend().memory_allocator()); - run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool()); - run_options.set_intra_op_thread_pool( - backend().eigen_intra_op_thread_pool_device()); - - ServiceExecutableRunOptions service_run_options( - run_options, backend().StreamBorrower(), - backend().inter_op_thread_pool()); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + CreateExecutable(std::move(module), run_hlo_passes)); + return executable->ExecuteOnStreamWrapper(&service_run_options, + /*profile=*/profile, arguments); +} - // Copy arguments to device. - std::vector> argument_buffers; - std::vector argument_buffer_ptrs; - for (Literal* argument : arguments) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr argument_buffer, - backend().transfer_manager()->AllocateScopedShapedBuffer( - argument->shape(), run_options.allocator(), - run_options.device_ordinal())); - TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - stream.parent(), *argument, *argument_buffer)); - argument_buffers.push_back(std::move(argument_buffer)); - argument_buffer_ptrs.push_back(argument_buffers.back().get()); +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(&argument); } + return ExecuteWithDeviceBuffers( + /*module=*/std::move(module), + /*arguments=*/argument_pointers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); +} +StatusOr>> HloRunner::ExecuteReplicated( + std::unique_ptr module, + const ReplicatedExecuteOptions& options) { TF_ASSIGN_OR_RETURN( - std::unique_ptr result, - executable->ExecuteOnStream(&service_run_options, argument_buffer_ptrs, - /*hlo_execution_profile=*/nullptr)); - - // Create a ScopedShapedBuffer of the result to manage deallocation. This will - // deallocate all the device memory when it goes out of scope. + std::unique_ptr executable, + CreateExecutable(std::move(module), options.run_hlo_passes)); TF_ASSIGN_OR_RETURN( - std::unique_ptr scoped_result, - ScopedShapedBuffer::MakeScoped(result.get(), run_options.allocator())); - - auto result_literal = backend().transfer_manager()->TransferLiteralFromDevice( - stream.parent(), *scoped_result); - if (result_literal.ok()) { - VLOG(4) << "Executed binary and got result: " - << result_literal.ValueOrDie()->ToString(); - } else { - VLOG(4) << "Executed binary and got status: " - << result_literal.status().ToString(); + DeviceAssignment device_assignment, + backend().computation_placer()->AssignDevices(options.num_replicas, 1)); + std::vector> streams; + std::vector service_run_options; + + std::vector argument_buffers; + // This reserve() call is necessary for correctness, because + // argument_buffer_ptrs contains pointers into the elements of + // argument_buffers. + argument_buffers.reserve(options.num_replicas * options.arguments.size()); + + // Plus one so we can safely get &argument_buffer_ptrs[0] in case there are + // no arguments. + std::vector argument_buffer_ptrs( + options.num_replicas * options.arguments.size() + 1); + std::vector> + argument_buffer_slices; + int64 index = 0; + for (int64 i = 0; i < options.num_replicas; ++i) { + int64 device = device_assignment(i, 0); + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, + backend().stream_executor(device)); + streams.push_back(MakeUnique(executor)); + streams.back()->Init(); + service_run_options.emplace_back(GetServiceRunOptionsForDevice( + device, streams.back().get(), &device_assignment)); + + // Copy arguments to device. + for (const Literal* argument : options.arguments) { + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer argument_buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + argument->shape(), backend().memory_allocator(), device)); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + executor, *argument, argument_buffer)); + argument_buffers.push_back(std::move(argument_buffer)); + argument_buffer_ptrs[index++] = &argument_buffers.back(); + } + argument_buffer_slices.emplace_back( + &argument_buffer_ptrs[index - options.arguments.size()], + options.arguments.size()); } - return result_literal; + + std::unique_ptr pool; + int64 num_threads = (options.infeed != nullptr) ? options.num_replicas : 0; + if (ShapeUtil::IsInitialized(options.outfeed_shape)) { + num_threads += options.num_replicas; + } + if (num_threads > 0) { + pool = MakeUnique( + tensorflow::Env::Default(), "infeed_outfeed", + /*num_threads=*/num_threads); + } + if (options.infeed != nullptr) { + for (int64 i = 0; i < options.num_replicas; ++i) { + int64 device = device_assignment(i, 0); + pool->Schedule([this, device, &options]() { + se::StreamExecutor* executor = + backend().stream_executor(device).ValueOrDie(); + VLOG(1) << "Starting infeed on device " << device; + for (int64 step = 1; + options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { + TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToInfeed( + executor, *options.infeed)); + if (step % 100 == 0) { + VLOG(1) << "Infeed step " << step; + } + } + }); + } + } + if (ShapeUtil::IsInitialized(options.outfeed_shape)) { + for (int64 i = 0; i < options.num_replicas; ++i) { + int64 device = device_assignment(i, 0); + pool->Schedule([this, device, &options]() { + se::StreamExecutor* executor = + backend().stream_executor(device).ValueOrDie(); + VLOG(1) << "Starting outfeed on device " << device; + for (int64 step = 1; + options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { + auto literal = MakeUnique(); + TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( + executor, options.outfeed_shape, literal.get())); + if (options.outfeed_values != nullptr) { + options.outfeed_values->push_back(std::move(literal)); + } + if (step % 100 == 0) { + VLOG(1) << "Outfeed step " << step; + } + } + }); + } + } + + LOG(INFO) << "Replicated execution started"; + TF_ASSIGN_OR_RETURN(std::vector results, + executable->ExecuteOnStreams(service_run_options, + argument_buffer_slices)); + LOG(INFO) << "Replicated execution terminated"; + + std::vector> exec_results; + for (int64 i = 0; i < options.num_replicas; ++i) { + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + backend().transfer_manager()->TransferLiteralFromDevice( + streams[i]->parent(), results[i])); + exec_results.push_back(std::move(literal)); + } + return std::move(exec_results); +} + +StatusOr> HloRunner::CreateExecutable( + std::unique_ptr module, bool run_hlo_passes) { + if (run_hlo_passes) { + TF_ASSIGN_OR_RETURN( + module, backend().compiler()->RunHloPasses( + std::move(module), backend().default_stream_executor(), + backend().memory_allocator())); + } + return backend().compiler()->RunBackend(std::move(module), + backend().default_stream_executor(), + backend().memory_allocator()); +} + +ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice( + int64 device, se::Stream* stream, DeviceAssignment* device_assignment) { + ExecutableRunOptions run_options; + run_options.set_device_ordinal(device); + run_options.set_stream(stream); + run_options.set_allocator(backend().memory_allocator()); + run_options.set_intra_op_thread_pool( + backend().eigen_intra_op_thread_pool_device()); + if (device_assignment != nullptr) { + run_options.set_device_assignment(device_assignment); + } + return ServiceExecutableRunOptions( + run_options, backend().StreamBorrower(), + /*xla_intra_op_thread_pool=*/backend().eigen_intra_op_thread_pool()); } Backend& HloRunner::backend() { if (!backend_) { backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie(); - VLOG(1) << "executing on platform " << backend().platform()->Name(); + VLOG(1) << "Executing on platform " << backend().platform()->Name(); } return *backend_; } +const Backend& HloRunner::backend() const { + return const_cast(this)->backend(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index cbaebc68bee708090b8ccb2eae19b556c4d6d453..65537f07f56e74b7fe2c2f9792af21efc7229573 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -16,17 +16,22 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ +#include #include +#include #include #include #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -39,9 +44,43 @@ namespace xla { // file), or parsed from a hlo textual IR string. class HloRunner { public: - HloRunner(); - - HloRunner(::perftools::gputools::Platform* platform); + // The options used to configure a ExecuteReplicated() call. + struct ReplicatedExecuteOptions { + // The number of devices the HLO module should be replicated onto. + int64 num_replicas = 1; + + // The arguments to be fed to each replica. Since this is used for a + // replicated execution, all the arguments are the same for all replicas. + std::vector arguments; + + // If the HLO module being run has an infeed instruction, this will be the + // data which will be fed to it, for as many as infeed_steps steps. + const Literal* infeed = nullptr; + + // The number of times the infeed literal should be fed to the HLO module. + // For a clean exit, this should match the iterations-per-loop parameter + // used when generating the HLO module proto (that is usually the main + // while bounary counter). A value higher then iterations-per-loop would + // lead to infeed threads feeding to a gone computation, while a lower + // value would trigger a stuck ExecuteReplicated() call (the computation + // will be trying to infeed data which will never come). + int64 infeed_steps = -1; + + // The shape of the outfeed operation. If empty, the HLO module does not + // generate any outfeed. + Shape outfeed_shape; + + // A pointer to a vector where the outfeed values will be stored. If + // nullptr, the values will be read and discarded. + std::vector>* outfeed_values = nullptr; + + // Whether the HLO passes should be run on the input module. Usually + // saved modules are coming from after the HLO pass pipeline, so triggering + // another run will likely cause errors. + bool run_hlo_passes = false; + }; + + explicit HloRunner(se::Platform* platform); ~HloRunner(); @@ -63,17 +102,48 @@ class HloRunner { static StatusOr> ReadModuleFromHloTextFile( const std::string& filename, const DebugOptions& debug_options); + // Transfers data between the host and device. + StatusOr TransferLiteralToDevice(const Literal& literal); + StatusOr> TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice literals); + StatusOr> TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice> literals); + StatusOr> TransferLiteralFromDevice( + const ShapedBuffer& buffer); + // Executes the given module with given literals as input and returns the - // result as a Literal. The LiteralPtr type accepts Literal* or - // std::unique_ptr. + // result as a Literal. // // If run_hlo_passes is false, the module will be executed without Hlo // optimization. - template StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes = true); + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + + StatusOr> Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice> arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + + // As Execute(), but accepts and returns device buffers instead of host + // buffers. + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + + // Executes a given HLO module into a set of replicas, and returns a map + // with the replica number as key, and the corresponding returned literal as + // value. + StatusOr>> ExecuteReplicated( + std::unique_ptr module, + const ReplicatedExecuteOptions& options); // If backend is not created in the constructor, creates and returns the // default backend. If creation fails, crashes the program. @@ -81,33 +151,24 @@ class HloRunner { // This creates the backend lazily so it's possible to instantiate an // HloRunner in a program without any backends linked in. Backend& backend(); + const Backend& backend() const; private: - StatusOr> ExecuteInternal( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes = true); - - struct EigenThreadPoolWrapper; - - std::unique_ptr thread_pool_wrapper_; + // Creates an executable object given an HLO module. If run_hlo_passes is + // true, the HLO passes will be run before. + StatusOr> CreateExecutable( + std::unique_ptr module, bool run_hlo_passes); + + // Creates a ServiceExecutableRunOptions object to configure a run on device, + // using the provided stream object. If device_assignment is not nullptr, it + // will be used to configure the replication parameters. Replicated executions + // should pass the device_assignment parameter. + ServiceExecutableRunOptions GetServiceRunOptionsForDevice( + int64 device, se::Stream* stream, DeviceAssignment* device_assignment); std::unique_ptr backend_; }; -template -StatusOr> HloRunner::Execute( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes) { - // Construct a vector of plain pointers for the arguments. - std::vector argument_pointers; - for (const auto& argument : arguments) { - argument_pointers.push_back(&*argument); - } - return ExecuteInternal(std::move(module), argument_pointers, run_hlo_passes); -} - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index f6e33403f538bd8492b04c34d46a458f7f06cc06..b14ade3549d093acdb5cdc7ae99dd025a42d5621 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -36,33 +36,37 @@ using ::tensorflow::strings::HumanReadableNumBytes; namespace xla { -StatusOr MinimumMemoryForSequence( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function) { - if (module_sequence.empty()) { - return 0; - } - - const HloModule* module = module_sequence.begin()->first->parent(); - TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // The absolute minimum memory required for a given sequence of instructions - // is determined by the sequence of Alloc and Free calls on a simulated heap, - // ignoring fragmentation. We run the heap simulation on the whole module, - // rather than summing each computation, since it gives us a better lower - // bound, by minimizing the liveness of sub-computations. - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), *module, - module_sequence, *points_to_analysis, size_function)); - return result.heap_size; -} - namespace { // Class implementing a list scheduler of HLO instructions which produces a -// sequence which minimizes memory usage. +// sequence which minimizes memory usage by preferring to schedule the node that +// frees bigger buffer and defines smaller outputs. +// +// Note that list scheduler is a greedy algorithm which cannot guarantee a +// global optimal solution. As a counterexample, considering the following +// graph: +// +// +--> B ===> C -------+ +// A -> | | +// | v +// +--> D ---> F=======>G +// | ^ +// | | +// +--> E -----+ +// +// --> : Buffer with size 1 +// ==> : Buffer with size 2 +// +// The list scheduler will always try to defer scheduling B in a greedy way +// since its output buffer is bigger than input. The sequence it creates will +// be: +// A D E F B C G +// , which has a maximum memory usage of 6 (B is alive while F is executing). +// +// An optimal way to shedule the previous graph is: +// A B C D E F G +// , which has a maximum memory usage of 5 (when F is executing). +// class ListScheduler { public: // Construct and return a memory-minimizing sequence of HLO instructions @@ -70,8 +74,11 @@ class ListScheduler { static StatusOr> Run( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - ListScheduler scheduler(computation, points_to_analysis, size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + ListScheduler scheduler(computation, points_to_analysis, size_function, + memory_by_computation); return scheduler.CreateSchedule(); } @@ -92,10 +99,13 @@ class ListScheduler { ListScheduler(const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) : computation_(computation), points_to_analysis_(points_to_analysis), - size_function_(size_function) { + size_function_(size_function), + memory_by_computation_(memory_by_computation) { // Create a map containing the LogicalBuffer uses for each HLO // instruction. An HLO instruction "uses" a LogicalBuffer if the // LogicalBuffer is in an operand of the instruction as indicated by @@ -103,10 +113,11 @@ class ListScheduler { for (auto* instruction : computation.instructions()) { tensorflow::gtl::FlatSet instr_uses; for (auto* operand : instruction->operands()) { - for (const LogicalBuffer* buffer : - points_to_analysis.GetBuffersDefinedByInstruction(operand)) { - instr_uses.insert(buffer); - } + points_to_analysis.GetPointsToSet(operand).ForEachElement( + [&](const ShapeIndex& /*index*/, + const PointsToSet::BufferList& buffers) { + instr_uses.insert(buffers.begin(), buffers.end()); + }); } buffer_uses_[instruction] = std::vector( instr_uses.begin(), instr_uses.end()); @@ -184,6 +195,12 @@ class ListScheduler { } // Returns the number of bytes freed if the HLO instruction is scheduled. + // If the instruction calls subcomputations, we count the memory used by the + // subcomputations as memory "defined" by the instruction. This is not + // entirely accurate, because subcomputation memory will be freed after the + // instruction finishes. But it is more accurate than not taking + // subcomputations into account at all. In the future, we may improve + // accounting for subcomputation memory (b/65409243). int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { int64 freed_bytes = 0; for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { @@ -193,7 +210,19 @@ class ListScheduler { freed_bytes += size_function_(*buffer); } } - return freed_bytes - entry.bytes_defined; + // We only count the memory usage of the largest subcomputation, instead of + // adding them all, because subcomputations won't execute in parallel. + int64 max_subcomputation_bytes = 0; + for (const auto* c : entry.instruction->called_computations()) { + auto it = memory_by_computation_.find(c); + if (it != memory_by_computation_.end()) { + int64 subcomputation_bytes = it->second; + if (subcomputation_bytes > max_subcomputation_bytes) { + max_subcomputation_bytes = subcomputation_bytes; + } + } + } + return freed_bytes - entry.bytes_defined - max_subcomputation_bytes; } // Constructs the scheduling priority of the given instruction. @@ -247,6 +276,8 @@ class ListScheduler { auto best_it = ready_queue.end(); --best_it; const HloInstruction* best = best_it->second.instruction; + VLOG(2) << "Schedule instruction: " << best->ToShortString() + << " Bytes freed: " << best_it->first.first; ready_queue.erase(best_it); ready_instructions.erase(best); schedule.push_back(best); @@ -314,6 +345,11 @@ class ListScheduler { const HloComputation& computation_; const TuplePointsToAnalysis& points_to_analysis_; const LogicalBuffer::SizeFunction& size_function_; + // Computations are analyzed in post-order. When scheduling an instruction + // that includes subcomputations, such as a while loop, we use this map to + // look up the memory needed by subcomputations. + const tensorflow::gtl::FlatMap& + memory_by_computation_; // A map containing the LogicalBuffers that each instruction uses. tensorflow::gtl::FlatMap> RunDFSMemoryScheduler( +StatusOr> ScheduleComputationsInModule( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + VLOG(2) << "Computation: " << computation.name(); + if (algorithm) { + return algorithm(computation, points_to_analysis, size_function, + memory_by_computation); + } + return DefaultMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation); +} + +} // namespace + +StatusOr> DFSMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { // This ordering is based on DFS post-order, with a heuristic to decide which // operand to visit first. The heuristic is based on 'extra_users', which is // simply users-1 for each instruction. By subtracting 1, we're saying that // instructions with no users or a single user don't count; instructions with // lots of fan-out will be visited earlier. + int64 cumulative_total_size = 0; tensorflow::gtl::FlatMap extra_users; tensorflow::gtl::FlatMap total_sizes; for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { @@ -357,14 +414,24 @@ StatusOr> RunDFSMemoryScheduler( continue; } extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1; - total_sizes[hlo] = SumLogicalBufferSizes( + int64 logical_buffer_size = SumLogicalBufferSizes( points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); + total_sizes[hlo] = logical_buffer_size; + cumulative_total_size += logical_buffer_size; tensorflow::gtl::FlatSet unique_operands( hlo->operands().begin(), hlo->operands().end()); for (const HloInstruction* operand : unique_operands) { extra_users[hlo] += extra_users[operand]; total_sizes[hlo] += total_sizes[operand]; } + // total_sizes[hlo] transitively includes the sizes of all nodes that + // lead to it. But computation is a DAG, so we are double-counting nodes, + // which can lead to overflows for large programs. + // cumulative_total_size caps the size to prevent overflows. + // NOTE(dimvar): this is quite ugly and should be changed. It's unclear + // why we care about transitive sizes; when scheduling a node, its input + // and output buffers should be all that matters, not its "history". + total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); } CHECK_EQ(extra_users.size(), computation.instruction_count()); CHECK_EQ(total_sizes.size(), computation.instruction_count()); @@ -390,97 +457,122 @@ StatusOr> RunDFSMemoryScheduler( })); CHECK_EQ(sequence.size(), computation.instruction_count()); return sequence; -} +} // namespace xla -StatusOr MinimumMemoryForComputation( +StatusOr> ListMemoryScheduler( const HloComputation& computation, - const std::vector& sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), computation, - sequence, points_to_analysis, size_function)); - return result.heap_size; + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + return ListScheduler::Run(computation, points_to_analysis, size_function, + memory_by_computation); } -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm) { - VLOG(2) << "Computation: " << computation.name(); - if (algorithm == SchedulerAlgorithm::kListSchedule) { - return ListScheduler::Run(computation, points_to_analysis, size_function); - } - if (algorithm == SchedulerAlgorithm::kDfsSchedule) { - return RunDFSMemoryScheduler(computation, points_to_analysis, - size_function); - } + const tensorflow::gtl::FlatMap& + memory_by_computation) { + const auto& post_order = computation.MakeInstructionPostOrder(); + return std::vector{post_order.begin(), + post_order.end()}; +} - // We try both a list-scheduler based ordering and a DFS based ordering, and - // choose whichever returns a lower min-memory, not accounting for - // fragmentation. - // - // Note that this is just a heuristic. One obvious inaccuracy is that the - // memory required for sub-computations might be different when considered - // within the caller's context. But it's good enough for now. +StatusOr> DefaultMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + // We try a few schedulers and choose whichever returns a lower min-memory, + // not accounting for fragmentation. + // - List is a scheduler that uses greedy heuristics. + // - DFS visits HLOs in postorder, with a heuristic to decide the order of + // children. + // - Postorder does not use any heuristics. + // List wins for most of our benchmarks; postorder-based schedulers win for + // some RNNs. TF_ASSIGN_OR_RETURN( std::vector list_sequence, - ListScheduler::Run(computation, points_to_analysis, size_function)); + ListMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation)); TF_ASSIGN_OR_RETURN( const int64 list_memory, MinimumMemoryForComputation(computation, list_sequence, points_to_analysis, size_function)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); - TF_ASSIGN_OR_RETURN( - std::vector dfs_sequence, - RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN(std::vector dfs_sequence, + DFSMemoryScheduler(computation, points_to_analysis, + size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN( const int64 dfs_memory, MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, size_function)); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); - if (list_memory <= dfs_memory) { + TF_ASSIGN_OR_RETURN( + std::vector post_order_sequence, + PostOrderMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation)); + TF_ASSIGN_OR_RETURN( + const int64 post_order_memory, + MinimumMemoryForComputation(computation, post_order_sequence, + points_to_analysis, size_function)); + VLOG(2) << "Min-memory post order sequence: " + << HumanReadableNumBytes(post_order_memory); + + auto min_memory = std::min({dfs_memory, post_order_memory, list_memory}); + + if (min_memory == list_memory) { VLOG(2) << "Chose min-memory list sequence: " << HumanReadableNumBytes(list_memory); return list_sequence; - } else { + } else if (min_memory == dfs_memory) { VLOG(2) << "Chose min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); return dfs_sequence; + } else { + VLOG(2) << "Chose min-memory post_order sequence: " + << HumanReadableNumBytes(post_order_memory); + return post_order_sequence; } } -} // namespace - -StatusOr -CreateMemoryMinimizingSequence(const HloModule& module, - const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm) { +StatusOr ScheduleComputationsInModule( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm) { SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); - for (const auto* computation : module.MakeNonfusionComputations()) { - TF_ASSIGN_OR_RETURN( - sequence[computation], - CreateMemoryMinimizingSequence(*computation, *points_to_analysis, - size_function, algorithm)); + tensorflow::gtl::FlatMap memory_by_computation; + for (const auto* computation : module.MakeComputationPostOrder()) { + if (!computation->IsFusionComputation()) { + TF_ASSIGN_OR_RETURN(auto one_computation_sequence, + ScheduleComputationsInModule( + *computation, *points_to_analysis, size_function, + algorithm, memory_by_computation)); + memory_by_computation[computation] = + MinimumMemoryForComputation(*computation, one_computation_sequence, + *points_to_analysis, size_function) + .ValueOrDie(); + sequence[computation] = std::move(one_computation_sequence); + } } return sequence; } -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> ScheduleOneComputation( const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm) { + const LogicalBuffer::SizeFunction& size_function) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); - return CreateMemoryMinimizingSequence(computation, *points_to_analysis, - size_function, algorithm); + tensorflow::gtl::FlatMap empty_map; + return ScheduleComputationsInModule(computation, *points_to_analysis, + size_function, nullptr, empty_map); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index 1d1eb1e064f75c2220b39e84b010e720a0c37880..2b33ccc8bfb895286bb3747aab0a16cf25e2cfae 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -22,39 +22,68 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" namespace xla { -// Returns the minimum memory required to compute the given module sequence, -// assuming no fragmentation. -StatusOr MinimumMemoryForSequence( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function); +// A memory scheduler computes an execution sequence for the HLO instructions in +// 'computation' that minimizes peak memory, given a points-to analysis result +// that describes buffer aliasing, together with a target-specific size function +// that maps a tensor's logical size to its padded size. +typedef std::function>( + const HloComputation&, const TuplePointsToAnalysis&, + const LogicalBuffer::SizeFunction&, + const tensorflow::gtl::FlatMap&)> + MemorySchedulerAlgorithm; -enum class SchedulerAlgorithm { - kListSchedule, - kDfsSchedule, +// List scheduler +StatusOr> ListMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); - // Selects the available scheduler algorithm that had the minimum memory in - // the resulting sequence (a la MinimumMemoryForSequence). - kAuto, -}; +// DFS-order scheduler +StatusOr> DFSMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); + +// Naive Post Order scheduler +StatusOr> PostOrderMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); + +// The default scheduling algorithm. Runs both the list scheduler +// and the DFS scheduler, and chooses whichever returns a lower min-memory, +// not accounting for fragmentation. +StatusOr> DefaultMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. -StatusOr -CreateMemoryMinimizingSequence( +StatusOr ScheduleComputationsInModule( const HloModule& module, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto); + const MemorySchedulerAlgorithm& algorithm = {}); -// Overload of above that computes the sequence for a single computation. -StatusOr> CreateMemoryMinimizingSequence( +// Computes the schedule for a single computation. +// Currently only used by the GPU backend. +StatusOr> ScheduleOneComputation( const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto); + const LogicalBuffer::SizeFunction& size_function); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 7fb338e7042ce19ac9647e23719e738f3ef42c7c..6f1b1215d39dfbaeff768de70fa0a0859cd97381 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" @@ -30,63 +31,298 @@ limitations under the License. namespace xla { namespace { -class MinimumMemoryForSequenceTest : public HloTestBase {}; +class HloSchedulingTest : public HloTestBase {}; + +TEST_F(HloSchedulingTest, LastUseScheduledFirst) { + // Tests scheduling of the following HLO code: + // + // %ab = abs(%param) + // %exp = exp(%param) + // %add = add(%ab, %exp) + // %negate = negate(%exp) + // %sub = subtract(%add, %negate) + // + // %add should be scheduled before %negate because %add is the last (and only) + // use of %ab. Scheduling %add first then frees up %ab's buffer. + const Shape vec = ShapeUtil::MakeShape(xla::F32, {42}); + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param")); + auto ab = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kExp, param)); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp)); + auto sub = builder.AddInstruction( + HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + + // The first instruction should be the parameter and the last the root "sub". + EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); + EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); + + SequentialHloOrdering ordering(module.get(), sequence); + EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); +} + +TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) { + const char* module_str = R"( +HloModule test_aliasing_module + +ENTRY root { + param = s32[1000] parameter(0) + p0 = s32[1000] copy(param) + p1 = s32[1000] copy(param) + t = (s32[1000], s32[1000]) tuple(p0, p1) + a = s32[1000] get-tuple-element(t), index=0 + b = s32[1000] get-tuple-element(t), index=1 + c = s32[1000] add(a, b) + d = s32[1000] add(c, b) + e = s32[1000] add(c, c) + f = s32[1000] add(e, e) + ROOT result = (s32[1000], s32[1000], s32[1000]) tuple(d, e, f) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + + std::unordered_map instructions_by_name; + for (const HloInstruction* instruction : + sequence.at(module->entry_computation())) { + instructions_by_name[instruction->name()] = instruction; + } + + // The first instruction should be the parameter and the last the root. + EXPECT_EQ(instructions_by_name.at("param"), + sequence.at(module->entry_computation()).front()); + EXPECT_EQ(instructions_by_name.at("result"), + sequence.at(module->entry_computation()).back()); + + // Instructions "d" and "e" will both be schedulable at the same time, but + // instruction "d" allows us to free the buffer of "p1", so the list scheduler + // should prefer it. + SequentialHloOrdering ordering(module.get(), sequence); + EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"), + instructions_by_name.at("e"))); +} + +TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { + // %WhileCond (cond_param: f32[4]) -> pred[] { + // %cond_param = f32[4]{0} parameter(0) + // %constant = f32[1,4]{1,0} constant(f32[1,4] { { 0, 0, 0, 0 } }) + // ROOT %not-equal-to = pred[] not-equal-to( + // f32[4]{0} %cond_param, f32[1,4]{1,0} %constant) + // } + // %WhileBody (body_param: f32[4]) -> f32[4] { + // %body_param = f32[4]{0} parameter(0) + // %constant.1 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) + // ROOT %subtract = f32[4]{0} subtract( + // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1) + // } + // %SubcomputationsNotAccounted () -> f32[2,4] { + // %constant.3 = f32[2,4]{1,0} constant( + // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } }) + // %transpose = f32[2,4]{1,0} transpose( + // f32[2,4]{1,0} %constant.3), dimensions={0,1} + // %constant.2 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) + // %while = f32[4]{0} while(f32[1,4]{1,0} %constant.2), + // condition=%WhileCond, + // body=%WhileBody + // %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0} + // ROOT %add = f32[2,4]{1,0} add( + // f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast) + // } -TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { auto module = CreateNewModule(); - const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); - const Shape tuple_shape = - ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); + const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); + // param != 0 + // Needs 17 bytes auto cond_builder = HloComputation::Builder("WhileCond"); - // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) HloInstruction* cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); - HloInstruction* cond_iter = cond_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); - HloInstruction* cond_data = cond_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); - // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) - HloInstruction* cond_lt = cond_builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kLt, cond_iter, cond_data)); - HloComputation* cond_computation = - module->AddEmbeddedComputation(cond_builder.Build()); + HloInstruction::CreateParameter(0, r1f32, "cond_param")); + HloInstruction* zero_vector = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{0, 0, 0, 0}}))); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); + auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); + // param - 1 + // Needs 16 bytes auto body_builder = HloComputation::Builder("WhileBody"); - // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) HloInstruction* body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "body_param")); - HloComputation* body_computation = - module->AddEmbeddedComputation(body_builder.Build()); + HloInstruction::CreateParameter(0, r1f32, "body_param")); + HloInstruction* one_vector = body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + body_builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, body_param, one_vector)); + auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); + // transpose(matrix) + bcast(while) auto builder = HloComputation::Builder(TestName()); - // Entry params: 8 bytes (4 bytes per param), TOTAL=8 - HloInstruction* iter = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); - HloInstruction* data = builder.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "param_data")); - // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 - HloInstruction* tuple = - builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); - // While: 8 bytes (4 bytes per element), TOTAL=32 - // Both cond and body use a max of 24 bytes, TOTAL=56 - HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( - tuple_shape, cond_computation, body_computation, tuple)); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - - auto size_fn = [](const LogicalBuffer& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - }; + HloInstruction* while_init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + // Creates 16 bytes, ignoring subcomputations + HloInstruction* while_loop = + builder.AddInstruction(HloInstruction::CreateWhile( + r1f32, cond_computation, body_computation, while_init)); + + // Creates 32 bytes and frees 16 + HloInstruction* bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, while_loop, {0})); + + HloInstruction* matrix = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2( + {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); + // Creates 32 bytes + HloInstruction* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(r2f32, matrix, {0, 1})); + + // Creates 32 bytes and frees 64 + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast)); + + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule( + *module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }, + ListMemoryScheduler)); + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); + // This schedule is an example of List's greedy heuristics being suboptimal. + // The while_loop is more expensive than transpose, so it would have been + // better to schedule it first, instead of during the busy time. + EXPECT_TRUE(ordering.ExecutesBefore(transpose, while_loop)); + EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast)); + EXPECT_TRUE(ordering.ExecutesBefore(bcast, add)); + EXPECT_TRUE(ordering.ExecutesBefore(transpose, add)); +} + +TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { + auto builder = HloComputation::Builder(TestName()); + const auto TUPLE_SIZE = 1; + const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {6}); + + // Wrap lit in abs because constants are considered free by + // IgnoreInstruction, and it skews the accounting. + auto lit = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1, 1, 1, 1, 1, 1}))); + auto abs_const = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit)); + + auto abs_abs1 = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( + tensorflow::gtl::ArraySlice({abs_abs1}))); + auto tuple_elm = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); + + auto abs_abs2 = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); + + builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, + tuple_elm, abs_abs2)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, + [&TUPLE_SIZE](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), TUPLE_SIZE); + }, + ListMemoryScheduler)); + + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); + // tuple allocates the tuple buffer and doesn't free anything. + // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0. + // abs_abs2 should be scheduled before tuple by List. + EXPECT_TRUE(ordering.ExecutesBefore(abs_abs2, tuple)); +} + +TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { + const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {5}); + HloComputation::Builder builder(TestName()); + + auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1, 1, 1, 1, 1}))); + auto c2 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1, 2, 3, 4, 5}))); + auto c3 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({0, 2, 4, 6, 8}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kMultiply, add, c3)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, mul})); + + auto tuple_elm = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); + + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kExp, c3)); + + builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto fusion = computation->CreateFusionInstruction( + {tuple, mul, add}, HloInstruction::FusionKind::kLoop); + + TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule( + *module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), 2); + }, + ListMemoryScheduler)); - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, - cond_lt}; - module_sequence[body_computation] = {body_param}; - module_sequence[entry_computation] = {iter, data, tuple, while_op}; - EXPECT_EQ(56, - MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); + // fusion allocates memory for the tuple elements and doesn't free anything, + // so it's more expensive than exp. + EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 447c2446668253c932b44b51b2db22bfd47f9957..9fb15df7c26951fb7f0d62b0d6533d6312e7a4d5 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -20,6 +20,7 @@ limitations under the License. namespace xla { +using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrCat; HloSharding HloSharding::AssignDevice(int64 device_id) { @@ -38,6 +39,34 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { return HloSharding(tile_shape, assignment); } +HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { + std::vector flattened_list; + flattened_list.reserve(sub_shardings.leaf_count()); + for (const auto& index_to_sharding : sub_shardings.leaves()) { + flattened_list.push_back(index_to_sharding.second); + } + if (flattened_list.empty()) { + // Empty tuple sharding ends up having no leaves, but we want to allow + // empty tuple HLO instruction results to have sharding, so we fetch the + // root ({}) sharding value from the ShapeTree. + // A ShapeTree created with ShapeTree(shape, init) will have + // init as value at its root. + flattened_list.push_back(sub_shardings.element(ShapeIndex({}))); + } + return HloSharding(flattened_list); +} + +HloSharding HloSharding::Tuple( + const Shape& tuple_shape, + tensorflow::gtl::ArraySlice shardings) { + CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + std::vector flattened_list(shardings.begin(), shardings.end()); + CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape)) + << "Flat list has " << flattened_list.size() << ", required " + << RequiredLeaves(tuple_shape); + return HloSharding(flattened_list); +} + string HloSharding::ToString() const { if (IsTuple()) { std::vector parts; @@ -48,17 +77,15 @@ string HloSharding::ToString() const { return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}"); } - string result = StrCat("{", (replicated_ ? " replicated" : ""), - (maximal_ ? " maximal" : "")); - if (replicated_) { return "{replicated}"; } else if (maximal_) { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); } else { - return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", - "devices=", VectorString(tile_assignment_), "}"); + return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", "devices=[", + Join(tile_assignment_.dimensions(), ","), "]", + Join(tile_assignment_, ","), "}"); } } @@ -124,6 +151,49 @@ std::vector HloSharding::TileLimitForDevice(int64 device) const { return index; } +int64 HloSharding::RequiredLeaves(const Shape& shape) { + // Empty tuples have no leaf nodes as far as ShapeUtil and ShapeTree are + // concerned, but they do have a single tuple_elements_ entry since we want + // to allow empty tuple results to have sharding. + return ShapeUtil::IsEmptyTuple(shape) ? 1 : ShapeUtil::GetLeafCount(shape); +} + +Status HloSharding::CheckLeafCount(const Shape& shape) const { + int64 shape_leaves = RequiredLeaves(shape); + TF_RET_CHECK(shape_leaves == tuple_elements_.size()) + << "Shape " << ShapeUtil::HumanString(shape) << " has " << shape_leaves + << " leaf nodes while this sharding has " << tuple_elements_.size(); + return Status::OK(); +} + +StatusOr> HloSharding::AsShapeTree( + const Shape& shape) const { + if (IsTuple()) { + ShapeTree result(shape, HloSharding::Replicate()); + TF_RETURN_IF_ERROR(CheckLeafCount(shape)); + auto it = tuple_elements_.begin(); + for (auto& index_to_sharding : result.leaves()) { + index_to_sharding.second = *it++; + } + if (ShapeUtil::IsEmptyTuple(shape)) { + // Empty tuples have no leaves, but we want to assign them a sharding + // anyway, so we use the root element sharding. + *result.mutable_element(ShapeIndex({})) = *it; + } + return std::move(result); + } else { + return ShapeTree(shape, *this); + } +} + +StatusOr HloSharding::GetTupleSharding(const Shape& shape) const { + if (IsTuple()) { + TF_RETURN_IF_ERROR(CheckLeafCount(shape)); + return *this; + } + return Tuple(ShapeTree(shape, *this)); +} + StatusOr HloSharding::UniqueDevice() const { if (IsTuple()) { if (tuple_elements_.empty()) { @@ -165,19 +235,7 @@ Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { return tensorflow::errors::InvalidArgument( StrCat("Sharding is tuple-shaped but validation shape is not.")); } - // The easiest way to get the number of elements in a nested tuple is just to - // create a shape tree. We could call GetAsShapeTree, but that will try and - // apply our tuple_shardings_ to the shape tree, and that might cause a crash - // at this point as we haven't validated them. - ShapeTree bool_shape_tree(shape, false); - int64 num_leaves = - std::distance(bool_shape_tree.leaf_begin(), bool_shape_tree.leaf_end()); - if (num_leaves != tuple_elements_.size()) { - return tensorflow::errors::InvalidArgument( - StrCat("Validation tuple shape has ", num_leaves, - " leaf elements, but this sharding contains ", - tuple_elements_.size(), " elements.")); - } + TF_RETURN_IF_ERROR(CheckLeafCount(shape)); // Now we've validated the number of tuple elements, it's safe to request a // shape tree. @@ -222,7 +280,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, Status status = Status::OK(); std::set seen_cores; tile_assignment_.Each( - [&](tensorflow::gtl::ArraySlice indices, uint32 core) { + [&](tensorflow::gtl::ArraySlice indices, int32 core) { // Don't overwrite a bad status, so we report the first error. if (status.ok()) { if (core >= num_devices) { @@ -250,37 +308,24 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, ", input_shape=", ShapeUtil::HumanString(shape)); } - // The tile shape must not be the same as the input shape without maximal_ - // also set. If this is the case, we're not actually sharded and the correct - // constructor should have been used. - if (ShapeUtil::Equal(shape, tile_shape_)) { + // The correct constructor have to be used to create tile maximal shardings. + if (tile_assignment_.num_elements() == 1) { return tensorflow::errors::InvalidArgument( - "Tile shape is the same as the input shape. If a replicated sharding " - "was intended, use HloSharding::Replicated(). If a device placement " - "was intended, use HloSharding::AssignDevice()"); - } - - // The tile shape must not be greater than the input shape in any dimension. - for (int64 i = 0, e = ShapeUtil::Rank(shape); i != e; ++i) { - auto tile_dim = tile_shape_.dimensions(i); - auto shape_dim = shape.dimensions(i); - if (tile_dim > shape_dim) { - return tensorflow::errors::InvalidArgument( - StrCat("Tile is larger than input shape (dimension ", i, ", ", - tile_dim, " > ", shape_dim)); - } + "Tile assignment only contains a single device. If a replicated " + "sharding was intended, use HloSharding::Replicated(). If a device " + "placement was intended, use HloSharding::AssignDevice()"); } - // The tile assignment tensor must be exactly dimensioned to ceil(shape[dim] - // tile[dim]) for every dimension contained within tile. + // The tile assignment tensor must contain enough element to cover the full + // shape with tiles of the specified size. for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) { - int64 expected_dim = - CeilOfRatio(shape.dimensions(i), tile_shape_.dimensions(i)); - if (tile_assignment_.dimensions()[i] != expected_dim) { + int64 total_tile_size = tile_assignment_.dim(i) * tile_shape_.dimensions(i); + if (shape.dimensions(i) > total_tile_size) { return tensorflow::errors::InvalidArgument( - StrCat("Tile assignment tensor has incorrect shape. Dimension ", i, - " expected ", expected_dim, " but got ", - tile_assignment_.dimensions()[i])); + StrCat("Tile assignment tensor has too few element to cover the full " + "shape. Dimension ", + i, ", shape ", shape.dimensions(i), ", total size ", + total_tile_size)); } } @@ -344,4 +389,59 @@ OpSharding HloSharding::ToProto() const { return result; } +HloSharding HloSharding::TransformShardedTileShape( + const Shape& new_shape, + const std::function& transform) const { + CHECK(!IsTuple()); + if (IsTileMaximal()) { + return *this; + } + CHECK_EQ(ShapeUtil::Rank(new_shape), ShapeUtil::Rank(tile_shape())); + Shape new_tile_shape; + new_tile_shape.set_element_type(tile_shape().element_type()); + for (int64 i = 0; i < ShapeUtil::Rank(new_shape); ++i) { + int64 dim; + if (tile_assignment().dim(i) == 1) { + dim = new_shape.dimensions(i); + } else if (transform) { + dim = transform(i, tile_shape().dimensions(i)); + } else { + dim = tile_shape().dimensions(i); + } + new_tile_shape.add_dimensions(dim); + } + TF_CHECK_OK( + LayoutUtil::CopyLayoutBetweenShapes(tile_shape_, &new_tile_shape)); + return HloSharding::Tile(new_tile_shape, tile_assignment()); +} + +HloSharding HloSharding::GetSubSharding(const Shape& shape, + const ShapeIndex& index) const { + CHECK(IsTuple()); + + Shape sub_shape = ShapeUtil::GetSubshape(shape, index); + ShapeTree sub_shape_tree(sub_shape, Replicate()); + sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {}); + return ShapeUtil::IsTuple(sub_shape) ? Tuple(sub_shape_tree) + : sub_shape_tree.element(ShapeIndex({})); +} + +tensorflow::gtl::optional HloSharding::ExtractSingleSharding() + const { + if (!IsTuple()) { + return *this; + } + for (int64 i = 1; i < tuple_elements_.size(); ++i) { + if (tuple_elements_[0] != tuple_elements_[i]) { + return tensorflow::gtl::optional(); + } + } + return tuple_elements_.front(); +} + +std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { + out << sharding.ToString(); + return out; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 7263198385cf0c84b1dac1e15177dcac99adaafb..6a744e0247273e25c5de3143b7bbba2b79ee816a 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -70,31 +70,25 @@ class HloSharding { // Creates a new sharding for a tuple type. The given ShapeTree must have // elements for every leaf shape contained in the tuple. - static HloSharding Tuple(const ShapeTree& sub_shardings) { - std::vector flattened_list; - flattened_list.reserve( - std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end())); - for (const auto& index_to_sharding : sub_shardings.leaves()) { - flattened_list.push_back(index_to_sharding.second); - } - return HloSharding(flattened_list); - } + static HloSharding Tuple(const ShapeTree& sub_shardings); - // Creates a new sharding for a tuple type. The requested tuple shape must not - // be nested. For nested tuples, use the ShapeTree overload. + // Creates a new sharding for a tuple type. The number of elements in + // shardings must match the number of leaf nodes in tuple_shape. For + // empty tuples, the shardings array must have one element. static HloSharding Tuple(const Shape& tuple_shape, - tensorflow::gtl::ArraySlice shardings) { - CHECK(ShapeUtil::IsTuple(tuple_shape)); - CHECK(!ShapeUtil::IsNestedTuple(tuple_shape)); - std::vector flattened_list(shardings.begin(), shardings.end()); - CHECK_EQ(flattened_list.size(), ShapeUtil::TupleElementCount(tuple_shape)); - return HloSharding(flattened_list); - } + tensorflow::gtl::ArraySlice shardings); // Create a new sharding from a protobuf OpSharding. static StatusOr FromProto(const OpSharding& proto); + // Checks whether device is a reserved device number. A reserved device number + // has usually a special meaning, with dedicated handling logic. + static bool IsReservedDevice(int64 device) { return device < 0; } + OpSharding ToProto() const; + + // Note that this string canonically has outer curly braces, e.g. + // "{replicated}". string ToString() const; // Validate that this sharding can be applied to a tensor with shape `shape`. @@ -156,24 +150,30 @@ class HloSharding { // tuple, if IsTuple, or a ShapeTree with a single element containing this // sharding. Only the leaf elements are populated. This creates a new // ShapeTree object so is not cheap. + StatusOr> AsShapeTree(const Shape& shape) const; ShapeTree GetAsShapeTree(const Shape& shape) const { - if (IsTuple()) { - ShapeTree result(shape, HloSharding::Replicate()); - CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()), - tuple_elements_.size()); - auto it = tuple_elements_.begin(); - for (auto& index_to_sharding : result.leaves()) { - index_to_sharding.second = *it++; - } - return result; - } else { - return ShapeTree(shape, *this); - } + return AsShapeTree(shape).ValueOrDie(); } + // Retrieves the sub sharding at a given index, out of a tuple sharding. + // REQUIRES: IsTuple() + HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const; + + // If the current sharding is a tuple sharding, return itself as result. + // Otherwise returns a tuple sharding for the input shape, with all the leaves + // having this object sharding. + StatusOr GetTupleSharding(const Shape& shape) const; + + // Extracts the sharding that is common within the current sharding. + // If the current sharding is not a tuple sharding, the current sharding will + // be returned. If it is a tuple, and all the tuple elements are common, the + // common element will be returned. Otherwise the optional will contain no + // value. + tensorflow::gtl::optional ExtractSingleSharding() const; + bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && - protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) && + ShapeUtil::Compatible(tile_shape_, other.tile_shape_) && tile_assignment_ == other.tile_assignment_ && tuple_elements_ == other.tuple_elements_; } @@ -200,6 +200,12 @@ class HloSharding { return h; } + struct Hasher { + size_t operator()(const HloSharding& sharding) const { + return sharding.Hash(); + } + }; + // Gets the tile shape. // REQUIRES: !IsTileMaximal() && !IsTuple() const Shape& tile_shape() const { return tile_shape_; } @@ -207,6 +213,26 @@ class HloSharding { // REQUIRES: !IsReplicated() && !IsTuple() const Array& tile_assignment() const { return tile_assignment_; } + // Returns the flattened list of all the leaf shardings in a tuple shape, by + // pre-order walk (ShapeTree iterator order). + // REQUIRES: IsTuple(). + const std::vector& tuple_elements() const { + return tuple_elements_; + } + + // Return a new sharding that can apply to the given new shape. + // If this sharding is tile-maximal, the returned sharding will be the same as + // this sharding. If this sharding is not tile-maximal, the returned + // sharding's tile size will differ: + // - Non-sharded dimensions will be adapted to be the same as `new_shape`; + // tile_dimension(i) = new_shape.dimensions(i); + // - Sharded dimensions will be kept the same unless `transform` is supplied + // in which case tile_dimension(i) = transform(i, tile_dimension(i)); + // REQUIRES: !IsTuple(). + HloSharding TransformShardedTileShape( + const Shape& new_shape, + const std::function& transform = nullptr) const; + private: HloSharding() : replicated_(true), @@ -233,11 +259,19 @@ class HloSharding { tile_assignment_({0}), tuple_elements_(tuple_shardings) {} + // Checks that the number of elements in tuple_elements_ is consistent with + // the tuple shape passes as argument. + Status CheckLeafCount(const Shape& shape) const; + // Internal helper to validate a tuple sharding. Status ValidateTuple(const Shape& shape, int64 num_devices) const; + // Internal helper to validate a non-tuple (leaf) sharding. Status ValidateNonTuple(const Shape& shape, int64 num_devices) const; + // Returns the number of tuple_elements_ entries to fit the shape. + static int64 RequiredLeaves(const Shape& shape); + bool replicated_; bool maximal_; bool tuple_; @@ -249,6 +283,8 @@ class HloSharding { std::vector tuple_elements_; }; +std::ostream& operator<<(std::ostream& out, const HloSharding& sharding); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc new file mode 100644 index 0000000000000000000000000000000000000000..7b4b071af46df19520f9ba1f1f632692d489de59 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -0,0 +1,394 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +namespace { + +struct PassThrough { + PassThrough(HloInstruction* user, HloInstruction* operand) + : user(user), operand(operand) {} + + HloInstruction* user = nullptr; + HloInstruction* operand = nullptr; +}; + +void SetSingleSharding(HloInstruction* instruction, + const HloSharding& sharding) { + VLOG(4) << " " << instruction->name() << " to " << sharding; + instruction->set_single_sharding(sharding); +} + +bool ShardingMatches(const HloSharding& sharding1, + const HloSharding& sharding2) { + auto single_sharding1 = sharding1.ExtractSingleSharding(); + if (single_sharding1) { + auto single_sharding2 = sharding2.ExtractSingleSharding(); + if (single_sharding2) { + return *single_sharding1 == single_sharding2; + } + } + // Anything which is not unique across all elements, gets a full sharding + // compare. + return sharding1 == sharding2; +} + +// When we create domains, they are never "empty", where with empty we mean +// that a kDomain instruction has as operand another kDomain instruction of the +// same kind. +// But when the HLO optimizations are run, empty domains can be created. +// For example: +// +// Domain(device=None, device=0) -> +// Tuple(device=0) -> +// GTE(device=0) -> +// Domain(device=0, device=None) +// +// In that case the tuple simplifier could create something like: +// +// Domain(device=None, device=0) -> Domain(device=0, device=None) +// +// Which is a so called empty domain. +// In the case above, crossing an empty domain which was transiting through +// device 0, requires the normalization phase to fixup the empty domain by +// adding back a Tuple+GTE pair with the proper device. +// One particular case where this can create problems is the result of the +// entry computation, where the GTE assignments are used by TF to tell the +// XLA where the results should be sent. +std::vector LocatePassThroughDomainLinks( + const DomainMetadata::Domain& domain) { + std::vector pass_through; + for (HloInstruction* instruction : domain.enter_domains) { + CHECK(instruction->opcode() == HloOpcode::kDomain) + << "Instruction is not a kDomain: " << instruction->ToString(); + for (HloInstruction* user : instruction->users()) { + if (user->opcode() == HloOpcode::kDomain && + domain.exit_domains.count(user) != 0) { + pass_through.emplace_back(user, instruction); + VLOG(2) << "Found passthrough domain link:"; + VLOG(2) << " " << user->ToString(); + VLOG(2) << " " << instruction->ToString(); + } + } + } + return pass_through; +} + +Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + for (auto& pass_through : LocatePassThroughDomainLinks(domain)) { + HloInstruction* tuple = pass_through.operand->parent()->AddInstruction( + HloInstruction::CreateTuple({pass_through.operand})); + HloInstruction* gte = pass_through.operand->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(pass_through.operand->shape(), + tuple, 0)); + gte->set_sharding(sharding); + TF_RETURN_IF_ERROR( + pass_through.operand->ReplaceUseWith(pass_through.user, gte)); + } + return Status::OK(); +} + +std::unique_ptr CloneShardingForDomain( + const HloSharding& sharding) { + auto single_sharding = sharding.ExtractSingleSharding(); + if (!single_sharding) { + return MakeUnique(sharding); + } + return MakeUnique(*single_sharding); +} + +Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + VLOG(4) << "Applying " << sharding << " sharding"; + for (HloInstruction* instruction : domain.instructions) { + // We only change instructions without sharding, since otherwise we might + // mess up with eventual HLO passes which has knowledge of it. + if (!instruction->has_sharding()) { + SetSingleSharding(instruction, sharding); + } else { + VLOG(4) << " " << instruction->name() << " already has sharding " + << instruction->sharding(); + } + } + return Status::OK(); +} + +// Retrieves the sharding of a tuple shaped instruction in form of a ShapeTree. +// If the instruction has no sharding, a ShapeTree with HloSharding::Replicate() +// sharding will be returned. +ShapeTree GetTupleSharding(HloInstruction* tuple) { + if (tuple->has_sharding()) { + return tuple->sharding().GetAsShapeTree(tuple->shape()); + } + return ShapeTree(tuple->shape(), HloSharding::Replicate()); +} + +// Retrieves the sharding of operand, asked from a user instruction which is +// within domain. If operand is a kDomain, it means that sharding argument is +// the operand sharding, otherwise the operand's own sharding will be returned. +const HloSharding* GetOperandSharding(const HloInstruction* operand, + const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + DCHECK_EQ(domain.reach_set.count(const_cast(operand)), 1); + // Here the user of operand is within the domain instruction set, and since it + // is user of operand, we need to look into the enter_domains set. If this is + // not a kDomain within the user domains set, then return the operand + // sharding, if any. + if (operand->opcode() != HloOpcode::kDomain || + domain.enter_domains.count(const_cast(operand)) == 0) { + return operand->has_sharding() ? &operand->sharding() : nullptr; + } + // At this point operand is a kDomain of the currently processed domain, so we + // can refer to sharding as the domain sharding. + return &sharding; +} + +// Tries to propagate the sharding information into the instructions that are +// part of the domain, in a post order manner (operand propagate to user). +StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + int64 assigned = 0; + for (HloInstruction* instruction : domain.instructions) { + if (instruction->has_sharding()) { + continue; + } + if (instruction->opcode() == HloOpcode::kGetTupleElement) { + HloInstruction* tuple = instruction->mutable_operand(0); + const HloSharding* tuple_sharding = + GetOperandSharding(tuple, domain, sharding); + if (tuple_sharding != nullptr) { + if (tuple_sharding->IsTuple()) { + HloSharding sub_sharding = tuple_sharding->GetSubSharding( + tuple->shape(), {instruction->tuple_index()}); + VLOG(4) << " " << instruction->name() << " to sharding " + << sub_sharding; + instruction->set_sharding(sub_sharding); + } else { + SetSingleSharding(instruction, *tuple_sharding); + } + ++assigned; + } + } else if (instruction->opcode() == HloOpcode::kTuple) { + int64 tuple_assigned = 0; + ShapeTree shape_tree = GetTupleSharding(instruction); + for (int64 i = 0; i < instruction->operand_count(); ++i) { + const HloSharding* operand_sharding = + GetOperandSharding(instruction->operand(i), domain, sharding); + if (operand_sharding != nullptr && + shape_tree.element({i}) != *operand_sharding) { + *shape_tree.mutable_element({i}) = *operand_sharding; + ++tuple_assigned; + } + } + if (tuple_assigned > 0) { + HloSharding tuple_sharding = HloSharding::Tuple(shape_tree); + VLOG(4) << " " << instruction->name() << " to sharding " + << tuple_sharding; + instruction->set_sharding(tuple_sharding); + ++assigned; + } + } else { + // If all the operand of the given instruction has the same single device + // assignment, assign that device to this instruction as well. + const HloSharding* common_sharding = nullptr; + for (const HloInstruction* operand : instruction->operands()) { + const HloSharding* operand_sharding = + GetOperandSharding(operand, domain, sharding); + if (operand_sharding != nullptr) { + if (common_sharding != nullptr && + *common_sharding != *operand_sharding) { + common_sharding = nullptr; + break; + } + common_sharding = operand_sharding; + } + } + if (common_sharding != nullptr) { + VLOG(4) << " " << instruction->name() << " to sharding " + << *common_sharding; + instruction->set_sharding(*common_sharding); + ++assigned; + } + } + } + return assigned; +} + +Status ApplyDomainSharding(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + auto single_sharding = sharding.ExtractSingleSharding(); + if (single_sharding) { + // Shortcut the simple case. We have a unique sharding, so we call + // the ApplyDomainSingleSharding() API which will apply array or tuple + // shaped sharding to the domain instructions. + return ApplyDomainSingleSharding(domain, *single_sharding); + } + VLOG(1) << "Assigning non-trivial sharding " << sharding; + for (;;) { + TF_ASSIGN_OR_RETURN(int64 assigned, + ApplyDomainShardingPass(domain, sharding)); + if (assigned == 0) { + break; + } + } + int64 unassigned = 0; + for (HloInstruction* instruction : domain.instructions) { + if (!instruction->has_sharding()) { + LOG(WARNING) << "Unassigned instruction: " << instruction->ToString(); + ++unassigned; + } + } + // Should we error out if unassigned > 0? + return Status::OK(); +} + +// Creates a kDomain instruction to be placed between instruction and operand. +// The kDomain instruction will be created only if the sharding differ between +// the instruction and the operand. +std::unique_ptr CreateDomain(HloInstruction* instruction, + HloInstruction* operand) { + const HloSharding* instruction_sharding = + instruction->has_sharding() ? &instruction->sharding() : nullptr; + const HloSharding* operand_sharding = + operand->has_sharding() ? &operand->sharding() : nullptr; + // No need for domain if they both have no sharding. + if (instruction_sharding == nullptr && operand_sharding == nullptr) { + return nullptr; + } + // No need for domain if they match. + if (instruction_sharding != nullptr && operand_sharding != nullptr && + ShardingMatches(*instruction_sharding, *operand_sharding)) { + return nullptr; + } + std::unique_ptr real_instruction_sharding; + std::unique_ptr real_operand_sharding; + if (instruction_sharding != nullptr) { + real_instruction_sharding = CloneShardingForDomain(*instruction_sharding); + } + if (operand_sharding != nullptr) { + real_operand_sharding = CloneShardingForDomain(*operand_sharding); + } + VLOG(3) << "Creating domain:"; + VLOG(3) << " Instruction: " << instruction->name(); + VLOG(3) << " Operand: " << operand->name(); + VLOG(3) << " User side sharding: " + << (real_instruction_sharding != nullptr + ? real_instruction_sharding->ToString() + : "None"); + VLOG(3) << " Operand side sharding: " + << (real_operand_sharding != nullptr + ? real_operand_sharding->ToString() + : "None"); + + std::unique_ptr operand_side_metadata = + MakeUnique(std::move(real_operand_sharding)); + std::unique_ptr user_side_metadata = + MakeUnique(std::move(real_instruction_sharding)); + return HloInstruction::CreateDomain(operand->shape(), operand, + std::move(operand_side_metadata), + std::move(user_side_metadata)); +} + +StatusOr> ExtractOriginalCommonSharding( + tensorflow::gtl::ArraySlice instructions) { + // If we are here, all the instructions being passed had the same sharding + // (or no sharding), by the means of the ShardingMatches() API. + // As such, no kDomain was inserted, and here we are asked to extract the + // original common sharding. + // All the instructions passed to this API are part of the same computation. + const HloSharding* sharding = nullptr; + for (HloInstruction* instruction : instructions) { + if (instruction->has_sharding()) { + if (sharding == nullptr) { + sharding = &instruction->sharding(); + } else { + TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding())) + << "Sharding " << *sharding << " does not match the one in " + << instruction->ToString(); + } + } + } + if (sharding == nullptr) { + return std::unique_ptr(); + } + VLOG(4) << "Extracted sharding is " << *sharding; + return CloneShardingForDomain(*sharding); +} + +} // namespace + +std::unique_ptr ShardingMetadata::Clone() const { + std::unique_ptr sharding; + if (sharding_ != nullptr) { + sharding = MakeUnique(*sharding_); + } + return MakeUnique(std::move(sharding)); +} + +bool ShardingMetadata::Matches(const DomainMetadata& other) const { + const ShardingMetadata* other_ptr = + dynamic_cast(&other); + if (other_ptr == nullptr) { + // If other is not a ShardingMetadata, then it is clearly a no match. + return false; + } + if (sharding_ == nullptr) { + return other_ptr->sharding_ == nullptr; + } + return other_ptr->sharding_ != nullptr + ? ShardingMatches(*sharding_, *other_ptr->sharding_) + : false; +} + +string ShardingMetadata::ToString() const { + return sharding_ != nullptr ? sharding_->ToString() : "None"; +} + +Status ShardingMetadata::NormalizeInstructions( + const DomainMetadata::Domain& domain) const { + if (sharding_ != nullptr) { + VLOG(4) << "Normalizing sharding to " << sharding_->ToString() << ":"; + TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding_)); + TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding_)); + } + return Status::OK(); +} + +Status NormalizeShardingDomain(const DomainMetadata::Domain& domain) { + TF_ASSIGN_OR_RETURN(std::unique_ptr sharding, + ExtractOriginalCommonSharding(domain.instructions)); + if (sharding != nullptr) { + VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString() + << ":"; + TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding)); + } else { + VLOG(1) << "Unable to find common sharding"; + } + return Status::OK(); +} + +std::unique_ptr CreateShardingDomain( + HloInstruction* instruction, HloInstruction* operand) { + return CreateDomain(instruction, operand); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h new file mode 100644 index 0000000000000000000000000000000000000000..ec162c34904ee2dfac3daeeee37133282a9c9698 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ + +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +// A DomainMetadata implementation that internally wraps a sharding attribute. +class ShardingMetadata : public DomainMetadata { + public: + explicit ShardingMetadata(std::unique_ptr sharding) + : sharding_(std::move(sharding)) {} + + std::unique_ptr Clone() const override; + + tensorflow::StringPiece Kind() const override { return KindName(); } + + bool Matches(const DomainMetadata& other) const override; + + string ToString() const override; + + Status NormalizeInstructions( + const DomainMetadata::Domain& domain) const override; + + static tensorflow::StringPiece KindName() { return "sharding"; } + + private: + std::unique_ptr sharding_; +}; + +// Within a set of instructions which had common sharding attributes before +// entring the HLO passes pipeline, apply sharding heuristics and normalize the +// instructions whose sharding deviates from the one which is inferred as to be +// the original one. +// Policy wise, HLO passes are allowed to create new unassigned instructions, +// but if they do create assigned ones, they have to conform to the ones around. +Status NormalizeShardingDomain(const DomainMetadata::Domain& domain); + +// Given an HLO graph edge between instruction and one of its operands, creates +// a ShardingMetadata based kDomain instruction if the sharding between +// instruction and operand changes. Returns nullptr if there is no need for a +// domain separation. +std::unique_ptr CreateShardingDomain( + HloInstruction* instruction, HloInstruction* operand); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 0c7487b3ac77ff181d44dd55ebcf2608feaf02ea..54b7402b866361748d9eb35182b0bf486c4c9bdc 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_sharding.h" - #include #include #include #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -88,7 +87,7 @@ TEST_F(HloShardingTest, Tile) { } { - // Test should pass. + // Test should fail because of more devices used then `num_device`. Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); HloSharding sharding = HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); @@ -97,17 +96,8 @@ TEST_F(HloShardingTest, Tile) { } { - // Test should fail due to the tile being larger than the input space. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); - EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {2, 2}), - /*num_devices=*/4)); - } - - { - // Test should fail due to the tile not dividing the input space into 4 - // sections (even with padding). + // Test should fail because the total tiled size in dimension 0 is 4 but we + // have 6 elements along that dimensions. Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); HloSharding sharding = HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); @@ -269,5 +259,102 @@ TEST_F(HloShardingTest, Hash) { } } +TEST_F(HloShardingTest, TransformShardedTileShapeTest) { + HloSharding sharding = + HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), + Array4D({{{{0, 1}, {2, 3}}}})); + HloSharding result = sharding.TransformShardedTileShape( + ShapeUtil::MakeShape(F32, {13, 15, 17, 19}), + [](int dim, int value) { return dim * 111; }); + HloSharding expected = + HloSharding::Tile(ShapeUtil::MakeShape(F32, {13, 15, 222, 333}), + Array4D({{{{0, 1}, {2, 3}}}})); + EXPECT_EQ(result, expected); +} + +TEST_F(HloShardingTest, ToStringReplicatedTest) { + HloSharding sharding = HloSharding::Replicate(); + EXPECT_EQ(sharding.ToString(), "{replicated}"); +} + +TEST_F(HloShardingTest, ToStringAssignDeviceTest) { + HloSharding sharding = HloSharding::AssignDevice(7); + EXPECT_EQ(sharding.ToString(), "{maximal device=7}"); +} + +TEST_F(HloShardingTest, ToStringTiledTest) { + HloSharding sharding = + HloSharding::Tile(ShapeUtil::MakeShape(S32, {7, 11, 13}), + Array3D({{{2, 3}}, {{5, 7}}})); + EXPECT_EQ(sharding.ToString(), "{s32[7,11,13] devices=[2,1,2]2,3,5,7}"); +} + +TEST_F(HloShardingTest, ToStringTupleTest) { + HloSharding sharding = HloSharding::Tuple( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}), + ShapeUtil::MakeShape(U32, {7, 25}), + ShapeUtil::MakeShape(S32, {9, 11})}), + {HloSharding::Replicate(), + HloSharding::Tile(ShapeUtil::MakeShape(U32, {7, 13}), + Array2D({{3, 5}})), + HloSharding::AssignDevice(3)}); + EXPECT_EQ(sharding.ToString(), + "{{replicated}, {u32[7,13] devices=[1,2]3,5}, {maximal device=3}}"); +} + +TEST_F(HloShardingTest, OstreamTest) { + HloSharding sharding = + HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), + Array4D({{{{0, 1}, {2, 3}}}})); + std::ostringstream oss; + oss << sharding; + EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}"); +} + +TEST_F(HloShardingTest, ParseHloString) { + auto check = [](const HloSharding& sharding) { + TF_ASSERT_OK_AND_ASSIGN(auto parsed_sharding, + ParseSharding(sharding.ToString())); + EXPECT_EQ(sharding, parsed_sharding); + }; + check(HloSharding::Replicate()); + check(HloSharding::AssignDevice(2)); + check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), + Array4D({{{{0}, {1}}}}))); + // Empty tuple. One sharding is required for empty tuples, as we need to be + // able to assign sharding to them, even though they have no leaves. + check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), + {HloSharding::Replicate()})); + { + // Non-nested tuple. + auto tuple_shape = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 1, 5, 7}), + ShapeUtil::MakeShape(F32, {3, 5, 7}), + ShapeUtil::MakeShape(F32, {3, 7})}); + check(HloSharding::Tuple( + tuple_shape, {HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), + Array4D({{{{0}, {1}}}})), + HloSharding::Replicate(), HloSharding::AssignDevice(1)})); + } + { + // Nested tuple. + auto tuple_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 1, 5, 7}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}), + ShapeUtil::MakeShape(F32, {3, 7})})}); + std::vector leaf_shardings = { + HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), + Array4D({{{{0}, {1}}}})), + HloSharding::Replicate(), HloSharding::AssignDevice(1)}; + ShapeTree sharding_tree(tuple_shape, HloSharding::Replicate()); + // Assign leaf_shardings to sharding_tree leaves. + auto it = leaf_shardings.begin(); + for (auto& index_to_sharding : sharding_tree.leaves()) { + index_to_sharding.second = *it++; + } + check(HloSharding::Tuple(sharding_tree)); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index f8d98f0678596750bb76462e550085753678e860..be156d765dc10d54eaf301e90883babbc5693e28 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h similarity index 84% rename from tensorflow/compiler/xla/tools/parser/hlo_token.h rename to tensorflow/compiler/xla/service/hlo_token.h index 7928bee5c2097f353b182095a555c334d7b69c95..533429608bc2e13626a3e746fbe465398e1f4bb4 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ b/tensorflow/compiler/xla/service/hlo_token.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ #include @@ -22,9 +22,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { -namespace tools { // Defines different kinds of tokens in a hlo module string. +// +// You shouldn't need to use this directly unless you're using HloLexer +// directly, and you probably don't need to do that. Use hlo_parser instead. enum class TokKind { // Markers kEof, @@ -72,7 +74,6 @@ enum class TokKind { string TokKindToString(TokKind kind); -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 05b7dce3d1ecf935b80ba1cb46ef089b7b3b6f33..7b27dbfec376b8ba16d00285f10e2cc291e07a61 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -29,9 +29,11 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" namespace xla { @@ -69,7 +71,7 @@ std::ostream& operator<<(std::ostream& out, const HloUse& use) { HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, const ShapeIndex& index, bool is_phi) - : id_(id), is_phi_(is_phi) { + : BufferValue(instruction, index, id), is_phi_(is_phi) { // The defining position is always the first element in the positions_ vector. positions_.push_back(HloPosition{instruction, index}); } @@ -90,8 +92,8 @@ string HloValue::ToShortString() const { string index_str = ShapeUtil::IsTuple(defining_instruction()->shape()) ? defining_index().ToString() : ""; - return StrCat(id_, " ", is_phi_ ? "PHI " : "", defining_instruction()->name(), - index_str); + return StrCat(id(), " ", is_phi_ ? "PHI " : "", + defining_instruction()->name(), index_str); } string HloValue::ToString(int indent) const { diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index 2a711e8b42590c29d0aaab95dcf110063ada3182..a1151f65e07dffdcd52f645f61dcc9b4f26459c0 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -16,16 +16,20 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ -#include +#include #include #include +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace xla { @@ -80,30 +84,9 @@ struct HloUse { std::ostream& operator<<(std::ostream& out, const HloUse& use); -// Class describing a value used by the dataflow analysis. XLA arrays are -// trivially a single HloValue. Tuples are made up of more than one HloValue: an -// HloValue for the pointer vector, and an HloValue for each child element. -// -// Every HloValue is defined by a particular instruction and most instructions -// define only a single HloValue. Instructions which define a single HloValue -// include array-shaped instructions such as Add but also includes Tuple-shaped -// instructions such as Tuple. The Tuple instruction defines a single HloValue -// which is a vector of pointers to the values containing the Tuple -// instruction's operands. Though the result of the Tuple instruction includes -// multiple values only the top-level HloValue (the vector of pointers) is -// defined by the Tuple instruction. The values containing the tuple elements -// are defined by earlier instructions, usually the operands of the Tuple -// instruction. -// -// Instructions which construct both the tuple *and* the tuple elements define -// more than one HloValue. This includes (at least) tuple-shaped Constant, -// Parameter, Infeed and While instructions. These tuple-shaped instructions do -// not assemble a tuple from existing HloValues like the Tuple instruction does, -// but rather define all the HloValues in the tuple. -class HloValue { +// HloDataflowAnalysis uses this subclass of BufferValue. +class HloValue : public BufferValue { public: - using Id = int64; - // Predicate comparing HloValues by increasing id, useful for std::sort. static bool IdLessThan(const HloValue* a, const HloValue* b) { return a->id() < b->id(); @@ -120,6 +103,7 @@ class HloValue { // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true). HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index, bool is_phi = false); + ~HloValue() override {} // Sets the positions in the module at which the HloValue appears. Updates // uses. Should be called once and only once. The defining position should not @@ -127,10 +111,6 @@ class HloValue { void SetPositionsAndComputeUses( tensorflow::gtl::ArraySlice positions); - // Return a unique identifier for this HloValue. This value is used for stable - // sorting and iteration - Id id() const { return id_; } - // Returns whether this value is a phi value. bool is_phi() const { return is_phi_; } @@ -142,12 +122,18 @@ class HloValue { return defining_position().instruction; } + HloInstruction* instruction() const override { + return defining_instruction(); + } + // Return the shape index at which this HloValue is defined in the output of // its defining instruction. const ShapeIndex& defining_index() const { return defining_position().index; } + const ShapeIndex& index() const override { return defining_index(); } + // Return the shape of this HloValue. - const Shape& shape() const { return defining_position().shape(); } + const Shape& shape() const override { return defining_position().shape(); } // Return all positions of the HloValue in the module. const std::vector& positions() const { return positions_; } @@ -164,12 +150,11 @@ class HloValue { // Return a single-line string representation of the value. string ToShortString() const; - string ToString(int indent = 0) const; + string ToString(int indent) const; - private: - // Unique identifier for this HloValue. Used for stable sorting and iteration. - const Id id_; + string ToString() const override { return ToString(0); } + private: // Whether this instruction is a phi value. const bool is_phi_; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index b1fd068115e1d104a11d880675ef84e07d6d5602..9034073cc8a82311297ccd087741e6713110a5a7 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" @@ -105,9 +106,7 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } -Status ShapeVerifier::HandleInfeed(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleInfeed(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // Outfeed has a separate shape field for the value which is outfed to the @@ -115,7 +114,7 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // produces no HLO value in the graph. if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { - return InvalidArgument( + return InternalError( "Expected outfeed to have shape compatible with operand's shape %s, " "actual shape is %s:\n%s", ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), @@ -126,12 +125,10 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { } Status ShapeVerifier::HandleHostCompute(HloInstruction*) { - return tensorflow::Status::OK(); + return Status::OK(); } -Status ShapeVerifier::HandleRng(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleRng(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { return CheckShape( @@ -163,7 +160,7 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { @@ -182,7 +179,7 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { operand_shape.dimensions(operand_dimension)) << broadcast->ToString() << " operand shape " << operand_shape; } - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { @@ -190,7 +187,7 @@ Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape())); TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == ShapeUtil::ElementsIn(reshape->operand(0)->shape())); - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { @@ -199,22 +196,18 @@ Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { transpose->operand(0)->shape(), transpose->dimensions())); } -Status ShapeVerifier::HandleParameter(HloInstruction*) { - return tensorflow::Status::OK(); +Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { + return Status::OK(); } -Status ShapeVerifier::HandleFusion(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleCall(HloInstruction* call) { // The shape of kCall should match the shape of the computation it calls. return CheckShape(call, call->to_apply()->ComputeProgramShape().result()); } -Status ShapeVerifier::HandleCustomCall(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleSlice(HloInstruction* slice) { return CheckShape(slice, @@ -383,6 +376,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kConstant: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: + case HloOpcode::kDomain: case HloOpcode::kFusion: case HloOpcode::kGetTupleElement: case HloOpcode::kInfeed: @@ -409,7 +403,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { if (fp_type == PRIMITIVE_TYPE_INVALID) { fp_type = subshape.element_type(); } else if (fp_type != subshape.element_type()) { - return FailedPrecondition( + return InternalError( "Seen floating point types of different precisions in " "%s, but mixed precision is disallowed.", instruction->ToString().c_str()); @@ -432,6 +426,14 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather->gather_dimension_numbers(), gather->gather_window_bounds())); } +Status ShapeVerifier::HandleGenerateToken(HloInstruction* token) { + std::vector operand_shapes; + for (const HloInstruction* operand : token->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape(token, ShapeInference::InferTokenShape(operand_shapes)); +} + Status ShapeVerifier::CheckShape(const HloInstruction* instruction, const Shape& inferred_shape) { // If allow_mixed_precision_ is false, check if there are operands with @@ -489,14 +491,14 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } } if (!compatible) { - return InvalidArgument( + return InternalError( "Expected instruction to have shape compatible with %s, actual " "shape is %s:\n%s", ShapeUtil::HumanString(inferred_shape).c_str(), ShapeUtil::HumanString(instruction->shape()).c_str(), instruction->ToString().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, @@ -540,13 +542,13 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1, const HloInstruction* instr2) { if (instr1->channel_id() != instr2->channel_id()) { - return FailedPrecondition( + return InternalError( "Expected to have the same channel id, actual channel ids are: %s " "(%lld), %s (%lld)", instr1->ToString().c_str(), instr1->channel_id(), instr2->ToString().c_str(), instr2->channel_id()); } - return tensorflow::Status::OK(); + return Status::OK(); } string ComputationsToString( @@ -570,22 +572,22 @@ string ComputationsToString( Status VerifyHloStructure(HloModule* module) { for (const HloComputation* computation : module->computations()) { if (computation->parent() == nullptr) { - return FailedPrecondition("Computation %s has a null parent pointer", - computation->name().c_str()); + return InternalError("Computation %s has a null parent pointer", + computation->name().c_str()); } if (computation->parent() != module) { - return FailedPrecondition( + return InternalError( "Computation %s parent() does not point to parent module", computation->name().c_str()); } for (const HloInstruction* instruction : computation->instructions()) { if (instruction->parent() == nullptr) { - return FailedPrecondition("Instruction %s has a null parent pointer", - instruction->name().c_str()); + return InternalError("Instruction %s has a null parent pointer", + instruction->name().c_str()); } if (instruction->parent() != computation) { - return FailedPrecondition( + return InternalError( "Instruction %s parent() does not point to parent computation", instruction->name().c_str()); } @@ -601,7 +603,7 @@ Status VerifyHloStructure(HloModule* module) { for (int i = 0; i < instruction->operand_count(); ++i) { const HloInstruction* operand = instruction->operand(i); if (operand->parent() != instruction->parent()) { - return FailedPrecondition( + return InternalError( "Operand %d (%s) of instruction %s is in a different " "computation: %s vs %s", i, operand->name().c_str(), instruction->name().c_str(), @@ -611,14 +613,14 @@ Status VerifyHloStructure(HloModule* module) { } } } - return tensorflow::Status::OK(); + return Status::OK(); } Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { // The parent fusion instruction of the fusion computation must be 'fusion'. HloComputation* fused_computation = fusion->fused_instructions_computation(); if (fusion != fused_computation->FusionInstruction()) { - return FailedPrecondition( + return InternalError( "Instruction of fused computation does not match expected instruction " "%s.", fusion->ToString().c_str()); @@ -634,37 +636,37 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto* instruction : fused_computation->instructions()) { if (fused_root == instruction) { if (root_owned) { - return FailedPrecondition("Root appears more than once in %s.", - fusion->ToString().c_str()); + return InternalError("Root appears more than once in %s.", + fusion->ToString().c_str()); } root_owned = true; } for (int i = 0; i < fused_parameters.size(); ++i) { if (fused_parameters[i] == instruction) { if (parameter_owned[i]) { - return FailedPrecondition("Parameter appears more than once in %s.", - fusion->ToString().c_str()); + return InternalError("Parameter appears more than once in %s.", + fusion->ToString().c_str()); } parameter_owned[i] = true; } } } if (!root_owned) { - return FailedPrecondition("Root not found in computation of %s.", - fusion->ToString().c_str()); + return InternalError("Root not found in computation of %s.", + fusion->ToString().c_str()); } // Make sure all the parameter_owned entries are set for (int i = 0; i < parameter_owned.size(); i++) { if (!parameter_owned[i]) { - return FailedPrecondition("Parameter %d not found in computation of %s.", - i, fusion->ToString().c_str()); + return InternalError("Parameter %d not found in computation of %s.", i, + fusion->ToString().c_str()); } } // Fused root must have no users. if (fused_root->user_count() != 0) { - return FailedPrecondition("Root of %s may not have users.", - fusion->ToString().c_str()); + return InternalError("Root of %s may not have users.", + fusion->ToString().c_str()); } // All uses of fused instructions must be in the fusion computation, and every @@ -673,13 +675,13 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { fusion->fused_instructions_computation()->instructions()) { if (instruction != fused_root) { if (instruction->user_count() == 0) { - return FailedPrecondition( - "Non-root instruction %s in %s must have users.", - instruction->ToString().c_str(), fusion->ToString().c_str()); + return InternalError("Non-root instruction %s in %s must have users.", + instruction->ToString().c_str(), + fusion->ToString().c_str()); } for (auto& user : instruction->users()) { if (fused_computation != user->parent()) { - return FailedPrecondition( + return InternalError( "Non-root instruction %s in %s may not have external users.", instruction->ToString().c_str(), fusion->ToString().c_str()); } @@ -694,43 +696,149 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); if (param_no < 0) { - return FailedPrecondition( - "Unexpected negative parameter number %lld in %s.", param_no, - fusion->ToString().c_str()); + return InternalError("Unexpected negative parameter number %lld in %s.", + param_no, fusion->ToString().c_str()); } if (param_no >= fused_parameters.size()) { - return FailedPrecondition( + return InternalError( "Unexpected parameter number %lld in %s: higher then number of " "parameters %lu.", param_no, fusion->ToString().c_str(), fused_parameters.size()); } if (parameter_numbers[param_no]) { - return FailedPrecondition( + return InternalError( "Did not expect parameter number %lld more than once in %s.", param_no, fusion->ToString().c_str()); } parameter_numbers[param_no] = true; if (!ShapeUtil::Compatible(fused_param->shape(), fusion->operand(param_no)->shape())) { - return FailedPrecondition( + return InternalError( "Shape mismatch between parameter number %lld and its operand in %s.", param_no, fusion->ToString().c_str()); } } - // Make sure all the parameter_numbers entries were seen + // Make sure all the parameter_numbers entries were seen. for (int i = 0; i < parameter_numbers.size(); i++) { if (!parameter_numbers[i]) { - return FailedPrecondition("Did not see parameter number %d in %s.", i, - fusion->ToString().c_str()); + return InternalError("Did not see parameter number %d in %s.", i, + fusion->ToString().c_str()); } } // TODO(b/65423525): We'd like to check that all operands are distinct. // This is currently disabled due to the invariant being violated by // multi-output fusion. - return tensorflow::Status::OK(); + return Status::OK(); +} + +Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { + auto* while_cond = instruction->while_condition(); + auto* while_body = instruction->while_body(); + if (while_cond->num_parameters() != 1) { + return FailedPrecondition( + "While condition must have exactly 1 parameter; had %lld : %s", + while_cond->num_parameters(), while_cond->ToString().c_str()); + } + if (while_body->num_parameters() != 1) { + return FailedPrecondition( + "While body must have exactly 1 parameter; had %lld : %s", + while_body->num_parameters(), while_body->ToString().c_str()); + } + if (instruction->operand_count() != 1) { + return FailedPrecondition( + "While loop must have exactly one operand; had %lld : %s", + instruction->operand_count(), instruction->ToString().c_str()); + } + auto* init = instruction->operand(0); + auto* cond_param = while_cond->parameter_instruction(0); + if (!ShapeUtil::Compatible(init->shape(), cond_param->shape())) { + return FailedPrecondition( + "While condition's parameter must have the same shape as the " + "loop's 'init'. init: %s, param: %s", + init->ToString().c_str(), cond_param->ToString().c_str()); + } + auto* cond_root = while_cond->root_instruction(); + if (!ShapeUtil::Compatible(cond_root->shape(), + ShapeUtil::MakeShape(PRED, {}))) { + return FailedPrecondition("While condition should have shape PRED: %s", + cond_root->ToString().c_str()); + } + auto* body_param = while_body->parameter_instruction(0); + if (!ShapeUtil::Compatible(init->shape(), body_param->shape())) { + return FailedPrecondition( + "While body's parameter must have the same shape as the loop's" + " 'init'. init: %s, param: %s", + init->ToString().c_str(), body_param->ToString().c_str()); + } + auto* body_root = while_body->root_instruction(); + if (!ShapeUtil::Compatible(init->shape(), body_root->shape())) { + return FailedPrecondition( + "While body should have same shape as the loop's 'init'." + "init: %s, body: %s", + init->ToString().c_str(), body_root->ToString().c_str()); + } + return Status::OK(); +} + +Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { + const Shape& out_shape = instruction->shape(); + for (HloInstruction* operand : instruction->operands()) { + const Shape& operand_shape = operand->shape(); + if (!ShapeUtil::IsScalar(operand_shape) && + !ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) { + return FailedPrecondition( + "Implicit broadcast is not allowed in HLO." + "Found non-compatible shapes for instruction %s.\n" + "output: %s\noperand: %s\n", + HloOpcodeString(instruction->opcode()).c_str(), + ShapeUtil::HumanString(out_shape).c_str(), + ShapeUtil::HumanString(operand_shape).c_str()); + } + } + return Status::OK(); } +namespace { + +// Returns true if the given Shape has a TOKEN shape as any subshape. +bool ShapeContainsToken(const Shape& shape) { + bool contains_token = false; + ShapeUtil::ForEachSubshape( + shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsToken(subshape)) { + contains_token = true; + } + }); + return contains_token; +} + +// Verifies that all types entering and exiting the entry computation are +// legal. For example, TOKEN types have no Literal representation and cannot be +// on the interface of the entry computation (parameters and root instruction). +Status VerifyEntryAndExitShapes(const HloModule& module) { + for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { + HloInstruction* param = + module.entry_computation()->parameter_instruction(i); + if (ShapeContainsToken(param->shape())) { + return InternalError( + "Entry parameter %d is or contains a token shape: %s", i, + ShapeUtil::HumanString(param->shape()).c_str()); + } + } + if (ShapeContainsToken( + module.entry_computation()->root_instruction()->shape())) { + return InternalError( + "Entry root is or contains a token shape: %s", + ShapeUtil::HumanString( + module.entry_computation()->root_instruction()->shape()) + .c_str()); + } + return Status::OK(); +} + +} // namespace + StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyHloStructure(module)); @@ -762,45 +870,18 @@ StatusOr HloVerifier::Run(HloModule* module) { } else if (instruction->opcode() == HloOpcode::kBroadcast) { // If you see this failure then someone has confused the difference // between the HLO broadcast op, and the UserComputation broadcast - // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I + // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I // or ComputationLowerer::Visit() TF_RET_CHECK(instruction->dimensions().size() == ShapeUtil::Rank(instruction->operand(0)->shape())) - << "Broadcast HLO has invalid number of dimensions."; + << "Broadcast HLO (" << instruction->ToShortString() + << ") has invalid number of dimensions: " + << instruction->dimensions().size() + << " != " << ShapeUtil::Rank(instruction->operand(0)->shape()); } else if (instruction->opcode() == HloOpcode::kWhile) { - auto* while_cond = instruction->while_condition(); - auto* while_body = instruction->while_body(); - TF_RET_CHECK(while_cond->num_parameters() == 1) - << "While condition must have exactly 1 parameter; had " - << while_cond->num_parameters() << ": " << while_cond->ToString(); - TF_RET_CHECK(while_body->num_parameters() == 1) - << "While body must have exactly 1 parameter; had " - << while_body->num_parameters() << ": " << while_body->ToString(); - TF_RET_CHECK(instruction->operand_count() == 1) - << "While loop must have exactly one operand; had " - << instruction->operand_count() << ": " << instruction->ToString(); - - auto* init = instruction->operand(0); - auto* cond_param = while_cond->parameter_instruction(0); - TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), cond_param->shape())) - << "While condition's parameter must have the same shape as the " - "loop's 'init'. init: " - << init->ToString() << ", param: " << cond_param->ToString(); - auto* cond_root = while_cond->root_instruction(); - TF_RET_CHECK(ShapeUtil::Compatible(cond_root->shape(), - ShapeUtil::MakeShape(PRED, {}))) - << "While condition should have shape PRED: " - << cond_root->ToString(); - - auto* body_param = while_body->parameter_instruction(0); - TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), body_param->shape())) - << "While body's parameter must have the same shape as the loop's " - "'init'. init: " - << init->ToString() << ", param: " << body_param->ToString(); - auto* body_root = while_body->root_instruction(); - TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), body_root->shape())) - << "While body should have same shape as the loop's 'init'. init: " - << init->ToString() << ", body: " << body_root->ToString(); + TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction)); + } else if (instruction->IsElementwise()) { + TF_RETURN_IF_ERROR(CheckElementwiseInstruction(instruction)); } auto previous = instructions.find(instruction->name()); @@ -818,6 +899,8 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); } + TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 1dd7ec3c51e18dcfe89bd478de87798ba3858119..7283b3e7dcdbed5be18a1da1571287cf0c089288 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -81,10 +81,9 @@ class ShapeVerifier : public DfsHloVisitor { HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; + Status HandleGenerateToken(HloInstruction* token) override; - Status FinishVisit(HloInstruction*) override { - return tensorflow::Status::OK(); - } + Status FinishVisit(HloInstruction*) override { return Status::OK(); } protected: // Check the instruction's shape against the shape given by ShapeInference @@ -102,7 +101,7 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckTernaryShape(const HloInstruction* instruction); Status CheckVariadicShape(const HloInstruction* instruction); - // Checks if the given two instructions shares the same channel id. + // Checks if the given two instructions share the same channel id. Status CheckSameChannel(const HloInstruction* instr1, const HloInstruction* instr2); @@ -144,9 +143,15 @@ class HloVerifier : public HloPassInterface { // CHECKs various invariants of a fusion instruction. Status CheckFusionInstruction(HloInstruction* fusion) const; + Status CheckWhileInstruction(HloInstruction* instruction); + + // Checks that the non-scalar operand shapes are compatible to the output + // shape, i.e., that there are no implicit broadcasts of size-one dimensions. + Status CheckElementwiseInstruction(HloInstruction* instruction); + // Creates a ShapeVerifier that checks that shapes match inferred - // expectations. This is a factory function because ShapeVerifier, Note that - // ShapeVerifier, being a DfsHloVisitor, is stateful. We want a clean object + // expectations. This is a factory function because ShapeVerifier, + // being a DfsHloVisitor, is stateful. We want a clean object // for each run of the verifier. ShapeVerifierFactory shape_verifier_factory_; }; diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index 13e4557317f74b3fb46f07fb91c339fd2f34752f..d7458c338e9f1df9fac90270845aae0b8f779ee2 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -27,6 +27,7 @@ using tensorflow::strings::HumanReadableElapsedTime; using tensorflow::strings::HumanReadableNumBytes; using tensorflow::strings::Printf; using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; string HumanReadableProfileBuilder::ToString() const { string s; @@ -35,20 +36,26 @@ string HumanReadableProfileBuilder::ToString() const { computation_name_.c_str(), HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str()); - auto append_op = [&](const OpInfo& op) { + auto print_op = [&](const OpInfo& op) { + // Skip ops with 0 optimal seconds and 0 actual cycles. These are ops that + // were expected to be free and are actually free -- things like (on most + // backends) kParameter or kConstant HLOs. There's no need to clutter the + // profile with these. + if (op.optimal_seconds == 0 && op.cycles == 0) { + return; + } + string bytes_per_sec; string bytes_per_cycle; - if (op.cycles <= 0 || op.bytes_accessed < 0) { - bytes_per_sec = ""; - bytes_per_cycle = ""; - } else { - bytes_per_sec = - HumanReadableNumBytes(op.bytes_accessed / CyclesToSeconds(op.cycles)); + if (op.cycles > 0 && op.bytes_accessed >= 0) { + bytes_per_sec = StrCat( + HumanReadableNumBytes(op.bytes_accessed / CyclesToSeconds(op.cycles)), + "/s"); + double bpc = static_cast(op.bytes_accessed) / op.cycles; if (op.bytes_accessed > op.cycles) { - bytes_per_cycle = HumanReadableNumBytes(op.bytes_accessed / op.cycles); + bytes_per_cycle = StrCat(HumanReadableNumBytes(bpc), "/cycle"); } else { - bytes_per_cycle = - Printf("%.3fB", static_cast(op.bytes_accessed) / op.cycles); + bytes_per_cycle = Printf("%.3fB/cycle", bpc); } } @@ -59,14 +66,16 @@ string HumanReadableProfileBuilder::ToString() const { double nsecs = op.cycles / clock_rate_ghz_; Appendf(&s, - "%15lld cycles (%6.2f%%) :: %12.1f usec (%12.1f optimal) :: %18s " - ":: %18s :: %12s/s :: %12s/cycle :: %s\n", + "%15lld cycles (%6.2f%%) :: %12.1f usec %22s :: %18s " + ":: %18s :: %14s :: %16s :: %s\n", op.cycles, cycles_percent, CyclesToMicroseconds(op.cycles), - op.optimal_seconds * 1e6, + op.optimal_seconds < 0 + ? "" + : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), op.flop_count <= 0 - ? "" + ? "" : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), - op.transcendental_count <= 0 ? "" + op.transcendental_count <= 0 ? "" : HumanReadableNumTranscendentalOps( op.transcendental_count, nsecs) .c_str(), @@ -78,24 +87,26 @@ string HumanReadableProfileBuilder::ToString() const { int64 total_transcendentals = 0.; int64 total_bytes = 0; for (const auto& op : op_infos_) { - optimal_seconds_sum += op.optimal_seconds; - total_flops += op.flop_count; - total_transcendentals += op.transcendental_count; - total_bytes += op.bytes_accessed; + if (op.optimal_seconds > 0) { + optimal_seconds_sum += op.optimal_seconds; + } + total_flops += std::max(op.flop_count, int64{0}); + total_transcendentals += std::max(op.transcendental_count, int64{0}); + total_bytes += std::max(op.bytes_accessed, int64{0}); } VLOG(1) << "Total floating point ops: " << total_flops; - append_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, - total_transcendentals, total_bytes, optimal_seconds_sum}); + print_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, + total_transcendentals, total_bytes, optimal_seconds_sum}); - // Sort ops in decreasing order of cycles. + // Sort ops in decreasing order of cycles, and print them. std::vector sorted_ops(op_infos_); std::sort( sorted_ops.begin(), sorted_ops.end(), [](const OpInfo& a, const OpInfo& b) { return a.cycles > b.cycles; }); for (const auto& op : sorted_ops) { - append_op(op); + print_op(op); } if (total_cycles_ <= 0) { @@ -109,8 +120,20 @@ string HumanReadableProfileBuilder::ToString() const { table.SetMetricName("microseconds above estimated optimum"); table.SetEntryName("ops"); table.SetShowCategoryTable(); + table.SetShowAllEntries(); float total_discrepancy_in_microseconds = 0.0f; - for (const auto& op : sorted_ops) { + for (const auto& op : op_infos_) { + // Skip ops with < 0 optimal seconds. These are ops for which we don't + // know the optimal time. + if (op.optimal_seconds < 0) { + continue; + } + // Also skip ops with 0 actual cycles. These ops were free; there's no + // need to clutter the "above estimated optimum" table with them, + // because they can't be optimized further. + if (op.cycles == 0) { + continue; + } MetricTableReport::Entry entry; entry.text = op.name; entry.short_text = op.short_name; @@ -128,7 +151,14 @@ string HumanReadableProfileBuilder::ToString() const { table.SetMetricName("microseconds"); table.SetEntryName("ops"); table.SetShowCategoryTable(); - for (const auto& op : sorted_ops) { + table.SetShowAllEntries(); + for (const auto& op : op_infos_) { + // Skip ops with 0 optimal seconds and 0 actual cycles. As in + // print_op(), these are uninteresting because they're expected to be + // free, and they were actually free. + if (op.cycles == 0 && op.optimal_seconds == 0) { + continue; + } MetricTableReport::Entry entry; entry.text = op.name; entry.short_text = op.short_name; @@ -139,6 +169,23 @@ string HumanReadableProfileBuilder::ToString() const { StrAppend(&s, table.MakeReport(CyclesToMicroseconds(total_cycles_))); } } + + if (total_bytes > 0) { + MetricTableReport table; + table.SetMetricName("MiB read+written"); + table.SetEntryName("ops"); + table.SetShowCategoryTable(); + for (const auto& op : op_infos_) { + MetricTableReport::Entry entry; + entry.text = op.name; + entry.short_text = op.short_name; + entry.category_text = op.category; + entry.metric = static_cast(op.bytes_accessed) / (1 << 20); + table.AddEntry(std::move(entry)); + } + StrAppend(&s, + table.MakeReport(static_cast(total_bytes) / (1 << 20))); + } return s; } diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h index fc24acd2713f4cd8af2816ffdf085e84a4920cbc..6f56c3aa82e9d1c942fd67ff7a5948cf2e54370d 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h @@ -32,7 +32,7 @@ class HumanReadableProfileBuilder { explicit HumanReadableProfileBuilder(tensorflow::StringPiece computation_name, int64 total_cycles, double clock_rate_ghz) - : computation_name_(computation_name.ToString()), + : computation_name_(std::string(computation_name)), total_cycles_(total_cycles), clock_rate_ghz_(clock_rate_ghz) { CHECK_GE(clock_rate_ghz, 1e-9); @@ -41,15 +41,17 @@ class HumanReadableProfileBuilder { int64 total_cycles() const { return total_cycles_; } // Adds an operation to the profile. If you don't know the number of - // floating-point ops or bytes touched by the op, pass -1 for that param. + // floating-point ops or bytes touched by the op, or if you don't know how + // fast it would run optimally, pass -1 for that param. void AddOp(tensorflow::StringPiece op_name, tensorflow::StringPiece short_name, tensorflow::StringPiece category, int64 cycles, int64 flop_count, int64 transcendental_count, int64 bytes_accessed, float optimal_seconds) { - op_infos_.push_back( - {op_name.ToString(), short_name.ToString(), category.ToString(), cycles, - flop_count, transcendental_count, bytes_accessed, optimal_seconds}); + op_infos_.push_back({std::string(op_name), std::string(short_name), + std::string(category), cycles, flop_count, + transcendental_count, bytes_accessed, + optimal_seconds}); } // Gets the human-readable profile. @@ -61,10 +63,10 @@ class HumanReadableProfileBuilder { string short_name; string category; int64 cycles; - int64 flop_count; + int64 flop_count; // -1 if unknown int64 transcendental_count; - int64 bytes_accessed; - float optimal_seconds; + int64 bytes_accessed; // -1 if unknown + float optimal_seconds; // -1 if unknown }; double CyclesToSeconds(int64 cycles) const { diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..8b3fa6c1572cf0ed91fc427722edcb23d8b8529d --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -0,0 +1,733 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace gtl = ::tensorflow::gtl; + +namespace { +using Analysis = IndexedArrayAnalysis; +using UnknownArray = Analysis::UnknownArray; +using ConstantArray = Analysis::ConstantArray; +using ScalarIndexedArray = Analysis::ScalarIndexedArray; +using tensorflow::gtl::ArraySlice; +using tensorflow::str_util::Join; +} // namespace + +string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { + switch (root->kind()) { + case Array::kUnknown: { + auto* unknown_tensor = root->as(); + return tensorflow::strings::StrCat("%", + unknown_tensor->instruction().name()); + } + + case Array::kConstant: { + if (print_constants) { + string contents = root->as()->literal()->ToString(); + return tensorflow::strings::StrCat( + "(constant ", ShapeUtil::HumanString(root->shape()), " ", contents, + ")"); + } + return tensorflow::strings::StrCat( + "(constant ", ShapeUtil::HumanString(root->shape()), ")"); + } + + case Array::kScalarIndexedConstant: + case Array::kScalarIndexed: { + auto* indexed_array = root->as(); + string name = root->kind() == Array::kScalarIndexedConstant + ? "scalar-indexed-const" + : "scalar-indexed"; + return tensorflow::strings::StrCat( + "(", name, " ", ToString(indexed_array->source(), print_constants), + " ", ToString(indexed_array->indices(), print_constants), " ", + indexed_array->source_dim(), "->[", + Join(indexed_array->output_dims(), ","), "])"); + } + } +} + +StatusOr IndexedArrayAnalysis::GetArrayFor( + const HloInstruction* instr) { + auto it = cache_.find(instr); + if (it != cache_.end()) { + return it->second; + } + + TF_RETURN_IF_ERROR(TraverseAndPopulateCache(instr)); + return FindOrDie(cache_, instr); +} + +Status IndexedArrayAnalysis::TraverseAndPopulateCache( + const HloInstruction* root) { + // Depth first search over the DAG, invoking ComputeArrayFor in post order. + // The HLO instructions already in the cache are considered leaves. + + gtl::InlinedVector stack; + + enum DfsState { kDiscovered, kVisited }; + gtl::FlatMap dfs_state_map; + + stack.push_back(root); + InsertOrDie(&dfs_state_map, root, kDiscovered); + + do { + const HloInstruction* instr = stack.back(); + if (cache_.count(instr)) { + stack.pop_back(); + continue; + } + + switch (FindOrDie(dfs_state_map, instr)) { + case kDiscovered: { + for (const HloInstruction* operand : instr->operands()) { + if (!cache_.count(operand)) { + stack.push_back(operand); + CHECK(!dfs_state_map.count(operand) || + dfs_state_map[operand] == kDiscovered); + dfs_state_map[operand] = kDiscovered; + } + } + dfs_state_map[instr] = kVisited; + break; + } + + case kVisited: + stack.pop_back(); + TF_ASSIGN_OR_RETURN(Array * array, ComputeArrayFor(instr)); + InsertOrDie(&cache_, instr, array); + break; + } + } while (!stack.empty()); + + return Status::OK(); +} + +StatusOr IndexedArrayAnalysis::ComputeArrayFor( + const HloInstruction* instr) { + Array* computed_array; + if (instr->IsElementwise() && instr->operand_count() == 1) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForElementwiseUnaryOp( + instr->opcode(), FindOrDie(cache_, instr->operand(0)))); + } else if (instr->IsElementwise() && instr->operand_count() == 2) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForElementwiseBinaryOp( + instr->opcode(), FindOrDie(cache_, instr->operand(0)), + FindOrDie(cache_, instr->operand(1)))); + } else if (instr->opcode() == HloOpcode::kConstant) { + TF_ASSIGN_OR_RETURN(computed_array, + ComputeArrayForConstant(instr->literal())); + } else if (instr->opcode() == HloOpcode::kGather) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(), + instr->gather_window_bounds(), + FindOrDie(cache_, instr->operand(0)), + FindOrDie(cache_, instr->operand(1)))); + } else if (instr->opcode() == HloOpcode::kReshape) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForReshape(instr->shape(), + FindOrDie(cache_, instr->operand(0)))); + } else { + computed_array = nullptr; + } + + if (!computed_array) { + computed_array = Construct(instr); + } + + return computed_array; +} + +StatusOr IndexedArrayAnalysis::ComputeArrayForConstant( + const Literal& literal) { + return Construct(&literal); +} + +StatusOr IndexedArrayAnalysis::FoldGatherOfGather( + ScalarIndexedArray* source, Array* indices, int64 source_dim, + tensorflow::gtl::ArraySlice output_dims, Shape shape) { + // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)). + // `source` is the inner Gather(A, X). + + Array* a = source->source(); + Array* x = source->indices(); + Array* y = indices; + + // This bit is slightly tricky, so we do a naive "simulation" of the two + // consecutive gather operations to infer what the composed gather should look + // like. + + enum class IndexComponent { Ungathered, GatheredFirst, GatheredSecond }; + + std::vector simulated_index(a->shape().dimensions_size(), + IndexComponent::Ungathered); + + // Simulate the first gather. + EraseAt(&simulated_index, source->source_dim()); + for (int64 gather_dim : source->output_dims()) { + simulated_index.insert(simulated_index.begin() + gather_dim, + IndexComponent::GatheredFirst); + } + + // Simulate the second gather. + EraseAt(&simulated_index, source_dim); + for (int64 output_dim : output_dims) { + simulated_index.insert(simulated_index.begin() + output_dim, + IndexComponent::GatheredSecond); + } + + int64 source_dim_for_index_array = + FindIndex(source->output_dims(), source_dim); + CHECK_NE(source_dim_for_index_array, source->output_dims().size()); + + std::vector output_dims_for_index_array; + int64 gathered_index_components_seen = 0; + for (IndexComponent simulation_dim : simulated_index) { + if (simulation_dim == IndexComponent::GatheredSecond) { + output_dims_for_index_array.push_back(gathered_index_components_seen); + } + if (simulation_dim != IndexComponent::Ungathered) { + gathered_index_components_seen++; + } + } + + std::vector dim_sizes_for_composed_index; + std::vector output_dims_for_new_gather; + for (int64 i = 0, e = simulated_index.size(); i < e; i++) { + if (simulated_index[i] != IndexComponent::Ungathered) { + dim_sizes_for_composed_index.push_back(shape.dimensions(i)); + output_dims_for_new_gather.push_back(i); + } + } + + Array* inner_indices = ConstructScalarIndexedArray( + x, y, source_dim_for_index_array, output_dims_for_index_array, + ShapeUtil::MakeShape(x->shape().element_type(), + dim_sizes_for_composed_index)); + return ConstructScalarIndexedArray(a, inner_indices, source->source_dim(), + output_dims_for_new_gather, + std::move(shape)); +} + +StatusOr IndexedArrayAnalysis::ComputeArrayForGather( + const Shape& shape, const GatherDimensionNumbers& dim_numbers, + tensorflow::gtl::ArraySlice window_bounds, Array* source, + Array* indices) { + if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { + return nullptr; + } + + CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1); + if (!c_binary_search(dim_numbers.elided_window_dims(), + dim_numbers.gather_dims_to_operand_dims(0))) { + return nullptr; + } + + int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0); + std::vector output_dims; + for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { + if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + output_dims.push_back(i); + } + } + + if (auto* indexed = dynamic_cast(source)) { + auto it = c_find(indexed->output_dims(), source_dim); + if (it != indexed->output_dims().end()) { + return FoldGatherOfGather(indexed, indices, source_dim, output_dims, + shape); + } + } else if (auto* constant = dynamic_cast(source)) { + return Construct(constant, indices, source_dim, + output_dims, shape); + } + + return Construct(source, indices, source_dim, output_dims, + shape); +} + +namespace { +// Returns an index into `values` such that the product of the range +// [values.begin()+index, values.end()) is equal to `product`. If there is no +// such index, return -1. All integers in `values` must be positive. +int64 FindSuffixWithProduct(ArraySlice values, int64 product) { + DCHECK(c_all_of(values, [](int64 value) { return value > 0; })); + + int64 current_product = 1; + int64 i; + for (i = values.size() - 1; i >= 0 && product > current_product; --i) { + current_product *= values[i]; + } + + if (product == current_product) { + return i + 1; + } + + return -1; +} + +struct ReshapePassthroughDimPair { + int64 result_dim; + int64 operand_dim; +}; + +// Returns a set of dimension pairs such for all (result_dim, operand_dim) in +// the set: +// +// output_index[result_dim] = SourceIndexOfReshape(output_index)[operand_dim] +// +// The returned vector of pairs is sorted in both the result_dim and the +// operand_dim components. +std::vector ComputeReshapePassthroughDimPairs( + ArraySlice operand_shape, ArraySlice result_shape) { + // A reshape can be seen as an index mapping from output index to input index: + // + // (i_0, ..., i_n) = f(o_0, ..., o_m) + // + // This function returns the pairs (j, k) for which the following invariant + // holds for all indices in the shape: + // + // o_j == i_k + // + // And this occurs when: + // + // O_{j+1} * ... * O_n == I_{k+1} * ... * I_m + // + // (where O_x are the sizes of the output shape and I_x are the sizes of the + // input shape) and the size of the dimension j of the result is the same as + // the size of dimension k in the operand. + // + // These conditions are sufficient because the Reshape HLO is spec'ed such + // that the rightmost dimensions are always minor in the flattening and refine + // operation. + + std::vector result; + int64 result_subarray_size = 1; + for (int64 result_dim = result_shape.size() - 1; result_dim >= 0; + --result_dim) { + int64 candidate_operand_dim = + FindSuffixWithProduct(operand_shape, result_subarray_size); + + // result_subarray_size does not include the elements in the current + // `result_dim` dimension (we multiply in result_shape[result_dim] at the + // end of loop body) so candidate_operand_dim can never be zero. + CHECK_NE(candidate_operand_dim, 0); + + if (candidate_operand_dim != -1 && + result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) { + result.push_back({/*result_dim=*/result_dim, + /*operand_dim=*/candidate_operand_dim - 1}); + } + result_subarray_size *= result_shape[result_dim]; + } + + c_reverse(result); + + if (VLOG_IS_ON(3)) { + std::vector result_strings; + c_transform(result, std::back_inserter(result_strings), + [](ReshapePassthroughDimPair value) { + return tensorflow::strings::StrCat(value.result_dim, "->", + value.operand_dim); + }); + VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to [" + << Join(result_shape, ",") << "] passthrough indices are [" + << Join(result_strings, ",") << "]"; + } + + DCHECK(c_is_sorted( + result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { + return lhs.result_dim < rhs.result_dim; + })); + + DCHECK(c_is_sorted( + result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { + return lhs.operand_dim < rhs.operand_dim; + })); + + return result; +} + +// Return true if `dim` is stated as an passthrough operand dim in +// `passthrough_dims`. +bool IsReshapePassthroughOperandDim( + ArraySlice passthrough_dims, int64 dim) { + return c_any_of(passthrough_dims, + [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == dim; + }); +} + +// Maps `operand_dim` which must be an passthrough operand dimension to its +// corresponding passthrough result dimension based on `passthrough_dims`. +int64 MapPassthroughOperandDimToResultDim( + ArraySlice passthrough_dims, int64 operand_dim) { + auto it = c_find_if(passthrough_dims, + [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == operand_dim; + }); + CHECK(it != passthrough_dims.end()); + return it->result_dim; +} + +int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, + ArraySlice result_shape, + int64 source_passthrough_dim) { + int64 indexed_source_subarray_size = + std::accumulate(operand_shape.begin() + source_passthrough_dim + 1, + operand_shape.end(), 1, std::multiplies()); + + return FindSuffixWithProduct(result_shape, indexed_source_subarray_size); +} + +}; // namespace + +StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( + const Shape& shape, Array* operand) { + auto* scalar_indexed = dynamic_cast(operand); + if (!scalar_indexed) { + return nullptr; + } + + // Try to fold Reshape(ScalarIndexed(Const, Indices)) + // => ScalarIndexed(Const', Indices) + // + // We can view the reshape and the scalar-indexed operations as functions that + // map an output index (i.e. an index into the result) to an input index + // (i.e. an index into the operand). The key idea used here is that the + // output-to-input mapping for some reshape operations may "pass through" some + // output dimensions into the input space unchanged -- i.e. there may exist + // output dimension "O" and input dimension "I" such that OutputIndex[O] is + // always == InputIndexForReshape(OutputIndex)[I]. If these pass-through + // dimensions in the input space of the reshape happen to be include all the + // output dimensions for the scalar-indexed node then, roughly, the following + // holds: + // + // SourceIndexOfScalarIndexed(SourceIndexOfReshape(Idx)) + // == SourceIndexOfScalarIndexed(SourceIndexOfReshape(Ps ++ Qs)) + // + // Where Ps are the set of the pass-through components of Idx that are + // also the output dims of the scalar-indexed node, and Qs are the rest. + // For brevity, we're playing fast and loose with the notation here -- we + // don't literally require Idx to be a concatenation of Ps and Qs, as + // suggested by the "++". + // + // == SourceIndexOfScalarIndexed(Ps ++ SourceIndexOfReshape(Qs)) + // + // Again, we're playing fast and loose with the notation around "++". + // Generally this ++ will be a different function that the ++ in the + // previous step. + // + // If the scalar-indexed node has a constant as the source then the + // SourceIndexOfReshape function can be "folded into" the constant itself by + // reshaping it, leaving us with: + // + // == SourceIndexOfScalarIndexed(Ps ++ Qs) + // == SourceIndexOfScalarIndexed(Idx) + // + // which is just a scalar-indexed node (with parameters different from the + // scalar-indexed node we started with) with a reshaped constant as the + // source. + // + // We can't fold SourceIndexOfReshape into the constant without introducing + // another precondition: since the new scalar-indexed node will have a + // reshaped (constant) array as its source it will, in general, have a + // different source dimension than the original scalar-indexed node. This + // source dimension will have to be a passthrough dimension of the + // SourceIndexOfReshape indexing function that is folded into the source. And + // such a dimension need not exist so this is a non-trivial precondition. + + std::vector reshape_passthrough_dims = + ComputeReshapePassthroughDimPairs( + /*operand_shape=*/AsInt64Slice(operand->shape().dimensions()), + /*result_shape=*/AsInt64Slice(shape.dimensions())); + + auto is_reshape_passthrough_operand_dim = [&](int64 operand_dim) { + return IsReshapePassthroughOperandDim(reshape_passthrough_dims, + operand_dim); + }; + + if (!c_all_of(scalar_indexed->output_dims(), + is_reshape_passthrough_operand_dim)) { + return nullptr; + } + + // To compute the shape of the source for the new scalar-indexed node we're + // going to create, we first "undo" the scalar-indexed operation. + std::vector new_scalar_indexed_source_shape(shape.dimensions().begin(), + shape.dimensions().end()); + for (int64 i = scalar_indexed->output_dims().size() - 1; i >= 0; i--) { + int64 output_dim = scalar_indexed->output_dims()[i]; + int64 output_dim_after_reshape = MapPassthroughOperandDimToResultDim( + reshape_passthrough_dims, output_dim); + EraseAt(&new_scalar_indexed_source_shape, output_dim_after_reshape); + } + + // After this, we need to add in the dimension that will be the source + // dimension for the new scalar-indexed node. A scalar-indexed node "removes" + // the source dimensions and "adds" the output dimensions, so to get back to + // the shape for the *source* of the scalar-indexed node we need to remove the + // output dims (which we did above) and then add back the source dim (which we + // are about to do below): + + const Shape& scalar_indexed_source_shape = scalar_indexed->source()->shape(); + + int64 source_dim_for_new_scalar_indexed_node = + FindSourcePositionForPassthroughResultDim( + /*operand_shape=*/AsInt64Slice( + scalar_indexed_source_shape.dimensions()), + /*result_shape=*/new_scalar_indexed_source_shape, + scalar_indexed->source_dim()); + + // We may not be able to find a source dim for the new scalar-indexed node. + // For instance consider: + // + // operand = s32[3,5,2] constant({...}) + // indices = s32[7] parameter(0) + // gather = s32[3,2,7] gather(operand, indices), + // output_window_dims={0,1}, + // elided_window_dims={1}, + // gather_dims_to_operand_dims={1}, + // index_vector_dim=1, + // window_bounds={3,1,2} + // reshape = s32[6,7] reshape(gather) + // + // In this case the gather maps to: + // (scalar-indexed-const (constant s32[3,5,2]) %indices 1->[2]) + // + // and the reshape passes through dimension 2 from its input into dimension 1 + // in its output. However, we can't rewrite the reshape as a scalar-indexed + // node because then we'd have to reshape the [3,5,2] `operand` array to + // [6,5], but then dimension 1 of the reshaped [6,5] array indexes differently + // (a.k.a. isn't pass-through) than the [3,5,2] array. + + if (source_dim_for_new_scalar_indexed_node == -1) { + return nullptr; + } + + InsertAt( + &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node, + scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim())); + + CHECK(IsReshapePassthroughOperandDim( + ComputeReshapePassthroughDimPairs( + /*operand_shape=*/AsInt64Slice( + scalar_indexed_source_shape.dimensions()), + /*result_shape=*/new_scalar_indexed_source_shape), + scalar_indexed->source_dim())); + + auto map_passthrough_operand_dim_to_result_dim = [&](int64 result_dim) { + return MapPassthroughOperandDimToResultDim(reshape_passthrough_dims, + result_dim); + }; + + std::vector output_dims_for_new_scalar_indexed_node; + c_transform(scalar_indexed->output_dims(), + std::back_inserter(output_dims_for_new_scalar_indexed_node), + map_passthrough_operand_dim_to_result_dim); + + TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal, + TakeOwnership(scalar_indexed->literal().Reshape( + new_scalar_indexed_source_shape))); + TF_ASSIGN_OR_RETURN( + Array * new_scalar_indexed_source, + ComputeArrayForConstant(*new_scalar_indexed_source_literal)); + + return ConstructScalarIndexedArray( + new_scalar_indexed_source, scalar_indexed->indices(), + source_dim_for_new_scalar_indexed_node, + output_dims_for_new_scalar_indexed_node, shape); +} + +StatusOr +IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, + Array* lhs, + Array* rhs) { + // Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices)) + // => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices) + // + // We can do this if every output dimension from the scalar-indexed node is a + // broadcasted dimension for the broadcast node. Informally, the precondition + // means Broadcast(Const0)[IDX] is solely a function of the components of IDX + // that are not output-dims for the scalar-indexed node. In other words, for + // every assignment to the non-output dims in IDX we have a "constant" LHS to + // the BinaryOp. This transform propagates this "constant" to the source for + // the scalar-indexed node. + + ScalarIndexedConstantArray* lhs_scalar_indexed_const = + dynamic_cast(lhs); + ScalarIndexedConstantArray* rhs_scalar_indexed_const = + dynamic_cast(rhs); + + bool lhs_is_indexed; + + // One of the operands must be scalar-indexed and the other must be a + // broadcast of a constant. + if (lhs_scalar_indexed_const && !rhs_scalar_indexed_const) { + lhs_is_indexed = true; + } else if (rhs_scalar_indexed_const && !lhs_scalar_indexed_const) { + lhs_is_indexed = false; + } else { + return nullptr; + } + + ScalarIndexedConstantArray* scalar_indexed_const = + lhs_is_indexed ? lhs_scalar_indexed_const : rhs_scalar_indexed_const; + UnknownArray* candidate_broadcast_array = + dynamic_cast(lhs_is_indexed ? rhs : lhs); + if (!candidate_broadcast_array || + candidate_broadcast_array->instruction().opcode() != + HloOpcode::kBroadcast) { + return nullptr; + } + + const HloInstruction* broadcast_instr = + &candidate_broadcast_array->instruction(); + const HloInstruction* broadcast_const_operand = broadcast_instr->operand(0); + if (broadcast_const_operand->opcode() != HloOpcode::kConstant) { + return nullptr; + } + + ArraySlice broadcast_dims = broadcast_instr->dimensions(); + auto is_broadcasted_dim = [&](int64 output_dim) { + return c_find(broadcast_dims, output_dim) == broadcast_dims.end(); + }; + + // All of the output dims must be "broadcasted" dims for the other operand. + if (!c_all_of(scalar_indexed_const->output_dims(), is_broadcasted_dim)) { + return nullptr; + } + + // To figure out the broadcast dimensions for the (constant) source for the + // scalar-indexed node, we "simulate" the index transformation done by the + // existing broadcsat: + enum class IndexComponent { Broadcasted, NotBroadcasted }; + std::vector simulated_index( + broadcast_instr->shape().dimensions_size(), IndexComponent::Broadcasted); + for (int64 broadcast_dim : broadcast_dims) { + simulated_index[broadcast_dim] = IndexComponent::NotBroadcasted; + } + + // The scalar-indexed node "removes" the source dim and "inserts" the output + // dims. We do the opposite here to undo the scalar-indexed operation. + ArraySlice output_dims = scalar_indexed_const->output_dims(); + for (int64 i = output_dims.size() - 1; i >= 0; --i) { + CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted); + EraseAt(&simulated_index, output_dims[i]); + } + + InsertAt(&simulated_index, scalar_indexed_const->source_dim(), + IndexComponent::Broadcasted); + + // new_inner_broadcast_dims holds the broadcast dimensions for the inner + // BinaryOp(Broadcast'(Const0), Const1). We now translate simulated_index to + // new_inner_broadcast_dims. + std::vector new_inner_broadcast_dims; + for (int64 i = 0; i < simulated_index.size(); i++) { + if (simulated_index[i] == IndexComponent::NotBroadcasted) { + new_inner_broadcast_dims.push_back(i); + } + } + + // inner_broadcast_result is the Broadcast'(Const0) bit in + // BinaryOp(Broadcast'(Const0), Const1) + TF_ASSIGN_OR_RETURN( + std::unique_ptr inner_broadcast_result, + broadcast_const_operand->literal().Broadcast( + scalar_indexed_const->source()->shape(), new_inner_broadcast_dims)); + + // literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1) + const Literal* literal_for_new_source; + if (lhs_is_indexed) { + TF_ASSIGN_OR_RETURN( + literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( + opcode, scalar_indexed_const->literal(), *inner_broadcast_result))); + } else { + TF_ASSIGN_OR_RETURN( + literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( + opcode, *inner_broadcast_result, scalar_indexed_const->literal()))); + } + + ConstantArray* new_source = Construct(literal_for_new_source); + return Construct( + new_source, scalar_indexed_const->indices(), + scalar_indexed_const->source_dim(), + std::vector(scalar_indexed_const->output_dims().begin(), + scalar_indexed_const->output_dims().end()), + scalar_indexed_const->shape()); +} + +StatusOr +IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, + Array* operand) { + auto* scalar_indexed_const = + dynamic_cast(operand); + if (scalar_indexed_const == nullptr) { + return nullptr; + } + + // Fold UnaryOp(ScalarIndexed(Const, Indices)) + // => ScalarIndexed(UnaryOp(Const), Indices) + + TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp( + opcode, scalar_indexed_const->literal()))); + ConstantArray* new_source = Construct(literal_for_new_source); + return Construct( + new_source, scalar_indexed_const->indices(), + scalar_indexed_const->source_dim(), + std::vector(scalar_indexed_const->output_dims().begin(), + scalar_indexed_const->output_dims().end()), + scalar_indexed_const->shape()); +} + +tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const { + return "indexed-array-analysis-printer-pass"; +} + +StatusOr IndexedArrayAnalysisPrinterPass::Run(HloModule* module) { + if (!VLOG_IS_ON(2)) { + return false; + } + + IndexedArrayAnalysis analysis; + for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* instr : computation->instructions()) { + TF_ASSIGN_OR_RETURN(Analysis::Array * t, analysis.GetArrayFor(instr)); + if (!dynamic_cast(t) && !dynamic_cast(t)) { + VLOG(2) << instr->ToString() << " -> " << analysis.ToString(t); + } + } + } + + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..ce92fd2919c90fa8a2fb7b796ed6f0fdaf48fe62 --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -0,0 +1,326 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace xla { + +// IndexedArrayAnalysis decides if an HLO instruction can be rewritten as a +// gather from another array. It does this by mapping HLO instructions to +// instances of IndexedArrayAnalysis::Array, which can be inspected to discover +// whether said HLO is equivalent to a gather. +class IndexedArrayAnalysis { + public: + // IndexedArrayAnalysis maps each HLO instruction to an instance of a Array. + // Array really just a sum type of the classes that inherit from it. The + // meaning of each of the subtypes is documented on the subtype declaration. + // + // Array instances are immutable once created. + class Array { + public: + enum Kind { kUnknown, kConstant, kScalarIndexedConstant, kScalarIndexed }; + + virtual Kind kind() const = 0; + virtual const Shape& shape() const = 0; + + // Does a checked downcast from `Array` to `T` which must be one of its + // subtypes. + template + T* as() { + static_assert((std::is_base_of::value), + "target type not derived from source type"); + // We skip the CHECK and hence the dynamic_cast if RTTI is disabled. +#if !defined(__GNUC__) || defined(__GXX_RTTI) + CHECK_NE(dynamic_cast(this), nullptr); +#endif // !defined(__GNUC__) || defined(__GXX_RTTI) + + return static_cast(this); + } + + virtual ~Array() = default; + + Array& operator=(const Array& other) = delete; + }; + + // Represents an HLO instruction that was not analyzable by this + // IndexedArrayAnalysis. Instances of UnknownArray just wrap an existing + // HloInstruction. + class UnknownArray : public Array { + public: + Kind kind() const override { return kUnknown; } + const Shape& shape() const override { return instruction().shape(); } + const HloInstruction& instruction() const { return instruction_; } + + private: + explicit UnknownArray(const HloInstruction* instr) : instruction_(*instr) {} + + const HloInstruction& instruction_; + + friend class IndexedArrayAnalysis; + }; + + // Represents a constant value. This constant value may be present in the HLO + // module being analyzed, or it could have been created on the fly by the + // analysis. + class ConstantArray : public Array { + public: + Kind kind() const override { return kConstant; } + const Shape& shape() const override { return literal()->shape(); } + const Literal* literal() const { return literal_; } + + private: + explicit ConstantArray(const Literal* literal) : literal_(literal) {} + const Literal* literal_; + + friend class IndexedArrayAnalysis; + }; + + // --------------------------------------------------------------------------- + // Indexed Array Overview + // --------------------------------------------------------------------------- + // + // ScalarIndexedArray and ScalarIndexedConstantArray form the core of this + // analysis. ScalarIndexedConstantArray is just a specialization of + // ScalarIndexedArray so we will only discuss ScalarIndexedArray in this + // overview. + // + // A ScalarIndexedArray represents an array that can be computed by indexing + // into a "source" array using an "indices" tensor. A simple example is a + // gather operation gathering 12 rows out of a [100,100] matrix -- such an + // operation will be represented by an instance of a ScalarIndexedArray with + // the [100,100] matrix as the "source" array and the [12]-shaped indices + // array as the "indices" tensor. The ScalarIndexedArray operation itself + // will be of shape [12,100] (assuming we were gathering with axis=0). + // + // Gather operations are not the only operation that maps to + // ScalarIndexedArray instances (if that were true there would be little point + // in having a separate analysis). We can often infer ScalarIndexedArrays for + // other operations too. For instance, consider: + // + // %source = f32[100,100] constant + // %indices = s32[12] ... + // %gather = f32[12,100] ... gather from %source using %indices at axis 0 + // %dot = dot(%gather, other_constant) [canonical contracting dims] + // + // The dot operation itself is also a ScalarIndexedArray with source = + // dot(constant, other_constant) and indices = %indices. A reshape of %gather + // to [12,5,20] too is a ScalarIndexedArray with source = an appropriately + // reshaped constant and indices = %indices. + + // Represents the result of a gather operation. This gather operation may + // explicitly be present in the HLO module being analyzed, or it could have + // been created on the fly by the analysis. + // + // An instance of ScalarIndexedArray represents a array whose I'th element can + // be mapped to the J'th element of the `source` array (where I and J are + // multidimensional indices) in this way: + // + // I' = remove components at positions `output_dims` from I + // G' = remove components not at positions `output_dims` from I + // T = indices[G'] + // J = I' with T inserted at position `source_dim` + // + // For example, if source is of shape [11,13,17,19], indices is of shape + // [23,29], output_dims is [0,2] and source_dim is 2 then the output is of + // shape [23,11,29,13,19] and the output index [A,B,C,D,E] is mapped to the + // input index [B,D,indices[A,C],E]. + class ScalarIndexedArray : public Array { + public: + Kind kind() const override { return kScalarIndexed; } + const Shape& shape() const override { return shape_; } + + Array* source() const { return source_; } + Array* indices() const { return indices_; } + + // `source_dim` is the dimension in the source array that is being indexed + // over using indices from the `indices` array. See the class documentation + // and the overview for more details. + int64 source_dim() const { return source_dim_; } + + // `output_dims` are the dimensions in the output array that are being used + // to compute an index into the `indices` array. See the class + // documentation and the overview for more details. + tensorflow::gtl::ArraySlice output_dims() const { + return output_dims_; + } + + private: + explicit ScalarIndexedArray(Array* source, Array* indices, int64 source_dim, + std::vector output_dims, Shape shape) + : source_(source), + indices_(indices), + source_dim_(source_dim), + output_dims_(std::move(output_dims)), + shape_(std::move(shape)) {} + + Array* source_; + Array* indices_; + int64 source_dim_; + std::vector output_dims_; + Shape shape_; + + friend class IndexedArrayAnalysis; + }; + + // A ScalarIndexedConstantArray is just a ScalarIndexedArray constrained to + // have a ConstantArray instance as the source. This is an ergonomic + // concession -- in theory it is possible to just keep ScalarIndexedArray and + // check source()->kind(). + class ScalarIndexedConstantArray : public ScalarIndexedArray { + public: + Kind kind() const override { return kScalarIndexedConstant; } + + const Literal& literal() const { + return *source()->as()->literal(); + } + + private: + explicit ScalarIndexedConstantArray(Array* source, Array* indices, + int64 source_dim, + std::vector output_dims, + Shape shape) + : ScalarIndexedArray(source, indices, source_dim, + std::move(output_dims), std::move(shape)) { + CHECK(dynamic_cast(source)); + } + + friend class IndexedArrayAnalysis; + }; + + // Returns an Array instance for `instr`. The IndexedArrayAnalysis instance + // keeps ownership of the returned Array instance. + // + // Caching Behavior: IndexedArrayAnalysis has a cache mapping HLO + // instructions to IndexedArrayAnalysis::Array instances. This entire cache + // becomes stale and may cause the analysis to return incorrect results if any + // transitive operand (stopping at the containing computation) is modified for + // any HLO instruction on which GetArrayFor has been invoked. + // + // NB! By inspecting the implementation, you may be able to infer a stronger + // caching guarantee than what is mentioned above. Nevertheless, what is + // stated above is the contract. + StatusOr GetArrayFor(const HloInstruction* instr); + + // Pretty-prints the expression rooted at `root`. + string ToString(Array* root, bool print_constants = false); + + private: + // Helper function that ensures that every HLO instruction that is + // transitively used by `root` has an entry in `cache_`. + Status TraverseAndPopulateCache(const HloInstruction* root); + + // Creates an Array instance for `instr` under the assumption that all + // operations of `instr` are present in `cache_`. + StatusOr ComputeArrayFor(const HloInstruction* instr); + + StatusOr ComputeArrayForConstant(const Literal& literal); + + StatusOr ComputeArrayForGather( + const Shape& shape, const GatherDimensionNumbers& dim_numbers, + tensorflow::gtl::ArraySlice window_bounds, Array* source, + Array* indices); + + // This tries to fold a ScalarIndexedArray which has another + // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a + // ScalarIndexedArray as indices. If `source` happened to be a + // ScalarIndexedConstantArray this can result in an expression that is more + // canonical. + // + // As an example, consider a gather operation, G0, gathering 7 elements from + // an array "Arr" of shape [100] resulting in an array of shape [7], and a + // second gather operation, G1, which gathers 3 elements out of the result of + // G0 resulting in an array of shape [3]. Let the indices uses by G0 be I0 + // (of shape [7]) and the indices used by G1 be I1 (of shape [3]). We can + // instead rewrite G1 to gather directly from "Arr" with the three indices + // from I0 as per I1. In other words, we can rewrite: + // + // G0 = [Arr[i] for i in I0] + // G1 = [G0[i] for i in I1] + // + // into + // + // I2 = [I0[i] for i in I1] + // G1 = [Arr[i] for i in I2] + StatusOr FoldGatherOfGather( + ScalarIndexedArray* source, Array* indices, int64 source_dim, + tensorflow::gtl::ArraySlice output_dims, Shape shape); + + StatusOr ComputeArrayForReshape(const Shape& shape, Array* operand); + + StatusOr ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, + Array* lhs, Array* rhs); + StatusOr ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, + Array* operand); + + template + T* Construct(Args&&... args) { + T* new_tensor = new T(std::forward(args)...); + owned_tensors_.push_back(std::unique_ptr(new_tensor)); + return new_tensor; + } + + ScalarIndexedArray* ConstructScalarIndexedArray( + Array* source, Array* indices, int64 source_dim, + std::vector output_dims, Shape shape) { + if (source->kind() == Array::kConstant) { + return Construct(source, indices, source_dim, + std::move(output_dims), + std::move(shape)); + } else { + return Construct(source, indices, source_dim, + std::move(output_dims), + std::move(shape)); + } + } + + Literal* TakeOwnership(std::unique_ptr literal) { + owned_literals_.push_back(std::move(literal)); + return owned_literals_.back().get(); + } + + StatusOr TakeOwnership( + StatusOr> literal_or_error) { + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + std::move(literal_or_error)); + owned_literals_.push_back(std::move(literal)); + return owned_literals_.back().get(); + } + + std::vector> owned_tensors_; + std::vector> owned_literals_; + tensorflow::gtl::FlatMap cache_; +}; + +// A pass that prints all non-trivial results returned by IndexedArrayAnalysis. +// This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to +// unconditionally add to the regular HLO pass pipeline. +class IndexedArrayAnalysisPrinterPass : public HloPassInterface { + public: + tensorflow::StringPiece name() const override; + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..373556ebeba883f7dc2116bdf0ffc3274182f775 --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -0,0 +1,504 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" + +namespace xla { +namespace { +class IndexedArrayAnalysisTest : public HloVerifiedTestBase { + protected: + void AssertArrayForRootExpressionIs(const string& hlo_text, + const string& root_expression) { + AssertArrayForRootExpressionIsImpl(hlo_text, root_expression, + /*print_constants=*/false); + } + + void AssertArrayWithConstantsForRootExpressionIs( + const string& hlo_text, const string& root_expression) { + AssertArrayForRootExpressionIsImpl(hlo_text, root_expression, + /*print_constants=*/true); + } + + private: + void AssertArrayForRootExpressionIsImpl(const string& hlo_text, + const string& root_expression, + bool print_constants) { + IndexedArrayAnalysis indexed_tensor_analysis; + ParseAndVerifyModule(hlo_text); + + TF_ASSERT_OK_AND_ASSIGN( + IndexedArrayAnalysis::Array* const array_result, + indexed_tensor_analysis.GetArrayFor( + module().entry_computation()->root_instruction())); + string string_result = + indexed_tensor_analysis.ToString(array_result, print_constants); + LOG(INFO) << string_result; + ASSERT_EQ(string_result, root_expression); + } +}; + +TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneGather) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand %indices 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneConstantGather) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + indices = s32[5] parameter(0) + ROOT gather = s32[5,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, "(scalar-indexed-const (constant s32[3,3]) %indices 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + indices_a = s32[5] parameter(0) + indices_b = s32[2] parameter(1) + gather_a = s32[5,3] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} + ROOT gather_b = s32[2,3] gather(gather_a, indices_b), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed-const (constant s32[3,3]) (scalar-indexed %indices_a " + "%indices_b 0->[0]) 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithOneToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,2] parameter(0) + indices_a = s32[5,7] parameter(1) + indices_b = s32[2] parameter(2) + gather_a = s32[5,3,7] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3,1} + ROOT gather_b = s32[5,3,2] gather(gather_a, indices_b), + output_window_dims={0,1}, + elided_window_dims={2}, + gather_dims_to_operand_dims={2}, + index_vector_dim=1, + window_bounds={5,3,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand (scalar-indexed " + "%indices_a %indices_b 1->[1]) 1->[0,2])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOneWithManyToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,6] parameter(0) + indices_a = s32[2] parameter(1) + indices_b = s32[5,7] parameter(2) + gather_a = s32[2,6] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,6} + ROOT gather_b = s32[5,6,7] gather(gather_a, indices_b), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,6} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand (scalar-indexed " + "%indices_a %indices_b 0->[0,1]) 0->[0,2])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithManyToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,2] parameter(0) + indices_a = s32[5,7] parameter(1) + indices_b = s32[4,8] parameter(2) + gather_a = s32[5,3,7] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3,1} + ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b), + output_window_dims={1,2}, + elided_window_dims={2}, + gather_dims_to_operand_dims={2}, + index_vector_dim=2, + window_bounds={5,3,1} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed %operand (scalar-indexed %indices_a %indices_b " + "1->[0,2]) 1->[0,1,3])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather0) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + indices = s32[5] parameter(0) + gather = s32[5,4] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT reshape = s32[5,2,2] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, "(scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather1) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + indices = s32[5,7] parameter(0) + gather = s32[5,4,7] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,4} + ROOT reshape = s32[5,2,2,7] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0,3])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather2) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,2,6] constant(s32[3,2,6]{ + {{1,2,3,4,5,6},{1,2,3,4,5,6}}, + {{1,2,3,4,5,6},{1,2,3,4,5,6}}, + {{1,2,3,4,5,6},{1,2,3,4,5,6}}}) + indices = s32[5,7] parameter(0) + gather = s32[5,2,6,7] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,2,6} + ROOT reshape = s32[5,3,4,7] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed-const (constant s32[3,3,4]) %indices 0->[0,3])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative0) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + indices = s32[5,6] parameter(0) + gather = s32[5,4,6] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,4} + ROOT reshape = s32[5,2,2,2,3] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%reshape"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative1) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,5,2] constant(s32[3,5,2]{ + {{1,2},{3,4},{5,6},{7,8},{9,10}}, + {{1,2},{3,4},{5,6},{7,8},{9,10}}, + {{1,2},{3,4},{5,6},{7,8},{9,10}}}) + indices = s32[7] parameter(0) + gather = s32[3,2,7] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3,1,2} + ROOT reshape = s32[6,7] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%reshape"); +} + +TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) { + string hlo_text = R"( +HloModule UnaryOpOfGather + +ENTRY main { + operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + indices = s32[5] parameter(0) + gather = f32[5,4] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT tanh = f32[5,4] tanh(gather) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant f32[3,4] f32[3,4] { + { 0.761594176, 0.964027584, 0.995054781, 0.999329329 }, + { 0.761594176, 0.995054781, 0.964027584, 0.999329329 }, + { 0.999329329, 0.995054781, 0.964027584, 0.761594176 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, AddBroadcastedScalarWithGather) { + string hlo_text = R"( +HloModule AddBroadcastedScalarWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant = s32[] constant(5) + constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT add = s32[5,4] add(gather, constant_broadcasted) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { 6, 7, 8, 9 }, + { 6, 8, 7, 9 }, + { 9, 8, 7, 6 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, + SubtractBroadcastedScalarWithGather_GatherIsLhs) { + string hlo_text = R"( +HloModule SubtractBroadcastedScalarWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant = s32[] constant(5) + constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT sub = s32[5,4] subtract(gather, constant_broadcasted) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { -4, -3, -2, -1 }, + { -4, -2, -3, -1 }, + { -1, -2, -3, -4 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, + SubtractBroadcastedScalarWithGather_GatherIsRhs) { + string hlo_text = R"( +HloModule SubtractBroadcastedScalarWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant = s32[] constant(5) + constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT sub = s32[5,4] subtract(constant_broadcasted, gather) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { 4, 3, 2, 1 }, + { 4, 2, 3, 1 }, + { 1, 2, 3, 4 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather) { + string hlo_text = R"( +HloModule AddBroadcastedVectorWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant_vect = s32[4] constant({10,11,12,13}) + constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT add = s32[5,4] add(gather, constant_broadcasted) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { 11, 13, 15, 17 }, + { 11, 14, 14, 17 }, + { 14, 14, 14, 14 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather_Negative) { + string hlo_text = R"( +HloModule AddBroadcastedVectorWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant_vect = s32[5] constant({10,11,12,13,14}) + constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT add = s32[5,4] add(gather, constant_broadcasted) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%add"); +} + +TEST_F(IndexedArrayAnalysisTest, RegularUnaryOp) { + string hlo_text = R"( +HloModule RegularUnaryOp + +ENTRY main { + input = f32[100] parameter(0) + ROOT tanh = f32[100] tanh(input) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%tanh"); +} + +TEST_F(IndexedArrayAnalysisTest, RegularBinaryOp) { + string hlo_text = R"( +HloModule RegularUnaryOp + +ENTRY main { + input0 = f32[100] parameter(0) + input1 = f32[100] parameter(1) + ROOT add = f32[100] add(input0, input1) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%add"); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 7aa1c7c8358318d02a000d968a2672123400ad6e..d2af261008f40ee83e0676cfc7e67c45f8be1844 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -71,7 +71,7 @@ TEST_F(InlinerTest, MapMax) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = Literal::CreateR1({4, 3, 3, 4}); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } // Test that `constant` function is changed to `broadcast`. @@ -105,7 +105,7 @@ TEST_F(InlinerTest, MapConstant) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = Literal::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } TEST_F(InlinerTest, MapSubtractOppositeOrder) { @@ -143,7 +143,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = Literal::CreateR1({3, 1, -1, -3}); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index f494748e17fc2d0de74dec67f7414d4791f76a07..abedb4063d3763516e66cff36633dbd90c8cafde 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -28,6 +28,25 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { +// These nodes can always be duplicated into consumers, even if +// InstructionFusion::may_duplicate_ is false. +// +// In general these should be nodes that get *cheaper* the more they're +// duplicated (and fused into consumers). +// +// TODO(jlebar): Duplicating instructions when we have a variable called "may +// duplicate" that's equal to false is not pretty. +bool IsAlwaysDuplicable(const HloInstruction& instruction) { + // We are always willing to duplicate a widening type-conversion instruction + // if it means we can fuse the convert into a consumer. This allows the + // consumer to read less memory, which is almost always a performance win. + return instruction.opcode() == HloOpcode::kConvert && + ShapeUtil::ByteSizeOf(instruction.operand(0)->shape()) < + ShapeUtil::ByteSizeOf(instruction.shape()); +} +} // namespace + /*static*/ bool InstructionFusion::IsExpensive( const HloInstruction& instruction) { switch (instruction.opcode()) { @@ -39,6 +58,7 @@ namespace xla { case HloOpcode::kBroadcast: case HloOpcode::kCeil: case HloOpcode::kClamp: + case HloOpcode::kClz: case HloOpcode::kComplex: case HloOpcode::kConcatenate: case HloOpcode::kConstant: @@ -76,6 +96,7 @@ namespace xla { case HloOpcode::kShiftRightLogical: case HloOpcode::kSlice: case HloOpcode::kSubtract: + case HloOpcode::kGenerateToken: case HloOpcode::kTranspose: case HloOpcode::kTuple: return false; @@ -98,13 +119,16 @@ namespace xla { case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: case HloOpcode::kDivide: + case HloOpcode::kDomain: case HloOpcode::kDot: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFft: case HloOpcode::kFusion: case HloOpcode::kGather: case HloOpcode::kHostCompute: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kMap: case HloOpcode::kParameter: case HloOpcode::kPower: @@ -127,11 +151,11 @@ namespace xla { return false; } -// An "effectively unary" operation is one that has one "large" +// An "effectively at most unary" operation is one that has at most one "large" // input with the others being negligible in terms of memory usage. // We use "has a smaller true rank than the output" as a heuristic // for "negligible" memory usage. -bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) { +bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { int64 output_rank = 0; ShapeUtil::ForEachSubshape( hlo->shape(), @@ -155,66 +179,89 @@ bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) { } bool InstructionFusion::CanFuseOnAllPaths( - const HloReachabilityMap& reachability_map, HloInstruction* producer, - HloInstruction* consumer, DoNotFuseSet* do_not_fuse) { - auto could_fuse_on_all_paths = [&] { - // First check to see if we have already marked this producer as infeasible - // to fuse into consumer. - if (do_not_fuse->count(producer) > 0) { + HloInstruction* producer, HloInstruction* consumer, + const HloInstructionSet& do_not_duplicate) { + if (consumer == producer) { + return true; + } + if (!consumer->IsFusable()) { + return false; + } + for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { + auto* consumer_operand = consumer->mutable_operand(i); + // If the operand is not on a path to the producer, it doesn't matter + // whether it's fusable. + if (!reachability_->IsReachable(producer, consumer_operand)) { + continue; + } + if (do_not_duplicate.count(consumer_operand) > 0 || + !ShouldFuse(consumer, i)) { return false; } - // Make sure it is possible for producer and consumer to exist in a fusion - // node. - if (!producer->IsFusable() || !consumer->IsFusable()) { + // The producer is reachable from consumer_operand which means we need + // to be able to fuse consumer_operand into consumer in order for + // producer to be fusable into consumer on all paths. + // Perform the recursive step: make sure producer can be fused into + // consumer_operand on all paths. + if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) { return false; } - // We do an upward walk of the graph from consumer towards all paths which - // lead to producer to find any unfusable paths. - for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { - auto* consumer_operand = consumer->mutable_operand(i); - if (consumer_operand == producer) { - // This is the base case: our upward crawl ends but we need to make sure - // that fusion from consumer can happen. - if (!ShouldFuse(consumer, i)) { - return false; - } - } else if (reachability_map.IsReachable(producer, consumer_operand)) { - // The reachability map told us that consumer_operand is a node on the - // path to producer. We need to further investigate from - // consumer_operand. - - // First check if we have already ruled out fusing producer into - // consumer_operand. - if (do_not_fuse->count(consumer_operand) > 0) { - return false; - } - // Make sure it is possible for consumer_operand to exist in a fusion - // node. - if (!consumer_operand->IsFusable()) { - return false; - } - // The producer is reachable from consumer_operand which means we need - // to be able to fuse consumer_operand into consumer in order for - // producer to be fusable into consumer on all paths. - if (!ShouldFuse(consumer, i)) { - return false; - } - // Perform the recursive step: make sure producer can be fused into - // consumer_operand on all paths. - if (!CanFuseOnAllPaths(reachability_map, producer, consumer_operand, - do_not_fuse)) { - return false; - } + } + return true; +} + +InstructionFusion::HloInstructionSet +InstructionFusion::ComputeGloballyUnfusable( + tensorflow::gtl::ArraySlice post_order) { + // Forbid fusion of producers that: + // a) Need to be duplicated, unless they can be fused into all consumers + // via all paths. + // b) Are more than unary, that is, fusing them would likely lead to an + // increase in memory bandwidth use. + // + // Note that if we allow fusion by these global rules, we may still forbid + // fusing operations that require duplication later depending on + // is_expensive_(). + HloInstructionSet do_not_duplicate; + for (HloInstruction* consumer : post_order) { + for (HloInstruction* producer : consumer->operands()) { + if (do_not_duplicate.count(producer) > 0) { + continue; } + + // If the producer is effectively not more than unary, duplicating it + // will not increase the number of relevant inputs read, as the fusion + // node will only need to read at most 1 relevant input (the input of + // the producer). In that case, we do not forbid fusion of the operation + // here. + if (EffectivelyAtMostUnary(producer)) { + continue; + } + // Otherwise we will forbid fusing the op unless we can fuse it into + // all of its consumers on all paths. + // + // That means, that for: + // A --> B (fusable) + // \-> C (non-fusable) + // A will be not allowed to be fused into B, as it cannot be fused into C. + // + // Similarly, for: + // A -------------> B + // \-> C -> D -/ + // If: + // - A is fusable into B and C, and D is fusable into B + // - C is *not* fusable into D + // A will be not allowed to be fused into B, as it cannot be fused via + // all paths. + if (producer->IsFusable() && + CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) { + continue; + } + do_not_duplicate.insert(producer); } - return true; - }; - if (could_fuse_on_all_paths()) { - return true; } - // We couldn't fuse on all paths, record this result. - do_not_fuse->insert(producer); - return false; + + return do_not_duplicate; } StatusOr InstructionFusion::Run(HloModule* module) { @@ -226,6 +273,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { for (auto* computation : module->MakeNonfusionComputations()) { CHECK(!computation->IsFusionComputation()); computation_ = computation; + reachability_ = computation_->ComputeReachability(); // We want to be able to remove arbitrary instructions from the post order // and also compare positions of instructions in the post order. To make @@ -243,36 +291,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { InsertOrDie(&post_order_index, post_order[i], i); } - DoNotFuseSet do_not_fuse; - auto reachability = computation->ComputeReachability(); - - auto cheap_to_duplicate = [this](HloInstruction* producer) { - if (producer->opcode() == HloOpcode::kBroadcast) { - return true; - } - if (producer->opcode() == HloOpcode::kConstant && - ShapeUtil::IsEffectiveScalar(producer->shape())) { - return true; - } - if (EffectivelyUnary(producer)) { - return true; - } - return false; - }; - - for (HloInstruction* consumer : post_order) { - for (HloInstruction* producer : consumer->operands()) { - if (cheap_to_duplicate(producer)) { - continue; - } - if (CanFuseOnAllPaths(*reachability, producer, consumer, - &do_not_fuse)) { - CHECK_EQ(do_not_fuse.count(producer), 0); - } else { - CHECK_GT(do_not_fuse.count(producer), 0); - } - } - } + HloInstructionSet do_not_duplicate = ComputeGloballyUnfusable(post_order); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all @@ -302,7 +321,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { // Consider each operand of this instruction for fusion into this // instruction. We want to consider the operands in a particular order to - // avoid created duplicate instruction clones in the fusion instruction. + // avoid creating duplicate instruction clones in the fusion instruction. // For example, consider the following expression: // // A = ... @@ -340,9 +359,20 @@ StatusOr InstructionFusion::Run(HloModule* module) { // ensures that B will be considered before A. // // We store the original indices of the operands to pass to ShouldFuse. - std::vector sorted_operand_numbers(instruction->operands().size()); - std::iota(std::begin(sorted_operand_numbers), - std::end(sorted_operand_numbers), 0); + std::vector sorted_operand_numbers; + sorted_operand_numbers.reserve(instruction->operands().size()); + for (int i = 0; i < instruction->operands().size(); ++i) { + // This will happen if we have two possible instructions to fuse the + // same operand into; once the operand is fused into one instruction, + // the other instruction will get a new get-tuple-element as its + // operand, which is not in the post-order index. + // TODO(tjoerg): Look into fusing past these multi-output fuse points. + if (post_order_index.find(instruction->mutable_operand(i)) == + post_order_index.end()) { + continue; + } + sorted_operand_numbers.push_back(i); + } std::sort( sorted_operand_numbers.begin(), sorted_operand_numbers.end(), [&](int64 i, int64 j) { @@ -359,13 +389,20 @@ StatusOr InstructionFusion::Run(HloModule* module) { if (!operand->IsFusable()) { continue; } - if (!ShouldFuse(instruction, i)) { - continue; - } - if (do_not_fuse.count(operand) > 0) { + + HloInstruction* fusion_instruction; + // Try "regular" fusion if the operand may be duplicated. Otherwise, + // perform multi-output fusion, unless this creates a cycle. + // TODO(tjoerg): Consider making multi-output fusion the default. + if (ShouldFuse(instruction, i) && + do_not_duplicate.count(operand) == 0) { + fusion_instruction = Fuse(operand, instruction); + } else if (ShouldFuseIntoMultiOutput(instruction, i) && + !MultiOutputFusionCreatesCycle(operand, instruction)) { + fusion_instruction = FuseIntoMultiOutput(operand, instruction); + } else { continue; } - HloInstruction* fusion_instruction = Fuse(operand, instruction); // Fusing an instruction into a fusion instruction can change the // operand set of the fusion instruction. For simplicity just push the @@ -377,7 +414,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { changed = true; if (operand->user_count() == 0) { - // Operand is now dead. Remove from post order by setting it's + // Operand is now dead. Remove from post order by setting its // location to nullptr. post_order[FindOrDie(post_order_index, operand)] = nullptr; post_order_index.erase(operand); @@ -396,12 +433,9 @@ StatusOr InstructionFusion::Run(HloModule* module) { return changed; } -HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, - HloInstruction* consumer) { +HloInstruction* InstructionFusion::AddFusionInstruction( + HloInstruction* producer, HloInstruction* consumer) { HloInstruction* fusion_instruction; - - VLOG(2) << "Fusing " << producer->ToString() << " into " - << consumer->ToString(); auto kind = ChooseKind(producer, consumer); if (consumer->opcode() == HloOpcode::kFusion) { fusion_instruction = consumer; @@ -413,17 +447,48 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, HloInstruction::CreateFusion(consumer->shape(), kind, consumer)); TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction)); } + return fusion_instruction; +} +HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, + HloInstruction* consumer) { + VLOG(2) << "Fusing " << producer->ToString() << " into " + << consumer->ToString(); + HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer); fusion_instruction->FuseInstruction(producer); return fusion_instruction; } +HloInstruction* InstructionFusion::FuseIntoMultiOutput( + HloInstruction* producer, HloInstruction* consumer) { + VLOG(2) << "Multi-output fusing " << producer->ToString() << " into " + << consumer->ToString(); + HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer); + fusion_instruction->FuseInstructionIntoMultiOutput(producer); + return fusion_instruction; +} + +bool InstructionFusion::MultiOutputFusionCreatesCycle( + HloInstruction* producer, HloInstruction* consumer) { + return c_any_of( + consumer->operands(), [&](const HloInstruction* consumer_operand) { + // The fusion algorithm traverses the HLO graph in reverse post order. + // Thus `cosumers` is visited before its operands (including + // `producer`). Therefore, consumer operands cannot have been fused yet. + // It is thus safe to use the pre-computed reachability map. + return consumer_operand != producer && + reachability_->IsReachable(producer, consumer_operand); + }); +} + bool InstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); + // Cost condition: don't duplicate expensive instructions. if (FusionWouldDuplicate(*producer, *consumer) && - (is_expensive_(*producer) || !may_duplicate_)) { + (!may_duplicate_ || is_expensive_(*producer)) && + !IsAlwaysDuplicable(*producer)) { return false; } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 152d0886ee9eda19961e092df44cb234ee2bd29d..f73ca9adf768ed26f9ec9f162e01b7b160f50daf 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -61,6 +61,14 @@ class InstructionFusion : public HloPassInterface { // Subtypes can override this with target-specific heuristics. virtual bool ShouldFuse(HloInstruction* consumer, int64 operand_index); + // Returns whether multi-output fusion can be applied to fuse `producer` into + // `consumer`. In contrast to "regular" fusion, the `producer` is not + // duplicated by multi-output fusion. + virtual bool ShouldFuseIntoMultiOutput(HloInstruction* consumer, + int64 operand_index) { + return false; + } + // Chooses a fusion kind for `producer` and `consumer`. // Default method chooses `kLoop`. virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer, @@ -70,11 +78,18 @@ class InstructionFusion : public HloPassInterface { virtual HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer); - // An "effectively unary" operation is one that has one "large" + // Creates a new fusion instruction containing `producer` and `consumer`. A + // tuple is added as the fusion instruction's root, which consumes from both, + // `producer` and `consumer`. This style of fusion is referred to as + // multi-output fusion. + virtual HloInstruction* FuseIntoMultiOutput(HloInstruction* producer, + HloInstruction* consumer); + + // An "effectively unary" operation is one that has at most one "large" // input with the others being negligible in terms of memory usage. // We use "has a smaller true rank than the output" as a heuristic // for "negligible" memory usage. - bool EffectivelyUnary(HloInstruction* hlo); + bool EffectivelyAtMostUnary(HloInstruction* hlo); // Returns true if fusing producer into consumer would cause producer to be // duplicated. This is the case if producer has uses other than consumer. @@ -90,21 +105,34 @@ class InstructionFusion : public HloPassInterface { // Current HloComputation instance the loop fuser is traversing. HloComputation* computation_; HloModule* module_; + // Reachability information for the current computation. + std::unique_ptr reachability_; private: // The set of producers whose consumers we cannot fuse into. - using DoNotFuseSet = std::unordered_set; + using HloInstructionSet = std::unordered_set; - // Whether or not we can fuse consumer into original_producer on all paths + HloInstruction* AddFusionInstruction(HloInstruction* producer, + HloInstruction* consumer); + + // Whether or not we can fuse producer into consumer on all paths // from the producer to the consumer where nodes are HLOs and edges are uses. - bool CanFuseOnAllPaths(const HloReachabilityMap& reachability_map, - HloInstruction* producer, HloInstruction* consumer, - DoNotFuseSet* do_not_fuse); + bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer, + const HloInstructionSet& do_not_fuse); + + // Computes the set of nodes that we do not want to fuse into any of their + // consumers based on a global analysis of the HLO graph. + HloInstructionSet ComputeGloballyUnfusable( + tensorflow::gtl::ArraySlice post_order); // Used to determine if an HLO is expensive. Expensive operations will not be // duplicated. std::function is_expensive_; + // Whether multi-output fusion would introduce a cycle into the HLO graph. + bool MultiOutputFusionCreatesCycle(HloInstruction* producer, + HloInstruction* consumer); + // Returns whether we may duplicate an instruction if we want to fuse it. bool may_duplicate_; diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 0fa2c95fb458f8f2b863388fd77bca5f10372a0a..21db2338995960bde00ec9c4b325e5562fc3a592 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -16,12 +16,100 @@ limitations under the License. #include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace xla { +namespace op = xla::testing::opcode_matchers; + using InstructionFusionTest = HloTestBase; +// Subclass of InstructionFusion exposing the protected methods Fuse and +// FuseIntoMultiOutput for testing. +class InstructionFusionForTesting : public InstructionFusion { + public: + explicit InstructionFusionForTesting(HloModule* module) + : InstructionFusion(InstructionFusion::IsExpensive) { + module_ = module; + computation_ = module->entry_computation(); + } + + HloInstruction* Fuse(HloInstruction* producer, + HloInstruction* consumer) override { + return InstructionFusion::Fuse(producer, consumer); + } + + HloInstruction* FuseIntoMultiOutput(HloInstruction* producer, + HloInstruction* consumer) override { + return InstructionFusion::FuseIntoMultiOutput(producer, consumer); + } +}; + +TEST_F(InstructionFusionTest, FuseInstructions) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY entry_computation { + p0 = f32[4,3]{1,0} parameter(0) + add = f32[4,3]{1,0} add(p0, p0) + ROOT sub = f32[4,3]{1,0} subtract(add, p0) + })") + .ValueOrDie(); + HloInstruction* sub = module->entry_computation()->root_instruction(); + HloInstruction* add = sub->mutable_operand(0); + HloInstruction* fusion = + InstructionFusionForTesting(module.get()).Fuse(add, sub); + + ASSERT_THAT(fusion, op::Fusion()) << module->ToString(); + EXPECT_THAT(fusion->fused_expression_root(), + op::Subtract(op::Add(), op::Parameter())) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) { + auto module = ParseHloString(R"( + HloModule test_module + fused_computation { + p1 = f32[4,3] parameter(0) + add = f32[4,3] add(p1, p1) + } + ENTRY entry_computation { + p0 = f32[4,3] parameter(0) + abs = f32[4,3] abs(p0) + ROOT fusion = f32[4,3] fusion(abs), kind=kLoop, calls=fused_computation + })") + .ValueOrDie(); + HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* abs = root->mutable_operand(0); + HloInstruction* fusion = + InstructionFusionForTesting(module.get()).Fuse(abs, root); + + ASSERT_THAT(fusion, op::Fusion()) << module->ToString(); + EXPECT_THAT(fusion->fused_expression_root(), op::Add(op::Abs(), op::Abs())) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY entry_computation { + p0 = f32[4,3]{1,0} parameter(0) + abs = f32[4,3]{1,0} abs(p0) + tanh = f32[4,3]{1,0} tanh(abs) + ROOT add = f32[4,3]{1,0} add(abs, tanh) + })") + .ValueOrDie(); + HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* abs = root->mutable_operand(0); + HloInstruction* tanh = root->mutable_operand(1); + HloInstruction* fusion = + InstructionFusionForTesting(module.get()).FuseIntoMultiOutput(abs, tanh); + + ASSERT_THAT(fusion, op::Fusion()) << module->ToString(); + EXPECT_THAT(fusion->fused_expression_root(), op::Tuple(op::Tanh(), op::Abs())) + << module->ToString(); +} + TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( @@ -89,7 +177,172 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { EXPECT_FALSE( InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) .Run(module.get()) - .ValueOrDie()); + .ValueOrDie()) + << module->ToString(); +} + +// Counts the number of HLO ops with a given op code in the specified module. +static int Count(const HloModule& module, HloOpcode op) { + int count = 0; + for (const auto* computation : module.computations()) { + for (const auto* instruction : computation->instructions()) { + if (instruction->opcode() == op) { + ++count; + } + } + } + return count; +} + +TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + add = f32[4,3]{1,0} add(p0, p0) + ROOT root = f32[4,3]{1,0} subtract(add, add) + })") + .ValueOrDie(); + // Expect the add and subtraction to be fused. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); + + // Make sure the add hasn't been duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); +} + +TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { + // Make sure we do not duplicate the add, as we cannot fuse through the rng. + // + // p0 -> add -------------------------> sub + // \-> abs1 -> rng -> abs2 -/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + add = f32[4,3]{1,0} add(p0, p0) + abs1 = f32[4,3]{1,0} abs(add) + rng = f32[4,3]{1,0} rng(abs1), distribution=rng_uniform + abs2 = f32[4,3]{1,0} abs(rng) + ROOT root = f32[4,3]{1,0} subtract(abs2, add) + })") + .ValueOrDie(); + // We expect abs2 to be fused into root. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Subtract(op::Abs(op::Parameter()), op::Parameter())) + << module->ToString(); + + // Make sure the add hasn't been duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); + + // Use a log node with a second consumer to break the fusion. + // + // p0 -> add -------------------------> sub + // \-> abs1 -> log -> abs2 -/ + // \-> send + module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + add = f32[4,3]{1,0} add(p0, p0) + abs1 = f32[4,3]{1,0} abs(add) + log = f32[4,3]{1,0} log(abs1) + send = f32[4,3]{1,0} send(log), channel_id=0 + abs2 = f32[4,3]{1,0} abs(log) + ROOT root = f32[4,3]{1,0} subtract(abs2, add) + })") + .ValueOrDie(); + + // We expect abs2 to be fused into root and abs1 to be fused into log. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 2) << module->ToString(); + + // Make sure the add hasn't been duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); + + // Make sure we still fuse ops where one operand in the chain to the producer + // can't be fused. + // + // p0 ---> add1 -----------> sub + // \ \-> add2 -/ + // \-> log -/ + // \-> send + module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + add1 = f32[4,3]{1,0} add(p0, p0) + log = f32[4,3]{1,0} log(p0) + send = f32[4,3]{1,0} send(log), channel_id=0 + add2 = f32[4,3]{1,0} add(log, add1) + ROOT root = f32[4,3]{1,0} subtract(add1, add2) + })") + .ValueOrDie(); + + // Expect the add1 and add2 to be fused into root. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); + + // Make sure we didn't duplicate any adds. + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString(); + + // A variant of the above that allows the algorithm to put add2 into the set + // of unfusable ops to short-circuit the decision whether add1 should be fused + // into sub2. + // + // /---------------\ + // p0 ---> add1 ---> add2 ------> sub2 + // \------> sub1 + // log -/ + // \-> send + module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + add1 = f32[4,3]{1,0} add(p0, p0) + add2 = f32[4,3]{1,0} add(add1, add1) + log = f32[4,3]{1,0} log(add2) + send = f32[4,3]{1,0} send(log), channel_id=0 + sub1 = f32[4,3]{1,0} subtract(log, add2) + sub2 = f32[4,3]{1,0} subtract(add2, add1) + ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2) + })") + .ValueOrDie(); + + // Expect sub1 and sub2 to be fused into root. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Tuple(op::Subtract(op::Parameter(), op::Parameter()), + op::Subtract(op::Parameter(), op::Parameter()))) + << module->ToString(); + + // Make sure we didn't duplicate any adds. + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString(); } TEST_F(InstructionFusionTest, AllowUnaryDuplication) { @@ -135,4 +388,29 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { .ValueOrDie()); } +TEST_F(InstructionFusionTest, + WideningConvertsAreAlwaysDuplicableIntoConsumers) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY Test { + p0 = f16[100] parameter(0) + c = f32[100] convert(p0) + add = f32[100] add(c, c) + ROOT mul = f32[100] multiply(c, c) + })") + .ValueOrDie(); + + // The convert should be fused into the add and mul, even though may_duplicate + // is false, because it's always beneficial to fuse/duplicate widening + // converts into consumers. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/false) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion(op::Parameter())); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 0819ab3b90b2360c6b0b2afaa89f322afe566eb3..524d3234eb4eff9c7d000eca1a0d9f5c4fae90af 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -18,7 +18,6 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service/interpreter:platform_id", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -63,10 +62,7 @@ cc_library( name = "platform_id", srcs = ["platform_id.cc"], hdrs = ["platform_id.h"], - deps = [ - "@nsync//:nsync_headers", - "//tensorflow/core:stream_executor_headers_lib", - ] + if_static( + deps = ["//tensorflow/core:stream_executor_headers_lib"] + if_static( ["@protobuf_archive//:protobuf"], ["@protobuf_archive//:protobuf_headers"], ), @@ -120,17 +116,5 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_headers_lib", - "//tensorflow/core:stream_executor_no_cuda", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/compiler/xla/service/interpreter/README.md b/tensorflow/compiler/xla/service/interpreter/README.md index 4c19a1b916d42149c670f4d3bd1d11cff87cf075..0b21b251c3f663540292d98e5a609b3e27446d38 100644 --- a/tensorflow/compiler/xla/service/interpreter/README.md +++ b/tensorflow/compiler/xla/service/interpreter/README.md @@ -5,7 +5,7 @@ evaluating the result of the HLO graph directly with HloEvaluator, without lowering it further (to LLVM IR for example) before execution as other backends (CPU and GPU for example) do. -Its key componenets are: +Its key components are: * [`InterpreterCompiler`] despite the inherited naming of "compiler", all `InterpreterCompiler` really does is the following: diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 9171e859c6f84ceef9664aa1eb90a07c87dfab40..c1666530687f2f8407a9dcb4e271c9d95552a689 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -34,22 +34,17 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace interpreter { -namespace se = ::perftools::gputools; -namespace sep = ::perftools::gputools::interpreter; - Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); pipeline.AddPass( - hlo_module->mutable_entry_computation_layout()); - + hlo_module->mutable_device_entry_computation_layout()); return pipeline.Run(hlo_module).status(); } @@ -74,7 +69,8 @@ StatusOr> InterpreterCompiler::RunBackend( // Create executable from only the Hlo module. std::unique_ptr executable = - xla::MakeUnique(std::move(hlo_module)); + xla::MakeUnique(std::move(hlo_module), + xla::MakeUnique()); return std::move(executable); } @@ -96,7 +92,7 @@ InterpreterCompiler::CompileAheadOfTime( } se::Platform::Id InterpreterCompiler::PlatformId() const { - return sep::kInterpreterPlatformId; + return se::interpreter::kXlaInterpreterPlatformId; } HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction() @@ -104,16 +100,14 @@ HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction() return InterpreterExecutable::ShapeSizeBytes; } -static std::unique_ptr CreateComputationPlacer() { - return xla::MakeUnique(); -} - static bool InitModule() { - xla::Compiler::RegisterCompilerFactory(sep::kInterpreterPlatformId, []() { - return xla::MakeUnique(); - }); - xla::ComputationPlacer::RegisterComputationPlacer(sep::kInterpreterPlatformId, - &CreateComputationPlacer); + xla::Compiler::RegisterCompilerFactory( + se::interpreter::kXlaInterpreterPlatformId, []() { + return xla::MakeUnique(); + }); + xla::ComputationPlacer::RegisterComputationPlacer( + se::interpreter::kXlaInterpreterPlatformId, + []() { return xla::MakeUnique(); }); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h index c8660c04d86a82e7dfcfd1658310c2a0e4fa0083..e90ae3e818522e6e4fd9d9f5acb846800bc899ca 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.h +++ b/tensorflow/compiler/xla/service/interpreter/compiler.h @@ -44,19 +44,16 @@ class InterpreterCompiler : public Compiler { ~InterpreterCompiler() override {} StatusOr> RunHloPasses( - std::unique_ptr hlo_module, - perftools::gputools::StreamExecutor* stream_exec, + std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( - std::unique_ptr hlo_module, - perftools::gputools::StreamExecutor* stream_exec, + std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr>> Compile( std::vector> hlo_modules, - std::vector> - stream_exec, + std::vector> stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr>> @@ -65,7 +62,7 @@ class InterpreterCompiler : public Compiler { HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; - perftools::gputools::Platform::Id PlatformId() const override; + se::Platform::Id PlatformId() const override; private: Status RunHloOptimization(HloModule* hlo_module); diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 0cb9b5d8107cd8bf468b07d5fe2a22930d9e8b8c..029e71058a7373b9310c6d9ffdb65f72ca28e5af 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -32,22 +31,21 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace interpreter { -namespace se = ::perftools::gputools; - InterpreterExecutable::InterpreterExecutable( - std::unique_ptr hlo_module) + std::unique_ptr hlo_module, + std::unique_ptr evaluator) : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr, - /*hlo_profile_index_map=*/nullptr) {} + /*hlo_profile_index_map=*/nullptr), + evaluator_(std::move(evaluator)) {} InterpreterExecutable::~InterpreterExecutable() {} -StatusOr> InterpreterExecutable::ExecuteOnStream( +StatusOr InterpreterExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { @@ -84,18 +82,21 @@ StatusOr> InterpreterExecutable::ExecuteOnStream( } // Execute the graph using the HloEvaluator. - HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN( - std::unique_ptr result_literal, - evaluator.Evaluate>(*computation, arg_literals)); + std::unique_ptr result_literal; + { + tensorflow::mutex_lock lock(evaluator_lock_); + TF_ASSIGN_OR_RETURN(result_literal, + evaluator_->Evaluate>( + *computation, arg_literals)); + } // Transform the result literal back into a ShapedBuffer. - TF_ASSIGN_OR_RETURN(std::unique_ptr result, - transfer_manager->AllocateShapedBuffer( + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + transfer_manager->AllocateScopedShapedBuffer( result_literal->shape(), run_options->allocator(), - run_options->device_ordinal())); + executor->device_ordinal())); TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( - executor, *result_literal, *result)); + executor, *result_literal, result)); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -108,8 +109,7 @@ StatusOr> InterpreterExecutable::ExecuteOnStream( return std::move(result); } -StatusOr> -InterpreterExecutable::ExecuteAsyncOnStream( +StatusOr InterpreterExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) { return tensorflow::errors::Unimplemented( diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index 410110a1adf04c83001c38ed03f5d60dd203dc7e..91d8148d26dc8eddbafdaf4870d9efbb73a12816 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" @@ -30,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -40,20 +42,27 @@ namespace interpreter { // buffer allocation. Refer to interpreter/README.md for more. class InterpreterExecutable : public Executable { public: - InterpreterExecutable(std::unique_ptr hlo_module); + InterpreterExecutable(std::unique_ptr hlo_module, + std::unique_ptr evaluator); ~InterpreterExecutable() override; - StatusOr> ExecuteOnStream( + StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) override; + HloExecutionProfile* hlo_execution_profile) override + LOCKS_EXCLUDED(evaluator_lock_); - StatusOr> ExecuteAsyncOnStream( + StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) override; static int64 ShapeSizeBytes(const Shape& shape); + protected: + // The interpreter interprets executables with an HloEvaluator. + std::unique_ptr evaluator_ PT_GUARDED_BY(evaluator_lock_); + mutable tensorflow::mutex evaluator_lock_; + private: TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable); }; diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index 68371910d76f42c0b6d4b1adad9d6a83bdb858e6..97e9fa2c8e8ecd918ffe3df2fd4e731f3b91e6db 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -19,8 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace interpreter { host::HostStream *AsExecutorStream(Stream *stream) { @@ -28,84 +27,85 @@ host::HostStream *AsExecutorStream(Stream *stream) { return dynamic_cast(stream->implementation()); } -InterpreterExecutor::InterpreterExecutor(const PluginConfig &plugin_config) +XlaInterpreterExecutor::XlaInterpreterExecutor( + const PluginConfig &plugin_config) : plugin_config_(plugin_config) {} -InterpreterExecutor::~InterpreterExecutor() {} +XlaInterpreterExecutor::~XlaInterpreterExecutor() {} -void *InterpreterExecutor::Allocate(uint64 size) { return new char[size]; } +void *XlaInterpreterExecutor::Allocate(uint64 size) { return new char[size]; } -void *InterpreterExecutor::AllocateSubBuffer(DeviceMemoryBase *parent, - uint64 offset_bytes, - uint64 /*size_bytes*/) { +void *XlaInterpreterExecutor::AllocateSubBuffer(DeviceMemoryBase *parent, + uint64 offset_bytes, + uint64 /*size_bytes*/) { return parent + offset_bytes; } -void InterpreterExecutor::Deallocate(DeviceMemoryBase *mem) { +void XlaInterpreterExecutor::Deallocate(DeviceMemoryBase *mem) { if (!mem->is_sub_buffer()) { delete[] static_cast(mem->opaque()); } } -bool InterpreterExecutor::Memcpy(Stream *stream, void *host_dst, - const DeviceMemoryBase &dev_src, uint64 size) { +bool XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst, + const DeviceMemoryBase &dev_src, + uint64 size) { AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() { port::Status ok = SynchronousMemcpy(host_dst, dev_src, size); }); return true; } -bool InterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, - const void *host_src, uint64 size) { +bool XlaInterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, + const void *host_src, uint64 size) { AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() { port::Status ok = SynchronousMemcpy(dev_dst, host_src, size); }); return true; } -port::Status InterpreterExecutor::SynchronousMemcpy(DeviceMemoryBase *dev_dst, - const void *host_src, - uint64 size) { +port::Status XlaInterpreterExecutor::SynchronousMemcpy( + DeviceMemoryBase *dev_dst, const void *host_src, uint64 size) { memcpy(dev_dst->opaque(), host_src, size); return port::Status::OK(); } -port::Status InterpreterExecutor::SynchronousMemcpy( +port::Status XlaInterpreterExecutor::SynchronousMemcpy( void *host_dst, const DeviceMemoryBase &dev_src, uint64 size) { memcpy(host_dst, dev_src.opaque(), size); return port::Status::OK(); } -bool InterpreterExecutor::HostCallback(Stream *stream, - std::function callback) { +bool XlaInterpreterExecutor::HostCallback(Stream *stream, + std::function callback) { AsExecutorStream(stream)->EnqueueTask(callback); return true; } -bool InterpreterExecutor::CreateStreamDependency(Stream *dependent, - Stream *other) { +bool XlaInterpreterExecutor::CreateStreamDependency(Stream *dependent, + Stream *other) { AsExecutorStream(dependent)->EnqueueTask( [other]() { SE_CHECK_OK(other->BlockHostUntilDone()); }); AsExecutorStream(dependent)->BlockUntilDone(); return true; } -bool InterpreterExecutor::StartTimer(Stream *stream, Timer *timer) { +bool XlaInterpreterExecutor::StartTimer(Stream *stream, Timer *timer) { dynamic_cast(timer->implementation())->Start(stream); return true; } -bool InterpreterExecutor::StopTimer(Stream *stream, Timer *timer) { +bool XlaInterpreterExecutor::StopTimer(Stream *stream, Timer *timer) { dynamic_cast(timer->implementation())->Stop(stream); return true; } -port::Status InterpreterExecutor::BlockHostUntilDone(Stream *stream) { +port::Status XlaInterpreterExecutor::BlockHostUntilDone(Stream *stream) { AsExecutorStream(stream)->BlockUntilDone(); return port::Status::OK(); } -DeviceDescription *InterpreterExecutor::PopulateDeviceDescription() const { +DeviceDescription *XlaInterpreterExecutor::PopulateDeviceDescription() const { internal::DeviceDescriptionBuilder builder; builder.set_device_address_bits(64); @@ -118,5 +118,4 @@ DeviceDescription *InterpreterExecutor::PopulateDeviceDescription() const { } } // namespace interpreter -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index c5d07e906dafb033905c50c604069e80e1ce80cd..9b109022fbfc698f7dadc678ef837da270a5e74a 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Declares the InterpreterExecutor class, which is a CPU-only implementation of -// the StreamExecutor interface. For now, this is used for testing and to +// Declares the XlaInterpreterExecutor class, which is a CPU-only implementation +// of the StreamExecutor interface. For now, this is used for testing and to // examine the performance of host-based StreamExecutor code. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_ @@ -44,16 +44,15 @@ limitations under the License. #include "tensorflow/stream_executor/stream_executor_internal.h" #include "tensorflow/stream_executor/timer.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace interpreter { using Args = tensorflow::gtl::ArraySlice; -class InterpreterExecutor : public internal::StreamExecutorInterface { +class XlaInterpreterExecutor : public internal::StreamExecutorInterface { public: - explicit InterpreterExecutor(const PluginConfig &plugin_config); - ~InterpreterExecutor() override; + explicit XlaInterpreterExecutor(const PluginConfig &plugin_config); + ~XlaInterpreterExecutor() override; port::Status Init(int device_ordinal, DeviceOptions device_options) override { return port::Status::OK(); @@ -213,7 +212,6 @@ class InterpreterExecutor : public internal::StreamExecutorInterface { }; } // namespace interpreter -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_ diff --git a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc index cf98ecd7749d61261bf072cdb1882c7603f39556..d27cd7502f10a1f615fc5b0d610acafdf55e3e43 100644 --- a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc @@ -21,12 +21,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" -namespace sei = ::perftools::gputools::interpreter; - namespace xla { InterpreterTransferManager::InterpreterTransferManager() - : GenericTransferManager(sei::kInterpreterPlatformId, + : GenericTransferManager(se::interpreter::kXlaInterpreterPlatformId, /*pointer_size=*/sizeof(void*)) {} } // namespace xla @@ -38,7 +36,8 @@ CreateInterpreterTransferManager() { static bool InitModule() { xla::TransferManager::RegisterTransferManager( - sei::kInterpreterPlatformId, &CreateInterpreterTransferManager); + stream_executor::interpreter::kXlaInterpreterPlatformId, + &CreateInterpreterTransferManager); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index a60e7fc59f7c5f0b1b24e026b34e195ca0fe5ebb..42c2c28997d5f3b02f1fe4effca164c893e4071d 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/interpreter/executor.h" -#include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/stream_executor/device_options.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/ptr_util.h" @@ -28,24 +27,22 @@ limitations under the License. #include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/platform.h" -namespace se = ::perftools::gputools; -namespace sep = ::perftools::gputools::interpreter; - -namespace perftools { -namespace gputools { +namespace stream_executor { namespace interpreter { -InterpreterPlatform::InterpreterPlatform() : name_("Interpreter") {} +XlaInterpreterPlatform::XlaInterpreterPlatform(const string& name, + const Platform::Id& id) + : name_(name), id_(id) {} -InterpreterPlatform::~InterpreterPlatform() {} +XlaInterpreterPlatform::~XlaInterpreterPlatform() {} -Platform::Id InterpreterPlatform::id() const { return kInterpreterPlatformId; } +Platform::Id XlaInterpreterPlatform::id() const { return id_; } -int InterpreterPlatform::VisibleDeviceCount() const { return 1; } +int XlaInterpreterPlatform::VisibleDeviceCount() const { return 1; } -const string& InterpreterPlatform::Name() const { return name_; } +const string& XlaInterpreterPlatform::Name() const { return name_; } -port::StatusOr InterpreterPlatform::ExecutorForDevice( +port::StatusOr XlaInterpreterPlatform::ExecutorForDevice( int ordinal) { StreamExecutorConfig config; config.ordinal = ordinal; @@ -55,7 +52,7 @@ port::StatusOr InterpreterPlatform::ExecutorForDevice( } port::StatusOr -InterpreterPlatform::ExecutorForDeviceWithPluginConfig( +XlaInterpreterPlatform::ExecutorForDeviceWithPluginConfig( int device_ordinal, const PluginConfig& plugin_config) { StreamExecutorConfig config; config.ordinal = device_ordinal; @@ -64,16 +61,17 @@ InterpreterPlatform::ExecutorForDeviceWithPluginConfig( return GetExecutor(config); } -port::StatusOr InterpreterPlatform::GetExecutor( +port::StatusOr XlaInterpreterPlatform::GetExecutor( const StreamExecutorConfig& config) { return executor_cache_.GetOrCreate( config, [&]() { return GetUncachedExecutor(config); }); } port::StatusOr> -InterpreterPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { - auto executor = port::MakeUnique( - this, port::MakeUnique(config.plugin_config)); +XlaInterpreterPlatform::GetUncachedExecutor( + const StreamExecutorConfig& config) { + auto executor = MakeUnique( + this, MakeUnique(config.plugin_config)); auto init_status = executor->Init(config.ordinal, config.device_options); if (!init_status.ok()) { return port::Status{ @@ -86,28 +84,26 @@ InterpreterPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { return std::move(executor); } -void InterpreterPlatform::RegisterTraceListener( +void XlaInterpreterPlatform::RegisterTraceListener( std::unique_ptr listener) { LOG(FATAL) << "not yet implemented: register executor trace listener"; } -void InterpreterPlatform::UnregisterTraceListener(TraceListener* listener) { +void XlaInterpreterPlatform::UnregisterTraceListener(TraceListener* listener) { LOG(FATAL) << "not yet implemented: unregister executor trace listener"; } -static void InitializeInterpreterPlatform() { - std::unique_ptr platform(new sep::InterpreterPlatform); - SE_CHECK_OK(se::MultiPlatformManager::RegisterPlatform(std::move(platform))); +static void InitializeXlaInterpreterPlatform() { + std::unique_ptr platform(new XlaInterpreterPlatform); + SE_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform))); } } // namespace interpreter -} // namespace gputools -} // namespace perftools - -REGISTER_MODULE_INITIALIZER(interpreter_platform, - sep::InitializeInterpreterPlatform()); +} // namespace stream_executor -DECLARE_MODULE_INITIALIZER(multi_platform_manager); +REGISTER_MODULE_INITIALIZER( + interpreter_platform, + stream_executor::interpreter::InitializeXlaInterpreterPlatform()); // Note that module initialization sequencing is not supported in the // open-source project, so this will be a no-op there. diff --git a/tensorflow/compiler/xla/service/interpreter/platform.h b/tensorflow/compiler/xla/service/interpreter/platform.h index c66ddb907d1c5a8e99d3178a202a77a72a646ce5..0187f6d473b19f50136e214708e56f833627d9d1 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.h +++ b/tensorflow/compiler/xla/service/interpreter/platform.h @@ -18,19 +18,20 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/stream_executor/executor_cache.h" #include "tensorflow/stream_executor/plugin.h" #include "tensorflow/stream_executor/stream_executor.h" #include "tensorflow/stream_executor/trace_listener.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace interpreter { -class InterpreterPlatform : public Platform { +class XlaInterpreterPlatform : public Platform { public: - InterpreterPlatform(); - ~InterpreterPlatform() override; + XlaInterpreterPlatform(const string& name = "Interpreter", + const Platform::Id& id = kXlaInterpreterPlatformId); + ~XlaInterpreterPlatform() override; Platform::Id id() const override; @@ -56,15 +57,16 @@ class InterpreterPlatform : public Platform { private: // This platform's name. string name_; + // This platform's id. + Platform::Id id_; // Cache of created StreamExecutors. ExecutorCache executor_cache_; - SE_DISALLOW_COPY_AND_ASSIGN(InterpreterPlatform); + SE_DISALLOW_COPY_AND_ASSIGN(XlaInterpreterPlatform); }; } // namespace interpreter -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_PLATFORM_H_ diff --git a/tensorflow/compiler/xla/service/interpreter/platform_id.cc b/tensorflow/compiler/xla/service/interpreter/platform_id.cc index 1a0373cf86e26b564e0e732e8de1a0a5d868bfa6..3272396ce5045129a7689a160ec859d11fbbe9fa 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform_id.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform_id.cc @@ -14,12 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/interpreter/platform_id.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace interpreter { -PLATFORM_DEFINE_ID(kInterpreterPlatformId); +PLATFORM_DEFINE_ID(kXlaInterpreterPlatformId); } // namespace interpreter -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/compiler/xla/service/interpreter/platform_id.h b/tensorflow/compiler/xla/service/interpreter/platform_id.h index 905efef1690d3bd32461353fe32dd394eb6bca9e..a6cc10bcc1eb756a3146d4a834efa4cd3ceb2d27 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform_id.h +++ b/tensorflow/compiler/xla/service/interpreter/platform_id.h @@ -18,14 +18,12 @@ limitations under the License. #include "tensorflow/stream_executor/platform.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace interpreter { -extern const Platform::Id kInterpreterPlatformId; +extern const Platform::Id kXlaInterpreterPlatformId; } // namespace interpreter -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_PLATFORM_ID_H_ diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 0668f66051ce96292c3c85bac7e649d89914106c..7067b6f86a0fb24fb946ad236bca9bbd48d53722 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -31,10 +31,12 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -57,76 +59,6 @@ namespace xla { // anonymous namespace, instead of three or four spread all over this file. namespace { -// Creates and returns a copy of the given instruction with a different -// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple -// instruction producing the copy is returned. -StatusOr CreateCopyWithNewLayout( - const Shape& shape_with_layout, HloInstruction* instruction) { - TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout)); - DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape())) - << ShapeUtil::HumanString(shape_with_layout) << " " - << ShapeUtil::HumanString(instruction->shape()) - << " instruction: " << instruction->ToString(); - - if (ShapeUtil::IsTuple(instruction->shape())) { - // Deep-copy tuples. - std::vector element_copies; - for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); - ++i) { - HloInstruction* gte = instruction->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, - i)); - - // Recurse to copy each elements. - TF_ASSIGN_OR_RETURN( - HloInstruction * element_copy, - CreateCopyWithNewLayout( - ShapeUtil::GetSubshape(shape_with_layout, {i}), gte)); - element_copies.push_back(element_copy); - } - // Gather element copies into a tuple with a new Tuple instruction. - HloInstruction* tuple_copy = instruction->parent()->AddInstruction( - HloInstruction::CreateTuple(element_copies)); - LayoutUtil::ClearLayout(tuple_copy->mutable_shape()); - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( - shape_with_layout, tuple_copy->mutable_shape())); - return tuple_copy; - } else if (ShapeUtil::IsArray(instruction->shape())) { - HloInstruction* copy = - instruction->parent()->AddInstruction(HloInstruction::CreateUnary( - instruction->shape(), HloOpcode::kCopy, instruction)); - LayoutUtil::ClearLayout(copy->mutable_shape()); - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( - shape_with_layout, copy->mutable_shape())); - - return copy; - } else { - return FailedPrecondition( - "Can only copy array and tuple shaped instructions"); - } -} - -// Creates a copy of the given operand if the operand's layout does not match -// the given layout. This copy replaces the use in the given instruction. Tuple -// operands will be deep-copied. -Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, - HloInstruction* instruction, - int64 operand_no) { - HloInstruction* operand = instruction->mutable_operand(operand_no); - TF_RET_CHECK(operand_layout.LayoutIsSet()); - TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); - - if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { - // Operand layout already matches our constraint. Nothing to do. - return Status::OK(); - } - - TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy, - CreateCopyWithNewLayout(operand_layout.shape(), operand)); - - return instruction->ReplaceOperandWith(operand_no, operand_copy); -} } // namespace @@ -192,17 +124,34 @@ LayoutConstraints::LayoutConstraints( } } +PointsToSet::BufferSet* LayoutConstraints::GetBufferSet( + const HloInstruction* instruction) const { + auto it = buffer_sets_cache_.find(instruction); + if (it != buffer_sets_cache_.end()) { + return it->second.get(); + } + auto& buffer_set = + buffer_sets_cache_ + .emplace(instruction, MakeUnique()) + .first->second; + const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction); + points_to_set.ForEachElement( + [&buffer_set](const ShapeIndex& /*index*/, + const PointsToSet::BufferList& buffers) { + buffer_set->insert(buffers.begin(), buffers.end()); + }); + return buffer_set.get(); +} + bool LayoutConstraints::OperandBufferForwarded( const HloInstruction* instruction, int64 operand_no) const { // The operand is potentially forwarded if the intersection of points-to sets // of the operand and the instruction is non-empty. - auto output_buffers = - points_to_analysis_.GetPointsToSet(instruction).CreateFlattenedSet(); - auto operand_buffers = - points_to_analysis_.GetPointsToSet(instruction->operand(operand_no)) - .CreateFlattenedSet(); - for (const LogicalBuffer* output_buffer : output_buffers) { - if (operand_buffers.count(output_buffer) > 0) { + PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction); + PointsToSet::BufferSet* operand_buffers = + GetBufferSet(instruction->operand(operand_no)); + for (const LogicalBuffer* output_buffer : *output_buffers) { + if (operand_buffers->count(output_buffer) > 0) { return true; } } @@ -453,9 +402,9 @@ string LayoutConstraints::ToString() const { } Status LayoutAssignment::AddMandatoryConstraints( - const ComputationLayout& computation_layout, - const ChannelLayoutConstraints* channel_constraints, - HloComputation* computation, LayoutConstraints* constraints) { + const ComputationLayout* computation_layout, + ChannelLayoutConstraints* channel_constraints, HloComputation* computation, + LayoutConstraints* constraints) { VLOG(3) << "Adding mandatory layout constraints to computation " << computation->name(); @@ -477,11 +426,16 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetOperandLayout( instruction->outfeed_shape(), instruction, 0)); } else if (instruction->opcode() == HloOpcode::kParameter) { - // Parameter layouts must match the respective layout in - // ComputationLayout. - shape_with_layout = - &computation_layout.parameter_layout(instruction->parameter_number()) - .shape(); + if (computation_layout != nullptr) { + const ShapeLayout& parameter_layout = + computation_layout->parameter_layout( + instruction->parameter_number()); + if (parameter_layout.LayoutIsSet()) { + // Parameter layouts must match the respective layout in + // ComputationLayout, if there is one. + shape_with_layout = ¶meter_layout.shape(); + } + } } if (shape_with_layout != nullptr) { TF_RETURN_IF_ERROR( @@ -546,9 +500,8 @@ Status LayoutAssignment::AddMandatoryConstraints( HloComputation* body = instruction->while_body(); HloComputation* condition = instruction->while_condition(); const HloInstruction* init = instruction->operand(0); - const ComputationLayout& body_layout = - FindOrDie(computation_layouts_, body); - const ComputationLayout& condition_layout = + ComputationLayout& body_layout = FindOrDie(computation_layouts_, body); + ComputationLayout& condition_layout = FindOrDie(computation_layouts_, condition); // Check a few invariants irrespective of layout. @@ -561,26 +514,19 @@ Status LayoutAssignment::AddMandatoryConstraints( condition_layout.parameter_shape(0))); DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape())); - // Return error if earlier layout assignment of the embedded computations - // has produced conflicting layouts. - if (!ShapeUtil::Equal(body_layout.result_shape(), - body_layout.parameter_shape(0))) { - return InternalError( - "Parameter and result of body computation %s of while instruction " - "%s have different layouts: %s vs %s", - body->name().c_str(), instruction->name().c_str(), - ShapeUtil::HumanString(body_layout.result_shape()).c_str(), - ShapeUtil::HumanString(body_layout.parameter_shape(0)).c_str()); + if (body_layout.result_layout() != body_layout.parameter_layout(0)) { + VLOG(2) << "Reset %while body parameter layout: body=" << body->name() + << " while=" << instruction->name() + << " shape=" << body_layout.result_layout().ToString(); + *body_layout.mutable_parameter_layout(0) = body_layout.result_layout(); } - if (!ShapeUtil::Equal(body->root_instruction()->shape(), - condition->parameter_instruction(0)->shape())) { - return InternalError( - "Parameter of condition computation %s of while instruction " - "%s does not match body computation %s result: %s vs %s", - condition->name().c_str(), instruction->name().c_str(), - body->name().c_str(), - ShapeUtil::HumanString(condition_layout.parameter_shape(0)).c_str(), - ShapeUtil::HumanString(body_layout.result_shape()).c_str()); + if (condition_layout.parameter_layout(0) != + body_layout.parameter_layout(0)) { + VLOG(2) << "Reset %while condition parameter layout: cond=" + << condition->name() << " while=" << instruction->name() + << " shape=" << body_layout.parameter_layout(0).ToString(); + *condition_layout.mutable_parameter_layout(0) = + body_layout.parameter_layout(0); } // Constrain the output and the operand of the while instruction to match @@ -610,7 +556,20 @@ Status LayoutAssignment::AddMandatoryConstraints( true_computation_layout.parameter_shape(0))); DCHECK(ShapeUtil::Compatible( false_operand->shape(), false_computation_layout.parameter_shape(0))); - + if (true_computation_layout.result_layout() != + false_computation_layout.result_layout()) { + // We assign layouts in DFS fashion, so the true and false computations + // might have negotiated a different layout. But for the conditional + // instruction POV the layout must match, so we run again on the false + // computation, this time with proper computation layout. + VLOG(2) << "Reset %conditional false computation result layout: " + "false_computation=" + << false_computation->name() + << " conditional=" << instruction->name() << " shape=" + << true_computation_layout.result_layout().ToString(); + *false_computation_layout.mutable_result_layout() = + true_computation_layout.result_layout(); + } TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( true_computation_layout.result_shape(), instruction)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( @@ -646,10 +605,14 @@ Status LayoutAssignment::AddMandatoryConstraints( } } } - - // Finally set the result layout to match ComputationLayout. - return constraints->SetResultLayout( - computation_layout.result_layout().shape()); + // Finally set the result layout to match ComputationLayout, if there is one. + if (computation_layout != nullptr) { + const ShapeLayout& result_layout = computation_layout->result_layout(); + if (result_layout.LayoutIsSet()) { + TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape())); + } + } + return Status::OK(); } namespace { @@ -776,6 +739,106 @@ Status CheckConstantLayout(HloInstruction* constant) { } // namespace +StatusOr LayoutAssignment::CreateCopyWithNewLayout( + const Shape& shape_with_layout, HloInstruction* instruction) { + TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout)); + DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape())) + << ShapeUtil::HumanString(shape_with_layout) << " " + << ShapeUtil::HumanString(instruction->shape()) + << " instruction: " << instruction->ToString(); + + if (ShapeUtil::IsTuple(instruction->shape())) { + // Deep-copy tuples. + std::vector element_copies; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); + ++i) { + HloInstruction* gte = instruction->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, + i)); + SetupCopiedInstruction(*instruction, gte, {i}); + // Recurse to copy each elements. + TF_ASSIGN_OR_RETURN( + HloInstruction * element_copy, + CreateCopyWithNewLayout( + ShapeUtil::GetSubshape(shape_with_layout, {i}), gte)); + element_copies.push_back(element_copy); + } + // Gather element copies into a tuple with a new Tuple instruction. + HloInstruction* tuple_copy = instruction->parent()->AddInstruction( + HloInstruction::CreateTuple(element_copies)); + SetupCopiedInstruction(*instruction, tuple_copy, {}); + LayoutUtil::ClearLayout(tuple_copy->mutable_shape()); + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + shape_with_layout, tuple_copy->mutable_shape())); + return tuple_copy; + } else if (ShapeUtil::IsArray(instruction->shape())) { + HloInstruction* copy = + instruction->parent()->AddInstruction(HloInstruction::CreateUnary( + instruction->shape(), HloOpcode::kCopy, instruction)); + RegisterAddedCopy(copy); + SetupCopiedInstruction(*instruction, copy, {}); + LayoutUtil::ClearLayout(copy->mutable_shape()); + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + shape_with_layout, copy->mutable_shape())); + + return copy; + } else { + return FailedPrecondition( + "Can only copy array and tuple shaped instructions"); + } +} + +// Creates a copy of the given operand if the operand's layout does not match +// the given layout. This copy replaces the use in the given instruction. Tuple +// operands will be deep-copied. +Status LayoutAssignment::CopyOperandIfLayoutsDiffer( + const ShapeLayout& operand_layout, HloInstruction* instruction, + int64 operand_no) { + HloInstruction* operand = instruction->mutable_operand(operand_no); + TF_RET_CHECK(operand_layout.LayoutIsSet()); + TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); + + if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { + VLOG(5) << "Operand " << operand->ToString() << " layout matches in " + << instruction->ToString(); + // Operand layout already matches our constraint. Nothing to do. + return Status::OK(); + } + VLOG(4) << "Operand " << operand->ToString() << " layout does not match " + << operand_layout.ToString() << " in " << instruction->ToString(); + + TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy, + CreateCopyWithNewLayout(operand_layout.shape(), operand)); + + VLOG(4) << "New copy of " << operand->ToString() << " is " + << operand_copy->ToString(); + return instruction->ReplaceOperandWith(operand_no, operand_copy); +} + +void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction, + HloInstruction* copy, + const ShapeIndex& index) { + if (instruction.has_sharding()) { + // If the index is empty, we want to copy the whole sharding, in case the + // sharding is a tuple sharding. + HloSharding sharding = + !index.empty() && instruction.sharding().IsTuple() + ? instruction.sharding().GetSubSharding(instruction.shape(), index) + : instruction.sharding(); + // We propagate the sharding to the copied instruction only if it is a + // special sharding, like tiled ones, or special devices like the + // HostCompute module. + // Otherwise it is preferable to leave the new instruction without device, + // and let the automatic device placer to choose the best location. + if (!sharding.HasUniqueDevice() || + HloSharding::IsReservedDevice(sharding.UniqueDevice().ValueOrDie())) { + copy->set_sharding(sharding); + } + } + copy->set_metadata(instruction.metadata()); +} + Status LayoutAssignment::CheckLayouts(HloModule* module) { TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(module)); @@ -856,15 +919,16 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { } } } - - // Finally verify the result layout matches the layout of the entry + // Finally verify the result layout, if set, matches the layout of the entry // computation root. - TF_RET_CHECK(ShapeUtil::Equal( - module->entry_computation()->root_instruction()->shape(), + const ShapeLayout& result_layout = FindOrDie(computation_layouts_, module->entry_computation()) - .result_layout() - .shape())); - + .result_layout(); + if (result_layout.LayoutIsSet()) { + TF_RET_CHECK(ShapeUtil::Equal( + module->entry_computation()->root_instruction()->shape(), + result_layout.shape())); + } return Status::OK(); } @@ -873,18 +937,13 @@ LayoutAssignment::LayoutAssignment( ChannelLayoutConstraints* channel_constraints) : entry_computation_layout_(entry_computation_layout), channel_layout_constraints_(channel_constraints) { - VLOG(1) << "entry computation layout given to layout assignment: " + VLOG(1) << "Entry computation layout given to layout assignment: " << entry_computation_layout_->ToString(); // Layouts of all parameter instructions must be set. for (const ShapeLayout& parameter_layout : entry_computation_layout_->parameter_layouts()) { CHECK(parameter_layout.LayoutIsSet()); } - // If the result layout is not set, then choose the default. - // TODO(b/29118294): Choose a better layout in this case. - if (!entry_computation_layout_->result_layout().LayoutIsSet()) { - entry_computation_layout_->mutable_result_layout()->SetToDefaultLayout(); - } } std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( @@ -1444,16 +1503,60 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, return Status::OK(); } +Status LayoutAssignment::CalculateComputationLayout( + HloComputation* computation) { + ComputationLayout computation_layout(computation->ComputeProgramShape(), + /*ignore_layouts=*/false); + InsertOrDie(&computation_layouts_, computation, computation_layout); + VLOG(2) << " Calculated ComputationLayout = " + << computation_layout.ToString(); + return Status::OK(); +} + +Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { + // Clear existing layouts of the instructions. All layouts must be assigned + // by the LayoutAssignment pass, except for those on infeeds, parameters, + // and the computation result. The latter two are specified in + // computation_layout, so we only need to keep the existing layouts for + // infeeds. Clearing the layouts here avoids hiding potential bugs in the + // layout assignment pass that may accidently use the existing layout. + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kBitcast) { + // bitcasts are inherently layout sensitive and so a bitcast instruction + // present in the IR before layout assignment is a bug. + return InternalError( + "Unexpected bitcast operation seen during layout assignment: %s.", + instruction->ToString().c_str()); + } + if (instruction->opcode() != HloOpcode::kInfeed) { + LayoutUtil::ClearLayout(instruction->mutable_shape()); + } + } + return Status::OK(); +} + Status LayoutAssignment::RunOnComputation( - const ComputationLayout& computation_layout, + ComputationLayout* computation_layout, const TuplePointsToAnalysis& points_to_analysis, HloComputation* computation, ChannelLayoutConstraints* channel_constraints) { - DCHECK(computation_layout.LayoutIsSet()); - InsertOrDie(&computation_layouts_, computation, computation_layout); VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name() << ")"; - VLOG(2) << " ComputationLayout = " << computation_layout.ToString(); + TF_RETURN_IF_ERROR(ClearComputationLayouts(computation)); + if (computation_layout != nullptr) { + auto it = computation_layouts_.find(computation); + if (it == computation_layouts_.end()) { + VLOG(2) << " New ComputationLayout = " << computation_layout->ToString(); + computation_layouts_.emplace(computation, *computation_layout); + } else { + TF_RET_CHECK(computation_layout == &it->second || + computation_layout == entry_computation_layout_); + VLOG(2) << " Existing ComputationLayout = " + << computation_layout->ToString(); + } + } else { + VLOG(2) << " No ComputationLayout specified (will be calculated)"; + } // Construct LayoutConstraints with all layout constraints of the computation. LayoutConstraints constraints(points_to_analysis, computation); @@ -1496,12 +1599,19 @@ Status LayoutAssignment::RunOnComputation( CHECK_LT(constraints.unconstrained_buffer_ids().size(), unconstrained_count); } - // All logical buffers should have constraints at this point. All that // remains is assign the constraints to the buffers and infer layouts for // aliased buffers. TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation)); + // If the computation layout wasn't specified, now it is the time to compute + // it according to the parameters and root instruction layouts. + // This allows the first pass through this API to record the best flowing + // layout to parameters and root instruction. + if (computation_layout == nullptr) { + TF_RETURN_IF_ERROR(CalculateComputationLayout(computation)); + } + // Record the layouts assigned for any communication ops in // channel_constraints so that they are constrained for future modules. for (HloInstruction* instruction : computation->instructions()) { @@ -1516,6 +1626,34 @@ Status LayoutAssignment::RunOnComputation( return Status::OK(); } +Status LayoutAssignment::PropagateComputationLayouts( + HloComputation* computation, ComputationLayout* computation_layout) { + ComputationLayout computed_computation_layout( + computation->ComputeProgramShape(), + /*ignore_layouts=*/false); + for (int64 i = 0; i < computed_computation_layout.parameter_count(); ++i) { + ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i); + if (!param_layout->LayoutIsSet()) { + VLOG(4) << "Assigning layout to parameter " << i << " of computation " + << computation->name() << ": " + << computed_computation_layout.parameter_layout(i).ToString(); + *param_layout = computed_computation_layout.parameter_layout(i); + } else { + TF_RET_CHECK(computed_computation_layout.parameter_layout(i) == + *param_layout); + } + } + ShapeLayout* result_layout = computation_layout->mutable_result_layout(); + if (!result_layout->LayoutIsSet()) { + VLOG(4) << "Assigning result layout of computation " << computation->name() + << ": " << computed_computation_layout.result_layout().ToString(); + *result_layout = computed_computation_layout.result_layout(); + } else { + TF_RET_CHECK(computed_computation_layout.result_layout() == *result_layout); + } + return Status::OK(); +} + StatusOr LayoutAssignment::Run(HloModule* module) { VLOG(2) << "Running layout assignment on module " << module->name(); XLA_VLOG_LINES(3, module->ToString()); @@ -1524,45 +1662,45 @@ StatusOr LayoutAssignment::Run(HloModule* module) { "before layout assignment", module->config().debug_options()); } - - TF_ASSIGN_OR_RETURN(auto points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // Assign layouts to computations in an order such that a callee computation - // is handled before its caller computation. This ensures that the layout of - // all callers of a computation will agree. - std::list computation_post_order = - module->MakeComputationPostOrder(); - for (auto* computation : module->MakeComputationPostOrder()) { - if (computation->IsFusionComputation()) { - continue; - } - // Clear existing layouts of the instructions. All layouts must be assigned - // by the LayoutAssignment pass, except for those on infeeds, parameters, - // and the computation result. The latter two are specified in - // computation_layout, so we only need to keep the existing layouts for - // infeeds. Clearing the layouts here avoids hiding potential bugs in the - // layout assignment pass that may accidently use the existing layout. - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kInfeed) { - LayoutUtil::ClearLayout(instruction->mutable_shape()); + TF_RETURN_IF_ERROR(Init()); + + // We do two passes. The first one we pass a nullptr ComputationLayout to + // the RunOnComputation() calls (for non entry computations), and we register + // the ComputationLayout which are naturally flowing in DFS fashion to the + // parameters and root instruction. + // Walking in DFS mode though, means that we can end up with incorrect layouts + // when seen from an outer instruction, which has across-computation + // constraints to impose. + // For example, the kWhile instruction needs to enforce the same layouts for + // the parameters and root of the bosy, as well as the condition parameters. + // Similarly, the kConditional instruction needs to enforce the same layouts + // for the root of the true and false computations. + // So in the first pass, while allowing the layouts to flow to parameters and + // root, we also fix up the eventually inconsistent ComputationLayout, which + // will be then made mandatory by the second pass. + for (int64 i = 0; i < 2; ++i) { + TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module)); + TF_ASSIGN_OR_RETURN(auto points_to_analysis, + TuplePointsToAnalysis::Run(module)); + for (auto* computation : module->MakeComputationPostOrder()) { + if (computation->IsFusionComputation()) { + continue; + } + if (computation == module->entry_computation()) { + TF_RETURN_IF_ERROR(RunOnComputation( + entry_computation_layout_, *points_to_analysis, + module->entry_computation(), channel_layout_constraints_)); + } else { + ComputationLayout* computation_layout = + (i == 0) ? nullptr : &FindOrDie(computation_layouts_, computation); + TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, + *points_to_analysis, computation, + channel_layout_constraints_)); } - } - if (computation == module->entry_computation()) { - TF_RETURN_IF_ERROR(RunOnComputation( - *entry_computation_layout_, *points_to_analysis, - module->entry_computation(), channel_layout_constraints_)); - } else { - ComputationLayout computation_layout(computation->ComputeProgramShape()); - // Setting all embedded computations to the default layout is potentially - // suboptimal. - computation_layout.SetToDefaultLayout(); - TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, - *points_to_analysis, computation, - channel_layout_constraints_)); } } - + TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(), + entry_computation_layout_)); TF_RETURN_IF_ERROR(CheckLayouts(module)); VLOG(3) << "After layout assignment:"; @@ -1572,9 +1710,54 @@ StatusOr LayoutAssignment::Run(HloModule* module) { "after layout assignment", module->config().debug_options()); } - // All layouts are reset then reassigned by this pass. return true; } +Status LayoutAssignment::Init() { + computation_layouts_.clear(); + return Status::OK(); +} + +Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { + // Clear all the copies which have been added, and all the related + // instructions (like GTE and tuples). + int64 removed_copies = 0; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kCopy && + added_copies_.count(instruction) > 0) { + VLOG(5) << "Removing added copy: " << instruction->ToString(); + TF_RETURN_IF_ERROR( + instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); + ++removed_copies; + } + } + } + added_copies_.clear(); + if (removed_copies > 0) { + TupleSimplifier tuple_simplifier; + HloDCE dce; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } + return Status::OK(); +} + +Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction, + int64 operand_number) { + HloInstruction* operand = instruction->mutable_operand(operand_number); + if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) { + HloInstruction* copy = + instruction->parent()->AddInstruction(HloInstruction::CreateUnary( + operand->shape(), HloOpcode::kCopy, operand)); + SetupCopiedInstruction(*operand, copy, {}); + LayoutUtil::ClearLayout(copy->mutable_shape()); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy)); + } + return Status::OK(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 29018584487cabfd740d7914625c2a50f552d6ff..c287cca0c54ba1bb514bd8d243c137eca99b258f 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -38,6 +38,8 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -199,6 +201,11 @@ class LayoutConstraints { string ToString() const; private: + // Find a bufferset in the bufferset cache. This is useful since we can + // currently create the flattened buffer set for the same instruction many + // times, which is often slow. + PointsToSet::BufferSet* GetBufferSet(const HloInstruction* instruction) const; + // The set of BufferLayoutConstraints applied to the computation. std::unordered_map buffer_constraints_; @@ -221,6 +228,10 @@ class LayoutConstraints { // Array-shaped buffers which have not yet been constrained. std::set unconstrained_buffer_ids_; + mutable tensorflow::gtl::FlatMap> + buffer_sets_cache_; + HloComputation* computation_; }; @@ -271,8 +282,8 @@ class LayoutAssignment : public HloPassInterface { // the case that no particular layout is requested. // // channel_constraints is both an input and output. Any sends or recvs that - // are present in channel_constraints will be layed out as constrained. Any - // unconstrained sends or recvs will be layed out as locally optimal and their + // are present in channel_constraints will be laid out as constrained. Any + // unconstrained sends or recvs will be laid out as locally optimal and their // layout will be added as a constraint to channel_constraints. // // If channel_constraints is nullptr, no kSend or kRecvs must be contained @@ -352,12 +363,15 @@ class LayoutAssignment : public HloPassInterface { int64 operand_no); private: + // Initializes the layout assignment object for a new Run() call. + Status Init(); + // Adds constraints which must be satisfied for correctness on all // backends. Called once prior to propagating constraints. - Status AddMandatoryConstraints( - const ComputationLayout& computation_layout, - const ChannelLayoutConstraints* channel_constraints, - HloComputation* computation, LayoutConstraints* constraints); + Status AddMandatoryConstraints(const ComputationLayout* computation_layout, + ChannelLayoutConstraints* channel_constraints, + HloComputation* computation, + LayoutConstraints* constraints); // This method can be overridden to add backend-specific constraints to the // layout of the instructions of a computation. This method is called after @@ -368,10 +382,12 @@ class LayoutAssignment : public HloPassInterface { } // Construct contraints and assign layouts to all instructions in the - // computation satisfying the given ComputationLayout. Layouts constraints are - // added, then propagated until all LogicalBuffers in the computation are - // constrained. - Status RunOnComputation(const ComputationLayout& computation_layout, + // computation satisfying the given ComputationLayout, if not nullptr. + // Otherwise the ComputationLayout will be calculated by propagating the + // computation instruction contraints. + // Layouts constraints are added, then propagated until all LogicalBuffers in + // the computation are constrained. + Status RunOnComputation(ComputationLayout* computation_layout, const TuplePointsToAnalysis& points_to_analysis, HloComputation* computation, ChannelLayoutConstraints* channel_constraints); @@ -392,15 +408,73 @@ class LayoutAssignment : public HloPassInterface { // necessary conditions. Status CheckLayouts(HloModule* module); + // Computes the ComputationLayout of the given computation based of the + // layouts assigned to parameters and root instruction, and inserts it to the + // computation_layouts_ map. + Status CalculateComputationLayout(HloComputation* computation); + + // Clears all the layouts which can be cleared within a computation. + Status ClearComputationLayouts(HloComputation* computation); + + // Clears the side effects of a previous pass, like added copy instructions. + Status ClearPreviousPassSideEffects(HloModule* module); + + // Propagates the layouts computed by the layout assignment pass on the given + // computation, to the computation layout passed in to this API. + // This API propagates missing layout, and also checks that the caller + // specified have been respected, by comparing those with the parameters and + // root computation instruction. + Status PropagateComputationLayouts(HloComputation* computation, + ComputationLayout* computation_layout); + ComputationLayout* entry_computation_layout_; - ChannelLayoutConstraints* channel_layout_constraints_; protected: + // Sets up the copy instruction according to the characteristic (sharding, + // metadata, ...) of the reference instruction. The index argument is used + // when the instruction is a tuple, and in such case the index represents + // the location from where the copy instruction was created from. + // If the index is empty, the whole sharding will be propagated, even in case + // the intruction has a tuple sharding. + static void SetupCopiedInstruction(const HloInstruction& instruction, + HloInstruction* copy, + const ShapeIndex& index); + + // Creates and returns a copy of the given instruction with a different + // layout. Tuple-shaped instructions will be deep-copied, and the last Tuple + // instruction producing the copy is returned. + StatusOr CreateCopyWithNewLayout( + const Shape& shape_with_layout, HloInstruction* instruction); + + // Creates a copy of the given operand if the operand's layout does not match + // the given layout. This copy replaces the use in the given instruction. + // Tuple operands will be deep-copied. + Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, + HloInstruction* instruction, + int64 operand_no); + + // Registers a copy instruction added by the layout assignment pass. + void RegisterAddedCopy(HloInstruction* copy) { + CHECK_EQ(copy->opcode(), HloOpcode::kCopy); + added_copies_.insert(copy); + } + + // Adds a copy for the operand of an instruction, unless such operand is + // already a copy, and has a single user (which is forcibly the instruction + // itself). + Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number); + // Map containing the layouts of all computations assigned so // far. Computations are handled in a topological sort where computations are // handled before their caller instructions so the layouts of caller // instructions can be set to match the computation. std::map computation_layouts_; + + // Every copy added to the module by the layout assignment pass is registered + // here. + tensorflow::gtl::FlatSet added_copies_; + + ChannelLayoutConstraints* channel_layout_constraints_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index dd0fba2758f0d77e72bc55138df229b24c026677..bf0448a67674f24591d866b646b98aea09ebb12c 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -29,13 +29,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -590,6 +590,84 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { transpose->shape(), {2, 3, 0, 1})); } +// TransposeIsBitcast shouldn't be called without layout information. +TEST_F(LayoutAssignmentTest, TransposeIsBitcastFail) { + auto builder = HloComputation::Builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + Shape input_shape_with_layout(input_shape); + *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape_with_layout, "param")); + auto hlo = builder.AddInstruction( + HloInstruction::CreateTranspose(input_shape, param, {0, 2, 1})); + // Clear the default layout assigned to the instruction. + LayoutUtil::ClearLayout(hlo->mutable_shape()); + EXPECT_DEATH(ShapeUtil::TransposeIsBitcast(hlo->operand(0)->shape(), + hlo->shape(), hlo->dimensions()), + "LayoutUtil::HasLayout"); +} + +// ReshapeIsBitcast shouldn't be called without layout information. +TEST_F(LayoutAssignmentTest, ReshapeIsBitcastFail) { + auto builder = HloComputation::Builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + Shape input_shape_with_layout(input_shape); + *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape_with_layout, "param")); + auto hlo = + builder.AddInstruction(HloInstruction::CreateReshape(input_shape, param)); + // Clear the default layout assigned to the instruction. + LayoutUtil::ClearLayout(hlo->mutable_shape()); + EXPECT_DEATH( + ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape()), + "LayoutUtil::HasLayout"); +} + +// Check that the computation below doesn't crash the compiler. +// +// Within a fusion computation, only the parameters and result get assigned a +// layout. When we run the algebraic simplifier on this computation post layout +// assignment, it should not call TransposeIsBitcast on the `transpose` node +// inside the fusion computation as TransposeIsBitcast checks both input_shape +// and output_shape have layouts. +TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { + const char* module_str = R"( + HloModule test_module + + fused_computation { + param_1 = f32[2,2,2]{2,1,0} parameter(1) + transpose = f32[2,2,2]{2,1,0} transpose(param_1), dimensions={0,2,1} + reduce_1 = f32[] parameter(0) + broadcast_1 = f32[2,2,2]{2,1,0} broadcast(reduce_1), dimensions={} + ROOT divide_1 = f32[2,2,2]{2,1,0} divide(transpose, broadcast_1) + } + + ENTRY entry_computation { + fusion.1 = f32[2,2,2]{2,1,0} parameter(1) + reduce.1 = f32[] parameter(0) + fusion.2 = f32[2,2,2]{2,1,0} fusion(reduce.1, fusion.1), kind=kLoop, calls=fused_computation + ROOT tuple.1 = (f32[2,2,2]{2,1,0}) tuple(fusion.2) + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + + module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + + EXPECT_EQ(Status::OK(), backend() + .compiler() + ->RunBackend(std::move(module), + backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .status()); +} + // A GTE inside of a fusion node inherits the layout of its operand (which // should, if we keep following operands, eventually be a parameter). TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { @@ -613,7 +691,7 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { } )"; - auto module = tools::Parse(module_str).ValueOrDie(); + auto module = ParseHloString(module_str).ValueOrDie(); ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( @@ -629,33 +707,29 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { LayoutUtil::MakeLayout({2, 1, 0})); AssignLayouts(module.get(), &computation_layout); - HloComputation* fused_computation = *std::find_if( - module->computations().begin(), module->computations().end(), - [](const HloComputation* c) { return c->name() == "fused_computation"; }); - - auto fused_instr = [&](const string& name) { - auto it = std::find_if( - fused_computation->instructions().begin(), - fused_computation->instructions().end(), - [&](const HloInstruction* i) { return i->name() == name; }); - CHECK(it != fused_computation->instructions().end()); - return *it; + auto layout_of = [&](tensorflow::StringPiece name) { + return FindInstruction(module.get(), name) + ->shape() + .layout() + .minor_to_major(); }; - EXPECT_THAT(fused_instr("gte0")->shape().layout().minor_to_major(), - ElementsAre(0, 1, 2)); - EXPECT_THAT( - fused_instr("gte1")->shape().tuple_shapes(0).layout().minor_to_major(), - ElementsAre(1, 2, 0)); - EXPECT_THAT( - fused_instr("gte1")->shape().tuple_shapes(1).layout().minor_to_major(), - ElementsAre(2, 0, 1)); - EXPECT_THAT(fused_instr("gte1a")->shape().layout().minor_to_major(), + EXPECT_THAT(layout_of("gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(layout_of("gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(layout_of("gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(layout_of("fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(FindInstruction(module.get(), "gte1") + ->shape() + .tuple_shapes(0) + .layout() + .minor_to_major(), ElementsAre(1, 2, 0)); - EXPECT_THAT(fused_instr("gte1b")->shape().layout().minor_to_major(), + EXPECT_THAT(FindInstruction(module.get(), "gte1") + ->shape() + .tuple_shapes(1) + .layout() + .minor_to_major(), ElementsAre(2, 0, 1)); - EXPECT_THAT(fused_instr("fresult")->shape().layout().minor_to_major(), - ElementsAre(2, 1, 0)); } TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { @@ -721,5 +795,26 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy); } +TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { + auto builder = HloComputation::Builder(TestName()); + auto constant0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + builder.AddInstruction(HloInstruction::CreateUnary( + constant0->shape(), HloOpcode::kBitcast, constant0)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape()); + LayoutAssignment layout_assignment(&computation_layout); + Status error_status = layout_assignment.Run(module.get()).status(); + EXPECT_FALSE(error_status.ok()); + EXPECT_THAT( + error_status.error_message(), + ::testing::HasSubstr( + "Unexpected bitcast operation seen during layout assignment")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc deleted file mode 100644 index 68c99256a246edcf43a8358f667fc4458b9b4fea..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ /dev/null @@ -1,379 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/liveness_util.h" - -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" - -namespace xla { - -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const TuplePointsToAnalysis& points_to_analysis) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { - // GetTupleElement instructions only access the top-level buffer of their - // operand. - return true; - } else if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { - // Find fusion parameter associated with 'operand'. - auto it = std::find_if( - user->fused_parameters().begin(), user->fused_parameters().end(), - [=](HloInstruction* fused_param) { - return user->operand(fused_param->parameter_number()) == operand; - }); - CHECK(it != user->fused_parameters().end()); - // Iterate through all users of all buffer aliases of the buffer in the - // points-to set of fusion parameter at 'index'. - // Return false if any uses are detected at 'index', returns true otherwise. - const LogicalBuffer* buffer = - points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie(); - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - for (HloInstruction* alias_user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), - alias_user, points_to_analysis)) { - continue; - } - // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'. - return false; - } - } - // Return true: found no uses of 'operand' at 'index' in 'user'. - return true; - } - return false; -} - -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const HloDataflowAnalysis& dataflow) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { - // Find fusion parameter associated with 'operand'. - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - // Iterate through all users of all uses of the fusion parameter value. - // Return false if any uses are detected, returns true otherwise. - const HloValue& value = dataflow.GetValueDefinedAt(fusion_param, index); - return value.uses().empty(); - } else { - // Return false if no value at 'operand' and 'index' is used at 'user'. - for (const HloValue* value : - dataflow.GetValueSet(operand, index).values()) { - for (const HloUse& use : value->uses()) { - if (use.instruction == user) { - return false; - } - } - } - } - - return true; -} - -namespace { - -// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. -// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) -// where 'user' is a user of an alias of 'instruction' at 'index', and -// 'operand_index' is the operand index at which the alias appears in the -// operand list of 'user'. -std::vector> GetAllUsesOfInstructionAtIndex( - HloInstruction* instruction, const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis) { - std::vector> uses; - const PointsToSet::BufferList& points_to = - points_to_analysis.GetPointsToSet(instruction).element(index); - for (const LogicalBuffer* buffer : points_to) { - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - for (HloInstruction* alias_user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), - alias_user, points_to_analysis)) { - continue; - } - for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) { - uses.emplace_back(alias_user, op_idx); - } - } - } - } - return uses; -} - -// Returns true if there is exactly one use of 'operand' at 'operand_index' -// in 'fusion.fused_instructions', where the singleton use is the fused -// root at operand index 'use_operand_index'. Returns false otherwise. -// -// REQUIRES: 'fusion' opcode is a kFusion instruction. -bool HasUniqueFusedUseOfOperandAt( - HloInstruction* operand, const ShapeIndex& operand_index, - HloInstruction* fusion, const int64 use_operand_index, - const TuplePointsToAnalysis& points_to_analysis) { - CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); - // Check that 'operand' is unique in the operand list of 'fusion'. - if (fusion->OperandIndices(operand).size() > 1) { - return false; - } - // Find fusion parameter associated with 'operand'. - const auto& fused_params = fusion->fused_parameters(); - auto fused_param_it = std::find_if( - fused_params.begin(), fused_params.end(), - [&](HloInstruction* fused_param) { - return fusion->operand(fused_param->parameter_number()) == operand; - }); - if (fused_param_it == fused_params.end()) { - return false; - } - auto* fused_param = *fused_param_it; - // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'. - auto fused_param_uses = GetAllUsesOfInstructionAtIndex( - fused_param, operand_index, points_to_analysis); - // Return true iff there is exactly one use of 'operand' at 'index', and - // this singleton use is the fused root (at index in 'use_operand_indices'). - return fused_param_uses.size() == 1 && - fused_param_uses[0].first == fusion->fused_expression_root() && - fused_param_uses[0].second == use_operand_index; -} - -} // namespace - -// User and operand can share buffers iff both instructions emit the same shape -// and layout, and 'user' meets one of the following qualifications: -// -// (1) Is element-wise. Or... -// (2) Is a loop fusion instruction where the only use of 'operand' at 'index' -// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root -// at operand 0. Or... -// (3) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion -// instruction where the only use of 'operand' at 'index' in the set -// 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or... -// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index -// 0. -// -// (2) and (3) can only be determined if points-to analysis is available. -bool CanShareOperandBufferWithUser( - HloInstruction* operand, const ShapeIndex& operand_index, - HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - const Shape& operand_subshape = - ShapeUtil::GetSubshape(operand->shape(), operand_index); - const Shape& user_subshape = - ShapeUtil::GetSubshape(user->shape(), user_index); - // Check that operand and user emit the same shape and layout. - if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { - return false; - } - if (user->opcode() == HloOpcode::kFusion) { - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0, - points_to_analysis); - } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && - user->fused_expression_root()->opcode() == HloOpcode::kAdd) { - // Output fusion with kAdd fused root. - - // Check if one operand of kAdd fused root is either kDot, or nested - // kFusion of kind kTransposeDot. - auto* add = user->fused_expression_root(); - auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot || - (operand->opcode() == HloOpcode::kFusion && - operand->fusion_kind() == - HloInstruction::FusionKind::kTransposeDot); - }); - if (add_operand_it == add->operands().end()) { - return false; - } - auto* matched_add_operand = *add_operand_it; - // Calculate operand index of 'add' operand which was not matched above. - const int64 other_add_operand_index = - matched_add_operand == add->operand(0) ? 1 : 0; - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root (at operand - // index 'other_add_operand_index'). - return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, - other_add_operand_index, - points_to_analysis); - } - } - if (user->opcode() == HloOpcode::kDynamicUpdateSlice || - user->opcode() == HloOpcode::kWhile) { - // We eliminated other users in BufferLiveness::live_range_strictly_before, - // so here we just need to check that the use is at operand index 0. - std::vector operand_indices = user->OperandIndices(operand); - return operand_indices.size() == 1 && operand_indices[0] == 0; - } - if (user->opcode() == HloOpcode::kCall) { - // TODO(b/62548313): Remove when buffer assignment is module scoped and - // does not assign buffers to calls. - // Find called computation parameter associated with 'operand'. - const std::vector operand_indices = user->OperandIndices(operand); - if (operand_indices.size() > 1) { - return false; - } - CHECK_EQ(1, operand_indices.size()); - auto* param = user->to_apply()->parameter_instruction(operand_indices[0]); - // Get all uses of 'operand' at 'index' in called computation. - auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index, - points_to_analysis); - - // Return true iff: - // *) There exists exactly one use of 'operand' in called computation. - // *) The unique use is by the root instruction of called computation. - // (Note: we check the root of the called computation, because the - // root result buffer is required to alias with the Call result buffer). - // *) The root instruction of the called computation is element-wise on - // 'operand'. - auto* callee_root = user->to_apply()->root_instruction(); - return param_uses.size() == 1 && param_uses[0].first == callee_root && - callee_root->IsElementwiseOnOperand(param_uses[0].second); - } - // Check if 'user' is element-wise. - return user->IsElementwise(); -} - -bool CanShareOperandBufferWithUser(HloInstruction* operand, - const ShapeIndex& operand_index, - HloInstruction* user, - const ShapeIndex& user_index, - const HloDataflowAnalysis& dataflow) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - const Shape& operand_subshape = - ShapeUtil::GetSubshape(operand->shape(), operand_index); - const Shape& user_subshape = - ShapeUtil::GetSubshape(user->shape(), user_index); - // Check that operand and user emit the same shape and layout. - if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { - return false; - } - - if (user->opcode() == HloOpcode::kFusion) { - // Get the parameter associated with 'operand'; - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - - const HloValue& value = - dataflow.GetValueDefinedAt(fusion_param, operand_index); - if (value.uses().size() != 1) { - return false; - } - const HloUse& use = value.uses()[0]; - - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return use.instruction == user->fused_expression_root() && - use.operand_number == 0; - } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && - user->fused_expression_root()->opcode() == HloOpcode::kAdd) { - // Output fusion with kAdd fused root. - - // Check if one operand of kAdd fused root is either kDot, or nested - // kFusion of kind kTransposeDot. - auto* add = user->fused_expression_root(); - auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot || - (operand->opcode() == HloOpcode::kFusion && - operand->fusion_kind() == - HloInstruction::FusionKind::kTransposeDot); - }); - if (add_operand_it == add->operands().end()) { - return false; - } - auto* matched_add_operand = *add_operand_it; - // Calculate operand index of 'add' operand which was not matched above. - const int64 other_add_operand_index = - matched_add_operand == add->operand(0) ? 1 : 0; - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root (at operand - // index 'other_add_operand_index'). - return use.instruction == user->fused_expression_root() && - use.operand_number == other_add_operand_index; - } - } - if (user->opcode() == HloOpcode::kDynamicUpdateSlice || - user->opcode() == HloOpcode::kWhile) { - // We eliminated other users in BufferLiveness::live_range_strictly_before, - // so here we just need to check that the use is at operand index 0. - std::vector operand_indices = user->OperandIndices(operand); - return operand_indices.size() == 1 && operand_indices[0] == 0; - } - if (user->opcode() == HloOpcode::kCall) { - // Get all uses of value defined by 'operand' at 'operand_index'. - const auto& uses = - dataflow.GetValueDefinedAt(operand, operand_index).uses(); - // Return true iff: - // *) There exists two uses of 'operand'. - // *) One use is by 'user' (caller). - // *) One use is by root instruction of called computation (callee root). - // (Note: we check the root of the called computation, because the - // root result buffer is required to alias with the Call result buffer). - // *) The root instruction of the called computation is element-wise on - // 'operand'. - const bool found_caller_use = - std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { - return use.instruction == user; - }) != uses.end(); - auto* callee_root = user->to_apply()->root_instruction(); - const bool found_elementwise_callee_use = - std::find_if( - uses.begin(), uses.end(), [callee_root](const HloUse& use) { - return use.instruction == callee_root && - callee_root->IsElementwiseOnOperand(use.operand_number); - }) != uses.end(); - return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; - } - // Check if 'user' is element-wise. - return user->IsElementwise(); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util.h b/tensorflow/compiler/xla/service/liveness_util.h deleted file mode 100644 index 28ef991880039de73cc158a67ef2a5f78fc90e6d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/liveness_util.h +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// A collection of utilities on the HLO graph. - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ - -#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/types.h" - -namespace xla { - -// Returns true if 'user' cannot possibly use the buffer at 'index' in -// 'operand'. Returns false otherwise. -// -// REQUIRES: 'operand' is an operand of 'user'. -// -// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have -// moved over to the dataflow overload. -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const TuplePointsToAnalysis& points_to_analysis); -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const HloDataflowAnalysis& dataflow); - -// Returns true if 'user' (at 'user_index') can share a buffer with its operand -// 'operand' (at 'operand_index'). Returns false otherwise. -// -// REQUIRES: 'operand' is an operand of 'user'. -// -// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have -// moved over to the dataflow overload. -bool CanShareOperandBufferWithUser( - HloInstruction* operand, const ShapeIndex& operand_index, - HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis); -bool CanShareOperandBufferWithUser(HloInstruction* operand, - const ShapeIndex& operand_index, - HloInstruction* user, - const ShapeIndex& user_index, - const HloDataflowAnalysis& dataflow); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc deleted file mode 100644 index f8b309488eeb5391b1cad5db760934ec1f7e3521..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ /dev/null @@ -1,463 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/liveness_util.h" - -#include - -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" - -namespace xla { -namespace { - -class PointsToAnalysisTestBase : public HloTestBase { - protected: - void BuildModule(std::unique_ptr computation) { - module_ = CreateNewModule(); - computation_ = module_->AddEntryComputation(std::move(computation)); - } - - void RunAnalysis() { - CHECK_NOTNULL(module_.get()); - points_to_analysis_ = - TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); - dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); - } - - void BuildModuleAndRunAnalysis(std::unique_ptr computation) { - BuildModule(std::move(computation)); - RunAnalysis(); - } - - std::unique_ptr module_; - HloComputation* computation_ = nullptr; - std::unique_ptr points_to_analysis_; - std::unique_ptr dataflow_analysis_; -}; - -class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {}; - -TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { - auto builder = HloComputation::Builder(TestName()); - - Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); - auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); - builder.AddInstruction( - HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); - - BuildModuleAndRunAnalysis(builder.Build()); - - // GetTupleElement instructions only access the top-level buffer of their - // operand. - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *points_to_analysis_)); - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *points_to_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *points_to_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *points_to_analysis_)); - - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *dataflow_analysis_)); - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *dataflow_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *dataflow_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *dataflow_analysis_)); -} - -TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { - auto builder = HloComputation::Builder(TestName()); - - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); - - // Create a DynamicUpdateSlice instruction of tuple element 1. - auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); - auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); - auto dynamic_update_slice = - builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); - builder.AddInstruction( - HloInstruction::CreateTuple({gte0, dynamic_update_slice})); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {dynamic_update_slice, starts, update, gte1}, - HloInstruction::FusionKind::kLoop); - RunAnalysis(); - - // The fusion instruction never uses tuple element 0, but does use element 1. - EXPECT_TRUE( - DoesNotUseOperandBuffer(tuple, {0}, fusion, *points_to_analysis_)); - EXPECT_FALSE( - DoesNotUseOperandBuffer(tuple, {1}, fusion, *points_to_analysis_)); - - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, fusion, *dataflow_analysis_)); - EXPECT_FALSE( - DoesNotUseOperandBuffer(tuple, {1}, fusion, *dataflow_analysis_)); -} - -class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {}; - -TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { - auto builder = HloComputation::Builder(TestName()); - - Shape shape = ShapeUtil::MakeShape(F32, {8}); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); - auto log = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp)); - - BuildModuleAndRunAnalysis(builder.Build()); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, log, {}, *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { - auto builder = HloComputation::Builder(TestName()); - - Shape in_shape = ShapeUtil::MakeShape(F32, {8}); - Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); - auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, in_shape, "param0")); - auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, in_shape, "param1")); - auto result = builder.AddInstruction( - HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); - - BuildModuleAndRunAnalysis(builder.Build()); - - EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, - *points_to_analysis_)); - EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, - *points_to_analysis_)); - - EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, - *dataflow_analysis_)); - EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { - auto builder = HloComputation::Builder(TestName()); - - Shape shape = ShapeUtil::MakeShape(F32, {8}); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); - auto copy = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); - - BuildModuleAndRunAnalysis(builder.Build()); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, copy, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, copy, {}, *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { - auto builder = HloComputation::Builder(TestName()); - - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); - - // Create a DynamicUpdateSlice instruction of tuple element 1. - auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); - auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); - auto dynamic_update_slice = - builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); - builder.AddInstruction( - HloInstruction::CreateTuple({gte0, dynamic_update_slice})); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {dynamic_update_slice, starts, update, gte1}, - HloInstruction::FusionKind::kLoop); - RunAnalysis(); - - // The fusion instruction can share with tuple element 1. - EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, - *points_to_analysis_)); - EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, - *points_to_analysis_)); - - EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, - *dataflow_analysis_)); - EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { - auto builder = HloComputation::Builder(TestName()); - - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - Shape update_shape = ShapeUtil::MakeShape(F32, {4}); - Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - auto update = builder.AddInstruction( - HloInstruction::CreateParameter(1, update_shape, "update")); - auto starts = builder.AddInstruction( - HloInstruction::CreateParameter(2, starts_shape, "starts")); - auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, data, update, starts)); - - BuildModuleAndRunAnalysis(builder.Build()); - - // The DynamicUpdateSlice instruction can share with the data operand, but not - // with update or starts. - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, dus, {}, *dataflow_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(update, {}, dus, {}, *dataflow_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(starts, {}, dus, {}, *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { - auto builder = HloComputation::Builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); - - auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); - auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); - - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); - - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto add_operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); - - auto add = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape, HloOpcode::kAdd, dot, add_operand)); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {add, dot}, HloInstruction::FusionKind::kOutput); - RunAnalysis(); - - // Output fused dot add should be able to share buffer with 'add_operand'. - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *points_to_analysis_)); - - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { - auto builder = HloComputation::Builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); - - auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); - auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); - auto b_t = builder.AddInstruction( - HloInstruction::CreateTranspose(data_shape, b, {1, 0})); - - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); - - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto add_operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); - - auto add = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape, HloOpcode::kAdd, dot, add_operand)); - - BuildModule(builder.Build()); - - auto nested_fusion = computation_->CreateFusionInstruction( - {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); - - auto fusion = computation_->CreateFusionInstruction( - {add, nested_fusion}, HloInstruction::FusionKind::kOutput); - RunAnalysis(); - - // Output fused transpose-dot-add should be share buffer with 'add_operand'. - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *points_to_analysis_)); - - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { - auto builder = HloComputation::Builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); - - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); - - auto reverse = builder.AddInstruction( - HloInstruction::CreateReverse(data_shape, operand, {0, 1})); - - auto two = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); - - auto add = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {add, two, reverse}, HloInstruction::FusionKind::kOutput); - RunAnalysis(); - - // Output fused operand->reverse->add cannot alias operand buffer 'operand'. - EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, - *points_to_analysis_)); - - EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - - auto make_cond = [this, &data_shape]() { - auto builder = HloComputation::Builder(TestName() + ".Cond"); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); - return builder.Build(); - }; - - auto make_body = [this, &data_shape]() { - auto builder = HloComputation::Builder(TestName() + ".Body"); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); - return builder.Build(); - }; - - module_ = CreateNewModule(); - HloComputation* cond_computation = - module_->AddEmbeddedComputation(make_cond()); - HloComputation* body_computation = - module_->AddEmbeddedComputation(make_body()); - - auto builder = HloComputation::Builder(TestName()); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - auto whil = builder.AddInstruction(HloInstruction::CreateWhile( - data_shape, cond_computation, body_computation, data)); - computation_ = module_->AddEntryComputation(builder.Build()); - - RunAnalysis(); - - // The While instruction can share with the data operand. - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, whil, {}, *dataflow_analysis_)); -} - -// Tests that Call can alias operand buffer if the only use of the operand -// in the called computation is an elementwise instruction. -TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { - Shape shape = ShapeUtil::MakeShape(F32, {8}); - // Build sub-computation with fusion root. - auto sub_builder = HloComputation::Builder(TestName() + "_sub"); - auto sub_param = sub_builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "sub_param")); - auto one = sub_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto ones = sub_builder.AddInstruction( - HloInstruction::CreateBroadcast(shape, one, {1})); - auto add = sub_builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); - - module_ = CreateNewModule(); - auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); - sub_computation->CreateFusionInstruction({add, ones}, - HloInstruction::FusionKind::kLoop); - - // Build entry-computation with kCall which calls 'sub_computation'. - auto builder = HloComputation::Builder(TestName()); - - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto reverse = - builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); - auto call = builder.AddInstruction( - HloInstruction::CreateCall(shape, {reverse}, sub_computation)); - computation_ = module_->AddEntryComputation(builder.Build()); - - RunAnalysis(); - - EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, - *points_to_analysis_)); - EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, - *dataflow_analysis_)); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc index 911b243fe28a5baf8a4b8ed752b892265f5388ac..b17c9d504501a907e27d5152e0082799e87443c7 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.cc +++ b/tensorflow/compiler/xla/service/llvm_compiler.cc @@ -23,7 +23,7 @@ limitations under the License. namespace xla { StatusOr>> LLVMCompiler::Compile( std::vector> modules, - std::vector> stream_execs, + std::vector> stream_execs, DeviceMemoryAllocator* device_allocator) { // Tensorflow tries to enable the following behaviors in all its threads: // diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h index d74e81bb7f622ac5e89203a3d02ca5ad839da07e..f1c623508c5307f2b1c036d3ec6823b75c7eda13 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.h +++ b/tensorflow/compiler/xla/service/llvm_compiler.h @@ -60,19 +60,18 @@ class LLVMCompiler : public Compiler { // Bring in // StatusOr> RunBackend( // std::unique_ptr module, - // perftools::gputools::StreamExecutor* stream_exec, + // se::StreamExecutor* stream_exec, // DeviceMemoryAllocator* device_allocator) // StatusOr> RunHloPasses( // std::unique_ptr module, - // perftools::gputools::StreamExecutor* stream_exec, + // se::StreamExecutor* stream_exec, // DeviceMemoryAllocator* device_allocator) using Compiler::RunBackend; using Compiler::RunHloPasses; StatusOr>> Compile( std::vector> modules, - std::vector> - stream_execs, + std::vector> stream_execs, DeviceMemoryAllocator* device_allocator) override; protected: diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 37261ed1e665ebed9685751161a412ad114a9e96..f1e7fc29532ce7e6841010a5258f4000a7c70383 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -169,17 +169,3 @@ cc_library( "@llvm//:core", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index bc683a1880b010d57e83aa6e9ffa95fda299e1a0..d909845a3a21fc55e44b0037371fca30e577980f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -80,8 +80,10 @@ Status FusedIrEmitter::HandleConstant(HloInstruction* constant) { *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, /*Name=*/""); + llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( + global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); generators_[constant] = [=](const IrArray::Index& index) { - return IrArray(global, constant->shape()) + return IrArray(shape_constant, constant->shape()) .EmitReadArrayElement(index, ir_builder_); }; @@ -151,7 +153,7 @@ Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) { Status FusedIrEmitter::FinishVisit(HloInstruction* root) { fused_root_ = root; - return tensorflow::Status::OK(); + return Status::OK(); } FusedIrEmitter::Generator FusedIrEmitter::GetRootGenerator() const { diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 6384c7f46f5ebbedaeda232b40095611a5d738a4..7323abeb2077154f82828bcda3e90eb45a67138a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -29,18 +29,13 @@ limitations under the License. namespace xla { namespace llvm_ir { -IrArray::Index::Index(llvm::Value* linear, const Shape& shape, - llvm::IRBuilder<>* ir_builder) - : multidim_(ShapeUtil::Rank(shape)), - linear_(linear), - layout_(shape.layout()), - dims_(shape.dimensions().begin(), shape.dimensions().end()) { - CHECK(LayoutUtil::HasLayout(shape)) - << "Shape " << ShapeUtil::HumanStringWithLayout(shape) - << " should have a layout."; +static void Delinearize(std::vector* multidim, + llvm::Value* linear, const Shape& shape, + llvm::IRBuilder<>* ir_builder) { int64 divisor = 1; - for (int64 i = 0; i < layout_.minor_to_major_size(); ++i) { - int64 dimension = layout_.minor_to_major(i); + const Layout& layout = shape.layout(); + for (int64 i = 0; i < layout.minor_to_major_size(); ++i) { + int64 dimension = layout.minor_to_major(i); int64 size_of_current_dimension = shape.dimensions(dimension); // If i is not the last dimension, compute @@ -54,16 +49,28 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, // memory lives in one big allocation, so cuda-memcheck can't detect // out-of-bounds accesses. auto* quot = ir_builder->CreateUDiv(linear, ir_builder->getInt64(divisor)); - if (i < layout_.minor_to_major_size() - 1) { - multidim_[dimension] = ir_builder->CreateURem( + if (i < layout.minor_to_major_size() - 1) { + (*multidim)[dimension] = ir_builder->CreateURem( quot, ir_builder->getInt64(size_of_current_dimension)); } else { - multidim_[dimension] = quot; + (*multidim)[dimension] = quot; } divisor *= size_of_current_dimension; } } +IrArray::Index::Index(llvm::Value* linear, const Shape& shape, + llvm::IRBuilder<>* ir_builder) + : multidim_(ShapeUtil::Rank(shape)), + linear_(linear), + layout_(shape.layout()), + dims_(shape.dimensions().begin(), shape.dimensions().end()) { + CHECK(LayoutUtil::HasLayout(shape)) + << "Shape " << ShapeUtil::HumanStringWithLayout(shape) + << " should have a layout."; + Delinearize(&multidim_, linear, shape, ir_builder); +} + IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, llvm::Value* linear, const Shape& shape) : multidim_(multidim.begin(), multidim.end()), @@ -83,7 +90,6 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, dims_(shape.dimensions().begin(), shape.dimensions().end()) { CHECK_EQ(shape.dimensions_size(), multidim.size()); CHECK(LayoutUtil::HasLayout(shape)); - linear_ = Linearize(AsInt64Slice(shape.dimensions()), ir_builder); } IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) @@ -106,16 +112,13 @@ IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) } } -// Returns whether given linear index valid on given shape. +// Returns whether the given linear index is valid on the given shape. bool IrArray::Index::LinearValidOnShape(const Shape& a) const { - auto b = ShapeUtil::MakeShape(PRED /* irrelevant */, dims_); + auto b = ShapeUtil::MakeShape(a.element_type(), dims_); *b.mutable_layout() = layout_; return linear_ != nullptr && - ContainersEqual( - ShapeUtil::StripDegenerateDimensions(a).dimensions(), - ShapeUtil::StripDegenerateDimensions(b).dimensions()) && - LayoutUtil::Equal(ShapeUtil::StripDegenerateDimensions(a).layout(), - ShapeUtil::StripDegenerateDimensions(b).layout()); + ShapeUtil::ElementsIn(a) == ShapeUtil::ElementsIn(b) && + ShapeUtil::ReshapeIsBitcast(a, b); } IrArray::Index IrArray::Index::SourceIndexOfReshape( @@ -160,7 +163,8 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( } } - if (linear() != nullptr && + if (linear() != nullptr && LayoutUtil::HasLayout(input_shape) && + LayoutUtil::HasLayout(output_shape) && ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) { return Index(source_multidim_index, linear(), input_shape); } @@ -195,13 +199,111 @@ IrArray::Index IrArray::Index::SourceIndexOfTranspose( llvm::IRBuilder<>* builder) const { std::vector operand_multidim_index = Permute(dimension_mapping, multidim()); - if (linear() != nullptr && + + if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) && + LayoutUtil::HasLayout(shape) && ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) { return Index(operand_multidim_index, linear(), operand_shape); } + return Index(operand_multidim_index); } +IrArray::Index IrArray::Index::SourceIndexOfBitcast( + const Shape& shape, const Shape& operand_shape, + llvm::IRBuilder<>* builder) const { + CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape)); + // In case the bitcast is just a reshape, we can use SourceIndexOfReshape() + // instead. This will reuse linear() if possible, so we don't have to build a + // new 'linear_index'. + if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) { + return SourceIndexOfReshape(shape, operand_shape, builder); + } + + // First linearize the index coming from the output of the bitcast. We want + // the physical index of the element in the buffer. This is like Linearize, + // but takes the layout into account. + int64 scale = 1; + llvm::Value* linear_index = builder->getInt64(0); + for (auto dimension : LayoutUtil::MinorToMajor(shape)) { + linear_index = builder->CreateAdd( + linear_index, + builder->CreateMul(multidim_[dimension], builder->getInt64(scale), "", + /*HasNUW=*/true, /*HasNSW=*/true), + "", /*HasNUW=*/true, /*HasNSW=*/true); + scale *= shape.dimensions(dimension); + } + + // Now delinearize it for the input of the bitcast. + std::vector multi_index(operand_shape.dimensions_size()); + Delinearize(&multi_index, linear_index, operand_shape, builder); + + return Index(multi_index, linear_index, operand_shape); +} + +IrArray::Index IrArray::Index::SourceIndexOfBroadcast( + const Shape& shape, const Shape& operand_shape, + tensorflow::gtl::ArraySlice dimension_mapping, + llvm::IRBuilder<>* builder) const { + int64 rank = ShapeUtil::Rank(operand_shape); + std::vector source_index(rank); + for (int64 i = 0; i < rank; ++i) { + source_index[i] = multidim_[dimension_mapping[i]]; + } + if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) || + !LayoutUtil::HasLayout(shape)) { + return Index(source_index); + } + // High-level idea: we can reuse the linear index if the broadcasted + // dimensions are contiguous, and this part of the operation is a bitcast. + // The other dimensions can be masked out with a div and a mod operation. + std::vector logical_to_physical = + LayoutUtil::MakeLogicalToPhysical(shape.layout()); + int64 output_rank = ShapeUtil::Rank(shape); + // The minimum physical dimension that is broadcasted. + int64 min_broadcasted_dimension = output_rank; + // The maximum physical dimension that is broadcasted. + int64 max_broadcasted_dimension = -1; + for (int64 i = 0; i < rank; ++i) { + int64 physical_dim = logical_to_physical[dimension_mapping[i]]; + min_broadcasted_dimension = + std::min(min_broadcasted_dimension, physical_dim); + max_broadcasted_dimension = + std::max(max_broadcasted_dimension, physical_dim); + } + bool contiguous_broadcast_dimensions = + max_broadcasted_dimension - min_broadcasted_dimension == rank - 1; + if (!contiguous_broadcast_dimensions) { + return Index(source_index); + } + // Check if the mapped dimensions are a bitcast. + std::vector operand_logical_to_physical = + LayoutUtil::MakeLogicalToPhysical(operand_shape.layout()); + for (int64 i = 0; i < rank; ++i) { + if (operand_logical_to_physical[i] != + logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) { + return Index(source_index); + } + } + llvm::Value* linear = linear_; + int64 divisor = 1; + for (int64 i = max_broadcasted_dimension + 1; i < output_rank; ++i) { + divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); + } + if (divisor > 1) { + linear = builder->CreateUDiv(linear, builder->getInt64(divisor)); + } + if (min_broadcasted_dimension > 0) { + int64 mod = 1; + for (int64 i = min_broadcasted_dimension; i <= max_broadcasted_dimension; + ++i) { + mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); + } + linear = builder->CreateURem(linear, builder->getInt64(mod)); + } + return Index(source_index, linear, operand_shape); +} + llvm::Value* IrArray::Index::Linearize( tensorflow::gtl::ArraySlice dimensions, llvm::IRBuilder<>* builder) const { @@ -231,18 +333,7 @@ llvm::Value* IrArray::EmitArrayElementAddress( } CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_)); - std::vector actual_index; - bool is_implicit_broadcast = false; - // We perform broadcasting when the operand shape has dimension(s) of size - // 1. In this case we fix the index value for that dimension to zero. This - // effectively broadcasts along this dimension. - for (int64 i = 0; i < index.size(); ++i) { - auto dim = shape_->dimensions(i); - actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]); - is_implicit_broadcast |= dim == 1; - } - - if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) { + if (index.LinearValidOnShape(*shape_)) { llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent(); return ir_builder->CreateInBoundsGEP( @@ -252,6 +343,15 @@ llvm::Value* IrArray::EmitArrayElementAddress( {index.linear()}, llvm_ir::AsStringRef(name)); } + std::vector actual_index; + for (int64 i = 0; i < index.size(); ++i) { + // When dimension i is of size 1, LLVM optimization is able to replace + // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to + // produce better code in some cases. + auto dim = shape_->dimensions(i); + actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]); + } + // "base_ptr_" has the type of "*" // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element // should be computed by diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 387d4629125cbb791840e943013188d14159908a..4c3195c29c859c9eef08e3f6531b059edbebfc47 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -76,8 +76,7 @@ class IrArray { llvm::IRBuilder<>* ir_builder); // Constructs an index from the given multi-dimensional index and the shape - // that it indexes into. Also, computes the linear index according to - // "shape". + // that it indexes into. // // Precondition: "shape" has a layout. Index(tensorflow::gtl::ArraySlice multidim, @@ -98,6 +97,10 @@ class IrArray { llvm::Value*& operator[](size_t i) { return multidim()[i]; } void push_back(llvm::Value* value) { multidim().push_back(value); } + void InsertAt(int64 index, llvm::Value* value) { + CHECK_LE(index, size()); + multidim().insert(multidim().begin() + index, value); + } using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; @@ -134,6 +137,18 @@ class IrArray { tensorflow::gtl::ArraySlice dimension_mapping, llvm::IRBuilder<>* builder) const; + // Given that "this" is the target index of a bitcast from `operand_shape` + // to `shape`, returns the source index. + Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape, + llvm::IRBuilder<>* builder) const; + + // Given that "this" is the target index of a broadcast from `operand_shape` + // to `shape` with the given dimension mapping, returns the source index. + Index SourceIndexOfBroadcast( + const Shape& shape, const Shape& operand_shape, + tensorflow::gtl::ArraySlice dimension_mapping, + llvm::IRBuilder<>* builder) const; + // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and // returns the index into the sole dimension 0 of the new shape. llvm::Value* Linearize(tensorflow::gtl::ArraySlice dimensions, diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index 23d2d4e87d26f4988ebddcf20f5a27af6a7fe0d6..1f6e3c829f890d68aa251b101f0402c120a19d61 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -15,53 +15,57 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" -#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" namespace xla { -void KernelSupportLibrary::For( +Status KernelSupportLibrary::For( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, - const std::function& for_body_generator) { - If(ir_builder_->CreateICmpSLT(start, end), [&]() { - for_body_generator(start, /*is_first_iteration=*/true); - For(name, ir_builder_->CreateAdd(start, step), end, step, - [&](llvm::Value* iv) { for_body_generator(iv, false); }); + const std::function& for_body_generator) { + return If(ir_builder_->CreateICmpSLT(start, end), [&]() -> Status { + TF_RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true)); + return For(name, ir_builder_->CreateAdd(start, step), end, step, + [&](llvm::Value* iv) { return for_body_generator(iv, false); }); }); } -void KernelSupportLibrary::For( +Status KernelSupportLibrary::For( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, - const std::function& for_body_generator) { + const std::function& + for_body_generator) { if (peel_first_iteration) { - For(name, start, end, step, true, - [&](llvm::Value* indvar, bool is_first_iteration) { - for_body_generator(indvar, ir_builder_->getInt1(is_first_iteration)); - }); + return For(name, start, end, step, true, + [&](llvm::Value* indvar, bool is_first_iteration) -> Status { + return for_body_generator( + indvar, ir_builder_->getInt1(is_first_iteration)); + }); } else { std::unique_ptr loop = llvm_ir::ForLoop::EmitForLoop( name, start, end, step, ir_builder_, - /*prevent_unrolling=*/prevent_unrolling_, + /*unroll_mode=*/unroll_mode_, /*prevent_vectorization=*/prevent_vectorization_); ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back()); - for_body_generator(loop->GetIndVarValue(), - /*is_first_iteration=*/ir_builder_->CreateICmpEQ( - loop->GetIndVarValue(), start)); + TF_RETURN_IF_ERROR( + for_body_generator(loop->GetIndVarValue(), + /*is_first_iteration=*/ir_builder_->CreateICmpEQ( + loop->GetIndVarValue(), start))); llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), ir_builder_); + return Status::OK(); } } -void KernelSupportLibrary::If( - llvm::Value* condition, const std::function& true_block_generator, - const std::function& false_block_generator) { +Status KernelSupportLibrary::If( + llvm::Value* condition, const std::function& true_block_generator, + const std::function& false_block_generator) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, "", ir_builder_); ir_builder_->SetInsertPoint(&if_data.true_block->back()); - true_block_generator(); + TF_RETURN_IF_ERROR(true_block_generator()); ir_builder_->SetInsertPoint(&if_data.false_block->back()); - false_block_generator(); + TF_RETURN_IF_ERROR(false_block_generator()); llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_); + return Status::OK(); } void KernelSupportLibrary::EmitAndCallOutlinedKernel( diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 1c00b2aabd182da72e78d2c9c01cbe70cfd8e33c..e17c649e5272a9e7c0d5126083ab76542abfdf48 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -30,13 +31,14 @@ namespace xla { class KernelSupportLibrary { public: // `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR. - // If `prevent_unrolling` is true then unrolling is explicitly disabled on - // every loop generated by this instance of KernelSupportLibrary. - explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling = true, - bool prevent_vectorization = true) + // `unroll_mode` specifies the desired LLVM unrolling behavior for every loop + // generated by this instance of KernelSupportLibrary. + explicit KernelSupportLibrary( + llvm::IRBuilder<>* ir_builder, + llvm_ir::UnrollMode unroll_mode = llvm_ir::UnrollMode::kNoUnroll, + bool prevent_vectorization = true) : ir_builder_(ir_builder), - prevent_unrolling_(prevent_unrolling), + unroll_mode_(unroll_mode), prevent_vectorization_(prevent_vectorization) {} // Generates the following control flow structure: @@ -46,19 +48,41 @@ class KernelSupportLibrary { // for (i64 i = `start` + `step`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; // } - void For( + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator); + + void ForReturnVoid( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& - for_body_generator); + for_body_generator) { + CHECK_EQ(Status::OK(), + For(name, start, end, step, + [&](llvm::Value* ind_var, bool is_first_iteration) -> Status { + for_body_generator(ind_var, is_first_iteration); + return Status::OK(); + })); + } + + Status For(tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& + for_body_generator) { + return For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); + } - void For( + void ForReturnVoid( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - For(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); } // Generates the following control flow structure if `peel_first_iteration` is @@ -75,37 +99,101 @@ class KernelSupportLibrary { // for (i64 i = `start`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, // /*is_first_iteration=*/,(i != `start`))`; - void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - llvm::Value* step, bool peel_first_iteration, - const std::function& + Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator); + + void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + llvm::Value* end, llvm::Value* step, + bool peel_first_iteration, + const std::function& + for_body_generator) { + TF_CHECK_OK(For( + name, start, end, step, peel_first_iteration, + [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status { + for_body_generator(ind_var, is_first_iteration); + return Status::OK(); + })); + } + + Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + return For(name, /*start=*/start, /*end=*/end, + /*step=*/ir_builder_->getInt64(step), peel_first_iteration, for_body_generator); + } - void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - int64 step, bool peel_first_iteration, - const std::function& - for_body_generator) { - For(name, /*start=*/start, /*end=*/end, - /*step=*/ir_builder_->getInt64(step), peel_first_iteration, - for_body_generator); + void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + llvm::Value* end, int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + ForReturnVoid(name, /*start=*/start, /*end=*/end, + /*step=*/ir_builder_->getInt64(step), peel_first_iteration, + for_body_generator); } - void For( + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator) { + return For(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); + } + + void ForReturnVoid( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - For(name, start, end, step, - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); }); + ForReturnVoid(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) { + return for_body_generator(indvar); + }); + } + + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, + const std::function& for_body_generator) { + return For(name, start, end, ir_builder_->getInt64(step), + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void For( + void ForReturnVoid( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, + const std::function& for_body_generator) { + ForReturnVoid(name, start, end, ir_builder_->getInt64(step), + for_body_generator); + } + + Status For( + tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& for_body_generator) { + return For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); + } + + void ForReturnVoid( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - For(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); } // Generates the following control flow structure: @@ -114,9 +202,25 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - void If(llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = []() {}); + Status If(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = + []() -> Status { return Status::OK(); }); + + void IfReturnVoid(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() { + }) { + TF_CHECK_OK(If(condition, + [&]() { + true_block_generator(); + return Status::OK(); + }, + [&]() { + false_block_generator(); + return Status::OK(); + })); + } using ArgumentVector = tensorflow::gtl::ArraySlice; @@ -174,7 +278,7 @@ class KernelSupportLibrary { private: llvm::IRBuilder<>* ir_builder_; - bool prevent_unrolling_; + llvm_ir::UnrollMode unroll_mode_; bool prevent_vectorization_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 7b227ce294176cfbbf7308bbf65afe21814f3dea..9f867014fb015845448c4fcf9c165750f8a61935 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -34,23 +34,23 @@ namespace llvm_ir { ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - llvm::Value* step, bool prevent_unrolling, + llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization) - : prefix_(prefix.ToString()), - suffix_(suffix.ToString()), + : prefix_(std::string(prefix)), + suffix_(std::string(suffix)), start_index_(start_index), end_index_(end_index), step_(step), insert_before_bb_(nullptr), - prevent_unrolling_(prevent_unrolling), + unroll_mode_(unroll_mode), prevent_vectorization_(prevent_vectorization) {} /* static */ std::unique_ptr ForLoop::EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling, bool prevent_vectorization) { + UnrollMode unroll_mode, bool prevent_vectorization) { std::unique_ptr loop(new ForLoop(prefix, /*suffix=*/"", start_index, - end_index, step, prevent_unrolling, + end_index, step, unroll_mode, prevent_vectorization)); loop->Emit(ir_builder); return loop; @@ -147,11 +147,12 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { std::vector ForLoop::GetLoopMetadata( llvm::IRBuilder<>* ir_builder) { const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; + const char* const kLlvmLoopUnrollFullMDName = "llvm.loop.unroll.full"; const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable"; llvm::LLVMContext* ctx = &start_index_->getContext(); std::vector result; - if (prevent_unrolling_) { + if (unroll_mode_ == xla::llvm_ir::UnrollMode::kNoUnroll) { result.push_back(llvm::MDNode::get( *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)})); } @@ -162,6 +163,10 @@ std::vector ForLoop::GetLoopMetadata( llvm::ConstantAsMetadata::get(ir_builder->getFalse())})); } + if (unroll_mode_ == xla::llvm_ir::UnrollMode::kFullyUnroll) { + result.push_back(llvm::MDNode::get( + *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollFullMDName)})); + } return result; } @@ -178,25 +183,25 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1), - prevent_unrolling, prevent_vectorization); + unroll_mode, prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); } std::unique_ptr loop(new ForLoop( - /*prefix=*/name_, suffix, start_index, end_index, stride, - prevent_unrolling, prevent_vectorization)); + /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode, + prevent_vectorization)); loop->Emit(ir_builder_); if (outer_loop_preheader_bb_ == nullptr) { @@ -215,23 +220,23 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, tensorflow::StringPiece suffix, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); return AddLoop(suffix, ir_builder_->getInt64(start_index), - ir_builder_->getInt64(end_index), prevent_unrolling, + ir_builder_->getInt64(end_index), unroll_mode, prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); return AddLoop(suffix, ir_builder_->getInt64(start_index), ir_builder_->getInt64(end_index), - ir_builder_->getInt64(stride), prevent_unrolling, + ir_builder_->getInt64(stride), unroll_mode, prevent_vectorization); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index 20069ce5a28184a5a9216d1a3751d1cee547727d..4e403cd994874c27453574283c5c573c876628db 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -34,6 +34,12 @@ limitations under the License. namespace xla { namespace llvm_ir { +enum class UnrollMode { + kDefaultUnroll, + kFullyUnroll, + kNoUnroll, +}; + // A class for constructing a for-loop in LLVM IR. class ForLoop { public: @@ -69,12 +75,13 @@ class ForLoop { // LLVM IR. If non-empty, it is prepended to the name of the induction // variable value and each basic block created for the loop. // - // If `prevent_unrolling` is true then emit metadata that directs LLVM to not - // unroll the generated loop. + // `unroll_mode` specifies the desired LLVM unrolling behavior for generated + // loop. static std::unique_ptr EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling = false, bool prevent_vectorization = false); + UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // The names of the blocks follow LLVM's conventions. Control flow amongst the // blocks for the example C code looks like: @@ -128,7 +135,7 @@ class ForLoop { ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, - bool prevent_unrolling, bool prevent_vectorization); + UnrollMode unroll_mode, bool prevent_vectorization); // Emit the loop at the insert point of the builder. void Emit(llvm::IRBuilder<>* ir_builder); @@ -161,7 +168,7 @@ class ForLoop { llvm::BasicBlock* body_bb_; llvm::BasicBlock* exit_bb_; llvm::Value* indvar_; - bool prevent_unrolling_; + UnrollMode unroll_mode_; bool prevent_vectorization_; TF_DISALLOW_COPY_AND_ASSIGN(ForLoop); @@ -174,7 +181,7 @@ class ForLoopNest { : ForLoopNest(/*name=*/"", ir_builder) {} ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) - : name_(name.ToString()), + : name_(std::string(name)), outer_loop_preheader_bb_(nullptr), outer_loop_exit_bb_(nullptr), inner_loop_body_bb_(nullptr), @@ -182,34 +189,34 @@ class ForLoopNest { // Adds a loop to the nest. If no loop has been added yet then emit a loop at // the current insert point of the given builder. If one or more loops have - // been added then emit loop inside the body of the last added loop. If - // prevent_unrolling is true, then metadata is emitting directing LLVM to not - // unroll this loop. - std::unique_ptr AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + // been added then emit loop inside the body of the last added loop. + // unroll_mode is used to emit metadata that controls LLVM unrolling. + std::unique_ptr AddLoop( + tensorflow::StringPiece suffix, llvm::Value* start_index, + llvm::Value* end_index, llvm::Value* stride, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. - std::unique_ptr AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + tensorflow::StringPiece suffix, llvm::Value* start_index, + llvm::Value* end_index, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // A convenient wrapper of the other flavor of AddLoop. The given start and // end index are constant. - std::unique_ptr AddLoop(int64 start_index, int64 end_index, - int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + int64 start_index, int64 end_index, int64 stride, + tensorflow::StringPiece suffix, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. - std::unique_ptr AddLoop(int64 start_index, int64 end_index, - tensorflow::StringPiece suffix, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + int64 start_index, int64 end_index, tensorflow::StringPiece suffix, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Add loops to iterate through the indices within the specified // shape. The returned index collects the induction variables of the diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 5c1866311d1ae1e0c33ab061ee326d86d647a908..ff64da87e9c9acf8a9d7ff87d3b1be7a9e9106bb 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -86,18 +87,10 @@ llvm::Value* EmitCallToIntrinsic( tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice overloaded_types, llvm::IRBuilder<>* ir_builder) { - std::vector types; - for (auto type : overloaded_types) { - types.push_back(type); - } llvm::Module* module = ModuleFromIRBuilder(ir_builder); - llvm::Function* intrinsic = - llvm::Intrinsic::getDeclaration(module, intrinsic_id, types); - std::vector operands_vec; - for (auto operand : operands) { - operands_vec.push_back(operand); - } - return ir_builder->CreateCall(intrinsic, operands_vec); + llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( + module, intrinsic_id, AsArrayRef(overloaded_types)); + return ir_builder->CreateCall(intrinsic, AsArrayRef(operands)); } llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, @@ -106,8 +99,10 @@ llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, auto cmp = ir_builder->CreateFCmpUGE(lhs_value, rhs_value); return ir_builder->CreateSelect(cmp, lhs_value, rhs_value); } else { - return EmitCallToIntrinsic(llvm::Intrinsic::maxnum, {lhs_value, rhs_value}, - {lhs_value->getType()}, ir_builder); + auto cmp_ge = ir_builder->CreateFCmpOGE(lhs_value, rhs_value); + auto lhs_is_nan = ir_builder->CreateFCmpUNE(lhs_value, lhs_value); + auto sel_lhs = ir_builder->CreateOr(cmp_ge, lhs_is_nan); + return ir_builder->CreateSelect(sel_lhs, lhs_value, rhs_value); } } @@ -117,8 +112,10 @@ llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value, auto cmp = ir_builder->CreateFCmpULE(lhs_value, rhs_value); return ir_builder->CreateSelect(cmp, lhs_value, rhs_value); } else { - return EmitCallToIntrinsic(llvm::Intrinsic::minnum, {lhs_value, rhs_value}, - {lhs_value->getType()}, ir_builder); + auto cmp_le = ir_builder->CreateFCmpOLE(lhs_value, rhs_value); + auto lhs_is_nan = ir_builder->CreateFCmpUNE(lhs_value, lhs_value); + auto sel_lhs = ir_builder->CreateOr(cmp_le, lhs_is_nan); + return ir_builder->CreateSelect(sel_lhs, lhs_value, rhs_value); } } @@ -363,15 +360,52 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, return llvm::ConstantArray::get(aggregate_type, elements); } +template +llvm::Constant* GetConstantDataArray(const Literal& literal, + llvm::Module* module) { + const T* data = static_cast(literal.untyped_data()); + int64 num_elements = literal.size_bytes() / sizeof(T); + return llvm::ConstantDataArray::get(module->getContext(), + llvm::makeArrayRef(data, num_elements)); +} + } // namespace llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::Module* module) { - std::vector multi_index(ShapeUtil::Rank(literal.shape()), 0); - llvm::Constant* value = LiteralToConstant( - literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1, - &multi_index, module); - return value; + const Shape& shape = literal.shape(); + // TODO(b/29904935): We can get rid of this switch by exposing a + // ConstantDataArray factory method that takes a llvm::Type and a StringRef. + switch (shape.element_type()) { + case U64: + return GetConstantDataArray(literal, module); + case U32: + return GetConstantDataArray(literal, module); + case U8: + return GetConstantDataArray(literal, module); + case S64: + return GetConstantDataArray(literal, module); + case S32: + return GetConstantDataArray(literal, module); + case F64: + return GetConstantDataArray(literal, module); + case F32: + return GetConstantDataArray(literal, module); + case BF16: + case F16: + return GetConstantDataArray(literal, module); + case PRED: + return GetConstantDataArray(literal, module); + // TODO(b/29904935): Also use ConstantDataArray for complex numbers. + case C64: { + int64 dimensions = ShapeUtil::Rank(shape); + std::vector multi_index(dimensions, 0); + return LiteralToConstant(literal, /*dimension_index=*/dimensions - 1, + &multi_index, module); + } + default: + LOG(FATAL) << "unsupported type " << shape.element_type(); + } } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, @@ -758,7 +792,7 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { fake_argv_storage.push_back(""); for (const auto& it : options) { // Skip options the XLA backend itself consumes. - if (!tensorflow::StringPiece(it.first).starts_with("xla_")) { + if (!tensorflow::str_util::StartsWith(it.first, "xla_")) { if (it.second.empty()) { fake_argv_storage.push_back(it.first); } else { diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index b6b918ec78a27b90325f72eea14b97f9aee43c54..dc2934a34c23f8229947210cacc9863d47c2ea55 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -39,14 +39,13 @@ LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, const IrArray& target_array, llvm::IRBuilder<>* ir_builder) - : body_emitter_([=](const llvm_ir::IrArray::Index array_index) - -> ::tensorflow::Status { + : body_emitter_([=](const llvm_ir::IrArray::Index array_index) -> Status { // Convert target_element_generator to a BodyEmitter. TF_ASSIGN_OR_RETURN(llvm::Value * target_element, target_element_generator(array_index)); target_array.EmitWriteArrayElement(array_index, target_element, ir_builder); - return tensorflow::Status::OK(); + return Status::OK(); }), shape_(target_array.GetShape()), ir_builder_(ir_builder) {} @@ -84,16 +83,18 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, // Sanity check: In multi-output fusion, all shapes produced must have the // same dimensions. for (const IrArray& array : target_arrays) { - CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape())); + CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape())) + << ": '" << shape_.ShortDebugString() << "' does not match '" + << array.GetShape().ShortDebugString() << "'"; } } -IrArray::Index LoopEmitter::EmitIndexAndSetExitBasicBlock( +std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( tensorflow::StringPiece loop_name) { if (ShapeUtil::IsScalar(shape_)) { // No loop needed, so set exit_bb_ to nullptr. exit_bb_ = nullptr; - return IrArray::Index(); + return {IrArray::Index()}; } // Create loop nest with one for-loop for each dimension of the target shape. @@ -121,19 +122,21 @@ IrArray::Index LoopEmitter::EmitIndexAndSetExitBasicBlock( exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); CHECK_NOTNULL(exit_bb_); - return array_index; + return {array_index}; } -tensorflow::Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { - IrArray::Index array_index = EmitIndexAndSetExitBasicBlock(loop_name); - TF_RETURN_IF_ERROR(body_emitter_(array_index)); +Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { + for (const IrArray::Index& array_index : + EmitIndexAndSetExitBasicBlock(loop_name)) { + TF_RETURN_IF_ERROR(body_emitter_(array_index)); + } // Set the insertion point of ir_builder_ to the loop exit, so that // code emitted for later instructions will be correctly placed. if (exit_bb_ != nullptr) { ir_builder_->SetInsertPoint(exit_bb_); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 0fc528439a0d5bf8382dfcf2d8b3051f8900bf1d..b70d28ecd3033eb26629718e50ce48f39b162273 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -38,8 +38,7 @@ using ElementGenerator = // Emits a loop for every element in the given shape. class LoopEmitter { public: - using BodyEmitter = - std::function; + using BodyEmitter = std::function; LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, llvm::IRBuilder<>* ir_builder); @@ -63,15 +62,16 @@ class LoopEmitter { // Emits a loop nest (with a yet-to-be-filled loop body) that iterates through // every element in the given shape. Returns the multi-dimensional index that - // specifies the element. - IrArray::Index EmitIndexAndSetExitBasicBlock() { + // specifies the element, will return multiple indices if the loop is + // unrolled. + std::vector EmitIndexAndSetExitBasicBlock() { return EmitIndexAndSetExitBasicBlock(/*loop_name=*/""); } - virtual IrArray::Index EmitIndexAndSetExitBasicBlock( + virtual std::vector EmitIndexAndSetExitBasicBlock( tensorflow::StringPiece loop_name); // Emits a complete loop nest for every element in the given shape. - tensorflow::Status EmitLoop(tensorflow::StringPiece loop_name = ""); + Status EmitLoop(tensorflow::StringPiece loop_name = ""); protected: // An IR emitter that generates the loop body. diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/ops.cc index 34899b7400464e4f4f97d301f35ed3b7b083bca1..dacc54742c0897bbd92315f1e33a484aae56bb7f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.cc @@ -49,22 +49,41 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( for (int64 i = 0; i < rank; ++i) { IrArray::Index dim_index({ir_builder->getInt64(i)}); TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); + llvm::Value* output_dim_size = llvm::ConstantInt::get( + start_index[i]->getType(), output_shape.dimensions(i)); + llvm::Value* update_dim_size = llvm::ConstantInt::get( + start_index[i]->getType(), update_shape.dimensions(i)); + + // Clamp the start index so that the update region fits in the operand. + // start_index = clamp(start_index, 0, output_dim_size - update_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. + llvm::Value* max_bound = + ir_builder->CreateSub(output_dim_size, update_dim_size); + llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0); + start_index[i] = ir_builder->CreateSelect( + ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SGE, zero, start_index[i]), + zero, start_index[i]); + + start_index[i] = ir_builder->CreateSelect( + ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SLE, max_bound, + start_index[i]), + max_bound, start_index[i]); } auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status { // Calculate output_index, where we'll write the value from update. For // each dimension, // - // output_index[dim] = (start_index[dim] + update_index[dim]) % dim_size. + // output_index[dim] = start_index[dim] + update_index[dim] // IrArray::Index output_index(rank); for (int64 i = 0; i < rank; ++i) { - llvm::Value* dim_size = llvm::ConstantInt::get( - update_index[i]->getType(), output_shape.dimensions(i)); - llvm::Value* start_index0 = ir_builder->CreateZExtOrBitCast( + llvm::Value* start_index0 = ir_builder->CreateSExtOrBitCast( start_index[i], update_index[i]->getType()); - output_index[i] = ir_builder->CreateURem( - ir_builder->CreateAdd(start_index0, update_index[i]), dim_size); + output_index[i] = ir_builder->CreateAdd(start_index0, update_index[i]); } // Do output[output_index] = update[update_index]. diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index 3a21eda35757aa706565ee4a5286eee1acea117b..5fc08aab916e377b245b6221108956c06da70767 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -24,15 +24,14 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace llvm_ir { -void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, - llvm::Value* on_false, llvm::IRBuilder<>* ir_builder, - llvm::Module* module) { +void EmitTupleSelect(const IrArray& select, const IrArray& pred, + llvm::Value* on_true, llvm::Value* on_false, + llvm::IRBuilder<>* ir_builder, llvm::Module* module) { CHECK(ShapeUtil::IsScalar(pred.GetShape())); llvm::LoadInst* pred_value = @@ -47,30 +46,27 @@ void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, VLOG(2) << " pred_cond: " << DumpToString(*pred_cond); for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) { - std::vector element_index = {ir_builder->getInt64(0), - ir_builder->getInt64(i)}; + llvm::Value* const element_index[] = {ir_builder->getInt64(0), + ir_builder->getInt64(i)}; llvm::Value* on_true_element_address = ir_builder->CreateInBoundsGEP(on_true, element_index); llvm::Value* on_true_element = ir_builder->CreateLoad( - on_true_element_address, - tensorflow::strings::Printf("on_true_element_%d", i).c_str()); + on_true_element_address, "on_true_element_" + llvm::Twine(i)); llvm::Value* on_false_element_address = ir_builder->CreateInBoundsGEP(on_false, element_index); llvm::Value* on_false_element = ir_builder->CreateLoad( - on_false_element_address, - tensorflow::strings::Printf("on_false_element_%d", i).c_str()); + on_false_element_address, "on_false_element_" + llvm::Twine(i)); llvm::Value* output_element_address = ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index); ir_builder->CreateStore( - ir_builder->CreateSelect( - pred_cond, on_true_element, on_false_element, - tensorflow::strings::Printf("select_output_element_%d", i).c_str()), + ir_builder->CreateSelect(pred_cond, on_true_element, on_false_element, + "select_output_element_" + llvm::Twine(i)), output_element_address); } } -void EmitTuple(IrArray tuple, +void EmitTuple(const IrArray& tuple, tensorflow::gtl::ArraySlice operands, llvm::IRBuilder<>* ir_builder, llvm::Module* module) { for (size_t i = 0; i < operands.size(); ++i) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h index dbf9a140068b60505f6798360438f709bfd3feba..352d34ebf839c6c2465abade7c3d3eb3b7a34506 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -59,13 +59,13 @@ namespace llvm_ir { // of the address from the corresponding element in either // tuple_on_true or tuple_on_false: // output[i] = pred ? tuple_on_true[i] : tuple_on_false[i] -void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, - llvm::Value* on_false, llvm::IRBuilder<>* ir_builder, - llvm::Module* module); +void EmitTupleSelect(const IrArray& select, const IrArray& pred, + llvm::Value* on_true, llvm::Value* on_false, + llvm::IRBuilder<>* ir_builder, llvm::Module* module); // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. -void EmitTuple(IrArray tuple, +void EmitTuple(const IrArray& tuple, tensorflow::gtl::ArraySlice operands, llvm::IRBuilder<>* ir_builder, llvm::Module* module); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 07f989d4faea199e812e54d2ae74d3ff9e7fa19a..296d04d4362b12fdc39798a016ca9e8795e02586 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -24,15 +24,12 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -43,13 +40,11 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { /* static */ StatusOr> LocalService::NewService( const ServiceOptions& options) { - perftools::gputools::Platform* platform = options.platform(); + se::Platform* platform = options.platform(); if (platform == nullptr) { TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } @@ -69,31 +64,100 @@ LocalService::LocalService(const ServiceOptions& options, std::unique_ptr execute_backend) : Service(options, std::move(execute_backend)) {} +namespace { + +// Retrieves the parameter metadata for the given computation and parameter +// number. +// +// If the parameter number is invalid for this computation, nullopt is +// returned. When the return value has_value(), nullptr will never be +// the held value. +tensorflow::gtl::optional ParameterMetadata( + const XlaComputation& computation, int parameter_number) { + for (const HloComputationProto& comp : computation.proto().computations()) { + if (comp.id() == computation.proto().entry_computation_id()) { + for (const HloInstructionProto& instr : comp.instructions()) { + if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) && + instr.parameter_number() == parameter_number) { + if (!instr.has_metadata()) { + return tensorflow::gtl::nullopt; + } + return &instr.metadata(); + } + } + } + } + return tensorflow::gtl::nullopt; +} + +ExecutionOptions CreateExecutionOptions( + const ExecutableBuildOptions& build_options, + const ProgramShape* program_shape) { + ExecutionOptions execution_options = CreateDefaultExecutionOptions(); + if (build_options.hlo_profile().has_value()) { + execution_options.mutable_debug_options()->set_xla_hlo_profile( + *build_options.hlo_profile()); + } + if (build_options.generate_hlo_graph().has_value()) { + execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( + build_options.generate_hlo_graph().value()); + } + if (build_options.dump_optimized_hlo_proto_to().has_value()) { + execution_options.mutable_debug_options() + ->set_xla_dump_optimized_hlo_proto_to( + build_options.dump_optimized_hlo_proto_to().value()); + } + if (build_options.dump_unoptimized_hlo_proto_to().has_value()) { + execution_options.mutable_debug_options() + ->set_xla_dump_unoptimized_hlo_proto_to( + build_options.dump_unoptimized_hlo_proto_to().value()); + } + if (build_options.dump_per_pass_hlo_proto_to().has_value()) { + execution_options.mutable_debug_options() + ->set_xla_dump_per_pass_hlo_proto_to( + build_options.dump_per_pass_hlo_proto_to().value()); + } + if (build_options.result_layout() != nullptr) { + *execution_options.mutable_shape_with_output_layout() = + *build_options.result_layout(); + } else { + *execution_options.mutable_shape_with_output_layout() = + program_shape->result(); + LayoutUtil::SetToDefaultLayout( + execution_options.mutable_shape_with_output_layout()); + } + + for (const std::string& disabled_pass : build_options.disabled_hlo_passes()) { + execution_options.mutable_debug_options()->add_xla_disable_hlo_passes( + disabled_pass); + } + + return execution_options; +} + +} // namespace + StatusOr> LocalService::CompileExecutable( - const ComputationHandle& computation, + const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& build_options) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); + const HloModuleProto& proto = computation.proto(); + TF_RET_CHECK(proto.has_program_shape()); + const ProgramShape& program_shape = proto.program_shape(); // Validate incoming layouts. - if (argument_layouts.size() != program_shape->parameters_size()) { + if (argument_layouts.size() != program_shape.parameters_size()) { return InvalidArgument( "Invalid number of arguments for computation: expected %d, got %zu.", - program_shape->parameters_size(), argument_layouts.size()); + program_shape.parameters_size(), argument_layouts.size()); } + for (int i = 0; i < argument_layouts.size(); ++i) { const Shape& argument_shape = *argument_layouts[i]; TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); - if (!ShapeUtil::Compatible(argument_shape, program_shape->parameters(i))) { + if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) { tensorflow::gtl::optional metadata = - user_computation->ParameterMetadata(i); + ParameterMetadata(computation, /*parameter_number=*/i); auto metadata_string = [&metadata]() -> string { if (!metadata.has_value()) { return ""; @@ -109,39 +173,27 @@ StatusOr> LocalService::CompileExecutable( return InvalidArgument( "Invalid argument shape for argument %d%s, expected %s, got %s.", i, metadata_string().c_str(), - ShapeUtil::HumanString(program_shape->parameters(i)).c_str(), + ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), ShapeUtil::HumanString(argument_shape).c_str()); } } if (build_options.result_layout() != nullptr) { TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( - *build_options.result_layout(), program_shape->result())); + *build_options.result_layout(), program_shape.result())); } - ExecutionOptions execution_options = CreateDefaultExecutionOptions(); - if (build_options.generate_hlo_graph().has_value()) { - execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( - build_options.generate_hlo_graph().value()); - } - if (build_options.result_layout() != nullptr) { - *execution_options.mutable_shape_with_output_layout() = - *build_options.result_layout(); - } else { - *execution_options.mutable_shape_with_output_layout() = - program_shape->result(); - LayoutUtil::SetToDefaultLayout( - execution_options.mutable_shape_with_output_layout()); - } + ExecutionOptions execution_options = + CreateExecutionOptions(build_options, &program_shape); + TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, argument_layouts, &execution_options, - *user_computation)); + CreateModuleConfig(program_shape, argument_layouts, &execution_options)); TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, execute_backend_->stream_executor(build_options.device_ordinal())); - return BuildExecutable(versioned_handle, std::move(module_config), + return BuildExecutable(proto, std::move(module_config), execute_backend_.get(), executor, build_options.device_allocator()); } @@ -152,4 +204,15 @@ StatusOr LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) { /*computation_count=*/1); } +StatusOr LocalService::GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number) { + TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data)); + if (replica_number >= buffers.size()) { + return InvalidArgument( + "replica_number %d out of range; must be less than num_replicas = %zu.", + replica_number, buffers.size()); + } + return buffers[replica_number]; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 15e120685e1be9190d49fdaf5ed6706bdf991a6c..39d6734c3fc06df6832cf67edddbc7c14c815cd1 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -40,15 +41,15 @@ class LocalService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // Builds an Executable with the given argument layouts and options. If - // result_layout is non-null, then the executable is compiled to produce a - // result of the given layout. If device_allocator is non-null, then the - // compiler may use it to allocate temp space on the device. The compiler is - // responsible for freeing any memory it allocates this way. + // Builds an Executable with the given XlaComputation, argument layouts and + // options. If result_layout is non-null, then the executable is compiled to + // produce a result of the given layout. If device_allocator is non-null, + // then the compiler may use it to allocate temp space on the device. The + // compiler is responsible for freeing any memory it allocates this way. StatusOr> CompileExecutable( - const ComputationHandle& computation, + const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options); + const ExecutableBuildOptions& build_options); // Returns the device ordinal that corresponds to the given replica number. // @@ -57,6 +58,11 @@ class LocalService : public Service { // the "easy" case where a single replica is a single device. StatusOr ReplicaNumberToDeviceOrdinal(int replica_number); + // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid + // as long as the handle is valid. + StatusOr GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number); + private: explicit LocalService(const ServiceOptions& options, std::unique_ptr backend); diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc index 68553bed121917850aaae41c6154f7895ed1add9..c742d35a7bcafa66692195a513992c9cfbb39335 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.cc +++ b/tensorflow/compiler/xla/service/logical_buffer.cc @@ -15,9 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/logical_buffer.h" -#include -#include - #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" @@ -28,43 +25,20 @@ namespace xla { LogicalBuffer::LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id) - : instruction_(instruction), id_(id), color_(kInvalidColor), index_(index) { - const auto& s = shape(); - is_array_ = ShapeUtil::IsArray(s); - is_tuple_ = ShapeUtil::IsTuple(s); -} + : BufferValue(instruction, index, id), + instruction_(instruction), + index_(index) {} + +LogicalBuffer::~LogicalBuffer() {} string LogicalBuffer::ToString() const { + string color_string; + if (has_color()) { + color_string = tensorflow::strings::StrCat(" @", color().value()); + } return tensorflow::strings::StrCat(instruction_->name(), "[", tensorflow::str_util::Join(index_, ","), - "](#", id_, " @", color_.value(), ")"); -} - -std::ostream& operator<<(std::ostream& out, const LogicalBuffer& buffer) { - out << buffer.ToString(); - return out; -} - -/*static*/ LogicalBufferProto::Location LogicalBuffer::ToLocationProto( - const HloInstruction& instruction, const ShapeIndex& index) { - LogicalBufferProto::Location proto; - proto.set_computation_name(instruction.parent()->name()); - proto.set_instruction_name(instruction.name()); - for (const int64 index_entry : index) { - proto.add_shape_index(index_entry); - } - return proto; -} - -LogicalBufferProto LogicalBuffer::ToProto(const SizeFunction& size_fn) const { - LogicalBufferProto proto; - proto.set_id(id_); - proto.set_size(size_fn(*this)); - LogicalBufferProto::Location proto_location = - ToLocationProto(*instruction_, index_); - proto.mutable_defined_at()->Swap(&proto_location); - proto.set_color(color_.value()); - return proto; + "](#", id(), color_string, ")"); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/logical_buffer.h b/tensorflow/compiler/xla/service/logical_buffer.h index 67b205e289e626f4db16c39a0a9ddf8618678c3a..f9ba5a554740c9d4cc2643fe59d18ba76c30d03b 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.h +++ b/tensorflow/compiler/xla/service/logical_buffer.h @@ -16,11 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_ -#include -#include #include -#include +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -33,133 +31,30 @@ limitations under the License. namespace xla { -// Class describing a contiguous sequence of elements (ie, C array) which form -// the components of Shaped values in XLA. XLA arrays are trivially a -// single LogicalBuffer. Tuple values are made up of more than one -// LogicalBuffer: a LogicalBuffer for the pointers to elements, and a -// LogicalBuffer for each child element. -// -// Every buffer is defined by a particular instruction and most instructions -// define only a single buffer. Instructions which define a single buffer -// include array-shaped instructions such as Add but also includes Tuple-shaped -// instructions such as Tuple. The Tuple instruction defines a single buffer -// which is a vector of pointers to the buffers containing the Tuple -// instruction's operands. Though the result of the Tuple instruction includes -// multiple buffers only the top-level buffer (the vector of pointers) is -// defined by the Tuple instruction. The buffers containing the tuple elements -// are defined by earlier instructions, usually the operands of the Tuple -// instruction. -// -// Instructions which construct both the tuple *and* the tuple elements define -// more than one buffer. This includes (at least) tuple-shaped Constant, -// Parameter, Infeed and While instructions. The tuple-shaped instructions do -// not assemble a tuple from existing buffers like the Tuple instruction does, -// but rather define the entire tuple. -// -// Some instructions, such as Bitcast, define no buffers. These instructions -// simply forward buffers from their operands. -// -// The LogicalBuffer object describes which HLO instruction defines a buffer and -// where within that instruction's output shape the buffer is defined. The -// location within the output shape is indicated by LogicalBuffer::index() which -// is defined identically to the index used in -// ShapeUtil::GetSubshape(). Examples: -// -// %add = Add(%foo, %bar) -// %tuple_constant = Constant({1, {42, 43}}) -// -// %add defines a single array-shaped buffer LogicalBuffer(%add, {}) which holds -// the array result of the add operation. The nested-tuple-shaped -// %tuple_constant defines 5 buffers described by the following LogicalBuffer -// objects: -// -// LogicalBuffer(%tuple_constant, {}) // "Top-level" buffer: vector of -// // pointers to LogicalBuffers at -// // indices {0} and {1} -// LogicalBuffer(%tuple_constant, {0}) // Holds value "1" -// LogicalBuffer(%tuple_constant, {1}) // Holds nested tuple: vector of -// // pointers to LogicalBuffers at -// // indices {1, 0} and {1, 1} -// LogicalBuffer(%tuple_constant, {1, 0}) // Holds value "42" -// LogicalBuffer(%tuple_constant, {1, 1}) // Holds value "43" -class LogicalBuffer { +// TuplePointsToAnalysis uses this subclass of BufferValue. +class LogicalBuffer : public BufferValue { public: - TF_LIB_GTL_DEFINE_INT_TYPE(Color, int64); - - // Id is a unique identifier for the LogicalBuffer to facilitate efficient - // collections of LogicalBuffers with stable iteration order. - // LogicalBuffers are typically created and accessed through - // TuplePointsToAnalysis, and points-to analysis assigns each LogicalBuffer a - // unique value. - using Id = int64; - - // Functions which return the size and alignment of a logical buffer in bytes. - using SizeFunction = std::function; - using AlignmentFunction = std::function; - LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id); - - Id id() const { return id_; } + ~LogicalBuffer() override; // Return the instruction that defines the buffer. - HloInstruction* instruction() const { return instruction_; } + HloInstruction* instruction() const override { return instruction_; } // Return the index within the output of the instruction where the buffer is // defined. Index used defined as in ShapeUtil::GetSubshape() - const ShapeIndex& index() const { return index_; } - - // Return the color of the logical buffer. Differently colored buffers can - // not be parts of the same allocation. - Color color() const { - CHECK_NE(color_, kInvalidColor) - << "Should not query the color of a buffer that was never colored"; - return color_; - } - - void set_color(Color color) { - CHECK_NE(color, kInvalidColor) - << "Should not set the color of a buffer to the invalid color"; - color_ = color; - } - - bool has_color() const { return color_ != kInvalidColor; } + const ShapeIndex& index() const override { return index_; } // Return the shape of the buffer. This reference points into the shape field // of the instruction defining the buffer. Therefore, the returned shape will // contain the layout of instruction, if any. - const Shape& shape() const { + const Shape& shape() const override { return ShapeUtil::GetSubshape(instruction_->shape(), index_); } - // Returns true if this buffer is the top-level output buffer of the defining - // HLO instruction. This is equivalent to index == {}. - bool IsTopLevel() const { return index_.empty(); } - - // Whether this buffer contains a tuple. - bool IsTuple() const { return is_tuple_; } - - // Whether this buffer contains an array. - bool IsArray() const { return is_array_; } - - // operator< is required for std::set. - bool operator<(const LogicalBuffer& other) const { return id_ < other.id_; } - - string ToString() const; - LogicalBufferProto ToProto(const SizeFunction& size_fn) const; - - // Returns the LogicalBufferProto::Location that serializes the given - // instruction and index. - static LogicalBufferProto::Location ToLocationProto( - const HloInstruction& instruction, const ShapeIndex& index); - - const Color kInvalidColor = Color(-1); + string ToString() const override; private: HloInstruction* instruction_; - Id id_ : 62; - bool is_array_ : 1; - bool is_tuple_ : 1; - Color color_; ShapeIndex index_; // Similar to HLO constructs (HloInstruction, etc), pointers are used for @@ -167,8 +62,6 @@ class LogicalBuffer { TF_DISALLOW_COPY_AND_ASSIGN(LogicalBuffer); }; -std::ostream& operator<<(std::ostream& out, const LogicalBuffer& buffer); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_ diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index 6aca6ba38572c5311797fbb91acbbcd6610a3410..f410921b4b5337192bdeae5924631d9c06b7d5a5 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -125,6 +125,12 @@ Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) { return Status::OK(); } +Status LogicalBufferAnalysis::HandleDomain(HloInstruction*) { + // A kDomain instruction aliases its operand. That is, the buffer of its + // result *is* the buffer of its operand. + return Status::OK(); +} + Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) { // RecvDone doesn't create a new buffer but rather aliases its input (Recv) // tuple element at {0} to its output. diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index f4c63dd86b4d8a6f598d46047012e4e5bc7b3d7e..b5ef3967875a58b35631d5f69c210f5cbcd91250 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -59,6 +59,7 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleDomain(HloInstruction* domain) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc new file mode 100644 index 0000000000000000000000000000000000000000..29f787b86b9cbb6f80d048b46b78bdad8074f488 --- /dev/null +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -0,0 +1,342 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/multi_output_fusion.h" + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +StatusOr MultiOutputFusion::Run(HloModule* module) { + bool changed = false; + + for (auto* computation : module->MakeNonfusionComputations()) { + computation_ = computation; + reachability_ = computation_->ComputeReachability(); + candidates_.clear(); + candidates_index_.clear(); + all_fusion_candidates_.clear(); + + int64 index = 0; + for (auto it : computation_->MakeInstructionPostOrder()) { + candidates_.emplace_back(it); + InsertOrDie(&candidates_index_, it, index++); + } + + // Create the initial candidate list for each Node. + for (auto& node : candidates_) { + HloInstruction* instruction = node.hlo; + int64 instruction_id = get_candidate_id(instruction); + FusionCandidate& instr_node = candidates_[instruction_id]; + if (!IsFusible(instruction)) { + continue; + } + all_fusion_candidates_.push_back(instruction); + + std::vector candidates; + tensorflow::gtl::FlatSet candidates_set; + VLOG(10) << "Looking at instruction: " << instruction->name(); + for (auto operand : instruction->operands()) { + // Filter out the non-interesting instructions -- they + // will not generate the savings. + if (!IsProfitableOperand(operand)) { + VLOG(10) << "Operand not profitable: " << operand->name(); + continue; + } + VLOG(10) << "Operand profitable: " << operand->name(); + for (auto user : operand->users()) { + VLOG(10) << "User: " << user->name(); + if (user == instruction || !IsFusible(user)) { + VLOG(10) << "User is not fusible, or is the instruction itself: " + << user->name(); + continue; + } + int64 user_id = get_candidate_id(user); + if (is_connected(instruction, user)) { + VLOG(10) << "User is connected: " << user->name(); + continue; + } + if (instruction_id < user_id && + user->opcode() == HloOpcode::kFusion) { + VLOG(10) << "User ID for user: " << user->name() << " is " + << user_id << " which is higher than " << instruction_id; + continue; + } + if (!LegalToFuse(instruction, user)) { + VLOG(10) << "User not legal to fuse: " << user->name(); + continue; + } + if (candidates_set.insert(user).second) { + VLOG(10) << "User added to candidate list: " << user->name(); + candidates.push_back(user); + } + } + } + + // Iterate over candidates rather than candidates_set to avoid + // nondeterminism. + for (auto candidate : candidates) { + int64 profit = GetProfit(instruction, candidate); + if (profit > 0) { + FusionCandidate& candidate_node = + candidates_[get_candidate_id(candidate)]; + instr_node.fusibles.emplace_back(candidate, profit); + candidate_node.fusibles.emplace_back(instruction, profit); + worklist_.emplace(instruction, candidate, profit); + } + } + } + if (Perform()) { + changed = true; + } + } + return changed; +} + +HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1, + HloInstruction* instr2) { + HloInstruction* remaining = instr1; + HloInstruction* fused = instr2; + // Make sure that if only one of the instructions is a fusion, or if only one + // of the instructions is a multi-output fusion, it's what will be fused into. + // + // An invariant is that no bitcast nodes will show up in the middle of a + // fusion node. This invariant must hold in order for us to lower it. Given + // that, we require that during multi-output fusion, a fusion node ending with + // bitcast to preserve its structure as a nested fusion instead being + // merged and flattened. + if (fused->opcode() == HloOpcode::kFusion && + fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + std::swap(remaining, fused); + } + if (fused->IsMultiOutputFusion()) { + std::swap(remaining, fused); + } + + if (fused->opcode() == HloOpcode::kFusion && + fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + remaining->MergeFusionInstructionIntoMultiOutput(fused); + } else { + if (remaining->opcode() == HloOpcode::kFusion && + remaining->fused_expression_root()->opcode() == HloOpcode::kBitcast) { + auto parent_computation = remaining->parent(); + // Create a nested fusion node. + auto remaining_nested_fused = + parent_computation->AddInstruction(HloInstruction::CreateFusion( + remaining->shape(), HloInstruction::FusionKind::kLoop, + remaining)); + TF_CHECK_OK(parent_computation->ReplaceInstruction( + remaining, remaining_nested_fused)); + remaining = remaining_nested_fused; + } + remaining->FuseInstructionIntoMultiOutput(fused); + } + + return remaining; +} + +void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) { + HloInstruction* fusion = instr1; + HloInstruction* fused = instr2; + if (is_fused(instr1)) { + fusion = instr2; + fused = instr1; + } + + // Insert the newly created instruction (if any), to candidates_. + for (auto use : fusion->users()) { + if (candidates_index_.find(use) == candidates_index_.end()) { + int64 index = candidates_.size(); + candidates_.emplace_back(use); + InsertOrDie(&candidates_index_, use, index++); + } + } + FusionCandidate& fusion_node = candidates_[get_candidate_id(fusion)]; + FusionCandidate& fused_node = candidates_[get_candidate_id(fused)]; + + // Update the reachability graph. + UpdateReachability(fusion, fused, all_fusion_candidates_, + [this](HloInstruction* instr) { return is_fused(instr); }); + + // Update the fusible list for fusion. Variable new_fusibles keeps + // track of the new or changed entries. + std::vector> new_fusibles; + tensorflow::gtl::FlatSet in_list; + auto it = fusion_node.fusibles.begin(); + while (it != fusion_node.fusibles.end()) { + HloInstruction* instr = it->first; + if (is_fused(instr) || is_connected(fusion, instr)) { + it = fusion_node.fusibles.erase(it); + continue; + } + in_list.insert(instr); + int64 profit = GetProfit(instr, fusion); + if (profit > it->second) { + it->second = profit; + new_fusibles.emplace_back(instr, profit); + } + ++it; + } + + // Fused_node has been fused into fusion_node. Take the fusion candidates + // (fusibles) from fused_nodes and add them to the fusion_node's. Filter + // out those fusibles that no longer valid (or already in the list). + for (const auto& it : fused_node.fusibles) { + HloInstruction* instr = it.first; + if (instr == fusion || is_fused(instr) || is_connected(fusion, instr)) { + continue; + } + if (in_list.count(instr) > 0) { + continue; + } + int64 profit = GetProfit(instr, fusion); + fusion_node.fusibles.emplace_back(instr, profit); + new_fusibles.emplace_back(instr, profit); + } + fused_node.fusibles.clear(); + + // Update the worklist_. + for (auto it : new_fusibles) { + worklist_.emplace(fusion, it.first, it.second); + } +} + +bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1, + HloInstruction* instr2) { + if (instr1 == instr2) { + return false; + } + if (instr1->opcode() != HloOpcode::kFusion) { + return false; + } + + // Fusing nodes with 0 user makes no sense and the rest of the implementation + // doesn't support it either. + if (instr1->user_count() == 0 || instr2->user_count() == 0) { + return false; + } + + // Check if the users of multioutput fusion is not a get-tuple-element. + // If this is the case, we bail out because the transformation assumes + // the users are get-tuple-element. + auto multioutput_user_is_not_gte = [](HloInstruction* instr) { + if (!instr->IsMultiOutputFusion()) { + return false; + } + for (auto user : instr->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + return true; + } + } + return false; + }; + if (multioutput_user_is_not_gte(instr1) || + multioutput_user_is_not_gte(instr2)) { + return false; + } + + if (is_connected(instr1, instr2)) { + return false; + } + if (!ShapesCompatibleForFusion(instr1, instr2)) { + return false; + } + + return true; +} + +void MultiOutputFusion::UpdateReachability( + HloInstruction* instr1, HloInstruction* instr2, + tensorflow::gtl::ArraySlice instrs_to_update, + const std::function& skip) { + for (auto instr : instrs_to_update) { + if (skip != nullptr && skip(instr)) { + continue; + } + if (reachability_->IsReachable(instr2, instr) && + reachability_->IsReachable(instr1, instr)) { + // If a candidate was already reachable by both, no update needed. + continue; + } + if (reachability_->IsReachable(instr2, instr)) { + reachability_->FastSetReachabilityToUnion({instr, instr1}, instr); + } + if (reachability_->IsReachable(instr1, instr)) { + reachability_->FastSetReachabilityToUnion({instr, instr2}, instr); + } + } +} + +bool MultiOutputFusion::Perform() { + int changed = false; + // Pick the top candidate from queue and try to merge. + while (!worklist_.empty()) { + if (fuel_ <= 0) { + VLOG(2) << "No fusing: run out of fuel."; + break; + } + ToBeFused candidate = worklist_.top(); + worklist_.pop(); + + HloInstruction* instr1 = candidate.instr1; + HloInstruction* instr2 = candidate.instr2; + + if (is_fused(instr1) || is_fused(instr2)) { + continue; + } + + VLOG(1) << "Considering candidate profit_score=" << candidate.score + << "\n\t\tinstr1 = " << instr1->ToString() + << "\n\t\tinstr2 = " << instr2->ToString(); + + if (LegalToFuse(instr1, instr2)) { + VLOG(1) << "Fuse!"; + VLOG(2) << "Before multi_output_fusion:"; + VLOG(2) << "instr1: " << instr1->ToString(); + VLOG(2) << "\n" + << instr1->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + VLOG(2) << "instr2: " << instr2->ToString(); + if (instr2->opcode() == HloOpcode::kFusion) { + VLOG(2) << "\n" + << instr2->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + } + HloInstruction* ret = Fuse(instr1, instr2); + set_is_fused(ret == instr1 ? instr2 : instr1); + Update(instr1, instr2); + changed = true; + VLOG(2) << "After fusion, \t this: " << ret->name() << "\n" + << ret->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + auto users = ret->users(); + --fuel_; + } + } + if (DoProducerConsumerMultiOutputFusion(computation_)) { + changed = true; + } + return changed; +} + +bool MultiOutputFusion::DoProducerConsumerMultiOutputFusion( + HloComputation* /*computation*/) { + return false; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..cfdf83cfe856a7c3b05f51129446cd4e1055a8d6 --- /dev/null +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -0,0 +1,160 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { + +// This class implements the fusing of sibling fusion instructions that sharing +// common operands. +// It constructs the following associated data structures. +// (1) candidates_: stores the instruction and the set of instructions it can +// fuse to. +// (2) candidates_index_: maps instruction to id. +// (3) reachability_: reachability map in this computation. +// (4) all_fusion_candidates_: the vector of candidate instructions. +// (5) worklist_: a priority queue that contains pairs of instructions to be +// fused and their fusion profit scores. +// +// Function Perform() applies the optimization. It picks up the most profitable +// pair in the worklist_, check if it's legal to fuse and fuse the pair. +// After fusion, it updates the associated structure such as reachability_, +// candidates_ and worklist_. +// Note that the reachability map is updated based on the original computation. +// This works because the reachability is monotonically increasing with +// instruction fusion. +class MultiOutputFusion : public HloPassInterface { + public: + MultiOutputFusion(int64 fuel) : fuel_(fuel) {} + + tensorflow::StringPiece name() const override { + return "multi_output_fusion"; + } + + // Run multi-output fusion on the given module. Returns whether the module + // was changed. + StatusOr Run(HloModule* module) override; + + protected: + // Main entry for the optimization. Returns true if the optimization happens. + bool Perform(); + + // Test if instr1 and instr2 have the compatible shapes that can be legally + // fused. + virtual bool ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) = 0; + + // Whether the instruction is a candidate for fusion. + virtual bool IsFusible(HloInstruction* instr) = 0; + + // This function estimates the savings by merging instr1 and instr2 into one + // multi-output fusion instruction. + virtual int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) = 0; + + // Whether fusing the instruction can reduce cost. + virtual bool IsProfitableOperand(HloInstruction* instr) = 0; + + // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. + virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2); + + // Update the reachability map after fusing instr1 and instr2. + void UpdateReachability( + HloInstruction* instr1, HloInstruction* instr2, + tensorflow::gtl::ArraySlice instrs_to_update, + const std::function& skip = nullptr); + + // Hook for multi-output fusion along producer-consumer edges. + // Returns whether any instructions were fused. + // + // TODO(b/80420762): Perform producer-consumer multi-output fusion in + // InstructionFusion instead. + virtual bool DoProducerConsumerMultiOutputFusion(HloComputation* computation); + + private: + // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction. + // The other instruction is removed from its parent computation. + HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2); + + // Update the internal data structures after instr1 and instr2 are fused into + // one fusion instruction. + void Update(HloInstruction* instr1, HloInstruction* instr2); + + // Optimization fuel is a compiler debugging technique that makes an + // optimization pass stop what it is doing after having made N changes to the + // program, where N is the fuel. By varying N, this can be used to find the + // first single change that makes a test fail. + int64 fuel_; + + // Computation for the pass. + HloComputation* computation_; + + // An internal data structure for each instruction in current computation. + // When an instruction is removed, member 'hlo' is set to nullptr. + struct FusionCandidate { + HloInstruction* hlo; + std::list> fusibles; + explicit FusionCandidate(HloInstruction* hlo) : hlo(hlo) {} + }; + std::vector candidates_; + + // A map that maps an instruction to the index_. + tensorflow::gtl::FlatMap candidates_index_; + + // The reachability map of current computation. + std::unique_ptr reachability_; + + // This stores all the candidate instructions in current computation. + std::vector all_fusion_candidates_; + + // The pair of candidates to be fused and the profit score. + struct ToBeFused { + HloInstruction* instr1; + HloInstruction* instr2; + int64 score; + ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score) + : instr1(instr1), instr2(instr2), score(score) {} + bool operator<(const ToBeFused& rhs) const { return score < rhs.score; } + }; + std::priority_queue worklist_; + + int64 get_candidate_id(HloInstruction* instr) { + return FindOrDie(candidates_index_, instr); + } + + bool is_fused(HloInstruction* instr) { + return candidates_[get_candidate_id(instr)].hlo == nullptr; + } + + void set_is_fused(HloInstruction* instr) { + candidates_[get_candidate_id(instr)].hlo = nullptr; + } + + bool is_connected(HloInstruction* instr1, HloInstruction* instr2) { + return reachability_->IsConnected(instr1, instr2); + } +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index 7d8c05fffa4ab11d7dbf9956d2cb7ebd5bcdd3c4..3a6a7c25f4b727c7112dbcbcb4f3d892679a0011 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -53,17 +53,18 @@ NameUniquer::NameUniquer(const string& separator) { } string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { - string root = prefix.empty() ? "name" : prefix.ToString(); - root = GetSanitizedName(root); + string root = GetSanitizedName(prefix.empty() ? "name" : std::string(prefix)); // Strip away numeric suffix (if any). Only recognize separator if it is in // the middle of the name. + bool has_numeric_suffix = false; + int64 numeric_suffix = 0; size_t separator_index = root.rfind(separator_); if (separator_index != string::npos && (separator_index > 0) && (separator_index < root.size() - 1)) { string after_suffix = root.substr(separator_index + 1); - int64 numeric_suffix; if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + has_numeric_suffix = true; // Remove numeric suffix from root. root = root.substr(0, separator_index); // Update count to at least the numeric suffix value to avoid future @@ -71,11 +72,11 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { generated_names_[root] = std::max(generated_names_[root], numeric_suffix); } } - int64* count = &(generated_names_[root]); if (*count == 0) { *count = 1; - return root; + return has_numeric_suffix ? tensorflow::strings::StrCat(root, separator_, 0) + : root; } else { tensorflow::strings::StrAppend(&root, separator_, *count); // Increment lookup under old 'root' name. diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 4258cf16876ab46dce6df062ab701b1b1a4a7580..2ec255558c4ed3695ec6c824458cbedac44dc297 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -57,11 +57,18 @@ TEST_F(NameUniquerTest, NumericSuffixes) { EXPECT_EQ("foo.55", uniquer.GetUniqueName("foo")); EXPECT_EQ("foo.55.1", uniquer.GetUniqueName("foo.55.1")); EXPECT_EQ("foo.55.2", uniquer.GetUniqueName("foo.55.1")); - EXPECT_EQ("bar", uniquer.GetUniqueName("bar.-1000")); + EXPECT_EQ("bar.0", uniquer.GetUniqueName("bar.-1000")); EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000")); EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1")); } +TEST_F(NameUniquerTest, PrefixHasSuffix) { + NameUniquer uniquer("."); + + EXPECT_EQ("foo.11.0", uniquer.GetUniqueName("foo.11.0")); + EXPECT_EQ("foo.11", uniquer.GetUniqueName("foo.11")); +} + TEST_F(NameUniquerTest, Sanitize) { NameUniquer uniquer("_"); @@ -73,7 +80,7 @@ TEST_F(NameUniquerTest, Sanitize) { EXPECT_EQ("foo_55", uniquer.GetUniqueName("foo")); // Invalid characters will be replaced with '_'. - EXPECT_EQ("bar", uniquer.GetUniqueName("bar<-1000")); + EXPECT_EQ("bar_0", uniquer.GetUniqueName("bar<-1000")); EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar<-2000")); EXPECT_EQ("bar_2", uniquer.GetUniqueName("bar_1")); diff --git a/tensorflow/compiler/xla/service/owning_device_memory.cc b/tensorflow/compiler/xla/service/owning_device_memory.cc new file mode 100644 index 0000000000000000000000000000000000000000..c115bc097f3b1dd810654745b835a977955718c3 --- /dev/null +++ b/tensorflow/compiler/xla/service/owning_device_memory.cc @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/owning_device_memory.h" + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" + +namespace xla { + +void OwningDeviceMemory::Free() { + CHECK(allocator_ != nullptr) + << "Can't call Free() on an inactive (i.e. moved from, Forget()'ten, " + "or Free()'ed) instance."; + auto status = allocator_->Deallocate(device_ordinal_, mem_); + if (!status.ok()) { + LOG(WARNING) << "Deallocating buffer " << mem_.opaque() << " failed."; + } + + allocator_ = nullptr; + mem_ = se::DeviceMemoryBase(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/owning_device_memory.h b/tensorflow/compiler/xla/service/owning_device_memory.h new file mode 100644 index 0000000000000000000000000000000000000000..9cf071f0d9d09dfbf74b15e73caaf542714ec8d5 --- /dev/null +++ b/tensorflow/compiler/xla/service/owning_device_memory.h @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// Break circular dependency between this file and device_memory_allocator.h. +class DeviceMemoryAllocator; + +// Owning pointer for memory on a device. +// +// OwningDeviceMemory is an owning pointer like std::unique_ptr, but it can +// point to memory that resides on a "device" (e.g. a GPU). When an +// OwningDeviceMemory goes out of scope, it frees the memory it owns. +// +// We say that an instance of OwningDeviceMemory is "active" if it currently +// owns a (possibly empty) slice of memory on the device. Moving, Forget()'ing, +// Free()'ing, and other actions can deactive an active object. +// +// Note that we can't simply use stream_executor::ScopedDeviceMemory instead of +// OwningDeviceMemory, because ScopedDeviceMemory frees its pointer via a +// StreamExecutor. This class needs to free via a xla::DeviceMemoryAllocator. +class OwningDeviceMemory { + public: + OwningDeviceMemory() : device_ordinal_(-1), allocator_(nullptr) {} + + explicit OwningDeviceMemory(se::DeviceMemoryBase mem, int device_ordinal, + DeviceMemoryAllocator* allocator) + : mem_(mem), device_ordinal_(device_ordinal), allocator_(allocator) { + CHECK(allocator != nullptr) << "allocator cannot be null."; + } + + OwningDeviceMemory(OwningDeviceMemory&& other) + : mem_(other.mem_), + device_ordinal_(other.device_ordinal_), + allocator_(other.allocator_) { + other.mem_ = se::DeviceMemoryBase(); + other.allocator_ = nullptr; + } + + OwningDeviceMemory& operator=(OwningDeviceMemory&& other) { + if (allocator_ != nullptr) { + Free(); + } + mem_ = other.mem_; + device_ordinal_ = other.device_ordinal_; + allocator_ = other.allocator_; + + other.mem_ = se::DeviceMemoryBase(); + other.allocator_ = nullptr; + return *this; + } + + // Deactivates this instance if it's active. Nop if it's not active. + OwningDeviceMemory& operator=(std::nullptr_t) { + if (allocator_ != nullptr) { + Free(); + } + return *this; + } + + ~OwningDeviceMemory() { + if (allocator_ != nullptr) { + Free(); + } + } + + // The returned allocator is nonnull iff this object is active. + DeviceMemoryAllocator* allocator() const { return allocator_; } + + int device_ordinal() const { return device_ordinal_; } + + // Gets the device memory pointer. + const void* opaque() const { return mem_.opaque(); } + void* opaque() { return mem_.opaque(); } + + uint64 size() const { return mem_.size(); } + + // Determines whether this wraps a null pointer. + // + // !is_null() is sufficient but not necessary to imply `this` is active. + bool is_null() const { return mem_.is_null(); } + + se::DeviceMemoryBase AsDeviceMemoryBase() { + return se::DeviceMemoryBase(opaque(), size(), /*is_sub_buffer=*/false); + } + + // Returns the wrapped DeviceMemoryBase without freeing it, and deactivates + // this object. Precondition: `this` is active. + TF_MUST_USE_RESULT se::DeviceMemoryBase Forget() { + CHECK(allocator_ != nullptr) + << "Can't call Forget() on an inactive (i.e. moved from, Forget()'ten, " + "or Free()'ed) instance."; + allocator_ = nullptr; + se::DeviceMemoryBase mem(mem_); + mem_ = se::DeviceMemoryBase(); + return mem; + } + + // Frees the wrapped DeviceMemoryBase and deactivates this object. + // Precondition: `this` is active. + void Free(); + + private: + se::DeviceMemoryBase mem_; + int device_ordinal_; + DeviceMemoryAllocator* allocator_; // Null if this object is inactive. +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h new file mode 100644 index 0000000000000000000000000000000000000000..2515222cf2db3d9699c85c13f4fe72b3488fa217 --- /dev/null +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -0,0 +1,1047 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { + +// A pattern matcher for HloInstructions, Shapes, and Layouts. +// +// The Match function's first argument must be HloInstruction*, Shape*, or +// Layout*. The second argument is a pattern that will be matched against the +// first argument, as described below. +// +// Patterns are constructed using the match::Op, match::Shape, or match::Layout +// functions. By default, the returned patterns will match any HloInstruction, +// Shape, or Layout, respectively. However the match can be made more specific +// by using the pattern's modifier methods, for example: +// +// match::Op().WithOpcode(HloOpcode::kAdd).WithOperand( +// 0, match::Op().WithOpcode(HloOpcode::kConstant)) +// +// This pattern will match Add instructions whose first operand is a constant. +// +// Each pattern type has the following modifiers: +// +// Op(): +// - WithName: match operations with the given name +// - WithOpcode: match operations with the given opcode +// - WithShape: match operations whose shape matches the given pattern +// - WithOperand: match operations whose operand matches the given pattern +// +// Shape(): +// - EqualTo: matches shapes that are equal to the argument +// - CompatibleTo: matches shapes that are compatible to the argument +// - IsScalar/IsArray/IsTuple: matches scalar/array/tuple shapes +// - IsDenseArray/IsSparseArray: matches arrays with dense/sparse format +// - WithLayout: match shapes whose layout matches the given pattern +// - WithLayoutEqualTo: matches shapes whose layouts equal the argument +// - WithSubshape: matches tuple shapes whose subshape matches the given +// pattern +// - WithSubshapeEqualTo: matches shapes with a subshape equal the argument +// - WithElementType: matches array/scalar shapes with the given element +// type +// - WithRank: matches array/scalar types with the given rank +// +// Layout(): +// - EqualTo: matches layouts that are equal to the argument +// - WithDenseFormat/WithSparseFormat: matches layouts with dense/sparse +// format +// +// Op(), Shape(), and Layout() may be passed an argument of type +// HloInstruction**, Shape**, or Layout**, respectively, or const versions of +// these pointers. If the pattern is matched, the address of the matched value +// will be "captured" and stored at this location. +// +// For example: +// HloInstruction* foo = ...; +// HloInstruction* matched_operand; +// CHECK(Match(foo, +// match::Op().WithOperand(0, match::Op(&matched_operand)))); +// +// Helpers are provided for common nullary, unary, binary, and ternary +// instructions. These helpers can be called with no arguments, in which case +// they will match any instruction matching the opcode. They may also be called +// with matches for the operands and with an optional capture. (The capture must +// be the first argument.) Some examples of these helpers and their equivalents +// are provided below. +// +// Example nullary instruction: +// Recv() == Op().WithOpcode(HloOpcode::kRecv) +// Recv(&a) == Op(&a).WithOpcode(HloOpcode::kRecv) +// +// Example unary instruction: +// Abs() == Op().WithOpcode(HloOpcode::kAbs) +// Abs(Op(&a)) == Op().WithOpcode(HloOpcode::kAbs) +// .WithOperand(0, Op(&a))) +// Abs(&a, Op(&b)) == Op(&a).WithOpcode(HloOpcode::kAbs) +// .WithOperand(0, Op(&b)) +// +// Example binary instruction: +// Add() == Op().WithOpcode(HloOpcode::kAdd) +// Add(Op(&a), Op(&b)) == Op().WithOpcode(HloOpcode::kAdd) +// .WithOperand(0, Op(&a)) +// .WithOperand(1, Op(&b)) +// Add(&a, Op(&b), Op(&c)) == Op(&a).WithOpcode(HloOpcode::kAdd) +// .WithOperand(0, Op(&b)) +// .WithOperand(1, Op(&c)) +// +// Example ternary instruction: +// Clamp() == Op().WithOpcode(HloOpcode::kClamp) +// Clamp(Op(&a), Op(&b), Op(&c)) == Op().WithOpcode(HloOpcode::kClamp) +// .WithOperand(0, Op(&a)) +// .WithOperand(1, Op(&b)) +// .WithOperand(2, Op(&c)) +// Clamp(&a, Op(&b), Op(&c), Op(&d)) == Op(&a).WithOpcode(HloOpcode::kClamp) +// .WithOperand(0, Op(&b)) +// .WithOperand(1, Op(&c)) +// .WithOperand(2, Op(&d)) +// +template +bool Match(Value* value, const Pattern& pattern) { + return pattern.Match(value); +} + +namespace match { + +namespace detail { + +template +class LayoutPattern; + +// The base LayoutPattern implementation. Matches only if the layout is not +// nullptr. +class LayoutPatternBaseImpl { + public: + bool Match(const ::xla::Layout* layout) const { return layout != nullptr; } +}; + +// A LayoutPattern implementation that matches only if the layout equals a +// Layout proto. +template +class LayoutPatternEqualImpl { + public: + explicit constexpr LayoutPatternEqualImpl(const Previous& previous, + const ::xla::Layout* layout) + : previous_(previous), layout_(layout) {} + + bool Match(const ::xla::Layout* layout) const { + return previous_.Match(layout) && LayoutUtil::Equal(*layout_, *layout); + } + + private: + Previous previous_; + const ::xla::Layout* layout_; +}; + +// A LayoutPattern implementation that matches only if the layout has a given +// format. +template +class LayoutPatternFormatImpl { + public: + explicit constexpr LayoutPatternFormatImpl(const Previous& previous, + Format format) + : previous_(previous), format_(format) {} + + bool Match(const ::xla::Layout* layout) const { + return previous_.Match(layout) && layout->format() == format_; + } + + private: + Previous previous_; + Format format_; +}; + +// A pattern that matches Layouts. +template +class LayoutPattern { + public: + explicit constexpr LayoutPattern(const Impl& impl, + LayoutType** matched_layout) + : impl_(impl), matched_layout_(matched_layout) {} + + // Returns true and captures the layout iff it matches the pattern. + bool Match(const ::xla::Layout* layout) const { + if (impl_.Match(layout)) { + if (matched_layout_) { + *matched_layout_ = layout; + } + return true; + } + return false; + } + + // Returns true and captures the layout iff it matches the pattern. + bool Match(::xla::Layout* layout) const { + if (impl_.Match(layout)) { + if (matched_layout_) { + *matched_layout_ = layout; + } + return true; + } + return false; + } + + // Modifies the pattern to match only if the layout equals the given proto. + // The layout must outlive the returned pattern. + constexpr LayoutPattern> EqualTo( + const ::xla::Layout* layout) const { + return LayoutPattern>( + LayoutPatternEqualImpl(impl_, layout), matched_layout_); + } + + // Modifies the pattern to match only if the layout has a dense format. + constexpr LayoutPattern> + WithDenseFormat() const { + return LayoutPattern>( + LayoutPatternFormatImpl(impl_, DENSE), matched_layout_); + } + + // Modifies the pattern to match only if the layout has a sparse format. + constexpr LayoutPattern> + WithSparseFormat() const { + return LayoutPattern>( + LayoutPatternFormatImpl(impl_, SPARSE), matched_layout_); + } + + private: + Impl impl_; + LayoutType** matched_layout_; +}; + +} // namespace detail + +// Creates a layout pattern that will capture the matched layout in the +// argument. +inline constexpr detail::LayoutPattern +Layout(const ::xla::Layout** matched_layout = nullptr) { + return detail::LayoutPattern( + detail::LayoutPatternBaseImpl(), matched_layout); +} + +// Creates a layout pattern that will capture the matched layout in the +// argument. +inline constexpr detail::LayoutPattern<::xla::Layout, + detail::LayoutPatternBaseImpl> +Layout(::xla::Layout** matched_layout) { + return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>( + detail::LayoutPatternBaseImpl(), matched_layout); +} + +namespace detail { + +template +class ShapePattern; + +// The base ShapePattern implementation. Matches only if the shape is not +// nullptr. +class ShapePatternBaseImpl { + public: + bool Match(const ::xla::Shape* shape) const { return shape != nullptr; } +}; + +// A ShapePattern implementation that matches only if the shape equals a Shape +// proto. +template +class ShapePatternEqualImpl { + public: + explicit constexpr ShapePatternEqualImpl(const Previous& previous, + const ::xla::Shape* shape) + : previous_(previous), shape_(shape) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::Equal(*shape_, *shape); + } + + private: + Previous previous_; + const ::xla::Shape* shape_; +}; + +// A ShapePattern implementation that matches only if the shape is compatible to +// a Shape proto. +template +class ShapePatternCompatibleImpl { + public: + explicit constexpr ShapePatternCompatibleImpl(const Previous& previous, + const ::xla::Shape* shape) + : previous_(previous), shape_(shape) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::Compatible(*shape_, *shape); + } + + private: + Previous previous_; + const ::xla::Shape* shape_; +}; + +// A ShapePattern implementation that matches only if the shape has a given +// element type. +template +class ShapePatternElementTypeImpl { + public: + explicit constexpr ShapePatternElementTypeImpl(const Previous& previous, + PrimitiveType element_type) + : previous_(previous), element_type_(element_type) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && shape->element_type() == element_type_; + } + + private: + Previous previous_; + PrimitiveType element_type_; +}; + +// A ShapePattern implementation that matches only if the shape is scalar. +template +class ShapePatternIsScalarImpl { + public: + explicit constexpr ShapePatternIsScalarImpl(const Previous& previous) + : previous_(previous) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::IsScalar(*shape); + } + + private: + Previous previous_; +}; + +// A ShapePattern implementation that matches only if the shape is an array +template +class ShapePatternIsArrayImpl { + public: + explicit constexpr ShapePatternIsArrayImpl(const Previous& previous) + : previous_(previous) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::IsArray(*shape); + } + + private: + Previous previous_; +}; + +// A ShapePattern implementation that matches only if the shape is a tuple. +template +class ShapePatternIsTupleImpl { + public: + explicit constexpr ShapePatternIsTupleImpl(const Previous& previous) + : previous_(previous) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::IsTuple(*shape); + } + + private: + Previous previous_; +}; + +// A ShapePattern implementation that matches only if the shape has a given +// rank. +template +class ShapePatternRankImpl { + public: + explicit constexpr ShapePatternRankImpl(const Previous& previous, int64 rank) + : previous_(previous), rank_(rank) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::Rank(*shape) == rank_; + } + + private: + Previous previous_; + int64 rank_; +}; + +// A ShapePattern implementation that matches only if the shape has a layout +// that matches a given pattern. +template +class ShapePatternLayoutImpl { + public: + explicit constexpr ShapePatternLayoutImpl( + const Previous& previous, + const LayoutPattern& layout) + : previous_(previous), layout_(layout) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) && + layout_.Match(&shape->layout()); + } + + bool Match(Shape* shape) const { + return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) && + layout_.Match(shape->mutable_layout()); + } + + private: + Previous previous_; + LayoutPattern layout_; +}; + +// A ShapePattern implementation that matches only if the shape has a subshape +// that matches a given pattern. +template +class ShapePatternSubshapeImpl { + public: + explicit ShapePatternSubshapeImpl( + const Previous& previous, ShapeIndexView index, + const ShapePattern& subshape) + : previous_(previous), index_(index), subshape_(subshape) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) && + subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_)); + } + + bool Match(::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) && + subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_)); + } + + private: + Previous previous_; + ShapeIndexView index_; + ShapePattern subshape_; +}; + +// A pattern that matches Shapes. +template +class ShapePattern { + public: + explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape) + : impl_(impl), matched_shape_(matched_shape) {} + + // Returns true and captures the shape iff it matches the pattern. + bool Match(const ::xla::Shape* shape) const { + if (impl_.Match(shape)) { + if (matched_shape_) { + *matched_shape_ = shape; + } + return true; + } + return false; + } + + // Returns true and captures the shape iff it matches the pattern. + bool Match(::xla::Shape* shape) const { + if (impl_.Match(shape)) { + if (matched_shape_) { + *matched_shape_ = shape; + } + return true; + } + return false; + } + + // Modifies the pattern to match only if the shape equals the given proto. + // The layout must outlive the returned pattern. + constexpr ShapePattern> EqualTo( + const ::xla::Shape* shape) const { + return ShapePattern>( + ShapePatternEqualImpl(impl_, shape), matched_shape_); + } + + // Modifies the pattern to match only if the shape is compatible to the given + // proto. The layout must outlive the returned pattern. + constexpr ShapePattern> + CompatibleTo(const ::xla::Shape* shape) const { + return ShapePattern>( + ShapePatternCompatibleImpl(impl_, shape), matched_shape_); + } + + // Modifies the pattern to match only if the shape has the given element type. + constexpr ShapePattern> + WithElementType(PrimitiveType element_type) const { + return ShapePattern>( + ShapePatternElementTypeImpl(impl_, element_type), matched_shape_); + } + + // Modifies the pattern to match only if the shape is scalar. + constexpr ShapePattern> IsScalar() + const { + return ShapePattern>( + ShapePatternIsScalarImpl(impl_), matched_shape_); + } + + // Modifies the pattern to match only if the shape is an array. + constexpr ShapePattern> IsArray() + const { + return ShapePattern>( + ShapePatternIsArrayImpl(impl_), matched_shape_); + } + + // Modifies the pattern to match only if the shape is a tuple. + constexpr ShapePattern> IsTuple() + const { + return ShapePattern>( + ShapePatternIsTupleImpl(impl_), matched_shape_); + } + + // Modifies the pattern to match only if the shape has the given rank. + constexpr ShapePattern> WithRank( + int64 rank) const { + return ShapePattern>( + ShapePatternRankImpl(impl_, rank), matched_shape_); + } + + // Modifies the pattern to match only if the shape has a layout that matches + // the given pattern. + template + constexpr ShapePattern> + WithLayout(const LayoutPattern& layout) const { + return ShapePattern>( + ShapePatternLayoutImpl(impl_, layout), + matched_shape_); + } + + constexpr ShapePattern< + ShapeType, + ShapePatternLayoutImpl>> + WithLayoutEqualTo(const ::xla::Layout* layout) const { + return WithLayout(Layout().EqualTo(layout)); + } + + constexpr ShapePattern< + ShapeType, + ShapePatternLayoutImpl>> + IsDenseArray() const { + return WithLayout(Layout().WithDenseFormat()); + } + + constexpr ShapePattern< + ShapeType, + ShapePatternLayoutImpl>> + IsSparseArray() const { + return WithLayout(Layout().WithSparseFormat()); + } + + // Modifies the pattern to match only if the shape has a subshape that matches + // the given pattern. + template + ShapePattern> + WithSubshape(ShapeIndexView index, + const ShapePattern& subshape) const { + return ShapePattern< + ShapeType, ShapePatternSubshapeImpl>( + ShapePatternSubshapeImpl(impl_, index, + subshape), + matched_shape_); + } + + ShapePattern>> + WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const { + return WithSubshape(index, + ShapePattern( + ShapePatternBaseImpl(), nullptr) + .EqualTo(shape)); + } + + ShapePattern>> + WithSubshapeCompatibleTo(ShapeIndexView index, + const ::xla::Shape* shape) const { + return WithSubshape(index, + ShapePattern( + ShapePatternBaseImpl(), nullptr) + .CompatibleTo(shape)); + } + + private: + Impl impl_; + ShapeType** matched_shape_; +}; + +} // namespace detail + +// Creates a shape pattern that will capture the matched layout in the argument. +inline constexpr detail::ShapePattern +Shape(const ::xla::Shape** matched_shape = nullptr) { + return detail::ShapePattern( + detail::ShapePatternBaseImpl(), matched_shape); +} + +// Creates a shape pattern that will capture the matched layout in the argument. +inline constexpr detail::ShapePattern<::xla::Shape, + detail::ShapePatternBaseImpl> +Shape(::xla::Shape** matched_shape) { + return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>( + detail::ShapePatternBaseImpl(), matched_shape); +} + +namespace detail { + +template +class HloInstructionPattern; + +// The base HloInstructionPattern implementation. Matches only if the +// instruction is not nullptr. +class HloInstructionPatternBaseImpl { + public: + bool Match(const ::xla::HloInstruction* inst) const { + return inst != nullptr; + } +}; + +// An HloInstructionPattern implementation that matches only if the instruction +// has a given name. +template +class HloInstructionPatternNameImpl { + public: + explicit HloInstructionPatternNameImpl(const Previous& previous, + tensorflow::StringPiece name) + : previous_(previous), name_(name) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && inst->name() == name_; + } + + private: + Previous previous_; + tensorflow::StringPiece name_; +}; + +// An HloInstructionPattern implementation that matches only if the instruction +// has a given opcode. +template +class HloInstructionPatternOpcodeImpl { + public: + explicit constexpr HloInstructionPatternOpcodeImpl(const Previous& previous, + HloOpcode opcode, + bool invert) + : previous_(previous), opcode_(opcode), invert_(invert) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && (invert_ ^ (inst->opcode() == opcode_)); + } + + private: + Previous previous_; + HloOpcode opcode_; + bool invert_; +}; + +// An HloInstructionPattern implementation that matches only if the instruction +// has a shape that matches a given pattern. +template +class HloInstructionPatternShapeImpl { + public: + explicit constexpr HloInstructionPatternShapeImpl( + const Previous& previous, const ShapePattern& shape) + : previous_(previous), shape_(shape) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && shape_.Match(&inst->shape()); + } + + bool Match(::xla::HloInstruction* inst) const { + return previous_.Match(inst) && shape_.Match(inst->mutable_shape()); + } + + private: + Previous previous_; + ShapePattern shape_; +}; + +// An HloInstructionPattern implementation that matches only if the instruction +// has an operand that matches a given pattern. +template +class HloInstructionPatternOperandImpl { + public: + explicit constexpr HloInstructionPatternOperandImpl( + const Previous& previous, int64 operand_index, + const HloInstructionPattern& operand) + : previous_(previous), operand_index_(operand_index), operand_(operand) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && operand_index_ < inst->operand_count() && + operand_.Match(inst->operand(operand_index_)); + } + + bool Match(::xla::HloInstruction* inst) const { + return previous_.Match(inst) && operand_index_ < inst->operand_count() && + operand_.Match(inst->mutable_operand(operand_index_)); + } + + private: + Previous previous_; + int64 operand_index_; + HloInstructionPattern operand_; +}; + +// An HloInstructionPattern implementation that matches only if the instruction +// is a fusion node with a particular kind. +template +class HloInstructionPatternFusionKindImpl { + public: + explicit constexpr HloInstructionPatternFusionKindImpl( + const Previous& previous, ::xla::HloInstruction::FusionKind kind) + : previous_(previous), kind_(kind) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && + inst->fusion_kind() == kind_; + } + + bool Match(::xla::HloInstruction* inst) const { + return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && + inst->fusion_kind() == kind_; + } + + private: + Previous previous_; + ::xla::HloInstruction::FusionKind kind_; +}; + +// A pattern that matches HloInstructions. +template +class HloInstructionPattern { + public: + explicit constexpr HloInstructionPattern(const Impl& impl, + HloInstructionType** matched_inst) + : impl_(impl), matched_inst_(matched_inst) {} + + // Returns true and captures the instruction iff it matches the pattern. + bool Match(const ::xla::HloInstruction* inst) const { + if (impl_.Match(inst)) { + if (matched_inst_) { + *matched_inst_ = inst; + } + return true; + } + return false; + } + + // Returns true and captures the instruction iff it matches the pattern. + bool Match(::xla::HloInstruction* inst) const { + if (impl_.Match(inst)) { + if (matched_inst_) { + *matched_inst_ = inst; + } + return true; + } + return false; + } + + // Modifies the pattern to match only if the instruction has the given name. + HloInstructionPattern> + WithName(tensorflow::StringPiece name) const { + return HloInstructionPattern>( + HloInstructionPatternNameImpl(impl_, name), matched_inst_); + } + + // Modifies the pattern to match only if the instruction has the given opcode. + constexpr HloInstructionPattern> + WithOpcode(HloOpcode opcode) const { + return HloInstructionPattern>( + HloInstructionPatternOpcodeImpl(impl_, opcode, false), + matched_inst_); + } + + // Modifies the pattern to match only if the instruction does not have the + // given opcode. + constexpr HloInstructionPattern> + WithoutOpcode(HloOpcode opcode) const { + return HloInstructionPattern>( + HloInstructionPatternOpcodeImpl(impl_, opcode, true), + matched_inst_); + } + + // Modifies the pattern to match only if the instruction is a constant. + constexpr HloInstructionPattern> + IsConstant() const { + return WithOpcode(HloOpcode::kConstant); + } + + // Modifies the pattern to match only if the instruction is not a constant. + constexpr HloInstructionPattern> + IsNonConstant() const { + return WithoutOpcode(HloOpcode::kConstant); + } + + // Modifies the pattern to match only if the instruction has a shape that + // matches the given pattern. + template + constexpr HloInstructionPattern< + HloInstructionType, + HloInstructionPatternShapeImpl> + WithShape(const ShapePattern& shape) const { + return HloInstructionPattern< + HloInstructionType, + HloInstructionPatternShapeImpl>( + HloInstructionPatternShapeImpl(impl_, + shape), + matched_inst_); + } + + // Modifies the pattern to match only if the instruction has an operand that + // matches the given pattern. + template + constexpr HloInstructionPattern< + HloInstructionType, + HloInstructionPatternOperandImpl> + WithOperand( + int64 operand_index, + const HloInstructionPattern& operand) const { + return HloInstructionPattern< + HloInstructionType, + HloInstructionPatternOperandImpl>( + HloInstructionPatternOperandImpl( + impl_, operand_index, operand), + matched_inst_); + } + + // Modifies the pattern to match only if the instruction is a fusion node with + // the given kind. + constexpr HloInstructionPattern> + WithFusionKind(HloInstruction::FusionKind kind) const { + return HloInstructionPattern>( + HloInstructionPatternFusionKindImpl(impl_, kind), matched_inst_); + } + + private: + Impl impl_; + HloInstructionType** matched_inst_; +}; + +} // namespace detail + +// Creates an instruction pattern that will capture the matched instruction in +// the argument. +inline constexpr detail::HloInstructionPattern< + const ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl> +Op(const ::xla::HloInstruction** matched_inst = nullptr) { + return detail::HloInstructionPattern( + detail::HloInstructionPatternBaseImpl(), matched_inst); +} + +// Creates an instruction pattern that will capture the matched instruction in +// the argument. +inline constexpr detail::HloInstructionPattern< + ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl> +Op(::xla::HloInstruction** matched_inst) { + return detail::HloInstructionPattern<::xla::HloInstruction, + detail::HloInstructionPatternBaseImpl>( + detail::HloInstructionPatternBaseImpl(), matched_inst); +} + +// Helpers for nullary instructions. +#define XLA_NULLOP_PATTERN(NAME) \ + inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ + return Op().WithOpcode(HloOpcode::k##NAME); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst) \ + ->decltype(Op(matched_inst).WithOpcode(HloOpcode::k##NAME)) { \ + return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \ + } +XLA_NULLOP_PATTERN(Constant) +XLA_NULLOP_PATTERN(Infeed) +XLA_NULLOP_PATTERN(Parameter) +XLA_NULLOP_PATTERN(Recv) +#undef XLA_NULLOP_PATTERN + +// Helpers for unary instructions. +#define XLA_UNOP_PATTERN(NAME) \ + inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ + return Op().WithOpcode(HloOpcode::k##NAME); \ + } \ + \ + template \ + inline auto NAME(Arg&& arg)->decltype( \ + Op().WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg))) { \ + return Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg)); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg))) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg)); \ + } +XLA_UNOP_PATTERN(Abs) +XLA_UNOP_PATTERN(RoundNearestAfz) +XLA_UNOP_PATTERN(Bitcast) +XLA_UNOP_PATTERN(Broadcast) +XLA_UNOP_PATTERN(Ceil) +XLA_UNOP_PATTERN(Copy) +XLA_UNOP_PATTERN(Cos) +XLA_UNOP_PATTERN(Exp) +XLA_UNOP_PATTERN(Fft) +XLA_UNOP_PATTERN(Floor) +XLA_UNOP_PATTERN(Imag) +XLA_UNOP_PATTERN(IsFinite) +XLA_UNOP_PATTERN(Log) +XLA_UNOP_PATTERN(Not) +XLA_UNOP_PATTERN(Negate) +XLA_UNOP_PATTERN(Outfeed) +XLA_UNOP_PATTERN(Real) +XLA_UNOP_PATTERN(Reduce) +XLA_UNOP_PATTERN(ReducePrecision) +XLA_UNOP_PATTERN(Reshape) +XLA_UNOP_PATTERN(Reverse) +XLA_UNOP_PATTERN(Send) +XLA_UNOP_PATTERN(Sign) +XLA_UNOP_PATTERN(Sin) +XLA_UNOP_PATTERN(Sort) +XLA_UNOP_PATTERN(Tanh) +XLA_UNOP_PATTERN(Transpose) +#undef XLA_UNOP_PATTERN + +// Helpers for binary instructions. +#define XLA_BINOP_PATTERN(NAME) \ + inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ + return Op().WithOpcode(HloOpcode::k##NAME); \ + } \ + \ + template \ + inline auto NAME(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs))) { \ + return Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs))) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)); \ + } +XLA_BINOP_PATTERN(Add) +XLA_BINOP_PATTERN(Atan2) +XLA_BINOP_PATTERN(Divide) +XLA_BINOP_PATTERN(Complex) +XLA_BINOP_PATTERN(Dot) +XLA_BINOP_PATTERN(Eq) +XLA_BINOP_PATTERN(Gather) +XLA_BINOP_PATTERN(Ge) +XLA_BINOP_PATTERN(Gt) +XLA_BINOP_PATTERN(Le) +XLA_BINOP_PATTERN(Lt) +XLA_BINOP_PATTERN(Maximum) +XLA_BINOP_PATTERN(Minimum) +XLA_BINOP_PATTERN(Multiply) +XLA_BINOP_PATTERN(Ne) +XLA_BINOP_PATTERN(Power) +XLA_BINOP_PATTERN(Remainder) +XLA_BINOP_PATTERN(Subtract) +XLA_BINOP_PATTERN(And) +XLA_BINOP_PATTERN(Or) +XLA_BINOP_PATTERN(ShiftLeft) +XLA_BINOP_PATTERN(ShiftRightArithmetic) +XLA_BINOP_PATTERN(ShiftRightLogical) +#undef XLA_BINOP_PATTERN + +// Helpers for ternary instructions. +#define XLA_TERNOP_PATTERN(NAME) \ + inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ + return Op().WithOpcode(HloOpcode::k##NAME); \ + } \ + \ + template \ + inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) \ + ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg0)) \ + .WithOperand(1, std::forward(arg1)) \ + .WithOperand(2, std::forward(arg2))) { \ + return Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg0)) \ + .WithOperand(1, std::forward(arg1)) \ + .WithOperand(2, std::forward(arg2)); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0, \ + Arg1&& arg1, Arg2&& arg2) \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg0)) \ + .WithOperand(1, std::forward(arg1)) \ + .WithOperand(2, std::forward(arg2))) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg0)) \ + .WithOperand(1, std::forward(arg1)) \ + .WithOperand(2, std::forward(arg2)); \ + } +XLA_TERNOP_PATTERN(Clamp); +XLA_TERNOP_PATTERN(Select); +#undef XLA_TERNOP_PATTERN + +// Helpers for matching non-constant instructions. +inline auto NonConstant() -> decltype(Op().IsNonConstant()) { + return Op().IsNonConstant(); +} + +template +inline auto NonConstant(HloInstructionType** matched_inst) + -> decltype(Op(matched_inst).IsNonConstant()) { + return Op(matched_inst).IsNonConstant(); +} + +} // namespace match + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fef3c132b0f3467a01b02f2be88b419459179277 --- /dev/null +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -0,0 +1,197 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +TEST(PatternMatcherTest, AddOp) { + constexpr char kModuleStr[] = R"(HloModule two_plus_two_module + ENTRY %two_plus_two_computation () -> f32[] { + %two = f32[] constant(2) + ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + + const HloInstruction* matched_inst; + HloInstruction* matched_operand; + Shape* matched_shape; + Layout* matched_layout; + + ASSERT_TRUE(Match( + hlo_module->entry_computation()->root_instruction(), + match::Op(&matched_inst) + .WithName("two_plus_two") + .WithOpcode(HloOpcode::kAdd) + .WithShape( + match::Shape(&matched_shape) + .WithLayout(match::Layout(&matched_layout).WithDenseFormat())) + .WithOperand( + 0, + match::Op(&matched_operand).WithOpcode(HloOpcode::kConstant)))); + ASSERT_NE(matched_inst, nullptr); + EXPECT_EQ(matched_inst->name(), "two_plus_two"); + EXPECT_EQ(matched_inst->opcode(), HloOpcode::kAdd); + + EXPECT_TRUE(Match(hlo_module->entry_computation()->root_instruction(), + match::Add(match::Constant(), match::Constant()))); + + EXPECT_FALSE(Match(hlo_module->entry_computation()->root_instruction(), + match::Op().WithName("bad_name"))); + matched_inst = nullptr; + EXPECT_FALSE(Match(hlo_module->entry_computation()->root_instruction(), + match::Multiply(&matched_inst, match::Op(), match::Op()))); +} + +TEST(PatternMatcherTest, ScalarShape) { + auto scalar_shape = ShapeUtil::MakeShape(F32, {}); + Shape* matched_shape; + EXPECT_TRUE(Match(&scalar_shape, match::Shape(&matched_shape).IsScalar())); + EXPECT_EQ(matched_shape, &scalar_shape); + EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsArray())); + EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsDenseArray())); + EXPECT_FALSE(Match(&scalar_shape, match::Shape().IsTuple())); + EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithElementType(F32))); + EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithRank(0))); + EXPECT_FALSE(Match( + &scalar_shape, + match::Shape().WithSubshape({0}, match::Shape()).WithElementType(F32))); +} + +TEST(PatternMatcherTest, DenseArrayShape) { + auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4}); + Shape* matched_shape; + EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray())); + EXPECT_EQ(matched_shape, &array_shape); + EXPECT_TRUE(Match(&array_shape, match::Shape().IsDenseArray())); + EXPECT_FALSE(Match(&array_shape, match::Shape().IsSparseArray())); + EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar())); + EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple())); + EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32))); + EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3))); + EXPECT_FALSE( + Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape()))); + Layout* matched_layout; + EXPECT_FALSE(Match(&array_shape, + match::Shape().WithLayout( + match::Layout(&matched_layout).WithSparseFormat()))); + EXPECT_TRUE(Match(&array_shape, + match::Shape().WithLayout( + match::Layout(&matched_layout).WithDenseFormat()))); + EXPECT_EQ(matched_layout, &array_shape.layout()); +} + +TEST(PatternMatcherTest, SparseArrayShape) { + auto array_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {2, 3, 4}, 10); + Shape* matched_shape; + EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray())); + EXPECT_EQ(matched_shape, &array_shape); + EXPECT_FALSE(Match(&array_shape, match::Shape().IsDenseArray())); + EXPECT_TRUE(Match(&array_shape, match::Shape().IsSparseArray())); + EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar())); + EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple())); + EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32))); + EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3))); + EXPECT_FALSE( + Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape()))); + Layout* matched_layout; + EXPECT_FALSE(Match(&array_shape, + match::Shape().WithLayout( + match::Layout(&matched_layout).WithDenseFormat()))); + EXPECT_TRUE(Match(&array_shape, + match::Shape().WithLayout( + match::Layout(&matched_layout).WithSparseFormat()))); + EXPECT_EQ(matched_layout, &array_shape.layout()); +} + +TEST(PatternMatcherTest, TupleShape) { + auto tuple_shape = ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {1, 2, 3}), + ShapeUtil::MakeShape(S32, {4, 5}), + }); + EXPECT_TRUE(Match(&tuple_shape, match::Shape().IsTuple())); + EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsArray())); + EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsScalar())); + + Shape* subshape; + ASSERT_TRUE(Match( + &tuple_shape, + match::Shape().WithSubshape( + {0}, match::Shape(&subshape).WithElementType(F32).WithRank(3)))); + ASSERT_NE(subshape, nullptr); + EXPECT_TRUE( + ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {0}))); + EXPECT_TRUE(Match(&tuple_shape, + match::Shape().WithSubshape( + {0}, match::Shape().EqualTo( + &ShapeUtil::GetSubshape(tuple_shape, {0}))))); + EXPECT_FALSE(Match(&tuple_shape, + match::Shape().WithSubshape( + {0}, match::Shape().EqualTo( + &ShapeUtil::GetSubshape(tuple_shape, {1}))))); + + ASSERT_TRUE(Match( + &tuple_shape, + match::Shape().WithSubshape( + {1}, match::Shape(&subshape).WithElementType(S32).WithRank(2)))); + ASSERT_NE(subshape, nullptr); + EXPECT_TRUE( + ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {1}))); + EXPECT_TRUE(Match(&tuple_shape, + match::Shape().WithSubshape( + {1}, match::Shape().EqualTo( + &ShapeUtil::GetSubshape(tuple_shape, {1}))))); + EXPECT_FALSE(Match(&tuple_shape, + match::Shape().WithSubshape( + {1}, match::Shape().EqualTo( + &ShapeUtil::GetSubshape(tuple_shape, {0}))))); + + EXPECT_FALSE( + Match(&tuple_shape, match::Shape().WithSubshape({2}, match::Shape()))); + EXPECT_FALSE( + Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape()))); +} + +TEST(PatternMatcherTest, FusionKind) { + constexpr char kModuleStr[] = R"( + HloModule test_module + + fused_computation { + ROOT fp0 = f32[] parameter(0) + } + + ENTRY while.v11 { + p0 = f32[] parameter(0) + ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + + auto* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_TRUE(Match( + root, match::Op().WithFusionKind(HloInstruction::FusionKind::kLoop))); + EXPECT_FALSE(Match( + root, match::Op().WithFusionKind(HloInstruction::FusionKind::kInput))); + EXPECT_FALSE(Match(root->operand(0), match::Op().WithFusionKind( + HloInstruction::FusionKind::kLoop))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index aa974ee61a27de9c19e97d8a6eb48f9261ce4bd9..7c63c0acc7764d558b2151190f0fa79fac355cbf 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -29,8 +29,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { using tensorflow::str_util::Lowercase; diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h index 69188820a70707d9c9be10b20fb7de92ad4d9873..571451ba43a81d19b70e4954e45d3447f15dcedc 100644 --- a/tensorflow/compiler/xla/service/platform_util.h +++ b/tensorflow/compiler/xla/service/platform_util.h @@ -34,29 +34,27 @@ class PlatformUtil { // // Note that, even if a platform is present with zero devices, if we *do* have // compilation support for it, it will be returned in this sequence. - static StatusOr> - GetSupportedPlatforms(); + static StatusOr> GetSupportedPlatforms(); // Convenience function which returns the default supported platform for // tests. If exactly one supported platform is present, then this platform is // the default platform. If exactly two platforms are present and one of them // is the interpreter platform, then the other platform is the default // platform. Otherwise returns an error. - static StatusOr GetDefaultPlatform(); + static StatusOr GetDefaultPlatform(); // Convenience function which returns the sole supported platform. If // exactly one supported platform is present, then this platform is the // default platform. Otherwise returns an error. - static StatusOr GetSolePlatform(); + static StatusOr GetSolePlatform(); // Returns the platform according to the given name. Returns error if there is // no such platform. - static StatusOr GetPlatform( - const string& platform_name); + static StatusOr GetPlatform(const string& platform_name); // Returns exactly one platform that does not have given name. Returns error // if there is no such platform, or there are multiple such platforms. - static StatusOr GetPlatformExceptFor( + static StatusOr GetPlatformExceptFor( const string& platform_name); // Returns a vector of StreamExecutors for the given platform. The vector is @@ -64,8 +62,8 @@ class PlatformUtil { // element is nullptr, then the device is present by not supported by XLA. // // If the platform has no visible devices, a not-found error is returned. - static StatusOr> - GetStreamExecutors(perftools::gputools::Platform* platform); + static StatusOr> GetStreamExecutors( + se::Platform* platform); private: TF_DISALLOW_COPY_AND_ASSIGN(PlatformUtil); diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc index e2c07e38271df8b8875b2c9291f18ba41a9e6acd..688cceff0cd10df62a4093f00ad3331ca77652e0 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc @@ -75,7 +75,7 @@ StatusOr ReducePrecisionInsertion::insert_after( return false; } - // Check that we haven't already inserted an equivalant reduce-precision + // Check that we haven't already inserted an equivalent reduce-precision // operation after this instruction. (The zero-user case occurs when this is // the root instruction.) if (instruction->user_count() > 0) { diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index e62bafc50b0e1270702621c9ea7b2ee43e001fe0..49ec38eb62c7b51c7a2d301d882cef032b288036 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -53,8 +53,8 @@ bool IsReshapeOrTranspose(const HloInstruction* instruction) { instruction->opcode() == HloOpcode::kTranspose; } -// Returns true iff `instruction` can change its shape simply by adjusting -// metadata. +// Returns true if `instruction` can change its shape simply by adjusting +// metadata or if `instruction` is a broadcast of a scalar value. bool CanTriviallyChangeShape(const HloInstruction* instruction) { // NOTE: Technically a sequence of reshape(reshape(constant)) is also // trivially reshapable, so we might be tempted to simply recurse if @@ -88,19 +88,31 @@ bool CanTriviallyChangeShape(const HloInstruction* instruction) { instruction->user_count() == 1) { return true; } + + // A broadcase of scalar can trivially change its shape. + if (instruction->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(instruction->operand(0)->shape())) { + return true; + } + return false; } -// Finds the first non-scalar operand of an instruction that is a non-trivial -// reshape or transpose. Returns the operand if it is found or nullptr if not -// found. +// Returns true iff `instruction` is a reshape/transpose instruction for which +// a shape change is nontrivial. +bool IsNontrivialReshape(const HloInstruction* instruction) { + return !ShapeUtil::IsScalar(instruction->shape()) && + IsReshapeOrTranspose(instruction) && + !CanTriviallyChangeShape(instruction->operand(0)); +} + +// Finds the first operand of an instruction that is a non-trivial reshape or +// transpose. Returns such an operand or nullptr if not found. HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand( const HloInstruction* hlo) { for (HloInstruction* operand : hlo->operands()) { - if (!ShapeUtil::IsScalar(operand->shape()) && - IsReshapeOrTranspose(operand) && - !CanTriviallyChangeShape(operand->operand(0))) { - VLOG(5) << "Found first non-scalar and non-trivial reshape operand of " + if (IsNontrivialReshape(operand)) { + VLOG(5) << "Found first non-trivial reshape operand of " << hlo->ToString(HloPrintOptions().set_print_metadata(false)) << ":\n\t" << operand->ToString(HloPrintOptions().set_print_metadata(false)); @@ -110,7 +122,7 @@ HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand( return nullptr; } -// Returns whether `a` and `b` are equivalent for the purposes of this pass. +// Returns whether `a` and `b` are equivalent reshapes/transposes. bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { if (a->opcode() != b->opcode() || !ShapeUtil::SameDimensions(a->shape(), b->shape())) { @@ -127,71 +139,14 @@ bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { } } -// Returns true if all operands of `instruction` can easily change shape. -// Operands can easily change shape if they are all reshapes/transposes to and -// from the same shape. Additionally, operands like constant, rng, and any -// scalar change shape with only an adjustment of metadata. -bool AllOperandsHaveEasyShapeChanges( - const HloInstruction* instruction, - const HloInstruction* first_reshape_operand) { - auto print_no_metadata = HloPrintOptions().set_print_metadata(false); - VLOG(3) << "** Checking whether all operands have easy shape changes: " - << instruction->ToString(print_no_metadata); - // Check whether all operands: - // 0. Have the same dimensions as the output -- if not, it may be - // implicitly broadcast, which can confound the movement's - // correctness. - // - // And one of the following: - // 1. Are reshapes or transposes that have the same input and - // output shapes as all other reshaped or transposed operands. - // or - // 2. Are one of kConstant, kRng, and scalars that can change shape - // trivially, - for (const HloInstruction* operand : instruction->operands()) { - if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { - VLOG(5) << "Operand shape differs from output shape; may be " - "implicitly broadcast, so preventing " - "movement\n\toperand: " - << operand->ToString(print_no_metadata) << "\n\tinstruction: " - << instruction->ToString(print_no_metadata); - return false; - } - - if (AreEquivalentReshapes(first_reshape_operand, operand)) { - VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " - << first_reshape_operand->ToString(print_no_metadata) - << "\n\toperand: " << operand->ToString(print_no_metadata); - continue; - } - - if (CanTriviallyChangeShape(operand)) { - VLOG(5) << "Operand can trivially change shape: " - << operand->ToString(print_no_metadata); - continue; - } - - // TODO(someone): Look into supporting general ops for the operands as - // well. - VLOG(5) << "Operand is neither equalivant to the first Reshape operand" - "nor can trivially change shape: " - << operand->ToString(print_no_metadata); - return false; - } - - VLOG(3) << "All operands have easy shape changes: " - << instruction->ToString(print_no_metadata); - return true; -} - // This function is called once we've decided to sink reshape/transpose operands // across an instruction. It returns an updated `operand` with a shape that // plays nicely with `new_operand_shape`; either it has the same shape (of the // correct type), or it is a scalar that may be implicitly broadcast. -HloInstruction* UpdateOperand(HloComputation* computation, - const HloInstruction* first_reshape_operand, +HloInstruction* UpdateOperand(const HloInstruction* first_reshape_operand, const Shape& new_operand_shape, HloInstruction* operand) { + HloComputation* computation = operand->parent(); const PrimitiveType element_type = operand->shape().element_type(); const Shape new_shape = ShapeUtil::ChangeElementType(new_operand_shape, element_type); @@ -222,36 +177,24 @@ HloInstruction* UpdateOperand(HloComputation* computation, VLOG(5) << "Using existing operand of kReshape or kTranspose"; return operand->mutable_operand(0); } + case HloOpcode::kBroadcast: { + CHECK(ShapeUtil::IsScalar(operand->operand(0)->shape())); + HloInstruction* inst = computation->AddInstruction( + operand->CloneWithNewOperands(new_shape, operand->operands())); + VLOG(5) << "Changing broadcast from " << operand->ToString() << " to " + << inst->ToString(); + return inst; + } + default: LOG(FATAL) << "Unexpected operand opcode during update: " << operand; } } -// Try to sink any reshape or transpose operands of `instruction` across it. We -// do so if `instruction` is elementwise and all operands are either equivalent -// reshapes/transposes or are trivially reshapable. -StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, - HloInstruction* instruction) { - // Only perform sinks for live elementwise instructions with operands. - const bool is_dead = instruction->user_count() == 0 && - instruction != computation->root_instruction(); - if (!instruction->IsElementwise() || instruction->operands().empty() || - is_dead) { - return false; - } - - // Only perform sinks if there are any nontrivial reshape/transpose operands. - const HloInstruction* first_reshape_operand = - FirstNonScalarAndNonTrivialReshapeOperand(instruction); - if (!first_reshape_operand) { - return false; - } - - // Only perform sinks if all operands can easily change shape. - if (!AllOperandsHaveEasyShapeChanges(instruction, first_reshape_operand)) { - return false; - } - +// Actually performs the reshape-move transformation -- that is, sinks the +// reshape or transpose operands of `instruction` across it. +StatusOr PerformSinkReshapeOrTranspose( + HloInstruction* instruction, const HloInstruction* first_reshape_operand) { auto print_no_metadata = HloPrintOptions().set_print_metadata(false); // At this point we've decided to sink reshape/transpose operands. const Shape& new_operand_shape = first_reshape_operand->operand(0)->shape(); @@ -272,8 +215,8 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, } VLOG(3) << "Updating operand #" << i << ": " << operands[i]->ToString(print_no_metadata); - operands[i] = UpdateOperand(computation, first_reshape_operand, - new_operand_shape, operands[i]); + operands[i] = + UpdateOperand(first_reshape_operand, new_operand_shape, operands[i]); } if (HloOpcode::kFusion == instruction->opcode()) { // Here we already know `instruction` is elementwise, and no operand is @@ -285,6 +228,7 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, *shape->mutable_layout() = new_operand_shape.layout(); } } + HloComputation* computation = instruction->parent(); HloInstruction* new_elementwise = computation->AddInstruction(instruction->CloneWithNewOperands( // `instruction` may change the element type, e.g., from @@ -319,6 +263,141 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, return true; } +// Returns true if the instruction is a reshape-move candidate. +// +// An instruction is a reshape-move candidate if the instruction is elementwise, +// has at least one nontrivial reshape/transpose operand, and its operands are +// either trivially reshapable or are equivalent nontrivial reshapes/transposes. +bool IsReshapeMoveCandidate(HloInstruction* instruction) { + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); + VLOG(5) << "** Checking instruction: " + << instruction->ToString(print_no_metadata); + + // Only perform reshape-move for live elementwise instructions with operands. + const bool is_dead = instruction->user_count() == 0 && + instruction != instruction->parent()->root_instruction(); + if (!instruction->IsElementwise() || instruction->operands().empty() || + is_dead) { + return false; + } + + // Check whether all operands: + // 0. Have the same dimensions as the output -- if not, they may be + // implicitly broadcast, which can confound the movement's + // correctness. + // + // And one of the following: + // 1. Are reshapes or transposes that have the same input and + // output shapes as all other reshaped or transposed operands. + // or + // 2. Are one of kConstant, kRng, broadcast of a scalar value, and scalars + // that can change shape trivially. + const HloInstruction* first_reshape_operand = nullptr; + for (const HloInstruction* operand : instruction->operands()) { + if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { + VLOG(5) << "Operand shape differs from output shape; may be " + "implicitly broadcast, so preventing " + "movement\n\toperand: " + << operand->ToString(print_no_metadata) << "\n\tinstruction: " + << instruction->ToString(print_no_metadata); + return false; + } + + if (CanTriviallyChangeShape(operand)) { + VLOG(5) << "Operand can trivially change shape: " + << operand->ToString(print_no_metadata); + continue; + } + + if (!IsNontrivialReshape(operand)) { + VLOG(5) << "Operand can't trivially change shape: " + << operand->ToString(print_no_metadata); + return false; + } + + if (first_reshape_operand == nullptr) { + first_reshape_operand = operand; + VLOG(5) << "First reshape operand " + << operand->ToString(print_no_metadata); + } else if (AreEquivalentReshapes(first_reshape_operand, operand)) { + VLOG(5) + << "Operand is an equivalent reshape of the first reshape operand " + << operand->ToString(print_no_metadata); + } else { + // TODO(someone): Look into supporting general ops for the operands as + // well. + VLOG(5) << "Operand is a reshape but is not equivalent to the first " + "Reshape operand" + << operand->ToString(print_no_metadata); + return false; + } + } + + if (first_reshape_operand) { + VLOG(5) << "All operands have easy shape changes: " + << instruction->ToString(print_no_metadata); + } + + return first_reshape_operand != nullptr; +} + +// Reshape-moves all qualifying instructions in reshape_candidates. Returns +// true if it makes changes. +// +// `reshape_candidates` is a set of HloInstructions with nontrivial reshape +// operands, and a instruction in the set can be reshape-moved iff all the users +// of its nontrivial reshape operands can also be reshaped-moved. +// +// The algorithm here iteratively finds the nontrivial operands with users that +// are outside the set of `reshape_candidates`, and removes their users from +// `reshape_candidates`, until either `reshape_candidates` becomes empty or none +// of the remaining nontrivial operands have users outside `reshape_candidates`. +// In the later case, all the remaining instructions in `reshape_candidates` +// are reshape-moved and the routine returns true. +StatusOr TryReshapeMoveOnCandidates( + HloInstructionSet* reshape_candidates) { + bool removed = true; + while (!reshape_candidates->empty() && removed) { + if (VLOG_IS_ON(5)) { + for (const HloInstruction* instruction : *reshape_candidates) { + VLOG(5) << "candidate " << instruction->ToString(); + } + } + ConstHloInstructionSet nontrivial_operands; + for (const HloInstruction* instruction : *reshape_candidates) { + for (const auto* operand : instruction->operands()) { + if (IsNontrivialReshape(operand)) { + nontrivial_operands.insert(operand); + } + } + } + + removed = false; + for (auto operand : nontrivial_operands) { + if (c_any_of(operand->users(), [&](HloInstruction* user) { + return !reshape_candidates->count(user); + })) { + for (auto* user : operand->users()) { + removed |= reshape_candidates->erase(user) > 0; + } + } + } + } + + if (reshape_candidates->empty()) { + return false; + } + for (HloInstruction* instruction : *reshape_candidates) { + const HloInstruction* first_reshape_operand = + FirstNonScalarAndNonTrivialReshapeOperand(instruction); + TF_ASSIGN_OR_RETURN( + bool did_change, + PerformSinkReshapeOrTranspose(instruction, first_reshape_operand)); + CHECK(did_change); + } + return true; +} + } // namespace StatusOr ReshapeMover::Run(HloModule* module) { @@ -326,11 +405,15 @@ StatusOr ReshapeMover::Run(HloModule* module) { VLOG(2) << "Pre ReshapeMover HLO:"; XLA_VLOG_LINES(2, module->ToString()); for (auto* comp : module->MakeNonfusionComputations()) { - for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool did_change, - TrySinkReshapeOrTranspose(comp, instruction)); - changed |= did_change; + HloInstructionSet reshape_candidates; + for (HloInstruction* instruction : comp->instructions()) { + if (IsReshapeMoveCandidate(instruction)) { + reshape_candidates.insert(instruction); + } } + TF_ASSIGN_OR_RETURN(bool did_change, + TryReshapeMoveOnCandidates(&reshape_candidates)); + changed |= did_change; } VLOG(2) << "Post ReshapeMover HLO:"; XLA_VLOG_LINES(2, module->ToString()); diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index aac8638a54f744f0c230ec6c5ca071c1daf45ab2..13e2d3258e3b92f52320201c382594962c0e3b2b 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -458,57 +458,6 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { EXPECT_EQ(select, computation->root_instruction()); } -// Tree looks like: -// -// param0 [1,128,1] -// | -// reshape [128,1] constant [128,1024] -// \ / -// multiply w/implicit broadcast [128,1024] -// -// The reshape mover would like to sink the reshape below the multiply. -// -// Previously we would attempt to insert a reshape of the constant to [1,128,1] -// (which is unsound, because it has a different number of elements) as -// preparation for sinking the reshape. -// -// To eliminate the unsoundness, we outlaw reshape sinking when one of the -// operands is implicitly broadcast in the elementwise consumer. -// -// TODO(b/37799338) However, it would be possible in this case to do a more -// in-depth analysis to get reshape movement to occur: -// -// 1. Note that the broadcast dimension (logical dimension 1) in the operands -// would map back to logical dimension 2 in the param0 node. -// 2. Match rank of the constant to the param0 node (by prepending a trivial 1 -// dimension). -// 3. Reshape to [128,1024] at the root. -// -// But this is not currently done. -TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) { - HloComputation::Builder builder(TestName()); - auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 128, 1}), "param0")); - auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(F32, {128, 1}), param0)); - Array2D a(128, 1024); - auto literal = Literal::CreateR2FromArray2D(a); - auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(std::move(literal))); - auto multiply = builder.AddInstruction(HloInstruction::CreateBinary( - constant->shape(), HloOpcode::kMultiply, constant, reshape)); - - auto computation = module().AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Constant(), op::Reshape(param0))); - - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Constant(), op::Reshape(param0))); - EXPECT_EQ(multiply, computation->root_instruction()); -} - // Tree looks like this: // // add1 @@ -560,5 +509,95 @@ TEST_F(ReshapeMoverTest, MultiplePasses) { op::Reshape(op::Add(param2, op::Reshape(op::Add(param0, param1))))); } +TEST_F(ReshapeMoverTest, SinkTransposeAcrossBroadcastScalar) { + const string hlo_string = R"( + HloModule TransposeMulInversedTransposeModule + ENTRY TransposeMulInversedTranspose { + src0 = f32[20,8]{1,0} parameter(0) + transpose0 = f32[8,20]{1,0} transpose(src0), dimensions={1,0} + src1 = f32[] parameter(1) + broadcast0 = f32[8,20]{1,0} broadcast(src1), dimensions={} + ROOT multiply0 = f32[8,20]{1,0} multiply(transpose0, broadcast0) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_TRUE(changed); + + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Transpose(op::Multiply())); +} + +TEST_F(ReshapeMoverTest, ReshapeWithUsersOutsideCandidatesNotSink) { + const string hlo_string = R"( + HloModule ReshapeWithUsersOutsideCandidates + ENTRY ReshapeWithMultipleUsers { + param0 = f32[20,8]{1,0} parameter(0) + reshape0 = f32[8,20]{1,0} reshape(param0) + param1 = f32[] parameter(1) + broadcast0 = f32[8,20]{1,0} broadcast(param1), dimensions={} + param2 = f32[20,8]{1,0} parameter(2) + reshape1 = f32[8,20]{1,0} reshape(param2) + param3 = f32[20,8]{1,0} parameter(3) + reshape2 = f32[8,20]{1,0} reshape(param3) + param4 = f32[8,20]{1,0} parameter(4) + add0 = f32[8,20]{1,0} add(reshape0, broadcast0) + add1 = f32[8,20]{1,0} add(reshape0, reshape1) + add2 = f32[8,20]{1,0} add(reshape1, param4) + ROOT tuple = (f32[8,20]{1,0},f32[8,20]{1,0}, + f32[8,20]{1,0}) tuple(add0, add1, add2) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink1) { + const string hlo_string = R"( + HloModule ReshapeNoUsersOutsideCandidates1 + ENTRY ReshapeWithMultipleUsers1 { + param0 = f32[20,8]{1,0} parameter(0) + reshape0 = f32[8,20]{1,0} reshape(param0) + param1 = f32[] parameter(1) + broadcast0 = f32[8,20]{1,0} broadcast(param1), dimensions={} + param2 = f32[20,8]{1,0} parameter(2) + reshape1 = f32[8,20]{1,0} reshape(param2) + param3 = f32[20,8]{1,0} parameter(3) + reshape2 = f32[8,20]{1,0} reshape(param3) + add0 = f32[8,20]{1,0} add(reshape0, broadcast0) + add1 = f32[8,20]{1,0} add(reshape0, reshape1) + add2 = f32[8,20]{1,0} add(reshape1, reshape2) + ROOT tuple = (f32[8,20]{1,0},f32[8,20]{1,0}, + f32[8,20]{1,0}) tuple(add0, add1, add2) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_TRUE(changed); + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Tuple(op::Reshape(), op::Reshape(), op::Reshape())); +} + +TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink2) { + const string hlo_string = R"( + HloModule ReshapeNoUsersOutsideCandidates2 + ENTRY ReshapeWithMultipleUsers2 { + param0 = f32[20,8]{1,0} parameter(0) + reshape0 = f32[8,20]{1,0} reshape(param0) + ROOT add0 = f32[8,20]{1,0} add(reshape0, reshape0) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_TRUE(changed); + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Reshape(op::Add())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index e278eab69088d3031b1d951734b7dcad6f8afc77..d01c35b99231310692f85d0f9fbf4f2c3709d44c 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_layout.h" @@ -54,8 +53,6 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - using ::tensorflow::strings::Printf; using ::tensorflow::strings::StrCat; using ::xla::source_map_util::InvalidParameterArgument; @@ -64,12 +61,11 @@ namespace xla { namespace { -// Records the arguments used to invoke a computation in a SessionModule -// proto. -tensorflow::Status RecordArguments( +// Records the arguments used to invoke a computation in an HloSnapshot proto. +Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, se::StreamExecutor* executor, TransferManager* transfer_manager, - SessionModule* module) { + HloSnapshot* module) { module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN( @@ -77,33 +73,28 @@ tensorflow::Status RecordArguments( transfer_manager->TransferLiteralFromDevice(executor, *argument)); *module->add_arguments() = literal->ToProto(); } - return tensorflow::Status::OK(); + return Status::OK(); } -// Records the result of a computation in a SessionModule proto. -tensorflow::Status RecordResult(const ShapedBuffer& result, - se::StreamExecutor* executor, - TransferManager* transfer_manager, - SessionModule* module) { +// Records the result of a computation in a HloSnapshot proto. +Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor, + TransferManager* transfer_manager, HloSnapshot* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( std::unique_ptr literal, transfer_manager->TransferLiteralFromDevice(executor, result)); *module->mutable_result() = literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace -ServiceOptions& ServiceOptions::set_platform( - perftools::gputools::Platform* platform) { +ServiceOptions& ServiceOptions::set_platform(se::Platform* platform) { platform_ = platform; return *this; } -perftools::gputools::Platform* ServiceOptions::platform() const { - return platform_; -} +se::Platform* ServiceOptions::platform() const { return platform_; } ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) { number_of_replicas_ = number_of_replicas; @@ -123,7 +114,7 @@ int ServiceOptions::intra_op_parallelism_threads() const { } /* static */ StatusOr> Service::NewService( - perftools::gputools::Platform* platform) { + se::Platform* platform) { ServiceOptions default_options; default_options.set_platform(platform); return NewService(default_options); @@ -131,7 +122,7 @@ int ServiceOptions::intra_op_parallelism_threads() const { /* static */ StatusOr> Service::NewService( const ServiceOptions& options) { - perftools::gputools::Platform* platform = options.platform(); + se::Platform* platform = options.platform(); std::unique_ptr execute_backend; if (platform == nullptr) { TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); @@ -176,35 +167,20 @@ Service::Service(const ServiceOptions& options, } } -tensorflow::Status Service::Computation(const ComputationRequest* arg, - ComputationResponse* result) { - if (arg->name().empty()) { - return InvalidArgument("computation request needs a name"); - } - - *result->mutable_computation() = - computation_tracker_.NewComputation(arg->name()); - VLOG(1) << Printf("Created new computation %s on service %p, name %s", - result->computation().ShortDebugString().c_str(), this, - arg->name().c_str()); - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) { +Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) { *result->mutable_channel() = channel_tracker_.NewChannel(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) { +Status Service::Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) { return allocation_tracker_.Unregister(arg->data()); } // Deconstructs a previously-allocated global handle. -tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) { +Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) { TF_ASSIGN_OR_RETURN( std::vector elements, allocation_tracker_.DeconstructTuple(arg->tuple_handle())); @@ -212,11 +188,11 @@ tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, for (auto& element : elements) { *result->add_element_handles() = element; } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ValidateResultShapeWithLayout( - const Shape& shape_with_layout, const Shape& result_shape) const { +Status Service::ValidateResultShapeWithLayout(const Shape& shape_with_layout, + const Shape& result_shape) const { if (!ShapeUtil::Compatible(shape_with_layout, result_shape)) { return InvalidArgument( "Shape used to set computation result layout %s is not compatible " @@ -232,10 +208,13 @@ tensorflow::Status Service::ValidateResultShapeWithLayout( return ShapeUtil::ValidateShape(shape_with_layout); } -StatusOr> Service::ResolveAndValidateArguments( +StatusOr>> +Service::ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, - int device_ordinal) { - std::vector shaped_buffers; + tensorflow::gtl::ArraySlice stream_executors) { + CHECK_EQ(options_.number_of_replicas(), stream_executors.size()); + std::vector> replicated_arguments; + replicated_arguments.resize(options_.number_of_replicas()); for (size_t i = 0; i < arguments.size(); ++i) { auto buffer_status = allocation_tracker_.Resolve(*arguments[i]); if (!buffer_status.ok()) { @@ -243,32 +222,36 @@ StatusOr> Service::ResolveAndValidateArguments( StrCat(buffer_status.status().error_message(), ", ", "failed to resolve allocation for parameter ", i)); } - const ShapedBuffer* shaped_buffer = buffer_status.ValueOrDie(); - - // Verify allocation is same platform and device as the execution. - if (shaped_buffer->platform() != execute_backend_->platform() || - shaped_buffer->device_ordinal() != device_ordinal) { - return InvalidArgument( - "argument %lu is on device %s:%d but computation will be executed " - "on device %s", - i, shaped_buffer->platform()->Name().c_str(), - shaped_buffer->device_ordinal(), - execute_backend_->device_name(device_ordinal).c_str()); + auto replicated_buffers = buffer_status.ValueOrDie(); + CHECK_EQ(options_.number_of_replicas(), replicated_buffers.size()); + for (int replica = 0; replica < options_.number_of_replicas(); ++replica) { + const ShapedBuffer* shaped_buffer = replicated_buffers[replica]; + int replica_device_ordinal = stream_executors[replica]->device_ordinal(); + // Verify allocation is same platform and device as the execution. + if (shaped_buffer->platform() != execute_backend_->platform() || + shaped_buffer->device_ordinal() != replica_device_ordinal) { + return InvalidArgument( + "argument %lu is on device %s:%d but computation will be executed " + "on device %s", + i, shaped_buffer->platform()->Name().c_str(), + shaped_buffer->device_ordinal(), + execute_backend_->device_name(replica_device_ordinal).c_str()); + } + replicated_arguments[replica].push_back(shaped_buffer); } - - shaped_buffers.push_back(shaped_buffer); } - return shaped_buffers; + return replicated_arguments; } StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options, - const UserComputation& user_computation) { + const ExecutionOptions* execution_options) { auto config = MakeUnique(program_shape); - auto* computation_layout = config->mutable_entry_computation_layout(); - + ComputationLayout* host_computation_layout = + config->mutable_host_entry_computation_layout(); + ComputationLayout* device_computation_layout = + config->mutable_device_entry_computation_layout(); if (program_shape.parameters_size() != argument_shapes.size()) { return InvalidArgument("computation takes %d parameters, but %zu given", program_shape.parameters_size(), @@ -279,16 +262,16 @@ StatusOr> Service::CreateModuleConfig( // ProgramShape. if (!ShapeUtil::Compatible(*argument_shapes[i], program_shape.parameters(i))) { - return InvalidParameterArgument( - *user_computation.ParameterMetadata(i).value(), - "Argument does not match shape of computation parameter %d: want %s, " - "got %s", + return InvalidArgument( + "Argument does not match shape of computation parameter %d: want " + "%s, got %s", i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), ShapeUtil::HumanString(*argument_shapes[i]).c_str()); } - TF_RETURN_IF_ERROR( - computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( - *argument_shapes[i])); + TF_RETURN_IF_ERROR(host_computation_layout->mutable_parameter_layout(i) + ->CopyLayoutFromShape(*argument_shapes[i])); + TF_RETURN_IF_ERROR(device_computation_layout->mutable_parameter_layout(i) + ->CopyLayoutFromShape(*argument_shapes[i])); } if (execution_options != nullptr && execution_options->has_shape_with_output_layout()) { @@ -297,18 +280,26 @@ StatusOr> Service::CreateModuleConfig( TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(shape_with_output_layout, program_shape.result())); TF_RETURN_IF_ERROR( - computation_layout->mutable_result_layout()->CopyLayoutFromShape( + host_computation_layout->mutable_result_layout()->CopyLayoutFromShape( + shape_with_output_layout)); + TF_RETURN_IF_ERROR( + device_computation_layout->mutable_result_layout()->CopyLayoutFromShape( shape_with_output_layout)); } else { - computation_layout->mutable_result_layout()->Clear(); + // If the result layout is not set, then choose the default. + // TODO(b/29118294): Allow the compiler to choose a better layout in this + // case. + // TODO(b/78356948): We are forcing the default layout here. We should fix + // clients which expect a default layout, to be explicit about it, by + // passing the proper ExecutionOptions with shape_with_output_layout set. + host_computation_layout->mutable_result_layout()->SetToDefaultLayout(); + device_computation_layout->mutable_result_layout()->SetToDefaultLayout(); } config->set_replica_count(options_.number_of_replicas()); if (execution_options != nullptr) { config->set_seed(execution_options->seed()); config->set_debug_options(execution_options->debug_options()); - config->enable_hlo_profiling( - execution_options->debug_options().xla_hlo_profile()); } else { config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); } @@ -324,62 +315,55 @@ StatusOr> Service::CreateModuleConfig( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options, - const UserComputation& user_computation) { + const ExecutionOptions& execution_options) { std::vector argument_shapes; for (const auto* arg : arguments) { argument_shapes.push_back(&arg->on_host_shape()); } - return CreateModuleConfig(program_shape, argument_shapes, &execution_options, - user_computation); + return CreateModuleConfig(program_shape, argument_shapes, &execution_options); } StatusOr>> Service::BuildExecutables( - std::vector versioned_handles, + const std::vector& module_protos, std::vector> module_configs, - Backend* backend, - std::vector> executors, + Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator) { VLOG(1) << Printf("BuildExecutable on service %p", this); // Dump computation proto state if flag is set. - std::vector> session_modules; - for (int64 i = 0; i < versioned_handles.size(); ++i) { + std::vector> hlo_snapshots; + for (int64 i = 0; i < module_protos.size(); ++i) { const string& directory_path = module_configs[i]->debug_options().xla_dump_computations_to(); - const string& other_directory_path = + const string& execution_directory_path = module_configs[i]->debug_options().xla_dump_executions_to(); - if (directory_path.empty() && other_directory_path.empty()) { + if (directory_path.empty() && execution_directory_path.empty()) { continue; } - TF_ASSIGN_OR_RETURN( - std::unique_ptr session_module, - computation_tracker_.SnapshotComputation(versioned_handles[i].handle)); + auto hlo_snapshot = MakeUnique(); + *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i]; if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s__version_%lld", - versioned_handles[i].handle.handle(), - session_module->entry().name().c_str(), - versioned_handles[i].version); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, - *session_module)); - session_modules.push_back(std::move(session_module)); + string filename = + Printf("computation_%lld__%s", module_protos[i]->id(), + module_protos[i]->entry_computation_name().c_str()); + TF_RETURN_IF_ERROR( + Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); + hlo_snapshots.push_back(std::move(hlo_snapshot)); } } - VLOG(1) << "Computation handles:"; - for (const VersionedComputationHandle& versioned_handle : versioned_handles) { - VLOG(1) << versioned_handle; + VLOG(1) << "Computations:"; + for (const HloModuleProto* proto : module_protos) { + VLOG(1) << proto->name(); } - CHECK_EQ(versioned_handles.size(), module_configs.size()); + CHECK_EQ(module_protos.size(), module_configs.size()); std::vector> modules; - for (int64 i = 0; i < versioned_handles.size(); ++i) { - const VersionedComputationHandle& versioned_handle = versioned_handles[i]; + for (int64 i = 0; i < module_protos.size(); ++i) { + const HloModuleProto* proto = module_protos[i]; const HloModuleConfig& config = *module_configs[i]; TF_ASSIGN_OR_RETURN(auto module, - computation_tracker_.BuildHloModule( - versioned_handle, config, - /*include_unreachable_instructions=*/true)); + HloModule::CreateFromProto(*proto, config)); modules.push_back(std::move(module)); } @@ -388,116 +372,43 @@ StatusOr>> Service::BuildExecutables( backend->compiler()->Compile(std::move(modules), std::move(executors), device_allocator)); - for (size_t i = 0; i < versioned_handles.size(); ++i) { + for (size_t i = 0; i < module_protos.size(); ++i) { if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) { - executables[i]->set_session_module(std::move(session_modules[i])); + executables[i]->set_hlo_snapshot(std::move(hlo_snapshots[i])); } } return std::move(executables); } -StatusOr> Service::BuildExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf("BuildExecutable on service %p with handle %s", this, - versioned_handle.ToString().c_str()); - - // Dump computation proto state if flag is set. - std::unique_ptr session_module; - const string& directory_path = - module_config->debug_options().xla_dump_computations_to(); - const string& other_directory_path = - module_config->debug_options().xla_dump_executions_to(); - if (!directory_path.empty() || !other_directory_path.empty()) { - TF_ASSIGN_OR_RETURN( - session_module, - computation_tracker_.SnapshotComputation(versioned_handle.handle)); - if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s__version_%lld", - versioned_handle.handle.handle(), - session_module->entry().name().c_str(), - versioned_handle.version); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, - *session_module)); - } +Status Service::ValidateEntryComputationLayout(HloModule* module) { + const ComputationLayout& on_device = + module->device_entry_computation_layout(); + for (int64 i = 0; i < on_device.parameter_count(); ++i) { + TF_RET_CHECK(ShapeUtil::Equal( + on_device.parameter_shape(i), + execute_backend_->transfer_manager()->HostShapeToDeviceShape( + module->host_entry_computation_layout().parameter_shape(i)))); } - - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, *module_config, - /*include_unreachable_instructions=*/ - true)); - - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); - - TF_ASSIGN_OR_RETURN( - module, backend->compiler()->RunHloPasses(std::move(module), executor, - device_allocator)); - - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - backend->compiler()->RunBackend( - std::move(module), executor, device_allocator)); - - if (!other_directory_path.empty()) { - executable->set_session_module(std::move(session_module)); - } - - return std::move(executable); -} - -StatusOr> Service::BuildAndCacheExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile, - DeviceMemoryAllocator* device_allocator) { - std::shared_ptr executable = - compilation_cache_.LookUp(versioned_handle, *module_config); - - if (executable != nullptr) { - // Executable found in the computation cache. - if (profile != nullptr) { - profile->set_compilation_cache_hit(true); - } - return executable; - } - - uint64 start_micros = - // Avoid reading the clock if we don't want timing info - (profile != nullptr) ? tensorflow::Env::Default()->NowMicros() : 0; - - // Take a copy of the module config, as compilation introduces layouts where - // layouts were optional before. - HloModuleConfig original_module_config = *module_config; - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable_unique_ptr, - BuildExecutable(versioned_handle, std::move(module_config), backend, - executor, device_allocator)); - - if (profile != nullptr) { - uint64 end_micros = tensorflow::Env::Default()->NowMicros(); - uint64 milliseconds = (end_micros - start_micros) / 1000; - profile->set_compilation_cache_hit(false); - profile->set_compile_time_ms(milliseconds); - } - - // Insert executable into the cache. - return compilation_cache_.Insert(std::move(executable_unique_ptr), - original_module_config); + TF_RET_CHECK(ShapeUtil::Equal( + module->device_entry_computation_layout().result_shape(), + execute_backend_->transfer_manager()->HostShapeToDeviceShape( + module->host_entry_computation_layout().result_shape()))); + return Status::OK(); } StatusOr> Service::ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice> arguments, + tensorflow::gtl::ArraySlice>> + arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags, ExecutionProfile* profile) { // Streams where the computation are launched, so we can wait on the streams // to complete. std::vector::SmartPtr> streams; - std::vector> timers; + std::vector> timers; // Global data handles for the computation results, one for each computation. std::vector result_handles; @@ -506,21 +417,29 @@ Service::ExecuteParallelAndRegisterResult( // profiled. std::map index_to_profiled_streams; - TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, - backend->computation_placer()->AssignDevices( - options_.number_of_replicas(), executables.size())); + // Build DeviceAssignment for all cores based on the provided device handles. + DeviceAssignment device_assignment(options_.number_of_replicas(), + executables.size()); + for (int64 i = 0; i < executables.size(); i++) { + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); + CHECK_EQ(replicas.size(), arguments[i].size()); + for (int64 replica = 0; replica < replicas.size(); ++replica) { + device_assignment(replica, i) = replicas[replica]->device_ordinal(); + } + } for (int64 i = 0; i < executables.size(); i++) { // Stream executors for the replicas of the current computation. TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); + CHECK_EQ(replicas.size(), arguments[i].size()); + std::vector result_buffers; for (int64 replica = 0; replica < replicas.size(); ++replica) { TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, backend->BorrowStream(replicas[replica])); streams.push_back(std::move(stream)); if (replica == 0 && profile != nullptr) { - timers.emplace_back( - new perftools::gputools::Timer(streams.back()->parent())); + timers.emplace_back(new se::Timer(streams.back()->parent())); streams.back() ->InitTimer(timers.back().get()) .ThenStartTimer(timers.back().get()); @@ -537,7 +456,6 @@ Service::ExecuteParallelAndRegisterResult( ExecutableRunOptions options; options.set_stream(streams.back().get()); options.set_allocator(backend->memory_allocator()); - options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); options.set_device_assignment(&device_assignment); @@ -545,23 +463,20 @@ Service::ExecuteParallelAndRegisterResult( backend->StreamBorrower()); // Asynchronously launch the computation. - TF_ASSIGN_OR_RETURN( - std::unique_ptr result, - executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i])); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + executables[i]->ExecuteAsyncOnStream( + &run_options, arguments[i][replica])); if (replica == 0 && profile != nullptr) { streams.back()->ThenStopTimer(timers.back().get()); } - // All replicas share the same device address for the result allocation, - // so only one of the replicas need to register the result handle. - if (replica == 0) { - TF_ASSIGN_OR_RETURN( - GlobalDataHandle handle, - allocation_tracker_.Register(std::move(result), result_tags[i])); - result_handles.push_back(handle); - } + result_buffers.emplace_back(std::move(result)); } + TF_ASSIGN_OR_RETURN(GlobalDataHandle handle, + allocation_tracker_.RegisterReplicatedBuffers( + std::move(result_buffers), result_tags[i])); + result_handles.push_back(handle); } // Wait for all executions to complete. @@ -627,9 +542,9 @@ Service::ExecuteParallelAndRegisterResult( StatusOr Service::ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice arguments, - Backend* backend, perftools::gputools::StreamExecutor* executor, - const string& result_tag, ExecutionProfile* profile) { + const tensorflow::gtl::ArraySlice> + arguments, + Backend* backend, const string& result_tag, ExecutionProfile* profile) { // Set up streams. std::vector::SmartPtr> streams; @@ -654,52 +569,91 @@ StatusOr Service::ExecuteAndRegisterResult( options.set_stream(stream.get()); options.set_device_ordinal(stream->parent()->device_ordinal()); options.set_allocator(backend->memory_allocator()); - options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); options.set_device_assignment(&device_assignment); - run_options.emplace_back(options, backend->StreamBorrower(), - backend->inter_op_thread_pool()); + run_options.emplace_back( + options, backend->StreamBorrower(), + /*xla_intra_op_thread_pool=*/backend->eigen_intra_op_thread_pool()); } - std::unique_ptr result; if (options_.number_of_replicas() == 1) { - TF_ASSIGN_OR_RETURN(result, executable->ExecuteOnStreamWrapper( - &run_options[0], profile, arguments)); - } else { - // TODO(b/69985541): Support profiling also on this path. - std::vector> - repeated_arguments(options_.number_of_replicas(), arguments); - - TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( - run_options, repeated_arguments)); - TF_RET_CHECK(!results.empty()); - result = std::move(results[0]); + TF_ASSIGN_OR_RETURN( + auto result, executable->ExecuteOnStreamWrapper(&run_options[0], + profile, arguments[0])); + return allocation_tracker_.Register(std::move(result), result_tag); + } + + // TODO(b/69985541): Support profiling also on this path. + + std::vector> + replicated_arguments; + for (const auto& arg : arguments) { + replicated_arguments.emplace_back(arg); } - return allocation_tracker_.Register(std::move(result), result_tag); + + TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( + run_options, replicated_arguments)); + TF_RET_CHECK(!results.empty()); + return allocation_tracker_.RegisterReplicatedBuffers(std::move(results), + result_tag); } -tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - return computation->SetReturnValue(arg->operand()); +StatusOr> Service::GetExecutors( + const ExecutionOptions& execution_options, int64 requests_size, + int64 request_index) const { + if (execution_options.device_handles().empty()) { + return FailedPrecondition( + "device handles must be given to execute parallel computations"); + } + if (requests_size > 1 && execution_options.device_handles_size() > 1) { + return InvalidArgument( + "Parallel requests with multiple device handles is not supported. " + "Found %lld parallel requests, with request %lld containing %d device " + "handles.", + requests_size, request_index, execution_options.device_handles_size()); + } + std::vector executors; + for (const auto& device_handle : execution_options.device_handles()) { + TF_ASSIGN_OR_RETURN(auto replicas, + Replicas(*execute_backend_, device_handle)); + se::StreamExecutor* executor = replicas[0]; + CHECK(executor != nullptr); + executors.push_back(executor); + } + return executors; +} + +StatusOr>> Service::GetArguments( + const ExecutionOptions& execution_options, + tensorflow::gtl::ArraySlice arguments) { + // Resolve the allocations for the arguments of the computation, and create + // a vector of device memory offsets for the arguments from the allocations. + // In the case of partitioned computations, assume all arguments go on the + // zeroth core. + TF_ASSIGN_OR_RETURN( + auto replicas, + Replicas(*execute_backend_, execution_options.device_handles(0))); + TF_ASSIGN_OR_RETURN( + std::vector> replicated_arguments, + ResolveAndValidateArguments(arguments, replicas)); + return replicated_arguments; } -tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) { - VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); +Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) { + VLOG(1) << "running execute-graph-parallel request"; - std::vector> all_arguments; - std::vector> all_executors; - std::vector versioned_handles; + std::vector>> all_arguments; + std::vector> all_executors; + std::vector module_protos; std::vector> module_configs; std::vector computation_names; std::vector device_handles; int num_requested_devices = std::accumulate(arg->requests().begin(), arg->requests().end(), 0, - [](int a, const ExecuteRequest& r) -> int { + [](int a, const ExecuteGraphRequest& r) -> int { return a + r.execution_options().device_handles_size(); }); if (num_requested_devices * options_.number_of_replicas() > @@ -714,75 +668,53 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // is one of the executors to run the replicated computation. const ExecutionOptions& execution_options = arg->requests(i).execution_options(); - if (execution_options.device_handles().empty()) { - return FailedPrecondition( - "device handles must be given to execute parallel computations"); - } - std::vector executors; - for (const auto& device_handle : execution_options.device_handles()) { - TF_ASSIGN_OR_RETURN(auto replicas, - Replicas(*execute_backend_, device_handle)); - se::StreamExecutor* executor = replicas[0]; - CHECK(executor != nullptr); - executors.push_back(executor); - } + const ExecuteGraphRequest& request = arg->requests(i); + TF_RET_CHECK(request.has_computation()) << "computations may not be empty"; + TF_RET_CHECK(request.computation().has_program_shape()) + << "programe shape may not be empty"; - // Resolve the UserComputation object associated with the requested - // computation and compute the program shape. - const ExecuteRequest& request = arg->requests(i); - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(request.computation())); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } + // Get the executors. + TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options, + arg->requests_size(), i)); - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - // Resolve the allocations for the arguments of the computation, and create - // a vector of device memory offsets for the arguments from the allocations. - // In the case of partitioned computations, assume all arguments go on the - // zeroth core. - TF_ASSIGN_OR_RETURN( - std::vector arguments, - ResolveAndValidateArguments(request.arguments(), - executors[0]->device_ordinal())); + // Get the replicated arguments. + TF_ASSIGN_OR_RETURN(auto replicated_arguments, + GetArguments(execution_options, request.arguments())); // Create an HloModuleConfig object for the computation, given the shape of - // the program and the argument allocations. + // the program and the argument allocations. Here, we care only about the + // shapes of the arguments, so, it is sufficient to use the arguments of + // replica 0. TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arguments, - request.execution_options(), *user_computation)); - VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " - << module_config->entry_computation_layout().ToString(); + CreateModuleConfig(request.computation().program_shape(), + replicated_arguments.front(), + request.execution_options())); + VLOG(3) + << "ExecuteGraphParallel created HloModuleConfig computation layout: " + << module_config->host_entry_computation_layout().ToString(); // Adds to the vectors to build and execute the computations after the loop. - all_arguments.push_back(arguments); - all_arguments.insert(all_arguments.end(), executors.size() - 1, {}); - versioned_handles.push_back(versioned_handle); + all_arguments.push_back(replicated_arguments); + all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}}); + module_protos.push_back(&request.computation()); module_configs.push_back(std::move(module_config)); computation_names.insert(computation_names.end(), executors.size(), - user_computation->name()); + request.computation().name()); all_executors.push_back(executors); device_handles.insert(device_handles.end(), execution_options.device_handles().begin(), execution_options.device_handles().end()); } - // Build the user computations into HloModules and compile to generate the - // executables. + // Build the HloModules and compile to generate the executables. // // TODO(jlebar): There's currently no way to pass a device allocator to - // ExecuteParallel, so we have to pass a null device_allocator below. - TF_ASSIGN_OR_RETURN( - std::vector> executables, - BuildExecutables(versioned_handles, std::move(module_configs), - execute_backend_.get(), all_executors, - /*device_allocator=*/nullptr)); + // ExecuteGraphParallel, so we have to pass a null device_allocator below. + TF_ASSIGN_OR_RETURN(std::vector> executables, + BuildExecutables(module_protos, std::move(module_configs), + execute_backend_.get(), all_executors, + /*device_allocator=*/nullptr)); std::vector executable_ptrs; executable_ptrs.reserve(executables.size()); for (const auto& executable : executables) { @@ -804,12 +736,12 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, *result->add_responses() = response; } - VLOG(1) << "successfully completed 'execute-parallel' request"; - return tensorflow::Status::OK(); + VLOG(1) << "successfully completed 'execute-graph-parallel' request"; + return Status::OK(); } -tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) { +Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) { const int64 available_device_count = execute_backend_->device_count(); const int64 replica_count = options_.number_of_replicas(); if (replica_count <= 0) { @@ -829,174 +761,148 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, *result->add_device_handles() = device_handle; } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::Execute(const ExecuteRequest* arg, - ExecuteResponse* result) { - VLOG(1) << "running execute request: " << arg->ShortDebugString(); - - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } +Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { + ExecuteGraphParallelRequest parallel_arg; + *parallel_arg.add_requests() = *arg; + ExecuteParallelResponse parallel_result; + TF_RETURN_IF_ERROR(ExecuteGraphParallel(¶llel_arg, ¶llel_result)); + return PickParallelResponse(parallel_result, result); +} - // If we received multiple device handles, we must partition the module. - if (arg->execution_options().device_handles_size() > 1) { - ExecuteParallelRequest parallel_arg; - *parallel_arg.add_requests() = *arg; - ExecuteParallelResponse parallel_result; - TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); - TF_RET_CHECK(parallel_result.responses_size() > 0); - *result = parallel_result.responses(0); - return Status::OK(); +Status Service::PickParallelResponse( + const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) { + // The "result device" selection is a bit hacky, but better than assuming it + // is device 0. We have b/76035356 for restructuring the client API to clean + // up the current asymmetries and support more functionalities. + for (int64 i = 0; i < parallel_result.responses_size(); ++i) { + TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, + allocation_tracker_.ResolveForReplica( + parallel_result.responses(i).output(), 0)); + const Shape& shape = buffer->on_host_shape(); + if (!ShapeUtil::IsEmptyTuple(shape)) { + *result = parallel_result.responses(i); + VLOG(3) << "Fetching result from device " << i << ": " + << ShapeUtil::HumanString(shape); + return Status::OK(); + } } + TF_RET_CHECK(parallel_result.responses_size() > 0); + *result = parallel_result.responses(0); + VLOG(1) << "Defaulting to device 0 result"; + return Status::OK(); +} - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - TF_ASSIGN_OR_RETURN( - std::vector arguments, - ResolveAndValidateArguments(arg->arguments(), - execute_backend_->default_device_ordinal())); +StatusOr> Service::BuildExecutable( + const HloModuleProto& module_proto, + std::unique_ptr module_config, Backend* backend, + se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { + VLOG(1) << Printf( + "BuildExecutable on service %p with serialized module proto: %s", this, + module_proto.name().c_str()); - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arguments, arg->execution_options(), - *user_computation)); + // Dump computation proto state if flag is set. + auto hlo_snapshot = MakeUnique(); + const string& directory_path = + module_config->debug_options().xla_dump_computations_to(); + const string& execution_directory_path = + module_config->debug_options().xla_dump_executions_to(); + if (!directory_path.empty() || !execution_directory_path.empty()) { + *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto; + if (!directory_path.empty()) { + string filename = Printf("computation_%lld__%s", module_proto.id(), + module_proto.entry_computation_name().c_str()); + TF_RETURN_IF_ERROR( + Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); + } + } - VLOG(3) << "Execute created HloModuleConfig computation layout: " - << module_config->entry_computation_layout().ToString(); + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(module_proto, *module_config)); - TF_ASSIGN_OR_RETURN( - std::shared_ptr executable, - BuildAndCacheExecutable(versioned_handle, std::move(module_config), - execute_backend_.get(), - execute_backend_->default_stream_executor(), - result->mutable_profile())); - - if (executable->dumping()) { - executable->session_module()->set_execution_platform( - execute_backend_->platform()->Name()); - TF_RETURN_IF_ERROR(RecordArguments( - arguments, execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->session_module())); - } + TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); TF_ASSIGN_OR_RETURN( - *result->mutable_output(), - ExecuteAndRegisterResult( - executable.get(), arguments, execute_backend_.get(), - execute_backend_->default_stream_executor(), - "result of " + user_computation->name(), result->mutable_profile())); + module, backend->compiler()->RunHloPasses(std::move(module), executor, + device_allocator)); + // Check that on-host and on-device shapes are consistent. + TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get())); - if (executable->dumping()) { - TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, - allocation_tracker_.Resolve(result->output())); - TF_RETURN_IF_ERROR(RecordResult( - *result_buffer, execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->session_module())); - TF_RETURN_IF_ERROR(executable->DumpSessionModule()); - } + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + backend->compiler()->RunBackend( + std::move(module), executor, device_allocator)); - VLOG(1) << "successfully completed 'execute' request"; - return tensorflow::Status::OK(); + return std::move(executable); } -tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) { - VLOG(1) << "running execute-async request: " << arg->ShortDebugString(); - - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); +Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { + VLOG(1) << "running execute-graph request"; - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - if (user_computation->request_count(versioned_handle.version) == 0) { + if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } + if (!arg->computation().has_program_shape()) { + return InvalidArgument("programe shape may not be empty"); + } - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - TF_ASSIGN_OR_RETURN( - std::vector arguments, - ResolveAndValidateArguments(arg->arguments(), - execute_backend_->default_device_ordinal())); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arguments, arg->execution_options(), - *user_computation)); - - VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " - << module_config->entry_computation_layout().ToString(); - - ExecutionProfile profile; - - TF_ASSIGN_OR_RETURN( - std::shared_ptr executable, - BuildAndCacheExecutable( - versioned_handle, std::move(module_config), execute_backend_.get(), - execute_backend_->default_stream_executor(), &profile)); + // If we received multiple device handles, we must partition the module. + if (arg->execution_options().device_handles_size() > 1) { + return ExecuteOneToN(arg, result); + } TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); - TF_RET_CHECK(!replicas.empty()); + TF_ASSIGN_OR_RETURN( + std::vector> replicated_arguments, + ResolveAndValidateArguments(arg->arguments(), replicas)); - // Set up streams. - std::vector::SmartPtr> streams; + TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, + CreateModuleConfig(arg->computation().program_shape(), + replicated_arguments.front(), + arg->execution_options())); - for (se::StreamExecutor* executor : replicas) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, - execute_backend_->BorrowStream(executor)); - streams.push_back(std::move(stream)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + BuildExecutable(arg->computation(), std::move(module_config), + execute_backend_.get(), + execute_backend_->default_stream_executor(), + /*device_allocator=*/nullptr)); + + if (executable->dumping_snapshot()) { + executable->hlo_snapshot()->set_execution_platform( + execute_backend_->platform()->Name()); + TF_RETURN_IF_ERROR(RecordArguments( + replicated_arguments.front(), + execute_backend_->default_stream_executor(), + execute_backend_->transfer_manager(), executable->hlo_snapshot())); } - std::unique_ptr result_buffer; - for (const Pool::SmartPtr& stream : streams) { - ExecutableRunOptions options; - options.set_stream(stream.get()); - options.set_allocator(execute_backend_->memory_allocator()); - options.set_inter_op_thread_pool(execute_backend_->inter_op_thread_pool()); - options.set_intra_op_thread_pool( - execute_backend_->eigen_intra_op_thread_pool_device()); - - ServiceExecutableRunOptions service_options( - options, execute_backend_->StreamBorrower()); + TF_ASSIGN_OR_RETURN( + *result->mutable_output(), + ExecuteAndRegisterResult( + executable.get(), replicated_arguments, execute_backend_.get(), + "result of " + arg->computation().name(), result->mutable_profile())); + if (executable->dumping_snapshot()) { TF_ASSIGN_OR_RETURN( - std::unique_ptr this_result_buffer, - executable->ExecuteAsyncOnStream(&service_options, arguments)); - - // Take the first result. - if (result_buffer == nullptr) { - result_buffer = std::move(this_result_buffer); - } + const ShapedBuffer* result_buffer, + allocation_tracker_.ResolveForReplica(result->output(), 0)); + TF_RETURN_IF_ERROR(RecordResult( + *result_buffer, execute_backend_->default_stream_executor(), + execute_backend_->transfer_manager(), executable->hlo_snapshot())); + TF_RETURN_IF_ERROR(executable->DumpHloSnapshot()); } - TF_ASSIGN_OR_RETURN( - GlobalDataHandle output, - allocation_tracker_.Register(std::move(result_buffer), - "result of " + user_computation->name())); - - *result->mutable_execution() = execution_tracker_.Register( - execute_backend_.get(), std::move(streams), profile, output); - streams.clear(); - - VLOG(1) << "successfully completed 'execute-async' request"; - return tensorflow::Status::OK(); + VLOG(1) << "successfully completed 'execute-graph' request"; + return Status::OK(); } -tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) { +Status Service::WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) { TF_ASSIGN_OR_RETURN(const auto execution, execution_tracker_.Resolve(arg->execution())); @@ -1007,13 +913,13 @@ tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution())); VLOG(1) << "successfully completed 'wait-for-execution' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, - TransferToClientResponse* result) { +Status Service::TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, - allocation_tracker_.Resolve(arg->data())); + allocation_tracker_.ResolveForReplica(arg->data(), 0)); const Shape* return_shape; if (arg->has_shape_with_layout()) { @@ -1041,7 +947,7 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, *result->mutable_literal() = result_literal->Relayout(*return_shape)->ToProto(); } - return tensorflow::Status::OK(); + return Status::OK(); } namespace { @@ -1059,8 +965,8 @@ std::unique_ptr CloneShapedBufferOnDevice( } // namespace -tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, - TransferToServerResponse* result) { +Status Service::TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, Literal::CreateFromProto(arg->literal())); const Shape& shape = literal->shape(); @@ -1074,43 +980,30 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); } - // All memory allocation is done on the first replica. The allocations in all - // other replicas mirror the firsts'. - int master_device_ordinal = replicas[0]->device_ordinal(); - TF_ASSIGN_OR_RETURN( - std::unique_ptr shaped_buffer, - execute_backend_->transfer_manager()->AllocateShapedBuffer( - shape, execute_backend_->memory_allocator(), master_device_ordinal)); - - // Transfer the data to the replicas. + // Allocate memory in each replica and transfer the data to all replicas. + std::vector replicated_buffers; for (se::StreamExecutor* executor : replicas) { - if (executor->device_ordinal() == master_device_ordinal) { - TF_RETURN_IF_ERROR( - execute_backend_->transfer_manager()->TransferLiteralToDevice( - executor, *literal, *shaped_buffer)); - } else { - // The replica is not the master. Create an cloned shaped buffer with - // the replica's device ordinal. This is required because - // TransferLiteralToDevice verifies that the device ordinal of the shaped - // buffer matches that of the executor. - std::unique_ptr clone = - CloneShapedBufferOnDevice(*shaped_buffer, executor->device_ordinal()); - TF_RETURN_IF_ERROR( - execute_backend_->transfer_manager()->TransferLiteralToDevice( - executor, *literal, *clone)); - } + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer shaped_buffer, + execute_backend_->transfer_manager()->AllocateScopedShapedBuffer( + shape, execute_backend_->memory_allocator(), + executor->device_ordinal())); + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferLiteralToDevice( + executor, *literal, shaped_buffer)); + replicated_buffers.emplace_back(std::move(shaped_buffer)); } - TF_ASSIGN_OR_RETURN( - *result->mutable_data(), - allocation_tracker_.Register(std::move(shaped_buffer), - StrCat("TransferToServer literal of shape ", - ShapeUtil::HumanString(shape)))); + TF_ASSIGN_OR_RETURN(*result->mutable_data(), + allocation_tracker_.RegisterReplicatedBuffers( + std::move(replicated_buffers), + StrCat("TransferToServer literal of shape ", + ShapeUtil::HumanString(shape)))); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) { +Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) { const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( @@ -1139,9 +1032,8 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, executor, *literal); } -tensorflow::Status Service::TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) { +Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) { const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( @@ -1167,169 +1059,76 @@ tensorflow::Status Service::TransferFromOutfeed( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( executor, arg->shape_with_layout(), &literal)); *result->mutable_literal() = literal.ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) { +Status Service::ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) { return execute_backend_->ResetDevices(); } -tensorflow::Status Service::IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandleAtOperation(arg->operand()); - - if (user_computation->request_count(versioned_handle.version) == 0) { +Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) { + if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } - - TF_ASSIGN_OR_RETURN( - bool is_constant, - user_computation->IsConstant(arg->operand(), arg->num_parameters())); - - result->set_is_constant(is_constant); - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandleAtOperation(arg->operand()); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); + if (!arg->computation().has_program_shape()) { + return InvalidArgument("program shape may not be empty"); } - - TF_ASSIGN_OR_RETURN( - bool is_constant, - user_computation->IsConstant(arg->operand(), arg->parameters_size())); - if (!is_constant) { - StatusOr op_request_status = - user_computation->LookUpRequestForErrorReporting(arg->operand()); - string op_request_string = ""; - if (op_request_status.ok()) { - op_request_string = op_request_status.ValueOrDie()->ShortDebugString(); - } + if (arg->computation().program_shape().parameters_size() != 0) { return InvalidArgument( - "Operand to ComputeConstant depends on a parameter.\n\n" - " op requested for constant evaluation: %s\n\n" - "This is an internal error that typically happens when the XLA user " - "(e.g. TensorFlow) is attempting to determine a value that must be a " - "compile-time constant (e.g. an array dimension) but it is not capable " - "of being evaluated at XLA compile time.\n\n" - "Please file a usability bug with the framework being used (e.g. " - "TensorFlow).", - op_request_string.c_str()); + "constant computation may not depend on any parameters."); } - // We can't use ComputeProgramShape because it checks that all parameter - // instructions are present and contiguous. Instead construct ProgramShape - // directly. - ProgramShape program_shape; - TF_ASSIGN_OR_RETURN(*program_shape.mutable_result(), - user_computation->GetShape(arg->operand())); - + ProgramShape program_shape = arg->computation().program_shape(); TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); - - ExecutionOptions execution_options = xla::CreateDefaultExecutionOptions(); - execution_options.mutable_debug_options()->set_xla_enable_fast_math(false); - execution_options.mutable_debug_options() - ->set_xla_eliminate_hlo_implicit_broadcast(true); - *execution_options.mutable_shape_with_output_layout() = - program_shape.result(); - - Shape shape_with_output_layout(program_shape.result()); if (arg->has_output_layout()) { TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( - arg->output_layout(), execution_options.shape_with_output_layout())); - *execution_options.mutable_shape_with_output_layout()->mutable_layout() = - arg->output_layout(); + arg->output_layout(), program_shape.result())); } - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(program_shape, {}, execution_options, - *user_computation)); + HloModuleConfig config(program_shape); + + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(arg->computation(), config)); - // Exclude dead parameter instructions for the purpose of computing constants. - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, *module_config, - /*include_unreachable_instructions=*/ - false)); - - std::vector> parameters(arg->parameters_size()); - for (int64 i = 0; i < arg->parameters_size(); ++i) { - TF_ASSIGN_OR_RETURN(parameters[i], - Literal::CreateFromProto(arg->parameters(i))); - } HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN( - auto result_literal, - evaluator.Evaluate>(*module, parameters)); + TF_ASSIGN_OR_RETURN(auto result_literal, + evaluator.Evaluate>( + *module, /*arg_literals=*/{})); - // Since the shape_with_output_layout option in ExecutionOption is - // non-effective to the Evaluator results, explicit relayout here. + // Since the result layout is non-effective to the Evaluator results, explicit + // relayout here. + // + // TODO(b/77824332): Make HloEvaluator take care of the re-layout. if (arg->has_output_layout()) { result_literal = result_literal->Relayout(arg->output_layout()); } *result->mutable_literal() = result_literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) { +Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, - allocation_tracker_.Resolve(arg->data())); + allocation_tracker_.ResolveForReplica(arg->data(), 0)); *result->mutable_shape() = buffer->on_host_shape(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - computation->GetVersionedHandle(); - - TF_ASSIGN_OR_RETURN(auto program_shape, computation->ComputeProgramShape( - versioned_handle.version)); - *result->mutable_program_shape() = *program_shape; - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - TF_ASSIGN_OR_RETURN(*result->mutable_shape(), - computation->GetShape(arg->operand())); - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::GetComputationStats( - const ComputationStatsRequest* arg, ComputationStatsResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); +Status Service::GetComputationGraphStats( + const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) { + if (!arg->has_computation()) { + return InvalidArgument("Computations may not be empty."); + } + if (!arg->computation().has_program_shape()) { + return InvalidArgument("Program shape may not be empty."); + } - HloModuleConfig config; + HloModuleConfig config(arg->computation().program_shape()); config.set_debug_options(arg->debug_options()); - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, config)); + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(arg->computation(), config)); hlo_graph_dumper::MaybeDumpHloModule(*module, "computation statistics subject"); @@ -1344,262 +1143,7 @@ tensorflow::Status Service::GetComputationStats( stats.set_flop_count(analysis.flop_count()); stats.set_transcendental_count(analysis.transcendental_count()); *result->mutable_stats() = stats; - return tensorflow::Status::OK(); -} - -template -tensorflow::Status Service::AddInstruction( - const RequestT* arg, ResponseT* result, - const std::function(UserComputation*)>& - adder) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - TF_ASSIGN_OR_RETURN(*result->mutable_output(), adder(computation)); - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - StatusOr handle_status; - - switch (arg->op_case()) { - case OpRequest::kBatchNormTrainingRequest: - handle_status = computation->AddBatchNormTrainingInstruction( - arg->batch_norm_training_request()); - break; - case OpRequest::kBatchNormInferenceRequest: - handle_status = computation->AddBatchNormInferenceInstruction( - arg->batch_norm_inference_request()); - break; - case OpRequest::kBatchNormGradRequest: - handle_status = computation->AddBatchNormGradInstruction( - arg->batch_norm_grad_request()); - break; - case OpRequest::kBinaryOpRequest: - handle_status = - computation->AddBinaryInstruction(arg->binary_op_request()); - break; - case OpRequest::kBroadcastRequest: - handle_status = - computation->AddBroadcastInstruction(arg->broadcast_request()); - break; - case OpRequest::kCallRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->call_request().to_apply())); - handle_status = - computation->AddCallInstruction(arg->call_request(), *to_apply); - break; - } - case OpRequest::kConcatenateRequest: - handle_status = - computation->AddConcatenateInstruction(arg->concatenate_request()); - break; - case OpRequest::kConditionalRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * true_computation, - computation_tracker_.Resolve( - arg->conditional_request().true_computation())); - TF_ASSIGN_OR_RETURN(UserComputation * false_computation, - computation_tracker_.Resolve( - arg->conditional_request().false_computation())); - handle_status = computation->AddConditionalInstruction( - arg->conditional_request(), *true_computation, *false_computation); - break; - } - case OpRequest::kConstantRequest: - handle_status = - computation->AddConstantInstruction(arg->constant_request()); - break; - case OpRequest::kConvertRequest: - handle_status = - computation->AddConvertInstruction(arg->convert_request()); - break; - case OpRequest::kBitcastConvertRequest: - handle_status = computation->AddBitcastConvertInstruction( - arg->bitcast_convert_request()); - break; - case OpRequest::kConvolveRequest: - handle_status = - computation->AddConvolveInstruction(arg->convolve_request()); - break; - case OpRequest::kCrossReplicaSumRequest: - handle_status = computation->AddCrossReplicaSumInstruction( - arg->cross_replica_sum_request()); - break; - case OpRequest::kCustomCallRequest: - handle_status = - computation->AddCustomCallInstruction(arg->custom_call_request()); - break; - case OpRequest::kDotRequest: - handle_status = computation->AddDotInstruction(arg->dot_request()); - break; - case OpRequest::kDynamicSliceRequest: - handle_status = - computation->AddDynamicSliceInstruction(arg->dynamic_slice_request()); - break; - case OpRequest::kDynamicUpdateSliceRequest: - handle_status = computation->AddDynamicUpdateSliceInstruction( - arg->dynamic_update_slice_request()); - break; - case OpRequest::kFftRequest: - handle_status = computation->AddFftInstruction(arg->fft_request()); - break; - case OpRequest::kGatherRequest: - handle_status = computation->AddGatherInstruction(arg->gather_request()); - break; - case OpRequest::kGetTupleElementRequest: - handle_status = computation->AddGetTupleElementInstruction( - arg->get_tuple_element_request()); - break; - case OpRequest::kInfeedRequest: - handle_status = computation->AddInfeedInstruction(arg->infeed_request()); - break; - case OpRequest::kOutfeedRequest: - handle_status = - computation->AddOutfeedInstruction(arg->outfeed_request()); - break; - case OpRequest::kHostComputeRequest: - handle_status = - computation->AddHostComputeInstruction(arg->host_compute_request()); - break; - case OpRequest::kMapRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->map_request().to_apply())); - handle_status = - computation->AddMapInstruction(arg->map_request(), *to_apply); - break; - } - case OpRequest::kPadRequest: - handle_status = computation->AddPadInstruction(arg->pad_request()); - break; - case OpRequest::kParameterRequest: - handle_status = - computation->AddParameterInstruction(arg->parameter_request()); - break; - case OpRequest::kReduceRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->reduce_request().to_apply())); - handle_status = - computation->AddReduceInstruction(arg->reduce_request(), *to_apply); - break; - } - case OpRequest::kReducePrecisionRequest: { - handle_status = computation->AddReducePrecisionInstruction( - arg->reduce_precision_request()); - break; - } - case OpRequest::kReduceWindowRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * to_apply, - computation_tracker_.Resolve( - arg->reduce_window_request().to_apply())); - handle_status = computation->AddReduceWindowInstruction( - arg->reduce_window_request(), *to_apply); - break; - } - case OpRequest::kReshapeRequest: - handle_status = - computation->AddReshapeInstruction(arg->reshape_request()); - break; - case OpRequest::kReverseRequest: - handle_status = - computation->AddReverseInstruction(arg->reverse_request()); - break; - case OpRequest::kRngRequest: - handle_status = computation->AddRngInstruction(arg->rng_request()); - break; - case OpRequest::kSelectAndScatterRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * select, - computation_tracker_.Resolve( - arg->select_and_scatter_request().select())); - TF_ASSIGN_OR_RETURN(UserComputation * scatter, - computation_tracker_.Resolve( - arg->select_and_scatter_request().scatter())); - handle_status = computation->AddSelectAndScatterInstruction( - arg->select_and_scatter_request(), *select, *scatter); - break; - } - case OpRequest::kSliceRequest: - handle_status = computation->AddSliceInstruction(arg->slice_request()); - break; - case OpRequest::kTernaryOpRequest: - handle_status = - computation->AddTernaryInstruction(arg->ternary_op_request()); - break; - case OpRequest::kTraceRequest: - return computation->AddTraceInstruction(arg->trace_request()); - case OpRequest::kTransposeRequest: - handle_status = - computation->AddTransposeInstruction(arg->transpose_request()); - break; - case OpRequest::kUnaryOpRequest: - handle_status = computation->AddUnaryInstruction(arg->unary_op_request()); - break; - case OpRequest::kVariadicOpRequest: - handle_status = - computation->AddVariadicInstruction(arg->variadic_op_request()); - break; - case OpRequest::kWhileRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * condition, - computation_tracker_.Resolve(arg->while_request().condition())); - TF_ASSIGN_OR_RETURN( - UserComputation * body, - computation_tracker_.Resolve(arg->while_request().body())); - handle_status = computation->AddWhileInstruction(arg->while_request(), - *condition, *body); - break; - } - case OpRequest::kSendRequest: { - TF_RETURN_IF_ERROR( - channel_tracker_.RegisterSend(arg->send_request().channel_handle())); - TF_RETURN_IF_ERROR(computation->AddSendInstruction(arg->send_request())); - return tensorflow::Status::OK(); - } - case OpRequest::kRecvRequest: { - TF_RETURN_IF_ERROR( - channel_tracker_.RegisterRecv(arg->recv_request().channel_handle())); - handle_status = computation->AddRecvInstruction(arg->recv_request()); - break; - } - case OpRequest::OP_NOT_SET: - return InvalidArgument("XLA service received OpRequest with OP_NOT_SET"); - default: - return InvalidArgument("Unsupported operation in XLA service"); - } - TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle_status); - - // We set the debug metadata here, because we slice off part of the OpRequest - // proto in the above switch statement. - TF_ASSIGN_OR_RETURN(ComputationDataHandle handle, handle_status); - TF_RETURN_IF_ERROR(computation->SetOpMetadata(handle, arg->metadata())); - if (arg->has_sharding()) { - TF_RETURN_IF_ERROR(computation->SetOpSharding(handle, arg->sharding())); - } - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::SnapshotComputation( - const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.SnapshotComputation(arg->computation())); - - result->set_allocated_module(module.release()); - - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::LoadComputationSnapshot( - const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) { - TF_ASSIGN_OR_RETURN(*result->mutable_computation(), - computation_tracker_.LoadSessionModule(arg->module())); - return tensorflow::Status::OK(); + return Status::OK(); } DeviceHandle Service::SingleComputationDeviceHandle() const { @@ -1609,9 +1153,9 @@ DeviceHandle Service::SingleComputationDeviceHandle() const { return device_handle; } -StatusOr> Service::Replicas( +StatusOr> Service::Replicas( const Backend& backend, const DeviceHandle& device_handle) const { - std::vector replicas; + std::vector replicas; for (int replica = 0; replica < options_.number_of_replicas(); ++replica) { // From the computation placer, find out the device ids of the replicas for // the given device handle. diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 6ce241971156599aaa25aea1b0caac0e1bd5379c..8748a4c1447eca691abc0f7ca48feda48ceb86e1 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -26,17 +26,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/allocation_tracker.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/channel_tracker.h" -#include "tensorflow/compiler/xla/service/compilation_cache.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/execution_tracker.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -53,8 +48,8 @@ namespace xla { class ServiceOptions { public: // Set the platform backing the service, or nullptr for the default platform. - ServiceOptions& set_platform(perftools::gputools::Platform* platform); - perftools::gputools::Platform* platform() const; + ServiceOptions& set_platform(se::Platform* platform); + se::Platform* platform() const; // Set the number of replicas to use when compiling replicated // programs. @@ -66,7 +61,7 @@ class ServiceOptions { int intra_op_parallelism_threads() const; private: - perftools::gputools::Platform* platform_ = nullptr; + se::Platform* platform_ = nullptr; int number_of_replicas_ = 1; int intra_op_parallelism_threads_ = -1; }; @@ -79,44 +74,33 @@ class Service : public ServiceInterface { public: // Factory method for creating a new Service. static StatusOr> NewService( - perftools::gputools::Platform* platform = nullptr); + se::Platform* platform = nullptr); static StatusOr> NewService( const ServiceOptions& options); - // Creates a new computation with the given name. - // A unique ComputationHandle is returned. - tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; - // Unregisters a previously-allocated global handle. // // If the handle given is not currently allocated, a NOT_FOUND status is // returned. - tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) override; + Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) override; // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each // element in the tuple. - tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) override; - - // Modifies the provided computation so that subsequent executions - // will compute the provided ComputationDataHandle, rather than the - // last expression enqueued on that Computation. - tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; + Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) override; // Executes a computation with the provided global data passed as - // immutable arguments. Returns global data output and execution timing. - tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override; + // immutable arguments. The request contains the whole computation graph. + // Returns global data output and execution timing. + Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) override; // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; + Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) override; // Requests one or more device handles from the target. // @@ -126,49 +110,33 @@ class Service : public ServiceInterface { // the first set of replicas, and the next R devices to the second set of // replicas, etc. Each returned device handle represents the device with the // replica id 0. - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override; - - // Asynchronously executes a computation with provided arguments. Invokes - // the provided computation with the provided global data passed as - // immutable arguments. Returns a handle to the execution. - // - // (Note: The corresponding function in xla::Client was removed as part of - // b/64116060, in an attempt to simplify our API. We're keeping this around - // for now in case we want to expose this to clients in a different way.) - tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override; // Waits until the specified execution is complete and returns the result. // Calling this API multiple times with the same execution handle returns the // method with an error since the execution handle is destroyed after the // first call. - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override; + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override; // Requests that global data be transferred to the client in literal form. - tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, - TransferToClientResponse* result) override; + Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) override; // Transfers data from a literal provided by the client, into device memory. - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override; + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override; // Transfers data from a literal provided by the client, into the Infeed // buffer of the device. - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override; + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override; // Transfers data from the Outfeed othe device to the literal provided by the // client. - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override; + Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override; // Resets devices, clearing all existing state on all the devices associated // with this service (including memory allocated on the devices). @@ -179,67 +147,25 @@ class Service : public ServiceInterface { // ResetDevice should be called before an Execution that expect the device to // be in the reset state. For example, if the prior Execution modifies device // state (e.g., architectural state) that the next Execution depends on. - tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override; - - // Tests if an expression is a compile-time constant. - tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override; - // Computes the value of a constant expression. - tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; + Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) override; // Returns the shape (with layout) of an array associated with a given data // handle. - tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) override; - - // Returns the program shape of the computation associated with the given - // handle. - tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - - ///// - // Computation-oriented methods. - - // Enqueues an Op on the computation. - tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override; - - // Retrieves the inferred shape for a value within a computation. - tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; + Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) override; // Retrieves the statistics of a computation. - tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - - // Snapshots the current state of a computation handle into a serializable - // protocol buffer form, so it can be loaded via - // LoadComputationSnapshot. - tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) override; - - // Loads a computation from a serialized protocol buffer created via - // SnapshotComputation. - tensorflow::Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) override; + Status GetComputationGraphStats(const ComputationGraphStatsRequest* arg, + ComputationStatsResponse* result) override; // Creates a unique channel handle that can be used for Send/Recv // instructions. - tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) override; - - // Returns the ComputationTracker of the current service instance. - // Only used in unit tests to access user computations from client. - const ComputationTracker& computation_tracker() { - return computation_tracker_; - } + Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) override; // Returns the backend used to execute computations. const Backend& backend() const { return *execute_backend_; } @@ -251,8 +177,24 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, + const ExecutionOptions& execution_options); + + // Picks a parallel response and fills the result. + Status PickParallelResponse(const ExecuteParallelResponse& parallel_result, + ExecuteResponse* result); + + // Prepare the executors for executing parallel. + StatusOr> GetExecutors( + const ExecutionOptions& execution_options, int64 requests_size, + int64 request_index) const; + + // Prepare the arguments for executing parallel. + StatusOr>> GetArguments( const ExecutionOptions& execution_options, - const UserComputation& user_computation); + tensorflow::gtl::ArraySlice arguments); + + // Assert that host- and device-shapes are in a consistent state. + Status ValidateEntryComputationLayout(HloModule* module); protected: friend class LocalExecutable; @@ -262,22 +204,21 @@ class Service : public ServiceInterface { Service(const ServiceOptions& options, std::unique_ptr execute_backend); - static StatusOr> CreateComputeConstantBackend(); - // Resolves the given argument handles in the allocation tracker and returns - // the corresponding allocations. The function also verifies that each - // allocation matches the execution platform and device ordinal. - StatusOr> ResolveAndValidateArguments( + // the corresponding allocations for every replica. The function also verifies + // that each allocation matches the execution platform and device ordinal of + // the corresponding replica. + StatusOr>> + ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, - int device_ordinal); + tensorflow::gtl::ArraySlice stream_executors); // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options, - const UserComputation& user_computation); + const ExecutionOptions* execution_options); // Builds an Executable for the given parameters. // @@ -285,67 +226,56 @@ class Service : public ServiceInterface { // buffers, which the compiler is responsible for freeing. The allocator // given here need not match the allocator used when running the executable. StatusOr> BuildExecutable( - const VersionedComputationHandle& versioned_handle, + const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, - perftools::gputools::StreamExecutor* executor, + se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator = nullptr); // Same as BuildExecutable() above, but builds a list of Executables for the // given computations that may interact with each other. StatusOr>> BuildExecutables( - std::vector versioned_handles, + const std::vector& module_protos, std::vector> module_configs, - Backend* backend, - std::vector> executors, + Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator); - // Similar to BuildExecutable, but look in the compilation cache for the - // executable first. If the executable is not in the cache, it is built and - // inserted into the cache. - StatusOr> BuildAndCacheExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile, - DeviceMemoryAllocator* device_allocator = nullptr); - // Runs the given executable with the given arguments and register the result // in the allocation tracker. The handle of the result from the tracker is // returned. If the parameter "profile" is not null, it points to an // ExecutionProfile object which will be filled in with profile data. StatusOr ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice arguments, - Backend* backend, perftools::gputools::StreamExecutor* executor, - const string& result_tag, ExecutionProfile* profile); + const tensorflow::gtl::ArraySlice> + arguments, + Backend* backend, const string& result_tag, ExecutionProfile* profile); // Runs the given executables with the given arguments and register the result // from each executable in the allocation tracker. The handles of the result // from the tracker are returned. StatusOr> ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice> arguments, + tensorflow::gtl::ArraySlice>> + arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags, ExecutionProfile* profile); - // Convenience function for adding a function to a user computation. - template - tensorflow::Status AddInstruction( - const RequestT* arg, ResponseT* result, - const std::function(UserComputation*)>& - adder); + // Executes a single computation which has more than one target device. + // The N devices are expected to all return an empty tuple, but one, which + // will be the result of this computation. + Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result); // Convenience function which checks whether the given shape_with_layout // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. - tensorflow::Status ValidateResultShapeWithLayout( - const Shape& shape_with_layout, const Shape& result_shape) const; + Status ValidateResultShapeWithLayout(const Shape& shape_with_layout, + const Shape& result_shape) const; // Returns the stream executors assigned to the replicas represented by the // given device handle. Each device_handle is a virtual replicated device that // represents a set of physical devices for the replicas. - StatusOr> Replicas( + StatusOr> Replicas( const Backend& backend, const DeviceHandle& device_handle) const; Status MaybeDumpHloModule(const HloModule& module) const; @@ -356,9 +286,6 @@ class Service : public ServiceInterface { ServiceOptions options_; - // Tracks computations built via the API. - ComputationTracker computation_tracker_; - // Tracks channels created via the API. ChannelTracker channel_tracker_; @@ -368,12 +295,7 @@ class Service : public ServiceInterface { // Tracks asynchronously launched executions via the API. ExecutionTracker execution_tracker_; - // Cache containing previously built Executables. - CompilationCache compilation_cache_; - // Backend to compile and execute computations on. - // - // TODO(b/28616830): Support multiple backends for execution. std::unique_ptr execute_backend_; TF_DISALLOW_COPY_AND_ASSIGN(Service); diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index 6c1f8feac7ed4423051cf2737be57dcfab508671..7f3910cdb0366078b97fb5f6a2dc498b37570926 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -28,7 +28,7 @@ namespace xla { class ServiceExecutableRunOptions { public: using StreamBorrower = - std::function::SmartPtr>(int)>; + std::function::SmartPtr>(int)>; ServiceExecutableRunOptions() : ServiceExecutableRunOptions(ExecutableRunOptions()) {} @@ -45,14 +45,13 @@ class ServiceExecutableRunOptions { ExecutableRunOptions* mutable_run_options() { return &run_options_; } // Delegate to `ExecutableRunOptions` member. - perftools::gputools::Stream* stream() const { return run_options_.stream(); } + se::Stream* stream() const { return run_options_.stream(); } DeviceMemoryAllocator* allocator() const { return run_options_.allocator(); } int device_ordinal() const { return run_options_.device_ordinal(); } // Borrows a stream and returns a smart pointer which returns the stream on // destruction. - StatusOr::SmartPtr> BorrowStream( - int device_ordinal) const { + StatusOr::SmartPtr> BorrowStream(int device_ordinal) const { return borrow_stream_ ? borrow_stream_(device_ordinal) : Status(tensorflow::error::UNIMPLEMENTED, "No stream cache"); diff --git a/tensorflow/compiler/xla/service/session.proto b/tensorflow/compiler/xla/service/session.proto deleted file mode 100644 index bb8d1cd2a106ea3e5bb61eee5052bd60c38cd0e2..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/session.proto +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This proto file defines messages which store the state of XLA -// computations within the XLA service. A computation is stored as a record -// of the operation requests used to build it. -syntax = "proto3"; - -import "tensorflow/compiler/xla/xla_data.proto"; - -package xla; - -// Describes a single operation request. -message OperationRequest { - ComputationDataHandle output_handle = 1; - Shape output_shape = 2; - - // For operations which call embedded computations such as "Map", these are - // the version(s) that the embedded computation should be called at. A version - // value of a computation is the ComputationDataHandle of the root of the - // computation at the point in time. - // - // "Call", "Map", "Reduce", and "ReduceWindow" operations take a single - // embedded computation so this field will have a single value for those - // operations. - // - // "While" operation takes two; index 0 is the "condition" version and index 1 - // is the "body" version. - repeated int64 embedded_computation_versions = 3; - - // The actual request, which in itself is a tagged union of all possible - // operation request types. - OpRequest request = 4; -} - -// Describes a sequence of operation requests which define an XLA -// computation. -message SessionComputation { - string name = 1; - - // The ComputationHandle used to refer to this computation in the XLA - // service. - ComputationHandle computation_handle = 2; - - // Map from ComputationDataHandle value to operation request. The highest - // ComputationDataHandle value corresponds to the root of the computation. - map requests = 3; -} - -// Describes a group of SessionComputations with an "entry point" computation -// that may refer to the other non-entry (AKA embedded) computations. -// -// This message is used to serialize a computation that has been built via the -// XLA service API, along with its dependencies, for purposes such as -// analysis/replay/file-storage. -message SessionModule { - // The entry computation, which was requested for serialization. This may have - // referred to embedded computations, which are reflected below. - SessionComputation entry = 1; - - // Embedded computations that are transitively referred to by the entry - // computation. - repeated SessionComputation embedded_computations = 2; - - // The arguments passed to the computation. - repeated LiteralProto arguments = 3; - - // The result of the computation. - LiteralProto result = 4; - - // The name of the platform used to run the computation. - string execution_platform = 5; -} diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index c9692757b27980b10a5ca562223c3d0f6462d820..bd98e86b08b7507b4a7a0d1a7ebac4b654ff2171 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -44,146 +44,29 @@ namespace xla { namespace { -// Return the UnaryOperation proto enum value associated with the given HLO -// opcode. -UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kAbs: - return UNOP_ABS; - case HloOpcode::kCeil: - return UNOP_CEIL; - case HloOpcode::kCos: - return UNOP_COS; - case HloOpcode::kExp: - return UNOP_EXP; - case HloOpcode::kFloor: - return UNOP_FLOOR; - case HloOpcode::kImag: - return UNOP_IMAG; - case HloOpcode::kIsFinite: - return UNOP_IS_FINITE; - case HloOpcode::kLog: - return UNOP_LOG; - case HloOpcode::kNot: - return UNOP_NOT; - case HloOpcode::kNegate: - return UNOP_NEGATE; - case HloOpcode::kReal: - return UNOP_REAL; - case HloOpcode::kRoundNearestAfz: - return UNOP_ROUND_NEAREST_AFZ; - case HloOpcode::kSign: - return UNOP_SIGN; - case HloOpcode::kSin: - return UNOP_SIN; - case HloOpcode::kSort: - return UNOP_SORT; - case HloOpcode::kTanh: - return UNOP_TANH; - default: - LOG(FATAL) << "Unhandled opcode for conversion to unary operation: " - << opcode; - } -} - -// Return the BinaryOperation proto enum value associated with the given HLO -// opcode. -BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kAtan2: - return BINOP_ATAN2; - case HloOpcode::kComplex: - return BINOP_COMPLEX; - case HloOpcode::kMultiply: - return BINOP_MUL; - case HloOpcode::kAdd: - return BINOP_ADD; - case HloOpcode::kSubtract: - return BINOP_SUB; - case HloOpcode::kDivide: - return BINOP_DIV; - case HloOpcode::kEq: - return BINOP_EQ; - case HloOpcode::kGe: - return BINOP_GE; - case HloOpcode::kGt: - return BINOP_GT; - case HloOpcode::kLe: - return BINOP_LE; - case HloOpcode::kLt: - return BINOP_LT; - case HloOpcode::kNe: - return BINOP_NE; - case HloOpcode::kMaximum: - return BINOP_MAX; - case HloOpcode::kMinimum: - return BINOP_MIN; - case HloOpcode::kPower: - return BINOP_POW; - case HloOpcode::kRemainder: - return BINOP_REM; - case HloOpcode::kOr: - return BINOP_OR; - case HloOpcode::kAnd: - return BINOP_AND; - case HloOpcode::kShiftLeft: - return BINOP_SHIFT_LEFT; - case HloOpcode::kShiftRightArithmetic: - return BINOP_SHIFT_RIGHT_ARITHMETIC; - case HloOpcode::kShiftRightLogical: - return BINOP_SHIFT_RIGHT_LOGICAL; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - -// Return the TernaryOperation proto enum value associated with the given HLO -// opcode. -TernaryOperation OpcodeToTernaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kClamp: - return TRIOP_CLAMP; - case HloOpcode::kSelect: - return TRIOP_SELECT; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - -// Return the VariadicOperation proto enum value associated with the given HLO -// opcode. -VariadicOperation OpcodeToVariadicOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kTuple: - return VAROP_TUPLE; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - // Returns true if no element is present in slice more than once. bool AllUnique(tensorflow::gtl::ArraySlice slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); } -tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape, - tensorflow::StringPiece op_type) { +Status ExpectNotTupleOrOpaque(const Shape& shape, + tensorflow::StringPiece op_type) { if (ShapeUtil::IsTuple(shape)) { - return InvalidArgument("Expected non-tuple argument for %s. Got: %s", - op_type.ToString().c_str(), + return InvalidArgument("Expected non-tuple argument for %s, but got %s.", + std::string(op_type).c_str(), ShapeUtil::HumanString(shape).c_str()); } else if (ShapeUtil::IsOpaque(shape)) { - return InvalidArgument("Expected non-opaque argument for %s. Got: %s", - op_type.ToString().c_str(), + return InvalidArgument("Expected non-opaque argument for %s, but got %s.", + std::string(op_type).c_str(), ShapeUtil::HumanString(shape).c_str()); } else { - return tensorflow::Status::OK(); + return Status::OK(); } } -tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, - const Shape& init_value_shape, - const PrimitiveType& input_element_type) { +Status VerifyReducerShape(const ProgramShape& reducer_shape, + const Shape& init_value_shape, + const PrimitiveType& input_element_type) { if (reducer_shape.parameters_size() != 2) { return InvalidArgument( "Reduction function must take 2 parameters, but " @@ -193,8 +76,10 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, const Shape& accumulator_shape = reducer_shape.result(); if (ShapeUtil::Rank(accumulator_shape) != 0) { - return Unimplemented( - "Reduction function currently must have rank-0 result."); + return InvalidArgument( + "Reduction function must have rank 0 (rank %lld reduction function " + "given).", + ShapeUtil::Rank(accumulator_shape)); } // Check that the accumulator can be passed in as the first argument. @@ -235,13 +120,13 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, reducer_shape.parameters(1))) { return InvalidArgument( - "Reduction function's second parameter shape currently must " - "match the result shape. Got %s vs %s", + "Reduction function's second parameter shape must " + "match the result shape, but got %s vs %s.", ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(), ShapeUtil::HumanString(accumulator_shape).c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr InferWindowOutputShape(const Shape& base_shape, @@ -258,29 +143,29 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, for (int64 i = 0; i < window.dimensions_size(); ++i) { const auto& dim = window.dimensions(i); if (dim.size() <= 0) { - return InvalidArgument("Window has a non-positive dimension. Window: %s", + return InvalidArgument("Window %s has a non-positive dimension.", window.DebugString().c_str()); } if (dim.stride() <= 0) { - return InvalidArgument("Window has a non-positive stride. Window: %s", + return InvalidArgument("Window %s has a non-positive stride.", window.DebugString().c_str()); } if (!allow_negative_padding && dim.padding_low() < 0) { - return InvalidArgument("Window has a negative low padding. Window: %s", + return InvalidArgument("Window %s has a negative low padding.", window.DebugString().c_str()); } if (!allow_negative_padding && dim.padding_high() < 0) { - return InvalidArgument("Window has a negative high padding. Window: %s", + return InvalidArgument("Window %s has a negative high padding.", window.DebugString().c_str()); } if (dim.base_dilation() < 1) { return InvalidArgument( - "Window has a non-positive base area dilation factor. Window: %s", + "Window %s has a non-positive base area dilation factor.", window.DebugString().c_str()); } if (dim.window_dilation() < 1) { return InvalidArgument( - "Window has a non-positive window dilation factor. Window: %s", + "Window %s has a non-positive window dilation factor.", window.DebugString().c_str()); } @@ -302,86 +187,92 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferUnaryOpShape( HloOpcode opcode, const HloInstruction* operand) { + return InferUnaryOpShape(opcode, operand->shape()); +} + +/* static */ StatusOr ShapeInference::InferUnaryOpShape( + HloOpcode opcode, const Shape& shape) { // There is no copy operation at the proto level, so handle copy explicitly. - if (opcode == HloOpcode::kCopy) { - return operand->shape(); + // A domain shape is the same as the input one. + if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) { + return shape; } - return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), operand->shape()); -} + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(shape, "operand of unary operation")); -/* static */ StatusOr ShapeInference::InferUnaryOpShape( - UnaryOperation operation, const Shape& arg) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of unary operation")); - - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(arg)); - switch (operation) { - case UNOP_FLOOR: - case UNOP_CEIL: - if (!ShapeUtil::ElementIsFloating(arg)) { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); + switch (opcode) { + case HloOpcode::kFloor: + case HloOpcode::kCeil: + if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( - "expected element type in shape to be floating for floor/ceil " - "operation; got %s", - PrimitiveType_Name(arg.element_type()).c_str()); + "Expected element type in shape to be floating for floor/ceil " + "operation; got %s.", + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; - case UNOP_COS: - case UNOP_SIN: - case UNOP_EXP: - case UNOP_LOG: - case UNOP_TANH: - if (!ShapeUtil::ElementIsFloating(arg) && - !ShapeUtil::ElementIsComplex(arg)) { + return shape; + case HloOpcode::kCos: + case HloOpcode::kSin: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kTanh: + if (!ShapeUtil::ElementIsFloating(shape) && + !ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( - "expected element type in shape to be floating or complex for " - "sin/cos/exp/log/tanh operation; got %s", - PrimitiveType_Name(arg.element_type()).c_str()); + "Expected element type in shape to be floating or complex for " + "sin/cos/exp/log/tanh operation; got %s.", + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; - case UNOP_REAL: - case UNOP_IMAG: - if (!ShapeUtil::ElementIsComplex(arg)) { + return shape; + case HloOpcode::kReal: + case HloOpcode::kImag: + if (!ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( - "expected element type in shape to be complex for real/imag " - "operation; got %s", - PrimitiveType_Name(arg.element_type()).c_str()); + "Expected element type in shape to be complex for real/imag " + "operation; got %s.", + PrimitiveType_Name(shape.element_type()).c_str()); } - return ShapeUtil::ChangeElementType(arg, F32); - case UNOP_ABS: - if (ShapeUtil::ElementIsComplex(arg)) { + return ShapeUtil::ChangeElementType(shape, F32); + case HloOpcode::kAbs: + if (ShapeUtil::ElementIsComplex(shape)) { return ShapeUtil::ChangeElementType( - arg, primitive_util::ComplexComponentType(arg.element_type())); + shape, primitive_util::ComplexComponentType(shape.element_type())); } - return arg; - case UNOP_NEGATE: - case UNOP_ROUND_NEAREST_AFZ: - case UNOP_SIGN: - case UNOP_SORT: - return arg; - - case UNOP_NOT: - if (arg.element_type() != PRED && - !primitive_util::IsIntegralType(arg.element_type())) { + return shape; + case HloOpcode::kClz: + case HloOpcode::kNegate: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kSign: + case HloOpcode::kSort: + return shape; + + case HloOpcode::kNot: + if (shape.element_type() != PRED && + !primitive_util::IsIntegralType(shape.element_type())) { return InvalidArgument( - "expected pred or an integral element type in argument to not " - "operation; got %s", - PrimitiveType_Name(arg.element_type()).c_str()); + "Expected pred or an integral element type in argument to Not " + "operation; got %s.", + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; + return shape; - case UNOP_IS_FINITE: - if (!ShapeUtil::ElementIsFloating(arg)) { + case HloOpcode::kIsFinite: + if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( - "expected element type in shape to be floating point for IsFinite " - "operation; got %s", - PrimitiveType_Name(arg.element_type()).c_str()); + "Expected element type in shape to be floating " + "point for IsFinite " + "operation; got %s.", + PrimitiveType_Name(shape.element_type()).c_str()); } - return ShapeUtil::ChangeElementType(arg, PRED); + return ShapeUtil::ChangeElementType(shape, PRED); default: return InvalidArgument( "Unknown operation for unary shape inference: \"%s\".", - UnaryOperation_Name(operation).c_str()); + HloOpcodeString(opcode).c_str()); } } @@ -389,10 +280,10 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, tensorflow::gtl::ArraySlice arg_shapes, const int64 dimension) { if (arg_shapes.empty()) { - return InvalidArgument("Concatenate expects at least one argument"); + return InvalidArgument("Concatenate expects at least one argument."); } if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) { - return InvalidArgument("dimension to concatenate along out of bounds: %lld", + return InvalidArgument("Concatenate dimension out of bounds: %lld.", dimension); } const Shape* arg_shape = nullptr; @@ -408,14 +299,14 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { return InvalidArgument( "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld " - "(%s)", + "(%s).", ShapeUtil::Rank(*arg_shape), ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape).c_str()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { return InvalidArgument( - "cannot concatenate arrays with different element types: %s vs %s", + "Cannot concatenate arrays with different element types: %s vs %s.", PrimitiveType_Name(arg_shape->element_type()).c_str(), PrimitiveType_Name(shape->element_type()).c_str()); } @@ -428,9 +319,9 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // concatenating. } return InvalidArgument( - "cannot concatenate arrays that differ in dimensions other than " + "Cannot concatenate arrays that differ in dimensions other than " "the one being concatenated (the other array dimensions must be " - "the same): %s vs %s in dimension %lld", + "the same): %s vs %s in dimension %lld.", ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::HumanString(*shape).c_str(), dimension); } @@ -446,13 +337,24 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(element_type, new_dimensions); } +/* static */ StatusOr ShapeInference::InferTokenShape( + tensorflow::gtl::ArraySlice arg_shapes) { + for (const Shape* arg_shape : arg_shapes) { + if (arg_shape->element_type() != TOKEN) { + return InvalidArgument( + "Operands of token instructions must be TOKEN types."); + } + } + return ShapeUtil::MakeTokenShape(); +} + /* static */ StatusOr ShapeInference::InferConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { auto old_element_type = operand_shape.element_type(); if (primitive_util::IsComplexType(old_element_type) && !primitive_util::IsComplexType(new_element_type)) { return Unimplemented( - "Unsupported conversion from complex to real type: %s => %s", + "Conversion from complex to real type %s => %s is not implemented.", ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } @@ -461,7 +363,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // future, by recursing into the tuple elements to check all sub-conversions // are valid. For now we just reject them, though. return InvalidArgument( - "cannot convert from or to tuple type; requested conversion: %s => %s", + "Convert does not allow tuples, so cannot convert from %s to %s.", ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } @@ -474,24 +376,23 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, auto old_element_type = operand_shape.element_type(); if (primitive_util::IsComplexType(old_element_type) != primitive_util::IsComplexType(new_element_type)) { - return Unimplemented( - "Unsupported conversion between real and complex types: %s => %s", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + return InvalidArgument("Conversion from complex to real type %s => %s.", + ShapeUtil::HumanString(operand_shape).c_str(), + PrimitiveType_Name(new_element_type).c_str()); } if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions // are valid. For now we just reject them, though. return InvalidArgument( - "cannot convert from or to tuple type; requested conversion: %s => %s", + "Cannot convert from or to tuple type; requested conversion: %s => %s.", ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } if (primitive_util::BitWidth(old_element_type) != primitive_util::BitWidth(new_element_type)) { return InvalidArgument( - "cannot bitcast types with different bit-widths: %s => %s", + "Cannot bitcast types with different bit-widths: %s => %s.", PrimitiveType_Name(old_element_type).c_str(), PrimitiveType_Name(new_element_type).c_str()); } @@ -504,20 +405,20 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, const int mantissa_bits) { if (!ShapeUtil::ElementIsFloating(operand_shape)) { return InvalidArgument( - "expected element type in shape to be floating point for " - "ReducePrecision operation; got %s", + "Expected element type in shape to be floating point for " + "ReducePrecision operation; got %s.", PrimitiveType_Name(operand_shape.element_type()).c_str()); } if (exponent_bits < 1) { // One exponent bit is necessary to distinguish 0 from infinity. Having // no exponent bits doesn't produce a sensible number, so we require at // least one. - return InvalidArgument("expected exponent_bits >= 1; got %d", + return InvalidArgument("Expected exponent_bits >= 1; got %d.", exponent_bits); } if (mantissa_bits < 0) { // A number with no mantissa bits is still meaningful, however. - return InvalidArgument("expected non-negative mantissa_bits; got %d", + return InvalidArgument("Expected non-negative mantissa_bits; got %d.", mantissa_bits); } return operand_shape; @@ -528,23 +429,23 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, const PaddingConfig& padding_config) { if (ShapeUtil::IsTuple(operand_shape)) { return InvalidArgument( - "pad operation does not support tuple-shape operands"); + "Pad operation does not support tuple-shape operands."); } if (!ShapeUtil::IsScalar(padding_value_shape)) { return InvalidArgument( - "pad operation does not support non-scalar padding values"); + "Pad operation does not support non-scalar padding values."); } if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) { return InvalidArgument( "The rank of the operand and the padding configuration do not match: " - "%s vs %s", + "%s vs %s.", ShapeUtil::HumanString(operand_shape).c_str(), padding_config.ShortDebugString().c_str()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, padding_value_shape)) { return InvalidArgument( - "the element types of the operands to pad do not match"); + "The element types of the operands to Pad do not match."); } std::vector dimensions(ShapeUtil::Rank(operand_shape)); for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { @@ -605,7 +506,7 @@ Status ValidateDotDimensionNumbers( lhs_batch_dimensions) || !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, rhs_batch_dimensions)) { - return InvalidArgument("A dimension number is out of range in dot: %s", + return InvalidArgument("A dimension number is out of range in Dot: %s.", dimension_numbers.DebugString().c_str()); } @@ -623,7 +524,7 @@ Status ValidateDotDimensionNumbers( if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) { - return InvalidArgument("A dimension number is not unique in dot: %s", + return InvalidArgument("A dimension number is not unique in Dot: %s.", dimension_numbers.DebugString().c_str()); } @@ -641,8 +542,7 @@ Status ValidateDotDimensionNumbers( rhs_non_contracting_non_batch_dims < 0 || rhs_non_contracting_non_batch_dims > 1) { return InvalidArgument( - "batch and contracting dimension number mismatch " - "with rank "); + "Batch and contracting dimension number mismatch with rank."); } // Check that batch dimension numbers are ordered before all others, and @@ -654,7 +554,7 @@ Status ValidateDotDimensionNumbers( !std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), rhs_batch_dimensions.begin())) { return InvalidArgument( - "batch dimension numbers must precede non-batch dimensions and be" + "Batch dimension numbers must precede non-batch dimensions and be" "monotonically increasing."); } @@ -671,22 +571,22 @@ Status ValidateDotDimensionNumbers( auto fail = [lhs, rhs](const string& addendum) -> Status { string message = tensorflow::strings::Printf( - "cannot infer shape for dot operation: %s %s", + "Cannot infer shape for dot operation: %s %s.", ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); if (!addendum.empty()) { - message += ": " + addendum; + message += " " + addendum; } return InvalidArgument("%s", message.c_str()); }; // Check if both element types are the same. if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { - return fail("element types do not match"); + return fail("Element types do not match."); } if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) { - return fail("dot only supports rank 1 or above."); + return fail("Dot only supports rank 1 or above."); } // Validate basic properties of dot dimension numbers. @@ -696,7 +596,7 @@ Status ValidateDotDimensionNumbers( if (dimension_numbers.lhs_contracting_dimensions_size() != dimension_numbers.rhs_contracting_dimensions_size() || dimension_numbers.lhs_contracting_dimensions_size() != 1) { - return fail("must specify one contracting dimension for both lhs and rhs."); + return fail("Must specify one contracting dimension for both lhs and rhs."); } // Check that contracting dimension sizes match. @@ -706,13 +606,13 @@ Status ValidateDotDimensionNumbers( dimension_numbers.rhs_contracting_dimensions(0); if (lhs.dimensions(lhs_contracting_dimension) != rhs.dimensions(rhs_contracting_dimension)) { - return fail("contracting dimension sizes do not match."); + return fail("Contracting dimension sizes do not match."); } // Check that number of batch dimensions match. if (dimension_numbers.lhs_batch_dimensions_size() != dimension_numbers.rhs_batch_dimensions_size()) { - return fail("must the same number of batch dimensions for lhs and rhs."); + return fail("Must the same number of batch dimensions for lhs and rhs."); } // Check that batch dimension numbers and sizes match. @@ -721,7 +621,7 @@ Status ValidateDotDimensionNumbers( dimension_numbers.rhs_batch_dimensions(i) || lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { - return fail("batch dimension numbers and sizes must match for lhs/rhs."); + return fail("Batch dimension numbers and sizes must match for lhs/rhs."); } } @@ -753,8 +653,9 @@ Status ValidateDotDimensionNumbers( } /* static */ StatusOr -ShapeInference::InferDegenerateDimensionBroadcastShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs) { +ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, + const Shape& lhs, + const Shape& rhs) { TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)); // The shapes have to be compatible. That is, if some dimension d has a @@ -770,10 +671,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } else if (rhs.dimensions(i) == 1) { output_dimensions[i] = lhs.dimensions(i); } else { - return InvalidArgument("binary op %s with incompatible shapes: %s and %s", - BinaryOperation_Name(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + return InvalidArgument( + "Binary op %s with incompatible shapes: %s and %s.", + HloOpcodeString(operation).c_str(), + ShapeUtil::HumanString(lhs).c_str(), + ShapeUtil::HumanString(rhs).c_str()); } } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), @@ -781,22 +683,21 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( - BinaryOperation operation, const Shape& smaller_shape, - const Shape& larger_shape, + const Shape& smaller_shape, const Shape& larger_shape, tensorflow::gtl::ArraySlice broadcast_dimensions) { if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) { // Reject "magic" inference for binops on different shapes, requiring // the user to provide an explicit broadcast dimension in this case. // See b/25177275 for more details. - return InvalidArgument("automatic shape inference not supported: %s and %s", + return InvalidArgument("Automatic shape inference not supported: %s and %s", ShapeUtil::HumanString(smaller_shape).c_str(), ShapeUtil::HumanString(larger_shape).c_str()); } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) { return InvalidArgument( - "size of broadcast_dimensions has to match lower-rank operand's " + "Size of broadcast_dimensions has to match lower-rank operand's " "rank; " " lower-rank operand's rank is %lld, size of broadcast_dimensions is " - "%zu", + "%zu.", ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size()); } @@ -846,13 +747,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( int64 dimension_to_match = broadcast_dimensions.at(i); if (dimension_to_match < 0) { return InvalidArgument( - "broadcast dimension number (%lld) cannot be negative", + "Broadcast dimension number (%lld) cannot be negative.", dimension_to_match); } if (dimension_to_match >= larger_shape.dimensions_size()) { return InvalidArgument( - "broadcast dimension number (%lld) too large; higher-rank " - "operand has rank %d", + "Broadcast dimension number (%lld) too large; higher-rank " + "operand has rank %d.", dimension_to_match, larger_shape.dimensions_size()); } int64 small_dimension_size = smaller_shape.dimensions(i); @@ -863,7 +764,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (small_dimension_size != large_dimension_size && small_dimension_size != 1 && large_dimension_size != 1) { return InvalidArgument( - "broadcast dimension %d mismatch: %lld != %lld; %s and %s", i, + "Broadcast dimension %d mismatch: %lld != %lld; %s and %s.", i, small_dimension_size, large_dimension_size, ShapeUtil::HumanString(smaller_shape).c_str(), ShapeUtil::HumanString(larger_shape).c_str()); @@ -872,7 +773,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // order. if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) { return InvalidArgument( - "broadcast dimensions order is wrong: %lld comes after %lld", + "Broadcast dimensions order is wrong: %lld comes after %lld.", dimension_to_match, broadcast_dimensions.at(i - 1)); } @@ -883,7 +784,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } /* static */ StatusOr ShapeInference::InferElementwiseBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, + HloOpcode operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { TF_RETURN_IF_ERROR( ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation")); @@ -892,9 +793,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( - "binary op %s with different element types: %s and %s", - BinaryOperation_Name(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), + "Binary op %s with different element types: %s and %s.", + HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } @@ -904,8 +804,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!broadcast_dimensions.empty() && broadcast_dimensions != identity_dims) { return InvalidArgument( - "broadcast dimensions field must either be not set or be the " - "identity on binary operations with operands of the same rank"); + "Broadcast dimensions field must either be not set or be the " + "identity on binary operations with operands of the same rank."); } } @@ -927,10 +827,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs; // After InDim broadcasting, perform degenerate dimensions broadcasting. - TF_ASSIGN_OR_RETURN( - Shape indim_broadcast_shape, - InferInDimBroadcastShape(operation, smaller_shape, larger_shape, - broadcast_dimensions)); + TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape, + InferInDimBroadcastShape(smaller_shape, larger_shape, + broadcast_dimensions)); return InferDegenerateDimensionBroadcastShape( operation, indim_broadcast_shape, larger_shape); @@ -939,110 +838,108 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) { - return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs->shape(), - rhs->shape(), /*broadcast_dimensions=*/{}); + return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(), + /*broadcast_dimensions=*/{}); } /* static */ StatusOr ShapeInference::InferBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, + HloOpcode opcode, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { VLOG(2) << tensorflow::strings::Printf( "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", - BinaryOperation_Name(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), + HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(), + ShapeUtil::HumanString(rhs).c_str(), Join(broadcast_dimensions, ", ").c_str()); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( lhs, tensorflow::strings::StrCat("lhs of binary operation ", - BinaryOperation_Name(operation)))); + HloOpcodeString(opcode)))); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( rhs, tensorflow::strings::StrCat("rhs of binary operation ", - BinaryOperation_Name(operation)))); - switch (operation) { - case BINOP_MAX: - case BINOP_MIN: - case BINOP_SUB: - case BINOP_ADD: - case BINOP_ATAN2: - case BINOP_POW: - case BINOP_DIV: - case BINOP_REM: - case BINOP_MUL: - case BINOP_SHIFT_LEFT: - case BINOP_SHIFT_RIGHT_ARITHMETIC: - case BINOP_SHIFT_RIGHT_LOGICAL: - return InferElementwiseBinaryOpShape(operation, lhs, rhs, + HloOpcodeString(opcode)))); + switch (opcode) { + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kSubtract: + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kPower: + case HloOpcode::kDivide: + case HloOpcode::kRemainder: + case HloOpcode::kMultiply: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); - case BINOP_COMPLEX: { + case HloOpcode::kComplex: { if (!ShapeUtil::ElementIsFloating(lhs)) { return InvalidArgument( - "expected element type in shape to be floating for complex compose " - "operation; got %s", + "Expected element type in shape to be floating for complex compose " + "operation; got %s.", PrimitiveType_Name(lhs.element_type()).c_str()); } TF_ASSIGN_OR_RETURN(const Shape& shape, - InferElementwiseBinaryOpShape(operation, lhs, rhs, + InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions)); if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); } else { - return Unimplemented("complex component type not supported"); + return Unimplemented("Complex component type is not implemented."); } } - case BINOP_AND: - case BINOP_OR: + case HloOpcode::kAnd: + case HloOpcode::kOr: if (lhs.element_type() != PRED && !primitive_util::IsIntegralType(lhs.element_type())) { return InvalidArgument( - "expected pred or integral type in argument to and/or operation; " - "got %s", + "Expected pred or integral type in argument to and/or operation; " + "got %s.", PrimitiveType_Name(lhs.element_type()).c_str()); } - return InferElementwiseBinaryOpShape(operation, lhs, rhs, + return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); - case BINOP_EQ: - case BINOP_GE: - case BINOP_GT: - case BINOP_LE: - case BINOP_LT: - case BINOP_NE: { + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kNe: { TF_ASSIGN_OR_RETURN(const Shape& shape, - InferElementwiseBinaryOpShape(operation, lhs, rhs, + InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions)); return ShapeUtil::ChangeElementType(shape, PRED); } default: return Unimplemented( - "not yet implemented; infer binary op shape: %s; lhs: %s; rhs: %s", - BinaryOperation_Name(operation).c_str(), - lhs.ShortDebugString().c_str(), rhs.ShortDebugString().c_str()); + "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.", + HloOpcodeString(opcode).c_str(), lhs.ShortDebugString().c_str(), + rhs.ShortDebugString().c_str()); } } /* static */ StatusOr ShapeInference::InferTernaryOpShape( HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs, const HloInstruction* ehs) { - return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs->shape(), - rhs->shape(), ehs->shape()); + return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape()); } /* static */ StatusOr ShapeInference::InferTernaryOpShape( - TernaryOperation operation, const Shape& lhs, const Shape& rhs, - const Shape& ehs) { + HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs)); - switch (operation) { - case TRIOP_CLAMP: + switch (opcode) { + case HloOpcode::kClamp: return InferClampShape(lhs, rhs, ehs); - case TRIOP_SELECT: + case HloOpcode::kSelect: return InferSelectShape(lhs, rhs, ehs); default: - return InvalidArgument("unknown operation %s", - TernaryOperation_Name(operation).c_str()); + return InvalidArgument("Unknown operation %s.", + HloOpcodeString(opcode).c_str()); } } @@ -1053,18 +950,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (const HloInstruction* operand : operands) { operand_shapes.push_back(&operand->shape()); } - return InferVariadicOpShape(OpcodeToVariadicOperation(opcode), - operand_shapes); + return InferVariadicOpShape(opcode, operand_shapes); } /* static */ StatusOr ShapeInference::InferVariadicOpShape( - VariadicOperation operation, + HloOpcode opcode, tensorflow::gtl::ArraySlice operand_shapes) { for (const Shape* shape : operand_shapes) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape)); } - switch (operation) { - case VAROP_TUPLE: { + switch (opcode) { + case HloOpcode::kTuple: { Shape result = ShapeUtil::MakeTupleShape({}); for (const Shape* shape : operand_shapes) { ShapeUtil::AppendShapeToTuple(*shape, &result); @@ -1072,8 +968,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return result; } default: - return InvalidArgument("unknown operation %s", - VariadicOperation_Name(operation).c_str()); + return InvalidArgument("Unknown operation %s.", + HloOpcodeString(opcode).c_str()); } } @@ -1082,7 +978,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const ProgramShape& to_apply, tensorflow::gtl::ArraySlice dimensions) { if (arg_shapes.empty()) { - return InvalidArgument("Map expects at least one argument"); + return InvalidArgument("Map expects at least one argument."); } // All arguments must have the same shape. @@ -1113,7 +1009,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } return InvalidArgument( "Map operation requires all operands to have the same shape; got: " - "%s", + "%s.", Join(pieces, ", ").c_str()); } @@ -1122,7 +1018,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (dimensions.size() != arg_shape->dimensions_size()) { return InvalidArgument( "Map applied to a subset of dimensions currently not supported: " - "arg_dimension_size: %d, requested_map_dimensions_size: %zu", + "arg_dimension_size: %d, requested_map_dimensions_size: %zu.", arg_shape->dimensions_size(), dimensions.size()); } @@ -1130,7 +1026,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (int i = 0; i < dimensions.size(); ++i) { if (dimensions[i] != i) { return InvalidArgument( - "Map requires monotonically increasing dimension numbers, found: %s ", + "Map requires monotonically increasing dimension numbers; got: %s.", Join(dimensions, ", ").c_str()); } } @@ -1139,7 +1035,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (arg_shapes.size() != to_apply.parameters_size()) { return InvalidArgument( "Map applied function arity must match number of arguments; got: " - "arity: %d, arguments: %zu", + "arity: %d, arguments: %zu.", to_apply.parameters_size(), arg_shapes.size()); } @@ -1147,8 +1043,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& output_shape = to_apply.result(); if (!ShapeUtil::IsScalar(output_shape)) { return InvalidArgument( - "mapped computation's result has to be a scalar; " - "got: %s", + "Mapped computation's result has to be a scalar; got: %s.", ShapeUtil::HumanString(output_shape).c_str()); } @@ -1157,16 +1052,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::IsScalar(parameter_shape)) { return InvalidArgument( - "mapped computation's parameter has to be a scalar; " - "got parameter %d shape: %s", + "Mapped computation's parameter has to be a scalar; " + "got parameter %d shape: %s.", i, ShapeUtil::HumanString(parameter_shape).c_str()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape, *arg_shape)) { return InvalidArgument( - "mapped computation's parameter type has to match argument element " - "type; got parameter %d shape: %s, argument shape: %s", + "Mapped computation's parameter type has to match argument element " + "type; got parameter %d shape: %s, argument shape: %s.", i, ShapeUtil::HumanString(parameter_shape).c_str(), ShapeUtil::HumanString(*arg_shape).c_str()); } @@ -1187,31 +1082,31 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( scale_shape, "scale input of batch norm training")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == - tensorflow::Status::OK()); + Status::OK()); if (feature_index >= ShapeUtil::Rank(operand_shape)) { return InvalidArgument( "Expected feature_index of batch-norm-training to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld", + "got feature_index %lld, and rank %lld.", feature_index, ShapeUtil::Rank(operand_shape)); } if (feature_index < 0) { return InvalidArgument( "Expected feature_index of batch-norm-training to " - "be a non-negative number, got %lld", + "be a non-negative number, got %lld.", feature_index); } if (ShapeUtil::Rank(operand_shape) < 1) { return InvalidArgument( "Expected the rank of operand to " - "batch-norm-training to be at least 1; got %lld", + "batch-norm-training to be at least 1; got %lld.", ShapeUtil::Rank(operand_shape)); } @@ -1232,7 +1127,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::ElementIsFloating(operand_shape)) { return InvalidArgument( "The operand to batch-norm-training must have a floating point " - "element type, but the shape is %s", + "element type, but the shape is %s.", PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1241,7 +1136,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-training, " "but the shape of offset factor is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(offset_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1251,7 +1146,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-training, " "but the shape of scale factor is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(scale_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1264,7 +1159,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of offset factor should be the same as feature count," "but the size of offset factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } @@ -1272,7 +1167,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of scale factor should be the same as feature count," "but the size of scale factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } @@ -1293,35 +1188,35 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( scale_shape, "scale input of batch norm inference")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) == - tensorflow::Status::OK()); + Status::OK()); if (feature_index >= ShapeUtil::Rank(operand_shape)) { return InvalidArgument( "Expected feature_index of batch-norm-inference to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld", + "got feature_index %lld, and rank %lld.", feature_index, ShapeUtil::Rank(operand_shape)); } if (feature_index < 0) { return InvalidArgument( "Expected feature_index of batch-norm-inference to " - "be a non-negative number, got %lld", + "be a non-negative number, got %lld.", feature_index); } if (ShapeUtil::Rank(operand_shape) < 1) { return InvalidArgument( "Expected the rank of operand to " - "batch-norm-inference to be at least 1; got %lld", + "batch-norm-inference to be at least 1; got %lld.", ShapeUtil::Rank(operand_shape)); } @@ -1342,7 +1237,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::ElementIsFloating(operand_shape)) { return InvalidArgument( "The operand to batch-norm-inference must have a floating point " - "element type, but the shape is %s", + "element type, but the shape is %s.", PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1352,7 +1247,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "The inputs should have the same element type for " "batch-norm-inference, " "but the shape of offset factor is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(offset_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1363,7 +1258,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "The inputs should have the same element type for " "batch-norm-inference, " "but the shape of scale factor is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(scale_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1374,7 +1269,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "The inputs should have the same element type for " "batch-norm-inference, " "but the shape of mean is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(mean_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1385,7 +1280,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "The inputs should have the same element type for " "batch-norm-inference, " "but the shape of variance is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(mean_shape.element_type()).c_str(), PrimitiveType_Name(variance_shape.element_type()).c_str()); } @@ -1398,7 +1293,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of offset factor should be the same as feature count," "but the size of offset factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } @@ -1406,7 +1301,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of scale factor should be the same as feature count," "but the size of scale factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } @@ -1414,7 +1309,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of mean should be the same as feature count," "but the size of mean is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } @@ -1422,7 +1317,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of variance should be the same as feature count," "but the size of variance is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(variance_shape, 0), feature_count); } @@ -1455,7 +1350,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Expected feature_index of batch-norm-grad to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld", + "got feature_index %lld, and rank %lld.", feature_index, ShapeUtil::Rank(operand_shape)); } @@ -1463,7 +1358,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Expected operand_shape of batch-norm-grad to have the same rank as" " output_grad_shape; got rank(oprand_shape) %lld, and" - " rank(output_grad_shape) %lld", + " rank(output_grad_shape) %lld.", ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape)); } @@ -1491,14 +1386,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::ElementIsFloating(operand_shape)) { return InvalidArgument( "The operand to batch-norm-grad must have a floating point " - "element type, but the shape is %s", + "element type, but the shape is %s.", PrimitiveType_Name(operand_shape.element_type()).c_str()); } if (!ShapeUtil::ElementIsFloating(output_grad_shape)) { return InvalidArgument( "The output_grad to batch-norm-grad must have a floating point " - "element type, but the shape is %s", + "element type, but the shape is %s.", PrimitiveType_Name(output_grad_shape.element_type()).c_str()); } @@ -1507,7 +1402,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of output_grad is %s " - "and the element type of operand is %s", + "and the element type of operand is %s.", PrimitiveType_Name(output_grad_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1517,7 +1412,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of scale factor is %s " - "and the element type of operand is %s", + "and the element type of operand is %s.", PrimitiveType_Name(scale_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1527,7 +1422,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " - "and the element type of operand is %s", + "and the element type of operand is %s.", PrimitiveType_Name(mean_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1537,7 +1432,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " - "and the element type of operand is %s", + "and the element type of operand is %s.", PrimitiveType_Name(mean_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1551,7 +1446,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of mean should be the same as feature count," "but the size of offset factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } @@ -1559,7 +1454,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of scale factor should be the same as feature count," "but the size of scale factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } @@ -1567,7 +1462,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of variance should be the same as feature count," "but the size of variance is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(var_shape, 0), feature_count); } @@ -1578,7 +1473,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The bounds of operand shape should be the same as output_grad's," "but the bound of operand_shape at dimension %lld is %lld " - "and the bound of output_grad_shape is %lld", + "and the bound of output_grad_shape is %lld.", i, ShapeUtil::GetDimension(operand_shape, i), ShapeUtil::GetDimension(output_grad_shape, i)); } @@ -1596,7 +1491,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( - "Convolution with different element types: %s and %s", + "Convolution with different element types: %s and %s.", ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } @@ -1612,21 +1507,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (window.dimensions_size() != num_spatial_dims) { return InvalidArgument( "Window must have same number of dimensions as dimension numbers.\n" - "Window: %s\nDimension numbers: %s", + "Window: %s\nDimension numbers: %s.", window.DebugString().c_str(), dnums.DebugString().c_str()); } const int num_dims = num_spatial_dims + 2; if (ShapeUtil::Rank(lhs) != num_dims) { return InvalidArgument( - "The LHS argument to a convolution should have rank %d.\n" - "lhs: %s", + "The LHS argument to a convolution should have rank %d; lhs: %s.", num_dims, ShapeUtil::HumanString(lhs).c_str()); } if (ShapeUtil::Rank(rhs) != num_dims) { return InvalidArgument( - "The RHS argument to a convolution should have rank %d.\n" - "lhs: %s", + "The RHS argument to a convolution should have rank %d; lhs: %s.", num_dims, ShapeUtil::HumanString(lhs).c_str()); } TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); @@ -1663,26 +1556,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( !std::all_of(window_dnums.begin(), window_dnums.end(), in_range) || !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) { return InvalidArgument( - "A dimension number is out of range in convolution: %s", + "A dimension number is out of range in convolution: %s.", dnums.DebugString().c_str()); } if (input_dnums != expected_dnums) { return InvalidArgument( "Input dimensions of convolution must contain each dimension exactly " - "once: %s", + "once: %s.", dnums.DebugString().c_str()); } if (window_dnums != expected_dnums) { return InvalidArgument( "Window dimensions of convolution must contain each dimension exactly " - "once: %s", + "once: %s.", dnums.DebugString().c_str()); } if (output_dnums != expected_dnums) { return InvalidArgument( "Output dimensions of convolution must contain each dimension exactly " - "once: %s", + "once: %s.", dnums.DebugString().c_str()); } @@ -1706,7 +1599,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Expected LHS feature dimension (value %lld) to match RHS " "input feature dimension (value %lld); got (%s, %s)\n" - "Dimension numbers: {%s}", + "Dimension numbers: {%s}.", input_features, kernel_input_features, ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str()); @@ -1720,7 +1613,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "Window dimensions do not match RHS shape:\n\t" "RHS shape: %s\n\t" "Window: {%s}\n\t" - "Dimension numbers: {%s}", + "Dimension numbers: {%s}.", ShapeUtil::HumanString(rhs).c_str(), window.ShortDebugString().c_str(), dnums.ShortDebugString().c_str()); } @@ -1748,8 +1641,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const tensorflow::gtl::ArraySlice fft_length) { const int64 fft_rank = fft_length.size(); if (fft_rank < 1 || fft_rank > 3) { - return InvalidArgument("FFT only supports ranks 1-3, but got %lld", - fft_rank); + return InvalidArgument("FFT only supports ranks 1-3; got %lld.", fft_rank); } #define RET_CHECK_RANK(x) \ if (x.dimensions_size() < fft_rank) { \ @@ -1762,7 +1654,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( case FFT: case IFFT: if (in.element_type() != C64) { - return InvalidArgument("%s requires C64 input type, found %s", + return InvalidArgument("%s requires C64 input type, found %s.", FftType_Name(fft_type).c_str(), PrimitiveType_Name(in.element_type()).c_str()); } @@ -1770,7 +1662,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return in; case RFFT: { if (in.element_type() != F32) { - return InvalidArgument("RFFT requires F32 input type, found %s", + return InvalidArgument("RFFT requires F32 input type, found %s.", PrimitiveType_Name(in.element_type()).c_str()); } RET_CHECK_RANK(in); @@ -1779,7 +1671,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( fft_length[i]) { return InvalidArgument( "RFFT requires innermost dimensions match fft_length but " - "dimension %lld is %lld and should be %lld", + "dimension %lld is %lld and should be %lld.", in.dimensions_size() - fft_rank + i, in.dimensions(in.dimensions_size() - fft_rank + i), fft_length[i]); @@ -1792,7 +1684,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } case IRFFT: { if (in.element_type() != C64) { - return InvalidArgument("IRFFT requires C64 input type, found %s", + return InvalidArgument("IRFFT requires C64 input type, found %s.", PrimitiveType_Name(in.element_type()).c_str()); } RET_CHECK_RANK(in); @@ -1802,7 +1694,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( fft_length[i]) { return InvalidArgument( "IRFFT requires all but one innermost dimensions match " - "fft_length, but dimension %lld is %lld and should be %lld", + "fft_length, but dimension %lld is %lld and should be %lld.", in.dimensions_size() - fft_rank + i, in.dimensions(in.dimensions_size() - fft_rank + i), fft_length[i]); @@ -1812,7 +1704,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( fft_length[fft_rank - 1] / 2 + 1) { return InvalidArgument( "IRFFT requires innermost dimension matches fft_length/2+1, but " - "dimension %d is %lld and should be %lld", + "dimension %d is %lld and should be %lld.", in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1), fft_length[fft_rank - 1] / 2 + 1); } @@ -1850,8 +1742,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (int64 dimension : dimensions_to_reduce) { if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { return InvalidArgument( - "attempting to reduce out-of-bounds dimension %lld in shape %s", - dimension, ShapeUtil::HumanString(arg).c_str()); + "Reducing out-of-bounds dimension %lld in shape %s.", dimension, + ShapeUtil::HumanString(arg).c_str()); } } TF_RETURN_IF_ERROR( @@ -1891,30 +1783,30 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // Check if the select function has a proper shape of (T,T) -> PRED. if (select_shape.parameters_size() != 2) { return InvalidArgument( - "select function must take 2 parameters, but " + "Select function must take 2 parameters, but " "takes %d parameter(s).", select_shape.parameters_size()); } const Shape& select_result_shape = select_shape.result(); if (!ShapeUtil::Compatible(select_result_shape, ShapeUtil::MakeShape(PRED, {}))) { - return Unimplemented("select function must have rank-0 PRED result."); + return InvalidArgument("Select function must have rank-0 PRED result."); } const Shape& operand_element_shape = ShapeUtil::MakeShape(operand_shape.element_type(), {}); if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, select_shape.parameters(0))) { return InvalidArgument( - "select function's first parameter shape currently must " - "match the operand element shape. Got %s vs %s", + "Select function's first parameter shape currently must " + "match the operand element shape, but got %s vs %s.", ShapeUtil::HumanString(select_shape.parameters(0)).c_str(), ShapeUtil::HumanString(operand_element_shape).c_str()); } if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, select_shape.parameters(1))) { return InvalidArgument( - "select function's second parameter shape currently must " - "match the operand element shape. Got %s vs %s", + "Select function's second parameter shape currently must " + "match the operand element shape, but got %s vs %s.", ShapeUtil::HumanString(select_shape.parameters(1)).c_str(), ShapeUtil::HumanString(operand_element_shape).c_str()); } @@ -1931,8 +1823,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape, window_result_shape)) { return InvalidArgument( - "source shape does not match the shape of window-reduced operand: " - "source(%s), window-reduced operand(%s)", + "Source shape does not match the shape of window-reduced operand: " + "source(%s), window-reduced operand(%s).", ShapeUtil::HumanString(source_shape).c_str(), ShapeUtil::HumanString(window_result_shape).c_str()); } @@ -1946,7 +1838,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( auto error = [&](const string& message) { return InvalidArgument( "%s in slice operation; argument shape: %s; starts: {%s}; limits: " - "{%s}; strides: {%s}", + "{%s}; strides: {%s}.", message.c_str(), ShapeUtil::HumanString(arg).c_str(), Join(starts, ",").c_str(), Join(limits, ",").c_str(), Join(strides, ",").c_str()); @@ -1969,7 +1861,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (starts.size() != ShapeUtil::Rank(arg)) { return InvalidArgument( - "slice index count does not match argument rank: %zu vs %lld", + "Slice index count does not match argument rank: %zu vs %lld.", starts.size(), ShapeUtil::Rank(arg)); } @@ -1979,7 +1871,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( int64 limit_index = limits[dimension]; int64 stride = strides[dimension]; if (start_index < 0) { - return InvalidArgument("negative start index to slice: %lld", + return InvalidArgument("Negative start index to slice: %lld.", start_index); } if (limit_index > arg.dimensions(dimension)) { @@ -1999,7 +1891,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( limit_index, start_index)); } if (stride <= 0) { - return InvalidArgument("stride (%lld) must be positive", stride); + return InvalidArgument("Stride (%lld) must be positive.", stride); } sizes.push_back((limit_index - start_index + stride - 1) / stride); } @@ -2023,20 +1915,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( - "dynamic slice start indices of rank %lld must be rank1.", + "Dynamic slice start indices of rank %lld must be rank1.", ShapeUtil::Rank(start_indices_shape)); } if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { return InvalidArgument( - "dynamic slice start indices must be of integral type."); + "Dynamic slice start indices must be of integral type."); } const int64 start_num_dims = start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "dynamic slice start number of dimensions %lld (%s) must match rank " - "%lld of slice input (%s)", + "Dynamic slice start number of dimensions %lld (%s) must match rank " + "%lld of slice input (%s).", start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape).c_str()); @@ -2044,7 +1936,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) { return InvalidArgument( - "dynamic slice index count does not match argument rank: %zu vs %lld", + "Dynamic slice index count does not match argument rank: %zu vs %lld.", slice_sizes.size(), ShapeUtil::Rank(operand_shape)); } @@ -2052,12 +1944,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const int64 input_dim_size = operand_shape.dimensions(dim); const int64 slice_dim_size = slice_sizes[dim]; if (slice_dim_size < 0) { - return InvalidArgument("negative size index to dynamic slice: %lld", + return InvalidArgument("Negative size index to dynamic slice: %lld.", slice_dim_size); } if (slice_dim_size > input_dim_size) { return InvalidArgument( - "slice dim size %lld greater than dynamic slice dimension: %lld", + "Slice dim size %lld greater than dynamic slice dimension: %lld.", slice_dim_size, input_dim_size); } VLOG(2) << tensorflow::strings::Printf("slice_sizes[%lld] = %lld", dim, @@ -2086,20 +1978,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( - "dynamic update slice start indices of rank %lld must be rank1.", + "Dynamic update slice start indices of rank %lld must be rank1.", ShapeUtil::Rank(start_indices_shape)); } if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { return InvalidArgument( - "dynamic update slice start indices must be of integral type."); + "Dynamic update slice start indices must be of integral type."); } const int64 start_num_dims = start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "dynamic slice start number of dimensions %lld (%s) must match rank " - "%lld of slice input (%s)", + "Dynamic update slice start number of dimensions %lld (%s) must match " + "rank %lld of slice input (%s).", start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape).c_str()); @@ -2107,16 +1999,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) { return InvalidArgument( - "dynamic update slice update rank does not match argument rank: " - "%lld vs %lld", + "Dynamic update slice update rank does not match argument rank: " + "%lld vs %lld.", ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, update_shape)) { return InvalidArgument( - "dynamic update slice update element type does not match argument. " - "operand.element_type: %s vs update.element_type: %s", + "Dynamic update slice update element type does not match argument. " + "operand.element_type: %s vs update.element_type: %s.", PrimitiveType_Name(operand_shape.element_type()).c_str(), PrimitiveType_Name(update_shape.element_type()).c_str()); } @@ -2126,12 +2018,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const int64 update_dim_size = update_shape.dimensions(dim); if (update_dim_size < 0) { return InvalidArgument( - "size index %lld to dynamic update slice must be >= 0", + "Size index %lld to dynamic update slice must be >= 0.", update_dim_size); } if (update_dim_size > input_dim_size) { return InvalidArgument( - "update dim size %lld greater than dynamic slice dimension: %lld", + "Update dim size %lld greater than dynamic slice dimension: %lld.", update_dim_size, input_dim_size); } VLOG(2) << tensorflow::strings::Printf("update_sizes[%lld] = %lld", dim, @@ -2151,7 +2043,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (int64 dimension : dimensions) { if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) { return InvalidArgument( - "one of the reverse dimensions (%lld) is out-of-bounds in shape %s", + "One of the reverse dimensions (%lld) is out-of-bounds in shape %s.", dimension, ShapeUtil::HumanString(operand_shape).c_str()); } } @@ -2162,14 +2054,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& arg, int64 index) { if (!ShapeUtil::IsTuple(arg)) { return InvalidArgument( - "cannot infer shape: attempting to index into non-tuple: %s", + "Cannot infer shape: attempting to index into non-tuple: %s.", ShapeUtil::HumanString(arg).c_str()); } if (index >= arg.tuple_shapes_size()) { return InvalidArgument( - "cannot infer shape: attempt to index out of tuple bounds: %lld " - ">= %d in shape %s", + "Cannot infer shape: attempt to index out of tuple bounds: %lld " + ">= %d in shape %s.", index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg).c_str()); } @@ -2181,17 +2073,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& init) { // Check the number of parameters for given computations. if (condition.parameters_size() != 1) { - return InvalidArgument("condition must take 1 arguments; got %d", + return InvalidArgument("Condition must take 1 arguments; got %d.", condition.parameters_size()); } if (body.parameters_size() != 1) { - return InvalidArgument("body must take 1 arguments; got %d", + return InvalidArgument("Body must take 1 arguments; got %d.", body.parameters_size()); } auto shape_string = [&]() { return tensorflow::strings::Printf( - "condition: %s; body: %s; init: %s", + "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition).c_str(), ShapeUtil::HumanString(body).c_str(), ShapeUtil::HumanString(init).c_str()); @@ -2199,15 +2091,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // Check the shapes of computation parameters and return types. if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) { - return InvalidArgument("condition must return a boolean; got %s", + return InvalidArgument("Condition must return a boolean; got %s.", shape_string().c_str()); } if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) || !ShapeUtil::Compatible(body.result(), body.parameters(0)) || !ShapeUtil::Compatible(body.result(), init)) { return InvalidArgument( - "the parameter of condition and body, the result of the body, and init " - "must all have the same shape; got %s", + "The parameter of condition and body, the result of the body, and init " + "must all have the same shape; got %s.", shape_string().c_str()); } @@ -2219,7 +2111,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& false_operand, const ProgramShape& true_computation, const ProgramShape& false_computation) { if (!ShapeUtil::ShapeIs(predicate, PRED, {})) { - return InvalidArgument("predicate must be a boolean; got %s.", + return InvalidArgument("Predicate must be a boolean; got %s.", ShapeUtil::HumanString(predicate).c_str()); } @@ -2302,8 +2194,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { return InvalidArgument( - "reshape operation has mismatched element counts: from=%lld (%s) " - "to=%lld (%s)", + "Reshape operation has mismatched element counts: from=%lld (%s) " + "to=%lld (%s).", ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(), ShapeUtil::ElementsIn(inferred_shape), ShapeUtil::HumanString(inferred_shape).c_str()); @@ -2351,7 +2243,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { - return InvalidArgument("clamp op with different operand types: %s, %s, %s", + return InvalidArgument("Clamp with different operand types: %s, %s, %s.", ShapeUtil::HumanString(min).c_str(), ShapeUtil::HumanString(operand).c_str(), ShapeUtil::HumanString(max).c_str()); @@ -2372,7 +2264,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } } return Unimplemented( - "not yet implemented: %s, %s %s", min.ShortDebugString().c_str(), + "%s, %s %s is not implemented.", min.ShortDebugString().c_str(), max.ShortDebugString().c_str(), operand.ShortDebugString().c_str()); } @@ -2391,25 +2283,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } if (!compatible) { return InvalidArgument( - "operands to select must be the same shape; got %s and %s", + "Operands to select must be the same shape; got %s and %s.", ShapeUtil::HumanString(on_true).c_str(), ShapeUtil::HumanString(on_false).c_str()); } if (pred.element_type() != PRED) { return InvalidArgument( - "select's pred operand must have PRED element type; got %s", + "Select's pred operand must have PRED element type; got %s.", ShapeUtil::HumanString(pred).c_str()); } - if (ShapeUtil::SameDimensions(pred, on_true) || ShapeUtil::Rank(pred) == 0) { + if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || + ShapeUtil::Rank(pred) == 0) { // By this stage we know that pred's element type is PRED. Therefore, this // check restricts pred to be a PRED scalar, or a PRED array with the same // dimensions as on_true and on_false. return ShapeUtil::ChangeElementType( on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false)); } else { - return Unimplemented( - "select operation with non-scalar predicate with dimensionality " - " different from the other operands: %s", + return InvalidArgument( + "Select operation with non-scalar predicate with dimensionality " + " different from the other operands: %s.", ShapeUtil::HumanString(pred).c_str()); } } @@ -2427,7 +2320,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Call applied function arity must match number of arguments; got: " "arity: %d, arguments: %zu; computation signature: %s; argument " - "shapes: [%s]", + "shapes: [%s].", to_apply.parameters_size(), arg_shapes.size(), computation_signature.c_str(), argument_shapes.c_str()); } @@ -2439,7 +2332,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::Compatible(arg_shape, param_shape)) { return InvalidArgument( "Call parameter must match argument; got parameter %d shape: %s, " - "argument shape: %s", + "argument shape: %s.", i, ShapeUtil::HumanString(param_shape).c_str(), ShapeUtil::HumanString(arg_shape).c_str()); } @@ -2454,40 +2347,40 @@ static Status ValidateGatherDimensionNumbers( const GatherDimensionNumbers& dim_numbers) { if (!c_is_sorted(dim_numbers.output_window_dims())) { return InvalidArgument( - "Output window dimensions in gather op must be ascending; got: %s", + "Output window dimensions in gather op must be ascending; got: %s.", Join(dim_numbers.output_window_dims(), ", ").c_str()); } if (c_adjacent_find(dim_numbers.output_window_dims()) != dim_numbers.output_window_dims().end()) { return InvalidArgument( - "Output window dimensions in gather op must not repeat; got: %s", + "Output window dimensions in gather op must not repeat; got: %s.", Join(dim_numbers.output_window_dims(), ", ").c_str()); } const int64 output_window_dim_count = dim_numbers.output_window_dims_size(); const int64 output_shape_rank = - output_window_dim_count + gather_indices_shape.size(); + output_window_dim_count + gather_indices_shape.size() - 1; for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) { int64 window_index = dim_numbers.output_window_dims(i); if (window_index < 0 || window_index >= output_shape_rank) { return InvalidArgument( "Window index %d in gather op is out of bounds; got %lld, but should " - "have been in" - "[0,%lld)", + "have been in [0,%lld).", i, window_index, output_shape_rank); } } if (dim_numbers.gather_dims_to_operand_dims_size() != - gather_indices_shape.back()) { + gather_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( - "There must be exactly as many elements in gather_dims_to_operand_dims " - "as there are elements in the last dimension of %%gather_indices; got: " - "%d, expected %lld", + "Gather op has %d elements in gather_dims_to_operand_dims and the " + "bound of dimension index_vector_dim=%lld of gather_indices is " + "%lld. These two numbers must be equal.", dim_numbers.gather_dims_to_operand_dims_size(), - gather_indices_shape.back()); + dim_numbers.index_vector_dim(), + gather_indices_shape[dim_numbers.index_vector_dim()]); } for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) { @@ -2496,7 +2389,7 @@ static Status ValidateGatherDimensionNumbers( gather_dim_to_input_dim >= input_shape.dimensions_size()) { return InvalidArgument( "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), " - "got: %d->%lld", + "got: %d->%lld.", input_shape.dimensions_size(), i, gather_dim_to_input_dim); } } @@ -2511,7 +2404,7 @@ static Status ValidateGatherDimensionNumbers( sorted_gather_dims_to_operand_dims.end()) { return InvalidArgument( "Repeated dimensions are not allowed in gather_dims_to_operand_dims; " - "got: %s", + "got: %s.", Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str()); } @@ -2519,7 +2412,7 @@ static Status ValidateGatherDimensionNumbers( if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) { return InvalidArgument( "Invalid elided_window_dims set in gather op; valid range is [0, " - "%d), got: %lld", + "%d), got: %lld.", input_shape.dimensions_size(), elided_dim); } } @@ -2534,7 +2427,7 @@ static Status ValidateGatherDimensionNumbers( dim_numbers.elided_window_dims().end()) { return InvalidArgument( "Repeated dimensions not allowed in elided_window_dims in gather op; " - "got: %s", + "got: %s.", Join(dim_numbers.elided_window_dims(), ", ").c_str()); } @@ -2550,24 +2443,33 @@ static Status ValidateGatherDimensionNumbers( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( gather_indices_shape, "gather indices operand of gather op")); - if (gather_indices_shape.dimensions_size() < 1) { + if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { return InvalidArgument( - "Gather indices parameter must at least of rank 1; got %s", + "Gather indices parameter must be an integral tensor; got %s.", ShapeUtil::HumanString(gather_indices_shape).c_str()); } - if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { + // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if + // index_vector_dim is rank(P). The bounds of this expanded shape is + // stored in expanded_gather_indices_shape. + + if (gather_indices_shape.dimensions_size() < + gather_dim_numbers.index_vector_dim() || + gather_dim_numbers.index_vector_dim() < 0) { return InvalidArgument( - "Gather indices parameter must be an integral tensor; got %s", - ShapeUtil::HumanString(gather_indices_shape).c_str()); + "Gather index leaf dimension must be within [0, rank(gather_indices) + " + "1). rank(gather_indices) is %d and gather index leaf dimension is " + "%lld.", + gather_indices_shape.dimensions_size(), + gather_dim_numbers.index_vector_dim()); } std::vector expanded_gather_indices_shape; - // We implicitly reshape gather indices of shape P[N] to P[N,1]. expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size()); c_copy(gather_indices_shape.dimensions(), std::back_inserter(expanded_gather_indices_shape)); - if (expanded_gather_indices_shape.size() == 1) { + if (expanded_gather_indices_shape.size() == + gather_dim_numbers.index_vector_dim()) { expanded_gather_indices_shape.push_back(1); } @@ -2577,7 +2479,7 @@ static Status ValidateGatherDimensionNumbers( if (window_bounds.size() != input_shape.dimensions_size()) { return InvalidArgument( "Gather op must have one window bound for every input dimension; got: " - "len(window_bounds)=%lu, input_shape.rank=%d", + "len(window_bounds)=%lu, input_shape.rank=%d.", window_bounds.size(), input_shape.dimensions_size()); } @@ -2587,7 +2489,7 @@ static Status ValidateGatherDimensionNumbers( return InvalidArgument( "All components of the window index in a gather op must either be a " "output window index or explicitly elided; got len(window_bounds)=%lu, " - "output_window_bounds=%s, elided_window_bounds=%s", + "output_window_bounds=%s, elided_window_bounds=%s.", window_bounds.size(), Join(gather_dim_numbers.output_window_dims(), ",").c_str(), Join(gather_dim_numbers.elided_window_dims(), ",").c_str()); @@ -2600,7 +2502,7 @@ static Status ValidateGatherDimensionNumbers( return InvalidArgument( "Window bound at index %d in gather op is out of range, must be " "within " - "[0, %lld), got %lld", + "[0, %lld), got %lld.", i, corresponding_input_bound + 1, window_bound); } } @@ -2609,7 +2511,7 @@ static Status ValidateGatherDimensionNumbers( if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) { return InvalidArgument( "Gather op can only elide window indices with bound 1, but bound is " - "%lld for index %lld at position %d", + "%lld for index %lld at position %d.", window_bounds[gather_dim_numbers.elided_window_dims(i)], gather_dim_numbers.elided_window_dims(i), i); } @@ -2632,6 +2534,9 @@ static Status ValidateGatherDimensionNumbers( } current_bound = window_bounds[window_dims_seen++]; } else { + if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) { + gather_dims_seen++; + } current_bound = expanded_gather_indices_shape[gather_dims_seen++]; } diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 0d3045213db2230da3e18ffcb1a9923250560b64..f1f7b50902d899c0c629c3098d80fc400fb1388d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -46,15 +46,15 @@ class ShapeInference { public: // Infers the shape produced by applying the given unary operation to the // given input shape. - static StatusOr InferUnaryOpShape(UnaryOperation operation, - const Shape& arg); + static StatusOr InferUnaryOpShape(HloOpcode opcode, + const Shape& shape); static StatusOr InferUnaryOpShape(HloOpcode opcode, const HloInstruction* operand); // Infers the shape produced by applying the given binary operation to the // given input shapes. static StatusOr InferBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, + HloOpcode opcode, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); static StatusOr InferBinaryOpShape(HloOpcode opcode, const HloInstruction* lhs, @@ -62,8 +62,8 @@ class ShapeInference { // Infers the shape produced by applying the given ternary operation to the // given input shapes. - static StatusOr InferTernaryOpShape(TernaryOperation operation, - const Shape& lhs, const Shape& rhs, + static StatusOr InferTernaryOpShape(HloOpcode opcode, const Shape& lhs, + const Shape& rhs, const Shape& ehs); static StatusOr InferTernaryOpShape(HloOpcode opcode, const HloInstruction* lhs, @@ -73,7 +73,7 @@ class ShapeInference { // Infers the shape produced by applying the given variadic operation to the // given input operand shapes. static StatusOr InferVariadicOpShape( - VariadicOperation operation, + HloOpcode opcode, tensorflow::gtl::ArraySlice operand_shapes); static StatusOr InferVariadicOpShape( HloOpcode opcode, @@ -216,6 +216,13 @@ class ShapeInference { static StatusOr InferConcatOpShape( tensorflow::gtl::ArraySlice arg_shapes, int64 dimension); + // Infers the shape produced by a kGenerateToken operation. Trivially this + // shape is always a TOKEN shape. However, ShapeInference serves two purposes: + // inferring shapes and checking operand shapes. This method verifies that the + // operand shapes are all TOKENs. + static StatusOr InferTokenShape( + tensorflow::gtl::ArraySlice arg_shapes); + // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that // the shape is identical except for the element type. @@ -268,7 +275,7 @@ class ShapeInference { // the LHS and a single element in the RHS to produce a single output element, // even in the presence of broadcasting of one of the operands over the other. static StatusOr InferElementwiseBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, + HloOpcode operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); // Helper for inferring the shape of Clamp ops. @@ -284,7 +291,7 @@ class ShapeInference { // dimension broadcasting (a dimension of size 1 in one operand is broadcast // up to match the size of the dimension in the other operand). static StatusOr InferDegenerateDimensionBroadcastShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs); + HloOpcode operation, const Shape& lhs, const Shape& rhs); // Helper for inferring shapes of binary operations using "InDim" // broadcasting. This is the broadcasting used in the *InDim binary operations @@ -292,8 +299,7 @@ class ShapeInference { // lower-rank shape than larger_shape. Returns the shape that the // smaller_shape is broadcast to. static StatusOr InferInDimBroadcastShape( - BinaryOperation operation, const Shape& smaller_shape, - const Shape& larger_shape, + const Shape& smaller_shape, const Shape& larger_shape, tensorflow::gtl::ArraySlice broadcast_dimensions); TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference); diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 7eb120843fd841d841048eeaefd895fde96d133c..6d017dffe2d8f927abad4a62bff7fe41bc871975 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -101,8 +101,8 @@ class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest { TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - auto inferred_status = ShapeInference::InferUnaryOpShape( - UnaryOperation::UNOP_NEGATE, matrix_shape); + auto inferred_status = + ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie())); } @@ -110,14 +110,14 @@ TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) { Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_}); auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, tuple, tuple); + HloOpcode::kSelect, pred_, tuple, tuple); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(tuple, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } @@ -125,34 +125,34 @@ TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) { auto predarray = ShapeUtil::MakeShape(PRED, {64, 48}); auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, predarray, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, predarray, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, SelectBadShapes) { auto inferred_status_error1 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_); + HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - HasSubstr("operands to select must be the same shape")); + HasSubstr("Operands to select must be the same shape")); auto inferred_status_error2 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), HasSubstr("pred operand must have PRED")); auto inferred_status_error3 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}), - matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_, + matrix_64_48_); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("with non-scalar predicate with dimensionality")); // Tuples have a TUPLE element type and cannot be the pred of a select. auto inferred_status_error4 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeTupleShape({pred_, pred_}), + HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}), ShapeUtil::MakeTupleShape({f32_, f32_}), ShapeUtil::MakeTupleShape({f32_, f32_})); ASSERT_FALSE(inferred_status_error4.ok()); @@ -162,102 +162,98 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { TEST_F(ShapeInferenceTest, ClampAllMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, - matrix_64_48_); + HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampAllScalar) { - auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_); + auto inferred_status = + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMinScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_); + HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMaxScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_); + HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampOperandScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_); + HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMinMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_); + HloOpcode::kClamp, matrix_64_48_, f32_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMaxMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_); + HloOpcode::kClamp, f32_, f32_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampOperandMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_); + HloOpcode::kClamp, f32_, matrix_64_48_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampBadShapes) { // Type mismatch - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_) - .ok()); - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_) - .ok()); - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_) - .ok()); - // Dimension mismatch ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_64_, vector_32_, vector_32_) + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, s32_, f32_, f32_) .ok()); ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_32_, vector_64_, vector_32_) + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, s32_, f32_) .ok()); ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_32_, vector_32_, vector_64_) + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, s32_) .ok()); - // Dimension mismatch, where one operand is a scalar + // Dimension mismatch ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_) + HloOpcode::kClamp, vector_64_, vector_32_, vector_32_) .ok()); ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_) + HloOpcode::kClamp, vector_32_, vector_64_, vector_32_) .ok()); ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_) + HloOpcode::kClamp, vector_32_, vector_32_, vector_64_) + .ok()); + // Dimension mismatch, where one operand is a scalar + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, + vector_64_, vector_32_, f32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, + vector_64_, f32_, vector_32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, + vector_64_, vector_32_) .ok()); } TEST_F(ShapeInferenceTest, Complex) { auto complex_shape = [&](const Shape& lhs, const Shape& rhs, const tensorflow::gtl::ArraySlice& bcast) { - return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX, - lhs, rhs, bcast); + return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs, + bcast); }; // Inputs must be FP. ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok()); @@ -292,8 +288,8 @@ TEST_F(ShapeInferenceTest, Complex) { } TEST_F(ShapeInferenceTest, VariadicOpTuplify) { - StatusOr result = ShapeInference::InferVariadicOpShape( - VariadicOperation::VAROP_TUPLE, {&s32_, &f32_}); + StatusOr result = + ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_}); ASSERT_IS_OK(result.status()); ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), ShapeUtil::MakeTupleShape({s32_, f32_}))); @@ -340,7 +336,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("source shape does not match")); + HasSubstr("Source shape does not match")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { @@ -351,7 +347,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("select function must take 2 parameters")); + HasSubstr("Select function must take 2 parameters")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { @@ -362,7 +358,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("select function must have rank-0 PRED")); + HasSubstr("Select function must have rank-0 PRED")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { @@ -373,7 +369,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("select function's first parameter")); + HasSubstr("Select function's first parameter")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { @@ -384,7 +380,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("select function's second parameter")); + HasSubstr("Select function's second parameter")); } TEST_F(ShapeInferenceTest, Convolve) { @@ -804,8 +800,8 @@ TEST_F(ShapeInferenceTest, InferConstIndexShape) { TEST_F(ShapeInferenceTest, InferPowShape) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_POW, ten_floats, f32_, {}); + auto inferred_status = ShapeInference::InferBinaryOpShape( + HloOpcode::kPower, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie())); } @@ -813,7 +809,7 @@ TEST_F(ShapeInferenceTest, InferPowShape) { TEST_F(ShapeInferenceTest, InferCompareShapeEq) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_EQ, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kEq, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -822,7 +818,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeEq) { TEST_F(ShapeInferenceTest, InferCompareShapeGe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_GE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kGe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -831,7 +827,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGe) { TEST_F(ShapeInferenceTest, InferCompareShapeGt) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_GT, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kGt, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -840,7 +836,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGt) { TEST_F(ShapeInferenceTest, InferCompareShapeLe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_LE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kLe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -849,7 +845,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLe) { TEST_F(ShapeInferenceTest, InferCompareShapeLt) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_LT, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kLt, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -858,7 +854,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLt) { TEST_F(ShapeInferenceTest, InferCompareShapeNe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_NE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kNe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -906,7 +902,7 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) { ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("dot only supports rank")); + HasSubstr("Dot only supports rank")); } // 3D 2D: error @@ -918,7 +914,7 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("batch and contracting dimension number mismatch")); + HasSubstr("Batch and contracting dimension number mismatch")); } // vector vector -> scalar @@ -1024,7 +1020,7 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("must specify one contracting dimension for both " + HasSubstr("Must specify one contracting dimension for both " "lhs and rhs")); } @@ -1044,7 +1040,7 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("batch dimension numbers and sizes must match")); + HasSubstr("Batch dimension numbers and sizes must match")); } // BatchMatMul with different batch dimension numbers fails. @@ -1063,7 +1059,7 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("batch dimension numbers must precede non-batch")); + HasSubstr("Batch dimension numbers must precede non-batch")); } // BatchMatMul with out-of-range dimension numbers fails. @@ -1111,22 +1107,22 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { const Shape vec8 = ShapeUtil::MakeShape(F32, {8}); const Shape vec16 = ShapeUtil::MakeShape(F32, {16}); - auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec8, {1}); + auto inferred_status_match = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec8, {0}); + auto inferred_status_mismatch = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0}); ASSERT_FALSE(inferred_status_mismatch.ok()); - inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec16, {0}); + inferred_status_match = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); - inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec16, {1}); + inferred_status_mismatch = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1}); ASSERT_FALSE(inferred_status_mismatch.ok()); } @@ -1138,17 +1134,17 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) { const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8}); auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, cube, matrix8_4, {1, 2}); + HloOpcode::kAdd, cube, matrix8_4, {1, 2}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, cube, matrix16_4, {0, 2}); + HloOpcode::kAdd, cube, matrix16_4, {0, 2}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, cube, matrix16_8, {0, 1}); + HloOpcode::kAdd, cube, matrix16_8, {0, 1}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); } @@ -1162,61 +1158,61 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8}); // "magical" broadcast rejected - auto inferred_status_error1 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {}); + auto inferred_status_error1 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {}); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - HasSubstr("automatic")); + HasSubstr("Automatic")); // broadcast_dimension out of bounds for tensor's rank - auto inferred_status_error2 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {3}); + auto inferred_status_error2 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3}); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), - ContainsRegex("broadcast dimension number .* too large")); + ContainsRegex("Broadcast dimension number .* too large")); // broadcast_dimension doesn't match corresponding dimension - auto inferred_status_error3 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {0}); + auto inferred_status_error3 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("broadcast dimension 0 mismatch")); + HasSubstr("Broadcast dimension 0 mismatch")); // broadcast_dimensions list too long auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2}); + HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2}); ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT(inferred_status_error4.status().error_message(), - HasSubstr("size of broadcast_dimensions has to match")); + HasSubstr("broadcast_dimensions has to match")); // there's a dimension above the rank of the tensor auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0}); + HloOpcode::kAdd, tensor, matrix8_4, {3, 0}); ASSERT_FALSE(inferred_status_error5.ok()); ASSERT_THAT(inferred_status_error5.status().error_message(), - ContainsRegex("broadcast dimension number .* too large")); + ContainsRegex("dimension number .* too large")); // broadcasting dimensions don't match in this order auto inferred_status_error6 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1}); + HloOpcode::kAdd, tensor, matrix8_4, {2, 1}); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), - HasSubstr("broadcast dimension 0 mismatch")); + HasSubstr("dimension 0 mismatch")); // The following two tests make sure that broadcasting dimensions are listed // in a proper (strictly increasing) order, even if the lower-rank array // matches the higher-rank array in many different ways. auto inferred_status_error7 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0}); + HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0}); ASSERT_FALSE(inferred_status_error7.ok()); ASSERT_THAT(inferred_status_error7.status().error_message(), - HasSubstr("broadcast dimensions order is wrong")); + HasSubstr("dimensions order is wrong")); auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0}); + HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0}); ASSERT_FALSE(inferred_status_error8.ok()); ASSERT_THAT(inferred_status_error8.status().error_message(), - HasSubstr("broadcast dimensions order is wrong")); + HasSubstr("dimensions order is wrong")); } // Tests for the while instruction with proper shapes. @@ -1242,7 +1238,7 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) { ShapeInference::InferWhileShape(bad_shape_1, body, result_shape); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - HasSubstr("condition must take 1 arguments")); + HasSubstr("Condition must take 1 arguments")); auto bad_shape_2 = ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape); @@ -1250,14 +1246,14 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) { ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), - HasSubstr("body must take 1 arguments")); + HasSubstr("Body must take 1 arguments")); auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_); auto inferred_status_error3 = ShapeInference::InferWhileShape(bad_shape_3, body, result_shape); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("condition must return a boolean")); + HasSubstr("Condition must return a boolean")); auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_); auto inferred_status_error4 = @@ -1301,13 +1297,13 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), - HasSubstr("dimension to concatenate along out of bounds: -1")); + HasSubstr("dimension out of bounds: -1")); auto inferred_status_error3 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("dimension to concatenate along out of bounds: 1")); + HasSubstr("dimension out of bounds: 1")); Shape tuple = ShapeUtil::MakeTupleShape({vector_32_}); auto inferred_status_error4 = ShapeInference::InferConcatOpShape( @@ -1315,21 +1311,20 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT( inferred_status_error4.status().error_message(), - HasSubstr("Expected non-tuple argument for operand of concatenation.")); + HasSubstr("Expected non-tuple argument for operand of concatenation")); const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); auto inferred_status_error5 = ShapeInference::InferConcatOpShape( {&vector_32_, &vector_s32}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_THAT( - inferred_status_error5.status().error_message(), - HasSubstr("cannot concatenate arrays with different element types")); + ASSERT_THAT(inferred_status_error5.status().error_message(), + HasSubstr("concatenate arrays with different element types")); auto inferred_status_error6 = ShapeInference::InferConcatOpShape( {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), - HasSubstr("cannot concatenate arrays that differ in " + HasSubstr("concatenate arrays that differ in " "dimensions other than the one being " "concatenated")); } @@ -1467,7 +1462,7 @@ TEST_F(ShapeInferenceTest, Conditional) { ShapeUtil::MakeProgramShape({vector_64_}, f32_)); EXPECT_FALSE(inferred_status_error0.ok()); EXPECT_THAT(inferred_status_error0.status().error_message(), - HasSubstr("predicate must be a boolean")); + HasSubstr("Predicate must be a boolean")); auto inferred_status_error1 = ShapeInference::InferConditionalShape( pred_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_, @@ -1530,11 +1525,17 @@ TEST_F(ShapeInferenceTest, BadSlice) { class GatherShapeInferenceTest : public ShapeInferenceTest { protected: + const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); + const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5}); const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32}); const Shape s64_4d_tensor_10_9_8_7_1_ = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}); const Shape s64_4d_tensor_10_9_8_7_5_ = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); + const Shape s64_4d_tensor_5_10_9_7_6_ = + ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6}); + const Shape s64_4d_tensor_10_9_5_7_6_ = + ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); const Shape f32_5d_tensor_50_49_48_47_46_ = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( @@ -1548,7 +1549,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGather) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{0}, /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}), + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), /*window_bounds=*/{64, 1})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) @@ -1562,7 +1564,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{1}, /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}), + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1), /*window_bounds=*/{1, 48})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) @@ -1576,7 +1579,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4}, /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}), + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4), /*window_bounds=*/{1, 48})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) @@ -1591,7 +1595,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, @@ -1599,12 +1604,85 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { << ShapeUtil::HumanString(gather_shape); } +TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + EXPECT_TRUE(ShapeUtil::Equal( + gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + EXPECT_TRUE(ShapeUtil::Equal( + gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { + // This is equivalent to a dynamic slice. + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0, 1, 2, 3, 4}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + EXPECT_TRUE(ShapeUtil::Equal(gather_shape, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { + // The gather indices "tensor" is a scalar S here that's used to slice out + // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result. + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_scalar_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0, 1, 2, 3}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/0), + /*window_bounds=*/{1, 30, 29, 28, 27})); + + EXPECT_TRUE(ShapeUtil::Equal(gather_shape, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27}))) + << ShapeUtil::HumanString(gather_shape); +} + TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { StatusOr statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}), + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1617,7 +1695,8 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { s64_vector_32_, tuple_shape_, HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}), + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1625,25 +1704,13 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { << statusor.status(); } -TEST_F(GatherShapeInferenceTest, ScalarGatherIndicesInput) { - StatusOr statusor = ShapeInference::InferGatherShape( - s64_vector_32_, s32_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}), - /*window_bounds=*/{64, 1}); - ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Gather indices parameter must at least of rank 1")) - << statusor.status(); -} - TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}), + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1658,7 +1725,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 8, 7}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1674,7 +1742,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 7}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1690,7 +1759,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 99, 100, 101}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1698,6 +1768,22 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 9}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Window index 4 in gather op is out of bounds")) + << statusor.status(); +} + TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingElidedWindowDims) { StatusOr statusor = ShapeInference::InferGatherShape( @@ -1705,7 +1791,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{4}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1722,7 +1809,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{0, 1, 2, 3, 19}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1738,7 +1826,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{0, 1, 2, 3, 3}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1755,15 +1844,15 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr( - "There must be exactly as many elements in " - "gather_dims_to_operand_dims " - "as there are elements in the last dimension of %gather_indices")) + HasSubstr("Gather op has 4 elements in gather_dims_to_operand_dims and " + "the bound of dimension index_vector_dim=4 of " + "gather_indices is 5. These two numbers must be equal.")) << statusor.status(); } @@ -1774,7 +1863,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1791,7 +1881,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1808,7 +1899,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{2, 1}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{1, 1, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1822,7 +1914,8 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7}, /*elided_window_dims=*/{2}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 1, 300, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1838,7 +1931,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1855,7 +1949,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7}, /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 26, 20}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1864,5 +1959,22 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } +TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/32), + /*window_bounds=*/{30, 29, 28, 27, 26}); + + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather index leaf dimension must be within [0, " + "rank(gather_indices) + 1)")) + << statusor.status(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 6e9986165f7eaf71a964b42b734a5ae5db5e45d7..7d7dcac10b65933d1c81b8aca77465932694bfdb 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include #include #include @@ -25,11 +24,10 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" -namespace se = ::perftools::gputools; - namespace xla { using ::tensorflow::strings::Appendf; @@ -68,6 +66,8 @@ ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) { return *this; } +ShapedBuffer::~ShapedBuffer() {} + void ShapedBuffer::clear() { for (auto& pair : buffers_) { // A default constructed DeviceMemoryBase is a null pointer. @@ -104,18 +104,6 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) { return out; } -/* static */ -StatusOr> ScopedShapedBuffer::MakeScoped( - ShapedBuffer* shaped_buffer, DeviceMemoryAllocator* allocator) { - auto scoped_buffer = WrapUnique(new ScopedShapedBuffer( - shaped_buffer->on_host_shape(), shaped_buffer->on_device_shape(), - allocator, shaped_buffer->device_ordinal())); - scoped_buffer->buffers_ = shaped_buffer->buffers(); - shaped_buffer->clear(); - - return std::move(scoped_buffer); -} - ScopedShapedBuffer::ScopedShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, DeviceMemoryAllocator* allocator, @@ -128,26 +116,46 @@ ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer, DeviceMemoryAllocator* allocator) : ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {} -ScopedShapedBuffer::~ScopedShapedBuffer() { +ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s) + : ShapedBuffer(static_cast(s)), allocator_(s.allocator_) { + // Null out s.allocator_ so it doesn't try to free anything in its destructor. + s.allocator_ = nullptr; +} + +ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) { + Deallocate(); + + *static_cast(this) = std::move(static_cast(s)); + allocator_ = s.allocator_; + // Null out s.allocator_ so it doesn't try to free anything in its destructor. + s.allocator_ = nullptr; + return *this; +} + +ScopedShapedBuffer::~ScopedShapedBuffer() { Deallocate(); } + +ShapedBuffer ScopedShapedBuffer::release() { + ShapedBuffer shaped_buffer(static_cast(*this)); + buffers_ = ShapeTree(); + return shaped_buffer; +} + +void ScopedShapedBuffer::Deallocate() { + // allocator_ will be null if we were moved-from. + if (allocator_ == nullptr) { + return; + } // Deallocate all non-null buffers. A buffer may appear in more than one spot // in the shape (eg, a tuple with a repeated element) so keep track of what // has been deallocated. - std::set deallocated_opaques; + tensorflow::gtl::FlatSet deallocated_ptrs; for (auto& pair : buffers_) { se::DeviceMemoryBase& memory_base = pair.second; if (!memory_base.is_null() && - deallocated_opaques.count(memory_base.opaque()) == 0) { - deallocated_opaques.insert(memory_base.opaque()); - TF_CHECK_OK( - this->allocator_->Deallocate(this->device_ordinal(), &memory_base)); + deallocated_ptrs.insert(memory_base.opaque()).second) { + TF_CHECK_OK(allocator_->Deallocate(device_ordinal(), memory_base)); } } } -std::unique_ptr ScopedShapedBuffer::release() { - auto shaped_buffer = MakeUnique(std::move(*this)); - buffers_ = ShapeTree(); - return shaped_buffer; -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index b816df8385ef65b0b69ede1d6e65a1991b4bd7c6..905a7e82e621f2bf4588b71be5dbab20f892cafe 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -30,6 +30,8 @@ limitations under the License. namespace xla { +class ScopedShapedBuffer; + // Class which encapsulates a buffer or set of buffers containing data of a // particular XLA shape. class ShapedBuffer { @@ -41,8 +43,19 @@ class ShapedBuffer { // determines the number of device allocations (DeviceMemoryBase) held by the // ShapedBuffer. ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, - const perftools::gputools::Platform* platform, - int device_ordinal); + const se::Platform* platform, int device_ordinal); + + // Movable, but not copyable. + ShapedBuffer(ShapedBuffer&& s); + ShapedBuffer& operator=(ShapedBuffer&&); + ShapedBuffer(const ShapedBuffer&) = delete; + ShapedBuffer& operator=(const ShapedBuffer&) = delete; + + // Prevent (some forms of) accidental object slicing. + ShapedBuffer(const ScopedShapedBuffer&) = delete; + ShapedBuffer& operator=(const ScopedShapedBuffer&) = delete; + + virtual ~ShapedBuffer(); // Returns the shape of the on-host representation of the data held by this // ShapedBuffer. @@ -52,48 +65,36 @@ class ShapedBuffer { // ShapedBuffer. const Shape& on_device_shape() const { return on_device_shape_; } - const perftools::gputools::Platform* platform() const { return platform_; } + const se::Platform* platform() const { return platform_; } int device_ordinal() const { return device_ordinal_; } // Return the root buffer of the shape (shape index {}). - const perftools::gputools::DeviceMemoryBase& root_buffer() const { + const se::DeviceMemoryBase& root_buffer() const { return buffer(/*index=*/{}); } // Returns the buffer at the given shape index where index is defined as in // ShapeUtil::GetSubshape. - const perftools::gputools::DeviceMemoryBase& buffer( - const ShapeIndex& index) const { + const se::DeviceMemoryBase& buffer(const ShapeIndex& index) const { return buffers_.element(index); } // Sets the device memory buffer at the given index. - void set_buffer(const perftools::gputools::DeviceMemoryBase& buffer, - const ShapeIndex& index) { + void set_buffer(const se::DeviceMemoryBase& buffer, const ShapeIndex& index) { *buffers_.mutable_element(index) = buffer; } // Returns the underlying ShapeTree containing all the device addresses in the // ShapedBuffer. - const ShapeTree& buffers() const { - return buffers_; - } - ShapeTree& buffers() { - return buffers_; - } + const ShapeTree& buffers() const { return buffers_; } + ShapeTree& buffers() { return buffers_; } // Set all device memory pointers in the object to null. void clear(); string ToString() const; - ShapedBuffer(ShapedBuffer&& s); - ShapedBuffer& operator=(ShapedBuffer&&); - protected: - ShapedBuffer(const ShapedBuffer&) = delete; - ShapedBuffer& operator=(const ShapedBuffer&) = delete; - // The shape of the data when represented on the host. Shape on_host_shape_; @@ -101,13 +102,13 @@ class ShapedBuffer { Shape on_device_shape_; // The platform the memory is allocated on. - const perftools::gputools::Platform* platform_; + const se::Platform* platform_; // The device the memory is allocated on. int device_ordinal_; // The tree of device buffers. Its shape is on_device_shape(). - ShapeTree buffers_; + ShapeTree buffers_; }; std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); @@ -115,40 +116,60 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); // ShapedBuffer derived class which allocates all internal buffers on // construction and deallocates the memory when the object is // destructed. +// +// TODO(timshen): Remove inheritance between ScopedShapedBuffer and +// ShapedBuffer. There should never be a need to consider a ScopedShapedBuffer +// as a ShapedBuffer, because in that case we should just be able to pass around +// our ShapeTree. Inheritance only adds complexity. See +// discussion in cl/192849370. class ScopedShapedBuffer : public ShapedBuffer { public: - // Takes a ShapedBuffer and returns a ScopedShapedBuffer which manages the - // deallocation of the device memory held in the shaped buffer. All device - // memory pointers in the given ShapedBuffer are set to null. - static StatusOr> MakeScoped( - ShapedBuffer* shaped_buffer, DeviceMemoryAllocator* allocator); - - // Create a ScopedShapedBuffer with null DeviceMemoryBases at each index. - ScopedShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, - DeviceMemoryAllocator* allocator, int device_ordinal); + // Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index. + explicit ScopedShapedBuffer(const Shape& on_host_shape, + const Shape& on_device_shape, + DeviceMemoryAllocator* allocator, + int device_ordinal); // Create a ScopedShapedBuffer by taking over the memory from the incoming // ShapedBuffer. - ScopedShapedBuffer(ShapedBuffer shaped_buffer, - DeviceMemoryAllocator* allocator); + explicit ScopedShapedBuffer(ShapedBuffer shaped_buffer, + DeviceMemoryAllocator* allocator); + + // Movable, but not copyable. + ScopedShapedBuffer(ScopedShapedBuffer&& s); + ScopedShapedBuffer& operator=(ScopedShapedBuffer&&); + ScopedShapedBuffer(const ScopedShapedBuffer&) = delete; + ScopedShapedBuffer& operator=(const ScopedShapedBuffer&) = delete; + + // All buffers in the shape are deallocated on destruction. + ~ScopedShapedBuffer() override; // Return the allocator used to allocate the device memory held in this // ScopedShapedBuffer. DeviceMemoryAllocator* memory_allocator() const { return allocator_; } - // Release all device memory owned by this ScopedShapedBuffer and - // return the device memory pointers in the form of a - // ShapedBuffer. The returned ShapedBuffer takes over the memory - // from the ScopedShapedBuffer. The resulting ScopedShapedBuffer can - // only be destroyed. - std::unique_ptr release(); + // Sets the device memory buffer at the given index. + // + // If the given buffer's device memory is non-null, its device_ordinal and + // allocator must match those in `this`. + void set_buffer(OwningDeviceMemory buffer, const ShapeIndex& index) { + if (!buffer.is_null()) { + CHECK_EQ(buffer.device_ordinal(), device_ordinal()); + CHECK_EQ(buffer.allocator(), allocator_); + *buffers_.mutable_element(index) = buffer.Forget(); + } else { + *buffers_.mutable_element(index) = se::DeviceMemoryBase(); + } + } - // All buffers in the shape are deallocated on destruction. - virtual ~ScopedShapedBuffer(); + // Like unique_ptr::release(), creates and returns a regular ShapedBuffer from + // this ScopedShapedBuffer, without freeing any of the associated memory. + // + // It's the caller's job to ensure that the memory contained therein is freed. + TF_MUST_USE_RESULT ShapedBuffer release(); protected: - ScopedShapedBuffer(const ScopedShapedBuffer&) = delete; - void operator=(const ScopedShapedBuffer&) = delete; + void Deallocate(); DeviceMemoryAllocator* allocator_; }; diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0fc243667911651c788e3c1e5f1d39d86170f1ad --- /dev/null +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/shaped_buffer.h" + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace xla { +namespace { + +TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) { + TF_ASSERT_OK_AND_ASSIGN(auto platforms, + xla::PlatformUtil::GetSupportedPlatforms()); + ASSERT_FALSE(platforms.empty()); + auto* platform = platforms[0]; + TF_ASSERT_OK_AND_ASSIGN(auto executors, + xla::PlatformUtil::GetStreamExecutors(platform)); + xla::StreamExecutorMemoryAllocator allocator(platform, executors); + const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {}); + const int kDeviceOrdinal = 0; + auto scoped_buffer = tensorflow::MakeUnique( + shape, shape, &allocator, kDeviceOrdinal); + std::unique_ptr buffer = std::move(scoped_buffer); + buffer = nullptr; +} + +class TestAllocator : public DeviceMemoryAllocator { + public: + TestAllocator() + : DeviceMemoryAllocator(PlatformUtil::GetDefaultPlatform().ValueOrDie()) { + } + + ~TestAllocator() override { + if (!allocations_.empty()) { + ADD_FAILURE() << "Some allocations not freed!"; + } + } + + // Pull in two-arg overload of Allocate. + using DeviceMemoryAllocator::Allocate; + + StatusOr Allocate(int device_ordinal, uint64 size, + bool /*retry_on_failure*/) override { + // By contract, we must return null if size == 0. + if (size == 0) { + return OwningDeviceMemory(); + } + void* buf = malloc(size); + allocations_.insert({device_ordinal, buf}); + return OwningDeviceMemory(se::DeviceMemoryBase(buf, size), device_ordinal, + this); + } + + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override { + if (mem.is_null()) { + return Status::OK(); + } + + auto it = allocations_.find({device_ordinal, mem.opaque()}); + if (it == allocations_.end()) { + ADD_FAILURE() << "Allocation not found (double free?)"; + } else { + free(mem.opaque()); + allocations_.erase(it); + } + return Status::OK(); + } + + bool AllowsAsynchronousDeallocation() const override { return false; } + + private: + std::set> allocations_; +}; + +TEST(ScopedShapedBufferTest, TestMoveAssignmentOperator) { + Shape s = ShapeUtil::MakeShape(F32, {1}); + TestAllocator allocator; + ScopedShapedBuffer sb1(s, s, &allocator, /*device_ordinal=*/0); + sb1.set_buffer( + allocator.Allocate(/*device_ordinal=*/0, /*size=*/42).ValueOrDie(), + /*index=*/{}); + + ScopedShapedBuffer sb2(s, s, &allocator, /*device_ordinal=*/1); + sb2.set_buffer( + allocator.Allocate(/*device_ordinal=*/1, /*size=*/10).ValueOrDie(), + /*index=*/{}); + + sb1 = std::move(sb2); + + // TestAllocator's destructor checks that all memory was freed. +} + +} // anonymous namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/source_map_util.h b/tensorflow/compiler/xla/service/source_map_util.h index a776d745f4e56ca4f3d2480740259832bbc85011..18e2651abb1600a7b9ffb79de887b8795717e55e 100644 --- a/tensorflow/compiler/xla/service/source_map_util.h +++ b/tensorflow/compiler/xla/service/source_map_util.h @@ -23,7 +23,7 @@ limitations under the License. namespace xla { namespace source_map_util { -// Creates an INVALID_ARUGMENT status with the given format string. +// Creates an INVALID_ARGUMENT status with the given format string. // // Also, attempts to extract the OpMetadata for parameter_number on executable // and append it to the status message for source mapping to user code. diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 2f36e2b16e0f2eed10aef811dd3cceeba6a5b8a9..c4d01562c4e32225ebb984d8fcd93ec3fa86e403 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -25,24 +25,20 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" -namespace se = ::perftools::gputools; - namespace xla { /* static */ tensorflow::mutex TransferManager::platform_transfer_manager_mutex_( tensorflow::LINKER_INITIALIZED); -/* static */ std::map* +/* static */ std::map* TransferManager::GetPlatformTransferManagers() { - static auto* r = - new std::map; + static auto* r = new std::map; return r; } Status TransferManager::TransferArrayToDevice( - perftools::gputools::StreamExecutor* executor, const Literal& literal, - const perftools::gputools::DeviceMemoryBase& dest) { + se::StreamExecutor* executor, const LiteralSlice& literal, + const se::DeviceMemoryBase& dest) { const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) << "On-device representation of " @@ -61,8 +57,8 @@ Status TransferManager::TransferArrayToDevice( } StatusOr> TransferManager::TransferArrayFromDevice( - perftools::gputools::StreamExecutor* executor, const Shape& shape, - const perftools::gputools::DeviceMemoryBase& source) { + se::StreamExecutor* executor, const Shape& shape, + const se::DeviceMemoryBase& source) { TF_RET_CHECK(ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) << "Shape " << ShapeUtil::HumanString(shape) << " has a differently shaped representation on-device: " @@ -112,8 +108,7 @@ StatusOr> TransferManager::TransferArrayFromDevice( } Status TransferManager::WriteTupleIndexTables( - perftools::gputools::StreamExecutor* executor, - const ShapedBuffer& device_buffer) { + se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { VLOG(2) << "Writing tuple index tables for " << device_buffer; TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); @@ -180,7 +175,7 @@ Status TransferManager::TransferBufferToDevice( return Status::OK(); } -StatusOr> TransferManager::AllocateShapedBuffer( +StatusOr TransferManager::AllocateScopedShapedBuffer( const Shape& on_host_shape, DeviceMemoryAllocator* allocator, int device_ordinal) { if (!LayoutUtil::HasLayout(on_host_shape)) { @@ -192,31 +187,23 @@ StatusOr> TransferManager::AllocateShapedBuffer( const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape); TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape)); - auto shaped_buffer = WrapUnique(new ShapedBuffer( - on_host_shape, on_device_shape, allocator->platform(), device_ordinal)); + ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape, allocator, + device_ordinal); // Allocate an appropriate sized buffer for each element in the shape // including the tuple pointer arrays. - for (auto& pair : shaped_buffer->buffers()) { + for (auto& pair : shaped_buffer.buffers()) { const ShapeIndex& index = pair.first; se::DeviceMemoryBase& memory_base = pair.second; const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index); - TF_ASSIGN_OR_RETURN(memory_base, - allocator->Allocate(shaped_buffer->device_ordinal(), + TF_ASSIGN_OR_RETURN(auto memory, + allocator->Allocate(shaped_buffer.device_ordinal(), GetByteSizeRequirement(subshape))); + // Move the allocated buffer into the ScopedShapedBuffer, which owns it. + memory_base = memory.Forget(); } return std::move(shaped_buffer); } -StatusOr> -TransferManager::AllocateScopedShapedBuffer(const Shape& on_host_shape, - DeviceMemoryAllocator* allocator, - int device_ordinal) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr unscoped_buffer, - AllocateShapedBuffer(on_host_shape, allocator, device_ordinal)); - return ScopedShapedBuffer::MakeScoped(unscoped_buffer.get(), allocator); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 9f2b5c4aecf0b52f610171e0c2755de577b2bd9e..43a8092b06fba0e2495bce0ee1a309c85a908273 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -42,7 +42,7 @@ class TransferManager { virtual ~TransferManager() {} // Returns the ID of the platform that this transfer manager acts on. - virtual perftools::gputools::Platform::Id PlatformId() const = 0; + virtual se::Platform::Id PlatformId() const = 0; // Returns the shape of the on-device representation for the given shape on // the host. This is intended for use with ShapedBuffer where buffers are @@ -58,48 +58,45 @@ class TransferManager { // DeviceShape(literal_shape) must be compatible, but need not have the same // layout. virtual StatusOr> TransferLiteralFromDevice( - perftools::gputools::StreamExecutor* executor, - const ShapedBuffer& device_buffer) = 0; + se::StreamExecutor* executor, const ShapedBuffer& device_buffer) = 0; // Transfers the given literal into the previously allocated device memory // represented by the given ShapedBuffer using the given executor. The shape // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, // but need not have the same layout - virtual Status TransferLiteralToDevice( - perftools::gputools::StreamExecutor* executor, const Literal& literal, - const ShapedBuffer& device_buffer) = 0; + virtual Status TransferLiteralToDevice(se::StreamExecutor* executor, + const LiteralSlice& literal, + const ShapedBuffer& device_buffer) = 0; // Convenience methods for transferring an array to or from the device at a // known address. This avoids having to construct a ShapedBuffer just to // transfer an array at a known address. - Status TransferArrayToDevice( - perftools::gputools::StreamExecutor* executor, const Literal& literal, - const perftools::gputools::DeviceMemoryBase& dest); + Status TransferArrayToDevice(se::StreamExecutor* executor, + const LiteralSlice& literal, + const se::DeviceMemoryBase& dest); StatusOr> TransferArrayFromDevice( - perftools::gputools::StreamExecutor* executor, const Shape& shape, - const perftools::gputools::DeviceMemoryBase& source); + se::StreamExecutor* executor, const Shape& shape, + const se::DeviceMemoryBase& source); // Transfers the given literal into the Infeed interface of the device, // using the given executor. - virtual Status TransferLiteralToInfeed( - perftools::gputools::StreamExecutor* executor, - const Literal& literal) = 0; + virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor, + const LiteralSlice& literal) = 0; // Transfers the given literal from the Outfeed interface of the device, // using the given executor. - virtual Status TransferLiteralFromOutfeed( - perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) = 0; + virtual Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, + const Shape& literal_shape, + Literal* literal) = 0; // Resets the devices associated with this transfer manager. virtual Status ResetDevices( - tensorflow::gtl::ArraySlice - executor) = 0; + tensorflow::gtl::ArraySlice executor) = 0; // Given an allocated ShapedBuffer, constructs the tuple index table(s) in // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the // ShapedBuffer is array-shaped this method does nothing. - Status WriteTupleIndexTables(perftools::gputools::StreamExecutor* executor, + Status WriteTupleIndexTables(se::StreamExecutor* executor, const ShapedBuffer& device_buffer); // Determines the byte size requirement for the given shape on the underlying @@ -107,13 +104,10 @@ class TransferManager { // region for a host-to-device transfer. virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0; - // Allocate a ShapedBuffer which can hold data with the given on-host + // Allocates a ScopedShapedBuffer which can hold data with the given on-host // shape. The on-device shape may be different as indicated by // HostShapeToDeviceShape. - StatusOr> AllocateShapedBuffer( - const Shape& on_host_shape, DeviceMemoryAllocator* allocator, - int device_ordinal); - StatusOr> AllocateScopedShapedBuffer( + StatusOr AllocateScopedShapedBuffer( const Shape& on_host_shape, DeviceMemoryAllocator* allocator, int device_ordinal); @@ -127,13 +121,13 @@ class TransferManager { // Precondition: a platform kind must not be registered more than once. typedef std::unique_ptr (*TransferManagerCreationFunction)(); static void RegisterTransferManager( - perftools::gputools::Platform::Id platform_id, + se::Platform::Id platform_id, TransferManagerCreationFunction transfer_manager); // Returns the transfer manager singleton pointer if it is available for the // given platform, or an error status if it is not. static StatusOr GetForPlatform( - const perftools::gputools::Platform* platform); + const se::Platform* platform); protected: // Transfer a memory block of the given size from 'source' buffer to the @@ -143,35 +137,32 @@ class TransferManager { // // source is the source data that must be in the target-dependent layout that // the Infeed HLO used in the computation expects. - virtual Status TransferBufferToInfeed( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source) = 0; + virtual Status TransferBufferToInfeed(se::StreamExecutor* executor, + int64 size, const void* source) = 0; // Transfer a memory block of the given size from the device source into the // 'destination' buffer. // // size is the size to transfer to destination in bytes. - virtual Status TransferBufferFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, int64 size, - void* destination); + virtual Status TransferBufferFromDevice(se::StreamExecutor* executor, + const se::DeviceMemoryBase& source, + int64 size, void* destination); // Transfer a memory block of the given size from 'source' buffer to the given // destination of the device. // // size is the size to transfer from source in bytes. - virtual Status TransferBufferToDevice( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source, perftools::gputools::DeviceMemoryBase* destination); + virtual Status TransferBufferToDevice(se::StreamExecutor* executor, + int64 size, const void* source, + se::DeviceMemoryBase* destination); // Writes the given device-memory pointers in 'elements' to the given region // to construct a tuple index table in the platform-specific tuple // representation. virtual Status WriteSingleTupleIndexTable( - perftools::gputools::StreamExecutor* executor, - tensorflow::gtl::ArraySlice - elements, - const Shape& shape, perftools::gputools::DeviceMemoryBase* region) = 0; + se::StreamExecutor* executor, + tensorflow::gtl::ArraySlice elements, + const Shape& shape, se::DeviceMemoryBase* region) = 0; private: // The mutex that guards the platform-to-transfer manager map. @@ -186,8 +177,7 @@ class TransferManager { }; // Map from platform kind to transfer manager singleton. - static std::map* - GetPlatformTransferManagers(); + static std::map* GetPlatformTransferManagers(); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 83185ac49e9b7c386d10d1cbc4e20dcdfdfd6cae..49e1f873192f800056a2272f7d4f698898b0f8a1 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -35,7 +35,8 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( const HloInstruction& dot, const TransposeFolding::TransposableGemmOperandsFn& transposable_gemm_operands) { - if (HloOpcode::kDot != dot.opcode()) { + if (HloOpcode::kDot != dot.opcode() || + dot.dot_dimension_numbers().lhs_batch_dimensions_size() != 0) { return {}; } @@ -44,6 +45,8 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( auto& operand = *dot.operand(i); if (operand.IsRank2Transpose()) { operand_set.push_back(i); + } else if (ShapeUtil::Rank(operand.shape()) != 2) { + return {}; } } @@ -74,23 +77,39 @@ using InstructionOperandsPair = // Folds the operands of `dot` that are foldable transposes. `computation` is // the parent HLO computation of `dot`. -// -// Returns whether the module is changed. -bool FoldTransposeIntoDot(InstructionOperandsPair pair) { - auto* dot = pair.first; - std::vector instructions_to_fuse(1, dot); - for (const int64 operand_index : pair.second) { - instructions_to_fuse.push_back(dot->mutable_operand(operand_index)); - } - - // Early-exit if no operands are foldable. - if (instructions_to_fuse.size() == 1) { - return false; +Status FoldTransposeIntoDot(InstructionOperandsPair pair) { + HloInstruction* dot = pair.first; + + DotDimensionNumbers new_dim_numbers = dot->dot_dimension_numbers(); + HloInstruction* new_lhs = dot->mutable_operand(0); + HloInstruction* new_rhs = dot->mutable_operand(1); + + CHECK_EQ(new_dim_numbers.lhs_batch_dimensions_size(), 0); + CHECK_EQ(new_dim_numbers.rhs_batch_dimensions_size(), 0); + CHECK_EQ(new_dim_numbers.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(new_dim_numbers.rhs_contracting_dimensions_size(), 1); + + for (int64 operand_index : pair.second) { + // We've checked that there aren't any batch dimensions and that the inputs + // are rank 2, and shape inference guarantees that there is exactly one + // contracting dimension. + if (operand_index == 0) { + CHECK_EQ(new_lhs->opcode(), HloOpcode::kTranspose); + new_dim_numbers.set_lhs_contracting_dimensions( + 0, 1 - new_dim_numbers.lhs_contracting_dimensions(0)); + new_lhs = new_lhs->mutable_operand(0); + } else { + CHECK_EQ(operand_index, 1); + CHECK_EQ(new_rhs->opcode(), HloOpcode::kTranspose); + new_dim_numbers.set_rhs_contracting_dimensions( + 0, 1 - new_dim_numbers.rhs_contracting_dimensions(0)); + new_rhs = new_rhs->mutable_operand(0); + } } - dot->parent()->CreateFusionInstruction( - instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot); - return true; + std::unique_ptr new_dot = HloInstruction::CreateDot( + dot->shape(), new_lhs, new_rhs, new_dim_numbers); + return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot)); } // Folds the operands of `convolution` that are foldable transposes. @@ -195,7 +214,7 @@ StatusOr TransposeFolding::Run(HloModule* module) { std::make_pair(instruction, operand_indices)); } } - return tensorflow::Status::OK(); + return Status::OK(); }; for (auto* comp : module->MakeNonfusionComputations()) { @@ -204,7 +223,8 @@ StatusOr TransposeFolding::Run(HloModule* module) { bool changed = false; for (InstructionOperandsPair& pair : foldable_dots) { - changed |= FoldTransposeIntoDot(pair); + TF_RETURN_IF_ERROR(FoldTransposeIntoDot(pair)); + changed = true; } for (InstructionOperandsPair& pair : foldable_convolutions) { changed |= FoldTransposeIntoConvolution(pair); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index caa1a111ad880b9dee62c1c94e32e8275c196fbf..cccb8f2fbb0266bbf1f40b09170938a1e5d3e78d 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -19,13 +19,15 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -34,6 +36,8 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { @@ -54,83 +58,102 @@ class TransposeFoldingTest : public HloTestBase { }; TEST_F(TransposeFoldingTest, FoldDotTranspose) { - auto builder = HloComputation::Builder("entry_computation"); - HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"x")); - HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"y")); - HloInstruction* transpose_y = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, - /*rhs=*/transpose_y, dot_dnums)); - - HloModule module("test_module"); - HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(dot)); - FoldTranspose(&module); + string hlo_string = R"( +HloModule FoldDotTranspose + +ENTRY entry_computation { + x = f32[2,3]{1,0} parameter(0) + y = f32[2,3]{1,0} parameter(1) + transpose = f32[3,2]{1,0} transpose(y), dimensions={1,0} + ROOT dot = f32[2,2]{1,0} dot(x, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); - // Instructions after folding: x, y, and the fusion. - std::unordered_set instruction_set( - entry_computation->instructions().begin(), - entry_computation->instructions().end()); - CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; - CHECK_EQ(1, instruction_set.size()) - << "entry_computation should contain exactly 3 instructions."; - HloInstruction* fusion = *instruction_set.begin(); - EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); + FoldTranspose(module.get()); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); +} - // The fusion instruction should contain two parameters, one transpose and - // one dot. - EXPECT_EQ(4, fusion->fused_instruction_count()); +TEST_F(TransposeFoldingTest, DontFoldTransposeOfBatchDim) { + string hlo_string = R"( +HloModule FoldDotTranspose + +ENTRY entry_computation { + x = f32[2,3] parameter(0) + y = f32[3,2] parameter(1) + transpose = f32[2,3] transpose(y), dimensions={1,0} + ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + [](const HloInstruction& convolution, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(TransposeFoldingTest, DontFoldTransposeOfRank1Dot) { + string hlo_string = R"( +HloModule FoldDotTranspose + +ENTRY entry_computation { + x = f32[3] parameter(0) + y = f32[3,2] parameter(1) + transpose = f32[2,3] transpose(y), dimensions={1,0} + ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={}, rhs_batch_dims={0}, lhs_contracting_dims={0}, rhs_contracting_dims={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + [](const HloInstruction& convolution, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + EXPECT_FALSE(changed); } TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { - auto builder = HloComputation::Builder("entry_computation"); - // 2x1 - HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{1}, {2}}))); - // 3x2 - HloInstruction* const1 = - builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); - HloInstruction* transpose0 = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {1, 2}), const0, {1, 0})); - HloInstruction* transpose1 = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0})); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( - ShapeUtil::MakeShape(F32, {1, 3}), - /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums)); - - HloModule module("test_module"); - HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(dot)); - FoldTranspose(&module); - - for (auto* instruction : entry_computation->instructions()) { - if (instruction->opcode() == HloOpcode::kFusion) { - CHECK_EQ(2, instruction->operand_count()); - EXPECT_EQ(const0, instruction->operand(0)); - EXPECT_EQ(const1, instruction->operand(1)); - } - } + string hlo_string = R"( +HloModule FoldDotTransposeConstant + +ENTRY entry_computation { + constant = f32[2,1]{1,0} constant(f32[2,1] { { 1 }, { 2 } }) + transpose = f32[1,2]{1,0} transpose(constant), dimensions={1,0} + constant.1 = f32[3,2]{1,0} constant(f32[3,2] { { 1, 2 }, { 3, 4 }, { 5, 6 } }) + transpose.1 = f32[2,3]{1,0} transpose(constant.1), dimensions={1,0} + ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + FoldTranspose(module.get()); - // The created fusion instruction should contain two parameters, two - // transposes (one for each parameter) and one dot. - EXPECT_EQ(5, - entry_computation->root_instruction()->fused_instruction_count()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::Constant(), op::Constant(), + /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/1)); } TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { @@ -149,11 +172,11 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary( add->shape(), HloOpcode::kMultiply, add, sub)); - HloModule module("fuse_with_constant_operands"); + auto module = CreateNewModule("fuse_with_constant_operands"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(mul)); - HloInstruction* call = module.OutlineExpressionFromComputation( - {add, sub, mul}, "", entry_computation); + module->AddEntryComputation(builder.Build(mul)); + HloInstruction* call = module->OutlineExpressionFromComputation( + {add, sub, mul}, "entry", entry_computation); EXPECT_EQ(call, entry_computation->root_instruction()); HloComputation* callee_computation = call->to_apply(); // The arguments to the call should be const1, const2, and const3. @@ -164,50 +187,32 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { EXPECT_EQ(6, callee_computation->instruction_count()); } -TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { - auto builder = HloComputation::Builder("entry_computation"); - HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"x")); - HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"y")); - HloInstruction* transpose_y = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, - /*rhs=*/transpose_y, dot_dnums)); - - HloModule module("test_module"); - HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(dot)); - - HloInstruction* call = module.OutlineExpressionFromComputation( - {transpose_y, dot}, "outlined", entry_computation); +TEST_F(TransposeFoldingTest, FoldDotTransposeInCall) { + string hlo_string = R"( +HloModule FoldDotTransposeInCall - FoldTranspose(&module); - - // Instructions after folding: x, y, and the fusion. - std::unordered_set instruction_set( - entry_computation->instructions().begin(), - entry_computation->instructions().end()); - CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(call)) - << "call is not in entry_computation."; - CHECK(instruction_set.empty()) - << "entry_computation should contain exactly 3 instructions."; - HloInstruction* fusion = - call->called_computations().front()->root_instruction(); - EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); +callee { + name.0 = f32[2,3]{1,0} parameter(0) + name.1 = f32[2,3]{1,0} parameter(1) + transpose.clone = f32[3,2]{1,0} transpose(name.0), dimensions={1,0} + ROOT dot.clone = f32[2,2]{1,0} dot(name.1, transpose.clone), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} - // The fusion instruction should contain two parameters, one transpose and - // one dot. - EXPECT_EQ(4, fusion->fused_instruction_count()); +ENTRY entry_computation { + y = f32[2,3]{1,0} parameter(1) + x = f32[2,3]{1,0} parameter(0) + ROOT call = f32[2,2]{1,0} call(y, x), to_apply=callee +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + FoldTranspose(module.get()); + + const HloComputation* callee = module->GetComputationWithName("callee"); + ASSERT_NE(callee, nullptr); + EXPECT_THAT(callee->root_instruction(), + op::Dot(op::Parameter(1), op::Parameter(0), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); } // Test that a two dimension swap of the kernel gets folded into convolution. @@ -222,7 +227,7 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 0, 2, 3})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -240,10 +245,10 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(conv)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(conv)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the convolution. std::unordered_set instruction_set( @@ -275,7 +280,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 3, 0, 2})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -293,10 +298,10 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(conv)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(conv)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the convolution. std::unordered_set instruction_set( @@ -334,7 +339,7 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { HloInstruction* transpose_x = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 2, 3})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -351,10 +356,10 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(conv)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(conv)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the convolution. std::unordered_set instruction_set( @@ -398,7 +403,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { HloInstruction* transpose_x = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -415,10 +420,10 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(conv)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(conv)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the convolution. std::unordered_set instruction_set( diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 657a8fe09ae9df906d695f7f49df72500d611792..eb6d1ada6b553f998fe06917dfdf0b5092cd79cd 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -273,6 +273,14 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) { + // A kDomain instruction aliases its operand. That is, the buffer of its + // result *is* the buffer of its operand, so just copy the operands points-to + // set. + CreateCopiedPointsToSet(domain, domain->operand(0)); + return Status::OK(); +} + Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { // A kSlice instruction aliases its operand if the backend lowers it to an // in-place implementation. @@ -588,4 +596,206 @@ void TuplePointsToAnalysis::InstructionToString( }); } +bool TuplePointsToAnalysis::DoesNotUseOperandBuffer( + const HloInstruction* operand, const ShapeIndex& index, + const HloInstruction* user) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { + // GetTupleElement instructions only access the top-level buffer of their + // operand. + return true; + } else if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + // Find fusion parameter associated with 'operand'. + auto it = std::find_if( + user->fused_parameters().begin(), user->fused_parameters().end(), + [=](HloInstruction* fused_param) { + return user->operand(fused_param->parameter_number()) == operand; + }); + CHECK(it != user->fused_parameters().end()); + // Iterate through all users of all buffer aliases of the buffer in the + // points-to set of fusion parameter at 'index'. + // Return false if any uses are detected at 'index', returns true otherwise. + const LogicalBuffer* buffer = GetBufferDefinedAt(*it, index).ValueOrDie(); + for (const BufferAlias& alias : GetBufferAliases(*buffer)) { + for (HloInstruction* alias_user : alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), + alias_user)) { + continue; + } + // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'. + return false; + } + } + // Return true: found no uses of 'operand' at 'index' in 'user'. + return true; + } + return false; +} + +// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. +// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) +// where 'user' is a user of an alias of 'instruction' at 'index', and +// 'operand_index' is the operand index at which the alias appears in the +// operand list of 'user'. +std::vector> +TuplePointsToAnalysis::GetAllUsesOfInstructionAtIndex( + HloInstruction* instruction, const ShapeIndex& index) const { + std::vector> uses; + const PointsToSet::BufferList& points_to = + GetPointsToSet(instruction).element(index); + for (const LogicalBuffer* buffer : points_to) { + for (const BufferAlias& alias : GetBufferAliases(*buffer)) { + for (HloInstruction* alias_user : alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), + alias_user)) { + continue; + } + for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) { + uses.emplace_back(alias_user, op_idx); + } + } + } + } + return uses; +} + +// Returns true if there is exactly one use of 'operand' at 'operand_index' +// in 'fusion.fused_instructions', where the singleton use is the fused +// root at operand index 'use_operand_index'. Returns false otherwise. +// +// REQUIRES: 'fusion' opcode is a kFusion instruction. +bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* fusion, const int64 use_operand_index) const { + CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); + // Check that 'operand' is unique in the operand list of 'fusion'. + if (fusion->OperandIndices(operand).size() > 1) { + return false; + } + // Find fusion parameter associated with 'operand'. + const auto& fused_params = fusion->fused_parameters(); + auto fused_param_it = std::find_if( + fused_params.begin(), fused_params.end(), + [&](HloInstruction* fused_param) { + return fusion->operand(fused_param->parameter_number()) == operand; + }); + if (fused_param_it == fused_params.end()) { + return false; + } + auto* fused_param = *fused_param_it; + // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'. + auto fused_param_uses = + GetAllUsesOfInstructionAtIndex(fused_param, operand_index); + // Return true iff there is exactly one use of 'operand' at 'index', and + // this singleton use is the fused root (at index in 'use_operand_indices'). + return fused_param_uses.size() == 1 && + fused_param_uses[0].first == fusion->fused_expression_root() && + fused_param_uses[0].second == use_operand_index; +} + +// User and operand can share buffers iff both instructions emit the same shape +// and layout, and 'user' meets one of the following qualifications: +// +// (1) Is element-wise. Or... +// (2) Is a loop fusion instruction where the only use of 'operand' at 'index' +// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root +// at operand 0. Or... +// (3) Is a kDot -> kAdd output fusion instruction where the only use of +// 'operand' at 'index' in the set 'user.fused_instructions' is a kAdd fused +// root at operand 0 or 1. Or... +// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index +// 0. +// +// (2) and (3) can only be determined if points-to analysis is available. +bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* user, const ShapeIndex& user_index) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + const Shape& operand_subshape = + ShapeUtil::GetSubshape(operand->shape(), operand_index); + const Shape& user_subshape = + ShapeUtil::GetSubshape(user->shape(), user_index); + // Check that operand and user emit the same shape and layout. + if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { + return false; + } + if (user->opcode() == HloOpcode::kFusion) { + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0); + } + } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + // Output fusion with kAdd fused root. + + // Check if one operand of kAdd fused root is kDot or kConvolution. + auto* add = user->fused_expression_root(); + auto add_operand_it = + std::find_if(add->operands().begin(), add->operands().end(), + [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); + if (add_operand_it == add->operands().end()) { + return false; + } + auto* matched_add_operand = *add_operand_it; + // Calculate operand index of 'add' operand which was not matched above. + const int64 other_add_operand_index = + matched_add_operand == add->operand(0) ? 1 : 0; + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root (at operand + // index 'other_add_operand_index'). + return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, + other_add_operand_index); + } + } + if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kWhile) { + // We eliminated other users in BufferLiveness::live_range_strictly_before, + // so here we just need to check that the use is at operand index 0. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && operand_indices[0] == 0; + } + if (user->opcode() == HloOpcode::kCall) { + // TODO(b/62548313): Remove when buffer assignment is module scoped and + // does not assign buffers to calls. + // Find called computation parameter associated with 'operand'. + const std::vector operand_indices = user->OperandIndices(operand); + if (operand_indices.size() > 1) { + return false; + } + CHECK_EQ(1, operand_indices.size()); + auto* param = user->to_apply()->parameter_instruction(operand_indices[0]); + // Get all uses of 'operand' at 'index' in called computation. + auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index); + + // Return true iff: + // *) There exists exactly one use of 'operand' in called computation. + // *) The unique use is by the root instruction of called computation. + // (Note: we check the root of the called computation, because the + // root result buffer is required to alias with the Call result buffer). + // *) The root instruction of the called computation is element-wise on + // 'operand'. + auto* callee_root = user->to_apply()->root_instruction(); + return param_uses.size() == 1 && param_uses[0].first == callee_root && + callee_root->IsElementwiseOnOperand(param_uses[0].second); + } + // Loop fusions that contain transposing copies won't reach here as they have + // different layouts, which fails the check in the beginning of this function. + // + // Multi-output fusion will fail the check here as tuples are not considered + // an elementwise operation. + return user->IsElementwiseOnOperand(user->operand_index(operand)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index c3743b150168ebcf1051050dc511e50c43108c4f..c0d82414806d9a6ff57aec59d077f444137fec9a 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -248,6 +248,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleDomain(HloInstruction* domain) override; Status HandleSlice(HloInstruction* slice) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; @@ -256,6 +257,23 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { string ToString() const; + // Returns true if 'user' cannot possibly use the buffer at 'index' in + // 'operand'. Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. + bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user) const; + + // Returns true if 'user' (at 'user_index') can share a buffer with its + // operand 'operand' (at 'operand_index'). Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. + bool CanShareOperandBufferWithUser(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* user, + const ShapeIndex& user_index) const; + private: explicit TuplePointsToAnalysis( const HloModule* module, @@ -310,6 +328,13 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { return &per_instruction_[id]; } + std::vector> GetAllUsesOfInstructionAtIndex( + HloInstruction* instruction, const ShapeIndex& index) const; + bool HasUniqueFusedUseOfOperandAt(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* fusion, + const int64 use_operand_index) const; + // The module this analysis is performed on. const HloModule* module_; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index dec446d4dac650ba43992f7870764eedc80cb2cf..5734f284071944bc22011405898cf86f33dc48d7 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -805,5 +805,373 @@ TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) { Run(/*add_additional_gte0_user=*/true); } +class PointsToAnalysisTestBase : public HloTestBase { + protected: + void BuildModule(std::unique_ptr computation) { + module_ = CreateNewModule(); + computation_ = module_->AddEntryComputation(std::move(computation)); + } + + void RunAnalysis() { + CHECK_NOTNULL(module_.get()); + points_to_analysis_ = + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); + } + + void BuildModuleAndRunAnalysis(std::unique_ptr computation) { + BuildModule(std::move(computation)); + RunAnalysis(); + } + + std::unique_ptr module_; + HloComputation* computation_ = nullptr; + std::unique_ptr points_to_analysis_; +}; + +class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {}; + +TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { + auto builder = HloComputation::Builder(TestName()); + + Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); + builder.AddInstruction( + HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // GetTupleElement instructions only access the top-level buffer of their + // operand. + EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {0}, gte0)); + EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {1}, gte1)); + EXPECT_FALSE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte0)); + EXPECT_FALSE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte1)); +} + +TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction never uses tuple element 0, but does use element 1. + EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion)); + EXPECT_FALSE( + points_to_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); +} + +class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {}; + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape in_shape = ShapeUtil::MakeShape(F32, {8}); + Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, in_shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, in_shape, "param1")); + auto result = builder.AddInstruction( + HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param0, {}, + result, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param1, {}, + result, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(exp, {}, copy, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction can share with tuple element 1. + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(tuple, {0}, + fusion, {})); + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(tuple, {1}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {4}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto update = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "update")); + auto starts = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "starts")); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, data, update, starts)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The DynamicUpdateSlice instruction can share with the data operand, but not + // with update or starts. + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {})); + EXPECT_FALSE( + points_to_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {})); + EXPECT_FALSE( + points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, dot, add_operand)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, dot}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused dot add should be able to share buffer with 'add_operand'. + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser( + add_operand, {}, fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, operand, {0, 1})); + + auto two = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, two, reverse}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused operand->reverse->add cannot alias operand buffer 'operand'. + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + + auto make_cond = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Cond"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + return builder.Build(); + }; + + auto make_body = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); + return builder.Build(); + }; + + module_ = CreateNewModule(); + HloComputation* cond_computation = + module_->AddEmbeddedComputation(make_cond()); + HloComputation* body_computation = + module_->AddEmbeddedComputation(make_body()); + + auto builder = HloComputation::Builder(TestName()); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto whil = builder.AddInstruction(HloInstruction::CreateWhile( + data_shape, cond_computation, body_computation, data)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + // The While instruction can share with the data operand. + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(data, {}, whil, {})); +} + +// Tests that Call can alias operand buffer if the only use of the operand +// in the called computation is an elementwise instruction. +TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { + Shape shape = ShapeUtil::MakeShape(F32, {8}); + // Build sub-computation with fusion root. + auto sub_builder = HloComputation::Builder(TestName() + "_sub"); + auto sub_param = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "sub_param")); + auto one = sub_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto ones = sub_builder.AddInstruction( + HloInstruction::CreateBroadcast(shape, one, {1})); + auto add = sub_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); + + module_ = CreateNewModule(); + auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); + sub_computation->CreateFusionInstruction({add, ones}, + HloInstruction::FusionKind::kLoop); + + // Build entry-computation with kCall which calls 'sub_computation'. + auto builder = HloComputation::Builder(TestName()); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto reverse = + builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(shape, {reverse}, sub_computation)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(reverse, {}, + call, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, LoopFusionWithElementwiseOperand) { + Shape full_shape = ShapeUtil::MakeShape(F32, {16, 32}); + Shape broadcast_shape = ShapeUtil::MakeShape(F32, {16}); + + auto builder = HloComputation::Builder(TestName() + "_fusion"); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, full_shape, "full")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, broadcast_shape, "small")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(full_shape, param1, {0})); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + full_shape, HloOpcode::kAdd, param0, broadcast)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, broadcast}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc index 113c2e2bd9f73a2b0c783103d7f2da9534bc97c3..77bdcc9de0d830991208a1db271d009bccaf550e 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc @@ -30,10 +30,17 @@ limitations under the License. namespace xla { +TupleSimplifier::TupleSimplifier(bool exclude_entry_computation) : + exclude_entry_computation_(exclude_entry_computation) {} + StatusOr TupleSimplifier::Run(HloModule* module) { // Initially add all GTE and Tuple instructions to the worklist. std::queue worklist; for (auto* computation : module->computations()) { + if (exclude_entry_computation_ && + computation == module->entry_computation()) { + continue; + } for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kTuple || instruction->opcode() == HloOpcode::kGetTupleElement) { @@ -78,7 +85,6 @@ StatusOr TupleSimplifier::Run(HloModule* module) { can_simplify = false; break; } - if (top_tuple == nullptr) { top_tuple = operand->mutable_operand(0); if (!ShapeUtil::Compatible(top_tuple->shape(), @@ -108,10 +114,10 @@ StatusOr TupleSimplifier::Run(HloModule* module) { // | // GTE if (instruction->operand(0)->opcode() == HloOpcode::kTuple) { - changed = true; HloInstruction* element_source = instruction->mutable_operand(0)->mutable_operand( instruction->tuple_index()); + changed = true; TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); for (HloInstruction* user : element_source->users()) { if (user->opcode() == HloOpcode::kTuple || diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h index e5e9b10b5bf3f452d1bfec476b8d5c7d74c4f4e8..750950188312c5077d487f2feef0606f07839432 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.h +++ b/tensorflow/compiler/xla/service/tuple_simplifier.h @@ -27,13 +27,20 @@ namespace xla { // the module. class TupleSimplifier : public HloPassInterface { public: - TupleSimplifier() {} + TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {} + explicit TupleSimplifier(bool exclude_entry_computation); ~TupleSimplifier() override {} tensorflow::StringPiece name() const override { return "tuple-simplifier"; } // Run tuple simplification on the given computation. Returns whether the // computation was changed. StatusOr Run(HloModule* module) override; + + private: + // When set, this pipeline stage will perform optimization of all computations + // apart from the module's entry computation. This is used by Graphcore's + // backend. + bool exclude_entry_computation_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index ca9ae91281fce5ee061d066fc3e538dbbc09f6b3..d3635eae81ec7017f9bf6a69250d10716309c9ec 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -42,6 +42,12 @@ class TupleSimplifierTest : public HloTestBase { TF_ASSERT_OK(changed_status.status()); EXPECT_EQ(change_expected, changed_status.ValueOrDie()); } + void Run(HloModule* module, bool change_expected, bool exclude_entry) { + TupleSimplifier simplifier(exclude_entry); + auto changed_status = simplifier.Run(module); + TF_ASSERT_OK(changed_status.status()); + EXPECT_EQ(change_expected, changed_status.ValueOrDie()); + } const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( @@ -211,5 +217,76 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { EXPECT_THAT(computation->root_instruction(), tuple); } +TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { + // Verify that the root computation can be excluded + auto module = CreateNewModule(); + + HloInstruction* p0; + HloInstruction* p1; + HloComputation* c0; + HloComputation* c1; + HloComputation* entry; + + { + HloComputation::Builder builder(TestName() + "_1"); + p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 1)); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 2)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); + + c0 = module->AddEmbeddedComputation(builder.Build()); + } + { + HloComputation::Builder builder(TestName() + "_2"); + p1 = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 1)); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 2)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); + + c1 = module->AddEmbeddedComputation(builder.Build()); + } + { + HloComputation::Builder builder(TestName() + "_Entry"); + HloInstruction* tuple_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* call0 = builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c0)); + HloInstruction* call1 = builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c1)); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, call0, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, call1, 1)); + HloInstruction* tuple0 = + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 0)); + HloInstruction* gte3 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 1)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte2, gte3})); + + entry = module->AddEntryComputation(builder.Build()); + } + + Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true); + + EXPECT_THAT(c0->root_instruction(), p0); + EXPECT_THAT(c1->root_instruction(), p1); + EXPECT_THAT(entry->instruction_count(), 9); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_util_test.cc b/tensorflow/compiler/xla/service/tuple_util_test.cc index 754fd8ef169231827eeb5bfd72aeb596644ca767..d33d5bb8f30c8504aa323d461e5f59709b48e1fc 100644 --- a/tensorflow/compiler/xla/service/tuple_util_test.cc +++ b/tensorflow/compiler/xla/service/tuple_util_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -37,7 +37,7 @@ ENTRY entry { )"; TF_ASSIGN_OR_RETURN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); *entry_computation = module->entry_computation(); *param0 = (*entry_computation)->parameter_instruction(0); diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc deleted file mode 100644 index 4a55e4095aa92cbdcd1bcb585dc851b2c5e9a32c..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ /dev/null @@ -1,3548 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/user_computation.h" - -#include -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" - -namespace xla { -namespace { - -HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { - switch (unop) { - case UNOP_ABS: - return HloOpcode::kAbs; - case UNOP_CEIL: - return HloOpcode::kCeil; - case UNOP_COS: - return HloOpcode::kCos; - case UNOP_EXP: - return HloOpcode::kExp; - case UNOP_FLOOR: - return HloOpcode::kFloor; - case UNOP_IMAG: - return HloOpcode::kImag; - case UNOP_IS_FINITE: - return HloOpcode::kIsFinite; - case UNOP_LOG: - return HloOpcode::kLog; - case UNOP_NOT: - return HloOpcode::kNot; - case UNOP_NEGATE: - return HloOpcode::kNegate; - case UNOP_REAL: - return HloOpcode::kReal; - case UNOP_ROUND_NEAREST_AFZ: - return HloOpcode::kRoundNearestAfz; - case UNOP_SIGN: - return HloOpcode::kSign; - case UNOP_SIN: - return HloOpcode::kSin; - case UNOP_SORT: - return HloOpcode::kSort; - case UNOP_TANH: - return HloOpcode::kTanh; - default: - LOG(FATAL) << "unhandled operation " << unop; - } -} - -HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { - switch (binop) { - case BINOP_ATAN2: - return HloOpcode::kAtan2; - case BINOP_COMPLEX: - return HloOpcode::kComplex; - case BINOP_MUL: - return HloOpcode::kMultiply; - case BINOP_ADD: - return HloOpcode::kAdd; - case BINOP_SUB: - return HloOpcode::kSubtract; - case BINOP_DIV: - return HloOpcode::kDivide; - case BINOP_EQ: - return HloOpcode::kEq; - case BINOP_GE: - return HloOpcode::kGe; - case BINOP_GT: - return HloOpcode::kGt; - case BINOP_LE: - return HloOpcode::kLe; - case BINOP_LT: - return HloOpcode::kLt; - case BINOP_NE: - return HloOpcode::kNe; - case BINOP_MAX: - return HloOpcode::kMaximum; - case BINOP_MIN: - return HloOpcode::kMinimum; - case BINOP_POW: - return HloOpcode::kPower; - case BINOP_REM: - return HloOpcode::kRemainder; - case BINOP_OR: - return HloOpcode::kOr; - case BINOP_AND: - return HloOpcode::kAnd; - case BINOP_SHIFT_LEFT: - return HloOpcode::kShiftLeft; - case BINOP_SHIFT_RIGHT_ARITHMETIC: - return HloOpcode::kShiftRightArithmetic; - case BINOP_SHIFT_RIGHT_LOGICAL: - return HloOpcode::kShiftRightLogical; - default: - LOG(FATAL) << "unhandled operation " << binop; - } -} - -HloOpcode TernaryOperationToHloOpcode(TernaryOperation triop) { - switch (triop) { - case TRIOP_CLAMP: - return HloOpcode::kClamp; - case TRIOP_SELECT: - return HloOpcode::kSelect; - default: - LOG(FATAL) << "unhandled operation " << triop; - } -} - -HloOpcode VariadicOperationToHloOpcode(VariadicOperation varop) { - switch (varop) { - case VAROP_TUPLE: - return HloOpcode::kTuple; - default: - LOG(FATAL) << "unhandled operation " << varop; - } -} - -} // namespace - -/* static */ StatusOr> -UserComputation::MakeWithRemapping( - const SessionComputation& session_computation, - const ComputationHandle& handle, - const std::map& old_to_new) { - auto user_computation = - MakeUnique(session_computation.name(), handle); - { - tensorflow::mutex_lock lock(user_computation->mutex_); - user_computation->session_computation_ = session_computation; - user_computation->next_handle_value_ = - std::max_element(session_computation.requests().begin(), - session_computation.requests().end(), - [](const std::pair& lhs, - const std::pair& rhs) { - return lhs.first < rhs.first; - }) - ->first + - 1; - TF_RETURN_IF_ERROR(user_computation->RemapEmbeddedComputations(old_to_new)); - } - - return std::move(user_computation); -} - -UserComputation::UserComputation(const string& name, - const ComputationHandle& handle) - : name_(name), next_handle_value_(1) { - *session_computation_.mutable_computation_handle() = handle; - session_computation_.set_name(name); - - VLOG(1) << "New UserComputation \"" << name - << "\", handle: " << handle.handle(); -} - -ComputationDataHandle UserComputation::CreateComputationDataHandle() { - ComputationDataHandle handle; - handle.set_handle(next_handle_value_); - // Handles are used as Version values and *must* be assigned consecutively for - // computation versioning to work. - next_handle_value_++; - return handle; -} - -StatusOr UserComputation::AddParameterInstruction( - const ParameterRequest& parameter_request) { - tensorflow::mutex_lock lock(mutex_); - - int64 parameter_number = parameter_request.parameter(); - if (parameters_.count(parameter_number) != 0) { - return InvalidArgument("parameter %lld already registered", - parameter_number); - } - ComputationDataHandle handle = CreateComputationDataHandle(); - - const Shape& validated_shape = parameter_request.shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_parameter_request() = parameter_request; - - parameters_[parameter_number] = &request; - - VLOG(1) << "AddParameterInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << parameter_request.ShortDebugString(); - return handle; -} - -Status UserComputation::AddSendInstruction(const SendRequest& send_request) { - tensorflow::mutex_lock lock(mutex_); - - // Check if the operand of the instruction is valid. - TF_RETURN_IF_ERROR(LookUpRequest(send_request.operand()).status()); - - // No handle is returned, but a handle must be assigned to this instruction - // for computation versioning. - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = ShapeUtil::MakeNil(); - *request.mutable_request()->mutable_send_request() = send_request; - - VLOG(1) << "AddSendInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << send_request.ShortDebugString(); - return Status::OK(); -} - -StatusOr UserComputation::AddRecvInstruction( - const RecvRequest& recv_request) { - tensorflow::mutex_lock lock(mutex_); - - const Shape& shape = recv_request.shape(); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_recv_request() = recv_request; - - VLOG(1) << "AddRecvInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << recv_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddPadInstruction( - const PadRequest& pad_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(pad_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* padding_value, - LookUpRequest(pad_request.padding_value())); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, ShapeInference::InferPadShape( - operand->output_shape(), - padding_value->output_shape(), - pad_request.padding_config())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_pad_request() = pad_request; - - VLOG(1) << "AddPadInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << pad_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConstantInstruction( - const ConstantRequest& constant_request) { - const Shape& validated_shape = constant_request.literal().shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - tensorflow::mutex_lock lock(mutex_); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_constant_request() = constant_request; - - VLOG(1) << "AddConstantInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle(); - return handle; -} - -StatusOr UserComputation::AddGatherInstruction( - const GatherRequest& gather_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* input_request, - LookUpRequest(gather_request.input())); - TF_ASSIGN_OR_RETURN(const OperationRequest* gather_indices_request, - LookUpRequest(gather_request.gather_indices())); - - TF_ASSIGN_OR_RETURN( - Shape shape, - ShapeInference::InferGatherShape( - input_request->output_shape(), gather_indices_request->output_shape(), - gather_request.dimension_numbers(), - AsInt64Slice(gather_request.window_bounds()))); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_gather_request() = gather_request; - - VLOG(1) << "AddGatherInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << gather_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddGetTupleElementInstruction( - const GetTupleElementRequest& get_tuple_element_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(get_tuple_element_request.operand())); - if (!ShapeUtil::IsTuple(operand->output_shape())) { - return InvalidArgument( - "Operand to GetTupleElement() is not a tuple; got %s", - ShapeUtil::HumanString(operand->output_shape()).c_str()); - } - Shape element_shape = ShapeUtil::GetTupleElementShape( - operand->output_shape(), get_tuple_element_request.index()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = element_shape; - *request.mutable_request()->mutable_get_tuple_element_request() = - get_tuple_element_request; - - VLOG(1) << "AddGetTupleElementInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << get_tuple_element_request.ShortDebugString(); - return handle; -} - -Status UserComputation::AddTraceInstruction(const TraceRequest& trace_request) { - tensorflow::mutex_lock lock(mutex_); - - // Verify that the operand index is valid. - TF_RETURN_IF_ERROR(LookUpRequest(trace_request.operand()).status()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = ShapeUtil::MakeNil(); - *request.mutable_request()->mutable_trace_request() = trace_request; - - VLOG(1) << "AddTraceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << trace_request.ShortDebugString(); - return Status::OK(); -} - -StatusOr UserComputation::AddRngInstruction( - const RngRequest& rng_request) { - tensorflow::mutex_lock lock(mutex_); - - // Check the number of parameters per RNG distribution. - switch (rng_request.distribution()) { - case RandomDistribution::RNG_NORMAL: - case RandomDistribution::RNG_UNIFORM: - if (rng_request.parameter_size() != 2) { - return InvalidArgument( - "RNG distribution (%s) expects 2 parameters, but got %d", - RandomDistribution_Name(rng_request.distribution()).c_str(), - rng_request.parameter_size()); - } - break; - default: - LOG(FATAL) << "unhandled distribution " << rng_request.distribution(); - } - - // Verify that the parameter indices are valid; - for (const ComputationDataHandle& param : rng_request.parameter()) { - TF_RETURN_IF_ERROR(LookUpRequest(param).status()); - } - const Shape& validated_shape = rng_request.shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_rng_request() = rng_request; - - VLOG(1) << "AddRngInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << rng_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddMapInstruction( - const MapRequest& map_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : map_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape, - AsInt64Slice(map_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_map_request() = map_request; - - VLOG(1) << "AddMapInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << map_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReduceInstruction( - const ReduceRequest& reduce_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(reduce_request.init_value())); - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReduceShape( - operand->output_shape(), init_value->output_shape(), - AsInt64Slice(reduce_request.dimensions()), *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_reduce_request() = reduce_request; - - VLOG(1) << "AddReduceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_request.ShortDebugString(); - return handle; -} - -StatusOr -UserComputation::AddBatchNormTrainingInstruction( - const BatchNormTrainingRequest& batch_norm_training_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_training_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_training_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* offset, - LookUpRequest(batch_norm_training_request.offset())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferBatchNormTrainingShape( - operand->output_shape(), scale->output_shape(), - offset->output_shape(), batch_norm_training_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_training_request() = - batch_norm_training_request; - - VLOG(1) << "AddBatchNormTrainingInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << batch_norm_training_request.ShortDebugString(); - - return handle; -} - -StatusOr -UserComputation::AddBatchNormInferenceInstruction( - const BatchNormInferenceRequest& batch_norm_inference_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_inference_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_inference_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* offset, - LookUpRequest(batch_norm_inference_request.offset())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* mean, - LookUpRequest(batch_norm_inference_request.mean())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* variance, - LookUpRequest(batch_norm_inference_request.variance())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferBatchNormInferenceShape( - operand->output_shape(), scale->output_shape(), - offset->output_shape(), mean->output_shape(), - variance->output_shape(), - batch_norm_inference_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_inference_request() = - batch_norm_inference_request; - - VLOG(1) << "AddBatchNormInferenceInstruction (" - << GetVersionedHandleInternal() << "), data handle " - << handle.handle() << ": " - << batch_norm_inference_request.ShortDebugString(); - - return handle; -} - -StatusOr UserComputation::AddBatchNormGradInstruction( - const BatchNormGradRequest& batch_norm_grad_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_grad_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_grad_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* mean, - LookUpRequest(batch_norm_grad_request.mean())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* variance, - LookUpRequest(batch_norm_grad_request.variance())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* grad_output, - LookUpRequest(batch_norm_grad_request.grad_output())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferBatchNormGradShape( - operand->output_shape(), scale->output_shape(), mean->output_shape(), - variance->output_shape(), grad_output->output_shape(), - batch_norm_grad_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_grad_request() = - batch_norm_grad_request; - - VLOG(1) << "AddBatchNormGradInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << batch_norm_grad_request.ShortDebugString(); - - return handle; -} - -StatusOr UserComputation::AddReduceWindowInstruction( - const ReduceWindowRequest& reduce_window_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_window_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(reduce_window_request.init_value())); - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReduceWindowShape( - operand->output_shape(), init_value->output_shape(), - reduce_window_request.window(), *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_reduce_window_request() = - reduce_window_request; - - VLOG(1) << "AddReduceWindowInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_window_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddSelectAndScatterInstruction( - const SelectAndScatterRequest& select_and_scatter_request, - const UserComputation& select_computation, - const UserComputation& scatter_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(select_and_scatter_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* source, - LookUpRequest(select_and_scatter_request.source())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(select_and_scatter_request.init_value())); - - VersionedComputationHandle::Version select_version = - select_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr select_program_shape, - select_computation.ComputeProgramShape(select_version)); - VersionedComputationHandle::Version scatter_version = - scatter_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr scatter_program_shape, - scatter_computation.ComputeProgramShape(scatter_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferSelectAndScatterShape( - operand->output_shape(), *select_program_shape, - select_and_scatter_request.window(), source->output_shape(), - init_value->output_shape(), *scatter_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(select_version); - request.add_embedded_computation_versions(scatter_version); - *request.mutable_request()->mutable_select_and_scatter_request() = - select_and_scatter_request; - - VLOG(1) << "AddSelectAndScatterInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << select_and_scatter_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReverseInstruction( - const ReverseRequest& reverse_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reverse_request.operand())); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReverseShape( - operand->output_shape(), AsInt64Slice(reverse_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_reverse_request() = reverse_request; - VLOG(1) << "AddReverseInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reverse_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddWhileInstruction( - const WhileRequest& while_request, - const UserComputation& condition_computation, - const UserComputation& body_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* init, - LookUpRequest(while_request.init())); - - VersionedComputationHandle::Version condition_version = - condition_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr condition_program_shape, - condition_computation.ComputeProgramShape(condition_version)); - - VersionedComputationHandle::Version body_version = body_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr body_program_shape, - body_computation.ComputeProgramShape(body_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferWhileShape( - *condition_program_shape, *body_program_shape, init->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(condition_version); - request.add_embedded_computation_versions(body_version); - *request.mutable_request()->mutable_while_request() = while_request; - - VLOG(1) << "AddWhileInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << while_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConditionalInstruction( - const ConditionalRequest& conditional_request, - const UserComputation& true_computation, - const UserComputation& false_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* pred, - LookUpRequest(conditional_request.predicate())); - TF_ASSIGN_OR_RETURN(const OperationRequest* true_operand, - LookUpRequest(conditional_request.true_operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* false_operand, - LookUpRequest(conditional_request.false_operand())); - - VersionedComputationHandle::Version true_computation_version = - true_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr true_computation_shape, - true_computation.ComputeProgramShape(true_computation_version)); - - VersionedComputationHandle::Version false_computation_version = - false_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr false_computation_shape, - false_computation.ComputeProgramShape(false_computation_version)); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferConditionalShape( - pred->output_shape(), true_operand->output_shape(), - false_operand->output_shape(), - *true_computation_shape, *false_computation_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(true_computation_version); - request.add_embedded_computation_versions(false_computation_version); - *request.mutable_request()->mutable_conditional_request() = - conditional_request; - - VLOG(1) << "AddConditionalInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << conditional_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBroadcastInstruction( - const BroadcastRequest& broadcast_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(broadcast_request.operand())); - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferBroadcastShape( - operand->output_shape(), - AsInt64Slice(broadcast_request.broadcast_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_broadcast_request() = broadcast_request; - - VLOG(1) << "AddBroadcastInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << broadcast_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReshapeInstruction( - const ReshapeRequest& reshape_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reshape_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReshapeShape( - operand->output_shape(), AsInt64Slice(reshape_request.dimensions()), - AsInt64Slice(reshape_request.new_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_reshape_request() = reshape_request; - - VLOG(1) << "AddReshapeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reshape_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddTransposeInstruction( - const TransposeRequest& transpose_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(transpose_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferTransposeShape( - operand->output_shape(), - AsInt64Slice(transpose_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_transpose_request() = transpose_request; - - VLOG(1) << "AddTransposeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << transpose_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddSliceInstruction( - const SliceRequest& slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(slice_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferSliceShape( - operand->output_shape(), AsInt64Slice(slice_request.start_indices()), - AsInt64Slice(slice_request.limit_indices()), - AsInt64Slice(slice_request.strides()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_slice_request() = slice_request; - - VLOG(1) << "AddSliceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << slice_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddDynamicSliceInstruction( - const DynamicSliceRequest& dynamic_slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(dynamic_slice_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* start_indices, - LookUpRequest(dynamic_slice_request.start_indices())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferDynamicSliceShape( - operand->output_shape(), start_indices->output_shape(), - AsInt64Slice(dynamic_slice_request.slice_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_dynamic_slice_request() = - dynamic_slice_request; - - VLOG(1) << "AddDynamicSliceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << dynamic_slice_request.ShortDebugString(); - return handle; -} - -StatusOr -UserComputation::AddDynamicUpdateSliceInstruction( - const DynamicUpdateSliceRequest& dynamic_update_slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(dynamic_update_slice_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* update, - LookUpRequest(dynamic_update_slice_request.update())); - - TF_ASSIGN_OR_RETURN( - const OperationRequest* start_indices, - LookUpRequest(dynamic_update_slice_request.start_indices())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, - ShapeInference::InferDynamicUpdateSliceShape( - operand->output_shape(), update->output_shape(), - start_indices->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_dynamic_update_slice_request() = - dynamic_update_slice_request; - - VLOG(1) << "AddDynamicUpdateSliceInstruction (" - << GetVersionedHandleInternal() << "), data handle " - << handle.handle() << ": " - << dynamic_update_slice_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConcatenateInstruction( - const ConcatenateRequest& concatenate_request) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : concatenate_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - TF_ASSIGN_OR_RETURN(Shape new_shape, - ShapeInference::InferConcatOpShape( - operand_shapes, concatenate_request.dimension())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_concatenate_request() = - concatenate_request; - - VLOG(1) << "AddConcatenateInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << concatenate_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConvertInstruction( - const ConvertRequest& convert_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(convert_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( - operand->output_shape(), - convert_request.new_element_type())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_convert_request() = convert_request; - - VLOG(1) << "AddConvertInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convert_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBitcastConvertInstruction( - const ConvertRequest& convert_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(convert_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( - operand->output_shape(), - convert_request.new_element_type())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_bitcast_convert_request() = - convert_request; - - VLOG(1) << "AddBitcastConvertInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convert_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReducePrecisionInstruction( - const ReducePrecisionRequest& reduce_precision_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_precision_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferReducePrecisionShape( - operand->output_shape(), reduce_precision_request.exponent_bits(), - reduce_precision_request.mantissa_bits())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_reduce_precision_request() = - reduce_precision_request; - - VLOG(1) << "AddReducePrecisionInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_precision_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConvolveInstruction( - const ConvolveRequest& convolve_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(convolve_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(convolve_request.rhs())); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvolveShape( - lhs->output_shape(), rhs->output_shape(), - convolve_request.window(), - convolve_request.dimension_numbers())); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_convolve_request() = convolve_request; - - VLOG(1) << "AddConvolveInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convolve_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddFftInstruction( - const FftRequest& fft_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(fft_request.operand())); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferFftShape( - operand->output_shape(), fft_request.fft_type(), - AsInt64Slice(fft_request.fft_length()))); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_fft_request() = fft_request; - - VLOG(1) << "AddFftInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << fft_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddCrossReplicaSumInstruction( - const CrossReplicaSumRequest& cross_replica_sum_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(cross_replica_sum_request.operand())); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( - {&operand->output_shape()})); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_cross_replica_sum_request() = - cross_replica_sum_request; - - VLOG(1) << "AddCrossreplicaSumInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << cross_replica_sum_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddInfeedInstruction( - const InfeedRequest& infeed_request) { - tensorflow::mutex_lock lock(mutex_); - - const Shape& shape = infeed_request.shape(); - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Given shape to Infeed must have a layout"); - } - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_infeed_request() = infeed_request; - - VLOG(1) << "AddInfeedInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << infeed_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddOutfeedInstruction( - const OutfeedRequest& outfeed_request) { - tensorflow::mutex_lock lock(mutex_); - - const Shape& shape = outfeed_request.shape(); - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Given shape to Outfeed must have a layout"); - } - - // Verify that operand is valid. - TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_outfeed_request() = outfeed_request; - - VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << outfeed_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddCallInstruction( - const CallRequest& call_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : call_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferCallShape(operand_shapes, *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_call_request() = call_request; - - VLOG(1) << "AddCallInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << call_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddCustomCallInstruction( - const CustomCallRequest& custom_call_request) { - tensorflow::mutex_lock lock(mutex_); - - for (const ComputationDataHandle& handle : custom_call_request.operands()) { - TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); - } - - if (tensorflow::StringPiece(custom_call_request.call_target_name()) - .starts_with("$")) { - return InvalidArgument( - "Invalid custom_call_target \"%s\": Call targets that start with '$' " - "are reserved for internal use.", - custom_call_request.call_target_name().c_str()); - } - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = custom_call_request.shape(); - *request.mutable_request()->mutable_custom_call_request() = - custom_call_request; - - VLOG(1) << "AddCustomCallInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << custom_call_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddHostComputeInstruction( - const HostComputeRequest& host_compute_request) { - tensorflow::mutex_lock lock(mutex_); - - for (const ComputationDataHandle& handle : host_compute_request.operands()) { - TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); - } - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = host_compute_request.shape(); - *request.mutable_request()->mutable_host_compute_request() = - host_compute_request; - - VLOG(1) << "AddHostComputeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << host_compute_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddDotInstruction( - const DotRequest& dot_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(dot_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(dot_request.rhs())); - - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape( - lhs->output_shape(), rhs->output_shape(), - dot_request.dimension_numbers())); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_dot_request() = dot_request; - - VLOG(1) << "AddDotInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << dot_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddUnaryInstruction( - const UnaryOpRequest& unary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(unary_request.operand())); - TF_ASSIGN_OR_RETURN( - Shape shape, ShapeInference::InferUnaryOpShape(unary_request.unop(), - operand->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_unary_op_request() = unary_request; - - VLOG(1) << "AddUnaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << unary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBinaryInstruction( - const BinaryOpRequest& binary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(binary_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(binary_request.rhs())); - TF_ASSIGN_OR_RETURN( - Shape shape, - ShapeInference::InferBinaryOpShape( - binary_request.binop(), lhs->output_shape(), rhs->output_shape(), - AsInt64Slice(binary_request.broadcast_dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_binary_op_request() = binary_request; - - VLOG(1) << "AddBinaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << binary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddTernaryInstruction( - const TernaryOpRequest& ternary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(ternary_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(ternary_request.rhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* ehs, - LookUpRequest(ternary_request.ehs())); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferTernaryOpShape( - ternary_request.triop(), lhs->output_shape(), - rhs->output_shape(), ehs->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_ternary_op_request() = ternary_request; - - VLOG(1) << "AddTernaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << ternary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddVariadicInstruction( - const VariadicOpRequest& variadic_request) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : variadic_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferVariadicOpShape( - variadic_request.varop(), operand_shapes)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_variadic_op_request() = variadic_request; - - VLOG(1) << "AddVariadicInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << variadic_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::GetShape(const ComputationDataHandle& handle) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - return operand->output_shape(); -} - -Status UserComputation::SetOpMetadata(const ComputationDataHandle& handle, - const OpMetadata& metadata) { - tensorflow::mutex_lock lock(mutex_); - - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("Invalid handle in SetOpMetadata (%lld)", - handle_value); - } - *session_computation_.mutable_requests() - ->at(handle_value) - .mutable_request() - ->mutable_metadata() = metadata; - return Status::OK(); -} - -Status UserComputation::SetOpSharding(const ComputationDataHandle& handle, - const OpSharding& sharding) { - tensorflow::mutex_lock lock(mutex_); - - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("Invalid handle in SetOpSharding (%lld)", - handle_value); - } - *session_computation_.mutable_requests() - ->at(handle_value) - .mutable_request() - ->mutable_sharding() = sharding; - return Status::OK(); -} - -Status UserComputation::SetReturnValue(const ComputationDataHandle& handle) { - tensorflow::mutex_lock lock(mutex_); - - if (!(handle.handle() > 0 && handle.handle() < next_handle_value_)) { - return InvalidArgument("Invalid handle in SetReturnValue"); - } - - handle_to_return_ = handle; - - VLOG(1) << "SetReturnValue of computation \"" << name() << "\" fixed to " - << GetVersionedHandleInternal(); - - return Status::OK(); -} - -VersionedComputationHandle UserComputation::GetVersionedHandle() const { - tensorflow::mutex_lock lock(mutex_); - return GetVersionedHandleInternal(); -} - -VersionedComputationHandle UserComputation::GetVersionedHandleInternal() const { - VersionedComputationHandle versioned_handle; - versioned_handle.handle = session_computation_.computation_handle(); - - if (handle_to_return_.handle() > 0) { - // A specific handle has been requested for the result of the computation. - versioned_handle.version = handle_to_return_.handle(); - } else { - // A version value is simply the most recently assigned - // ComputationDataHandle value, ie the handle value of the root of the - // computation. - versioned_handle.version = next_handle_value_ - 1; - } - - return versioned_handle; -} - -VersionedComputationHandle UserComputation::GetVersionedHandleAtOperation( - const ComputationDataHandle& operation) const { - tensorflow::mutex_lock lock(mutex_); - - // The version at which an operation was added is simply the handle value of - // the ComputationDataHandle. - VersionedComputationHandle versioned_handle; - versioned_handle.handle = session_computation_.computation_handle(); - versioned_handle.version = operation.handle(); - return versioned_handle; -} - -VersionedComputationHandle::Version UserComputation::version() const { - return GetVersionedHandle().version; -} - -namespace { - -// Returns true if the operation type corresponding to the given opcase can be -// the root of the computation. -bool CanBeRoot(const OpRequest::OpCase& op_case) { - switch (op_case) { - case OpRequest::kTraceRequest: - case OpRequest::kSendRequest: - case OpRequest::kOutfeedRequest: - return false; - default: - return true; - } -} - -// Returns a pointer to the operation with the given data handle value in the -// given SessionComputation. -StatusOr LookUpRequest( - int64 handle_value, const SessionComputation& session_computation) { - if (session_computation.requests().count(handle_value) == 0) { - return InvalidArgument("no ComputationDataHandle value %lld", handle_value); - } - return &session_computation.requests().at(handle_value); -} - -// Returns the OperationRequest corresponding to the root (result) of the -// session computation. -StatusOr GetRoot( - VersionedComputationHandle::Version version, - const SessionComputation& session_computation) { - TF_RET_CHECK(version > 0); - // Not all instructions can be roots. Walk backwards from the operation - // indicated by this version until a valid root is found. - const OperationRequest* root_request = nullptr; - while (version > 0) { - TF_ASSIGN_OR_RETURN(root_request, - LookUpRequest(version, session_computation)); - if (CanBeRoot(root_request->request().op_case())) { - break; - } - version--; - } - if (version == 0) { - return InternalError("Computation contains no root operation"); - } - return root_request; -} - -} // namespace - -StatusOr> -UserComputation::ComputeProgramShape( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - - TF_RET_CHECK(version > 0 && version < next_handle_value_); - - if (program_shape_ == nullptr || program_shape_version_ != version) { - // ProgramShape has not been computed yet, or is for different - // version. Compute it now. - TF_RETURN_IF_ERROR(CheckParametersAreContiguous(version)); - - auto program_shape = MakeUnique(); - for (int64 request_num = 1; request_num <= version; ++request_num) { - const OperationRequest& request = - session_computation_.requests().at(request_num); - if (request.request().op_case() == OpRequest::kParameterRequest) { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - int64 param_no = parameter_request.parameter(); - // Parameters may be out of order so expand ProgramShape parameters - // until it is at least large enough to hold the current parameter - // number. - while (program_shape->parameters_size() <= param_no) { - program_shape->add_parameters(); - program_shape->add_parameter_names(); - } - *program_shape->mutable_parameters(param_no) = request.output_shape(); - *program_shape->mutable_parameter_names(param_no) = - parameter_request.name(); - } - } - - // The root determines the output shape. - TF_ASSIGN_OR_RETURN(const OperationRequest* root_request, - GetRoot(version, session_computation_)); - *program_shape->mutable_result() = root_request->output_shape(); - if (ShapeUtil::IsOpaque(program_shape->result())) { - return Unimplemented("Computation results cannot be opaque"); - } - - program_shape_ = std::move(program_shape); - program_shape_version_ = version; - } - - return program_shape_; -} - -namespace { - -// A visitor which checks whether an operation is pure functional meaning that -// it doesn't depend on any parameter with an index higher then num_parameters. -// The visitor walks the computation starting at a given operation and sets -// is_functional to false iff a parameter or RNG operation is encountered. -void PureFunctionalVisitor(const SessionComputation& session_computation, - const ComputationDataHandle& handle, - int64 num_parameters, std::set* visited, - bool* is_functional) { - if (visited->count(handle.handle()) != 0 || !*is_functional) { - return; - } - - const OperationRequest& request = - session_computation.requests().at(handle.handle()); - switch (request.request().op_case()) { - case OpRequest::kRngRequest: - *is_functional = false; - break; - - case OpRequest::kConstantRequest: - break; - - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - PureFunctionalVisitor(session_computation, - get_tuple_element_request.operand(), num_parameters, - visited, is_functional); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - PureFunctionalVisitor(session_computation, slice_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - PureFunctionalVisitor(session_computation, - dynamic_slice_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - dynamic_slice_request.start_indices(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - request.request().dynamic_update_slice_request(); - PureFunctionalVisitor(session_computation, - dynamic_update_slice_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - dynamic_update_slice_request.update(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - dynamic_update_slice_request.start_indices(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - PureFunctionalVisitor(session_computation, convolve_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, convolve_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - PureFunctionalVisitor(session_computation, fft_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - // TODO(b/33009255): Implmement constant folding for cross replica sum. - *is_functional = false; - break; - } - - case OpRequest::kInfeedRequest: { - *is_functional = false; - break; - } - - case OpRequest::kOutfeedRequest: { - *is_functional = false; - break; - } - - case OpRequest::kHostComputeRequest: { - *is_functional = false; - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - for (const ComputationDataHandle& handle : call_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - // TODO(b/32495713): We aren't checking the to_apply computation itself, - // so we conservatively say that computations containing the Call op - // cannot be constant. We cannot set is_functional=false in other similar - // cases since we're already relying on IsConstant to return true. - *is_functional = false; - break; - } - - case OpRequest::kCustomCallRequest: { - *is_functional = false; - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - PureFunctionalVisitor(session_computation, dot_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, dot_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kSendRequest: { - *is_functional = false; - break; - } - - case OpRequest::kRecvRequest: { - *is_functional = false; - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - for (const ComputationDataHandle& handle : map_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - PureFunctionalVisitor(session_computation, reduce_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, reduce_request.init_value(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - PureFunctionalVisitor(session_computation, - reduce_window_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - reduce_window_request.init_value(), num_parameters, - visited, is_functional); - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.source(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.init_value(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the select and scatter - // computations themselves. - break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - PureFunctionalVisitor(session_computation, broadcast_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - PureFunctionalVisitor(session_computation, reshape_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - PureFunctionalVisitor(session_computation, reverse_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - PureFunctionalVisitor(session_computation, pad_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, pad_request.padding_value(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kParameterRequest: { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - if (parameter_request.parameter() >= num_parameters) { - *is_functional = false; - } - break; - } - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - PureFunctionalVisitor(session_computation, convert_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - PureFunctionalVisitor(session_computation, convert_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - PureFunctionalVisitor(session_computation, while_request.init(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the condition and body - // computations themselves. - *is_functional = false; - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - PureFunctionalVisitor(session_computation, - conditional_request.predicate(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - conditional_request.true_operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - conditional_request.false_operand(), num_parameters, - visited, is_functional); - // TODO(b/32495713): We aren't checking the true and false computations - // themselves. - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - PureFunctionalVisitor(session_computation, ternary_op_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, ternary_op_request.rhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, ternary_op_request.ehs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - PureFunctionalVisitor(session_computation, transpose_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - PureFunctionalVisitor(session_computation, unary_op_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.scale(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.offset(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.scale(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.offset(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.mean(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.variance(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.scale(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, batch_norm_grad_request.mean(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.variance(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.grad_output(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - PureFunctionalVisitor(session_computation, binary_op_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, binary_op_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kGatherRequest: { - PureFunctionalVisitor(session_computation, - request.request().gather_request().input(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - request.request().gather_request().gather_indices(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } - if (!*is_functional) { - VLOG(1) << "Non-functional: " << request.request().DebugString(); - } - visited->insert(handle.handle()); -} - -} // namespace - -StatusOr UserComputation::IsConstant(const ComputationDataHandle& handle, - int64 num_parameters) { - tensorflow::mutex_lock lock(mutex_); - - // Verify that the handle is valid. - auto operation_status = LookUpRequest(handle); - if (!operation_status.ok()) { - return operation_status.status(); - } - - bool is_constant = true; - std::set visited; - PureFunctionalVisitor(session_computation_, handle, num_parameters, &visited, - &is_constant); - - return is_constant; -} - -std::vector -UserComputation::GetEmbeddedComputations( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - - VLOG(1) - << "GetEmbeddedComputations(" << name() << " " - << VersionedComputationHandle{session_computation_.computation_handle(), - version} - << ")"; - XLA_VLOG_LINES(3, session_computation_.DebugString()); - - std::vector computations; - std::vector sorted_handles; - for (const auto& handle_request : session_computation_.requests()) { - sorted_handles.push_back(handle_request.first); - } - std::sort(sorted_handles.begin(), sorted_handles.end()); - for (int64 handle : sorted_handles) { - const auto& handle_request = session_computation_.requests().find(handle); - CHECK(handle_request != session_computation_.requests().end()); - int64 handle_value = handle_request->first; - if (handle_value <= version) { - const OperationRequest& request = handle_request->second; - switch (request.request().op_case()) { - case OpRequest::kCallRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const CallRequest& call_request = request.request().call_request(); - const VersionedComputationHandle versioned_handle = { - call_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kMapRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const MapRequest& map_request = request.request().map_request(); - const VersionedComputationHandle versioned_handle = { - map_request.to_apply(), request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kReduceRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const ReduceRequest& reduce_request = - request.request().reduce_request(); - const VersionedComputationHandle versioned_handle = { - reduce_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kReduceWindowRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - const VersionedComputationHandle versioned_handle = { - reduce_window_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - const VersionedComputationHandle select_versioned_handle = { - select_and_scatter_request.select(), - request.embedded_computation_versions(0)}; - computations.push_back(select_versioned_handle); - const VersionedComputationHandle scatter_versioned_handle = { - select_and_scatter_request.scatter(), - request.embedded_computation_versions(1)}; - computations.push_back(scatter_versioned_handle); - break; - } - - case OpRequest::kWhileRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const WhileRequest& while_request = request.request().while_request(); - const VersionedComputationHandle condition_versioned_handle = { - while_request.condition(), - request.embedded_computation_versions(0)}; - computations.push_back(condition_versioned_handle); - const VersionedComputationHandle body_versioned_handle = { - while_request.body(), request.embedded_computation_versions(1)}; - computations.push_back(body_versioned_handle); - break; - } - - case OpRequest::kConditionalRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - const VersionedComputationHandle true_computation_versioned_handle = { - conditional_request.true_computation(), - request.embedded_computation_versions(0)}; - computations.push_back(true_computation_versioned_handle); - const VersionedComputationHandle false_computation_versioned_handle = - {conditional_request.false_computation(), - request.embedded_computation_versions(1)}; - computations.push_back(false_computation_versioned_handle); - break; - } - - default: - // No embedded computation. - break; - } - } - } - VLOG(2) << "Embedded computations: " - << tensorflow::str_util::Join( - computations, ", ", - [](string* out, const VersionedComputationHandle& h) { - out->append(h.ToString()); - }); - return computations; -} - -StatusOr -UserComputation::LookUpRequestForErrorReporting( - const ComputationDataHandle& handle) const { - tensorflow::mutex_lock lock(mutex_); - return LookUpRequest(handle); -} - -tensorflow::gtl::optional UserComputation::ParameterMetadata( - int parameter_number) const { - tensorflow::mutex_lock lock(mutex_); - auto it = parameters_.find(parameter_number); - if (it == parameters_.end()) { - return tensorflow::gtl::nullopt; - } - OperationRequest* op = it->second; - return &op->request().metadata(); -} - -Status UserComputation::RemapEmbeddedComputations( - const std::map& old_to_new) { - auto update = [&old_to_new](ComputationHandle* to_update) -> Status { - int64 old = to_update->handle(); - auto it = old_to_new.find(old); - if (it == old_to_new.end()) { - string mapping = tensorflow::str_util::Join( - old_to_new, ", ", - [](string* out, std::pair element) { - tensorflow::strings::Appendf(out, "%lld:%lld", element.first, - element.second.handle()); - }); - return NotFound( - "could not find referenced (old) computation handle in mapping: " - "%lld; mapping: {%s}", - old, mapping.c_str()); - } - VLOG(2) << "remapping " << old << " to " << it->second.handle(); - *to_update = it->second; - return Status::OK(); - }; - TF_RETURN_IF_ERROR(update(session_computation_.mutable_computation_handle())); - for (auto& handle_request : *session_computation_.mutable_requests()) { - OperationRequest& request = handle_request.second; - switch (request.request().op_case()) { - case OpRequest::kCallRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - CallRequest* call_request = - request.mutable_request()->mutable_call_request(); - TF_RETURN_IF_ERROR(update(call_request->mutable_to_apply())); - break; - } - case OpRequest::kMapRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - MapRequest* map_request = - request.mutable_request()->mutable_map_request(); - TF_RETURN_IF_ERROR(update(map_request->mutable_to_apply())); - break; - } - case OpRequest::kReduceRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - ReduceRequest* reduce_request = - request.mutable_request()->mutable_reduce_request(); - TF_RETURN_IF_ERROR(update(reduce_request->mutable_to_apply())); - break; - } - case OpRequest::kReduceWindowRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - ReduceWindowRequest* reduce_window_request = - request.mutable_request()->mutable_reduce_window_request(); - TF_RETURN_IF_ERROR(update(reduce_window_request->mutable_to_apply())); - break; - } - case OpRequest::kSelectAndScatterRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - SelectAndScatterRequest* select_and_scatter_request = - request.mutable_request()->mutable_select_and_scatter_request(); - TF_RETURN_IF_ERROR( - update(select_and_scatter_request->mutable_select())); - TF_RETURN_IF_ERROR( - update(select_and_scatter_request->mutable_scatter())); - break; - } - case OpRequest::kWhileRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - WhileRequest* while_request = - request.mutable_request()->mutable_while_request(); - TF_RETURN_IF_ERROR(update(while_request->mutable_condition())); - TF_RETURN_IF_ERROR(update(while_request->mutable_body())); - break; - } - case OpRequest::kConditionalRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - ConditionalRequest* conditional_request = - request.mutable_request()->mutable_conditional_request(); - TF_RETURN_IF_ERROR( - update(conditional_request->mutable_true_computation())); - TF_RETURN_IF_ERROR( - update(conditional_request->mutable_false_computation())); - break; - } - default: - // No embedded computation. - TF_RET_CHECK(0 == request.embedded_computation_versions_size()); - break; - } - } - return Status::OK(); -} - -SessionComputation UserComputation::CloneSessionComputation( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - SessionComputation result = session_computation_; - // Erase all the requests that exceed the version specified. - // There's no lower_bound method on tensorflow::protobuf::Map so we iterate - // all the elements. - auto it = result.mutable_requests()->begin(); - while (it != result.mutable_requests()->end()) { - if (it->first > version) { - it = result.mutable_requests()->erase(it); - } else { - ++it; - } - } - return result; -} - -StatusOr UserComputation::LookUpRequest( - const ComputationDataHandle& handle) const { - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("no ComputationDataHandle value %lld", handle_value); - } - return &session_computation_.requests().at(handle_value); -} - -Status UserComputation::CheckParametersAreContiguous( - VersionedComputationHandle::Version version) const { - TF_RET_CHECK(version > 0 && version < next_handle_value_); - - // Determine number of parameter inputs at the given version. - std::map parameter_requests; - for (int64 request_num = 1; request_num <= version; ++request_num) { - const OperationRequest& request = - session_computation_.requests().at(request_num); - - if (request.request().op_case() == OpRequest::kParameterRequest) { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - // Duplicate parameters should be checked when parameter requests are - // added. - TF_RET_CHECK(0 == - parameter_requests.count(parameter_request.parameter())); - parameter_requests[parameter_request.parameter()] = ¶meter_request; - } - } - - for (int64 i = 0; i < parameter_requests.size(); ++i) { - auto it = parameter_requests.find(i); - if (it == parameter_requests.end()) { - return FailedPrecondition( - "computation %s does not have all its parameters populated " - "sequentially, missing parameter %lld", - name_.c_str(), i); - } - } - - return Status::OK(); -} - -namespace { - -// Helper class which builds an HLO computation from a SessionComputation. To -// construct the HLO computation, the SessionComputation graph is walked in -// DFS order lowering each OperationRequest to an HLO instruction. -class ComputationLowerer { - public: - static StatusOr> Lower( - const string& computation_name, - const SessionComputation& session_computation, - VersionedComputationHandle::Version version, - UserComputation::HloComputationResolver hlo_resolver, - const DebugOptions& debug_options, - bool include_unreachable_instructions) { - ComputationLowerer lowerer(computation_name, session_computation, version, - std::move(hlo_resolver), debug_options, - include_unreachable_instructions); - return lowerer.Lower(); - } - - private: - ComputationLowerer(const string& computation_name, - const SessionComputation& session_computation, - VersionedComputationHandle::Version version, - UserComputation::HloComputationResolver hlo_resolver, - const DebugOptions& debug_options, - bool include_unreachable_instructions) - : hlo_builder_(computation_name), - session_computation_(session_computation), - version_(version), - hlo_resolver_(std::move(hlo_resolver)), - debug_options_(debug_options), - include_unreachable_instructions_(include_unreachable_instructions) {} - - // Build an HLO computation from the SessionComputation at the given - // version. - StatusOr> Lower(); - - private: - // Traverses the computation 'root' using a DFS, calling 'visit' in postorder. - void TraversePostorder( - const ComputationDataHandle& root, - std::unordered_map* visited, - const std::function& visit); - - // DFS visitor of the UserComputation operations which lowers the operations - // to HLO instructions. - void Visit(const ComputationDataHandle& handle, - std::unordered_map* instructions); - - // Resolves a ComputationHandle and Version to a previously lowered - // HloComputation using the hlo_resolver_ function. - HloComputation* ResolveComputation( - const ComputationHandle& handle, - VersionedComputationHandle::Version version); - - // This function takes an input value which is being implicitly broadcast into - // an output shape and figures out the right kBroadcast instruction(s) - // necessary to replicate the implicit broadcast semantics explicitly. - HloInstruction* ImplicitBroadcastToExplicitBroadcast( - HloInstruction* operand, const Shape& output_shape); - - HloComputation::Builder hlo_builder_; - const SessionComputation& session_computation_; - const VersionedComputationHandle::Version version_; - const UserComputation::HloComputationResolver hlo_resolver_; - const DebugOptions& debug_options_; - const bool include_unreachable_instructions_; -}; - -// Calls 'apply' on each operand of 'request'. -static void ForEachOperand( - const OperationRequest& request, - const std::function& apply) { - switch (request.request().op_case()) { - case OpRequest::kRngRequest: { - const RngRequest& rng_request = request.request().rng_request(); - for (const ComputationDataHandle& param : rng_request.parameter()) { - apply(param); - } - break; - } - - case OpRequest::kConstantRequest: - break; - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - apply(get_tuple_element_request.operand()); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - apply(slice_request.operand()); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - apply(dynamic_slice_request.operand()); - apply(dynamic_slice_request.start_indices()); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - request.request().dynamic_update_slice_request(); - apply(dynamic_update_slice_request.operand()); - apply(dynamic_update_slice_request.update()); - apply(dynamic_update_slice_request.start_indices()); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - apply(convolve_request.lhs()); - apply(convolve_request.rhs()); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - apply(fft_request.operand()); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - - apply(batch_norm_training_request.operand()); - apply(batch_norm_training_request.scale()); - apply(batch_norm_training_request.offset()); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - - apply(batch_norm_inference_request.operand()); - apply(batch_norm_inference_request.scale()); - apply(batch_norm_inference_request.offset()); - apply(batch_norm_inference_request.mean()); - apply(batch_norm_inference_request.variance()); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - - apply(batch_norm_grad_request.operand()); - apply(batch_norm_grad_request.scale()); - apply(batch_norm_grad_request.mean()); - apply(batch_norm_grad_request.variance()); - apply(batch_norm_grad_request.grad_output()); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - const CrossReplicaSumRequest& cross_replica_sum_request = - request.request().cross_replica_sum_request(); - apply(cross_replica_sum_request.operand()); - break; - } - - case OpRequest::kInfeedRequest: - break; - - case OpRequest::kOutfeedRequest: { - const OutfeedRequest& outfeed_request = - request.request().outfeed_request(); - apply(outfeed_request.operand()); - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - for (const ComputationDataHandle& handle : map_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - apply(reduce_request.operand()); - apply(reduce_request.init_value()); - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - apply(reduce_window_request.operand()); - apply(reduce_window_request.init_value()); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - apply(select_and_scatter_request.operand()); - apply(select_and_scatter_request.source()); - apply(select_and_scatter_request.init_value()); - - break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - apply(broadcast_request.operand()); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - apply(reshape_request.operand()); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - apply(transpose_request.operand()); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - apply(reverse_request.operand()); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - apply(pad_request.operand()); - apply(pad_request.padding_value()); - break; - } - - case OpRequest::kRecvRequest: - case OpRequest::kParameterRequest: - break; - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - apply(convert_request.operand()); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - apply(convert_request.operand()); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - apply(while_request.init()); - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - apply(conditional_request.predicate()); - apply(conditional_request.true_operand()); - apply(conditional_request.false_operand()); - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - apply(ternary_op_request.lhs()); - apply(ternary_op_request.rhs()); - apply(ternary_op_request.ehs()); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - for (const ComputationDataHandle& handle : call_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kCustomCallRequest: { - const CustomCallRequest& cc_request = - request.request().custom_call_request(); - for (const ComputationDataHandle& operand : cc_request.operands()) { - apply(operand); - } - break; - } - - case OpRequest::kHostComputeRequest: { - const HostComputeRequest& hc_request = - request.request().host_compute_request(); - for (const ComputationDataHandle& operand : hc_request.operands()) { - apply(operand); - } - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - apply(dot_request.rhs()); - apply(dot_request.lhs()); - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - apply(unary_op_request.operand()); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - apply(binary_op_request.rhs()); - apply(binary_op_request.lhs()); - break; - } - - case OpRequest::kReducePrecisionRequest: { - const ReducePrecisionRequest& reduce_precision_request = - request.request().reduce_precision_request(); - apply(reduce_precision_request.operand()); - break; - } - - case OpRequest::kTraceRequest: { - const TraceRequest& trace_request = request.request().trace_request(); - apply(trace_request.operand()); - break; - } - - case OpRequest::kSendRequest: { - const SendRequest& send_request = request.request().send_request(); - apply(send_request.operand()); - break; - } - - case OpRequest::kGatherRequest: { - const GatherRequest& gather_request = request.request().gather_request(); - apply(gather_request.input()); - apply(gather_request.gather_indices()); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } -} - -void ComputationLowerer::TraversePostorder( - const ComputationDataHandle& root, - std::unordered_map* visited, - const std::function& visit) { - // Stack containing {handle, enter} pairs. The 'enter' value describes whether - // we are entering or leaving 'handle'. - std::stack> work; - work.push({root, true}); - while (!work.empty()) { - ComputationDataHandle handle; - bool enter; - std::tie(handle, enter) = work.top(); - work.pop(); - - if (enter) { - // We are entering 'handle'. The first time we enter 'handle', we add it - // to 'visited' with a nullptr value. If 'handle' is already in 'visited', - // we do not visit it again. This algorithm only uses the presence of - // a handle in 'visited', but we use a map so we can use the same data - // structure to store the HloInstruction outputs. - if (visited->emplace(handle.handle(), nullptr).second) { - const OperationRequest& request = - session_computation_.requests().at(handle.handle()); - // Push the corresponding 'leave' action onto the stack, followed by - // the operands. - work.push({handle, false}); - ForEachOperand(request, [&work](const ComputationDataHandle& child) { - work.push({child, true}); - }); - } - } else { - // We are leaving 'handle'. We have visited the operands of 'handle', and - // now can visit the 'handle' itself. - visit(handle); - } - } -} - -StatusOr> ComputationLowerer::Lower() { - // Map from ComputationDataHandle to HLO instruction. Serves as a record of - // which operations have been visited as well as a cache for looking up - // ComputationDataHandles as HloInstructions. - std::unordered_map instructions; - - TF_ASSIGN_OR_RETURN(const OperationRequest* root_request, - GetRoot(version_, session_computation_)); - - auto visit = [&](const ComputationDataHandle& handle) { - Visit(handle, &instructions); - }; - TraversePostorder(root_request->output_handle(), &instructions, visit); - HloInstruction* hlo_root = - instructions.at(root_request->output_handle().handle()); - - if (include_unreachable_instructions_) { - // Iterate through all computation data handles, and visit any unvisited - // operations. - for (int64 request_num = 1; request_num <= version_; ++request_num) { - TF_ASSIGN_OR_RETURN(const OperationRequest* request, - LookUpRequest(request_num, session_computation_)); - TraversePostorder(request->output_handle(), &instructions, visit); - } - } - - return hlo_builder_.Build(hlo_root); -} - -HloComputation* ComputationLowerer::ResolveComputation( - const ComputationHandle& handle, - VersionedComputationHandle::Version version) { - const VersionedComputationHandle checked_handle = {handle, version}; - return hlo_resolver_(checked_handle); -} - -HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( - HloInstruction* operand, const Shape& output_shape) { - auto fadd = [this](std::unique_ptr x) { - return hlo_builder_.AddInstruction(std::move(x)); - }; - return fadd( - HloInstruction::CreateBroadcastSequence(output_shape, operand, fadd)); -} - -void ComputationLowerer::Visit( - const ComputationDataHandle& handle, - std::unordered_map* instructions) { - CHECK_LE(handle.handle(), version_); - CHECK(instructions->at(handle.handle()) == nullptr); - const OperationRequest& request = - session_computation_.requests().at(handle.handle()); - auto add_instruction = [&](std::unique_ptr instruction) { - HloInstruction* hlo_instruction = - hlo_builder_.AddInstruction(std::move(instruction)); - hlo_instruction->set_metadata(request.request().metadata()); - if (request.request().has_sharding()) { - OpSharding op_sharding = request.request().sharding(); - hlo_instruction->set_sharding( - HloSharding::FromProto(op_sharding).ValueOrDie()); - } - return hlo_instruction; - }; - auto lookup_instruction = [&](const ComputationDataHandle& handle) { - return instructions->at(handle.handle()); - }; - HloInstruction* hlo_instruction; - switch (request.request().op_case()) { - case OpRequest::kRngRequest: { - const RngRequest& rng_request = request.request().rng_request(); - std::vector parameters; - for (const ComputationDataHandle& param : rng_request.parameter()) { - parameters.push_back(lookup_instruction(param)); - } - hlo_instruction = add_instruction(HloInstruction::CreateRng( - request.output_shape(), rng_request.distribution(), parameters)); - break; - } - - case OpRequest::kConstantRequest: { - const ConstantRequest& constant_request = - request.request().constant_request(); - hlo_instruction = add_instruction(HloInstruction::CreateConstant( - Literal::CreateFromProto(constant_request.literal()) - .ConsumeValueOrDie())); - break; - } - - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - HloInstruction* operand = - lookup_instruction(get_tuple_element_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateGetTupleElement( - request.output_shape(), operand, get_tuple_element_request.index())); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - HloInstruction* operand = lookup_instruction(slice_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateSlice( - request.output_shape(), operand, - AsInt64Slice(slice_request.start_indices()), - AsInt64Slice(slice_request.limit_indices()), - AsInt64Slice(slice_request.strides()))); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - HloInstruction* operand = - lookup_instruction(dynamic_slice_request.operand()); - HloInstruction* start_indices = - lookup_instruction(dynamic_slice_request.start_indices()); - - hlo_instruction = add_instruction(HloInstruction::CreateDynamicSlice( - request.output_shape(), operand, start_indices, - AsInt64Slice(dynamic_slice_request.slice_sizes()))); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - request.request().dynamic_update_slice_request(); - HloInstruction* operand = - lookup_instruction(dynamic_update_slice_request.operand()); - HloInstruction* update = - lookup_instruction(dynamic_update_slice_request.update()); - HloInstruction* start_indices = - lookup_instruction(dynamic_update_slice_request.start_indices()); - hlo_instruction = - add_instruction(HloInstruction::CreateDynamicUpdateSlice( - request.output_shape(), operand, update, start_indices)); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - std::vector operands; - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - hlo_instruction = add_instruction(HloInstruction::CreateConcatenate( - request.output_shape(), operands, concatenate_request.dimension())); - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - HloInstruction* lhs = lookup_instruction(convolve_request.lhs()); - HloInstruction* rhs = lookup_instruction(convolve_request.rhs()); - hlo_instruction = add_instruction(HloInstruction::CreateConvolve( - request.output_shape(), lhs, rhs, convolve_request.window(), - convolve_request.dimension_numbers())); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - HloInstruction* operand = lookup_instruction(fft_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateFft( - request.output_shape(), operand, fft_request.fft_type(), - AsInt64Slice(fft_request.fft_length()))); - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - HloInstruction* lhs = lookup_instruction(dot_request.lhs()); - HloInstruction* rhs = lookup_instruction(dot_request.rhs()); - hlo_instruction = add_instruction(HloInstruction::CreateDot( - request.output_shape(), lhs, rhs, dot_request.dimension_numbers())); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - const CrossReplicaSumRequest& cross_replica_sum_request = - request.request().cross_replica_sum_request(); - HloInstruction* operand = - lookup_instruction(cross_replica_sum_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum( - request.output_shape(), {operand})); - break; - } - - case OpRequest::kInfeedRequest: { - const InfeedRequest& infeed_request = request.request().infeed_request(); - hlo_instruction = add_instruction(HloInstruction::CreateInfeed( - request.output_shape(), infeed_request.config())); - break; - } - - case OpRequest::kOutfeedRequest: { - const OutfeedRequest& outfeed_request = - request.request().outfeed_request(); - HloInstruction* operand = lookup_instruction(outfeed_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateOutfeed( - outfeed_request.shape(), operand, outfeed_request.outfeed_config())); - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - std::vector operands; - for (const ComputationDataHandle& handle : map_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version map_version = - request.embedded_computation_versions(0); - HloComputation* map_computation = - ResolveComputation(map_request.to_apply(), map_version); - hlo_instruction = add_instruction(HloInstruction::CreateMap( - request.output_shape(), operands, map_computation)); - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - HloInstruction* operand = lookup_instruction(reduce_request.operand()); - HloInstruction* init_value = - lookup_instruction(reduce_request.init_value()); - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version reduce_version = - request.embedded_computation_versions(0); - HloComputation* reduce_computation = - ResolveComputation(reduce_request.to_apply(), reduce_version); - hlo_instruction = add_instruction(HloInstruction::CreateReduce( - request.output_shape(), operand, init_value, - AsInt64Slice(reduce_request.dimensions()), reduce_computation)); - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - HloInstruction* operand = - lookup_instruction(reduce_window_request.operand()); - HloInstruction* init_value = - lookup_instruction(reduce_window_request.init_value()); - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version reduce_window_version = - request.embedded_computation_versions(0); - HloComputation* reduce_window_computation = ResolveComputation( - reduce_window_request.to_apply(), reduce_window_version); - hlo_instruction = add_instruction(HloInstruction::CreateReduceWindow( - request.output_shape(), operand, init_value, - reduce_window_request.window(), reduce_window_computation)); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - HloInstruction* operand = - lookup_instruction(select_and_scatter_request.operand()); - HloInstruction* source = - lookup_instruction(select_and_scatter_request.source()); - HloInstruction* init_value = - lookup_instruction(select_and_scatter_request.init_value()); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version select_version = - request.embedded_computation_versions(0); - VersionedComputationHandle::Version scatter_version = - request.embedded_computation_versions(1); - HloComputation* select_computation = ResolveComputation( - select_and_scatter_request.select(), select_version); - HloComputation* scatter_computation = ResolveComputation( - select_and_scatter_request.scatter(), scatter_version); - hlo_instruction = add_instruction(HloInstruction::CreateSelectAndScatter( - request.output_shape(), operand, select_computation, - select_and_scatter_request.window(), source, init_value, - scatter_computation)); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - HloInstruction* operand = - lookup_instruction(batch_norm_training_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_training_request.scale()); - HloInstruction* offset = - lookup_instruction(batch_norm_training_request.offset()); - - hlo_instruction = add_instruction(HloInstruction::CreateBatchNormTraining( - request.output_shape(), operand, scale, offset, - batch_norm_training_request.epsilon(), - batch_norm_training_request.feature_index())); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - HloInstruction* operand = - lookup_instruction(batch_norm_inference_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_inference_request.scale()); - HloInstruction* offset = - lookup_instruction(batch_norm_inference_request.offset()); - HloInstruction* mean = - lookup_instruction(batch_norm_inference_request.mean()); - HloInstruction* variance = - lookup_instruction(batch_norm_inference_request.variance()); - - hlo_instruction = - add_instruction(HloInstruction::CreateBatchNormInference( - request.output_shape(), operand, scale, offset, mean, variance, - batch_norm_inference_request.epsilon(), - batch_norm_inference_request.feature_index())); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - - HloInstruction* operand = - lookup_instruction(batch_norm_grad_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_grad_request.scale()); - HloInstruction* mean = lookup_instruction(batch_norm_grad_request.mean()); - HloInstruction* variance = - lookup_instruction(batch_norm_grad_request.variance()); - HloInstruction* grad_output = - lookup_instruction(batch_norm_grad_request.grad_output()); - - hlo_instruction = add_instruction(HloInstruction::CreateBatchNormGrad( - request.output_shape(), operand, scale, mean, variance, grad_output, - batch_norm_grad_request.epsilon(), - batch_norm_grad_request.feature_index())); - break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - HloInstruction* operand = lookup_instruction(broadcast_request.operand()); - std::vector broadcast_dimensions; - // The client-level broadcast instruction just appends dimensions on the - // left (adds lowest numbered dimensions). The HLO broadcast op is more - // flexible and can add new dimensions anywhere. The broadcast_dimensions - // maps operand dimensions to dimensions in the broadcast output, so - // to append dimensions on the left the broadcast_dimensions should just - // be the n highest dimension numbers of the output shape where n is - // the number of input dimensions. - broadcast_dimensions.reserve(ShapeUtil::Rank(operand->shape())); - for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { - broadcast_dimensions.push_back(i + - ShapeUtil::Rank(request.output_shape()) - - ShapeUtil::Rank(operand->shape())); - } - hlo_instruction = add_instruction(HloInstruction::CreateBroadcast( - request.output_shape(), operand, broadcast_dimensions)); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - HloInstruction* operand = lookup_instruction(reshape_request.operand()); - HloInstruction* transposed; - if (IsIdentityPermutation(AsInt64Slice(reshape_request.dimensions()))) { - transposed = operand; - } else { - transposed = add_instruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions( - InversePermutation(AsInt64Slice(reshape_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(reshape_request.dimensions()))); - } - hlo_instruction = add_instruction( - HloInstruction::CreateReshape(request.output_shape(), transposed)); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - HloInstruction* operand = lookup_instruction(transpose_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions( - InversePermutation(AsInt64Slice(transpose_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(transpose_request.dimensions()))); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - HloInstruction* operand = lookup_instruction(reverse_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateReverse( - request.output_shape(), operand, - AsInt64Slice(reverse_request.dimensions()))); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - HloInstruction* operand = lookup_instruction(pad_request.operand()); - HloInstruction* padding_value = - lookup_instruction(pad_request.padding_value()); - hlo_instruction = add_instruction(HloInstruction::CreatePad( - request.output_shape(), operand, padding_value, - pad_request.padding_config())); - break; - } - - case OpRequest::kRecvRequest: { - const RecvRequest& recv_request = request.request().recv_request(); - HloInstruction* recv = add_instruction(HloInstruction::CreateRecv( - request.output_shape(), recv_request.channel_handle().handle())); - hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv)); - break; - } - - case OpRequest::kParameterRequest: { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - hlo_instruction = add_instruction(HloInstruction::CreateParameter( - parameter_request.parameter(), request.output_shape(), - parameter_request.name())); - break; - } - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - HloInstruction* operand = lookup_instruction(convert_request.operand()); - hlo_instruction = add_instruction( - HloInstruction::CreateConvert(request.output_shape(), operand)); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - HloInstruction* operand = lookup_instruction(convert_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateBitcastConvert( - request.output_shape(), operand)); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version condition_version = - request.embedded_computation_versions(0); - HloComputation* condition = - ResolveComputation(while_request.condition(), condition_version); - VersionedComputationHandle::Version body_version = - request.embedded_computation_versions(1); - HloComputation* body = - ResolveComputation(while_request.body(), body_version); - HloInstruction* init = lookup_instruction(while_request.init()); - hlo_instruction = add_instruction(HloInstruction::CreateWhile( - request.output_shape(), condition, body, init)); - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version true_computation_version = - request.embedded_computation_versions(0); - HloComputation* true_computation = ResolveComputation( - conditional_request.true_computation(), true_computation_version); - VersionedComputationHandle::Version false_computation_version = - request.embedded_computation_versions(1); - HloComputation* false_computation = ResolveComputation( - conditional_request.false_computation(), false_computation_version); - HloInstruction* predicate = - lookup_instruction(conditional_request.predicate()); - HloInstruction* true_operand = - lookup_instruction(conditional_request.true_operand()); - HloInstruction* false_operand = - lookup_instruction(conditional_request.false_operand()); - hlo_instruction = add_instruction(HloInstruction::CreateConditional( - request.output_shape(), predicate, true_operand, true_computation, - false_operand, false_computation)); - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - HloInstruction* lhs = lookup_instruction(ternary_op_request.lhs()); - HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs()); - HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs()); - auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop()); - - if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { - if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { - // lhs side is being implicitly broadcast. Change to explicit. - lhs = - ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); - } - - if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { - rhs = - ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); - } - - if (!ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) { - ehs = - ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape()); - } - } - - hlo_instruction = add_instruction(HloInstruction::CreateTernary( - request.output_shape(), hlo_opcode, lhs, rhs, ehs)); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - std::vector operands; - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - auto hlo_opcode = - VariadicOperationToHloOpcode(variadic_op_request.varop()); - hlo_instruction = add_instruction(HloInstruction::CreateVariadic( - request.output_shape(), hlo_opcode, operands)); - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - std::vector operands; - for (const ComputationDataHandle& handle : call_request.operands()) { - operands.push_back(lookup_instruction(handle)); - } - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version call_version = - request.embedded_computation_versions(0); - HloComputation* call_computation = - ResolveComputation(call_request.to_apply(), call_version); - hlo_instruction = add_instruction(HloInstruction::CreateCall( - request.output_shape(), operands, call_computation)); - break; - } - - case OpRequest::kCustomCallRequest: { - const CustomCallRequest& cc_request = - request.request().custom_call_request(); - std::vector operands; - for (const ComputationDataHandle& operand : cc_request.operands()) { - operands.push_back(lookup_instruction(operand)); - } - hlo_instruction = add_instruction(HloInstruction::CreateCustomCall( - cc_request.shape(), operands, cc_request.call_target_name())); - break; - } - - case OpRequest::kHostComputeRequest: { - const HostComputeRequest& host_compute_request = - request.request().host_compute_request(); - std::vector operands; - for (const ComputationDataHandle& operand : - host_compute_request.operands()) { - operands.push_back(lookup_instruction(operand)); - } - auto output_shape = host_compute_request.shape(); - auto channel_name = host_compute_request.channel_name(); - auto cost_estimate_ns = host_compute_request.cost_estimate_ns(); - hlo_instruction = add_instruction(HloInstruction::CreateHostCompute( - output_shape, operands, channel_name, cost_estimate_ns)); - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - HloInstruction* operand = lookup_instruction(unary_op_request.operand()); - auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop()); - hlo_instruction = add_instruction(HloInstruction::CreateUnary( - request.output_shape(), hlo_opcode, operand)); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - HloInstruction* lhs = lookup_instruction(binary_op_request.lhs()); - HloInstruction* rhs = lookup_instruction(binary_op_request.rhs()); - auto hlo_opcode = BinaryOperationToHloOpcode(binary_op_request.binop()); - if (binary_op_request.broadcast_dimensions_size() > 0 && - ShapeUtil::Rank(lhs->shape()) != ShapeUtil::Rank(rhs->shape())) { - // Emit a broadcast instruction to perform the "broadcast in dimension" - // operation. - HloInstruction* operand_to_broadcast = - ShapeUtil::Rank(lhs->shape()) < ShapeUtil::Rank(rhs->shape()) ? lhs - : rhs; - CHECK_EQ(ShapeUtil::Rank(operand_to_broadcast->shape()), - binary_op_request.broadcast_dimensions().size()); - - // Construct the bounds of the shape of the kBroadcast instruction - // responsible for the in-dimension broadcast. - std::vector output_dimensions; - for (int64 size : request.output_shape().dimensions()) { - output_dimensions.push_back(size); - } - for (int64 operand_dim = 0; - operand_dim < ShapeUtil::Rank(operand_to_broadcast->shape()); - ++operand_dim) { - int64 output_dim = - binary_op_request.broadcast_dimensions()[operand_dim]; - output_dimensions[output_dim] = - operand_to_broadcast->shape().dimensions(operand_dim); - } - - Shape broadcast_shape = ShapeUtil::MakeShape( - operand_to_broadcast->shape().element_type(), output_dimensions); - - // The broadcast semantics of a client-level binary op broadcast is - // identical to the HLO broadcast semantics so the broadcast_dimensions - // field can just be passed to the instruction builder. - HloInstruction* broadcasted_operand = - add_instruction(HloInstruction::CreateBroadcast( - broadcast_shape, operand_to_broadcast, - AsInt64Slice(binary_op_request.broadcast_dimensions()))); - - lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; - rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; - } - if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { - if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { - // lhs side is being implicitly broadcast. Change to explicit. - lhs = - ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); - } - - if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { - rhs = - ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); - } - } - hlo_instruction = add_instruction(HloInstruction::CreateBinary( - request.output_shape(), hlo_opcode, lhs, rhs)); - break; - } - - case OpRequest::kReducePrecisionRequest: { - const ReducePrecisionRequest& reduce_precision_request = - request.request().reduce_precision_request(); - HloInstruction* operand = - lookup_instruction(reduce_precision_request.operand()); - auto exponent_bits = reduce_precision_request.exponent_bits(); - auto mantissa_bits = reduce_precision_request.mantissa_bits(); - hlo_instruction = add_instruction(HloInstruction::CreateReducePrecision( - request.output_shape(), operand, exponent_bits, mantissa_bits)); - break; - } - - case OpRequest::kTraceRequest: { - const TraceRequest& trace_request = request.request().trace_request(); - HloInstruction* operand = lookup_instruction(trace_request.operand()); - hlo_instruction = add_instruction( - HloInstruction::CreateTrace(trace_request.tag(), operand)); - operand->set_tracing(hlo_instruction); - break; - } - - case OpRequest::kSendRequest: { - const SendRequest& send_request = request.request().send_request(); - HloInstruction* operand = lookup_instruction(send_request.operand()); - HloInstruction* send = add_instruction(HloInstruction::CreateSend( - operand, send_request.channel_handle().handle())); - hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send)); - break; - } - - case OpRequest::kGatherRequest: { - const GatherRequest& gather_request = request.request().gather_request(); - HloInstruction* input_operand = - lookup_instruction(gather_request.input()); - HloInstruction* gather_indices_operand = - lookup_instruction(gather_request.gather_indices()); - std::vector window_bounds; - c_copy(gather_request.window_bounds(), std::back_inserter(window_bounds)); - hlo_instruction = add_instruction(HloInstruction::CreateGather( - request.output_shape(), input_operand, gather_indices_operand, - gather_request.dimension_numbers(), window_bounds)); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } - (*instructions)[handle.handle()] = hlo_instruction; -} // NOLINT(readability/fn_size) - -} // namespace - -StatusOr> UserComputation::BuildHloComputation( - VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, const DebugOptions& debug_options, - bool include_unreachable_instructions) const { - tensorflow::mutex_lock lock(mutex_); - - VLOG(2) << "Building HloComputation from UserComputation " << name_ - << " at version " << version; - XLA_VLOG_LINES(3, session_computation_.DebugString()); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_computation, - ComputationLowerer::Lower( - tensorflow::strings::StrCat(name(), ".v", version), - session_computation_, version, std::move(hlo_resolver), debug_options, - include_unreachable_instructions)); - - return std::move(hlo_computation); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h deleted file mode 100644 index fd5a2ace9bacf66727dc91b6d96305424771a99b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/user_computation.h +++ /dev/null @@ -1,412 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// A UserComputation is the built-up computation that users create via the -// XLA Service interface. -// -// The XLA service adds instructions to a user computation via this -// interface. The state of the computation is stored as a SessionComputation -// proto which holds a record of all operation-building requests received by the -// XLA service. -// -// UserComputations are lowered to HloComputations which are passed to the high -// level compiler interface. -class UserComputation { - public: - // Factory used when restoring a computation from serialized session - // computation (computation snapshot) data. Remaps any references to - // computation handle via the old_to_new mapping. - // - // An error will occur if the old_to_new mapping cannot resolve a reference to - // a computation that is present in session_computation. - static StatusOr> MakeWithRemapping( - const SessionComputation& session_computation, - const ComputationHandle& handle, - const std::map& old_to_new); - - // Creates an empty computation with the given name and computation handle. - explicit UserComputation(const string& name, const ComputationHandle& handle); - - // Enqueues a parameter-retrieving instruction onto this user computation. - // Returns an error status if the parameter number is already registered with - // different values. - StatusOr AddParameterInstruction( - const ParameterRequest& parameter_request); - - // Enqueues a pad instruction onto this user computation. - StatusOr AddPadInstruction( - const PadRequest& pad_request); - - // Enqueues a tracing instruction onto this user computation. - // Returns an error status if the operand cannot be resolved. - Status AddTraceInstruction(const TraceRequest& trace_request); - - // Enqueues a random number generation instruction onto this user computation. - StatusOr AddRngInstruction( - const RngRequest& rng_request); - - // Enqueues a unary instruction onto this user computation. - // Returns an error status if the operand index is out of bounds. - StatusOr AddUnaryInstruction( - const UnaryOpRequest& unary_request); - - // Enqueues a batch norm training instruction onto this user computation. - StatusOr AddBatchNormTrainingInstruction( - const BatchNormTrainingRequest& batch_norm_training_request); - - // Enqueues a batch norm inference instruction onto this user computation. - StatusOr AddBatchNormInferenceInstruction( - const BatchNormInferenceRequest& batch_norm_inference_request); - - // Enqueues a batch norm grad instruction onto this user computation. - StatusOr AddBatchNormGradInstruction( - const BatchNormGradRequest& batch_norm_grad_request); - - // Enqueues a binary instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddBinaryInstruction( - const BinaryOpRequest& binary_request); - - // Enqueues a ternary instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddTernaryInstruction( - const TernaryOpRequest& ternary_request); - - // Enqueues a variadic instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddVariadicInstruction( - const VariadicOpRequest& variadic_request); - - // Enqueues a constant instruction onto this user computation. - StatusOr AddConstantInstruction( - const ConstantRequest& constant_request); - - // Enqueues a get tuple element instruction onto this user computation. - StatusOr AddGetTupleElementInstruction( - const GetTupleElementRequest& get_tuple_element_request); - - // Enqueues a map instruction onto this user computation. - StatusOr AddMapInstruction( - const MapRequest& map_request, - const UserComputation& to_apply_computation); - - // Enqueues a reduce-precision instruction onto this user computation. - StatusOr AddReducePrecisionInstruction( - const ReducePrecisionRequest& reduce_precision_request); - - // Enqueues a convolution instruction onto this user computation. - StatusOr AddConvolveInstruction( - const ConvolveRequest& convolve_request); - - // Enqueues an FFT instruction onto this user computation. - StatusOr AddFftInstruction( - const FftRequest& fft_request); - - // Enqueues a cross replica sum instruction onto this user computation. - StatusOr AddCrossReplicaSumInstruction( - const CrossReplicaSumRequest& cross_replica_sum_request); - - // Enqueues an infeed instruction onto this user computation. - StatusOr AddInfeedInstruction( - const InfeedRequest& infeed_request); - - // Enqueues an outfeed instruction onto this user computation. - StatusOr AddOutfeedInstruction( - const OutfeedRequest& outfeed_request); - - // Enqueues a host compute instruction onto this user computation. - StatusOr AddHostComputeInstruction( - const HostComputeRequest& host_compute_request); - - // Enqueues a call instruction onto this user computation. - StatusOr AddCallInstruction( - const CallRequest& call_request, - const UserComputation& to_apply_computation); - - // Enqueues a custom call instruction onto this user computation. - StatusOr AddCustomCallInstruction( - const CustomCallRequest& custom_call_request); - - // Enqueues a dot instruction onto this user computation. - StatusOr AddDotInstruction( - const DotRequest& dot_request); - - // Enqueues a broadcast instruction onto this user computation. - StatusOr AddBroadcastInstruction( - const BroadcastRequest& broadcast_request); - - // Enqueues a reshape instruction onto this user computation. - StatusOr AddReshapeInstruction( - const ReshapeRequest& reshape_request); - - // Enqueues a transpose instruction onto this user computation. - StatusOr AddTransposeInstruction( - const TransposeRequest& transpose_request); - - // Enqueues a slice instruction onto this user computation. - StatusOr AddSliceInstruction( - const SliceRequest& slice_request); - - // Enqueues a dynamic slice instruction onto this user computation. - StatusOr AddDynamicSliceInstruction( - const DynamicSliceRequest& dynamic_slice_request); - - // Enqueues a dynamic update slice instruction onto this user computation. - StatusOr AddDynamicUpdateSliceInstruction( - const DynamicUpdateSliceRequest& dynamic_update_slice_request); - - // Enqueues a concatenate instruction onto this user computation. - StatusOr AddConcatenateInstruction( - const ConcatenateRequest& concatenate_request); - - // Enqueues a convert instruction onto this user computation. - StatusOr AddConvertInstruction( - const ConvertRequest& convert_request); - - // Enqueues a bitcast element instruction onto this user computation. - StatusOr AddBitcastConvertInstruction( - const ConvertRequest& convert_request); - - // Enqueues a reduce instruction onto this user computation. - StatusOr AddReduceInstruction( - const ReduceRequest& reduce_request, - const UserComputation& to_apply_computation); - - // Enqueues a windowed reduce instruction onto this user computation. - StatusOr AddReduceWindowInstruction( - const ReduceWindowRequest& reduce_window_request, - const UserComputation& to_apply_computation); - - // Enqueues a select-and-scatter instruction onto this user - // computation. - StatusOr AddSelectAndScatterInstruction( - const SelectAndScatterRequest& select_and_scatter_request, - const UserComputation& select_computation, - const UserComputation& scatter_computation); - - // Enqueues a reverse instruction onto this user computation. - StatusOr AddReverseInstruction( - const ReverseRequest& reverse_request); - - // Enqueues a while instruction onto this user computation. - StatusOr AddWhileInstruction( - const WhileRequest& while_request, - const UserComputation& condition_computation, - const UserComputation& body_computation); - - // Enqueues a conditional instruction on this user computation. - StatusOr AddConditionalInstruction( - const ConditionalRequest& conditional_request, - const UserComputation& true_computation, - const UserComputation& false_computation); - - // Enqueues a Send instruction onto this user computation. - Status AddSendInstruction(const SendRequest& send_request); - - // Enqueues a Recv instruction onto this user computation. - StatusOr AddRecvInstruction( - const RecvRequest& recv_request); - - // Enqueues a Gather instruction onto this user computation. - StatusOr AddGatherInstruction( - const GatherRequest& gather_request); - - // Returns the user-provided name of this user computation, which is provided - // via the XLA computation-building API. - const string& name() const { return name_; } - - // Subsequent executions of this computation will compute the value - // represented by handle, rather than the last expression enqueued - // on the computation. - Status SetReturnValue(const ComputationDataHandle& handle); - - // Return a versioned handle for this computation. - VersionedComputationHandle GetVersionedHandle() const; - - // Return a versioned handle for this computation with a version equal to the - // point at which given operation was added to the computation. - VersionedComputationHandle GetVersionedHandleAtOperation( - const ComputationDataHandle& operation) const; - - // Return a version value representing the current state of the - // computation. - VersionedComputationHandle::Version version() const; - - // Computes and returns the program shape for the user computation -- gathers - // parameters and result type into a single proto. A shared_ptr is used - // because the returned pointer refers to an internally cached value which may - // be discarded by the UserComputation object. This avoid unnecessary copies. - // - // If the parameter space is not dense (i.e. there are holes in the parameter - // numbers provided) then an error status is returned. - StatusOr> ComputeProgramShape( - VersionedComputationHandle::Version version) const; - - // Returns true if the given data handle does not depend on any parameter with - // index higher then num_parameters. That is, the value can be computed at - // compile time if we know the first num_parameters arguments. - StatusOr IsConstant(const ComputationDataHandle& handle, - int64 num_parameters); - - // Returns the output shape of the operation indicated by the given handle. - StatusOr GetShape(const ComputationDataHandle& handle); - - // Sets metadata on the Hlo instruction referenced by the given handle. - Status SetOpMetadata(const ComputationDataHandle& handle, - const OpMetadata& metadata); - - // Sets the device assignment on the Hlo instruction referenced by 'handle'. - Status SetOpSharding(const ComputationDataHandle& handle, - const OpSharding& sharding); - - // Builds a HLO computation from the UserComputation. The parameter "resolver" - // is a function which returns a pointer to the HloComputation corresponding - // to the given ComputationHandle at the given version. The resolver is used - // for operations, such as map, which call other computations and need a - // pointer to the called HloComputation to construct the respective HLO - // instructions. If include_unreachable_instructions is true, then - // instructions which are not reachable from the root are lowered into - // HloInstructions. - using HloComputationResolver = - std::function; - StatusOr> BuildHloComputation( - VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, const DebugOptions& debug_options, - bool include_unreachable_instructions = true) const; - - // Return a vector containing the embedded computations used by this - // UserComputation. Only embedded computations which are called directly by - // this UserComputation are included. That is, the transitive closure of - // embedded computations is not included. - std::vector GetEmbeddedComputations( - VersionedComputationHandle::Version version) const; - - // Returns the number of OperationRequest objects in this UserComputation. - // The 'version' of a computation is identical to the number of - // OperationRequests in the UserComputation. - int64 request_count(VersionedComputationHandle::Version version) const { - return version; - } - - // Returns a copy of the internal session state for this computation -- this - // is useful for serializing the guts of a user computation, though references - // to other handles (e.g. referred-to computations) must be handled with care - // in the serialization / de-serialization process. - SessionComputation CloneSessionComputation( - VersionedComputationHandle::Version version) const; - - // Warning: typically we don't want to look up computation data handles until - // the computation is finished being built, for consistency purposes. We - // expose this routine for error reporting purposes so that we can provide - // more meaningful error messages from the XLA service layer. - // - // Returns the operation request that the handle comes from. - StatusOr LookUpRequestForErrorReporting( - const ComputationDataHandle& handle) const; - - // Retrieves the parameter metadata for the given parameter number. - // - // If the parameter number is invalid for this computation, nullopt is - // returned. When the return value has_value(), nullptr will never be - // the held value. - tensorflow::gtl::optional ParameterMetadata( - int parameter_number) const; - - private: - // Warning: dangerous mutating operation that doesn't respect versioning. - // This is only used at initialization time when constructing from a - // SessionComputation a la MakeWithRemapping. - // - // Remaps references to old computations (with handle values in the keys of - // old_to_new) to the computation handle given in the values. This is useful - // when loading computations from snapshots, to finish initialization, before - // the user computation is released into the wild. - Status RemapEmbeddedComputations( - const std::map& old_to_new) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Returns the OperationRequest corresponding to the given handle. - StatusOr LookUpRequest( - const ComputationDataHandle& handle) const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Creates a new ComputationDataHandle with the next available handle value. - ComputationDataHandle CreateComputationDataHandle() - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Checks whether the parameter numbers of the parameter operations are - // contiguous starting from zero. Returns appropriate error status if not. - Status CheckParametersAreContiguous( - VersionedComputationHandle::Version version) const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - VersionedComputationHandle GetVersionedHandleInternal() const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Name of the computation. - string name_; - - mutable tensorflow::mutex mutex_; - - // State of the computation as a record of all operation-building requests. - SessionComputation session_computation_ GUARDED_BY(mutex_); - - // Mapping from parameter number to operation request containing the - // respective ParameterRequest. - std::map parameters_ GUARDED_BY(mutex_); - - // The next ComputationDataHandle value to assign. Handle values are assigned - // sequentially. - int64 next_handle_value_ GUARDED_BY(mutex_); - - // If handle_to_return_.has_handle() then an Execution of this Computation - // will compute the value represented by handle_to_return_, otherwise it will - // compute the value of (next_handle_value_ - 1). - ComputationDataHandle handle_to_return_ GUARDED_BY(mutex_); - - // Memoized ProgramShape and its version. A shared_ptr is used because - // references to this object are returned by ComputeProgramShape. - mutable int64 program_shape_version_ GUARDED_BY(mutex_) = 0; - mutable std::shared_ptr program_shape_ GUARDED_BY(mutex_); - - TF_DISALLOW_COPY_AND_ASSIGN(UserComputation); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc deleted file mode 100644 index 2fa163953f638c0038e9f6bb11ce2a3742e0558c..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ /dev/null @@ -1,340 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/user_computation.h" - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" - -namespace op = xla::testing::opcode_matchers; - -namespace xla { -namespace { - -using UserComputationTest = ::testing::Test; - -TEST_F(UserComputationTest, SimpleComputation) { - const Shape kScalarShape = ShapeUtil::MakeShape(F32, {}); - const Shape kVectorShape = ShapeUtil::MakeShape(F32, {2}); - - // Build a simple three operation computatation: - // - // %constant = Constant({123, 42}) - // %param = Param(0) - // %outfeed = Outfeed(%constant) - // - // Build the computation at two different versions and check invariants. - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ConstantRequest constant_request; - *constant_request.mutable_literal() = - Literal::CreateR1({123.0f, 42.0f})->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle constant_handle, - computation.AddConstantInstruction(constant_request)); - - ParameterRequest param_request; - *param_request.mutable_shape() = kScalarShape; - param_request.set_parameter(0); - param_request.set_name("param0"); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle param_handle, - computation.AddParameterInstruction(param_request)); - OpMetadata metadata; - metadata.set_op_name("meta"); - TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata)); - - OutfeedRequest outfeed_request; - *outfeed_request.mutable_operand() = constant_handle; - *outfeed_request.mutable_shape() = kVectorShape; - outfeed_request.set_outfeed_config("abc"); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle outfeed_handle, - computation.AddOutfeedInstruction(outfeed_request)); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - { - // Test the computation at the latest version. In this case, the most - // recently added operation is an outfeed. However, the outfeed is not the - // root because outfeeds cannot be the root of a computation. - VersionedComputationHandle latest_version = - computation.GetVersionedHandle(); - - // Program shape should have a single scalar parameter and scalar - // result. The outfeed instruction should not affect the program shape. - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr program_shape, - computation.ComputeProgramShape(latest_version.version)); - ASSERT_EQ(1, program_shape->parameters_size()); - EXPECT_TRUE( - ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0))); - EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result())); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - DebugOptions())); - // There should be one HloInstruction per UserComputation operation. - EXPECT_EQ(3, hlo_computation->instruction_count()); - // The root of the instruction should be the parameter instruction (not the - // outfeed). - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - } - - { - // Test the computation at the version right after the parameter instruction - // is added. - VersionedComputationHandle version_at_param = - computation.GetVersionedHandleAtOperation(param_handle); - - // Program shape should have a single scalar parameter, and scalar result. - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr program_shape, - computation.ComputeProgramShape(version_at_param.version)); - ASSERT_EQ(1, program_shape->parameters_size()); - EXPECT_TRUE( - ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0))); - EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result())); - - // There should be two instructions, one for the constant and one for the - // parameter. The outfeed instruction should not be included. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(version_at_param.version, hlo_resolver, - DebugOptions())); - EXPECT_EQ(2, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - } - { - // Test the computation at the latest version, but lowered with - // include_unreachable_instructions set to false. - VersionedComputationHandle latest_version = - computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation( - latest_version.version, hlo_resolver, DebugOptions(), - /*include_unreachable_instructions=*/false)); - // There is only one reachable instruction, the parameter. - EXPECT_EQ(1, hlo_computation->instruction_count()); - // The root of the instruction should be the parameter instruction (not the - // outfeed). - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - EXPECT_EQ(hlo_computation->root_instruction()->metadata().op_name(), - "meta"); - } -} - -TEST_F(UserComputationTest, EliminateScalarBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with scalar broadcast. - // - // %a = Constant({123, 42}) - // %b = Constant(1) - // %add = Add(%a, %b) - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ConstantRequest a_request; - *a_request.mutable_literal() = - Literal::CreateR1({123.0f, 42.0f})->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddConstantInstruction(a_request)); - - ConstantRequest b_request; - *b_request.mutable_literal() = Literal::CreateR0(1.0f)->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddConstantInstruction(b_request)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - // The binary operation has implicit scalar broadcast, should be converted - // to an explicit broadcast intruction and a binary instruction. - EXPECT_EQ(4, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); - LOG(INFO) << hlo_computation->root_instruction()->ToString(); - const auto& operands = hlo_computation->root_instruction()->operands(); - ASSERT_EQ(2, operands.size()); - EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast || - operands[1]->opcode() == HloOpcode::kBroadcast); -} - -TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with degenerate broadcast. - // - // %a = Param({1, 2, 3}); - // %b = Param({1, 2, 1}); - // %add = Add(%a, %b, {}); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 1}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - const int64 kDevice = 7; - OpSharding sharding; - sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); - sharding.add_tile_assignment_dimensions(1); - sharding.add_tile_assignment_devices(kDevice); - - TF_EXPECT_OK(computation.SetOpSharding(b_handle, sharding)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - // b a - // | | - // reshape | - // | | - // broadcast | - // \ / - // add - EXPECT_EQ(5, hlo_computation->instruction_count()); - ASSERT_THAT( - hlo_computation->root_instruction(), - op::Add(op::Parameter(), op::Broadcast(op::Reshape(op::Parameter())))); - - const HloInstruction* broadcast = - hlo_computation->root_instruction()->operand(1); - EXPECT_TRUE(broadcast->has_sharding()); - - const HloInstruction* reshape = broadcast->operand(0); - EXPECT_TRUE(reshape->has_sharding()); -} - -TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with in-dim broadcast and degenerate broadcast. - // - // %a = Param({2, 3}); - // %b = Param({2, 1, 4}); - // %add = Add(%a, %b, {0, 1}); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 1, 4}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - add.add_broadcast_dimensions(0); - add.add_broadcast_dimensions(1); - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - // The binary operation has in-dim broadcast and degenerate broadcast, should - // first do the in-dim broadcast then convert the degnerate broadcast into a - // reshape and a broadcast. - // - // b a - // | | - // broadcast reshape - // | | - // | broadcast - // \ / - // add - EXPECT_EQ(6, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); - const auto& operands = hlo_computation->root_instruction()->operands(); - ASSERT_EQ(2, operands.size()); - EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast && - operands[1]->opcode() == HloOpcode::kBroadcast); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.h b/tensorflow/compiler/xla/service/versioned_computation_handle.h deleted file mode 100644 index 5732a56caffa31dde52dff5c2775f9fde0cacfbd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/versioned_computation_handle.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { - -// A data structure encapsulating a ComputationHandle and version value of that -// computation. This object is used to unambiguously refer to a particular -// computation in the service. -struct VersionedComputationHandle { - // A version value unambiguously specifying the state of the computation at a - // particular point in time as it is being built. This value is the - // ComputationDataHandle of the current root instruction. - using Version = int64; - - ComputationHandle handle; - Version version; - - string ToString() const; - bool operator==(const VersionedComputationHandle& other) const { - return (handle.handle() == other.handle.handle()) && - (version == other.version); - } - bool operator<(const VersionedComputationHandle& other) const { - return ((handle.handle() < other.handle.handle()) || - ((handle.handle() == other.handle.handle()) && - (version < other.version))); - } -}; - -std::ostream& operator<<(std::ostream& out, - const VersionedComputationHandle& versioned_handle); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc new file mode 100644 index 0000000000000000000000000000000000000000..10fc4958fae06414dbe7a3a0a798cb5c6e0f35c2 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -0,0 +1,128 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace xla { + +// Replaces all uses of old_instr with new_instr except the use at +// `while_body_root` (which must be a tuple instruction) at index `tuple_index`. +// This utility helps us replace an instruction in the while body with a +// constant while still keeping it trivially loop invariant. +static Status ReplaceUsesWhileKeepingLoopInvariance( + HloInstruction* old_instr, HloInstruction* new_instr, + HloInstruction* while_body_root, int64 tuple_index) { + CHECK_EQ(while_body_root->opcode(), HloOpcode::kTuple); + + std::vector users; + users.reserve(old_instr->user_count()); + c_copy(old_instr->users(), std::back_inserter(users)); + + for (auto* user : users) { + for (int64 i = 0, e = user->operand_count(); i < e; i++) { + if (user->operand(i) == old_instr && + !(user == while_body_root && i == tuple_index)) { + TF_RETURN_IF_ERROR(user->ReplaceOperandWith(i, new_instr)); + } + } + } + + return Status::OK(); +} + +StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileBody( + HloInstruction* while_instr) { + HloComputation* while_body = while_instr->while_body(); + + const HloInstruction& init_value = *while_instr->operand(0); + if (init_value.opcode() != HloOpcode::kTuple) { + return false; + } + + bool changed = false; + + for (HloInstruction* invariant_gte : + WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { + int64 index = invariant_gte->tuple_index(); + const HloInstruction& invariant_value = *init_value.operand(index); + if (invariant_value.opcode() == HloOpcode::kConstant) { + auto* constant_instr = + while_body->AddInstruction(invariant_value.Clone(/*suffix=*/".sunk")); + TF_RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance( + invariant_gte, constant_instr, while_body->root_instruction(), + index)); + changed = true; + } + } + + return changed; +} + +StatusOr WhileLoopConstantSinking::Run(HloModule* module) { + VLOG(2) << "HLO module before WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + + bool changed = false; + std::vector while_instrs; + for (auto* comp : module->MakeNonfusionComputations()) { + // Right now we don't particulary care about optimizing while-of-while + // patterns. If/When we do, we'll want to visit the outer while (while_0) + // before we visit the inner while (while_1): + // + // while_1_body(state) { + // val = gte(state, 0) // Loop invariant + // use(val) + // } + // + // while_0_body(state) { + // val = gte(state, 0) // Loop invariant + // while_1 = while(init=tuple(val, ...), body=while_1_body, ...) + // ... + // } + // + // main { + // while_0 = while(init=(constant, ...), body=while_0_body, ...) + // } + // + // This will let us sink the constant into the outer while first and then + // into the inner while in a single run of this pass. + c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); + } + + for (HloInstruction* while_instr : while_instrs) { + // We only sink into while loop bodies, but this can be extended to + // transform conditions as well. + TF_ASSIGN_OR_RETURN(bool result, + TrySinkingConstantsIntoWhileBody(while_instr)); + changed |= result; + } + + if (changed) { + VLOG(2) << "HLO module after WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + } else { + VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking"; + } + + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h new file mode 100644 index 0000000000000000000000000000000000000000..21fb8568a84985692026e145c363500a154a1599 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONSTANT_SINKING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONSTANT_SINKING_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Sinks while loop invariant values that happen to be constants into the while +// loop body. This is probably not a win in isolation but may unlock further +// optimizations like constant folding. +// +// state = (..., const, ...) +// while (pred(state)) { +// (..., v, ...) = state +// use(v) +// state = (..., v, ...) +// } +// +// => +// +// state = (..., const, ...) +// while (pred(state)) { +// (..., v, ...) = state +// use(const) +// state = (..., v, ...) +// } +// +// Note that it leaves the `v` in place to keep that component of the state +// tuple trivially loop invariant. WhileLoopSimplifier will later get rid of +// `v`. +// +// We only sink into while loop bodies, but this can be extended to transform +// conditions as well. +// +// TODO(b/79121449): We should also sink broadcasts of constants. +class WhileLoopConstantSinking : public HloPassInterface { + public: + ~WhileLoopConstantSinking() override = default; + + tensorflow::StringPiece name() const override { + return "while-loop-invariant-code-motion"; + } + + StatusOr Run(HloModule* module) override; + + private: + StatusOr TrySinkingConstantsIntoWhileBody(HloInstruction* while_instr); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONSTANT_SINKING_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..393e75803888d8a642881c4d525b170d1e1180ba --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -0,0 +1,200 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; +using ::testing::_; + +class WhileLoopConstantSinkingTest : public ::testing::Test {}; + +TEST_F(WhileLoopConstantSinkingTest, SinkOneConstant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2],f32[2]) parameter(0) + p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 + p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 + + add.0 = f32[2] add(p_body.0, p_body.1) + ROOT root = (f32[2],f32[2]) tuple(add.0, p_body.1) +} + +condition { + p_cond = (f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = f32[2] constant({2, 1}) + while_init = (f32[2],f32[2]) tuple(const_0, const_1) + ROOT while = (f32[2],f32[2]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::Add(_, op::Constant()), _)); +} + +TEST_F(WhileLoopConstantSinkingTest, KeepConstantsLoopInvariant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2],f32[2],f32[2]) parameter(0) + p_body.0 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=0 + p_body.1 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=1 + p_body.2 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=2 + + add.0 = f32[2] add(p_body.1, p_body.2) + ROOT root = (f32[2],f32[2],f32[2]) tuple(add.0, p_body.1, p_body.2) +} + +condition { + p_cond = (f32[2],f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = f32[2] constant({2, 1}) + const_2 = f32[2] constant({3, 1}) + while_init = (f32[2],f32[2],f32[2]) tuple(const_0, const_1, const_2) + ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::Add(op::Constant(), op::Constant()), + op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0)))); +} + +TEST_F(WhileLoopConstantSinkingTest, TupleShapedConstants) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_b = (f32[2],(f32[2],f32[2])) parameter(0) + p_b.0 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=0 + p_b.1 = (f32[2],f32[2]) get-tuple-element((f32[2],(f32[2],f32[2])) p_b), index=1 + + p_b.1.1 = f32[2] get-tuple-element(p_b.1), index=0 + + ROOT root = (f32[2],f32[2],f32[2]) tuple(p_b.1.1, p_b.1) +} + +condition { + p_cond = (f32[2],(f32[2],f32[2])) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = (f32[2], f32[2]) constant((f32[2], f32[2]) ({2, 1},{3,1})) + while_init = (f32[2],(f32[2],f32[2])) tuple(const_0, const_1) + ROOT while = (f32[2],(f32[2],f32[2])) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::GetTupleElement(op::Constant(), 0), + op::GetTupleElement(op::Parameter(0)))); +} + +TEST_F(WhileLoopConstantSinkingTest, DuplicateGTEs) { + // This test shows that the pass fails to optimize non-canonical IR. + // + // Even though the input IR has a constant value for p_b.2.dup, + // WhileLoopConstantSinking doesn't try to detect this. Instead, it relies on + // prior runs of HLO CSE to have commoned these identical GTE instructions. + + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_b = (f32[2],f32[2],f32[2]) parameter(0) + + p_b.1 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=1 + p_b.2 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=2 + p_b.2.dup = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=2 + + add.0 = f32[2] add(p_b.1, p_b.2.dup) + ROOT root = (f32[2],f32[2],f32[2]) tuple(add.0, p_b.1, p_b.2) +} + +condition { + p_cond = (f32[2],f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = f32[2] constant({2, 1}) + const_2 = f32[2] constant({3, 1}) + while_init = (f32[2],f32[2],f32[2]) tuple(const_0, const_1, const_2) + ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::Add(op::Constant(), ::testing::Not(op::Constant())), + op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0)))); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index a5f9b01f011ce04f1114c74391a967c62f015221..09ddcffb22c2184262adf87d570870ec000c0e6f 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -98,51 +98,28 @@ static void CreateLoopInvariantCopy( // Returns true if `instruction` is worth hoisting only if it lets us hoist some // instruction using it. The rationale is that hoisting these instructions will // prevent simplification and fusion in the while body. -static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { +bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually( + const HloInstruction& instruction) { switch (instruction.opcode()) { default: return false; + case HloOpcode::kConstant: + return !hoist_constants_; + case HloOpcode::kBitcast: case HloOpcode::kBroadcast: - case HloOpcode::kConstant: + case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSlice: + case HloOpcode::kTranspose: case HloOpcode::kTuple: return true; - - case HloOpcode::kTranspose: - return ShapeUtil::TransposeIsBitcast( - /*input_shape=*/instruction.operand(0)->shape(), - /*output_shape=*/instruction.shape(), instruction.dimensions()); - - case HloOpcode::kReshape: - return ShapeUtil::ReshapeIsBitcast( - /*input_shape=*/instruction.operand(0)->shape(), - /*output_shape=*/instruction.shape()); - } -} - -// Populates `gte_set` with the GetTupleElement instructions in `while_body` -// that access elements in the parameter tuple that don't change across -// iterations. Assumes `while_body` is the body computation of the while loop -// in question. -static void GatherInvariantGTEs(HloComputation* while_body, - FlatSet* gte_set) { - const HloInstruction::InstructionVector root_operands = - while_body->root_instruction()->operands(); - for (int i = 0; i < root_operands.size(); i++) { - HloInstruction* instr = root_operands[i]; - if (instr->opcode() == HloOpcode::kGetTupleElement && - instr->tuple_index() == i && - instr->operand(0) == while_body->parameter_instruction(0) && - ShapeUtil::IsArray(instr->shape())) { - InsertOrDie(gte_set, instr); - } } } -static StatusOr TryHoistingInvariantInstructionsFromWhileBody( +StatusOr +WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( HloInstruction* while_instr) { auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); @@ -180,14 +157,24 @@ static StatusOr TryHoistingInvariantInstructionsFromWhileBody( // unhoisted_invariant_instructions -- they can be legally hoisted, but there // is no benefit to hoisting them unless something that uses it is also // hoisted. - GatherInvariantGTEs(while_body, &unhoisted_invariant_instructions); + for (auto* instr : WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { + if (ShapeUtil::IsArray(instr->shape())) { + // TODO(b/79147885): We should try to generalize this to tuples for + // uniformity's sake, if nothing else. + InsertOrDie(&unhoisted_invariant_instructions, instr); + } + } - if (unhoisted_invariant_instructions.empty()) { + if (unhoisted_invariant_instructions.empty() && !hoist_constants_) { // There are no obviously loop invariant elements in the state being // threaded through the while loop so give up. In theory this precondition // is too strong -- we could have code that e.g. permutes the elements in // the while state but uses a select to pick the same value on every // iteration. + // + // If we were asked to hoist constants, we need to scan the while body for + // constants even if we didn't find any loop invariant values in the while + // state tuple. return false; } @@ -264,6 +251,9 @@ static StatusOr TryHoistingInvariantInstructionsFromWhileBody( } StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { + VLOG(2) << "HLO module before WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + bool changed = false; std::vector while_instrs; for (auto* comp : module->computations()) { @@ -291,6 +281,14 @@ StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { TryHoistingInvariantInstructionsFromWhileBody(while_instr)); changed |= result; } + + if (changed) { + VLOG(2) << "HLO module after WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + } else { + VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking"; + } + return changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 8c4b765b0003c48cfacb9d28e7c8259ac0927d66..8e6cc8787576e4f041229da5cf8dd2b09194eb2a 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -27,12 +27,28 @@ namespace xla { class WhileLoopInvariantCodeMotion : public HloPassInterface { public: + // If `hoist_constants` is true then constants are always hoisted out of while + // loop bodies. Otherwise they are only hoisted out if they enable other + // non-trivial computations to be hoisted out. + // + // Setting `hoist_constants` to false can be help if LICM is run in the mid + // level HLO pipeline because hoisting constants out of while loop bodies can + // break optimizations like constant folding. + explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false) + : hoist_constants_(hoist_constants) {} ~WhileLoopInvariantCodeMotion() override = default; tensorflow::StringPiece name() const override { return "while-loop-invariant-code-motion"; } StatusOr Run(HloModule* module) override; + + private: + bool NotWorthHoistingIndividually(const HloInstruction& instruction); + StatusOr TryHoistingInvariantInstructionsFromWhileBody( + HloInstruction* while_instr); + + bool hoist_constants_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 799340fda905fb7d40b19b4cb79bb0fcb5629fd3..8831c513eee66e36163135b732f833d46cb7eb03 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -438,5 +439,77 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { EXPECT_FALSE(simplified_loop); } +const char* const kConstantHoistingTestCase = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2]{0}) parameter(0) + p_body.1 = f32[2]{0} get-tuple-element(p_body), index=0 + const = f32[2]{0} constant({3, 4}) + add.0 = f32[2]{0} add(p_body.1, const) + ROOT root = (f32[2]{0}) tuple(add.0) +} + +condition { + p_cond = (f32[2]{0}) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2]{0} constant({1, 2}) + while_init = (f32[2]{0}) tuple(const_0) + ROOT while = (f32[2]{0}) while(while_init), condition=condition, body=body +} +)"; + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistsConstantWhenAsked) { + ParseAndVerifyModule(kConstantHoistingTestCase); + + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(&module())); + EXPECT_TRUE(simplified_loop); + + HloComputation* while_body = module().GetComputationWithName("wide.body"); + ASSERT_NE(while_body, nullptr); + + // We expect the while body to be the equivalent of: + // + // wide.body { + // wide_param.1 = (f32[2]{0}, f32[2]{0}) parameter(0) + // get-tuple-element.1 = f32[2]{0} get-tuple-element(wide_param.1), index=0 + // tuple.1 = (f32[2]{0}) tuple(get-tuple-element.1) + // get-tuple-element.4 = f32[2]{0} get-tuple-element(tuple.1), index=0 + // get-tuple-element.7 = f32[2]{0} get-tuple-element(wide_param.1), index=1 + // add.1 = f32[2]{0} add(get-tuple-element.4, get-tuple-element.7) + // tuple.3 = (f32[2]{0}) tuple(add.1) + // get-tuple-element.8 = f32[2]{0} get-tuple-element(tuple.3), index=0 + // get-tuple-element.9 = f32[2]{0} get-tuple-element(wide_param.1), index=1 + // ROOT tuple.4 = (f32[2]{0}, f32[2]{0}) tuple(get-tuple-element.8, + // get-tuple-element.9) + // } + + auto wide_param_1 = op::Parameter(0); + auto get_tuple_element_1 = op::GetTupleElement(wide_param_1, 0); + auto tuple_1 = op::Tuple(get_tuple_element_1); + auto get_tuple_element_4 = op::GetTupleElement(tuple_1, 0); + auto get_tuple_element_7 = op::GetTupleElement(wide_param_1, 1); + auto add_1 = op::Add(get_tuple_element_4, get_tuple_element_7); + auto tuple_3 = op::Tuple(add_1); + auto get_tuple_element_8 = op::GetTupleElement(tuple_3, 0); + auto get_tuple_element_9 = op::GetTupleElement(wide_param_1, 1); + auto tuple_4 = op::Tuple(get_tuple_element_8, get_tuple_element_9); + + EXPECT_THAT(while_body->root_instruction(), tuple_4); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistConstantByDefault) { + ParseAndVerifyModule(kConstantHoistingTestCase); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 981de9b2200a9ae8938db21299580f510834d2f0..ec05a74e286c89dd8db5ae07580e461938d7c087 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -212,7 +213,7 @@ static optional GetLoopTripCount(HloInstruction* while_op) { // Now that we know the index of the induction variable, we can we can try to // compute how many times the loop executes. Start by computing the induction // variable's initial value. - HloEvaluator evaluator; + HloEvaluator evaluator(/*max_loop_iterations=*/0); auto* while_init = while_op->mutable_operand(0); auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); StatusOr> indvar_init_result = @@ -605,6 +606,78 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { return false; } +static StatusOr TryPropagateConstant(HloInstruction* while_op) { + auto while_init = while_op->operand(0); + if (while_init->opcode() != HloOpcode::kTuple) { + return false; + } + + auto while_body = while_op->while_body(); + auto while_body_root = while_body->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + return false; + } + + auto while_body_param = while_body->parameter_instruction(0); + const HloInstruction::InstructionVector& root_operands = + while_body_root->operands(); + + // Find the loop invariant tuple elements with scalar constant init value and + // build a map from the tuple element index to the constant value. Limit this + // to scalar constant values because propagating array constants can regress + // performance by forcing us to copy constants. + tensorflow::gtl::FlatMap index_to_constant; + for (int i = 0; i < root_operands.size(); i++) { + HloInstruction* instr = root_operands[i]; + if (instr->opcode() == HloOpcode::kGetTupleElement && + instr->tuple_index() == i && instr->operand(0) == while_body_param && + ShapeUtil::IsScalar(instr->shape())) { + auto tuple_element = while_init->operand(i); + if (tuple_element->IsConstant()) { + VLOG(3) << "Found loop invariant tuple element " << i << " " + << tuple_element->ToString(); + index_to_constant[i] = tuple_element; + } + } + } + + if (index_to_constant.empty()) { + return false; + } + + // Replace the use of each constant tuple element in the loop_condition and + // loop_body with the corresponding constant value. + auto propagate_constant = [&](HloComputation* computation) -> StatusOr { + HloInstruction* param = computation->parameter_instruction(0); + bool changed = false; + for (auto instr : param->users()) { + // Since only a while-loop with a tuple result reaches here, we can safely + // assume that `param` is a tuple and the first operand of the + // GetTupleElement instruction is a use of `param`. + if (instr->opcode() == HloOpcode::kGetTupleElement) { + VLOG(3) << "tuple index " << instr->tuple_index() << " " + << instr->ToString(); + auto iter = index_to_constant.find(instr->tuple_index()); + if (iter != index_to_constant.end()) { + const HloInstruction* hlo_constant = (*iter).second; + VLOG(3) << "Replace use of " << instr->ToString() << " with " + << hlo_constant->ToString(); + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith( + computation->AddInstruction(hlo_constant->Clone()))); + changed = true; + } + } + } + return changed; + }; + + TF_ASSIGN_OR_RETURN(bool changed_cond, + propagate_constant(while_op->while_condition())); + TF_ASSIGN_OR_RETURN(bool changed_body, propagate_constant(while_body)); + + return changed_cond || changed_body; +} + StatusOr WhileLoopSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(3, "WhileLoopSimplifier::Run(), before:\n" + module->ToString()); @@ -635,7 +708,11 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { continue; } - StatusOr result = TryRemoveWhileLoop(while_op); + StatusOr result = TryPropagateConstant(while_op); + TF_RETURN_IF_ERROR(result.status()); + changed |= result.ValueOrDie(); + + result = TryRemoveWhileLoop(while_op); TF_RETURN_IF_ERROR(result.status()); if (result.ValueOrDie()) { changed = true; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index d3d55634c97bbdf3f81321d8089bb808c411340b..3d3e1d60f294c3a2574513c1c2f071805a341ad1 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -25,7 +25,7 @@ namespace xla { // HLO pass that makes the following transformations on while loops: // // - A while loop with static trip count of 0 is deleted. -// - A while loops with static trip count of 1 is replaced by its body (sans +// - A while loop with static trip count of 1 is replaced by its body (sans // loop). // - Elements of a while loop's tuple that the loop doesn't use are removed // from the tuple. diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index c5183f8d3aee99696ed4114c3f7e451888222137..619e87caa5b6d0f6ec3c3b1489b0d4f50ef29963 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { @@ -26,112 +27,140 @@ namespace { namespace op = xla::testing::opcode_matchers; class WhileLoopSimplifierTest : public HloVerifiedTestBase { - public: - // Makes a computation that contains a loop that runs num_iters times. - HloComputation* MakeSimpleLoop(int num_iters, HloModule* module); - - // Makes a computation which has one parameter, of the given shape, and always - // returns PRED[]{true}. This is useful as a dummy loop condition. - HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape, - HloModule* module); + protected: + // Makes an HloModule that contains a loop with `num_iters` iteration. + void MakeModuleWithSimpleLoop(int num_iters); + + // Similar to MakeModuleWithSimpleLoop except that the loop bound is passed to + // the loop-condition through an element of a tuple which is the + // loop-condition parameter. + void MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters); }; -HloComputation* WhileLoopSimplifierTest::MakeSimpleLoop(int num_iters, - HloModule* module) { - HloComputation::Builder builder(TestName()); - - auto loop_iter_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - auto loop_data_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 1, 2}))); - auto loop_init = builder.AddInstruction( - HloInstruction::CreateTuple({loop_iter_init, loop_data_init})); - - HloComputation* condition; - { - HloComputation::Builder cond_builder(TestName() + ".condition"); - auto loop_var = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - auto loop_induction_var = - cond_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::MakeShape(S32, {}), loop_var, 0)); - auto limit = cond_builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(42 + num_iters))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, loop_induction_var, - limit)); - condition = module->AddEmbeddedComputation(cond_builder.Build()); +void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { + string hlo_string_template = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) } - - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - auto loop_var = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - auto loop_induction_var = - body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::MakeShape(S32, {}), loop_var, 0)); - auto new_loop_induction_var = - body_builder.AddInstruction(HloInstruction::CreateBinary( - loop_induction_var->shape(), HloOpcode::kAdd, loop_induction_var, - body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))))); - auto loop_data = - body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - loop_data_init->shape(), loop_var, 1)); - auto new_loop_data = - body_builder.AddInstruction(HloInstruction::CreateBinary( - loop_data_init->shape(), HloOpcode::kMultiply, loop_data, - loop_data)); - body_builder.AddInstruction( - HloInstruction::CreateTuple({new_loop_induction_var, new_loop_data})); - body = module->AddEmbeddedComputation(body_builder.Build()); + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant({{LOOP_BOUND}}) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) } + ENTRY SimpleLoop { + constant.3 = s32[] constant(42) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + } + )"; - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - - return module->AddEntryComputation(builder.Build()); + string hlo_string = tensorflow::str_util::StringReplace( + hlo_string_template, "{{LOOP_BOUND}}", + tensorflow::strings::StrCat(42 + num_iters), + /*replace_all=*/true); + ParseAndVerifyModule(hlo_string); } -HloComputation* WhileLoopSimplifierTest::MakeAlwaysTrueComputation( - const Shape& param_shape, HloModule* module) { - HloComputation::Builder builder(TestName() + ".always_true"); - builder.AddInstruction( - HloInstruction::CreateParameter(0, param_shape, "param")); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); - return module->AddEmbeddedComputation(builder.Build()); +void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( + int num_iters) { + string hlo_string_template = R"( + HloModule SimpleLoopWithIndirectLoopBound + SimpleLoopWithIndirectLoopBound.body { + loop_var.1 = (s32[], s32[3]{0}, s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + limit = s32[] get-tuple-element(loop_var.1), index=2 + ROOT tuple = (s32[], s32[3]{0}, s32[]) tuple(add, multiply, limit) + } + SimpleLoopWithIndirectLoopBound.condition { + loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2 + ROOT less-than = pred[] less-than(get-tuple-element.3, get-tuple-element.4) + } + ENTRY SimpleLoopWithIndirectLoopBound { + constant.3 = s32[] constant(42) + constant.4 = s32[3]{0} constant({0, 1, 2}) + constant.2 = s32[] constant({{LOOP_BOUND}}) + tuple.1 = (s32[], s32[3]{0}, s32[]) tuple(constant.3, constant.4, + constant.2) + ROOT while = (s32[], s32[3]{0}, s32[]) while(tuple.1), + condition=SimpleLoopWithIndirectLoopBound.condition, + body=SimpleLoopWithIndirectLoopBound.body + } + )"; + + string hlo_string = tensorflow::str_util::StringReplace( + hlo_string_template, "{{LOOP_BOUND}}", + tensorflow::strings::StrCat(42 + num_iters), + /*replace_all=*/true); + ParseAndVerifyModule(hlo_string); } -TEST_F(WhileLoopSimplifierTest, WhileLoopWithZeroIterations) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/0, &module()); - ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), +TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/0); + HloModule* the_module = &module(); + ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_THAT(the_module->entry_computation()->root_instruction(), op::Tuple(op::Constant(), op::Constant())); } -TEST_F(WhileLoopSimplifierTest, WhileLoopWithOneIteration) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); - ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), +TEST_F(WhileLoopSimplifierTest, + LoopWithZeroIterationTupleElementLoopBoundSimplified) { + MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0); + HloModule* the_module = &module(); + ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_THAT(the_module->entry_computation()->root_instruction(), + op::Tuple(op::Constant(), op::Constant(), op::Constant())); +} + +TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloModule* the_module = &module(); + ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_THAT(the_module->entry_computation()->root_instruction(), op::Tuple(op::Add(), op::Multiply())); } -TEST_F(WhileLoopSimplifierTest, WhileLoopWithTwoIterations) { - MakeSimpleLoop(/*num_iters=*/2, &module()); +TEST_F(WhileLoopSimplifierTest, + LoopWithOneIterationTupleELementLoopBoundSimplified) { + MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1); + HloModule* the_module = &module(); + ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_THAT(the_module->entry_computation()->root_instruction(), + op::Tuple(op::Add(), op::Multiply(), op::Constant())); +} + +TEST_F(WhileLoopSimplifierTest, LoopWithTwoIterationsNotSimplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/2); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } -TEST_F(WhileLoopSimplifierTest, WhileLoopWithControlDependency) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); +TEST_F(WhileLoopSimplifierTest, + LoopWithControlDependencySimplifiedDependencyPreserved) { + MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloModule* the_module = &module(); + HloComputation* computation = the_module->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* true_op = while_op->while_body()->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(true))); TF_ASSERT_OK(true_op->AddControlDependencyTo( while_op->while_body()->root_instruction())); - ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); EXPECT_THAT(computation->root_instruction()->control_predecessors(), ElementsAre(op::Constant())) << computation->ToString(); @@ -139,8 +168,10 @@ TEST_F(WhileLoopSimplifierTest, WhileLoopWithControlDependency) { // Loops that contain send/recv nodes can't be simplified; the loop structure // around send/recv nodes must be preserved. -TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); +TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloModule* the_module = &module(); + HloComputation* computation = the_module->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); @@ -149,11 +180,13 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) { HloInstruction::CreateConstant(Literal::CreateR0(true))), /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateSendDone(send)); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); } -TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); +TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloModule* the_module = &module(); + HloComputation* computation = the_module->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); @@ -161,247 +194,217 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) { HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateRecvDone(recv)); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); } // The limitation on not being able to simplify loops that contain infeeds (and // other non-removable instructions) isn't fundamental -- it just stems from the // fact that our infrastructure sees simplifying such a loop as tantamount to // removing the non-removable instruction. -TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); +TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloModule* the_module = &module(); + HloComputation* computation = the_module->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); while_body->AddInstruction( HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); } -// Check that we don't crash when given a loop whose shape is not a tuple. -TEST_F(WhileLoopSimplifierTest, IgnoreNonTupleShapedLoop) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - - HloComputation* condition; - { - HloComputation::Builder cond_builder(TestName() + ".condition"); - auto param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, - cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(100))))); - condition = module().AddEmbeddedComputation(cond_builder.Build()); +// A non-tuple shaped loop shouldn't be simplified or crash the compiler. +TEST_F(WhileLoopSimplifierTest, NonTupleShapedLoopNotSimplified) { + const string hlo_string = R"( + HloModule NonTupleShapedLoop + NonTupleShapedLoop.body { + loop_var.1 = s32[] parameter(0) + constant.1 = s32[] constant(-1) + ROOT add = s32[] add(s32[] loop_var.1, s32[] constant.1) + } + NonTupleShapedLoop.condition { + loop_var = s32[] parameter(0) + constant = s32[] constant(100) + ROOT less-than = pred[] less-than(s32[] loop_var, s32[] constant) + } + ENTRY INonTupleShapedLoop { + constant.2 = s32[] constant(42) + ROOT while = s32[] while(s32[] constant.2), + condition=NonTupleShapedLoop.condition, + body=NonTupleShapedLoop.body } + )"; - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - auto param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - body_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, - body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-1))))); - body = module().AddEmbeddedComputation(body_builder.Build()); - } - - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - - module().AddEntryComputation(builder.Build()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } -// Construct a loop where we swap the tuple elements in each iteration. -// Although the tuple elements aren't used in the loop, we don't eliminate them, -// because the swapping side-effect is visible to users of the loop. -TEST_F(WhileLoopSimplifierTest, SwapTupleIndices) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))), - })); - - HloComputation* condition = - MakeAlwaysTrueComputation(loop_init->shape(), &module()); - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - auto param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - body_builder.AddInstruction(HloInstruction::CreateTuple({ - body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)), - body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)), - })); - body = module().AddEmbeddedComputation(body_builder.Build()); +// A while loop that does nothing else besides swapping tuple elements +// can't be simplified as the result of the swapping is visible to users of the +// loop. +TEST_F(WhileLoopSimplifierTest, LoopSwappingTupleElementsNotSimplified) { + const string hlo_string = R"( + HloModule SwappingTupleElements + SwappingTupleElements.body { + loop_var = (s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element((s32[], s32[]) loop_var),index=1 + get-tuple-element.1 = s32[] get-tuple-element((s32[], s32[]) loop_var), + index=0 + ROOT tuple = (s32[], s32[]) tuple(s32[] get-tuple-element, + s32[] get-tuple-element.1) } + SwappingTupleElements.always_true { + param = (s32[], s32[]) parameter(0) + ROOT constant = pred[] constant(true) + } + ENTRY SwappingTupleElements { + x = s32[] parameter(0) + y = s32[] parameter(1) + tuple.1 = (s32[], s32[]) tuple(s32[] x, s32[] y) + ROOT while = (s32[], s32[]) while((s32[], s32[]) tuple.1), + condition=SwappingTupleElements.always_true, + body=SwappingTupleElements.body + } + )"; - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - - module().AddEntryComputation(builder.Build()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } // Construct a loop where we assign a constant to tuple element 0 in each // iteration. We can't eliminate tuple element 0, even though we never use its // value. -TEST_F(WhileLoopSimplifierTest, UnusedButModifiedTupleElement) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction( - HloInstruction::CreateTuple({builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0)))})); - - HloComputation* condition = - MakeAlwaysTrueComputation(loop_init->shape(), &module()); - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - body_builder.AddInstruction(HloInstruction::CreateTuple({ - body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))), - })); - body = module().AddEmbeddedComputation(body_builder.Build()); +TEST_F(WhileLoopSimplifierTest, + LoopWithUnusedButModifiedTupleElementNotSimplified) { + const string hlo_string = R"( + HloModule UnusedButModifiedTupleElement + UnusedButModifiedTupleElement.body { + loop_var = (s32[]) parameter(0) + constant.1 = s32[] constant(1) + ROOT tuple = (s32[]) tuple(s32[] constant.1) } + UnusedButModifiedTupleElement.always_true { + param = (s32[]) parameter(0) + ROOT constant = pred[] constant(true) + } + ENTRY UnusedButModifiedTupleElement { + constant.2 = s32[] constant(0) + tuple.1 = (s32[]) tuple(s32[] constant.2) + ROOT while = (s32[]) while((s32[]) tuple.1), + condition=UnusedButModifiedTupleElement.always_true, + body=UnusedButModifiedTupleElement.body + } + )"; - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - - module().AddEntryComputation(builder.Build()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } // Nothing to simplify in a while loop whose tuple has 0 elements. -TEST_F(WhileLoopSimplifierTest, EmptyTuple) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({})); - - HloComputation* condition = - MakeAlwaysTrueComputation(loop_init->shape(), &module()); - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - body_builder.AddInstruction(HloInstruction::CreateTuple({})); - body = module().AddEmbeddedComputation(body_builder.Build()); +TEST_F(WhileLoopSimplifierTest, LoopWithEmptyTupleNotSimplified) { + const string hlo_string = R"( + HloModule EmptyTuple + EmptyTuple.body { + loop_var = () parameter(0) + ROOT tuple = () tuple() + } + EmptyTuple.always_true { + param = () parameter(0) + ROOT constant = pred[] constant(true) + } + ENTRY EmptyTuple { + tuple.1 = () tuple() + ROOT while = () while(() tuple.1), condition=EmptyTuple.always_true, + body=EmptyTuple.body } + )"; - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - module().AddEntryComputation(builder.Build()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } // While loop where one tuple element is used twice in the body, and thus can't // be simplified away. -TEST_F(WhileLoopSimplifierTest, ElemUsedTwice) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))), - })); - - HloComputation* condition = - MakeAlwaysTrueComputation(loop_init->shape(), &module()); - - auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - auto* param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "param0")); - auto* gte0 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/0)); - // get0 is used twice in the loop body's tuple. - body_builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte0})); - body = module().AddEmbeddedComputation(body_builder.Build()); +TEST_F(WhileLoopSimplifierTest, LoopWithElemUsedTwiceNotSimplified) { + const string hlo_string = R"( + HloModule ElemUsedTwice + ElemUsedTwice.body { + param0 = (s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element((s32[], s32[]) param0), index=0 + ROOT tuple = (s32[], s32[]) tuple(s32[] get-tuple-element, + s32[] get-tuple-element) + } + ElemUsedTwice.always_true { + param = (s32[], s32[]) parameter(0) + ROOT constant = pred[] constant(true) } + ENTRY ElemUsedTwice { + x = s32[] parameter(0) + y = s32[] parameter(1) + tuple.1 = (s32[], s32[]) tuple(s32[] x, s32[] y) + ROOT while = (s32[], s32[]) while((s32[], s32[]) tuple.1), + condition=ElemUsedTwice.always_true, body=ElemUsedTwice.body + } + )"; - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - module().AddEntryComputation(builder.Build()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } // This while loop has three tuple elements. Element 0 is unused and should be // removed. Element 1 is used by the loop body, and element 2 is used by the // loop condition; these two should stay. -TEST_F(WhileLoopSimplifierTest, RemoveUnusedOperand) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - })); - auto loop_shape = loop_init->shape(); - auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - - HloComputation* condition; - { - HloComputation::Builder cond_builder(TestName() + ".loop_condition"); - auto param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_shape, "param0")); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, - cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - cond_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - scalar_s32, param, /*index=*/2)))); - condition = module().AddEmbeddedComputation(cond_builder.Build()); +TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { + const string hlo_string = R"( + HloModule RemoveUnusedOperands + RemoveUnusedOperands.body { + loop_var = (s32[], s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element((s32[], s32[], + s32[]) loop_var), index=0 + get-tuple-element.2 = s32[] get-tuple-element((s32[], s32[], + s32[]) loop_var), index=1 + constant.1 = s32[] constant(1) + add = s32[] add(s32[] get-tuple-element.2, s32[] constant.1) + get-tuple-element.3 = s32[] get-tuple-element((s32[], s32[], s32[]) + loop_var), index=2 + ROOT tuple = (s32[], s32[], s32[]) tuple(s32[] get-tuple-element.1, + s32[] add, s32[] get-tuple-element.3) } - - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - auto* param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_shape, "loop_var")); - - auto* tuple0 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/0)); - auto* tuple1 = body_builder.AddInstruction(HloInstruction::CreateBinary( - scalar_s32, HloOpcode::kAdd, - body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - scalar_s32, param, /*index=*/1)), - body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))))); - auto* tuple2 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/2)); - body_builder.AddInstruction( - HloInstruction::CreateTuple({tuple0, tuple1, tuple2})); - - body = module().AddEmbeddedComputation(body_builder.Build()); + RemoveUnusedOperands.loop_condition { + constant.2 = s32[] constant(0) + param0 = (s32[], s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element((s32[], s32[], s32[]) param0), + index=2 + ROOT equal-to = pred[] equal-to(s32[] constant.2, s32[] get-tuple-element) } + ENTRY RemoveUnusedOperands { + x = s32[] parameter(0) + constant.3 = s32[] constant(0) + y = s32[] parameter(1) + tuple.1 = (s32[], s32[], s32[]) tuple(s32[] x, s32[] constant.3, + s32[] y) + ROOT while = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) tuple.1), + condition=RemoveUnusedOperands.loop_condition, + body=RemoveUnusedOperands.body + } + )"; + + ParseAndVerifyModule(hlo_string); + HloModule* the_module = &module(); + EXPECT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + + // The original while instruction is still left in the module as a dead + // instruction, find a while instruction with a different name as the new + // while instruction. + HloInstruction* new_while_op = + *std::find_if(the_module->entry_computation()->instructions().begin(), + the_module->entry_computation()->instructions().end(), + [&](const HloInstruction* instr) { + return (instr->opcode() == HloOpcode::kWhile && + instr->name() != "while"); + }); - auto* while_op = builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - - module().AddEntryComputation(builder.Build()); - EXPECT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); - - // We leave most of the checking to HloVerifiedTestBase, which runs the - // verifier on module() at the end of this test. - HloInstruction* new_while_op = *std::find_if( - module().entry_computation()->instructions().begin(), - module().entry_computation()->instructions().end(), - [&](const HloInstruction* instr) { - return instr != while_op && instr->opcode() == HloOpcode::kWhile; - }); + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); EXPECT_TRUE( ShapeUtil::Equal(new_while_op->shape(), ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}))) @@ -418,31 +421,91 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedOperand) { op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1))); } -TEST_F(WhileLoopSimplifierTest, BodyHasNonTupleRoot) { - auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); - - HloComputation* while_body = [&]() { - HloComputation::Builder builder(TestName() + ".passthrough"); - HloInstruction* param = builder.AddInstruction( - HloInstruction::CreateParameter(0, while_shape, "param")); - HloComputation* result = module().AddEmbeddedComputation(builder.Build()); - - result->AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); - return result; - }(); - - HloComputation::Builder builder(TestName()); - auto* init_value = builder.AddInstruction( - HloInstruction::CreateParameter(0, while_shape, "init_value")); - builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - module().AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopSimplifier{}.Run(&module())); - EXPECT_FALSE(simplified_loop); +TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) { + const string hlo_string = R"( + HloModule BodyHasNonTupleRoot + BodyHasNonTupleRoot.passthrough { + ROOT param = (s32[], s32[]) parameter(0) + } + BodyHasNonTupleRoot.always_true { + param.1 = (s32[], s32[]) parameter(0) + ROOT constant = pred[] constant(true) + } + ENTRY BodyHasNonTupleRoot { + init_value = (s32[], s32[]) parameter(0) + ROOT while = (s32[], s32[]) while((s32[], s32[]) init_value), + condition=BodyHasNonTupleRoot.always_true, + body=BodyHasNonTupleRoot.passthrough + } + )"; + + ParseAndVerifyModule(hlo_string); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(WhileLoopSimplifierTest, + LoopWithNonTupleBodyRootInstructionNotSimplified) { + const string hlo_string = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT custom-call = (s32[], s32[3]{0}) custom-call(add, multiply), + custom_call_target="x" + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(44) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(42) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + } + )"; + + ParseAndVerifyModule(hlo_string); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) { + const string hlo_string = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + get-tuple-element.3 = s32[3]{0} get-tuple-element(loop_var.1), index=2 + add.2 = s32[3]{0} add(get-tuple-element.2, get-tuple-element.3) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, add.2, get-tuple-element.3) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(47) + ROOT less-than = pred[] less-than(get-tuple-element.4, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(42) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4, constant.4) + ROOT while = (s32[], s32[3]{0}, s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + } + )"; + + ParseAndVerifyModule(hlo_string); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } } // namespace diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index e20b25e4a08a946f6b58575a4d4e557744f8035c..473eab2ea84eb8faf745cbe299bc80bcc1b62a35 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -15,18 +15,21 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/tuple_util.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace xla { +using tensorflow::strings::StrCat; + static StatusOr WidenWhileCondition( HloComputation* narrow_condition, const Shape& wide_shape) { const Shape& narrow_shape = narrow_condition->parameter_instruction(0)->shape(); HloComputation* wide_while_cond = [&]() { - HloComputation::Builder builder( - tensorflow::strings::StrCat("wide.", narrow_condition->name())); + HloComputation::Builder builder(StrCat("wide.", narrow_condition->name())); builder.AddInstruction( HloInstruction::CreateParameter(0, wide_shape, "wide_param")); @@ -57,8 +60,7 @@ WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape(); HloComputation* wide_while_body = [&]() { - HloComputation::Builder builder( - tensorflow::strings::StrCat("wide.", narrow_body->name())); + HloComputation::Builder builder(StrCat("wide.", narrow_body->name())); builder.AddInstruction( HloInstruction::CreateParameter(0, wide_shape, "wide_param")); return narrow_body->parent()->AddEmbeddedComputation(builder.Build()); @@ -115,9 +117,13 @@ WhileUtil::MakeInstructionsLiveIn( HloInstruction* new_while = containing_computation->AddInstruction( HloInstruction::CreateWhile(new_while_shape, new_while_condition, new_while_body, new_while_init)); - TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction( - while_instr, TupleUtil::ExtractPrefix( - new_while, while_instr->shape().tuple_shapes_size()))); + + // We want to get rid of the old while instruction even if it has side + // effecting operations so we do a manual HloComputation::RemoveInstruction + // instead of relying on HloComputation::ReplaceInstruction. + TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(TupleUtil::ExtractPrefix( + new_while, while_instr->shape().tuple_shapes_size()))); + TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr)); HloInstruction* while_body_param = new_while_body->parameter_instruction(0); std::vector live_in_instructions; @@ -137,4 +143,126 @@ WhileUtil::MakeInstructionsLiveIn( return std::move(result); } + +static StatusOr> +MakeCountedLoopConditionComputation(const Shape& loop_state_shape, + int32 trip_count) { + Shape scalar_pred = ShapeUtil::MakeShape(PRED, {}); + + TF_ASSIGN_OR_RETURN(std::unique_ptr cond_computation, + CreateComputationWithSignature( + {&loop_state_shape}, scalar_pred, "while_cond")); + + HloInstruction* trip_count_constant = cond_computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(trip_count))); + + HloInstruction* param = cond_computation->parameter_instruction(0); + TF_ASSIGN_OR_RETURN(HloInstruction * indvar, + MakeGetTupleElementHlo(param, 0)); + + TF_ASSIGN_OR_RETURN( + HloInstruction * compare, + MakeBinaryHlo(HloOpcode::kLt, indvar, trip_count_constant)); + cond_computation->set_root_instruction(compare); + return std::move(cond_computation); +} + +static StatusOr> MakeCountedLoopBodyComputation( + const Shape& loop_state_shape, + const std::function( + HloInstruction*, const WhileUtil::LoopStateTy&)>& loop_body_generator) { + TF_ASSIGN_OR_RETURN(std::unique_ptr body_computation, + CreateComputationWithSignature( + {&loop_state_shape}, loop_state_shape, "while_body")); + HloInstruction* one = body_computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction* param = body_computation->parameter_instruction(0); + TF_ASSIGN_OR_RETURN(HloInstruction * indvar, + MakeGetTupleElementHlo(param, 0)); + TF_ASSIGN_OR_RETURN(HloInstruction * next_indvar, + MakeBinaryHlo(HloOpcode::kAdd, indvar, one)); + + std::vector loop_body_generator_args; + for (int64 i = 1, e = loop_state_shape.tuple_shapes_size(); i < e; i++) { + TF_ASSIGN_OR_RETURN(HloInstruction * tuple_element, + MakeGetTupleElementHlo(param, i)); + loop_body_generator_args.push_back(tuple_element); + } + TF_ASSIGN_OR_RETURN(std::vector next_state, + loop_body_generator(indvar, loop_body_generator_args)); + next_state.insert(next_state.begin(), next_indvar); + HloInstruction* next_state_tuple = + body_computation->AddInstruction(HloInstruction::CreateTuple(next_state)); + body_computation->set_root_instruction(next_state_tuple); + + return std::move(body_computation); +} + +static StatusOr MakeInitTupleFromInitValues( + HloComputation* computation, const WhileUtil::LoopStateTy& init_values) { + std::vector init_values_with_indvar; + init_values_with_indvar.reserve(init_values.size() + 1); + HloInstruction* zero = computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))); + init_values_with_indvar.push_back(zero); + c_copy(init_values, std::back_inserter(init_values_with_indvar)); + return computation->AddInstruction( + HloInstruction::CreateTuple(init_values_with_indvar)); +} + +static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { + std::vector loop_state_shape_components; + loop_state_shape_components.reserve(init_values.size() + 1); + loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {})); + c_transform(init_values, std::back_inserter(loop_state_shape_components), + [](HloInstruction* instr) { return instr->shape(); }); + return ShapeUtil::MakeTupleShape(loop_state_shape_components); +} + +/*static*/ StatusOr WhileUtil::MakeCountedLoop( + HloComputation* computation, int32 trip_count, + const WhileUtil::LoopStateTy& init_values, + const WhileUtil::LoopBodyGeneratorTy& loop_body_generator) { + CHECK_GE(trip_count, 0); + + Shape loop_state_shape = MakeLoopStateShape(init_values); + TF_ASSIGN_OR_RETURN( + std::unique_ptr cond, + MakeCountedLoopConditionComputation(loop_state_shape, trip_count)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr body, + MakeCountedLoopBodyComputation(loop_state_shape, loop_body_generator)); + TF_ASSIGN_OR_RETURN(HloInstruction * init_tuple, + MakeInitTupleFromInitValues(computation, init_values)); + HloModule* module = computation->parent(); + HloInstruction* while_instr = + computation->AddInstruction(HloInstruction::CreateWhile( + loop_state_shape, module->AddEmbeddedComputation(std::move(cond)), + module->AddEmbeddedComputation(std::move(body)), init_tuple)); + + std::vector result; + for (int64 i = 0, e = init_values.size(); i < e; i++) { + TF_ASSIGN_OR_RETURN(HloInstruction * user_state, + MakeGetTupleElementHlo(while_instr, i + 1)); + result.push_back(user_state); + } + return result; +} + +/*static*/ std::vector WhileUtil::GetInvariantGTEsForWhileBody( + const HloComputation& while_body) { + std::vector result; + const HloInstruction::InstructionVector root_operands = + while_body.root_instruction()->operands(); + for (int i = 0; i < root_operands.size(); i++) { + HloInstruction* instr = root_operands[i]; + if (instr->opcode() == HloOpcode::kGetTupleElement && + instr->tuple_index() == i && + instr->operand(0) == while_body.parameter_instruction(0)) { + result.push_back(instr); + } + } + return result; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index 3600b5a80d26e37fdb7d5173c3b8743734306390..e67636d80f4b682fe1335eae535fb86105ac082b 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -38,20 +38,52 @@ class WhileUtil { }; // Replaces `while_instr` with a new while instruction that is equivalent to - // `while_instr`, except that it has all of the HLO instructions in + // `while_instr` except that it has all of the HLO instructions in // `instructions` as live-in, loop invariant values. These new live in values // are represented as new elements appended to the parameter of the while // loop, which must be of tuple shape. GetTupleElement instructions computing // each new live in value is returned in the `while_body_live_in_values` // vector. // - // Precondition: `while_instr` must have a tuple shaped state. + // Deletes `while_instr` after replacing it. // - // Every instruction in `instructions` must be contained in the computation - // that contains `while_instr`. + // Preconditions: + // + // `while_instr` must have a tuple shaped state. + // + // Every instruction in `instructions` must be contained in the computation + // that contains `while_instr`. static StatusOr MakeInstructionsLiveIn( HloInstruction* while_instr, tensorflow::gtl::ArraySlice instructions); + + using LoopStateTy = std::vector; + using LoopBodyGeneratorTy = std::function( + HloInstruction* /*induction_var*/, + const LoopStateTy& /*current_values*/)>; + + // Creates a while loop in `computation` that runs for `trip_count` + // iterations. The structure of the while loop is as follows, in pseudocode: + // + // loop_state while_loop() { + // indvar = 0; + // loop_state = init_values + // while (indvar < trip_count) { + // loop_state = loop_body_generator(loop_state) + // indvar++; + // } + // return loop_state; + // } + static StatusOr MakeCountedLoop( + HloComputation* computation, int32 trip_count, + const LoopStateTy& init_values, + const LoopBodyGeneratorTy& loop_body_generator); + + // Returns the GetTupleElement instructions in `while_body` that access + // elements in the parameter tuple that don't change across iterations. + // Assumes `while_body` is the body computation of the while loop in question. + static std::vector GetInvariantGTEsForWhileBody( + const HloComputation& while_body); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index cf0d0db99bd92b6b364b4e28e56a0902d4065963..d79d3297213e832306ea4726483b0f215df0f5d3 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -16,8 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { namespace { @@ -49,7 +50,7 @@ ENTRY entry { )"; TF_ASSIGN_OR_RETURN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); *entry_computation = module->entry_computation(); *param0 = (*entry_computation)->parameter_instruction(0); @@ -126,5 +127,84 @@ TEST(WhileUtilTest, MakeTwoInstructionsLive) { op::GetTupleElement(op::Parameter(0), 3))); } +TEST(WhileUtilTest, GetInvariantGTEsForWhileBody) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + param.b = (s32[], s32[]) parameter(0) + gte.0 = s32[] get-tuple-element(param.b), index=0 + gte.1 = s32[] get-tuple-element(param.b), index=1 + add = s32[] add(gte.0, gte.1) + ROOT tuple = (s32[], s32[]) tuple(gte.0, add) +} + +cond { + param.c = (s32[], s32[]) parameter(0) + ROOT constant = pred[] constant(true) +} + +ENTRY main { + init = (s32[], s32[]) parameter(0) + ROOT while = (s32[], s32[]) while(init), condition=cond, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + HloComputation* while_body = module->GetComputationWithName("body"); + + ASSERT_NE(while_body, nullptr) + << "Expected exactly one while_body computation"; + + std::vector gte_list = + WhileUtil::GetInvariantGTEsForWhileBody(*while_body); + + ASSERT_EQ(gte_list.size(), 1); + EXPECT_EQ((*gte_list.begin())->name(), "gte.0"); +} + +TEST(WhileUtilTest, AlwaysRemovePreviousWhileBody) { + const char* const hlo_string = R"( +HloModule WhileWithSideEffects + +body { + param.b = (s32[], s32[]) parameter(0) + gte.0 = s32[] get-tuple-element(param.b), index=0 + gte.1 = s32[] get-tuple-element(param.b), index=1 + add = s32[] add(gte.0, gte.1) + ROOT tuple = (s32[], s32[]) tuple(gte.0, add) +} + +cond { + param.c = (s32[], s32[]) parameter(0) + ROOT condition = pred[] infeed() +} + +ENTRY main { + init = (s32[], s32[]) parameter(0) + to_make_live_in = f32[100] parameter(1) + ROOT while = (s32[], s32[]) while(init), condition=cond, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + HloComputation* main = module->GetComputationWithName("main"); + HloInstruction* while_instr = main->root_instruction(); + HloInstruction* to_make_live_in = main->parameter_instruction(1); + + TF_ASSERT_OK_AND_ASSIGN( + WhileUtil::MakeInstructionsLiveInResult make_live_in_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, + /*instructions=*/{to_make_live_in})); + + auto is_while = [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }; + EXPECT_EQ(c_count_if(main->instructions(), is_while), 1); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h index 063e312df66ce9cba0fa9f49c2fc6026ba6b74aa..8763e588c484011ba2ccbc7cad8f29817347a605 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -// HLO pass that replaces zero sized Hlos with an zero sized constant literal. +// HLO pass that replaces zero sized Hlos with a zero sized constant literal. namespace xla { class ZeroSizedHloElimination : public HloPassInterface { public: diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index 4f8cdc1e0e73cdaa8675fc945ba3dbe19ce3da7d..f5331280ee9f252aa5717baab88f2c203be5c372 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -46,9 +45,9 @@ class ZeroSizedHloEliminationTest : public HloTestBase { 0, ShapeUtil::MakeShape(F32, {3, 0}), "zero sized param"))) {} StatusOr RunZeroSizedElimination() { - HloModule module("zero_sized_elimination_test_module"); - module.AddEntryComputation(builder_.Build()); - return ZeroSizedHloElimination{}.Run(&module); + auto module = CreateNewModule("zero_sized_elimination_test_module"); + module->AddEntryComputation(builder_.Build()); + return ZeroSizedHloElimination{}.Run(module.get()); } HloComputation::Builder builder_; diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 809941d8fe1f63d66bf104e66eea66167a0f509d..14c35e7b84f07bebac33a9753ac26a8ee1418f1e 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -16,8 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_ +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -31,84 +32,52 @@ class ServiceInterface { virtual ~ServiceInterface() = default; // TODO(b/31824348): Convert to use StatusOr. - virtual tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, TransferToClientResponse* result) = 0; + virtual Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) = 0; - virtual tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, TransferToServerResponse* result) = 0; + virtual Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) = 0; - virtual tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) = 0; + virtual Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) = 0; - virtual tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) = 0; + virtual Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) = 0; - virtual tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) = 0; + virtual Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) = 0; - virtual tensorflow::Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* result) = 0; + virtual Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) = 0; - virtual tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) = 0; + virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) = 0; - virtual tensorflow::Status ExecuteParallel( - const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) = 0; + virtual Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) = 0; - virtual tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) = 0; + virtual Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) = 0; - virtual tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) = 0; + virtual Status GetComputationGraphStats( + const ComputationGraphStatsRequest* arg, + ComputationStatsResponse* result) = 0; - virtual tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) = 0; + virtual Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) = 0; - virtual tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, ComputationStatsResponse* result) = 0; + virtual Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) = 0; - virtual tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) = 0; + virtual Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) = 0; - virtual tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) = 0; - - virtual tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) = 0; - - virtual tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) = 0; - - // Methods used by ComputationBuilder. - virtual tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) = 0; - - virtual tensorflow::Status Op(const OpRequest* arg, OpResponse* result) = 0; - - virtual tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) = 0; - - virtual tensorflow::Status SetReturnValue( - const SetReturnValueRequest* arg, SetReturnValueResponse* results) = 0; - - virtual tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) = 0; - - virtual tensorflow::Status ComputeConstant( - const ComputeConstantRequest* arg, ComputeConstantResponse* result) = 0; - - // Methods used by Computation. - virtual tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) = 0; + virtual Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) = 0; // Methods used by GlobalData. - virtual tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) = 0; + virtual Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) = 0; }; } // namespace xla diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index 789eba5780d37e1fd4d80ec881855951c8bba0eb..7ee366b27a82bdbcb7a63a57ea80194db8ca7df4 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -22,24 +22,24 @@ limitations under the License. namespace xla { -tensorflow::Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { +Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { if (!ShapeUtil::Compatible(other_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", ShapeUtil::HumanString(other_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } shape_ = other_shape; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { +Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { if (!ShapeUtil::Compatible(*to_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", ShapeUtil::HumanString(*to_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } *to_shape = shape_; - return tensorflow::Status::OK(); + return Status::OK(); } void ShapeLayout::SetToDefaultLayout() { diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index 4c83750f3e6f3c735db66d8e0b86ae3f43e5ca11..36806da599cc9b27286e67c128bb7f496f29c105 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -40,7 +40,7 @@ class ShapeLayout { // Assigns the layouts in this ShapeLayout to the Layout fields of the given // shape. 'to_shape' and the shape of the ShapeLayout object must be // compatible. - tensorflow::Status AssignLayoutToShape(Shape* to_shape) const; + Status AssignLayoutToShape(Shape* to_shape) const; // Returns true if the Layouts in this ShapeLayout match the layouts in the // given shape. Returns false otherwise. If the given shape is not compatible @@ -48,9 +48,8 @@ class ShapeLayout { bool MatchesLayoutInShape(const Shape& shape) const; // Copies the layout from the given shape into this ShapeLayout. 'other_shape' - // must be compatible with the ShapeLayout's shape, and 'other_shape' must - // have a layout (LayoutUtil::HasLayout). - tensorflow::Status CopyLayoutFromShape(const Shape& other_shape); + // must be compatible with the ShapeLayout's shape. + Status CopyLayoutFromShape(const Shape& other_shape); // Clears (Layout::Clear) all the Layouts stored in this object. void Clear(); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 280f02e88675381bd75108bfae0dd22c462ba718..18e54d23c241ae0d4c61d8be79ff021dfb02a3e6 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -42,36 +42,23 @@ namespace internal { template struct ShapeTreeNode { // Data corresponding to this node. - T data; + std::pair data; - // Children of this node. - std::vector> children; + // Children of this node, as indices into the container's nodes_ array. + std::vector children; - ShapeTreeNode() = default; - explicit ShapeTreeNode(const T& data) : data(data) {} + // Tells whether this is a leaf node. + bool is_leaf = true; - ShapeTreeNode(const ShapeTreeNode& other) - : data(other.data), children(other.children.size()) { - for (size_t i = 0; i < children.size(); ++i) { - children[i] = MakeUnique(*other.children[i]); - } - } - - ShapeTreeNode& operator=(const ShapeTreeNode& other) { - if (this != &other) { - data = other.data; - children.resize(other.children.size()); - for (size_t i = 0; i < children.size(); ++i) { - children[i] = MakeUnique(*other.children[i]); - } - } - return *this; - } + explicit ShapeTreeNode(ShapeIndex index) + : ShapeTreeNode(std::move(index), T()) {} + ShapeTreeNode(ShapeIndex index, T data) + : data(std::move(index), std::move(data)) {} }; } // namespace internal -template +template class ShapeTreeIterator; // A ShapeTree is a recursive data structure which mirrors the structure of a @@ -95,10 +82,9 @@ class ShapeTreeIterator; // before its ShapeTree goes away. template class ShapeTree { - friend class ShapeTreeIterator; - friend class ShapeTreeIterator; - public: + using Node = internal::ShapeTreeNode; + // Default constructor creates a tree with a nil shape (i.e. an empty tuple). ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} @@ -110,30 +96,12 @@ class ShapeTree { // alive longer than this ShapeTree. explicit ShapeTree(Shape shape); explicit ShapeTree(const Shape* shape); + explicit ShapeTree(const std::shared_ptr& shape); // Create ShapeTree with the given shape, and init_value for all nodes. ShapeTree(Shape shape, const T& init_value); ShapeTree(const Shape* shape, const T& init_value); - - ShapeTree(const ShapeTree& other) { *this = other; } - ShapeTree(ShapeTree&&) = default; - - ShapeTree& operator=(const ShapeTree& other) { - root_ = other.root_; - - // Fix up internal pointer if necessary. - if (other.shape_storage_) { - CHECK_EQ(other.shape_, other.shape_storage_.get()); - shape_storage_.reset(new Shape(*other.shape_)); - shape_ = shape_storage_.get(); - } else { - shape_ = other.shape_; - } - - return *this; - } - - ShapeTree& operator=(ShapeTree&& other) = default; + ShapeTree(const std::shared_ptr& shape, const T& init_value); // Returns the data element associated with the array in the shape at the // given index (see ShapeUtil::GetSubshape for how indexes are defined). @@ -157,67 +125,72 @@ class ShapeTree { // Returns true if the node at the given index is a leaf node (an array // shape). - bool IsLeaf(const ShapeIndex& index) const { - return Lookup(index)->children.empty(); - } + bool IsLeaf(const ShapeIndex& index) const { return Lookup(index)->is_leaf; } + + ShapeTree(const ShapeTree&) = default; + ShapeTree& operator=(const ShapeTree&) = default; + ShapeTree(ShapeTree&&) = default; + ShapeTree& operator=(ShapeTree&& other) = default; - // iterator implements a forward_iterator with value_type = - // std::pair - using iterator = ShapeTreeIterator; - using const_iterator = ShapeTreeIterator; + // iterator implements a bidirectional_iterator with + // value_type = std::pair. + // + // The iteration order is guaranteed to be a pre-order walk of the ShapeTree. + using iterator = + ShapeTreeIterator, typename std::vector::iterator, + std::pair>; + using const_iterator = + ShapeTreeIterator, + typename std::vector::const_iterator, + const std::pair>; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; // begin/end for iterating over all nodes. iterator begin() { - return iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/false); } iterator end() { - return iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/false); } const_iterator begin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/false); } const_iterator end() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/false); } // rbegin/rend for iterating over all nodes in reverse. - iterator rbegin() { - return iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/true); - } - iterator rend() { - return iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/true); + reverse_iterator rbegin() { return reverse_iterator(end()); } + reverse_iterator rend() { return reverse_iterator(begin()); } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(end()); } - const_iterator rbegin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/true); - } - const_iterator rend() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/true); + const_reverse_iterator rend() const { + return const_reverse_iterator(begin()); } // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no // children). iterator leaf_begin() { - return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/false); + return iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/true); } iterator leaf_end() { - return iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/false); + return iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/true); } const_iterator leaf_begin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/true, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/true); } const_iterator leaf_end() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/true); } // range-based iterator for leaf_begin()/leaf_end(). tensorflow::gtl::iterator_range leaves() { @@ -227,22 +200,32 @@ class ShapeTree { return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); } - iterator leaf_rbegin() { - return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/true); + reverse_iterator leaf_rbegin() { return reverse_iterator(leaf_end()); } + reverse_iterator leaf_rend() { return reverse_iterator(leaf_begin()); } + const_reverse_iterator leaf_rbegin() const { + return const_reverse_iterator(leaf_end()); } - iterator leaf_rend() { - return iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/true); + const_reverse_iterator leaf_rend() const { + return const_reverse_iterator(leaf_begin()); } - const_iterator leaf_rbegin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/true, - /*reverse=*/true); + + // Returns an iterator pointing to the given ShapeIndex. + // REQUIRES: index must exist in the ShapeTree. + iterator find(const ShapeIndex& index) { + Node* element = Lookup(index); + return iterator(&nodes_, typename std::vector::iterator(element), + /*iterate_leaves_only=*/false); } - const_iterator leaf_rend() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/true); + const_iterator find(const ShapeIndex& index) const { + Node* element = Lookup(index); + return iterator(&nodes_, + typename std::vector::const_iterator(element), + /*iterate_leaves_only=*/false); } + // Returns the number of leaf nodes in the tree. + int64 leaf_count() const { return std::distance(leaf_begin(), leaf_end()); } + // Recursively traverses the shape and calls the given function at each // element. The function has the following arguments: // @@ -282,8 +265,6 @@ class ShapeTree { bool operator!=(const ShapeTree& other) const { return !(*this == other); } private: - using Node = internal::ShapeTreeNode; - // Initialize node->children based on 'shape'. All children are assigned the // the given 'init_value'. void InitChildren(const Shape& shape, const T& init_value, Node* node); @@ -292,136 +273,55 @@ class ShapeTree { // default-constructed data values. void InitChildren(const Shape& shape, Node* node); + // Returns the number of subshapes, including interior nodes, in shape. + int64 CountSubshapes(const Shape& shape); + // Helpers for traversing the shape via ForEachElement. The helpers // recursively traverse the subtree rooted at "index" (defined as in // ShapeUtil::GetSubshape). template - static Status ForEachHelper(const Fn& func, const Node& node, - ShapeIndex* index); + static Status ForEachHelper(const Fn& func, const std::vector& nodes); template - static Status ForEachMutableHelper(const Fn& func, Node* node, - ShapeIndex* index); + static Status ForEachMutableHelper(const Fn& func, std::vector* nodes); // Return the tree node at the given index. Node* Lookup(const ShapeIndex& index); const Node* Lookup(const ShapeIndex& index) const; - // The root node, which contains all other nodes. - Node root_; + // The nodes in this shape tree. + std::vector nodes_; // If we own our Shape, this field contains it, and shape_ is a pointer into // here. Otherwise if we don't own our shape, this is nullptr. - std::unique_ptr shape_storage_; + std::shared_ptr shape_storage_; // The XLA shape mirrored in this ShapeTree. This is either // shape_storage_.get() or the Shape pointer passed to our constructor. const Shape* shape_; }; -// Internal iterator that performs a pre-order walk. This is copyable, but -// contains a vector so isn't cheap to copy. This also means post-increment is -// expensive. The iterator value_type is equivalent to a std::pair, similar to std::map. The non-const iterator's T& type can be mutated -// in-place. -template -class ShapeTreeIterator : public std::iterator> { +// Internal iterator that performs a pre-order walk. This is cheap to copy. +// The iterator value_type is equivalent to a +// std::pair&, similar to std::map. +template +class ShapeTreeIterator + : public std::iterator { public: - using value_type = - typename std::conditional, - std::pair>::type; - using NodeType = - typename std::conditional::Node, - typename ShapeTree::Node>::type; - - // Construct an iterator pointing at node. Node must either be the tree root - // or nullptr (which is equivalent to end() and should not be dereferenced or - // incremented). If iterate_leaves_only is true, the iterator will not include - // interior tree nodes, only leaves. If reverse is true, the iterator will - // visit nodes in the reverse of pre-order traversal. - ShapeTreeIterator(NodeType* node, bool iterate_leaves_only, bool reverse) - : node_(node), - iterate_leaves_only_(iterate_leaves_only), - reverse_(reverse) { - if (node_) { - if (reverse_) { - while (!node_->children.empty()) { - const int child_index = node_->children.size() - 1; - stack_.push_back({node_, child_index}); - node_ = node_->children[child_index].get(); - } - } else { - if (!node_->children.empty() && iterate_leaves_only) { - ++*this; - } - } + ShapeTreeIterator(ContainerType* nodes, IteratorType node, + bool iterate_leaves_only) + : nodes_(nodes), + node_(std::move(node)), + iterate_leaves_only_(iterate_leaves_only) { + while (iterate_leaves_only && node_ != nodes_->end() && !node_->is_leaf) { + ++node_; } } - ShapeTreeIterator(const ShapeTreeIterator& other) - : node_(other.node_), - stack_(other.stack_), - iterate_leaves_only_(other.iterate_leaves_only_), - reverse_(other.reverse_) {} ShapeTreeIterator& operator++() { - CHECK_NE(nullptr, node_) << "walking off the end() of an iterator!"; - if (reverse_) { - while (!stack_.empty()) { - node_ = stack_.back().first; - int64 next_child_index = stack_.back().second - 1; - stack_.pop_back(); - if (next_child_index < 0) { - if (!iterate_leaves_only_) { - // All children are visited, yield . - return *this; - } - } else { - stack_.push_back({node_, next_child_index}); - node_ = node_->children[next_child_index].get(); - while (!node_->children.empty()) { - const int child_index = node_->children.size() - 1; - stack_.push_back({node_, child_index}); - node_ = node_->children[child_index].get(); - } - return *this; - } - } - } else { - // We're doing a pre-order walk, so if our current node has children take - // the first child. - if (!node_->children.empty()) { - stack_.push_back({node_, /*child-index=*/0}); - node_ = node_->children[0].get(); - if (node_->children.empty() || !iterate_leaves_only_) { - return *this; - } else { - // This is a non-leaf; tail-recurse. - return ++(*this); - } - } - // Otherwise we are currently at a leaf. Walk back up until a node - // contains a child we haven't visited yet. - while (!stack_.empty()) { - node_ = stack_.back().first; - int64 next_child_index = stack_.back().second + 1; - stack_.pop_back(); - if (node_->children.size() > next_child_index) { - stack_.push_back({node_, next_child_index}); - node_ = node_->children[next_child_index].get(); - - if (node_->children.empty() || !iterate_leaves_only_) { - return *this; - } else { - // This is a non-leaf; tail-recurse. - return ++(*this); - } - } - } + ++node_; + while (iterate_leaves_only_ && node_ != nodes_->end() && !node_->is_leaf) { + ++node_; } - // We've walked off the end of the tree. Set node_ to nullptr to signify - // end(). - node_ = nullptr; - current_.reset(); return *this; } ShapeTreeIterator operator++(int) { @@ -429,52 +329,62 @@ class ShapeTreeIterator : public std::iterator nodes_->begin() && !node_->is_leaf) { + --node_; + } + return *this; + } + ShapeTreeIterator operator--(int) { + auto i = *this; + --(*this); + return i; + } + bool operator==(const ShapeTreeIterator& other) const { return node_ == other.node_; } bool operator!=(const ShapeTreeIterator& other) const { return node_ != other.node_; } - value_type& operator*() { return UpdateCurrent(); } - value_type* operator->() { return &UpdateCurrent(); } + ValueType& operator*() { return node_->data; } + ValueType* operator->() { return &node_->data; } private: - // Updates the current_ member to reflect the current state. - value_type& UpdateCurrent() { - ShapeIndex index; - for (auto& node_and_index : stack_) { - index.push_back(node_and_index.second); - } - current_ = MakeUnique(index, node_->data); - return *current_; - } - - // The node to which this iterator is pointing. This is the source of truth in - // the iterator - the stack only exists to facilitate walking back from - // children to parents. - NodeType* node_; - // Stack of {node, child-index} pairs of the path taken from the root to get - // to node_. This allows us to backtrack and know where to go next. - std::vector> stack_; + ContainerType* nodes_; + IteratorType node_; // True if we should not include interior nodes in our walk. - bool iterate_leaves_only_; - // True if we should yield the reverse of the pre-order traversal. - bool reverse_; - // Placeholder for the current value. Ideally this wouldn't exist and would - // just be an rvalue, but operator -> needs to return a pointer to something. - // We cannot just use a plain old value_type as it contains a reference so - // cannot be default-constructed. - std::unique_ptr current_; + const bool iterate_leaves_only_; }; +template +int64 ShapeTree::CountSubshapes(const Shape& shape) { + int64 current_count = 1; + if (ShapeUtil::IsTuple(shape)) { + int64 count = ShapeUtil::TupleElementCount(shape); + for (int i = 0; i < count; ++i) { + current_count += CountSubshapes(shape.tuple_shapes(i)); + } + } + return current_count; +} + template void ShapeTree::InitChildren(const Shape& shape, const T& init_value, Node* node) { if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - node->children.emplace_back(new Node(init_value)); - InitChildren(shape.tuple_shapes(i), init_value, - node->children.back().get()); + const int64 size = ShapeUtil::TupleElementCount(shape); + node->children.reserve(size); + node->is_leaf = false; + ShapeIndex shape_index = node->data.first; + shape_index.push_back(0); + for (int i = 0; i < size; ++i) { + shape_index[shape_index.size() - 1] = i; + node->children.push_back(nodes_.size()); + nodes_.emplace_back(shape_index, init_value); + InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back()); } } } @@ -482,63 +392,93 @@ void ShapeTree::InitChildren(const Shape& shape, const T& init_value, template void ShapeTree::InitChildren(const Shape& shape, Node* node) { if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - node->children.emplace_back(new Node()); - InitChildren(shape.tuple_shapes(i), node->children.back().get()); + const int64 size = ShapeUtil::TupleElementCount(shape); + node->children.reserve(size); + node->is_leaf = false; + ShapeIndex shape_index = node->data.first; + shape_index.push_back(0); + for (int i = 0; i < size; ++i) { + shape_index[shape_index.size() - 1] = i; + node->children.push_back(nodes_.size()); + nodes_.emplace_back(shape_index); + InitChildren(shape.tuple_shapes(i), &nodes_.back()); } } } template ShapeTree::ShapeTree(Shape shape) - : root_(), - shape_storage_(MakeUnique(std::move(shape))), + : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. LayoutUtil::ClearLayout(shape_storage_.get()); - InitChildren(*shape_, &root_); + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}); + InitChildren(*shape_, &nodes_[0]); +} + +template +ShapeTree::ShapeTree(const Shape* shape) : shape_(shape) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}); + InitChildren(*shape_, &nodes_[0]); } template -ShapeTree::ShapeTree(const Shape* shape) : root_(), shape_(shape) { - InitChildren(*shape_, &root_); +ShapeTree::ShapeTree(const std::shared_ptr& shape) + : shape_storage_(shape), shape_(shape_storage_.get()) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}); + InitChildren(*shape_, &nodes_[0]); } template ShapeTree::ShapeTree(Shape shape, const T& init_value) - : root_(init_value), - shape_storage_(MakeUnique(std::move(shape))), + : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. LayoutUtil::ClearLayout(shape_storage_.get()); - InitChildren(*shape_, init_value, &root_); + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}, init_value); + InitChildren(*shape_, init_value, &nodes_[0]); } template ShapeTree::ShapeTree(const Shape* shape, const T& init_value) - : root_(init_value), shape_(shape) { - InitChildren(*shape_, init_value, &root_); + : shape_(shape) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}, init_value); + InitChildren(*shape_, init_value, &nodes_[0]); +} + +template +ShapeTree::ShapeTree(const std::shared_ptr& shape, + const T& init_value) + : shape_storage_(shape), shape_(shape_storage_.get()) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}, init_value); + InitChildren(*shape_, init_value, &nodes_[0]); } template const T& ShapeTree::element(const ShapeIndex& index) const { - return Lookup(index)->data; + return Lookup(index)->data.second; } template T* ShapeTree::mutable_element(const ShapeIndex& index) { - return &Lookup(index)->data; + return &Lookup(index)->data.second; } template internal::ShapeTreeNode* ShapeTree::Lookup(const ShapeIndex& index) { - Node* node = &root_; + Node* node = &nodes_[0]; for (const int64 i : index) { CHECK_GE(i, 0); CHECK_LT(i, node->children.size()); - node = node->children[i].get(); + node = &nodes_[node->children[i]]; } return node; } @@ -552,13 +492,10 @@ const internal::ShapeTreeNode* ShapeTree::Lookup( /* static */ template template -Status ShapeTree::ForEachHelper(const Fn& func, const Node& node, - ShapeIndex* index) { - TF_RETURN_IF_ERROR(func(*index, node.data)); - for (int64 i = 0; i < node.children.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR(ForEachHelper(func, *node.children[i], index)); - index->pop_back(); +Status ShapeTree::ForEachHelper(const Fn& func, + const std::vector& nodes) { + for (const auto& node : nodes) { + TF_RETURN_IF_ERROR(func(node.data.first, node.data.second)); } return Status::OK(); } @@ -566,14 +503,10 @@ Status ShapeTree::ForEachHelper(const Fn& func, const Node& node, /* static */ template template -Status ShapeTree::ForEachMutableHelper(const Fn& func, Node* node, - ShapeIndex* index) { - TF_RETURN_IF_ERROR(func(*index, &node->data)); - for (int64 i = 0; i < node->children.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR( - ForEachMutableHelper(func, node->children[i].get(), index)); - index->pop_back(); +Status ShapeTree::ForEachMutableHelper(const Fn& func, + std::vector* nodes) { + for (auto& node : *nodes) { + TF_RETURN_IF_ERROR(func(node.data.first, &node.data.second)); } return Status::OK(); } @@ -581,40 +514,36 @@ Status ShapeTree::ForEachMutableHelper(const Fn& func, Node* node, template template Status ShapeTree::ForEachElementWithStatus(const Fn& func) const { - ShapeIndex index; - return ForEachHelper(func, root_, &index); + return ForEachHelper(func, nodes_); } template template Status ShapeTree::ForEachMutableElementWithStatus(const Fn& func) { - ShapeIndex index; - return ForEachMutableHelper(func, &root_, &index); + return ForEachMutableHelper(func, &nodes_); } template template void ShapeTree::ForEachElement(const Fn& func) const { - ShapeIndex index; return ForEachHelper( [&func](const ShapeIndex& index, const T& data) { func(index, data); return Status::OK(); }, - root_, &index) + nodes_) .IgnoreError(); } template template void ShapeTree::ForEachMutableElement(const Fn& func) { - ShapeIndex index; return ForEachMutableHelper( [&func](const ShapeIndex& index, T* data) { func(index, data); return Status::OK(); }, - &root_, &index) + &nodes_) .IgnoreError(); } diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index 4b6ab772811f4a6c6ffc1d10befc7122f883b8f9..51de82e95746281ed6e587b545dc933b48ce1ad4 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace xla { namespace { @@ -115,6 +116,11 @@ TEST_F(ShapeTreeTest, InitValueConstructor) { TestInitValueConstructor(nested_tuple_shape_, 10); } +TEST_F(ShapeTreeTest, EmptyTupleMustHaveNoLeaves) { + ShapeTree shape_tree{ShapeUtil::MakeTupleShape({})}; + EXPECT_EQ(0, shape_tree.leaf_count()); +} + TEST_F(ShapeTreeTest, ArrayShape) { ShapeTree shape_tree{array_shape_}; *shape_tree.mutable_element({}) = 42; @@ -421,8 +427,8 @@ TEST_F(ShapeTreeTest, IterateAndMutate) { } ++i; } - t.begin()->second = 78; - EXPECT_EQ(78, t.begin()->second); + (*t.begin()).second = 78; + EXPECT_EQ(78, (*t.begin()).second); i = 0; for (auto& index_to_data : t) { if (i == 0) { @@ -434,14 +440,14 @@ TEST_F(ShapeTreeTest, IterateAndMutate) { } ++i; } - EXPECT_EQ(78, t.begin()->second); - EXPECT_EQ(98, std::next(t.begin())->second); + EXPECT_EQ(78, (*t.begin()).second); + EXPECT_EQ(98, (*std::next(t.begin())).second); } TEST_F(ShapeTreeTest, IterateOrder) { ShapeTree t(nested_tuple_shape_, 42); std::vector v; - for (auto& index_to_data : t) { + for (auto index_to_data : t) { v.push_back(index_to_data.first); } EXPECT_EQ(v, (std::vector{{}, @@ -479,7 +485,7 @@ TEST_F(ShapeTreeTest, ReverseIterateOrder) { TEST_F(ShapeTreeTest, IterateOrderLeaves) { ShapeTree t(nested_tuple_shape_, 42); std::vector v; - for (auto& index_to_data : t.leaves()) { + for (auto index_to_data : t.leaves()) { v.push_back(index_to_data.first); } EXPECT_EQ(v, (std::vector{ @@ -502,5 +508,106 @@ TEST_F(ShapeTreeTest, ReverseIterateOrderLeaves) { })); } +void BM_Construct(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + for (int i = 0; i < iters; ++i) { + ShapeTree shape_tree(shape); + } +} + +void BM_ConstructUnowned(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + for (int i = 0; i < iters; ++i) { + ShapeTree shape_tree(&shape); + } +} + +void BM_Copy(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + ShapeTree copy = shape_tree; + tensorflow::testing::DoNotOptimize(copy); + } +} + +void BM_Move(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + ShapeTree copy = std::move(shape_tree); + shape_tree = std::move(copy); + } +} + +void BM_ForEach(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + shape_tree.ForEachMutableElement([](const ShapeIndex& index, int* data) { + tensorflow::testing::DoNotOptimize(index); + }); + } +} + +void BM_Iterate(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + for (auto& iter : shape_tree) { + tensorflow::testing::DoNotOptimize(iter.second); + } + } +} + +BENCHMARK(BM_Construct)->ArgPair(2, 8); +BENCHMARK(BM_ConstructUnowned)->ArgPair(2, 8); +BENCHMARK(BM_Copy)->ArgPair(2, 8); +BENCHMARK(BM_Move)->ArgPair(2, 8); +BENCHMARK(BM_ForEach)->ArgPair(2, 8); +BENCHMARK(BM_Iterate)->ArgPair(2, 8); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 604e0173e789348923316174873f58058eaf2815..5db66599324913b9214d7623597060950246fb03 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -27,11 +27,11 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -41,17 +41,35 @@ limitations under the License. namespace xla { +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + string ShapeIndex::ToString() const { - return tensorflow::strings::StrCat( - "{", tensorflow::str_util::Join(indices_, ","), "}"); + return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}"); } string ShapeIndexView::ToString() const { - return tensorflow::strings::StrCat( - "{", - tensorflow::str_util::Join(tensorflow::gtl::make_range(begin_, end_), - ","), - "}"); + return StrCat("{", + tensorflow::str_util::Join( + tensorflow::gtl::make_range(begin_, end_), ","), + "}"); +} + +bool ShapeIndexView::operator==(const ShapeIndexView& other) const { + if (size() != other.size()) { + return false; + } + for (auto it = begin(), other_it = other.begin(); it != end(); + ++it, ++other_it) { + if (*it != *other_it) { + return false; + } + } + return true; +} + +bool ShapeIndexView::operator!=(const ShapeIndexView& other) const { + return !(*this == other); } std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) { @@ -66,18 +84,30 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) { namespace { +// Returns whether the given primitive type corresponds to an array shape. +bool IsArrayPrimitiveType(PrimitiveType primitive_type) { + return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && + primitive_type != OPAQUE && primitive_type != TOKEN; +} + // Recursive helper for comparing the equality of two shapes. Returns true if // the shapes are the same. If compare_layouts is true, then layouts must also // match. bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { - if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) { - return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) && - ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), + if (!ShapeUtil::SameElementType(lhs, rhs)) { + VLOG(3) << "CompareShapes: lhs element type != rhs element type"; + return false; + } + + if (ShapeUtil::IsTuple(lhs)) { + return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), [=](const Shape& l, const Shape& r) { return CompareShapes(l, r, compare_layouts); }); - } else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) { - return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs); + } else if (!ShapeUtil::IsArray(lhs)) { + // Non-tuple, non-array tupes such as opaque and token types are trivially + // the same. + return true; } if (compare_layouts) { @@ -107,10 +137,6 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; return false; } - if (!ShapeUtil::SameElementType(lhs, rhs)) { - VLOG(3) << "CompareShapes: lhs element type != rhs element type"; - return false; - } return true; } @@ -153,8 +179,8 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ int64 ShapeUtil::Rank(const Shape& shape) { - CHECK(!ShapeUtil::IsTuple(shape)) - << "Tuples do not have a rank, shape: " << shape; + CHECK(ShapeUtil::IsArray(shape)) + << "Non-arrays do not have a rank, shape: " << shape; return shape.dimensions_size(); } @@ -181,8 +207,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShape( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { - DCHECK_NE(TUPLE, element_type); - DCHECK_NE(OPAQUE, element_type); + CHECK(IsArrayPrimitiveType(element_type)); Shape result; PopulateShape(element_type, dimensions, &result); return result; @@ -205,8 +230,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, int64 max_sparse_elements) { - DCHECK_NE(TUPLE, element_type); - DCHECK_NE(OPAQUE, element_type); + CHECK(IsArrayPrimitiveType(element_type)); Shape shape = ShapeUtil::MakeShape(element_type, dimensions); *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); TF_DCHECK_OK(ShapeUtil::ValidateShape(shape)); @@ -253,6 +277,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return result; } +/* static */ Shape ShapeUtil::MakeTokenShape() { + Shape result; + result.set_element_type(TOKEN); + TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result)); + return result; +} + /* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape, Shape* tuple_shape) { TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape)); @@ -276,7 +307,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) { - if (shape.element_type() == TUPLE || shape.element_type() == OPAQUE) { + if (!IsArray(shape)) { return false; } return primitive_util::BitWidth(shape.element_type()) == bits; @@ -302,6 +333,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( case C64: case TUPLE: case OPAQUE: + case TOKEN: return false; default: @@ -317,6 +349,10 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return primitive_util::IsFloatingPointType(shape.element_type()); } +/* static */ bool ShapeUtil::IsArray(const Shape& shape) { + return IsArrayPrimitiveType(shape.element_type()); +} + /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) { return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), IsTuple); @@ -370,7 +406,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - CHECK(!IsTuple(shape)) << ShapeUtil::HumanString(shape); + CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); CHECK_EQ(shape.dimensions_size(), Rank(shape)); return std::accumulate( shape.dimensions().begin(), shape.dimensions().end(), 1LL, @@ -385,23 +421,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return shape.element_type() == F32 && Rank(shape) == 0; } -/* static */ string ShapeUtil::HumanString(const Shape& shape) { - if (IsTuple(shape)) { - string text = "("; - const char* prefix = ""; - for (const Shape& elem_shape : shape.tuple_shapes()) { - tensorflow::strings::StrAppend(&text, prefix, HumanString(elem_shape)); - prefix = ", "; - } - text += ")"; - return text; - } else { - return tensorflow::strings::StrCat( - tensorflow::str_util::Lowercase( - PrimitiveType_Name(shape.element_type())), - "[", tensorflow::str_util::Join(shape.dimensions(), ","), "]"); - } -} namespace { @@ -452,48 +471,56 @@ StatusOr StringToPrimitiveType(const string& name) { } // namespace -/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { +/* static */ string ShapeUtil::HumanString(const Shape& shape) { if (IsTuple(shape)) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { - tensorflow::strings::StrAppend(&text, prefix, - HumanStringWithLayout(elem_shape)); + StrAppend(&text, prefix, HumanString(elem_shape)); prefix = ", "; } text += ")"; return text; - } else { - string result = tensorflow::strings::StrCat( - LowercasePrimitiveTypeName(shape.element_type()), "["); - for (int i = 0; i < shape.dimensions().size(); i++) { - tensorflow::strings::StrAppend(&result, (i > 0) ? "," : "", - shape.dimensions(i)); + } + return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[", + tensorflow::str_util::Join(shape.dimensions(), ","), "]"); +} + +/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { + if (IsTuple(shape)) { + string text = "("; + const char* prefix = ""; + for (const Shape& elem_shape : shape.tuple_shapes()) { + StrAppend(&text, prefix, HumanStringWithLayout(elem_shape)); + prefix = ", "; } - result += "]"; - if (!IsScalar(shape) && !IsOpaque(shape)) { - if (LayoutUtil::HasLayout(shape)) { - tensorflow::strings::StrAppend(&result, - LayoutUtil::HumanString(shape.layout())); - } + text += ")"; + return text; + } + string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "["); + for (int i = 0; i < shape.dimensions().size(); i++) { + StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i)); + } + result += "]"; + if (!IsScalar(shape) && IsArray(shape)) { + if (LayoutUtil::HasLayout(shape)) { + StrAppend(&result, LayoutUtil::HumanString(shape.layout())); } - return result; } + return result; } /* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) { std::vector parameters; for (auto& shape : program_shape.parameters()) { const int i = parameters.size(); - parameters.push_back( - tensorflow::strings::StrCat(i < program_shape.parameter_names_size() - ? program_shape.parameter_names(i) - : "(unknown)", - ": ", HumanString(shape))); + parameters.push_back(StrCat(i < program_shape.parameter_names_size() + ? program_shape.parameter_names(i) + : "(unknown)", + ": ", HumanString(shape))); } - return tensorflow::strings::StrCat( - "(", tensorflow::str_util::Join(parameters, ", "), ") -> ", - HumanString(program_shape.result())); + return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ", + HumanString(program_shape.result())); } namespace { @@ -502,20 +529,20 @@ namespace { StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { tensorflow::str_util::RemoveLeadingWhitespace(s); - if (s->Consume("(")) { // Tuple. + if (tensorflow::str_util::ConsumePrefix(s, "(")) { // Tuple. std::vector shapes; bool must_end = false; while (true) { - if (s->Consume(")")) { + if (tensorflow::str_util::ConsumePrefix(s, ")")) { break; } else if (must_end) { return InvalidArgument("Expected end of tuple; got: \"%s\"", - s->ToString().c_str()); + std::string(*s).c_str()); } shapes.emplace_back(); TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); tensorflow::str_util::RemoveLeadingWhitespace(s); - must_end = !s->Consume(","); + must_end = !tensorflow::str_util::ConsumePrefix(s, ","); } return ShapeUtil::MakeTupleShape(shapes); } @@ -540,7 +567,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) { return InvalidArgument( "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", - input.c_str(), s->ToString().c_str()); + input.c_str(), std::string(*s).c_str()); } return element; }; @@ -563,14 +590,17 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { // Extract the primitive element type. TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, StringToPrimitiveType(element_type_string)); - if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE || - primitive_type == OPAQUE) { + if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) { return InvalidArgument("Invalid element type string: \"%s\".", element_type_string.c_str()); } Shape result; - if (format_string.empty() && layout_string.empty()) { + if (primitive_type == OPAQUE) { + result = ShapeUtil::MakeOpaqueShape(); + } else if (primitive_type == TOKEN) { + result = ShapeUtil::MakeTokenShape(); + } else if (format_string.empty() && layout_string.empty()) { // Create a shape without a layout set. result = ShapeUtil::MakeShape(primitive_type, dimensions); } else if (format_string == "sparse") { @@ -593,7 +623,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } return InvalidArgument("Invalid shape string to parse: \"%s\"", - s->ToString().c_str()); + std::string(*s).c_str()); } } // namespace @@ -602,45 +632,57 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); if (!s.empty()) { return InvalidArgument("Invalid shape string to parse: \"%s\"", - s.ToString().c_str()); + std::string(s).c_str()); } return shape; } /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, const Shape& rhs) { + CHECK(ShapeUtil::IsArray(lhs)); + CHECK(ShapeUtil::IsArray(rhs)); return ContainersEqual(lhs.dimensions(), rhs.dimensions()); } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible); + } else { + // Opaque, token, etc types are vacuously compatible. + return true; } - return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return IsArray(rhs) && SameDimensions(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), CompatibleIgnoringElementType); + } else { + // Opaque, token, etc types are vacuously compatible. + return true; } - return SameDimensions(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return IsArray(rhs) && SameElementTypeIgnoringFpPrecision(lhs, rhs) && + CompatibleIgnoringElementType(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), CompatibleIgnoringFpPrecision); + } else { + // Opaque, token, etc types are vacuously compatible. + return true; } - if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) { - return CompatibleIgnoringElementType(lhs, rhs); - } - return false; } /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, @@ -662,10 +704,6 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { switch (primitive_type) { case PRED: return sizeof(int8); - case TUPLE: - LOG(FATAL) << "tuples have no definitive size"; - case OPAQUE: - LOG(FATAL) << "opaque have no definitive size"; case S8: return sizeof(int8); case S16: @@ -692,6 +730,13 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return sizeof(double); case C64: return sizeof(complex64); + case TOKEN: + // Tokens require no space. + return 0; + case TUPLE: + case OPAQUE: + LOG(FATAL) << PrimitiveType_Name(primitive_type) + << " primitive type has no definitive size"; default: LOG(FATAL) << "Unhandled primitive type " << primitive_type; } @@ -700,28 +745,32 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape, int64 pointer_size) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK_NE(OPAQUE, shape.element_type()); if (shape.element_type() == TUPLE) { return ByteSizeOfTupleIndexTable(shape, pointer_size); + } else if (IsArray(shape)) { + int64 byte_size = ByteSizeOfElements(shape); + if (LayoutUtil::IsSparseArray(shape)) { + byte_size += ByteSizeOfSparseIndices(shape); + } + return byte_size; + } else if (shape.element_type() == TOKEN) { + return 0; } - int64 byte_size = ByteSizeOfElements(shape); - if (LayoutUtil::IsSparseArray(shape)) { - byte_size += ByteSizeOfSparseIndices(shape); - } - return byte_size; + LOG(FATAL) << PrimitiveType_Name(shape.element_type()) + << " primitive type has no definitive size"; } /* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape, int64 pointer_size) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK_EQ(TUPLE, shape.element_type()); + CHECK_EQ(TUPLE, shape.element_type()); CHECK_GT(pointer_size, 0); return pointer_size * shape.tuple_shapes_size(); } /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK(ShapeUtil::IsArray(shape)); + CHECK(ShapeUtil::IsArray(shape)); int64 allocated_element_count; if (LayoutUtil::IsSparseArray(shape)) { @@ -746,13 +795,17 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK(LayoutUtil::IsSparseArray(shape)); + CHECK(LayoutUtil::IsSparseArray(shape)); return LayoutUtil::MaxSparseElements(shape.layout()) * ShapeUtil::Rank(shape) * sizeof(int64); } /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { + if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { + return InvalidArgument("shape has invalid element type: %s", + shape.ShortDebugString().c_str()); + } if (shape.element_type() == TUPLE) { if (shape.dimensions_size() != 0) { return InvalidArgument("tuples must not have dimensions specified"); @@ -768,10 +821,24 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (shape.tuple_shapes_size() > 0) { return InvalidArgument("non-tuple shape has tuple_shapes field"); } - if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { - return InvalidArgument("shape has invalid element type: %s", - shape.ShortDebugString().c_str()); + + // Tokens and opaques can should not have layout or dimensions. + if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE) { + if (shape.dimensions_size() != 0) { + return InvalidArgument( + "shape has %s element type, but has dimensions field: %s", + LowercasePrimitiveTypeName(shape.element_type()).c_str(), + shape.ShortDebugString().c_str()); + } + if (shape.has_layout()) { + return InvalidArgument( + "shape has %s element type, but has layout field: %s", + LowercasePrimitiveTypeName(shape.element_type()).c_str(), + shape.ShortDebugString().c_str()); + } + return Status::OK(); } + if (Rank(shape) != shape.dimensions_size()) { return InvalidArgument( "shape's rank is mismatched with dimension count; rank=%lld " @@ -813,6 +880,18 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return new_shape; } +/* static */ bool ShapeUtil::IndexIsValid(const Shape& shape, + ShapeIndexView index) { + const Shape* subshape = &shape; + for (auto i : index) { + if (!IsTuple(*subshape) || i >= subshape->tuple_shapes_size()) { + return false; + } + subshape = &subshape->tuple_shapes(i); + } + return true; +} + /* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape, ShapeIndexView index) { const Shape* return_shape = &shape; @@ -839,57 +918,25 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { return !IsTuple(GetSubshape(shape, index)); } -/* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) { - std::vector dimension_sizes; - std::vector degenerate_dimensions; - for (int64 i = 0; i < shape.dimensions_size(); ++i) { - if (shape.dimensions(i) == 1) { - degenerate_dimensions.push_back(i); - } else { - dimension_sizes.push_back(shape.dimensions(i)); - } - } - - // Construct minor_to_major of stripped shape. The order of the non-degenerate - // dimensions should be preserved from the original shape. First, create - // vector of the non-degenerate dimensions from the original minor_to_major - // array. - std::vector minor_to_major; - for (int64 i : shape.layout().minor_to_major()) { - if (std::find(degenerate_dimensions.begin(), degenerate_dimensions.end(), - i) == degenerate_dimensions.end()) { - minor_to_major.push_back(i); +/* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) { + int64 count = 0; + ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) { + if (IsLeafIndex(shape, index)) { + ++count; } - } + }); + return count; +} - // The dimensions in minor_to_major need to be renumbered to account for the - // degenerate dimensions which have removed. Decrement each dimension number - // once for each degenerate dimension which has a smaller number. - for (int i = 0; i < minor_to_major.size(); ++i) { - int adjustment = 0; - for (int64 dim : degenerate_dimensions) { - if (minor_to_major[i] > dim) { - adjustment++; - } +/* static */ std::vector ShapeUtil::GetLeafShapes( + const Shape& shape) { + std::vector leaves; + ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) { + if (IsLeafIndex(shape, index)) { + leaves.emplace_back(index, sub_shape); } - minor_to_major[i] -= adjustment; - } - - { - std::vector dims(minor_to_major.size()); - std::iota(dims.begin(), dims.end(), 0); - DCHECK(minor_to_major.size() == dims.size() && - std::is_permutation(minor_to_major.begin(), minor_to_major.end(), - dims.begin())); - } - Shape stripped_shape = - shape.has_layout() ? MakeShapeWithLayout(shape.element_type(), - dimension_sizes, minor_to_major) - : MakeShape(shape.element_type(), dimension_sizes); - - VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape); - VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape); - return stripped_shape; + }); + return leaves; } namespace { @@ -997,6 +1044,9 @@ Status ForEachMutableSubshapeHelper( /* static */ std::tuple, std::vector> ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, const Shape& shape_post) { + CHECK(IsArray(shape_pre)); + CHECK(IsArray(shape_post)); + auto nil = std::make_tuple(false, std::vector(), std::vector()); std::vector deleted_indices; @@ -1054,6 +1104,9 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, /* static */ std::vector> ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, const Shape& output_shape) { + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + // Unmodified dimensions are merely common factors of rank 1. auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), AsInt64Slice(output_shape.dimensions())); @@ -1073,9 +1126,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::TransposeIsBitcast( const Shape& input_shape, const Shape& output_shape, tensorflow::gtl::ArraySlice dimension_mapping) { - // Can't insert bitcasts without layout information. - if (!LayoutUtil::HasLayout(input_shape) && - !LayoutUtil::HasLayout(output_shape)) { + CHECK(LayoutUtil::HasLayout(input_shape) && + LayoutUtil::HasLayout(output_shape)); + + if (!SameElementType(input_shape, output_shape)) { return false; } @@ -1106,9 +1160,12 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape) { - // Can't convert reshapes into bitcasts without layout information. - if (!LayoutUtil::HasLayout(input_shape) || - !LayoutUtil::HasLayout(output_shape)) { + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + CHECK(LayoutUtil::HasLayout(input_shape)); + CHECK(LayoutUtil::HasLayout(output_shape)); + + if (!SameElementType(input_shape, output_shape)) { return false; } @@ -1268,6 +1325,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ tensorflow::gtl::optional ShapeUtil::AlignLayouts( const Shape& input_shape, const Shape& output_shape) { + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + int64 input_rank = Rank(input_shape); int64 output_rank = Rank(output_shape); @@ -1402,6 +1462,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { + CHECK(IsArray(shape)); shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); if (LayoutUtil::HasLayout(shape)) { Layout* layout = shape.mutable_layout(); @@ -1423,6 +1484,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::FilterDimensions( const std::function& p, Shape shape) { + CHECK(IsArray(shape)); std::vector dims_to_delete; for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) { if (!p(i)) { @@ -1440,4 +1502,26 @@ std::ostream& operator<<(std::ostream& out, const Shape& shape) { return out; } +/*static*/ size_t ShapeUtil::Hash(const Shape& shape) { + using tensorflow::hash; + using tensorflow::Hash64Combine; + + size_t hash_value = hash()(shape.element_type()); + + if (shape.tuple_shapes().empty()) { + for (int64 dim : shape.dimensions()) { + hash_value = Hash64Combine(hash_value, hash()(dim)); + } + + hash_value = Hash64Combine(hash_value, LayoutUtil::Hash(shape.layout())); + } else { + hash_value = 0; + for (const Shape& subshape : shape.tuple_shapes()) { + hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(subshape)); + } + } + + return hash_value; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 19b1aa93bd373ebd5f502d0dca56c9b31ab4fd7f..ae2d17d6bbbfed96e1da192253838ae5e9a67e17 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -24,11 +24,16 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -57,6 +62,8 @@ class ShapeIndex { public: ShapeIndex() = default; ShapeIndex(std::initializer_list init) : indices_(init) {} + template + ShapeIndex(InputIt start, InputIt end) : indices_(start, end) {} bool empty() const { return indices_.empty(); } size_t size() const { return indices_.size(); } @@ -127,6 +134,10 @@ class ShapeIndexView { ++new_begin; return ShapeIndexView(new_begin, end_); } + ShapeIndex ToShapeIndex() const { return ShapeIndex(begin_, end_); } + + bool operator==(const ShapeIndexView& other) const; + bool operator!=(const ShapeIndexView& other) const; string ToString() const; @@ -146,12 +157,22 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index); // properties, which do invariant checks before / after the operation. class ShapeUtil { public: + // Data structure which describes the coordinates and the shape, of a tuple + // shaped sub-shape. + struct IndexedShape { + IndexedShape() = default; + IndexedShape(ShapeIndex index, Shape shape) + : index(std::move(index)), shape(std::move(shape)) {} + ShapeIndex index; + Shape shape; + }; + // Returns the number of elements are contained within the provided shape; // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes // may not actually be able to store this number of elements. See // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of // elements that can be stored in a sparse shape. - // Precondition: !IsTuple(shape) + // Precondition: IsArray(shape) static int64 ElementsIn(const Shape& shape); // Returns true if 'shape' has zero elements. @@ -162,13 +183,11 @@ class ShapeUtil { // shapes. This includes only the size of the top-level buffer. For example, a // tuple is stored as an array of pointers to other buffers. In this case, // this method only returns the size of the pointer array. - // Precondition: (!ShapeUtil::IsTuple(shape) || pointer_size > 0) && - // !ShapeUtil::IsOpaque(shape) static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1); // Returns the number of bytes used to store the primitive_type. // - // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape) + // Precondition: ShapeUtil::IsArray(shape) static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type); // Returns the number of bytes required to store the tuple member pointers for @@ -208,6 +227,7 @@ class ShapeUtil { // Returns whether the LHS and RHS shapes have the same dimensions; note: does // not check element type. + // Precondition: IsArray(lhs) && IsArray(rhs) static bool SameDimensions(const Shape& lhs, const Shape& rhs); // Returns whether the lhs and rhs shapes have the same element type. @@ -274,10 +294,10 @@ class ShapeUtil { // Scalar-specific static bool IsScalar(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape) && Rank(shape) == 0; + return IsArray(shape) && Rank(shape) == 0; } static bool IsEffectiveScalar(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape) && TrueRank(shape) == 0; + return IsArray(shape) && TrueRank(shape) == 0; } static bool IsScalarF32(const Shape& shape); @@ -306,6 +326,10 @@ class ShapeUtil { // into a custom operation. static Shape MakeOpaqueShape(); + // Creates a token shape. Values of this shape are used for ordering + // side-effecting operations. + static Shape MakeTokenShape(); + // Appends a shape to the given tuple. static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape); @@ -315,11 +339,25 @@ class ShapeUtil { // Returns an empty tuple shape. Can be used to indicate side-effects. static Shape MakeNil() { return MakeTupleShape({}); } + // Checks whether the shape is initialized. + static bool IsInitialized(const Shape& shape) { + return shape.element_type() != PRIMITIVE_TYPE_INVALID; + } + // Constructs a new shape with the given element type and sequence of // dimensions. static Shape MakeShape(PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions); + // Creates a Shape with element type corresponding to T and the given + // dimensions + template + static Shape MakeShapeWithType( + tensorflow::gtl::ArraySlice dimensions) { + return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), + dimensions); + } + // Constructs a new shape with the given minor_to_major order in its Layout. // Returns a value shape such that shape.has_layout(). static Shape MakeShapeWithLayout( @@ -391,11 +429,15 @@ class ShapeUtil { return shape.element_type() == OPAQUE; } + // Returns whether the shape is an token value used for ordering + // side-effecting operations. + static bool IsToken(const Shape& shape) { + return shape.element_type() == TOKEN; + } + // Returns whether the shape is an array. Note that scalars are considered // arrays. - static bool IsArray(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape); - } + static bool IsArray(const Shape& shape); // Returns whether the shape is a tuple with at least one element which is // also a tuple. @@ -430,6 +472,9 @@ class ShapeUtil { static bool ShapeIs(const Shape& shape, PrimitiveType element_type, std::initializer_list dimensions); + // Returns true if the given shape has a subshape at the given index. + static bool IndexIsValid(const Shape& shape, ShapeIndexView index); + // GetSubshape and GetMutableSubshape return a particular nested Shape within // the given Shape argument. static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index); @@ -439,6 +484,13 @@ class ShapeUtil { // shape. static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index); + // Returns the number of leaves in the shape. + static int64 GetLeafCount(const Shape& shape); + + // Retrieves all the leaf shapes and their indexes, in the order walked by + // the ForEachSubshape() API. + static std::vector GetLeafShapes(const Shape& shape); + // Calls the given visitor function for each subshape of the given shape. // Subshapes are visited in DFS pre-order starting with the entire shape // (index {}). @@ -461,26 +513,6 @@ class ShapeUtil { static Status ForEachMutableSubshapeWithStatus( Shape* shape, const MutatingStatusVisitorFunction& func); - // Removes all degenerate dimensions (size one) from the given shape. The - // stripped minor_to_major preserves the relative ordering of non-degenerate - // dimensions. The stripped shape has the property that the underlying - // representation (bits in memory) for the stripped shape is the same as the - // original shape modulo padding. Examples: - // - // input shape: F32 [1, 2, 1], minor_to_major = {0, 1, 2} - // stripped shape: F32 [2], minor_to_major = {0} - // - // input shape: F32 [6, 1, 5], minor_to_major = {2, 0, 1} - // stripped shape: F32 [6, 5], minor_to_major = {1, 0} - // - // input shape: F32 [1, 7, 1, 6, 5, 1], minor_to_major = {0, 2, 5, 4, 3, 1} - // stripped shape: F32 [7, 6, 5], minor_to_major = {0, 2, 1} - // - // input shape: F32 [1, 1], minor_to_major = {0, 1} - // stripped shape: F32 [], minor_to_major = {} - // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape) - static Shape StripDegenerateDimensions(const Shape& shape); - // Permutes the dimensions by the given permutation, so // return_value.dimensions[permutation[i]] = argument.dimensions[i] static Shape PermuteDimensions(tensorflow::gtl::ArraySlice permutation, @@ -522,12 +554,16 @@ class ShapeUtil { // Returns whether a transpose from input_shape to output_shape with dimension // mapping "dimension_mapping" produces a result which is bit-wise identical // to its input and thus may be replaced with a bitcast. + // + // Precondition: Both input_shape and output_shape have explicit layouts. static bool TransposeIsBitcast( const Shape& input_shape, const Shape& output_shape, tensorflow::gtl::ArraySlice dimension_mapping); // Returns whether a reshape from "input_shape" to "output_shape" is a // bitcast. + // + // Precondition: Both input_shape and output_shape have explicit layouts. static bool ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape); @@ -560,16 +596,109 @@ class ShapeUtil { // The visitor_function visitor function should return true if it wants to // continue, or false otherwise. // - // visitor_function must be a callable of type bool(const std::vector&) - // or compatible. + // visitor_function must be a callable of type + // StatusOr(ArraySlice) or compatible. + template + static Status ForEachIndexWithStatus(const Shape& shape, + tensorflow::gtl::ArraySlice base, + tensorflow::gtl::ArraySlice count, + tensorflow::gtl::ArraySlice incr, + const FnType& visitor_function) { + return ForEachIndexInternal(shape, base, count, incr, visitor_function); + } + + // Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus. + struct IndexIterationSpace { + std::vector index_base; + std::vector index_count; + std::vector index_incr; + }; + + template + static Status ForEachIndexWithStatus( + const Shape& shape, const IndexIterationSpace& iteration_space, + FnTy&& function) { + return ShapeUtil::ForEachIndexWithStatus( + shape, iteration_space.index_base, iteration_space.index_count, + iteration_space.index_incr, std::forward(function)); + } + template static void ForEachIndex(const Shape& shape, tensorflow::gtl::ArraySlice base, tensorflow::gtl::ArraySlice count, tensorflow::gtl::ArraySlice incr, const FnType& visitor_function) { + ForEachIndexWithStatus(shape, base, count, incr, + [&](tensorflow::gtl::ArraySlice indices) { + return StatusOr(visitor_function(indices)); + }) + .IgnoreError(); + } + + // These convenience wrappers don't take `base`, `count` and `incr` + // explicitly, but iterate over every element in `shape` instead. + + template + static Status ForEachIndexWithStatus(const Shape& shape, + const FnType& visitor_function) { + std::vector base(shape.dimensions_size()); + std::vector incr(shape.dimensions_size(), 1); + return ForEachIndexWithStatus(shape, base, + /*count=*/AsInt64Slice(shape.dimensions()), + incr, visitor_function); + } + + template + static void ForEachIndex(const Shape& shape, const FnType& visitor_function) { + ForEachIndexWithStatus(shape, + [&](tensorflow::gtl::ArraySlice indices) { + return StatusOr(visitor_function(indices)); + }) + .IgnoreError(); + } + + // A parallel version of ForEachIndex(WithStatus). This can only be used if + // the visitor_function is thread-safe and the order of iteration does not + // matter. + // + // visitor_function must be a callable of type + // void(ArraySlice) or compatible. + template + static void ForEachIndexParallel(const Shape& shape, + tensorflow::gtl::ArraySlice base, + tensorflow::gtl::ArraySlice count, + tensorflow::gtl::ArraySlice incr, + const FnType& visitor_function) { + // The parallel version of ForEachIndexInternal can never fail. + CHECK(ForEachIndexInternal( + shape, base, count, incr, + [&visitor_function](tensorflow::gtl::ArraySlice indexes) + -> StatusOr { + visitor_function(indexes); + return true; + }, + /*parallel=*/true) + .ok()); + } + + // Compute a hash for `shape`. + static size_t Hash(const Shape& shape); + + private: + // Validates all of the non-layout properties of the shape -- this is a helper + // used by both the layout-optional and layout-required public method. + static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape); + + template + static Status ForEachIndexInternal(const Shape& shape, + tensorflow::gtl::ArraySlice base, + tensorflow::gtl::ArraySlice count, + tensorflow::gtl::ArraySlice incr, + const FnType& visitor_function, + bool parallel = false) { if (ShapeUtil::HasZeroElements(shape)) { - return; + return Status::OK(); } CHECK_EQ(Rank(shape), base.size()); CHECK_EQ(incr.size(), base.size()); @@ -579,7 +708,22 @@ class ShapeUtil { // once with the proper empty indexes. int64 n = -1; std::vector indexes(base.begin(), base.end()); - while (n < rank && visitor_function(indexes)) { + const int kNumThreads = tensorflow::port::NumSchedulableCPUs(); + tensorflow::gtl::optional pool; + if (parallel) { + pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads); + } + + while (n < rank) { + if (pool != tensorflow::gtl::nullopt) { + pool->Schedule( + [indexes, &visitor_function] { visitor_function(indexes); }); + } else { + TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes)); + if (!should_continue) { + break; + } + } // Increments dimensions in minor to major order. for (n = 0; n < rank; ++n) { int64 dim = LayoutUtil::Minor(shape.layout(), n); @@ -590,12 +734,9 @@ class ShapeUtil { indexes[dim] = base[dim]; } } - } - private: - // Validates all of the non-layout properties of the shape -- this is a helper - // used by both the layout-optional and layout-required public method. - static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape); + return Status::OK(); + } TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil); }; diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 4db97d45b20b86dc60531845c6e28a223203ff7f..0ff514564bdb27b7afa4cf99b0d727f2c029a5ae 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -93,12 +93,14 @@ TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { } TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { - string shape_string = "(f32[1],(f32[2]), f32[3])"; + string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString(shape_string)); Shape expected = ShapeUtil::MakeTupleShape({ ShapeUtil::MakeShape(F32, {1}), - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeShape(F32, {3}), }); ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) @@ -136,6 +138,23 @@ TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST(ShapeUtilTest, ParseOpaqueType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString("opaque[]")); + Shape expected = ShapeUtil::MakeOpaqueShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseTokenType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]")); + Shape expected = ShapeUtil::MakeTokenShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + TEST(ShapeUtilTest, ParseInvalidShapeString) { string shape_strings[] = { "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", @@ -238,6 +257,18 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) { EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); } +TEST(ShapeUtilTest, IncompatibleScalarVsTuple) { + Shape shape1 = ShapeUtil::MakeShape(F32, {}); + Shape shape2 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(U32, {})}); + EXPECT_FALSE(ShapeUtil::Compatible(shape1, shape2)); + EXPECT_FALSE(ShapeUtil::Compatible(shape2, shape1)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(shape1, shape2)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(shape2, shape1)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape2, shape1)); +} + TEST(ShapeUtilTest, CompareShapesWithPaddedDimensionsMismatch) { Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30}); shape1.mutable_layout()->add_padded_dimensions(10); @@ -283,6 +314,9 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64)); EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {}))); EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20}))); + + EXPECT_EQ(0, ShapeUtil::ByteSizeOfPrimitiveType(TOKEN)); + EXPECT_EQ(0, ShapeUtil::ByteSizeOf(ShapeUtil::MakeTokenShape())); } TEST(ShapeUtilTest, ByteSizeOfWithPadding) { @@ -437,19 +471,21 @@ TEST(ShapeUtilTest, IsLeafIndex) { TEST(ShapeUtilTest, HumanString) { Shape opaque = ShapeUtil::MakeOpaqueShape(); + Shape token = ShapeUtil::MakeTokenShape(); Shape scalar = ShapeUtil::MakeShape(F32, {}); Shape matrix = ShapeUtil::MakeShape(U32, {1, 2}); Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2}); - Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix}); + Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token}); EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque)); + EXPECT_EQ("token[]", ShapeUtil::HumanString(token)); EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar)); EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix)); EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2)); EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", ShapeUtil::HumanString(tuple)); - EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(nested_tuple)); EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque)); @@ -458,8 +494,10 @@ TEST(ShapeUtilTest, HumanString) { EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2)); EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})", ShapeUtil::HumanStringWithLayout(tuple)); - EXPECT_EQ("((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0})", - ShapeUtil::HumanStringWithLayout(nested_tuple)); + EXPECT_EQ( + "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, " + "token[])", + ShapeUtil::HumanStringWithLayout(nested_tuple)); ProgramShape prog = ShapeUtil::MakeProgramShape( {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple); @@ -469,8 +507,9 @@ TEST(ShapeUtilTest, HumanString) { "(unknown): u32[1,2], " "(unknown): s32[3,4], " "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " - "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(prog)); prog.add_parameter_names("arg0"); @@ -485,8 +524,10 @@ TEST(ShapeUtilTest, HumanString) { "matrix: u32[1,2], " "matrix2: s32[3,4], " "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), " - "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], " + "token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(prog)); } @@ -573,10 +614,11 @@ TEST(ShapeUtilTest, ForEachIndex) { Shape shape = ShapeUtil::MakeShape(F32, data.dimensions); // Increments at every invocation. int invocations = 0; - auto increment_func = [&invocations](const std::vector& indexes) { - invocations++; - return true; - }; + auto increment_func = + [&invocations](tensorflow::gtl::ArraySlice indexes) { + invocations++; + return true; + }; std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); @@ -588,6 +630,47 @@ TEST(ShapeUtilTest, ForEachIndex) { } } +TEST(ShapeUtilTest, ForEachIndexWithStatus) { + Shape shape = ShapeUtil::MakeShape(F32, {10, 10}); + // Increments at every invocation. + int invocations = 0; + auto increment_func = + [&invocations]( + tensorflow::gtl::ArraySlice indexes) -> StatusOr { + if (++invocations == 5) { + return Unimplemented("Cannot increment beyond 5."); + } + return true; + }; + + Status error_status = ShapeUtil::ForEachIndexWithStatus( + shape, /*base=*/{0, 0}, /*count=*/{10, 10}, /*incr=*/{0, 1}, + increment_func); + + EXPECT_FALSE(error_status.ok()); + EXPECT_THAT(error_status.error_message(), + ::testing::HasSubstr("Cannot increment beyond 5.")); + EXPECT_EQ(invocations, 5); +} + +TEST(ShapeUtilTest, ForEachIndexParallel) { + Shape shape = ShapeUtil::MakeShape(F32, {10, 10}); + int64 output[10][10]; + int init = 5; + auto set_func = [&](tensorflow::gtl::ArraySlice indexes) { + output[indexes[0]][indexes[1]] = init + indexes[0] + indexes[1]; + }; + + ShapeUtil::ForEachIndexParallel(shape, /*base=*/{0, 0}, /*count=*/{10, 10}, + /*incr=*/{1, 1}, set_func); + + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 10; ++j) { + EXPECT_EQ(output[i][j], init + i + j); + } + } +} + TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) { // All output dimensions should be unmodified. One of the input dimensions is // modified because the input rank is larger by one. diff --git a/tensorflow/compiler/xla/status.h b/tensorflow/compiler/xla/status.h index 4eb3bf3766412d5d9a8e78a4652807c5eaeef6ee..69abb51852ac09e8d357a9ba7924efc348ef2001 100644 --- a/tensorflow/compiler/xla/status.h +++ b/tensorflow/compiler/xla/status.h @@ -21,7 +21,7 @@ limitations under the License. namespace xla { -using tensorflow::Status; +using tensorflow::Status; // TENSORFLOW_STATUS_OK } // namespace xla diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h index 641b5e9a6accc0a2e7737f79bcd485d317e4e521..0e1387c93938fa520562fcd63ac107a82b089a51 100644 --- a/tensorflow/compiler/xla/statusor.h +++ b/tensorflow/compiler/xla/statusor.h @@ -13,13 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// StatusOr is the union of a Status object and a T -// object. StatusOr models the concept of an object that is either a -// usable value, or an error Status explaining why such a value is -// not present. To this end, StatusOr does not allow its Status -// value to be Status::OK. Furthermore, the value of a StatusOr -// must not be null. This is enforced by a debug check in most cases, -// but even when it is not, clients must not set the value to null. +// StatusOr is the union of a Status object and a T object. StatusOr models +// the concept of an object that is either a value, or an error Status +// explaining why such a value is not present. To this end, StatusOr does not +// allow its Status value to be Status::OK. // // The primary use-case for StatusOr is as the return value of a // function which may fail. @@ -113,17 +110,19 @@ class StatusOr : private internal_statusor::StatusOrData, StatusOr& operator=(StatusOr&&) = default; // Conversion copy/move constructor, T must be convertible from U. - // TODO(b/62186717): These should not participate in overload resolution if U - // is not convertible to T. - template + template ::value>::type* = nullptr> StatusOr(const StatusOr& other); - template + template ::value>::type* = nullptr> StatusOr(StatusOr&& other); // Conversion copy/move assignment operator, T must be convertible from U. - template + template ::value>::type* = nullptr> StatusOr& operator=(const StatusOr& other); - template + template ::value>::type* = nullptr> StatusOr& operator=(StatusOr&& other); // Constructs a new StatusOr with the given value. After calling this @@ -233,12 +232,14 @@ StatusOr& StatusOr::operator=(Status&& status) { } template -template +template ::value>::type*> inline StatusOr::StatusOr(const StatusOr& other) : Base(static_cast::Base&>(other)) {} template -template +template ::value>::type*> inline StatusOr& StatusOr::operator=(const StatusOr& other) { if (other.ok()) this->Assign(other.ValueOrDie()); @@ -248,12 +249,14 @@ inline StatusOr& StatusOr::operator=(const StatusOr& other) { } template -template +template ::value>::type*> inline StatusOr::StatusOr(StatusOr&& other) : Base(static_cast::Base&&>(other)) {} template -template +template ::value>::type*> inline StatusOr& StatusOr::operator=(StatusOr&& other) { if (other.ok()) { this->Assign(std::move(other).ValueOrDie()); diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index f9d25945bc617507735fb6c4d011c39723497f69..377a618ffbd99316d409130df8a39f352664dee0 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -75,6 +75,14 @@ TEST(StatusOr, ElementType) { static_assert(std::is_same::element_type, char>(), ""); } +TEST(StatusOr, NullPointerStatusOr) { + // As a very special case, null-plain-pointer StatusOr used to be an + // error. Test that it no longer is. + StatusOr null_status(nullptr); + EXPECT_TRUE(null_status.ok()); + EXPECT_EQ(null_status.ValueOrDie(), nullptr); +} + TEST(StatusOr, TestNoDefaultConstructorInitialization) { // Explicitly initialize it with an error code. StatusOr statusor(tensorflow::errors::Cancelled("")); @@ -405,7 +413,7 @@ TEST(StatusOr, TestPointerValueConst) { EXPECT_EQ(&kI, thing.ValueOrDie()); } -// NOTE(tucker): tensorflow::StatusOr does not support this kind +// NOTE(tucker): StatusOr does not support this kind // of resize op. // TEST(StatusOr, StatusOrVectorOfUniquePointerCanResize) { // using EvilType = std::vector>; diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h index 17bae2e4f611268df824ce793c75ba1c95573455..8918350135fbb86973b228b35f5873fea8695b2f 100644 --- a/tensorflow/compiler/xla/test_helpers.h +++ b/tensorflow/compiler/xla/test_helpers.h @@ -40,13 +40,10 @@ class Literal; namespace testing { namespace internal_status { -inline const ::tensorflow::Status& GetStatus( - const ::tensorflow::Status& status) { - return status; -} +inline const Status& GetStatus(const Status& status) { return status; } template -inline const ::tensorflow::Status& GetStatus(const StatusOr& status) { +inline const Status& GetStatus(const StatusOr& status) { return status.status(); } } // namespace internal_status @@ -57,21 +54,17 @@ inline const ::tensorflow::Status& GetStatus(const StatusOr& status) { // The following macros are similar to macros in gmock, but deliberately named // differently in order to avoid conflicts in files which include both. -// Macros for testing the results of functions that return tensorflow::Status or +// Macros for testing the results of functions that return Status or // StatusOr (for any type T). -#define EXPECT_IS_OK(expression) \ - EXPECT_EQ(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) -#define EXPECT_IS_NOT_OK(expression) \ - EXPECT_NE(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define EXPECT_IS_OK(expression) \ + EXPECT_EQ(Status::OK(), xla::testing::internal_status::GetStatus(expression)) +#define EXPECT_IS_NOT_OK(expression) \ + EXPECT_NE(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #undef ASSERT_IS_OK -#define ASSERT_IS_OK(expression) \ - ASSERT_EQ(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define ASSERT_IS_OK(expression) \ + ASSERT_EQ(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #undef ASSERT_IS_NOT_OK -#define ASSERT_IS_NOT_OK(expression) \ - ASSERT_NE(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define ASSERT_IS_NOT_OK(expression) \ + ASSERT_NE(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #endif // TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_ diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 8339d08ef4d7455f9739b80074ab0405a404e8e8..e7e0a19db0516e4210f6bb78d6b5e6968bf78b2a 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -44,6 +44,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:test", ], + alwayslink = True, ) cc_library( @@ -86,12 +87,12 @@ cc_library( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:error_spec", + "//tensorflow/compiler/xla:literal_comparison", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -116,11 +117,11 @@ cc_library( "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -137,6 +138,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -150,7 +152,8 @@ tf_cc_binary( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", @@ -184,10 +187,12 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend + "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -252,8 +257,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", @@ -271,15 +276,18 @@ cc_library( xla_test( name = "bad_rng_shape_validation_test", srcs = ["bad_rng_shape_validation_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -290,6 +298,9 @@ xla_test( xla_test( name = "check_execution_arity_test", srcs = ["check_execution_arity_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -297,9 +308,10 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -309,13 +321,17 @@ xla_test( xla_test( name = "query_inferred_shape_test", srcs = ["query_inferred_shape_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -325,6 +341,9 @@ xla_test( xla_test( name = "while_test", srcs = ["while_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -332,10 +351,10 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -352,8 +371,9 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:test_utils", @@ -366,9 +386,13 @@ xla_test( xla_test( name = "axpy_simple_test", srcs = ["axpy_simple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -389,11 +413,11 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -415,10 +439,10 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -430,11 +454,15 @@ xla_test( xla_test( name = "pred_test", srcs = ["pred_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -444,11 +472,15 @@ xla_test( xla_test( name = "select_test", srcs = ["select_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -460,11 +492,13 @@ xla_test( xla_test( name = "conditional_test", srcs = ["conditional_test.cc"], + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -476,12 +510,14 @@ xla_test( xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -500,9 +536,10 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -518,10 +555,10 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -532,6 +569,9 @@ xla_test( xla_test( name = "deconstruct_tuple_test", srcs = ["deconstruct_tuple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -539,10 +579,10 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -565,10 +605,12 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -577,6 +619,7 @@ xla_test( xla_test( name = "exhaustive_f32_elementwise_op_test", + size = "enormous", srcs = ["exhaustive_f32_elementwise_op_test.cc"], backends = [ "cpu", @@ -584,13 +627,13 @@ xla_test( ], shard_count = 48, tags = [ - "enormous", "manual", "notap", ], deps = [ ":client_library_test_base", ":literal_test_util", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -608,9 +651,9 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -622,16 +665,19 @@ xla_test( xla_test( name = "dot_operation_test", srcs = ["dot_operation_test.cc"], + shard_count = 20, tags = [ "enable_for_xla_interpreter", + "optonly", ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -642,32 +688,21 @@ xla_test( ], ) -# Tests the dot operation in some cases that can be performed via a -# runtime call on some backends - e.g. a runtime call to Eigen. xla_test( - name = "dot_operation_runtime_test", - srcs = ["dot_operation_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], + name = "gather_operation_test", + srcs = ["gather_operation_test.cc"], deps = [ - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:array3d", - "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:test_utils", + ":client_library_test_base", + ":hlo_test_base", + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:test", ], ) -# Repeat dot_operation_runtime_test with single-threded eigen. +# Repeat dot_operation_runtime_test with single-threaded eigen. xla_test( name = "dot_operation_single_threaded_runtime_test", srcs = ["dot_operation_test.cc"], @@ -675,17 +710,17 @@ xla_test( "cpu": [ "--xla_cpu_multi_thread_eigen=false", ], - "cpu_parallel": [ - "--xla_cpu_multi_thread_eigen=false", - ], }, + shard_count = 20, + tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -699,13 +734,17 @@ xla_test( xla_test( name = "transpose_test", srcs = ["transpose_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -717,14 +756,18 @@ xla_test( xla_test( name = "constants_test", srcs = ["constants_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -733,30 +776,42 @@ xla_test( ], ) +CONVOLUTION_TEST_DEPS = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", +] + xla_test( name = "convolution_test", timeout = "long", srcs = ["convolution_test.cc"], shard_count = 25, - deps = [ - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], + deps = CONVOLUTION_TEST_DEPS, +) + +xla_test( + name = "convolution_test_gpu_alternative_layout", + timeout = "long", + srcs = ["convolution_test.cc"], + backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]}, + backends = ["gpu"], + shard_count = 25, + deps = CONVOLUTION_TEST_DEPS, ) xla_test( @@ -766,7 +821,6 @@ xla_test( backend_tags = { # TODO(b/31436974): Fix msan failure. Failed on 2016-09-12. "cpu": ["nomsan"], - "cpu_parallel": ["nomsan"], }, shard_count = 30, deps = [ @@ -775,9 +829,10 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -798,9 +853,10 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -826,11 +882,11 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -858,11 +914,11 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -877,8 +933,7 @@ xla_test( name = "half_test", srcs = ["half_test.cc"], backends = [ - # TODO(b/72509305): Flaky (fails with SEGV) as of 2018-01-25 - # "cpu", + "cpu", "gpu", ], deps = [ @@ -887,8 +942,8 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -902,11 +957,14 @@ xla_test( name = "slice_test", srcs = ["slice_test.cc"], shard_count = 40, + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -918,11 +976,15 @@ xla_test( xla_test( name = "multidimensional_slice_test", srcs = ["multidimensional_slice_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -934,14 +996,16 @@ xla_test( name = "dynamic_ops_test", timeout = "moderate", srcs = ["dynamic_ops_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", @@ -960,6 +1024,9 @@ xla_test( xla_test( name = "tuple_test", srcs = ["tuple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -967,10 +1034,11 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -980,13 +1048,17 @@ xla_test( xla_test( name = "vector_ops_reduce_test", srcs = ["vector_ops_reduce_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -998,6 +1070,10 @@ xla_test( name = "reduce_test", srcs = ["reduce_test.cc"], shard_count = 40, + tags = [ + "enable_for_xla_interpreter", + "optonly", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1008,11 +1084,11 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1035,10 +1111,11 @@ xla_test_library( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1075,11 +1152,11 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1091,12 +1168,16 @@ xla_test( xla_test( name = "copy_test", srcs = ["copy_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1109,11 +1190,30 @@ xla_test( xla_test( name = "reduce_hlo_test", srcs = ["reduce_hlo_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ ":client_library_test_base", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "token_hlo_test", + srcs = ["token_hlo_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], + deps = [ + ":client_library_test_base", + "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -1122,13 +1222,16 @@ xla_test( xla_test( name = "call_test", srcs = ["call_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1159,12 +1262,16 @@ xla_test( xla_test( name = "binop_scaling_test", srcs = ["binop_scaling_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1175,14 +1282,18 @@ xla_test( xla_test( name = "broadcast_simple_test", srcs = ["broadcast_simple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1192,15 +1303,18 @@ xla_test( xla_test( name = "pad_test", srcs = ["pad_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1212,9 +1326,13 @@ xla_test( xla_test( name = "fmax_test", srcs = ["fmax_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1225,9 +1343,13 @@ xla_test( xla_test( name = "log_test", srcs = ["log_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1238,6 +1360,9 @@ xla_test( xla_test( name = "matrix_ops_simple_test", srcs = ["matrix_ops_simple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -1247,11 +1372,12 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1267,8 +1393,9 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1280,6 +1407,9 @@ xla_test( name = "reshape_test", srcs = ["reshape_test.cc"], shard_count = 30, + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1291,10 +1421,10 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1306,11 +1436,14 @@ xla_test( xla_test( name = "reverse_test", srcs = ["reverse_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1322,17 +1455,20 @@ xla_test( xla_test( name = "vector_ops_simple_test", srcs = ["vector_ops_simple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1345,6 +1481,9 @@ xla_test( xla_test( name = "concat_test", srcs = ["concat_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -1352,9 +1491,9 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1365,11 +1504,16 @@ xla_test( xla_test( name = "convert_test", srcs = ["convert_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1379,14 +1523,41 @@ xla_test( ], ) +xla_test( + name = "cross_replica_sum_test", + srcs = ["cross_replica_sum_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", + ], +) + xla_test( name = "bitcast_convert_test", srcs = ["bitcast_convert_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1405,10 +1576,10 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1421,9 +1592,13 @@ xla_test( xla_test( name = "floor_ceil_test", srcs = ["floor_ceil_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1444,9 +1619,9 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1464,9 +1639,10 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1481,8 +1657,9 @@ xla_test( srcs = ["execution_profile_test.cc"], deps = [ ":client_library_test_base", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1494,8 +1671,9 @@ xla_test( args = ["--xla_hlo_profile"], deps = [ ":client_library_test_base", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1504,17 +1682,20 @@ xla_test( xla_test( name = "replay_test", srcs = ["replay_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1526,6 +1707,9 @@ xla_test( xla_test( name = "broadcast_test", srcs = ["broadcast_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -1582,6 +1766,9 @@ xla_test( xla_test( name = "fusion_test", srcs = ["fusion_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -1589,9 +1776,9 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", @@ -1616,9 +1803,9 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", @@ -1653,9 +1840,9 @@ xla_test( deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1668,7 +1855,10 @@ xla_test( xla_test( name = "local_client_execute_test", + # TODO(b/79375911): Test times out in LLVM at normal size. + size = "large", srcs = ["local_client_execute_test.cc"], + shard_count = 30, tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:literal_util", @@ -1678,9 +1868,9 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", @@ -1704,9 +1894,8 @@ tf_cc_test( deps = [ ":local_client_test_base", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/service:computation_tracker", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/core:test_main", @@ -1733,22 +1922,6 @@ xla_test( ], ) -xla_test( - name = "set_return_value_test", - srcs = ["set_return_value_test.cc"], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - xla_test( name = "reshape_motion_test", srcs = ["reshape_motion_test.cc"], @@ -1762,10 +1935,10 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1778,6 +1951,7 @@ xla_test( name = "deep_graph_test", srcs = ["deep_graph_test.cc"], deps = [ + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -1854,16 +2028,16 @@ tf_cc_test( ], ) -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], +xla_test( + name = "test_utils_test", + srcs = ["test_utils_test.cc"], + deps = [ + ":local_client_test_base", + ":test_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], ) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 7e9005001db34d403ea923eb9c152d114bf32803..36a706496918ac8c15780473019e2a8d098ffa22 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -22,9 +22,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -50,28 +50,28 @@ class ArrayElementwiseOpTestParamCount public ::testing::WithParamInterface {}; XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto result = builder.Neg(a); + builder.Neg(a); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); - auto result = builder.Neg(a); + builder.Neg(a); ComputeAndCompareR1(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-1, 0, 1, 324, std::numeric_limits::min(), std::numeric_limits::max()}); - auto result = builder.Neg(a); + builder.Neg(a); // -min == min for int32 due to an overflow. In C++ it is undefined behavior // to do this calculation. For XLA we have not specified that, so it @@ -83,28 +83,55 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto result = builder.Neg(a); + builder.Neg(a); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); - auto result = builder.Neg(a); + builder.Neg(a); ComputeAndCompareR1( &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}}, {}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { + XlaBuilder builder(TestName()); + auto a = builder.ConstantR1({ + -1, + 1, + 0, + 0x12345678, + static_cast(0xffffffff12345678l), + static_cast(0x8000000000000000LL), + static_cast(0x8000000000000001LL), + }); + builder.Neg(a); + LOG(INFO) << -static_cast(0x7FFFFFFFFFFFFFFFLL); + + ComputeAndCompareR1(&builder, + { + 1, + -1, + 0, + -0x12345678, + 0xedcba988, + static_cast(0x8000000000000000LL), + -static_cast(0x8000000000000001LL), + }, + {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto result = builder.IsFinite(a); + builder.IsFinite(a); ComputeAndCompareR1(&builder, {}, {}); } @@ -113,64 +140,63 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { static const float kNonCanonicalNaN = tensorflow::bit_cast(0x7FD01234); XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) { - ComputationBuilder builder(client_, TestName()); - auto result = builder.IsFinite(builder.ConstantR0(NAN)); + XlaBuilder builder(TestName()); + builder.IsFinite(builder.ConstantR0(NAN)); ComputeAndCompareR0(&builder, false, {}); EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); - auto result_non_canonical = - builder.IsFinite(builder.ConstantR0(kNonCanonicalNaN)); + builder.IsFinite(builder.ConstantR0(kNonCanonicalNaN)); ComputeAndCompareR0(&builder, false, {}); const float inf = std::numeric_limits::infinity(); - auto result_inf = builder.IsFinite(builder.ConstantR0(inf)); + builder.IsFinite(builder.ConstantR0(inf)); ComputeAndCompareR0(&builder, false, {}); - auto result_neg_inf = builder.IsFinite(builder.ConstantR0(-inf)); + builder.IsFinite(builder.ConstantR0(-inf)); ComputeAndCompareR0(&builder, false, {}); - auto result_zero = builder.IsFinite(builder.ConstantR0(0.0f)); + builder.IsFinite(builder.ConstantR0(0.0f)); ComputeAndCompareR0(&builder, true, {}); } XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const float inf = std::numeric_limits::infinity(); EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); auto a = builder.ConstantR1( {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}}); - auto result = builder.IsFinite(a); + builder.IsFinite(a); ComputeAndCompareR1(&builder, {false, true, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); auto b = builder.ConstantR1({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR1(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); auto b = builder.ConstantR1( {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR1( &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {}, @@ -178,17 +204,97 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { + XlaBuilder b(TestName()); + + std::vector lhs{0xFFFFFFFF, + static_cast(-1), + 0, + 0, + 0x7FFFFFFFFFFFFFFFLL, + 0x7FFFFFFFFFFFFFFLL, + 0x8000000000000000LL, + 0x8000000000000000LL, + 1}; + std::unique_ptr lhs_literal = Literal::CreateR1({lhs}); + auto lhs_param = b.Parameter(0, lhs_literal->shape(), "lhs_param"); + std::unique_ptr lhs_data = + client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); + + std::vector rhs{1, + 0x7FFFFFFFFFFFFFFLL, + 0x7FFFFFFFFFFFFFFFLL, + 0x8000000000000000LL, + 0, + static_cast(-1), + 0, + 1, + 0x8000000000000000LL}; + std::unique_ptr rhs_literal = Literal::CreateR1({rhs}); + auto rhs_param = b.Parameter(1, rhs_literal->shape(), "rhs_param"); + std::unique_ptr rhs_data = + client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); + + b.Add(lhs_param, rhs_param); + + std::vector expected(lhs.size()); + for (int64 i = 0; i < lhs.size(); ++i) { + expected[i] = lhs[i] + rhs[i]; + } + + ComputeAndCompareR1(&b, expected, {lhs_data.get(), rhs_data.get()}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { + XlaBuilder b(TestName()); + + std::vector lhs{static_cast(0x8000000000000000LL), + static_cast(0x8000000000000000LL), + -1, + 0x7FFFFFFFFFFFFFFLL, + 0x7FFFFFFFFFFFFFFFLL, + 1, + 0, + -1}; + std::unique_ptr lhs_literal = Literal::CreateR1({lhs}); + auto lhs_param = b.Parameter(0, lhs_literal->shape(), "lhs_param"); + std::unique_ptr lhs_data = + client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); + + std::vector rhs{-1, + 0, + static_cast(0x8000000000000000LL), + 1, + 0, + 0x7FFFFFFFFFFFFFFLL, + 0x7FFFFFFFFFFFFFFFLL, + 0x7FFFFFFFFFFFFFFFLL}; + std::unique_ptr rhs_literal = Literal::CreateR1({rhs}); + auto rhs_param = b.Parameter(1, rhs_literal->shape(), "rhs_param"); + std::unique_ptr rhs_data = + client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); + + auto sub = b.Sub(lhs_param, rhs_param); + + std::vector expected(lhs.size()); + for (int64 i = 0; i < lhs.size(); ++i) { + expected[i] = lhs[i] - rhs[i]; + } + + ComputeAndCompareR1(&b, expected, {lhs_data.get(), rhs_data.get()}); +} + TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { const int count = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector a_values; std::vector b_values; for (int i = 0; i < count; ++i) { @@ -227,49 +333,49 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); auto b = builder.ConstantR1({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-1, 0, 2, 1000000000}); auto b = builder.ConstantR1({-1, 2, 1, -1}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {0, -2, 1, 1000000001}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}}); auto b = builder.ConstantR1( {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1( &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {}, @@ -277,29 +383,29 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto b = builder.ConstantR1({10.0f, 5.1f, 1.0f, 10.0f, -6.0f}); - auto add = builder.Div(a, b); + builder.Div(a, b); ComputeAndCompareR1(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Div(a, b); + builder.Div(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -329,9 +435,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = @@ -344,8 +450,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { // Test with a compile-time constant divisor. { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; + XlaBuilder builder(TestName()); + XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); builder.Div(dividend, builder.ConstantR1(divisors)); @@ -354,9 +460,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = @@ -369,8 +475,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { // Test with a compile-time constant divisor. { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; + XlaBuilder builder(TestName()); + XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); builder.Rem(dividend, builder.ConstantR1(divisors)); @@ -400,9 +506,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = @@ -414,8 +520,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; + XlaBuilder builder(TestName()); + XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); builder.Div(dividend, builder.ConstantR1(divisors)); @@ -424,9 +530,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = @@ -438,8 +544,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; + XlaBuilder builder(TestName()); + XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); builder.Rem(dividend, builder.ConstantR1(divisors)); @@ -449,33 +555,33 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); auto b = builder.ConstantR1( {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); - auto div = builder.Div(a, b); + builder.Div(a, b); ComputeAndCompareR1( &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto div = builder.Div(a, b); + builder.Div(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f}); auto b = builder.ConstantR1( {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f}); - auto add = builder.Rem(a, b); + builder.Rem(a, b); ComputeAndCompareR1( &builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 1.0f, -1.0f, -0.0f}, {}, @@ -483,21 +589,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Rem(a, b); + builder.Rem(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0}); auto b = builder.ConstantR1( {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0}); - auto add = builder.Rem(a, b); + builder.Rem(a, b); ComputeAndCompareR1( &builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {}, @@ -505,20 +611,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto b = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -541,19 +647,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1(a_data); auto b = builder.ConstantR1(b_data); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, {}, {}); } @@ -572,21 +678,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1(a_data); auto b = builder.ConstantR1(b_data); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); auto b = builder.ConstantR1( {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1( &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {}, @@ -594,378 +700,386 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({false, false, true, true}); auto b = builder.ConstantR1({false, true, false, true}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {false, false, false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{false, false}, {true, true}}); auto b = builder.ConstantR2({{false, true}, {false, true}}); - auto out = builder.And(a, b); + builder.And(a, b); Array2D expected_array({{false, false}, {false, true}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, -1, -8}); auto b = builder.ConstantR1({5, -7, 12}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {0, -7, 8}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, -5}, {-1, 5}}); auto b = builder.ConstantR2({{1, -6}, {4, 5}}); - auto out = builder.And(a, b); + builder.And(a, b); Array2D expected_array({{0, -6}, {4, 5}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, 1, 8}); auto b = builder.ConstantR1({5, 7, 12}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {0, 1, 8}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, 1}, {3, 8}}); auto b = builder.ConstantR2({{1, 0}, {7, 6}}); - auto out = builder.And(a, b); + builder.And(a, b); Array2D expected_array({{0, 0}, {3, 0}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({false, false, true, true}); auto b = builder.ConstantR1({false, true, false, true}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {false, true, true, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{false, false}, {true, true}}); auto b = builder.ConstantR2({{false, true}, {false, true}}); - auto out = builder.Or(a, b); + builder.Or(a, b); Array2D expected_array({{false, true}, {true, true}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, -1, 8}); auto b = builder.ConstantR1({5, -7, 4}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {5, -1, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, -1}, {8, 8}}); auto b = builder.ConstantR2({{5, -7}, {4, 1}}); - auto out = builder.Or(a, b); + builder.Or(a, b); Array2D expected_array({{5, -1}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, 1, 8}); auto b = builder.ConstantR1({5, 7, 4}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {5, 7, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, 1}, {8, 8}}); auto b = builder.ConstantR2({{5, 7}, {4, 1}}); - auto out = builder.Or(a, b); + builder.Or(a, b); Array2D expected_array({{5, 7}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({false, true, true, false}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {true, false, false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{false, true}, {true, false}}); - auto out = builder.Not(a); + builder.Not(a); Array2D expected_array({{true, false}, {false, true}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-1, 0, 1}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {0, -1, -2}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-1, 0}, {1, 8}}); - auto out = builder.Not(a); + builder.Not(a); Array2D expected_array({{0, -1}, {-2, -9}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, 4294967295}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {4294967295, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, 4294967295}, {1, 4294967294}}); - auto out = builder.Not(a); + builder.Not(a); Array2D expected_array({{4294967295, 0}, {4294967294, 1}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { - ComputationBuilder builder(client_, TestName()); - auto a = - builder.ConstantR1({static_cast(0x12345678), - static_cast(0xF0001000), 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 15}); - auto out = builder.ShiftLeft(a, b); + XlaBuilder builder(TestName()); + auto a = builder.ConstantR1({static_cast(0x12345678), + static_cast(0xF0001000), 1, 3, 77, + 1, -3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 15, 32, 100, -1}); + builder.ShiftLeft(a, b); - ComputeAndCompareR1( - &builder, - {static_cast(0x23456780), 0x00100000, 0x4, 0x180, 2523136}, {}); + ComputeAndCompareR1(&builder, + {static_cast(0x23456780), 0x00100000, 0x4, + 0x180, 2523136, 0, 0, 0}, + {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { - ComputationBuilder builder(client_, TestName()); - auto a = - builder.ConstantR1({static_cast(0x92345678), - static_cast(0x10001000), 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 2}); - auto out = builder.ShiftRightArithmetic(a, b); + XlaBuilder builder(TestName()); + auto a = builder.ConstantR1({static_cast(0x92345678), + static_cast(0x10001000), 1, 3, 77, + 1, -3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 2, 32, 100, -1}); + builder.ShiftRightArithmetic(a, b); - ComputeAndCompareR1(&builder, - {static_cast(0xF9234567), - static_cast(0x00100010), 0, 0, 19}, - {}); + ComputeAndCompareR1( + &builder, + {static_cast(0xF9234567), static_cast(0x00100010), 0, 0, 19, + 0, -1, 0}, + {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) { - ComputationBuilder builder(client_, TestName()); - auto a = - builder.ConstantR1({static_cast(0x92345678), - static_cast(0x10001000), 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 5}); - auto out = builder.ShiftRightLogical(a, b); + XlaBuilder builder(TestName()); + auto a = builder.ConstantR1({static_cast(0x92345678), + static_cast(0x10001000), 1, 3, 77, + 1, -3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 5, 32, 100, -1}); + builder.ShiftRightLogical(a, b); - ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {}); + ComputeAndCompareR1(&builder, + {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) { - ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({0x12345678, 0xF0001000, 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 15}); - auto out = builder.ShiftLeft(a, b); + XlaBuilder builder(TestName()); + auto a = builder.ConstantR1( + {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 15, 32, 100, ~0u}); + builder.ShiftLeft(a, b); ComputeAndCompareR1( - &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136}, {}); + &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136, 0, 0, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) { - ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({0x92345678, 0x10001000, 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 2}); - auto out = builder.ShiftRightArithmetic(a, b); + XlaBuilder builder(TestName()); + auto a = builder.ConstantR1( + {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 2, 32, 100, ~0u}); + builder.ShiftRightArithmetic(a, b); - ComputeAndCompareR1(&builder, {0xF9234567, 0x00100010, 0, 0, 19}, {}); + ComputeAndCompareR1( + &builder, {0xF9234567, 0x00100010, 0, 0, 19, 0, ~0u, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { - ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({0x92345678, 0x10001000, 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 5}); - auto out = builder.ShiftRightLogical(a, b); + XlaBuilder builder(TestName()); + auto a = builder.ConstantR1( + {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 5, 32, 100, ~0u}); + builder.ShiftRightLogical(a, b); - ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {}); + ComputeAndCompareR1(&builder, + {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 2.25f, 10.0f, NAN}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - auto compare = builder.Ge(lhs, rhs); + builder.Ge(lhs, rhs); ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - auto compare = builder.Gt(lhs, rhs); + builder.Gt(lhs, rhs); ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 5.0f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - auto compare = builder.Le(lhs, rhs); + builder.Le(lhs, rhs); ComputeAndCompareR1(&builder, {true, true, false, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - auto compare = builder.Lt(lhs, rhs); + builder.Lt(lhs, rhs); ComputeAndCompareR1(&builder, {true, false, false, false, false}, {}); } @@ -973,10 +1087,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, false, true, false, false, false, true}, @@ -984,17 +1098,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, {1.0f, 25.5f}, {2.25f, -3.0f}, @@ -1005,16 +1119,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) { {2.25f, -3.0f}, {10.0f, 0.0f}, {1.0f, NAN}}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } @@ -1023,7 +1137,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) { // Disable fast-math because we're operating on NaNs. SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, {1.0f, 25.5f}, {2.25f, -3.0f}, @@ -1034,7 +1148,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) { {2.25f, -3.0f}, {10.0f, 0.0f}, {1.0f, NAN}}); - auto compare = builder.Ne(lhs, rhs); + builder.Ne(lhs, rhs); ComputeAndCompareR1(&builder, {true, true, false, true, true}, {}); } @@ -1043,10 +1157,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { // Disable fast-math because we're operating on NaNs. SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 25.5f, 1.0f, 10.0f, NAN}); - auto compare = builder.Ne(lhs, rhs); + builder.Ne(lhs, rhs); ComputeAndCompareR1(&builder, {true, false, true, true, true}, {}); } @@ -1054,10 +1168,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Ne(lhs, rhs); + builder.Ne(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, true, false, true, true, true, false}, {}); @@ -1066,10 +1180,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Ge(lhs, rhs); + builder.Ge(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, true, true, false, true, true, true}, {}); @@ -1078,10 +1192,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Gt(lhs, rhs); + builder.Gt(lhs, rhs); ComputeAndCompareR1( &builder, {false, false, false, true, false, false, true, true, false}, @@ -1091,10 +1205,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Le(lhs, rhs); + builder.Le(lhs, rhs); ComputeAndCompareR1( &builder, {true, true, true, false, true, true, false, false, true}, {}); @@ -1103,10 +1217,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Lt(lhs, rhs); + builder.Lt(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, false, false, true, false, false, false}, @@ -1115,10 +1229,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, false, true, false, false, false, true}, @@ -1127,10 +1241,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Ne(lhs, rhs); + builder.Ne(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, true, false, true, true, true, false}, {}); @@ -1138,10 +1252,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Ge(lhs, rhs); + builder.Ge(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, true, true, false, true, true, true}, {}); @@ -1149,10 +1263,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Gt(lhs, rhs); + builder.Gt(lhs, rhs); ComputeAndCompareR1( &builder, {false, false, false, true, false, false, true, true, false}, @@ -1161,10 +1275,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Le(lhs, rhs); + builder.Le(lhs, rhs); ComputeAndCompareR1( &builder, {true, true, true, false, true, true, false, false, true}, {}); @@ -1172,10 +1286,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Lt(lhs, rhs); + builder.Lt(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, false, false, true, false, false, false}, @@ -1184,12 +1298,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f}); auto rhs = builder.ConstantR1({2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f}); - auto minimum = builder.Pow(lhs, rhs); + builder.Pow(lhs, rhs); ComputeAndCompareR1( &builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_); @@ -1197,27 +1311,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.0f, -0.6f, -0.6f, 0.0f}); auto rhs = builder.ConstantR1({0.5f, 0.6f, -0.6f, -0.6f}); - auto minimum = builder.Pow(lhs, rhs); + builder.Pow(lhs, rhs); ComputeAndCompareR1(&builder, {NAN, NAN, NAN, INFINITY}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto minimum = builder.Pow(lhs, rhs); + builder.Pow(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } // Some Pow cases that can be implemented more efficiently. XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values = {1.0f, 2.0f, 3.2f, -4.0f}; std::vector exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; @@ -1245,7 +1359,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { } XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; @@ -1270,7 +1384,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { } XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; @@ -1295,7 +1409,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { } XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; @@ -1320,7 +1434,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { } XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; @@ -1345,7 +1459,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { } XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; @@ -1377,7 +1491,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { } XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; @@ -1410,7 +1524,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { } XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f}; @@ -1443,7 +1557,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { } XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; @@ -1484,14 +1598,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { const int count = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector values; values.reserve(count); for (int i = 0; i < count; ++i) { values.push_back(i / static_cast(count)); } auto x = builder.ConstantR1(values); - auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); + builder.Pow(x, builder.ConstantR0(2.0f)); std::vector expected; expected.reserve(values.size()); @@ -1503,7 +1617,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { } XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D values(2, 2, 2, 2); std::vector values_vector; @@ -1517,140 +1631,86 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) { Array4D expected(2, 2, 2, 2, expected_vector); auto x = builder.ConstantR4FromArray4D(values); - auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); + builder.Pow(x, builder.ConstantR0(2.0f)); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D values(2, 2, 0, 2); Array4D expected(2, 2, 0, 2); auto x = builder.ConstantR4FromArray4D(values); - auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); + builder.Pow(x, builder.ConstantR0(2.0f)); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } -// GPU backend emits nvvm intrinsic for fmin and fmax, whose semantics is NOT -// such -// * fmin(NaN, x) = x -// * fmax(NaN, x) = x -// so we only test NAN on CPU. -// -// TODO(b/28180546): Make this compile in a way that is consistent -// among backends. XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) { - ComputationBuilder builder(client_, TestName()); -#if !defined(XLA_TEST_BACKEND_CPU) - auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f}); - auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f}); -#else + XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f, 10.0f, NAN}); -#endif - auto minimum = builder.Min(lhs, rhs); - - ComputeAndCompareR1(&builder, -#if !defined(XLA_TEST_BACKEND_CPU) - {1.0f, -5.0f, 1.0f}, -#else - {1.0f, -5.0f, 1.0f, 10.0f, 6.0f}, -#endif - {}, error_spec_); + builder.Min(lhs, rhs); + + ComputeAndCompareR1(&builder, {1.0f, -5.0f, 1.0f, NAN, NAN}, {}, + error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto minimum = builder.Min(lhs, rhs); + builder.Min(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } -// TODO(b/28180546): Make this compile in a way that is consistent -// among backends. See comment on MinF32s test above. XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) { - ComputationBuilder builder(client_, TestName()); -#if !defined(XLA_TEST_BACKEND_CPU) - auto lhs = builder.ConstantR1({1.0, 1.0, 2.25}); - auto rhs = builder.ConstantR1({2.0, -5.0, 1.0}); -#else + XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = builder.ConstantR1({1.0, 1.0, 2.25, NAN, 6.0}); auto rhs = builder.ConstantR1({2.0, -5.0, 1.0, 10.0, NAN}); -#endif - auto minimum = builder.Min(lhs, rhs); + builder.Min(lhs, rhs); - ComputeAndCompareR1(&builder, -#if !defined(XLA_TEST_BACKEND_CPU) - {1.0, -5.0, 1.0}, -#else - {1.0, -5.0, 1.0, 10.0, 6.0}, -#endif - {}, error_spec_); + ComputeAndCompareR1(&builder, {1.0, -5.0, 1.0, NAN, NAN}, {}, + error_spec_); } -// TODO(b/28180546): Make this compile in a way that is consistent -// among backends. See comment on MinF32s test above. XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) { - ComputationBuilder builder(client_, TestName()); -#if !defined(XLA_TEST_BACKEND_CPU) - auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f}); - auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f}); -#else + XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f, 10.0f, NAN}); -#endif - auto maximum = builder.Max(lhs, rhs); - - ComputeAndCompareR1(&builder, -#if !defined(XLA_TEST_BACKEND_CPU) - {2.0f, 1.0f, 2.25f}, -#else - {2.0f, 1.0f, 2.25f, 10.0f, 6.0f}, -#endif - {}, error_spec_); + builder.Max(lhs, rhs); + + ComputeAndCompareR1(&builder, {2.0f, 1.0f, 2.25f, NAN, NAN}, {}, + error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto minimum = builder.Max(lhs, rhs); + builder.Max(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } -// TODO(b/28180546): Make this compile in a way that is consistent -// among backends. See comment on MinF32s test above. XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) { - ComputationBuilder builder(client_, TestName()); -#if !defined(XLA_TEST_BACKEND_CPU) - auto lhs = builder.ConstantR1({1.0, 1.0, 2.25}); - auto rhs = builder.ConstantR1({2.0, -5.0, 1.0}); -#else + XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = builder.ConstantR1({1.0, 1.0, 2.25, NAN, 6.0}); auto rhs = builder.ConstantR1({2.0, -5.0, 1.0, 10.0, NAN}); -#endif - auto maximum = builder.Max(lhs, rhs); + builder.Max(lhs, rhs); - ComputeAndCompareR1(&builder, -#if !defined(XLA_TEST_BACKEND_CPU) - {2.0, 1.0, 2.25}, -#else - {2.0, 1.0, 2.25, 10.0, 6.0}, -#endif - {}, error_spec_); + ComputeAndCompareR1(&builder, {2.0, 1.0, 2.25, NAN, NAN}, {}, + error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); auto y = builder.ConstantR1( @@ -1665,7 +1725,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); auto y = builder.ConstantR1( @@ -1679,7 +1739,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0, 0, 1, 1, 1, max, max, max}); auto y = builder.ConstantR1({0, 1, 0, 1, 10, 0, 234234, max}); builder.Max(x, y); @@ -1690,7 +1750,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0, 0, 1, 1, 1, max, max, max}); auto y = builder.ConstantR1({0, 1, 0, 1, 10, 0, 234234, max}); builder.Min(x, y); @@ -1700,7 +1760,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) { } XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); auto y = builder.ConstantR1( @@ -1713,7 +1773,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto u = builder.ConstantR1({3.5}); auto v = builder.ConstantR1({}); builder.Max(u, v); @@ -1723,7 +1783,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { for (int broadcast_dim : {0, 1}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto u = builder.ConstantR1({3.5}); auto v = builder.ConstantR2FromArray2D(Array2D(0, 2)); builder.Max(u, v, /*broadcast_dimensions=*/{broadcast_dim}); @@ -1733,7 +1793,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({2.0f, 3.0f, 4.0f}); auto m = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); @@ -1744,7 +1804,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({}); auto m = builder.ConstantR2({{}, {}}); builder.Max(v, m, /*broadcast_dimensions=*/{1}); @@ -1754,7 +1814,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto scalar = builder.ConstantR0(2); Array3D a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}}); auto array = builder.ConstantR3FromArray3D(a_3d); @@ -1765,7 +1825,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto scalar = builder.ConstantR0(2); Array3D a_3d(2, 0, 3); auto array = builder.ConstantR3FromArray3D(a_3d); @@ -1776,7 +1836,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantR2({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}}); auto v = builder.ConstantR1({-10.2f, 16.4f}); @@ -1787,7 +1847,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantR2({{}, {}}); auto v = builder.ConstantR1({-10.2f, 16.4f}); builder.Min(m, v, /*broadcast_dimensions=*/{0}); @@ -1797,7 +1857,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto array2d = builder.ConstantR2({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); auto array4d = builder.ConstantR4FromArray4D( @@ -1812,7 +1872,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto array2d = builder.ConstantR2({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); Array4D arg(2, 2, 0, 3); @@ -1824,7 +1884,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); auto y = builder.ConstantR1({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); builder.Min(x, y); @@ -1834,7 +1894,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); auto y = builder.ConstantR1({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); builder.Max(x, y); @@ -1844,110 +1904,107 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-3, 26, 2, -1, 1}); auto b = builder.ConstantR1({10, 5, 1, 10, -10}); - auto add = builder.Rem(a, b); + builder.Rem(a, b); ComputeAndCompareR1(&builder, {-3, 1, 0, -1, 1}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto minimum = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); auto argument = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); auto maximum = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); - auto clamp = builder.Clamp(minimum, argument, maximum); + builder.Clamp(minimum, argument, maximum); ComputeAndCompareR1(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto minimum = builder.ConstantR0(0.0f); auto argument = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); auto maximum = builder.ConstantR0(5.0f); - auto clamp = builder.Clamp(minimum, argument, maximum); + builder.Clamp(minimum, argument, maximum); ComputeAndCompareR1(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_scalar = builder.ConstantR0(0.0f); auto min_vector = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); auto arg_vector = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); auto max_scalar = builder.ConstantR0(3.0f); auto max_vector = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); // Perform clamp with broadcasted scalar and vector. - auto clamp = builder.Add( - builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0, -5}); auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4, 10}); auto max_vector = builder.ConstantR1({3, 0, 25, 5, 123, -1}); - auto clamp = builder.Clamp(min_vector, arg_vector, max_vector); + builder.Clamp(min_vector, arg_vector, max_vector); ComputeAndCompareR1(&builder, {2, 0, 1, 2, 4, -1}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_scalar = builder.ConstantR0(0); auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0}); auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4}); auto max_scalar = builder.ConstantR0(3); auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); // Perform clamp with broadcasted scalar and vector. - auto clamp = builder.Add( - builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_vector = builder.ConstantR1({1, 2, 1, 2, 0, ~0u - 4}); auto arg_vector = builder.ConstantR1({2, 10, 5, 1, 4, 10}); auto max_vector = builder.ConstantR1({3, 5, 25, 5, 123, ~0u}); - auto clamp = builder.Clamp(min_vector, arg_vector, max_vector); + builder.Clamp(min_vector, arg_vector, max_vector); ComputeAndCompareR1(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_scalar = builder.ConstantR0(0); auto min_vector = builder.ConstantR1({1, 0, 1, 2, 0}); auto arg_vector = builder.ConstantR1({2, 10, 0, 1, 4}); auto max_scalar = builder.ConstantR0(3); auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); // Perform clamp with broadcasted scalar and vector. - auto clamp = builder.Add( - builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); @@ -1961,7 +2018,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto add = builder.Add(p0, p1); + builder.Add(p0, p1); ComputeAndCompareR1(&builder, {8.3f, 4.5f, 6.7f, 11.1f}, {param0_data.get(), param1_data.get()}, @@ -1969,7 +2026,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR3FromArray3D(Array3D(0, 7, 0)); @@ -1983,7 +2040,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto add = builder.Add(p0, p1); + builder.Add(p0, p1); Array3D expected(0, 7, 0); ComputeAndCompareR3( @@ -1991,7 +2048,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); @@ -2000,35 +2057,35 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { auto a = builder.ConstantR1({1.1f, 2.2f, 3.3f, 4.4f}); auto p = builder.Parameter(0, param0_literal->shape(), "param0"); - auto add = builder.Add(a, p); + builder.Add(a, p); ComputeAndCompareR1(&builder, {2.2f, 4.4f, 6.6f, 9.9f}, {param0_data.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); - auto result = builder.Cos(a); + builder.Cos(a); ComputeAndCompareR1(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); - auto result = builder.Sin(a); + builder.Sin(a); ComputeAndCompareR1(&builder, {0.0f, 0.0f, 1.0f, -0.707107f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f}); auto b = builder.ConstantR1({6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f}); - auto atan = builder.Atan2(a, b); + builder.Atan2(a, b); ComputeAndCompareR1( &builder, @@ -2037,9 +2094,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { } XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f}); - auto result = builder.Tanh(a); + builder.Tanh(a); ComputeAndCompareR1(&builder, {-0.986614f, 0.996260f, 0.978026}, {}, error_spec_); @@ -2049,7 +2106,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { // This is like the test ArrayElementwiseOpTest.TanhF32s above, except that // the input tensor is large enough to exercise the vectorized tanh // implementation on XLA CPU. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1( {1.02, -0.32, 0.85, 0.90, 1.23, -0.91, -0.49, 0.80, -0.67, 0.16, -0.07, 0.39, -0.41, 0.04, 1.36, 1.25, 0.41, 0.65, -1.08, 0.32, @@ -2088,7 +2145,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { // The input tensor is large enough to exercise the vectorized exp // implementation on XLA CPU. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Just to help make sense of the scales here -- exp(89) saturates float32 and // exp(-10) is smaller than our error spec. @@ -2124,7 +2181,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { // The input tensor is large enough to exercise the vectorized exp // implementation on XLA CPU. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr input_literal = Literal::CreateR1( {-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198, @@ -2159,19 +2216,37 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) { + XlaBuilder builder(TestName()); + auto a = builder.ConstantR1( + {0, 1, 0x10, 0x10000, 0x700000, 0x12345678, 0xF2345678}); + builder.Clz(a); + + ComputeAndCompareR1(&builder, {32, 31, 27, 15, 9, 3, 0}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) { + XlaBuilder builder(TestName()); + auto a = + builder.ConstantR1({0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1}); + builder.Clz(a); + + ComputeAndCompareR1(&builder, {64, 63, 32, 1, 0}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { // a ------ (add) --------- (add) // / / // b -----/ / // c---------------------/ - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({1.1f, 2.2f, 3.3f, 4.4f}); auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); auto add = builder.Add(a, b); - auto add2 = builder.Add(add, c); + builder.Add(add, c); ComputeAndCompareR1(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {}, error_spec_); @@ -2182,14 +2257,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) { // / / // c -----/ / // a---------------------/ - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); auto add = builder.Add(b, c); - auto add2 = builder.Add(a, add); + builder.Add(a, add); ComputeAndCompareR1(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {}, error_spec_); @@ -2199,14 +2274,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) { // a ----- (neg) ----- (add) // / // b ----- (neg) ----/ - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); auto neg_a = builder.Neg(a); auto neg_b = builder.Neg(b); - auto result = builder.Add(neg_a, neg_b); + builder.Add(neg_a, neg_b); ComputeAndCompareR1(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {}, error_spec_); @@ -2220,7 +2295,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { // c ------ (add) ------------/ // / // d -----/ - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); @@ -2229,19 +2304,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { auto add_ab = builder.Add(a, b); auto add_cd = builder.Add(c, d); - auto add_all = builder.Add(add_ab, add_cd); + builder.Add(add_ab, add_cd); ComputeAndCompareR1(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto b = builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - auto add = builder.Add(a, b); + builder.Add(a, b); Array2D expected_array( {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); @@ -2250,11 +2325,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) { // Add a scalar + matrix. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto scalar = builder.ConstantR0(3.0f); - auto add = builder.Add(scalar, a); + builder.Add(scalar, a); Array2D expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2262,11 +2337,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) { XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) { // Add a matrix + scalar. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto scalar = builder.ConstantR0(3.0f); - auto add = builder.Add(a, scalar); + builder.Add(a, scalar); Array2D expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2275,14 +2350,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) { // Test simple broadcasting of a R1F32 over R2F32. The vector's size matches // only dim 0 of the matrix. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({20.0f, 40.0f, 60.0f}); // clang-format off auto m = builder.ConstantR2({ {-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); // clang-format on - auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1}); + builder.Add(v, m, /*broadcast_dimensions=*/{1}); Array2D expected_array( {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2290,7 +2365,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { // Test broadcasting in Eq comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({42, 73}); auto m = builder.ConstantR2({{42, 73}, {42, 52}}); @@ -2308,10 +2383,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { // Test broadcasting in Ne comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({42, 73}); auto m = builder.ConstantR2({{42, 73}, {42, 52}}); - auto cmp = builder.Ne(v, m, /*broadcast_dimensions=*/{1}); + builder.Ne(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,2] { { 00 }, @@ -2322,10 +2397,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { // Test broadcasting in Ge comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, 2, 3, 4}); auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - auto cmp = builder.Ge(v, m, /*broadcast_dimensions=*/{1}); + builder.Ge(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 1100 }, @@ -2336,10 +2411,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { // Test broadcasting in Gt comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, 2, 3, 4}); auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - auto cmp = builder.Gt(v, m, /*broadcast_dimensions=*/{1}); + builder.Gt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 0100 }, @@ -2350,10 +2425,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { // Test broadcasting in Le comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, 2, 3, 4}); auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - auto cmp = builder.Le(v, m, /*broadcast_dimensions=*/{1}); + builder.Le(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 1011 }, @@ -2364,10 +2439,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { // Test broadcasting in Lt comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, 2, 3, 4}); auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - auto cmp = builder.Lt(v, m, /*broadcast_dimensions=*/{1}); + builder.Lt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 0011 }, @@ -2379,24 +2454,24 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) { // Test simple broadcasting of a R1F32 over R2F32 when the order of binary op // arguments is reversed. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantR2({{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}}); auto v = builder.ConstantR1({2.0f, 4.0f, 6.0f}); - auto add = builder.Mul(m, v, /*broadcast_dimensions=*/{1}); + builder.Mul(m, v, /*broadcast_dimensions=*/{1}); Array2D expected_array({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) { // Tests broadcasting for arrays with degenerate (size == 1) dimensions. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // m's shape in XLA notation is {3, 2} // md's shape in XLA notation is {3, 1} // The result has shape {3, 2}, where md is broadcast over m auto m = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto md = builder.ConstantR2({{10.0f, 20.0f, 30.0f}}); - auto add = builder.Add(m, md); + builder.Add(m, md); Array2D expected_array( {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2404,14 +2479,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) { XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) { // Tests broadcasting for arrays with degenerate (size == 1) dimensions. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // m's shape in XLA notation is {3, 2} // md's shape in XLA notation is {1, 2} // The result has shape {3, 2}, where md is broadcast over m auto m = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto md = builder.ConstantR2({{10.0f}, {20.0f}}); - auto add = builder.Add(m, md); + builder.Add(m, md); Array2D expected_array( {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2422,13 +2497,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) { // effectively creates an "outer product" operation. // This is taken from the Numpy docs example at: // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // a's shape in XLA notation is {1, 4} // b's shape in XLA notation is {3, 1} // The result has shape {3, 4}. auto a = builder.ConstantR2({{0.0f}, {10.0f}, {20.0f}, {30.0f}}); auto b = builder.ConstantR2({{1.0f, 2.0f, 3.0f}}); - auto add = builder.Add(a, b); + builder.Add(a, b); Array2D expected_array({{1.0f, 2.0f, 3.0f}, {11.0f, 12.0f, 13.0f}, {21.0f, 22.0f, 23.0f}, @@ -2439,10 +2514,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) { // Add together a (2,2) array and a (2) array, using dimension 0 for // broadcasting (though there are two ways to broadcast these shapes). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({20.0f, 40.0f}); auto m = builder.ConstantR2({{10.0f, 50.0f}, {77.0f, 88.0f}}); - auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1}); + builder.Add(v, m, /*broadcast_dimensions=*/{1}); Array2D expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } @@ -2450,17 +2525,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) { // Add together a (2,2) array and a (2) array, using dimension 1 for // broadcasting (though there are two ways to broadcast these shapes). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({20.0f, 40.0f}); auto m = builder.ConstantR2({{10.0f, 50.0f}, {77.0f, 88.0f}}); - auto add = builder.Add(v, m, /*broadcast_dimensions=*/{0}); + builder.Add(v, m, /*broadcast_dimensions=*/{0}); Array2D expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { // Binary add of two R3s together - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); auto a = builder.ConstantR3FromArray3D(a_3d); @@ -2468,7 +2543,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { Array3D b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}}, {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}}); auto b = builder.ConstantR3FromArray3D(b_3d); - auto add = builder.Add(a, b); + builder.Add(a, b); Array3D expected_3d( {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}}, @@ -2479,7 +2554,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { // Add together a (2, 3, 2) array with a (2) array, using dimension 0 for // broadcasting (though there are two ways to broadcast these shapes). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // clang-format off Array3D a_3d({ {{1.0f, 2.0f}, @@ -2492,7 +2567,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { // clang-format on auto a = builder.ConstantR3FromArray3D(a_3d); auto v = builder.ConstantR1({10.0f, 20.0f}); - auto add = builder.Add(a, v, /*broadcast_dimensions=*/{2}); + builder.Add(a, v, /*broadcast_dimensions=*/{2}); Array3D expected_3d( {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}}, @@ -2503,7 +2578,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { // Add together a (2, 3, 2) array with a (2) array, using dimension 2 for // broadcasting (though there are two ways to broadcast these shapes). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // clang-format off Array3D a_3d({ {{1.0f, 2.0f}, @@ -2516,7 +2591,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { // clang-format on auto a = builder.ConstantR3FromArray3D(a_3d); auto v = builder.ConstantR1({10.0f, 20.0f}); - auto add = builder.Add(a, v, /*broadcast_dimensions=*/{0}); + builder.Add(a, v, /*broadcast_dimensions=*/{0}); // clang-format off Array3D expected_3d({ @@ -2534,7 +2609,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { // Add together a (2, 3, 2) array with a (3, 2) array, using dimensions {1,2} // for broadcasting. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // clang-format off Array3D a_3d({ {{1.0f, 2.0f}, @@ -2549,7 +2624,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { {10.0f, 20.0f, 30.0f}, {40.0f, 50.0f, 60.0f}, }); - auto add = builder.Add(a, m, /*broadcast_dimensions=*/{0, 1}); + builder.Add(a, m, /*broadcast_dimensions=*/{0, 1}); Array3D expected_3d({ {{11.0f, 12.0f}, @@ -2566,7 +2641,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { // Comparison between two 3D arrays of compatible shapes: // (2, 3, 2) and (2, 3, 1): expected to produce a (2, 3, 2) shape of PREDs. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); auto a = builder.ConstantR3FromArray3D(a_3d); @@ -2574,7 +2649,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { Array3D b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}}); auto b = builder.ConstantR3FromArray3D(b_3d); - auto compare = builder.Gt(a, b); + builder.Gt(a, b); Array3D expected_3d( {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}}); @@ -2590,7 +2665,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { } XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr> operand_a_4d(new Array4D(2, 3, 4, 5)); std::unique_ptr> operand_b_4d(new Array4D(2, 3, 4, 5)); @@ -2611,13 +2686,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) { auto a = builder.ConstantR4FromArray4D(*operand_a_4d); auto b = builder.ConstantR4FromArray4D(*operand_b_4d); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR4(&builder, *expected_4d, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr> operand_a_4d(new Array4D(2, 3, 4, 5)); std::unique_ptr> expected_4d(new Array4D(2, 3, 4, 5)); @@ -2639,7 +2714,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) { auto a = builder.ConstantR4FromArray4D(*operand_a_4d); auto b = builder.ConstantR1(operand_b_1d); - auto add = builder.Add(a, b, {1}); + builder.Add(a, b, {1}); ComputeAndCompareR4(&builder, *expected_4d, {}, error_spec_); } @@ -2654,7 +2729,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { std::vector r1(d1); std::iota(r1.begin(), r1.end(), 1.0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR4FromArray4DWithLayout( r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); auto a = builder.ConstantLiteral(*a_literal); @@ -2675,11 +2750,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { // Show that we can't add two opaques. XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto shape = ShapeUtil::MakeOpaqueShape(); auto x = builder.Parameter(0, shape, "x"); - auto concatenated = builder.Add(x, x); - StatusOr computation_status = builder.Build(); + builder.Add(x, x); + auto computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), ::testing::ContainsRegex( @@ -2687,12 +2762,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { } XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto b = builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - auto add = builder.Add(a, b, /*broadcast_dimensions=*/{0, 1}); + builder.Add(a, b, /*broadcast_dimensions=*/{0, 1}); Array2D expected_array( {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); @@ -2700,14 +2775,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { } XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto b = builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - auto add = builder.Add(a, b, /*broadcast_dimensions=*/{1, 0}); + builder.Add(a, b, /*broadcast_dimensions=*/{1, 0}); - StatusOr computation_status = builder.Build(); + auto computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().error_message(), ::testing::ContainsRegex("must.*be the identity")); @@ -2716,7 +2791,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) { // Regression test for b/31927799. "slice - y" is fused and requires implicit // broadcast. XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x_literal = Literal::CreateR1({1, 2, 3}); auto y_literal = Literal::CreateR1({4, 5}); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index 3f6fd7c65d3360a622dbf754833009fb20410535..fcd9ff55e393f64476ddd4754e0fa74427f1cb51 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -28,11 +28,11 @@ namespace { class AxpySimpleTest : public ClientLibraryTestBase {}; TEST_F(AxpySimpleTest, AxTenValues) { - ComputationBuilder builder(client_, "ax_10"); + XlaBuilder builder("ax_10"); auto alpha = builder.ConstantR0(3.1415926535); auto x = builder.ConstantR1( {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Mul(alpha, x); + builder.Mul(alpha, x); std::vector expected = { -3.14159265, 3.14159265, 6.28318531, -6.28318531, -9.42477796, @@ -41,26 +41,26 @@ TEST_F(AxpySimpleTest, AxTenValues) { } XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) { - ComputationBuilder builder(client_, "axpy_10"); + XlaBuilder builder("axpy_10"); auto alpha = builder.ConstantR0(3.1415926535); auto x = builder.ConstantR1({}); auto y = builder.ConstantR1({}); auto ax = builder.Mul(alpha, x); - auto axpy = builder.Add(ax, y); + builder.Add(ax, y); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } TEST_F(AxpySimpleTest, AxpyTenValues) { - ComputationBuilder builder(client_, "axpy_10"); + XlaBuilder builder("axpy_10"); auto alpha = builder.ConstantR0(3.1415926535); auto x = builder.ConstantR1( {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); auto y = builder.ConstantR1( {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); auto ax = builder.Mul(alpha, x); - auto axpy = builder.Add(ax, y); + builder.Add(ax, y); TF_ASSERT_OK_AND_ASSIGN(ProgramShape shape, builder.GetProgramShape()); diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc index e4bf1827acf24bcdbfe20fe39e794a0265ab89e3..22c3394e6f34bd018ffaaaa4d9d68339673c3764 100644 --- a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc +++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -34,13 +34,13 @@ namespace { class BadRngShapeValidationTest : public ClientLibraryTestBase {}; TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto zero = builder.ConstantR0(0.0); auto one = builder.ConstantR0(1.0); Shape default_constructed; builder.RngUniform(zero, one, default_constructed); - StatusOr computation = builder.Build(); + StatusOr computation = builder.Build(); EXPECT_FALSE(computation.ok()); LOG(INFO) << "status received: " << computation.status(); EXPECT_THAT(computation.status().error_message(), @@ -48,7 +48,7 @@ TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { } TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto zero = builder.ConstantR0(0.0); auto one = builder.ConstantR0(1.0); Shape sans_layout; @@ -57,7 +57,7 @@ TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { builder.RngUniform(zero, one, sans_layout); - StatusOr computation = builder.Build(); + StatusOr computation = builder.Build(); ASSERT_TRUE(computation.ok()); LOG(INFO) << computation.status(); } diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 28ab9654997728fbafd6610af840e721e72cce5a..f3dac75a44b948c4b45b80b93e7462073010979e 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -69,6 +69,15 @@ class BatchNormalizationTest CHECK_EQ(kY, input_array_.width()); } + XlaOp CheckShape(XlaBuilder* b, const XlaOp& operand, + const Shape& expected_shape) const { + Shape actual_shape = b->GetShape(operand).ConsumeValueOrDie(); + CHECK(ShapeUtil::Equal(expected_shape, actual_shape)) + << "want " << ShapeUtil::HumanString(expected_shape) << " got " + << ShapeUtil::HumanString(actual_shape); + return operand; + } + static constexpr int64 kSamples = 3; static constexpr int64 kX = 1; static constexpr int64 kY = 1; @@ -91,7 +100,7 @@ INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest, #endif XLA_TEST_P(BatchNormalizationTest, SubtractInZ) { - ComputationBuilder builder(client_, "subtract_in_z_one_sample"); + XlaBuilder builder("subtract_in_z_one_sample"); auto x = builder.ConstantLiteral(input_literal_); auto y = builder.ConstantR1({3.14, 4.25}); builder.Sub(x, y, /*broadcast_dimensions=*/{1}); @@ -107,7 +116,7 @@ XLA_TEST_P(BatchNormalizationTest, SubtractInZ) { } XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) { - ComputationBuilder builder(client_, "square_tesseract_elementwise"); + XlaBuilder builder("square_tesseract_elementwise"); auto x = builder.ConstantLiteral(input_literal_); builder.SquareF32(x); @@ -124,9 +133,9 @@ XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) { } XLA_TEST_P(BatchNormalizationTest, SumToZ) { - ComputationBuilder builder(client_, "sum_to_z"); + XlaBuilder builder("sum_to_z"); auto input_activations = builder.ConstantLiteral(input_literal_); - Computation add = CreateScalarAddComputation(F32, &builder); + XlaComputation add = CreateScalarAddComputation(F32, &builder); // Reduce all but the Z dimension. builder.Reduce(input_activations, builder.ConstantR0(0.0f), add, {0, 2, 3}); @@ -136,24 +145,23 @@ XLA_TEST_P(BatchNormalizationTest, SumToZ) { } XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) { - ComputationBuilder builder(client_, "square_and_reduce"); + XlaBuilder builder("square_and_reduce"); auto input_activations = builder.ConstantLiteral(input_literal_); auto set_means = builder.ConstantR1({2.f, 4.2f}); auto activation_deviations = builder.Sub(input_activations, set_means, /*broadcast_dimensions=*/{1}); - Computation add = CreateScalarAddComputation(F32, &builder); + XlaComputation add = CreateScalarAddComputation(F32, &builder); auto dev_squares = builder.SquareF32(activation_deviations); - auto sum_of_squares = builder.Reduce( - dev_squares, builder.ConstantR0(0.0f), add, {0, 2, 3}); + builder.Reduce(dev_squares, builder.ConstantR0(0.0f), add, {0, 2, 3}); std::vector expected = {18, 0.06}; ComputeAndCompareR1(&builder, expected, {}, error_spec_); } XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { - ComputationBuilder builder(client_, "variance_to_stddev"); + XlaBuilder builder("variance_to_stddev"); auto variance = builder.ConstantR1({6.f, .02f}); - auto sqrt = builder.SqrtF32(variance); + builder.SqrtF32(variance); std::vector expected = {2.44948974f, 0.14142136f}; ComputeAndCompareR1(&builder, expected, {}, error_spec_); @@ -162,23 +170,24 @@ XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { // Compare against a forward batch normalization example in the NN spec // reference. XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { - ComputationBuilder builder(client_, "batch_normalize_per_spec"); + XlaBuilder builder("batch_normalize_per_spec"); auto input_activations = - builder.CheckShape(builder.ConstantLiteral(input_literal_), - ShapeUtil::MakeShape(F32, {3, 2, 1, 1})); + CheckShape(&builder, builder.ConstantLiteral(input_literal_), + ShapeUtil::MakeShape(F32, {3, 2, 1, 1})); auto gamma = builder.ConstantR1({1.0, 1.0}); auto beta = builder.ConstantR1({0.0, 0.0}); - Computation add = CreateScalarAddComputation(F32, &builder); + XlaComputation add = CreateScalarAddComputation(F32, &builder); // Reduce all dimensions except dimension 1. Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2}); - auto sum = builder.CheckShape( + auto sum = CheckShape( + &builder, builder.Reduce(input_activations, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0, 2, 3}), TwoElementVectorF32); auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie(); auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie(); - auto count = builder.ConstantR0(ShapeUtil::ElementsIn(*input_shape) / - ShapeUtil::ElementsIn(*sum_shape)); + auto count = builder.ConstantR0(ShapeUtil::ElementsIn(input_shape) / + ShapeUtil::ElementsIn(sum_shape)); auto set_means = builder.Div(sum, count); const float kEpsilon = 1e-9f; @@ -187,14 +196,16 @@ XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { auto activation_deviations = builder.Sub(input_activations, set_means, /*broadcast_dimensions=*/{1}); auto dev_squares = builder.SquareF32(activation_deviations); - auto sum_of_squares = builder.CheckShape( + auto sum_of_squares = CheckShape( + &builder, builder.Reduce(dev_squares, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0, 2, 3}), TwoElementVectorF32); auto variance = builder.Div(sum_of_squares, count); auto standard_deviation = builder.SqrtF32(variance); - auto standard_deviation_above_epsilon = builder.CheckShape( - builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2})); + auto standard_deviation_above_epsilon = + CheckShape(&builder, builder.Gt(standard_deviation, epsilon), + ShapeUtil::MakeShape(PRED, {2})); auto gt_eps = builder.Select(standard_deviation_above_epsilon, standard_deviation, epsilon2); auto normalization_factors = builder.ReciprocalF32(gt_eps); @@ -219,7 +230,7 @@ XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { XLA_TEST_P(BatchNormalizationTest, BasicTraining) { const int kFeatureIndex = 3; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto operand = builder.ConstantR4FromArray4D( {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}}); @@ -228,8 +239,8 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { auto offset = builder.ConstantR1({1.0f, 2.0f}); - auto tuple = builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); + builder.BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, @@ -243,7 +254,7 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) { const int kFeatureIndex = 2; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto operand = builder.ConstantR4FromArray4D( {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); @@ -252,8 +263,8 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) { auto offset = builder.ConstantR1({1.0f, 2.0f}); - auto tuple = builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); + builder.BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, @@ -268,23 +279,23 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) { XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { // Use 0 dimension as feature, tests layout analyzer. const int kFeatureIndex = 0; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle h0; + XlaOp h0; auto operand = CreateR3Parameter(Array3D(260, 2, 2, 1.0f), /*parameter_number=*/0, "operand", &builder, &h0); - ComputationDataHandle h1; + XlaOp h1; auto scale = CreateR1Parameter(std::vector(260, 1.0f), /*parameter_number=*/1, "scale", &builder, &h1); - ComputationDataHandle h2; + XlaOp h2; auto offset = CreateR1Parameter(std::vector(260, 1.0f), /*parameter_number=*/2, "offset", &builder, &h2); - auto tuple = builder.BatchNormTraining(h0, h1, h2, - /*epsilon=*/1, kFeatureIndex); + builder.BatchNormTraining(h0, h1, h2, + /*epsilon=*/1, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) @@ -300,24 +311,24 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { // Test the correctness of choosing a large epsilon value. const int kFeatureIndex = 2; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle h0; + XlaOp h0; auto operand = CreateR3Parameter({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}}, /*parameter_number=*/0, "operand", &builder, &h0); - ComputationDataHandle h1; + XlaOp h1; auto scale = CreateR1Parameter(std::vector(1, 1.0f), /*parameter_number=*/1, "scale", &builder, &h1); - ComputationDataHandle h2; + XlaOp h2; auto offset = CreateR1Parameter(std::vector(1, 0.0f), /*parameter_number=*/2, "offset", &builder, &h2); // var = 125, mean = 15, epsilon = -100 - auto tuple = builder.BatchNormTraining(h0, h1, h2, - /*epsilon=*/-100, kFeatureIndex); + builder.BatchNormTraining(h0, h1, h2, + /*epsilon=*/-100, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR3FromArray3D({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) @@ -332,7 +343,7 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) { const int kFeatureIndex = 2; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto operand = builder.ConstantR4FromArray4D(Array4D(2, 2, 2, 1, 0.0f)); @@ -439,7 +450,7 @@ INSTANTIATE_TEST_CASE_P(BatchNormTest_Instantiation, BatchNormTestManySizes, XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { float epsilon = 0.001; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const std::vector& bounds = GetParam().bounds; Array4D input_array(bounds[0], bounds[1], bounds[2], bounds[3]); input_array.FillRandom(GetParam().random_value_var, @@ -539,7 +550,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { float epsilon = 0.001; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const std::vector& bounds = GetParam().bounds; Array4D input_array(bounds[0], bounds[1], bounds[2], bounds[3]); input_array.FillRandom(GetParam().random_value_var, @@ -647,7 +658,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { float epsilon = 0.001; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const std::vector& bounds = GetParam().bounds; Array4D input_array(bounds[0], bounds[1], bounds[2], bounds[3]); input_array.FillRandom(GetParam().random_value_var, @@ -814,9 +825,9 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { std::unique_ptr grad_output_data = client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie(); - auto t = builder.BatchNormGrad(input_parameter, scale_parameter, - mean_parameter, var_parameter, - grad_output_parameter, epsilon, feature_index); + builder.BatchNormGrad(input_parameter, scale_parameter, mean_parameter, + var_parameter, grad_output_parameter, epsilon, + feature_index); auto expected = Literal::MakeTuple({expected_grad_activation.get(), diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index b853dfaa15d7ff2e21048a5a6a486d22c5a05416..ca337e78840e77377719636cd4cf33af2578210d 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -19,10 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -38,7 +37,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -52,7 +50,7 @@ class Bfloat16Test : public ClientLibraryTestBase { }; XLA_TEST_F(Bfloat16Test, ScalarOperation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR0(static_cast(2.0f)); auto y = builder.ConstantR0(static_cast(1.0f)); builder.Add(x, y); @@ -62,7 +60,7 @@ XLA_TEST_F(Bfloat16Test, ScalarOperation) { } XLA_TEST_F(Bfloat16Test, LogOperation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR0(static_cast(4.0f)); builder.Log(x); @@ -71,7 +69,7 @@ XLA_TEST_F(Bfloat16Test, LogOperation) { } XLA_TEST_F(Bfloat16Test, NegateScalarF16) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Neg(builder.ConstantR0(static_cast(2.1f))); ComputeAndCompareR0(&builder, static_cast(-2.1f), {}, @@ -80,7 +78,7 @@ XLA_TEST_F(Bfloat16Test, NegateScalarF16) { XLA_TEST_F(Bfloat16Test, BatchNormTraining) { const int kFeatureIndex = 2; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto operand = builder.ConstantR4FromArray4D( {{{{static_cast(1.f)}, {static_cast(2.f)}}, @@ -117,7 +115,7 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { XLA_TEST_F(Bfloat16Test, BatchNormGrad) { const int kFeatureIndex = 2; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto operand = builder.ConstantR4FromArray4D( Array4D(2, 2, 2, 1, static_cast(0.0f))); diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc index 97fec89b63fb8d3a4264275f3253a91e1ea2ce68..48203b1d40ea69ff00a57c2c9e42620739b23d59 100644 --- a/tensorflow/compiler/xla/tests/binop_scaling_test.cc +++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -32,7 +32,7 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_32x4) { auto alhs = MakeLinspaceArray2D(0.0, 1.0, 32, 4); auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 4); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR2FromArray2D(*alhs); auto rhs = builder.ConstantR2FromArray2D(*arhs); builder.Add(lhs, rhs); @@ -48,7 +48,7 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_129x129) { auto alhs = MakeLinspaceArray2D(0.0, 1.0, 129, 129); auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 129); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR2FromArray2D(*alhs); auto rhs = builder.ConstantR2FromArray2D(*arhs); builder.Add(lhs, rhs); @@ -64,7 +64,7 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_9x5) { auto alhs = MakeLinspaceArray2D(0.0, 1.0, 9, 5); auto arhs = MakeLinspaceArray2D(0.0, 1.0, 9, 1); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR2FromArray2D(*alhs); auto rhs = builder.ConstantR2FromArray2D(*arhs); builder.Add(lhs, rhs); @@ -80,7 +80,7 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) { auto alhs = MakeLinspaceArray2D(0.0, 1.0, 129, 257); auto arhs = MakeLinspaceArray2D(0.0, 1.0, 129, 1); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR2FromArray2D(*alhs); auto rhs = builder.ConstantR2FromArray2D(*arhs); builder.Add(lhs, rhs); @@ -93,7 +93,7 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) { } TEST_F(BinopScalingTest, R0PlusR2F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR0(42.0); auto rhs = builder.ConstantR2({ {1.0, 2.0}, {3.0, 4.0}, @@ -109,7 +109,7 @@ TEST_F(BinopScalingTest, R0PlusR2F32) { } TEST_F(BinopScalingTest, R4PlusR0S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // clang-format off Array4D lhs_array({ {{{1, 2}, diff --git a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc index 0d94d65c1015fb54ada3fdfc95d0c31d0a0f158b..bff60f25ec8f15d372d251ac313200301a04f20f 100644 --- a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc +++ b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -34,7 +34,7 @@ namespace { class BitcastConvertTest : public ClientLibraryTestBase { public: - explicit BitcastConvertTest(perftools::gputools::Platform* platform = nullptr) + explicit BitcastConvertTest(se::Platform* platform = nullptr) : ClientLibraryTestBase(platform) { mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); mutable_debug_options()->add_xla_disable_hlo_passes("inline"); @@ -42,7 +42,7 @@ class BitcastConvertTest : public ClientLibraryTestBase { }; TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42, 64}); builder.BitcastConvertType(a, S32); @@ -51,7 +51,7 @@ TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) { } TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0f, 64.0f}); builder.BitcastConvertType(a, F32); @@ -60,7 +60,7 @@ TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) { } TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, static_cast(0x80000000), 0x3F800000, static_cast(0xBF800000), 0x3F000000, @@ -72,7 +72,7 @@ TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) { } XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); builder.BitcastConvertType(a, F32); @@ -81,7 +81,7 @@ XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) { } TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.6, 64.4}); builder.BitcastConvertType(a, S32); @@ -90,7 +90,7 @@ TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) { } TEST_F(BitcastConvertTest, ConvertS32Extremes) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {std::numeric_limits::min(), std::numeric_limits::max()}); builder.BitcastConvertType(a, F32); @@ -100,7 +100,7 @@ TEST_F(BitcastConvertTest, ConvertS32Extremes) { } TEST_F(BitcastConvertTest, ConvertMapToS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); b->BitcastConvertType(param, S32); @@ -112,7 +112,7 @@ TEST_F(BitcastConvertTest, ConvertMapToS32) { } TEST_F(BitcastConvertTest, ConvertMapToF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); b->BitcastConvertType(param, F32); @@ -129,7 +129,7 @@ TEST_F(BitcastConvertTest, ConvertMapToF32) { // input -> convert -> reshape // the new convert should have the same element type as the old convert. TEST_F(BitcastConvertTest, ConvertReshape) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR1({0x42280000}); auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); builder.BitcastConvertType(reshape, F32); diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 03f5e08315bfed2bcb43ebb7098aaa0b97228605..3a0f51fc66d65c8684bd607b9e8103559cd4d8d4 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -33,10 +33,8 @@ namespace { class BroadcastSimpleTest : public ClientLibraryTestBase { public: - ComputationDataHandle BuildBinOp(HloOpcode op, - const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - ComputationBuilder* builder) { + XlaOp BuildBinOp(HloOpcode op, const XlaOp& lhs, const XlaOp& rhs, + XlaBuilder* builder) { switch (op) { case HloOpcode::kMinimum: { return builder->Min(lhs, rhs); @@ -105,21 +103,21 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { using ::testing::HasSubstr; XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Broadcast(b.ConstantR0(1.5), {}); ComputeAndCompareR0(&b, 1.5, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Broadcast(b.ConstantR0(2.25), {2, 3}); Array2D expected(2, 3, 2.25); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) { - ComputationBuilder b(client_, TestName()); - ComputationDataHandle src; + XlaBuilder b(TestName()); + XlaOp src; std::unique_ptr param_data = CreateR0Parameter(2.25f, /*parameter_number=*/0, /*name=*/"src", /*builder=*/&b, /*data_handle=*/&src); @@ -131,21 +129,21 @@ XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) { } XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Broadcast(b.ConstantR0(2.25), {2, 0}); Array2D expected(2, 0); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Broadcast(b.ConstantR0(2.25), {0, 2}); Array2D expected(0, 2); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Broadcast(b.ConstantR1({1, 2, 3}), {2}); Array2D expected(2, 3); @@ -160,7 +158,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { // Tests implicit broadcasting of PREDs. XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); Array2D x_vals(2, 1); x_vals(0, 0) = true; @@ -171,7 +169,7 @@ XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { y_vals(1, 0, 0) = true; y_vals(1, 1, 0) = true; - ComputationDataHandle x, y; + XlaOp x, y; auto x_data = CreateR2Parameter(x_vals, 0, "x", &b, &x); auto y_data = CreateR3Parameter(y_vals, 1, "y", &b, &y); b.And(x, y, /*broadcast_dimensions=*/{1, 2}); @@ -186,7 +184,7 @@ XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { } XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Broadcast(b.ConstantR1({}), {2}); Array2D expected(2, 0); @@ -194,7 +192,7 @@ XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { } XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Broadcast(b.ConstantR1({1, 2, 3}), {0}); Array2D expected(0, 3); @@ -209,7 +207,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { // broadcasting (broadcast_dimensions {1, 2}), then is added to the rhs shape // [2, 3, 1]. Degenerate dimension broadcasting then broadcasts the size one // dimensions. - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Add(b.ConstantR2({{1.0, 5.0}}), b.ConstantLiteral(*Literal::CreateR3( @@ -247,7 +245,7 @@ class BroadcastR3ImplicitTest XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { const R3ImplicitBroadcastSpec& spec = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Shape r3_shape, r3_implicit_shape; Array3D r3_array(spec.output_bounds[0], spec.output_bounds[1], @@ -264,8 +262,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input"); auto r3_parameter = builder.Parameter(1, r3_shape, "input"); - ComputationDataHandle op = - BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder); + XlaOp op = BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder); Array3D expected_array(spec.output_bounds[0], spec.output_bounds[1], spec.output_bounds[2]); @@ -300,9 +297,9 @@ INSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances, // r1 and r3's dim0 matches, and r1's dim1 and dim2 have size 1: XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { - ComputationBuilder b(client_, TestName()); - ComputationDataHandle r1h; - ComputationDataHandle r3h; + XlaBuilder b(TestName()); + XlaOp r1h; + XlaOp r3h; Array3D r1d = {{{1}}, {{2}}}; Array3D r3d = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}; @@ -319,7 +316,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}}})); auto r3 = b.ConstantLiteral( *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); @@ -332,7 +329,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}, {2}}})); auto r3 = b.ConstantLiteral( *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); @@ -345,7 +342,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}, {3, 4}}})); auto r3 = b.ConstantLiteral( *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); @@ -358,7 +355,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}}, {{3, 4}}})); auto r3 = b.ConstantLiteral( *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); @@ -371,7 +368,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}, {2}}, {{3}, {4}}})); auto r3 = b.ConstantLiteral( @@ -385,7 +382,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}}})); auto r3 = b.ConstantLiteral( *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); @@ -491,7 +488,7 @@ class BroadcastR2ImplicitTest XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { const R2ImplicitBroadcastSpec& spec = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Operands with degenerate dimensions require implicit broadcasting: Shape r2_shape, r2_implicit_shape1, r2_implicit_shape2; @@ -517,10 +514,9 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { auto r2_implicit_parameter2 = builder.Parameter(2, r2_implicit_shape2, "input2"); - ComputationDataHandle op1 = + XlaOp op1 = BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder); - ComputationDataHandle op2 = - BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder); + XlaOp op2 = BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder); Array2D expected_array(spec.output_bounds[0], spec.output_bounds[1]); @@ -547,7 +543,7 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, ::testing::ValuesIn(kR2ImplicitBroadcastTestCases)); XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}})); auto r2 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}, {3, 4}})); b.Add(r2, r1); @@ -558,7 +554,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { } XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantLiteral(*Literal::CreateR2({{1}, {2}})); auto r2 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}, {3, 4}})); b.Add(r2, r1); @@ -569,7 +565,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); @@ -582,7 +578,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); @@ -595,7 +591,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); @@ -608,7 +604,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1_0 = b.ConstantR1({1000, 2000}); auto r1_1 = b.ConstantR1({100, 200}); auto r1_2 = b.ConstantR1({10, 20}); @@ -629,7 +625,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1_0 = b.ConstantR1({1000, 2000}); auto r1_1 = b.ConstantR1({100, 200}); auto r1_2 = b.ConstantR1({10, 20}); @@ -652,7 +648,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { // Binary dimension broadcasting of the smaller lhs ([2, 2] up to [2, 2, 2]) // results in a shape incompatible with the lhs [2, 3, 1]. - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Add(b.ConstantR2({{1.0, 5.0}, {1.0, 5.0}}), b.ConstantLiteral(*Literal::CreateR3( @@ -662,12 +658,12 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("broadcast dimension 0 mismatch")); + HasSubstr("dimension 0 mismatch")); } XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { // Test invalid broadcasting with [1, 2] and [2, 3] inputs. - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Add(b.ConstantR2({{1.0, 2.0}}), b.ConstantR2({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); @@ -675,12 +671,12 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("binary op BINOP_ADD with incompatible shapes")); + HasSubstr("op add with incompatible shapes")); } XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { // Test invalid broadcasting with [1, 2] and [2, 3] inputs. - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Add(b.ConstantR2({{1.0, 2.0}}), b.ConstantR2({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); @@ -688,7 +684,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("binary op BINOP_ADD with incompatible shapes")); + HasSubstr("op add with incompatible shapes")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 6ebbf7191833ef85ee4a48cc96c0a3be38c71228..51b9f0d3e330e73f5d110f0a62f824179d5c7cf7 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*Literal::CreateR0(42.0), *result, - error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR0(42.0), *result, + error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { @@ -62,9 +62,9 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, - error_spec_); + error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { @@ -85,13 +85,13 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), - LiteralView::Create(*result, {0}), error_spec_); + LiteralSlice(*result, {0}), error_spec_)); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), - LiteralView::Create(*result, {1}), error_spec_); + LiteralSlice(*result, {1}), error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { @@ -106,9 +106,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( - *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, - error_spec_); + EXPECT_TRUE( + LiteralTestUtil::Near(*Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { @@ -125,9 +125,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( - *Literal::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, - error_spec_); + EXPECT_TRUE( + LiteralTestUtil::Near(*Literal::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), + *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { @@ -142,10 +142,10 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), - *result, error_spec_); + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { @@ -166,8 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -196,8 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -218,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(r4_array), *result, - error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR4FromArray4D(r4_array), + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { @@ -238,8 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { @@ -260,8 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); expected.FillWithYX(to_broadcast); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -291,8 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 610302ac1256a57db6ed6e18016a4136973e3891..53f2c3bfbfce9585cb68f103a495ce2f1ad8432e 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -4,7 +4,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins") load("//tensorflow:tensorflow.bzl", "tf_cc_test") -all_backends = ["cpu", "cpu_parallel", "gpu"] + plugins.keys() +all_backends = ["cpu", "gpu"] + plugins.keys() def filter_backends(backends): """Removes "gpu" from a backend list if CUDA is not enabled. @@ -39,10 +39,10 @@ def xla_test(name, **kwargs): """Generates cc_test targets for the given XLA backends. - This rule generates a cc_test target for one or more XLA backends and also - a platform-agnostic cc_library rule. The arguments are identical to cc_test - with two additions: 'backends' and 'backend_args'. 'backends' specifies the - backends to generate tests for ("cpu", "cpu_parallel", "gpu"), and + This rule generates a cc_test target for one or more XLA backends and also a + platform-agnostic cc_library rule. The arguments are identical to cc_test with + two additions: 'backends' and 'backend_args'. 'backends' specifies the + backends to generate tests for ("cpu", "gpu"), and 'backend_args'/'backend_tags' specifies backend-specific args parameters to use when generating the cc_test. @@ -90,9 +90,9 @@ def xla_test(name, deps: Dependencies of the target. xla_test_library_deps: If set, the generated test targets will depend on the respective cc_libraries generated by the xla_test_library rule. - backends: A list of backends to generate tests for. Supported - values: "cpu", "cpu_parallel", "gpu". If this list is empty, the test will - be generated for all supported backends. + backends: A list of backends to generate tests for. Supported values: "cpu", + "gpu". If this list is empty, the test will be generated for all supported + backends. blacklisted_backends: A list of backends to NOT generate tests for. args: Test arguments for the target. tags: Tags for the target. @@ -128,16 +128,13 @@ def xla_test(name, if backend == "cpu": backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"] backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"] - elif backend == "cpu_parallel": - backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"] - backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"] - this_backend_args += ["--xla_backend_extra_options=\"xla_cpu_parallel\""] elif backend == "gpu": backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"] backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"] this_backend_tags += ["requires-gpu-sm35"] elif backend in plugins: - backend_deps = plugins[backend]["deps"] + backend_deps = [] + backend_deps += plugins[backend]["deps"] this_backend_copts += plugins[backend]["copts"] this_backend_tags += plugins[backend]["tags"] this_backend_args += plugins[backend]["args"] @@ -200,7 +197,7 @@ def xla_test_library(name, hdrs: Headers for the target. deps: Dependencies of the target. backends: A list of backends to generate libraries for. - Supported values: "cpu", "cpu_parallel", "gpu". If this list is empty, the + Supported values: "cpu", "gpu". If this list is empty, the library will be generated for all supported backends. """ @@ -209,7 +206,7 @@ def xla_test_library(name, for backend in filter_backends(backends): this_backend_copts = [] - if backend in ["cpu", "cpu_parallel", "gpu"]: + if backend in ["cpu", "gpu"]: backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend] elif backend in plugins: backend_deps = plugins[backend]["deps"] diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index 5e42365ae38dcc770bc2f1c9cb2c088fe02241a3..5fd33b50c94356839bbed58acd43b7d0286f4a7e 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -32,16 +32,16 @@ namespace { class CallOpTest : public ClientLibraryTestBase { protected: - Computation CreateR0F32IdentityComputation() { - ComputationBuilder builder(client_, "Identity"); + XlaComputation CreateR0F32IdentityComputation() { + XlaBuilder builder("Identity"); builder.Parameter(0, r0f32_, "x"); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); } - Computation CreateR1S0F32AdditionComputation() { - ComputationBuilder builder(client_, "Addition"); + XlaComputation CreateR1S0F32AdditionComputation() { + XlaBuilder builder("Addition"); auto x = builder.Parameter(0, r1s0f32_, "x"); auto y = builder.Parameter(1, r1s0f32_, "y"); builder.Add(x, y); @@ -50,8 +50,8 @@ class CallOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR1S2F32AdditionComputation() { - ComputationBuilder builder(client_, "Addition"); + XlaComputation CreateR1S2F32AdditionComputation() { + XlaBuilder builder("Addition"); auto x = builder.Parameter(0, r1s2f32_, "x"); auto y = builder.Parameter(1, r1s2f32_, "y"); builder.Add(x, y); @@ -60,8 +60,8 @@ class CallOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0F32TupleComputation() { - ComputationBuilder builder(client_, "Tuple"); + XlaComputation CreateR0F32TupleComputation() { + XlaBuilder builder("Tuple"); builder.Tuple({builder.Parameter(0, r0f32_, "x")}); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); @@ -74,8 +74,8 @@ class CallOpTest : public ClientLibraryTestBase { }; XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { - ComputationBuilder builder(client_, TestName()); - Computation callee = CreateR0F32IdentityComputation(); + XlaBuilder builder(TestName()); + XlaComputation callee = CreateR0F32IdentityComputation(); auto constant = builder.ConstantLiteral(*Literal::CreateR0(42.0)); builder.Call(callee, {constant}); @@ -83,8 +83,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { } XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { - ComputationBuilder builder(client_, TestName()); - Computation callee = CreateR1S0F32AdditionComputation(); + XlaBuilder builder(TestName()); + XlaComputation callee = CreateR1S0F32AdditionComputation(); auto x = builder.ConstantLiteral(*Literal::CreateR1({})); auto y = builder.ConstantLiteral(*Literal::CreateR1({})); builder.Call(callee, {x, y}); @@ -93,8 +93,8 @@ XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { } XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { - ComputationBuilder builder(client_, TestName()); - Computation callee = CreateR1S2F32AdditionComputation(); + XlaBuilder builder(TestName()); + XlaComputation callee = CreateR1S2F32AdditionComputation(); auto x = builder.ConstantLiteral(*Literal::CreateR1({1.0f, 2.0f})); auto y = builder.ConstantLiteral(*Literal::CreateR1({2.0f, 3.0f})); builder.Call(callee, {x, y}); @@ -103,23 +103,23 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { } XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { - ComputationBuilder builder(client_, "inner"); + XlaBuilder builder("inner"); { auto x = builder.Parameter(0, r0f32_, "x"); builder.Add(x, builder.ConstantR0(1.0)); } - TF_ASSERT_OK_AND_ASSIGN(Computation inner, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation inner, builder.Build()); - ComputationBuilder builder2(client_, "outer"); + XlaBuilder builder2("outer"); { auto x = builder2.Parameter(0, r0f32_, "x"); x = builder2.Call(inner, {x}); x = builder2.Call(inner, {x}); x = builder2.Call(inner, {x}); } - TF_ASSERT_OK_AND_ASSIGN(Computation outer, builder2.Build()); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation outer, builder2.Build()); - ComputationBuilder builder3(client_, "outermost"); + XlaBuilder builder3("outermost"); { auto x = builder3.Parameter(0, r0f32_, "x"); x = builder3.Call(outer, {x}); @@ -134,8 +134,8 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { } XLA_TEST_F(CallOpTest, CallR0F32Tuple) { - ComputationBuilder builder(client_, TestName()); - Computation callee = CreateR0F32TupleComputation(); + XlaBuilder builder(TestName()); + XlaComputation callee = CreateR0F32TupleComputation(); auto elem = Literal::CreateR0(42.0); auto tuple = Literal::MakeTuple({elem.get()}); builder.Call(callee, {builder.ConstantLiteral(*elem)}); diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index f594cc10ac6496f710d03f0b0b134e6dd3b6d38f..660ff0cad5666219a4a7cb1eedbed03f06e651ba 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -35,7 +35,7 @@ using ::testing::ContainsRegex; class CheckExecutionArityTest : public ClientLibraryTestBase {}; TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { - ComputationBuilder builder(client_, "add_two_params"); + XlaBuilder builder("add_two_params"); auto param_literal = Literal::CreateR1({1.1f, 2.2f}); auto p0 = builder.Parameter(0, param_literal->shape(), "param0"); @@ -75,7 +75,7 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { } XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { - ComputationBuilder builder(client_, "add_two_params"); + XlaBuilder builder("add_two_params"); auto p0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); auto p1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4}), "param1"); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index a677986cd926cc0054d8f36abc98ccac33dc043d..bf8ed4d9fb0bc61b86ef0b5872711a122a3d416b 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -18,11 +18,11 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -31,10 +31,12 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { namespace { + +// Name of the interpreter backend. +constexpr char kInterpreter[] = "interpreter"; + // Wrapper function that creates a nicer error message (than a bare // ValueOrDie()) if the platform we intend to test is not available. Client* GetOrCreateLocalClientOrDie(const LocalClientOptions& client_options) { @@ -43,14 +45,26 @@ Client* GetOrCreateLocalClientOrDie(const LocalClientOptions& client_options) { TF_CHECK_OK(result.status()) << " could not create local client for testing"; return result.ValueOrDie(); } + +// Helper functions to get the reference platform. +se::Platform* GetReferencePlatform() { + auto result = PlatformUtil::GetPlatform(kInterpreter); + TF_CHECK_OK(result.status()) << "could not get interpreter platform"; + return result.ValueOrDie(); +} + } // namespace ClientLibraryTestBase::ClientLibraryTestBase( - perftools::gputools::Platform* platform, - const LocalClientOptions& client_options) + se::Platform* platform, const LocalClientOptions& client_options) : client_(GetOrCreateLocalClientOrDie(client_options)), execution_options_(CreateDefaultExecutionOptions()) { CHECK_EQ(platform, client_options.platform()); + + LocalClientOptions ref_options; + ref_options.set_platform(GetReferencePlatform()); + ref_client_ = GetOrCreateLocalClientOrDie(ref_options); + // Disabling constant_folding so that tests (usually written using Constants) // will exercise the intended code paths, instead of being constant folded. // @@ -66,6 +80,11 @@ ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) LocalClientOptions default_options; default_options.set_platform(platform); client_ = GetOrCreateLocalClientOrDie(default_options); + + LocalClientOptions ref_options; + ref_options.set_platform(GetReferencePlatform()); + ref_client_ = GetOrCreateLocalClientOrDie(ref_options); + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "constant_folding"); } @@ -75,15 +94,14 @@ string ClientLibraryTestBase::TestName() const { } StatusOr> ClientLibraryTestBase::Execute( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { // Build the computation, as a convenience. TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); return client_->Execute(computation, arguments, &execution_options_); } StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - const Computation& computation, + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; @@ -96,34 +114,35 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( } StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments, + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout) { // Build the computation, as a convenience. TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); } -std::unique_ptr ClientLibraryTestBase::ExecuteOrDie( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - return Execute(builder, arguments).ConsumeValueOrDie(); -} - -std::unique_ptr ClientLibraryTestBase::ExecuteAndTransferOrDie( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - return ExecuteAndTransfer(builder, arguments).ConsumeValueOrDie(); +StatusOr> +ClientLibraryTestBase::ExecuteAndTransferReference( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout) { + ExecutionOptions execution_options = execution_options_; + if (shape_with_output_layout != nullptr) { + *execution_options.mutable_shape_with_output_layout() = + *shape_with_output_layout; + } + execution_options.clear_device_handles(); + return ref_client_->ExecuteAndTransfer(computation, arguments, + &execution_options); } string ClientLibraryTestBase::ExecuteToString( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - StatusOr computation_status = builder->Build(); + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + auto computation_status = builder->Build(); if (!computation_status.ok()) { return computation_status.status().ToString(); } - Computation computation = computation_status.ConsumeValueOrDie(); + auto computation = computation_status.ConsumeValueOrDie(); auto result = client_->ExecuteAndTransfer(computation, arguments, &execution_options_); @@ -135,7 +154,7 @@ string ClientLibraryTestBase::ExecuteToString( } void ClientLibraryTestBase::ComputeAndCompareR1( - ComputationBuilder* builder, const tensorflow::core::Bitmap& expected, + XlaBuilder* builder, const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -143,7 +162,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( } void ClientLibraryTestBase::ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, @@ -151,16 +170,15 @@ void ClientLibraryTestBase::ComputeAndCompareLiteral( } void ClientLibraryTestBase::ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, error, shape_with_layout)); } -tensorflow::Status -ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( - const xla::Computation& computation, const Literal& expected, +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( + const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output) { @@ -181,12 +199,11 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( "Test with output layout: ", ShapeUtil::HumanStringWithLayout(layout))); } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end())); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status -ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( - const xla::Computation& computation, const Literal& expected, +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( + const xla::XlaComputation& computation, const Literal& /*expected*/, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output, @@ -196,8 +213,8 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( // This is a recursive function. It's an std::function instead of a lambda // because it needs to capture itself. The index is the index of the argument // to try all layouts for. - std::function choose; - choose = [&, this](int64 index) -> tensorflow::Status { + std::function choose; + choose = [&, this](int64 index) -> Status { if (index < arguments.size()) { // Try out all layouts for the operand. TF_ASSIGN_OR_RETURN(auto literal, @@ -210,7 +227,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); layout_strings.pop_back(); - return tensorflow::Status::OK(); + return Status::OK(); } std::vector minor_to_major(ShapeUtil::Rank(literal->shape())); @@ -228,7 +245,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( layout_strings.pop_back(); } while ( std::next_permutation(minor_to_major.begin(), minor_to_major.end())); - return tensorflow::Status::OK(); + return Status::OK(); } // Every argument has an assigned layout. @@ -243,14 +260,14 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( tensorflow::strings::StrAppend(&error_message, str, " "); } verify_output(*actual, error_message); - return tensorflow::Status::OK(); + return Status::OK(); }; return choose(0); } -tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), @@ -277,7 +294,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( std::unique_ptr converted_expected; Shape layout_shape; if (use_bfloat16_) { - converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + converted_expected = Literal::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; @@ -291,7 +308,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } } auto expect_equal = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)) << error_message; }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { return ComputeAndCompareLiteralWithAllOutputLayouts( @@ -303,12 +320,12 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectEqual(*expected_ptr, *actual); - return tensorflow::Status::OK(); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual)); + return Status::OK(); } -tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, ErrorSpec error, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), @@ -329,7 +346,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( std::unique_ptr converted_expected; Shape layout_shape; if (use_bfloat16_) { - converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + converted_expected = Literal::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; @@ -343,7 +360,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } } auto expect_near = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error)) + << error_message; }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { return ComputeAndCompareLiteralWithAllOutputLayouts( @@ -355,12 +373,12 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error); - return tensorflow::Status::OK(); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error)); + return Status::OK(); } void ClientLibraryTestBase::ComputeAndCompareR1U8( - ComputationBuilder* builder, tensorflow::StringPiece expected, + XlaBuilder* builder, tensorflow::StringPiece expected, tensorflow::gtl::ArraySlice arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); @@ -379,7 +397,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( } void ClientLibraryTestBase::ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); @@ -387,11 +405,11 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(expected, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual)); } void ClientLibraryTestBase::ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); @@ -399,46 +417,47 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(expected, *actual, error); + EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error)); } void ClientLibraryTestBase::ComputeAndCompare( - ComputationBuilder* builder, const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments) { - auto status_or_data = ComputeValueAndReference(builder, operand, arguments); + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + auto status_or_data = ComputeValueAndReference(builder, arguments); EXPECT_IS_OK(status_or_data); if (!status_or_data.ok()) { return; } std::unique_ptr reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(*reference, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result)); } void ClientLibraryTestBase::ComputeAndCompare( - ComputationBuilder* builder, const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { - auto status_or_data = ComputeValueAndReference(builder, operand, arguments); + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, + ErrorSpec error) { + auto status_or_data = ComputeValueAndReference(builder, arguments); EXPECT_IS_OK(status_or_data); if (!status_or_data.ok()) { return; } std::unique_ptr reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*reference, *result, error); + EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error)); } StatusOr, std::unique_ptr>> ClientLibraryTestBase::ComputeValueAndReference( - ComputationBuilder* builder, const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { // Transfer the arguments to the executor service. We put the unique_ptr's // into a vector to keep the data alive on the service until the end of this // function. std::vector> argument_data; + std::vector> ref_argument_data; for (const auto& arg : arguments) { - TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg)); + TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg.Clone())); + TF_ASSIGN_OR_RETURN(auto ref_data, ref_client_->TransferToServer(arg)); argument_data.push_back(std::move(data)); + ref_argument_data.push_back(std::move(ref_data)); } // Create raw pointers to the GlobalData for the rest of the call stack. @@ -447,17 +466,25 @@ ClientLibraryTestBase::ComputeValueAndReference( argument_data.begin(), argument_data.end(), std::back_inserter(argument_data_ptr), [](const std::unique_ptr& data) { return data.get(); }); + std::vector ref_argument_data_ptr; + std::transform( + ref_argument_data.begin(), ref_argument_data.end(), + std::back_inserter(ref_argument_data_ptr), + [](const std::unique_ptr& data) { return data.get(); }); + + TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); - TF_ASSIGN_OR_RETURN( - auto reference, - builder->ComputeConstant(operand, /*output_layout=*/nullptr, arguments)); TF_ASSIGN_OR_RETURN(auto result, - ExecuteAndTransfer(builder, argument_data_ptr)); + ExecuteAndTransfer(computation, argument_data_ptr)); + + TF_ASSIGN_OR_RETURN(auto reference, ExecuteAndTransferReference( + computation, ref_argument_data_ptr)); + return std::make_pair(std::move(reference), std::move(result)); } -Computation ClientLibraryTestBase::CreateScalarRelu() { - ComputationBuilder builder(client_, "relu"); +XlaComputation ClientLibraryTestBase::CreateScalarRelu() { + XlaBuilder builder("relu"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); auto z_value = builder.Parameter(0, shape, "z_value"); auto zero = use_bfloat16_ @@ -469,8 +496,8 @@ Computation ClientLibraryTestBase::CreateScalarRelu() { return computation_status.ConsumeValueOrDie(); } -Computation ClientLibraryTestBase::CreateScalarMax() { - ComputationBuilder builder(client_, "max"); +XlaComputation ClientLibraryTestBase::CreateScalarMax() { + XlaBuilder builder("max"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); auto x = builder.Parameter(0, shape, "x"); auto y = builder.Parameter(1, shape, "y"); @@ -480,8 +507,8 @@ Computation ClientLibraryTestBase::CreateScalarMax() { return computation_status.ConsumeValueOrDie(); } -Computation ClientLibraryTestBase::CreateScalarReluSensitivity() { - ComputationBuilder builder(client_, "relu_sensitivity"); +XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() { + XlaBuilder builder("relu_sensitivity"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); auto activation = builder.Parameter(0, shape, "activation"); auto backprop = builder.Parameter(1, shape, "backprop"); @@ -522,10 +549,26 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, return array; } +XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, + XlaBuilder* builder) { + XlaOp data_handle; + arguments_.push_back(CreateParameterAndTransferLiteral( + arguments_.size(), argument, "", builder, &data_handle)); + return data_handle; +} + +XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, + XlaBuilder* builder) { + return builder->ConstantLiteral( + use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal); +} + std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral( - int64 parameter_number, const Literal& literal, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle) { +ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, + const Literal& literal, + const string& name, + XlaBuilder* builder, + XlaOp* data_handle) { return CreateParameterAndTransferLiteral(parameter_number, literal, name, nullptr, builder, data_handle); } @@ -533,12 +576,12 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral( std::unique_ptr ClientLibraryTestBase::CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { + const DeviceHandle* device_handle, XlaBuilder* builder, + XlaOp* data_handle) { const Literal* param_literal = &literal; std::unique_ptr converted_literal; if (use_bfloat16_) { - converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); + converted_literal = Literal::ConvertF32ToBF16(literal); param_literal = converted_literal.get(); } std::unique_ptr data = @@ -549,18 +592,4 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral( return data; } -ComputationDataHandle ClientLibraryTestBase::AddParam( - const Literal& argument, ComputationBuilder* builder) { - ComputationDataHandle data_handle; - arguments_.push_back(CreateParameterAndTransferLiteral( - arguments_.size(), argument, "", builder, &data_handle)); - return data_handle; -} - -ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral( - const Literal& literal, ComputationBuilder* builder) { - return builder->ConstantLiteral( - use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index ba0319990bc04196386e6812b0a03671676698ec..0499fec5898a42affa0e0a712dee10187355c13e 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -25,9 +25,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -63,11 +63,10 @@ std::vector ExpandUseBfloat16( // A client library test establishes an in-process XLA client connection. class ClientLibraryTestBase : public ::testing::Test { protected: - explicit ClientLibraryTestBase( - perftools::gputools::Platform* platform = nullptr); + explicit ClientLibraryTestBase(se::Platform* platform = nullptr); // Creates a new ClientLibraryTestBase with custom client options. - ClientLibraryTestBase(perftools::gputools::Platform* platform, + ClientLibraryTestBase(se::Platform* platform, const LocalClientOptions& client_options); // Returns the name of the test currently being run. @@ -92,86 +91,86 @@ class ClientLibraryTestBase : public ::testing::Test { // execution options. Modify execution_options_ in your test if you want to // customize the options. StatusOr> Execute( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); + StatusOr> ExecuteAndTransfer( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments, + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); + StatusOr> ExecuteAndTransfer( - const Computation& computation, + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); - // Convenience OrDie variants of above methods. - std::unique_ptr ExecuteOrDie( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments); - std::unique_ptr ExecuteAndTransferOrDie( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + // This executes the computation via the reference client (which connects a + // interpreter backend). The result is used as the expected values of the + // computation. + StatusOr> ExecuteAndTransferReference( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout = nullptr); // Run a computation and return its value as a string. If an error // occurs, then instead return the error as a string. - string ExecuteToString(ComputationBuilder* builder, + string ExecuteToString(XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); // Convenience methods for building and running a computation, transferring // the result, and comparing it to the expected value(s). Methods are // templated on the native host type which maps to specific XLA types (See - // ComputationBuilder for details). For each rank, two forms are provided: one - // for floating point types with an ErrorSpec parameter, and one for integral - // types without the ErrorSpec parameter. + // XlaBuilder for details). For each rank, two forms are + // provided: one for floating point types with an ErrorSpec parameter, and one + // for integral types without the ErrorSpec parameter. template - void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected, + void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments); template - void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected, + void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); template - void ComputeAndCompareR1(ComputationBuilder* builder, + void ComputeAndCompareR1(XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments); template - void ComputeAndCompareR1(ComputationBuilder* builder, + void ComputeAndCompareR1(XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // As above, but uses a bitmap to hold the predicate vector to avoid // deficiencies of vector. - void ComputeAndCompareR1(ComputationBuilder* builder, + void ComputeAndCompareR1(XlaBuilder* builder, const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments); template - void ComputeAndCompareR2(ComputationBuilder* builder, + void ComputeAndCompareR2(XlaBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments); template - void ComputeAndCompareR2(ComputationBuilder* builder, + void ComputeAndCompareR2(XlaBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); template - void ComputeAndCompareR3(ComputationBuilder* builder, + void ComputeAndCompareR3(XlaBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments); template - void ComputeAndCompareR3(ComputationBuilder* builder, + void ComputeAndCompareR3(XlaBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); template - void ComputeAndCompareR4(ComputationBuilder* builder, + void ComputeAndCompareR4(XlaBuilder* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments); template - void ComputeAndCompareR4(ComputationBuilder* builder, + void ComputeAndCompareR4(XlaBuilder* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); @@ -180,53 +179,51 @@ class ClientLibraryTestBase : public ::testing::Test { // literal. shape_with_layout indicates the result layout to request when // calling Execute. void ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); void ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); // ComputeAndCompare variant which returns an error status. - tensorflow::Status ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + Status ComputeAndCompareLiteralWithStatus( + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); - tensorflow::Status ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + Status ComputeAndCompareLiteralWithStatus( + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); // Compare the result of the computation to a strings. In XLA strings are // represented using rank-1 U8 shapes. void ComputeAndCompareR1U8( - ComputationBuilder* builder, tensorflow::StringPiece expected, + XlaBuilder* builder, tensorflow::StringPiece expected, tensorflow::gtl::ArraySlice arguments); // Convenience method for running a built computation, transferring the // result, and comparing it to the expected tuple literal. void ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments); void ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // Convenience method for running a built computation and comparing the result - // with the HloEvaluator. - void ComputeAndCompare(ComputationBuilder* builder, - const ComputationDataHandle& operand, + // with the reference result. + void ComputeAndCompare(XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); - void ComputeAndCompare(ComputationBuilder* builder, - const ComputationDataHandle& operand, + void ComputeAndCompare(XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // Create scalar operations for use in reductions. - Computation CreateScalarRelu(); - Computation CreateScalarMax(); - Computation CreateScalarReluSensitivity(); + XlaComputation CreateScalarRelu(); + XlaComputation CreateScalarMax(); + XlaComputation CreateScalarReluSensitivity(); // Special case convenience functions for creating filled arrays. @@ -268,49 +265,46 @@ class ClientLibraryTestBase : public ::testing::Test { // elements, the literal will be converted to BF16 before being transferred. std::unique_ptr CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle); + XlaBuilder* builder, XlaOp* data_handle); // As above, but the caller can specify the device that the literal is // transferred to. If device_handle is nullptr, the literal will be // transferred to the default device. std::unique_ptr CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const DeviceHandle* device_handle, XlaBuilder* builder, + XlaOp* data_handle); // Creates a parameter instruction and sets the value that will be passed to // the computation as specified. This function must be used for all parameters // or none and no parameters must be passed when invoking the computation if // using this mechanism. If using this mechanism, then each parameter must be // set exactly once. The first added parameter gets index 0, then 1 and so on. - ComputationDataHandle AddParam(const Literal& argument, - ComputationBuilder* builder); + XlaOp AddParam(const Literal& argument, XlaBuilder* builder); template - ComputationDataHandle AddParam(const Array& argument, - ComputationBuilder* builder) { + XlaOp AddParam(const Array& argument, XlaBuilder* builder) { return AddParam(*Literal::CreateFromArray(argument), builder); } // Creates a constant instruction with the given literal. When the // use_bfloat16 flag is set but the literal has F32 elements, the elements // will be converted to BF16s. - ComputationDataHandle CreateConstantFromLiteral(const Literal& literal, - ComputationBuilder* builder); + XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder); // Creates a constant instruction with the given array. When the use_bfloat16 // flag is set but the array has float elements, the elements will be // converted to bfloat16s. + template - ComputationDataHandle CreateConstantFromArray(const Array& array, - ComputationBuilder* builder) { + XlaOp CreateConstantFromArray(const Array& array, + XlaBuilder* builder) { return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); } // Same as CreateConstantFromArray, but for scalars. template - ComputationDataHandle CreateConstantFromScalar(NativeT value, - ComputationBuilder* builder) { + XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) { return CreateConstantFromLiteral(*Literal::CreateR0(value), builder); } @@ -324,9 +318,11 @@ class ClientLibraryTestBase : public ::testing::Test { // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. template - std::unique_ptr CreateR0Parameter( - NativeT value, int64 parameter_number, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle); + std::unique_ptr CreateR0Parameter(NativeT value, + int64 parameter_number, + const string& name, + XlaBuilder* builder, + XlaOp* data_handle); // Creates a parameter instruction that wraps the given values and then stores // into "data_handle" the global handle for that parameter. @@ -339,8 +335,7 @@ class ClientLibraryTestBase : public ::testing::Test { template std::unique_ptr CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const string& name, XlaBuilder* builder, XlaOp* data_handle); // Creates a parameter instruction that wraps the given constant array // "array_2d" and then stores to "data_handle" the global handle for that @@ -354,8 +349,7 @@ class ClientLibraryTestBase : public ::testing::Test { template std::unique_ptr CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const string& name, XlaBuilder* builder, XlaOp* data_handle); // Creates a parameter instruction that wraps the given constant array // "array_3d" and then stores to "data_handle" the global handle for that @@ -369,8 +363,7 @@ class ClientLibraryTestBase : public ::testing::Test { template std::unique_ptr CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const string& name, XlaBuilder* builder, XlaOp* data_handle); // Getter and setter for the use_bfloat16 flag, which indicates whether to run // tests with all float-type input/output converted to bfloat16. @@ -381,29 +374,27 @@ class ClientLibraryTestBase : public ::testing::Test { PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; } Client* client_; + Client* ref_client_; // To compute reference result. ExecutionOptions execution_options_; private: - // Build and run the computation with all permutations of output layouts. - tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts( - const xla::Computation& computation, const Literal& expected, + Status ComputeAndCompareLiteralWithAllOutputLayouts( + const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output); - // Build and run the computation with all permutations of layouts of all input - // arguments. - tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts( - const xla::Computation& computation, const Literal& expected, + Status ComputeAndCompareLiteralWithAllInputLayouts( + const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output, const Shape* output_with_layout = nullptr); // Executes the computation and calculates the expected reference value using - // the HloEvaluator. Returns two literal in the order of (expected, actual). + // the reference client. Returns two literals in the order of (expected, + // actual). StatusOr, std::unique_ptr>> - ComputeValueAndReference(ComputationBuilder* builder, - const ComputationDataHandle& operand, + ComputeValueAndReference(XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); // Whether to run tests with all float-type input/output converted to @@ -416,7 +407,7 @@ class ClientLibraryTestBase : public ::testing::Test { template void ClientLibraryTestBase::ComputeAndCompareR0( - ComputationBuilder* builder, NativeT expected, + XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR0(expected); @@ -426,7 +417,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( template void ClientLibraryTestBase::ComputeAndCompareR0( - ComputationBuilder* builder, NativeT expected, + XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -442,7 +433,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( template void ClientLibraryTestBase::ComputeAndCompareR1( - ComputationBuilder* builder, tensorflow::gtl::ArraySlice expected, + XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR1(expected); @@ -452,7 +443,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( template void ClientLibraryTestBase::ComputeAndCompareR1( - ComputationBuilder* builder, tensorflow::gtl::ArraySlice expected, + XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -468,7 +459,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( template void ClientLibraryTestBase::ComputeAndCompareR2( - ComputationBuilder* builder, const Array2D& expected, + XlaBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR2FromArray2D(expected); @@ -478,7 +469,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( template void ClientLibraryTestBase::ComputeAndCompareR2( - ComputationBuilder* builder, const Array2D& expected, + XlaBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -494,7 +485,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( template void ClientLibraryTestBase::ComputeAndCompareR3( - ComputationBuilder* builder, const Array3D& expected, + XlaBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR3FromArray3D(expected); @@ -504,7 +495,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( template void ClientLibraryTestBase::ComputeAndCompareR3( - ComputationBuilder* builder, const Array3D& expected, + XlaBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -520,7 +511,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( template void ClientLibraryTestBase::ComputeAndCompareR4( - ComputationBuilder* builder, const Array4D& expected, + XlaBuilder* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR4FromArray4D(expected); @@ -530,7 +521,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( template void ClientLibraryTestBase::ComputeAndCompareR4( - ComputationBuilder* builder, const Array4D& expected, + XlaBuilder* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -547,10 +538,10 @@ void ClientLibraryTestBase::ComputeAndCompareR4( template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle) { + XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR0(value); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -561,11 +552,10 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { + const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR1(values); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -576,11 +566,10 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( template std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { + const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -591,11 +580,10 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( template std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { + const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 045148cdd11da94ae4789a753efca95c6aaa1f27..08671cf62445826649b5c97003f998ae98a59d97 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -16,9 +16,10 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -37,7 +38,7 @@ namespace { class ClientTest : public ClientLibraryTestBase {}; XLA_TEST_F(ClientTest, ExecuteWithLayout) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector> layouts = {{0, 1}, {1, 0}}; for (const std::vector& execute_layout : layouts) { @@ -61,15 +62,15 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { TF_ASSERT_OK_AND_ASSIGN( auto computed, client_->Transfer(*data, &expected_literal->shape())); - LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(), - computed->shape()); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( + expected_literal->shape(), computed->shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } } XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Tuple({b.ConstantR2({{1, 2}, {3, 4}}), b.ConstantR2({{10, 20}, {30, 40}})}); @@ -90,9 +91,9 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, - LiteralView::Create(*result, {0})); + LiteralSlice(*result, {0})); LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, - LiteralView::Create(*result, {1})); + LiteralSlice(*result, {1})); EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); @@ -107,16 +108,15 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { /*minor_to_major=*/{1, 0}))); } -XLA_TEST_F(ClientTest, - DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) { - Computation add_with_one_arg, mul_with_two_args, dot_with_one_arg; +XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { + XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg; Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr const_arg, client_->TransferToServer(*Literal::CreateR2({{5, 6}, {7, 8}}))); - ComputationBuilder b(client_, TestName() + ".add"); + XlaBuilder b(TestName() + ".add"); b.Add(b.Parameter(0, shape, "param_0"), b.ConstantR2({{1, 2}, {3, 4}})); TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build()); @@ -124,14 +124,14 @@ XLA_TEST_F(ClientTest, // We can't really test parallel execution on CPU since all of the cores in a // CPU are presented as a single device. So for now we test "parallel" // execution on a single device. - std::vector computation_instances; + std::vector computation_instances; TF_ASSERT_OK_AND_ASSIGN(std::vector devices, client_->GetDeviceHandles(1)); ASSERT_EQ(devices.size(), 1); ExecutionOptions options = execution_options_; *options.add_device_handles() = devices[0]; - computation_instances.push_back(Client::ComputationInstance( + computation_instances.push_back(Client::XlaComputationInstance( add_with_one_arg, {const_arg.get()}, options, nullptr)); TF_ASSERT_OK_AND_ASSIGN(auto results, @@ -142,7 +142,7 @@ XLA_TEST_F(ClientTest, auto result_literal, client_->Transfer(*results[0], &expected_result->shape())); - LiteralTestUtil::ExpectEqual(*expected_result, *result_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 0f780fa87ef98fd5c48726ef83fa8efc1e90fbf7..50a006964869b3e5dce431d441f7cd81af9df910 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -39,7 +39,7 @@ namespace { class CompilationCacheTest : public ClientLibraryTestBase { public: void ExecuteComputationR0F32( - const Computation& computation, + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, float expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; @@ -49,13 +49,13 @@ class CompilationCacheTest : public ClientLibraryTestBase { /*execution_options=*/&execution_options_, &execution_profile) .ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*Literal::CreateR0(expected_result), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR0(expected_result), *result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } void ExecuteComputationR2F32( - const Computation& computation, + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, std::initializer_list> expected_result, bool expect_cache_hit) { @@ -66,25 +66,28 @@ class CompilationCacheTest : public ClientLibraryTestBase { .ConsumeValueOrDie(); std::unique_ptr result = client_->Transfer(*data_handle).ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*Literal::CreateR2(expected_result), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR2(expected_result), *result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } ErrorSpec error_spec_{0.0001}; }; -XLA_TEST_F(CompilationCacheTest, ComputationCalledMultipleTimes) { - ComputationBuilder builder(client_, TestName()); +// TODO(b/74197823): Disabled because there is no cache in the new design. +XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) { + XlaBuilder builder(TestName()); builder.Neg(builder.ConstantR0(42.0)); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false); ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); } -XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) { +// TODO(b/74197823): Disabled because there is no cache in the new design. +XLA_TEST_F(CompilationCacheTest, + DISABLED_ComputationCalledWithDifferentParameters) { std::unique_ptr data_42 = client_->TransferToServer(*Literal::CreateR0(42.0f)) .ConsumeValueOrDie(); @@ -95,9 +98,9 @@ XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) { client_->TransferToServer(*Literal::CreateR0(456.0f)) .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Neg(builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation, {data_42.get()}, -42.0, /*expect_cache_hit=*/false); @@ -109,19 +112,20 @@ XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) { /*expect_cache_hit=*/true); } -XLA_TEST_F(CompilationCacheTest, MultipleComputations) { - ComputationBuilder builder_neg(client_, TestName() + "_neg"); +// TODO(b/74197823): Disabled because there is no cache in the new design. +XLA_TEST_F(CompilationCacheTest, DISABLED_MultipleComputations) { + XlaBuilder builder_neg(TestName() + "_neg"); builder_neg.Neg(builder_neg.ConstantR0(42.0)); - Computation computation_neg = builder_neg.Build().ConsumeValueOrDie(); + XlaComputation computation_neg = builder_neg.Build().ConsumeValueOrDie(); - ComputationBuilder builder_exp(client_, TestName() + "_exp"); + XlaBuilder builder_exp(TestName() + "_exp"); builder_exp.Exp(builder_exp.ConstantR0(1.0)); - Computation computation_exp = builder_exp.Build().ConsumeValueOrDie(); + XlaComputation computation_exp = builder_exp.Build().ConsumeValueOrDie(); - ComputationBuilder builder_add(client_, TestName() + "_add"); + XlaBuilder builder_add(TestName() + "_add"); builder_add.Add(builder_add.ConstantR0(2.0), builder_add.ConstantR0(3.0)); - Computation computation_add = builder_add.Build().ConsumeValueOrDie(); + XlaComputation computation_add = builder_add.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation_neg, {}, -42.0, /*expect_cache_hit=*/false); @@ -133,7 +137,8 @@ XLA_TEST_F(CompilationCacheTest, MultipleComputations) { /*expect_cache_hit=*/true); } -XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) { +// TODO(b/74197823): Disabled because there is no cache in the new design. +XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) { // Create two GlobalData arrays with the same shape but different // layouts. Use these arrays as parameters to a simple computation. If the // layout of the array changes then computation should be recompiled (cache @@ -148,9 +153,9 @@ XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) { auto colmaj_handle = client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR2F32(computation, {colmaj_handle.get()}, {{1.0f, 2.0f}, {3.0f, 4.0f}}, @@ -169,32 +174,5 @@ XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) { /*expect_cache_hit=*/true); } -XLA_TEST_F(CompilationCacheTest, MutatedComputation) { - // Build a computation, execute it, then mutate it. The mutated computation - // should not be in the cache until it is run once. This must be done through - // the stub interface because Computations built from ComputationBuilder are - // immutable. - ComputationBuilder builder(client_, TestName()); - auto neg = builder.Neg(builder.ConstantR0(42.0)); - Computation computation = builder.Build().ConsumeValueOrDie(); - - ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false); - ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); - - BinaryOpRequest request; - request.set_binop(BINOP_ADD); - *request.mutable_lhs() = neg; - *request.mutable_rhs() = neg; - OpRequest op_request; - *op_request.mutable_computation() = computation.handle(); - *op_request.mutable_binary_op_request() = request; - OpResponse response; - tensorflow::Status s = client_->stub()->Op(&op_request, &response); - ASSERT_TRUE(s.ok()); - - ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/false); - ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/true); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index ec2c580670cfac14ba42e8c9a836c86551af4b89..ba22530f1cfee56337f862c25122d399dbf0f1e4 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -31,6 +31,8 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -43,16 +45,14 @@ ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly}; class ComputeConstantTest : public ::testing::Test { public: - explicit ComputeConstantTest( - perftools::gputools::Platform* platform = nullptr) + explicit ComputeConstantTest(se::Platform* platform = nullptr) : platform_(platform) {} string TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } - Client* ClientOrDie(::perftools::gputools::Platform* platform, - ClientType client_type) { + Client* ClientOrDie(se::Platform* platform, ClientType client_type) { if (client_type == ClientType::kLocal) { StatusOr result = ClientLibrary::GetOrCreateLocalClient(platform); @@ -70,39 +70,35 @@ class ComputeConstantTest : public ::testing::Test { } StatusOr> ComputeConstantLiteral( - Client* client, const ComputationDataHandle& operand, - ComputationBuilder* builder, Layout* output_layout = nullptr, - tensorflow::gtl::ArraySlice parameters = {}) { - TF_ASSIGN_OR_RETURN(auto computed, builder->ComputeConstant( - operand, output_layout, parameters)); + Client* client, const XlaOp& operand, XlaBuilder* builder, + Layout* output_layout = nullptr) { + TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand)); + TF_ASSIGN_OR_RETURN(auto computed, + client->ComputeConstant(subgraph, output_layout)); return std::move(computed); } template - StatusOr ComputeConstantScalar( - Client* client, const ComputationDataHandle& operand, - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice parameters = {}) { - TF_ASSIGN_OR_RETURN( - auto literal, - ComputeConstantLiteral(client, operand, builder, nullptr, parameters)); + StatusOr ComputeConstantScalar(Client* client, const XlaOp& operand, + XlaBuilder* builder) { + TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand, + builder, nullptr)); return literal->Get({}); } - bool IsConstant(const ComputationDataHandle& operand, - ComputationBuilder* builder, int64 num_parameters = 0) { - StatusOr result = builder->IsConstant(operand, num_parameters); + bool IsConstant(const XlaOp& operand, XlaBuilder* builder) { + StatusOr result = builder->IsConstant(operand); EXPECT_TRUE(result.ok()) << result.status(); return result.ok() ? result.ValueOrDie() : false; } - perftools::gputools::Platform* platform_; + se::Platform* platform_; }; TEST_F(ComputeConstantTest, ScalarInt32Literal) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); auto computation = b.ConstantR0(42); EXPECT_TRUE(IsConstant(computation, &b)); @@ -115,7 +111,7 @@ TEST_F(ComputeConstantTest, ScalarInt32Literal) { TEST_F(ComputeConstantTest, ScalarFloatAdd) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); auto computation = b.Add(b.ConstantR0(42.5f), b.ConstantR0(1.5f)); EXPECT_TRUE(IsConstant(computation, &b)); @@ -129,7 +125,7 @@ TEST_F(ComputeConstantTest, ScalarFloatAdd) { TEST_F(ComputeConstantTest, ScalarRng) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); auto computation = b.RngUniform(b.ConstantR0(1.1f), b.ConstantR0(2.1f), ShapeUtil::MakeShape(F32, {})); @@ -141,34 +137,16 @@ TEST_F(ComputeConstantTest, ScalarRng) { } } -TEST_F(ComputeConstantTest, Param) { - for (ClientType client_type : client_types) { - Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); - auto param = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "lhs"); - auto computation = b.Add(param, b.ConstantR0(1.5f)); - - std::vector arguments; - arguments.push_back(std::move(*Literal::CreateR0(42.5f))); - EXPECT_TRUE(IsConstant(computation, &b, arguments.size())); - - auto value = - ComputeConstantScalar(client, computation, &b, arguments); - ASSERT_TRUE(value.ok()) << value.status(); - EXPECT_EQ(value.ValueOrDie(), 44.0f); - } -} - TEST_F(ComputeConstantTest, DirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); - EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) - .contains("depends on a parameter")) + EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(), + "depends on a parameter")) << value.status(); } } @@ -176,15 +154,15 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { TEST_F(ComputeConstantTest, IndirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); auto computation = b.Add(b.ConstantR0(1.0f), b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); - EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) - .contains("depends on a parameter")) + EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(), + "depends on a parameter")) << value.status(); } } @@ -194,7 +172,7 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) { TEST_F(ComputeConstantTest, UnrelatedParam) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0"); auto constant_4 = @@ -211,64 +189,64 @@ TEST_F(ComputeConstantTest, UnrelatedParam) { EXPECT_TRUE(IsConstant(constant_13, &b)); - auto value = ComputeConstantScalar(client, constant_13, &b); - ASSERT_TRUE(value.ok()) << value.status(); - EXPECT_EQ(value.ValueOrDie(), 13.0f); + TF_ASSERT_OK_AND_ASSIGN( + auto value, ComputeConstantScalar(client, constant_13, &b)); + EXPECT_EQ(value, 13.0f); } } TEST_F(ComputeConstantTest, NonScalarAdd) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); auto computation = b.Add(b.ConstantR1({1, 2}), b.ConstantR1({3, 4})); EXPECT_TRUE(IsConstant(computation, &b)); - auto computed = ComputeConstantLiteral(client, computation, &b); - ASSERT_TRUE(computed.ok()) << computed.status(); + TF_ASSERT_OK_AND_ASSIGN(auto computed, + ComputeConstantLiteral(client, computation, &b)); std::unique_ptr expected_literal = Literal::CreateR1({4, 6}); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } TEST_F(ComputeConstantTest, IntegerDivide) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); auto computation = b.Div(b.ConstantR0(15), b.ConstantR0(3)); EXPECT_TRUE(IsConstant(computation, &b)); - auto computed = ComputeConstantLiteral(client, computation, &b); - ASSERT_TRUE(computed.ok()) << computed.status(); + TF_ASSERT_OK_AND_ASSIGN(auto computed, + ComputeConstantLiteral(client, computation, &b)); std::unique_ptr expected_literal = Literal::CreateR0(5); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } XLA_TEST_F(ComputeConstantTest, Layout) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); std::vector> layouts = {{0, 1}, {1, 0}}; for (const std::vector& layout : layouts) { auto layout_proto = LayoutUtil::MakeLayout(layout); - auto computed = ComputeConstantLiteral( - client, - b.Add(b.ConstantR2({{1, 2}, {3, 4}}), - b.ConstantR2({{10, 20}, {30, 40}})), - &b, &layout_proto); - ASSERT_TRUE(computed.ok()) << computed.status(); + TF_ASSERT_OK_AND_ASSIGN( + auto computed, ComputeConstantLiteral( + client, + b.Add(b.ConstantR2({{1, 2}, {3, 4}}), + b.ConstantR2({{10, 20}, {30, 40}})), + &b, &layout_proto)); std::unique_ptr expected_literal = Literal::CreateR2WithLayout({{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); - LiteralTestUtil::AssertEqualShapesAndLayouts( - expected_literal->shape(), computed.ValueOrDie()->shape()); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( + expected_literal->shape(), computed->shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } } diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 1bcad5a3f37a37c9d482f3a5a899ac527666cca3..a4c8a83eb15f7cc279b6c8f1bf1394c0afb9f7cf 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -38,9 +38,9 @@ using ::testing::HasSubstr; // Concatenate expects at least one argument. XLA_TEST_F(ConcatTest, Concat_Nothing) { - ComputationBuilder builder(client_, TestName()); - auto concatenated = builder.ConcatInDim({}, 0); - StatusOr computation_status = builder.Build(); + XlaBuilder builder(TestName()); + builder.ConcatInDim({}, 0); + StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), HasSubstr("Concatenate expects at least one argument")); @@ -48,18 +48,18 @@ XLA_TEST_F(ConcatTest, Concat_Nothing) { // Concatenate with one argument works. XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0, 64.0}); - auto concatenated = builder.ConcatInDim({a}, 0); + builder.ConcatInDim({a}, 0); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto concatenated = builder.ConcatInDim({a}, 0); + builder.ConcatInDim({a}, 0); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -68,51 +68,51 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { // Show that we can't concatenate R0 with R0 because we can't name the dimension // to concatenate on. XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR0(42.0); auto b = builder.ConstantR0(64.0); - auto concatenated = builder.ConcatInDim({a, b}, 0); - StatusOr computation_status = builder.Build(); + builder.ConcatInDim({a, b}, 0); + StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), - HasSubstr("dimension to concatenate along out of bounds: 0")); + HasSubstr("out of bounds: 0")); } XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({256.0}); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); std::vector expected = {256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0, 64.0}); auto b = builder.ConstantR1({}); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0, 64.0}); auto b = builder.ConstantR1({256.0}); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -129,20 +129,20 @@ XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) { expected[253 + i] = rhs[i] = 253 + i + 1; } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1(lhs); auto b = builder.ConstantR1(rhs); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) { for (int dim : {0, 1}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2D(Array2D(0, 0)); auto b = builder.ConstantR2FromArray2D(Array2D(0, 0)); - auto concatenated = builder.ConcatInDim({a, b}, dim); + builder.ConcatInDim({a, b}, dim); ComputeAndCompareR2(&builder, Array2D(0, 0), {}, ErrorSpec(0.0001)); @@ -150,26 +150,27 @@ XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) { } XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(1, 1); auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); Array2D expected({ - {0}, {64}, + {0}, + {64}, }); ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(1, 1); auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 1); + builder.ConcatInDim({a, b}, 1); Array2D expected({ {0, 64}, @@ -178,22 +179,22 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) { } XLA_TEST_F(ConcatTest, Concat2x0With2x5) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(Array2D(2, 0)); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 1); + builder.ConcatInDim({a, b}, 1); ComputeAndCompareR2(&builder, *b_array, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat2x3With2x5) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(2, 3); auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 1); + builder.ConcatInDim({a, b}, 1); Array2D expected({ {0, 1, 2, 64, 65, 66, 67, 68}, @@ -203,22 +204,22 @@ XLA_TEST_F(ConcatTest, Concat2x3With2x5) { } XLA_TEST_F(ConcatTest, Concat3x2With0x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(3, 2); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); ComputeAndCompareR2(&builder, *a_array, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat3x2With5x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(3, 2); auto b_array = CreatePatternedMatrix(5, 2, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); Array2D expected({ {0, 1}, @@ -234,16 +235,16 @@ XLA_TEST_F(ConcatTest, Concat3x2With5x2) { } XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR3FromArray3D(Array3D(3, 0, 2)); auto b = builder.ConstantR3FromArray3D(Array3D(3, 0, 1)); - auto concatenated = builder.ConcatInDim({a, b}, 2); + builder.ConcatInDim({a, b}, 2); ComputeAndCompareR3(&builder, Array3D(3, 0, 3), {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D a_array({ // 3x1x2 {{0, 1}}, @@ -258,27 +259,29 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) { }); auto a = builder.ConstantR3FromArray3D(a_array); auto b = builder.ConstantR3FromArray3D(b_array); - auto concatenated = builder.ConcatInDim({a, b}, 2); + builder.ConcatInDim({a, b}, 2); Array3D expected({ - {{0, 1, 6}}, {{2, 3, 7}}, {{4, 5, 8}}, + {{0, 1, 6}}, + {{2, 3, 7}}, + {{4, 5, 8}}, }); ComputeAndCompareR3(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0}); auto b = builder.ConstantR1({64.0}); auto c = builder.ConstantR1({256.0}); - auto concatenated = builder.ConcatInDim({a, b, c}, 0); + builder.ConcatInDim({a, b, c}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D a_array({ // 3x1x2 {{0, 1}}, @@ -300,35 +303,35 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) { auto a = builder.ConstantR3FromArray3D(a_array); auto b = builder.ConstantR3FromArray3D(b_array); auto c = builder.ConstantR3FromArray3D(c_array); - auto concatenated = builder.ConcatInDim({a, b, c}, 2); + builder.ConcatInDim({a, b, c}, 2); Array3D expected({ - {{0, 1, 2, 3}}, {{4, 5, 6, 7}}, {{8, 9, 10, 11}}, + {{0, 1, 2, 3}}, + {{4, 5, 6, 7}}, + {{8, 9, 10, 11}}, }); ComputeAndCompareR3(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0}); auto b = builder.ConstantR1({64.0}); auto c = builder.ConstantR1({256.0}); // concatenated = (a concat b) concat c - auto concatenated = - builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0); + builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0}); auto b = builder.ConstantR1({64.0}); auto c = builder.ConstantR1({256.0}); // concatenated = a concat (b concat c) - auto concatenated = - builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0); + builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -342,7 +345,7 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) { rhs(0, i) = i + 1024; } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2D(lhs); auto b = builder.ConstantR2FromArray2D(rhs); builder.ConcatInDim({a, b}, 0); @@ -363,7 +366,7 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) { rhs(0, i) = i + 1024; } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2D(lhs); auto b = builder.ConstantR2FromArray2D(rhs); builder.ConcatInDim({a, b}, 1); @@ -388,7 +391,7 @@ XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2D(lhs); auto b = builder.ConstantR2FromArray2D(rhs); builder.ConcatInDim({a, b}, 1); @@ -404,13 +407,13 @@ XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) { // Show that we can't concatenate with an opaques. XLA_TEST_F(ConcatTest, CannotConcatOpaques) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto opaque_shape = ShapeUtil::MakeOpaqueShape(); auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1}); auto x = builder.Parameter(0, r1f32, "x"); auto y = builder.Parameter(1, opaque_shape, "y"); - auto concatenated = builder.ConcatInDim({x, y}, 0); - StatusOr computation_status = builder.Build(); + builder.ConcatInDim({x, y}, 0); + StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT( computation_status.status().ToString(), @@ -418,23 +421,23 @@ XLA_TEST_F(ConcatTest, CannotConcatOpaques) { } XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto p0 = builder.ConstantR1({true}); auto p1 = builder.ConstantR1({false}); auto p2 = builder.ConstantR1({true}); - auto concatenated = builder.ConcatInDim({p0, p1, p2}, 0); + builder.ConcatInDim({p0, p1, p2}, 0); bool expected[] = {true, false, true}; ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a0 = builder.ConstantR1({1}); auto a1 = builder.ConstantR1({2, 3}); auto a2 = builder.ConstantR1({4, 5, 6}); auto a3 = builder.ConstantR1({7, 8, 9, 10}); - auto concatenated = builder.ConcatInDim({a0, a1, a2, a3}, 0); + builder.ConcatInDim({a0, a1, a2, a3}, 0); std::vector expected(10); std::iota(expected.begin(), expected.end(), 1); @@ -442,7 +445,7 @@ XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) { } XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D arr0(9, 17, 1); arr0.Fill(1); @@ -462,14 +465,14 @@ XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { } } - ComputationDataHandle h0; + XlaOp h0; auto p0 = CreateR3Parameter(arr0, /*parameter_number=*/0, "p0", &builder, &h0); - ComputationDataHandle h1; + XlaOp h1; auto p1 = CreateR3Parameter(arr1, /*parameter_number=*/1, "p1", &builder, &h1); - auto concatenated = builder.ConcatInDim({h0, h1}, 2); + builder.ConcatInDim({h0, h1}, 2); ComputeAndCompareR3(&builder, expected, {p0.get(), p1.get()}); } @@ -495,7 +498,7 @@ TEST_P(ConcatR2BinaryTest, DoIt) { Array2D rhs(spec.rhs_dim0, spec.rhs_dim1); rhs.FillUnique(1000); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a0 = builder.ConstantR2FromArray2D(lhs); auto a1 = builder.ConstantR2FromArray2D(rhs); builder.ConcatInDim({a0, a1}, spec.concat_dimension); @@ -521,7 +524,7 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, f32_scalar, "x"); auto y = builder.Parameter(1, f32_scalar, "y"); auto mul = builder.Mul(x, y); @@ -545,7 +548,7 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, x_literal->shape(), "x"); auto y = builder.Parameter(1, f32_scalar, "y"); auto z = builder.Parameter(2, f32_scalar, "z"); @@ -573,7 +576,7 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, x_literal->shape(), "x"); auto y = builder.Parameter(1, f32_scalar, "y"); auto z = builder.Parameter(2, f32_scalar, "y"); diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index bc821674820fb128823786d7149037fc59b22ab6..7ff6706935740c7d76ee5cd03eae292386760397 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -23,8 +24,8 @@ namespace { class ConditionalOpTest : public ClientLibraryTestBase { protected: - Computation CreateR0ConstantComputation(float value) { - ComputationBuilder builder(client_, "Constant"); + XlaComputation CreateR0ConstantComputation(float value) { + XlaBuilder builder("Constant"); builder.Parameter(0, empty_tuple_, "tuple"); builder.ConstantR0(value); auto build_status = builder.Build(); @@ -32,16 +33,16 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0IdentityComputation() { - ComputationBuilder builder(client_, "Identity"); + XlaComputation CreateR0IdentityComputation() { + XlaBuilder builder("Identity"); builder.Parameter(0, r0f32_, "x"); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); } - Computation CreateCeilComputation(const Shape& shape) { - ComputationBuilder builder(client_, "Ceil"); + XlaComputation CreateCeilComputation(const Shape& shape) { + XlaBuilder builder("Ceil"); auto param = builder.Parameter(0, shape, "param"); builder.Ceil(param); auto build_status = builder.Build(); @@ -49,16 +50,16 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0CeilComputation() { + XlaComputation CreateR0CeilComputation() { return CreateCeilComputation(r0f32_); } - Computation CreateR1CeilComputation() { + XlaComputation CreateR1CeilComputation() { return CreateCeilComputation(r1s2f32_); } - Computation CreateFloorComputation(const Shape& shape) { - ComputationBuilder builder(client_, "Floor"); + XlaComputation CreateFloorComputation(const Shape& shape) { + XlaBuilder builder("Floor"); auto param = builder.Parameter(0, shape, "param"); builder.Floor(param); auto build_status = builder.Build(); @@ -66,17 +67,17 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0FloorComputation() { + XlaComputation CreateR0FloorComputation() { return CreateFloorComputation(r0f32_); } - Computation CreateR1FloorComputation() { + XlaComputation CreateR1FloorComputation() { return CreateFloorComputation(r1s2f32_); } - Computation CreateTupleCeilComputation(const string& computation_name, - const Shape& tuple_shape) { - ComputationBuilder builder(client_, computation_name); + XlaComputation CreateTupleCeilComputation(const string& computation_name, + const Shape& tuple_shape) { + XlaBuilder builder(computation_name); auto tuple = builder.Parameter(0, tuple_shape, "tuple"); auto x = builder.GetTupleElement(tuple, 0); auto y = builder.GetTupleElement(tuple, 1); @@ -88,17 +89,17 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0TupleCeilComputation() { + XlaComputation CreateR0TupleCeilComputation() { return CreateTupleCeilComputation("CeilR0", tuple_2_r0f32_); } - Computation CreateR1TupleCeilComputation() { + XlaComputation CreateR1TupleCeilComputation() { return CreateTupleCeilComputation("CeilR1", tuple_2_r1s2f32_); } - Computation CreateTupleFloorComputation(const string& computation_name, - const Shape& tuple_shape) { - ComputationBuilder builder(client_, computation_name); + XlaComputation CreateTupleFloorComputation(const string& computation_name, + const Shape& tuple_shape) { + XlaBuilder builder(computation_name); auto tuple = builder.Parameter(0, tuple_shape, "tuple"); auto x = builder.GetTupleElement(tuple, 0); auto y = builder.GetTupleElement(tuple, 1); @@ -110,17 +111,17 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0TupleFloorComputation() { + XlaComputation CreateR0TupleFloorComputation() { return CreateTupleFloorComputation("FloorR0", tuple_2_r0f32_); } - Computation CreateR1TupleFloorComputation() { + XlaComputation CreateR1TupleFloorComputation() { return CreateTupleFloorComputation("FloorR1", tuple_2_r1s2f32_); } - Computation CreateTupleAddComputation(const string& computation_name, - const Shape& tuple_shape) { - ComputationBuilder builder(client_, computation_name); + XlaComputation CreateTupleAddComputation(const string& computation_name, + const Shape& tuple_shape) { + XlaBuilder builder(computation_name); auto tuple = builder.Parameter(0, tuple_shape, "tuple"); auto x = builder.GetTupleElement(tuple, 0); auto y = builder.GetTupleElement(tuple, 1); @@ -130,17 +131,17 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0TupleAddComputation() { + XlaComputation CreateR0TupleAddComputation() { return CreateTupleAddComputation("AddR0", tuple_2_r0f32_); } - Computation CreateR1TupleAddComputation() { + XlaComputation CreateR1TupleAddComputation() { return CreateTupleAddComputation("AddR1", tuple_2_r1s2f32_); } - Computation CreateTupleSubComputation(const string& computation_name, - const Shape& tuple_shape) { - ComputationBuilder builder(client_, computation_name); + XlaComputation CreateTupleSubComputation(const string& computation_name, + const Shape& tuple_shape) { + XlaBuilder builder(computation_name); auto tuple = builder.Parameter(0, tuple_shape, "tuple"); auto x = builder.GetTupleElement(tuple, 0); auto y = builder.GetTupleElement(tuple, 1); @@ -150,11 +151,11 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0TupleSubComputation() { + XlaComputation CreateR0TupleSubComputation() { return CreateTupleSubComputation("SubR0", tuple_2_r0f32_); } - Computation CreateR1TupleSubComputation() { + XlaComputation CreateR1TupleSubComputation() { return CreateTupleSubComputation("SubR1", tuple_2_r1s2f32_); } @@ -170,26 +171,25 @@ class ConditionalOpTest : public ClientLibraryTestBase { // Test true and false computations that do not take any parameters. XLA_TEST_F(ConditionalOpTest, Parameters0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto operands = builder.Tuple({}); auto true_computation = CreateR0ConstantComputation(56.0f); auto false_computation = CreateR0ConstantComputation(12.0f); - auto result = builder.Conditional(pred, operands, true_computation, operands, - false_computation); + builder.Conditional(pred, operands, true_computation, operands, + false_computation); ComputeAndCompareR0(&builder, 56.0f, {}, error_spec_); } // Test true and false computations that take in 1 parameter. XLA_TEST_F(ConditionalOpTest, Parameters1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand1 = builder.ConstantR0(56.0f); auto operand2 = builder.ConstantR0(12.0f); auto identity = CreateR0IdentityComputation(); - auto result = - builder.Conditional(pred, operand1, identity, operand2, identity); + builder.Conditional(pred, operand1, identity, operand2, identity); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -197,12 +197,12 @@ XLA_TEST_F(ConditionalOpTest, Parameters1) { // Test conditional with two different computations in the true and false cases // that take in different arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand1 = builder.ConstantR0(56.4f); auto operand2 = builder.ConstantR0(12.6f); - auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(), - operand2, CreateR0FloorComputation()); + builder.Conditional(pred, operand1, CreateR0CeilComputation(), operand2, + CreateR0FloorComputation()); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -210,11 +210,11 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { // Test conditional with two different computations in the true and false cases // that take in the same arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand = builder.ConstantR0(12.6f); - auto result = builder.Conditional(pred, operand, CreateR0CeilComputation(), - operand, CreateR0FloorComputation()); + builder.Conditional(pred, operand, CreateR0CeilComputation(), operand, + CreateR0FloorComputation()); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -222,12 +222,12 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { // Test conditional with the same computation in the true and false cases but // take in different arguments. XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand1 = builder.ConstantR0(56.4f); auto operand2 = builder.ConstantR0(12.6f); auto floor = CreateR0FloorComputation(); - auto result = builder.Conditional(pred, operand1, floor, operand2, floor); + builder.Conditional(pred, operand1, floor, operand2, floor); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -235,11 +235,11 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { // Test conditional with the same computation in the true and false cases that // take in the same arguments. XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand = builder.ConstantR0(12.6f); auto floor = CreateR0FloorComputation(); - auto result = builder.Conditional(pred, operand, floor, operand, floor); + builder.Conditional(pred, operand, floor, operand, floor); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -247,12 +247,12 @@ XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { // Test conditional with different instances of the same computation in the true // and false cases. XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand1 = builder.ConstantR0(56.4f); auto operand2 = builder.ConstantR0(12.6f); - auto result = builder.Conditional(pred, operand1, CreateR0FloorComputation(), - operand2, CreateR0FloorComputation()); + builder.Conditional(pred, operand1, CreateR0FloorComputation(), operand2, + CreateR0FloorComputation()); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -260,7 +260,7 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { // Test the case when a call invokes a computation that contains a conditional. XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); - ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + XlaBuilder inner_builder(TestName() + ".inner_conditional"); auto pred_cond = inner_builder.Parameter(0, r0bool, "param0"); auto true_operand = inner_builder.Parameter(1, r0f32_, "param1"); auto false_operand = inner_builder.Parameter(2, r0f32_, "param2"); @@ -268,7 +268,7 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { false_operand, CreateR0FloorComputation()); auto inner_builder_result = inner_builder.Build(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand1 = builder.ConstantR0(56.4f); auto operand2 = builder.ConstantR0(12.6f); @@ -281,14 +281,13 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { // Test true and false computations that take in 2 parameters and predicate is // true. XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto operand1 = builder.ConstantR0(56.0f); auto operand2 = builder.ConstantR0(12.0f); auto operands = builder.Tuple({operand1, operand2}); - auto result = - builder.Conditional(pred, operands, CreateR0TupleAddComputation(), - operands, CreateR0TupleSubComputation()); + builder.Conditional(pred, operands, CreateR0TupleAddComputation(), operands, + CreateR0TupleSubComputation()); ComputeAndCompareR0(&builder, 68.0f, {}, error_spec_); } @@ -296,14 +295,13 @@ XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { // Test true and false computations that take in 2 parameters and predicate is // false. XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand1 = builder.ConstantR0(56.0f); auto operand2 = builder.ConstantR0(12.0f); auto operands = builder.Tuple({operand1, operand2}); - auto result = - builder.Conditional(pred, operands, CreateR0TupleAddComputation(), - operands, CreateR0TupleSubComputation()); + builder.Conditional(pred, operands, CreateR0TupleAddComputation(), operands, + CreateR0TupleSubComputation()); ComputeAndCompareR0(&builder, 44.0f, {}, error_spec_); } @@ -311,14 +309,13 @@ XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { // Test true and false computations that take in 2 array parameters and // predicate is true. XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto operand1 = builder.ConstantR1({24.0f, 56.0f}); auto operand2 = builder.ConstantR1({10.0f, 11.0f}); auto operands = builder.Tuple({operand1, operand2}); - auto result = - builder.Conditional(pred, operands, CreateR1TupleAddComputation(), - operands, CreateR1TupleSubComputation()); + builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR1TupleSubComputation()); ComputeAndCompareR1(&builder, {34.0f, 67.0f}, {}, error_spec_); } @@ -326,21 +323,20 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { // Test true and false computations that take in 2 array parameters and // predicate is false. XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand1 = builder.ConstantR1({24.0f, 56.0f}); auto operand2 = builder.ConstantR1({10.0f, 11.0f}); auto operands = builder.Tuple({operand1, operand2}); - auto result = - builder.Conditional(pred, operands, CreateR1TupleAddComputation(), - operands, CreateR1TupleSubComputation()); + builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR1TupleSubComputation()); ComputeAndCompareR1(&builder, {14.0f, 45.0f}, {}, error_spec_); } // Test true and false computations that return a tuple of scalars. XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operands = builder.Tuple( {builder.ConstantR0(12.2f), builder.ConstantR0(25.6f)}); @@ -356,7 +352,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { // Test true and false computations that return a tuple of arrays. XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto operands = builder.Tuple({builder.ConstantR1({12.2f, 15.8f}), builder.ConstantR1({25.6f, 29.2f})}); @@ -373,7 +369,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { // Test true and false computations that return a tuple of a predicate, a // scalar, and an array. XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { - ComputationBuilder true_builder(client_, TestName() + ".true"); + XlaBuilder true_builder(TestName() + ".true"); { true_builder.Parameter(0, empty_tuple_, "tuple"); auto true_pred = true_builder.ConstantR0(true); @@ -384,7 +380,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { auto true_builder_result = true_builder.Build(); EXPECT_IS_OK(true_builder_result.status()); - ComputationBuilder false_builder(client_, TestName() + ".false"); + XlaBuilder false_builder(TestName() + ".false"); { false_builder.Parameter(0, empty_tuple_, "tuple"); auto false_pred = false_builder.ConstantR0(false); @@ -395,7 +391,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { auto false_builder_result = false_builder.Build(); EXPECT_IS_OK(false_builder_result.status()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto operands = builder.Tuple({}); builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), @@ -411,7 +407,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { // Test true and false computations that return a nested tuple. XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { - ComputationBuilder true_builder(client_, TestName() + ".true"); + XlaBuilder true_builder(TestName() + ".true"); { true_builder.Parameter(0, empty_tuple_, "tuple"); auto true_constant1 = true_builder.ConstantR0(12.2f); @@ -424,7 +420,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { auto true_builder_result = true_builder.Build(); EXPECT_IS_OK(true_builder_result.status()); - ComputationBuilder false_builder(client_, TestName() + ".false"); + XlaBuilder false_builder(TestName() + ".false"); { false_builder.Parameter(0, empty_tuple_, "tuple"); auto false_constant1 = false_builder.ConstantR0(46.6f); @@ -438,7 +434,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { auto false_builder_result = false_builder.Build(); EXPECT_IS_OK(false_builder_result.status()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operands = builder.Tuple({}); builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), @@ -460,16 +456,16 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { // params. XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle pred, operand1, operand2; + XlaOp pred, operand1, operand2; auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); auto operand1_param = CreateR0Parameter(56.3f, 1, "operand1", &builder, &operand1); auto operand2_param = CreateR0Parameter(12.7f, 2, "operand2", &builder, &operand2); - auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(), - operand2, CreateR0FloorComputation()); + builder.Conditional(pred, operand1, CreateR0CeilComputation(), operand2, + CreateR0FloorComputation()); ComputeAndCompareR0( &builder, 57.0f, @@ -480,16 +476,16 @@ XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) { // Test conditional that takes in array operands in the form of external params. XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle pred, operand1, operand2; + XlaOp pred, operand1, operand2; auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); auto operand1_param = CreateR1Parameter({24.3f, 56.7f}, 1, "operand1", &builder, &operand1); auto operand2_param = CreateR1Parameter({10.2f, 11.6f}, 2, "operand2", &builder, &operand2); - auto result = builder.Conditional(pred, operand1, CreateR1CeilComputation(), - operand2, CreateR1FloorComputation()); + builder.Conditional(pred, operand1, CreateR1CeilComputation(), operand2, + CreateR1FloorComputation()); ComputeAndCompareR1( &builder, {10.0f, 11.0f}, @@ -499,7 +495,7 @@ XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) { // Test the case where one conditional is nested within another. XLA_TEST_F(ConditionalOpTest, NestedConditionals) { - ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + XlaBuilder inner_builder(TestName() + ".inner_conditional"); { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); @@ -514,7 +510,7 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) { auto inner_builder_result = inner_builder.Build(); EXPECT_IS_OK(inner_builder_result.status()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred1 = builder.ConstantR0(true); auto pred2 = builder.ConstantR0(false); auto operand1 = builder.ConstantR0(1.1f); @@ -529,7 +525,7 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) { } XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { - ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + XlaBuilder inner_builder(TestName() + ".inner_conditional"); { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); @@ -544,7 +540,7 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { auto inner_builder_result = inner_builder.Build(); EXPECT_IS_OK(inner_builder_result.status()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred2 = builder.ConstantR0(false); auto operand1 = builder.ConstantR0(1.1f); auto operand2 = builder.ConstantR0(12.2f); @@ -556,7 +552,7 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { // Test a mismatch in the shape of the true operand and true computation. XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto operand1 = builder.ConstantR0(56.0f); auto operand2 = builder.ConstantR0(12.0f); @@ -571,5 +567,56 @@ XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { "only parameter of true_computation")); } +XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { + Shape tuple_shape = ShapeUtil::MakeTupleShape({r0f32_, r0f32_}); + XlaComputation swapper; + { + XlaBuilder builder(TestName() + ".swapper"); + auto param0 = builder.Parameter(0, tuple_shape, "sp0"); + auto x = builder.GetTupleElement(param0, 0); + auto y = builder.GetTupleElement(param0, 1); + builder.Tuple({y, x}); + swapper = builder.Build().ConsumeValueOrDie(); + } + XlaComputation forwarder; + { + XlaBuilder builder(TestName() + ".forwarder"); + auto param0 = builder.Parameter(0, tuple_shape, "fp0"); + auto x = builder.GetTupleElement(param0, 0); + auto y = builder.GetTupleElement(param0, 1); + builder.Tuple({x, y}); + forwarder = builder.Build().ConsumeValueOrDie(); + } + XlaComputation main; + { + XlaBuilder builder(TestName() + ".main"); + auto param0 = builder.Parameter(0, tuple_shape, "mp0"); + auto x = builder.GetTupleElement(param0, 0); + auto y = builder.GetTupleElement(param0, 1); + auto lt_pred = builder.Lt(x, y); + auto res = builder.Conditional(lt_pred, param0, forwarder, param0, swapper); + auto ge_pred = builder.Ge(x, y); + builder.Conditional(ge_pred, res, swapper, res, forwarder); + main = builder.Build().ConsumeValueOrDie(); + } + + auto test_swap = [&](float a, float b) { + XlaBuilder builder(TestName()); + auto x = builder.ConstantR0(a); + auto y = builder.ConstantR0(b); + auto tuple_operand = builder.Tuple({x, y}); + builder.Call(main, {tuple_operand}); + + ComputeAndCompareTuple( + &builder, + *Literal::MakeTuple({Literal::CreateR0(a).get(), + Literal::CreateR0(b).get()}), + {}, error_spec_); + }; + + test_swap(3.11f, 9.4f); + test_swap(11.24f, 5.55f); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 35aa3f6d696297efb7d95d826ed75a504a24529d..916ffadbc798ec0dd016f45b0bc4c36233455ee7 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -21,12 +21,11 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -39,7 +38,7 @@ class ConstantsTest : public ClientLibraryTestBase { }; TEST_F(ConstantsTest, ZeroCellF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1({}); ComputeAndCompareR1(&builder, {}, {}, error_spec_); @@ -48,7 +47,7 @@ TEST_F(ConstantsTest, ZeroCellF32) { TEST_F(ConstantsTest, OneCellF32) { std::vector constant = {2.0}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1(constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); @@ -57,7 +56,7 @@ TEST_F(ConstantsTest, OneCellF32) { TEST_F(ConstantsTest, OneCellS32) { std::vector constant = {2}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1(constant); ComputeAndCompareR1(&builder, constant, {}); @@ -66,7 +65,7 @@ TEST_F(ConstantsTest, OneCellS32) { TEST_F(ConstantsTest, OneCellU32) { std::vector constant = {2}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1(constant); ComputeAndCompareR1(&builder, constant, {}); @@ -75,7 +74,7 @@ TEST_F(ConstantsTest, OneCellU32) { TEST_F(ConstantsTest, EightCells) { std::vector constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1(constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); @@ -85,14 +84,14 @@ TEST_F(ConstantsTest, SixteenCells) { std::vector constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1(constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); } TEST_F(ConstantsTest, Empty_0x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR2FromArray2D(Array2D(0, 2)); ComputeAndCompareR2(&builder, Array2D(0, 2), {}, error_spec_); @@ -102,14 +101,14 @@ TEST_F(ConstantsTest, Small_2x2) { std::unique_ptr> constant = MakeLinspaceArray2D(100.0, 200.0, 2, 2); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR2FromArray2D(*constant); ComputeAndCompareR2(&builder, *constant, {}, error_spec_); } TEST_F(ConstantsTest, Empty_3x0x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto constant = builder.ConstantLiteral( *Literal::CreateR3FromArray3D(Array3D(3, 0, 2))); @@ -117,7 +116,7 @@ TEST_F(ConstantsTest, Empty_3x0x2) { } TEST_F(ConstantsTest, Small_2x2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D array3d({ // x0 x1 {{1.f, 2.f}, // y0 @@ -145,13 +144,13 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { Literal::CreateR4FromArray4D(input_array); { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantLiteral(*input_literal); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR4FromArray4D(input_array); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } @@ -159,17 +158,18 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { // TODO(b/29263943): Support tuple constants. TEST_F(ConstantsTest, DISABLED_TupleConstant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantLiteral( *Literal::MakeTuple({Literal::CreateR2({{1.0}, {2.0}}).get(), Literal::CreateR1({2.0, 42}).get()})); - std::unique_ptr result = ExecuteAndTransferOrDie(&builder, {}); + std::unique_ptr result = + ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); LiteralTestUtil::ExpectR2Near( - {{1.0}, {2.0}}, LiteralView::Create(*result, {0}), error_spec_); + {{1.0}, {2.0}}, LiteralSlice(*result, {0}), error_spec_); LiteralTestUtil::ExpectR1Near( - {2.0, 42.0}, LiteralView::Create(*result, {1}), error_spec_); + {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_); } } // namespace diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index f66e3b57bf45fbc9f8ea786146d6fffe5d55a262..722d882471a41a75c1e5e60f8c1a151b76c7e004 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -18,13 +18,15 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -34,7 +36,7 @@ namespace { class ConvertTest : public ClientLibraryTestBase { public: - explicit ConvertTest(perftools::gputools::Platform* platform = nullptr) + explicit ConvertTest(se::Platform* platform = nullptr) : ClientLibraryTestBase(platform) { mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); mutable_debug_options()->add_xla_disable_hlo_passes("inline"); @@ -42,7 +44,7 @@ class ConvertTest : public ClientLibraryTestBase { }; TEST_F(ConvertTest, ConvertR1S32ToR1S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42, 64}); builder.ConvertElementType(a, S32); @@ -51,7 +53,7 @@ TEST_F(ConvertTest, ConvertR1S32ToR1S32) { } TEST_F(ConvertTest, ConvertR1F32ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0f, 64.0f}); builder.ConvertElementType(a, F32); @@ -60,7 +62,7 @@ TEST_F(ConvertTest, ConvertR1F32ToR1F32) { } TEST_F(ConvertTest, ConvertR1S32ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42, 64}); builder.ConvertElementType(a, F32); @@ -69,7 +71,7 @@ TEST_F(ConvertTest, ConvertR1S32ToR1F32) { } TEST_F(ConvertTest, ConvertR1PREDToR1S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({true, false, true}); builder.ConvertElementType(a, S32); @@ -78,7 +80,7 @@ TEST_F(ConvertTest, ConvertR1PREDToR1S32) { } TEST_F(ConvertTest, ConvertR1PREDToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({true, false, true}); builder.ConvertElementType(a, F32); @@ -87,7 +89,7 @@ TEST_F(ConvertTest, ConvertR1PREDToR1F32) { } XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); builder.ConvertElementType(a, F32); @@ -96,7 +98,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { } TEST_F(ConvertTest, ConvertR1F32ToR1S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.6, 64.4}); builder.ConvertElementType(a, S32); @@ -105,16 +107,168 @@ TEST_F(ConvertTest, ConvertR1F32ToR1S32) { } XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { - ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({32, 64}); - builder.ConvertElementType(a, F32); + XlaBuilder builder(TestName()); + std::vector arg{ + -9223371216516022272, + -2, + -1, + -0x7FFFFFFF, + -0x80000000, + 0, + 1, + 2, + 1073742145, + 1073742656, + 0x7FFFFFFF, + 0x80000000, + 826720496944058148, + 4296062029846194332, + 0x0007FB72E4000000LL, + 0x0007FB72E4000001LL, + 0x0007FB72E6000000LL, + 0x0007FB72E7000000LL, + 0x0007FB72E7FFFFFFLL, + 0x0007FB72E8000000LL, + 0x0007FB72E8000001LL, + 0x0007FB72EA000000LL, + 0x0007FB72EB000000LL, + 0x0007FB72EBFFFFFFLL, + 0x0007FB72EC000000LL, + 0x7FFFFF0000000000LL, + 0x7FFFFF8000000000LL, + 0x7FFFFFFFFFFFFF00, + static_cast(0xFFFFFFFFFFFFFFFF), + static_cast(0x0000f234e67e0001LL), + static_cast(0x8000000000000000), + static_cast(0x8000000000000000LL), + static_cast(0x8000000000000001LL), + static_cast(0x8000008000000000LL), + static_cast(0x8000010000000000LL), + }; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, F32); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); +} - std::vector expected = {32.0, 64.0}; - ComputeAndCompareR1(&builder, expected, {}); +XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { + XlaBuilder builder(TestName()); + std::vector arg{0, 1, 0x1000, 0x7fffffff, + 0x80000000, 0x80000001, 0x80000002, 0x80000003, + 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, F32); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); +} + +XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { + XlaBuilder builder(TestName()); + std::vector arg{0.0f, 1.0f, 16777216.0f, + 16777218.0f, 2147483647.0f, 4294967040.0f}; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, U32); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); +} + +XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { + XlaBuilder builder(TestName()); + std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, S64); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); +} + +XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { + XlaBuilder builder(TestName()); + std::vector arg{0, 1, 0x1000, -1, -0x1000}; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, S64); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); +} + +XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { + XlaBuilder builder(TestName()); + // Test cases from compiler_rt library. + std::vector arg{0.0f, + 0.5f, + 0.99f, + 1.0f, + 1.5f, + 1.99f, + 2.0f, + 2.01f, + 2147483648.f, + -0.5f, + -0.99f, + -1.0f, + -1.5f, + -1.99f, + -2.0f, + -2.01f, + 9223371487098961920.f, + 9223370937343148032.f, + -9223371487098961920.f, + -9223370937343148032.f}; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, S64); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); } XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({32, 64}); builder.ConvertElementType(a, F32); @@ -123,7 +277,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { } XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({32, 64}); builder.ConvertElementType(a, S32); @@ -132,7 +286,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) { } XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({32, 64}); builder.ConvertElementType(a, U32); @@ -141,7 +295,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) { } XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({32.0f, 64.0f}); builder.ConvertElementType(a, F64); @@ -150,7 +304,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) { } XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({32.0, 64.0}); builder.ConvertElementType(a, F32); @@ -159,7 +313,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) { } TEST_F(ConvertTest, ConvertS32Extremes) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {std::numeric_limits::min(), std::numeric_limits::max()}); builder.ConvertElementType(a, F32); @@ -171,7 +325,7 @@ TEST_F(ConvertTest, ConvertS32Extremes) { } TEST_F(ConvertTest, ConvertMapToS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); b->ConvertElementType(param, S32); @@ -183,7 +337,7 @@ TEST_F(ConvertTest, ConvertMapToS32) { } TEST_F(ConvertTest, ConvertMapToF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); b->ConvertElementType(param, F32); @@ -200,7 +354,7 @@ TEST_F(ConvertTest, ConvertMapToF32) { // input -> convert -> reshape // the new convert should have the same element type as the old convert. TEST_F(ConvertTest, ConvertReshape) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR1({42}); auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); builder.ConvertElementType(reshape, F32); @@ -208,5 +362,104 @@ TEST_F(ConvertTest, ConvertReshape) { ComputeAndCompareR0(&builder, 42.0f, {}, ErrorSpec(0.0001)); } +std::vector GetInterestingF16ConversionTestCases() { + float infinity = std::numeric_limits::infinity(); + float half_min_positive_normal = + tensorflow::bit_cast(0x38800000); + float half_max_subnormal = tensorflow::bit_cast(0x387fc000); + float half_min_positive_subnormal = + tensorflow::bit_cast(0x33800000); + float half_max = 65504.0f; + + std::vector test_cases( + {-infinity, -(half_max * 2 + 1), -half_max, -42.0f, -1.0f, + -half_min_positive_subnormal, -half_max_subnormal, + -half_min_positive_normal, -0.0f, 0.0f, half_min_positive_subnormal, + half_max_subnormal, half_min_positive_normal, 1.0f, 42.0f, half_max, + (half_max * 2 + 1), infinity}); + return test_cases; +} + +XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { + std::vector test_cases = GetInterestingF16ConversionTestCases(); + std::vector input; + c_transform(test_cases, std::back_inserter(input), + [](float f) { return Eigen::half(f); }); + std::vector expected_output; + c_transform(input, std::back_inserter(expected_output), + [](Eigen::half h) { return static_cast(h); }); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr dot_lhs_handle, + client_->TransferToServer(*Literal::CreateR1(input))); + + XlaBuilder builder(TestName()); + builder.ConvertElementType( + builder.Parameter( + 0, ShapeUtil::MakeShape(F16, {static_cast(input.size())}), + "param"), + F32); + + ComputeAndCompareR1(&builder, expected_output, {dot_lhs_handle.get()}); +} + +XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { + std::vector input = GetInterestingF16ConversionTestCases(); + std::vector expected_output; + c_transform(input, std::back_inserter(expected_output), + [](float f) { return Eigen::half(f); }); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr dot_lhs_handle, + client_->TransferToServer(*Literal::CreateR1(input))); + + XlaBuilder builder(TestName()); + builder.ConvertElementType( + builder.Parameter( + 0, ShapeUtil::MakeShape(F32, {static_cast(input.size())}), + "param"), + F16); + + ComputeAndCompareR1(&builder, expected_output, {dot_lhs_handle.get()}); +} + +XLA_TEST_F(ConvertTest, ConvertC64ToC64) { + XlaBuilder builder(TestName()); + std::vector x = {{42.0f, 64.0f}}; + builder.ConvertElementType(builder.ConstantR1(x), C64); + ComputeAndCompareR1(&builder, x, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConvertTest, ConvertS64S64) { + XlaBuilder builder(TestName()); + std::vector x = {{-42, 64}}; + builder.ConvertElementType(builder.ConstantR1(x), S64); + ComputeAndCompareR1(&builder, x, {}); +} + +XLA_TEST_F(ConvertTest, ConvertU64U64) { + XlaBuilder builder(TestName()); + std::vector x = {{42, 64}}; + builder.ConvertElementType(builder.ConstantR1(x), U64); + ComputeAndCompareR1(&builder, x, {}); +} + +XLA_TEST_F(ConvertTest, ConvertU64S64) { + XlaBuilder builder(TestName()); + std::vector unsigned_x = {{42, UINT64_MAX}}; + builder.ConvertElementType(builder.ConstantR1(unsigned_x), S64); + std::vector signed_x = {{42, -1}}; + ComputeAndCompareR1(&builder, signed_x, {}); +} + +XLA_TEST_F(ConvertTest, ConvertS64U64) { + XlaBuilder builder(TestName()); + std::vector signed_x = {{42, -1, INT64_MIN}}; + builder.ConvertElementType(builder.ConstantR1(signed_x), U64); + std::vector unsigned_x = { + {42, UINT64_MAX, tensorflow::MathUtil::IPow(2, 63)}}; + ComputeAndCompareR1(&builder, unsigned_x, {}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 896b34fb6e2762c14bd9ec2bf1ba13c548d4cf60..b5a42e305987df030c15d089f5877f73bb61de1b 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,13 +34,35 @@ limitations under the License. namespace xla { namespace { +StatusOr CreateConvDimensionNumbers( + int64 input_batch, int64 input_feature, int64 input_first_spatial, + int64 input_second_spatial, int64 output_batch, int64 output_feature, + int64 output_first_spatial, int64 output_second_spatial, + int64 kernel_output_feature, int64 kernel_input_feature, + int64 kernel_first_spatial, int64 kernel_second_spatial) { + ConvolutionDimensionNumbers dimension_numbers; + dimension_numbers.set_input_batch_dimension(input_batch); + dimension_numbers.set_input_feature_dimension(input_feature); + dimension_numbers.add_input_spatial_dimensions(input_first_spatial); + dimension_numbers.add_input_spatial_dimensions(input_second_spatial); + dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature); + dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature); + dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial); + dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial); + dimension_numbers.set_output_batch_dimension(output_batch); + dimension_numbers.set_output_feature_dimension(output_feature); + dimension_numbers.add_output_spatial_dimensions(output_first_spatial); + dimension_numbers.add_output_spatial_dimensions(output_second_spatial); + TF_RETURN_IF_ERROR(XlaBuilder::Validate(dimension_numbers)); + return dimension_numbers; +} + class ConvolutionDimensionNumbersTest : public ClientLibraryTestBase {}; // Tests the convolution operation with invalid input dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3, 0, - 1, 2, 3); + CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("input are not unique")); @@ -49,8 +71,7 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { // Tests the convolution operation with invalid weight dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 0, 1, 2, 3, 0, - 2, 2, 3); + CreateConvDimensionNumbers(0, 1, 2, 3, 0, 1, 2, 3, 0, 2, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("weight are not unique")); @@ -59,8 +80,7 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) { // Tests the convolution operation with invalid output dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidOutputDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 0, 2, 2, 3, 0, - 1, 2, 3); + CreateConvDimensionNumbers(0, 1, 2, 3, 0, 2, 2, 3, 0, 1, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("output are not unique")); @@ -76,14 +96,14 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, client_->TransferToServer(*Literal::CreateR4FromArray4D(*weight_array)) .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D(*input_array); auto weight = builder.Parameter(0, ShapeUtil::MakeShape(F32, {4, 3, 1, 1}), "weight"); auto conv1 = builder.Conv(input, weight, {1, 1}, Padding::kValid); ConvolutionDimensionNumbers dim_nums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(); + XlaBuilder::CreateDefaultConvDimensionNumbers(); // Swap batch_dimension and feature_dimension. int64 old_input_batch_dim = dim_nums.input_batch_dimension(); int64 old_output_batch_dim = dim_nums.output_batch_dimension(); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 1385b437fc47fe5289c401581fab8b5278872382..346bb3a3996ee5bf662b0f74dd0c2096efbf5295 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -47,33 +47,18 @@ class ConvolutionTest : public ClientLibraryTestBase { #if XLA_TEST_BACKEND_GPU // XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial // convolution. So relax the absolute error threshold. - ErrorSpec error_spec_ = ErrorSpec(1e-2); + ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-4); #else - ErrorSpec error_spec_ = ErrorSpec(1e-4); + ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-4); #endif }; -// TODO(b/72509305): Enable half data type tests for CPU -#if (XLA_TEST_BACKEND_GPU) -using TestTypes = ::testing::Types; -#else +#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 using TestTypes = ::testing::Types; +#else +using TestTypes = ::testing::Types; #endif -template -Shape MakeShapeWrapper(tensorflow::gtl::ArraySlice dimensions); - -template <> -Shape MakeShapeWrapper(tensorflow::gtl::ArraySlice dimensions) { - return ShapeUtil::MakeShape(F32, dimensions); -} - -template <> -Shape MakeShapeWrapper( - tensorflow::gtl::ArraySlice dimensions) { - return ShapeUtil::MakeShape(F16, dimensions); -} - template class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { public: @@ -103,12 +88,12 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { ASSERT_EQ(2, arhs->width()); ASSERT_EQ(2, arhs->height()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR4FromArray4D(*alhs); auto rhs = builder.ConstantR4FromArray4D(*arhs); - auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); + builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); - ComputeAndCompare(&builder, conv, {}, error_spec_); + ComputeAndCompare(&builder, {}, error_spec_); } }; @@ -121,12 +106,12 @@ template class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); - Shape input_shape = MakeShapeWrapper({1, 1, 1, 2}); - Shape filter_shape = MakeShapeWrapper({1, 1, 1, 2}); + XlaBuilder builder(TestName()); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 1, 2}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 1, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + builder.Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 1, 2); input_data.FillWithYX(Array2D({ @@ -137,7 +122,7 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { {5.0f, 6.0f}, })); - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(input_data)), std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); @@ -152,12 +137,12 @@ template class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); - Shape input_shape = MakeShapeWrapper({1, 1, 4, 4}); - Shape filter_shape = MakeShapeWrapper({1, 1, 2, 2}); + XlaBuilder builder(TestName()); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 2, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + builder.Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D({ @@ -171,7 +156,7 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { {5.0f, 6.0f}, {7.0f, 8.0f}, })); - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(input_data)), std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); @@ -186,12 +171,12 @@ template class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); - Shape input_shape = MakeShapeWrapper({1, 1, 4, 4}); - Shape filter_shape = MakeShapeWrapper({1, 1, 2, 2}); + XlaBuilder builder(TestName()); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 2, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); + builder.Conv(input, filter, {1, 1}, Padding::kSame); Array4D input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D({ @@ -206,7 +191,7 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { {7.0f, 8.0f}, })); - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(input_data)), std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); @@ -222,12 +207,12 @@ template class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); - Shape input_shape = MakeShapeWrapper({1, 1, 4, 4}); - Shape filter_shape = MakeShapeWrapper({1, 1, 3, 3}); + XlaBuilder builder(TestName()); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 3, 3}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); + builder.Conv(input, filter, {1, 1}, Padding::kSame); Array4D input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D({{1.0f, 2.0f, 3.0f, 4.0f}, @@ -238,7 +223,7 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { filter_data.FillWithYX(Array2D( {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}})); // clang-format on - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(input_data)), std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); @@ -249,7 +234,7 @@ TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x3x3_Same, TestTypes); TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same, Types) { this->RunTest(); } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); @@ -279,10 +264,10 @@ template class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { - Shape input_shape = MakeShapeWrapper({1, 2, 5}); - Shape filter_shape = MakeShapeWrapper({1, 2, 2}); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 2, 5}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 2, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. @@ -315,7 +300,7 @@ TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes); TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); @@ -346,7 +331,7 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); @@ -380,10 +365,10 @@ template class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { - Shape input_shape = MakeShapeWrapper({1, 2, 5}); - Shape filter_shape = MakeShapeWrapper({1, 2, 2}); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 2, 5}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 2, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. @@ -417,7 +402,7 @@ TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes); TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); } XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_dims = {1, 4, 2, 3, 3}; std::vector filter_dims = {2, 2, 2, 3, 3}; Shape input_shape = ShapeUtil::MakeShape(F32, input_dims); @@ -484,11 +469,11 @@ template class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_dims = {1, 3, 3, 5}; std::vector filter_dims = {3, 3, 5, 3}; - Shape input_shape = MakeShapeWrapper(input_dims); - Shape filter_shape = MakeShapeWrapper(filter_dims); + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); { auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); @@ -552,7 +537,7 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "convolution-canonicalization"); } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29}); Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10}); @@ -566,8 +551,7 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, dnums.set_kernel_output_feature_dimension(1); dnums.set_output_batch_dimension(0); dnums.set_output_feature_dimension(1); - auto conv = builder.ConvWithGeneralDimensions(input, filter, {}, - Padding::kValid, dnums); + builder.ConvWithGeneralDimensions(input, filter, {}, Padding::kValid, dnums); Array2D param0(4, 29); param0.FillUnique(); @@ -578,7 +562,7 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, Array2D expected_result(29, 10); expected_result.Fill(0); - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(param0)), std::move(*Literal::CreateFromArray(param1))}, error_spec_); @@ -602,7 +586,7 @@ class Convolve1D1WindowTestBase protected: template void TestImpl() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); int64 input_feature = GetParam().input_feature; int64 output_feature = GetParam().output_feature; int64 batch = GetParam().batch; @@ -612,8 +596,8 @@ class Convolve1D1WindowTestBase input_feature}; std::vector filter_dims = {window_size, input_feature, output_feature}; - Shape input_shape = MakeShapeWrapper(input_dims); - Shape filter_shape = MakeShapeWrapper(filter_dims); + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); { auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); @@ -699,9 +683,7 @@ INSTANTIATE_TEST_CASE_P( #if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU) class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {}; -// TODO(b/72509305): Enable half data type tests for CPU. -XLA_TEST_P(Convolve1D1WindowTestHalf, - DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(Convolve1D1Window))) { +XLA_TEST_P(Convolve1D1WindowTestHalf, Convolve1D1Window) { TestImpl(); } @@ -719,14 +701,16 @@ INSTANTIATE_TEST_CASE_P( Convolve1DTestParam{130, 1, 1, 1, 3}, Convolve1DTestParam{64, 1, 1, 1, 1}, Convolve1DTestParam{128, 1, 1, 1, 1}, - // TODO(b/72566306): the following three tests fail on CPU - // backend due to result miscompare. +// TODO(b/72566306): The following five tests failed on CPU with unreasonable +// relative errors. Last ran on 2018-02-22. +#if XLA_TEST_BACKEND_GPU Convolve1DTestParam{139, 1, 1, 128, 1}, Convolve1DTestParam{640, 3, 3, 128, 1}, Convolve1DTestParam{900, 1, 1, 10, 1}, Convolve1DTestParam{1, 10, 10, 1, 10}, - Convolve1DTestParam{1, 10, 130, 1, 2}, Convolve1DTestParam{1, 10, 130, 1, 1}, +#endif + Convolve1DTestParam{1, 10, 130, 1, 2}, Convolve1DTestParam{1, 64, 64, 1, 10}, Convolve1DTestParam{1, 65, 65, 1, 1}, Convolve1DTestParam{1, 128, 128, 1, 1}, @@ -738,13 +722,13 @@ INSTANTIATE_TEST_CASE_P( ); #endif -TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { - ComputationBuilder builder(client_, TestName()); +XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { + XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); Shape filter_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + builder.Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 1, 2); input_data.FillWithYX(Array2D({ @@ -755,11 +739,34 @@ TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { {bfloat16(5), bfloat16(6)}, })); - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(input_data)), std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); } +// Check that GPU convs still work if the CudnnAlgorithmPicker pass is disabled. +// (We run this test on all platforms, because, what the heck.) +XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "cudnn-convolution-algorithm-picker"); + + XlaBuilder builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D input_data(1, 1, 1, 2); + input_data.FillIota(0); + Array4D filter_data(1, 1, 1, 2); + filter_data.FillIota(10); + + ComputeAndCompare(&builder, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 9c1145def8c11f1222c63adf006102887d49f00d..fea850dc135e33fe098aa755c6fdd93319cd2837 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -25,9 +25,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -52,7 +52,7 @@ class ConvolutionVariantsTest : public ClientLibraryTestBase { }; XLA_TEST_F(ConvolutionVariantsTest, Minimal) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const Array4D input_array(1, 1, 1, 1, {2}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -67,7 +67,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Minimal) { } XLA_TEST_F(ConvolutionVariantsTest, MinimalWithBatch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const Array4D input_array(5, 1, 1, 1, {1, 2, 3, 4, 5}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -82,7 +82,7 @@ XLA_TEST_F(ConvolutionVariantsTest, MinimalWithBatch) { } XLA_TEST_F(ConvolutionVariantsTest, Flat1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(2, 1, 3, 4); input_array.FillWithMultiples(1); @@ -99,7 +99,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Flat1x1) { } XLA_TEST_F(ConvolutionVariantsTest, Deep1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 2, 1, 1, {10, 1}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -114,7 +114,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Deep1x1) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 2, {1, 2}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -129,7 +129,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -144,7 +144,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -159,7 +159,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -174,7 +174,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -189,7 +189,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array( 2, 2, 2, 3, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, // plane 0 @@ -210,7 +210,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 4, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -225,7 +225,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 5, {1, 2, 3, 4, 5}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -240,7 +240,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 4, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -255,7 +255,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 5, {1, 2, 3, 4, 5}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -270,7 +270,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -285,7 +285,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) { } XLA_TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 1, {1}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -300,7 +300,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) { } XLA_TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -315,7 +315,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) { } XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -333,7 +333,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 2, 1, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -348,7 +348,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -363,7 +363,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -378,7 +378,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(64); std::iota(input_data.begin(), input_data.end(), 0.0); @@ -398,7 +398,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(16 * 1 * 1 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -419,7 +419,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); constexpr int bs = 16; constexpr int kx = 2; @@ -450,7 +450,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); constexpr int kx = 2; constexpr int ky = 2; @@ -482,7 +482,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(16, 1, 8, 8); for (int i0 = 0; i0 < 16; ++i0) { @@ -510,7 +510,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); @@ -536,7 +536,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(2 * 2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); @@ -562,7 +562,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(32 * 2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); @@ -602,7 +602,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) { } XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(16, 16, 1, 1); Array4D filter_array(16, 16, 1, 1); @@ -628,7 +628,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) { } XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 4 * 6); std::iota(input_data.begin(), input_data.end(), 0.0); @@ -640,14 +640,14 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) { builder.ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{2, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 2, 2, {3924, 4257, 5922, 6255}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -659,14 +659,14 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) { builder.ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 1, 8, {10, 2, 20, 3, 30, 4, 40, 5}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 3 * 4); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -682,8 +682,7 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { builder.ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{2, 1}, /*padding=*/{{1, 0}, {0, 0}}, /*lhs_dilation=*/{3, 2}, - /*rhs_dilation=*/{}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + /*rhs_dilation=*/{}, XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 3, 5, {204, 40, 406, 60, 608, // @@ -693,7 +692,7 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { } XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -705,14 +704,14 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) { builder.ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-1, -1}}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 1, 2, {23, 34}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -724,14 +723,14 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) { builder.ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-1, 2}}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 1, 5, {23, 34, 45, 50, 0}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -743,14 +742,14 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) { builder.ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {2, -1}}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 1, 5, {0, 1, 12, 23, 34}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -763,7 +762,7 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) { /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {3, 2}}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); // input: // [1, 2, 3, 4, 5] --dilate-> [1, 0, 2, 0, 3, 0, 4, 0, 5] @@ -775,7 +774,7 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) { ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -788,7 +787,7 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) { /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-3, -2}}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); // input: // [1, 2, 3, 4, 5] --dilate-> [1, 0, 2, 0, 3, 0, 4, 0, 5] @@ -821,7 +820,7 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) { Array4D input_array(bs, iz, iy, ix, input_data); Array4D filter_array(oz, iz, ky, kx, kernel_data); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D(input_array); auto filter = builder.ConstantR4FromArray4D(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -854,7 +853,7 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) { Array4D input_array(bs, iz, iy, ix, input_data); Array4D filter_array(oz, iz, ky, kx, kernel_data); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D(input_array); auto filter = builder.ConstantR4FromArray4D(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -887,7 +886,7 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) { Array4D input_array(bs, iz, iy, ix, input_data); Array4D filter_array(oz, iz, ky, kx, kernel_data); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D(input_array); auto filter = builder.ConstantR4FromArray4D(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -920,7 +919,7 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) { Array4D input_array(bs, iz, iy, ix, input_data); Array4D filter_array(oz, iz, ky, kx, kernel_data); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D(input_array); auto filter = builder.ConstantR4FromArray4D(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -954,7 +953,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Array4D input_array(bs, iz, iy, ix, input_data); Array4D filter_array(oz, iz, ky, kx, kernel_data); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D(input_array); auto filter = builder.ConstantR4FromArray4D(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -966,7 +965,7 @@ XLA_TEST_F(ConvolutionVariantsTest, } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -1010,7 +1009,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -1054,7 +1053,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -1095,7 +1094,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 2 * 3 * 2); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -1147,7 +1146,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { // BackwardInputConv([1,2,3], [5,6], padding_low=0, padding_high=1) XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients = builder.ConstantR4FromArray4D( Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); @@ -1166,19 +1165,18 @@ XLA_TEST_F(ConvolutionVariantsTest, // BackwardInputConv([1], [1,10,100], stride=3, padding=(2,1)) XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients = builder.ConstantR4FromArray4D( Array4D(1, 1, 1, 1, /*values=*/{1})); auto weights = builder.ConstantR4FromArray4D( Array4D(1, 1, 1, 3, /*values=*/{1, 10, 100})); auto mirrored_weights = builder.Rev(weights, {2, 3}); - builder.ConvGeneralDilated( - gradients, mirrored_weights, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {0, 3}}, - /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + builder.ConvGeneralDilated(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {0, 3}}, + /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{}, + XlaBuilder::CreateDefaultConvDimensionNumbers()); ComputeAndCompareR4(&builder, {{{{100, 0}}}}, {}, error_spec_); } @@ -1187,7 +1185,7 @@ XLA_TEST_F(ConvolutionVariantsTest, // into // BackwardInputConv([1], [1,10,100], padding=(1,1)) XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients = builder.ConstantR4FromArray4D( Array4D(1, 1, 1, 1, /*values=*/{1})); @@ -1208,7 +1206,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { // However, XLA:GPU doesn't actually fuse it because PadInsertion doesn't // support negative padding on backward convolution yet (b/32744257). XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients = builder.ConstantR4FromArray4D( Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); @@ -1224,7 +1222,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) { XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // activations: 1,2,3,4 ---pad--> 0,1,2,3,4,0,0 // gradients: 100,10,1 -dilate-> 100,0,10,0,1 @@ -1240,7 +1238,7 @@ XLA_TEST_F(ConvolutionVariantsTest, /*window_strides=*/{1, 1}, /*padding=*/{{0, 0}, {1, 2}}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); builder.Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4(&builder, {{{{24, 130, 240}}}}, {}, error_spec_); @@ -1248,7 +1246,7 @@ XLA_TEST_F(ConvolutionVariantsTest, XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingGreaterThanHighPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4 // gradients: 100,10,1 -dilate-> 100,0,10,0,1 @@ -1266,14 +1264,14 @@ XLA_TEST_F(ConvolutionVariantsTest, /*window_strides=*/{1, 1}, /*padding=*/{{0, 0}, {2, 0}}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); builder.Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4(&builder, {{{{13, 24}}}}, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4,0 // gradients: 100,10,1 -dilate-> 100,0,10,0,1 @@ -1293,14 +1291,14 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { /*window_strides=*/{1, 1}, /*padding=*/{{0, 0}, {2, 1}}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); builder.Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4(&builder, {{{{13, 24, 130}}}}, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients = builder.ConstantR3FromArray3D( Array3D(1, 1, 1, /*value=*/1)); @@ -1314,26 +1312,26 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) { } XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto activations = builder.ConstantR3FromArray3D(Array3D({{{1, 2, 3, 4}}})); auto gradients = builder.ConstantR3FromArray3D(Array3D({{{100, 10, 1}}})); - auto forward_conv = builder.ConvGeneralDilated( - activations, gradients, - /*window_strides=*/{1}, - /*padding=*/{{2, 1}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers( - /*num_spatial_dims=*/1)); + auto forward_conv = + builder.ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1}, + /*padding=*/{{2, 1}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{2}, + XlaBuilder::CreateDefaultConvDimensionNumbers( + /*num_spatial_dims=*/1)); builder.Transpose(forward_conv, {0, 1, 2}); ComputeAndCompareR3(&builder, {{{13, 24, 130}}}, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients_flat = Literal::CreateR1({1}); auto gradients_literal = @@ -1357,7 +1355,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { } XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto activations_flat = Literal::CreateR1({1, 2, 3, 4}); auto activations_literal = @@ -1378,7 +1376,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { /*window_strides=*/{1, 1, 1}, /*padding=*/{{0, 0}, {0, 0}, {2, 1}}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers( + XlaBuilder::CreateDefaultConvDimensionNumbers( /*num_spatial_dims=*/3)); builder.Transpose(forward_conv, {0, 1, 2, 3, 4}); ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index ece7c3b05e7fafa299db7f9cbf50610c8204f95e..2b3390ca98cb2922410d451c06811aa9d4ff8c0b 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -48,7 +49,7 @@ class CopyOpTest : public HloTestBase { module->AddEntryComputation(std::move(computation)); std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectEqual(literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result)); } void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); @@ -246,13 +247,13 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { Shape out_shape = ShapeUtil::MakeShapeWithLayout(F32, {0, 0}, {1, 0}); auto empty = Literal::CreateFromShape(in_shape); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto param0 = builder.Parameter(0, in_shape, "input"); auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie(); auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) .ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(*empty, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b151187c4b8f01c5b46ccadf27d2e22a7c902e98 --- /dev/null +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -0,0 +1,103 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +class TrivialCrossReplicaSumTest : public HloTestBase {}; + +// Currently the CPU and GPU backends only support CrossReplicaSum with one +// replica. But we can at least check this. + +XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { + const char* module_str = R"( + HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + + ENTRY test_computation { + p = f32[3] parameter(0) + ROOT crs = f32[3] cross-replica-sum(p), to_apply=add + })"; + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto literal = Literal::CreateR1({1, 2, 3}); + EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()})); +} + +XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { + const char* module_str = R"( + HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + + ENTRY test_computation { + p0 = f32[3] parameter(0) + p1 = f32[2] parameter(1) + ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add + })"; + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto literal0 = Literal::CreateR1({1, 2, 3}); + auto literal1 = Literal::CreateR1({10, 20}); + EXPECT_EQ( + *Literal::MakeTuple({literal0.get(), literal1.get()}), + *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()})); +} + +// On the GPU backend, constants get special handling. Someone might pass a +// constant to CRS to e.g. count the number of replicas -- we need to make sure +// it works. +XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { + const char* module_str = R"( + HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + + ENTRY test_computation { + p0 = f32[3] parameter(0) + p1 = f32[2] constant({10, 20}) + ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add + })"; + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto literal0 = Literal::CreateR1({1, 2, 3}); + auto literal1 = Literal::CreateR1({10, 20}); + EXPECT_EQ(*Literal::MakeTuple({literal0.get(), literal1.get()}), + *ExecuteAndTransfer(std::move(module), {literal0.get()})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 2d847a66b0ae7c8f09fa0cb181a4c84ea99be5b1..b43d5c9ff5d75ee0e1b3c9ceb2bc295e631ac107 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -134,9 +134,9 @@ class CustomCallClientAPITest : public ClientLibraryTestBase {}; // When using the client API, CustomCall targets can't begin with '$' -- these // are reserved for internal use. XLA_TEST_F(CustomCallClientAPITest, IllegalCustomCallTarget) { - ComputationBuilder builder(client_, TestName()); - auto call = builder.CustomCall("$illegal", /*operands=*/{}, - ShapeUtil::MakeShape(F32, {1})); + XlaBuilder builder(TestName()); + builder.CustomCall("$illegal", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {1})); StatusOr> result = Execute(&builder, /*arguments=*/{}); diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index fe5621e8dc209d6113e74030444c198716d355dc..bfe688e20d182d581c3e3b545ac2289413deef7c 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -15,10 +15,10 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -36,9 +36,8 @@ class DeallocationTest : public ClientLibraryTestBase { // Build and execute the given computation then verify the results can be // transferred from the device successfully. std::unique_ptr ExecuteAndCheckTransfer( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - Computation computation = builder->Build().ConsumeValueOrDie(); + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaComputation computation = builder->Build().ConsumeValueOrDie(); auto global_data = client_->Execute(computation, arguments, &execution_options_) .ConsumeValueOrDie(); @@ -48,7 +47,7 @@ class DeallocationTest : public ClientLibraryTestBase { }; TEST_F(DeallocationTest, DeallocateScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR0(42.0); auto global_data = ExecuteAndCheckTransfer(&builder, {}); @@ -66,7 +65,7 @@ TEST_F(DeallocationTest, DeallocateScalar) { } TEST_F(DeallocationTest, DeallocateVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); @@ -79,7 +78,7 @@ TEST_F(DeallocationTest, DeallocateVector) { } TEST_F(DeallocationTest, DeallocateEmptyVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1({}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); @@ -92,7 +91,7 @@ TEST_F(DeallocationTest, DeallocateEmptyVector) { } XLA_TEST_F(DeallocationTest, DeallocateTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Tuple({builder.ConstantR0(42.0), builder.ConstantR1({1.0, 2.0, 3.0})}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); @@ -106,7 +105,7 @@ XLA_TEST_F(DeallocationTest, DeallocateTuple) { } XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto element = builder.ConstantR0(42.0); auto inner_tuple = builder.Tuple({builder.ConstantR0(42.0), element}); builder.Tuple({element, inner_tuple, element}); @@ -121,7 +120,7 @@ XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { } XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto inner_tuple = builder.Tuple({builder.ConstantR0(42.0), builder.ConstantR1({1.0, 2.0, 3.0})}); diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index 032c06cd3c9f872f57674d3d7b5adc201c91ea77..12789fe66530fe03eb33316eda652336f29971ab 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -42,9 +42,8 @@ class DeconstructTupleTest : public ClientLibraryTestBase { // Build and execute the given computation then verify the results can be // transferred from the device successfully. std::unique_ptr ExecuteAndCheckTransfer( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - Computation computation = builder->Build().ConsumeValueOrDie(); + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaComputation computation = builder->Build().ConsumeValueOrDie(); auto global_data = client_->Execute(computation, arguments, &execution_options_) .ConsumeValueOrDie(); @@ -54,7 +53,7 @@ class DeconstructTupleTest : public ClientLibraryTestBase { }; TEST_F(DeconstructTupleTest, DeconstructTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); builder.Tuple({const1, const2}); @@ -73,7 +72,7 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) { } TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); builder.Tuple({const1, const2}); @@ -103,7 +102,7 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { } XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); builder.Tuple({const1, const2, const2, const1}); @@ -129,7 +128,7 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { } TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); builder.Tuple({const1, const2, const1}); @@ -159,7 +158,7 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { } TEST_F(DeconstructTupleTest, DeconstructNonTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); @@ -170,7 +169,7 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) { } XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = @@ -186,7 +185,7 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { } XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); builder.Tuple({builder.Tuple({const1, const2}), const1}); @@ -195,7 +194,7 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) { auto result_status = client_->DeconstructTuple(*global_data); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("deconstructing nested tuples not yet supported")); + HasSubstr("Deconstructing nested tuples is not implemented")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc index 1da7a96fe2388eabd647a72aac81bdf2ef5bb6c6..085a5105aca1c173a7cbc211aebbeb5b254b0753 100644 --- a/tensorflow/compiler/xla/tests/deep_graph_test.cc +++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" namespace xla { @@ -22,12 +23,12 @@ TEST_F(ClientLibraryTestBase, DeepGraph) { // intended to track, we need to set kDepth to 20000. // Unfortunately, setting it that high causes the test to time out. const int kDepth = 200; - ComputationBuilder b(client_, TestName()); - ComputationDataHandle x; - ComputationDataHandle y; + XlaBuilder b(TestName()); + XlaOp x; + XlaOp y; auto x_data = CreateR0Parameter(3, 0, "x", &b, &x); auto y_data = CreateR0Parameter(1, 1, "y", &b, &y); - ComputationDataHandle z = x; + XlaOp z = x; for (int i = 0; i < kDepth; ++i) { z = b.Add(z, y); } diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 6b0c04c2c083bbfce267dd92d24ef15c06186d26..0fd846cef8095a857dd7b2c12d8afdf409e2bd66 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -34,169 +34,219 @@ limitations under the License. namespace xla { namespace { -// TODO(b/34468543): use GUnit typed tests when we can do all tests on all -// backends. class DotOperationTest : public ClientLibraryTestBase { public: ErrorSpec error_spec_{0.0001, 1e-5}; - - protected: - template - void TestOneElementVectorDot(); - template - void TestVectorDot(); - template - void TestSquareMatrixDot(bool lhs_row_major = false, - bool rhs_row_major = false); - template - void TestNonsquareMatrixDot(bool lhs_row_major = false, - bool rhs_row_major = false); }; -XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - auto result = builder.Dot(lhs, rhs); - - ComputeAndCompareR0(&builder, 0.0, {}, error_spec_); +#if defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \ + defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +using TypesF16F32 = ::testing::Types; +using TypesF16F32F64 = ::testing::Types; +using TypesF16F32F64CF64 = ::testing::Types; +#elif !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \ + !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +using TypesF16F32 = ::testing::Types; +using TypesF16F32F64 = ::testing::Types; +using TypesF16F32F64CF64 = + ::testing::Types; +#elif !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \ + defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) && \ + defined(XLA_BACKEND_DOES_NOT_SUPPORT_COMPLEX) +using TypesF16F32 = ::testing::Types; +using TypesF16F32F64 = ::testing::Types; +using TypesF16F32F64CF64 = ::testing::Types; +#else +#error "Situation not handled yet" +#endif + +// Check that we can safely pass an input tuple's elements to a dot operation. +XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { + XlaBuilder builder(TestName()); + + XlaOp param; + auto param_data = CreateParameterAndTransferLiteral( + 0, + *Literal::MakeTuple({Literal::CreateR2({{1, 2}, {3, 4}}).get(), + Literal::CreateR2({{5, 6}, {7, 8}}).get()}), + "arg0", &builder, ¶m); + auto lhs = builder.GetTupleElement(param, 0); + auto rhs = builder.GetTupleElement(param, 1); + builder.Dot(lhs, rhs); + + ComputeAndCompareLiteral(&builder, + *Literal::CreateR2({{19, 22}, {43, 50}}), + {param_data.get()}); } -XLA_TEST_F(DotOperationTest, TrivialMatrixVectorDotF32) { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR2({{3.0, 4.0}}); - auto rhs = builder.ConstantR1({3.0, 4.0}); - auto result = builder.Dot(lhs, rhs); +template +class DotOperationTest_F16F32F64CF64 : public DotOperationTest {}; +TYPED_TEST_CASE(DotOperationTest_F16F32F64CF64, TypesF16F32F64CF64); - ComputeAndCompareR1(&builder, {25.0}, {}, error_spec_); -} +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ZeroElementVectorDot) { + using T = TypeParam; + XlaBuilder builder(this->TestName()); -template -void DotOperationTest::TestOneElementVectorDot() { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR1({2.0}); - auto rhs = builder.ConstantR1({3.0}); + auto lhs = builder.ConstantR1({}); + auto rhs = builder.ConstantR1({}); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR0(&builder, 6.0, {}, error_spec_); + this->template ComputeAndCompareR0(&builder, static_cast(0.0), {}, + this->error_spec_); } -XLA_TEST_F(DotOperationTest, OneElementVectorDotF32) { - TestOneElementVectorDot(); -} +template +class DotOperationTest_F16F32F64 : public DotOperationTest {}; +TYPED_TEST_CASE(DotOperationTest_F16F32F64, TypesF16F32F64); -XLA_TEST_F(DotOperationTest, OneElementVectorDotF64) { - TestOneElementVectorDot(); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { + using T = TypeParam; + XlaBuilder builder(this->TestName()); + auto lhs = builder.ConstantR2FromArray2D({{3.0f, 4.0f}}); + auto rhs = builder.ConstantFromArray({3.0f, 4.0f}); + auto result = builder.Dot(lhs, rhs); + + this->template ComputeAndCompareR1(&builder, {static_cast(25.0f)}, {}, + this->error_spec_); } -template -void DotOperationTest::TestVectorDot() { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR1({1.0, 2.5, 42.0}); - auto rhs = builder.ConstantR1({11.0, -1.0, 0.5}); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { + using T = TypeParam; + XlaBuilder builder(this->TestName()); + auto lhs = builder.ConstantR1({static_cast(2.0f)}); + auto rhs = builder.ConstantR1({static_cast(3.0f)}); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR0(&builder, 29.5, {}, error_spec_); + this->template ComputeAndCompareR0(&builder, static_cast(6.0f), {}, + this->error_spec_); } -XLA_TEST_F(DotOperationTest, VectorDotF32) { TestVectorDot(); } - -XLA_TEST_F(DotOperationTest, VectorDotF64) { TestVectorDot(); } +XLA_TYPED_TEST(DotOperationTest_F16F32F64, VectorDot) { + using T = TypeParam; + XlaBuilder builder(this->TestName()); + auto lhs = builder.ConstantFromArray({1.0f, 2.5f, 42.0f}); + auto rhs = builder.ConstantFromArray({11.0f, -1.0f, 0.5f}); + auto result = builder.Dot(lhs, rhs); -namespace { + this->template ComputeAndCompareR0(&builder, static_cast(29.5f), {}, + this->error_spec_); +} std::vector MinorToMajorForIsRowMajor(bool row_major) { return {row_major ? 1 : 0, row_major ? 0 : 1}; } -} // namespace - -XLA_TEST_F(DotOperationTest, Dot_0x2_2x0) { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { + using T = TypeParam; + XlaBuilder builder(this->TestName()); + auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); + auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR2(&builder, Array2D(0, 0), {}, error_spec_); + this->template ComputeAndCompareR2(&builder, Array2D(0, 0), {}, + this->error_spec_); } -XLA_TEST_F(DotOperationTest, Dot_0x2_2x3) { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto rhs = builder.ConstantR2({{7.0, 8.0, 9.0}, {42.0, 77.0, 101.0}}); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { + using T = TypeParam; + XlaBuilder builder(this->TestName()); + auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); + auto rhs = builder.ConstantR2FromArray2D( + {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}}); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR2(&builder, Array2D(0, 3), {}, error_spec_); + this->template ComputeAndCompareR2(&builder, Array2D(0, 3), {}, + this->error_spec_); } -XLA_TEST_F(DotOperationTest, Dot_3x2_2x0) { - ComputationBuilder builder(client_, TestName()); - auto lhs = - builder.ConstantR2({{7.0, 8.0}, {9.0, 42.0}, {77.0, 101.0}}); - auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { + using T = TypeParam; + XlaBuilder builder(this->TestName()); + auto lhs = builder.ConstantR2FromArray2D( + {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}}); + auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR2(&builder, Array2D(3, 0), {}, error_spec_); + this->template ComputeAndCompareR2(&builder, Array2D(3, 0), {}, + this->error_spec_); } -XLA_TEST_F(DotOperationTest, Dot_2x0_0x2) { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); - auto rhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { + using T = TypeParam; + XlaBuilder builder(this->TestName()); + auto lhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); + auto rhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR2(&builder, Array2D(2, 2, 0.0f), {}, - error_spec_); + this->template ComputeAndCompareR2( + &builder, Array2D(2, 2, static_cast(0.0f)), {}, this->error_spec_); } -XLA_TEST_F(DotOperationTest, FusedDot) { - ComputationBuilder builder(client_, TestName()); - auto param0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 4}), "arg0"); - auto param1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4, 1}), "arg1"); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) { + using T = TypeParam; + XlaBuilder builder(this->TestName()); + auto param0 = + builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 4}), "arg0"); + auto param1 = + builder.Parameter(1, ShapeUtil::MakeShapeWithType({4, 1}), "arg1"); auto exp0 = builder.Exp(param0); auto result = builder.Dot(exp0, param1); - auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateR2( - {{1.0, 2.0, 3.0, 4.0}, {-1.0, -2.0, -3.0, -4.0}})) - .ConsumeValueOrDie(); - auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateR2( - {{1.0}, {2.0}, {3.0}, {4.0}})) - .ConsumeValueOrDie(); - - ComputeAndCompareR2( - &builder, Array2D({{296.14560492846033}, {0.8611737683031964}}), - {lhs_handle.get(), rhs_handle.get()}, error_spec_); -} - -template -void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major, - bool rhs_row_major) { auto lhs_handle = - client_ - ->TransferToServer(*Literal::CreateR2WithLayout( - {{1.0, 2.0}, {3.0, -4.0}}, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) - .ConsumeValueOrDie(); - auto rhs_handle = - client_ - ->TransferToServer(*Literal::CreateR2WithLayout( - {{1.0, 6.0}, {7.0, -4.0}}, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) + this->client_ + ->TransferToServer(*Literal::CreateR2FromArray2D( + {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}})) .ConsumeValueOrDie(); + auto rhs_handle = this->client_ + ->TransferToServer(*Literal::CreateR2FromArray2D( + {{1.0f}, {2.0f}, {3.0f}, {4.0f}})) + .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); + if (std::is_same::value) { + this->error_spec_ = ErrorSpec{0.0001, 1e-3}; + } - Array2D expected({{15.0, -2.0}, {-25.0, 34.0}}); - ComputeAndCompareR2( - &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); + this->template ComputeAndCompareR2( + &builder, Array2D({{296.14560492846033f}, {0.8611737683031964f}}), + {lhs_handle.get(), rhs_handle.get()}, this->error_spec_); } +template +class SquareMatrixDot : public DotOperationTest { + public: + void TestImpl(bool lhs_row_major, bool rhs_row_major) { + auto lhs_handle = + client_ + ->TransferToServer(*Literal::CreateFromArrayWithLayout( + {{1.0f, 2.0f}, {3.0f, -4.0f}}, + LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(lhs_row_major)))) + .ConsumeValueOrDie(); + auto rhs_handle = + client_ + ->TransferToServer(*Literal::CreateFromArrayWithLayout( + {{1.0f, 6.0f}, {7.0f, -4.0f}}, + LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(rhs_row_major)))) + .ConsumeValueOrDie(); + XlaBuilder builder(TestName()); + auto prim_type = primitive_util::NativeToPrimitiveType(); + auto result = builder.Dot( + builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), + builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); + + Array2D expected({{15.0f, -2.0f}, {-25.0f, 34.0f}}); + ComputeAndCompareR2(&builder, expected, + {lhs_handle.get(), rhs_handle.get()}, error_spec_); + } +}; + +TYPED_TEST_CASE(SquareMatrixDot, TypesF16F32F64CF64); +XLA_TYPED_TEST(SquareMatrixDot, TypesFF) { this->TestImpl(false, false); } +XLA_TYPED_TEST(SquareMatrixDot, TypesFT) { this->TestImpl(false, true); } +XLA_TYPED_TEST(SquareMatrixDot, TypesTF) { this->TestImpl(true, false); } +XLA_TYPED_TEST(SquareMatrixDot, TypesTT) { this->TestImpl(true, true); } + struct DotTestParam { int m; int k; @@ -225,58 +275,73 @@ string PrintDotTestParam( } class ParametricDotTest : public DotOperationTest, - public ::testing::WithParamInterface {}; + public ::testing::WithParamInterface { + protected: + template + void TestImpl(); +}; -XLA_TEST_P(ParametricDotTest, TestF32) { +template +void ParametricDotTest::TestImpl() { DotTestParam param = GetParam(); - std::unique_ptr> dot_lhs_data = - MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); + std::unique_ptr> dot_lhs_data = + MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); std::unique_ptr dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout( *dot_lhs_data, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); std::unique_ptr dot_lhs_handle = client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); - std::unique_ptr> dot_rhs_data = - MakeLinspaceArray2D(0.0, 1.0, param.k, param.n); - std::unique_ptr dot_rhs_lit = Literal::CreateR2FromArray2DWithLayout( - *dot_rhs_data, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(param.dot_rhs_row_major))); + std::unique_ptr> dot_rhs_data = + MakeLinspaceArray2D(0.0, 1.0, param.k, param.n); + Layout rhs_layout = LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.dot_rhs_row_major)); + std::unique_ptr dot_rhs_lit = + Literal::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); std::unique_ptr dot_rhs_handle = client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); - std::unique_ptr> addend_data; + std::unique_ptr> addend_data; std::unique_ptr addend_lit; std::unique_ptr addend_handle; if (param.has_addend) { - addend_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.n); + addend_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.n); addend_lit = Literal::CreateR2FromArray2DWithLayout( *addend_data, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.addend_row_major))); addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); } - ComputationBuilder builder(client_, TestName()); - auto prim_type = primitive_util::NativeToPrimitiveType(); + XlaBuilder builder(TestName()); + auto prim_type = primitive_util::NativeToPrimitiveType(); auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {param.m, param.k}), + builder.Parameter(0, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.m, param.k}, + MinorToMajorForIsRowMajor(param.dot_lhs_row_major)), "dot_lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {param.k, param.n}), + builder.Parameter(1, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.k, param.n}, + MinorToMajorForIsRowMajor(param.dot_rhs_row_major)), "dot_rhs")); if (param.has_addend) { result = builder.Add( - result, - builder.Parameter( - 2, ShapeUtil::MakeShape(prim_type, {param.m, param.n}), "addend")); + result, builder.Parameter( + 2, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.m, param.n}, + MinorToMajorForIsRowMajor(param.addend_row_major)), + "addend")); } - std::unique_ptr> expected; + std::unique_ptr> expected; if (param.has_addend) { expected = ReferenceUtil::ApplyElementwise2D( - std::plus(), + std::plus(), *ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data), *addend_data); } else { @@ -287,8 +352,11 @@ XLA_TEST_P(ParametricDotTest, TestF32) { if (param.has_addend) { args.push_back(addend_handle.get()); } - - ComputeAndCompareR2(&builder, *expected, args, ErrorSpec(0.3, 3e-3)); + ErrorSpec error_spec(0.3, 3e-3); + if (std::is_same::value) { + error_spec = ErrorSpec(0.3, 5e-3); + } + ComputeAndCompareR2(&builder, *expected, args, error_spec); } std::vector CreateDotTestParameters() { @@ -305,30 +373,77 @@ std::vector CreateDotTestParameters() { } }; + add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7); + add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520); + add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520); + + return params; +} + +#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +XLA_TEST_P(ParametricDotTest, TestF16) { TestImpl(); } +#endif +XLA_TEST_P(ParametricDotTest, TestF32) { TestImpl(); } +XLA_TEST_P(ParametricDotTest, TestF64) { TestImpl(); } + +INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest, + ::testing::ValuesIn(CreateDotTestParameters()), + PrintDotTestParam); + +class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest { + public: + ParametricDotTestWithoutLayoutAssignment() { + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "layout-assignment"); + } +}; + +std::vector CreateNoLayoutAssignmentDotTestParameters() { + std::vector params; + auto add_matrix_vector_dot_test = [&](int k, int n) { - for (bool has_addend : {false, true}) { - params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, - /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true, - /*has_addend=*/has_addend, /*addend_row_major=*/true}); - if (n != 1) { - params.push_back( - {/*m=*/n, /*k=*/k, /*n=*/1, - /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true, - /*has_addend=*/has_addend, /*addend_row_major=*/true}); + for (bool lhs_row_major : {true, false}) { + for (bool rhs_row_major : {true, false}) { + for (bool has_addend : {true, false}) { + params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, + /*dot_lhs_row_major=*/lhs_row_major, + /*dot_rhs_row_major=*/rhs_row_major, + /*has_addend=*/has_addend, + /*addend_row_major=*/true}); + if (has_addend) { + params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, + /*dot_lhs_row_major=*/lhs_row_major, + /*dot_rhs_row_major=*/rhs_row_major, + /*has_addend=*/has_addend, + /*addend_row_major=*/false}); + } + if (n != 1) { + params.push_back({/*m=*/n, /*k=*/k, /*n=*/1, + /*dot_lhs_row_major=*/lhs_row_major, + /*dot_rhs_row_major=*/rhs_row_major, + /*has_addend=*/has_addend, + /*addend_row_major=*/true}); + if (has_addend) { + params.push_back({/*m=*/n, /*k=*/k, /*n=*/1, + /*dot_lhs_row_major=*/lhs_row_major, + /*dot_rhs_row_major=*/rhs_row_major, + /*has_addend=*/has_addend, + /*addend_row_major=*/false}); + } + } + } } } }; - add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7); - add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520); - add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520); - add_matrix_vector_dot_test(/*k=*/8, /*n=*/8); add_matrix_vector_dot_test(/*k=*/130, /*n=*/8); add_matrix_vector_dot_test(/*k=*/8, /*n=*/130); add_matrix_vector_dot_test(/*k=*/290, /*n=*/130); add_matrix_vector_dot_test(/*k=*/1, /*n=*/1); add_matrix_vector_dot_test(/*k=*/1, /*n=*/16); + add_matrix_vector_dot_test(/*k=*/1, /*n=*/4); + add_matrix_vector_dot_test(/*k=*/1, /*n=*/3); add_matrix_vector_dot_test(/*k=*/3, /*n=*/16); add_matrix_vector_dot_test(/*k=*/3, /*n=*/3); add_matrix_vector_dot_test(/*k=*/29, /*n=*/29); @@ -339,109 +454,60 @@ std::vector CreateDotTestParameters() { return params; } -INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest, - ::testing::ValuesIn(CreateDotTestParameters()), - PrintDotTestParam); - -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { - TestSquareMatrixDot(false, false); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) { - TestSquareMatrixDot(false, true); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) { - TestSquareMatrixDot(true, false); +#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF16) { + TestImpl(); } - -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) { - TestSquareMatrixDot(true, true); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFF) { - TestSquareMatrixDot(false, false); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFT) { - TestSquareMatrixDot(false, true); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTF) { - TestSquareMatrixDot(true, false); +#endif +XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF32) { + TestImpl(); } - -XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTT) { - TestSquareMatrixDot(true, true); +XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF64) { + TestImpl(); } -XLA_TEST_F(DotOperationTest, SquareMatrixDotF64) { - TestSquareMatrixDot(); -} +INSTANTIATE_TEST_CASE_P( + DotTests, ParametricDotTestWithoutLayoutAssignment, + ::testing::ValuesIn(CreateNoLayoutAssignmentDotTestParameters()), + PrintDotTestParam); -template -void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major, - bool rhs_row_major) { - auto lhs_handle = - client_ - ->TransferToServer(*Literal::CreateR2WithLayout( - {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}}, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) - .ConsumeValueOrDie(); - auto rhs_handle = - client_ - ->TransferToServer(*Literal::CreateR2WithLayout( - {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}}, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) - .ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); - - Array2D expected({{26.0, 0.0}, {-12.0, 10.0}}); - - ComputeAndCompareR2( - &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFF) { - TestNonsquareMatrixDot(false, false); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFT) { - TestNonsquareMatrixDot(false, true); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) { - TestNonsquareMatrixDot(true, false); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) { - TestNonsquareMatrixDot(true, true); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) { - TestNonsquareMatrixDot(); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFF) { - TestNonsquareMatrixDot(false, false); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFT) { - TestNonsquareMatrixDot(false, true); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTF) { - TestNonsquareMatrixDot(true, false); -} +template +class NonsquareMatrixDot : public DotOperationTest { + public: + void TestImpl(bool lhs_row_major, bool rhs_row_major) { + auto lhs_handle = + client_ + ->TransferToServer(*Literal::CreateFromArrayWithLayout( + {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}, + LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(lhs_row_major)))) + .ConsumeValueOrDie(); + auto rhs_handle = + client_ + ->TransferToServer(*Literal::CreateFromArrayWithLayout( + {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}, + LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(rhs_row_major)))) + .ConsumeValueOrDie(); + + XlaBuilder builder(TestName()); + auto prim_type = primitive_util::NativeToPrimitiveType(); + auto result = builder.Dot( + builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), + builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); + + Array2D expected({{26.0f, 0.0f}, {-12.0f, 10.0f}}); + + ComputeAndCompareR2(&builder, expected, + {lhs_handle.get(), rhs_handle.get()}, error_spec_); + } +}; -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTT) { - TestNonsquareMatrixDot(true, true); -} +TYPED_TEST_CASE(NonsquareMatrixDot, TypesF16F32F64CF64); +XLA_TYPED_TEST(NonsquareMatrixDot, TestFF) { this->TestImpl(false, false); } +XLA_TYPED_TEST(NonsquareMatrixDot, TestFT) { this->TestImpl(false, true); } +XLA_TYPED_TEST(NonsquareMatrixDot, TestTF) { this->TestImpl(true, false); } +XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); } XLA_TEST_F(DotOperationTest, MatrixVectorC64) { auto lhs_handle = @@ -456,7 +522,7 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) { LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); auto result = builder.Dot( builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), @@ -468,31 +534,41 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) { &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); } -XLA_TEST_F(DotOperationTest, ConcurrentMatMul) { - ComputationBuilder builder(client_, TestName()); - auto matrix1 = builder.ConstantR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix2 = builder.ConstantR2({{5.0, 6.0}, {7.0, 8.0}}); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, ConcurrentMatMult) { + using T = TypeParam; + + XlaBuilder builder(this->TestName()); + auto matrix1 = builder.ConstantR2FromArray2D({{1.0f, 2.0f}, {3.0f, 4.0f}}); + auto matrix2 = builder.ConstantR2FromArray2D({{5.0f, 6.0f}, {7.0f, 8.0f}}); auto matrix12 = builder.Dot(matrix1, matrix2); auto matrix21 = builder.Dot(matrix2, matrix1); builder.Add(matrix12, matrix21); - Array2D expected({{42.0, 56.0}, {74.0, 96.0}}); - ComputeAndCompareR2(&builder, expected, {}, error_spec_); + Array2D expected({{42.0f, 56.0f}, {74.0f, 96.0f}}); + this->template ComputeAndCompareR2(&builder, expected, {}, + this->error_spec_); } +template +class DotOperationTestForBatchMatMul : public DotOperationTest {}; +TYPED_TEST_CASE(DotOperationTestForBatchMatMul, TypesF16F32F64); + // Regression test for b/32055648. The root of the graph is a kFusion of 4 // bitcasts. Although bitcasts don't map to thunks, the root should still be // sync-dependent on bitcasts' operands. -XLA_TEST_F(DotOperationTest, BatchMatMul) { - ComputationBuilder builder(client_, TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "y"); +XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { + using T = TypeParam; + XlaBuilder builder(this->TestName()); + auto x = + builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), "x"); + auto y = + builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), "y"); auto x_flat = builder.Reshape(x, {0, 1, 2, 3}, {4, 2, 2}); auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2}); // Slice batches into individual matrices and multiply them. - std::vector out_slices; + std::vector out_slices; for (int i = 0; i < 4; ++i) { // Slice off individual matrices and reshape to 2D tensors. auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); @@ -507,29 +583,42 @@ XLA_TEST_F(DotOperationTest, BatchMatMul) { auto out_flat = builder.ConcatInDim(out_slices, 0); builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); - auto x_data = client_ - ->TransferToServer(*Literal::CreateR4( - {{{{1000, 100}, {10, 1}}, {{2000, 200}, {20, 2}}}, - {{{3000, 300}, {30, 3}}, {{4000, 400}, {40, 4}}}})) - .ConsumeValueOrDie(); - auto y_data = client_ - ->TransferToServer(*Literal::CreateR4( - {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}, - {{{11, 22}, {33, 44}}, {{55, 66}, {77, 88}}}})) + auto x_data = this->client_ + ->TransferToServer(*Literal::CreateR4FromArray4D( + {{{{1000.0f, 100.0f}, {10.0f, 1.0f}}, + {{2000.0f, 200.0f}, {20.0f, 2.0f}}}, + {{{3000.0f, 300.0f}, {30.0f, 3.0f}}, + {{4000.0f, 400.0f}, {40.0f, 4.0f}}}})) .ConsumeValueOrDie(); + auto y_data = + this->client_ + ->TransferToServer(*Literal::CreateR4FromArray4D( + {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + {{{11.0f, 22.0f}, {33.0f, 44.0f}}, + {{55.0f, 66.0f}, {77.0f, 88.0f}}}})) + .ConsumeValueOrDie(); - ComputeAndCompareR4( + if (std::is_same::value) { + this->error_spec_ = ErrorSpec{0.0001, 1e-3}; + } + this->template ComputeAndCompareR4( &builder, /*expected=*/ - {{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}}, - {{{42900, 79200}, {429, 792}}, {{250800, 299200}, {2508, 2992}}}}, - {x_data.get(), y_data.get()}, error_spec_); + {{{{1300.0f, 2400.0f}, {13.0f, 24.0f}}, + {{11400.0f, 13600.0f}, {114.0f, 136.0f}}}, + {{{42900.0f, 79200.0f}, {429.0f, 792.0f}}, + {{250800.0f, 299200.0f}, {2508.0f, 2992.0f}}}}, + {x_data.get(), y_data.get()}, this->error_spec_); } -XLA_TEST_F(DotOperationTest, GeneralMatMul) { - ComputationBuilder builder(client_, TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2}), "y"); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { + using T = TypeParam; + + XlaBuilder builder(this->TestName()); + auto x = + builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2, 2}), "x"); + auto y = + builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 2, 2}), "y"); DotDimensionNumbers dnums; dnums.add_lhs_contracting_dimensions(2); @@ -539,31 +628,34 @@ XLA_TEST_F(DotOperationTest, GeneralMatMul) { auto out = builder.DotGeneral(x, y, dnums); - auto x_data = client_ - ->TransferToServer(*Literal::CreateR3( - {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}})) - .ConsumeValueOrDie(); + auto x_data = + this->client_ + ->TransferToServer(*Literal::CreateR3FromArray3D( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) + .ConsumeValueOrDie(); - auto y_data = client_ - ->TransferToServer(*Literal::CreateR3( - {{{1.0, 0.0}, {0.0, 1.0}}, {{1.0, 0.0}, {0.0, 1.0}}})) - .ConsumeValueOrDie(); + auto y_data = + this->client_ + ->TransferToServer(*Literal::CreateR3FromArray3D( + {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}})) + .ConsumeValueOrDie(); - ComputeAndCompareR3( + this->template ComputeAndCompareR3( &builder, /*expected=*/ - {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}}, - {x_data.get(), y_data.get()}, error_spec_); + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + {x_data.get(), y_data.get()}, this->error_spec_); } -TEST_F(DotOperationTest, TransposeFolding) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { + using T = TypeParam; for (bool transpose_lhs : {false, true}) { for (bool transpose_rhs : {false, true}) { for (bool row_major : {false, true}) { - std::unique_ptr> lhs( - new Array2D({{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}})); - std::unique_ptr> rhs( - new Array2D({{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}})); + std::unique_ptr> lhs( + new Array2D({{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}})); + std::unique_ptr> rhs( + new Array2D({{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}})); if (transpose_lhs) { lhs = ReferenceUtil::TransposeArray2D(*lhs); @@ -572,22 +664,20 @@ TEST_F(DotOperationTest, TransposeFolding) { rhs = ReferenceUtil::TransposeArray2D(*rhs); } auto lhs_handle = - client_ - ->TransferToServer( - *Literal::CreateR2FromArray2DWithLayout( - *lhs, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(row_major)))) + this->client_ + ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + *lhs, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); auto rhs_handle = - client_ - ->TransferToServer( - *Literal::CreateR2FromArray2DWithLayout( - *rhs, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(row_major)))) + this->client_ + ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + *rhs, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto prim_type = primitive_util::NativeToPrimitiveType(); + XlaBuilder builder(this->TestName()); + auto prim_type = primitive_util::NativeToPrimitiveType(); auto lhs_arg = builder.Parameter( 0, ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}), "lhs"); @@ -602,24 +692,27 @@ TEST_F(DotOperationTest, TransposeFolding) { } auto result = builder.Dot(lhs_arg, rhs_arg); - Array2D expected({{26.0, 0.0}, {-12.0, 10.0}}); + Array2D expected({{26.0f, 0.0f}, {-12.0f, 10.0f}}); VLOG(1) << "TestTransposeFolding " << transpose_lhs << " " << transpose_rhs << " " << row_major; - ComputeAndCompareR2(&builder, expected, - {lhs_handle.get(), rhs_handle.get()}, - error_spec_); + this->template ComputeAndCompareR2( + &builder, expected, {lhs_handle.get(), rhs_handle.get()}, + this->error_spec_); } } } } -TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstLHS) { - auto prim_type = primitive_util::NativeToPrimitiveType(); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, + DotOfConcatOptimizationWithConstLHS) { + using T = TypeParam; + auto prim_type = primitive_util::NativeToPrimitiveType(); - std::unique_ptr> constant_lhs_array(new Array2D( - {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, + {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}})); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(this->TestName()); auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs_arg_0"); @@ -630,78 +723,325 @@ TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstLHS) { auto result = builder.Dot( lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); - std::unique_ptr> arg_0_value_array( - new Array2D({{1.0, 2.0}, {3.0, 4.0}})); - std::unique_ptr> arg_1_value_array( - new Array2D({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})); - std::unique_ptr> arg_2_value_array( - new Array2D({{1.0, 2.0}})); + std::unique_ptr> arg_0_value_array( + new Array2D({{1.0f, 2.0f}, {3.0f, 4.0f}})); + std::unique_ptr> arg_1_value_array( + new Array2D({{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}})); + std::unique_ptr> arg_2_value_array(new Array2D({{1.0f, 2.0f}})); TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_0_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_1_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_2_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_2_value_array))); - Array2D expected({{53.0, 74.0}, {45.0, 66.0}}); - ComputeAndCompareR2( + Array2D expected({{53.0f, 74.0f}, {45.0f, 66.0f}}); + this->template ComputeAndCompareR2( &builder, expected, - {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); + {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, + this->error_spec_); } -TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstRHS) { - auto prim_type = primitive_util::NativeToPrimitiveType(); - - std::unique_ptr> constant_rhs_array( - new Array2D({{1.0, 2.0}, - {3.0, 4.0}, - {5.0, 6.0}, - {6.0, 5.0}, - {4.0, 3.0}, - {2.0, 1.0}})); - - ComputationBuilder builder(client_, TestName()); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, + DotOfConcatOptimizationWithConstRHS) { + using T = TypeParam; + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0f, 2.0f}, + {3.0f, 4.0f}, + {5.0f, 6.0f}, + {6.0f, 5.0f}, + {4.0f, 3.0f}, + {2.0f, 1.0f}})); + + XlaBuilder builder(this->TestName()); auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), + auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2}), "lhs_arg_0"); - auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 3}), + auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 3}), "lhs_arg_1"); - auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {2, 1}), + auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShapeWithType({2, 1}), "lhs_arg_2"); auto result = builder.Dot( builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant); - std::unique_ptr> arg_0_value_array( - new Array2D({{1.0, 2.0}, {3.0, 4.0}})); - std::unique_ptr> arg_1_value_array( - new Array2D({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); - std::unique_ptr> arg_2_value_array( - new Array2D({{1.0}, {2.0}})); + std::unique_ptr> arg_0_value_array( + new Array2D({{1.0f, 2.0f}, {3.0f, 4.0f}})); + std::unique_ptr> arg_1_value_array( + new Array2D({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}})); + std::unique_ptr> arg_2_value_array( + new Array2D({{1.0f}, {2.0f}})); TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_0_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_1_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_2_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_2_value_array))); - Array2D expected({{38.0, 36.0}, {93.0, 91.0}}); - ComputeAndCompareR2( + Array2D expected({{38.0f, 36.0f}, {93.0f, 91.0f}}); + this->template ComputeAndCompareR2( &builder, expected, - {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); + {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, + this->error_spec_); +} + +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) { + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D expected({{96.0, 105.0, 114.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D expected({{105.0}, {105.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstRHSReverseMM)))) { + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D expected({{105.0, 105.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstLHSReverseMM)))) { + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D expected({{96.0}, {105.0}, {114.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) { + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0}, + {6.0, 5.0}, + {4.0, 3.0}, + {2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D expected({{126.0, 129.0, 132.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) { + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0}, + {6.0, 5.0}, + {4.0, 3.0}, + {2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D expected({{129.0}, {129.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) { + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0, 9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D expected({{56.0, 168.0, 91.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) { + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0, 9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D expected({{168.0}, {168.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 877dc7db0eec229a7119b3627f177a33ed0d971b..49f3a10d227f2f9edfe76405ba13498fe822f8d8 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -18,9 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" @@ -36,8 +35,6 @@ limitations under the License. #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { namespace { @@ -56,9 +53,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR1Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {6, 7, 0, 1}); + void TestR1OOB() { + // Slice at dimension boundaries, but with out of bounds indices. + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {4, 5, 6, 7}); } template @@ -81,10 +78,10 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR2Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestR2OOB() { + // Slice at dimension boundaries, but with out of bounds indices. RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3}, - {{5, 6, 4}, {8, 9, 7}, {2, 3, 1}}); + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); } template @@ -109,13 +106,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR3Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestR3OOB() { + // Slice at dimension boundaries, but with out of bounds indices. RunR3( - {{{1, 2}, {3, 4}, {5, 6}}, - {{7, 8}, {9, 10}, {11, 12}}}, - {0, 2, 1}, {2, 1, 2}, - {{{6, 5}}, {{12, 11}}}); + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1}, + {2, 1, 2}, {{{5, 6}}, {{11, 12}}}); } template @@ -137,9 +132,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -163,9 +158,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -189,9 +184,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -204,21 +199,21 @@ class DynamicSliceTest : public ClientLibraryTestBase { XLA_TEST_F(DynamicSliceTest, Int32R1BF16) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } -XLA_TEST_F(DynamicSliceTest, Int32R1Wrap) { TestR1Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R1OOB) { TestR1OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } -XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } +XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } -XLA_TEST_F(DynamicSliceTest, Int32R2Wrap) { TestR2Wrap(); } -XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } +XLA_TEST_F(DynamicSliceTest, Int32R2OOB) { TestR2OOB(); } +XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } -XLA_TEST_F(DynamicSliceTest, Int32R3Wrap) { TestR3Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R3OOB) { TestR3OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } -XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } +XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R1Pred) { // Slice at dimension start. @@ -281,6 +276,15 @@ XLA_TEST_F(DynamicSliceTest, Int32R3Pred) { class DynamicUpdateSliceTest : public ClientLibraryTestBase { protected: + template + void TestR0() { + // Disable algebraic simplifier, otherwise the op will be replaced by a + // constant. + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "algsimp"); + RunR0(0, 123, {}, 123); + } + template void TestR1() { // Slice at dimension start. @@ -328,17 +332,46 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } template - void TestWrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestOOB() { + // // Slice at dimension boundaries, but with out of bounds indices. RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6}, - {10, 1, 2, 3, 4, 5, 8, 9}); + {0, 1, 2, 3, 4, 8, 9, 10}); // R2 Shape: [3, 3] RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2}, - {{1, 2, 3}, {4, 5, 6}, {11, 8, 10}}); + {{1, 2, 3}, {4, 5, 6}, {7, 10, 11}}); // R3 Shape: [2, 3, 2] RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}}, - {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 15}, {9, 10}, {11, 13}}}); + {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}}); + } + + template + void RunR0(int input_value_int, int update_value_int, + const std::vector slice_starts, int expected_value_int) { + Literal input_value = + std::move(*Literal::CreateR0(input_value_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal update_value = + std::move(*Literal::CreateR0(update_value_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal expected_value = + std::move(*Literal::CreateR0(expected_value_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + + XlaBuilder builder(TestName()); + // Initialize and transfer dynamic slice start indices parameter. + XlaOp starts; + std::unique_ptr start_data = CreateR1Parameter( + slice_starts, 0, "slice_starts", &builder, &starts); + // Build dynamic slice computation. + auto input = builder.ConstantLiteral(input_value); + auto update = builder.ConstantLiteral(update_value); + builder.DynamicUpdateSlice(input, update, starts); + // Run computation and compare against expected values. + ComputeAndCompareLiteral(&builder, expected_value, {start_data.get()}); } template @@ -359,9 +392,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -390,9 +423,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -421,9 +454,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -437,33 +470,25 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { template void RunR3Contiguous(std::vector operand_shape, int32 index, int32 size) { -#ifdef XLA_TEST_BACKEND_CPU_PARALLEL - // TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. - if (std::is_same::value) { - return; - } -#endif - const int32 kSeq = operand_shape[0]; const int32 kBatch = operand_shape[1]; const int32 kDim = operand_shape[2]; Array3D input_values(kSeq, kBatch, kDim); Array3D update_values(size, kBatch, kDim); Array3D expected_values(kSeq, kBatch, kDim); + index = std::min(std::max(0, index), kSeq - size); input_values.FillIota(static_cast(0)); T value = static_cast(10); update_values.FillIota(static_cast(value)); // TODO(b/34128753) Expected values may vary depending on backend when - // the update wraps. According to documentation, the results are technically - // implementation specific where the update is out of bounds, and hence - // we don't really know what to pass into ComputeAndCompareR3. + // the indices are out of bounds. expected_values.FillIota(static_cast(0)); for (int i = 0; i < size; i++) { for (int j = 0; j < kBatch; j++) { for (int k = 0; k < kDim; k++) { - expected_values((index + i) % kSeq, j, k) = value++; + expected_values(index + i, j, k) = value++; } } } @@ -474,13 +499,13 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } // Build dynamic slice computation. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer input parameter. - ComputationDataHandle input; + XlaOp input; std::unique_ptr input_data = CreateR3Parameter(input_values, 0, "input_values", &builder, &input); // Initialize and transfer update parameter. - ComputationDataHandle update; + XlaOp update; std::unique_ptr update_data = CreateR3Parameter( update_values, 1, "update_values", &builder, &update); auto starts = builder.ConstantR1({index, 0, 0}); @@ -500,36 +525,31 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } }; +XLA_TEST_F(DynamicUpdateSliceTest, Int32R0BF16) { TestR0(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32R0) { TestR0(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int64R0) { TestR0(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt64R0) { TestR0(); } + // TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. -XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R1BF16)) { - TestR1(); -} +XLA_TEST_F(DynamicUpdateSliceTest, Int32R1BF16) { TestR1(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1(); } -// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. -XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R2BF16)) { - TestR2(); -} +XLA_TEST_F(DynamicUpdateSliceTest, Int32R2BF16) { TestR2(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2(); } -// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. -XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R3BF16)) { - TestR3(); -} +XLA_TEST_F(DynamicUpdateSliceTest, Int32R3BF16) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3(); } -XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32WrapBF16)) { - TestWrap(); -} -XLA_TEST_F(DynamicUpdateSliceTest, Int32Wrap) { TestWrap(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int64Wrap) { TestWrap(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64Wrap) { TestWrap(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int64OOB) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt64OOB) { TestOOB(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) { // Slice at dimension start. @@ -592,37 +612,37 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) { // Tests for simple R3 case where the update is contiguous (i.e. the minor // two dimensions are not sliced). XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) { - // Single element, no wrap. + // Single element, index in-bounds std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) { - // Single element, no wrap. + // Single element, index in-bounds std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) { - // Multiple element, no wrap. + // Multiples element, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) { - // Multiple element, no wrap. + // Multiples element, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrapping) { - // Multiple element, wrapping. +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOB) { + // Multiple element, index out of bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrappingBF16) { - // Multiple element, wrapping. +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOBBF16) { + // Multiple element, index out of bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } @@ -672,7 +692,7 @@ void BM_DynamicSlice(int num_iters) { TransferManager::GetForPlatform(platform).ValueOrDie(); int device_ordinal = client->default_device_ordinal(); - ComputationBuilder builder(client, "DynamicSlice"); + XlaBuilder builder("DynamicSlice"); // Create input as a constant: shape [1, 2, 3, 4] auto input_literal = Literal::CreateR4( @@ -697,11 +717,11 @@ void BM_DynamicSlice(int num_iters) { auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *start_indices_literal, *buffer)); + executors[device_ordinal], *start_indices_literal, buffer)); std::unique_ptr executable = client - ->Compile(computation, {&buffer->on_host_shape()}, + ->Compile(computation, {&buffer.on_host_shape()}, ExecutableBuildOptions()) .ConsumeValueOrDie(); @@ -710,14 +730,14 @@ void BM_DynamicSlice(int num_iters) { options.set_allocator(&allocator); const int kWarmups = 2; for (int i = 0; i < kWarmups; ++i) { - auto result = executable->Run({buffer.get()}, options); + auto result = executable->Run({&buffer}, options); ASSERT_TRUE(result.ok()); } // Run benchmark. tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { - auto result = executable->Run({buffer.get()}, options); + auto result = executable->Run({&buffer}, options); ASSERT_TRUE(result.ok()); } } diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc index 644cbbf40f296eb2a574ae568b4f32aa3d0bd12f..a6ba6db5d3bf86de91f6fda022c46afee01281c2 100644 --- a/tensorflow/compiler/xla/tests/execution_profile_test.cc +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/platform/test.h" @@ -24,8 +25,7 @@ namespace { class ExecutionProfileTest : public ClientLibraryTestBase {}; -XLA_TEST_F(ExecutionProfileTest, - DISABLED_ON_CPU_PARALLEL(ExecuteWithExecutionProfile)) { +XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) { Shape shape = ShapeUtil::MakeShape(F32, {256, 256}); TF_ASSERT_OK_AND_ASSIGN( @@ -33,9 +33,9 @@ XLA_TEST_F(ExecutionProfileTest, client_->TransferToServer( *Literal::CreateR2F32Linspace(1e0, 1e5, 256, 256))); - ComputationBuilder b(client_, TestName() + ".add"); + XlaBuilder b(TestName() + ".add"); b.Dot(b.Parameter(0, shape, "param_0"), b.Parameter(1, shape, "param_1")); - TF_ASSERT_OK_AND_ASSIGN(Computation dot_product, b.Build()); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation dot_product, b.Build()); ExecutionProfile execution_profile; TF_ASSERT_OK_AND_ASSIGN( diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index 6fe7737de7af349dca2931b52d62dbc03b14e0b3..0a37e4d423620122f2e109343a86a964f46d778f 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -35,7 +36,7 @@ class ExhaustiveF32ElementwiseOpTest int64 input_size = end - begin; LOG(INFO) << "Checking range [" << begin << ", " << end << ")"; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr input_literal = Literal::CreateFromDimensions(F32, {input_size}); @@ -71,16 +72,14 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) { #ifdef XLA_TEST_BACKEND_CPU // TODO(b/73141998): The vectorized Log implementation gives results outside // our error spec in this range (these numbers are bitwise representations of - // floats expressed as a zero extended int64): - std::pair known_incorrect_range = {1, 8315654}; + // floats expressed as a zero extended int64). + std::pair known_incorrect_range = {1, 8388608}; #else std::pair known_incorrect_range = {0, 0}; #endif ExhaustivelyTestF32Op( - [](ComputationBuilder* builder, const ComputationDataHandle& input) { - builder->Log(input); - }, + [](XlaBuilder* builder, const XlaOp& input) { builder->Log(input); }, std::log, known_incorrect_range); } @@ -96,17 +95,13 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) { #endif ExhaustivelyTestF32Op( - [](ComputationBuilder* builder, const ComputationDataHandle& input) { - builder->Exp(input); - }, + [](XlaBuilder* builder, const XlaOp& input) { builder->Exp(input); }, std::exp, known_incorrect_range); } XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, TanhF32) { ExhaustivelyTestF32Op( - [](ComputationBuilder* builder, const ComputationDataHandle& input) { - builder->Tanh(input); - }, + [](XlaBuilder* builder, const XlaOp& input) { builder->Tanh(input); }, std::tanh, /*known_incorrect_range=*/{0, 0}); } diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc index a5f6872c46c7800b8b76a571a2546795f8814fb5..93d1c921c4a138cda55ed7338b8e3aa82518d114 100644 --- a/tensorflow/compiler/xla/tests/filecheck.cc +++ b/tensorflow/compiler/xla/tests/filecheck.cc @@ -38,7 +38,7 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, pattern_path, pattern)); // Invoke FileCheck to check whether input matches `pattern`. - const char* file_check_path_suffix = "external/llvm/FileCheck"; + const char* file_check_path_suffix = "org_tensorflow/external/llvm/FileCheck"; string file_check_path; if (const char* test_srcdir = getenv("TEST_SRCDIR")) { file_check_path = JoinPath(test_srcdir, file_check_path_suffix); @@ -66,6 +66,11 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { // the error message generated by FileCheck and the inputs. bool succeeded = (exit_status == 0); if (!succeeded) { + LOG(WARNING) << "Tried to execute FileCheck at " << file_check_path; + if (!env->FileExists(file_check_path).ok()) { + LOG(WARNING) << "NOTE: FileCheck binary does not exist!"; + } + LOG(WARNING) << "FileCheck error: " << standard_error; LOG(WARNING) << "FileCheck input was:"; XLA_LOG_LINES(tensorflow::WARNING, input); diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index e75a41acacc3aaad770f8bba78b43d8bf99b911b..71eb914a8e5eaef2e38b9e6e7d45b8a10ce1bd7a 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -41,7 +41,7 @@ class FloorCeilTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice expected, Function f) { LOG(INFO) << "input: {" << tensorflow::str_util::Join(expected, ", ") << "}"; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto c = builder.ConstantR1(input); if (f == kCeil) { builder.Ceil(c); @@ -54,7 +54,7 @@ class FloorCeilTest : public ClientLibraryTestBase { void TestR0F32(float input, float expected, Function f) { LOG(INFO) << "input: " << expected; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto c = builder.ConstantR0(input); if (f == kCeil) { builder.Ceil(c); diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc index f2aaf6621c1f0d7a7d1bc29b845859579d8e8d9d..73f029b59bc56aa6c3e86200a49fcae0fd177101 100644 --- a/tensorflow/compiler/xla/tests/fmax_test.cc +++ b/tensorflow/compiler/xla/tests/fmax_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/platform/test.h" @@ -27,7 +27,7 @@ namespace { class FmaxSimpleTest : public ClientLibraryTestBase {}; TEST_F(FmaxSimpleTest, FmaxTenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); auto y = builder.ConstantR1( diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index a292eab1d198fbf69c6dc81c780487ea46756f72..e6f79b5ac55dddfbb213a36cadbee53bc9443d9d 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -25,8 +25,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -50,8 +49,6 @@ limitations under the License. using tensorflow::gtl::ArraySlice; -namespace se = ::perftools::gputools; - namespace xla { namespace { @@ -121,9 +118,9 @@ class FusionTest : public HloTestBase { auto expected = Literal::CreateR2FromArray2D(answer_data); auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); if (primitive_util::IsFloatingPointType(prim_type)) { - LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4))); } else { - LiteralTestUtil::ExpectEqual(*expected, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); } } @@ -224,9 +221,9 @@ XLA_TEST_F(FusionTest, Test) { const4, reshape3, add2, const1, const0}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear(*Literal::CreateR2({{0.5}, {2.72}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), - ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR2({{0.5}, {2.72}}), + *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } // Test whether we emit appropriate code for parameters of fusion instructions. @@ -250,9 +247,9 @@ XLA_TEST_F(FusionTest, Parameter) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear(*Literal::CreateR2({{-1.0, 0.0, 1.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), - ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR2({{-1.0, 0.0, 1.0}}), + *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, RandomizedParallelPartition) { @@ -310,9 +307,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); + *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, ReshapeToScalar) { @@ -325,8 +322,9 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(5), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(5), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { @@ -339,9 +337,9 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { @@ -354,9 +352,9 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by1by1_) { @@ -369,8 +367,9 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(7), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__1by1by1) { @@ -383,8 +382,9 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR3({{{7}}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR3({{{7}}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__) { @@ -397,8 +397,9 @@ XLA_TEST_F(FusionTest, Reshape__) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(7), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { @@ -411,9 +412,9 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_2by3) { @@ -426,9 +427,9 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 4}, {2, 5}, {3, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_3by3) { @@ -441,9 +442,9 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reverse) { @@ -457,8 +458,9 @@ XLA_TEST_F(FusionTest, Reverse) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({3, 2, 1}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({3, 2, 1}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReverseNegate) { @@ -474,8 +476,9 @@ XLA_TEST_F(FusionTest, ReverseNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-3, -2, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-3, -2, -1}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, BroadcastNegate) { @@ -491,8 +494,9 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-1, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-1, -1}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, SliceNegate) { @@ -508,8 +512,9 @@ XLA_TEST_F(FusionTest, SliceNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-1, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-1, -3}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DynamicSliceNegate) { @@ -529,8 +534,9 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { /*instructions_to_fuse=*/{negate3, dynamic_slice2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-2, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-2, -3}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReshapeNegate) { @@ -546,8 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR2({{-1, -2}, {-3, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-1, -2}, {-3, -4}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } // TODO(b/64070202): Investigate failure. @@ -564,8 +571,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR2({{-1, -3}, {-2, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-1, -3}, {-2, -4}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } std::unique_ptr MakeReduceTestComputation() { @@ -594,8 +602,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(15), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(15), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { @@ -615,8 +624,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(-15), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(-15), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { @@ -664,9 +674,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{462, 2145}, {24871, 62491}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } // When a constant (or other op) which has multiple users is imported @@ -677,21 +687,20 @@ XLA_TEST_F(FusionTest, SharedConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0}))); + HloInstruction::CreateConstant(Literal::CreateR1({0}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(Literal::CreateR1({2}))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0)); + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0)); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1)); + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1)); auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2)); + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2)); auto add4 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3)); + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3)); hlo_module->AddEntryComputation(builder.Build()) - ->CreateFusionInstruction( - {add4, add3, add2, add1, const1}, - HloInstruction::FusionKind::kLoop); + ->CreateFusionInstruction({add4, add3, add2, add1, const1}, + HloInstruction::FusionKind::kLoop); HloComputation* entry_comp = hlo_module->entry_computation(); @@ -701,8 +710,9 @@ XLA_TEST_F(FusionTest, SharedConstant) { // fused instruction contains the constant(2), the parameter, and 4 adds EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({8}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({8}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D(HloOpcode::kAdd); } @@ -781,7 +791,7 @@ void BM_ParallelFusion(int num_iters) { const int64 param2_dim1 = 1024; // Create computation. - ComputationBuilder builder(client, "ParallelFusion"); + XlaBuilder builder("ParallelFusion"); Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1}); auto param0 = builder.Parameter(0, shape0, "param0"); Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1}); @@ -796,19 +806,19 @@ void BM_ParallelFusion(int num_iters) { // Transfer literals to device. auto param0_literal = Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); - std::unique_ptr buffer0 = + ScopedShapedBuffer buffer0 = client->LiteralToShapedBuffer(*param0_literal, device_ordinal) .ConsumeValueOrDie(); auto param1_literal = Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); - std::unique_ptr buffer1 = + ScopedShapedBuffer buffer1 = client->LiteralToShapedBuffer(*param1_literal, device_ordinal) .ConsumeValueOrDie(); auto param2_literal = Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); - std::unique_ptr buffer2 = + ScopedShapedBuffer buffer2 = client->LiteralToShapedBuffer(*param2_literal, device_ordinal) .ConsumeValueOrDie(); @@ -816,8 +826,8 @@ void BM_ParallelFusion(int num_iters) { std::unique_ptr executable = client ->Compile(computation, - {&buffer0->on_host_shape(), &buffer1->on_host_shape(), - &buffer2->on_host_shape()}, + {&buffer0.on_host_shape(), &buffer1.on_host_shape(), + &buffer2.on_host_shape()}, ExecutableBuildOptions()) .ConsumeValueOrDie(); @@ -838,8 +848,7 @@ void BM_ParallelFusion(int num_iters) { // Run some warm-up executions. const int kWarmups = 2; for (int i = 0; i < kWarmups; ++i) { - auto result = - executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options); + auto result = executable->Run({&buffer0, &buffer1, &buffer2}, options); ASSERT_TRUE(result.ok()); } @@ -852,8 +861,7 @@ void BM_ParallelFusion(int num_iters) { tensorflow::testing::UseRealTime(); tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { - auto result = - executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options); + auto result = executable->Run({&buffer0, &buffer1, &buffer2}, options); ASSERT_TRUE(result.ok()); } } diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..143ffbdeb409d91ab6d46d386aa5ff98ebc4ae10 --- /dev/null +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -0,0 +1,636 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +// NB! TODO(b/74360564): These tests do not test out of bounds behavior since +// that hasn't been specced yet. + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class GatherOperationTest : public HloTestBase { + protected: + void RunTest(const string& hlo_text, Literal* operand, + Literal* gather_indices) { + RunTest(hlo_text, {operand, gather_indices}); + } + + void RunTest(const string& hlo_text, + tensorflow::gtl::ArraySlice args) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); + } +}; + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherV1) { + const string hlo_text = R"( +HloModule TensorFlowGatherV1 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[2,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 3} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) { + const string hlo_text = R"( +HloModule TensorFlowGatherV2 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[3,2] gather(operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3, 1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) { + const string hlo_text = R"( +HloModule TensorFlowGatherMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,3,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 2}, {2, 1}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) { + const string hlo_text = R"( +HloModule TensorFlowGatherNdMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2,2] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=2, + window_bounds={1, 1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) { + const string hlo_text = R"( +HloModule TensorFlowGatherNdMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2,2] parameter(1) + ROOT gather = s32[2,1,1,2] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=2, + window_bounds={1, 1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) { + const string hlo_text = R"( +HloModule TensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1,2} +} +)"; + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) { + const string hlo_text = R"( +HloModule TensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1,2} +} +)"; + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, DynamicSlice) { + const char* hlo_text = R"( +HloModule DynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[1,1] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,1,1] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, ZeroDimBounds) { + const char* hlo_text = R"( +HloModule TensorFlowGatherV1 + +ENTRY main { + operand = s32[3,0] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[2,0] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 0} +} +)"; + std::unique_ptr operand = Literal::CreateR2({{}, {}, {}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) { + // Out of bounds indices must not crash, and the indices in range should + // produce the same values across all backends. + // + // TODO(b/74360564): Once we have a well defined semantics for OOB accesses, + // we should get rid of the mask and check that backends produce the same + // value for OOB indices too. + + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + gather = s32[6,1,1]{2,1,0} gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} + gather_reshaped = s32[6]{0} reshape(gather) + in_bounds_mask = s32[6]{0} parameter(2) + ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); + std::unique_ptr in_bounds_mask = + Literal::CreateR1({0, 1, 1, 0, 0, 1}); + + RunTest(hlo_text, + {operand.get(), gather_indices.get(), in_bounds_mask.get()}); +} + +XLA_TEST_F(GatherOperationTest, NegativeIndex) { + // Negative indices must not crash, and the indices in range should produce + // the same values across all backends. + // + // TODO(b/74360564): Once we have a well defined semantics for negative + // accesses, we should get rid of the mask and check that backends produce the + // same value for negative indices too. + + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + gather = s32[6,1,1]{2,1,0} gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} + gather_reshaped = s32[6]{0} reshape(gather) + in_bounds_mask = s32[6]{0} parameter(2) + ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR2( + {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); + std::unique_ptr in_bounds_mask = + Literal::CreateR1({0, 1, 1, 0, 0, 1}); + + RunTest(hlo_text, + {operand.get(), gather_indices.get(), in_bounds_mask.get()}); +} + +XLA_TEST_F(GatherOperationTest, OneScalarIndex) { + const char* hlo_text = R"( +HloModule OneScalarIndex + +ENTRY main { + operand = s32[2,3,2]{2,1,0} parameter(0) + index = s32[] parameter(1) + ROOT gather = s32[1,3,2]{2,1,0} gather(operand, index), + output_window_dims={0,1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0}, + index_vector_dim=0, + window_bounds={1,3,2} +} +)"; + std::unique_ptr operand = Literal::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); + std::unique_ptr gather_indices = Literal::CreateR0(1); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, ScalarResult) { + const char* hlo_text = R"( +HloModule ScalarResult + +ENTRY main { + operand = s32[4]{0} parameter(0) + index = s32[] parameter(1) + ROOT gather = s32[] gather(operand, index), + output_window_dims={}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=0, + window_bounds={1} +} +)"; + std::unique_ptr operand = Literal::CreateR1({1, 2, 3, 4}); + std::unique_ptr gather_indices = Literal::CreateR0(1); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, ZeroSizedResult) { + const string hlo_text = R"( +HloModule ZeroSizedResult + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[0] parameter(1) + ROOT gather = s32[0,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 3} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherV2 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + gather = s32[3,2] gather(operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3, 1} + one = s32[] constant(1) + one_broadcasted = s32[3,2] broadcast(one), dimensions={} + ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,3,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} + one = s32[] constant(1) + one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} + ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 2}, {2, 1}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherNdMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=2, + window_bounds={1, 1} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1,2} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, + FusedTensorFlowGatherNdNonDefaultIndexVectorDim) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1,2} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) { + const char* hlo_text = R"( +HloModule FusedDynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + gather = s32[1,1] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} + one = s32[] constant(1) + one_broadcasted = s32[1,1] broadcast(one), dimensions={} + ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) { + const string hlo_text = R"( +HloModule FusedBatchDynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,1,1] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} + one = s32[] constant(1) + one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} + ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +class GatherClientLibraryTest : public ClientLibraryTestBase {}; + +XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { + // We create this HLO, but using the XlaBuilder API. + // + // ENTRY main { + // operand = s32[3,3] parameter(0) + // indices = s32[2] parameter(1) + // ROOT gather = s32[2,3] gather(operand, indices), + // output_window_dims={1}, + // elided_window_dims={0}, + // gather_dims_to_operand_dims={0}, + // index_vector_dim=1, + // window_bounds={1, 3} + // } + + XlaBuilder builder("gather_basic"); + + Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3}); + Shape indices_shape = ShapeUtil::MakeShape(S32, {2}); + + auto operand = builder.Parameter(0, operand_shape, "operand"); + auto indices = builder.Parameter(1, indices_shape, "indices"); + GatherDimensionNumbers dim_numbers; + dim_numbers.add_output_window_dims(1); + dim_numbers.add_elided_window_dims(0); + dim_numbers.add_gather_dims_to_operand_dims(0); + dim_numbers.set_index_vector_dim(1); + builder.Gather(operand, indices, dim_numbers, {1, 3}); + + std::vector expected = {}; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr operand_arg, + client_->TransferToServer(*Literal::CreateR2( + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr indices_arg, + client_->TransferToServer(*Literal::CreateR1({0, 2}))); + TF_ASSERT_OK_AND_ASSIGN(std::vector devices, + client_->GetDeviceHandles(1)); + xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); + *execution_options.add_device_handles() = devices[0]; + TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build()); + std::vector computation_instances = { + {computation, + {operand_arg.get(), indices_arg.get()}, + execution_options, + /*execution_profile=*/nullptr}}; + TF_ASSERT_OK_AND_ASSIGN( + std::vector> result_data, + client_->ExecuteParallel(computation_instances)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + client_->Transfer(*(result_data[0]))); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}))); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index ec2f49d43bd8cee84c6b0abe1892e8b2278eefeb..76bf47845ca045b4eede9a3b47ae5c2ce93ce577 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -16,8 +16,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -39,7 +38,7 @@ class HalfTestBase : public ClientLibraryTestBase { }; using UnaryBuildFuncTy = - std::function; + std::function; struct UnaryOpTestParam { std::function compute_func; @@ -51,8 +50,8 @@ class UnaryOpTest : public HalfTestBase, XLA_TEST_P(UnaryOpTest, Ops) { std::vector x({half(1.4), half(-2.3), half(3.2), half(-4.1)}); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle x_opnd; + XlaBuilder builder(TestName()); + XlaOp x_opnd; auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", &builder, &x_opnd); @@ -79,30 +78,21 @@ half round_imp(half value) { INSTANTIATE_TEST_CASE_P( half, UnaryOpTest, - ::testing::Values(UnaryOpTestParam{[](half x) { return abs(x); }, - &ComputationBuilder::Abs}, - UnaryOpTestParam{[](half x) { return round_imp(x); }, - &ComputationBuilder::Round}, - UnaryOpTestParam{[](half x) { return ceil(x); }, - &ComputationBuilder::Ceil}, - UnaryOpTestParam{[](half x) { return cos(x); }, - &ComputationBuilder::Cos}, - UnaryOpTestParam{[](half x) { return exp(x); }, - &ComputationBuilder::Exp}, - UnaryOpTestParam{[](half x) { return floor(x); }, - &ComputationBuilder::Floor}, - UnaryOpTestParam{[](half x) { return log(x); }, - &ComputationBuilder::Log}, - UnaryOpTestParam{[](half x) { return -x; }, - &ComputationBuilder::Neg}, - UnaryOpTestParam{[](half x) { return sign_imp(x); }, - &ComputationBuilder::Sign}, - UnaryOpTestParam{[](half x) { return sin(x); }, - &ComputationBuilder::Sin}, - UnaryOpTestParam{[](half x) { return tanh(x); }, - &ComputationBuilder::Tanh} + ::testing::Values( + UnaryOpTestParam{[](half x) { return abs(x); }, &XlaBuilder::Abs}, + UnaryOpTestParam{[](half x) { return round_imp(x); }, + &XlaBuilder::Round}, + UnaryOpTestParam{[](half x) { return ceil(x); }, &XlaBuilder::Ceil}, + UnaryOpTestParam{[](half x) { return cos(x); }, &XlaBuilder::Cos}, + UnaryOpTestParam{[](half x) { return exp(x); }, &XlaBuilder::Exp}, + UnaryOpTestParam{[](half x) { return floor(x); }, &XlaBuilder::Floor}, + UnaryOpTestParam{[](half x) { return log(x); }, &XlaBuilder::Log}, + UnaryOpTestParam{[](half x) { return -x; }, &XlaBuilder::Neg}, + UnaryOpTestParam{[](half x) { return sign_imp(x); }, &XlaBuilder::Sign}, + UnaryOpTestParam{[](half x) { return sin(x); }, &XlaBuilder::Sin}, + UnaryOpTestParam{[](half x) { return tanh(x); }, &XlaBuilder::Tanh} - )); + )); struct UnaryPredTestParam { std::function compute_func; @@ -115,8 +105,8 @@ class UnaryPredTest : public HalfTestBase, XLA_TEST_P(UnaryPredTest, Ops) { std::vector x({half(1.4), half(-2.3), half(3.2), half(-4.1)}); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle x_opnd; + XlaBuilder builder(TestName()); + XlaOp x_opnd; auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", &builder, &x_opnd); @@ -136,11 +126,11 @@ XLA_TEST_P(UnaryPredTest, Ops) { INSTANTIATE_TEST_CASE_P(half, UnaryPredTest, ::testing::Values(UnaryPredTestParam{ [](half x) { return isfinite(x); }, - &ComputationBuilder::IsFinite})); + &XlaBuilder::IsFinite})); using BinaryBuildFuncTy = std::function)>; + xla::XlaBuilder*, const xla::XlaOp& x, const xla::XlaOp& y, + tensorflow::gtl::ArraySlice)>; struct BinaryOpTestParam { std::function compute_func; @@ -153,12 +143,12 @@ class BinaryOpTest : public HalfTestBase, XLA_TEST_P(BinaryOpTest, Ops) { std::vector x({half(1.0), half(2.0), half(3.0), half(-4.0)}); std::vector y({half(0.4), half(-0.3), half(0.2), half(0.1)}); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle x_opnd; + XlaBuilder builder(TestName()); + XlaOp x_opnd; auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", &builder, &x_opnd); - ComputationDataHandle y_opnd; + XlaOp y_opnd; auto y_data = CreateR1Parameter(y, /*parameter_number=*/1, "y", &builder, &y_opnd); @@ -184,21 +174,21 @@ INSTANTIATE_TEST_CASE_P( half, BinaryOpTest, ::testing::Values( BinaryOpTestParam{[](half x, half y) { return x + y; }, - &ComputationBuilder::Add}, + &XlaBuilder::Add}, BinaryOpTestParam{[](half x, half y) { return atan2_imp(x, y); }, - &ComputationBuilder::Atan2}, + &XlaBuilder::Atan2}, BinaryOpTestParam{[](half x, half y) { return x / y; }, - &ComputationBuilder::Div}, + &XlaBuilder::Div}, BinaryOpTestParam{[](half x, half y) { return max(x, y); }, - &ComputationBuilder::Max}, + &XlaBuilder::Max}, BinaryOpTestParam{[](half x, half y) { return min(x, y); }, - &ComputationBuilder::Min}, + &XlaBuilder::Min}, BinaryOpTestParam{[](half x, half y) { return x * y; }, - &ComputationBuilder::Mul}, + &XlaBuilder::Mul}, BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, - &ComputationBuilder::Pow}, + &XlaBuilder::Pow}, BinaryOpTestParam{[](half x, half y) { return x - y; }, - &ComputationBuilder::Sub} + &XlaBuilder::Sub} )); @@ -214,12 +204,12 @@ class BinaryPredTest XLA_TEST_P(BinaryPredTest, Ops) { std::vector x({half(1.0), half(2.0), half(0.2), half(-4.0)}); std::vector y({half(0.4), half(-0.3), half(0.2), half(0.1)}); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle x_opnd; + XlaBuilder builder(TestName()); + XlaOp x_opnd; auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", &builder, &x_opnd); - ComputationDataHandle y_opnd; + XlaOp y_opnd; auto y_data = CreateR1Parameter(y, /*parameter_number=*/1, "y", &builder, &y_opnd); @@ -239,17 +229,17 @@ XLA_TEST_P(BinaryPredTest, Ops) { INSTANTIATE_TEST_CASE_P( half, BinaryPredTest, ::testing::Values(BinaryPredTestParam{[](half x, half y) { return x == y; }, - &ComputationBuilder::Eq}, + &XlaBuilder::Eq}, BinaryPredTestParam{[](half x, half y) { return x != y; }, - &ComputationBuilder::Ne}, + &XlaBuilder::Ne}, BinaryPredTestParam{[](half x, half y) { return x >= y; }, - &ComputationBuilder::Ge}, + &XlaBuilder::Ge}, BinaryPredTestParam{[](half x, half y) { return x > y; }, - &ComputationBuilder::Gt}, + &XlaBuilder::Gt}, BinaryPredTestParam{[](half x, half y) { return x <= y; }, - &ComputationBuilder::Le}, + &XlaBuilder::Le}, BinaryPredTestParam{[](half x, half y) { return x < y; }, - &ComputationBuilder::Lt} + &XlaBuilder::Lt} )); diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc index eded2077fce965ab1c729c610764afa2228ca128..cf971dd61b71ad329b20b0bb7c16166126562681 100644 --- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc +++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" @@ -30,7 +29,7 @@ class HloMetadataTest : public LocalClientTestBase { metadata_.set_op_name("my_sum_op"); } - void BuildAddComputation(ComputationBuilder* builder) { + void BuildAddComputation(XlaBuilder* builder) { auto x = builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder->Add(x, y); @@ -40,7 +39,7 @@ class HloMetadataTest : public LocalClientTestBase { }; TEST_F(HloMetadataTest, MetadataPropagation) { - ComputationBuilder builder(local_client_, "add"); + XlaBuilder builder("add"); builder.SetOpMetadata(metadata_); BuildAddComputation(&builder); builder.ClearOpMetadata(); @@ -61,7 +60,7 @@ TEST_F(HloMetadataTest, MetadataPropagation) { } TEST_F(HloMetadataTest, MetadataClearing) { - ComputationBuilder builder(local_client_, "add"); + XlaBuilder builder("add"); builder.SetOpMetadata(metadata_); // Some other pretend computation here. builder.ClearOpMetadata(); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 9f5806c5e16c30cf198027cffab5f78c315cb957..242cc5db11ff2bdf69209df7537216573d8afbf3 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -23,11 +23,11 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -35,8 +35,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { namespace { @@ -91,21 +89,19 @@ HloTestBase::HloTestBase() HloTestBase::HloTestBase(se::Platform* test_platform, se::Platform* reference_platform) : test_runner_(test_platform), reference_runner_(reference_platform) { - hlo_verifier_ = MakeUnique(); + hlo_verifier_ = MakeUnique(/*allow_mixed_precision=*/true); } /* static */ -std::unique_ptr HloTestBase::CreateNewModule() { - HloModuleConfig config; - config.set_debug_options(GetDebugOptionsForTest()); - return MakeUnique(TestName(), VersionedComputationHandle(), - config); +std::unique_ptr HloTestBase::CreateNewModule(const string& name) { + return MakeUnique(name, GetModuleConfigForTest()); } /*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); return debug_options; } @@ -115,6 +111,15 @@ StatusOr> HloTestBase::Execute( return test_runner_.Execute(std::move(module), arguments); } +std::unique_ptr HloTestBase::ExecuteNoHloPasses( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments) { + return test_runner_ + .Execute(std::move(module), arguments, + /*run_hlo_passes=*/false) + .ValueOrDie(); +} + std::unique_ptr HloTestBase::ExecuteAndTransfer( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments) { @@ -135,22 +140,15 @@ StatusOr> HloTestBase::MakeReferenceModule( "reference preprocessor must not modify the program shape"); } } - TF_RETURN_IF_ERROR(VerifyHloModule(*reference_runner_.backend().platform(), - reference_module.get())); + TF_RETURN_IF_ERROR(hlo_verifier_->Run(reference_module.get()).status()); return std::move(reference_module); } -template StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, const ArraySlice arguments, const optional& error, bool run_hlo_passes, const std::function& reference_preprocessor) { - static_assert( - std::is_same::value || - std::is_same, LiteralPtr>::value, - "The LiteralPtr type only accepts Literal* or std::unique_ptr."); - TF_RETURN_IF_ERROR( - VerifyHloModule(*test_runner_.backend().platform(), module.get())); + TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status()); TF_ASSIGN_OR_RETURN(auto reference_module, MakeReferenceModule(*module, reference_preprocessor)); @@ -165,9 +163,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( error); } -template ::testing::AssertionResult HloTestBase::RunAndCompare( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, const ArraySlice arguments, const optional& error, const std::function& reference_preprocessor) { auto result = @@ -179,9 +176,8 @@ template return result.ValueOrDie(); } -template ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, const ArraySlice arguments, const optional& error, const std::function& reference_preprocessor) { auto result = @@ -198,8 +194,14 @@ template const std::function& reference_preprocessor) { const auto& fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); - return RunAndCompare>( - std::move(module), fake_arguments, error, reference_preprocessor); + + std::vector fake_argument_ptrs; + c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const std::unique_ptr& literal) { return literal.get(); }); + + return RunAndCompare(std::move(module), fake_argument_ptrs, error, + reference_preprocessor); } ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( @@ -207,8 +209,13 @@ template const std::function& reference_preprocessor) { const auto& fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); - return RunAndCompareNoHloPasses>( - std::move(module), fake_arguments, error, reference_preprocessor); + std::vector fake_argument_ptrs; + c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const std::unique_ptr& literal) { return literal.get(); }); + + return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error, + reference_preprocessor); } ::testing::AssertionResult HloTestBase::RunAndCompare( @@ -267,6 +274,28 @@ template reference_preprocessor); } +HloComputation* HloTestBase::FindComputation(HloModule* module, + tensorflow::StringPiece name) { + auto it = c_find_if(module->computations(), + [&](HloComputation* c) { return c->name() == name; }); + if (it == module->computations().end()) { + return nullptr; + } + return *it; +} + +HloInstruction* HloTestBase::FindInstruction(HloModule* module, + tensorflow::StringPiece name) { + for (const HloComputation* c : module->computations()) { + auto it = c_find_if(c->instructions(), + [&](HloInstruction* i) { return i->name() == name; }); + if (it != c->instructions().end()) { + return *it; + } + } + return nullptr; +} + Backend& HloTestBase::backend() { return test_runner_.backend(); } /* static */ diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 4aea9fc9fd027231106e529eb16bcd43f23fbe1c..249da87f489324ed9d377cc46a15cef5a9e74192 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -44,7 +44,7 @@ namespace xla { // enables, for one, explicitly building a graph of HLO instructions to run. // // This can also be used to write text/file-based test cases. Note that the test -// target is responsible for linking the needed backends. A covenient way to do +// target is responsible for linking the needed backends. A convenient way to do // this is to make it an xla_test: it will generate test targets linking with // the respective backends, which will be used as the test backend; the // interpreter backend is already linked with hlo_test_base so it will be the @@ -66,6 +66,15 @@ namespace xla { // // For a more detailed example, see "../tests/sample_text_test.cc". class HloTestBase : public ::testing::Test { + public: + // Creates a new HLO module for a test. The module created will have + // TestName() for its name; it will also automatically populate its debug + // options from command-line flags. If you want a fresh HloModule object and + // then add HloComputations to it, it's recommended to use this method in your + // tests. + static std::unique_ptr CreateNewModule( + const string& name = TestName()); + protected: // This uses the interpreter backend as the reference backend and // automatically finds another supported backend as the test backend. If the @@ -76,36 +85,40 @@ class HloTestBase : public ::testing::Test { // If your test doesn't use interpreter as the reference backend, you can use // this constructor. Note that your test target is responsible for linking in // both needed backends. - HloTestBase(::perftools::gputools::Platform* test_platform, - ::perftools::gputools::Platform* reference_platform); + HloTestBase(se::Platform* test_platform, se::Platform* reference_platform); ~HloTestBase() override {} - // Creates a new HLO module for a test. The module created will have - // TestName() for its name; it will also automatically populate its debug - // options from command-line flags. If you want a fresh HloModule object and - // then add HloComputations to it, it's recommended to use this method in your - // tests. - static std::unique_ptr CreateNewModule(); - // Populates debug options from command-line flags and adjusts the options for // testing. It is recommended to use this when you need to pass in // DebugOptions, e.g. when creating a module from a string or a file. static DebugOptions GetDebugOptionsForTest(); + // Gets an HloModuleConfig with options appropriate for tests. + static HloModuleConfig GetModuleConfigForTest() { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + return config; + } + // Executes the given module and return the result as a Literal. StatusOr> Execute( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments); + // Same as above, except the module will be executed without running any HLO + // passes on it. + std::unique_ptr ExecuteNoHloPasses( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments); + std::unique_ptr ExecuteAndTransfer( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments); // Executes the given hlo module on two backends and compares results. // - // 'arguments': the input of the hlo module. The LiteralPtr type accepts - // Literal* or std::unique_ptr. + // 'arguments': the input of the hlo module. // // 'error': if has value, expects the results to be near (within the error // bound). Otherwise, expects the results to be equal. @@ -114,20 +127,18 @@ class HloTestBase : public ::testing::Test { // backend, but it might need to be tailored so that it is able to run on the // reference backend. Note that the program shape of the module must not be // modified. - template ::testing::AssertionResult RunAndCompare( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::ArraySlice arguments, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; // Same as above, except that the module will be executed without Hlo // optimization. - template ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::ArraySlice arguments, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; @@ -174,9 +185,13 @@ class HloTestBase : public ::testing::Test { // 'layout'. void ForceParameterLayout(HloModule* module, int64 param_no, const Layout& layout) { - ASSERT_LT(param_no, - module->mutable_entry_computation_layout()->parameter_count()); - module->mutable_entry_computation_layout() + ASSERT_LT( + param_no, + module->mutable_host_entry_computation_layout()->parameter_count()); + module->mutable_host_entry_computation_layout() + ->mutable_parameter_layout(param_no) + ->ResetLayout(layout); + module->mutable_device_entry_computation_layout() ->mutable_parameter_layout(param_no) ->ResetLayout(layout); } @@ -184,7 +199,10 @@ class HloTestBase : public ::testing::Test { // Convenience method to force the layout of the computation result in a // module. The result layout of 'module' is set to 'layout'. void ForceResultLayout(HloModule* module, const Layout& layout) { - module->mutable_entry_computation_layout() + module->mutable_host_entry_computation_layout() + ->mutable_result_layout() + ->ResetLayout(layout); + module->mutable_device_entry_computation_layout() ->mutable_result_layout() ->ResetLayout(layout); } @@ -192,11 +210,23 @@ class HloTestBase : public ::testing::Test { // Convenience method to clear the layout of the computation result in // 'module'. void ForceClearResultLayout(HloModule* module) { - module->mutable_entry_computation_layout() + module->mutable_host_entry_computation_layout() + ->mutable_result_layout() + ->Clear(); + module->mutable_device_entry_computation_layout() ->mutable_result_layout() ->Clear(); } + // Gets the computation/instruction from the given module with the given name. + // + // This is useful for tests which create HLOs from a string and then want to + // inspect a particular computation or instruction. + HloComputation* FindComputation(HloModule* module, + tensorflow::StringPiece name); + HloInstruction* FindInstruction(HloModule* module, + tensorflow::StringPiece name); + // Return an HLO verifier constructed for the test backend. HloVerifier& verifier() const { return *hlo_verifier_; } @@ -223,10 +253,9 @@ class HloTestBase : public ::testing::Test { // Runs the module on two platforms with or without running hlo passes and // compares the results. Returns whether the results are near or equal. If any // error happens before the results are computed, returns the error status. - template StatusOr<::testing::AssertionResult> RunAndCompareInternal( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::ArraySlice arguments, const tensorflow::gtl::optional& error, bool run_hlo_passes, const std::function& reference_preprocessor); }; diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index 506091ddd8d1d8e6519525bb7031f4e8b296b5fb..22c664d1426c598dbb695ff1b66ce009b0a19c00 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -40,23 +41,41 @@ void HloVerifiedTestBase::TearDown() { << "TearDown called more than once; it should be called exactly once."; tear_down_called_ = true; if (module_) { - HloVerifier verifier; - xla::StatusOr mutated = verifier.Run(module_.get()); - if (!mutated.ok()) { - ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); - } else { - EXPECT_FALSE(mutated.ValueOrDie()) - << "HloVerifier should never mutate the HloModule"; - } + VerifyModule(module_.get()); + } + for (int i = 0; i < modules_.size(); ++i) { + VerifyModule(modules_.at(i).get()); } HloTestBase::TearDown(); } +void HloVerifiedTestBase::VerifyModule(HloModule* module) { + HloVerifier verifier(/*allow_mixed_precision=*/true); + xla::StatusOr mutated = verifier.Run(module); + if (!mutated.ok()) { + ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); + } else { + EXPECT_FALSE(mutated.ValueOrDie()) + << "HloVerifier should never mutate the HloModule"; + } +} + HloModule& HloVerifiedTestBase::module() { if (!module_) { - module_ = CreateNewModule(); + module_ = HloTestBase::CreateNewModule(); } return *module_; } +HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { + modules_.emplace_back(HloTestBase::CreateNewModule()); + return modules_.back().get(); +} + +void HloVerifiedTestBase::ParseAndVerifyModule( + tensorflow::StringPiece hlo_text) { + CHECK(!module_) << "Called ParseModule when test already has a module."; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); + VerifyModule(module_.get()); +} } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index 492688bf7d682cf991cb8c09399492a0437f651b..5b59cc77f61b05092d3afb331e73932c9edc5840 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -44,6 +44,7 @@ class HloVerifiedTestBase : public HloTestBase { // Returns the default HloModule, lazily creating it if necessary via // HloTestBase::CreateNewModule(). HloModule& module(); + void ParseAndVerifyModule(tensorflow::StringPiece hlo_text); // Sets the shape-size function used during hlo verification. If this isn't // called, a default ShapeVerifier is used instead. @@ -51,10 +52,23 @@ class HloVerifiedTestBase : public HloTestBase { shape_verifier_ = std::move(shape_verifier); } + // Creates a new module for a test, and stores it in modules_ so it can be + // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent + // creation of unverified modules. + HloModule* CreateNewModule(const string& name = TestName()); + + // It is confusing to store modules created by module() and CreateNewModule() + // in different fields, but it allows us to migrate tests to + // HloVerifiedTestBase more easily, so it's a win because we can verify more + // modules. See b/80488902. private: - std::unique_ptr module_; // Lazily populated. Access via module(). + // Lazily populated. Access via module(). + std::unique_ptr module_; + // Populated by calls to CreateNewModule. + std::vector> modules_; std::unique_ptr shape_verifier_; bool tear_down_called_ = false; + static void VerifyModule(HloModule* module); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 5aa71a9261dbd414d1499f15c9b83cd63b634b49..cde1dcd9cd10c86107f495a92be42b57bf6a085b 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -15,813 +15,93 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include -#include -#include - -#include "tensorflow/compiler/xla/index_util.h" -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/compiler/xla/literal_comparison.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" namespace xla { -/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes( - const Shape& expected, const Shape& actual) { - if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { - return ::testing::AssertionFailure() - << "tupleness-mismatch! want: " << ShapeUtil::HumanString(expected) - << " got: " << ShapeUtil::HumanString(actual); - } - if (ShapeUtil::IsTuple(expected)) { - if (ShapeUtil::TupleElementCount(expected) != - ShapeUtil::TupleElementCount(actual)) { - return ::testing::AssertionFailure() - << "want tuple element count: " - << ShapeUtil::TupleElementCount(expected) - << " got tuple element count: " - << ShapeUtil::TupleElementCount(actual); - } - for (int i = 0; i < expected.tuple_shapes_size(); ++i) { - ::testing::AssertionResult result = - EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)) - << "mismatch in tuple index " << i; - if (!result) { - return result; - } - } - } else { - if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { - return ::testing::AssertionFailure() - << "want rank of: " << ShapeUtil::HumanString(expected) - << " got rank of: " << ShapeUtil::HumanString(actual); - } - if (expected.element_type() != actual.element_type()) { - return ::testing::AssertionFailure() - << PrimitiveType_Name(expected.element_type()) << " vs " - << PrimitiveType_Name(actual.element_type()); - } - if (expected.dimensions_size() != actual.dimensions_size()) { - return ::testing::AssertionFailure() - << "want dimensions_size " << expected.dimensions_size() - << " got dimensions_size " << actual.dimensions_size(); - } - for (int i = 0; i < expected.dimensions_size(); ++i) { - if (expected.dimensions(i) != actual.dimensions(i)) { - return ::testing::AssertionFailure() - << "mismatch in dimension #" << i - << " expected: " << ShapeUtil::HumanString(expected) - << " actual: " << ShapeUtil::HumanString(actual); - } - } - } - return ::testing::AssertionSuccess(); -} - -/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, - const Shape& actual) { - ASSERT_TRUE(EqualShapes(expected, actual)); -} - -/* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts( - const Shape& expected, const Shape& actual) { - ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString()); -} - namespace { -// Return a literal with all arrays of type FromNativeT converted to type -// ToNativeT in the given literal. -template -std::unique_ptr ConvertType(const Literal& literal) { - // First construct shape of the result. - Shape result_shape(literal.shape()); - ShapeUtil::ForEachMutableSubshape( - &result_shape, [](Shape* subshape, const ShapeIndex&) { - if (subshape->element_type() == - primitive_util::NativeToPrimitiveType()) { - subshape->set_element_type( - primitive_util::NativeToPrimitiveType()); - } - }); - auto result = MakeUnique(result_shape); - - // Then copy over the data from 'literal' converting FromNativeT values to - // ToNativeT values as necessary. - ShapeUtil::ForEachSubshape( - literal.shape(), - [&](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsArray(subshape)) { - if (subshape.element_type() == - primitive_util::NativeToPrimitiveType()) { - auto src = literal.data(shape_index); - auto dest = result->data(shape_index); - for (int64 i = 0; i < src.size(); ++i) { - dest[i] = static_cast(src[i]); - } - } else { - TF_CHECK_OK(result->CopyFrom(literal, - /*dest_shape_index=*/shape_index, - /*src_shape_index=*/shape_index)); - } - } - }); - return result; -} - -} // namespace - -/* static */ std::unique_ptr LiteralTestUtil::ConvertBF16ToF32( - const Literal& literal) { - return ConvertType(literal); -} - -/* static */ std::unique_ptr LiteralTestUtil::ConvertF32ToBF16( - const Literal& literal) { - return ConvertType(literal); -} - -namespace { - -string Hostname() { - char hostname[1024]; - gethostname(hostname, sizeof hostname); - hostname[sizeof hostname - 1] = 0; - return string(hostname); -} - -// Helper function for comparing a floating point type, FloatT, bitwise equal -// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT -// -- on miscompare, a nice error message is given in the AssertionFailure. -template -::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { - auto ulhs = tensorflow::bit_cast(lhs); - auto urhs = tensorflow::bit_cast(rhs); - auto lhs_double = static_cast(lhs); - auto rhs_double = static_cast(rhs); - if (ulhs != urhs) { - return ::testing::AssertionFailure() << tensorflow::strings::Printf( - "floating values are not bitwise-equal; and equality testing " - "was requested: %s=%g=%a vs %s=%g=%a", - tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs)) - .c_str(), - lhs_double, lhs_double, - tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs)) - .c_str(), - rhs_double, rhs_double); - } - return ::testing::AssertionSuccess(); -} - -// Templated comparator that specializes for float equality comparison with the -// bitwise helper above (this is the un-specialized fallback, to just use the -// default gunit implementation). -template -::testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) { - if (lhs == rhs) { +// Writes the given literal to a file in the test temporary directory. +void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) { + auto get_hostname = [] { + char hostname[1024]; + gethostname(hostname, sizeof hostname); + hostname[sizeof hostname - 1] = 0; + return string(hostname); + }; + int64 now_usec = tensorflow::Env::Default()->NowMicros(); + string filename = tensorflow::io::JoinPath( + tensorflow::testing::TmpDir(), + tensorflow::strings::Printf("tempfile-%s-%llx-%s", get_hostname().c_str(), + now_usec, name.c_str())); + TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename, + literal.ToProto())); + LOG(ERROR) << "wrote to " << name << " file: " << filename; +} + +// Callback helper that dumps literals to temporary files in the event of a +// miscomparison. +void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, + const LiteralSlice& mismatches) { + LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) << " " + << literal_comparison::ToStringTruncated(expected); + LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) << " " + << literal_comparison::ToStringTruncated(actual); + LOG(INFO) << "Dumping literals to temp files..."; + WriteLiteralToTempFile(expected, "expected"); + WriteLiteralToTempFile(actual, "actual"); + WriteLiteralToTempFile(mismatches, "mismatches"); +} + +::testing::AssertionResult StatusToAssertion(const Status& s) { + if (s.ok()) { return ::testing::AssertionSuccess(); } - ::testing::Message msg; - msg << "Expected equality of these values:"; - msg << "\n " << lhs; - msg << "\n " << rhs; - - return ::testing::AssertionFailure() << msg; -} - -// Specializations for floating types that do bitwise comparisons when equality -// comparison is requested. -template <> -::testing::AssertionResult CompareEqual(bfloat16 lhs, bfloat16 rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(float lhs, float rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(double lhs, double rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(complex64 lhs, - complex64 rhs) { - auto res = CompareEqual(lhs.real(), rhs.real()); - if (!res) { - return res; - } - return CompareEqual(lhs.imag(), rhs.imag()); -} - -// A recursive function which iterates through every index of expected and -// actual literal and compares their values elementwise. Returns true if all -// elements are equal. -template -bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, - tensorflow::gtl::MutableArraySlice multi_index, - int64 dimension) { - if (dimension == expected.shape().dimensions_size()) { - NativeT expected_value = expected.Get(multi_index); - NativeT actual_value = actual.Get(multi_index); - ::testing::AssertionResult result = - CompareEqual(expected_value, actual_value); - return result; // Defines implicit coersion to bool. - } - - bool all_match = true; - for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { - multi_index[dimension] = i; - all_match = all_match && ExpectLiteralsEqual( - expected, actual, multi_index, dimension + 1); - } - return all_match; + return ::testing::AssertionFailure() << s.error_message(); } } // namespace -/* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected, - const Literal& actual, - const string& message) { - EXPECT_TRUE(Equal(expected, actual)) - << "expected:\n" - << expected.ToString() << "\n\tvs actual:\n" - << actual.ToString() - << (message.empty() - ? "" - : tensorflow::strings::StrCat("\nmessage: ", message)); -} - -/* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected, - const Literal& actual) { - EXPECT_FALSE(Equal(expected, actual)); -} - -/* static */ ::testing::AssertionResult LiteralTestUtil::Equal( - const Literal& expected, const Literal& actual) { - VLOG(1) << "expected:"; - XLA_VLOG_LINES(1, expected.ToString()); - VLOG(1) << "actual:"; - XLA_VLOG_LINES(1, actual.ToString()); - - AssertEqualShapes(expected.shape(), actual.shape()); - std::vector multi_index(expected.shape().dimensions_size(), 0); - bool match = false; - switch (expected.shape().element_type()) { - case PRED: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case U8: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case S32: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case S64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case U32: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case U64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case BF16: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case F16: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case F32: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case F64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case C64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case TUPLE: { - bool tuple_match = true; - for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - SCOPED_TRACE(tensorflow::strings::StrCat( - "Tuple index ", i, " in ", - ShapeUtil::HumanString(expected.shape()))); - - // Create LiteralViews of the expected and actual elements. - auto result = Equal(LiteralView::Create(expected, {i}), - LiteralView::Create(actual, {i})); - tuple_match = tuple_match ? !!result : false; - } - match = tuple_match; - break; - } - default: - LOG(FATAL) - << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " - << PrimitiveType_Name(expected.shape().element_type()); - } - ::testing::AssertionResult result = ::testing::AssertionSuccess(); - if (!match) { - result = ::testing::AssertionFailure() - << "expected: " << expected.ToString() - << "\nactual: " << actual.ToString(); - VLOG(1) << result.message(); - } - return result; +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes( + const Shape& expected, const Shape& actual) { + return StatusToAssertion(literal_comparison::EqualShapes(expected, actual)); } -namespace { - -// Helper class for comparing floating-point literals within an error bound. -class NearComparator { - public: - explicit NearComparator(ErrorSpec error) : error_(error) {} - - // Compares the two literals elementwise. EXPECTs each pair of elements to be - // within the error bound. Emits useful log messages and dumps literals to - // temporary files on failure. Returns true if literals match. - bool ExpectNear(const Literal& expected, const Literal& actual) { - VLOG(1) << "expected:"; - XLA_VLOG_LINES(1, TruncateHugeLiteral(expected)); - VLOG(1) << "actual:"; - XLA_VLOG_LINES(1, TruncateHugeLiteral(actual)); - - // If the shapes mismatch, we simply fail the expectation instead of - // printing out data, as it's a type error rather than a value error. - ::testing::AssertionResult equal_shapes = - LiteralTestUtil::EqualShapes(expected.shape(), actual.shape()); - if (!equal_shapes) { - EXPECT_TRUE(equal_shapes); - return false; - } - - // Set up members used during the comparison. - num_miscompares_ = 0; - abs_diff_sum_ = 0.0; - abs_expected_sum_ = 0.0; - abs_diff_miscompare_sum_ = 0.0; - abs_expected_miscompare_sum_ = 0.0; - max_rel_err_ = 0.0; - max_abs_err_ = 0.0; - first_linear_index_ = -1; - last_linear_index_ = -1; - max_rel_linear_index_ = -1; - max_abs_linear_index_ = -1; - miscompares_ = Literal(ShapeUtil::ChangeElementType(actual.shape(), PRED)); - miscompares_.PopulateWithValue(false); - multi_index_.resize(expected.shape().dimensions_size(), 0); - - switch (expected.shape().element_type()) { - case BF16: - ExpectLiteralsNear(expected, actual, 0); - break; - case F16: - ExpectLiteralsNear(expected, actual, 0); - break; - case F32: - ExpectLiteralsNear(expected, actual, 0); - break; - case F64: - ExpectLiteralsNear(expected, actual, 0); - break; - case C64: - ExpectLiteralsNear(expected, actual, 0); - break; - default: - LOG(FATAL) << "Unsupported primitive type in near comparator: " - << PrimitiveType_Name(expected.shape().element_type()) - << ". Must be floating-point type."; - } - - if (num_miscompares_ > 0) { - if (!VLOG_IS_ON(1)) { - LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) - << " " << TruncateHugeLiteral(expected); - LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) - << " " << TruncateHugeLiteral(actual); - LOG(INFO) << "Dumping literals to temp files..."; - WriteLiteralToTempFile(expected, "expected"); - WriteLiteralToTempFile(actual, "actual"); - WriteLiteralToTempFile(miscompares_, "miscompares"); - } - EXPECT_TRUE(num_miscompares_ == 0) - << "\nmax relative mismatch at index " - << LiteralTestUtil::MultiIndexAsString( - IndexUtil::LinearIndexToMultidimensionalIndex( - actual.shape(), max_rel_linear_index_)) - << "\nmaximum relative error " << max_rel_err_ - << "\nmax absolute mismatch at index " - << LiteralTestUtil::MultiIndexAsString( - IndexUtil::LinearIndexToMultidimensionalIndex( - actual.shape(), max_abs_linear_index_)) - << "\nmaximum absolute error " << max_abs_err_ - << "\nfirst mismatch at index " - << LiteralTestUtil::MultiIndexAsString( - IndexUtil::LinearIndexToMultidimensionalIndex( - actual.shape(), first_linear_index_)) - << "\nlast mismatch at index " - << LiteralTestUtil::MultiIndexAsString( - IndexUtil::LinearIndexToMultidimensionalIndex( - actual.shape(), last_linear_index_)) - << "\ntotal absolute error " << abs_diff_sum_ - << "\ntotal absolute error of miscompares " - << abs_diff_miscompare_sum_ << "\ntotal relative error " - << (abs_diff_sum_ / abs_expected_sum_) - << "\ntotal relative error of miscompares " - << (abs_diff_miscompare_sum_ / abs_expected_miscompare_sum_) - << "\nfailure count " << num_miscompares_; - } - return num_miscompares_ == 0; - } - - private: - template - bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { - if (relaxed_nans) { - return !std::isnan(expected) && std::isnan(actual); - } else { - return std::isnan(expected) != std::isnan(actual); - } - } - - template - void ExpectNear(NativeT expected, NativeT actual, - const ::testing::Message& message) { - EXPECT_NEAR(expected, actual, error_.abs) - << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" - << message; - } - - // EXPECTs that the two given scalar values are within the error bound. Keeps - // track of how many mismatches have occurred to keep the size of the output - // manageable. - template - bool ExpectValuesNear(NativeT expected, NativeT actual) { - if (expected == actual) { - return true; - } - - const float abs_diff = std::abs(actual - expected); - const float rel_err = abs_diff / std::abs(expected); - const bool nan_mismatch = - NanMismatch(expected, actual, error_.relaxed_nans); - const bool mismatch = - (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel)); - return !mismatch; - } - - // Assumes that expected vs actual fail ExpectValuesNear. - template - void UpdateAndLogMiscompares(const NativeT expected, const NativeT actual, - const Shape& shape, const int64 linear_index) { - const float abs_diff = std::abs(actual - expected); - const float rel_err = abs_diff / std::abs(expected); - abs_diff_sum_ += abs_diff; - abs_expected_sum_ += std::abs(expected); - if (rel_err > max_rel_err_ || std::isnan(rel_err)) { - max_rel_err_ = rel_err; - max_rel_linear_index_ = linear_index; - } - if (abs_diff > max_abs_err_ || std::isnan(abs_diff)) { - max_abs_err_ = abs_diff; - max_abs_linear_index_ = linear_index; - } - if (VLOG_IS_ON(10)) { - VLOG(10) << tensorflow::strings::Printf( - "index %s abs_diff %f rel_err %f", - LiteralTestUtil::MultiIndexAsString( - IndexUtil::LinearIndexToMultidimensionalIndex(shape, - linear_index)) - .c_str(), - abs_diff, rel_err); - } - abs_diff_miscompare_sum_ += abs_diff; - abs_expected_miscompare_sum_ += std::abs(expected); - const int64 kMaxFailures = 2; - if (num_miscompares_ < kMaxFailures) { - const auto multi_index = - IndexUtil::LinearIndexToMultidimensionalIndex(shape, linear_index); - ::testing::Message msg; - msg << "mismatch at index " - << LiteralTestUtil::MultiIndexAsString(multi_index) << " abs diff " - << abs_diff << " rel err " << rel_err << " failure #" - << num_miscompares_; - ExpectNear(expected, actual, msg); - } else if (num_miscompares_ == kMaxFailures) { - LOG(ERROR) << "reached max 'loud' failure count; silently proceeding..."; - } - if (num_miscompares_ == 0) { - first_linear_index_ = linear_index; - } - num_miscompares_++; - last_linear_index_ = linear_index; - miscompares_.data()[linear_index] = true; - } - - // Recursive function which compares the two given literals elementwise. - template - void ExpectLiteralsNear(const Literal& expected, const Literal& actual, - int64 dimension) { - // Fast path optimization for the case were layouts match. - if (LayoutUtil::Equal(actual.shape().layout(), expected.shape().layout())) { - tensorflow::gtl::ArraySlice expected_data = - expected.data(); - tensorflow::gtl::ArraySlice actual_data = - actual.data(); - const int64 len = expected_data.size(); - for (int64 i = 0; i < len; ++i) { - const bool near = ExpectValuesNear(expected_data[i], actual_data[i]); - if (!near) { - UpdateAndLogMiscompares(expected_data[i], actual_data[i], - actual.shape(), i); - } - } - return; - } - - if (dimension == expected.shape().dimensions_size()) { - bool near = ExpectValuesNear(expected.Get(multi_index_), - actual.Get(multi_index_)); - if (!near) { - UpdateAndLogMiscompares( - expected.Get(multi_index_), - actual.Get(multi_index_), actual.shape(), - IndexUtil::MultidimensionalIndexToLinearIndex(actual.shape(), - multi_index_)); - } - } else { - for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { - multi_index_[dimension] = i; - ExpectLiteralsNear(expected, actual, dimension + 1); - } - } - } - - // Writes the given literal to a file in the test temporary directory. - void WriteLiteralToTempFile(const Literal& literal, const string& name) { - int64 now_usec = tensorflow::Env::Default()->NowMicros(); - string filename = tensorflow::io::JoinPath( - tensorflow::testing::TmpDir(), - tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(), - now_usec, name.c_str())); - TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), - filename, literal.ToProto())); - LOG(ERROR) << "wrote to " << name << " file: " << filename; - } - - // Gets the total element count. For tuples, this is not the count of tuple - // elements, but the sum of elements of each tuple element. - int64 RecursiveElementCount(const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { - const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); - int64 total = 0; - for (int64 i = 0; i < tuple_elements; ++i) { - total += - RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); - } - return total; - } else { - return ShapeUtil::ElementsIn(shape); - } - } - - // Calling ToString on a literal with over 100 million elements takes around - // 3 minutes. The utility of printing a literal with >1000 elements is - // questionable, especially when writing the Literal proto to disk is orders - // of magnitude faster. - string TruncateHugeLiteral(const Literal& literal) { - return RecursiveElementCount(literal.shape()) < 1000 - ? literal.ToString() - : "[TRUNCATED, Literal with more than 1000 values]"; +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapesAndLayouts( + const Shape& expected, const Shape& actual) { + if (expected.ShortDebugString() != actual.ShortDebugString()) { + return ::testing::AssertionFailure() + << "want: " << expected.ShortDebugString() + << " got: " << actual.ShortDebugString(); } - - ErrorSpec error_; - - // Number of element miscomparisons encountered so far. - int64 num_miscompares_; - - // A Literal containing which elements did not match in the expected and - // actual literals. miscompares_ contains PREDs and is of the same sizes as - // the comparison literals. - Literal miscompares_; - - // A multidimensional index used when performing the recursive comparison. - std::vector multi_index_; - - // Aggregated Statistics on input. - double abs_diff_sum_; - double abs_expected_sum_; - double abs_diff_miscompare_sum_; - double abs_expected_miscompare_sum_; - float max_rel_err_; - float max_abs_err_; - int64 first_linear_index_; - int64 last_linear_index_; - int64 max_rel_linear_index_; - int64 max_abs_linear_index_; -}; - -template <> -bool NearComparator::NanMismatch(complex64 expected, - complex64 actual, - bool relaxed_nans) { - return NanMismatch(expected.real(), actual.real(), relaxed_nans) || - NanMismatch(expected.imag(), actual.imag(), relaxed_nans); -} - -template <> -void NearComparator::ExpectNear(complex64 expected, complex64 actual, - const ::testing::Message& message) { - EXPECT_NEAR(expected.real(), actual.real(), error_.abs) - << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" - << message; - EXPECT_NEAR(expected.imag(), actual.imag(), error_.abs) - << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" - << message; -} - -template <> -bool NearComparator::ExpectValuesNear(bfloat16 expected, - bfloat16 actual) { - return ExpectValuesNear(static_cast(expected), - static_cast(actual)); -} - -template <> -bool NearComparator::ExpectValuesNear(half expected, half actual) { - return ExpectValuesNear(static_cast(std::move(expected)), - static_cast(std::move(actual))); -} - -template <> -void NearComparator::UpdateAndLogMiscompares( - const bfloat16 expected, const bfloat16 actual, const Shape& shape, - const int64 linear_index) { - UpdateAndLogMiscompares(static_cast(expected), - static_cast(actual), shape, linear_index); + return ::testing::AssertionSuccess(); } -template <> -void NearComparator::UpdateAndLogMiscompares(half expected, half actual, - const Shape& shape, - const int64 linear_index) { - UpdateAndLogMiscompares(static_cast(std::move(expected)), - static_cast(std::move(actual)), shape, - linear_index); +/* static */ ::testing::AssertionResult LiteralTestUtil::Equal( + const LiteralSlice& expected, const LiteralSlice& actual) { + return StatusToAssertion(literal_comparison::Equal(expected, actual)); } -} // namespace - /* static */ ::testing::AssertionResult LiteralTestUtil::Near( - const Literal& expected, const Literal& actual, const ErrorSpec& error) { - ::testing::AssertionResult err = - EqualShapes(expected.shape(), actual.shape()); - if (!err) { - return err; - } - - if (ShapeUtil::IsTuple(expected.shape())) { - for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - SCOPED_TRACE(tensorflow::strings::StrCat( - "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape()))); - const auto expected_element = LiteralView::Create(expected, {i}); - const auto actual_element = LiteralView::Create(actual, {i}); - - ::testing::AssertionResult res = - Near(expected_element, actual_element, error); - if (err && !res) { - err = res; - } - } - return err; - } - - if (ShapeUtil::ElementIsFloating(expected.shape()) || - ShapeUtil::ElementIsComplex(expected.shape())) { - NearComparator comparator(error); - return comparator.ExpectNear(expected, actual) - ? ::testing::AssertionSuccess() - : ::testing::AssertionFailure() << "values were not near"; - } - - return Equal(expected, actual); + const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error_spec, bool detailed_message) { + return StatusToAssertion(literal_comparison::Near( + expected, actual, error_spec, detailed_message, &OnMiscompare)); } -/* static */ void LiteralTestUtil::ExpectNear(const Literal& expected, - const Literal& actual, - const ErrorSpec& error, - const string& message) { - EXPECT_TRUE(Near(expected, actual, error)) - << (message.empty() - ? "" - : tensorflow::strings::StrCat("\nmessage: ", message)); -} - -/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( - const Literal& expected, const Literal& actual, +/* static */ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( + const LiteralSlice& expected, const LiteralSlice& actual, const tensorflow::gtl::optional& error) { if (error.has_value()) { VLOG(1) << "Expects near"; - return Near(expected, actual, *error); + return StatusToAssertion(literal_comparison::Near( + expected, actual, *error, /*detailed_message=*/false, &OnMiscompare)); } VLOG(1) << "Expects equal"; - return Equal(expected, actual); -} - -/*static*/ void LiteralTestUtil::ExpectNearOrEqual( - const Literal& expected, const Literal& actual, - const tensorflow::gtl::optional& error) { - EXPECT_TRUE(NearOrEqual(expected, actual, error)); -} - -/* static */ string LiteralTestUtil::MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index) { - return tensorflow::strings::StrCat( - "{", tensorflow::str_util::Join(multi_index, ","), "}"); -} - -/* static */ std::unique_ptr LiteralTestUtil::Reshape( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, const Literal& literal) { - int64 new_num_elements = 1; - for (int64 i = 0; i < new_dimensions.size(); ++i) { - new_num_elements *= new_dimensions[i]; - } - CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); - CHECK_EQ(new_dimensions.size(), minor_to_major.size()); - - auto new_literal = MakeUnique( - ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); - - // Create a new shape with the given minor-to-major layout. This shape is used - // solely for converting linear address to multi-dimensional addresses when - // writing elements to the new literal. - Shape shape_with_layout = new_literal->shape(); - *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); - - // Copy data into new literal, element-by-element. - for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { - std::vector from_multi_index = - IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); - std::vector to_multi_index = - IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); - switch (literal.shape().element_type()) { - case PRED: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case U8: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case U32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case S32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case U64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case S64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case F32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case F64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case C64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - default: - LOG(FATAL) << "Unhandled primitive element type: " - << PrimitiveType_Name(literal.shape().element_type()); - } - } - - return new_literal; + return StatusToAssertion(literal_comparison::Equal(expected, actual)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 7b757a4bd7e7592583b7596b4305ddb7e6c52d75..d1b8a6cf0b2552f1b7d95a2560d502da14ddc39a 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/error_spec.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -38,279 +39,190 @@ limitations under the License. namespace xla { -// Structure describing permissible absolute and relative error bounds. -struct ErrorSpec { - explicit ErrorSpec(float aabs, float arel = 0, bool relaxed_nans = false) - : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {} - - float abs; // Absolute error bound. - float rel; // Relative error bound. - - // If relaxed_nans is true then any result is valid if we are expecting NaNs. - // In effect, this allows the tested operation to produce incorrect results - // for inputs outside its mathematical domain. - bool relaxed_nans; -}; - // Utility class for making expectations/assertions related to XLA literals. class LiteralTestUtil { public: // Asserts that the given shapes have the same rank, dimension sizes, and // primitive types. - static ::testing::AssertionResult EqualShapes(const Shape& expected, - const Shape& actual); - static void AssertEqualShapes(const Shape& expected, const Shape& actual); + static ::testing::AssertionResult EqualShapes( + const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT; // Asserts that the provided shapes are equal as defined in AssertEqualShapes // and that they have the same layout. - static void AssertEqualShapesAndLayouts(const Shape& expected, - const Shape& actual); - - // If the given literal's data type is bfloat16, converts it to a float - // literal; otherwise, returns a copy of it. If the literal is a tuple, - // recursively converts its elements. - static std::unique_ptr ConvertBF16ToF32(const Literal& bf16_literal); - - // If the given literal's data type is float, converts it to a bfloat16 - // literal; otherwise, returns a copy of it. If the literal is a tuple, - // recursively converts its elements. - static std::unique_ptr ConvertF32ToBF16(const Literal& f32_literal); - - // Asserts that the expected and actual literals are (bitwise) equal for all - // elements in the literal. Also, asserts that the rank, dimensions sizes, and - // primitive type are equal. - static ::testing::AssertionResult Equal( - const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT; + static ::testing::AssertionResult EqualShapesAndLayouts( + const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT; - // Expects that expected and actual are Equal. - static void ExpectEqual(const Literal& expected, const Literal& actual, - const string& message = ""); - - // Expects that expected and actual are Not Equal. - static void ExpectNotEqual(const Literal& expected, const Literal& actual); + static ::testing::AssertionResult Equal(const LiteralSlice& expected, + const LiteralSlice& actual) + TF_MUST_USE_RESULT; // Asserts the given literal are (bitwise) equal to given expected values. template - static void ExpectR0Equal(NativeT expected, const Literal& actual); + static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual); + template static void ExpectR1Equal(tensorflow::gtl::ArraySlice expected, - const Literal& actual); + const LiteralSlice& actual); template static void ExpectR2Equal( std::initializer_list> expected, - const Literal& actual); + const LiteralSlice& actual); + template static void ExpectR3Equal( std::initializer_list< std::initializer_list>> expected, - const Literal& actual); + const LiteralSlice& actual); // Asserts the given literal are (bitwise) equal to given array. template static void ExpectR2EqualArray2D(const Array2D& expected, - const Literal& actual); + const LiteralSlice& actual); template static void ExpectR3EqualArray3D(const Array3D& expected, - const Literal& actual); + const LiteralSlice& actual); template static void ExpectR4EqualArray4D(const Array4D& expected, - const Literal& actual); + const LiteralSlice& actual); - // Asserts that the expected and actual literals are within the given error - // bound for all elements. Also, asserts that the rank, dimensions sizes, and - // bounds are equivalent. + // Decorates literal_comparison::Near() with an AssertionResult return type. // - // Tuples are matched recursively. When comparing tensors of - // non-floating-point type, checks for exact equality, ignoring the ErroSpec. - // - // If the shape of the literals is neither a complex/floating-point tensor nor - // a tuple which contains a complex/floating-point tensor, Near() is - // equivalent to Equal(). We don't raise an error in this case, because we - // want to allow callers to call Near() even if they have no preconceptions - // about the shapes being compared. + // See comment on literal_comparison::Near(). static ::testing::AssertionResult Near( - const Literal& expected, const Literal& actual, - const ErrorSpec& error) TF_MUST_USE_RESULT; - - // Expects expected and actual to be Near with the given error. - static void ExpectNear(const Literal& expected, const Literal& actual, - const ErrorSpec& error, const string& message = ""); + const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error_spec, + bool detailed_message = false) TF_MUST_USE_RESULT; // Asserts the given literal are within the given error bound of the given // expected values. Only supported for floating point values. template - static void ExpectR0Near(NativeT expected, const Literal& actual, + static void ExpectR0Near(NativeT expected, const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR1Near(tensorflow::gtl::ArraySlice expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR2Near( std::initializer_list> expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR3Near( std::initializer_list< std::initializer_list>> expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR4Near( std::initializer_list>>> expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); // Asserts the given literal are within the given error bound to the given // array. Only supported for floating point values. template static void ExpectR2NearArray2D(const Array2D& expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR3NearArray3D(const Array3D& expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR4NearArray4D(const Array4D& expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error); // If the error spec is given, returns whether the expected and the actual are // within the error bound; otherwise, returns whether they are equal. Tuples // will be compared recursively. static ::testing::AssertionResult NearOrEqual( - const Literal& expected, const Literal& actual, + const LiteralSlice& expected, const LiteralSlice& actual, const tensorflow::gtl::optional& error) TF_MUST_USE_RESULT; - // If the error spec is given, expects the expected and the actual to be near; - // otherwise, expects them to be equal. Tuples will be compared recursively. - static void ExpectNearOrEqual( - const Literal& expected, const Literal& actual, - const tensorflow::gtl::optional& error); - - // Returns a multi-dimensional index as a string. For example: '{7, 8}' will - // be returned for a 2-dimensional index with dimension 0 index equal to 7, - // dimension 1 equal to 8. - static string MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index); - - // Creates a literal with a new shape with the given new dimensions using the - // data in the given input literal. For reshaping purposes the (flat) data - // buffer of the input literal is assumed to have the given minor_to_major - // layout order. - static std::unique_ptr Reshape( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, - const Literal& literal); - - // Creates a literal with the supplied shape, and uses the provided value - // generator to populate the literal's values. - // Returns the new literal object, or an error Status if failed. - template < - PrimitiveType type, - typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, - const std::function)>& generator); - - // Creates a literal with the supplied shape, and initializes the literal - // values using a normal distribution with given mean and stddev standard - // deviation, and using the engine as entropy generator. - // Returns the new literal object, or an error Status if failed. - template < - PrimitiveType type, typename E, - typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, E* engine, T mean, T stddev); - - // Creates a literal with the supplied shape, and initializes the literal - // values using a normal distribution with given mean and stddev standard - // deviation. - // Returns the new literal object, or an error Status if failed. - template < - PrimitiveType type, - typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, T mean, T stddev); - private: TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil); }; template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, - const Literal& actual) { - ExpectEqual(*Literal::CreateR0(expected), actual); + const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR0(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR1Equal( - tensorflow::gtl::ArraySlice expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR1(expected), actual); + tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR1(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, - const Literal& actual) { - ExpectEqual(*Literal::CreateR2(expected), actual); + const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR2(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3Equal( std::initializer_list>> expected, - const Literal& actual) { - ExpectEqual(*Literal::CreateR3(expected), actual); + const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR3(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( - const Array2D& expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual); + const Array2D& expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR2FromArray2D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( - const Array3D& expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual); + const Array3D& expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR3FromArray3D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( - const Array4D& expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual); + const Array4D& expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR4FromArray4D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR0(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR0(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR1Near( - tensorflow::gtl::ArraySlice expected, const Literal& actual, + tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR1(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR1(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, - const Literal& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR2(expected), actual, error); + const LiteralSlice& actual, const ErrorSpec& error) { + EXPECT_TRUE(Near(*Literal::CreateR2(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3Near( std::initializer_list>> expected, - const Literal& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR3(expected), actual, error); + const LiteralSlice& actual, const ErrorSpec& error) { + EXPECT_TRUE(Near(*Literal::CreateR3(expected), actual, error)); } template @@ -318,63 +230,29 @@ template std::initializer_list>>> expected, - const Literal& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR4(expected), actual, error); + const LiteralSlice& actual, const ErrorSpec& error) { + EXPECT_TRUE(Near(*Literal::CreateR4(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( - const Array2D& expected, const Literal& actual, + const Array2D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR2FromArray2D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( - const Array3D& expected, const Literal& actual, + const Array3D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR3FromArray3D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( - const Array4D& expected, const Literal& actual, + const Array4D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR4FromArray4D(expected), actual, error); -} - -template -/* static */ StatusOr> -LiteralTestUtil::CreateRandomLiteral( - const Shape& shape, - const std::function)>& generator) { - using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - TF_RET_CHECK(shape.element_type() == type); - std::unique_ptr literal = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(literal.get()->Populate( - [&](tensorflow::gtl::ArraySlice indexes) { - return generator(indexes); - })); - return std::move(literal); -} - -template -/* static */ StatusOr> -LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, - T stddev) { - using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - std::normal_distribution generator(mean, stddev); - return CreateRandomLiteral( - shape, [&](tensorflow::gtl::ArraySlice /*indexes*/) { - return generator(*engine); - }); -} - -template -/* static */ StatusOr> -LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { - std::minstd_rand0 engine; - return CreateRandomLiteral(shape, &engine, mean, stddev); + EXPECT_TRUE(Near(*Literal::CreateR4FromArray4D(expected), actual, error)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index 3a421f8458268a14dcdd84889bcae4990c095ea4..bbac7285aefbb1f028fad152e4b7fe6af01e9f6d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -34,7 +34,7 @@ TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { std::unique_ptr literal = Literal::MakeTuple({ Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), }); - LiteralTestUtil::ExpectEqual(*literal, *literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal)); } TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { @@ -89,7 +89,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { EXPECT_EQ("2", literal->ToString()); } else if (result.find("actual") != string::npos) { EXPECT_EQ("4", literal->ToString()); - } else if (result.find("miscompares") != string::npos) { + } else if (result.find("mismatches") != string::npos) { EXPECT_EQ("true", literal->ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; @@ -97,6 +97,15 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { } } +TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { + auto expected = Literal::CreateR1({1, 2, 3}); + auto actual = Literal::CreateR1({4, 5, 6}); + ::testing::AssertionResult result = + LiteralTestUtil::Equal(*expected, *actual); + EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}")); + EXPECT_THAT(result.message(), ::testing::HasSubstr("actual: {4, 5, 6}")); +} + TEST(LiteralTestUtilTest, NearComparatorR1) { auto a = Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 7e92439c494b677f718a63c71c20828d65bebef4..082bc34136e004795ce300c66591758f47c665fe 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -43,7 +43,7 @@ class LLVMCompilerTest : public ::testing::Test { ~LLVMCompilerTest() override {} protected: - using Platform = ::perftools::gputools::Platform; + using Platform = se::Platform; explicit LLVMCompilerTest(string platform_name) : platform_name_(std::move(platform_name)) {} @@ -95,7 +95,7 @@ class LLVMCompilerTest : public ::testing::Test { modules.push_back(hlo_module->Clone()); modules.push_back(std::move(hlo_module)); - std::vector> executors; + std::vector> executors; executors.push_back({backend_->default_stream_executor()}); executors.push_back({backend_->default_stream_executor()}); @@ -124,8 +124,7 @@ class LLVMCompilerTest : public ::testing::Test { static std::unique_ptr CreateNewModule() { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - return MakeUnique(TestName(), VersionedComputationHandle(), - config); + return MakeUnique(TestName(), config); } }; diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc index 99514baf23cafe61adc28a30dfdfe2691ab82d32..2c45f19c090d2690878430363bf0d20252b2f3df 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -49,11 +50,11 @@ void LLVMIRGenTestBase::CompileAndVerifyIr( std::unique_ptr hlo_module, const string& pattern, bool match_optimized_ir) { SetIrHook(match_optimized_ir); - ASSERT_TRUE(CompileToExecutable(std::move(hlo_module)).ok()); + TF_ASSERT_OK(CompileToExecutable(std::move(hlo_module)).status()); ResetIrHook(); StatusOr filecheck_result = RunFileCheck(ir_, pattern); - ASSERT_TRUE(filecheck_result.ok()); + TF_ASSERT_OK(filecheck_result.status()); EXPECT_TRUE(filecheck_result.ValueOrDie()); } @@ -61,8 +62,8 @@ void LLVMIRGenTestBase::CompileAheadOfTimeAndVerifyIr( std::unique_ptr hlo_module, const AotCompilationOptions& options, const string& pattern, bool match_optimized_ir) { SetIrHook(match_optimized_ir); - ASSERT_TRUE( - CompileToAotCompilationResult(std::move(hlo_module), options).ok()); + TF_ASSERT_OK( + CompileToAotCompilationResult(std::move(hlo_module), options).status()); ResetIrHook(); StatusOr filecheck_result = RunFileCheck(ir_, pattern); diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc index 3d30ceeaf1b0369b6fdc0cd9620c04aae287941c..f21f83992ffb7c07dff31c68a7e9e3f7944bf512 100644 --- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -15,9 +15,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -25,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -37,7 +37,7 @@ class LocalClientAllocationTest : public LocalClientTestBase { }; XLA_TEST_F(LocalClientAllocationTest, AddVectors) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0.0f, 1.0f, 2.0f}); auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); builder.Add(x, y); @@ -53,7 +53,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { // deallocation happen on the right allocator. ExecutableRunOptions options; options.set_allocator(allocator); - std::unique_ptr result = + tensorflow::gtl::optional result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), options); @@ -66,14 +66,14 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { // Deallocate result and verify that deallocate was called once. int64 deallocation_count_before = allocator_->deallocation_count(); - result = nullptr; + result.reset(); EXPECT_EQ(deallocation_count_before + 1, allocator_->deallocation_count()); } XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) { // Run a computation on every device on the system. Verify that allocation // occurs on the proper device. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0.0f, 1.0f, 2.0f}); auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); builder.Add(x, y); @@ -92,7 +92,7 @@ XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) { computation, {}, ExecutableBuildOptions().set_device_ordinal(d), ExecutableRunOptions().set_device_ordinal(d).set_allocator(allocator)); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); + {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); // At least one allocation should have been performed when executing the // computation. diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 3704ddd8010bf727b75ff81b63605e8b7ffe2ca8..a366afe8262e1f537b225e395bba9cb2fc22683a 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -21,7 +21,8 @@ limitations under the License. #include "llvm/ADT/Triple.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/types.h" @@ -29,27 +30,31 @@ limitations under the License. #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" +namespace { + using xla::string; -xla::Computation Doubler(xla::Client* client) { - xla::ComputationBuilder builder(client, "doubler"); +xla::XlaComputation Doubler() { + xla::XlaBuilder builder("doubler"); auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); auto x = builder.Parameter(0, r0f32, "x"); builder.Mul(x, builder.ConstantR0(2.0)); return std::move(builder.Build().ValueOrDie()); } +} // namespace + int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie(); - xla::ComputationBuilder builder(client, "aot_test_helper"); + xla::XlaBuilder builder("aot_test_helper"); auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); auto opaque_param = builder.Parameter(0, opaque_shape, "x"); auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); auto sum = builder.CustomCall("SumStructElements", {opaque_param}, r0f32); - builder.Call(Doubler(client), {sum}); + builder.Call(Doubler(), {sum}); if (argc != 2) { LOG(FATAL) << "local_client_aot_test_helper TARGET_CPU"; @@ -71,8 +76,8 @@ int main(int argc, char** argv) { llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string)); - xla::Computation computation = builder.Build().ConsumeValueOrDie(); - xla::CompileOnlyClient::AotComputationInstance instance{ + xla::XlaComputation computation = builder.Build().ConsumeValueOrDie(); + xla::CompileOnlyClient::AotXlaComputationInstance instance{ &computation, /*argument_layouts=*/{&opaque_shape}, &r0f32}; xla::cpu::CpuAotCompilationOptions options( diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 2462ea39f914b1dbb525ea777a48d9ce66035638..96858c00d6bbe59b673a34e7d5ca261756709596 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -18,9 +18,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -43,8 +42,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" -namespace se = ::perftools::gputools; - namespace xla { namespace { @@ -56,61 +53,57 @@ class LocalClientExecuteTest : public LocalClientTestBase { }; XLA_TEST_F(LocalClientExecuteTest, Constant) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto y = builder.ConstantR0(123.0f); - std::unique_ptr result = + ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); - - LiteralTestUtil::ExpectR0Near(123.f, *ShapedBufferToLiteral(*result), + LiteralTestUtil::ExpectR0Near(123.f, *ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddScalars) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.ConstantR0(123.0f); builder.Add(x, y); auto x_value = LiteralToShapedBuffer(*Literal::CreateR0(42.0f)); - std::unique_ptr result = - ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {x_value.get()}); - - LiteralTestUtil::ExpectR0Near(165.f, *ShapedBufferToLiteral(*result), + ScopedShapedBuffer result = + ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value}); + LiteralTestUtil::ExpectR0Near(165.f, *ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "x"); auto y = builder.ConstantR1({}); builder.Add(x, y); auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({})); - std::unique_ptr result = - ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {x_array.get()}); - - LiteralTestUtil::ExpectR1Near({}, *ShapedBufferToLiteral(*result), + ScopedShapedBuffer result = + ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); + LiteralTestUtil::ExpectR1Near({}, *ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddVectors) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); builder.Add(x, y); auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); - std::unique_ptr result = - ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {x_array.get()}); - + ScopedShapedBuffer result = + ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); + {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); builder.Add(x, y); @@ -118,18 +111,17 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); ExecutionProfile profile; - std::unique_ptr result = ExecuteLocallyOrDie( - builder.Build().ValueOrDie(), {x_array.get()}, - DefaultExecutableBuildOptions(), + ScopedShapedBuffer result = ExecuteLocallyOrDie( + builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions().set_execution_profile(&profile)); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); + {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); EXPECT_GT(profile.compute_and_transfer_time_ns(), 0); } XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); builder.Add(x, y); @@ -138,31 +130,31 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { // Create x as a col-major array. auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); - EXPECT_TRUE(LayoutUtil::Equal(x_array->on_device_shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); - EXPECT_TRUE(LayoutUtil::Equal(y_array->on_device_shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); - std::unique_ptr result_colmaj = - ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); + ScopedShapedBuffer result_colmaj = + ExecuteLocallyOrDie(computation, {&x_array, &y_array}); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(*result_colmaj), + *ShapedBufferToLiteral(result_colmaj), error_spec_); // Run with the parameter values in a different order. - std::unique_ptr result_param_swap = - ExecuteLocallyOrDie(computation, {y_array.get(), x_array.get()}); + ScopedShapedBuffer result_param_swap = + ExecuteLocallyOrDie(computation, {&y_array, &x_array}); LiteralTestUtil::ExpectR2Near( {{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(*result_param_swap), error_spec_); + *ShapedBufferToLiteral(result_param_swap), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); builder.Add(x, y); @@ -174,32 +166,32 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); // Run with col-major result layout. - std::unique_ptr result_colmaj = ExecuteLocallyOrDie( - computation, {x_array.get(), y_array.get()}, + ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie( + computation, {&x_array, &y_array}, DefaultExecutableBuildOptions().set_result_layout( ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {0, 1})), DefaultExecutableRunOptions()); - EXPECT_TRUE(LayoutUtil::Equal(result_colmaj->on_device_shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(result_colmaj.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(*result_colmaj), + *ShapedBufferToLiteral(result_colmaj), error_spec_); // Run with row-major result layout. - std::unique_ptr result_rowmaj = ExecuteLocallyOrDie( - computation, {x_array.get(), y_array.get()}, + ScopedShapedBuffer result_rowmaj = ExecuteLocallyOrDie( + computation, {&x_array, &y_array}, DefaultExecutableBuildOptions().set_result_layout( ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {1, 0})), DefaultExecutableRunOptions()); - EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj->on_device_shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(*result_rowmaj), + *ShapedBufferToLiteral(result_rowmaj), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, TupleResult) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); builder.Tuple({x, y, x}); @@ -210,24 +202,24 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { auto y_array = LiteralToShapedBuffer( *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); - std::unique_ptr result = - ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); + ScopedShapedBuffer result = + ExecuteLocallyOrDie(computation, {&x_array, &y_array}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); - EXPECT_EQ(3, ShapeUtil::TupleElementCount(result->on_host_shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); + EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR2Equal( {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralView::Create(*result_literal, {1})); + LiteralSlice(*result_literal, {1})); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {2})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); auto inner_tuple = builder.Tuple({x, y, x}); @@ -239,29 +231,29 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { auto y_array = LiteralToShapedBuffer( *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); - std::unique_ptr result = - ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); + ScopedShapedBuffer result = + ExecuteLocallyOrDie(computation, {&x_array, &y_array}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->on_host_shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralView::Create(*result_literal, {0, 0})); + LiteralSlice(*result_literal, {0, 0})); LiteralTestUtil::ExpectR2Equal( {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralView::Create(*result_literal, {0, 1})); + LiteralSlice(*result_literal, {0, 1})); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralView::Create(*result_literal, {0, 2})); + LiteralSlice(*result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { // Verify setting the result layout of a computation with a tuple output. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); builder.Tuple({x, y}); @@ -276,15 +268,15 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{1, 0})}); options.set_result_layout(shape_with_layout); - std::unique_ptr result = ExecuteLocallyOrDie( - builder.Build().ValueOrDie(), {array.get(), array.get()}, options, - DefaultExecutableRunOptions()); + ScopedShapedBuffer result = + ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array}, + options, DefaultExecutableRunOptions()); - std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -298,7 +290,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { // Computation adds the respective array and vector elements from each tuple // argument and returns the results as a tuple. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, tuple_shape0, "x"); auto y = builder.Parameter(1, tuple_shape1, "y"); auto x_0 = builder.GetTupleElement(x, 0); @@ -320,18 +312,18 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { auto x_buffer = LiteralToShapedBuffer(*x_literal); auto y_buffer = LiteralToShapedBuffer(*y_literal); - std::unique_ptr result = - ExecuteLocallyOrDie(computation, {x_buffer.get(), y_buffer.get()}); + ScopedShapedBuffer result = + ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->on_host_shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( {{56.0f, 46.0f}, {36.0f, 26.0f}}, - LiteralView::Create(*result_literal, {0})); + LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {40.0f, 71.0f, 117.0f}, LiteralView::Create(*result_literal, {1})); + {40.0f, 71.0f, 117.0f}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -345,7 +337,7 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { // Computation negates the array element and sums the two vector elements in // the nested tuple. The resulting array and vector are returned as a tuple. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto param = builder.Parameter(0, nested_tuple_shape, "param"); auto inner_tuple = builder.GetTupleElement(param, 0); auto inner_array = builder.GetTupleElement(inner_tuple, 0); @@ -365,14 +357,13 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { Literal::CreateR1({222.0, -2.0, 10.0}).get()}); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); - std::unique_ptr result = - ExecuteLocallyOrDie(computation, {arg_buffer.get()}); + ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{-1.0, -2.0}, {-3.0, -4}}, LiteralView::Create(*result_literal, {0})); + {{-1.0, -2.0}, {-3.0, -4}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {264.0, 73.0, 133.0}, LiteralView::Create(*result_literal, {1})); + {264.0, 73.0, 133.0}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -384,7 +375,7 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { const Shape tuple_shape = ShapeUtil::MakeTupleShape({array_shape, array_shape}); - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto param = builder.Parameter(0, tuple_shape, "param"); auto element_0 = builder.GetTupleElement(param, 0); auto element_1 = builder.GetTupleElement(param, 1); @@ -396,22 +387,20 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { Literal::CreateR2({{11.0, 3.0}, {4.0, 5.0}}).get()}); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); - std::unique_ptr result_0 = - ExecuteLocallyOrDie(computation, {arg_buffer.get()}); - std::unique_ptr result_0_literal = ShapedBufferToLiteral(*result_0); + ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer}); + std::unique_ptr result_0_literal = ShapedBufferToLiteral(result_0); LiteralTestUtil::ExpectR2Equal( {{-1.0, -2.0}, {-3.0, -4.0}}, - LiteralView::Create(*result_0_literal, {0})); + LiteralSlice(*result_0_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{22.0, 6.0}, {8.0, 10}}, LiteralView::Create(*result_0_literal, {1})); + {{22.0, 6.0}, {8.0, 10}}, LiteralSlice(*result_0_literal, {1})); - std::unique_ptr result_1 = - ExecuteLocallyOrDie(computation, {result_0.get()}); - std::unique_ptr result_1_literal = ShapedBufferToLiteral(*result_1); + ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0}); + std::unique_ptr result_1_literal = ShapedBufferToLiteral(result_1); LiteralTestUtil::ExpectR2Equal( - {{1.0, 2.0}, {3.0, 4.0}}, LiteralView::Create(*result_1_literal, {0})); + {{1.0, 2.0}, {3.0, 4.0}}, LiteralSlice(*result_1_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{44.0, 12.0}, {16.0, 20}}, LiteralView::Create(*result_1_literal, {1})); + {{44.0, 12.0}, {16.0, 20}}, LiteralSlice(*result_1_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -430,11 +419,11 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { std::vector element_shapes(kElementCount, element_shape); const Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes); - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto param = builder.Parameter(0, tuple_shape, "param"); // Add each element's tuple index value to every element. - std::vector result_elements; + std::vector result_elements; for (int i = 0; i < kElementCount; ++i) { auto element = builder.GetTupleElement(param, i); result_elements.push_back( @@ -453,21 +442,17 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { Literal::MakeTupleOwned(std::move(arg_elements)); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); - std::unique_ptr result = - ExecuteLocallyOrDie(computation, {arg_buffer.get()}); - - std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); + std::unique_ptr result_literal = ShapedBufferToLiteral(result); for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, LiteralView::Create(*result_literal, {i}), + {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_); } } -// TODO(b/66968986): Test times out on CPU parallel backend. Disabled -// 2017-09-26. -XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_CPU_PARALLEL(LargeNestedTuple)) { +XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { // Construct and run a computation which takes a two-level nested tuple // parameter with a large fanout. const int kFanout = 40; @@ -479,15 +464,15 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_CPU_PARALLEL(LargeNestedTuple)) { std::vector inner_tuple_shapes(kFanout, inner_tuple_shape); const Shape tuple_shape = ShapeUtil::MakeTupleShape(inner_tuple_shapes); - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto param = builder.Parameter(0, tuple_shape, "param"); // The computation increments each leaf value by an amount equal to the leaf's // ordinal position in a traversal of the tuple. - std::vector result_elements; + std::vector result_elements; for (int i = 0; i < kFanout; ++i) { auto outer_element = builder.GetTupleElement(param, i); - std::vector inner_result_elements; + std::vector inner_result_elements; for (int j = 0; j < kFanout; ++j) { auto inner_element = builder.GetTupleElement(outer_element, j); inner_result_elements.push_back(builder.Add( @@ -511,14 +496,13 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_CPU_PARALLEL(LargeNestedTuple)) { auto arg_literal = Literal::MakeTupleOwned(std::move(outer_tuple_elements)); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); - std::unique_ptr result = - ExecuteLocallyOrDie(computation, {arg_buffer.get()}); - std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); + std::unique_ptr result_literal = ShapedBufferToLiteral(result); for (int i = 0; i < kFanout; ++i) { for (int j = 0; j < kFanout; ++j) { LiteralTestUtil::ExpectR0Near( - i + j + i * kFanout + j, LiteralView::Create(*result_literal, {i, j}), + i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}), error_spec_); } } @@ -535,7 +519,7 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { shape = ShapeUtil::MakeTupleShape({shape}); } - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto element = builder.Parameter(0, shape, "param"); for (int i = 0; i < kTupleDepth; ++i) { element = builder.GetTupleElement(element, 0); @@ -556,21 +540,20 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { } auto arg_buffer = LiteralToShapedBuffer(*arg_literal); - std::unique_ptr result = - ExecuteLocallyOrDie(computation, {arg_buffer.get()}); - std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); + std::unique_ptr result_literal = ShapedBufferToLiteral(result); ShapeIndex index; for (int i = 0; i < kTupleDepth; ++i) { index.push_back(0); } LiteralTestUtil::ExpectR0Equal( - 165.0, LiteralView::Create(*result_literal, index)); + 165.0, LiteralSlice(*result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { // Test passing in an invalid number of arguments. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {3}), "y"); builder.Add(x, y); @@ -578,7 +561,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({1.0f, 2.0f, 3.0f})); auto execute_status = - ExecuteLocally(builder.Build().ValueOrDie(), {x_array.get()}); + ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); EXPECT_FALSE(execute_status.ok()); EXPECT_THAT(execute_status.status().error_message(), @@ -587,14 +570,14 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { // Test passing in an argument with the wrong shape. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); builder.Neg(x); auto x_array = LiteralToShapedBuffer( *Literal::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = - ExecuteLocally(builder.Build().ValueOrDie(), {x_array.get()}); + ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); EXPECT_FALSE(execute_status.ok()); EXPECT_THAT(execute_status.status().error_message(), @@ -604,14 +587,14 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { // Test passing in an invalid result layout parameter. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); builder.Neg(x); auto x_array = LiteralToShapedBuffer( *Literal::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally( - builder.Build().ValueOrDie(), {x_array.get()}, + builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions().set_result_layout( ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{1, 2, 3, 4}, @@ -627,7 +610,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) { // Try to run a trivial computation on every device on the system. If a // specific device is not supported, check that the right error is returned. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR0(42.0f); auto computation = builder.Build().ConsumeValueOrDie(); for (int d = 0; d < local_client_->device_count(); ++d) { @@ -644,9 +627,9 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) { computation, {}, DefaultExecutableBuildOptions().set_device_ordinal(d), DefaultExecutableRunOptions().set_device_ordinal(d)); - EXPECT_EQ(d, result->device_ordinal()); + EXPECT_EQ(d, result.device_ordinal()); LiteralTestUtil::ExpectR0Equal(42.0f, - *ShapedBufferToLiteral(*result)); + *ShapedBufferToLiteral(result)); } } } @@ -654,7 +637,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) { XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) { // Try running computations on devices with device ordinal values which do not // exist. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR0(42.0f); auto computation = builder.Build().ConsumeValueOrDie(); @@ -671,7 +654,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) { XLA_TEST_F(LocalClientExecuteTest, RunOnStream) { // Run a computation on a specific stream on each device on the system. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR0(42.0f); auto computation = builder.Build().ConsumeValueOrDie(); @@ -689,9 +672,9 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnStream) { DefaultExecutableRunOptions().set_stream(&stream)); // As a check to verify that the computation ran of the device associated // with the stream. This is a weak check, but stronger verification is hard. - EXPECT_EQ(d, result->device_ordinal()); + EXPECT_EQ(d, result.device_ordinal()); LiteralTestUtil::ExpectR0Equal(42.0f, - *ShapedBufferToLiteral(*result)); + *ShapedBufferToLiteral(result)); } } @@ -707,7 +690,7 @@ XLA_TEST_F(LocalClientExecuteTest, se::Stream wrong_stream(wrong_platform->ExecutorForDevice(0).ValueOrDie()); wrong_stream.Init(); - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR0(42.0f); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), @@ -724,7 +707,7 @@ XLA_TEST_F(LocalClientExecuteTest, .ValueOrDie(); TestAllocator allocator(wrong_platform); - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto y = builder.ConstantR0(123.0f); auto execute_status = ExecuteLocally( @@ -737,7 +720,7 @@ XLA_TEST_F(LocalClientExecuteTest, XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) { // Try to run a computation on a stream that has not been initialized. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR0(42.0f); LOG(INFO) << "default device = " << local_client_->default_device_ordinal(); @@ -757,7 +740,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) { } XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -767,17 +750,17 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); builder.Select(builder.ConstantR0(false), tuple12, tuple21); - std::unique_ptr result = + ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); - std::unique_ptr tuple_literal = ShapedBufferToLiteral(*result); + std::unique_ptr tuple_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR1Equal( - {2.0f, 4.0f, 6.0f}, LiteralView::Create(*tuple_literal, {0})); + {2.0f, 4.0f, 6.0f}, LiteralSlice(*tuple_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {1.0f, 2.0f, 3.0f}, LiteralView::Create(*tuple_literal, {1})); + {1.0f, 2.0f, 3.0f}, LiteralSlice(*tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); builder.Add(x, y); @@ -793,12 +776,12 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); - std::unique_ptr result = - executable->Run({x_array.get()}, DefaultExecutableRunOptions()) + ScopedShapedBuffer result = + executable->Run({&x_array}, DefaultExecutableRunOptions()) .ConsumeValueOrDie(); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); + {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { @@ -811,7 +794,7 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { literal, local_client_->default_device_ordinal(), allocator_)); TF_ASSERT_OK_AND_ASSIGN( auto transferred_literal, - local_client_->ShapedBufferToLiteral(*shaped_buffer)); + local_client_->ShapedBufferToLiteral(shaped_buffer)); EXPECT_EQ(literal, *transferred_literal); }; @@ -851,7 +834,7 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { literal, local_client_->default_device_ordinal(), allocator_)); TF_ASSERT_OK_AND_ASSIGN( auto transferred_literal, - local_client_->ShapedBufferToLiteral(*shaped_buffer)); + local_client_->ShapedBufferToLiteral(shaped_buffer)); EXPECT_EQ(literal, *transferred_literal); }; @@ -867,9 +850,8 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { // TODO(b/34359662): Support infeed/outfeed on GPU and CPU parallel. // 2017-10-18. -XLA_TEST_F(LocalClientExecuteTest, - DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(InfeedOutfeedTest))) { - ComputationBuilder builder(local_client_, TestName()); +XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(InfeedOutfeedTest)) { + XlaBuilder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {3}); auto in = builder.Infeed(shape); auto constant = builder.ConstantR1({1.0f, 2.0f, 3.0f}); @@ -907,7 +889,7 @@ void BM_LocalClientOverhead(int num_iters) { int device_ordinal = client->default_device_ordinal(); // Use a tiny add operation as the computation. - ComputationBuilder builder(client, "Add"); + XlaBuilder builder("Add"); auto shape = ShapeUtil::MakeShape(F32, {2, 3}); auto x = builder.Parameter(0, shape, "x"); builder.Add(x, x); @@ -919,12 +901,12 @@ void BM_LocalClientOverhead(int num_iters) { .ConsumeValueOrDie(); auto literal = Literal::CreateR2({{0, 0, 0}, {0, 0, 0}}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *literal, *buffer)); + executors[device_ordinal], *literal, buffer)); const int kWarmups = 2; auto executable_status = client->Compile( - computation, {&buffer->on_host_shape()}, ExecutableBuildOptions()); + computation, {&buffer.on_host_shape()}, ExecutableBuildOptions()); ASSERT_IS_OK(executable_status); std::unique_ptr executable = executable_status.ConsumeValueOrDie(); @@ -936,13 +918,13 @@ void BM_LocalClientOverhead(int num_iters) { run_options.set_allocator(&allocator).set_stream(&stream); for (int i = 0; i < kWarmups; ++i) { - auto result = executable->Run({buffer.get()}, run_options); + auto result = executable->Run({&buffer}, run_options); ASSERT_IS_OK(result); } tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { - auto result = executable->Run({buffer.get()}, run_options); + auto result = executable->Run({&buffer}, run_options); ASSERT_IS_OK(result); } } diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 96b976d25d75d35f46adfd104a03aceb363661eb..88797a7d0a7d0567b3a380c5fb1ad0c0ee875587 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -35,19 +35,20 @@ namespace xla { /* static */ TestAllocator* LocalClientTestBase::allocator_; -StatusOr TestAllocator::Allocate( - int device_ordinal, uint64 size, bool retry_on_failure) { +StatusOr TestAllocator::Allocate(int device_ordinal, + uint64 size, + bool retry_on_failure) { VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")"; { tensorflow::mutex_lock lock(count_mutex_); allocation_count_++; device_allocation_count_[device_ordinal]++; } - return StreamExecutorMemoryAllocator::Allocate(device_ordinal, size); + return StreamExecutorMemoryAllocator::Allocate(device_ordinal, size, + retry_on_failure); } -tensorflow::Status TestAllocator::Deallocate( - int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) { +Status TestAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { VLOG(2) << "Deallocate(" << device_ordinal << ")"; { tensorflow::mutex_lock lock(count_mutex_); @@ -88,7 +89,7 @@ int64 TestAllocator::deallocation_count(int device_ordinal) const { } /* static */ TestAllocator* LocalClientTestBase::GetOrCreateAllocator( - perftools::gputools::Platform* platform) { + se::Platform* platform) { static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); tensorflow::mutex_lock lock(mu); @@ -115,8 +116,7 @@ struct LocalClientTestBase::EigenThreadPoolWrapper { std::unique_ptr device; }; -LocalClientTestBase::LocalClientTestBase( - perftools::gputools::Platform* platform) +LocalClientTestBase::LocalClientTestBase(se::Platform* platform) : local_client_( ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()), thread_pool_wrapper_(new EigenThreadPoolWrapper()) { @@ -128,7 +128,7 @@ LocalClientTestBase::LocalClientTestBase( LocalClientTestBase::~LocalClientTestBase() {} -std::unique_ptr LocalClientTestBase::LiteralToShapedBuffer( +ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer( const Literal& literal) { return local_client_ ->LiteralToShapedBuffer(literal, local_client_->default_device_ordinal()) @@ -148,23 +148,21 @@ ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions() ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const { ExecutableRunOptions run_options; - run_options.set_inter_op_thread_pool( - local_client_->backend().inter_op_thread_pool()); run_options.set_intra_op_thread_pool(thread_pool_wrapper_->device.get()); run_options.set_allocator(GetOrCreateAllocator(local_client_->platform())); return run_options; } -std::unique_ptr LocalClientTestBase::ExecuteLocallyOrDie( - const Computation& computation, +ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments) { return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions()) .ConsumeValueOrDie(); } -std::unique_ptr LocalClientTestBase::ExecuteLocallyOrDie( - const Computation& computation, +ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options) { @@ -172,17 +170,15 @@ std::unique_ptr LocalClientTestBase::ExecuteLocallyOrDie( .ConsumeValueOrDie(); } -StatusOr> -LocalClientTestBase::ExecuteLocally( - const Computation& computation, +StatusOr LocalClientTestBase::ExecuteLocally( + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments) { return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions()); } -StatusOr> -LocalClientTestBase::ExecuteLocally( - const Computation& computation, +StatusOr LocalClientTestBase::ExecuteLocally( + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options) { diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index f0c73f04f6eb67b2e9cb5e111eccdc3818059b2b..258226523d830b40ecaa761df95988dc90f5ca47 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -21,8 +21,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -41,15 +41,14 @@ namespace xla { class TestAllocator : public StreamExecutorMemoryAllocator { public: - explicit TestAllocator(perftools::gputools::Platform* platform) + explicit TestAllocator(se::Platform* platform) : StreamExecutorMemoryAllocator( platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) { } - StatusOr Allocate( - int device_ordinal, uint64 size, bool retry_on_failure) override; - tensorflow::Status Deallocate( - int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) override; + StatusOr Allocate(int device_ordinal, uint64 size, + bool retry_on_failure) override; + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; // Return the number of allocations that have been performed. int64 allocation_count() const; @@ -75,18 +74,15 @@ class TestAllocator : public StreamExecutorMemoryAllocator { class LocalClientTestBase : public ::testing::Test { protected: struct EigenThreadPoolWrapper; - explicit LocalClientTestBase( - perftools::gputools::Platform* platform = nullptr); + explicit LocalClientTestBase(se::Platform* platform = nullptr); virtual ~LocalClientTestBase(); - static TestAllocator* GetOrCreateAllocator( - perftools::gputools::Platform* platform); + static TestAllocator* GetOrCreateAllocator(se::Platform* platform); // Copy the given literal onto the default device and return a // ScopedShapedBuffer. Convenience wrapper around // LocalClient::LiteralToShapedBuffer. - std::unique_ptr LiteralToShapedBuffer( - const Literal& literal); + ScopedShapedBuffer LiteralToShapedBuffer(const Literal& literal); // Construct and return a literal containing the array represented by // shaped_buffer. @@ -95,20 +91,20 @@ class LocalClientTestBase : public ::testing::Test { // Execute the given computation on the local client. With and without // options. - StatusOr> ExecuteLocally( - const Computation& computation, + StatusOr ExecuteLocally( + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments); - StatusOr> ExecuteLocally( - const Computation& computation, + StatusOr ExecuteLocally( + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options); - std::unique_ptr ExecuteLocallyOrDie( - const Computation& computation, + ScopedShapedBuffer ExecuteLocallyOrDie( + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments); - std::unique_ptr ExecuteLocallyOrDie( - const Computation& computation, + ScopedShapedBuffer ExecuteLocallyOrDie( + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options); @@ -128,7 +124,7 @@ class LocalClientTestBase : public ::testing::Test { // of the process. So make the allocator static. static TestAllocator* allocator_; - perftools::gputools::StreamExecutor* stream_executor_; + se::StreamExecutor* stream_executor_; TransferManager* transfer_manager_; LocalClient* local_client_; diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc index 174d433a9e17312c3548668feeeb2e92712c87f8..c0c02e584c2348f64a9d7d0800038f5ca67a2171 100644 --- a/tensorflow/compiler/xla/tests/log_test.cc +++ b/tensorflow/compiler/xla/tests/log_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -29,7 +29,7 @@ namespace { class LogTest : public ClientLibraryTestBase {}; XLA_TEST_F(LogTest, LogZeroValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR3FromArray3D(Array3D(3, 0, 0)); builder.Log(x); @@ -41,7 +41,7 @@ TEST_F(LogTest, LogTenValues) { std::vector input = {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1(input); builder.Log(x); diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 2b0f7e6e80c48435ca55432a2afa3b6d69162625..3975e9125703ee081d4e84fa8bd27fcbe483ac34 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -16,11 +16,11 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -39,7 +39,7 @@ namespace { class MapTest : public ClientLibraryTestBase { public: - explicit MapTest(perftools::gputools::Platform* platform = nullptr) + explicit MapTest(se::Platform* platform = nullptr) : ClientLibraryTestBase(platform) { mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); mutable_debug_options()->add_xla_disable_hlo_passes("inline"); @@ -50,18 +50,18 @@ class MapTest : public ClientLibraryTestBase { // x {R0F32} ----> (add) // / // 1.0f ---------/ - Computation CreateAdderToOne() { - ComputationBuilder mapped_builder(client_, TestName()); + XlaComputation CreateAdderToOne() { + XlaBuilder mapped_builder(TestName()); auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto one = mapped_builder.ConstantR0(1.0); - auto adder_to_one = mapped_builder.Add(x, one); + mapped_builder.Add(x, one); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); } - Computation CreateMax() { - ComputationBuilder b(client_, TestName()); + XlaComputation CreateMax() { + XlaBuilder b(TestName()); auto lhs = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto rhs = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); b.Max(lhs, rhs); @@ -73,8 +73,8 @@ class MapTest : public ClientLibraryTestBase { // Creates a computation that accepts an F32 and returns T(1) (ignoring the // argument). template - Computation CreateScalarOne() { - ComputationBuilder mapped_builder(client_, "scalar_one"); + XlaComputation CreateScalarOne() { + XlaBuilder mapped_builder("scalar_one"); (void)mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); mapped_builder.ConstantR0(1); auto computation_status = mapped_builder.Build(); @@ -87,11 +87,11 @@ class MapTest : public ClientLibraryTestBase { // x {R0F32} ----> (mul) // / // 2.0f ---------/ - Computation CreateMulByTwo() { - ComputationBuilder mapped_builder(client_, TestName()); + XlaComputation CreateMulByTwo() { + XlaBuilder mapped_builder(TestName()); auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto two = mapped_builder.ConstantR0(2.0); - auto mul_by_two = mapped_builder.Mul(x, two); + mapped_builder.Mul(x, two); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -105,12 +105,12 @@ class MapTest : public ClientLibraryTestBase { // x {R0F32} ----> (add) ----> (mul) // / // 1.0f ---------/ - Computation CreateAdderToOneTimesItself() { - ComputationBuilder mapped_builder(client_, TestName()); + XlaComputation CreateAdderToOneTimesItself() { + XlaBuilder mapped_builder(TestName()); auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto one = mapped_builder.ConstantR0(1.0); auto adder_to_one = mapped_builder.Add(x, one); - auto result = mapped_builder.Mul(x, adder_to_one); + mapped_builder.Mul(x, adder_to_one); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -122,12 +122,13 @@ class MapTest : public ClientLibraryTestBase { // x {R0F32} -----------> (map) ----> (add) // / / // embedded_computation --/ n --/ - Computation CreateMapPlusN(const Computation& embedded_computation, float n) { - ComputationBuilder builder(client_, TestName()); + XlaComputation CreateMapPlusN(const XlaComputation& embedded_computation, + float n) { + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto map = builder.Map({x}, embedded_computation, {}); auto constant_n = builder.ConstantR0(n); - auto add = builder.Add(map, constant_n); + builder.Add(map, constant_n); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -135,11 +136,11 @@ class MapTest : public ClientLibraryTestBase { // Creates a binary function with signature (F32, F32) -> Pred // defined by (x, y) -> x > y. - Computation CreateGt() { - ComputationBuilder b(client_, "Gt"); + XlaComputation CreateGt() { + XlaBuilder b("Gt"); auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - auto gt = b.Gt(x, y); + b.Gt(x, y); auto computation_status = b.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -152,13 +153,13 @@ class MapTest : public ClientLibraryTestBase { // y {R0F32} ----> (add) ---> (add) // / // z {R0F32} ---------------/ - Computation CreateTernaryAdder() { - ComputationBuilder mapped_builder(client_, "TernaryAdder"); + XlaComputation CreateTernaryAdder() { + XlaBuilder mapped_builder("TernaryAdder"); auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = mapped_builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); auto z = mapped_builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "z"); auto xy = mapped_builder.Add(x, y); - auto xyz = mapped_builder.Add(xy, z); + mapped_builder.Add(xy, z); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -167,13 +168,13 @@ class MapTest : public ClientLibraryTestBase { TEST_F(MapTest, MapEachElemPlusOneR0) { // Applies lambda (x) (+ x 1)) to an input scalar. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR0(42.0); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOne(), {}); + builder.Map({param}, CreateAdderToOne(), {}); ComputeAndCompareR0(&builder, 43.0, {param0_data.get()}, ErrorSpec(0.01f)); @@ -181,13 +182,13 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOne(), {0}); + builder.Map({param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -195,55 +196,55 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { TEST_F(MapTest, MapEachElemPlusOneR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOne(), {0}); + builder.Map({param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {3.2f, 4.3f, 5.4f, 6.5f}, {param0_data.get()}, ErrorSpec(0.01f)); } TEST_F(MapTest, MapEachF32ElementToS32Constant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateScalarOne(), {0}); + builder.Map({param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); } TEST_F(MapTest, MapEachF32ElementToU32Constant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateScalarOne(), {0}); + builder.Map({param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); } TEST_F(MapTest, MapEachElemLongerChainR1) { // Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOneTimesItself(), {0}); + builder.Map({param}, CreateAdderToOneTimesItself(), {0}); ComputeAndCompareR1( &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f}, @@ -253,14 +254,14 @@ TEST_F(MapTest, MapEachElemLongerChainR1) { XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then // maps (lambda (x) (* x 2)) on the result. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); auto map1 = builder.Map({param}, CreateAdderToOne(), {0}); - auto map2 = builder.Map({map1}, CreateMulByTwo(), {0}); + builder.Map({map1}, CreateMulByTwo(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -269,7 +270,7 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { TEST_F(MapTest, MapMultipleMapsR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then // maps (lambda (x) (* x 2)) on the result. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = @@ -277,7 +278,7 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { auto param = builder.Parameter(0, param0_literal->shape(), "param0"); auto map1 = builder.Map({param}, CreateAdderToOne(), {0}); - auto map2 = builder.Map({map1}, CreateMulByTwo(), {0}); + builder.Map({map1}, CreateMulByTwo(), {0}); ComputeAndCompareR1(&builder, {6.4f, 8.6f, 10.8f, 13.0f}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -285,14 +286,14 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { TEST_F(MapTest, MapEachElemPlusOneR2) { // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR2( {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOne(), {0, 1}); + builder.Map({param}, CreateAdderToOne(), {0, 1}); Array2D expected_array( {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}}); @@ -317,18 +318,18 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { auto embed2 = CreateMapPlusN(embed1, 2.0); auto embed3 = CreateMapPlusN(embed1, 4.0); - ComputationBuilder embed4_builder(client_, "embed4"); + XlaBuilder embed4_builder("embed4"); auto embed4_param = embed4_builder.Parameter(0, scalar_shape, "x"); auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2, {}); auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3, {}); - auto embed4_add = embed4_builder.Add(embed4_map_lhs, embed4_map_rhs); + embed4_builder.Add(embed4_map_lhs, embed4_map_rhs); auto embed4_status = embed4_builder.Build(); ASSERT_IS_OK(embed4_status.status()); auto embed4 = embed4_status.ConsumeValueOrDie(); auto embed5 = CreateMapPlusN(embed2, 6.0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto constant_42 = builder.ConstantR0(42.0); auto constant_7 = builder.ConstantR0(7.0); auto map_42 = builder.Map({constant_42}, embed5, {}); @@ -338,50 +339,9 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { ComputeAndCompareR0(&builder, 73.0, {}, ErrorSpec(0.01f)); } -TEST_F(MapTest, VersionedEmbeddedComputation) { - // Build a computation X, use it in a map, then add an additional operation to - // computation X and use it again in a different map. Verify that the proper - // versions of computation X are used in each of the maps. - - // Create a (embedded) computation which adds one to its parameter argument. - ComputationBuilder embedded_builder(client_, "EmbeddedComputation"); - auto param_0 = - embedded_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); - auto constant_one = embedded_builder.ConstantR0(1.0); - auto adder_to_one = embedded_builder.Add(param_0, constant_one); - auto computation_status = embedded_builder.Build(); - ASSERT_IS_OK(computation_status.status()); - auto embedded_computation = computation_status.ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto constant_vector = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto map_plus_1 = builder.Map({constant_vector}, embedded_computation, {0}); - - // Add another Add(1) operation to the existing embedded computation. This - // requires using the stub interface because the ComputationBuilder does not - // allow modification to the Computation objects after they have been built. - BinaryOpRequest request; - request.set_binop(BINOP_ADD); - *request.mutable_lhs() = adder_to_one; - *request.mutable_rhs() = constant_one; - OpRequest op_request; - *op_request.mutable_computation() = embedded_computation.handle(); - *op_request.mutable_binary_op_request() = request; - OpResponse response; - tensorflow::Status s = client_->stub()->Op(&op_request, &response); - ASSERT_TRUE(s.ok()); - - auto map_plus_2 = builder.Map({map_plus_1}, embedded_computation, {0}); - - // The original vector has Add(1) applied to it with a map, followed by - // Add(1+1) resulting in a net Add(3). - ComputeAndCompareR1(&builder, {4.0, 5.0, 6.0, 7.0}, {}, - ErrorSpec(0.01f)); -} - TEST_F(MapTest, MapBinaryAdder) { // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = @@ -393,8 +353,7 @@ TEST_F(MapTest, MapBinaryAdder) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto map = builder.Map({param0, param1}, - CreateScalarAddComputation(F32, &builder), {0}); + builder.Map({param0, param1}, CreateScalarAddComputation(F32, &builder), {0}); ComputeAndCompareR1(&builder, {7.3f, 7.7, 4.3f, 0}, {param0_data.get(), param1_data.get()}, @@ -404,7 +363,7 @@ TEST_F(MapTest, MapBinaryAdder) { // Adds two rank-2 arrays with different layouts. This test exercises a path // for Map that used to fail in shape inference (b/28989438). XLA_TEST_F(MapTest, AddWithMixedLayouts) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0})); std::unique_ptr param0_data = @@ -417,8 +376,8 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto map = builder.Map({param0, param1}, - CreateScalarAddComputation(S32, &builder), {0, 1}); + builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder), + {0, 1}); Array2D expected(2, 2); expected(0, 0) = 11; @@ -430,7 +389,7 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { } XLA_TEST_F(MapTest, AddR3_3x0x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param0_data = @@ -443,8 +402,8 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto map = builder.Map({param0, param1}, - CreateScalarAddComputation(S32, &builder), {0, 1, 2}); + builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder), + {0, 1, 2}); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {param0_data.get(), param1_data.get()}); @@ -452,7 +411,7 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) { TEST_F(MapTest, MapTernaryAdder) { // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = @@ -469,7 +428,7 @@ TEST_F(MapTest, MapTernaryAdder) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); auto param2 = builder.Parameter(2, param2_literal->shape(), "param2"); - auto map = builder.Map({param0, param1, param2}, CreateTernaryAdder(), {0}); + builder.Map({param0, param1, param2}, CreateTernaryAdder(), {0}); ComputeAndCompareR1( &builder, {-2.7f, -92.3f, -895.7f, -400.0f}, @@ -479,24 +438,24 @@ TEST_F(MapTest, MapTernaryAdder) { TEST_F(MapTest, MapGt) { // Maps (x,y) -> x > y onto two R1F32 vectors. - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto gt = CreateGt(); b.Map({b.ConstantR1({1, 20}), b.ConstantR1({10, 2})}, gt, {0}); ComputeAndCompareR1(&b, {false, true}, {}); } TEST_F(MapTest, NestedBinaryMap) { - Computation max_with_square; + XlaComputation max_with_square; { // max_with_square(x) = do max(x, x^2) via a map. - ComputationBuilder b(client_, "max_with_square"); + XlaBuilder b("max_with_square"); auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); b.Map({x, b.Mul(x, x)}, CreateMax(), {}); auto computation_status = b.Build(); ASSERT_IS_OK(computation_status.status()); max_with_square = computation_status.ConsumeValueOrDie(); } - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto input = b.ConstantR1({0.1f, 0.5f, -0.5f, 1.0f, 2.0f}); b.Map({input}, max_with_square, {0}); ComputeAndCompareR1(&b, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {}); @@ -505,13 +464,13 @@ TEST_F(MapTest, NestedBinaryMap) { TEST_F(MapTest, MapOperantionWithBuildError) { // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors but uses an unsupported // type combination (F32 + U16) to test that the error is reported to the - // outermost ComputationBuilder. - ComputationBuilder builder(client_, TestName()); + // outermost XlaBuilder. + XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("ErrorAdd"); auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(U16, {}), "y"); - auto adder = sub_builder->Add(x, y); + sub_builder->Add(x, y); auto error_add = sub_builder->BuildAndNoteError(); std::unique_ptr param0_literal = @@ -525,14 +484,13 @@ TEST_F(MapTest, MapOperantionWithBuildError) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto map = builder.Map({param0, param1}, error_add, {0}); + builder.Map({param0, param1}, error_add, {0}); - StatusOr computation_status = builder.Build(); + StatusOr computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); - EXPECT_THAT( - computation_status.status().ToString(), - ::testing::HasSubstr("error from: ErrorAdd: binary op BINOP_ADD with " - "different element types: f32[] and u16[]")); + EXPECT_THAT(computation_status.status().ToString(), + ::testing::HasSubstr("error from: ErrorAdd: Binary op add with " + "different element types: f32[] and u16[]")); } // MapTest disables inline and algsimp. MapTestWithFullOpt runs all @@ -545,7 +503,7 @@ using MapTestWithFullOpt = ClientLibraryTestBase; // to have issues with such patterns and maybe invalidate the pointer to entry // computation. TEST_F(MapTestWithFullOpt, MapScalarPower) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("power"); auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); @@ -572,7 +530,7 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { // Regression test for b/35786417, where the inliner would not notice the change // of parameter order inside the map. TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("power"); auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); @@ -598,7 +556,7 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { // Regression test for b/35786417, where the inliner would CHECK-fail due to the // mul inside the map having more parameters than the map does. TEST_F(MapTestWithFullOpt, MapSquare) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("power"); auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 6c86dd5b9ef673c9facffafa37e00a859ce82010..27fd36e06acdc589f3a84ad561164e4a33b93506 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -29,6 +29,8 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" @@ -38,258 +40,211 @@ limitations under the License. namespace xla { namespace { -class MatOpsSimpleTest : public ClientLibraryTestBase { - protected: - Computation BuildSum() { - // sum(x, y) = x + y - ComputationBuilder builder(client_, "sum"); - auto x_value = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value"); - auto y_value = - builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y_value"); - builder.Add(x_value, y_value); - auto computation_status = builder.Build(); - TF_CHECK_OK(computation_status.status()); - return computation_status.ConsumeValueOrDie(); - } +#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +using TypesF16F32 = ::testing::Types; +#else +using TypesF16F32 = ::testing::Types; +#endif - void TestLinspaceMax(int64 rows, int64 cols) { - float from = -128.0, to = 256.0; - std::unique_ptr> alhs = - MakeLinspaceArray2D(from, to, rows, cols); - auto arhs = MakeUnique>(rows, cols, 1.0); - - ComputationBuilder builder( - client_, - tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - auto max = builder.Max(lhs, rhs); +class MatOpsSimpleTest : public ClientLibraryTestBase {}; - Array2D aexpected(rows, cols); - for (int row = 0; row < rows; ++row) { - for (int col = 0; col < cols; ++col) { - aexpected(row, col) = std::max((*alhs)(row, col), (*arhs)(row, col)); - } - } +template +class MatOpsSimpleTest_F16F32 : public MatOpsSimpleTest {}; - ComputeAndCompareR2(&builder, aexpected, {}, ErrorSpec(1e-6)); - } -}; +TYPED_TEST_CASE(MatOpsSimpleTest_F16F32, TypesF16F32); -TEST_F(MatOpsSimpleTest, ExpTwoByTwoValues) { - ComputationBuilder builder(client_, "exp_2x2"); - auto data = builder.ConstantR2({ - {1.0, 0.0}, // row 0 - {-1.0, 0.5}, // row 1 +XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { + using T = TypeParam; + XlaBuilder builder("exp_2x2"); + auto data = builder.ConstantR2FromArray2D({ + {1.0f, 0.0f}, // row 0 + {-1.0f, 0.5f}, // row 1 }); builder.Exp(data); std::unique_ptr expected = - Literal::CreateR2({{2.71828, 1.00000}, // row 0 - {0.36788, 1.64872}}); // row 1 + Literal::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 + {0.36788f, 1.64872f}}); // row 1 - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } -TEST_F(MatOpsSimpleTest, MapTwoByTwo) { - Computation add_half; +XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { + using T = TypeParam; + XlaComputation add_half; { // add_half(x) = x + 0.5 - ComputationBuilder builder(client_, "add_half"); + XlaBuilder builder("add_half"); auto x_value = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value"); - auto half = builder.ConstantR0(0.5); + builder.Parameter(0, ShapeUtil::MakeShapeWithType({}), "x_value"); + auto half = builder.ConstantR0(static_cast(0.5)); builder.Add(x_value, half); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); add_half = computation_status.ConsumeValueOrDie(); } - ComputationBuilder builder(client_, "map_2x2"); - auto data = builder.ConstantR2({ - {1.0, 0.0}, // row 0 - {-1.0, 0.5}, // row 1 + XlaBuilder builder("map_2x2"); + auto data = builder.ConstantR2FromArray2D({ + {1.0f, 0.0f}, // row 0 + {-1.0f, 0.5f}, // row 1 }); auto map = builder.Map({data}, add_half, {0, 1}); std::unique_ptr expected = - Literal::CreateR2({{1.5, 0.5}, // row 0 - {-0.5, 1.0}}); // row 1 - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); + Literal::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 + {-0.5f, 1.0f}}); // row 1 + this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } -TEST_F(MatOpsSimpleTest, MaxTwoByTwoValues) { - ComputationBuilder builder(client_, "max_2x2"); - auto lhs = builder.ConstantR2({ - {7.0, 2.0}, // row 0 - {3.0, -4.0}, // row 1 +XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { + using T = TypeParam; + XlaBuilder builder("max_2x2"); + auto lhs = builder.ConstantR2FromArray2D({ + {7.0f, 2.0f}, // row 0 + {3.0f, -4.0f}, // row 1 }); - auto rhs = builder.ConstantR2({ - {5.0, 6.0}, // row 0 - {1.0, -8.0}, // row 1 + auto rhs = builder.ConstantR2FromArray2D({ + {5.0f, 6.0f}, // row 0 + {1.0f, -8.0f}, // row 1 }); auto max = builder.Max(lhs, rhs); std::unique_ptr expected = - Literal::CreateR2({{7.0, 6.0}, // row 0 - {3.0, -4.0}}); // row 1 - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); + Literal::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 + {3.0f, -4.0f}}); // row 1 + this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); } -TEST_F(MatOpsSimpleTest, Max1x1Linspace) { TestLinspaceMax(1, 1); } - -TEST_F(MatOpsSimpleTest, Max2x2Linspace) { TestLinspaceMax(2, 2); } - -TEST_F(MatOpsSimpleTest, Max3x3Linspace) { TestLinspaceMax(3, 3); } - -TEST_F(MatOpsSimpleTest, Max4x4Linspace) { TestLinspaceMax(4, 4); } - -TEST_F(MatOpsSimpleTest, Max6x6Linspace) { TestLinspaceMax(6, 6); } - -TEST_F(MatOpsSimpleTest, Max8x8Linspace) { TestLinspaceMax(8, 8); } - -TEST_F(MatOpsSimpleTest, Max12x12Linspace) { TestLinspaceMax(12, 12); } - -TEST_F(MatOpsSimpleTest, Max16x16Linspace) { TestLinspaceMax(16, 16); } +struct TestLinspaceMaxParam { + int64 rows; + int64 cols; +}; -TEST_F(MatOpsSimpleTest, Max32x8Linspace) { TestLinspaceMax(32, 8); } +class TestLinspaceMaxParametric + : public MatOpsSimpleTest, + public ::testing::WithParamInterface { + public: + template + void TestImpl() { + TestLinspaceMaxParam param = GetParam(); + int64 rows = param.rows; + int64 cols = param.cols; + float from = -128.0, to = 256.0; + std::unique_ptr> alhs = + MakeLinspaceArray2D(from, to, rows, cols); + auto arhs = MakeUnique>(rows, cols, static_cast(1.0f)); -TEST_F(MatOpsSimpleTest, Max64x8Linspace) { TestLinspaceMax(64, 8); } + XlaBuilder builder( + tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); + auto lhs = builder.ConstantR2FromArray2D(*alhs); + auto rhs = builder.ConstantR2FromArray2D(*arhs); + auto max = builder.Max(lhs, rhs); -class MatOpsDotAddTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface> {}; - -TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2) { - bool row_major = std::get<0>(GetParam()); - bool add_lhs = std::get<1>(GetParam()); - bool transpose = std::get<2>(GetParam()); - Array2D lhs({{1.0, 2.0}, {3.0, 4.0}}); - Array2D rhs({{10.0, 11.0}, {12.0, 13.0}}); - - auto minor_to_major = [](bool row_major) -> std::vector { - return {row_major ? 1 : 0, row_major ? 0 : 1}; - }; - - auto prim_type = primitive_util::NativeToPrimitiveType(); - Shape lhs_shape = - ShapeUtil::MakeShape(prim_type, {lhs.height(), lhs.width()}); - Shape rhs_shape = - ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()}); - - TF_ASSERT_OK_AND_ASSIGN( - auto lhs_handle, - client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); - TF_ASSERT_OK_AND_ASSIGN( - auto rhs_handle, - client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); - - ComputationBuilder builder(client_, TestName()); - auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); - auto lhs_mat_arg = lhs_arg; - if (transpose) { - lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0}); - } - auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs"); - auto result = builder.Dot(lhs_mat_arg, rhs_arg); - Array2D expected; - if (add_lhs) { - result = builder.Add(result, lhs_arg); - if (transpose) { - expected = Array2D({{47, 52}, {71, 78}}); - } else { - expected = Array2D({{35, 39}, {81, 89}}); + Array2D expected(rows, cols); + for (int row = 0; row < rows; ++row) { + for (int col = 0; col < cols; ++col) { + expected(row, col) = std::max((*alhs)(row, col), (*arhs)(row, col)); + } } - } else { - result = builder.Add(result, rhs_arg); - if (transpose) { - expected = Array2D({{56, 61}, {80, 87}}); - } else { - expected = Array2D({{44, 48}, {90, 98}}); + ErrorSpec error_spec(1e-6); + if (std::is_same::value) { + error_spec = ErrorSpec(1e-6, 2e-4); } + ComputeAndCompareR2(&builder, expected, {}, error_spec); } +}; - ComputeAndCompareR2(&builder, expected, - {lhs_handle.get(), rhs_handle.get()}, - ErrorSpec(1e-6)); +string PrintTestLinspaceMaxParam( + const ::testing::TestParamInfo& test_param) { + const TestLinspaceMaxParam& param = test_param.param; + return tensorflow::strings::StrCat(param.rows, "r", param.cols, "c"); } -INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest, - ::testing::Combine(::testing::Bool(), ::testing::Bool(), - ::testing::Bool())); +#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +XLA_TEST_P(TestLinspaceMaxParametric, TestF16) { TestImpl(); } +#endif +XLA_TEST_P(TestLinspaceMaxParametric, TestF32) { TestImpl(); } + +INSTANTIATE_TEST_CASE_P( + TestLinspaceMax, TestLinspaceMaxParametric, + ::testing::Values(TestLinspaceMaxParam{1, 1}, TestLinspaceMaxParam{2, 2}, + TestLinspaceMaxParam{3, 3}, TestLinspaceMaxParam{4, 4}, + TestLinspaceMaxParam{6, 6}, TestLinspaceMaxParam{8, 8}, + TestLinspaceMaxParam{12, 12}, + TestLinspaceMaxParam{16, 16}, TestLinspaceMaxParam{32, 8}, + TestLinspaceMaxParam{64, 8}), + PrintTestLinspaceMaxParam); -class MatOpsDotAddTest_bf16 +class MatOpsDotAddTest : public ClientLibraryTestBase, - public ::testing::WithParamInterface> {}; - -TEST_P(MatOpsDotAddTest_bf16, Dot_Add_2x2_2x2) { - bool row_major = std::get<0>(GetParam()); - bool add_lhs = std::get<1>(GetParam()); - bool transpose = std::get<2>(GetParam()); - Array2D lhs( - {{bfloat16(1.0f), bfloat16(2.0f)}, {bfloat16(3.0), bfloat16(4.0)}}); - Array2D rhs( - {{bfloat16(10.0f), bfloat16(11.0f)}, {bfloat16(12.0f), bfloat16(13.0f)}}); - - auto minor_to_major = [](bool row_major) -> std::vector { - return {row_major ? 1 : 0, row_major ? 0 : 1}; - }; - - auto prim_type = primitive_util::NativeToPrimitiveType(); - Shape lhs_shape = - ShapeUtil::MakeShape(prim_type, {lhs.height(), lhs.width()}); - Shape rhs_shape = - ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()}); - - TF_ASSERT_OK_AND_ASSIGN( - auto lhs_handle, - client_->TransferToServer( - *Literal::CreateR2FromArray2DWithLayout( - lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); - TF_ASSERT_OK_AND_ASSIGN( - auto rhs_handle, - client_->TransferToServer( - *Literal::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); - - ComputationBuilder builder(client_, TestName()); - auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); - auto lhs_mat_arg = lhs_arg; - if (transpose) { - lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0}); - } - auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs"); - auto result = builder.Dot(lhs_mat_arg, rhs_arg); - Array2D expected; - if (add_lhs) { - result = builder.Add(result, lhs_arg); + public ::testing::WithParamInterface> { + public: + template + void TestImpl() { + bool row_major = std::get<0>(GetParam()); + bool add_lhs = std::get<1>(GetParam()); + bool transpose = std::get<2>(GetParam()); + Array2D lhs({{1.0f, 2.0f}, {3.0f, 4.0f}}); + Array2D rhs({{10.0f, 11.0f}, {12.0f, 13.0f}}); + + auto minor_to_major = [](bool row_major) -> std::vector { + return {row_major ? 1 : 0, row_major ? 0 : 1}; + }; + + auto prim_type = primitive_util::NativeToPrimitiveType(); + Shape lhs_shape = + ShapeUtil::MakeShape(prim_type, {lhs.height(), lhs.width()}); + Shape rhs_shape = + ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()}); + + TF_ASSERT_OK_AND_ASSIGN( + auto lhs_handle, + client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + TF_ASSERT_OK_AND_ASSIGN( + auto rhs_handle, + client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + + XlaBuilder builder(TestName()); + auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); + auto lhs_mat_arg = lhs_arg; if (transpose) { - expected = Array2D( - {{bfloat16(47), bfloat16(52)}, {bfloat16(71), bfloat16(78)}}); - } else { - expected = Array2D( - {{bfloat16(35), bfloat16(39)}, {bfloat16(81), bfloat16(89)}}); + lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0}); } - } else { - result = builder.Add(result, rhs_arg); - if (transpose) { - expected = Array2D( - {{bfloat16(56), bfloat16(61)}, {bfloat16(80), bfloat16(87)}}); + auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs"); + auto result = builder.Dot(lhs_mat_arg, rhs_arg); + Array2D expected; + if (add_lhs) { + result = builder.Add(result, lhs_arg); + if (transpose) { + expected = Array2D({{47.0f, 52.0f}, {71.0f, 78.0f}}); + } else { + expected = Array2D({{35.0f, 39.0f}, {81.0f, 89.0f}}); + } } else { - expected = Array2D( - {{bfloat16(44), bfloat16(48)}, {bfloat16(90), bfloat16(98)}}); + result = builder.Add(result, rhs_arg); + if (transpose) { + expected = Array2D({{56.0f, 61.0f}, {80.0f, 87.0f}}); + } else { + expected = Array2D({{44.0f, 48.0f}, {90.0f, 98.0f}}); + } } + + ComputeAndCompareR2(&builder, expected, + {lhs_handle.get(), rhs_handle.get()}, + ErrorSpec(1e-6)); } +}; - ComputeAndCompareR2(&builder, expected, - {lhs_handle.get(), rhs_handle.get()}, - ErrorSpec(1e-6)); -} +XLA_TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2BF16) { TestImpl(); } +#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +XLA_TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2F16) { TestImpl(); } +#endif +XLA_TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2F32) { TestImpl(); } -INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest_bf16, +INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest, ::testing::Combine(::testing::Bool(), ::testing::Bool(), ::testing::Bool())); diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc index 11c0bf7a5a5bde9edcfb7f76a5c10ac4dd77bcee..0791a71aacf7614286fe964623a3172a174d4722 100644 --- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc +++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -32,7 +32,7 @@ namespace { class SliceTest : public ClientLibraryTestBase {}; XLA_TEST_F(SliceTest, Slice2D) { - ComputationBuilder builder(client_, "slice_2d"); + XlaBuilder builder("slice_2d"); auto original = builder.ConstantR2( {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}, {10.0, 11.0, 12.0}}); builder.Slice(original, {2, 1}, {4, 3}, {1, 1}); @@ -42,7 +42,7 @@ XLA_TEST_F(SliceTest, Slice2D) { } XLA_TEST_F(SliceTest, Slice3D) { - ComputationBuilder builder(client_, "slice_3d"); + XlaBuilder builder("slice_3d"); Array3D array_3d( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}); auto original = builder.ConstantR3FromArray3D(array_3d); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 0a603f4954badd12adf3144320789a5edd0d9c6c..41f723edf1ff3518686231f31b61b64291b1f6bf 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -108,7 +107,7 @@ class MultiOutputFusionTest : public HloTestBase { expect.PopulateWithValue(size * 1.5f * 3.5f); auto actual = ExecuteAndTransfer( std::move(hlo_module), {Literal::CreateR0(-9.0f).get(), &arg1}); - LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); } void RunTest1D(bool manual_fusion, int size) { @@ -168,7 +167,7 @@ class MultiOutputFusionTest : public HloTestBase { Literal expect = std::move(*Literal::CreateR1({size * 1.5f * 3.5f})); auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); - LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); } }; @@ -211,5 +210,309 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { *result, *Literal::MakeTupleOwned(Literal::CreateR0(42)))); } +XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { + const char* testcase = R"( + HloModule m + + fused_computation { + p = f32[4] parameter(0) + multiply = f32[4] multiply(p, p) + less-than = pred[4] less-than(p, multiply) + ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply) + } + + ENTRY PredFloatMOF { + p0 = f32[4] parameter(0) + fusion = (pred[4], f32[4]) fusion(p0), kind=kLoop, calls=fused_computation + gte0 = pred[4] get-tuple-element(fusion), index=0 + gte1 = f32[4] get-tuple-element(fusion), index=1 + const = f32[4] constant({0, 0, 0, 0}) + ROOT select = f32[4] select(gte0, gte1, const) + })"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR1({1.0, 2.0, 3.0, -1.0}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::CreateR1({0.0, 4.0, 9.0, 1.0}))); +} + +XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { + const char* testcase = R"( + HloModule m + + fused_computation { + p = f32[] parameter(0) + multiply = f32[] multiply(p, p) + less-than = pred[] less-than(p, multiply) + ROOT tuple = (pred[], f32[]) tuple(less-than, multiply) + } + + map_computation { + p0 = f32[] parameter(0) + fusion = (pred[], f32[]) fusion(p0), kind=kLoop, calls=fused_computation + gte0 = pred[] get-tuple-element(fusion), index=0 + gte1 = f32[] get-tuple-element(fusion), index=1 + const = f32[] constant(0) + ROOT select = f32[] select(gte0, gte1, const) + } + + ENTRY MapMOF { + p1 = f32[3] parameter(0) + ROOT map = f32[3] map(p1), to_apply=map_computation + })"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR1({1.0, 2.0, 3.0}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::CreateR1({0.0, 4.0, 9.0}))); +} + +const char* const kScalarOps = R"( + HloModule m + + Add { + lhsadd = f32[] parameter(0) + rhsadd = f32[] parameter(1) + ROOT add = f32[] add(lhsadd, rhsadd) + } + + Max { + lhsmax = f32[] parameter(0) + rhsmax = f32[] parameter(1) + ROOT max = f32[] maximum(lhsmax, rhsmax) + } +)"; + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned(Literal::CreateR2({{3, 7}, {11, 15}}), + Literal::CreateR2({{5, 16}, {36, 64}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::MakeTupleOwned( + Literal::CreateR2({{6, 8}, {10, 12}}), + Literal::CreateR2({{25, 36}, {49, 64}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(1.17549e-38) + r2 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max + r3 = f32[2]{0} reduce(mul, c0), dimensions={0,2}, to_apply=Add + ROOT tuple = (f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple(r1, r2, r3) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::MakeTupleOwned(Literal::CreateR1({14, 22}), + Literal::CreateR1({36, 64}), + Literal::CreateR1({66, 138})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) + tuple(p0, r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned( + Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), + Literal::CreateR2({{3, 7}, {11, 15}}), + Literal::CreateR2({{5, 16}, {36, 64}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) + tuple(r1, mul, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned( + Literal::CreateR2({{6, 8}, {10, 12}}), + Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + Literal::CreateR2({{25, 36}, {49, 64}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + mul2 = f32[2,2,2]{2,1,0} multiply(p0, c1) + ROOT tuple = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) + tuple(r1, mul, mul2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned( + Literal::CreateR1({14, 22}), + Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + Literal::CreateR3( + {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + init1 = f32[] parameter(1) + init2 = f32[] parameter(2) + r1 = f32[2,2]{1,0} reduce(p0, init1), dimensions={2}, to_apply=Add + r2 = f32[2,2]{1,0} reduce(p0, init2), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + i = f32[] parameter(1) + j = f32[] parameter(2) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto init1 = Literal::CreateR0(5); + auto init2 = Literal::CreateR0(6); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + Execute(std::move(module), {param.get(), init1.get(), init2.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::MakeTupleOwned( + Literal::CreateR2({{167, 172}, {176, 180}}), + Literal::CreateR2({{6, 6}, {6, 8}})))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index 8cef8dd34dc7b16b1e58ded67d6b6a4ba79f20db..ce295b832d79e4f00656f2893c2ba1162693dd73 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -85,7 +85,7 @@ class PadTestFloat : public PadTest, // Tests a Pad() with a zero-element input and output. XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); // Set up the padding configuration {low: 0, high: 0, interior: 0}. PaddingConfig padding_config; auto dimension = padding_config.add_dimensions(); @@ -100,7 +100,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) { // Tests a Pad() with a zero-element input but a non-zero-element output. XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); // Set up the padding configuration {low: 3, high: 0, interior: 1}. PaddingConfig padding_config; auto dimension = padding_config.add_dimensions(); @@ -115,7 +115,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { } XLA_TEST_P(PadTestFloat, Pad1DS3Array) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); // Set up the padding configuration {low: 3, high: 0, interior: 1}. PaddingConfig padding_config; auto dimension = padding_config.add_dimensions(); @@ -130,7 +130,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { } XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Pad(AddParam(Array4D(2, 0, 3, 2), &b), AddParam(*Literal::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); ComputeAndCompareR4(&b, Array4D(5, 2, 3, 2, 1.5f), {}, @@ -138,7 +138,7 @@ XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { } TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto input = MakeUnique>(1, 1, 3, 2); Array2D input_xy({ {1.0f, 2.0f}, // row 0 @@ -162,7 +162,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { } TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); const float pad_value = 1.5f; Array4D input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); @@ -181,7 +181,7 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { } TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); PaddingConfig padding_config; auto dimension0 = padding_config.add_dimensions(); @@ -223,7 +223,7 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { } XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); PaddingConfig padding_config; auto dimension0 = padding_config.add_dimensions(); @@ -266,7 +266,7 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { } XLA_TEST_F(PadTest, Pad4DU8Array) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto input = MakeUnique>(1, 1, 3, 2); Array2D input_xy({ {1, 2}, // row 0 @@ -290,7 +290,7 @@ XLA_TEST_F(PadTest, Pad4DU8Array) { } XLA_TEST_F(PadTest, Pad4DPredArray) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); // Since bool is currently not well supported, use Broadcast operation to // create the operand for Pad. @@ -317,7 +317,7 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { } XLA_TEST_P(PadTestFloat, Large2DPad) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto ones = MakeUnique>(4, 4); ones->Fill(1.0f); @@ -329,15 +329,14 @@ XLA_TEST_P(PadTestFloat, Large2DPad) { padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 + 100 * dim); } - auto padded = b.Pad(input, AddParam(*Literal::CreateR0(0.0f), &b), - padding_config); + b.Pad(input, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } XLA_TEST_P(PadTestFloat, AllTypes2DPad) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); constexpr int64 in_rows = 35; constexpr int64 in_cols = 35; @@ -352,15 +351,14 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) { padding_config.mutable_dimensions(1)->set_edge_padding_low(6); padding_config.mutable_dimensions(1)->set_edge_padding_high(4); padding_config.mutable_dimensions(1)->set_interior_padding(2); - auto padded = b.Pad(input, AddParam(*Literal::CreateR0(3.14f), &b), - padding_config); + b.Pad(input, AddParam(*Literal::CreateR0(3.14f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } XLA_TEST_P(PadTestFloat, High2DPad) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); constexpr int64 in_rows = 129; constexpr int64 in_cols = 129; @@ -378,8 +376,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - auto padded = b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), - padding_config); + b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -387,7 +384,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) { } XLA_TEST_P(PadTestFloat, NegativePadding2D) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); constexpr int64 in_rows = 129; constexpr int64 in_cols = 129; @@ -406,8 +403,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - auto padded = b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), - padding_config); + b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -415,7 +411,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { } XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); constexpr int64 in_rows = 8; constexpr int64 in_cols = 11; @@ -434,8 +430,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding[dim]); } - auto padded = b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), - padding_config); + b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -444,20 +439,19 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { // Regression test for b/31827337. XLA_TEST_P(PadTestFloat, ReducePad) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto ones = MakeUnique>(2, 2, 2, 2); ones->Fill(1.0); auto input = AddParam(*ones, &b); - Computation add = CreateScalarAddComputation(FloatType(), &b); + XlaComputation add = CreateScalarAddComputation(FloatType(), &b); auto reduce = b.Reduce(input, AddParam(*Literal::CreateR0(0.0), &b), add, {0}); PaddingConfig padding_config = MakeNoPaddingConfig(3); padding_config.mutable_dimensions(0)->set_edge_padding_low(1); padding_config.mutable_dimensions(0)->set_edge_padding_high(1); - auto padded = b.Pad(reduce, AddParam(*Literal::CreateR0(0.0f), &b), - padding_config); + b.Pad(reduce, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); Array3D expected({{{0.0, 0.0}, {0.0, 0.0}}, {{2.0, 2.0}, {2.0, 2.0}}, diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index bb7e800df84121f2045141bc366c34b94ba694ea..838f1b4e2f0f0e0871ec717bdeefcbbc653397e3 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -41,7 +41,7 @@ namespace { class ParamsTest : public ClientLibraryTestBase {}; XLA_TEST_F(ParamsTest, ConstantR0F32Param) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR0(3.14159f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -53,7 +53,7 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) { } XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -65,7 +65,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { } XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = @@ -78,7 +78,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { } XLA_TEST_F(ParamsTest, ConstantR1U8Param) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); string str("hello world"); std::unique_ptr param0_literal = Literal::CreateR1U8(str); std::unique_ptr param0_data = @@ -91,7 +91,7 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { } XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR2FromArray2D(Array2D(3, 0)); std::unique_ptr param0_data = @@ -104,7 +104,7 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { } XLA_TEST_F(ParamsTest, ConstantR2F32Param) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR2( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); std::unique_ptr param0_data = @@ -119,7 +119,7 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { } XLA_TEST_F(ParamsTest, TwoParameters) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = @@ -156,19 +156,15 @@ XLA_TEST_F(ParamsTest, MissingParameter) { std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2"); - auto computation = builder.Build().ConsumeValueOrDie(); + auto computation_status = builder.Build(); - auto execute_status = client_->Execute(computation, {data.get(), data.get()}, - /*execution_options=*/nullptr, - /*execution_profile=*/nullptr); - ASSERT_EQ(execute_status.status().code(), - tensorflow::error::FAILED_PRECONDITION); + ASSERT_NE(computation_status.status(), Status::OK()); } XLA_TEST_F(ParamsTest, UnusedParameter) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = @@ -188,7 +184,7 @@ XLA_TEST_F(ParamsTest, UnusedParameter) { XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { // Build a computation with a couple unused parameters which are used in an // unused expression. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = @@ -214,12 +210,12 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { } XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); constexpr int size = 8 * 128 * 2; std::vector init_value = {{0, 1}}; init_value.resize(size); - ComputationDataHandle sum_handle = builder.ConstantR1(init_value); + XlaOp sum_handle = builder.ConstantR1(init_value); std::vector sum = {{0, 1}}; sum.resize(size); @@ -237,8 +233,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::unique_ptr literal = Literal::CreateR1(sum_value); param_data_owner.push_back( client_->TransferToServer(*literal).ConsumeValueOrDie()); - ComputationDataHandle param = - builder.Parameter(i, literal->shape(), "param"); + XlaOp param = builder.Parameter(i, literal->shape(), "param"); sum_handle = builder.Add(sum_handle, param); } @@ -262,10 +257,10 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { // compilation. XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(ThreeThousandParameters))) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector> param_data_owner; - ComputationDataHandle sum_handle = builder.ConstantR0(0.0f); + XlaOp sum_handle = builder.ConstantR0(0.0f); float target = 0.0; constexpr int kParamCount = 3000; for (int i = 0; i < kParamCount; ++i) { @@ -273,8 +268,7 @@ XLA_TEST_F(ParamsTest, std::unique_ptr literal = Literal::CreateR0(i); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); - ComputationDataHandle param = - builder.Parameter(i, literal->shape(), "param"); + XlaOp param = builder.Parameter(i, literal->shape(), "param"); sum_handle = builder.Add(sum_handle, param); } @@ -294,25 +288,24 @@ XLA_TEST_F(ParamsTest, // compilation. XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( ThreeThousandParametersAndOutputElements))) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector> param_data_owner; - ComputationDataHandle sum_handle = builder.ConstantR1({0, 0}); + XlaOp sum_handle = builder.ConstantR1({0, 0}); int32 target = 0; constexpr int kParamCount = 3000; - std::vector params; + std::vector params; for (int i = 0; i < kParamCount; ++i) { target += i; std::unique_ptr literal = Literal::CreateR1({i, i}); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); - ComputationDataHandle param = - builder.Parameter(i, literal->shape(), "param"); + XlaOp param = builder.Parameter(i, literal->shape(), "param"); params.push_back(param); sum_handle = builder.Add(sum_handle, param); } - std::vector outputs; + std::vector outputs; for (int i = 0; i < kParamCount; ++i) { outputs.push_back(builder.Add(params[i], sum_handle)); } @@ -353,18 +346,17 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( // 2017-12-12. XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(ManyParametersIntoWhileLoop))) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector> param_data_owner; constexpr int kParamCount = 1900; - std::vector params; + std::vector params; std::vector parameter_shapes; for (int i = 0; i < kParamCount; ++i) { std::unique_ptr literal = Literal::CreateR1({i, i}); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); - ComputationDataHandle param = - builder.Parameter(i, literal->shape(), "param"); + XlaOp param = builder.Parameter(i, literal->shape(), "param"); params.push_back(param); parameter_shapes.push_back(literal->shape()); } @@ -374,7 +366,7 @@ XLA_TEST_F(ParamsTest, std::unique_ptr bool_literal = Literal::CreateR0(false); param_data_owner.push_back( std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); - ComputationDataHandle bool_param = + XlaOp bool_param = builder.Parameter(kParamCount, bool_literal->shape(), "bool_param"); params.push_back(bool_param); parameter_shapes.push_back(bool_literal->shape()); @@ -383,9 +375,9 @@ XLA_TEST_F(ParamsTest, // Create a computation for the condition: while(bool_param). Shape while_shape = ShapeUtil::MakeTupleShape(parameter_shapes); - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto condition_parameter = builder.Parameter(0, while_shape, "condition_parameter"); builder.GetTupleElement(condition_parameter, kParamCount); @@ -394,11 +386,11 @@ XLA_TEST_F(ParamsTest, // Create a computation for the body. // Add {1, 1} to the each tuple element. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto body_parameter = builder.Parameter(0, while_shape, "body_parameter"); - std::vector updates; + std::vector updates; for (int i = 0; i < kParamCount; ++i) { auto add = builder.Add(builder.GetTupleElement(body_parameter, i), builder.ConstantR1({1, 1})); @@ -413,7 +405,7 @@ XLA_TEST_F(ParamsTest, auto loop = builder.While(condition, body, init); - std::vector outputs; + std::vector outputs; for (int i = 0; i < kParamCount; ++i) { outputs.push_back(builder.GetTupleElement(loop, i)); } @@ -437,7 +429,7 @@ XLA_TEST_F(ParamsTest, #endif XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Shape r1f32_3 = ShapeUtil::MakeShape(F32, {3}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r1f32_3, r1f32_3}); @@ -464,7 +456,7 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { std::unique_ptr literal = Literal::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Parameter(0, literal->shape(), "input"); std::unique_ptr data = @@ -476,7 +468,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { std::unique_ptr literal = Literal::CreateR2WithLayout( {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0})); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Parameter(0, literal->shape(), "input"); std::unique_ptr data = @@ -501,7 +493,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { ASSERT_EQ(2, literal->Get({0, 1})); } // Use the original shape in building the computation. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.Parameter(0, original, "input"); // Use the slice operator to get an off-diagonal element. builder.Slice(input, {0, 1}, {1, 2}, {1, 1}); diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 10e44b274a8a9f3ac28dc40d7b1938d24a9ee40c..77159efb26f3b7dd4918f24305f7269a2d6ff647 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -29,63 +29,62 @@ namespace { class PredTest : public ClientLibraryTestBase { protected: - void TestCompare(bool lhs, bool rhs, bool expected, - ComputationDataHandle (ComputationBuilder::*op)( - const ComputationDataHandle&, - const ComputationDataHandle&, - tensorflow::gtl::ArraySlice)) { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle lhs_op = builder.ConstantR0(lhs); - ComputationDataHandle rhs_op = builder.ConstantR0(rhs); - ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {}); + void TestCompare( + bool lhs, bool rhs, bool expected, + XlaOp (XlaBuilder::*op)(const xla::XlaOp&, const xla::XlaOp&, + tensorflow::gtl::ArraySlice)) { + XlaBuilder builder(TestName()); + XlaOp lhs_op = builder.ConstantR0(lhs); + XlaOp rhs_op = builder.ConstantR0(rhs); + XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } }; TEST_F(PredTest, ConstantR0PredTrue) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR0(true); ComputeAndCompareR0(&builder, true, {}); } TEST_F(PredTest, ConstantR0PredFalse) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR0(false); ComputeAndCompareR0(&builder, false, {}); } TEST_F(PredTest, ConstantR0PredCompareEq) { - TestCompare(true, false, false, &ComputationBuilder::Eq); + TestCompare(true, false, false, &XlaBuilder::Eq); } TEST_F(PredTest, ConstantR0PredCompareNe) { - TestCompare(true, false, true, &ComputationBuilder::Ne); + TestCompare(true, false, true, &XlaBuilder::Ne); } TEST_F(PredTest, ConstantR0PredCompareLe) { - TestCompare(true, false, false, &ComputationBuilder::Le); + TestCompare(true, false, false, &XlaBuilder::Le); } TEST_F(PredTest, ConstantR0PredCompareLt) { - TestCompare(true, false, false, &ComputationBuilder::Lt); + TestCompare(true, false, false, &XlaBuilder::Lt); } TEST_F(PredTest, ConstantR0PredCompareGe) { - TestCompare(true, false, true, &ComputationBuilder::Ge); + TestCompare(true, false, true, &XlaBuilder::Ge); } TEST_F(PredTest, ConstantR0PredCompareGt) { - TestCompare(true, false, true, &ComputationBuilder::Gt); + TestCompare(true, false, true, &XlaBuilder::Gt); } TEST_F(PredTest, ConstantR1Pred) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({true, false, false, true}); ComputeAndCompareR1(&builder, {true, false, false, true}, {}); } TEST_F(PredTest, ConstantR2Pred) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{false, true, true}, {true, false, false}}); const string expected = R"(pred[2,3] { @@ -96,28 +95,28 @@ TEST_F(PredTest, ConstantR2Pred) { } TEST_F(PredTest, AnyR1True) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({true, false}); TF_ASSERT_OK(Any(a, &builder).status()); ComputeAndCompareR0(&builder, true, {}); } TEST_F(PredTest, AnyR1False) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({false, false}); TF_ASSERT_OK(Any(a, &builder).status()); ComputeAndCompareR0(&builder, false, {}); } TEST_F(PredTest, AnyR1VacuouslyFalse) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); TF_ASSERT_OK(Any(a, &builder).status()); ComputeAndCompareR0(&builder, false, {}); } TEST_F(PredTest, AnyR2True) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({ {false, false, false}, {false, false, false}, @@ -128,7 +127,7 @@ TEST_F(PredTest, AnyR2True) { } TEST_F(PredTest, AnyR2False) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({ {false, false, false}, {false, false, false}, diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 6aafb9fa6cb2175c478f0e9a5e16f5808cbea590..1a2de6937c3e134852a730f62f7b56417cf49b28 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -52,13 +52,14 @@ class PrngTest : public ClientLibraryTestBase { template std::unique_ptr PrngTest::UniformTest( T a, T b, tensorflow::gtl::ArraySlice dims, int64 seed) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.RngUniform( builder.ConstantR0(a), builder.ConstantR0(b), ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), dims)); SetSeed(seed); - auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); + auto actual = + ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); actual->EachCell([=](tensorflow::gtl::ArraySlice, T value) { EXPECT_LE(a, value); @@ -81,8 +82,7 @@ XLA_TEST_F(PrngTest, LargeU01) { UniformTest(0, 1, {0x100, 0x100}); } XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest(5, 24, {12}); } // TODO(b/71543667): Fix Rng ops on LLVM backends. -XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL( - DISABLED_ON_CPU(ScalarBF16Tests)))) { +XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests))) { for (int64 seed = 0; seed < 100; ++seed) { // The largest negative number smaller than zero in bf16 that's not // denormalized. @@ -105,8 +105,7 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL( } // TODO(b/71543667): Fix Rng ops on LLVM backends. -XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU( - DISABLED_ON_CPU_PARALLEL(ScalarBF16CountTests)))) { +XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) { // There are 3 BF16 values in the range of [32.25, 33): 32.25, 32.5, 32.75, // they should get similar counts. bfloat16 low = static_cast(32.25); @@ -141,13 +140,14 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count, int64 seed) { int32 sample_size = range_size * expected_count; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.RngUniform(builder.ConstantR0(0), builder.ConstantR0(range_size), ShapeUtil::MakeShape(S32, {sample_size})); SetSeed(seed); - auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); + auto actual = + ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); std::vector counts(range_size, 0); actual->EachCell([&counts](tensorflow::gtl::ArraySlice, int32 value) { ++counts[value]; }); @@ -182,16 +182,15 @@ XLA_TEST_F(PrngTest, Uniformity256) { XLA_TEST_F(PrngTest, MapUsingRng) { // Build a x -> (x + U[0,1)) computation. - auto build_sum_rng = [this](ComputationBuilder& builder) { + auto build_sum_rng = [this](XlaBuilder& builder) { auto b = builder.CreateSubBuilder("sum_with_rng"); auto x = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "input"); - b->Add(x, - b->RngUniform(b->ConstantR0(0), b->ConstantR0(1), - ShapeUtil::MakeShape(F32, {}))); + b->Add(x, b->RngUniform(b->ConstantR0(0), b->ConstantR0(1), + ShapeUtil::MakeShape(F32, {}))); return b->BuildAndNoteError(); }; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr param0_data, @@ -226,7 +225,7 @@ XLA_TEST_F(PrngTest, MapUsingRng) { XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { // Build a U[0,1) computation. auto build_computation = [this]() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.RngUniform(builder.ConstantR0(0), builder.ConstantR0(1), ShapeUtil::MakeShape(F32, {10})); @@ -274,32 +273,32 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { &execution_options_)); } - LiteralTestUtil::ExpectEqual(*result1, *result2); - LiteralTestUtil::ExpectEqual(*result1, *result3); - LiteralTestUtil::ExpectNotEqual(*result1, *result4); - LiteralTestUtil::ExpectNotEqual(*result4, *result5); - LiteralTestUtil::ExpectNotEqual(*result5, *result6); + EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2)); + EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3)); + EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4)); + EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5)); + EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6)); } XLA_TEST_F(PrngTest, TenValuesN01) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.RngNormal(builder.ConstantR0(0), builder.ConstantR0(1), ShapeUtil::MakeShape(F32, {10})); SetSeed(42); - ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); + ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); // TODO(b/25995601): Test that resultant values are reasonable } XLA_TEST_F(PrngTest, RngUniformCrash) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // This used to crash XLA during LLVM IR generation for CPUs. auto rng_uniform = builder.RngUniform(builder.ConstantR0(0), builder.ConstantR0(1000 * 1000), ShapeUtil::MakeShape(S32, {})); SetSeed(0); - ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); + ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); } } // namespace diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc index 212512207cfdc4d2ebdc4e7fd8f5794852cc6a79..f95e75648343aa88bd7c39de4ee9f387f2b60506 100644 --- a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc +++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -30,13 +30,13 @@ namespace { class QueryInferredShapeTest : public ClientLibraryTestBase {}; TEST_F(QueryInferredShapeTest, OnePlusOneShape) { - ComputationBuilder builder(client_, "one_plus_one"); + XlaBuilder builder("one_plus_one"); auto one = builder.ConstantR0(1.0); auto result = builder.Add(one, one); - StatusOr> shape_status = builder.GetShape(result); + StatusOr shape_status = builder.GetShape(result); ASSERT_IS_OK(shape_status.status()); auto shape = shape_status.ConsumeValueOrDie(); - ASSERT_TRUE(ShapeUtil::Equal(*shape, ShapeUtil::MakeShape(F32, {}))); + ASSERT_TRUE(ShapeUtil::Equal(shape, ShapeUtil::MakeShape(F32, {}))); } } // namespace diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index c0a2c0ca4cb8414e0771a541b9f963f9aedc8376..9052b188ed09a715b6ad7c3a40dc853d02cdd70c 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" @@ -73,7 +73,7 @@ ENTRY reduce.1 { } )"; - return tools::Parse(hlo_string); + return ParseHloString(hlo_string); } // TODO(b/72454718): XLA:GPU does not support executing code compiled without diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index dc7ce3253cee255a7949326fa5b49fc8917432b8..b311785449f1774c3bc1e4d7ad35c2866e3b4061 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" @@ -228,15 +228,14 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) { // This is required for proper handling of NaN values. SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR1({input_values}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a = builder.Parameter(0, a_literal->shape(), "a"); - auto reduce_precision = - builder.ReducePrecision(a, exponent_bits, mantissa_bits); + builder.ReducePrecision(a, exponent_bits, mantissa_bits); ComputeAndCompareR1(&builder, expected_values, {a_data.get()}); } @@ -252,7 +251,7 @@ class ReducePrecisionInsertionTest : public ClientLibraryTestBase {}; // The interpreter has no fusion pass, so skip this test. XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = @@ -265,7 +264,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // Near 1.0, Log(x) approximates x - 1; this lets us confirm that the // reduce-precision operation showed up in the correct place in the // graph. - auto log = builder.Log(abs); + builder.Log(abs); // Insert precision-reduction after the Abs(x) operation, rounding that // result to exactly 1.0f. @@ -281,7 +280,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // The interpreter has no fusion pass, so skip this test. XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = @@ -290,7 +289,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // These two operations should be fused by any reasonable backend. auto abs = builder.Abs(a); - auto neg = builder.Neg(abs); + builder.Neg(abs); // Add a pass after operation fusion, suffixing kAbs operations. This // should not see into the fusion nodes and thus should not affect the @@ -307,7 +306,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // The interpreter has no fusion pass, so skip this test. XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = @@ -316,7 +315,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // These two operations should be fused by any reasonable backend. auto abs = builder.Abs(a); - auto neg = builder.Neg(abs); + builder.Neg(abs); // Add a pass after operation fusion, suffixing kFusion operations. auto reduce_precision_pass = execution_options_.mutable_debug_options() @@ -331,7 +330,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // The interpreter has no fusion pass, so skip this test. XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = @@ -340,7 +339,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // These two operations should be fused by any reasonable backend. auto abs = builder.Abs(a); - auto neg = builder.Neg(abs); + builder.Neg(abs); // Add a pass suffixing fusion nodes containing kCos operations. This // should have no effect. @@ -356,7 +355,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // The interpreter has no fusion pass, so skip this test. XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = @@ -365,7 +364,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // These two operations should be fused by any reasonable backend. auto abs = builder.Abs(a); - auto neg = builder.Neg(abs); + builder.Neg(abs); // Add a pass suffixing fusion nodes containing kAbs operations. This // should see the kAbs operation within the above fusion node. diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 50d7b5074d201d2292cf90224ef4cd37efdbb8d3..d671d40456a276a44b462f390c95aa4af301263a 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -34,11 +34,11 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -50,6 +50,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -57,6 +58,10 @@ limitations under the License. namespace xla { namespace { +using FuncGeneratorForType = XlaComputation (*)(PrimitiveType, XlaBuilder*); + +using FuncGenerator = XlaComputation (*)(XlaBuilder*); + class ReduceTest : public ClientLibraryTestBase { protected: ReduceTest() { @@ -81,8 +86,8 @@ class ReduceTest : public ClientLibraryTestBase { // Runs an R1 => R0 reduction test with the given number of elements. void RunR1ToR0Test(int64 element_count) { - ComputationBuilder builder(client_, TestName()); - Computation add_f32 = CreateScalarAddComputation(F32, &builder); + XlaBuilder builder(TestName()); + XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count}); auto input = builder.Parameter(0, input_shape, "input"); auto zero = builder.ConstantR0(0.0); @@ -111,13 +116,13 @@ class ReduceTest : public ClientLibraryTestBase { void RunR1ToR0PredTest(bool and_reduce, tensorflow::gtl::ArraySlice input_data) { const int element_count = input_data.size(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count}); auto input_par = builder.Parameter(0, input_shape, "input"); auto pred_values = builder.Eq(input_par, builder.ConstantR1(element_count, 1)); - ComputationDataHandle init_value; - Computation reduce; + XlaOp init_value; + XlaComputation reduce; if (and_reduce) { init_value = builder.ConstantR0(true); reduce = CreateScalarAndComputation(&builder); @@ -149,13 +154,13 @@ class ReduceTest : public ClientLibraryTestBase { template void RunR2ToR1PredTest(bool and_reduce, int64 rows, int64 minor = 1, int64 major = 0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const Shape input_shape = ShapeUtil::MakeShape(U8, {rows, cols}); auto input = builder.Parameter(0, input_shape, "input"); auto input_pred = builder.Eq(input, builder.ConstantR0(1)); - ComputationDataHandle init_value; - Computation reduce_op; + XlaOp init_value; + XlaComputation reduce_op; if (and_reduce) { init_value = builder.ConstantR0(true); reduce_op = CreateScalarAndComputation(&builder); @@ -194,8 +199,8 @@ class ReduceTest : public ClientLibraryTestBase { // Runs an R2 => R0 reduction test with the given number of (rows, cols). void RunR2ToR0Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) { - ComputationBuilder builder(client_, TestName()); - Computation add_f32 = CreateScalarAddComputation(F32, &builder); + XlaBuilder builder(TestName()); + XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); auto input = builder.Parameter(0, input_shape, "input"); auto zero = builder.ConstantR0(0.0); @@ -222,8 +227,8 @@ class ReduceTest : public ClientLibraryTestBase { // Runs an R2 => R1 reduction test with the given number of (rows, cols). void RunR2ToR1Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) { - ComputationBuilder builder(client_, TestName()); - Computation add_f32 = CreateScalarAddComputation(F32, &builder); + XlaBuilder builder(TestName()); + XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); auto input = builder.Parameter(0, input_shape, "input"); auto zero = builder.ConstantR0(0.0); @@ -253,7 +258,7 @@ class ReduceTest : public ClientLibraryTestBase { template void ComputeAndCompareGeneric( typename std::enable_if::value, - ComputationBuilder>::type* builder, + XlaBuilder>::type* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { ComputeAndCompareR1(builder, expected, arguments, @@ -263,7 +268,7 @@ class ReduceTest : public ClientLibraryTestBase { template void ComputeAndCompareGeneric( typename std::enable_if::value, - ComputationBuilder>::type* builder, + XlaBuilder>::type* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { ComputeAndCompareR1(builder, expected, arguments); @@ -271,15 +276,15 @@ class ReduceTest : public ClientLibraryTestBase { template void RunVectorizedReduceTestForType( - const std::function& + const std::function& reduction_function_generator, const std::function& reference_reduction_function, const NativeT& initial_value) { const int rows = 64, cols = 128; const int minor = 1, major = 0; - ComputationBuilder builder(client_, TestName()); - Computation reduction_function = reduction_function_generator(&builder); + XlaBuilder builder(TestName()); + XlaComputation reduction_function = reduction_function_generator(&builder); const Shape input_shape = ShapeUtil::MakeShape( xla::primitive_util::NativeToPrimitiveType(), {rows, cols}); auto input = builder.Parameter(0, input_shape, "input"); @@ -314,7 +319,7 @@ class ReduceTest : public ClientLibraryTestBase { } void RunVectorizedReduceTest( - const std::function& + const std::function& reduction_function_generator_for_type, const std::function& reference_reduction_function_for_floats, @@ -326,21 +331,21 @@ class ReduceTest : public ClientLibraryTestBase { uint32 unsigned_int_identity) { // Float version RunVectorizedReduceTestForType( - [&](ComputationBuilder* builder) { + [&](XlaBuilder* builder) { return reduction_function_generator_for_type(F32, builder); }, reference_reduction_function_for_floats, floating_point_identity); // Signed int version RunVectorizedReduceTestForType( - [&](ComputationBuilder* builder) { + [&](XlaBuilder* builder) { return reduction_function_generator_for_type(S32, builder); }, reference_reduction_function_for_ints, signed_int_identity); // Unsigned int version RunVectorizedReduceTestForType( - [&](ComputationBuilder* builder) { + [&](XlaBuilder* builder) { return reduction_function_generator_for_type(U32, builder); }, reference_reduction_function_for_uints, unsigned_int_identity); @@ -434,8 +439,8 @@ XLA_TEST_F(ReduceTest, OrReduceOnesAndZerosR1_10_Pred) { XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { const int64 rows = 111, cols = 50; - ComputationBuilder builder(client_, TestName()); - Computation add_f32 = CreateScalarAddComputation(F32, &builder); + XlaBuilder builder(TestName()); + XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); auto input = builder.Parameter(0, input_shape, "input"); auto zero = builder.ConstantR0(0.0); @@ -465,8 +470,8 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { const int64 rows = 111, cols = 50; - ComputationBuilder builder(client_, TestName()); - Computation add_f32 = CreateScalarAddComputation(F32, &builder); + XlaBuilder builder(TestName()); + XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); auto input = builder.Parameter(0, input_shape, "input"); auto zero = builder.ConstantR0(0.0); @@ -497,28 +502,25 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { // Test that algebraic simplifier does not incorrectly fold a transpose into a // reduction operation. XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) { - ComputationBuilder builder(client_, TestName()); - Computation add_f32 = CreateScalarAddComputation(F32, &builder); + XlaBuilder builder(TestName()); + XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {12, 111, 50}); - ComputationDataHandle input = builder.Parameter(0, input_shape, "input"); - ComputationDataHandle zero = builder.ConstantR0(0.0); - ComputationDataHandle transpose = - builder.Transpose(input, /*permutation=*/{1, 0, 2}); - ComputationDataHandle reduce = - builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); + XlaOp input = builder.Parameter(0, input_shape, "input"); + XlaOp zero = builder.ConstantR0(0.0); + XlaOp transpose = builder.Transpose(input, /*permutation=*/{1, 0, 2}); + builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, MakeFakeLiteral(input_shape)); - ComputeAndCompare(&builder, reduce, {std::move(*input_data)}, - ErrorSpec(0.01, 1e-4)); + ComputeAndCompare(&builder, {std::move(*input_data)}, ErrorSpec(0.01, 1e-4)); } XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { const int64 rows = 111, cols = 50; - ComputationBuilder builder(client_, TestName()); - Computation add_f32 = CreateScalarAddComputation(F32, &builder); + XlaBuilder builder(TestName()); + XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, 2, cols / 2}); auto input = builder.Parameter(0, input_shape, "input"); auto zero = builder.ConstantR0(0.0); @@ -564,7 +566,7 @@ void PrintTo(const BoundsLayout& spec, std::ostream* os) { // Add-reduces a broadcasted scalar matrix among dimension 1 and 0. XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto add = CreateScalarAddComputation(F32, &builder); auto scalar = builder.ConstantR0(42.0); auto broadcasted = builder.Broadcast(scalar, {500, 500}); @@ -576,7 +578,7 @@ XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) { // Max-reduces a broadcasted scalar matrix among dimension 1 and 0. XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto max = CreateScalarMaxComputation(F32, &builder); auto scalar = builder.ConstantR0(42.0); auto broadcasted = builder.Broadcast(scalar, {500, 500}); @@ -588,7 +590,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) { // Max-reduces a matrix among dimension 1 and 0. XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto max = CreateScalarMaxComputation(F32, &builder); Array2D input(300, 250); input.FillRandom(214.0f); @@ -603,7 +605,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { // Min-reduces matrix among dimension 1 and 0. XLA_TEST_F(ReduceTest, MinReduce2DToR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min = CreateScalarMinComputation(F32, &builder); Array2D input(150, 130); input.FillRandom(214.0f); @@ -618,7 +620,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { } XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array2D input({{1}, {2}}); auto min = CreateScalarMinComputation(U32, &builder); auto input_literal = Literal::CreateR2FromArray2D(input); @@ -631,7 +633,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { } XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array2D input({{1}, {2}}); auto max = CreateScalarMaxComputation(U32, &builder); auto input_literal = Literal::CreateR2FromArray2D(input); @@ -645,7 +647,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { // Reduces a matrix among dimension 1. XLA_TEST_F(ReduceTest, Reduce2DAmong1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantLiteral(*literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {1}); @@ -656,7 +658,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantLiteral(*literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1}); @@ -666,7 +668,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Tests 2D matrix ReduceToRow operation. XLA_TEST_F(ReduceTest, Reduce2DAmongY) { - ComputationBuilder builder(client_, "reduce_among_y"); + XlaBuilder builder("reduce_among_y"); auto m = builder.ConstantLiteral(*literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {0}); @@ -676,7 +678,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) { } XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantLiteral(*literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {1, 2}); @@ -686,7 +688,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { } XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantLiteral(*literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1}); @@ -696,7 +698,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { } XLA_TEST_F(ReduceTest, ReduceR3ToR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantLiteral(*literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1, 2}); @@ -706,7 +708,7 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) { } XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantLiteral(*literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {0}); @@ -721,7 +723,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { } XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantLiteral(*literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {1}); @@ -738,7 +740,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { } XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantLiteral(*literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {2}); @@ -755,60 +757,64 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { } XLA_TEST_F(ReduceTest, VectorizedReduce_Add) { - RunVectorizedReduceTest(CreateScalarAddComputation, - [](float a, float b) { return a + b; }, - [](int32 a, int32 b) { - return static_cast(static_cast(a) + - static_cast(b)); - }, - [](uint32 a, uint32 b) { return a + b; }, 0.0, 0, 0); + RunVectorizedReduceTest( + static_cast(CreateScalarAddComputation), + [](float a, float b) { return a + b; }, + [](int32 a, int32 b) { + return static_cast(static_cast(a) + + static_cast(b)); + }, + [](uint32 a, uint32 b) { return a + b; }, 0.0, 0, 0); } XLA_TEST_F(ReduceTest, VectorizedReduce_Multiply) { - RunVectorizedReduceTest(CreateScalarMultiplyComputation, - [](float a, float b) { return a * b; }, - [](int32 a, int32 b) { - return static_cast(static_cast(a) * - static_cast(b)); - }, - [](uint32 a, uint32 b) { return a * b; }, 1.0, 1, 1); + RunVectorizedReduceTest( + static_cast(CreateScalarMultiplyComputation), + [](float a, float b) { return a * b; }, + [](int32 a, int32 b) { + return static_cast(static_cast(a) * + static_cast(b)); + }, + [](uint32 a, uint32 b) { return a * b; }, 1.0, 1, 1); } XLA_TEST_F(ReduceTest, VectorizedReduce_Max) { - RunVectorizedReduceTest(CreateScalarMaxComputation, - [](float a, float b) { return std::max(a, b); }, - [](int32 a, int32 b) { return std::max(a, b); }, - [](uint32 a, uint32 b) { return std::max(a, b); }, - std::numeric_limits::min(), - std::numeric_limits::min(), - std::numeric_limits::min()); + RunVectorizedReduceTest( + static_cast(CreateScalarMaxComputation), + [](float a, float b) { return std::max(a, b); }, + [](int32 a, int32 b) { return std::max(a, b); }, + [](uint32 a, uint32 b) { return std::max(a, b); }, + std::numeric_limits::min(), std::numeric_limits::min(), + std::numeric_limits::min()); } XLA_TEST_F(ReduceTest, VectorizedReduce_Min) { - RunVectorizedReduceTest(CreateScalarMinComputation, - [](float a, float b) { return std::min(a, b); }, - [](int32 a, int32 b) { return std::min(a, b); }, - [](uint32 a, uint32 b) { return std::min(a, b); }, - std::numeric_limits::max(), - std::numeric_limits::max(), - std::numeric_limits::max()); + RunVectorizedReduceTest( + static_cast(CreateScalarMinComputation), + [](float a, float b) { return std::min(a, b); }, + [](int32 a, int32 b) { return std::min(a, b); }, + [](uint32 a, uint32 b) { return std::min(a, b); }, + std::numeric_limits::max(), std::numeric_limits::max(), + std::numeric_limits::max()); } XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) { RunVectorizedReduceTestForType( - CreateScalarAndComputation, [](bool a, bool b) { return a && b; }, true); + static_cast(CreateScalarAndComputation), + [](bool a, bool b) { return a && b; }, true); } XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) { RunVectorizedReduceTestForType( - CreateScalarOrComputation, [](bool a, bool b) { return a || b; }, false); + static_cast(CreateScalarOrComputation), + [](bool a, bool b) { return a || b; }, false); } class ReduceR3ToR2Test : public ReduceTest, public ::testing::WithParamInterface {}; XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const auto& bounds = GetParam().bounds; Array3D input_array(bounds[0], bounds[1], bounds[2]); // input_array.FillRandom(3.14f, 0.05); @@ -822,7 +828,7 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { auto input_activations = builder.Parameter(0, input_literal->shape(), "input"); - Computation add = CreateScalarAddComputation(F32, &builder); + XlaComputation add = CreateScalarAddComputation(F32, &builder); auto sum = builder.Reduce(input_activations, builder.ConstantR0(0.0f), add, GetParam().reduce_dims); @@ -862,8 +868,8 @@ INSTANTIATE_TEST_CASE_P( // IrEmitterUnnested::EmitInitializer() for the Reduce operator. Failed on // 2017-07-26. XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) { - ComputationBuilder builder(client_, TestName()); - Computation max_f32 = CreateScalarMaxComputation(F32, &builder); + XlaBuilder builder(TestName()); + XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder); auto a = builder.ConstantR0(2.0f); auto a2 = builder.Abs(a); @@ -884,5 +890,78 @@ XLA_TEST_F(ReduceTest, ReduceOrPredR2_64x32_To_R1) { RunR2ToR1PredTest(/*and_reduce=false*/ false, /*rows=64*/ 64); } +// Tests reductions with different initial values. There's no test macro that +// combines TYPED_TEST and TYPED_P, so we have to do it manually. +class ReduceInitializerTest : public ReduceTest { + protected: + template + void DoTest(T initializer, int num_elems) { + XlaBuilder builder(TestName()); + XlaComputation max_fn = CreateScalarMaxComputation( + primitive_util::NativeToPrimitiveType(), &builder); + + auto init = builder.ConstantR0(initializer); + std::vector input_arr(num_elems, std::numeric_limits::lowest()); + auto input_literal = Literal::CreateR1(input_arr); + auto input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + builder.Reduce(builder.Parameter(0, input_literal->shape(), "input"), init, + max_fn, {0}); + + ComputeAndCompareR0(&builder, initializer, {input_data.get()}); + } +}; + +XLA_TEST_F(ReduceInitializerTest, U8Small) { DoTest(42, 2); } + +XLA_TEST_F(ReduceInitializerTest, U8BigPowerOf2) { DoTest(42, 4096); } + +XLA_TEST_F(ReduceInitializerTest, U8InitializerBigNonPowerOf2) { + DoTest(42, 4095); +} + +XLA_TEST_F(ReduceInitializerTest, U64InitializerZero) { + DoTest(0, 1024); +} + +XLA_TEST_F(ReduceInitializerTest, U64InitializerOne) { + DoTest(1, 1024); +} + +XLA_TEST_F(ReduceInitializerTest, U64InitializerBigValue) { + DoTest(1234556789123, 1024); +} + +// Test the operational semantic that the init value is passed on the lhs for +// reduces. Can be tested by performing an "identity" reduce (that simply +// returns one of the parameters). In this case, we return the rhs, which for +// a 1D array with one element, should not be the init value. +XLA_TEST_F(ReduceTest, ReduceIdentity) { + XlaBuilder builder(TestName()); + Shape single_float = ShapeUtil::MakeShape(F32, {}); + builder.Parameter(0, single_float, "lhs-unused"); + builder.Parameter(1, single_float, "rhs-used"); + auto computation_status = builder.Build(); + TF_ASSERT_OK(computation_status.status()); + + Shape operand_shape = ShapeUtil::MakeShape(F32, {1}); + builder.Reduce(builder.Parameter(0, operand_shape, "operand"), + builder.Parameter(1, single_float, "init"), + computation_status.ValueOrDie(), {0}); + + float operand[] = {42.0f}; + float init = 58.5f; + float expected = 42.0f; + std::unique_ptr input_literal = Literal::CreateR1(operand); + std::unique_ptr input_global_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + std::unique_ptr input_literal2 = Literal::CreateR0(init); + std::unique_ptr input_global_data2 = + client_->TransferToServer(*input_literal2).ConsumeValueOrDie(); + ComputeAndCompareR0( + &builder, expected, {input_global_data.get(), input_global_data2.get()}, + ErrorSpec(0.0001)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index b11b64e40a582150d6adf29e915cd70b4bcb982b..266760e8202fddc48792ac66dda334255e428808 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -21,10 +21,11 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -63,11 +64,9 @@ class ReduceWindowTestBase : public ClientLibraryTestBase { class ReduceWindowTest : public ::testing::WithParamInterface, public ReduceWindowTestBase { public: - ReduceWindowTest() : builder_(client_, TestName()) { - set_use_bfloat16(GetParam()); - } + ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); } - void ReduceWindowAdd(const ComputationDataHandle& input, + void ReduceWindowAdd(const XlaOp& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { @@ -78,16 +77,17 @@ class ReduceWindowTest : public ::testing::WithParamInterface, window_dimensions, window_strides, padding); } - void ReduceWindowMax(const ComputationDataHandle& input, + void ReduceWindowMax(const XlaOp& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_); - builder_.ReduceWindow(input, init, CreateScalarMax(), window_dimensions, - window_strides, padding); + builder_.ReduceWindow(input, init, + CreateScalarMaxComputation(FloatType(), &builder_), + window_dimensions, window_strides, padding); } - void ReduceWindowMin(const ComputationDataHandle& input, + void ReduceWindowMin(const XlaOp& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { @@ -97,7 +97,7 @@ class ReduceWindowTest : public ::testing::WithParamInterface, window_dimensions, window_strides, padding); } - ComputationBuilder builder_; + XlaBuilder builder_; }; TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { @@ -252,6 +252,48 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { DefaultErrorSpec()); } +// Tests the super windowing logic w.r.t handling prime number of windows in a +// major dimension with reduction. +TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) { + Array4D input_array(15, 15, 4, 128); + input_array.FillRandom(2.f, 4.f); + + int win_len = 3; + int win_stride = 2; + + const auto input_data_handle = + CreateConstantFromArray(input_array, &builder_); + + Padding padding = Padding::kSame; + // Reduce only along the x and y dimensions, according to the win_len. + ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + + auto result = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); +} + +TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { + Array4D input_array(19, 17, 8, 256); + input_array.FillWithMinorDimNum(); + + const auto input_data_handle = + CreateConstantFromArray(input_array, &builder_); + + Padding padding = Padding::kSame; + ReduceWindowAdd(input_data_handle, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); + + auto result = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); + + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); +} + // Tests a reduction function that is not a simple add/min/max/etc. XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { Array4D input_array(1, 2, 2, 1); @@ -268,7 +310,7 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { auto rhs = b->Parameter(1, scalar, "rhs"); b->Min(b->Add(lhs, rhs), CreateConstantFromLiteral(*Literal::CreateR0(8.0f), b.get())); - Computation reduce_fn = b->BuildAndNoteError(); + XlaComputation reduce_fn = b->BuildAndNoteError(); builder_.ReduceWindow( input, @@ -296,7 +338,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); - ComputationDataHandle input; + XlaOp input; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "parameter", &builder_, &input); @@ -314,12 +356,8 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - std::unique_ptr arg_literal = Literal::CreateFromShape(shape); - auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> float { - return 1.0f; - }; - TF_EXPECT_OK(arg_literal->Populate(generator)); - + auto arg_literal = MakeUnique(shape); + arg_literal->PopulateWithValue(1.0f); const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); Padding padding = Padding::kValid; @@ -329,13 +367,8 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - std::unique_ptr expected = Literal::CreateFromShape(result_shape); - auto out_generator = - [&](tensorflow::gtl::ArraySlice indexes) -> float { - return 27.0f; - }; - TF_EXPECT_OK(expected->Populate(out_generator)); - + auto expected = MakeUnique(result_shape); + expected->PopulateWithValue(27.0f); ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } @@ -364,7 +397,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input; + XlaOp input; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "parameter", &builder_, &input); @@ -386,7 +419,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input; + XlaOp input; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "parameter", &builder_, &input); @@ -408,7 +441,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input; + XlaOp input; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "parameter", &builder_, &input); @@ -509,7 +542,7 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { Array2D input_array(6, 4, 1.0f); - ComputationDataHandle input = builder_.Broadcast( + XlaOp input = builder_.Broadcast( CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4}); Padding padding = Padding::kSame; @@ -568,7 +601,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } void DoIt() { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); const float kInitValue = 0.0f; @@ -579,7 +612,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); - ComputationDataHandle parameter; + XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); @@ -819,8 +852,7 @@ INSTANTIATE_TEST_CASE_P( class R4ReduceWindowAnyDimsTest : public R4ReduceWindowTest {}; // TODO(b/72234705): Fix the test cases failed on CPU and GPU. -XLA_TEST_P(R4ReduceWindowAnyDimsTest, - DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) { +XLA_TEST_P(R4ReduceWindowAnyDimsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt))) { DoIt(); } @@ -920,7 +952,7 @@ class R3ReduceWindowTest : public ReduceWindowTestBase, }; TEST_P(R3ReduceWindowTest, Add) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); CHECK(param.reducer == kAdd); @@ -931,7 +963,7 @@ TEST_P(R3ReduceWindowTest, Add) { Literal::CreateR3FromArray3DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); - ComputationDataHandle parameter; + XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); auto init_value = @@ -960,45 +992,73 @@ struct R2ReduceWindowTestData { int64 base_bounds[2]; int64 window_bounds[2]; int64 strides[2]; + int64 pad_low[2]; + int64 pad_high[2]; int64 layout[2]; - Padding padding; Reducer reducer; } kR2TestCases[] = { {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4}, - /*strides=*/{1, 2}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 2}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{2, 5}, /*window_bounds=*/{2, 4}, - /*strides=*/{1, 1}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 2}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{1, 3}, /*window_bounds=*/{2, 3}, - /*strides=*/{1, 1}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3, 129}, /*window_bounds=*/{1, 100}, - /*strides=*/{2, 99}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{2, 99}, /*pad_low=*/{0, 0}, /*pad_high=*/{35, 35}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, +// TODO(b/74260408): This test last failed on GPU on 2018-03-08, likely due to a +// ptxas bug. +#ifndef XLA_TEST_BACKEND_GPU {/*base_bounds=*/{6, 152}, /*window_bounds=*/{2, 25}, - /*strides=*/{5, 4}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{5, 4}, /*pad_low=*/{0, 1}, /*pad_high=*/{10, 11}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, +#endif {/*base_bounds=*/{6, 4}, /*window_bounds=*/{4, 2}, - /*strides=*/{3, 3}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{3, 3}, /*pad_low=*/{0, 1}, /*pad_high=*/{0, 1}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5, 147}, /*window_bounds=*/{1, 36}, - /*strides=*/{4, 5}, /*layout=*/{1, 0}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{4, 5}, /*pad_low=*/{0, 0}, /*pad_high=*/{17, 17}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{4, 153}, /*window_bounds=*/{2, 93}, - /*strides=*/{1, 1}, /*layout=*/{1, 0}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{46, 46}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, // Regression test for a bug that appeared in Inception (b/34784899). {/*base_bounds=*/{28, 28}, /*window_bounds=*/{3, 3}, - /*strides=*/{1, 1}, /*layout=*/{1, 0}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{1, 1}, /*pad_high=*/{1, 1}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, // Regression test for a bug that appeared in Inception (b/34784899). {/*base_bounds=*/{4, 32}, /*window_bounds=*/{2, 2}, - /*strides=*/{2, 2}, /*layout=*/{1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, - {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, - /*strides=*/{1, 1}, /*layout=*/{1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*strides=*/{2, 2}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, + // Regression test for b/73903312: bf16 lacks precision to store result of + // very large windows. Testing with a reasonable window larger than 128. + {/*base_bounds=*/{8, 130}, /*window_bounds=*/{1, 130}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 130}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{8, 256}, /*window_bounds=*/{1, 4}, + /*strides=*/{1, 64}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{4096, 4096}, /*window_bounds=*/{1, 4}, + /*strides=*/{1, 1024}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0}, + /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, }; string R2ReduceWindowTestDataToString( @@ -1008,10 +1068,11 @@ string R2ReduceWindowTestDataToString( string str = tensorflow::strings::StrCat( "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // "__window_bounds_", - tensorflow::str_util::Join(param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(param.strides, "x"), // - "__padding_", param.padding == Padding::kSame ? "same" : "valid", // - "__layout_", param.layout[0], "_", param.layout[1], // + tensorflow::str_util::Join(param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(param.strides, "x"), // + "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), + "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), + "__layout_", param.layout[0], "_", param.layout[1], // "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { str = tensorflow::strings::StrCat(str, "_bfloat16"); @@ -1026,7 +1087,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } void DoIt() { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); CHECK(param.reducer == kAdd); @@ -1036,20 +1097,32 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, Literal::CreateR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); - ComputationDataHandle parameter; + XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); + std::vector> padding(2); + for (int i = 0; i < 2; ++i) { + padding[i] = {param.pad_low[i], param.pad_high[i]}; + } + auto computation = param.reducer == kAdd + ? CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); auto init_value = CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); - b.ReduceWindow(/*operand=*/parameter, - /*init_value=*/init_value, - /*computation=*/CreateScalarAddComputation(FloatType(), &b), - /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/param.padding); + b.ReduceWindowWithGeneralPadding( + /*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/computation, + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/padding); - auto expected = ReferenceUtil::ReduceWindow2DAdd( - /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); + auto reduce_func = param.reducer == kAdd + ? +[](float a, float b) { return a + b; } + : +[](float a, float b) { return std::max(a, b); }; + auto expected = ReferenceUtil::ReduceWindow2DGeneric( + /*operand=*/input, /*init=*/kInitValue, /*reduce_func=*/reduce_func, + /*window=*/param.window_bounds, + /*stride=*/param.strides, /*padding=*/padding); ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); @@ -1068,14 +1141,15 @@ class R2ReduceWindowFailingCpuGpuBf16Test : public R2ReduceWindowTest {}; // TODO(b/72234705): Fix the test cases failed on CPU and GPU. XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test, - DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) { + DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt))) { DoIt(); } const R2ReduceWindowTestData kR2FailingValuesCpuGpuBf16Test[] = { {/*base_bounds=*/{8, 128}, /*window_bounds=*/{8, 128}, - /*strides=*/{1, 1}, /*layout=*/{1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, }; INSTANTIATE_TEST_CASE_P( @@ -1211,7 +1285,7 @@ class R1ReduceWindowTest : public ReduceWindowTestBase, }; TEST_P(R1ReduceWindowTest, DoIt) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); CHECK(param.reducer == kAdd || param.reducer == kMax); @@ -1220,7 +1294,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { std::iota(std::begin(input_vector), std::end(input_vector), 0); std::unique_ptr input_literal = Literal::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); - ComputationDataHandle parameter; + XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); @@ -1265,7 +1339,7 @@ INSTANTIATE_TEST_CASE_P( class ReduceWindowTextTest : public HloTestBase {}; TEST_F(ReduceWindowTextTest, R2General256x384) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R2Window mul { lhs = f32[] parameter(0) @@ -1282,7 +1356,7 @@ ENTRY R2Window { } TEST_F(ReduceWindowTextTest, R2General256x384Layout01) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R2Window mul { lhs = f32[] parameter(0) @@ -1299,7 +1373,7 @@ ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window= } TEST_F(ReduceWindowTextTest, R2General2x5) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R2Window mul { lhs = f32[] parameter(0) @@ -1315,5 +1389,77 @@ ENTRY R2Window { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } +TEST_F(ReduceWindowTextTest, R2EffectiveScalar) { + const string hlo_string = R"( +HloModule R2Window +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} +ENTRY R2Window { + operand = f32[1,1]{1,0} parameter(0) + negate = f32[1,1]{1,0} negate(operand) + constant = f32[] constant(1) + ROOT reduce-window = f32[1,1]{1,0} reduce-window(negate, constant), window={size=1x1 pad=0_0x0_0}, to_apply=mul +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); +} + +TEST_F(ReduceWindowTextTest, R3EffectiveScalar) { + const string hlo_string = R"( +HloModule R3Window +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} +ENTRY R3Window { + operand = f32[1,1,1]{2,1,0} parameter(0) + negate = f32[1,1,1]{2,1,0} negate(operand) + constant = f32[] constant(1) + ROOT reduce-window = f32[1,1,1]{2,1,0} reduce-window(negate, constant), window={size=1x1x1 pad=0_0x0_0x0_0}, to_apply=mul +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); +} + +TEST_F(HloTestBase, ReduceWindowIdentity) { + const string hlo_string = R"( +HloModule ReduceWindowIdentity +identity.pad_to_reduce_window { + param0 = f32[] parameter(0) + ROOT param1 = f32[] parameter(1) +} +ENTRY reduce-window-identity { + operand = f32[1,32,64]{2,1,0} parameter(0) + constant.4466 = f32[] constant(0) + ROOT reduce-window = f32[1,33,64]{2,1,0} reduce-window(operand, constant.4466), window={size=1x1x1 pad=0_0x1_0x0_0}, to_apply=identity.pad_to_reduce_window +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); +} + +TEST_F(HloTestBase, ReduceWindowS32) { + const string hlo_string = R"( +HloModule reduce-window + +%identity.pad_to_reduce_window (param0: s32[], param1: s32[]) -> s32[] { + %param0 = s32[] parameter(0) + ROOT %param1 = s32[] parameter(1) +} + +ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] { + %parameter.0 = s32[81,8]{1,0} parameter(0) + %parameter.1 = s32[] parameter(1) + ROOT %reduce-window = s32[82,8]{1,0} reduce-window(s32[81,8]{1,0} %parameter.0, s32[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index 6d063ffc363c092a1fbc40cbc22e87181d0c2502..36d763b0f7f4267ede076c0b25cfaf9654e96e0d 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -15,13 +15,13 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -38,17 +38,17 @@ class ReplayTest : public ClientLibraryTestBase {}; TEST_F(ReplayTest, TwoPlusTwoReplay) { // Make 2+2 computation. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto two = builder.ConstantR0(2); builder.Add(two, two); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); // Serialize it out. - std::unique_ptr module = + std::unique_ptr module = computation.Snapshot().ConsumeValueOrDie(); // Replay it. - Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); + XlaComputation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); // Check signature is the same. std::unique_ptr original_shape = @@ -69,18 +69,18 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { // Make computation. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(S32, {}), "y"); builder.Add(x, y); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); // Serialize it out. - std::unique_ptr module = + std::unique_ptr module = computation.Snapshot().ConsumeValueOrDie(); // Replay it. - Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); + XlaComputation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); // Check signature is the same. std::unique_ptr original_shape = @@ -109,24 +109,24 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { TEST_F(ReplayTest, MapPlusTwoOverR1) { // As above, but with map(+2) over some constant array. - ComputationBuilder plus_two_builder(client_, "plus two"); + XlaBuilder plus_two_builder("plus two"); auto input = plus_two_builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "input"); plus_two_builder.Add(input, plus_two_builder.ConstantR0(2)); - Computation plus_two = plus_two_builder.Build().ConsumeValueOrDie(); + XlaComputation plus_two = plus_two_builder.Build().ConsumeValueOrDie(); - ComputationBuilder mapper_builder(client_, TestName()); + XlaBuilder mapper_builder(TestName()); auto original = mapper_builder.ConstantR1({1, 2, 3}); mapper_builder.Map({original}, plus_two, {0}); - Computation computation = mapper_builder.Build().ConsumeValueOrDie(); + XlaComputation computation = mapper_builder.Build().ConsumeValueOrDie(); // Serialize it out. - std::unique_ptr module = + std::unique_ptr module = computation.Snapshot().ConsumeValueOrDie(); // Replay it. - Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); + XlaComputation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); // Check signature is the same. std::unique_ptr original_shape = @@ -135,10 +135,6 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { client_->GetComputationShape(replayed).ConsumeValueOrDie(); ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); - // Destroy the originals. - computation.Reset(); - plus_two.Reset(); - // Run it. std::unique_ptr literal = client_ diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index e045e164e2e2db7d3480e7c2d1e20f461820ae67..da1b588ec41cef711412367e89b2a9b1029bca71 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -20,10 +20,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -34,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -45,7 +43,7 @@ namespace { using ReshapeMotionTest = ClientLibraryTestBase; TEST_F(ReshapeMotionTest, ElementwiseOfReshapesWithNonSameInputShapes) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{2, 3, 5}, {7, 11, 13}}); auto b = builder.ConstantR2({{17, 19}, {23, 29}, {31, 37}}); auto c = builder.Reshape(a, {6}); diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index f7b04debd4f5c40a904e32c832b6fc384a03c33b..a4580cd71d46ad0a0186eddd51291f9c322b6f49 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -52,11 +52,11 @@ class ReshapeTest : public ::testing::WithParamInterface, // Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension. XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array2D input_array(1, 1); input_array.Fill(1.0f); auto input_literal = Literal::CreateR2FromArray2D(input_array); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); @@ -67,9 +67,9 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { } XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1({1.0f}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{}); @@ -80,9 +80,9 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { } XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1({1.0f}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0}); @@ -94,11 +94,11 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { // Collapses 2-dimensional pseudo-scalar (single-element array) to scalar. XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array2D input_array(1, 1); input_array.Fill(1.0f); auto input_literal = Literal::CreateR2FromArray2D(input_array); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); auto reshape = builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, @@ -111,15 +111,14 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { } XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR0(1.0f); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", &builder, ¶meter); auto a = builder.Neg(parameter); - auto reshape = - builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); + builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); auto expected_literal = Literal::CreateR1({-1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -130,10 +129,10 @@ XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { // does not handle zero-sized shapes correctly. Failed last on 2017-11-30 // with an incorrect result rank. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array2D input_array(0, 3); auto input_literal = Literal::CreateR2FromArray2D(input_array); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); @@ -146,11 +145,11 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) { // does not handle zero-sized shapes correctly. Failed last on 2017-05-15 // with an incorrect result rank. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR2FromArray2D(Array2D(0, 3)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); @@ -163,10 +162,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { // does not handle zero-sized shapes correctly. Failed last on 2017-11-30 // with an incorrect result rank. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array2D input_array(3, 0); auto input_literal = Literal::CreateR2FromArray2D(input_array); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); @@ -177,9 +176,9 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) { // Collapses a 2-dimensional row vector to 1 dimension. XLA_TEST_P(ReshapeTest, Trivial1x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR2({{1.0f, 2.0f, 3.0f}}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); @@ -190,9 +189,9 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) { // Collapses a 2-dimensional column vector to 1 dimension. XLA_TEST_P(ReshapeTest, Trivial3x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR2({{1.0f}, {2.0f}, {3.0f}}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); @@ -207,9 +206,9 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { // // Splits an empty vector into an empty matrix. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1({}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, @@ -221,10 +220,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { // Splits a vector into a matrix. XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, @@ -241,9 +240,9 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { // // Transposes a 2x0 array to a 0x2 array. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(0, 2)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, @@ -255,10 +254,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { // Transposes a 2-dimensional row vector to a column vector. XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); auto input_literal = Literal::CreateFromArray(*simple); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, @@ -272,10 +271,10 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { // Transposes a 2-dimensional array. XLA_TEST_P(ReshapeTest, TransposeAsReshape) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = Literal::CreateFromArray(*a4x3); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, @@ -291,11 +290,11 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { // does not handle zero-sized shapes correctly. Failed last on 2017-11-30 // with an incorrect result rank. // -// Transposes a 0x4 array with ComputationBuilder::Trans. +// Transposes a 0x4 array with XlaBuilder::Transpose. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(0, 4)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Transpose(parameter, {1, 0}); @@ -306,10 +305,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { // Transposes a 2-dimensional array with ComputationBuilder::Trans. XLA_TEST_P(ReshapeTest, Transpose4x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = Literal::CreateFromArray(*a4x3); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Transpose(parameter, {1, 0}); @@ -327,9 +326,9 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { // Reshapes an empty 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(6, 0)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, @@ -343,9 +342,9 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { // does not handle zero-sized shapes correctly. Failed last on 2017-11-30 // with an incorrect result rank. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array4D(2, 3, 4, 0)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, @@ -358,10 +357,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) { // Reshapes a 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = Literal::CreateFromArray(*a4x3); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, @@ -378,9 +377,9 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { // with an incorrect result rank. // XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(0, 6)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, @@ -393,10 +392,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) { // Reshapes a 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), and reorder the input (shuffle). XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = Literal::CreateFromArray(*a4x3); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, @@ -420,9 +419,9 @@ static Array3D ArrayForDocR3Tests() { } XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, @@ -435,9 +434,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { } XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, @@ -455,9 +454,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { } XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, @@ -470,9 +469,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { } XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, @@ -490,9 +489,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { } XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, @@ -520,12 +519,12 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { // // 1 2 3 4 5 6 1 2 3 4 5 6 XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D t2x2x2x3(2, 2, 2, 3); auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3); t2x2x2x3.FillWithYX(*filler2x3); auto input_literal = Literal::CreateFromArray(t2x2x2x3); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); @@ -539,7 +538,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { // As above, but uses reshape directly. XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D t(2, 1, 2, 2); t(0, 0, 0, 0) = 0; t(0, 0, 0, 1) = 1; @@ -550,7 +549,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 1, 0) = 6; t(1, 0, 1, 1) = 7; auto input_literal = Literal::CreateFromArray(t); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, @@ -565,7 +564,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { // Reshape various ranks to a scalar. XLA_TEST_P(ReshapeTest, ToScalar) { for (int rank = 0; rank < 8; ++rank) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector ones(rank, 1); // this is {1, ..., 1}. std::vector dimensions(rank); std::iota(dimensions.begin(), dimensions.end(), 0); @@ -573,7 +572,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) { std::vector zeros(rank, 0); // this is {0, ..., 0}. input_literal.Set(zeros, 83.0f); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, ¶meter); b.Reshape(parameter, dimensions, {}); @@ -585,9 +584,9 @@ XLA_TEST_P(ReshapeTest, ToScalar) { } XLA_TEST_P(ReshapeTest, BadDimensions) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto input_literal = Literal::CreateR1({1.0f}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, ¶meter); b.Reshape(parameter, {}, {}); @@ -597,9 +596,9 @@ XLA_TEST_P(ReshapeTest, BadDimensions) { } XLA_TEST_P(ReshapeTest, BadNewSizes) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto input_literal = Literal::CreateR1({1.0f, 2.0f}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, ¶meter); b.Reshape(parameter, {1}, {}); @@ -608,7 +607,7 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) { } XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // clang-format off auto input_literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ { @@ -634,7 +633,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { }, LayoutUtil::MakeLayout({0, 1, 2, 3})); // clang-format on - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); @@ -645,7 +644,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { {222, 333, 444, 555, 666, 777, 888, 999}, }); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8}, @@ -657,19 +656,19 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { std::unique_ptr expected = Literal::CreateR2FromArray2D(expected_array); if (use_bfloat16()) { - expected = LiteralTestUtil::ConvertF32ToBF16(*expected); + expected = Literal::ConvertF32ToBF16(*expected); } - LiteralTestUtil::ExpectEqual(*expected, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); } XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr input_literal = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); @@ -690,13 +689,13 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { // Tests R2->R4 reshape with the reshape dimensions {1, 0}. XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr input_literal = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); @@ -716,7 +715,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { } XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 1, 1); @@ -726,19 +725,19 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle parameter; + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); std::unique_ptr expected = - LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal); + Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 4, 1); @@ -748,20 +747,20 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle parameter; + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); std::unique_ptr expected = - LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal); + Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } // Tests R4->R2 reshape with the reshape dimensions {0, 2, 1, 3}. XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(5, 10, 2, 3); @@ -771,7 +770,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle parameter; + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, @@ -788,7 +787,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { } XLA_TEST_P(ReshapeTest, NoopReshape) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input_array(2, 3, 5, 7); @@ -798,12 +797,12 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); - ComputationDataHandle parameter; + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, /*new_sizes=*/{7, 2, 3, 5}); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = @@ -818,7 +817,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. if (use_bfloat16()) { - auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal); + auto expected = Literal::ConvertF32ToBF16(*input_literal); EXPECT_EQ(expected->data(), output_literal->data()); } else { EXPECT_EQ(input_literal->data(), output_literal->data()); @@ -826,12 +825,12 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { } XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, @@ -845,8 +844,8 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle parameter; + XlaBuilder builder(TestName()); + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, @@ -879,15 +878,15 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle parameter; + XlaBuilder builder(TestName()); + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -908,15 +907,15 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle parameter; + XlaBuilder builder(TestName()); + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -937,15 +936,15 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle parameter; + XlaBuilder builder(TestName()); + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -967,15 +966,15 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle parameter; + XlaBuilder builder(TestName()); + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -996,15 +995,15 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({0, 1, 2, 3})); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle parameter; + XlaBuilder builder(TestName()); + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal) + Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) ->Relayout(input_literal->shape().layout()); // Specify the requested output shape explicitly to ensure that this reshape diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 8fc841f14087cdea02fe44cdaea521ff92122aec..e7bd142dc9ddefbd8bebfb77d72218d662645c31 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -85,7 +85,7 @@ TEST_P(FloatReverseTest, Reverses) { auto r1_literal = Literal::CreateR1(input_vector); auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = AddParam(*input_literal, &builder); builder.Rev(a, spec.reversal); @@ -114,7 +114,7 @@ class ReverseTest : public ClientLibraryTestBase {}; // Tests the reverse operation on a 4D U8 array on dimension 0 and 3. XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); // Input shape is U8[1x2x3x4]. // clang-format off Array4D input({{ @@ -144,7 +144,7 @@ XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) { // Tests the reverse operation on a 4D float array on dimension 0 and 1. TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); // Input shape is float[4x3x2x1]. // clang-format off Array4D input({ diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index 8cbfcc6f5c4272706a0f9fd809041516bf32432b..7cfca781acda15879075f4386c2096e537877aac 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -100,7 +100,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { EXPECT_EQ(46.0f, actual->Get({1, 1})); std::unique_ptr round_tripped = RoundTripToServer(*actual); - LiteralTestUtil::ExpectEqual(*round_tripped, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { @@ -135,7 +135,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { EXPECT_EQ(46.0f, actual->Get({1, 1})); std::unique_ptr round_tripped = RoundTripToServer(*actual); - LiteralTestUtil::ExpectEqual(*round_tripped, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index 32db45f8a66266712ba4091c2aa6368f0b822bd2..f334a8c1318a59bbfdd27dd1a63ed162600089ce 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -41,7 +41,7 @@ class RoundTripTransferTest : public ClientLibraryTestBase { client_->TransferToServer(original).ConsumeValueOrDie(); std::unique_ptr result = client_->Transfer(*data).ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(original, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(original, *result)); } }; diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 4da6ee91607941b395b00befc98a10e7c17746ed..308d3fc78a51e63c0e3db8c0cda18caf11f665bd 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -17,9 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -43,83 +44,80 @@ class ScalarComputationsTest : public ClientLibraryTestBase { protected: // A template for building and running a binary comparison test. template - void TestCompare(NativeT lhs, NativeT rhs, bool expected, - ComputationDataHandle (ComputationBuilder::*op)( - const ComputationDataHandle&, - const ComputationDataHandle&, - tensorflow::gtl::ArraySlice)) { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle lhs_op = builder.ConstantR0(lhs); - ComputationDataHandle rhs_op = builder.ConstantR0(rhs); - ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {}); + void TestCompare( + NativeT lhs, NativeT rhs, bool expected, + XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, + tensorflow::gtl::ArraySlice)) { + XlaBuilder builder(TestName()); + XlaOp lhs_op = builder.ConstantR0(lhs); + XlaOp rhs_op = builder.ConstantR0(rhs); + XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } template void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected, - ComputationDataHandle (ComputationBuilder::*op)( - const ComputationDataHandle&, - const ComputationDataHandle&, - tensorflow::gtl::ArraySlice)) { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle lhs_op = builder.ConstantR0(lhs); - ComputationDataHandle rhs_op = builder.ConstantR0(rhs); - ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {}); + XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, + tensorflow::gtl::ArraySlice)) { + XlaBuilder builder(TestName()); + XlaOp lhs_op = builder.ConstantR0(lhs); + XlaOp rhs_op = builder.ConstantR0(rhs); + XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } }; XLA_TEST_F(ScalarComputationsTest, ReturnScalarF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR0(2.1f); ComputeAndCompareR0(&builder, 2.1f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Neg(builder.ConstantR0(2.1f)); ComputeAndCompareR0(&builder, -2.1f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, NegateScalarS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Neg(builder.ConstantR0(2)); ComputeAndCompareR0(&builder, -2, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Add(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)); ComputeAndCompareR0(&builder, 7.6f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Add(builder.ConstantR0(2), builder.ConstantR0(5)); ComputeAndCompareR0(&builder, 7, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Add(builder.ConstantR0(35), builder.ConstantR0(57)); ComputeAndCompareR0(&builder, 92, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Add(builder.ConstantR0(35), builder.ConstantR0(57)); ComputeAndCompareR0(&builder, 92, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const uint64 a = static_cast(1) << 63; const uint64 b = a + 1; builder.Add(builder.ConstantR0(a), builder.ConstantR0(b)); @@ -128,7 +126,7 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) { } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const int64 a = static_cast(1) << 62; const int64 b = a - 1; builder.Add(builder.ConstantR0(a), builder.ConstantR0(b)); @@ -137,7 +135,7 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) { } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Add(builder.ConstantR0(0.25), builder.ConstantR0(3.5)); @@ -145,25 +143,25 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) { } XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Sub(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)); ComputeAndCompareR0(&builder, -3.4f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Sub(builder.ConstantR0(2), builder.ConstantR0(5)); ComputeAndCompareR0(&builder, -3, {}); } XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.Parameter(0, ShapeUtil::MakeShape(S64, {}), "a"); builder.ConvertElementType(a, F32); - int64 value = 3LL << 32; + int64 value = 3LL << 35; std::unique_ptr a_literal = Literal::CreateR0(value); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); @@ -172,7 +170,7 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { } XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Mul(builder.Mul(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)), builder.ConstantR0(0.5f)); @@ -191,7 +189,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsS32) { for (int32 x : data) { for (int32 y : data) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Mul(builder.ConstantR0(x), builder.ConstantR0(y)); // Signed integer overflow is undefined behavior in C++. Convert the input @@ -210,7 +208,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) { for (uint32 x : data) { for (uint32 y : data) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Mul(builder.ConstantR0(x), builder.ConstantR0(y)); uint32 expected = x * y; @@ -220,7 +218,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) { } XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Mul( builder.Mul(builder.ConstantR0(2), builder.ConstantR0(5)), builder.ConstantR0(1)); @@ -229,7 +227,7 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { } XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR0(2.1f); std::unique_ptr b_literal = Literal::CreateR0(5.5f); std::unique_ptr c_literal = Literal::CreateR0(0.5f); @@ -241,9 +239,9 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { std::unique_ptr c_data = client_->TransferToServer(*c_literal).ConsumeValueOrDie(); - ComputationDataHandle a = builder.Parameter(0, a_literal->shape(), "a"); - ComputationDataHandle b = builder.Parameter(1, b_literal->shape(), "b"); - ComputationDataHandle c = builder.Parameter(2, c_literal->shape(), "c"); + XlaOp a = builder.Parameter(0, a_literal->shape(), "a"); + XlaOp b = builder.Parameter(1, b_literal->shape(), "b"); + XlaOp c = builder.Parameter(2, c_literal->shape(), "c"); builder.Mul(builder.Mul(a, b), c); ComputeAndCompareR0(&builder, 5.775f, @@ -252,14 +250,14 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { } XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Div(builder.ConstantR0(5.0f), builder.ConstantR0(2.5f)); ComputeAndCompareR0(&builder, 2.0f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Rem(builder.ConstantR0(2.5f), builder.ConstantR0(5.0f)); ComputeAndCompareR0(&builder, 2.5f, {}, error_spec_); @@ -282,7 +280,7 @@ class DivS32Test : public ClientLibraryTestBase, XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { DivS32Params p = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Div(builder.ConstantR0(p.dividend), builder.ConstantR0(p.divisor)); @@ -291,7 +289,7 @@ XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) { DivS32Params p = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Rem(builder.ConstantR0(p.dividend), builder.ConstantR0(p.divisor)); @@ -300,9 +298,9 @@ XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) { XLA_TEST_P(DivS32Test, DivideTwoScalarsNonConstS32) { DivS32Params p = GetParam(); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividendd = CreateR0Parameter(p.dividend, 0, "dividend", &builder, ÷nd); auto divisord = @@ -315,9 +313,9 @@ XLA_TEST_P(DivS32Test, DivideTwoScalarsNonConstS32) { XLA_TEST_P(DivS32Test, RemainderTwoScalarsNonConstDivisorS32) { DivS32Params p = GetParam(); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividendd = CreateR0Parameter(p.dividend, 0, "dividend", &builder, ÷nd); auto divisord = @@ -364,13 +362,13 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { 0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX}; // clang-format on - Computation div_computation; + XlaComputation div_computation; { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle dividend = + XlaOp dividend = builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend"); - ComputationDataHandle divisor = + XlaOp divisor = builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); builder.Div(dividend, divisor); TF_ASSERT_OK_AND_ASSIGN(div_computation, builder.Build()); @@ -392,7 +390,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { &execution_options_) .ConsumeValueOrDie(); auto expected_literal = Literal::CreateR0(dividend / divisor); - LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } } @@ -405,13 +403,13 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { 0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX}; // clang-format on - Computation rem_computation; + XlaComputation rem_computation; { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle dividend = + XlaOp dividend = builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend"); - ComputationDataHandle divisor = + XlaOp divisor = builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); builder.Rem(dividend, divisor); TF_ASSERT_OK_AND_ASSIGN(rem_computation, builder.Build()); @@ -433,14 +431,14 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { &execution_options_) .ConsumeValueOrDie(); auto expected_literal = Literal::CreateR0(dividend % divisor); - LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } } } XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); builder.Rem(x, builder.ConstantR0(80000)); @@ -450,7 +448,7 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { } XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // This verifies 0xFFFFFFFE / 2 = 0x7FFFFFFF. If XLA incorrectly treated U32 // as S32, it would output -2 / 2 = -1 (0xFFFFFFFF). builder.Div(builder.ConstantR0(0xFFFFFFFE), @@ -460,7 +458,7 @@ XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) { } XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Rem(builder.ConstantR0(11), builder.ConstantR0(3)); ComputeAndCompareR0(&builder, 2, {}); @@ -469,7 +467,7 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) { XLA_TEST_F(ScalarComputationsTest, AndBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x && y, {}); @@ -480,7 +478,7 @@ XLA_TEST_F(ScalarComputationsTest, AndBool) { XLA_TEST_F(ScalarComputationsTest, AndS32) { for (int32 x : {0, 8}) { for (int32 y : {1, -16}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x & y, {}); @@ -491,7 +489,7 @@ XLA_TEST_F(ScalarComputationsTest, AndS32) { XLA_TEST_F(ScalarComputationsTest, AndU32) { for (uint32 x : {0, 8}) { for (uint32 y : {1, 16}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x & y, {}); @@ -502,7 +500,7 @@ XLA_TEST_F(ScalarComputationsTest, AndU32) { XLA_TEST_F(ScalarComputationsTest, OrBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x || y, {}); @@ -513,7 +511,7 @@ XLA_TEST_F(ScalarComputationsTest, OrBool) { XLA_TEST_F(ScalarComputationsTest, OrS32) { for (int32 x : {0, 8}) { for (int32 y : {1, -16}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x | y, {}); @@ -524,7 +522,7 @@ XLA_TEST_F(ScalarComputationsTest, OrS32) { XLA_TEST_F(ScalarComputationsTest, OrU32) { for (uint32 x : {0, 8}) { for (uint32 y : {1, 16}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x | y, {}); @@ -534,7 +532,7 @@ XLA_TEST_F(ScalarComputationsTest, OrU32) { XLA_TEST_F(ScalarComputationsTest, NotBool) { for (bool x : {false, true}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Not(builder.ConstantR0(x)); ComputeAndCompareR0(&builder, !x, {}); @@ -543,7 +541,7 @@ XLA_TEST_F(ScalarComputationsTest, NotBool) { XLA_TEST_F(ScalarComputationsTest, NotS32) { for (int32 x : {-1, 0, 1}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Not(builder.ConstantR0(x)); ComputeAndCompareR0(&builder, ~x, {}); @@ -552,7 +550,7 @@ XLA_TEST_F(ScalarComputationsTest, NotS32) { XLA_TEST_F(ScalarComputationsTest, NotU32) { for (uint32 x : {0, 1, 2}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Not(builder.ConstantR0(x)); ComputeAndCompareR0(&builder, ~x, {}); @@ -560,7 +558,7 @@ XLA_TEST_F(ScalarComputationsTest, NotU32) { } XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Select(builder.ConstantR0(true), // The predicate. builder.ConstantR0(123.0f), // The value on true. builder.ConstantR0(42.0f)); // The value on false. @@ -569,7 +567,7 @@ XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) { } XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Select(builder.ConstantR0(false), // The predicate. builder.ConstantR0(123.0f), // The value on true. builder.ConstantR0(42.0f)); // The value on false. @@ -580,7 +578,7 @@ XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) { // This test is an explicit version of what is happening in the following // templatized comparison tests. XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Gt(builder.ConstantR0(2.0f), builder.ConstantR0(1.0f)); ComputeAndCompareR0(&builder, true, {}); @@ -588,157 +586,156 @@ XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) { // S32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) { - TestCompare(2, 1, false, &ComputationBuilder::Eq); + TestCompare(2, 1, false, &XlaBuilder::Eq); } XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) { - TestCompare(3, 3, true, &ComputationBuilder::Eq); + TestCompare(3, 3, true, &XlaBuilder::Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeS32) { - TestCompare(2, 1, true, &ComputationBuilder::Ne); + TestCompare(2, 1, true, &XlaBuilder::Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeS32) { - TestCompare(2, 1, true, &ComputationBuilder::Ge); + TestCompare(2, 1, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtS32) { - TestCompare(1, 5, false, &ComputationBuilder::Gt); + TestCompare(1, 5, false, &XlaBuilder::Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeS32) { - TestCompare(2, 1, false, &ComputationBuilder::Le); + TestCompare(2, 1, false, &XlaBuilder::Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtS32) { - TestCompare(9, 7, false, &ComputationBuilder::Lt); + TestCompare(9, 7, false, &XlaBuilder::Lt); TestCompare(std::numeric_limits::min(), - std::numeric_limits::max(), true, - &ComputationBuilder::Lt); + std::numeric_limits::max(), true, &XlaBuilder::Lt); } // U32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) { - TestCompare(2, 1, false, &ComputationBuilder::Eq); + TestCompare(2, 1, false, &XlaBuilder::Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeU32) { - TestCompare(2, 1, true, &ComputationBuilder::Ne); + TestCompare(2, 1, true, &XlaBuilder::Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) { - TestCompare(2, 1, true, &ComputationBuilder::Ge); + TestCompare(2, 1, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) { - TestCompare(3, 3, true, &ComputationBuilder::Ge); + TestCompare(3, 3, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtU32) { - TestCompare(1, 5, false, &ComputationBuilder::Gt); - TestCompare(5, 5, false, &ComputationBuilder::Gt); - TestCompare(5, 1, true, &ComputationBuilder::Gt); + TestCompare(1, 5, false, &XlaBuilder::Gt); + TestCompare(5, 5, false, &XlaBuilder::Gt); + TestCompare(5, 1, true, &XlaBuilder::Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeU32) { - TestCompare(2, 1, false, &ComputationBuilder::Le); + TestCompare(2, 1, false, &XlaBuilder::Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtU32) { - TestCompare(9, 7, false, &ComputationBuilder::Lt); + TestCompare(9, 7, false, &XlaBuilder::Lt); TestCompare(0, std::numeric_limits::max(), true, - &ComputationBuilder::Lt); + &XlaBuilder::Lt); } // F32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) { - TestCompare(2.0, 1.3, false, &ComputationBuilder::Eq); + TestCompare(2.0, 1.3, false, &XlaBuilder::Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeF32) { - TestCompare(2.0, 1.3, true, &ComputationBuilder::Ne); + TestCompare(2.0, 1.3, true, &XlaBuilder::Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) { - TestCompare(2.0, 1.9, true, &ComputationBuilder::Ge); + TestCompare(2.0, 1.9, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) { - TestCompare(3.5, 3.5, true, &ComputationBuilder::Ge); + TestCompare(3.5, 3.5, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtF32) { - TestCompare(1.0, 5.2, false, &ComputationBuilder::Gt); + TestCompare(1.0, 5.2, false, &XlaBuilder::Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeF32) { - TestCompare(2.0, 1.2, false, &ComputationBuilder::Le); + TestCompare(2.0, 1.2, false, &XlaBuilder::Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32) { - TestCompare(9.0, 7.2, false, &ComputationBuilder::Lt); + TestCompare(9.0, 7.2, false, &XlaBuilder::Lt); } // F32 comparisons with exceptional values. The test names encode the // left/right operands at the end, and use Minf and Mzero for -inf and -0.0. XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) { - TestCompare(-INFINITY, -0.0, true, &ComputationBuilder::Lt); + TestCompare(-INFINITY, -0.0, true, &XlaBuilder::Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. - TestCompare(-0.0, 0.0, false, &ComputationBuilder::Lt); + TestCompare(-0.0, 0.0, false, &XlaBuilder::Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) { - TestCompare(0.0, INFINITY, true, &ComputationBuilder::Lt); + TestCompare(0.0, INFINITY, true, &XlaBuilder::Lt); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) { - TestCompare(-INFINITY, -0.0, false, &ComputationBuilder::Ge); + TestCompare(-INFINITY, -0.0, false, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. - TestCompare(-0.0, 0.0, true, &ComputationBuilder::Ge); + TestCompare(-0.0, 0.0, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) { - TestCompare(0.0, INFINITY, false, &ComputationBuilder::Ge); + TestCompare(0.0, INFINITY, false, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, ExpScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Exp(builder.ConstantR0(2.0f)); ComputeAndCompareR0(&builder, 7.3890562, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, LogScalar) { - ComputationBuilder builder(client_, "log"); + XlaBuilder builder("log"); builder.Log(builder.ConstantR0(2.0f)); ComputeAndCompareR0(&builder, 0.6931471, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, TanhScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Tanh(builder.ConstantR0(2.0f)); ComputeAndCompareR0(&builder, 0.96402758, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Tanh(builder.ConstantR0(2.0)); ComputeAndCompareR0(&builder, 0.96402758, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, PowScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Pow(builder.ConstantR0(2.0f), builder.ConstantR0(3.0f)); ComputeAndCompareR0(&builder, 8.0, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(-1), // The lower bound. builder.ConstantR0(5), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -747,7 +744,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(-1), // The lower bound. builder.ConstantR0(2), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -756,7 +753,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(-1), // The lower bound. builder.ConstantR0(-5), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -765,7 +762,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(1), // The lower bound. builder.ConstantR0(5), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -774,7 +771,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(1), // The lower bound. builder.ConstantR0(2), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -783,7 +780,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(1), // The lower bound. builder.ConstantR0(0), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -792,7 +789,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(5.0f), // The operand to be clamped. builder.ConstantR0(3.0f)); // The upper bound. @@ -801,7 +798,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(2.5f), // The operand to be clamped. builder.ConstantR0(3.0f)); // The upper bound. @@ -810,7 +807,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(-5.0f), // The operand to be clamped. builder.ConstantR0(3.0f)); // The upper bound. @@ -819,58 +816,70 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { } XLA_TEST_F(ScalarComputationsTest, MinS32Above) { - TestMinMax(10, 3, 3, &ComputationBuilder::Min); + TestMinMax(10, 3, 3, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MinS32Below) { - TestMinMax(-100, 3, -100, &ComputationBuilder::Min); + TestMinMax(-100, 3, -100, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MaxS32Above) { - TestMinMax(10, 3, 10, &ComputationBuilder::Max); + TestMinMax(10, 3, 10, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MaxS32Below) { - TestMinMax(-100, 3, 3, &ComputationBuilder::Max); + TestMinMax(-100, 3, 3, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MinU32Above) { const uint32 large = std::numeric_limits::max(); - TestMinMax(large, 3, 3, &ComputationBuilder::Min); + TestMinMax(large, 3, 3, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MinU32Below) { - TestMinMax(0, 5, 0, &ComputationBuilder::Min); + TestMinMax(0, 5, 0, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MaxU32Above) { const uint32 large = std::numeric_limits::max(); - TestMinMax(large, 3, large, &ComputationBuilder::Max); + TestMinMax(large, 3, large, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MaxU32Below) { - TestMinMax(0, 5, 5, &ComputationBuilder::Max); + TestMinMax(0, 5, 5, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MinF32Above) { - TestMinMax(10.1f, 3.1f, 3.1f, &ComputationBuilder::Min); + TestMinMax(10.1f, 3.1f, 3.1f, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MinF32Below) { - TestMinMax(-100.1f, 3.1f, -100.1f, &ComputationBuilder::Min); + TestMinMax(-100.1f, 3.1f, -100.1f, &XlaBuilder::Min); +} + +XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) { + SetFastMathDisabled(true); + TestMinMax(NAN, 3.1f, NAN, &XlaBuilder::Min); + TestMinMax(-3.1f, NAN, NAN, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MaxF32Above) { - TestMinMax(10.1f, 3.1f, 10.1f, &ComputationBuilder::Max); + TestMinMax(10.1f, 3.1f, 10.1f, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MaxF32Below) { - TestMinMax(-100.1f, 3.1f, 3.1f, &ComputationBuilder::Max); + TestMinMax(-100.1f, 3.1f, 3.1f, &XlaBuilder::Max); +} + +XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) { + SetFastMathDisabled(true); + TestMinMax(NAN, 3.1f, NAN, &XlaBuilder::Max); + TestMinMax(-3.1f, NAN, NAN, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { // Compute the expression (1 * (3 - 1) * (7 + 0) - 4) / 20. - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Div( b.Sub(b.Mul(b.ConstantR0(1), b.Mul(b.Sub(b.ConstantR0(3), b.ConstantR0(1)), @@ -883,7 +892,7 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { // Compute the expression 1 * (3 - 1) * (7 + 0) - 4. - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Sub(b.Mul(b.ConstantR0(1), b.Mul(b.Sub(b.ConstantR0(3), b.ConstantR0(1)), b.Add(b.ConstantR0(7), b.ConstantR0(0)))), @@ -893,21 +902,20 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { } XLA_TEST_F(ScalarComputationsTest, SqrtF320) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Literal zero_literal = Literal::Zero(PrimitiveType::F32); std::unique_ptr zero_data = client_->TransferToServer(zero_literal).ConsumeValueOrDie(); - ComputationDataHandle zero = - builder.Parameter(0, zero_literal.shape(), "zero"); + XlaOp zero = builder.Parameter(0, zero_literal.shape(), "zero"); builder.SqrtF32(zero); ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, RoundScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Round(builder.ConstantR0(1.4f)); ComputeAndCompareR0(&builder, 1.0f, {}, error_spec_); diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index 9ee94b8571e5fc8789b60501462986967ce909a0..7015e5a6a31f506d30c2629d7735482cf354455a 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -19,11 +19,11 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -50,7 +50,7 @@ class SelectAndScatterTest : public ClientLibraryTestBase, public ::testing::WithParamInterface { public: - SelectAndScatterTest() : builder_(client_, TestName()) { + SelectAndScatterTest() : builder_(TestName()) { // Create S32 GE and ADD computations for select and scatter respectively. ge_s32_ = CreateScalarGeComputation(S32, &builder_); add_s32_ = CreateScalarAddComputation(S32, &builder_); @@ -60,13 +60,13 @@ class SelectAndScatterTest min_f32_ = CreateScalarMinComputation(F32, &builder_); } - ComputationBuilder builder_; - Computation ge_s32_; - Computation add_s32_; - Computation ge_f32_; - Computation add_f32_; - Computation max_f32_; - Computation min_f32_; + XlaBuilder builder_; + XlaComputation ge_s32_; + XlaComputation add_s32_; + XlaComputation ge_f32_; + XlaComputation add_f32_; + XlaComputation max_f32_; + XlaComputation min_f32_; }; XLA_TEST_P(SelectAndScatterTest, ParamTest) { @@ -80,12 +80,11 @@ XLA_TEST_P(SelectAndScatterTest, ParamTest) { s.FillRandom(12.0f); auto source = builder_.ConstantFromArray(s); - auto select_and_scatter = builder_.SelectAndScatter( - operand, ge_f32_, GetParam().window_dimensions, GetParam().window_strides, - GetParam().padding_type, source, builder_.ConstantR0(0.0f), - add_f32_); + builder_.SelectAndScatter(operand, ge_f32_, GetParam().window_dimensions, + GetParam().window_strides, GetParam().padding_type, + source, builder_.ConstantR0(0.0f), add_f32_); - ComputeAndCompare(&builder_, select_and_scatter, {}, ErrorSpec(1e-5)); + ComputeAndCompare(&builder_, {}, ErrorSpec(1e-5)); } INSTANTIATE_TEST_CASE_P( @@ -252,6 +251,21 @@ XLA_TEST_F(SelectAndScatterTest, R2S32) { ComputeAndCompareR2(&builder_, expected, {}); } +// Test for tie breaking rule in ge_f32_. When a tie is present, the operand +// that has the lower lexicographical order (smaller index) should be chosen. +XLA_TEST_F(SelectAndScatterTest, R2F32Tie) { + const auto operand = builder_.ConstantR2( + {{0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}}); + const auto source = builder_.ConstantR2( + {{1.0f, 2.0f, 3.0f}, {4.f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); + Array2D expected( + {{12.f, 9.f, 0.f}, {15.f, 9.f, 0.f}, {0.f, 0.f, 0.f}}); + builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3, 3}, + /*window_strides=*/{1, 1}, Padding::kSame, source, + builder_.ConstantR0(0.0f), add_f32_); + ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(1e-7)); +} + // Similar to SelectAndScatterTest.R2S32 but the input is transposed. XLA_TEST_F(SelectAndScatterTest, ReshapeR2S32) { const auto operand = builder_.ConstantR2( diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc index 009e7d24c5cbface4da910e2366db1ff749d5d68..72707f224446c7585d1d90ac6681a7b38c41d5f1 100644 --- a/tensorflow/compiler/xla/tests/select_test.cc +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -16,13 +16,12 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -35,7 +34,7 @@ class SelectTest : public ClientLibraryTestBase { }; TEST_F(SelectTest, SelectScalarF32True) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto on_true = builder.ConstantR0(123.0f); auto on_false = builder.ConstantR0(42.0f); @@ -45,7 +44,7 @@ TEST_F(SelectTest, SelectScalarF32True) { } TEST_F(SelectTest, SelectScalarS32True) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto on_true = builder.ConstantR0(-42); auto on_false = builder.ConstantR0(42); @@ -55,7 +54,7 @@ TEST_F(SelectTest, SelectScalarS32True) { } TEST_F(SelectTest, SelectScalarF32False) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto on_true = builder.ConstantR0(123.0f); auto on_false = builder.ConstantR0(42.0f); @@ -65,7 +64,7 @@ TEST_F(SelectTest, SelectScalarF32False) { } XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR1({}); auto on_true = builder.ConstantR1({}); auto on_false = builder.ConstantR1({}); @@ -75,7 +74,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) { } TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR1({false, true, false, true, false}); auto on_true = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto on_false = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); @@ -88,7 +87,7 @@ TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) { XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) { // Similar to SelectR1S0F32WithConstantR1S0PRED, except that the pred vector // is not a constant, but rather the result of comparing two other vectors. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v1 = builder.ConstantR1({}); auto v2 = builder.ConstantR1({}); auto cmp = builder.Eq(v1, v2); @@ -102,7 +101,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) { TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { // Similar to SelectR1F32WithConstantR1PRED, except that the pred vector is // not a constant, but rather the result of comparing two other vectors. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v1 = builder.ConstantR1({1, 2, 3, 4, 5}); auto v2 = builder.ConstantR1({9, 2, 9, 4, 9}); auto cmp = builder.Eq(v1, v2); @@ -116,7 +115,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) { // Similar to SelectR1F32WithCmpR1S32s, except "gt"-comparing two R1F32s. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v1 = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); auto v2 = builder.ConstantR1({-1.0f, -2.0f, 13.0f, 14.0f, 4.4f}); auto cmp = builder.Gt(v1, v2); @@ -131,9 +130,9 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) { TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) { // Selects among two R1F32s, which come from parameters. v1 and v2 are // compared, and selection between them happens based on a gt-comparison mask. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle v1, v2; + XlaOp v1, v2; std::unique_ptr param0_data = CreateR1Parameter( {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&builder, /*data_handle=*/&v1); @@ -151,7 +150,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) { TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { // Similar to SelectR1F32WithCmpR1F32sFromParamsSmall, except that the // data size passed in and out is large. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Number of floats in the data passed into and out of the computation. constexpr int datalen = 15 * 1000; @@ -174,7 +173,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { expected_vec.push_back(larger); } - ComputationDataHandle v1, v2; + XlaOp v1, v2; std::unique_ptr param0_data = CreateR1Parameter(v1vec, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&builder, /*data_handle=*/&v1); @@ -192,7 +191,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) { // "gt"-compares a R1S32 with a S32 scalar, and uses the resulting R1PRED to // select between two R1F32s. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, -1, 2, -2}); auto s = builder.ConstantR0(0); auto cmp = builder.Gt(v, s); @@ -209,7 +208,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) { TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { // "gt"-compares a R1F32 with a F32 scalar, and uses the resulting R1PRED to // select between two R1F32s. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f}); auto s = builder.ConstantR0(2.5f); auto cmp = builder.Gt(v, s); @@ -225,7 +224,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { for (bool which : {false, true}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(which); auto on_true = builder.ConstantR1({}); auto on_false = builder.ConstantR1({}); @@ -236,7 +235,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { } TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto on_true = builder.ConstantR1({-2.5f, 25.5f}); auto on_false = builder.ConstantR1({10.0f, 5.0f}); @@ -246,7 +245,7 @@ TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) { } TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto on_true = builder.ConstantR1({-2.5f, 25.5f}); auto on_false = builder.ConstantR1({10.0f, 5.0f}); diff --git a/tensorflow/compiler/xla/tests/set_return_value_test.cc b/tensorflow/compiler/xla/tests/set_return_value_test.cc deleted file mode 100644 index 29f79ec28a1ae6fcd5299846e85eec992ad2e46f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tests/set_return_value_test.cc +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "tensorflow/compiler/xla/client/computation_builder.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace { - -class SetReturnValueTest : public ClientLibraryTestBase {}; - -TEST_F(SetReturnValueTest, NoSetValue) { - ComputationBuilder builder(client_, "no_set_value"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - - std::vector expected = {1.0, 3.0, 4.0, 0.0, -1.0, - 5.0, 6.0, -2.0, -3.0, 7.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(SetReturnValueTest, SetValue) { - ComputationBuilder builder(client_, "set_value"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - auto builder_status = builder.SetReturnValue(ax); - EXPECT_TRUE(builder_status.ok()); - - std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, - 4.0, 5.0, -3.0, -4.0, 6.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(SetReturnValueTest, SetValueAndModify) { - ComputationBuilder builder(client_, "set_value_and_modify"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - auto builder_status = builder.SetReturnValue(ax); - EXPECT_TRUE(builder_status.ok()); - auto aaax = builder.Add(alpha, aax); - - std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, - 4.0, 5.0, -3.0, -4.0, 6.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(SetReturnValueTest, SetValueMultipleTimesAndModify) { - ComputationBuilder builder(client_, "set_value_multiple_times_and_modify"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - auto builder_status = builder.SetReturnValue(aax); - EXPECT_TRUE(builder_status.ok()); - auto aaax = builder.Add(alpha, aax); - builder_status = builder.SetReturnValue(ax); - EXPECT_TRUE(builder_status.ok()); - auto aaaax = builder.Add(alpha, aaax); - - std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, - 4.0, 5.0, -3.0, -4.0, 6.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index ac163df127e0087c02777fa3d5ce7970c51b97b9..5653bf11a7364bf9ed79bcb6b53f7db31f454803 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -41,7 +41,7 @@ TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) { Array3D values(3, 3, 3); values.FillIota(0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR3FromArray3D(values); builder.Slice(original, {0, 0, 0}, {3, 3, 1}, {1, 1, 1}); @@ -54,7 +54,7 @@ TEST_F(SliceTest, Slice3x3x3_To_3x1x3_F32) { Array3D values(3, 3, 3); values.FillIota(0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR3FromArray3D(values); builder.Slice(original, {0, 0, 0}, {3, 1, 3}, {1, 1, 1}); @@ -67,7 +67,7 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) { Array3D values(3, 3, 3); values.FillIota(0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR3FromArray3D(values); builder.Slice(original, {0, 0, 0}, {1, 3, 3}, {1, 1, 1}); @@ -77,7 +77,7 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) { } XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(Array2D(0, 0)); builder.Slice(original, {0, 0}, {0, 0}, {1, 1}); @@ -85,7 +85,7 @@ XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { } XLA_TEST_F(SliceTest, Slice0x20to0x5F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(Array2D(0, 20)); builder.Slice(original, {0, 15}, {0, 20}, {1, 1}); @@ -93,7 +93,7 @@ XLA_TEST_F(SliceTest, Slice0x20to0x5F32) { } XLA_TEST_F(SliceTest, Slice3x0to2x0F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(Array2D(3, 0)); builder.Slice(original, {1, 0}, {3, 0}, {1, 1}); @@ -108,7 +108,7 @@ XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(values); builder.Slice(original, {128, 128}, {256, 256}, {1, 1}); @@ -126,7 +126,7 @@ TEST_F(SliceTest, Slice_1x4096_To_1x1024) { Array2D values(1, 4096); std::iota(values.data(), values.data() + 4096, 0.0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(values); builder.Slice(original, {0, 3072}, {1, 4096}, {1, 1}); @@ -147,7 +147,7 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) { } } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(values); builder.Slice(original, {0, 0}, {16, 2}, {1, 1}); ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.000001)); @@ -159,7 +159,7 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { values.FillRandom(3.14f); auto expected = ReferenceUtil::Slice4D( values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}, /*strides=*/{{1, 1, 1, 1}}); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR4FromArray4D(values); builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1}); ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001)); @@ -172,7 +172,7 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { /*strides=*/{{1, 1, 2, 1}}); auto expected_literal = Literal::CreateR4FromArray4DWithLayout( *expected, LayoutUtil::MakeLayout({0, 1, 2, 3})); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR4FromArray4D(values); builder.Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1}); ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001), @@ -193,24 +193,33 @@ class SliceR1Test : public ClientLibraryTestBase, protected: template void Run(const R1Spec& spec) { - std::vector input(spec.input_dim0); + // This can't be an std::vector, since you can't grab an ArraySlice of a + // vector. + tensorflow::gtl::InlinedVector input(spec.input_dim0); std::iota(input.begin(), input.end(), NativeT()); + auto literal = Literal::CreateR1(input); - ComputationBuilder builder(client_, TestName()); - auto original = builder.ConstantR1(input); + XlaBuilder builder(TestName()); + auto original = builder.Parameter(0, literal->shape(), "p0"); builder.Slice(original, {spec.slice_start}, {spec.slice_limit}, {spec.slice_stride}); - std::vector expected; + // Ditto. + tensorflow::gtl::InlinedVector expected; for (int i = spec.slice_start; i < spec.slice_limit; i += spec.slice_stride) { expected.push_back(i); } - ComputeAndCompareR1(&builder, expected, {}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, + client_->TransferToServer(*literal)); + ComputeAndCompareR1(&builder, expected, {arg.get()}); } }; +// A version of SliceR1Test used to label and disable 'large' tests +class SliceR1LargeTest : public SliceR1Test {}; + string SliceR1TestDataToString(const ::testing::TestParamInfo& data) { const R1Spec& spec = data.param; return ::tensorflow::strings::Printf("%lld_%lld_%lld_%lld", spec.input_dim0, @@ -230,6 +239,21 @@ XLA_TEST_P(SliceR1Test, DoIt_U64) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_S64) { Run(GetParam()); } +XLA_TEST_P(SliceR1LargeTest, DoIt_F32) { Run(GetParam()); } + +XLA_TEST_P(SliceR1LargeTest, DoIt_F64) { Run(GetParam()); } + +XLA_TEST_P(SliceR1LargeTest, DoIt_U32) { Run(GetParam()); } + +XLA_TEST_P(SliceR1LargeTest, DoIt_S32) { Run(GetParam()); } + +XLA_TEST_P(SliceR1LargeTest, DoIt_U64) { Run(GetParam()); } + +XLA_TEST_P(SliceR1LargeTest, DoIt_S64) { Run(GetParam()); } + +XLA_TEST_P(SliceR1Test, DoIt_PRED) { Run(GetParam()); } + + // Tests for R1 slice ops. // The format for each testcase is {input size, start, limit, stride}. // clang-format off @@ -267,13 +291,32 @@ INSTANTIATE_TEST_CASE_P( R1Spec{64 * 1024, 1024 + 1, 63 * 1024 - 1, 1}, R1Spec{64 * 1024, 32 * 1024, 33 * 1024, 1}, R1Spec{64 * 1024, 32 * 1024 + 1, 33 * 1024 - 1, 1}, - R1Spec{64 * 1024, 32 * 1024 - 17, 36 * 1024 - 18, 1}, + R1Spec{64 * 1024, 32 * 1024 - 17, 36 * 1024 - 18, 1} + ), + SliceR1TestDataToString +); + // TODO(b/69425338): This uses too much memory on GPU. #ifndef XLA_TEST_BACKEND_GPU - R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1}, - R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1}, - R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1}, +INSTANTIATE_TEST_CASE_P( + SliceR1TestBigSlicesInstantiation, + SliceR1LargeTest, + ::testing::Values( + R1Spec{ + 16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1}, + R1Spec{ + 16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1}, + R1Spec{ + 16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1} + ), + SliceR1TestDataToString +); #endif + +INSTANTIATE_TEST_CASE_P( + SliceStridedR1TestInstantiation, + SliceR1Test, + ::testing::Values( R1Spec{10, 2, 4, 2}, R1Spec{10, 0, 10, 2}, R1Spec{10, 0, 10, 3}, @@ -285,8 +328,24 @@ INSTANTIATE_TEST_CASE_P( R1Spec{2047, 1024 - 24, 1024 + 160, 31}, R1Spec{2047, 1, 2046, 3 * 128}, R1Spec{4096, 1024 + 3, 4095, 500}, - R1Spec{8192, 0, 8192, 1024 * 3 + 400} - ), + R1Spec{8192, 0, 8192, 1024 * 3 + 400}, + R1Spec{1024 * 1024, 0, 1024 * 1024, 2}, + R1Spec{1024 * 1024, 0, 1024 * 1024, 8}, + R1Spec{1024 * 1024, 0, 1024 * 1024, 7}, + R1Spec{1024 * 1024, 0, 1024 * 1024, 125}, + R1Spec{1024 * 1024, 3, 1024 - 9, 2}, + R1Spec{1024 * 1024, 3, 1024 - 9, 8}, + R1Spec{1024 * 1024, 3, 1024 - 9, 7}, + R1Spec{1024 * 1024, 3, 1024 - 9, 125}, + R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 2}, + R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 8}, + R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 7}, + R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 125}, + R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 2}, + R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 8}, + R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 7}, + R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 125} + ), SliceR1TestDataToString ); // clang-format on @@ -309,15 +368,18 @@ XLA_TEST_P(SliceR2Test, DoIt) { const R2Spec& spec = GetParam(); Array2D input(spec.input_dim0, spec.input_dim1); input.FillUnique(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2DWithLayout( + auto literal = Literal::CreateR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(spec.layout)); + + XlaBuilder builder(TestName()); + auto a = builder.Parameter(0, literal->shape(), "p0"); builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, + client_->TransferToServer(*literal)); std::unique_ptr> expected = ReferenceUtil::Slice2D( input, spec.slice_starts, spec.slice_limits, spec.slice_strides); - ComputeAndCompareR2(&builder, *expected, {}); + ComputeAndCompareR2(&builder, *expected, {arg.get()}); } INSTANTIATE_TEST_CASE_P( @@ -397,10 +459,10 @@ class SliceR4Test : public ClientLibraryTestBase, void Run(const R4Spec& spec) { Array4D values(spec.input_dims[0], spec.input_dims[1], spec.input_dims[2], spec.input_dims[3]); - values.FillRandom(3.14f); + values.FillIota(3.14159); auto expected = ReferenceUtil::Slice4D( values, spec.slice_starts, spec.slice_limits, spec.slice_strides); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto literal = Literal::CreateR4FromArray4DWithLayout( values, LayoutUtil::MakeLayout(spec.input_layout)); auto parameter = builder.Parameter(0, literal->shape(), "p0"); diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc index 978a669bcab720bddec5c4bcd0144810ba3c8477..be35ec6c6ee4c015755622b2dc9bb92e23af7c85 100644 --- a/tensorflow/compiler/xla/tests/test_macros.cc +++ b/tensorflow/compiler/xla/tests/test_macros.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/regexp.h" namespace xla { diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index e2d406f66d94f8ec76faa5b7d2d2e84dcaf6db57..7ca99a91635e85cd0888e59ecde31e47fec21844 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #define DISABLED_ON_CPU(X) X -#define DISABLED_ON_CPU_PARALLEL(X) X #define DISABLED_ON_GPU(X) X #define DISABLED_ON_INTERPRETER(X) X @@ -51,13 +50,6 @@ limitations under the License. # define DISABLED_ON_CPU(X) XLA_TEST_PASTE(DISABLED_, X) #endif // XLA_TEST_BACKEND_CPU -#ifdef XLA_TEST_BACKEND_CPU_PARALLEL -# undef DISABLED_ON_CPU -# define DISABLED_ON_CPU(X) XLA_TEST_PASTE(DISABLED_, X) -# undef DISABLED_ON_CPU_PARALLEL -# define DISABLED_ON_CPU_PARALLEL(X) XLA_TEST_PASTE(DISABLED_, X) -#endif // XLA_TEST_BACKEND_CPU_PARALLEL - #ifdef XLA_TEST_BACKEND_GPU # undef DISABLED_ON_GPU # define DISABLED_ON_GPU(X) XLA_TEST_PASTE(DISABLED_, X) diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 0bc7df2a65b44a76f877b6513e6bf93b99fbc1a3..dd7c541733634213606b5a7983b59bb1f14bf75c 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -23,14 +23,15 @@ namespace xla { namespace { -template -void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { +template +void PopulateWithRandomFloatingPointDataImpl(Literal* literal, + std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); // Create uniform numbers between 1 and 1.125 to avoid creating denormal // numbers. - std::uniform_real_distribution generator(1.0f, 1.125f); + std::uniform_real_distribution generator(1.0f, 1.125f); const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000; TF_CHECK_OK(literal->Populate( [&](tensorflow::gtl::ArraySlice indices) { @@ -52,15 +53,30 @@ void PopulateWithRandomFloatingPointData(Literal* literal, FloatT index_bias = static_cast(index_product % 113 - negative_bias) / static_cast(256.0f); - return (generator(*engine) - 1.0625) + index_bias; + return static_cast(generator(*engine) - 1.0625f) + index_bias; })); } +template +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { + CHECK(engine != nullptr); + PopulateWithRandomFloatingPointDataImpl(literal, engine); +} + +template <> +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { + CHECK(engine != nullptr); + PopulateWithRandomFloatingPointDataImpl(literal, engine); +} + // The standard library does not have a case for bfloat16, unsurprisingly, so we // handle that one specially. template <> void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), BF16); std::uniform_real_distribution generator(-0.9f, 1.0f); TF_CHECK_OK(literal->Populate( @@ -72,6 +88,7 @@ void PopulateWithRandomFloatingPointData(Literal* literal, template void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); std::uniform_int_distribution generator( @@ -95,11 +112,17 @@ StatusOr> MakeFakeLiteralInternal( } return Literal::MakeTupleOwned(std::move(elements)); } - std::unique_ptr literal = Literal::CreateFromShape(shape); + if (engine == nullptr) { + return Literal::CreateFromShape(shape); + } + auto literal = MakeUnique(shape); switch (shape.element_type()) { case BF16: PopulateWithRandomFloatingPointData(literal.get(), engine); break; + case F16: + PopulateWithRandomFloatingPointData(literal.get(), engine); + break; case F32: PopulateWithRandomFloatingPointData(literal.get(), engine); break; @@ -145,27 +168,38 @@ StatusOr> MakeFakeLiteralInternal( return std::move(literal); } -// Matches binary addition computations. -bool LooksLikeSum(const HloComputation& computation) { +enum class ConstantType { kUnknown, kZero, kOne }; + +// Return the constant type required by this computation, if known. +ConstantType GetInitValue(const HloComputation& computation) { const HloInstruction* const root = computation.root_instruction(); - return root->opcode() == HloOpcode::kAdd && - computation.num_parameters() == 2 && - root->operand(0)->opcode() == HloOpcode::kParameter && - root->operand(1)->opcode() == HloOpcode::kParameter && - root->operand(0) != root->operand(1); + if (computation.num_parameters() != 2 || root->operand_count() != 2 || + root->operand(0)->opcode() != HloOpcode::kParameter || + root->operand(1)->opcode() != HloOpcode::kParameter || + root->operand(0) == root->operand(1)) { + return ConstantType::kUnknown; + } + + switch (root->opcode()) { + case HloOpcode::kAdd: + return ConstantType::kZero; + case HloOpcode::kMultiply: + return ConstantType::kOne; + default: + return ConstantType::kUnknown; + } } -// Reduce, ReduceWindow, and SelectAndScatter ops may use binary addition, -// which requires an init_value of 0 rather than a random value. -bool NeedsZeroInitValue(const HloUse& use) { +// Reduce, ReduceWindow, and SelectAndScatter ops may need a non-random +// initialization value. +bool NeedsInitValue(const HloUse& use) { const HloInstruction* const instruction = use.instruction; const HloOpcode opcode = instruction->opcode(); const int64 op_num = use.operand_number; return ( ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) && - op_num == 1 && LooksLikeSum(*instruction->to_apply())) || - (opcode == HloOpcode::kSelectAndScatter && op_num == 2 && - LooksLikeSum(*instruction->scatter()))); + op_num == 1) || + (opcode == HloOpcode::kSelectAndScatter && op_num == 2)); } // Generate random values that are constrained to the input_shape minus the @@ -175,11 +209,13 @@ std::unique_ptr MakeRandomNonwrappingSliceIndex( std::minstd_rand0* engine) { const int64 rank = ShapeUtil::Rank(input_shape); std::vector start_indices(rank); - for (int i = 0; i < rank; ++i) { - const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); - std::uniform_int_distribution generator(0, upper_bound); - start_indices[i] = generator(*engine); + if (engine != nullptr) { + for (int i = 0; i < rank; ++i) { + const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); + std::uniform_int_distribution generator(0, upper_bound); + start_indices[i] = generator(*engine); + } } return Literal::CreateR1(start_indices); } @@ -207,7 +243,7 @@ std::vector FindConstrainedUses( auto fused_uses = FindConstrainedUses(dataflow, *to_analyze); constrained_uses.insert(constrained_uses.end(), fused_uses.begin(), fused_uses.end()); - } else if (NeedsZeroInitValue(use)) { + } else if (NeedsInitValue(use)) { constrained_uses.push_back(instruction); } else if (opcode == HloOpcode::kConvert || opcode == HloOpcode::kReducePrecision) { @@ -228,7 +264,8 @@ StatusOr> CreateLiteralForConstrainedUses( const tensorflow::gtl::ArraySlice constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { HloInstruction* needs_index = nullptr; - HloInstruction* needs_zero = nullptr; + HloInstruction* needs_constant = nullptr; + ConstantType constant_type = ConstantType::kUnknown; for (HloInstruction* use : constrained_uses) { switch (use->opcode()) { case HloOpcode::kDynamicSlice: @@ -243,8 +280,13 @@ StatusOr> CreateLiteralForConstrainedUses( case HloOpcode::kReduce: case HloOpcode::kReduceWindow: + needs_constant = use; + constant_type = GetInitValue(*use->to_apply()); + break; + case HloOpcode::kSelectAndScatter: - needs_zero = use; + needs_constant = use; + constant_type = GetInitValue(*use->scatter()); break; default: @@ -253,17 +295,26 @@ StatusOr> CreateLiteralForConstrainedUses( use->ToString().c_str()); } } - if (needs_index != nullptr && needs_zero != nullptr) { + if (needs_index != nullptr && needs_constant != nullptr) { return Unimplemented( "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds " - "zero: %s\n", - needs_index->ToString().c_str(), needs_zero->ToString().c_str()); + "constant: %s\n", + needs_index->ToString().c_str(), needs_constant->ToString().c_str()); } if (needs_index != nullptr) { return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(), needs_index->shape(), engine); - } else if (needs_zero != nullptr) { - return Literal::CreateFromShape(param.shape()); + } else if (needs_constant != nullptr) { + switch (constant_type) { + case ConstantType::kZero: + return Literal::Zero(param.shape().element_type()).CloneToUnique(); + case ConstantType::kOne: + return Literal::One(param.shape().element_type()).CloneToUnique(); + case ConstantType::kUnknown: + // We want the identity element for the computation, but we don't really + // know what it is - so any value we generate will be just as wrong. + return MakeFakeLiteralInternal(param.shape(), engine); + } } else { return MakeFakeLiteralInternal(param.shape(), engine); } @@ -280,27 +331,27 @@ StatusOr> MakeConstrainedArgument( } // namespace -StatusOr> MakeFakeLiteral(const Shape& shape) { - std::minstd_rand0 engine; - return MakeFakeLiteralInternal(shape, &engine); +StatusOr> MakeFakeLiteral(const Shape& shape, + bool pseudo_random) { + auto engine = pseudo_random ? MakeUnique() : nullptr; + return MakeFakeLiteralInternal(shape, engine.get()); } StatusOr>> MakeFakeArguments( - HloModule* const module) { + HloModule* const module, bool pseudo_random) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); - std::minstd_rand0 engine; + auto engine = pseudo_random ? MakeUnique() : nullptr; std::vector> arguments(params.size()); for (int i = 0; i < params.size(); ++i) { - TF_ASSIGN_OR_RETURN( - arguments[i], MakeConstrainedArgument(*dataflow, *params[i], &engine)); + TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument( + *dataflow, *params[i], engine.get())); } return std::move(arguments); } -Status VerifyHloModule(const perftools::gputools::Platform& platform, - HloModule* const module) { - return HloVerifier().Run(module).status(); +Status VerifyHloModule(HloModule* const module, bool allow_mixed_precision) { + return HloVerifier(allow_mixed_precision).Run(module).status(); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 0fb024ffb074f1c90b75022bc7f5a8b58b03c0c2..a8689f64981569ceb7c8a712f8ece00c99e8cf2d 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -55,21 +55,33 @@ class PseudorandomGenerator { }; // Generates fake data in a literal of the given shape, or returns an error -// status if the element type is currently unhandled for fake data generation. -StatusOr> MakeFakeLiteral(const Shape& shape); +// status if the element type is currently unhandled for fake data +// generation. See below for documentation of pseudo_random. +StatusOr> MakeFakeLiteral(const Shape& shape, + bool pseudo_random = true); // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. // // Will handle special cases such as making sure that indices used for dynamic // slices are bounded, reduces that call adds use 0 as an init value, etc. +// +// If pseudo_random is true, the generated numbers will be generated +// deterministically in a pseudo random way unless the values are constrated to +// be e.g. init values as above. If pseudo_random is false, the returned values +// will be generated in a faster way that yields less interesting data, e.g. the +// values may all be just the same value. +// +// TODO(b/79942829): Make interesting argument generation fast enough that using +// pseudo_random does not save any noticeable amount of time so that the +// parameter can be removed. StatusOr>> MakeFakeArguments( - HloModule* const module); + HloModule* const module, bool pseudo_random = true); // Check that a given module satisfies various constraints before trying to // execute it. -Status VerifyHloModule(const perftools::gputools::Platform& platform, - HloModule* const module); +Status VerifyHloModule(HloModule* const module, + bool allow_mixed_precision = false); } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..59afd28a80c0fbf3df38457cd05961c883769856 --- /dev/null +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/test_utils.h" + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/local_client_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +// A test fixture is used because we need a client for our computation builder. +class TestUtilsTest : public LocalClientTestBase {}; + +XLA_TEST_F(TestUtilsTest, UnusedParam) { + XlaBuilder builder(TestName()); + // Make the reduction lambda. + Shape single_float = ShapeUtil::MakeShape(F32, {}); + builder.Parameter(0, single_float, "unused"); + builder.Parameter(1, single_float, "used"); + auto computation_status = builder.Build(); + TF_ASSERT_OK(computation_status.status()); + + // Make the reduction. + Shape pair_float = ShapeUtil::MakeShape(F32, {2}); + builder.Reduce(builder.Parameter(0, pair_float, "operand"), + builder.Parameter(1, single_float, "init"), + computation_status.ValueOrDie(), {0}); + computation_status = builder.Build(); + TF_ASSERT_OK(computation_status.status()); + + auto executable_status = local_client_->Compile( + computation_status.ValueOrDie(), {&pair_float, &single_float}, + ExecutableBuildOptions()); + TF_ASSERT_OK(executable_status.status()); + HloModule& module = const_cast( + executable_status.ValueOrDie()->executable()->module()); + TF_ASSERT_OK(MakeFakeArguments(&module).status()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4585244ce81c14ab6d4d629bb7d208d73c82248d --- /dev/null +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -0,0 +1,124 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class TokenHloTest : public HloTestBase {}; + +// TODO(b/79770375): Compile, not just verify the HLO module when the backends +// support kGenerateToken. +XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + module->AddEntryComputation(builder.Build()); + EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); +} + +XLA_TEST_F(TokenHloTest, TokenTree) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto token0 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token1 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token2 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + builder.AddInstruction( + HloInstruction::CreateGenerateToken({token0, token0, token1, token2})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + module->AddEntryComputation(builder.Build()); + EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); +} + +XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeTokenShape(), "p1")); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Entry parameter 1 is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateParameter( + 0, + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeTokenShape()}), + "param")); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Entry parameter 0 is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidTokenRoot) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Entry root is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + builder.AddInstruction(HloInstruction::CreateGenerateToken({param})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(123))); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr( + "Operands of token instructions must be TOKEN types")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 268ba338f2e6740a1d1a046d5a85494f3cf2e9f8..0063e7ad415e9b6718c164f415ced6fb76cbf44a 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -45,7 +45,7 @@ class TransferManagerTest : public LocalClientTestBase { ~TransferManagerTest() override = default; - std::unique_ptr AllocateDeviceBuffer(const Shape& shape) { + ScopedShapedBuffer AllocateDeviceBuffer(const Shape& shape) { return transfer_manager_ ->AllocateScopedShapedBuffer( shape, GetOrCreateAllocator(local_client_->platform()), @@ -64,10 +64,10 @@ XLA_TEST_F(TransferManagerTest, TransferR0U32) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); LiteralTestUtil::ExpectR0Equal(42, *result); } @@ -80,10 +80,10 @@ XLA_TEST_F(TransferManagerTest, TransferR1F32) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); LiteralTestUtil::ExpectR1Equal({1.25f, 2.5f, -17.0f, -20.125f}, *result); @@ -98,10 +98,10 @@ XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); LiteralTestUtil::ExpectR1Equal(test_vector, *result); } @@ -114,10 +114,10 @@ XLA_TEST_F(TransferManagerTest, TransferR1U8) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); EXPECT_EQ(result->GetR1U8AsString(), test_string); } @@ -130,10 +130,10 @@ XLA_TEST_F(TransferManagerTest, TransferR2F32) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); @@ -150,10 +150,10 @@ XLA_TEST_F(TransferManagerTest, // Round trip literal through device. Set the on-device layout to something // different than the literal layout. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); EXPECT_FALSE( LayoutUtil::Equal(result->shape().layout(), literal->shape().layout())); @@ -170,12 +170,12 @@ XLA_TEST_F(TransferManagerTest, TransferTuple) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { @@ -184,12 +184,12 @@ XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { @@ -204,12 +204,12 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValue) { @@ -219,12 +219,12 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValue) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { @@ -238,12 +238,12 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc index fe5a1778a2cecff0121cee4d8b406c5b23a13e40..fe1e3da7eca00e128377e6e56af877868aafa836 100644 --- a/tensorflow/compiler/xla/tests/transpose_test.cc +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -16,14 +16,13 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -38,7 +37,7 @@ class TransposeTest : public ClientLibraryTestBase { }; XLA_TEST_F(TransposeTest, Transpose0x0) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 0)); auto result = builder.Transpose(lhs, {1, 0}); @@ -46,7 +45,7 @@ XLA_TEST_F(TransposeTest, Transpose0x0) { } XLA_TEST_F(TransposeTest, Transpose0x42) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 42)); auto result = builder.Transpose(lhs, {1, 0}); @@ -54,7 +53,7 @@ XLA_TEST_F(TransposeTest, Transpose0x42) { } XLA_TEST_F(TransposeTest, Transpose7x0) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto lhs = builder.ConstantR2FromArray2D(Array2D(7, 0)); auto result = builder.Transpose(lhs, {1, 0}); @@ -62,7 +61,7 @@ XLA_TEST_F(TransposeTest, Transpose7x0) { } TEST_F(TransposeTest, Transpose2x2) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto lhs = builder.ConstantR2({ {1.0, 2.0}, {3.0, 4.0}, }); @@ -74,7 +73,7 @@ TEST_F(TransposeTest, Transpose2x2) { } XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto operand = builder.ConstantR3FromArray3D(Array3D(0, 2, 3)); auto result = builder.Transpose(operand, {1, 2, 0}); @@ -82,7 +81,7 @@ XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) { } TEST_F(TransposeTest, Transpose1x2x3_2x3x1) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); auto result = builder.Transpose(operand, {1, 2, 0}); @@ -92,7 +91,7 @@ TEST_F(TransposeTest, Transpose1x2x3_2x3x1) { } TEST_F(TransposeTest, Transpose1x2x3_3x2x1) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); auto result = builder.Transpose(operand, {2, 1, 0}); @@ -102,7 +101,7 @@ TEST_F(TransposeTest, Transpose1x2x3_3x2x1) { } TEST_F(TransposeTest, Transpose1x2x3_1x2x3) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); auto result = builder.Transpose(operand, {0, 1, 2}); @@ -116,7 +115,7 @@ TEST_F(TransposeTest, MultiTranspose3x2) { Array2D transposed({{1.0f, 3.0f, 5.0f}, {2.0f, 4.0f, 6.0f}}); for (int transposes = 0; transposes <= 10; ++transposes) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto computed = builder.ConstantR2FromArray2D(input); for (int i = 0; i < transposes; ++i) { computed = builder.Transpose(computed, {1, 0}); @@ -130,7 +129,7 @@ TEST_F(TransposeTest, MultiTranspose3x2) { TEST_F(TransposeTest, Small_1x1) { auto aoperand = MakeLinspaceArray2D(0.0, 1.0, 1, 1); - ComputationBuilder builder(client_, "transpose_1x1"); + XlaBuilder builder("transpose_1x1"); auto operand = builder.ConstantR2FromArray2D(*aoperand); builder.Transpose(operand, {1, 0}); @@ -142,7 +141,7 @@ TEST_F(TransposeTest, Small_1x1) { TEST_F(TransposeTest, Small_2x2) { auto aoperand = MakeLinspaceArray2D(0.0, 4.0, 2, 2); - ComputationBuilder builder(client_, "transpose_2x2"); + XlaBuilder builder("transpose_2x2"); auto operand = builder.ConstantR2FromArray2D(*aoperand); builder.Transpose(operand, {1, 0}); @@ -162,7 +161,7 @@ void TransposeTest::TestTransposeConstant021(size_t n1, size_t n2, size_t n3) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto operand = builder.ConstantR3FromArray3D(aoperand); builder.Transpose(operand, {0, 2, 1}); diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 2029312f94a14bc81706368b9ecfc2727fd9fe4c..41189231b90e842292830a932cf381af60456d4c 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -17,14 +17,15 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -40,7 +41,7 @@ class TupleTest : public ClientLibraryTestBase { // Tests a tuple-shaped constant. XLA_TEST_F(TupleTest, TupleConstant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; @@ -53,13 +54,13 @@ XLA_TEST_F(TupleTest, TupleConstant) { Literal::CreateR1(constant_vector).get(), Literal::CreateR2(constant_matrix).get()}); - auto result = builder.ConstantLiteral(*value); + builder.ConstantLiteral(*value); ComputeAndCompareTuple(&builder, *value, {}, error_spec_); } // Tests a tuple made of scalar constants. XLA_TEST_F(TupleTest, TupleScalarConstant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const float constant_scalar1 = 7.3f; const float constant_scalar2 = 1.2f; @@ -67,13 +68,13 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) { Literal::MakeTuple({Literal::CreateR0(constant_scalar1).get(), Literal::CreateR0(constant_scalar2).get()}); - auto result = builder.ConstantLiteral(*value); + builder.ConstantLiteral(*value); ComputeAndCompareTuple(&builder, *value, {}, error_spec_); } // Tests the creation of tuple data. XLA_TEST_F(TupleTest, TupleCreate) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; @@ -81,9 +82,9 @@ XLA_TEST_F(TupleTest, TupleCreate) { {1.1f, 2.2f, 3.5f}, // row 0 {4.8f, 5.0f, 6.7f}, // row 1 }; - auto result = builder.Tuple({builder.ConstantR0(constant_scalar), - builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); + builder.Tuple({builder.ConstantR0(constant_scalar), + builder.ConstantR1(constant_vector), + builder.ConstantR2(constant_matrix)}); auto expected = Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), @@ -94,9 +95,9 @@ XLA_TEST_F(TupleTest, TupleCreate) { // Tests the creation of tuple data. XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - auto result = builder.Tuple( + builder.Tuple( {builder.ConstantR0(7.0), builder.ConstantR1({})}); auto expected = Literal::MakeTuple({Literal::CreateR0(7.0).get(), @@ -106,15 +107,15 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { // Tests the creation of an empty tuple. XLA_TEST_F(TupleTest, EmptyTupleCreate) { - ComputationBuilder builder(client_, TestName()); - auto result = builder.Tuple({}); + XlaBuilder builder(TestName()); + builder.Tuple({}); auto expected = Literal::MakeTuple({}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } // Trivial test for extracting a tuple element with GetTupleElement. XLA_TEST_F(TupleTest, GetTupleElement) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list constant_vector = {1.f, 2.f, 3.f}; std::initializer_list> constant_matrix = { {1.f, 2.f, 3.f}, // row 0 @@ -122,23 +123,23 @@ XLA_TEST_F(TupleTest, GetTupleElement) { }; auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), builder.ConstantR2(constant_matrix)}); - auto matrix_element = builder.GetTupleElement(tuple_data, 1); + builder.GetTupleElement(tuple_data, 1); ComputeAndCompareR2(&builder, Array2D(constant_matrix), {}, error_spec_); } // Trivial test for extracting a tuple element with GetTupleElement. XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto tuple_data = builder.Tuple( {builder.ConstantR1({}), builder.ConstantR2FromArray2D(Array2D(0, 101))}); - auto matrix_element = builder.GetTupleElement(tuple_data, 1); + builder.GetTupleElement(tuple_data, 1); ComputeAndCompareR2(&builder, Array2D(0, 101), {}, error_spec_); } XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto value = builder.ConstantR1({4.5f}); builder.GetTupleElement(value, 1); auto result_status = builder.Build(); @@ -151,7 +152,7 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) { // Extracts both elements from a tuple with GetTupleElement and then adds them // together. XLA_TEST_F(TupleTest, AddTupleElements) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list constant_vector = {1.f, 2.f, 3.f}; std::initializer_list> constant_matrix = { {1.f, 2.f, 3.f}, // row 0 @@ -163,22 +164,22 @@ XLA_TEST_F(TupleTest, AddTupleElements) { auto matrix_element = builder.GetTupleElement(tuple_data, 1); auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie(); auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie(); - auto result = builder.Add(matrix_element, vector_element, - /*broadcast_dimensions=*/{1}); + builder.Add(matrix_element, vector_element, + /*broadcast_dimensions=*/{1}); Array2D expected({ {2.f, 4.f, 6.f}, // row 0 {5.f, 7.f, 9.f}, // row 1 }); - ASSERT_TRUE(ShapeUtil::ShapeIs(*vector_shape, F32, {3})); - ASSERT_TRUE(ShapeUtil::ShapeIs(*matrix_shape, F32, {/*y=*/2, /*x=*/3})); + ASSERT_TRUE(ShapeUtil::ShapeIs(vector_shape, F32, {3})); + ASSERT_TRUE(ShapeUtil::ShapeIs(matrix_shape, F32, {/*y=*/2, /*x=*/3})); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } // Extracts both elements from a tuple and then puts them into a new tuple in // the opposite order. XLA_TEST_F(TupleTest, TupleGTEToTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list constant_vector = {1.f, 2.f, 3.f}; std::initializer_list> constant_matrix = { {1.f, 2.f, 3.f}, // row 0 @@ -186,8 +187,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { }; auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), builder.ConstantR2(constant_matrix)}); - auto new_tuple = builder.Tuple({builder.GetTupleElement(tuple_data, 1), - builder.GetTupleElement(tuple_data, 0)}); + builder.Tuple({builder.GetTupleElement(tuple_data, 1), + builder.GetTupleElement(tuple_data, 0)}); auto expected = Literal::MakeTuple({Literal::CreateR2(constant_matrix).get(), Literal::CreateR1(constant_vector).get()}); @@ -195,8 +196,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { } XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { - ComputationBuilder b(client_, TestName()); - ComputationDataHandle v1, v2; + XlaBuilder b(TestName()); + XlaOp v1, v2; for (bool direction : {false, true}) { std::unique_ptr v1_data = @@ -209,7 +210,7 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { auto v2_gt = b.Gt(v2, v1); // true auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true} auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false} - auto select = b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); + b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); auto expected = Literal::MakeTuple({Literal::CreateR0(direction).get(), Literal::CreateR0(!direction).get()}); @@ -236,7 +237,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { // \ (tuple10)-- / // \ / \ / // -----(GTE 0)-- --(GTE 1)---------- - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list constant_vector = {1.f, 2.f, 3.f}; std::initializer_list> constant_matrix = { {1.f, 2.f, 3.f}, // row 0 @@ -256,8 +257,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { auto addvectors = builder.Add(vector_from_01, vector_from_10); auto addmatrices = builder.Add(matrix_from_01, matrix_from_10); - auto result = builder.Add(addmatrices, addvectors, - /*broadcast_dimensions=*/{1}); + builder.Add(addmatrices, addvectors, + /*broadcast_dimensions=*/{1}); Array2D expected({ {4.f, 8.f, 12.f}, // row 0 @@ -266,9 +267,9 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) { +XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) { // Tests a selection between tuples with "false" path taken. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -277,21 +278,20 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) { auto tuple21 = builder.Tuple( {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); - auto select = - builder.Select(builder.ConstantR0(false), tuple12, tuple21); + builder.Select(builder.ConstantR0(false), tuple12, tuple21); auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), Literal::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } XLA_TEST_F(TupleTest, TuplesInAMap) { - Computation tuple_computation; + XlaComputation tuple_computation; { // tuple_computation(x) = 100 * min(x, x^2) + max(x, x^2) using tuples. // // Need to put a select in there to prevent HLO-level optimizations from // optimizing out the tuples. - ComputationBuilder b(client_, "sort_square"); + XlaBuilder b("sort_square"); auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto x2 = b.Mul(x, x); auto x_smaller_tuple = b.Tuple({x, x2}); @@ -305,15 +305,15 @@ XLA_TEST_F(TupleTest, TuplesInAMap) { tuple_computation = computation_status.ConsumeValueOrDie(); } - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto input = b.ConstantR1({-1.0f, 1.0f, 2.1f}); b.Map({input}, tuple_computation, {0}); ComputeAndCompareR1(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_); } -XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) { +XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) { // Tests a selection between tuples with "true" path taken. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -322,8 +322,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) { auto tuple21 = builder.Tuple( {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); - auto select = - builder.Select(builder.ConstantR0(true), tuple12, tuple21); + builder.Select(builder.ConstantR0(true), tuple12, tuple21); auto expected = Literal::MakeTuple({Literal::CreateR1(vec1).get(), Literal::CreateR1(vec2).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); @@ -332,7 +331,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) { XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { // Tests a selection between tuples but the final result is an element of the // tuple, not the whole tuple. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -343,13 +342,13 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { auto select = builder.Select(builder.ConstantR0(false), tuple12, tuple21); - auto element = builder.GetTupleElement(select, 0); + builder.GetTupleElement(select, 0); ComputeAndCompareR1(&builder, vec2, {}, error_spec_); } // Cascaded selects between tuple types. -XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) { +XLA_TEST_F(TupleTest, SelectBetweenTuplesCascaded) { // // vec1 vec2 vec2 vec1 // | | | | @@ -367,7 +366,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) { // / --(GTE 1)-- // / // (tuple 21) - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -383,17 +382,16 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) { builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21); auto select2 = builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1); - auto result = builder.Add(builder.GetTupleElement(select2, 0), - builder.GetTupleElement(select2, 1)); + builder.Add(builder.GetTupleElement(select2, 0), + builder.GetTupleElement(select2, 1)); ComputeAndCompareR1(&builder, {3.f, 6.f, 9.f}, {}, error_spec_); } -XLA_TEST_F(TupleTest, - DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesReuseConstants)) { +XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) { // Similar to SelectBetweenTuples, but the constants are shared between the // input tuples. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -402,19 +400,18 @@ XLA_TEST_F(TupleTest, auto tuple12 = builder.Tuple({c1, c2}); auto tuple21 = builder.Tuple({c2, c1}); - auto select = - builder.Select(builder.ConstantR0(false), tuple12, tuple21); + builder.Select(builder.ConstantR0(false), tuple12, tuple21); + auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), Literal::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } XLA_TEST_F(TupleTest, NestedTuples) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto inner_tuple = builder.Tuple( {builder.ConstantR1({1.0, 2.0}), builder.ConstantR0(42.0)}); - auto outer_tuple = - builder.Tuple({inner_tuple, builder.ConstantR1({22.0, 44.0})}); + builder.Tuple({inner_tuple, builder.ConstantR1({22.0, 44.0})}); auto expected_v1 = Literal::CreateR1({1.0, 2.0}); auto expected_s = Literal::CreateR0(42.0); @@ -428,7 +425,7 @@ XLA_TEST_F(TupleTest, NestedTuples) { } XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {3}); Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape}); @@ -459,7 +456,7 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { } XLA_TEST_F(TupleTest, ComplexTuples) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { Shape c64r0 = ShapeUtil::MakeShape(C64, {}); Shape c64r1 = ShapeUtil::MakeShape(C64, {2}); @@ -498,7 +495,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) { auto sum = Literal::CreateR2({{{111, 222}, {331, 442}}, {{1011, 2022}, {3031, 4042}}, {{10011, 20022}, {30031, 40042}}}); - auto prod = Literal::CreateFromShape(sum->shape()); + auto prod = MakeUnique(sum->shape()); ASSERT_TRUE(prod->Populate( [&sum](tensorflow::gtl::ArraySlice indexes) { return sum->Get(indexes) * @@ -514,5 +511,30 @@ XLA_TEST_F(TupleTest, ComplexTuples) { error_spec_); } +class TupleHloTest : public HloTestBase {}; + +// Disabled on the interpreter because bitcast doesn't exist on the interpreter. +XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { + const char* testcase = R"( + HloModule m + + ENTRY test { + name.1 = (f32[3]{0}) parameter(0) + get-tuple-element.1 = f32[3]{0} get-tuple-element(name.1), index=0 + bitcast = f32[1,3]{1,0} bitcast(get-tuple-element.1) + copy = f32[1,3]{1,0} copy(bitcast) + ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy) + } + )"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::MakeTupleOwned(Literal::CreateR1({1, 2, 3})); + auto result = ExecuteNoHloPasses(std::move(module), {param.get()}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned(Literal::CreateR2({{1, 2, 3}})))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 835e2d7e5594d7c8c6e523f9806e32dce23a87e9..c3abe22797f5eaa76ced2ad8534bd68c32983e60 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -37,7 +37,7 @@ class UnaryOpTest : public ClientLibraryTestBase { } template void AbsSize0TestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1({}); auto abs = builder.Abs(arg); @@ -50,7 +50,7 @@ class UnaryOpTest : public ClientLibraryTestBase { template void AbsTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1({-2, 25, 0, -123, inf(), -inf()}); auto abs = builder.Abs(arg); @@ -59,7 +59,7 @@ class UnaryOpTest : public ClientLibraryTestBase { template void SignTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1( {-2, 25, 0, static_cast(-0.0), -123, inf(), -inf()}); auto sign = builder.Sign(arg); @@ -69,7 +69,7 @@ class UnaryOpTest : public ClientLibraryTestBase { template void SignAbsTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1({-2, 25, 0, -123}); auto sign = builder.Sign(arg); auto abs = builder.Abs(arg); @@ -84,9 +84,14 @@ int UnaryOpTest::inf() { return 2147483647; } +template <> +int64 UnaryOpTest::inf() { + return 0x7FFFFFFFFFFFFFFFl; +} + template <> void UnaryOpTest::AbsTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1({{-2, 0}, {0, 25}, {0, 0}, @@ -102,7 +107,7 @@ void UnaryOpTest::AbsTestHelper() { template <> void UnaryOpTest::SignTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1( {{-2, 0}, {0, 25}, {0, 0}, {static_cast(-0.0), 0}, {-1, 1}}); auto sign = builder.Sign(arg); @@ -114,7 +119,7 @@ void UnaryOpTest::SignTestHelper() { template <> void UnaryOpTest::SignAbsTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1({{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}}); auto sign = builder.Sign(arg); @@ -139,7 +144,7 @@ XLA_TEST_F(UnaryOpTest, AbsTestR1) { } XLA_TEST_F(UnaryOpTest, AbsTestR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto argi = builder.ConstantR0(-5); auto absi = builder.Abs(argi); auto argf = builder.ConstantR0(-3.0f); @@ -155,7 +160,7 @@ XLA_TEST_F(UnaryOpTest, AbsTestR0) { } XLA_TEST_F(UnaryOpTest, SignTestR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto argi = builder.ConstantR0(-5); auto sgni = builder.Sign(argi); // -1 auto argf = builder.ConstantR0(-4.0f); @@ -176,6 +181,7 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) { XLA_TEST_F(UnaryOpTest, SignTestR1) { SignTestHelper(); + SignTestHelper(); SignTestHelper(); SignTestHelper(); } @@ -187,7 +193,7 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR1) { } XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1( {2, 25, 0, 123, std::numeric_limits::max()}); auto abs = builder.Abs(arg); @@ -197,7 +203,7 @@ XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { } XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1( {2, 25, 0, 123, std::numeric_limits::max()}); auto sign = builder.Sign(arg); @@ -206,7 +212,7 @@ XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) { } XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR2({{1.0, -2.0}, {-3.0, 4.0}}); auto sign = builder.Sign(arg); auto abs = builder.Abs(arg); @@ -216,7 +222,7 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { } XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 1}); auto rhs = builder.ConstantR1({1, 1}); builder.ConvertElementType(builder.Eq(lhs, rhs), S32); @@ -225,7 +231,7 @@ XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) { } XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 1}); auto rhs = builder.ConstantR1({1, 1}); builder.ConvertElementType(builder.Eq(lhs, rhs), F32); diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc index 32ba067a10df6c15348344da813e6a960f05491c..82d301983fc7885ef5c1c1ed05b74fc017bb7727 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -33,9 +33,9 @@ namespace { class VecOpsReduceTest : public ClientLibraryTestBase { public: - VecOpsReduceTest() : builder_(client_, TestName()) {} + VecOpsReduceTest() : builder_(TestName()) {} - ComputationDataHandle BuildSampleConstantCube() { + XlaOp BuildSampleConstantCube() { // clang-format off Array3D x3d({ {{1.0, 2.0, 3.0}, // | dim 1 // } plane 0 in dim 0 @@ -49,7 +49,7 @@ class VecOpsReduceTest : public ClientLibraryTestBase { return builder_.ConstantR3FromArray3D(x3d); } - ComputationBuilder builder_; + XlaBuilder builder_; ErrorSpec errspec_{1e-3, 0}; }; diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index b52c718814d4ffeff68c60588a6637a2159d57e5..5cce7a2bf82c1a8403536a91e67910f949ef185a 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -18,11 +18,11 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -39,7 +39,7 @@ namespace { class VecOpsSimpleTest : public ClientLibraryTestBase { public: - explicit VecOpsSimpleTest(perftools::gputools::Platform* platform = nullptr) + explicit VecOpsSimpleTest(se::Platform* platform = nullptr) : ClientLibraryTestBase(platform) { mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); mutable_debug_options()->add_xla_disable_hlo_passes("inline"); @@ -49,7 +49,7 @@ class VecOpsSimpleTest : public ClientLibraryTestBase { }; XLA_TEST_F(VecOpsSimpleTest, ExpTenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); auto exp = builder.Exp(x); @@ -63,7 +63,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpTenValues) { XLA_TEST_F(VecOpsSimpleTest, ExpManyValues) { for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector exponents; exponents.reserve(count); for (int i = 0; i < count; ++i) { @@ -84,7 +84,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpManyValues) { } XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D exponents(2, 2, 2, 2); std::vector exponents_vector; @@ -106,7 +106,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) { } XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); builder.Neg(x); @@ -117,7 +117,7 @@ XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) { } XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({2, -2, 12, -4, 5, 20, -15, 0, -2, 1}); builder.Neg(x); @@ -126,7 +126,7 @@ XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) { } XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {0, 1, 42, static_cast(-1), static_cast(-12)}); builder.Neg(x); @@ -136,7 +136,7 @@ XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) { } XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); builder.SquareF32(x); @@ -147,7 +147,7 @@ XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) { } XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); builder.ReciprocalF32(x); @@ -159,7 +159,7 @@ XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) { } XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0.0, -0.0}); auto exp = builder.SqrtF32(x); @@ -167,7 +167,7 @@ XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) { } XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({16.0, 1.0, 1024.0, 0.16, 0.2, 12345}); auto exp = builder.SqrtF32(x); @@ -176,7 +176,7 @@ XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) { } XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({16.0, 1.0, 1024.0, 0.16, 0.2, 12345, 1.2345}); auto exp = builder.Pow(x, builder.ConstantR0(-.5f)); @@ -188,7 +188,7 @@ XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) { } XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto add = CreateScalarAddComputation(F32, &builder); auto x = builder.ConstantR1( @@ -203,7 +203,7 @@ XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { } XLA_TEST_F(VecOpsSimpleTest, MaxTenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); auto y = builder.ConstantR1( @@ -218,8 +218,8 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValues) { XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) { // Similar to MaxTenValues, except that the inputs come from params rather // than constants. - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle v1, v2; + XlaBuilder builder(TestName()); + XlaOp v1, v2; std::unique_ptr param0_data = CreateR1Parameter( {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&builder, /*data_handle=*/&v1); @@ -236,7 +236,7 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) { XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { // Similar to MaxTenValuesFromParams, except that the data size passed in and // out is large. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Number of floats in the data passed into and out of the computation. constexpr int datalen = 15 * 1000; @@ -259,7 +259,7 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { expected_vec.push_back(larger); } - ComputationDataHandle v1, v2; + XlaOp v1, v2; std::unique_ptr param0_data = CreateR1Parameter(v1vec, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&builder, /*data_handle=*/&v1); @@ -274,7 +274,7 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { } XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); auto y = builder.ConstantR0(0); @@ -286,7 +286,7 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) { } XLA_TEST_F(VecOpsSimpleTest, MinTenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); auto y = builder.ConstantR1( @@ -299,7 +299,7 @@ XLA_TEST_F(VecOpsSimpleTest, MinTenValues) { } XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto zero = builder.ConstantR0(0); auto one = builder.ConstantR0(1); auto x = builder.ConstantR1( @@ -312,7 +312,7 @@ XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) { } XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto zero = builder.ConstantR0(0); auto one = builder.ConstantR0(1); auto x = builder.ConstantR1( @@ -325,7 +325,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) { } XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto zero = builder.ConstantR1({0.0f, 0.0f}); auto one = builder.ConstantR1({1.0f, 1.0f}); auto x = builder.ConstantR1({2.1, -2.6}); @@ -336,7 +336,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) { } XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto one = builder.ConstantR0(1); auto two = builder.ConstantR0(2); auto x = builder.ConstantR1( @@ -348,11 +348,22 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { ComputeAndCompareR1(&builder, expected, {}); } +XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) { + XlaBuilder builder(TestName()); + auto zero = builder.ConstantR0(0); + auto one = builder.ConstantR0(10); + auto x = builder.ConstantR1({-3, 3, 9, 13}); + auto clamp = builder.Clamp(zero, x, one); + + std::vector expected = {0, 3, 9, 10}; + ComputeAndCompareR1(&builder, expected, {}); +} + XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { - Computation add_half; + XlaComputation add_half; { // add_half(x) = x + 0.5 - ComputationBuilder builder(client_, "add_half"); + XlaBuilder builder("add_half"); auto x_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value"); auto half = builder.ConstantR0(0.5); @@ -362,10 +373,10 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { add_half = computation_status.ConsumeValueOrDie(); } - Computation clamp; + XlaComputation clamp; { // clamp(y) = clamp<0,5>(y) - ComputationBuilder builder(client_, "clamp"); + XlaBuilder builder("clamp"); auto y_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y_value"); auto zero = builder.ConstantR0(0.0); @@ -375,10 +386,10 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { clamp = computation_status.ConsumeValueOrDie(); } - Computation mult_relu_add; + XlaComputation mult_relu_add; { // mult_relu_add(z) = clamp(add_half(2 * max(z, 0))) - ComputationBuilder builder(client_, "mult_relu_add"); + XlaBuilder builder("mult_relu_add"); auto z_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value"); auto zero = builder.ConstantR0(0.0); @@ -392,7 +403,7 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { mult_relu_add = computation_status.ConsumeValueOrDie(); } - ComputationBuilder builder(client_, "map10"); + XlaBuilder builder("map10"); { auto x = builder.ConstantR1( {2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); @@ -405,7 +416,7 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { } XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({-5, -4, -3, -2, -1, 0, 1, 2, 3, 4}); auto y = builder.ConstantR0(3); builder.Rem(x, y); @@ -415,7 +426,7 @@ XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) { } XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({false, true}); auto y = builder.ConstantR1({true, false}); builder.Eq(x, y); @@ -425,7 +436,7 @@ XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) { } XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({false, true}); auto y = builder.ConstantR1({true, false}); builder.Ne(x, y); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 52157b837c383205f77a030ef98b2fd03a41aff5..c463f3eac55e5b8ab32dc52d5a38e7840241bc58 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -37,8 +37,6 @@ limitations under the License. #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { namespace { @@ -54,29 +52,28 @@ TEST_F(WhileTest, WhileWithScalarS32Result) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0(0); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -91,29 +88,28 @@ TEST_F(WhileTest, WhileWithScalarS64Result) { auto result_shape = ShapeUtil::MakeShape(S64, {}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0(0); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -123,31 +119,30 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { auto orig_shape = ShapeUtil::MakeShape(S32, {2}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.Reduce(builder.ConstantR1(2, 1), builder.ConstantR0(0), CreateScalarAddComputation(S32, &builder), {0}); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -156,28 +151,28 @@ TEST_F(WhileTest, WhileWithPredicateResult) { auto result_shape = ShapeUtil::MakeShape(PRED, {}); // Create a computation for the condition: run until condition is true. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Ne(builder.ConstantR0(true), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: or condition with true. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); - auto result = builder.Or(prev, builder.ConstantR0(true)); + builder.Or(prev, builder.ConstantR0(true)); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.Ne(builder.ConstantR0(false), builder.ConstantR0(true)); - auto result = builder.While(condition, body, init); + builder.While(condition, body, init); ComputeAndCompareR0(&builder, true, {}); } @@ -194,9 +189,9 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { Shape result_shape = ShapeUtil::MakeShape(F32, {0}); // Create a computation for the reduction. - Computation add; + XlaComputation add; { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -205,33 +200,34 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { // Create a computation for the condition. // Repeat until the sum of the result vector is less than 15.5f. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0}); - auto test = builder.Gt(builder.ConstantR0(15.5f), sum); + builder.Gt(builder.ConstantR0(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1({}); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.ConstantR1({}); auto result = builder.While(condition, body, init); - VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + builder.GetShape(result).ConsumeValueOrDie()); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.0001)); } @@ -247,9 +243,9 @@ TEST_F(WhileTest, WhileWithVectorResult) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. - Computation add; + XlaComputation add; { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -258,33 +254,34 @@ TEST_F(WhileTest, WhileWithVectorResult) { // Create a computation for the condition. // Repeat until the sum of the result vector is less than 5.5f. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0}); - auto test = builder.Gt(builder.ConstantR0(15.5f), sum); + builder.Gt(builder.ConstantR0(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1(8, 0.125f); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.ConstantR1(8, 0.f); auto result = builder.While(condition, body, init); - VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + builder.GetShape(result).ConsumeValueOrDie()); // Individual elements with increase by 1/8 each time through the loop, so // the sum will increase by 1.0. It will first be >15.5 when the elements @@ -306,9 +303,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. - Computation add; + XlaComputation add; { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -317,34 +314,34 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { // Create a computation for the condition. // Repeat until the sum of the result vector is less than 5.5f. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0}); - auto test = builder.Gt(builder.ConstantR0(15.5f), sum); + builder.Gt(builder.ConstantR0(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1(8, 0.125f); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.ConstantR1(8, 0.f); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); builder.Tuple({result}); // Individual elements with increase by 1/8 each time through the loop, so @@ -366,9 +363,9 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { // Create a computation for the condition. // Repeat for N iterations. const int N = 2; - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(N), iteration); @@ -377,28 +374,28 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and permute the weights. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto w1 = builder.GetTupleElement(prev, 1); auto w2 = builder.GetTupleElement(prev, 2); auto w3 = builder.GetTupleElement(prev, 3); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); auto result = builder.While(condition, body, init); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(N); auto expected_w1 = Literal::CreateR1({1.0f, 1.0f, 1.0f}); @@ -419,9 +416,9 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { // Create a computation for the condition. // Repeat for N iterations. const int N = 2; - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(N), iteration); @@ -430,21 +427,21 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { // Create a computation for the body. // Add 1 to the iteration variable permute the weights. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto w1 = builder.GetTupleElement(prev, 1); auto w2 = builder.GetTupleElement(prev, 2); auto w3 = builder.GetTupleElement(prev, 3); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); @@ -455,7 +452,7 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3)); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); std::vector expected = {6.f, 6.f, 6.f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } @@ -474,9 +471,9 @@ TEST_F(WhileTest, WhileWithTupleResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(5), iteration); @@ -486,26 +483,27 @@ TEST_F(WhileTest, WhileWithTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto result = builder.While(condition, body, init); - VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(5); auto expected_data = Literal::CreateR1( @@ -523,9 +521,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(5), iteration); @@ -534,27 +532,27 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and or the predicate with true - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto pred = builder.GetTupleElement(prev, 1); auto new_pred = builder.Or(pred, builder.ConstantR0(true)); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_pred}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple({builder.ConstantR0(0), builder.Ne(builder.ConstantR0(false), builder.ConstantR0(true))}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(5); auto expected_predicate = Literal::CreateR0(true); @@ -570,9 +568,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(5), iteration); @@ -582,25 +580,24 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { // Create a computation for the body. // Add 1 to the iteration variable and set the other tuple element to a // constant. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); - auto result = - builder.Tuple({builder.Add(iteration, builder.ConstantR0(1)), - builder.ConstantR0(7)}); + builder.Tuple({builder.Add(iteration, builder.ConstantR0(1)), + builder.ConstantR0(7)}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR0(7)}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(5); auto expected_data = Literal::CreateR0(7); @@ -631,20 +628,20 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; const int c1 = 5; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation condition2; + XlaComputation condition2; const int c2 = 7; { - ComputationBuilder builder(client_, "condition2"); + XlaBuilder builder("condition2"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c2)); @@ -654,34 +651,34 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } - Computation body2; + XlaComputation body2; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build()); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto while1 = builder.While(condition, body, init); @@ -692,11 +689,11 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { auto while_result2 = builder.GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( - *builder.GetShape(while_result2).ConsumeValueOrDie()); + builder.GetShape(while_result2).ConsumeValueOrDie()); auto result = builder.Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); const float sum = c1 + c2; std::vector expected(10, sum); ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -710,20 +707,20 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; const int c1 = 5; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation condition2; + XlaComputation condition2; const int c2 = 7; { - ComputationBuilder builder(client_, "condition2"); + XlaBuilder builder("condition2"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c2)); @@ -733,21 +730,21 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto while1 = builder.While(condition, body, init); @@ -758,11 +755,11 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { auto while_result2 = builder.GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( - *builder.GetShape(while_result2).ConsumeValueOrDie()); + builder.GetShape(while_result2).ConsumeValueOrDie()); auto result = builder.Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); const float sum = c1 + c2; std::vector expected(10, sum); ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -777,20 +774,20 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; const int c1 = 5; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation condition2; + XlaComputation condition2; const int c2 = 7; { - ComputationBuilder builder(client_, "condition2"); + XlaBuilder builder("condition2"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c2)); @@ -800,21 +797,21 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto while1 = builder.While(condition, body, init); @@ -824,11 +821,11 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { auto while_result2 = builder.GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( - *builder.GetShape(while_result2).ConsumeValueOrDie()); + builder.GetShape(while_result2).ConsumeValueOrDie()); auto result = builder.Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); const float sum = c1 + c2; std::vector expected(10, sum); ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -844,9 +841,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(5), iteration); @@ -856,9 +853,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); // TupleElement 0 auto iteration = builder.GetTupleElement(prev, 0); @@ -873,18 +870,18 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // UpdateSlice. auto out1 = builder.DynamicUpdateSlice(input, update, starts); - auto result = builder.Tuple({out0, out1}); + builder.Tuple({out0, out1}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(5); auto expected_data = Literal::CreateR1( @@ -910,23 +907,23 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Per backend the values generated can be different as the different backends // use different random number generators. // TODO(b/32240857): Extend test to verify outputs. -TEST_F(WhileTest, WhileWithPrngScalarResult) { +TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { auto v6s32 = ShapeUtil::MakeShape(S32, {6}); // Create a computation for the condition: repeat for count iterations. auto build_condition = [this, v6s32](int count) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto prev = builder.Reshape( builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0}, - {}); + {}); builder.Gt(builder.ConstantR0(count), prev); return builder.Build().ConsumeValueOrDie(); }; // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, v6s32, "prev"); auto inc = builder.ConcatInDim( {builder.ConstantR1({1}), @@ -934,16 +931,15 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) { builder.ConstantR0(100), ShapeUtil::MakeShape(S32, {5}))}, 0); - auto result = builder.Add(inc, prev); + builder.Add(inc, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. auto while_loop = [this, &body, build_condition](int count) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR1({0, 0, 0, 0, 0, 0}); - auto result = builder.While(build_condition(count), body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(build_condition(count), body, init); return builder.Build(); }; @@ -961,22 +957,21 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) { TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); - ComputationBuilder outer(client_, "outer"); + XlaBuilder outer("outer"); auto p = outer.Parameter(0, element_shape, "param"); auto t = outer.Tuple({p, outer.ConstantR1({1, 1})}); - TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr tuple_shape, - outer.GetShape(t)); + TF_ASSERT_OK_AND_ASSIGN(Shape tuple_shape, outer.GetShape(t)); - ComputationBuilder cond(client_, "cond"); - auto cond_t = cond.Parameter(0, *tuple_shape, "t"); + XlaBuilder cond("cond"); + auto cond_t = cond.Parameter(0, tuple_shape, "t"); TF_ASSERT_OK(Any(cond.Eq(cond.GetTupleElement(cond_t, 0), cond.ConstantR1({42, 42})), &cond) .status()); - ComputationBuilder body(client_, "body"); - auto body_t = body.Parameter(0, *tuple_shape, "t"); + XlaBuilder body("body"); + auto body_t = body.Parameter(0, tuple_shape, "t"); auto e = body.GetTupleElement(body_t, 1); body.Tuple({e, e}); @@ -997,15 +992,15 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); - ComputationBuilder outer(client_, "outer"); + XlaBuilder outer("outer"); auto p = outer.Parameter(0, element_shape, "param"); - ComputationBuilder cond(client_, "cond"); + XlaBuilder cond("cond"); auto cond_t = cond.Parameter(0, element_shape, "t"); TF_ASSERT_OK( Any(cond.Eq(cond_t, cond.ConstantR1({42, 42})), &cond).status()); - ComputationBuilder body(client_, "body"); + XlaBuilder body("body"); auto body_t = body.Parameter(0, element_shape, "t"); auto e = body.Broadcast(body.ConstantR0(1.0), {2}); @@ -1023,14 +1018,14 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {}); - ComputationBuilder outer(client_, "outer"); + XlaBuilder outer("outer"); auto p = outer.Parameter(0, element_shape, "param"); - ComputationBuilder cond(client_, "cond"); + XlaBuilder cond("cond"); auto cond_t = cond.Parameter(0, element_shape, "t"); cond.Eq(cond_t, cond.ConstantR0(42)); - ComputationBuilder body(client_, "body"); + XlaBuilder body("body"); auto body_t = body.Parameter(0, element_shape, "t"); auto tuple = body.Tuple({body_t, body.Add(body_t, body.ConstantR0(1))}); @@ -1059,23 +1054,23 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) { auto result_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); - ComputationBuilder outer(client_, "outer"); + XlaBuilder outer("outer"); auto p = outer.Tuple({outer.ConstantR0(0), outer.Parameter(0, ShapeUtil::MakeShape(S32, {}), "t")}); - ComputationBuilder cond(client_, "cond"); + XlaBuilder cond("cond"); auto params = cond.Parameter(0, result_shape, "prev"); auto cond_t = cond.Add(cond.GetTupleElement(params, 1), cond.GetTupleElement(params, 0)); cond.Lt(cond_t, cond.ConstantR0(30)); - ComputationBuilder body(client_, "body"); + XlaBuilder body("body"); auto body_t = body.Parameter(0, result_shape, "t"); auto tuple = body.Tuple( - {body.Add(body.GetTupleElement(params, 0), body.ConstantR0(1)), - body.Add(body.GetTupleElement(params, 1), body.ConstantR0(1))}); + {body.Add(body.GetTupleElement(body_t, 0), body.ConstantR0(1)), + body.Add(body.GetTupleElement(body_t, 1), body.ConstantR0(1))}); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); @@ -1107,9 +1102,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { auto inner_result_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); - Computation inner_condition; + XlaComputation inner_condition; { - ComputationBuilder builder(client_, "inner_condition"); + XlaBuilder builder("inner_condition"); auto params = builder.Parameter(0, inner_result_shape, "prev"); auto i = builder.GetTupleElement(params, 0); builder.Lt(i, builder.ConstantR0(7)); @@ -1118,9 +1113,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // Creates a computation for the outer loop condition: // repeat while result < 30. - Computation outer_condition; + XlaComputation outer_condition; { - ComputationBuilder builder(client_, "outer_condition"); + XlaBuilder builder("outer_condition"); auto prev = builder.Parameter(0, outer_result_shape, "prev"); builder.Lt(prev, builder.ConstantR0(30)); outer_condition = builder.Build().ConsumeValueOrDie(); @@ -1128,34 +1123,33 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // Creates a computation for the inner loop body: add 1 to `i`, and add 2 to // `result`. - Computation inner_body; + XlaComputation inner_body; { - ComputationBuilder builder(client_, "inner_body"); + XlaBuilder builder("inner_body"); auto params = builder.Parameter(0, inner_result_shape, "prev"); auto i = builder.GetTupleElement(params, 0); auto result = builder.GetTupleElement(params, 1); i = builder.Add(builder.ConstantR0(1), i); result = builder.Add(builder.ConstantR0(2), result); - auto output = builder.Tuple({i, result}); + builder.Tuple({i, result}); inner_body = builder.Build().ConsumeValueOrDie(); } // Creates a computation for the outer loop: run the inner loop with i = 0. - Computation outer_body; + XlaComputation outer_body; { - ComputationBuilder builder(client_, "outer_body"); + XlaBuilder builder("outer_body"); auto prev = builder.Parameter(0, outer_result_shape, "prev"); auto init = builder.Tuple({builder.ConstantR0(0), prev}); auto result = builder.While(inner_condition, inner_body, init); - auto output = builder.GetTupleElement(result, 1); + builder.GetTupleElement(result, 1); outer_body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0(0); - auto result = builder.While(outer_condition, outer_body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(outer_condition, outer_body, init); ComputeAndCompareR0(&builder, 42, {}); } @@ -1166,22 +1160,22 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // while (f(result).get<0>()) { // result = result + 1; // } -TEST_F(WhileTest, WhileWithCallInsideCondition) { +TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition_callee; + XlaComputation condition_callee; { - ComputationBuilder builder(client_, "condition_callee"); + XlaBuilder builder("condition_callee"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Tuple({builder.Gt(builder.ConstantR0(5), prev)}); condition_callee = builder.Build().ConsumeValueOrDie(); } - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto result = builder.Call(condition_callee, {prev}); builder.GetTupleElement(result, 0); @@ -1189,20 +1183,19 @@ TEST_F(WhileTest, WhileWithCallInsideCondition) { } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0(0); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -1214,28 +1207,28 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { {scalar_s32, matrix_shape, matrix_shape, matrix_shape}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto state = builder.Parameter(0, while_shape, "state"); builder.Gt(builder.ConstantR0(5), builder.GetTupleElement(state, 0)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto state = builder.Parameter(0, while_shape, "state"); auto indvar = builder.GetTupleElement(state, 0); auto input_0 = builder.GetTupleElement(state, 1); auto input_1 = builder.GetTupleElement(state, 2); auto output = builder.Tanh(builder.Dot(input_0, input_1)); auto indvar_next = builder.Add(indvar, builder.ConstantR0(1)); - auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output}); + builder.Tuple({indvar_next, input_0, input_1, output}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto matrix_input = builder.Parameter(0, matrix_shape, "matrix"); auto init = builder.Tuple( {builder.ConstantR0(0), matrix_input, matrix_input, matrix_input}); @@ -1268,9 +1261,9 @@ void BM_WhileLoop(int num_iters) { // Create while condition computation with 'loop_limit'. const int32 loop_limit = 100; - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, loop_state_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(loop_limit)); @@ -1278,9 +1271,9 @@ void BM_WhileLoop(int num_iters) { } // Create while body computation with unit loop increment. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, loop_state_shape, "prev"); // TupleElement 0 auto iteration = builder.GetTupleElement(prev, 0); @@ -1294,12 +1287,12 @@ void BM_WhileLoop(int num_iters) { auto starts = builder.ConstantR1({0, 0, 0}); // UpdateSlice. auto out1 = builder.DynamicUpdateSlice(input, update, starts); - auto result = builder.Tuple({out0, out1}); + builder.Tuple({out0, out1}); body = builder.Build().ConsumeValueOrDie(); } // Create a While instruction. - ComputationBuilder builder(client, "while"); + XlaBuilder builder("while"); auto zero = builder.ConstantR0(0.0); auto input = builder.Broadcast(zero, {seq_len, 1024, 1024}); auto init = builder.Tuple({builder.ConstantR0(0), input}); @@ -1327,10 +1320,6 @@ void BM_WhileLoop(int num_iters) { } } -// TODO(b/32470510): Benchmark fails on parallel CPU backend. -#ifndef XLA_TEST_BACKEND_CPU_PARALLEL BENCHMARK(BM_WhileLoop); -#endif - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 9ad2a1985331b80625dd0687ea052300bc99e440..3c9a01653c67203cbc962a3d3d967142f7a2102c 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -17,8 +17,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -27,13 +28,14 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { -namespace se = ::perftools::gputools; + namespace gtl = ::tensorflow::gtl; class HloProfileTest : public ClientLibraryTestBase {}; @@ -82,8 +84,8 @@ Status ParseOneProfileOutputLine( string match_percentage = "\\d+\\.\\d\\d%"; string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)"; string match_usecs = "([0-9.]+) usec"; - string match_flops = "([^ ]+)"; - string match_trops = "([^ ]+)"; + string match_flops = "([^ ]*)"; + string match_trops = "([^ ]*)"; string match_bytes_per_sec = "([0-9.TGMKi]+)B/s"; string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle"; @@ -118,7 +120,7 @@ Status ParseOneProfileOutputLine( // Returns void so that we can ASSERT. void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, - const Computation& computation, + const XlaComputation& computation, const Shape& lhs_arg_shape, const Shape& rhs_arg_shape) { LocalService* service = ClientLibrary::GetXlaService(client->platform()); @@ -128,23 +130,23 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, auto* transfer_manager = backend->transfer_manager(); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr lhs_arg, + ScopedShapedBuffer lhs_arg, transfer_manager->AllocateScopedShapedBuffer( lhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - executor, *Literal::CreateFromShape(lhs_arg_shape), *lhs_arg)); + executor, *Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr rhs_arg, + ScopedShapedBuffer rhs_arg, transfer_manager->AllocateScopedShapedBuffer( rhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - executor, *Literal::CreateFromShape(rhs_arg_shape), *rhs_arg)); + executor, *Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape}, - ExecutableBuildOptions())); + ExecutableBuildOptions().set_hlo_profile(true))); Executable* executable = local_executable->executable(); HloExecutionProfile hlo_execution_profile( @@ -164,7 +166,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, backend->eigen_intra_op_thread_pool()); TF_ASSERT_OK_AND_ASSIGN( auto execution_result, - executable->ExecuteOnStream(&run_options, {lhs_arg.get(), rhs_arg.get()}, + executable->ExecuteOnStream(&run_options, {&lhs_arg, &rhs_arg}, &hlo_execution_profile)); (void)execution_result; @@ -174,8 +176,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, XLA_VLOG_LINES(4, *profile_output); } -// TODO(b/71364943): This test exposes a bug in the parallel CPU backend. -XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) { +XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { const int64 m = 256, k = 256, n = 256; Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {m, k}); @@ -185,7 +186,7 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) { TF_ASSERT_OK_AND_ASSIGN(LocalClient * client, ClientLibrary::GetOrCreateLocalClient(platform)); - ComputationBuilder builder(client, TestName()); + XlaBuilder builder(TestName()); auto result = builder.Tanh(builder.Add( builder.Parameter(0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"), builder.Parameter(1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs"))); @@ -238,12 +239,9 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) { EXPECT_TRUE(HasTrops(tanh_profile)); } -// TODO(b/71364943): This test exposes a bug in the parallel CPU backend. -// // TODO(b/71544591): The GPU backend does not record cycles spent in on Hlo // instructions "interior" to while nodes. -XLA_TEST_F(HloProfileTest, - DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(ProfileWhileComputation))) { +XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { const int64 size = 256; Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size}); Shape while_result_shape = @@ -254,18 +252,18 @@ XLA_TEST_F(HloProfileTest, TF_ASSERT_OK_AND_ASSIGN(LocalClient * client, ClientLibrary::GetOrCreateLocalClient(platform)); - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client, "condition"); + XlaBuilder builder("condition"); auto state = builder.Parameter(0, while_result_shape, "state"); auto iteration = builder.GetTupleElement(state, 0); builder.Gt(builder.ConstantR0(5), iteration); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation body; + XlaComputation body; { - ComputationBuilder builder(client, "body"); + XlaBuilder builder("body"); auto state = builder.Parameter(0, while_result_shape, "state"); auto matrix = builder.GetTupleElement(state, 1); auto next_iteration = builder.Add(builder.GetTupleElement(state, 0), @@ -274,7 +272,7 @@ XLA_TEST_F(HloProfileTest, TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } - ComputationBuilder builder(client, TestName()); + XlaBuilder builder(TestName()); auto initial_while_state = builder.Tuple({builder.ConstantR0(0), builder.Parameter(0, matrix_shape, "initial_value")}); @@ -294,7 +292,8 @@ XLA_TEST_F(HloProfileTest, auto while_body_profile_start = std::find_if(profile_output_lines.begin(), profile_output_lines.end(), [](tensorflow::StringPiece s) { - return s.starts_with("Execution profile for body"); + return tensorflow::str_util::StartsWith( + s, "Execution profile for body"); }); ASSERT_NE(while_body_profile_start, profile_output_lines.end()); diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index 92b2b1ee778f8b0f8104e7d7ff27a5c11db59768..a9f2915b458b1816926de727b3da21982d06f6c0 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" GTEST_API_ int main(int argc, char** argv) { std::vector flag_list; @@ -25,7 +29,38 @@ GTEST_API_ int main(int argc, char** argv) { return 2; } + // If the --benchmarks flag is passed in then only run the benchmarks, not the + // tests. + for (int i = 1; i < argc; i++) { + tensorflow::StringPiece arg(argv[i]); + if (arg == "--benchmarks" || + tensorflow::str_util::StartsWith(arg, "--benchmarks=")) { + const char* pattern = nullptr; + if (tensorflow::str_util::StartsWith(arg, "--benchmarks=")) { + pattern = argv[i] + strlen("--benchmarks="); + } else { + // Handle flag of the form '--benchmarks foo' (no '='). + if (i + 1 >= argc || + tensorflow::str_util::StartsWith(argv[i + 1], "--")) { + LOG(ERROR) << "--benchmarks flag requires an argument."; + return 2; + } + pattern = argv[i + 1]; + } + // Unfortunately Google's internal benchmark infrastructure has a + // different API than Tensorflow's. +#if defined(PLATFORM_GOOGLE) + base::SetFlag(&FLAGS_benchmarks, pattern); + RunSpecifiedBenchmarks(); +#else + tensorflow::testing::Benchmark::Run(pattern); +#endif + return 0; + } + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; return 2; diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 6fa4c48e11d1102367b21bc21d4734466495ef0e..56702feab9a4e8d00df3a165ab994aef2d42d830 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -38,11 +38,11 @@ namespace xla { StatusOr> TextLiteralReader::ReadPath( tensorflow::StringPiece path) { - CHECK(!path.ends_with(".gz")) + CHECK(!tensorflow::str_util::EndsWith(path, ".gz")) << "TextLiteralReader no longer supports reading .gz files"; std::unique_ptr file; Status s = - tensorflow::Env::Default()->NewRandomAccessFile(path.ToString(), &file); + tensorflow::Env::Default()->NewRandomAccessFile(std::string(path), &file); if (!s.ok()) { return s; } @@ -92,7 +92,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { tensorflow::StringPiece sp(shape_string); if (tensorflow::str_util::RemoveWhitespaceContext(&sp) > 0) { - string tmp = sp.ToString(); + string tmp = std::string(sp); shape_string = tmp; } TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string)); @@ -115,7 +115,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { tensorflow::StringPiece value_string = pieces[1]; tensorflow::str_util::RemoveWhitespaceContext(&coordinates_string); tensorflow::str_util::RemoveWhitespaceContext(&value_string); - if (!coordinates_string.Consume("(")) { + if (!tensorflow::str_util::ConsumePrefix(&coordinates_string, "(")) { return InvalidArgument( "expected '(' at the beginning of coordinates: \"%s\"", line.c_str()); } @@ -124,10 +124,10 @@ StatusOr> TextLiteralReader::ReadAllLines() { line.c_str()); } float value; - if (!tensorflow::strings::safe_strtof(value_string.ToString().c_str(), + if (!tensorflow::strings::safe_strtof(std::string(value_string).c_str(), &value)) { return InvalidArgument("could not parse value as float: \"%s\"", - value_string.ToString().c_str()); + std::string(value_string).c_str()); } SplitByDelimToStringPieces(coordinates_string, ',', &coordinates); coordinate_values.clear(); @@ -136,7 +136,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { if (!tensorflow::strings::safe_strto64(piece, &coordinate_value)) { return InvalidArgument( "could not parse coordinate member as int64: \"%s\"", - piece.ToString().c_str()); + std::string(piece).c_str()); } coordinate_values.push_back(coordinate_value); } diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 3fee467594d8423c707abf07a0622a738437830a..373c0d2d8d8ab05dec11e51f265d41b91e7920bf 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -30,10 +30,10 @@ limitations under the License. namespace xla { -/* static */ tensorflow::Status TextLiteralWriter::WriteToPath( +/* static */ Status TextLiteralWriter::WriteToPath( const Literal& literal, tensorflow::StringPiece path) { std::unique_ptr f; - auto s = tensorflow::Env::Default()->NewWritableFile(path.ToString(), &f); + auto s = tensorflow::Env::Default()->NewWritableFile(std::string(path), &f); if (!s.ok()) { return s; } @@ -43,7 +43,7 @@ namespace xla { return s; } - tensorflow::Status status; + Status status; tensorflow::WritableFile* f_ptr = f.get(); literal.EachCellAsString( [f_ptr, &status](tensorflow::gtl::ArraySlice indices, diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h index 7375493f4309c9bf75fc9d724626267dff7ce5ed..0a1235b5e04675da0f412bafab6c4ecf04367787 100644 --- a/tensorflow/compiler/xla/text_literal_writer.h +++ b/tensorflow/compiler/xla/text_literal_writer.h @@ -37,8 +37,8 @@ namespace xla { // This should be readable by xla::TextLiteralReader. class TextLiteralWriter { public: - static tensorflow::Status WriteToPath(const Literal& literal, - tensorflow::StringPiece path); + static Status WriteToPath(const Literal& literal, + tensorflow::StringPiece path); private: TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralWriter); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 091fa0c3ec807a66449eca0bfbb141285b8eb532..e4a052c8f1c0009619c3a94606f6384d04006e4e 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -36,11 +36,10 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -63,10 +62,9 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -75,6 +73,7 @@ cc_library( name = "replay_computation_library", srcs = ["replay_computation.cc"], deps = [ + "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -83,11 +82,12 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:testing", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service/gpu:infeed_manager", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", @@ -136,7 +136,7 @@ tf_cc_binary( deps = [ "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -163,12 +163,10 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:computation_tracker", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -182,12 +180,11 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -200,13 +197,12 @@ tf_cc_binary( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -222,17 +218,3 @@ tf_cc_binary( "//tensorflow/core:lib", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/tools/convert_computation.cc b/tensorflow/compiler/xla/tools/convert_computation.cc index fe03a6e7bdfe99877c250fe1ae22beee4c8018a2..14d01b5bfb067cc39abc4d6e0605007624b6e0ae 100644 --- a/tensorflow/compiler/xla/tools/convert_computation.cc +++ b/tensorflow/compiler/xla/tools/convert_computation.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/env.h" @@ -33,7 +33,7 @@ namespace xla { namespace tools { void RealMain(const string& mode, const string& path) { - SessionModule module; + HloSnapshot module; tensorflow::Env* env = tensorflow::Env::Default(); if (mode == "txt2bin") { TF_CHECK_OK(tensorflow::ReadTextProto(env, path, &module)); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index 21ae8583d7cd3343230dcaff7dc17456e9e3e702..befb55453777dce30af89bcaad2ffe1647097576 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -17,7 +17,7 @@ limitations under the License. // // Dumps a graphviz URL for a snapshot computation to the command line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // The GraphViz URL is placed into the log stderr, whereas computation @@ -30,11 +30,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -49,10 +48,11 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + XlaComputation computation = + client->LoadSnapshot(module).ConsumeValueOrDie(); DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); ComputationStats stats = diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index b82f1c81c84b487c1661af5267b9123da97bb107..cfb8f37487d6499b803438a135be54524fcf17d2 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -21,11 +21,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -66,16 +65,16 @@ void RealMain(tensorflow::gtl::ArraySlice args) { LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); for (char* arg : args) { - SessionModule session_module; + HloSnapshot snapshot; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, - &session_module)); - auto computation_status = client->LoadSnapshot(session_module); + &snapshot)); + auto computation_status = client->LoadSnapshot(snapshot); if (!computation_status.ok()) { fprintf(stderr, "could not load snapshot for %s: %s\n", arg, computation_status.status().ToString().c_str()); continue; } - Computation computation = computation_status.ConsumeValueOrDie(); + XlaComputation computation = computation_status.ConsumeValueOrDie(); std::unique_ptr program_shape = client->GetComputationShape(computation).ConsumeValueOrDie(); @@ -89,8 +88,7 @@ void RealMain(tensorflow::gtl::ArraySlice args) { build_options.set_device_ordinal(0); build_options.set_result_layout(program_shape->result()); StatusOr> executable = - local_service->CompileExecutable(computation.handle(), layouts, - build_options); + local_service->CompileExecutable(computation, layouts, build_options); const HloModule& module = executable.ValueOrDie()->module(); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 05c0fdf97d27c09eb2bbb0f265b5b2a5982ca7b1..5dd5150be339846d0775880931f615b92c5b08d8 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -19,11 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -40,16 +38,16 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); for (char* arg : args) { - SessionModule session_module; + HloSnapshot snapshot; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, - &session_module)); - auto computation_status = client->LoadSnapshot(session_module); + &snapshot)); + auto computation_status = client->LoadSnapshot(snapshot); if (!computation_status.ok()) { fprintf(stderr, "could not load snapshot for %s: %s\n", arg, computation_status.status().ToString().c_str()); continue; } - Computation computation = computation_status.ConsumeValueOrDie(); + XlaComputation computation = computation_status.ConsumeValueOrDie(); if (compile) { std::unique_ptr program_shape = @@ -65,8 +63,7 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { build_options.set_device_ordinal(0); build_options.set_result_layout(program_shape->result()); StatusOr> executable = - local_service->CompileExecutable(computation.handle(), layouts, - build_options); + local_service->CompileExecutable(computation, layouts, build_options); const HloModule& module = executable.ValueOrDie()->module(); @@ -74,13 +71,11 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { local_service->backend().platform()->Name().c_str(), module.ToString(HloPrintOptions::ShortParsable()).c_str()); } else { - const ComputationTracker& tracker = local_service->computation_tracker(); - UserComputation* user_computation = - tracker.Resolve(computation.handle()).ConsumeValueOrDie(); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); + auto config = HloModule::CreateModuleConfigFromProto(computation.proto(), + DebugOptions()) + .ConsumeValueOrDie(); std::unique_ptr module = - tracker.BuildHloModule(versioned_handle, HloModuleConfig()) + HloModule::CreateFromProto(computation.proto(), config) .ConsumeValueOrDie(); fprintf(stdout, "%s\n", diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index 51f90b07c66f7d839f587350726333b9dbe6a9f0..a5dce20456c6a2402f425ebb3d575d1bb625f839 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -28,11 +28,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -48,10 +47,11 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + XlaComputation computation = + client->LoadSnapshot(module).ConsumeValueOrDie(); DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); debug_options.set_xla_hlo_dump_as_graphdef(true); diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD deleted file mode 100644 index 97aacf6b39f83978e732060817cd93ede81ca782..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ /dev/null @@ -1,86 +0,0 @@ -# Build file for the Hlo parser. - -licenses(["notice"]) # Apache 2.0 - -package( - default_visibility = [":friends"], -) - -package_group( - name = "friends", - includes = [ - "//tensorflow/compiler/xla:friends", - ], -) - -# Filegroup used to collect source files for dependency checking. -filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), -) - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "hlo_lexer", - srcs = ["hlo_lexer.cc"], - hdrs = [ - "hlo_lexer.h", - "hlo_token.h", - ], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", - ], -) - -cc_library( - name = "hlo_parser", - srcs = ["hlo_parser.cc"], - hdrs = ["hlo_parser.h"], - deps = [ - ":hlo_lexer", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], -) - -tf_cc_test( - name = "hlo_parser_test", - size = "small", - srcs = ["hlo_parser_test.cc"], - deps = [ - ":hlo_parser", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index eda5effbb92db92c9317a956497a00c0ec15c27c..f7574e0b1cc95daee6d6743ba4e2e490ee87e7c6 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -17,13 +17,16 @@ limitations under the License. // // Replays computations and shows the results on the command line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // Computations that require arguments can be replayed using fake data by // passing --use_fake_data on the command line. If the real data is available // in the proto and --use_fake_data is false, the real data is used. // +// Input can be a binary HloSnapshot proto, a binary HloProto proto, or a +// textual HLO string. +// // The output format is: // // file_path: computation_name :: type:literal_str @@ -36,12 +39,14 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -63,6 +68,7 @@ namespace { // fields. struct Options { string fake_infeed_shape; + bool generate_fake_infeed = false; bool use_fake_data = false; bool print_result = true; int num_runs = 1; @@ -71,85 +77,168 @@ struct Options { // Invokes the given computation passing arbitrary data for every (unbound) // parameter if use_fake_data, Otherwise use recorded data if available. // -// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided; -// otherwise, no infeed is performed. -StatusOr> ReplayComputation( - const SessionModule& module, Client* client, const Options& opts) { - TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module)); +// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided. +// If generate_fake_infeed is true, the required infeed shape is derived from +// the computation and then used to provide a fake infeed shape. +// +// If neither generate_fake_infeed is true nor a fake_infeed_shape is provided, +// no infeed is performed. +StatusOr ReplayComputation(const HloSnapshot& module, + LocalClient* client, const Options& opts) { + XlaComputation computation(module.hlo().hlo_module()); - std::vector> arguments; + // Build the `argument_ptrs` vector, which contains ShapedBuffer*s to our + // arguments. This is a bit involved, because we may have to convert from + // GlobalData to ShapedBuffer*, and we have to manage the lifetime of all our + // objects. + std::vector scoped_shaped_buffer_arguments; + std::vector> global_data_arguments; + std::vector argument_ptrs; if (opts.use_fake_data) { - arguments = MakeFakeArgumentsOrDie(computation, client); + global_data_arguments = MakeFakeArgumentsOrDie(computation, client); + for (const auto& data : global_data_arguments) { + argument_ptrs.push_back( + client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0) + .ValueOrDie()); + } } else { // use recorded data if available for (const auto& proto : module.arguments()) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, Literal::CreateFromProto(proto)); - TF_ASSIGN_OR_RETURN(std::unique_ptr data, - client->TransferToServer(*literal)); - arguments.push_back(std::move(data)); + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer data, + client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0)); + scoped_shaped_buffer_arguments.push_back(std::move(data)); + } + for (const auto& argument : scoped_shaped_buffer_arguments) { + argument_ptrs.push_back(&argument); } } + bool provide_infeed = false; + Shape infeed_shape; + if (!opts.fake_infeed_shape.empty()) { + StatusOr shape_status = + ShapeUtil::ParseShapeString(opts.fake_infeed_shape); + TF_CHECK_OK(shape_status.status()); + infeed_shape = std::move(shape_status).ValueOrDie(); + provide_infeed = true; + } else if (opts.generate_fake_infeed) { + for (const auto& comp : computation.proto().computations()) { + for (const auto& instruction : comp.instructions()) { + if (instruction.opcode() == HloOpcodeString(HloOpcode::kInfeed)) { + CHECK(!provide_infeed) + << "--generate_fake_infeed only works if the model has 0 or 1 " + "infeed ops, but this one has >= 2."; + provide_infeed = true; + infeed_shape = instruction.shape(); + LOG(INFO) << "Generating fake infeed shape for inferred shape: " + << ShapeUtil::HumanString(infeed_shape); + } + } + } + } // We only instantiate the thread pool if the user has requested that a - // concurrent infeed occur via the fake_infeed_shape. + // concurrent infeed occur via the fake_infeed_shape, or when + // --generate_fake_infeed is passed and there exists an infeed operation in + // the HloSnapshot. tensorflow::gtl::optional pool; - - if (!opts.fake_infeed_shape.empty()) { + std::unique_ptr data; + if (provide_infeed) { + data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie(); + } + auto transfer_infeed = [&data, client]() { + TF_CHECK_OK(client->TransferToInfeed(*data)); + }; + if (provide_infeed) { pool.emplace(tensorflow::Env::Default(), "infeed", /*num_threads=*/1); - pool->Schedule([opts, client]() { - StatusOr shape_status = - ShapeUtil::ParseShapeString(opts.fake_infeed_shape); - TF_CHECK_OK(shape_status.status()); - Shape shape = std::move(shape_status).ValueOrDie(); - StatusOr> data_status = MakeFakeLiteral(shape); - TF_CHECK_OK(data_status.status()); - std::unique_ptr data = std::move(data_status).ValueOrDie(); - while (true) { - TF_CHECK_OK(client->TransferToInfeed(*data)); - } + pool->Schedule([transfer_infeed]() { + // There may be several infeed buffers needed, however we don't know how + // many. If we proactively transfer too many infeed buffers, we may run + // out of memory. If we transfer too few infeed buffers, the program will + // hang. Therefore, we register a callback that is called when the infeed + // becomes empty, and in this callback we will transfer another fake + // infeed. + auto infeed_manager = xla::gpu::GetOrCreateInfeedManager(); + infeed_manager->RegisterOnEmptyCallback(transfer_infeed); + transfer_infeed(); }); } - std::vector execute_arguments; - execute_arguments.reserve(arguments.size()); - for (auto& argument : arguments) { - execute_arguments.push_back(argument.get()); + std::vector argument_layouts; + for (const auto& param : computation.proto().program_shape().parameters()) { + argument_layouts.push_back(¶m); } + std::unique_ptr executable = + client->Compile(computation, argument_layouts, ExecutableBuildOptions()) + .ValueOrDie(); // Run the computation num_runs times, and return the result from the last // execution. - std::unique_ptr result; + StreamExecutorMemoryAllocator allocator( + client->platform(), + {client->platform()->ExecutorForDevice(0).ValueOrDie()}); + tensorflow::gtl::optional result; for (int i = 0; i < opts.num_runs; ++i) { ExecutionProfile profile; - if (opts.print_result) { - TF_ASSIGN_OR_RETURN(result, client->ExecuteAndTransfer( - computation, execute_arguments, - /*execution_options=*/nullptr, &profile)); - } else { - // If we're not printing the result, execute the computation but don't - // bother retrieving the result. This can be a significant speedup. - TF_RETURN_IF_ERROR(client - ->Execute(computation, execute_arguments, - /*execution_options=*/nullptr, &profile) - .status()); - } + ExecutableRunOptions run_options; + run_options.set_execution_profile(&profile); + run_options.set_allocator(&allocator); + + TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options)); LOG(INFO) << "Execution took " << static_cast(profile.compute_time_ns()) / 1e9 << "s"; } - return std::move(result); + // Check that --num_runs > 0, otherwise *result below will fail with an + // unhelpful error (because the loop didn't run any iterations). + CHECK_GT(opts.num_runs, 0) << "--num_runs must be > 0"; + TF_ASSIGN_OR_RETURN(std::unique_ptr result_literal, + client->ShapedBufferToLiteral(*result)); + return std::move(*result_literal); } -int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { - Client* client = ClientLibrary::LocalClientOrDie(); +StatusOr ParseInputFile(const string& filename, + const Options& opts) { tensorflow::Env* env = tensorflow::Env::Default(); + HloSnapshot snapshot; + if (tensorflow::ReadBinaryProto(env, filename, &snapshot).ok()) { + return snapshot; + } + CHECK(opts.use_fake_data) + << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto " + "and textual HLO don't carry real data."; + fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", + filename.c_str()); + + if (tensorflow::ReadBinaryProto(env, filename, snapshot.mutable_hlo()).ok()) { + return snapshot; + } + fprintf(stderr, "%s: is not HloProto. Trying HLO text.\n", filename.c_str()); + string contents; + TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents)); + StatusOr> module = ParseHloString(contents); + if (module.ok()) { + *snapshot.mutable_hlo()->mutable_hlo_module() = + module.ValueOrDie()->ToProto(); + return snapshot; + } + fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n", + filename.c_str()); + return InvalidArgument("Could not parse %s.", filename.c_str()); +} + +int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { + LocalClient* client = ClientLibrary::LocalClientOrDie(); int exit_status = EXIT_SUCCESS; for (char* arg : args) { - SessionModule module; - TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module)); - StatusOr> result_status = - ReplayComputation(module, client, opts); + StatusOr maybe_snapshot = ParseInputFile(arg, opts); + if (!maybe_snapshot.ok()) { + continue; + } + HloSnapshot snapshot = std::move(maybe_snapshot).ValueOrDie(); + StatusOr result_status = ReplayComputation(snapshot, client, opts); if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", arg, result_status.status().ToString().c_str()); @@ -157,16 +246,17 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { continue; } - std::unique_ptr result = result_status.ConsumeValueOrDie(); - if (result != nullptr) { - fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(), - ShapeUtil::HumanString(result->shape()).c_str(), - result->ToString().c_str()); - if (module.has_result()) { + if (opts.print_result) { + Literal result = std::move(result_status).ValueOrDie(); + fprintf(stdout, "%s: %s :: %s:%s\n", arg, + snapshot.hlo().hlo_module().name().c_str(), + ShapeUtil::HumanString(result.shape()).c_str(), + result.ToString().c_str()); + if (snapshot.has_result()) { std::unique_ptr literal = - Literal::CreateFromProto(module.result()).ConsumeValueOrDie(); + Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); fprintf(stdout, "was %s:%s\n", - ShapeUtil::HumanString(module.result().shape()).c_str(), + ShapeUtil::HumanString(snapshot.result().shape()).c_str(), literal->ToString().c_str()); } } @@ -191,6 +281,9 @@ int main(int argc, char** argv) { "Number of times to run each computation"), tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, "Shape of fake data to construct for (infinite) infeed"), + tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed, + "Whether a fake infeed shape should be generated " + "derived from the computation"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc index 1f3340cbc6afa9bda8bf639d01b8185968f79a4d..4e53fafcc97ff53afc5713e7ed8ee5222fac316b 100644 --- a/tensorflow/compiler/xla/tools/show_signature.cc +++ b/tensorflow/compiler/xla/tools/show_signature.cc @@ -18,7 +18,7 @@ limitations under the License. // Shows the signature (ProgramShape) of binary snapshot proto(s) on the command // line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // The output format is: @@ -31,9 +31,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -49,13 +48,14 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + auto computation = client->LoadSnapshot(module).ConsumeValueOrDie(); std::unique_ptr shape = client->GetComputationShape(computation).ConsumeValueOrDie(); - fprintf(stdout, "%s: %s :: %s\n", arg, module.entry().name().c_str(), + fprintf(stdout, "%s: %s :: %s\n", arg, + module.hlo().hlo_module().name().c_str(), ShapeUtil::HumanString(*shape).c_str()); } } diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index 9fa4297523bab0748863479be52dff1b7b523a8b..b645acb700b0f168112a40c9c72b4669435f717d 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -46,4 +46,10 @@ using ::Eigen::half; } // namespace xla +// Alias namespace ::stream_executor as ::xla::se. +namespace stream_executor {} +namespace xla { +namespace se = ::stream_executor; +} // namespace xla + #endif // TENSORFLOW_COMPILER_XLA_TYPES_H_ diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 1f0c626bbb2d64ef4e67c9ec51485ae96ae73d04..e43498e381b8e63543e2ddda08ca7c0df91817e4 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" -#include #include #include @@ -244,8 +243,8 @@ string HumanReadableNumOps(double flops, double nanoseconds, static_cast(nano_flops * 1e9)); tensorflow::StringPiece sp(throughput); // Use the more common "G(FLOPS)", rather than "B(FLOPS)" - if (sp.ends_with("B") || // Ends in 'B', ignoring case - sp.ends_with("b")) { + if (tensorflow::str_util::EndsWith(sp, "B") || // Ends in 'B', ignoring case + tensorflow::str_util::EndsWith(sp, "b")) { *throughput.rbegin() = 'G'; } throughput += tensorflow::strings::StrCat(op_prefix, "OP/s"); @@ -292,7 +291,8 @@ void LogLines(int sev, tensorflow::StringPiece text, const char* fname, } int64 Product(tensorflow::gtl::ArraySlice xs) { - return std::accumulate(xs.begin(), xs.end(), 1, std::multiplies()); + return std::accumulate(xs.begin(), xs.end(), static_cast(1), + std::multiplies()); } std::vector> CommonFactors( diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 46ec7af54290f40dfac1e4627801eab4dabb8aa5..b4f45cc972d3d397ddff8e8d9163d1fef387392f 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include #include #include "tensorflow/compiler/xla/status.h" @@ -217,6 +218,12 @@ Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); // Passed-varargs variant of the InvalidArgument factory above. Status InvalidArgumentV(const char* format, va_list args); +template +Status InvalidArgumentStrCat(Args&&... concat) { + return InvalidArgument( + "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); +} + template Status UnimplementedStrCat(Args&&... concat) { return Unimplemented( @@ -427,29 +434,37 @@ std::vector> CommonFactors( string SanitizeFileName(string file_name); template -bool c_all_of(Container container, Predicate predicate) { - return std::all_of(std::begin(container), std::end(container), predicate); +bool c_all_of(const Container& container, Predicate&& predicate) { + return std::all_of(std::begin(container), std::end(container), + std::forward(predicate)); +} + +template +bool c_any_of(const Container& container, Predicate&& predicate) { + return std::any_of(std::begin(container), std::end(container), + std::forward(predicate)); } template -OutputIterator c_transform(InputContainer input_container, +OutputIterator c_transform(const InputContainer& input_container, OutputIterator output_iterator, - UnaryOperation unary_op) { + UnaryOperation&& unary_op) { return std::transform(std::begin(input_container), std::end(input_container), - output_iterator, unary_op); + output_iterator, + std::forward(unary_op)); } template -OutputIterator c_copy_if(InputContainer input_container, +OutputIterator c_copy_if(const InputContainer& input_container, OutputIterator output_iterator, - UnaryPredicate predicate) { + UnaryPredicate&& predicate) { return std::copy_if(std::begin(input_container), std::end(input_container), - output_iterator, predicate); + output_iterator, std::forward(predicate)); } template -OutputIterator c_copy(InputContainer input_container, +OutputIterator c_copy(const InputContainer& input_container, OutputIterator output_iterator) { return std::copy(std::begin(input_container), std::end(input_container), output_iterator); @@ -461,12 +476,13 @@ void c_sort(InputContainer& input_container) { } template -void c_sort(InputContainer& input_container, Comparator comparator) { - std::sort(std::begin(input_container), std::end(input_container), comparator); +void c_sort(InputContainer& input_container, Comparator&& comparator) { + std::sort(std::begin(input_container), std::end(input_container), + std::forward(comparator)); } template -bool c_binary_search(Sequence& sequence, T&& value) { +bool c_binary_search(const Sequence& sequence, T&& value) { return std::binary_search(std::begin(sequence), std::end(sequence), std::forward(value)); } @@ -476,10 +492,81 @@ bool c_is_sorted(const C& c) { return std::is_sorted(std::begin(c), std::end(c)); } +template +bool c_is_sorted(const C& c, Compare&& comp) { + return std::is_sorted(std::begin(c), std::end(c), + std::forward(comp)); +} + template auto c_adjacent_find(const C& c) -> decltype(std::begin(c)) { return std::adjacent_find(std::begin(c), std::end(c)); } + +template +auto c_find_if(const C& c, Pred&& pred) -> decltype(std::begin(c)) { + return std::find_if(std::begin(c), std::end(c), std::forward(pred)); +} + +template +auto c_find(const C& c, Value&& value) -> decltype(std::begin(c)) { + return std::find(std::begin(c), std::end(c), std::forward(value)); +} + +template +void c_reverse(Sequence& sequence) { + std::reverse(std::begin(sequence), std::end(sequence)); +} + +template +typename std::decay::type c_accumulate(const Sequence& sequence, T&& init, + BinaryOp&& binary_op) { + return std::accumulate(std::begin(sequence), std::end(sequence), + std::forward(init), + std::forward(binary_op)); +} + +template +typename std::iterator_traits< + decltype(std::begin(std::declval()))>::difference_type +c_count_if(const C& c, Pred&& pred) { + return std::count_if(std::begin(c), std::end(c), std::forward(pred)); +} + +template +int64 FindIndex(const C& c, Value&& value) { + auto it = c_find(c, std::forward(value)); + return std::distance(c.begin(), it); +} + +template +void InsertAt(C* c, int64 index, Value&& value) { + c->insert(c->begin() + index, std::forward(value)); +} + +template +void EraseAt(C* c, int64 index) { + c->erase(c->begin() + index); +} + +// Returns true if `x` fits in 32-bits. +template +bool IsInt32(T x) { + // Following conversion rules: "the value is unchanged if it can be + // represented in the destination type (and bit-field width); otherwise, the + // value is implementation-defined." + return static_cast(x) == x; +} + +template +Status EraseElementFromVector(std::vector* container, const T& value) { + // c_find returns a const_iterator which does not seem to work on gcc 4.8.4, + // and this breaks the ubuntu/xla_gpu build bot. + auto it = std::find(container->begin(), container->end(), value); + TF_RET_CHECK(it != container->end()); + container->erase(it); + return Status::OK(); +} } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 93284b80f9e1f82c4b18dc7388754d5c01a7740c..f11123ca24849af1d9c4fd49809a986eb7202bd5 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -199,6 +199,9 @@ bool IsInactiveWindowDimension(const Window& window, int64 logical_dim) { int64 DilatedBound(int64 bound, int64 dilation) { CHECK_GE(bound, 0); CHECK_GE(dilation, 1); + if (bound == 0) { + return 0; + } // Suppose the array has three entries 123 and the dilation factor is 4. Then // the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except @@ -212,7 +215,7 @@ int64 StridedBound(int64 bound, int64 window_size, int64 stride) { CHECK_GE(bound, 0); CHECK_GE(stride, 1); - if (window_size > bound) { + if (bound == 0 || window_size > bound) { return 0; } diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 6b136d333bbf079efd314833f46fe3b98743fbac..1439f1bcc5cec39203a7cb4b1f8604e7349382c6 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -6,7 +6,9 @@ load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") # xla_proto_library() is a convenience wrapper around cc_proto_library. -def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0): +def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0, **kwargs): + if kwargs.get('use_grpc_plugin'): + kwargs['use_grpc_namespace'] = True cc_proto_library(name=name, srcs=srcs, deps=deps, @@ -16,6 +18,13 @@ def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0): ), protoc="@protobuf_archive//:protoc", testonly=testonly, - visibility=visibility,) + visibility=visibility, + **kwargs) + +def xla_py_grpc_library(**kwargs): + # Note: we don't currently define any special targets for Python GRPC in OSS. + _ignore = kwargs + pass + ORC_JIT_MEMORY_MAPPER_TARGETS = [] diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 56162ab44e2e0e3e4478fe631888f243332dc1d8..6f07e4606bef015214f2c564515c8258a906205b 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -16,7 +16,7 @@ limitations under the License. syntax = "proto3"; import "tensorflow/compiler/xla/xla_data.proto"; -import "tensorflow/compiler/xla/service/session.proto"; +import "tensorflow/compiler/xla/service/hlo.proto"; package xla; @@ -188,6 +188,12 @@ message DebugOptions { // directory. string xla_dump_per_pass_hlo_proto_to = 96; + // Generate calls to MKL-DNN in the CPU backend. + bool xla_cpu_use_mkl_dnn = 97; + + // Maximum kernel unroll factor for the GPU backend. + int32 xla_gpu_max_kernel_unroll_factor = 98; + // Extra options to pass to the compilation backend; specific interpretation // of these values is left to the backend. map xla_backend_extra_options = 500; @@ -219,22 +225,6 @@ message ExecutionOptions { repeated DeviceHandle device_handles = 5; } -message SnapshotComputationRequest { - ComputationHandle computation = 1; -} - -message SnapshotComputationResponse { - SessionModule module = 1; -} - -message LoadComputationSnapshotRequest { - SessionModule module = 1; -} - -message LoadComputationSnapshotResponse { - ComputationHandle computation = 1; -} - message GetDeviceHandlesRequest { int64 device_count = 1; } @@ -293,8 +283,8 @@ message ResetDeviceRequest { message ResetDeviceResponse { } -message ComputationStatsRequest { - ComputationHandle computation = 1; +message ComputationGraphStatsRequest { + HloModuleProto computation = 1; DebugOptions debug_options = 2; } @@ -302,14 +292,6 @@ message ComputationStatsResponse { ComputationStats stats = 1; } -message ComputationRequest { - string name = 1; -} - -message ComputationResponse { - ComputationHandle computation = 1; -} - message CreateChannelHandleRequest { } @@ -324,26 +306,16 @@ message UnregisterRequest { message UnregisterResponse { } -message SetReturnValueRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; -} - -message SetReturnValueResponse { -} - -message ExecuteRequest { - reserved 3, 4; - - ComputationHandle computation = 1; +message ExecuteGraphRequest { + HloModuleProto computation = 1; repeated GlobalDataHandle arguments = 2; // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 5; + ExecutionOptions execution_options = 3; } -message ExecuteParallelRequest { - repeated ExecuteRequest requests = 1; +message ExecuteGraphParallelRequest { + repeated ExecuteGraphRequest requests = 1; } message ExecuteResponse { @@ -355,21 +327,6 @@ message ExecuteParallelResponse { repeated ExecuteResponse responses = 1; } -message ExecuteAsyncRequest { - reserved 3, 4; - - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; - - // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 6; -} - -message ExecuteAsyncResponse { - // A handle to the execution launched asynchronously. - ExecutionHandle execution = 1; -} - message WaitForExecutionRequest { ExecutionHandle execution = 1; } @@ -379,26 +336,13 @@ message WaitForExecutionResponse { ExecutionProfile profile = 2; } -message IsConstantRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; - int64 num_parameters = 3; -} - -message IsConstantResponse { - bool is_constant = 1; -} - -message ComputeConstantRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; - Layout output_layout = 3; - repeated LiteralProto parameters = 4; +message ComputeConstantGraphRequest { + HloModuleProto computation = 1; + Layout output_layout = 2; } message ComputeConstantResponse { - // A LiteralProto is returned directly for this request, instead of a - // ComputationDataHandle. + // A LiteralProto is returned directly for this request. LiteralProto literal = 1; } @@ -440,14 +384,6 @@ message LoadDataResponse { int64 nanoseconds = 5; } -message SpecializeRequest { - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; -} - -message SpecializeResponse { -} - message GetShapeRequest { GlobalDataHandle data = 1; } @@ -456,14 +392,6 @@ message GetShapeResponse { Shape shape = 1; } -message GetComputationShapeRequest { - ComputationHandle computation = 1; -} - -message GetComputationShapeResponse { - ProgramShape program_shape = 1; -} - message UnpackRequest { GlobalDataHandle data = 1; } diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 28620c3b86349281573eaf57d2838bee1488d838..0af73e8a93060f4569ddef9697b89a6fa2b8674b 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -66,11 +66,16 @@ enum PrimitiveType { // in the dimensions field. TUPLE = 13; - // An opaque type used for passing context specific data to a custom - // operation. + // An opaque type used for passing context-specific data to a custom + // operation. Shapes of this primitive type will have empty dimensions and + // tuple_shapes fields. OPAQUE = 14; - // Next = 17 + // A token type threaded between side-effecting operations. Shapes of this + // primitive type will have empty dimensions and tuple_shapes fields. + TOKEN = 17; + + // Next = 18 } // Describes the value held inside padding elements. @@ -134,6 +139,8 @@ enum Format { // example, Convert) are ignored. // // See the XLA documentation for more information on shapes and layouts. +// +// LINT.IfChange message Layout { // The method used to store the data in memory. The format determines which of // the other fields are used by the layout. @@ -159,9 +166,12 @@ message Layout { // memory. This field must be unset unless the format is SPARSE. int64 max_sparse_elements = 5; - // Important: if any field is added, be sure to modify ShapeUtil::Equal() - // appropriately to account for the new field. + // Important: if any field is added, be sure to modify ShapeUtil::Equal() and + // LayoutUtil::Hash appropriately to account for the new field. } +// LINT.ThenChange( \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc) // A shape describes the number of dimensions in the array, the size of each // dimension, and the primitive component type. @@ -170,6 +180,8 @@ message Layout { // defined. // // See the XLA documentation for more information on shapes and layouts. +// +// LINT.IfChange message Shape { reserved 1; reserved "rank"; @@ -190,9 +202,12 @@ message Shape { // The layout used to back this shape. Layout layout = 5; - // Important: if any field is added, be sure to modify ShapeUtil::Equal() and - // ShapeUtil::Compatible() appropriately to account for the new field. + // Important: if any field is added, be sure to modify ShapeUtil::Equal(), + // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for + // the new field. } +// LINT.ThenChange( \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc) // Shape of the parameters and output of a computation (like a traditional // function signature). @@ -261,12 +276,6 @@ message ExecutionProfile { int64 compute_and_transfer_time_ns = 5; } -// Handle given to a user that represents a computation that the user builds up -// before execution. -message ComputationHandle { - int64 handle = 1; -} - // Handle given to a user that represents an execution that the user launched // asynchronously on the device. message ExecutionHandle { @@ -280,13 +289,6 @@ message GlobalDataHandle { int64 handle = 1; } -// Handle given to a user that represents a data result in a computation. -// This is used to pass to subsequent computations that depends upon the data as -// an operand. -message ComputationDataHandle { - int64 handle = 1; -} - // Handle given to a user that represents a replicated virtual device. Each // replicated device represents N physical devices for execution where N is the // number of replicas. @@ -355,17 +357,19 @@ message WindowDimension { // positions of the window in this dimension. int64 stride = 2; - // If positive, means the amount of padding with zeroes to add to the base - // area at the low end of this dimension; if negative, its negative means the - // number of elements removed from the low end of this dimension. For example, - // in the horizontal dimension of a rectangle, this would be the number of - // zeroes to pad on the left, given that indices increase when going right. + // If positive, means the amount of padding to add to the base area at the low + // end of this dimension; if negative, its negative means the number of + // elements removed from the low end of this dimension. For example, in the + // horizontal dimension of a rectangle, this would be the number of padding + // values to pad on the left, given that indices increase when going right. + // The actual padding value depends upon the context. Convolution pads with + // zeros. ReduceWindow and SelectAndScatter pads with the reduce function's + // init value. int64 padding_low = 3; - // As padding_low, but on the high end of this dimension. For - // example, in the horizontal dimension of a rectangle, this would - // be the number of zeroes to pad on the right, given that indices - // increase when going right. + // As padding_low, but on the high end of this dimension. For example, in the + // horizontal dimension of a rectangle, this would be the number of values to + // pad on the right, given that indices increase when going right. int64 padding_high = 4; // Dilation factor of the sliding window in this dimension. A dilation factor @@ -418,44 +422,10 @@ message GatherDimensionNumbers { // transforms the gather index looked up from the gather_indices tensor into // the starting index in the input space. repeated int64 gather_dims_to_operand_dims = 3; -} - -// Operation requests that are all collected as a tagged union with a oneof -// field in OpRequest. - -message ConstantRequest { - LiteralProto literal = 2; -} -message GetTupleElementRequest { - ComputationDataHandle operand = 2; - int64 index = 3; -} - -message SliceRequest { - ComputationDataHandle operand = 2; - repeated int64 start_indices = 3; - repeated int64 limit_indices = 4; - repeated int64 strides = 5; -} - -message DynamicSliceRequest { - // Operand from which to slice at dynamic 'start_indices'. - ComputationDataHandle operand = 2; - // Dynamically computed 'start_indices' for slice operation. - ComputationDataHandle start_indices = 3; - // Slice sizes for each dimension (note that indices calculations are computed - // modulo dimension sizes to avoid out-of-bound array accesses). - repeated int64 slice_sizes = 4; -} - -message DynamicUpdateSliceRequest { - // Operand on which slice 'update' is to be applied. - ComputationDataHandle operand = 2; - // The slice update to apply to 'operand'. - ComputationDataHandle update = 3; - // Dynamically computed start indices for the update slice operation. - ComputationDataHandle start_indices = 4; + // The dimension in the gather_indices input that contains the starting + // indices. + int64 index_vector_dim = 4; } message ConvolutionDimensionNumbers { @@ -495,13 +465,6 @@ message ConvolutionDimensionNumbers { // Next = 13 }; -message ConvolveRequest { - ComputationDataHandle lhs = 2; - ComputationDataHandle rhs = 3; // This is the filter/kernel. - Window window = 4; // Describes the filter/kernel. - ConvolutionDimensionNumbers dimension_numbers = 5; -} - enum FftType { FFT = 0; // Forward FFT; complex in, complex out. IFFT = 1; // Inverse FFT; complex in, complex out. @@ -510,56 +473,6 @@ enum FftType { // fft_length real out } -message FftRequest { - FftType fft_type = 1; - repeated int64 fft_length = 2; // Multivalent for higher-order FFT. - ComputationDataHandle operand = 3; -} - -message InfeedRequest { - // The shape of the data returned by reading the device's infeed buffer. - Shape shape = 2; - - // Additional infeed configuration for the backend. - bytes config = 3; -} - -message OutfeedRequest { - // The shape of the data returned by reading the device's outfeed buffer. - Shape shape = 1; - - // Operand to the Outfeed. Supports tuple. - ComputationDataHandle operand = 2; - - // Backend-specific information for how to perform the outfeed. - bytes outfeed_config = 3; -} - -message CallRequest { - ComputationHandle to_apply = 2; - repeated ComputationDataHandle operands = 3; -} - -message CustomCallRequest { - string call_target_name = 2; - repeated ComputationDataHandle operands = 3; - Shape shape = 4; -} - -message HostComputeRequest { - // Operand to the HostCompute. Supports tuple. - repeated ComputationDataHandle operands = 1; - - // Name used to identify HostSend/Recv channels. - string channel_name = 2; - - // Cost estimate in nanoseconds. - int64 cost_estimate_ns = 3; - - // The shape of any data returned by host. - Shape shape = 4; -} - message DotDimensionNumbers { // The dimension numbers that represent the 'lhs' contracting dimensions. repeated int64 lhs_contracting_dimensions = 1; @@ -571,288 +484,6 @@ message DotDimensionNumbers { repeated int64 rhs_batch_dimensions = 4; }; -message DotRequest { - ComputationDataHandle lhs = 2; - ComputationDataHandle rhs = 3; - DotDimensionNumbers dimension_numbers = 4; -} - -message MapRequest { - repeated ComputationDataHandle operands = 2; - ComputationHandle to_apply = 3; - repeated ComputationDataHandle static_operands = 4; - // The dimensions over which to map. - // Example mapping a Dot operation along the batch dimension 0: - // operand0.shape = [2, 2, 2], operand1.shape = [2,2,3] - // Map({operand0, operand1}, Dot, {0}) - repeated int64 dimensions = 5; -} - -message ReduceRequest { - // Operand to the reduction. - ComputationDataHandle operand = 2; - - // Initial value for the reduction. This must be consistent with the result - // shape of to_apply. - ComputationDataHandle init_value = 3; - - // The dimensions to reduce over. - repeated int64 dimensions = 4; - - // The computation to apply in the reduction. - ComputationHandle to_apply = 5; -} - -message ReduceWindowRequest { - ComputationDataHandle operand = 2; - ComputationDataHandle init_value = 3; - Window window = 4; - ComputationHandle to_apply = 5; -} - -message BatchNormTrainingRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle offset = 3; - float epsilon = 4; - int64 feature_index = 5; -} - -message BatchNormInferenceRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle offset = 3; - ComputationDataHandle mean = 4; - ComputationDataHandle variance = 5; - float epsilon = 6; - int64 feature_index = 7; -} - -message BatchNormGradRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle mean = 3; - ComputationDataHandle variance = 4; - ComputationDataHandle grad_output = 5; - float epsilon = 6; - int64 feature_index = 7; -} - -message CrossReplicaSumRequest { - ComputationDataHandle operand = 2; -} - -message SelectAndScatterRequest { - // Operand array on which the windows slide. - ComputationDataHandle operand = 2; - - // Source array for the data to scatter. - ComputationDataHandle source = 3; - - // Initial scalar value for each element in the output. - ComputationDataHandle init_value = 4; - - // Window configuration. - Window window = 5; - - // Binary function used to select an element from each window. - ComputationHandle select = 6; - - // Binary function used to combine each scattered value from source with the - // current output value at the selected location. - ComputationHandle scatter = 7; -} - -message ReverseRequest { - ComputationDataHandle operand = 2; - repeated int64 dimensions = 3; -} - -message BroadcastRequest { - ComputationDataHandle operand = 2; - repeated int64 broadcast_sizes = 3; -} - -message PadRequest { - ComputationDataHandle operand = 2; - ComputationDataHandle padding_value = 3; - PaddingConfig padding_config = 4; -} - -message ReshapeRequest { - ComputationDataHandle operand = 2; - - // The dimension order for collapse (from fastest-changing to slowest). - repeated int64 dimensions = 3; - - // The new dimension sizes (from dimension 0 to n-1). - repeated int64 new_sizes = 4; -} - -message TransposeRequest { - ComputationDataHandle operand = 2; - - // The permutation of the operand's dimensions (in the range 0 to n-1). - repeated int64 dimensions = 3; -} - -message ParameterRequest { - Shape shape = 2; - int64 parameter = 3; - string name = 4; -} - -message GetLocalShapeRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; -} - -message GetLocalShapeResponse { - Shape shape = 1; -} - -message TraceRequest { - string tag = 2; - ComputationDataHandle operand = 3; -} - -message ConvertRequest { - ComputationDataHandle operand = 2; - PrimitiveType new_element_type = 3; -} - -message ConcatenateRequest { - repeated ComputationDataHandle operands = 2; - // The dimension in which we concatenate; e.g. if you had dimension arrays of - // [4, 1] and [5, 1], you'd concatenate in dimension 0 to produce a [9, 1]. - // Attempting to concatenate those in dimension 1 would produce an error, as - // 4 != 5 (and there is no ragged array support). - int64 dimension = 3; -} - -message ConditionalRequest { - ComputationDataHandle predicate = 2; - ComputationDataHandle true_operand = 3; - ComputationHandle true_computation = 4; - ComputationDataHandle false_operand = 5; - ComputationHandle false_computation = 6; -} - -message WhileRequest { - ComputationHandle condition = 2; - ComputationHandle body = 3; - ComputationDataHandle init = 4; -} - -enum UnaryOperation { - UNOP_INVALID = 0; - - // Elementwise, logical negation on booleans and bitwise negation on ints. - UNOP_NOT = 1; - - // Elementwise, computes e^x. - UNOP_EXP = 2; - - // Elementwise, computes -x. - UNOP_NEGATE = 3; - - // Puts the elements in the operand into sorted order. - UNOP_SORT = 4; - - // Elementwise, computes tanh(x). - UNOP_TANH = 5; - - // Elementwise, computes the natural logarithm of x. - UNOP_LOG = 6; - - // Elementwise, computes the floor of x. - UNOP_FLOOR = 7; - - // Elementwise, computes the ceil of x. - UNOP_CEIL = 8; - - // Elementwise, computes the abs of x. - UNOP_ABS = 9; - - // Elementwise, computes the sign of x. - UNOP_SIGN = 10; - - // Elementwise, tests if values are finite (not NaN or inf) - UNOP_IS_FINITE = 11; - - // Elementwise, computes the cosine of x. - UNOP_COS = 12; - - // Elementwise, computes the sine of x. - UNOP_SIN = 13; - - // Elementwise, rounds x to nearest integral value, rounding half-way cases - // away from zero. - UNOP_ROUND_NEAREST_AFZ = 14; - - // Elementwise, extract real component of complex x. - UNOP_REAL = 15; - - // Elementwise, extract real component of complex x. - UNOP_IMAG = 16; -} - -message UnaryOpRequest { - UnaryOperation unop = 2; - ComputationDataHandle operand = 3; -} - -enum BinaryOperation { - BINOP_INVALID = 0; - - // Arithmetic operations. - BINOP_ADD = 1; - BINOP_DIV = 2; - BINOP_MUL = 3; - BINOP_SUB = 4; - - // Comparison operators. - BINOP_EQ = 5; - BINOP_GE = 6; - BINOP_GT = 7; - BINOP_LE = 8; - BINOP_LT = 9; - BINOP_NE = 10; - - // Element-wise maximum. - BINOP_MAX = 14; - - // Element-wise minimum. - BINOP_MIN = 15; - - // Raises the left-hand-side to the right-hand-side power. - BINOP_POW = 16; - - // Remainder operation. - BINOP_REM = 17; - - // Element-wise, logical operators on booleans and bitwise operators on ints. - BINOP_AND = 18; - BINOP_OR = 19; - - BINOP_SHIFT_LEFT = 20; - BINOP_SHIFT_RIGHT_ARITHMETIC = 21; - BINOP_SHIFT_RIGHT_LOGICAL = 22; - - // Complex from real, imag. - BINOP_COMPLEX = 23; - - // Computes the 4-quadrant arctangent of the y, x input arguments. - BINOP_ATAN2 = 24; -} - -message BinaryOpRequest { - BinaryOperation binop = 2; - ComputationDataHandle lhs = 3; - ComputationDataHandle rhs = 4; - repeated int64 broadcast_dimensions = 5; -} - enum RandomDistribution { RNG_INVALID = 0; @@ -867,67 +498,6 @@ enum RandomDistribution { // Next: 4 } -message RngRequest { - RandomDistribution distribution = 2; - repeated ComputationDataHandle parameter = 3; - Shape shape = 4; -} - -enum TernaryOperation { - TRIOP_INVALID = 0; - - // Given a predicate and two operands, selects operand0 if the predicate is - // true and operand1 if the predicate is false. - TRIOP_SELECT = 1; - - // Given a min, max and an operand returns the operand if between min and max, - // else returns min if operand is less than min or max if operand is greater - // than max. - TRIOP_CLAMP = 3; -} - -message TernaryOpRequest { - TernaryOperation triop = 2; - ComputationDataHandle lhs = 3; - ComputationDataHandle rhs = 4; - ComputationDataHandle ehs = 5; -} - -enum VariadicOperation { - VAROP_INVALID = 0; - - // Creates a tuple from its operands. - VAROP_TUPLE = 1; -} - -message VariadicOpRequest { - VariadicOperation varop = 2; - repeated ComputationDataHandle operands = 3; -} - -message ReducePrecisionRequest { - ComputationDataHandle operand = 1; - int32 exponent_bits = 2; - int32 mantissa_bits = 3; -} - -message SendRequest { - ComputationDataHandle operand = 1; - ChannelHandle channel_handle = 2; -} - -message RecvRequest { - Shape shape = 1; - ChannelHandle channel_handle = 2; -} - -message GatherRequest { - ComputationDataHandle input = 1; - ComputationDataHandle gather_indices = 2; - GatherDimensionNumbers dimension_numbers = 3; - repeated int64 window_bounds = 4; -} - message OpSharding { enum Type { // This sharding is replicated across all devices (implies maximal, @@ -958,59 +528,3 @@ message OpSharding { // to. repeated OpSharding tuple_shardings = 5; } - -message OpRequest { - ComputationHandle computation = 1; - OpMetadata metadata = 33; - OpSharding sharding = 40; - - oneof op { - BinaryOpRequest binary_op_request = 2; - BroadcastRequest broadcast_request = 3; - CallRequest call_request = 4; - ConcatenateRequest concatenate_request = 5; - ConstantRequest constant_request = 6; - ConvertRequest convert_request = 7; - ConvolveRequest convolve_request = 8; - CrossReplicaSumRequest cross_replica_sum_request = 9; - CustomCallRequest custom_call_request = 10; - DotRequest dot_request = 43; - DynamicSliceRequest dynamic_slice_request = 11; - DynamicUpdateSliceRequest dynamic_update_slice_request = 12; - GetTupleElementRequest get_tuple_element_request = 13; - InfeedRequest infeed_request = 14; - MapRequest map_request = 15; - PadRequest pad_request = 16; - ParameterRequest parameter_request = 17; - ReducePrecisionRequest reduce_precision_request = 36; - ReduceRequest reduce_request = 18; - ReduceWindowRequest reduce_window_request = 19; - ReshapeRequest reshape_request = 20; - ReverseRequest reverse_request = 21; - RngRequest rng_request = 22; - SelectAndScatterRequest select_and_scatter_request = 23; - SliceRequest slice_request = 24; - TernaryOpRequest ternary_op_request = 25; - TraceRequest trace_request = 26; - TransposeRequest transpose_request = 34; - UnaryOpRequest unary_op_request = 27; - VariadicOpRequest variadic_op_request = 28; - WhileRequest while_request = 29; - SendRequest send_request = 30; - RecvRequest recv_request = 31; - OutfeedRequest outfeed_request = 32; - BatchNormTrainingRequest batch_norm_training_request = 35; - BatchNormGradRequest batch_norm_grad_request = 37; - BatchNormInferenceRequest batch_norm_inference_request = 38; - FftRequest fft_request = 41; - ConvertRequest bitcast_convert_request = 42; - ConditionalRequest conditional_request = 44; - HostComputeRequest host_compute_request = 45; - GatherRequest gather_request = 46; - // Next: 47 - } -} - -message OpResponse { - ComputationDataHandle output = 1; -} diff --git a/tensorflow/compiler/xla/xlalogo.png b/tensorflow/compiler/xla/xlalogo.png new file mode 100644 index 0000000000000000000000000000000000000000..7a0a295953d0c47b23718197dcbab1677b337455 Binary files /dev/null and b/tensorflow/compiler/xla/xlalogo.png differ diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index bab37e8906e5c648acdc1556da7e5f4601776ff5..50b1ae5cc3cba2d6ac89c4415a3419ffdf7aec93 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -8,6 +8,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load("//third_party/mpi:mpi.bzl", "if_mpi") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") +load("//tensorflow:tensorflow.bzl", "if_not_windows") py_library( name = "contrib_py", @@ -24,22 +25,26 @@ py_library( "//tensorflow/contrib/batching:batch_py", "//tensorflow/contrib/bayesflow:bayesflow_py", "//tensorflow/contrib/boosted_trees:init_py", + "//tensorflow/contrib/checkpoint/python:checkpoint", "//tensorflow/contrib/cloud:cloud_py", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", - "//tensorflow/contrib/coder:coder_ops_py", + "//tensorflow/contrib/coder:coder_py", "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/contrib/autograph", + "//tensorflow/contrib/constrained_optimization", + "//tensorflow/contrib/control_flow", "//tensorflow/contrib/copy_graph:copy_graph_py", "//tensorflow/contrib/crf:crf_py", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", "//tensorflow/contrib/data", "//tensorflow/contrib/deprecated:deprecated_py", + "//tensorflow/contrib/distribute:distribute", "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/eager/python:tfe", "//tensorflow/contrib/estimator:estimator_py", "//tensorflow/contrib/factorization:factorization_py", "//tensorflow/contrib/feature_column:feature_column_py", - "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/fused_conv:fused_conv_py", "//tensorflow/contrib/gan", @@ -51,7 +56,6 @@ py_library( "//tensorflow/contrib/image:single_image_random_dot_stereograms_py", "//tensorflow/contrib/input_pipeline:input_pipeline_py", "//tensorflow/contrib/integrate:integrate_py", - "//tensorflow/contrib/kafka", "//tensorflow/contrib/keras", "//tensorflow/contrib/kernel_methods", "//tensorflow/contrib/kfac", @@ -63,28 +67,31 @@ py_library( "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", - "//tensorflow/contrib/lite/python:lite", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/contrib/losses:losses_py", "//tensorflow/contrib/losses:metric_learning_py", "//tensorflow/contrib/memory_stats:memory_stats_py", "//tensorflow/contrib/meta_graph_transform", "//tensorflow/contrib/metrics:metrics_py", + "//tensorflow/contrib/mixed_precision:mixed_precision", "//tensorflow/contrib/model_pruning", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/opt:opt_py", + "//tensorflow/contrib/optimizer_v2:optimizer_v2_py", "//tensorflow/contrib/periodic_resample:init_py", "//tensorflow/contrib/predictor", + "//tensorflow/contrib/proto", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", - "//tensorflow/contrib/py2tf", "//tensorflow/contrib/receptive_field:receptive_field_py", + "//tensorflow/contrib/recurrent:recurrent_py", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py", "//tensorflow/contrib/resampler:resampler_py", "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/contrib/rpc", "//tensorflow/contrib/saved_model:saved_model_py", "//tensorflow/contrib/seq2seq:seq2seq_py", "//tensorflow/contrib/signal:signal_py", @@ -110,6 +117,15 @@ py_library( "//tensorflow/python:util", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([ "//tensorflow/contrib/tensorrt:init_py", + ]) + select({ + "//tensorflow:with_kafka_support_windows_override": [], + "//tensorflow:with_kafka_support": [ + "//tensorflow/contrib/kafka", + ], + "//conditions:default": [], + }) + if_not_windows([ + "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", + "//tensorflow/contrib/lite/python:lite", # unix dependency, need to fix code ]), ) @@ -119,7 +135,6 @@ cc_library( deps = [ "//tensorflow/contrib/boosted_trees:boosted_trees_kernels", "//tensorflow/contrib/coder:all_kernels", - "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_kernels", "//tensorflow/contrib/data/kernels:dataset_kernels", "//tensorflow/contrib/factorization/kernels:all_kernels", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", @@ -133,7 +148,13 @@ cc_library( "//tensorflow/contrib/text:all_kernels", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_cuda([ "//tensorflow/contrib/nccl:nccl_kernels", - ]), + ]) + select({ + "//tensorflow:with_kafka_support_windows_override": [], + "//tensorflow:with_kafka_support": [ + "//tensorflow/contrib/kafka:dataset_kernels", + ], + "//conditions:default": [], + }), ) cc_library( @@ -142,12 +163,10 @@ cc_library( deps = [ "//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib", "//tensorflow/contrib/coder:all_ops", - "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_ops_op_lib", "//tensorflow/contrib/data:dataset_ops_op_lib", "//tensorflow/contrib/factorization:all_ops", "//tensorflow/contrib/framework:all_ops", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib", - "//tensorflow/contrib/kafka:kafka_ops_op_lib", "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib", @@ -158,17 +177,11 @@ cc_library( "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", "//tensorflow/contrib/text:all_ops", "//tensorflow/contrib/tpu:all_ops", - ], -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", + ] + select({ + "//tensorflow:with_kafka_support_windows_override": [], + "//tensorflow:with_kafka_support": [ + "//tensorflow/contrib/kafka:dataset_ops_op_lib", ], - ), - visibility = ["//tensorflow:__subpackages__"], + "//conditions:default": [], + }), ) diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 4f6f539027b040de7554d09fe9118ff97aa006f8..ad8c40395c2cdcc5e4288e04bb2115bd3627cdc9 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -1,3 +1,4 @@ +# pylint: disable=g-import-not-at-top # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,18 +19,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + # Add projects here, they will show up under tf.contrib. from tensorflow.contrib import batching from tensorflow.contrib import bayesflow +from tensorflow.contrib import checkpoint from tensorflow.contrib import cloud from tensorflow.contrib import cluster_resolver from tensorflow.contrib import coder from tensorflow.contrib import compiler +from tensorflow.contrib import constrained_optimization +from tensorflow.contrib import control_flow from tensorflow.contrib import copy_graph from tensorflow.contrib import crf from tensorflow.contrib import cudnn_rnn from tensorflow.contrib import data from tensorflow.contrib import deprecated +from tensorflow.contrib import distribute from tensorflow.contrib import distributions from tensorflow.contrib import estimator from tensorflow.contrib import factorization @@ -54,17 +61,20 @@ from tensorflow.contrib import lookup from tensorflow.contrib import losses from tensorflow.contrib import memory_stats from tensorflow.contrib import metrics +from tensorflow.contrib import mixed_precision from tensorflow.contrib import model_pruning from tensorflow.contrib import nccl from tensorflow.contrib import nn from tensorflow.contrib import opt from tensorflow.contrib import periodic_resample from tensorflow.contrib import predictor +from tensorflow.contrib import proto from tensorflow.contrib import quantization from tensorflow.contrib import quantize from tensorflow.contrib import reduce_slice_ops from tensorflow.contrib import resampler from tensorflow.contrib import rnn +from tensorflow.contrib import rpc from tensorflow.contrib import saved_model from tensorflow.contrib import seq2seq from tensorflow.contrib import signal @@ -83,8 +93,11 @@ from tensorflow.contrib import tpu from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager -from tensorflow.contrib.lite.python import lite +if os.name != "nt": + from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.optimizer_v2 import optimizer_v2_symbols as optimizer_v2 from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field +from tensorflow.contrib.recurrent.python import recurrent_api as recurrent from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph from tensorflow.contrib.specs import python as specs from tensorflow.contrib.summary import summary @@ -92,6 +105,7 @@ from tensorflow.contrib.summary import summary from tensorflow.python.util.lazy_loader import LazyLoader ffmpeg = LazyLoader("ffmpeg", globals(), "tensorflow.contrib.ffmpeg") +del os del LazyLoader del absolute_import diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD index 8dff93b4f825277dcf0a64aa3b96bd809d36e1e9..881808a98bfd688c2efaa8beb5b8f11a2527fee8 100644 --- a/tensorflow/contrib/all_reduce/BUILD +++ b/tensorflow/contrib/all_reduce/BUILD @@ -11,6 +11,16 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "tf_py_test") +py_library( + name = "all_reduce_py", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":all_reduce", + "//tensorflow/python:util", + ], +) + py_library( name = "all_reduce", srcs = [ @@ -45,16 +55,3 @@ tf_py_test( "//tensorflow/python:state_ops", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "g3doc/sitemap.md", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/all_reduce/__init__.py b/tensorflow/contrib/all_reduce/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9824f4cfbf83d9b001a58cafe582226e96c076f --- /dev/null +++ b/tensorflow/contrib/all_reduce/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""All-reduce implementations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.all_reduce.python.all_reduce import * + +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'build_ring_all_reduce', + 'build_recursive_hd_all_reduce', + 'build_shuffle_all_reduce', + 'build_nccl_all_reduce', + 'build_nccl_then_ring', + 'build_nccl_then_recursive_hd', + 'build_nccl_then_shuffle', + 'build_shuffle_then_ring', + 'build_shuffle_then_shuffle' +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index 6658f0d9c13f6db17b25354cde2593d57f104f17..159d985db5c48f8fe1a26350255f8d8f68482473 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -18,10 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import math -import re from tensorflow.contrib import nccl +from tensorflow.python.framework import device as device_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -38,16 +39,15 @@ def _flatten_tensors(tensors): shape: the original shape of each element of input tensors Raises: - ValueError: tensors are empty or non-isomorphic. + ValueError: tensors are empty or non-isomorphic or have unknown shape. """ if not tensors: raise ValueError("tensors cannot be empty") shape = tensors[0].shape for tensor in tensors: shape = shape.merge_with(tensor.shape) - if shape.ndims is None: - raise ValueError("At least one of the tensors in 'tensors' must have " - "statically known rank.") + if not shape.is_fully_defined(): + raise ValueError("Tensors must have statically known shape.") if len(shape) != 1: reshaped = [] for t in tensors: @@ -660,21 +660,20 @@ def _split_by_task(devices, values): num_devices = len(devices) if num_devices != len(values): raise ValueError("len(devices) must equal len(values)") - pattern = re.compile(r"/task:(\d+)/") - per_task_devices = [] - per_task_values = [] + per_task_devices = collections.OrderedDict() + per_task_values = collections.OrderedDict() for d in range(num_devices): - m = pattern.search(devices[d]) - if m: - index = int(m.group(1)) - while index >= len(per_task_devices): - per_task_devices.append([]) - per_task_values.append([]) - per_task_devices[index].append(devices[d]) - per_task_values[index].append(values[d]) - else: + d_spec = device_lib.DeviceSpec.from_string(devices[d]) + if not hasattr(d_spec, "task") or d_spec.task is None: assert False, "failed to parse device %s" % devices[d] - return (per_task_devices, per_task_values) + index = (d_spec.job or "localhost", d_spec.replica or 0, d_spec.task) + if index not in per_task_devices: + per_task_devices[index] = [] + per_task_values[index] = [] + per_task_devices[index].append(devices[d]) + per_task_values[index].append(values[d]) + + return (list(per_task_devices.values()), list(per_task_values.values())) def build_nccl_all_reduce(input_tensors, red_op, un_op=None): diff --git a/tensorflow/contrib/all_reduce/python/all_reduce_test.py b/tensorflow/contrib/all_reduce/python/all_reduce_test.py index 47bab0a3670a90644972b2c961954a3036b8ecba..b3f5d92259df8475b205110dd3f0cee1cb5bde6f 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce_test.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce_test.py @@ -36,6 +36,12 @@ from tensorflow.python.platform import tf_logging class AllReduceTest(test_util.TensorFlowTestCase): + def testFlattenTensorsShapesDefined(self): + x = array_ops.placeholder(types_pb2.DT_FLOAT, [None]) + with self.assertRaisesRegexp(ValueError, + "must have statically known shape"): + ar._flatten_tensors([x, x]) + def testRingPermutations(self): # 0 devices pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 0, []) diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD index 4bff3c27d22c4550747a651a59909bdef80e8285..c10179ba8b290b6209f5567d6323df4bcf711585 100644 --- a/tensorflow/contrib/android/BUILD +++ b/tensorflow/contrib/android/BUILD @@ -38,20 +38,6 @@ cc_library( alwayslink = 1, ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "bin/**", - "gen/**", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - # JAR with Java bindings to TF. android_library( name = "android_tensorflow_inference_java", @@ -86,7 +72,7 @@ cc_binary( "-s", "-Wl,--gc-sections", "-Wl,--version-script", # This line must be directly followed by LINKER_SCRIPT. - LINKER_SCRIPT, + "$(location {})".format(LINKER_SCRIPT), ]), linkshared = 1, linkstatic = 1, diff --git a/tensorflow/contrib/android/asset_manager_filesystem.cc b/tensorflow/contrib/android/asset_manager_filesystem.cc index 380a652435ad089f46f3ca80e4fd43097fd96e10..513d519eabbd54f46fde9ec0f004247c02277732 100644 --- a/tensorflow/contrib/android/asset_manager_filesystem.cc +++ b/tensorflow/contrib/android/asset_manager_filesystem.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system_helper.h" namespace tensorflow { namespace { @@ -228,9 +229,8 @@ string AssetManagerFileSystem::NormalizeDirectoryPath(const string& fname) { } string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) { - string output(name); - StringPiece piece(output); - piece.Consume(prefix_); + StringPiece piece(name); + str_util::ConsumePrefix(&piece, prefix_); return piece.ToString(); } @@ -243,6 +243,11 @@ bool AssetManagerFileSystem::DirectoryExists(const std::string& fname) { return AAssetDir_getNextFileName(dir.get()) != NULL; } +Status AssetManagerFileSystem::GetMatchingPaths(const string& pattern, + std::vector* results) { + return internal::GetMatchingPaths(this, Env::Default(), pattern, results); +} + Status AssetManagerFileSystem::NewWritableFile( const string& fname, std::unique_ptr* result) { return errors::Unimplemented("Asset storage is read only."); diff --git a/tensorflow/contrib/android/asset_manager_filesystem.h b/tensorflow/contrib/android/asset_manager_filesystem.h index 665304b5eef1f8a3633c8c522259e20d744b1808..a87ff42ae217c429ecf5d2458b88b3431551ad97 100644 --- a/tensorflow/contrib/android/asset_manager_filesystem.h +++ b/tensorflow/contrib/android/asset_manager_filesystem.h @@ -66,6 +66,9 @@ class AssetManagerFileSystem : public FileSystem { Status DeleteDir(const string& d) override; Status RenameFile(const string& s, const string& t) override; + Status GetMatchingPaths(const string& pattern, + std::vector* results) override; + private: string RemoveAssetPrefix(const string& name); diff --git a/tensorflow/contrib/android/cmake/CMakeLists.txt b/tensorflow/contrib/android/cmake/CMakeLists.txt index a115d1610e2334a6626f29674f3dd195e3a3c648..ecf1a103d2981f409a4598d762fb26100217f779 100644 --- a/tensorflow/contrib/android/cmake/CMakeLists.txt +++ b/tensorflow/contrib/android/cmake/CMakeLists.txt @@ -75,7 +75,6 @@ target_link_libraries(tensorflow_inference include_directories( ${PREBUILT_DIR}/proto ${PREBUILT_DIR}/protobuf/include - ${PREBUILT_DIR}/nsync/public ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/downloads/eigen ${TENSORFLOW_ROOT_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/..) diff --git a/tensorflow/contrib/android/jni/run_stats_jni.cc b/tensorflow/contrib/android/jni/run_stats_jni.cc index 707853b59befc2625145ad96952fbf9f66d62b43..30de7b59af79cb36ee266a15bb6e668c2e3f628a 100644 --- a/tensorflow/contrib/android/jni/run_stats_jni.cc +++ b/tensorflow/contrib/android/jni/run_stats_jni.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/contrib/android/jni/run_stats_jni.h" #include + #include #include "tensorflow/core/protobuf/config.pb.h" @@ -73,7 +74,8 @@ JNIEXPORT jstring RUN_STATS_METHOD(summary)(JNIEnv* env, jclass clazz, StatSummarizer* s = requireHandle(env, handle); if (s == nullptr) return nullptr; std::stringstream ret; - ret << s->GetStatsByMetric("Top 10 CPU", StatSummarizer::BY_TIME, 10) + ret << s->GetStatsByMetric("Top 10 CPU", tensorflow::StatsCalculator::BY_TIME, + 10) << s->GetStatsByNodeType() << s->ShortSummary(); return env->NewStringUTF(ret.str().c_str()); } diff --git a/tensorflow/contrib/py2tf/BUILD b/tensorflow/contrib/autograph/BUILD similarity index 75% rename from tensorflow/contrib/py2tf/BUILD rename to tensorflow/contrib/autograph/BUILD index d91220f6ddb859ff52d4e5853948cb667981009b..30dd846893c30b9205972bd5216cc1871ab03d76 100644 --- a/tensorflow/contrib/py2tf/BUILD +++ b/tensorflow/contrib/autograph/BUILD @@ -15,16 +15,16 @@ filegroup( ) py_library( - name = "py2tf", + name = "autograph", srcs = [ "__init__.py", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/py2tf/impl", - "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/impl", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/utils", "@gast_archive//:gast", "@six_archive//:six", ], diff --git a/tensorflow/contrib/autograph/CONTRIBUTING.md b/tensorflow/contrib/autograph/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..a4aec8c74a9ad1418072471a5d3cde8c3b968a38 --- /dev/null +++ b/tensorflow/contrib/autograph/CONTRIBUTING.md @@ -0,0 +1,48 @@ +# How to Contribute + +We'd love to have your patches and contributions! Here are some guidelines. In general, we follow the [TensorFlow contributing guidelines](../../CONTRIBUTING.md), but have some [AutoGraph-specific style guidelines](STYLE_GUIDE.md). More details below. + +## TensorFlow Code of Conduct +Please review and follow the [TensorFlow Code of Conduct](../../CODE_OF_CONDUCT.md). + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution; +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult [GitHub +Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +After a pull request is approved, we merge it. Note our merging process differs +from GitHub in that we pull and submit the change into an internal version +control system. This system automatically pushes a git commit to the GitHub +repository (with credit to the original author) and closes the pull request. + +## Style + +See the [AutoGraph style guide](STYLE_GUIDE.md). + +## Unit tests + +Please include unit tests when contributing new features ([example here](converters/continue_statements_test.py)), as they help to a) prove that your code works correctly, and b) guard against future breaking +changes to lower the maintenance cost. +It's also helpful to check that any +changes you propose do not break existing unit tests. You can run tests using the command, + +```shell +bazel test --config=opt --copt=-O3 --copt=-march=native \ + //tensorflow/contrib/autograph/... +``` + +from the root of the `tensorflow` repository. For more details see the [main TensorFlow Contributing File](../../CONTRIBUTING.md) diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md new file mode 100644 index 0000000000000000000000000000000000000000..674859bed4ec157d5d5b33b6fc015c930e54b392 --- /dev/null +++ b/tensorflow/contrib/autograph/README.md @@ -0,0 +1,122 @@ +# AutoGraph + +IMPORTANT: AutoGraph is alpha software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)). + +AutoGraph is a Python to TensorFlow compiler. + +With AutoGraph, you can write [Eager style](https://www.tensorflow.org/programmers_guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. + +For example, this Python function: + +``` +def f(x): + if x < 0: + x = -x + return x +``` + +would be converted to this: + +``` +def graph_mode_f(x): + with tf.name_scope('f'): + + def if_true(): + with tf.name_scope('if_true'): + x_1, = x, + x_1 = tf.negative(x_1) + return x_1, + + def if_false(): + with tf.name_scope('if_false'): + x_1, = x, + return x_1, + x = ag__.utils.run_cond(tf.greater(x, 0), if_true, if_false) + return x +``` + +so you can use it like an op: + +``` +with tf.Graph().as_default(): + x = tf.constant(-1.0) + + converted_f = autograph.to_graph(f) + y = converted_f(x) + + with tf.Session() as sess: + print(sess.run(y)) + # Output: 1 +``` + +# Getting started + +Use AutoGraph in one of the following ways, described below: + + 1. Annotations (simpler) + 2. Functional API (more flexible) + +To get started, install the latest nightly TensorFlow build: + +```shell +pip install -U tf-nightly +``` + +Then import the `autograph` module from `tf.contrib`: + +``` +from tensorflow.contrib import autograph as ag +``` + +### Interactive demo notebooks + +For more extensive examples, check out these interactive notebooks: + + * [RNN trained using Keras and Estimators](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb) + * [Demo from the TF Dev Summit 2018](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb) + +## Using with annotations + +Annotating a function or class with `@convert` converts it in place: + +``` +@ag.convert() +def f(x): + if x < 0: + x = -x + return x +``` + +... so that it always outputs TensorFlow code: + +``` +with tf.Graph().as_default(): + x = tf.constant(-1) + + y = f(x) + + with tf.Session() as sess: + print(sess.run(y)) + # Output: 1 +``` + +## Using the functional API + +The functional API allows you to convert an existing function, class or object after it was defined: + +``` +converted_f = ag.to_graph(f) + +print(converted_f(tf.constant(-1))) +# Output: Tensor + +print(f(-1)) +# Output: 1 +``` + +You can use the functional API to inspect the generated code as well: + +``` +print(ag.to_code(f)) +# Output: +``` diff --git a/tensorflow/contrib/autograph/STYLE_GUIDE.md b/tensorflow/contrib/autograph/STYLE_GUIDE.md new file mode 100644 index 0000000000000000000000000000000000000000..866e5f583a34570dfddc733f57561ed1d2b7c5bf --- /dev/null +++ b/tensorflow/contrib/autograph/STYLE_GUIDE.md @@ -0,0 +1,75 @@ +# AutoGraph Style Guide + +This page contains style decisions that developers should follow when +contributing code to AutoGraph. + +## TensorFlow Style + +Follow the [TensorFlow style +guide](https://www.tensorflow.org/community/style_guide), the [documentation +guide](https://www.tensorflow.org/community/documentation) and the +[Google Python style guide](https://google.github.io/styleguide/pyguide.html). + +Naming conventions: + +1. The name is TensorFlow, not Tensorflow. +2. The name is AutoGraph, not Autograph. + +## AutoGraph Style + +Below are AutoGraph-specific conventions. In the event of conflict, +it supercedes all previous conventions. + +1. __Citations in Docstrings.__ Write a `#### References` subsection at the + bottom of any docstring with citations. Use ICLR’s bibliography style to + write references; for example, order entries by the first author's last + name. Add a link to the paper if the publication is open source (ideally, + arXiv). + + Write in-paragraph citations in general, e.g., [(Tran and Blei, 2018)][1]. + Write in-text citations when the citation is a noun, e.g., [Tran and Blei + (2018)][1]. Write citations with more than two authors using et al., e.g., + [(Tran et al., 2018)][1]. Separate multiple citations with semicolon, e.g., + ([Tran and Blei, 2018][1]; [Gelman and Rubin, 1992][2]). + + Examples: + + ```none + #### References + + # technical report + [1]: Tony Finch. Incremental calculation of weighted mean and variance. + _Technical Report_, 2009. + http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf + + # journal + [2]: Andrew Gelman and Donald B. Rubin. Inference from Iterative Simulation + Using Multiple Sequences. _Statistical Science_, 7(4):457-472, 1992. + + # arXiv preprint + # use "et al." for papers with too many authors to maintain + [3]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech + Synthesis. _arXiv preprint arXiv:1711.10433_, 2017. + https://arxiv.org/abs/1711.10433 + + # conference + [4]: Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, and Roger Grosse. + Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches. In _International Conference on Learning + Representations_, 2018. + https://arxiv.org/abs/1803.04386 + ``` + +2. Avoid LaTeX in docstrings. + + * It is not rendered in many (if not most) editors and can be hard to read + for both LaTeX experts and non-experts. + +3. Write docstring and comment math using ASCII friendly notation; python using + operators. E.g., `x**2` better than `x^2`, `x[i, j]` better than `x_{i,j}`, + `sum{ f(x[i]) : i=1...n }` better than `\sum_{i=1}^n f(x_i)` `int{sin(x) dx: + x in [0, 2 pi]}` better than `\int_0^{2\pi} sin(x) dx`. + + * The more we stick to python style, the more someone can + copy/paste/execute. + * Python style is usually easier to read as ASCII. diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dbdbad8f4c91c725294baa36acebbaf5b5e8cf5c --- /dev/null +++ b/tensorflow/contrib/autograph/__init__.py @@ -0,0 +1,59 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Autograph compiles Python code into equivalent TensorFlow code. + +Equivalent here means that they have the same effect when executed. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# TODO(mdan): Bring only the relevant symbols to the top level. +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph import operators +from tensorflow.contrib.autograph.impl.api import convert +from tensorflow.contrib.autograph.impl.api import converted_call +from tensorflow.contrib.autograph.impl.api import do_not_convert +from tensorflow.contrib.autograph.impl.api import RunMode +from tensorflow.contrib.autograph.impl.api import to_code +from tensorflow.contrib.autograph.impl.api import to_graph +from tensorflow.contrib.autograph.impl.directives import set_element_type +from tensorflow.contrib.autograph.impl.directives import set_loop_options +from tensorflow.contrib.autograph.impl.special_functions import stack +from tensorflow.contrib.autograph.pyct.transformer import AutographParseError +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + # Main API + 'RunMode', + 'convert', + 'converted_call', + 'do_not_convert', + 'to_code', + 'to_graph', + # Overloaded operators + 'operators', + # Special functions and directives + 'set_element_type', + 'set_loop_options', + 'stack', + # Exceptions + 'AutographParseError', + # Utilities: to be removed + 'utils', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/py2tf/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD similarity index 67% rename from tensorflow/contrib/py2tf/converters/BUILD rename to tensorflow/contrib/autograph/converters/BUILD index e9a96ec8d1dfc01ff6bc3b1fcaaef8e9b71a14a8..284ad84be566199adaaa1ab641d37528ae4dfd2d 100644 --- a/tensorflow/contrib/py2tf/converters/BUILD +++ b/tensorflow/contrib/autograph/converters/BUILD @@ -24,10 +24,14 @@ py_library( "continue_statements.py", "control_flow.py", "decorators.py", - "for_loops.py", + "ifexp.py", "list_comprehension.py", + "lists.py", "logical_expressions.py", + "name_scopes.py", "side_effect_guards.py", + "single_return.py", + "slices.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -45,8 +49,10 @@ py_library( visibility = ["//tensorflow:__subpackages__"], deps = [ ":converters", - "//tensorflow/contrib/py2tf/pyct/static_analysis", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/operators", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis", + "//tensorflow/contrib/autograph/utils", "@gast_archive//:gast", "@six_archive//:six", ], @@ -56,9 +62,9 @@ py_test( name = "asserts_test", srcs = ["asserts_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -69,7 +75,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -78,20 +83,22 @@ py_test( name = "builtin_functions_test", srcs = ["builtin_functions_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) py_test( name = "call_trees_test", + size = "large", srcs = ["call_trees_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/impl", "//tensorflow/python:client_testlib", ], ) @@ -102,7 +109,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -113,7 +119,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -124,18 +129,16 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) py_test( - name = "for_loops_test", - srcs = ["for_loops_test.py"], - srcs_version = "PY2AND3", + name = "name_scopes_test", + srcs = ["name_scopes_test.py"], deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], ) @@ -146,7 +149,16 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "lists_test", + srcs = ["lists_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_lib", "//tensorflow/python:client_testlib", ], ) @@ -157,7 +169,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -166,9 +177,46 @@ py_test( name = "side_effect_guards_test", srcs = ["side_effect_guards_test.py"], srcs_version = "PY2AND3", + tags = [ + # TODO(mdan): Fix. + "flaky", + "notap", + ], + deps = [ + ":test_lib", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "single_return_test", + srcs = ["single_return_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_lib", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "ifexp_test", + srcs = ["ifexp_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_lib", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "slices_test", + srcs = ["slices_test.py"], + srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], ) diff --git a/tensorflow/contrib/py2tf/converters/__init__.py b/tensorflow/contrib/autograph/converters/__init__.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/__init__.py rename to tensorflow/contrib/autograph/converters/__init__.py index ca10896ee5c6c23d9b20ff23add9945de68e5bf9..e4e8eda42f655e204310eaa9defdd5c90bf06e15 100644 --- a/tensorflow/contrib/py2tf/converters/__init__.py +++ b/tensorflow/contrib/autograph/converters/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Code converters used by Py2TF.""" +"""Code converters used by Autograph.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/py2tf/converters/asserts.py b/tensorflow/contrib/autograph/converters/asserts.py similarity index 85% rename from tensorflow/contrib/py2tf/converters/asserts.py rename to tensorflow/contrib/autograph/converters/asserts.py index 5b9b8e772bed82df2429fd6cb94dbf7b565e22b3..3b0db677ce5e417e7afea8d8fe4121a0352bb6d7 100644 --- a/tensorflow/contrib/py2tf/converters/asserts.py +++ b/tensorflow/contrib/autograph/converters/asserts.py @@ -20,22 +20,20 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer class AssertsTransformer(transformer.Base): """Transforms Print nodes to Call so they can be handled as functions.""" - # pylint:disable=invalid-name - def visit_Assert(self, node): self.generic_visit(node) # Note: The lone tf.Assert call will be wrapped with control_dependencies # by side_effect_guards. template = """ - tf.Assert(test, [msg]) + tf.Assert(test, (msg,)) """ if node.msg is None: @@ -44,9 +42,7 @@ class AssertsTransformer(transformer.Base): elif isinstance(node.msg, gast.Str): return templates.replace(template, test=node.test, msg=node.msg) else: - raise NotImplementedError('Can only convert string messages for now.') - - # pylint:enable=invalid-name + raise NotImplementedError('can only convert string messages for now.') def transform(node, context): diff --git a/tensorflow/contrib/py2tf/converters/asserts_test.py b/tensorflow/contrib/autograph/converters/asserts_test.py similarity index 90% rename from tensorflow/contrib/py2tf/converters/asserts_test.py rename to tensorflow/contrib/autograph/converters/asserts_test.py index 6611f2777a93a7e819c8becfa06a09b27f4e6aaf..cc913febe8d0f411588af69b87ec52ce58f4469c 100644 --- a/tensorflow/contrib/py2tf/converters/asserts_test.py +++ b/tensorflow/contrib/autograph/converters/asserts_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.converters import asserts -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import asserts +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py new file mode 100644 index 0000000000000000000000000000000000000000..775d92c1d9f8bc35d1eda62f3f3ef7ee43414779 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -0,0 +1,141 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Canonicalizes break statements by de-sugaring into a control boolean.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno + + +# Tags for local state. +BREAK_USED = 'break_used' +CONTROL_VAR_NAME = 'control_var_name' + + +class BreakStatementTransformer(transformer.Base): + """Canonicalizes break statements into additional conditionals.""" + + def visit_Break(self, node): + self.set_local(BREAK_USED, True) + var_name = self.get_local(CONTROL_VAR_NAME) + # TODO(mdan): This will fail when expanded inside a top-level else block. + template = """ + var_name = True + continue + """ + return templates.replace(template, var_name=var_name) + + def _guard_if_present(self, block, var_name): + """Prevents the block from executing if var_name is set.""" + if not block: + return block + + template = """ + if not var_name: + block + """ + node = templates.replace( + template, + var_name=var_name, + block=block) + return node + + def _track_body(self, nodes, break_var): + self.enter_local_scope() + self.set_local(CONTROL_VAR_NAME, break_var) + nodes = self.visit_block(nodes) + break_used = self.get_local(BREAK_USED, False) + self.exit_local_scope() + return nodes, break_used + + def visit_While(self, node): + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + break_var = self.context.namer.new_symbol('break_', scope.referenced) + + node.test = self.visit(node.test) + node.body, break_used = self._track_body(node.body, break_var) + # A break in the else clause applies to the containing scope. + node.orelse = self.visit_block(node.orelse) + + if break_used: + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + guarded_orelse = self._guard_if_present(node.orelse, break_var) + + template = """ + var_name = False + while test and not var_name: + body + else: + orelse + """ + node = templates.replace( + template, + var_name=break_var, + test=node.test, + body=node.body, + orelse=guarded_orelse) + + return node + + def visit_For(self, node): + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + break_var = self.context.namer.new_symbol('break_', scope.referenced) + + node.target = self.visit(node.target) + node.iter = self.visit(node.iter) + node.body, break_used = self._track_body(node.body, break_var) + # A break in the else clause applies to the containing scope. + node.orelse = self.visit_block(node.orelse) + + if break_used: + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + guarded_orelse = self._guard_if_present(node.orelse, break_var) + extra_test = templates.replace_as_expression( + 'not var_name', var_name=break_var) + + # The extra test is hidden in the AST, which will confuse the static + # analysis. To mitigate that, we insert a no-op statement that ensures + # the control variable is marked as used. + # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name) + template = """ + var_name = False + for target in iter_: + (var_name,) + body + else: + orelse + """ + node = templates.replace( + template, + var_name=break_var, + iter_=node.iter, + target=node.target, + body=node.body, + orelse=guarded_orelse) + + anno.setanno(node[1], 'extra_test', extra_test) + + return node + + +def transform(node, context): + return BreakStatementTransformer(context).visit(node) diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1af59e9b5260fe0d3a3ef72c7a003dc451e230f3 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/break_statements_test.py @@ -0,0 +1,147 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for break_statements module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.converters import break_statements +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.python.platform import test + + +class BreakCanonicalizationTest(converter_test_base.TestCase): + + def test_basic_while(self): + + def test_fn(x): + v = [] + while x > 0: + x -= 1 + if x % 2 == 0: + break + v.append(x) + return v + + node = self.parse_and_analyze(test_fn, {}) + node = break_statements.transform(node, self.ctx) + + with self.compiled(node) as result: + self.assertEqual([], result.test_fn(0)) + self.assertEqual([], result.test_fn(1)) + self.assertEqual([3], result.test_fn(4)) + + def test_basic_for(self): + + def test_fn(a): + v = [] + for x in a: + x -= 1 + if x % 2 == 0: + break + v.append(x) + return v + + node = self.parse_and_analyze(test_fn, {}) + node = break_statements.transform(node, self.ctx) + + with self.compiled(node) as result: + # The break is incompletely canonicalized. The loop will not interrupt, + # but the section following the break will be skipped. + self.assertEqual([], result.test_fn([])) + self.assertEqual([3, 3], result.test_fn([4, 4])) + self.assertEqual([3], result.test_fn([4, 5])) + self.assertEqual([3], result.test_fn([5, 4])) + + def test_deeply_nested(self): + + def test_fn(x): + v = [] + u = [] + w = [] + while x > 0: + x -= 1 + if x % 2 == 0: + if x % 3 != 0: + u.append(x) + else: + w.append(x) + break + v.append(x) + return v, u, w + + node = self.parse_and_analyze(test_fn, {}) + node = break_statements.transform(node, self.ctx) + + with self.compiled(node) as result: + self.assertEqual(([], [], []), result.test_fn(0)) + self.assertEqual(([2, 1], [2], [0]), result.test_fn(3)) + self.assertEqual(([10, 9, 8, 7], [10, 8], [6]), result.test_fn(11)) + + def test_nested_loops(self): + + def test_fn(x): + v = [] + u = [] + while x > 0: + x -= 1 + y = x + while y > 0: + y -= 1 + if y % 2 == 0: + break + u.append(y) + if x == 0: + break + v.append(x) + return v, u + + node = self.parse_and_analyze(test_fn, {}) + node = break_statements.transform(node, self.ctx) + + with self.compiled(node) as result: + self.assertEqual(([], []), result.test_fn(0)) + self.assertEqual(([1], []), result.test_fn(2)) + self.assertEqual(([2, 1], [1]), result.test_fn(3)) + self.assertEqual(([4, 3, 2, 1], [3, 1]), result.test_fn(5)) + + def test_loop_else(self): + + def test_fn(x): + v = [] + u = [] + while x > 0: + x -= 1 + y = x + while y > 1: + break + else: + u.append(y) + break + v.append(x) + return v, u + + node = self.parse_and_analyze(test_fn, {}) + node = break_statements.transform(node, self.ctx) + + with self.compiled(node) as result: + self.assertEqual(([], []), result.test_fn(0)) + self.assertEqual(([], [1]), result.test_fn(2)) + self.assertEqual(([2], [1]), result.test_fn(3)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py similarity index 75% rename from tensorflow/contrib/py2tf/converters/builtin_functions.py rename to tensorflow/contrib/autograph/converters/builtin_functions.py index 2eb00f90575920ac948e799b0e97a9cfccb42fad..231e4ee35a72f51845a476d9f605986ac73b4676 100644 --- a/tensorflow/contrib/py2tf/converters/builtin_functions.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions.py @@ -20,8 +20,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer class BuiltinFunctionTransformer(transformer.Base): @@ -31,28 +31,26 @@ class BuiltinFunctionTransformer(transformer.Base): TF equivalent, like `len`. """ - def __init__(self, context): - super(BuiltinFunctionTransformer, self).__init__(context) - - # pylint:disable=invalid-name - - def _convert_len(self, node): + def _convert_builtin(self, node): template = """ - tf.shape(args)[0] + ag__.utils.dynamic_builtin(func, args) """ - return templates.replace(template, args=node.args)[0].value + return templates.replace(template, func=node.func, args=node.args)[0].value def _convert_print(self, node): template = """ - py2tf_utils.call_print(args) + ag__.utils.dynamic_print(args) """ return templates.replace(template, args=node.args)[0].value def visit_Call(self, node): self.generic_visit(node) # TODO(mdan): This won't work if the function was hidden. - if isinstance(node.func, gast.Name) and node.func.id == 'len': - return self._convert_len(node) + # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead. + if (isinstance(node.func, gast.Name) and + node.func.id in ('len', 'range', 'xrange', 'float', 'int')): + return self._convert_builtin(node) + # Print needs to be handled separately because it can be read as statement. if isinstance(node.func, gast.Name) and node.func.id == 'print': return self._convert_print(node) return node @@ -69,8 +67,6 @@ class BuiltinFunctionTransformer(transformer.Base): function_call = templates.replace(template, fname='print', args=args)[0] return self.visit(function_call) - # pylint:enable=invalid-name - def transform(node, context): return BuiltinFunctionTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py similarity index 64% rename from tensorflow/contrib/py2tf/converters/builtin_functions_test.py rename to tensorflow/contrib/autograph/converters/builtin_functions_test.py index b279ff77ef10b96586d3d68585adb0d5424afb90..30272409df322560b04ba75b3e1cb6f9ad5ff0af 100644 --- a/tensorflow/contrib/py2tf/converters/builtin_functions_test.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py @@ -22,12 +22,10 @@ import sys import six -from tensorflow.contrib.py2tf.converters import builtin_functions -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import builtin_functions +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops -from tensorflow.python.ops import logging_ops -from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -47,7 +45,9 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): sess.run( result.test_fn(constant_op.constant([0, 0, 0])))) - def test_print_with_op(self): + self.assertEqual(3, result.test_fn([0, 0, 0])) + + def test_print(self): def test_fn(a): print(a) @@ -55,14 +55,12 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {'print': print}) node = builtin_functions.transform(node, self.ctx) - # Note: it's relevant not to include script_ops.py_func here, to verify - # that tf.Print is used. - with self.compiled(node, logging_ops.Print) as result: + with self.compiled(node) as result: with self.test_session() as sess: try: out_capturer = six.StringIO() sys.stdout = out_capturer - result.test_fn('a') + result.test_fn(constant_op.constant('a')) sess.run(sess.graph.get_operations()) self.assertEqual(out_capturer.getvalue(), 'a\n') finally: @@ -70,41 +68,19 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): def test_print_with_op_multiple_values(self): - def test_fn(a, b): - print(a, b) - - node = self.parse_and_analyze(test_fn, {'print': print}) - node = builtin_functions.transform(node, self.ctx) - - # Note: it's relevant not to include script_ops.py_func here, to verify - # that tf.Print is used. - with self.compiled(node, logging_ops.Print) as result: - with self.test_session() as sess: - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - result.test_fn('a', 1) - sess.run(sess.graph.get_operations()) - self.assertEqual(out_capturer.getvalue(), 'a 1\n') - finally: - sys.stdout = sys.__stdout__ - - def test_print_with_py_func(self): - def test_fn(a, b, c): print(a, b, c) node = self.parse_and_analyze(test_fn, {'print': print}) node = builtin_functions.transform(node, self.ctx) - # Note: it's relevant not to include logging_ops.Print here, to verify - # that py_func is used. - with self.compiled(node, script_ops.py_func) as result: + with self.compiled(node) as result: with self.test_session() as sess: try: out_capturer = six.StringIO() sys.stdout = out_capturer - result.test_fn('a', 1, [2, 3]) + result.test_fn( + constant_op.constant('a'), constant_op.constant(1), [2, 3]) sess.run(sess.graph.get_operations()) self.assertEqual(out_capturer.getvalue(), 'a 1 [2, 3]\n') finally: diff --git a/tensorflow/contrib/py2tf/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py similarity index 65% rename from tensorflow/contrib/py2tf/converters/call_trees.py rename to tensorflow/contrib/autograph/converters/call_trees.py index 1050ba654c63bb52c1c5e71c981a6a0baa3fc987..b6ecdcb7809b1ad7e7461324cb6a110ef4180609 100644 --- a/tensorflow/contrib/py2tf/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -22,17 +22,29 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import types +from collections import namedtuple import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect +class FunctionInfo(namedtuple('FunctionInfo', ('dtype',))): + pass + + +# TODO(mdan): Move this to config.py. +KNOWN_NUMPY_FUNCTIONS = { + ('numpy', 'random', 'binomial'): FunctionInfo(dtype='tf.int64'), +} + + class FunctionNamer(object): """Describes the interface for CallTreeTransformer's namer.""" @@ -72,9 +84,8 @@ class CallTreeTransformer(transformer.Base): self.uncompiled_modules = uncompiled_modules self.nocompile_decorators = nocompile_decorators - # pylint:disable=invalid-name - def _resolve_name(self, node): + """Used to resolve decorator info.""" if isinstance(node, gast.Call): return self._resolve_name(node.func) if isinstance(node, gast.Name): @@ -99,7 +110,19 @@ class CallTreeTransformer(transformer.Base): (owner_type, node.attr)) return None + def _function_is_compilable(self, target_entity): + """Determines whether an entity can be compiled at all.""" + # TODO(mdan): This is just a placeholder. Implement. + return not inspect_utils.isbuiltin(target_entity) + def _should_compile(self, node, fqn): + """Determines whether an entity should be compiled in the context.""" + # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether. + module_name = fqn[0] + for mod in self.uncompiled_modules: + if module_name.startswith(mod[0] + '.'): + return False + for i in range(1, len(fqn)): if fqn[:i] in self.uncompiled_modules: return False @@ -123,7 +146,7 @@ class CallTreeTransformer(transformer.Base): # Inspect the target function decorators. If any include a @convert # or @graph_ready annotation, then they must be called as they are. # TODO(mdan): This may be quite heavy. - # To parse and re-analize each function for every call site could be quite + # To parse and re-analyze each function for every call site could be quite # wasteful. Maybe we could cache the parsed AST? try: target_node, _ = parser.parse_entity(target_entity) @@ -141,33 +164,6 @@ class CallTreeTransformer(transformer.Base): return True - def _determine_function_owner(self, m): - # TODO(mdan): The parent type should be known at analysis. Use that instead. - if hasattr(m, 'im_class'): # Python 2 - return m.im_class - if hasattr(m, '__qualname__'): # Python 3 - # Object attributes: should be bound to "self". - if hasattr(m, '__self__'): - return type(m.__self__) - - # Class attributes: should have the owner name in their namespace. - qn = m.__qualname__.split('.') - if len(qn) < 2: - return None - owner_name, func_name = qn[-2:] - if func_name != m.__name__: - raise ValueError('Inconsistent names detected ' - '(__qualname__[1] = "%s", __name__ = "%s") for %s.' % - (func_name, m.__name__, m)) - if owner_name == '': - return None - if owner_name not in self.context.namespace: - raise ValueError( - 'Could not resolve name "%s" while analyzing %s. Namespace:\n%s' % - (owner_name, m, self.context.namespace)) - return self.context.namespace[owner_name] - return None - def _rename_compilable_function(self, node): assert anno.hasanno(node.func, 'live_val') assert anno.hasanno(node.func, 'fqn') @@ -182,7 +178,11 @@ class CallTreeTransformer(transformer.Base): target_fqn, live_entity=target_entity) do_rename = True else: - owner_type = self._determine_function_owner(target_entity) + if anno.hasanno(node.func, 'parent_type'): + owner_type = anno.getanno(node.func, 'parent_type') + else: + # Fallback - not reliable. + owner_type = inspect_utils.getmethodclass(target_entity) new_name, do_rename = self.context.namer.compiled_function_name( target_fqn, live_entity=target_entity, owner_type=owner_type) @@ -196,15 +196,54 @@ class CallTreeTransformer(transformer.Base): return node def _wrap_to_py_func_no_return(self, node): - # TODO(mdan): Properly handle varargs, kwargs, etc. + # TODO(mdan): Properly handle varargs, etc. template = """ - py2tf_utils.wrap_py_func(func, None, (original_args,), True) + ag__.utils.wrap_py_func(func, None, (args,), kwargs, True) """ - return templates.replace(template, func=node.func, original_args=node.args) - - def _function_is_compilable(self, target_entity): - # TODO(mdan): This is just a placeholder. Implement. - return not isinstance(target_entity, types.BuiltinFunctionType) + return templates.replace( + template, + func=node.func, + args=node.args, + kwargs=ast_util.keywords_to_dict(node.keywords)) + + def _wrap_to_py_func_single_return(self, node, dtype): + # TODO(mdan): Properly handle varargs, etc. + template = """ + ag__.utils.wrap_py_func(func, dtype, (args,), kwargs, False) + """ + return templates.replace_as_expression( + template, + func=node.func, + dtype=parser.parse_expression(dtype), + args=node.args, + kwargs=ast_util.keywords_to_dict(node.keywords)) + + def _insert_dynamic_conversion(self, node): + """Inlines a dynamic conversion for a dynamic function.""" + # TODO(mdan): Pass information on the statically compiled functions. + # Having access to the statically compiled functions can help avoid + # unnecessary compilation. + # For example, this would lead to function `a` being compiled twice: + # + # def a(): + # v = b + # b() + # def b(): + # a() + # + # This is really a problem with recursive calls, which currently can + # only be gated by a static condition, and should be rare. + # TODO(mdan): It probably makes sense to use dynamic conversion every time. + # Before we could convert all the time though, we'd need a reasonable + # caching mechanism. + template = """ + ag__.converted_call(func, True, False, {}, args) + """ + call_expr = templates.replace(template, func=node.func, args=node.args) + new_call = call_expr[0].value + # TODO(mdan): Improve the template mechanism to better support this. + new_call.keywords = node.keywords + return new_call def visit_Expr(self, node): if isinstance(node.value, gast.Call): @@ -239,20 +278,41 @@ class CallTreeTransformer(transformer.Base): self.generic_visit(node) if anno.hasanno(node.func, 'live_val'): target_entity = anno.getanno(node.func, 'live_val') + if anno.hasanno(node.func, 'fqn'): + target_fqn = anno.getanno(node.func, 'fqn') + else: + target_fqn = None if self._function_is_compilable(target_entity): node = self._rename_compilable_function(node) + elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS: + # TODO(mdan): Should we replace these with equivalent TF ops instead? + node = self._wrap_to_py_func_single_return( + node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype) else: - raise NotImplementedError('py_func with return values') + raise NotImplementedError( + 'py_func with return values (unknown function)') else: + if anno.hasanno(node.func, anno.Basic.QN): + # Special-case a few builtins that otherwise go undetected. This + # normally doesn't pose a problem, but the dict built-in doesn't + # work with inspect.getargspec which is required for dynamic functions. + # Note: expecting this is resilient to aliasing (e.g. + # dict = an_evil_dict), because in those cases the regular mechanisms + # process a simple user function. + qn = anno.getanno(node.func, anno.Basic.QN) + # Add items to this list as needed. + if str(qn) in ('dict',): + return node + + if ast_util.matches(node, 'super(_)'): + # super() calls are preserved. The class conversion mechanism will + # ensure that they return the correct value. + return node + if self.context.recursive: - raise NotImplementedError('Could not resolve target function.') - else: - # TODO(mdan): Double check. Is this reachable code? - pass + node = self._insert_dynamic_conversion(node) return node - # pylint:enable=invalid-name - def transform(node, context, uncompiled_modules, nocompile_decorators): """Transform function call to the compiled counterparts. diff --git a/tensorflow/contrib/py2tf/converters/call_trees_test.py b/tensorflow/contrib/autograph/converters/call_trees_test.py similarity index 75% rename from tensorflow/contrib/py2tf/converters/call_trees_test.py rename to tensorflow/contrib/autograph/converters/call_trees_test.py index 777648dc0b31863227262fbf931aba680bb4ed98..303dd54a4ee49de27fad0c5cdc2d6274abfe0fa8 100644 --- a/tensorflow/contrib/py2tf/converters/call_trees_test.py +++ b/tensorflow/contrib/autograph/converters/call_trees_test.py @@ -18,9 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import call_trees -from tensorflow.contrib.py2tf.converters import converter_test_base +import numpy as np + +from tensorflow.contrib.autograph.converters import call_trees +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -30,7 +34,7 @@ class CallTreesTest(converter_test_base.TestCase): def test_basic(self): def test_fn_1(_): - raise ValueError('This should not be called in the compiled verison.') + raise ValueError('This should not be called in the compiled version.') def renamed_test_fn_1(a): return a + 1 @@ -47,6 +51,21 @@ class CallTreesTest(converter_test_base.TestCase): result.renamed_test_fn_1 = renamed_test_fn_1 self.assertEquals(3, result.test_fn_2(1)) + def test_dynamic_function(self): + + def test_fn_1(): + raise ValueError('This should be masked by the mock.') + + def test_fn_2(f): + return f() + 3 + + node = self.parse_and_analyze(test_fn_2, {}) + node = call_trees.transform(node, self.ctx, (), ()) + + with self.compiled(node) as result: + # 10 = 7 (from the mock) + 3 (from test_fn_2) + self.assertEquals(10, result.test_fn_2(test_fn_1)) + def test_simple_methods(self): class TestClass(object): @@ -59,6 +78,7 @@ class CallTreesTest(converter_test_base.TestCase): node = self.parse_and_analyze( TestClass.test_fn_2, {'TestClass': TestClass}, + namer=converter_test_base.FakeNoRenameNamer(), arg_types={'self': (TestClass.__name__, TestClass)}) node = call_trees.transform(node, self.ctx, (), ()) @@ -89,6 +109,20 @@ class CallTreesTest(converter_test_base.TestCase): sess.run(sess.graph.get_operations()[0]) self.assertEquals('bar', a.foo) + def test_py_func_wrap_known_function(self): + + def test_fn(): + return np.random.binomial(2, 0.5) + + node = self.parse_and_analyze(test_fn, {'np': np}) + node = call_trees.transform(node, self.ctx, (), ()) + + with self.compiled(node, dtypes.int64) as result: + result.np = np + with self.test_session() as sess: + self.assertTrue(isinstance(result.test_fn(), ops.Tensor)) + self.assertIn(sess.run(result.test_fn()), (0, 1, 2)) + def test_uncompiled_modules(self): def test_fn(a): diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py new file mode 100644 index 0000000000000000000000000000000000000000..0417817a77e706fc0ce805f7391bea600f5fbb2d --- /dev/null +++ b/tensorflow/contrib/autograph/converters/continue_statements.py @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Canonicalizes continue statements by de-sugaring into a control boolean.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno + + +# Tags for local state. +CONTROL_VAR_NAME = 'control_var_name' +CONTINUE_USED = 'continue_used' +GUARD_CREATED = 'guard_created' +CREATE_GUARD_NEXT = 'create_guard_next' + + +class ContinueCanonicalizationTransformer(transformer.Base): + """Canonicalizes continue statements into additional conditionals.""" + + def visit_Continue(self, node): + self.set_local(CONTINUE_USED, True) + template = """ + var_name = True + """ + return templates.replace( + template, var_name=self.get_local(CONTROL_VAR_NAME)) + + def _postprocess_statement(self, node): + # Example of how the state machine below works: + # + # 1| stmt # State: CONTINUE_USED = False + # | # Action: none + # 2| if cond: + # 3| continue # State: CONTINUE_USED = True, + # | # GUARD_CREATED = False, + # | # CREATE_GUARD_NEXT = False + # | # Action: set CREATE_GUARD_NEXT = True + # 4| stmt # State: CONTINUE_USED = True, + # | # GUARD_CREATED = False, + # | # CREATE_GUARD_NEXT = True + # | # Action: create `if not continue_used`, + # | # set GUARD_CREATED = True + # 5| stmt # State: CONTINUE_USED = True, GUARD_CREATED = True + # | # Action: none (will be wrapped under previously + # | # created if node) + + if self.get_local(CONTINUE_USED, False): + if self.get_local(GUARD_CREATED, False): + return node, None + + elif not self.get_local(CREATE_GUARD_NEXT, False): + self.set_local(CREATE_GUARD_NEXT, True) + return node, None + + else: + self.set_local(GUARD_CREATED, True) + template = """ + if not var_name: + original_node + """ + cond, = templates.replace( + template, + var_name=self.get_local(CONTROL_VAR_NAME), + original_node=node) + return cond, cond.body + return node, None + + def _visit_loop_body(self, node, nodes): + self.enter_local_scope() + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + continue_var = self.context.namer.new_symbol('continue_', scope.referenced) + self.set_local(CONTROL_VAR_NAME, continue_var) + + nodes = self.visit_block(nodes, after_visit=self._postprocess_statement) + + if self.get_local(CONTINUE_USED, False): + template = """ + var_name = False + """ + control_var_init = templates.replace(template, var_name=continue_var) + nodes = control_var_init + nodes + + self.exit_local_scope() + return nodes + + def _visit_non_loop_body(self, nodes): + self.enter_local_scope(inherit=(CONTROL_VAR_NAME,)) + nodes = self.visit_block(nodes, after_visit=self._postprocess_statement) + continue_used = self.get_local(CONTINUE_USED, False) + self.exit_local_scope(keep=(CONTINUE_USED,)) + return nodes, continue_used + + def visit_While(self, node): + node.test = self.visit(node.test) + node.body = self._visit_loop_body(node, node.body) + # A continue in the else clause applies to the containing scope. + node.orelse, _ = self._visit_non_loop_body(node.orelse) + return node + + def visit_For(self, node): + node.target = self.generic_visit(node.target) + node.iter = self.generic_visit(node.iter) + node.body = self._visit_loop_body(node, node.body) + # A continue in the else clause applies to the containing scope. + node.orelse, _ = self._visit_non_loop_body(node.orelse) + return node + + def visit_If(self, node): + node.test = self.generic_visit(node.test) + node.body, continue_used_body = self._visit_non_loop_body(node.body) + node.orelse, continue_used_orelse = self._visit_non_loop_body(node.orelse) + self.set_local(CONTINUE_USED, continue_used_body or continue_used_orelse) + return node + + def visit_With(self, node): + node.items = self.visit_block(node.items) + node.body, _ = self._visit_non_loop_body(node.body) + return node + + +def transform(node, namer): + return ContinueCanonicalizationTransformer(namer).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/continue_statements_test.py rename to tensorflow/contrib/autograph/converters/continue_statements_test.py index a598dcd1aed29478b7e3fe27e3c1b20010247dd9..bcbb316d7459aa5a25bb0bd128cd6e359a393288 100644 --- a/tensorflow/contrib/py2tf/converters/continue_statements_test.py +++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import continue_statements -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import continue_statements +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..d7ddbe8a04f64848d6ec21155d8d85f60e19d276 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -0,0 +1,342 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Handles control flow statements: while, for, if.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import cfg +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno + + +class SymbolNamer(object): + """Describes the interface for ControlFlowTransformer's namer.""" + + def new_symbol(self, name_root, reserved_locals): + """Generate a new unique symbol. + + Args: + name_root: String, used as stem in the new name. + reserved_locals: Set(string), additional local symbols that are reserved + and which should not be used. + Returns: + String. + """ + raise NotImplementedError() + + +class ControlFlowTransformer(transformer.Base): + """Transforms control flow structures like loops an conditionals.""" + + def _create_cond_branch(self, body_name, aliased_orig_names, + aliased_new_names, body, returns): + if aliased_orig_names: + template = """ + def body_name(): + aliased_new_names, = aliased_orig_names, + body + return (returns,) + """ + return templates.replace( + template, + body_name=body_name, + body=body, + aliased_orig_names=aliased_orig_names, + aliased_new_names=aliased_new_names, + returns=returns) + else: + template = """ + def body_name(): + body + return (returns,) + """ + return templates.replace( + template, body_name=body_name, body=body, returns=returns) + + def _create_cond_expr(self, results, test, body_name, orelse_name): + if results is not None: + template = """ + results = ag__.utils.run_cond(test, body_name, orelse_name) + """ + return templates.replace( + template, + test=test, + results=results, + body_name=body_name, + orelse_name=orelse_name) + else: + template = """ + ag__.utils.run_cond(test, body_name, orelse_name) + """ + return templates.replace( + template, test=test, body_name=body_name, orelse_name=orelse_name) + + def visit_If(self, node): + self.generic_visit(node) + + body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE) + body_defs = body_scope.created | body_scope.modified + orelse_defs = orelse_scope.created | orelse_scope.modified + live = anno.getanno(node, 'live_out') + + # We'll need to check if we're closing over variables that are defined + # elsewhere in the function + # NOTE: we can only detect syntactic closure in the scope + # of the code passed in. If the AutoGraph'd function itself closes + # over other variables, this analysis won't take that into account. + defined = anno.getanno(node, 'defined_in') + + # We only need to return variables that are + # - modified by one or both branches + # - live (or has a live parent) at the end of the conditional + modified = [] + for def_ in body_defs | orelse_defs: + def_with_parents = set((def_,)) | def_.support_set + if live & def_with_parents: + modified.append(def_) + + # We need to check if live created variables are balanced + # in both branches + created = live & (body_scope.created | orelse_scope.created) + + # The if statement is illegal if there are variables that are created, + # that are also live, but both branches don't create them. + if created: + if created != (body_scope.created & live): + raise ValueError( + 'The main branch does not create all live symbols that the else ' + 'branch does.') + if created != (orelse_scope.created & live): + raise ValueError( + 'The else branch does not create all live symbols that the main ' + 'branch does.') + + # Alias the closure variables inside the conditional functions + # to avoid errors caused by the local variables created in the branch + # functions. + # We will alias variables independently for body and orelse scope, + # because different branches might write different variables. + aliased_body_orig_names = tuple(body_scope.modified - body_scope.created) + aliased_orelse_orig_names = tuple(orelse_scope.modified - + orelse_scope.created) + aliased_body_new_names = tuple( + self.context.namer.new_symbol(s.ssf(), body_scope.referenced) + for s in aliased_body_orig_names) + aliased_orelse_new_names = tuple( + self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced) + for s in aliased_orelse_orig_names) + + alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names)) + alias_orelse_map = dict( + zip(aliased_orelse_orig_names, aliased_orelse_new_names)) + + node_body = ast_util.rename_symbols(node.body, alias_body_map) + node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map) + + if not modified: + # When the cond would return no value, we leave the cond called without + # results. That in turn should trigger the side effect guards. The + # branch functions will return a dummy value that ensures cond + # actually has some return value as well. + results = None + elif len(modified) == 1: + results = modified[0] + else: + results = gast.Tuple([s.ast() for s in modified], None) + + body_name = self.context.namer.new_symbol('if_true', body_scope.referenced) + orelse_name = self.context.namer.new_symbol('if_false', + orelse_scope.referenced) + if modified: + + def build_returns(aliased_names, alias_map, scope): + """Builds list of return variables for a branch of a conditional.""" + returns = [] + for s in modified: + if s in aliased_names: + returns.append(alias_map[s]) + else: + if s not in scope.created | defined: + raise ValueError( + 'Attempting to return variable "%s" from the true branch of ' + 'a conditional, but it was not closed over, or created in ' + 'this branch.' % str(s)) + else: + returns.append(s) + return tuple(returns) + + body_returns = build_returns(aliased_body_orig_names, alias_body_map, + body_scope) + orelse_returns = build_returns(aliased_orelse_orig_names, + alias_orelse_map, orelse_scope) + + else: + body_returns = orelse_returns = templates.replace('tf.ones(())')[0].value + + body_def = self._create_cond_branch( + body_name, + aliased_orig_names=tuple(aliased_body_orig_names), + aliased_new_names=tuple(aliased_body_new_names), + body=node_body, + returns=body_returns) + orelse_def = self._create_cond_branch( + orelse_name, + aliased_orig_names=tuple(aliased_orelse_orig_names), + aliased_new_names=tuple(aliased_orelse_new_names), + body=node_orelse, + returns=orelse_returns) + cond_expr = self._create_cond_expr(results, node.test, body_name, + orelse_name) + + return body_def + orelse_def + cond_expr + + def visit_While(self, node): + self.generic_visit(node) + + body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + body_closure = body_scope.modified - body_scope.created + all_referenced = body_scope.referenced + + cond_scope = anno.getanno(node, NodeAnno.COND_SCOPE) + cond_closure = set() + for s in cond_scope.referenced: + for root in s.support_set: + if root not in body_scope.created: + cond_closure.add(root) + + state = list(body_closure) + if not state: + # TODO(mdan): Implement this properly. + # To complete this statement, we need to check whether any variable + # created inside the body scope is used before being modified outside the + # scope. This should be done during activity analysis, and in general + # should cover the case where variables may not be initialized. + raise ValueError('cannot convert while loop: no outputs') + + state_ssf = [ + self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state + ] + ssf_map = { + name: ssf + for name, ssf in zip(state, state_ssf) + if str(name) != ssf + } + + if len(state) == 1: + state = state[0] + state_ssf = state_ssf[0] + state_ast_tuple = state + else: + state_ast_tuple = gast.Tuple([n.ast() for n in state], None) + + node_body = ast_util.rename_symbols(node.body, ssf_map) + test = ast_util.rename_symbols(node.test, ssf_map) + + template = """ + def test_name(state_ssf): + return test + def body_name(state_ssf): + body + return state_ssf, + state_ast_tuple = ag__.while_stmt( + test_name, body_name, (state,), (extra_deps,)) + """ + node = templates.replace( + template, + state=state, + state_ssf=state_ssf, + state_ast_tuple=state_ast_tuple, + test_name=self.context.namer.new_symbol('loop_test', + body_scope.referenced), + test=test, + body_name=self.context.namer.new_symbol('loop_body', + body_scope.referenced), + body=node_body, + extra_deps=tuple(s.ast() for s in cond_closure), + ) + + return node + + def visit_For(self, node): + self.generic_visit(node) + + body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + body_closure = body_scope.modified - body_scope.created + all_referenced = body_scope.referenced + + state = list(body_closure) + + state_ssf = [ + self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state + ] + ssf_map = { + name: ssf + for name, ssf in zip(state, state_ssf) + if str(name) != ssf + } + + if len(state) == 1: + state = state[0] + state_ssf = state_ssf[0] + state_ast_tuple = state + else: + state_ast_tuple = gast.Tuple([n.ast() for n in state], None) + + node_body = ast_util.rename_symbols(node.body, ssf_map) + if anno.hasanno(node, 'extra_test'): + extra_test = anno.getanno(node, 'extra_test') + extra_test = ast_util.rename_symbols(extra_test, ssf_map) + else: + extra_test = parser.parse_expression('True') + + template = """ + def extra_test_name(state_ssf): + return extra_test_expr + def body_name(iterate, state_ssf): + body + return state_ssf, + state_ast_tuple = ag__.for_stmt( + iter_, extra_test_name, body_name, (state,)) + """ + node = templates.replace( + template, + state=state, + state_ssf=state_ssf, + state_ast_tuple=state_ast_tuple, + iter_=node.iter, + iterate=node.target, + extra_test_name=self.context.namer.new_symbol('extra_test', + all_referenced), + extra_test_expr=extra_test, + body_name=self.context.namer.new_symbol('loop_body', all_referenced), + body=node_body) + + return node + + +def transform(node, context): + cfg.run_analyses(node, cfg.Liveness(context)) + cfg.run_analyses(node, cfg.Defined(context)) + node = ControlFlowTransformer(context).visit(node) + return node diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9d23d9b5b7e8e8480e04fccc1c8c81799abf382b --- /dev/null +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -0,0 +1,257 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for control_flow module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.converters import control_flow +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.platform import test + + +class ControlFlowTest(converter_test_base.TestCase): + + def test_simple_while(self): + + def test_fn(n): + i = 0 + s = 0 + while i < n: + s += i + i += 1 + return s, i, n + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node) as result: + with self.test_session() as sess: + self.assertEqual((10, 5, 5), + sess.run(result.test_fn(constant_op.constant(5)))) + + def test_while_single_var(self): + + def test_fn(n): + while n > 0: + n -= 1 + return n + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node) as result: + with self.test_session() as sess: + self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5)))) + + def test_simple_if(self): + + def test_fn(n): + a = 0 + b = 0 + if n > 0: + a = -n + else: + b = 2 * n + return a, b + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node) as result: + with self.test_session() as sess: + self.assertEqual((-1, 0), + sess.run(result.test_fn(constant_op.constant(1)))) + self.assertEqual((0, -2), + sess.run(result.test_fn(constant_op.constant(-1)))) + + def test_if_single_var(self): + + def test_fn(n): + if n > 0: + n = -n + return n + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node) as result: + with self.test_session() as sess: + self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1)))) + + def test_imbalanced_aliasing(self): + + def test_fn(n): + if n > 0: + n = 3 + return n + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond) as result: + with self.test_session() as sess: + self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(2)))) + self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3)))) + + def test_ignore_unread_variable(self): + + def test_fn(n): + b = 3 # pylint: disable=unused-variable + if n > 0: + b = 4 + return n + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: + with self.test_session() as sess: + self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(3)))) + self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3)))) + + def test_handle_temp_variable(self): + + def test_fn_using_temp(x, y, w): + if x < y: + z = x + y + else: + w = 2 + tmp = w + z = x - tmp + return z, w + + node = self.parse_and_analyze(test_fn_using_temp, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: + with self.test_session() as sess: + z, w = sess.run( + result.test_fn_using_temp( + constant_op.constant(-3), constant_op.constant(3), + constant_op.constant(3))) + self.assertEqual(0, z) + self.assertEqual(3, w) + z, w = sess.run( + result.test_fn_using_temp( + constant_op.constant(3), constant_op.constant(-3), + constant_op.constant(3))) + self.assertEqual(1, z) + self.assertEqual(2, w) + + def test_fn_ignoring_temp(x, y, w): + if x < y: + z = x + y + else: + w = 2 + tmp = w + z = x - tmp + return z + + node = self.parse_and_analyze(test_fn_ignoring_temp, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: + with self.test_session() as sess: + z = sess.run( + result.test_fn_ignoring_temp( + constant_op.constant(-3), constant_op.constant(3), + constant_op.constant(3))) + self.assertEqual(0, z) + z = sess.run( + result.test_fn_ignoring_temp( + constant_op.constant(3), constant_op.constant(-3), + constant_op.constant(3))) + self.assertEqual(1, z) + + def test_simple_for(self): + + def test_fn(l): + s1 = 0 + s2 = 0 + for e in l: + s1 += e + s2 += e * e + return s1, s2 + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node) as result: + with self.test_session() as sess: + l = [1, 2, 3] + self.assertEqual( + test_fn(l), sess.run(result.test_fn(constant_op.constant(l)))) + l = [] + self.assertEqual( + test_fn(l), + sess.run( + result.test_fn( + constant_op.constant(l, shape=(0,), dtype=dtypes.int32)))) + + def test_for_single_var(self): + + def test_fn(l): + s = 0 + for e in l: + s += e + return s + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node) as result: + with self.test_session() as sess: + l = [1, 2, 3] + self.assertEqual( + test_fn(l), sess.run(result.test_fn(constant_op.constant(l)))) + l = [] + self.assertEqual( + test_fn(l), + sess.run( + result.test_fn( + constant_op.constant(l, shape=(0,), dtype=dtypes.int32)))) + + def test_for_with_iterated_expression(self): + + eval_count = [0] + + def count_evals(x): + eval_count[0] += 1 + return x + + def test_fn(n): + s = 0 + for e in count_evals(range(n)): + s += e + return s + + node = self.parse_and_analyze(test_fn, {'count_evals': count_evals}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node) as result: + result.count_evals = count_evals + self.assertEqual(test_fn(5), result.test_fn(5)) + # count_evals ran twice, once for test_fn and another for result.test_fn + self.assertEqual(eval_count[0], 2) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/converters/converter_test_base.py b/tensorflow/contrib/autograph/converters/converter_test_base.py similarity index 55% rename from tensorflow/contrib/py2tf/converters/converter_test_base.py rename to tensorflow/contrib/autograph/converters/converter_test_base.py index 67747183dd323a799a04943ce4c7fe8c4093d002..41c2e71702e7e3ee3811a2cbee27c8c988eb3a5c 100644 --- a/tensorflow/contrib/py2tf/converters/converter_test_base.py +++ b/tensorflow/contrib/autograph/converters/converter_test_base.py @@ -21,26 +21,31 @@ from __future__ import print_function import contextlib import imp -from tensorflow.contrib.py2tf import utils -from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info +from tensorflow.contrib.autograph import operators +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import pretty_printer +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.platform import test class FakeNamer(object): + """A fake namer that uses a global counter to generate unique names.""" + + def __init__(self): + self.i = 0 def new_symbol(self, name_root, used): - i = 0 while True: - name = '%s%d' % (name_root, i) + self.i += 1 + name = '%s%d' % (name_root, self.i) if name not in used: return name - i += 1 def compiled_function_name(self, original_fqn, @@ -52,26 +57,51 @@ class FakeNamer(object): return ('renamed_%s' % '_'.join(original_fqn)), True +class FakeNoRenameNamer(FakeNamer): + + def compiled_function_name(self, original_fqn, **_): + return str(original_fqn), False + + class TestCase(test.TestCase): """Base class for unit tests in this module. Contains relevant utilities.""" @contextlib.contextmanager def compiled(self, node, *symbols): - source = '' + source = None + + self.dynamic_calls = [] + def converted_call(*args): + """Mock version of api.converted_call.""" + self.dynamic_calls.append(args) + return 7 + try: result, source = compiler.ast_to_object(node) - result.tf = self.make_fake_tf(*symbols) - result.py2tf_utils = utils + result.tf = self.make_fake_mod('fake_tf', *symbols) + fake_ag = self.make_fake_mod('fake_ag', converted_call) + fake_ag.__dict__.update(operators.__dict__) + fake_ag.__dict__['utils'] = utils + result.__dict__['ag__'] = fake_ag yield result except Exception: # pylint:disable=broad-except - print('Offending compiled code:\n%s' % source) + if source is None: + print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False)) + else: + print('Offending compiled code:\n%s' % source) raise - def make_fake_tf(self, *symbols): - fake_tf = imp.new_module('fake_tf') + def make_fake_mod(self, name, *symbols): + fake_mod = imp.new_module(name) for s in symbols: - setattr(fake_tf, s.__name__, s) - return fake_tf + if hasattr(s, '__name__'): + setattr(fake_mod, s.__name__, s) + elif hasattr(s, 'name'): + # This is a bit of a hack, but works for things like tf.int32 + setattr(fake_mod, s.name, s) + else: + raise ValueError('can not attach %s - what should be its name?' % s) + return fake_mod def attach_namespace(self, module, **ns): for k, v in ns.items(): @@ -83,6 +113,7 @@ class TestCase(test.TestCase): namer=None, arg_types=None, include_type_analysis=True, + owner_type=None, recursive=True): node, source = parser.parse_entity(test_fn) ctx = context.EntityContext( @@ -92,7 +123,9 @@ class TestCase(test.TestCase): namespace=namespace, arg_values=None, arg_types=arg_types, - recursive=recursive) + owner_type=owner_type, + recursive=recursive, + type_annotation_func=utils.set_element_type) node = qual_names.resolve(node) node = activity.resolve(node, ctx) node = live_values.resolve(node, ctx, {}) diff --git a/tensorflow/contrib/py2tf/converters/decorators.py b/tensorflow/contrib/autograph/converters/decorators.py similarity index 57% rename from tensorflow/contrib/py2tf/converters/decorators.py rename to tensorflow/contrib/autograph/converters/decorators.py index 3f620c1cd2d9b75f82410754a7e812e13eabe3ae..92445f31746cf94856ea43893f99a2ba60355fb5 100644 --- a/tensorflow/contrib/py2tf/converters/decorators.py +++ b/tensorflow/contrib/autograph/converters/decorators.py @@ -24,8 +24,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import pretty_printer +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import pretty_printer class DecoratorsTransformer(gast.NodeTransformer): @@ -33,6 +33,7 @@ class DecoratorsTransformer(gast.NodeTransformer): def __init__(self, remove_decorators): self.remove_decorators = remove_decorators + self.additional_dependencies = set() # pylint:disable=invalid-name @@ -44,13 +45,38 @@ class DecoratorsTransformer(gast.NodeTransformer): dec_func = dec.func else: dec_func = dec + + # Special cases. + # TODO(mdan): Is there any way we can treat these more generically? + # We may want to forego using decorators altogether if we can't + # properly support them. + if isinstance(dec_func, gast.Name) and dec_func.id in ('classmethod',): + # Assumption: decorators are only visible in the AST when converting + # a function inline (via another decorator). + # In that case, the converted function is no longer part of the + # original object that it was declared into. + # This is currently verified by tests. + continue + if not anno.hasanno(dec_func, 'live_val'): raise ValueError( 'Could not resolve decorator: %s' % pretty_printer.fmt(dec_func)) + dec_value = anno.getanno(dec_func, 'live_val') if dec_value not in self.remove_decorators: - kept_decorators.append(dec) - node.decorator_list = kept_decorators + kept_decorators.append((dec, dec_value)) + + for _, dec_value in kept_decorators: + if dec_value.__module__ == '__main__': + raise ValueError( + 'decorator "%s" was not allowed because it is declared ' + 'in the module "%s". To fix this, declare it in a separate ' + 'module that we can import it from.' % (dec_value, + dec_value.__module__)) + else: + self.additional_dependencies.add(dec_value) + + node.decorator_list = [dec for dec, _ in kept_decorators] return node # pylint:enable=invalid-name @@ -59,4 +85,4 @@ class DecoratorsTransformer(gast.NodeTransformer): def transform(node, remove_decorators): transformer = DecoratorsTransformer(remove_decorators) node = transformer.visit(node) - return node + return node, transformer.additional_dependencies diff --git a/tensorflow/contrib/autograph/converters/decorators_test.py b/tensorflow/contrib/autograph/converters/decorators_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9c01f689127dbedad7669c65b03e7da071b2d64d --- /dev/null +++ b/tensorflow/contrib/autograph/converters/decorators_test.py @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for decorators module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from functools import wraps + +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import decorators +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.python.platform import test + + +# The Python parser only briefly captures decorators into the AST. +# The interpreter desugars them on load, and the decorated function loses any +# trace of the decorator (which is normally what you would expect, since +# they are meant to be transparent). +# However, decorators are still visible when you analyze the function +# from inside a decorator, before it was applied - as is the case +# with our conversion decorators. + + +def simple_decorator(f): + return lambda a: f(a) + 1 + + +def self_removing_decorator(removing_wrapper): + def decorator(f): + @wraps(f) + def wrapper(*args): + # This removing wrapper is defined in the test below. This setup is so + # intricate just to simulate how we use the transformer in practice. + transformed_f = removing_wrapper(f, (self_removing_decorator,)) + return transformed_f(*args) + 1 + return wrapper + return decorator + + +class DecoratorsTest(converter_test_base.TestCase): + + def _remover_wrapper(self, f, remove_decorators): + namespace = { + 'self_removing_decorator': self_removing_decorator, + 'simple_decorator': simple_decorator + } + node = self.parse_and_analyze(f, namespace) + node, _ = decorators.transform(node, remove_decorators=remove_decorators) + result, _ = compiler.ast_to_object(node) + return getattr(result, f.__name__) + + def test_noop(self): + + def test_fn(a): + return a + + node = self.parse_and_analyze(test_fn, {}) + node, deps = decorators.transform(node, remove_decorators=()) + result, _ = compiler.ast_to_object(node) + + self.assertFalse(deps) + self.assertEqual(1, result.test_fn(1)) + + def test_function(self): + + @self_removing_decorator(self._remover_wrapper) + def test_fn(a): + return a + + # 2 = 1 (a) + 1 (decorator applied exactly once) + self.assertEqual(2, test_fn(1)) + + def test_method(self): + + class TestClass(object): + + @self_removing_decorator(self._remover_wrapper) + def test_fn(self, a): + return a + + # 2 = 1 (a) + 1 (decorator applied exactly once) + self.assertEqual(2, TestClass().test_fn(1)) + + def test_multiple_decorators(self): + + class TestClass(object): + + # Note that reversing the order of this two doesn't work. + @classmethod + @self_removing_decorator(self._remover_wrapper) + def test_fn(cls, a): + return a + + # 2 = 1 (a) + 1 (decorator applied exactly once) + self.assertEqual(2, TestClass.test_fn(1)) + + def test_nested_decorators(self): + + @self_removing_decorator(self._remover_wrapper) + def test_fn(a): + @simple_decorator + def inner_fn(b): + return b + 11 + return inner_fn(a) + + with self.assertRaises(ValueError): + test_fn(1) + + # TODO(mdan): Uncomment this test once converter_test_base is updated. + # (can't do it now because it has unrelated pending changes) + # def test_nested_decorators(self): + # + # @self_removing_decorator(self._remover_wrapper) + # def test_fn(a): + # @imported_decorator + # def inner_fn(b): + # return b + 11 + # return inner_fn(a) + # + # # 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn) + # self.assertEqual(14, test_fn(1)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/printing.py b/tensorflow/contrib/autograph/converters/ifexp.py similarity index 50% rename from tensorflow/contrib/py2tf/utils/printing.py rename to tensorflow/contrib/autograph/converters/ifexp.py index 95a62bd80b5f4854e6a062df18d882f7bd495555..616d222762e09feeba1809f119d915dfbe522283 100644 --- a/tensorflow/contrib/py2tf/utils/printing.py +++ b/tensorflow/contrib/autograph/converters/ifexp.py @@ -12,36 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TensorFlow printing support utilities.""" +"""Canonicalizes the ternary conditional operator.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import py_func -from tensorflow.python.ops import logging_ops +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer -def is_tf_print_compatible(value): - # TODO(mdan): Enable once we can reliably test this. - # This is currently disabled because we can't capture the output of - # op kernels from Python. - del value - return False +class IfExp(transformer.Base): + """Canonicalizes all IfExp nodes into plain conditionals.""" + def visit_IfExp(self, node): + template = """ + ag__.utils.run_cond(test, lambda: (body,), lambda: (orelse,)) + """ + desugared_ifexp = templates.replace_as_expression( + template, test=node.test, body=node.body, orelse=node.orelse) + return desugared_ifexp -def call_print(*values): - """Compiled counterpart of the print builtin. - The function attempts to use tf.Print if all the values are compatible. - Otherwise, it will fall back to py_func. +def transform(node, context): + """Desugar IfExp nodes into plain conditionals. Args: - *values: values to print + node: an AST node to transform + context: a context object + Returns: - A dummy value indicating the print completed. If tf. + new_node: an AST with no IfExp nodes, only conditionals. """ - if all(map(is_tf_print_compatible, values)): - return logging_ops.Print(1, values) - return py_func.wrap_py_func(print, None, values, use_dummy_return=True) + node = IfExp(context).visit(node) + return node diff --git a/tensorflow/contrib/autograph/converters/ifexp_test.py b/tensorflow/contrib/autograph/converters/ifexp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ac6849dcb4bd7dacd84bb205f5c65395d8c2f51e --- /dev/null +++ b/tensorflow/contrib/autograph/converters/ifexp_test.py @@ -0,0 +1,106 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ifexp module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import ifexp +from tensorflow.python.platform import test + + +class IfExpTest(converter_test_base.TestCase): + + def compiled_fn(self, test_fn, *args): + node = self.parse_and_analyze(test_fn, {}) + node = ifexp.transform(node, self.ctx) + module = self.compiled(node, *args) + return module + + def test_simple(self): + + def test_fn(x): + return 1 if x else 0 + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + for x in [0, 1]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_fn(self): + + def f(x): + return 3 * x + + def test_fn(x): + y = f(x * x if x > 0 else x) + return y + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + result.f = f + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_exp(self): + + def test_fn(x): + return x * x if x > 0 else x + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_nested(self): + + def test_fn(x): + return x * x if x > 0 else x if x else 1 + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + for x in [-2, 0, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_in_cond(self): + + def test_fn(x): + if x > 0: + return x * x if x < 5 else x * x * x + return -x + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + for x in [-2, 2, 5]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_assign_in_cond(self): + + def test_fn(x): + if x > 0: + x = -x if x < 5 else x + return x + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + for x in [-2, 2, 5]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/converters/list_comprehension.py b/tensorflow/contrib/autograph/converters/list_comprehension.py similarity index 93% rename from tensorflow/contrib/py2tf/converters/list_comprehension.py rename to tensorflow/contrib/autograph/converters/list_comprehension.py index e8744831100e4852919b5cd1253b74acea4d790d..d7f292015164e047d054c5d1fb0b391e960bb73d 100644 --- a/tensorflow/contrib/py2tf/converters/list_comprehension.py +++ b/tensorflow/contrib/autograph/converters/list_comprehension.py @@ -31,9 +31,9 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer class ListCompCanonicalizationTransformer(transformer.Base): diff --git a/tensorflow/contrib/py2tf/converters/list_comprehension_test.py b/tensorflow/contrib/autograph/converters/list_comprehension_test.py similarity index 93% rename from tensorflow/contrib/py2tf/converters/list_comprehension_test.py rename to tensorflow/contrib/autograph/converters/list_comprehension_test.py index 025fac11e41e6771fbb9b80ff3da70dc3ceec73e..4758671f5ec83c26cfa54be0ef68f5f564094f6c 100644 --- a/tensorflow/contrib/py2tf/converters/list_comprehension_test.py +++ b/tensorflow/contrib/autograph/converters/list_comprehension_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import list_comprehension +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import list_comprehension from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py new file mode 100644 index 0000000000000000000000000000000000000000..c15dfff9e8ebd8b96fd4aff82459a6fd7d0ac8ab --- /dev/null +++ b/tensorflow/contrib/autograph/converters/lists.py @@ -0,0 +1,227 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converter for list operations. + +This includes converting Python lists to TensorArray/TensorList. +""" + +# TODO(mdan): Elaborate the logic here. +# TODO(mdan): Does it even make sense to attempt to try to use TAs? +# The current rule (always convert to TensorArray) is naive and insufficient. +# In general, a better mechanism could look like: +# * convert to TensorList by default +# * leave as Python list if the user explicitly forbids it +# * convert to TensorArray only when complete write once behavior can be +# guaranteed (e.g. list comprehensions) + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno + + +# Tags for local state. +POP_USES = 'pop_uses' + + +class ListTransformer(transformer.Base): + """Converts lists and related operations to their TF counterpart.""" + + def visit_List(self, node): + node = self.generic_visit(node) + template = """ + ag__.new_list(elements) + """ + return templates.replace_as_expression(template, elements=node) + + def _replace_append_call(self, node): + assert len(node.args) == 1 + assert isinstance(node.func, gast.Attribute) + template = """ + target = ag__.list_append(target, element) + """ + return templates.replace( + template, + target=node.func.value, + element=node.args[0]) + + def _replace_pop_call(self, node): + # Expressions that use pop() are converted to a statement + expression. + # + # For example: + # + # print(target.pop()) + # + # ... is converted to: + # + # target, target_pop = ag__.list_pop(target) + # print(target_pop) + # + # Here, we just generate the variable name and swap it in, + # and _generate_pop_operation will handle the rest. + # + # Multiple uses of pop() are allowed: + # + # print(tartget.pop(), target.pop()) + # print(tartget.pop().pop()) + # + assert isinstance(node.func, gast.Attribute) + scope = anno.getanno(node, NodeAnno.ARGS_SCOPE) + target_node = node.func.value + + # Attempt to use a related name if can get one. Otherwise use something + # generic. + if anno.hasanno(target_node, anno.Basic.QN): + target_name = anno.getanno(target_node, anno.Basic.QN).ssf() + else: + target_name = 'list' + pop_var_name = self.context.namer.new_symbol(target_name, scope.referenced) + + pop_uses = self.get_local(POP_USES, []) + pop_uses.append((node, pop_var_name)) + self.set_local(POP_USES, pop_uses) + + return templates.replace_as_expression('var_name', var_name=pop_var_name) + + def _replace_stack_call(self, node): + assert len(node.args) == 1 + dtype = anno.getanno( + node.args[0], + 'element_type', + default=templates.replace_as_expression('None')) + template = """ + ag__.list_stack( + target, + opts=ag__.ListStackOpts( + element_dtype=dtype, + original_call=orig_call)) + """ + return templates.replace_as_expression( + template, + dtype=dtype, + target=node.args[0], + orig_call=node.func) + + def visit_Call(self, node): + node = self.generic_visit(node) + + # TODO(mdan): This is insufficient if target is a function argument. + # In the case of function arguments, we need to add the list to the + # function's return value, because it is being modified. + # TODO(mdan): Checking just the name is brittle, can it be improved? + if isinstance(node.func, gast.Attribute): + func_name = node.func.attr + if func_name == 'append' and (len(node.args) == 1): + node = self._replace_append_call(node) + elif func_name == 'pop' and (len(node.args) <= 1): + node = self._replace_pop_call(node) + elif func_name == 'stack' and (len(node.args) == 1): + node = self._replace_stack_call(node) + + return node + + def _generate_pop_operation(self, original_call_node, pop_var_name): + assert isinstance(original_call_node.func, gast.Attribute) + + if original_call_node.args: + pop_element = original_call_node.args[0] + else: + pop_element = parser.parse_expression('None') + # The call will be something like "target.pop()", and the dtype is hooked to + # target, hence the func.value. + dtype = anno.getanno( + original_call_node.func.value, + 'element_type', + default=templates.replace_as_expression('None')) + shape = anno.getanno( + original_call_node.func.value, + 'element_shape', + default=templates.replace_as_expression('None')) + + template = """ + target, pop_var_name = ag__.list_pop( + target, element, + opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape)) + """ + return templates.replace( + template, + target=original_call_node.func.value, + pop_var_name=pop_var_name, + element=pop_element, + dtype=dtype, + shape=shape) + + def _postprocess_statement(self, node): + """Inserts any separate pop() calls that node may use.""" + pop_uses = self.get_local(POP_USES, None) + if pop_uses: + replacements = [] + for original_call_node, pop_var_name in pop_uses: + replacements.extend( + self._generate_pop_operation(original_call_node, pop_var_name)) + replacements.append(node) + node = replacements + self.exit_local_scope() + return node, None + + # TODO(mdan): Should we have a generic visit_block instead? + # Right now it feels that a visit_block would add too much magic that's + # hard to follow. + + def _visit_and_process_block(self, block): + return self.visit_block( + block, + before_visit=self.enter_local_scope, + after_visit=self._postprocess_statement) + + def visit_FunctionDef(self, node): + node.args = self.generic_visit(node.args) + node.decorator_list = self.visit_block(node.decorator_list) + node.body = self._visit_and_process_block(node.body) + return node + + def visit_For(self, node): + node.target = self.visit(node.target) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node + + def visit_While(self, node): + node.test = self.visit(node.test) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node + + def visit_If(self, node): + node.test = self.visit(node.test) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node + + def visit_With(self, node): + node.items = self.visit_block(node.items) + node.body = self._visit_and_process_block(node.body) + return node + + +def transform(node, context): + return ListTransformer(context).visit(node) diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9f18ab9f44dd8c3f341a02b950f75317c676eff8 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/lists_test.py @@ -0,0 +1,148 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for lists module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import lists +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class ListTest(converter_test_base.TestCase): + + def test_empty_list(self): + + def test_fn(): + return [] + + node = self.parse_and_analyze(test_fn, {}) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: + tl = result.test_fn() + # Empty tensor lists cannot be evaluated or stacked. + self.assertTrue(isinstance(tl, ops.Tensor)) + self.assertEqual(tl.dtype, dtypes.variant) + + def test_initialized_list(self): + + def test_fn(): + return [1, 2, 3] + + node = self.parse_and_analyze(test_fn, {}) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: + with self.test_session() as sess: + tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2, 3]) + + def test_list_append(self): + + def test_fn(): + l = [1] + l.append(2) + l.append(3) + return l + + node = self.parse_and_analyze(test_fn, {}) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: + with self.test_session() as sess: + tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2, 3]) + + def test_list_pop(self): + + def test_fn(): + l = [1, 2, 3] + utils.set_element_type(l, dtypes.int32, ()) + s = l.pop() + return s, l + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: + result.utils = utils + result.dtypes = dtypes + with self.test_session() as sess: + ts, tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2]) + self.assertAllEqual(sess.run(ts), 3) + + def test_double_list_pop(self): + + def test_fn(l): + s = l.pop().pop() + return s + + node = self.parse_and_analyze(test_fn, {}) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: + test_input = [1, 2, [1, 2, 3]] + # TODO(mdan): Pass a list of lists of tensor when we fully support that. + # For now, we just pass a regular Python list of lists just to verify that + # the two pop calls are sequenced properly. + self.assertAllEqual(result.test_fn(test_input), 3) + + def test_list_stack(self): + + tf = None # Will be replaced with a mock. + + def test_fn(): + l = [1, 2, 3] + utils.set_element_type(l, dtypes.int32) + return tf.stack(l) + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) + node = lists.transform(node, self.ctx) + + with self.compiled(node, array_ops.stack, dtypes.int32) as result: + result.utils = utils + result.dtypes = dtypes + with self.test_session() as sess: + self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/contrib/autograph/converters/logical_expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..3a795a315a3c2aa08ac1577a204102755b6e849c --- /dev/null +++ b/tensorflow/contrib/autograph/converters/logical_expressions.py @@ -0,0 +1,132 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converter for logical expressions. + +e.g. `a and b -> tf.logical_and(a, b)`. This is not done automatically in TF. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer + + +# TODO(mdan): Properly extrack boolean ops according to lazy eval rules. +# Note that this isn't completely safe either, because tensors may have control +# dependencies. +# Note that for loops that should be done after the loop was converted to +# tf.while_loop so that the expanded conditionals are properly scoped. + +# Used to signal that an operand is safe for non-lazy evaluation. +SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND' + + +class LogicalExpressionTransformer(transformer.Base): + """Converts logical expressions to corresponding TF calls.""" + + def __init__(self, context): + super(LogicalExpressionTransformer, self).__init__(context) + # TODO(mdan): Look into replacing with bitwise operators instead. + # TODO(mdan): Skip replacing if the function is trivial. + self.op_mapping = { + gast.And: 'tf.logical_and', + gast.Eq: 'tf.equal', + gast.Gt: 'tf.greater', + gast.GtE: 'tf.greater_equal', + gast.Lt: 'tf.less', + gast.LtE: 'tf.less_equal', + gast.Not: 'tf.logical_not', + gast.NotEq: 'tf.not_equal', + gast.Or: 'tf.logical_or', + gast.USub: 'tf.negative', + gast.Is: 'autograph_utils.dynamic_is', + gast.IsNot: 'autograph_utils.dynamic_is_not' + } + + def _expect_simple_symbol(self, operand): + if isinstance(operand, gast.Name): + return + if anno.hasanno(operand, SAFE_BOOLEAN_OPERAND): + return + raise NotImplementedError( + 'only simple local variables are supported in logical and compound ' + 'comparison expressions; for example, we support "a or b" but not ' + '"a.x or b"; for a workaround, assign the expression to a local ' + 'variable and use that instead, for example "tmp = a.x", "tmp or b"') + + def _matching_func(self, operator): + op_type = type(operator) + mapped_op = self.op_mapping.get(op_type) + if not mapped_op: + raise NotImplementedError('operator %s is not yet supported' % op_type) + return mapped_op + + def _as_function(self, func_name, args): + template = """ + func_name(args) + """ + replacement = templates.replace_as_expression( + template, func_name=parser.parse_expression(func_name), args=args) + anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True) + return replacement + + def visit_Compare(self, node): + node = self.generic_visit(node) + ops_and_comps = list(zip(node.ops, node.comparators)) + left = node.left + op_tree = None + + # Repeated comparisons are converted to conjunctions: + # a < b < c -> a < b and b < c + while ops_and_comps: + op, right = ops_and_comps.pop(0) + binary_comparison = self._as_function( + self._matching_func(op), (left, right)) + if isinstance(left, gast.Name) and isinstance(right, gast.Name): + anno.setanno(binary_comparison, SAFE_BOOLEAN_OPERAND, True) + if op_tree: + self._expect_simple_symbol(right) + op_tree = self._as_function('tf.logical_and', + (binary_comparison, op_tree)) + else: + op_tree = binary_comparison + left = right + assert op_tree is not None + return op_tree + + def visit_UnaryOp(self, node): + node = self.generic_visit(node) + return self._as_function(self._matching_func(node.op), node.operand) + + def visit_BoolOp(self, node): + node = self.generic_visit(node) + node_values = node.values + right = node.values.pop() + self._expect_simple_symbol(right) + while node_values: + left = node_values.pop() + self._expect_simple_symbol(left) + right = self._as_function(self._matching_func(node.op), (left, right)) + return right + + +def transform(node, context): + return LogicalExpressionTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py similarity index 86% rename from tensorflow/contrib/py2tf/converters/logical_expressions_test.py rename to tensorflow/contrib/autograph/converters/logical_expressions_test.py index a28326c517d468230f35e45f0fbfe5257d769895..2814060c4d831e4dddacb3dcbcbe1db42160db20 100644 --- a/tensorflow/contrib/py2tf/converters/logical_expressions_test.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import logical_expressions +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import logical_expressions from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -32,7 +32,7 @@ class GradientsFunctionTest(converter_test_base.TestCase): return a == b node = self.parse_and_analyze(test_fn, {}) - node = logical_expressions.transform(node) + node = logical_expressions.transform(node, self.ctx) with self.compiled(node, math_ops.equal) as result: with self.test_session() as sess: @@ -45,7 +45,7 @@ class GradientsFunctionTest(converter_test_base.TestCase): return (a or b) and (a or b or c) node = self.parse_and_analyze(test_fn, {}) - node = logical_expressions.transform(node) + node = logical_expressions.transform(node, self.ctx) with self.compiled(node, math_ops.logical_or, math_ops.logical_and) as result: diff --git a/tensorflow/contrib/autograph/converters/name_scopes.py b/tensorflow/contrib/autograph/converters/name_scopes.py new file mode 100644 index 0000000000000000000000000000000000000000..dfee529abaa8c14d9b408819b32c5199500a2c2f --- /dev/null +++ b/tensorflow/contrib/autograph/converters/name_scopes.py @@ -0,0 +1,74 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wraps a function body with a `name_scope` of the function name.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer + + +class FunctionNameScopeTransformer(transformer.Base): + """Wrap a function body with a `name_scope` of the function name.""" + + def _name_for_current_scope(self): + innermost = self.enclosing_entities[-1] + if len(self.enclosing_entities) > 1: + parent = self.enclosing_entities[-2] + if isinstance(parent, gast.ClassDef): + # Methods also take the name of their class. + name = '%s/%s' % (parent.name, innermost.name) + else: + name = innermost.name + else: + name = innermost.name + + # Sanitize the name. + # See https://www.tensorflow.org/api_docs/python/tf/Graph#name_scope + # TensorFlow doesn't like leading underscores at the top level. + while name[0] == '_': + name = name[1:] + return name + + def visit_FunctionDef(self, node): + node = self.generic_visit(node) + + unscoped_body = [] + scoped_body = node.body + if scoped_body: + first = scoped_body[0] + if isinstance(first, gast.Expr) and isinstance(first.value, gast.Str): + # Skip any docstring. + unscoped_body = scoped_body[:1] + scoped_body = scoped_body[1:] + + template = """ + with tf.name_scope(scope_name): + body + """ + scoped_body = templates.replace( + template, + scope_name=gast.Str(self._name_for_current_scope()), + body=scoped_body) + node.body = unscoped_body + scoped_body + return node + + +def transform(node, context): + return FunctionNameScopeTransformer(context).visit(node) diff --git a/tensorflow/contrib/autograph/converters/name_scopes_test.py b/tensorflow/contrib/autograph/converters/name_scopes_test.py new file mode 100644 index 0000000000000000000000000000000000000000..17692cbd880dbc1db4bb40ad7345e27907499f9d --- /dev/null +++ b/tensorflow/contrib/autograph/converters/name_scopes_test.py @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for for_canonicalization module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import name_scopes +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.platform import test + + +class FunctionNameScopeTransformer(converter_test_base.TestCase): + + def test_basic(self): + + def test_fn(l): + """This should stay here.""" + a = 5 + l += a + return l + + node = self.parse_and_analyze(test_fn, {}) + node = name_scopes.transform(node, self.ctx) + + with self.compiled(node, ops.name_scope) as result: + result_op = result.test_fn(constant_op.constant(1)) + self.assertIn('test_fn/', result_op.op.name) + + self.assertEqual('This should stay here.', result.test_fn.__doc__) + + def test_long_docstring(self): + + def test_fn(l): + """Multi-line docstring. + + Args: + l: A thing. + Returns: + l + """ + return l + + node = self.parse_and_analyze(test_fn, {}) + node = name_scopes.transform(node, self.ctx) + + with self.compiled(node, ops.name_scope) as result: + self.assertIn('Multi-line', result.test_fn.__doc__) + self.assertIn('Returns:', result.test_fn.__doc__) + + def test_nested_functions(self): + + def test_fn(l): + + def inner_fn(i): + return i ** 2 + + l += 4 + return inner_fn(l) + + node = self.parse_and_analyze(test_fn, {}) + node = name_scopes.transform(node, self.ctx) + + with self.compiled(node, ops.name_scope) as result: + result_op = result.test_fn(constant_op.constant(1)) + first_result_input_name = result_op.op.inputs[0].name + second_result_input_name = result_op.op.inputs[1].name + self.assertIn('test_fn/', first_result_input_name) + self.assertNotIn('inner_fn', first_result_input_name) + self.assertIn('test_fn/inner_fn/', second_result_input_name) + + def test_method(self): + + class TestClass(object): + + def test_fn(self, l): + + def inner_fn(i): + return i ** 2 + + l += 4 + return inner_fn(l) + + # Note that 'TestClass' was needed in the namespace here. + node = self.parse_and_analyze( + TestClass, {'TestClass': TestClass}, owner_type=TestClass) + node = name_scopes.transform(node, self.ctx) + + with self.compiled(node, ops.name_scope) as result: + result_op = result.TestClass().test_fn(constant_op.constant(1)) + first_result_input_name = result_op.op.inputs[0].name + second_result_input_name = result_op.op.inputs[1].name + self.assertIn('TestClass/test_fn/', first_result_input_name) + self.assertNotIn('inner_fn', first_result_input_name) + self.assertIn('TestClass/test_fn/inner_fn/', second_result_input_name) + + def test_operator(self): + + class TestClass(object): + + def __call__(self, l): + + def inner_fn(i): + return i ** 2 + + l += 4 + return inner_fn(l) + + # Note that 'TestClass' was needed in the namespace here. + node = self.parse_and_analyze( + TestClass.__call__, {'TestClass': TestClass}, owner_type=TestClass) + node = name_scopes.transform(node, self.ctx) + + with self.compiled(node, ops.name_scope) as result: + result_op = result.__call__(TestClass(), constant_op.constant(1)) + first_result_input_name = result_op.op.inputs[0].name + second_result_input_name = result_op.op.inputs[1].name + self.assertIn('call__/', first_result_input_name) + self.assertNotIn('inner_fn', first_result_input_name) + self.assertIn('call__/inner_fn/', second_result_input_name) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/converters/side_effect_guards.py b/tensorflow/contrib/autograph/converters/side_effect_guards.py similarity index 92% rename from tensorflow/contrib/py2tf/converters/side_effect_guards.py rename to tensorflow/contrib/autograph/converters/side_effect_guards.py index 30976b3ec6db5a6607023ac804d9d54cfb296190..3bcb2d3c42c6e0663c8f78523199a364b6ac231f 100644 --- a/tensorflow/contrib/py2tf/converters/side_effect_guards.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards.py @@ -36,12 +36,12 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import ast_util -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class SymbolNamer(object): @@ -160,8 +160,8 @@ class SideEffectGuardTransformer(transformer.Base): [alias_map.get(s, s).ast() for s in guarded_args], None) template = """ - with py2tf_utils.control_dependency_on_returns(call): - aliased_guarded_args = py2tf_utils.alias_tensors(guarded_args) + with ag__.utils.control_dependency_on_returns(call): + aliased_guarded_args = ag__.utils.alias_tensors(guarded_args) """ control_deps_guard = templates.replace( template, @@ -172,7 +172,7 @@ class SideEffectGuardTransformer(transformer.Base): alias_map = {} template = """ - with py2tf_utils.control_dependency_on_returns(call): + with ag__.utils.control_dependency_on_returns(call): pass """ control_deps_guard = templates.replace(template, call=node.value)[-1] diff --git a/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py similarity index 97% rename from tensorflow/contrib/py2tf/converters/side_effect_guards_test.py rename to tensorflow/contrib/autograph/converters/side_effect_guards_test.py index 463db2e770213ba9636d2537b095a77dece5d8f6..ce0ce33243a1352107eb8121050ee76474869809 100644 --- a/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import side_effect_guards +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import side_effect_guards from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/autograph/converters/single_return.py b/tensorflow/contrib/autograph/converters/single_return.py new file mode 100644 index 0000000000000000000000000000000000000000..bcc9ca9dfeb00ef2d2e60edf6a1abfba19a1bad7 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/single_return.py @@ -0,0 +1,317 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Canonicalizes functions with multiple returns to use just one.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno + + +# TODO(mdan): Move this logic into transformer_base. +class BodyVisitor(transformer.Base): + """Walks breadth- or depth-first the list-of-nodes bodies of AST nodes.""" + + def __init__(self, context, depth_first=False): + self.depth_first = depth_first + self.changes_made = False + super(BodyVisitor, self).__init__(context) + + def visit_nodelist(self, nodelist): + for node in nodelist: + if isinstance(node, list): + node = self.visit_nodelist(node) + else: + node = self.generic_visit(node) + return nodelist + + def visit_If(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + node.orelse = self.visit_nodelist(node.orelse) + if not self.depth_first: + node = self.generic_visit(node) + return node + + def visit_For(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + node.orelse = self.visit_nodelist(node.orelse) + if not self.depth_first: + node = self.generic_visit(node) + return node + + def visit_While(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + node.orelse = self.visit_nodelist(node.orelse) + if not self.depth_first: + node = self.generic_visit(node) + return node + + def visit_Try(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + node.orelse = self.visit_nodelist(node.orelse) + node.finalbody = self.visit_nodelist(node.finalbody) + for i in range(len(node.handlers)): + node.handlers[i].body = self.visit_nodelist(node.handlers[i].body) + if not self.depth_first: + node = self.generic_visit(node) + return node + + def visit_With(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + if not self.depth_first: + node = self.generic_visit(node) + return node + + def visit_FunctionDef(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + self.generic_visit(node) + if not self.depth_first: + node = self.generic_visit(node) + return node + + +class FoldElse(BodyVisitor): + + def visit_nodelist(self, nodelist): + for i in range(len(nodelist)): + node = nodelist[i] + if isinstance(node, gast.If): + true_branch_returns = isinstance(node.body[-1], gast.Return) + false_branch_returns = len(node.orelse) and isinstance( + node.orelse[-1], gast.Return) + # If the last node in the if body is a return, + # then every line after this if statement effectively + # belongs in the else. + if true_branch_returns and not false_branch_returns: + for j in range(i + 1, len(nodelist)): + nodelist[i].orelse.append(ast_util.copy_clean(nodelist[j])) + if nodelist[i + 1:]: + self.changes_made = True + return nodelist[:i + 1] + elif not true_branch_returns and false_branch_returns: + for j in range(i + 1, len(nodelist)): + nodelist[i].body.append(ast_util.copy_clean(nodelist[j])) + if nodelist[i + 1:]: + self.changes_made = True + return nodelist[:i + 1] + elif true_branch_returns and false_branch_returns: + if nodelist[i + 1:]: + raise ValueError( + 'Unreachable code after conditional where both branches return.' + ) + return nodelist + elif isinstance(node, gast.Return) and nodelist[i + 1:]: + raise ValueError( + 'Cannot have statements after a return in the same basic block') + return nodelist + + +def contains_return(node): + for n in gast.walk(node): + if isinstance(n, gast.Return): + return True + return False + + +class LiftReturn(transformer.Base): + """Move return statements out of If and With blocks.""" + + def __init__(self, context): + self.changes_made = False + self.common_return_name = None + super(LiftReturn, self).__init__(context) + + def visit_If(self, node): + # Depth-first traversal of if statements + node = self.generic_visit(node) + + # We check if both branches return, and if so, lift the return out of the + # conditional. We don't enforce that the true and false branches either + # both return or both do not, because FoldElse might move a return + # into a branch after this transform completes. FoldElse and LiftReturn + # are alternately run until the code reaches a fixed point. + true_branch_returns = isinstance(node.body[-1], gast.Return) + false_branch_returns = len(node.orelse) and isinstance( + node.orelse[-1], gast.Return) + if true_branch_returns and false_branch_returns: + node.body[-1] = templates.replace( + 'a = b', a=self.common_return_name, b=node.body[-1].value)[0] + node.orelse[-1] = templates.replace( + 'a = b', a=self.common_return_name, b=node.orelse[-1].value)[0] + return_node = templates.replace('return a', a=self.common_return_name)[0] + self.changes_made = True + return [node, return_node] + else: + return node + + def visit_With(self, node): + # Depth-first traversal of syntax + node = self.generic_visit(node) + + # If the with statement returns, lift the return + if isinstance(node.body[-1], gast.Return): + node.body[-1] = templates.replace( + 'a = b', a=self.common_return_name, b=node.body[-1].value)[0] + return_node = templates.replace('return a', a=self.common_return_name)[0] + node = self.generic_visit(node) + self.changes_made = True + return [node, return_node] + else: + return node + + def visit_FunctionDef(self, node): + # Ensure we're doing depth-first traversal + last_return_name = self.common_return_name + body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + referenced_names = body_scope.referenced + self.common_return_name = self.context.namer.new_symbol( + 'return_', referenced_names) + node = self.generic_visit(node) + self.common_return_name = last_return_name + return node + + +class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor): + """Throws an error if code returns inside loops or try/except.""" + + # First, throw an error if we detect a return statement in a loop. + # TODO(alexbw): we need to learn to handle returns inside a loop, + # but don't currently have the TF constructs to do so (need something + # that looks vaguely like a goto). + + def __init__(self): + self.cant_return = False + super(DetectReturnInUnsupportedControlFlow, self).__init__() + + def visit_While(self, node): + self.cant_return = True + self.generic_visit(node) + self.cant_return = False + + def visit_For(self, node): + self.cant_return = True + self.generic_visit(node) + self.cant_return = False + + def visit_Try(self, node): + self.cant_return = True + self.generic_visit(node) + self.cant_return = False + + def visit_Return(self, node): + if self.cant_return: + raise ValueError( + '`return` statements are not supported in loops. ' + 'Try assigning to a variable in the while loop, and returning ' + 'outside of the loop') + + +class DetectReturnInConditional(gast.NodeVisitor): + """Assert that no return statements are present in conditionals.""" + + def __init__(self): + self.cant_return = False + super(DetectReturnInConditional, self).__init__() + + def visit_If(self, node): + self.cant_return = True + self.generic_visit(node) + self.cant_return = False + + def visit_Return(self, node): + if self.cant_return: + raise ValueError( + 'After transforms, a conditional contained a `return `statement, ' + 'which is not allowed. This is a bug, and should not happen.') + + +class DetectReturnInFunctionDef(gast.NodeVisitor): + + def visit_FunctionDef(self, node): + self.generic_visit(node) + if not contains_return(node): + raise ValueError( + 'Each function definition should contain at least one return.') + + +def transform(node, context): + """Ensure a function has only a single return. + + This transforms an AST node with multiple returns successively into containing + only a single return node. + There are a few restrictions on what we can handle: + - An AST being transformed must contain at least one return. + - No returns allowed in loops. We have to know the type of the return value, + and we currently don't have either a type inference system to discover it, + nor do we have a mechanism for late type binding in TensorFlow. + - After all transformations are finished, a Return node is not allowed inside + control flow. If we were unable to move a return outside of control flow, + this is an error. + + Args: + node: an AST node to transform + context: a context object + + Returns: + new_node: an AST with a single return value + + Raises: + ValueError: if the AST is structured so that we can't perform the + transform. + """ + # Make sure that the function has at least one return statement + # TODO(alexbw): turning off this assertion for now -- + # we need to not require this in e.g. class constructors. + # DetectReturnInFunctionDef().visit(node) + + # Make sure there's no returns in unsupported locations (loops, try/except) + DetectReturnInUnsupportedControlFlow().visit(node) + + while True: + + # Try to lift all returns out of if statements and with blocks + lr = LiftReturn(context) + node = lr.visit(node) + changes_made = lr.changes_made + fe = FoldElse(context) + node = fe.visit(node) + changes_made = changes_made or fe.changes_made + + if not changes_made: + break + + # Make sure we've scrubbed all returns from conditionals + DetectReturnInConditional().visit(node) + + return node diff --git a/tensorflow/contrib/autograph/converters/single_return_test.py b/tensorflow/contrib/autograph/converters/single_return_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d483005a09537ea8227814f65aa7e6402c853f60 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/single_return_test.py @@ -0,0 +1,189 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for single_return module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import single_return +from tensorflow.python.framework.ops import name_scope +from tensorflow.python.platform import test + + +class SingleReturnTest(converter_test_base.TestCase): + + def compiled_fn(self, test_fn, *args): + node = self.parse_and_analyze(test_fn, {}) + node = single_return.transform(node, self.ctx) + module = self.compiled(node, *args) + return module + + def test_noop(self): + # Noop + def test_fn(x): + return x + + with self.compiled_fn(test_fn) as result: + self.assertEqual(test_fn(2.0), result.test_fn(2.0)) + + def test_return_expression(self): + # ANF + def test_fn(x): + return x * x + + with self.compiled_fn(test_fn) as result: + x = 2 + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_merge(self): + # Simple merge + def test_fn(x): + if x > 0: + return x + else: + return x * x + + with self.compiled_fn(test_fn) as result: + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_orphan_branch(self): + + def test_fn(x): + if x > 0: + return x + + with self.assertRaises(ValueError): + self.compiled_fn(test_fn) + + def test_lift_body_into_false_branch(self): + + def test_fn(x): + if x > 0: + return x + return x * x + + with self.compiled_fn(test_fn) as result: + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_lift_body_into_true_branch(self): + + def test_fn(x): + if x < 0: + x *= x + else: + # TODO(alexbw): linter bug here that requires us suppress this warning. + return x # pylint: disable=undefined-loop-variable + return x + + with self.compiled_fn(test_fn) as result: + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_nested_if(self): + + def test_fn(x): + if x > 0: + if x < 5: + return x + else: + return x * x + else: + return x * x * x + + with self.compiled_fn(test_fn) as result: + for x in [-2, 2, 5]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_context_manager(self): + + def test_fn(x): + + with name_scope(''): + return x * x + + with self.compiled_fn(test_fn) as result: + result.name_scope = name_scope + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_context_manager_in_conditional(self): + + def test_fn(x): + if x > 0: + with name_scope(''): + return x * x + else: + return x + + with self.compiled_fn(test_fn, name_scope) as result: + result.name_scope = name_scope + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def text_conditional_in_context_manager(self): + + def test_fn(x): + with name_scope(''): + if x > 0: + return x * x + else: + return x + + with self.compiled_fn(test_fn) as result: + result.name_scope = name_scope + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_no_return(self): + + def test_fn(x): + x *= x + + with self.compiled_fn(test_fn) as result: + self.assertEqual(test_fn(2), result.test_fn(2)) + + def test_nested_functiondefs(self): + + def test_fn(x): + + def inner_fn(y): + if y > 0: + return y * y + else: + return y + + return inner_fn(x) + + with self.compiled_fn(test_fn) as result: + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_loop(self): + + def test_fn(x): + for _ in range(10): + return x + return x + + with self.assertRaises(ValueError): + self.compiled_fn(test_fn) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/contrib/autograph/converters/slices.py new file mode 100644 index 0000000000000000000000000000000000000000..85aeda9c4164eb70329bd50f789eea5441c8fc87 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/slices.py @@ -0,0 +1,83 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converter for slice operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer + + +class SliceTransformer(transformer.Base): + """Converts slicing operations to their TF counterpart. + + Currently, relying on the default slice operator that Tensor uses is + insufficient, because TensorArray and tensor lists use dedicated index read + and write functions. + """ + + def _process_single_assignment(self, target, value): + if not isinstance(target, gast.Subscript): + return None + + template = """ + target = ag__.set_item(target, key, item) + """ + return templates.replace( + template, target=target.value, key=target.slice, item=value) + + def visit_Assign(self, node): + node = self.generic_visit(node) + # TODO(mdan): Support unpackings and multiple assignments. + if len(node.targets) != 1: + raise NotImplementedError('multiple assignment') + replacement = self._process_single_assignment(node.targets[0], node.value) + if replacement is not None: + return replacement + return node + + def visit_Subscript(self, node): + node = self.generic_visit(node) + if not isinstance(node.slice, gast.Index): + # TODO(mdan): It might make more sense to wave them through. + raise NotImplementedError('non-index slice') + + if not isinstance(node.ctx, gast.Load): + # Index writes are handled at a higher level, one at which the rvalue is + # also available. + return node + + dtype = anno.getanno( + node.value, + 'element_type', + default=templates.replace_as_expression('None')) + + template = """ + ag__.get_item( + target, + key, + opts=ag__.GetItemOpts(element_dtype=dtype)) + """ + return templates.replace_as_expression( + template, target=node.value, key=node.slice, dtype=dtype) + + +def transform(node, context): + return SliceTransformer(context).visit(node) diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/contrib/autograph/converters/slices_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2d7e1ea1a6c46fcc3a2c6972a24507646ef858 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/slices_test.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for slices module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import slices +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SliceTest(converter_test_base.TestCase): + + def test_index_access(self): + + def test_fn(l): + utils.set_element_type(l, dtypes.int32) + return l[1] + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) + node = slices.transform(node, self.ctx) + + with self.compiled(node, dtypes.int32) as result: + result.utils = utils + result.dtypes = dtypes + with self.test_session() as sess: + tl = list_ops.tensor_list_from_tensor( + [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32)) + y = result.test_fn(tl) + self.assertEqual(2, sess.run(y)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..d62390494b78c415212ba91ac914cdfee324f971 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb @@ -0,0 +1,1919 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Dev Summit 2018 - Autograph", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [ + { + "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K", + "timestamp": 1522238054357 + }, + { + "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ", + "timestamp": 1521743157199 + }, + { + "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-", + "timestamp": 1520522344607 + } + ], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python2", + "display_name": "Python 2" + } + }, + "cells": [ + { + "metadata": { + "id": "g7nGs4mzVUHP", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Experimental: TF Autograph\n", + "**TensorFlow Dev Summit, 2018.**\n", + "\n", + "This interactive notebook demonstrates **autograph**, an experimental source-code transformation library to automatically convert TF.Eager and Python code to TensorFlow graphs.\n", + "\n", + "**Note: this is pre-alpha software!** The notebook works best with Python 2, for now.\n", + "\n", + "> ![alt text](https://lh3.googleusercontent.com/QOvy0clmg7siaVKzwmSPAjicWWNQ0OeyaB16plDjSJMf35WD3vLjF6mz4CGrhSHw60HnlZPJjkyDCBzw5XOI0oBGSewyYw=s688)\n", + "\n", + "### Table of Contents\n", + "1. _Write Eager code that is fast and scalable._\n", + "2. _Case study: complex control flow._\n", + "3. _Case study: training MNIST with Keras._\n", + "4. _Case study: building an RNN._" + ] + }, + { + "metadata": { + "id": "uFcgBENZqkB2", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Install TensorFlow; note that Colab notebooks run remotely, on virtual\n", + "# instances provided by Google.\n", + "!pip install -U -q tf-nightly" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Pa2qpEmoVOGe", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import os\n", + "import time\n", + "\n", + "import tensorflow as tf\n", + "from tensorflow.contrib import autograph\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import six\n", + "\n", + "from google.colab import widgets" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "ZVKfj5ttVkqz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# 1. Write Eager code that is fast and scalable\n", + "\n", + "TF.Eager gives you more flexibility while coding, but at the cost of losing the benefits of TensorFlow graphs. For example, Eager does not currently support distributed training, exporting models, and a variety of memory and computation optimizations.\n", + "\n", + "Autograph gives you the best of both worlds: write your code in an Eager style, and we will automatically transform it into the equivalent TF graph code. The graph code can be executed eagerly (as a single op), included as part of a larger graph, or exported." + ] + }, + { + "metadata": { + "id": "snaZRFdWd9ym", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "For example, autograph can convert a function like this:" + ] + }, + { + "metadata": { + "id": "9__n8cSIeDnD", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def g(x):\n", + " if x > 0:\n", + " x = x * x\n", + " else:\n", + " x = 0\n", + " return x" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "gq0eQcuReHET", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "... into a TF graph-building function:" + ] + }, + { + "metadata": { + "id": "sELSn599ePUF", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 413 + }, + "outputId": "bb0c7216-1ca3-4da1-d1fb-589902cdcd1a", + "executionInfo": { + "status": "ok", + "timestamp": 1522345737505, + "user_tz": 240, + "elapsed": 243, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "print(autograph.to_code(g))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "from __future__ import print_function\n", + "import tensorflow as tf\n", + "from tensorflow.contrib.autograph.impl import api as autograph_api\n", + "from tensorflow.contrib.autograph import utils as autograph_utils\n", + "\n", + "def tf__g(x):\n", + " with tf.name_scope('g'):\n", + "\n", + " def if_true():\n", + " with tf.name_scope('if_true'):\n", + " x_1, = x,\n", + " x_1 = x_1 * x_1\n", + " return x_1,\n", + "\n", + " def if_false():\n", + " with tf.name_scope('if_false'):\n", + " x_1, = x,\n", + " x_1 = 0\n", + " return x_1,\n", + " x = autograph_utils.run_cond(tf.greater(x, 0), if_true, if_false)\n", + " return x\n", + "\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "j74n-8hEe6dk", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "You can then use the converted function as you would any regular TF op -- you can pass `Tensor` arguments and it will return `Tensor`s:" + ] + }, + { + "metadata": { + "id": "AkVaY0-dfEbH", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 53 + }, + "outputId": "4ffe3757-c44d-424c-c2a8-7ddc973bfcce", + "executionInfo": { + "status": "ok", + "timestamp": 1522345737841, + "user_tz": 240, + "elapsed": 257, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "tf_g = autograph.to_graph(g)\n", + "\n", + "with tf.Graph().as_default(): \n", + "\n", + " g_ops = tf_g(tf.constant(9))\n", + "\n", + " with tf.Session() as sess:\n", + " tf_g_result = sess.run(g_ops)\n", + "\n", + " print('g(9) = %s' % g(9))\n", + " print('tf_g(9) = %s' % tf_g_result)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "g(9) = 81\n", + "tf_g(9) = 81\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "trrHQBM1VnD0", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# 2. Case study: complex control flow\n", + "\n", + "Autograph can convert a large chunk of the Python language into graph-equivalent code, and we're adding new supported language features all the time. In this section, we'll give you a taste of some of the functionality in autograph.\n", + "Autograph will automatically convert most Python control flow statements into their correct graph equivalent.\n", + " " + ] + }, + { + "metadata": { + "id": "u0YG3DPgZxoW", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "We support common statements like `while`, `for`, `if`, `break`, `return` and more. You can even nest them as much as you like. Imagine trying to write the graph version of this code by hand:" + ] + }, + { + "metadata": { + "id": "xJYDzOcrZ8pI", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "outputId": "6c244ee4-b141-4ad6-eefa-cfffa71f33c6", + "executionInfo": { + "status": "ok", + "timestamp": 1522345738402, + "user_tz": 240, + "elapsed": 483, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def sum_even(numbers):\n", + " s = 0\n", + " for n in numbers:\n", + " if n % 2 > 0:\n", + " continue\n", + " s += n\n", + " return s\n", + "\n", + "\n", + "tf_sum_even = autograph.to_graph(sum_even)\n", + "\n", + "with tf.Graph().as_default(): \n", + " with tf.Session() as sess:\n", + " result = sess.run(tf_sum_even(tf.constant([10, 12, 15, 20])))\n", + "\n", + " print('Sum of even numbers: %s' % result)\n", + " \n", + "# Uncomment the line below to print the generated graph code\n", + "# print(autograph.to_code(sum_even))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Sum of even numbers: 42\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "_YXo4KOcbKrn", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Try replacing the `continue` in the above code with `break` -- Autograph supports that as well!" + ] + }, + { + "metadata": { + "id": "xHmC0rBIavW_", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "The Python code above is much more readable than the matching graph code. Autograph takes care of tediously converting every piece of Python code into the matching TensorFlow graph version for you, so that you can quickly write maintainable code, but still benefit from the optimizations and deployment benefits of graphs." + ] + }, + { + "metadata": { + "id": "UEHWGpBXbS7g", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Let's try some other useful Python constructs, like `print` and `assert`. We automatically convert Python `assert` statements into the equivalent `tf.Assert` code. " + ] + }, + { + "metadata": { + "id": "qUU57xlEbauI", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 53 + }, + "outputId": "add3db4a-2077-4dd5-f7a7-a5b5a4529c26", + "executionInfo": { + "status": "ok", + "timestamp": 1522345738697, + "user_tz": 240, + "elapsed": 253, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def f(x):\n", + " assert x != 0, 'Do not pass zero!'\n", + " return x * x\n", + "\n", + "tf_f = autograph.to_graph(f)\n", + "with tf.Graph().as_default(): \n", + " with tf.Session() as sess:\n", + " try:\n", + " print(sess.run(tf_f(tf.constant(0))))\n", + " except tf.errors.InvalidArgumentError as e:\n", + " print('Got error message: %s' % e.message)\n", + " \n", + "# Uncomment the line below to print the generated graph code\n", + "# print(autograph.to_code(f))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Got error message: assertion failed: [Do not pass zero!]\n", + "\t [[Node: f/Assert/Assert = Assert[T=[DT_STRING], summarize=3, _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](f/NotEqual, f/Assert/Assert/data_0)]]\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "w5hBZaVJbck4", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "You can also use `print` functions in-graph:" + ] + }, + { + "metadata": { + "id": "6NdzRKLEboRv", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "outputId": "fb82dfc3-790f-4127-87f6-361805be9e9b", + "executionInfo": { + "status": "ok", + "timestamp": 1522345739013, + "user_tz": 240, + "elapsed": 247, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def print_sign(n):\n", + " if n >= 0:\n", + " print(n, 'is positive!')\n", + " else:\n", + " print(n, 'is negative!')\n", + " return n\n", + "\n", + "\n", + "tf_print_sign = autograph.to_graph(print_sign)\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " sess.run(tf_print_sign(tf.constant(1)))\n", + " \n", + "# Uncomment the line below to print the generated graph code\n", + "# print(autograph.to_code(print_sign))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "1 is positive!\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "9u_Z3i3AivLA", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "We can convert lists to TensorArray, so appending to lists also works, with a few modifications:" + ] + }, + { + "metadata": { + "id": "MjhCQJVuiTNR", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "outputId": "dc320b87-595b-4392-d29c-994486fd8a0a", + "executionInfo": { + "status": "ok", + "timestamp": 1522345744470, + "user_tz": 240, + "elapsed": 5391, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def f(n):\n", + " numbers = []\n", + " # We ask you to tell us about the element dtype.\n", + " autograph.utils.set_element_type(numbers, tf.int32)\n", + " for i in range(n):\n", + " numbers.append(i)\n", + " return numbers.stack() # Stack the list so that it can be used as a Tensor\n", + "\n", + "\n", + "tf_f = autograph.to_graph(f)\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " print(sess.run(tf_f(tf.constant(5))))\n", + " \n", + "# Uncomment the line below to print the generated graph code\n", + "# print(autograph.to_code(f))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[0 1 2 3 4]\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "UdG8ZFrkTAF2", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "And all of these functionalities, and more, can be composed into more complicated code:\n" + ] + }, + { + "metadata": { + "id": "DVs6wt8NKaGQ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 53 + }, + "cellView": "code", + "outputId": "0a4b8d08-8f65-4bbc-85ba-dc4c60563519", + "executionInfo": { + "status": "ok", + "timestamp": 1522345745186, + "user_tz": 240, + "elapsed": 658, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def print_primes(n):\n", + " \"\"\"Returns all the prime numbers less than n.\"\"\"\n", + " assert n > 0\n", + " \n", + " primes = []\n", + " autograph.utils.set_element_type(primes, tf.int32)\n", + " for i in range(2, n):\n", + " is_prime = True\n", + " for k in range(2, i):\n", + " if i % k == 0:\n", + " is_prime = False\n", + " break\n", + " if not is_prime:\n", + " continue\n", + " primes.append(i)\n", + " all_primes = primes.stack()\n", + "\n", + " print('The prime numbers less than', n, 'are:')\n", + " print(all_primes)\n", + " return tf.no_op()\n", + "\n", + " \n", + "tf_print_primes = autograph.to_graph(print_primes)\n", + "with tf.Graph().as_default(): \n", + " with tf.Session() as sess:\n", + " n = tf.constant(50)\n", + " sess.run(tf_print_primes(n))\n", + " \n", + "# Uncomment the line below to print the generated graph code\n", + "# print(autograph.to_code(print_primes))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "The prime numbers less than 50 are:\n", + "[ 2 3 5 7 11 13 17 19 23 29 31 37 41 43 47]\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "JQ8kQT99VqDk", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# 3. Case study: training MNIST with Keras\n", + "\n", + "As we've seen, writing control flow in Autograph is easy. So running a training loop in graph should be easy as well!\n", + "\n", + "Here, we show an example of such a training loop for a simple Keras model that trains on MNIST." + ] + }, + { + "metadata": { + "id": "0CrtGWgwuLJr", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import gzip\n", + "import shutil\n", + "\n", + "from six.moves import urllib\n", + "\n", + "\n", + "def download(directory, filename):\n", + " filepath = os.path.join(directory, filename)\n", + " if tf.gfile.Exists(filepath):\n", + " return filepath\n", + " if not tf.gfile.Exists(directory):\n", + " tf.gfile.MakeDirs(directory)\n", + " url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n", + " zipped_filepath = filepath + '.gz'\n", + " print('Downloading %s to %s' % (url, zipped_filepath))\n", + " urllib.request.urlretrieve(url, zipped_filepath)\n", + " with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:\n", + " shutil.copyfileobj(f_in, f_out)\n", + " os.remove(zipped_filepath)\n", + " return filepath\n", + "\n", + "\n", + "def dataset(directory, images_file, labels_file):\n", + " images_file = download(directory, images_file)\n", + " labels_file = download(directory, labels_file)\n", + "\n", + " def decode_image(image):\n", + " # Normalize from [0, 255] to [0.0, 1.0]\n", + " image = tf.decode_raw(image, tf.uint8)\n", + " image = tf.cast(image, tf.float32)\n", + " image = tf.reshape(image, [784])\n", + " return image / 255.0\n", + "\n", + " def decode_label(label):\n", + " label = tf.decode_raw(label, tf.uint8)\n", + " label = tf.reshape(label, [])\n", + " return tf.to_int32(label)\n", + "\n", + " images = tf.data.FixedLengthRecordDataset(\n", + " images_file, 28 * 28, header_bytes=16).map(decode_image)\n", + " labels = tf.data.FixedLengthRecordDataset(\n", + " labels_file, 1, header_bytes=8).map(decode_label)\n", + " return tf.data.Dataset.zip((images, labels))\n", + "\n", + "\n", + "def mnist_train(directory):\n", + " return dataset(directory, 'train-images-idx3-ubyte',\n", + " 'train-labels-idx1-ubyte')\n", + "\n", + "def mnist_test(directory):\n", + " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "2zu1U9Nqir6L", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "First, we'll define a small three-layer neural network using the Keras API" + ] + }, + { + "metadata": { + "id": "x_MU13boiok2", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def mlp_model(input_shape):\n", + " model = tf.keras.Sequential([\n", + " tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n", + " tf.keras.layers.Dense(100, activation='relu'),\n", + " tf.keras.layers.Dense(10, activation='softmax')])\n", + " model.build()\n", + " return model" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Wuqg3H8mi0Xj", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Let's connect the model definition (here abbreviated as `m`) to a loss function, so that we can train our model." + ] + }, + { + "metadata": { + "id": "W51sfbONiz_5", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def predict(m, x, y):\n", + " y_p = m(x)\n", + " losses = tf.keras.losses.categorical_crossentropy(y, y_p)\n", + " l = tf.reduce_mean(losses)\n", + " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n", + " accuracy = tf.reduce_mean(accuracies)\n", + " return l, accuracy" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "035tNWQki9tr", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Now the final piece of the problem specification (before loading data, and clicking everything together) is backpropagating the loss through the model, and optimizing the weights using the gradient." + ] + }, + { + "metadata": { + "id": "CsAD0ajbi9iZ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def fit(m, x, y, opt):\n", + " l, accuracy = predict(m, x, y)\n", + " opt.minimize(l)\n", + " return l, accuracy" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "PcVRIacKjSwb", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "These are some utility functions to download data and generate batches for training" + ] + }, + { + "metadata": { + "id": "RVw57HdTjPzi", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def setup_mnist_data(is_training, hp, batch_size):\n", + " if is_training:\n", + " ds = mnist_train('/tmp/autograph_mnist_data')\n", + " ds = ds.shuffle(batch_size * 10)\n", + " else:\n", + " ds = mnist_test('/tmp/autograph_mnist_data')\n", + " ds = ds.repeat()\n", + " ds = ds.batch(batch_size)\n", + " return ds\n", + "\n", + "def get_next_batch(ds):\n", + " itr = ds.make_one_shot_iterator()\n", + " image, label = itr.get_next()\n", + " x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n", + " y = tf.one_hot(tf.squeeze(label), 10)\n", + " return x, y" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "2zEJH5XNjgFz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "This function specifies the main training loop. We instantiate the model (using the code above), instantiate an optimizer (here we'll use SGD with momentum, nothing too fancy), and we'll instantiate some lists to keep track of training and test loss and accuracy over time.\n", + "\n", + "In the loop inside this function, we'll grab a batch of data, apply an update to the weights of our model to improve its performance, and then record its current training loss and accuracy. Every so often, we'll log some information about training as well." + ] + }, + { + "metadata": { + "id": "UUI0566FjZPx", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def train(train_ds, test_ds, hp):\n", + " m = mlp_model((28 * 28,))\n", + " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n", + " train_losses = []\n", + " train_losses = autograph.utils.set_element_type(train_losses, tf.float32)\n", + " test_losses = []\n", + " test_losses = autograph.utils.set_element_type(test_losses, tf.float32)\n", + " train_accuracies = []\n", + " train_accuracies = autograph.utils.set_element_type(train_accuracies,\n", + " tf.float32)\n", + " test_accuracies = []\n", + " test_accuracies = autograph.utils.set_element_type(test_accuracies,\n", + " tf.float32)\n", + " i = tf.constant(0)\n", + " while i < hp.max_steps:\n", + " train_x, train_y = get_next_batch(train_ds)\n", + " test_x, test_y = get_next_batch(test_ds)\n", + " step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n", + " step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n", + " if i % (hp.max_steps // 10) == 0:\n", + " print('Step', i, 'train loss:', step_train_loss, 'test loss:',\n", + " step_test_loss, 'train accuracy:', step_train_accuracy,\n", + " 'test accuracy:', step_test_accuracy)\n", + " train_losses.append(step_train_loss)\n", + " test_losses.append(step_test_loss)\n", + " train_accuracies.append(step_train_accuracy)\n", + " test_accuracies.append(step_test_accuracy)\n", + " i += 1\n", + " return (train_losses.stack(), test_losses.stack(), train_accuracies.stack(),\n", + " test_accuracies.stack())" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "cYiUQ1ppkHzk", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Everything is ready to go, let's train the model and plot its performance!" + ] + }, + { + "metadata": { + "id": "K1m8TwOKjdNd", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {}, + {}, + {} + ], + "base_uri": "https://localhost:8080/", + "height": 988 + }, + "outputId": "f9d3eef3-5bea-45c1-ddf9-4edee73e4436", + "executionInfo": { + "status": "ok", + "timestamp": 1522345800262, + "user_tz": 240, + "elapsed": 52391, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "with tf.Graph().as_default():\n", + " hp = tf.contrib.training.HParams(\n", + " learning_rate=0.05,\n", + " max_steps=500,\n", + " )\n", + " train_ds = setup_mnist_data(True, hp, 50)\n", + " test_ds = setup_mnist_data(False, hp, 1000)\n", + " tf_train = autograph.to_graph(train)\n", + " (train_losses, test_losses, train_accuracies,\n", + " test_accuracies) = tf_train(train_ds, test_ds, hp)\n", + "\n", + " with tf.Session() as sess:\n", + " sess.run(tf.global_variables_initializer())\n", + " (train_losses, test_losses, train_accuracies,\n", + " test_accuracies) = sess.run([train_losses, test_losses, train_accuracies,\n", + " test_accuracies])\n", + " plt.title('MNIST train/test losses')\n", + " plt.plot(train_losses, label='train loss')\n", + " plt.plot(test_losses, label='test loss')\n", + " plt.legend()\n", + " plt.xlabel('Training step')\n", + " plt.ylabel('Loss')\n", + " plt.show()\n", + " plt.title('MNIST train/test accuracies')\n", + " plt.plot(train_accuracies, label='train accuracy')\n", + " plt.plot(test_accuracies, label='test accuracy')\n", + " plt.legend(loc='lower right')\n", + " plt.xlabel('Training step')\n", + " plt.ylabel('Accuracy')\n", + " plt.show()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/autograph_mnist_data/train-images-idx3-ubyte.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/autograph_mnist_data/train-labels-idx1-ubyte.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/autograph_mnist_data/t10k-images-idx3-ubyte.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/autograph_mnist_data/t10k-labels-idx1-ubyte.gz\n", + "Step 0 train loss: 2.244329 test loss: 2.2499208 train accuracy: 0.12 test accuracy: 0.161\n", + "Step 50 train loss: 0.64771986 test loss: 0.56013924 train accuracy: 0.82 test accuracy: 0.836\n", + "Step 100 train loss: 0.49011207 test loss: 0.42143965 train accuracy: 0.84 test accuracy: 0.879\n", + "Step 150 train loss: 0.3768609 test loss: 0.39319593 train accuracy: 0.88 test accuracy: 0.883\n", + "Step 200 train loss: 0.36007702 test loss: 0.37089333 train accuracy: 0.9 test accuracy: 0.881\n", + "Step 250 train loss: 0.182115 test loss: 0.28543878 train accuracy: 0.94 test accuracy: 0.915\n", + "Step 300 train loss: 0.2119576 test loss: 0.22305593 train accuracy: 0.92 test accuracy: 0.93\n", + "Step 350 train loss: 0.12932214 test loss: 0.29057172 train accuracy: 0.96 test accuracy: 0.906\n", + "Step 400 train loss: 0.22937602 test loss: 0.2200287 train accuracy: 0.92 test accuracy: 0.925\n", + "Step 450 train loss: 0.23444137 test loss: 0.19857481 train accuracy: 0.94 test accuracy: 0.94\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAe8AAAFnCAYAAACPasF4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzs3XmAFNW9Pvynlt5mYdhmQMHggnGN\nS9zCD0ElKug1edUY9ZoQTYze3GuiRk1uYjRqRHNj4n5NrhKjiUYlbihGQFRUFDSoKIvgICAO6+xL\n711V5/2jlq7qZaZnpnumZ3g+/zjTXV1dXSP91PecU+dIQggBIiIiGjLkwT4AIiIi6h2GNxER0RDD\n8CYiIhpiGN5ERERDDMObiIhoiGF4ExERDTEMb6JeOOigg3DllVdmPf6rX/0KBx10kGe766+/3rPN\ne++9h9mzZwMAtm3bhkMPPdR57osvvsCPfvQjzJw5EzNnzsTZZ5+NV199FQBw0003YdasWZg1axYO\nO+wwnHLKKc7v4XDY8x7JZBLz58/v9edavXo1Lr300oK2XbBgAebMmdPn97J19/rZs2fjhRde6PO+\niYY7hjdRL3366aee0Ewmk1izZk3WditXrsQnn3xS0D6vu+46TJs2DYsXL8bixYtxyy234LrrrsPO\nnTtxyy23YNGiRVi0aBHGjRuH3//+987vVVVVnv188sknfQrUI444Ag8//HBB2y5fvhxTpkzp83vZ\n+vt6oj0Zw5uol0444QQsWbLE+f3tt9/GV77ylaztrrnmGtx+++0F7bO+vh5HHnmk8/uRRx6JxYsX\nY/z48QUfV3NzM3784x/jo48+wkUXXQTAbAF48MEHMXPmTOi6jlWrVuHcc8/FrFmzcOaZZ2L58uUA\nzFaB0047DQBw//334ze/+Q2uuOIKfP3rX8d5552HxsZG533ee+89HHzwwVnv9cEHH+Bb3/oWTjvt\nNJx//vloaGgAAOzevRsXX3wxzjzzTJx66qm4++67cx5rPu+99x7OOecczJo1C9/+9redC6Vc++3u\ncSEE/vd//xczZ87EKaecgjlz5kDXdQDAwoULcdZZZ+GMM87AN77xDbz33nsFn3eiwcDwJuqlM844\nAy+99JLz+z//+U/MmjUr53ZCCCxatKjHfU6fPh1XXnkl/va3v2HTpk0AgHHjxkGSpIKPa+zYsbjm\nmmtw1FFH4YknnnAeF0Jg8eLFUBQFv/71r3HppZdi0aJFuPzyy3HTTTfl3NeiRYtw/fXX49VXX8WY\nMWPw7LPPAgA2bdqE2tpaTJgwwfNe4XAY//mf/4lrrrkGS5Yswfe+9z1cddVVAIBHH30Uxx13HF5+\n+WUsWLAADQ0NMAwj57FmikQiuOqqq3DDDTdg0aJF+OEPf4jrrrsOhmHk3G9jY2Pex1944QUsWrQI\nzzzzDJYsWYKGhgY8+eSTAIBbbrkFDz74IBYuXIibbroJr7/+esHnnWgwMLyJeun444/Hxo0b0dLS\nglgshlWrVmHKlCk5t73++uvxhz/8AYlEott9/v73v8d3vvMdLFiwAGeddRZmzJjhBEt/nXzyyc7P\n8+fPxxlnnAEAOOaYY5zqONOxxx6LCRMmQJIkHHLIIdi5cycAYMWKFTk/6wcffIBx48Zh6tSpAICz\nzjoLX3zxBXbs2IExY8bg7bffxvvvvw+/34+77roLdXV1BR376tWrMX78eBxzzDEAgJkzZ6KtrQ3b\nt2/Pu998jy9duhTf+ta3UF1dDVVV8e1vfxuvvPIKAGDMmDF46qmnsH37dhx77LH45S9/WdjJJRok\n6mAfANFQoygKTj/9dCxcuBCjR4/GiSeeCFXN/U/psMMOw3HHHYdHHnkERx99dN59BgIBXHrppbj0\n0kvR2dmJRYsW4fbbb8fEiRMxbdq0fh3vyJEjnZ8XLFiAv/3tb4hEIjAMA/mWNqiurnZ+VhTFaV5+\n5513cMkll2Rt39nZiYaGBk8LhN/vR2trKy655BIYhoFbbrkFjY2N+M53voOf/OQnBR17a2srRowY\nkXVsLS0tefeb7/Guri48/PDDmDdvHgBA13WMHj0aAPCnP/0Jf/rTn3Duuedir732wvXXX4/jjz++\noGMkGgwMb6I+OPPMM3H33Xdj1KhRPfbZ/vSnP8W5556LiRMn5ny+tbUV69evd6rWESNG4Pzzz8ey\nZctQX1/f7/C27d69GzfccAOefvppHHLIIfj8888xc+bMgl+vaRrWrFmT8yKkrq4O+++/P5577rmc\nr7388stx+eWXY8uWLbjsssucSronY8aMQXt7u/O7EAIdHR0YM2YMVFXNud+pU6fmfLyurg4zZszA\nd7/73az3+dKXvoTf/va3MAwD8+fPx7XXXotly5YVeGaIBh6bzYn64Oijj0ZjYyM2btzYY4VWV1eH\n73znO7j//vtzPh+Px3HllVd6wmLr1q34+OOPceyxx/bquFRVRTgczllRt7a2oqKiAvvvvz80TXMq\n0EgkUtC+V69ejYMOOgh+vz/rvY488kg0NTXh448/BgA0NDTgZz/7GYQQ+PWvf4133nkHgBmSY8eO\nhSRJ3R6r7YgjjkBzczNWrVoFwBxfMH78eEycODHvfvM9/vWvfx0vvPACYrEYAOCpp57C888/j9bW\nVnz/+99HOByGLMs48sgjezXWgGgwsPIm6gNJknDaaachFotBlnu+Bv7BD36Ap59+Oudze++9N/70\npz/hvvvuw5w5cyCEQFVVFX75y196RqAX4phjjsEf/vAHTJs2DW+++abnuYMPPhjTp0/HzJkzMWbM\nGPziF7/Ahx9+iNmzZ+O///u/e9y3fYtYvve67777cOuttyISicDn8+Gqq66CJEm48MIL8etf/xq3\n3norhBCYMWMGpkyZgh07dnheryhK1ntWVFTgnnvuwa233opoNIrRo0fjrrvu6na/I0eOzPk4AGzc\nuBHnnHMOADPYb7vtNowePRrTpk3Dt771LSiKAp/Ph9tuu61X551ooElcz5uIiGhoYbM5ERHREMPw\nJiIiGmIY3kREREMMw5uIiGiIYXgTERENMUPmVrGmpq6i7m/UqAq0tUWLus89Ec9j//Ec9h/PYXHw\nPPZfsc9hbW11zsf32MpbVbPvKaXe43nsP57D/uM5LA6ex/4bqHO4x4Y3ERHRUMXwJiIiGmIY3kRE\nREMMw5uIiGiIYXgTERENMQxvIiKiIYbhTURENMQwvImIaNh6443XCt723nvvxI4d23vc7sMP38cN\nN/y8P4fVbwxvIiIalnbu3IFXX11c8PZXXXUt9t57QgmPqHiGzPSoREREvXHXXb/D+vXr8Mgjc2EY\nBnbs2I6dO3fgnnv+iN/+9jdoampELBbDD35wOaZOnYYf//hyXHPNz7F06WuIRML44out2L59G668\n8lpMmTI153u89toSzJv3dyiKgoMOOgS33XYL6us34M47fwefzwe/349bbvktdu7cnvVYdXXuqU8L\nsceGd0c4gfc3NOLYg+sG+1CIiIa9f7z+GVZuaCzqPo87uA7nz5ic9/l///fZeO65f+D7378MDz/8\nIDQthT/+8c9oa2vF8cd/DWeccRa2b9+GG2/8BaZOneZ5bWPjbvzhD/fh3XeX44UXns0Z3tFoFA89\n9AAeeeQJVFRU4Oc//yneffddvPzyyzjnnPMwa9a/4YMPVqK1tQUvv7wg6zGGdx9ceecbaO2M46ZL\njsOk8X0/gURENDQccshhAIDq6hFYv34dXnzxOUiSjM7OjqxtjzjiKABAXV0dwuFwzv01NHyBiRO/\nhIqKCgDA0Ucfg/Xr1+PEE0/CH/7wP2ho+AJf//ppmDRp35yP9cceGd5b23YiPOFNSMnD0dwRZ3gT\nEZXY+TMmd1slDwSfzwcAWLJkETo7O/HAA39GZ2cnfvjD2VnbKkp6gREhRM79SZL3OU1LQZJCOPbY\n4/HnP/8Ny5cvw5w5N+PHP74652Nf/eqxff4se2R4f7ztCyjVbTBG70RLZ3ywD4eIiEpAlmXoup71\neHt7O/baa2/Isow333wdqVSqT/vfZ59J2LbtC0SjEVRUVGLVqg9x1VU/xrPPzsOUKSfi9NPPgBAC\n9fUbsGXLpqzHGN69dPykA7G4CZArO9DSwfAmIhqOJk3aD59+ugH33XcnKiurnMdPPnkGfvGLa/DJ\nJ2vxb//2TdTV1eGRR+b2ev+hUAhXXHEVrr32J5AkGUcccRSOPfZY7NzZghtv/AWqqqrg8/lw/fU3\nob7+06zH+kMS+doDykxTU1dR93fjit+ipTOCQyLn4yfnHlHUfe9Jamuri/632dPwHPYfz2Fx8Dz2\nX7HPYW1t7m7dPfY+7y+P2Q+SL4mmcOtgHwoREVGv7LHhPbFmPACgLdk2yEdCRETUO3tseI8JjQIA\nxBFGPKkN8tEQEREVbs8N74rRAADJH+egNSIiGlL22PAeW2FW3pI/xtvFiIhoSNljw3uME96svImI\naGjZY8M75AvCLwcg+eNoZuVNRDQs9WZJUNtHH32ItjbvnUjlsAyo2x4b3gAwMlDDypuIaJjq7ZKg\ntn/+88Ws8C43e+QMa7a6ijFojDWiqSt7UnoiIhra3EuCXnDBRbj99lvQ1dUFXddx9dU/w+TJB+Lx\nxx/Fm28uhSzLmDp1Gg455FAsW/YGtmzZjDlz7sD48eOz9pu5DOjVV1/nLANaWRkCIJdkGVC3PTy8\nxwItQKfWPtiHQkQ0rD332UtY1bimqPs8uu4rOHfyWXmfdy8J+uijf8YJJ/w/fOMbZ2PLls24994/\n4J57/oinnnoc8+cvgqIomD//WRx33NcwefKXcc01P88Z3LmWAf3ww/fx1ltLcc4552H27AuxaNHr\nJVkG1G2PDu/a0FgAQAysvImIhrM1a1ajvb0Nixe/DABIJMzu0pNP/jquvvq/cNpps3D66bN63E+u\nZUDr6zc4S362tOzClCknlWQZULc9OrzrKszwTildMISALEmDfERERMPTuZPP6rZKLjWfT8VPf/oz\nHH64dy2L6677JbZu/Ryvv74EP/nJf+Chh/7a7X5yLQMaCAScJT/XrFlZsmVA3fboAWt25Y1gFNE4\nZ1kjIhpO3EuCHnro4XjrrTcAAFu2bMZTTz2OcDiMRx6Zi0mT9sX3v38ZqqtrEI1G8i4lCniXAQWA\nVas+xEEHHYpnn52Hzs4OfPOb38QFF1yE+voNzmOnn36G81ix7NGV96hgDSQhQw5EEYmnUBXyDfYh\nERFRkbiXBP3hD3+E2267Gf/1Xz+EYRi4+urrUFVVhfb2Nlx22fcQClXg8MOPwIgRNTjqqK/ihhv+\nG7/97Z3Yf/8DPPvMtQzokUcehVgsihtv/AVGjaoBIJdkGVC3PXZJUHvZtp++fgvicQM/O+pa7L/3\niKK+x56ASwj2H89h//EcFgfPY/9xSdABEpCCkNQUIvHUYB8KERFRQfb48A4qIUiqhs4oJ2ohIqKh\nYY8P7wo1BABoj4YH+UiIiIgKs8eHd6XPvFevIxEZ5CMhIiIqzB4f3iMClQCAjjjDm4iIhoY9PrxH\nVZgj+Xa1tw3ykRARERVmjw/v0RXm7WE7OjrQHk4M8tEQERH1bI8P70qfOWBNUpNYvallkI+GiIio\nZwxvn9nnDSWFpvbY4B4MERFRAUo6Peodd9yBDz74AJqm4T/+4z9w+umnO88tX74cd911FxRFwfTp\n03HFFVeU8lDysm8Vk9QUWjvZbE5EROWvZOH97rvvYuPGjZg3bx7a2tpwzjnneMJ7zpw5ePjhhzFu\n3Dh897vfxcyZMzF58uRSHU5eITVo/qBoaOviRC1ERFT+Shbexx13HI44wlx6bcSIEYjFYtB1HYqi\noKGhATU1Ndhrr70AACeddBJWrFgxKOHtV/wAAJ9foK2NlTcREZW/koW3oijOYuXPPPMMpk+fDkVR\nAABNTU0YPXq0s+3o0aPR0NDQ7f5GjaqAqipFPcba2mqM1M3K2+8XaI8kMXZsFSSu690r+SbOp8Lx\nHPYfz2Fx8Dz230Ccw5IvCfrqq6/imWeewV/+8pd+7aetLVqkIzLZK78IISBLMiTFQCKpY+u2NlQG\nuTRoobgKUf/xHPYfz2Fx8Dz237BYVWzZsmX4v//7P8ydOxfV1ekDqKurQ3Nzs/P77t27UVdXV8pD\nyUuSJPhlP2TVXHi9jYPWiIiozJUsvLu6unDHHXfgwQcfxMiRIz3PTZw4EeFwGNu2bYOmaVi6dCmm\nTp1aqkPpkV/xAbIZ3h2R5KAdBxERUSFK1mz+8ssvo62tDVdffbXz2AknnICDDjoIp512Gm6++WZc\ne+21AIAzzzwT++23X6kOpUd+xY9kyhxpHo5xXW8iIipvJQvvCy64ABdccEHe54877jjMmzevVG/f\nKwHFjw6YS4IyvImIqNzt8TOsAYBf9kMXZmhHGN5ERFTmGN4w+7wNGIBksPImIqKyx/BGeqIWyDrC\ncYY3ERGVN4Y3zD5vAGZ4s/ImIqIyx/CG2ecNAKrPYJ83ERGVPYY3rPu8AYRCEitvIiIqewxvpPu8\nQyEgHNMG+WiIiIi6x/BGus87GABiCQ26YQzyEREREeXH8Ea68g5YS3tH46y+iYiofDG8Afhls89b\nVc2KO57UB/NwiIiIusXwRrryllUBwGw6JyIiKlcMbwABJQAAzrKgrLyJiKicMbwBhFQzvCXVrLjj\nSVbeRERUvhjeAIKKNVJNNu/xjiVYeRMRUflieAMIWpW3IZkVd4yVNxERlTGGN4CQGgIAGJJZecdZ\neRMRURljeAMIWgPWdCQBsM+biIjKG8MbgCqrUCQFmhXe7PMmIqJyxvAGIEkSgmoAKWGFNytvIiIq\nYwxvS1AJImkkAABxTtJCRERljOFtCaoBJHQrvDlJCxERlTGGtyWkBpHQk1BkNpsTEVF5Y3hbgkoQ\nAgKBoOCtYkREVNYY3hZ7opZgCIiyz5uIiMoYw9sSVM0pUitCQCSWGuSjISIiyo/hbQlZ85sHQwJJ\nzUAixaZzIiIqTwxvi115B4IGAFbfRERUvhjeFrvP2+c3wzvM8CYiojLF8LbYzeYqw5uIiMocw9ti\nV96yz+zrZngTEVG5YnhbglblLavmbWIMbyIiKlcMb4tdeUNheBMRUXljeFtC1mhzQzJDm+FNRETl\niuFtCTK8iYhoiGB4W+w+b3tNb85vTkRE5YrhbfHJKmRJdtb0TunGIB8RERFRbgxviyRJCClBZ01v\nneFNRERliuHtElQDiGlxKLLEypuIiMoWw9slqAYR1xJQFRmaJgb7cIiIiHJieLsErWZzRQE0Vt5E\nRFSmGN4uITUAAQHVLxjeRERUthjeLj7Fb/5XNRjeRERUthjeLn7ZBwCQVYGUzj5vIiIqTwxvF5+s\nAgAU1YCm9b/ybutK4MEX16G5I9bvfREREdkY3i4+xay8FaU4fd5PvFqP9z7Zjb8u3NDvfREREdkY\n3i4+u9ncZ0ArQrN5PKl7/ktERFQMDG8Xu89bUQwYQsAw2O9NRETlh+HtYjebS4rZZM5Z1oiIqBwx\nvF2c0eayGdq8XYyIiMoRw9vF7vO2K+9i9HsTEREVG8PbxWk2tyvvItwuRkREVGwlDe/6+nqceuqp\nePzxx7OemzFjBi666CLMnj0bs2fPxu7du0t5KAWxK2/I5ujwfjebC1buRERUfGqpdhyNRnHrrbdi\nypQpebeZO3cuKisrS3UIvebPCG8OWCMionJUssrb7/dj7ty5qKurK9VbFF1Ws3mxwlsqzm6IiIiA\nElbeqqpCVbvf/U033YTt27fjmGOOwbXXXgtJGtyUs6dHFZLdbM5mbyIiKj8lC++eXHnllZg2bRpq\nampwxRVXYPHixZg1a1be7UeNqoCqKkU9htraas/vcf9IAIBqLi6Gqqpg1ja94fObp9enKv3aT7kb\nzp9toPAc9h/PYXHwPPbfQJzDQQvvs88+2/l5+vTpqK+v7za829qiRX3/2tpqNDV1eR4Lx1IAgJSW\nBAA0t4TRVBPo83ukkpq1Pz3rvYaLXOeReofnsP94DouD57H/in0O810IDMqtYl1dXbj00kuRTJoh\nuXLlShx44IGDcSge9mhzQ+KANSIiKl8lq7zXrl2L3/3ud9i+fTtUVcXixYsxY8YMTJw4Eaeddhqm\nT5+OCy64AIFAAIceemi3VfdA8St2n7dZMevs8yYiojJUsvA+/PDD8dhjj+V9/uKLL8bFF19cqrfv\nE6fyBitvIiIqX5xhzUWRFEiQYMCsvDnDGhERlSOGt4skSfApPqfy7uk+7x3hXXjsk38grsUH4vCI\niIgADOJo83Lll33QhTVKvIc+7/s+eghdyTDGVdTi9H1PGYjDIyIiYuWdKagEkDQSAAC9m8p7W2MY\nXckwACBpJAfk2IiIiACGd5bairGIGREEv/oqtic3593ulfcbnJ8lzn9KREQDiOGdYXyFORe7pGpY\nrb2af0N3i/ogT+tKRER7FoZ3hnGV6YVUVPjzbifAe8CJiGhwMLwzjK+oTf8iCquoZTabExHRAGJ4\nZxhfOc75OYEINEPLvaGn8GZ4ExHRwGF4Z6j2V+EHX/4h9I4xgCTQGm/r8TXs8iYiooHE8M5h/5pJ\nMLpGAQCaYq05t/GMV8tTebNXnIiISoHhnYOqSBApc7BaLJV7KVLhSmbeKkZERAOJ4Z2DqsiAYU4+\nl8g7AYsnvYmIiAYMwzsHVZEhDAUAkNBzh3chzeZERESlwPDOQVUkQDfDO5knvHuD4U5ERMXE8M5B\nkiQo1pot21o6cm8kvNsTERENFIZ3Hgp8AICV9TuwsyWS9TxHkhMR0WBheOdhhzdkHZ2R7pvO2SxO\nREQDieGdhyqlwzsX4bpXzBD5lw4lIiIqNoZ3HnZ4S0qe6VFd3EFORERUagzvPHyKz5yIRdaR1Lqv\nrA2w8iYiooHD8M5DlRXAUCApOpKp7KZzd7HNZnMiIhpIDO88fKp1r7esI5nqofJmszkREQ0ghnce\n5ixrKiRFQ0LLUXm7f2blTUREA4jhnYeqyAVX3nqe8GZBTkREpcDwzkOWJXN+c1lHIpljxLn7VjEO\nWCMiogHE8M7DMIQ5YE0WSGiprOe9zeY9lNicw4WIiIqI4Z2HYQhAN+c3j2mJ7rdlnzcREQ0ghnce\nuiGcZUFjqXj2Bp5bxdi5TUREA6eg8F67di2WLl0KALj77rtx8cUX4/333y/pgQ023RAQyQAAIKKH\nu92Wfd5ERDSQCgrvOXPmYL/99sP777+PNWvW4MYbb8R9991X6mMbVIYhIBIhAEBMdGU9z1vFiIho\nsBQU3oFAAPvuuy9ee+01nH/++Zg8eTJkeXi3uJuVdzfh7VmYhM3mREQ0cApK4FgshoULF+LVV1/F\niSeeiPb2dnR2dpb62AaVIQREMggASCLXet7pwK7f1pZzxDkXLCEiolIoKLyvueYaLFiwAD/96U9R\nVVWFxx57DJdcckmJD21w6a5m86Scq8873VS+uy2CpvZY9hZ2djPDiYioiNRCNvra176Gww8/HFVV\nVWhubsaUKVPw1a9+tdTHNqgMwwAMFUJToSvRrOfdlTckkQ5q9zZW5c0KnIiIiqmgyvvWW2/FwoUL\n0d7ejgsvvBCPP/44br755hIf2uD60rhqAIBIhKCrkawAzry3O3ezub1taY6RiIj2TAWF9yeffIJv\nf/vbWLhwIc455xzcc8892Lp1a6mPbVBdcsbB+N7Mg+DTqwFZR0cyo49fSieyxMqbiIgGUEHhbYfP\nG2+8gRkzZgAAkslk6Y6qDFQGfTj56AkIiBEAgMZok+d54b63WxI5VyGxA53ZTURExVRQeO+33344\n88wzEYlEcMghh2D+/Pmoqakp9bGVhZAwP+fOsDe8vROziJwBzcqbiIhKoaABa3PmzEF9fT0OOOAA\nAMDkyZNxxx13lPTAykW1MgotALZ37fY8nll557rXO+lrARSJfd5ERFRUBYV3PB7H66+/jnvvvReS\nJOGoo47C5MmTS31sZWGkbzSAXM3mmaPNvQm9qf1ztO31OvwVYyBaTy71YRIR0R6koGbzG2+8EeFw\nGBdeeCHOP/98NDc344Ybbij1sZWFmmAVhACi1uIkH3zaiBfe3gJkNJvruje869s2AQCUmhb2eRMR\nUVEVVHk3Nzfjrrvucn4/5ZRTMHv27JIdVDmpCKpAVIJm6ACAB55fCwA4cLLrukcS6EqGcdt7D+Oc\nyf+GQ8cchNZ4KwBApHzs8yYioqIqeHrUWCw9g1g0GkUi0f0a18NFZVAFhAxN1z2Pp9y/S8DqjlXY\nEdmFBz5+GADQEm8DAIhkiH3eRERUVAVV3hdccAHOOOMMHH744QCAdevW4aqrrirpgZWLiqAPEBI0\n4Q3vpK65fhMQGQndaoe3prLyJiKioioovM877zxMnToV69atgyRJuPHGG/HYY4+V+tjKgll5S9AN\n74xqKS0d3lLGgDUhhFN5QzYY3kREVFQFhTcA7LXXXthrr72c31evXl2SAyo3duWti+6azYUnoBN6\n0pk+VZJ1DlgjIqKi6vOi3HtKNVkZVCGEbC5U4pIy3GEunAFtABDX4+mnFH2POVdERDQw+hzekiQV\n8zjKVkVQBSBBhze8tYzKO2GkB/DFtHR4S7LOAWtERFRU3Tabn3TSSTlDWgiBtra2kh1UOamw+ryF\n8PZdp3QNAdd2yTzhzT5vIiIqtm7D+4knnhio4yhbiixDEjIMpKC5J2KR3TOsGXkrb7DPm4iIiqzb\n8J4wYcJAHUdZkyUJAgZSWrqpXJJdt4pJQMod3qmoazsDBpjeRERUPH3u8y5EfX09Tj31VDz++ONZ\nzy1fvhznnXceLrjgAjzwwAOlPIx+kyADEEhprn5v1R3eAkmRDu+2RIfn9ULSQEREVCwlC+9oNIpb\nb70VU6ZMyfn8nDlzcP/99+PJJ5/EO++8g88++6xUh9JviiRDSAaSrvD2VN4QSBnp9c2d8DbM0ysk\n721mRERE/VGy8Pb7/Zg7dy7q6uqynmtoaEBNTQ322msvyLKMk046CStWrCjVofSbLCkABOJJVwgr\n3klaUsIV3vF28wctCAAQYOVNRETFU7LwVlUVwWAw53NNTU0YPXq08/vo0aPR1NSUc9tyoMgyJFmg\nI5JuGpdUb+Wtwd1sboV3yhwf+2w3AAAgAElEQVSPzsqbiIiKqeAZ1gbbqFEVUFWlqPusra0uaDtV\nMU+TUFzXOlblLTQVkj/puQu8PWk1m1vhDVkv+L2GouH82QYKz2H/8RwWB89j/w3EORyU8K6rq0Nz\nc7Pz++7du3M2r7u1tUW7fb63amur0dTUVdC2spABCdi2M31vu2SHt+5zqvB9qvZGQ3gHuhJh87lU\nABIAQ9IKfq+hpjfnkXLjOew/nsPi4Hnsv2Kfw3wXAiUdbZ7PxIkTEQ6HsW3bNmiahqVLl2Lq1KmD\ncSgFUWTzNHVE0/3akDXz/m3NvP6pwlhMm+AdnCecZnP2eRMRUfGUrPJeu3Ytfve732H79u1QVRWL\nFy/GjBkzMHHiRJx22mm4+eabce211wIAzjzzTOy3336lOpR+UxUF0IHOqGvaU1UDdBWQzHu4fQhA\nldOn0y/7ENet5nb2eRMRURGVLLwPP/zwbpcNPe644zBv3rxSvX1RqbIV3rH0oDQoGoSuOn3fighA\nkdN98j7Fh6iumE0bDG8iIiqiQWk2H2pUK5S7oq7R5opZeUtOePuhSunwViU1fZ+3zGZzIiIqHoZ3\nAXyKGcrhuN3nLZzK2x6spghvs7kqqxCaFeYyK28iIioehncB/NatYl0xK7xlA5IkzD5v2A/5PM3m\nqqxA6NbvisaVxYiIqGgY3gXwWfeX68KqoJ3bxNzh7Tebyi2qpMIwrN9lnUuTEBFR0TC8C1Dh91k/\nmRHszGvuCm9J+Jy+cQBmFW5V3pJVeXdGk7j/2dXY1hgekOMmIqLhieFdAL/PCm/JmkdNsSpwwzXj\nm65mNJur6cpcMdf0/ufyrVi1sRn3Pbt6AI6aiIiGK4Z3ARTJOk2SgCSlK293szkMNavZ3A53STYr\nb3s98GSKA9iIiKjvGN4FUOxbwCSByqAPvoDVg+2qvCXd22yuuprNoegw2OlNRERFMmQWJhlMslV5\nS5JAZciHsGrAACB0BYmNR0EZ2QRZqvbcKqZIKgAZQpedypuIiKgYWHkXIN1sbuDEr4yHL2D1fRsq\njLbxSG35CoRhB7b9GsXZxu7zdkjSwBw4ERENSwzvAshWEP/7qZNx5tcmweczk9i5jxuArouMZnPV\n2UbKvM+bVTgREfUDw7sAduU9fmwIkiRB8dmVtyu8hchoNndV3jL7vImIqHgY3gWQrSVBDWGGtqxa\no8Vdo811XXjmNreb0IWuAIoGwzDSO2SzORER9QPDuwB2Fa1b4S1Z93m7m80NQzgD2wCkg1xXIUlA\nyuDiJEREVBwcbV4AO5S/6NyGz9o3A0rKfMJwVd6GAclVUctOs7n537iWXguceuetj3dgQm0lDti7\nZrAPhYioLDC8C2D3eS/e+joAQLZWGfMMWDPSk7CYr/Fuc+fqe3AELhqQ4x1OYgkNjy7cAAD4yy9m\nDPLREBGVBzabF0B29WUDgJCyJ2nRDYHHXql3frfDW1LNKj2hJyBghntnJIkHnlsDg6POe6TpRs8b\nERHtYRjeBVAk72kSMMy7vQxvn/fGhnbXa8zntN2T0tsgXZl/UN+Enc2REh0xERENZwzvAmSGNwAr\nuNN93PGk7unztkebG51jobWMN3+Gd05znfeP9YhniIgoG8O7ALKsZD+oe4cLRGIpz+8KXK+xKnQD\n3hHnDO+esWeBiCgbw7sAco7KWxgZ/eAwB1c5r5HdK45Z94lL3srbYHj3iOdoePvX+t247I6l2N0a\nHexDIRpSGN4FyN9s7hV2Vd+K69TaQS/YbN5rDO/h7c8vfQLdEFi2eudgHwrRkMLwLkDmaHMATjXt\n5g7j5vZk+glhbtssbYLkT1cYKY6k7hFH5A9v/PMS9Q3DuwC5Ku9RVaFuX7P4vW3pX6yg362uQ/Co\nt5yHUymGd08Y3kRE2RjeBcjV5z1+VBWqQj4AQCjQw1w3OZrYAVbehWCzORFRNoZ3AZQczeaqrDrL\nfFZX+Lp9vcjRxA4AyZSe83FKY3jvGbhWD1HvMLwLIOf4ZnEv/1kdyg7vQyeNxrUXHoWvHTouu/KW\nNQACb3e8jHe2vwcAWPDOFsxd8ElRj3s4YHYTEWVjeBcgksq+jcW9/Gd1hT/r+ZHVfhy272jzOeE9\nzZI/DskfxxfJT/HEp88CAJ5ftgUr1u0q8pEPnLWbW7BibfGPn5U3EVE2hncBJo3YBwDw5VGTncfM\nZnPz5xGV6fAWmlmFj6kYCQCQ5exmcykQg+RPrzJmrxMOwGmKz2f+ss34+LPmPnyK0rrrHx9j7kvF\nbznggDUiomwM7wJU+6vwwIw7cOa+pzqPqa5Z1/yq7Axei6+ZisTGo3DUhP0BABUBNavZXArEIAVi\nzu+NkXQYdxdWndEkXnznc9z7zOr+faAS6unio7fKObxffGeLs+IZEdFAYnj3guIKbFVWPfNu71NX\nBQCoCYzAdbPOwKTx1QCAiqAv655wSUlB8qfD+4GP/+KsEa7p3YR3JJn3uXKR1Io7gr6cm83nL9uC\ntz7eMdiHMaSV8bUZUVljePeCu59blVQ4y2ZIwMGTRgEAamtCzs8AUBlSs/q8IQlP5d2aaIW692YA\ngN7N7WPhaCrvc+Wi2CPoyzm8iYgGC8O7F7Iqbye7JZxxwpfwzan74rJvHOp5TWXQlzUPOmTDCe+v\n7zPd3F9tAyBrWZV3Uk9i4ZZX0Z7oQGe0/CvvRLHD23U6/ufvH6KxPZZ/40HCC4y+4y1iRH3D8O4F\n9/3e7j5vSQJURcbZ0/ZH7UjvzGuVOZrNIRmQ/HEoIoBzDzwLB4a+AknVIPnj0DIq73/Uv4CXtryC\nFzctQke4/MM7WeRZ49x93vUN7Zj32sai7r8YONlO37HZnKhvGN69oHbT551PZUjN7vOWDEi+JFQj\nCAAwdOt5SUBzVXHtiQ6s2LnS+b0jo897xY6VWLBpUS8/RWkVu/IWGVVtOS7mknnBRURUagzvXvBU\n3pKCQtK7MuiDENnN5lBSkI0AAEDT0o+7+7zvWzU3/RJJdgashQLm/h7f8DQWbX0dulE+M7UVu887\nM6zLsYk6VeRBekREPWF490L2aHMzSLrrtzNvFcucpCUBSQIk3QzvlDUOTZIM6Fafd2ckieZYi/Oa\nqBZzKu+qjBndolr59AMnSthsnuv3cqAxvPuNfd9EvcPw7gXPaHO5h8VILLIsZd/n7TMnaJE0c3IX\n3S5WJQOaYQbBtX98C7rQceDIAwAAsVQMHZEEAMCvKp77qbuS4d5/mCJyH0vxR5tn/l4e4a27Dox9\n3n0nCup8IqJMDO9eUFyBrcqq606xHsoG4X1e8pshbM/GJgzreUkgkdTxxqrt0GWzyq5UK+BX/Ihq\nMcQTZjDqhkBcT8/QFklF+vyZisHdtF30Pu/MyrtMwlvT0sfR3b35ROXglZUNWL+1bbAPg4qI4d0L\n7nW9PQPWemzyywhvnxnMwqq8Dd16Xjbw7Fub8bfFn0JSzI5wvxxEhRpCTIs5FZ4hBLqS6cAO55h7\nfSC5w9tdeacMDXd/+Ces2LEy18sKkt1s3uddFZW72uaANSpniZSOp17biN8/uWqwD4WKiOHdC1kD\n1iw9ZXfdqFDOx42kWXk74S0Z2N5kNoFLqtkRHrDCO6rFnYFRhiEQTrnDe5Arb1d4ufu8t3Y24LP2\nLXh8w9N933eZjjZ3BzYHrFE5K5fWKiouhncvSK5RNYprkpae3Pz94zC+60QkPj3G83gqYTbD233e\nkiTgXApY06X65QBCaghxLY6UZm5oCOFpKh/sZnMtT+WtGVquzXsl84unXAasuQepsc+7H6w/Z+bY\nBiLqHsO7j3yyAvf0qN0J+lVMCnwZRkcthKv/OxaVYRgCuqvytkmqGXw+KYAKXxACAklh9pUbRmaz\n+WBX3q4+by0d3rmWUu2tzLDOvO97sLgvWDjavP/K5aJsOOK5HZ4Y3n2UOT1qTxTFOtVGeluR8iMS\nT0HX0n3e6RdYlbcUQIVaYT1vPmYIIJxKjzAPJ7NDMvMfbCKl4911u5zqvZjczebJZPrnLtcxLnrv\niz7tO/N7Rx/gLyJNN7BkZQOice+88u7AZp93/7Fpt3TKpauJiovh3UfmwiSmQu5R9dnh7VqkROgq\nWjsTcPLUGm0OpPu8VRFASA1ab2pW45l93pnN5l/s7sIPf7cUb3603XnsuTc346EFn2D+si0FfT63\ndVta8doH2/I+7xlt7ro4CLtuYVv4Xu/fF8jRbD7AX0TPv7UZT762EX9fUu953N1Uzmbz/mN4l065\ntFZRcTG8+6jQ+7xtimIlvHuFMV3BLY+uRDhid3obTsVsh7cCPypUc8CbZFXjhiE8TdI7Irs8s6wt\nX7sLAPDU6585jzU0dgEANu3o7NVxA8Cd8z7C35fU521+y9fn7b7/3JD7tiJa1gxrA/w9tHFbBwCg\ntTPheZwD1orD/nOyabd0mN3DE8O7j1RZ6dWiCnblLQz7vxKc028FuiS5m83NKlsRfgTUgPVYesBa\nXDPv8z669itoT3Rgbct656V2FSP7EqhvMwPcp5qj4/vTbJ6vOvKMNk+6wtvVIiAUb/gV/J5Z93kP\nbFC2dZnHPbI64Hnc22wuEI2nsO7z1gE9tuGEAVM6bNUYnhjefaRIhU2PalNVO6itjQ1X5S7Sk7TY\n7CpbNgLpJnopfatYzArv0yedAgB4a9sK57VOv/A+a3DvqofwcdNa+K33T/ajSszXt5tvkhZ35d3X\n8M5s8hvoUcntYfO4R1T4PY+nXIP0NM3AnfM+xp1PfYT6hvYBPb6hzv6nM9AXZW5CCCz9cBt2tQ7u\nfAmlwlaN4Ynh3UeqrOLkoycAAA760qgCtvc2m8vCHd7Wn8E1YE3yJyAMCbLwwWc10UtyepKWmBaD\nLBQ89sIuHDhyf2xo24hdkd3m7uzAC5lBMn/Ty/D5zPdI9WPu8XwDX9yjzd39v+5BdYbct+VMM9/S\nEAKabmTNvFYq9mfOfD8tY5KWLTvN7oimMlxvvJw5zeaD2POweWcnHnulHr+a++7gHUQJsfIenhje\nfeSTFXzntC/jjh9NwWH7ju5xe6fytprNZeSqvN3hHYNIhqDrrv512Wo2N4CYFofQfdi8vRMnjDfv\nH/+0bZP5vN1vnjJHqTdGmyGpZgWZ7EezuZ5jGlAhhGeeb/e0oRHXKHhD6Wt4Zw9Yu/z3b2DO3z7o\n0/56w90FkDkoTcszYI0LbPRNb6vD9zc04sEX1xWlqozEzC6q4VqgsvIenhjefaTKKmRJwtiRuWdP\ny9o+Y7S5e7S63Q9uN5srqg7Jn4RIhKDpBnyKtYqY5K6844BuTtEaUioBAM+/vRHN7bF0FaOkB4nF\nVXOFMvfgqriW7hMvRGbl3RZvx/ee+ylWtb0PqGY4u0MtYbgCW+1bs3n2gDXzd7vSLaXWrvT88cmM\nFgv3eXT/LGekdyyhIZbo/2Q1A03TDWzd1TVg79fb6vCP89fivU92Y3cBTd2vf7itV5/FMATunPeR\n526NoYyV9/DE8O4j91SphVCt0eb2JC2q5FrW0x6wZjWLT5xo7lskQkhpRlazOWCGt6GZj8swt4+m\n4nhx+efpK20lHRqblXcANenp835iwzO4d9VD+KhpbUGfQc+oPjd1fI6ElsCyliUIHvkG4Is74W0I\nA5qhQYHVV9zHyjuzz3sgFwGx108Huq+8NU/l7Q3vK+5+C1fc/VaJjrB0HlrwCW55dOWA9eH3tTp0\n5k/Io7E9hsdfqcctj+afXz+ztaSxPYZ1W1rx10Wf9umYyg2ze3gqaXjffvvtuOCCC3DhhRdi9erV\nnudmzJiBiy66CLNnz8bs2bOxe/fuUh5K0Vz+le/hrP1metb2LoRdeUtWda1KKn71PWu61IxmcyVo\n9puKRAU03Ug3m9vN6rIBXegwUnZ4p5/fvKPTudIWcgp1FWMBAElEoY773FMl2iPUN1rN7T3pbrIH\nSTGgjGx0giypm8EXRJV1Avp2q1jSSHpaEIq95Gh3uqLp982cRU3zDFhzDTTsY7N5S6wVN7xzOza0\nbuzbDors/Q2NAFBQZVsMfa0OMy8oMxXy/0vmn2y4dX2w8h6eShbe//rXv7B161bMmzcPt912G267\n7basbebOnYvHHnsMjz32GMaNG1eqQymqI2sPxxn7fb3Xr3Oaza3wliDjgL1roMiS0w/ujDb3m1+Y\nduWdsjPErrytMBO6VZFb64VLio4dzRGzipEMQNYxOjAKFx30LfP5jACt9JnN7YVOr6plfAnYI95t\nysgmZxR2QrdmiDMqrffuW+W9UnseoWNegz20qbezRdW3fYZH1j2BVB/mWe+KuirvjLECqTxzm/e1\ngnz1izfRlmjH3DWPFbS9EAJPLKnHui2lvT2tKuTreaMi6Gu+9HSPfUE5LHX765DHPu/hqWThvWLF\nCpx66qkAgAMOOAAdHR0Ih8M9vGr4UuzR5s7tZVbftyJD2KPN7T5t1aq8k0GkdAML3ramFrUGrNnL\nhcIKbwjF83wkrjkBH1KDOGj0gZ7nbVU+c0BbOJk7vNvi7Xh03ZOQrIuJzCrHvtf8lJFnw4hXQK5u\ncyrUlNXfrYgghC47y6D2VhhmOMk1zX16/b2rHsL7uz/CxwV2Dbh1uirvzJDwDNJzN6FrfWz+tVpy\nNFHYRcb2pghe/WAb7pz3UZ/erzvukfWZF2yl4q4OOyNJrN3ckndbz/H11I3ShzJ6uGXdnlh5G0Lg\nd3//EP9c8flgH0rJ9G6asF5obm7GYYcd5vw+evRoNDU1oaqqynnspptuwvbt23HMMcfg2muvzeov\ndBs1qgKq2rum6p7U1lYXdX/dGdlsNT/azeaKitraavhUGYmUt887FJKBlFlZ+/wqWtpTwCjXJC5W\neAvdrIpGjrA+hxXOmiGchU1GVY/AXnXWrWzW/u3PXREIAl1AzIjmPBfLPnkbK3evQuAIGfH3T0f1\niJBnu+QXZrhVh6ogkgHIwSg0w0BtbTVi7eaAMkX2QST9gJrM+R7vbVuFUcEafHns/p7Ho/EUKoLp\nqk8Zux1GR61nm978/Xyh3v+9U64vPUOSPK/3B9LHJrv6XYMVfmc7dytBT+8dCJj/FDVDK+g4O+Lp\nC7Fi/3/c1pluUQm5Pk8pqT7FeZ9fPLQEja1R3Hftydhv75qsbSOx9EVVVXXQeV2u44y7rrnyfY6a\n1phnm5he+N9tKGgKpy+cC/k8w+Ezh6NJfNrQjk8b2nHJN78y4O8/IP9mSv4Olsz7ZK+88kpMmzYN\nNTU1uOKKK7B48WLMmjUr7+vb2orb91ZbW42mpoEbTdvVZX1BWAEqdKCpqQuyLOWYpMX6YhYyOrvi\nUCUFCddrneZva8BaW6s5ktsO/65Iup9Y0hR0tVnPWzO0NTZ2QpIkdMbMintXuAlf7GyCX/Z5+vI7\nwzFnv1IwjJaWCJpC5nsmkjpefOdTqOOARFQ4rQApI4mmpi7s6jAHOukpCdD8kIKRrPNtCAN3vvMQ\nAOCBGXc4j2/Y2oY7nlyF807e32yokAClphkpyfBML9ubv19LR5dn+22NYUgSMKG2yrPdZ9s7cO/T\nH+Pq849EY0u6RSIWT3le3+EKuIireb2tPepsF0+mq+jujvWtxmVY9Nkbzu/23ydTMqXj3U9247iD\n69DWnv73UKz/j1es24VdLVEctl/61sfWtuiA/DuJu85vo9XPvnFLC6p82Y2Dja576Ztawmiq9uf9\n99zSkm7ty/c5OjLOZVNzz68ZSlpb0/8f9/R5Bvp7sVQiroWEBvrzFPsc5rsQKFmzeV1dHZqb002d\njY2NqK1NV05nn302xowZA1VVMX36dNTX1+fazbDh3ELk6vMGkNHnbQ1YU+1FjmVougG/Yo3Ylg0E\nfIrTbG73ecPwNpvHk5qzTYUagk/2eZ5/7q3NeGfNTqfPOqEncd1bv8b1Cx/CZ9s7nGN2z58uBSOe\npuKuWNJpAZAMn3MsQtagG4bTbA5dgdB8kBQdCc3bdJ7Qc98+ttIaLLVw5RanA1JSNchVfR/5HE15\nJ0/59V/+hRsf/lfWds8s/QyRuIZnlm5yms1HVPg8zea6YeCL3el/nO4+b/e98O6R/d01Xc5bu8Dz\ne0TLfaG6YPnneHThBjz56saSNO3OXfAJFiz/HM0d6XM1UPO25+qXzddk7668e1qOtZAxEpnvM9xW\n4ervx+mIJPHBp03FOZgBMtz+hrmULLynTp2KxYsXAwDWrVuHuro6p8m8q6sLl156KZJJ88t85cqV\nOPDAA0t1KGUhff+vNe847D5vKV1NWuEtK1Z1bshIaQYCavo+74BPTt8CZjWb6xrM6t0K51hCd6rz\nkBqCIiuQhAzJev6fK7bi4X+uzxpwFg5twd3/SPehulcFU0Y1oiWRHhxlGMK5QJCFz6m8JUWDpgkk\nrNHmwlAgUubFR1vcezUaTaXf372wiv1FLqvmY0I3L07kEd5+UPMiQcsK5lw6k7mvhA1hYHc0/cVk\nT6aj6Qa6oklUhXwI+BVPiK1YuxtrXQPFtDyD19yz2el5phAzRPbjLbHcg9B2tZih/vmu0t7j3tKR\n/ruUMrzdrXG5Lm7yjST3hHcPo80LGayV+d7D7YvffQ76MjPh//z9Qzzw/JohNfVvrgmlhpuShfdX\nv/pVHHbYYbjwwgsxZ84c3HTTTXjuueewZMkSVFdXY/r06c5tZKNHj+62yXw4kK0Ba3a/tV15T6yt\nAmA1nctWVW6FNwwFKV0gqKbv8w74FSek7cldNN2AJBSn2Tye0JyAt5cTlaB41wuHQEJPoNJeKxyA\n0FTPVbq78lZrt+PvDQ8imdJR39BuTlpih7er8oaiIaUbSOr2iHgZ0Mzwrm/x3pIW1dKh61772/4y\ntS8OjC6zGVcOeQc8aprAw2sfw8+W3YRIKuoJccMQePHtLc5kOJ3J3IG3cMur+M27v3fudbfvCtB0\nga5oCtUVPvhUb3hv3tHh2UfKM2DNtba5PUJd0pHQcg9Ei2vZrQ/5Rv/b/w/phijpl1OLq0uglMud\nunMkZ3jnCdGwK7x7Or5CgjgrvHvY519eXo///r/lPe63XLg/X19Gntu3Cw6lqX/zXSwPJyXt877u\nuus8vx988MHOzxdffDEuvvjiUr59WXEqbzugrdHml5xxMPbfewdeiSsw7AFp9n+FDE0zYOj23Oe6\n2Wwup8MdsKoj4Qp1pG/NspcTlYXqHW2uaBAQ2K9mknO/t4hXpudgR+4Q+cvL6/Gv9Y046/9NgqRo\nELoC3YCn8tZ1A0nDWr5UV2AkzGOYt+kZHD/hSAStVdJirvDuSHRiZMAcnGR/v8jWoDsRr4DQFUhB\n7/GkdANrms1j//mymwEAPzjsIhwz7ij8a/1uzH97C0LHCEABOhPp4HdXa+/sMCfvWLV7DY6qPdwJ\n70RKRziWwoSxlYgndU94B/2utdxlAy0j/gUlPgJSMIIW3Q/AHHyXTBmAZCB45FuYv6kT3z3sW1nn\nM7P1AzAHreVi37FgGKLHirM/8lXeiZSOrmgSY2sKm1WwJ+4gyZWx+YI3Ek+fn54uYgoZaZ35Pj0F\n/turd1rbGVDk8p/nyhPehkAP89rkNZRaJIbSsfZV+f+fN0yMqQlaP3mbzasr/Pi3KftClc1wkkc2\nIpwKQ7Kq8ZRuIJGy/keUzD5vJ9ytyjulGea93q5wloLm1fLY0Bjzd6E4zeZAetWyoBLElV/5sfmg\nrHtmrAqnIhgd9C668q/1Zn/0xoYOc1CcrkI3BISRWXlbzea6DL3xSxBJM7Cjrv5cd3gv3rrUqdad\nudntixFdhYhXWp/JfZtQdoDtipjHZ37BC2cZ1Q5X5e2e6tReS317s9msbs+E12atJmZW3rInxGLW\nQLTbLjsBoTFtSFR/Dv8Bq+GbsAmrxItOU3hS0yH545D8CWzsyD0NbVw3gzKoBHDKPicCQM570qOp\nmPP31Q0BrciVhbs5tdm1drm7JeGOJ1bh539a4Zl5rj/0HirCfBcoUdd0sz1V3oWFd+ZtgIV98Wd2\nKQgh8PFnzWU3Ha773PYn1IbSLWdsNqeiGVUdwG2XneBUywq8k18okgJJ1RD48ofYEdllzaomIaUZ\nSKYEhCFDkg34fa6QFq5lPo2McA5EASFhTMhscpZyVN4AsGZjJ3738GdQjRAgG051J4RAOBVBlTWR\nS6akpluVt2p++en2RDEaNF044W1oMiBk6O3mYEU7oAHvILKPm9bipS3mGIms6V11FUas0hz17k+/\nRtMMyJL3f2FP8Lmmh+1IdDrv51621J4AJ2m9zl533V6UpLrSD58qw3AtwBK3ngv61Zwz7dmfMakZ\nkHxmOLfEWz2f3WZX3idNnIqJVXtnfwaYs9XNee9ObAq+CkBYK6sV98vJHUSeytsVjvZ88u5m6/5w\nh0GuUMn3GfU83RQ5t+1L5a27jyv//jOPb/naXbj3mdX466INPb7nQHJ/hP5c8w2lanYoXWj0FcN7\nAO01phIHpk6D3joOU+r+n+e5zBHGftkHVTFHmyeSulllywZURXaazYWr2VwYCiRfylkARA5GIeuh\n9LzoIqMyt5qk7XlzzIsD3ak8E3oSmqGhUq1AcrN5n6Q9hzoAJDTdDEddhaYJT5/3X15ej664GZS6\nZq+mZr7WDnXAW3kDwEeNa8xNnT5vb+UNAFIo3XSe0DSnYjyq1jzGlGEHp56ezAaAgMD7uz92nks/\nYU9ba430z2hTHFFhhjeQDji7sgr6FShq9rdh0khCNwxsbwxD8iec998VzZ4C2J7oJqQGnb9VKiPk\nP2hcjY5kJ8LyLsjVbVafd3Er77hrBTXPimnWZ3avsFasL3F3tZ85h735Prk/o2dq2j40my98dytW\nb2rJu02+VfIyZVbe9iDGTdtLv2hOb3gGBvbjNoVi/z9XSkPpQqOvGN4D7IpZU3H1cT/AtMO+1O12\nqqzCp5qVdyJlhndNtYq6kaH0wDO72Vw3zIFhAEJfXQqoCUj+BJRUFYQQ+OPzaxCNCUiygNPsbM8X\n7p5iVU734dn93RVqJTsCVIYAACAASURBVPTmCdA7R8OADsBuEk5Bks3Qjic1T5/3Z9s6sGqTGVS6\nZlXyVmVu94UDQDSjv7cl3obGaHO6/9OpvBWIlNns7p5mNZyMQkDgyNrD8c39Z5rnwqpaw7GU83q9\nrQ5CAG82LEdjW8QTRPY99kKyBt9l3F49wmo2B9Jf1vGEBglAwK84I+LdknoKTy/dhKde/wzwpZug\nd4R3ZW1rV95BNQjVuqVPM7zhvXLXh87PytjtVp934V9OWzq+wJ/XPo5wMpJ3tHE8zxzg9mduaDSv\n8tQJG/HytpcKfu/u9NRsnm+kuztce2w2z9hvNK7h6Tc24Z6nP05v002fd+b+3ecvc8rcZmtA11in\ni6w86D20cBS8nyE09Vyxu5XKEcN7gPl9Cg6eNKrb2eQAwCer8Cmy1WyuQ5FVCDkJQ04BkvWl4fR5\n604VDgDKSPP+eilZiVhCw/ufNmXdCy75zdCwQ1FYfeZ25b0zYgbNCL81QYDzeiu8EXFeH4lrnsob\nSFfYuqZ4Xp9wVd6fN5mVynXH/Bhn7GtOpdueaE9/Qcqu+9l17/EDQJc1rWuVr8IJPrvyjsRS6dHq\nsSroLXthV2wXfvXss3jh7S3pE23tLwnzizfzy7raVXl/vH0zVu/YjFhSRzCgQJYkyEqu8E5imTWo\nya68ge7DO6QE0pV3RrN5S7wN1b4qyEKFXNmRNWBtV2Q33tuZf33zf9TPx6rG1bjx5b/i9sdzbxfP\n009rn4+GRnNMgG/CJqxu/xDhWApPLKnvVxO6O0dyZUrmMqw276IwvWs2D8ey++u7azbPvIBwH1Pm\nc01Wd8OoEYFuj2mgGT3cklfwfoZQNbsn9HkP2Axr1DuqYt5fHEtqSGoGAlAQTnXhXflvgGwu4iKE\nq9lcl5wFFeRqs0lQSoUQtkfmWkHv+9IGpD4/DFLAXrnMHDls6GZzvH070spdq8ztIxMAtDmVM2Qd\nMFTEpU4oMEeoR42Uq/K2mrqtCwwtZVW2OZrN12zdBXUsMMJf5dzSFtPi6S8J2T52NT0JnSss7TnZ\nK32V8Cne4AvHNE+fubbjAKhjd0Ie0YpVG9OTB9n3w8eMMO54/37ElNEA9nKe19ROrK94BsrY/fBk\nwyLz/RJnOyPOJVflbUSrIFeEkdCTqAgoiCU0p88bAHZEssM77qq884V3OBnBmNAoSMlKdIR2Q0fK\nEzi3vncnAKAmMAL71UxCwJ7Uxz4uawBdomIbNm34ctYxAN5mczc7HKMJLf33ADD3pbVYs6kNmiHw\nvZkH5XxtT/L1ecuSBEOIvCuCefq8ezlgzT1ffa73BrxVW+b+3bPmuS/0hBDOQL5yC7nM0eZ9NZQC\nkc3mNGh8soqAT0VXxPyySfc3Cydw7C/7ZMqAcC2bKVeYVZIwZIStLys7PNW6bZBrmiFb4W3Ezfu8\nDatvWlENGMLA6uZPEMIIvPCKNWGIYc+/bo149pnNqCJe4am81boGyKN2Oc3Qeko2mxFzhLd7xLs7\nvJ2Kxj2TnD0Lnavytu9Dr/RVOLPI2f3F4VjK2b/QVAgtPdGNh7WNhhS2djag0f+x5+k2YxeSUgT+\n/dMLm8STOkKBdDcBACTWHwe9dbzzGYP281blXaFU5q68dbvPO2Teiw/vrWIpQ0Ncj6PaV4VqqRaS\nBBiBzpxNyvd/NBd3rLwPQgi8/uE2ayY2gdZ4m3ksqubchZApka/Z3AooTRee127aaf5/YfSjedId\nJO4+b7v1J6nluaDoplk7U+aXeFeOkfLdNptrmeGt53zOfftavhYDIQSefXOTZxbDgeAZbd6Ppu+B\nWqSmGPaE+7wZ3mXi9Emn4OBRBzr3OvtkFUG/4vzDE3L6S8eerGREyAy8aELzNM/KlWZ4G7qcbtbU\nXaOiJQEpEDOraWsCFfteclk2oBk6UkYKWjQEZ35Su9ncqnxl64vciFeiPZxIr3AGwDfhMwgrZFMp\nCaOqA1AlMzyT1rSpumFO8iKEOUNb0BXedsUl5HSfd2azPwBENKvyViucCxk7+CKxFKCmK3e4lk1N\nnwcjPSFOPkp2c3IslUDQnx5dDwBC8zvvsaO1AxV2ePviEJoPtYFx6Eh2eia+AbwD1pZ+YIb79tb0\ngCd7lrsqfyWqYN72JwKdeQcP7Yo2YmP7Jjz+Sj2WvN+AjmSXZzIcuTJ7MFVcS2Bly/KsVeeAdEBp\nugHZdZ99NGVdlAT7vmSonmcglWKHd54Q9FbehQ9YE0KgM9q7ZvPsyts1sM89IY/r4iffRcfnu7rw\nzxVbcftj+bs4SsGd17kGBvbEPb/AUFGs1oZyxvAuE//fAWfgJ0dfZt0iBqiy2WxuS0npL2A7qGsq\nrfCOa5B82TN1CUN2+vjcfeKSrEMKRK0mc8nZFgBk1XACUNcl1768fd72hCkiXoH2cDIdrjCb0u3K\nW+gKVEXGiKDZPG/fLhWOpszPofmh6cKpvONaHAnNACCQ8rUDwgxGu9nefTucHUqVvgrzVjtIzoC4\ncCzlDG4TKX+Oyl3Af9BKz2fPJQVr/vdPj4HeYYYnanYh4LcvqlyD6qxz8MTrG8zKXElBCkZhRKsw\n2mfeKpdZfcdc4b16o1khN7anJ5SxZ56r9lVBMsygFJLebRW0rvlT5+ftHebAwUlV5gBJKZQ9Teyz\nG1/EB13L4Nvn06zn7PBKaYZ5+6HFvmjpzz3NIk+zuT1oMpmnP7uvfd66Yc6alykz4LuvvHM3m7tb\nLqJGFz5sXJ31Pok8XROl1t8Ba3Z4r9ncUvKpeYuluwuw7vx10QZnEp5yx/AuM4p137JPVhH0eatl\nN2HIGGmFdyyhQW8dl7Uvs/K2vmxEOoilQBSSqkEkXTNluSpbu8/VHinuft4OTykQMydesSZnCfpV\nnDzCmkFMNiAk3ZqaVIJPlVFTYb5XJGmGVWs4BikQgxGvQDJleJrNkykdck0z9EA7KhP7mHO4Z1T+\nABC1Ku8qfyUkSYJPVhFJxLHgnS1WeNvN5n4AMoQhpcPfl4AywgxLvWW8s09ZT48U3m+vaucWPpEM\nQsTMufn9B6xG20irerIH1bmqe8jmjGxyVbvZzN01GmN85t9nc8fnnr+RHd6bGiLOHPbuPm+7X7/K\nX+lZtz39hZT9ZRxOpPvZd0fN/v0vjzBnN5QrO7NGnO+KmhPb5Ap2O7xSuuFtclfSa8c/99YmvPbB\ntqzX9iTfaPN05e0Nu2ff3ISXln+OsNTsdH/0ps9b13NX3lpGuOVbqx0AYnmazd2tBLvHvYSH1z6O\n19au97z2b1/8H/wHZy+GUyhDCM/FQ8Gv62cVav89GhrD+M2j72PTjg7c98xqRON9v3ArplxjI/py\nwZJM6Xjzox34y8vre964DDC8y4w9baoqKZ7KO4sho6bKbPKOJjSkPj8MX459A0YsPamK0CWn8naW\nEQUgWc3u9qxn9v4As9nZvlXJXXlnjVZXXCPMAVQEVewd2MfaRoMOzemHVhUZFX4zFCNWsOzobIYk\nCYh4BZKajpCSEd6VZr9gZXQ/81hzNJt3psyFEsYEzYrYp/iwszWM55dtMf/B+lyVN2BeaNjH75rn\nXa4I43jpAnNbpB8/cOJIp5lbaH7rIsDU5WswH5fTg+Ls1gF17834rKUBcrV5cWCER2Kczzw3G9q8\nM63t6mqF0GX88dkNCPnM/bv7vO1b9qp9Va7WAyNdfarZYbR5d5vzc7u1GEyNOhpGIgg51JX1ZaZI\n9oWZAXWfDVDqvjB/l1zN5prhad2RrM8djafw0vKt+PuSeuyM7Ma8T5/POV97Lkae+7ztqYTXbmnF\nR5+lBxf+c8VWvLBuGT6vfhnq3uY8+T32eXtmFzN6rLx1XXQ72txTeWu5K2+7p+mJ1z51LpQMYaAj\n1QZlRO5FZwpx3zOr8V93vdVta4emG57lMM337t993plTwP7+yVX46LNmvPnR9l7vq9iefXMTfnTn\nm9jZ4p06ubtBh24bt7Wjrcv8/zVfyLd0xEs6HXFfMbzLjF15GxCe8E6PJbcYCmoqrfCOpwChoEau\n9TRf64bkDFjzfPFat4l5mrqtn9uqV6fvv3Y1J2eFp6x7Xl8RUOFTFXMOckWH4QlvCRX/P3vfGW9H\nVa/9TN/19H5OzknvIR0SEjpEulIFiShYLyI2BEQR9PpD5aJX5d5XQbHAtYAIypULWABpIXRIg5De\nc0pO3XXKej+sMmv2npOQkJAE5vlAOHvKXrNm9jzr356/wfTMWax0xyDt5EUKSdiOhxjTYM+5eRSY\nJjgdgxEYq+w2H3D7YGkmKkxqERuqESB33eR9z/k51OD4GZztI2F5FehIjxDu/mRMxylzRvgxascA\nsf34rgGLndIG8VhnODZGNZaFPvlpkZvgZSphKnG0pVqwrm+DiGl7xENPoQskT5vT8LwDuXXqoBTz\nlhdQPO6rsAXKmMqRmBc/EwCwTYqZ9xeY7CuJg2QroJhF9GT7AxKnQo42MQijeQN0Rt5xU4fteMg5\nebylPwa10idSYXnnfCL58Su348mtS/D0tucwHAghIg8j2DDD30f+/Cf3Bd3PWgO18DU2Fsfx8Pra\nnmFL1uRzOR4JlXYtzXrfXZ33cAlroXFuhYhEtqCG/b4RAReWkRvHlOJbv34Bn//RUwGyGS488Xah\nlogfcC/Du9Uudnd4aMlGAMCK9cFFUVAlL/yaO3uz+O7/vIyb734RQDjJ9/Tn8dWfPov/vPe1sm0H\nGxF5H2Lgcp8e8QJu85OSl+CyKR8RWeeEqEjFDWiqItxXhq4G4rfE8RPWvLxvkYsabznWy+uwY9vx\n5zUPBT4D4Mufqi4AAqjB2vKERevS4eo0EU3xyds0NCTMIHl35ujLl1reXiDmXXRcEVsnpCRhTvXd\nxUNuH+pitejpz+P7v30Zjh20qDXTptYwczcTT2rqwv51drbD3dXC+qYbgOohldBx2xePRW1lDBk7\nA0MxAaIGLG9DYeSt+AI1gfkC/GQ3x4DjemhPt8EhDr551xMAgK5cD4jiwcumEDM1DGR4jbwjXLfC\n8jZTUtzft7x5XH989RjUaC1iO8cga4X6qwfXw8vSmv2/L1+BL972NJ5bSePv/YX+wHGinaylwXE9\nPLVlCfr1jZClCbjl3ZcpAIoL64h/iYXGa10rUIqubA8Gi0P4zSNv4KofP4Wt3ZlhNbdLX7YD2SJW\n71oLrWETVOba91gI47W1PfjRH1/Dd38d7o4W51IdvLTzFQxk/UUst4qD3+1hlf009DYa/y+zvAsS\nebthbnP/M0X1hFUnt9flHqF9xe6M561d9HmRPQHDLYzeLrRS5aJDEKW6GYFF2zBW86ad9J70MC3/\nsORH3tt+1cbesm0HGxF5H2KQyduSyLs53YA5jTNgKiwm66mIWzoqkqZI7DE0FfUVPkm7riLI29ky\nDsUNkwFIwiEy2Uj//0bvWwDoAsHfTv/fHLNMxHJlyzwRM6DrKrW8VQdEcaEpPB6uIcXc5isHX8fD\n6/+B3iJzKRcSKNgutnXmoCoqNnX3omh7Qq5UfAdhMWtOiEYBLhxk+k1c87MleHNzHwaGXMhtTxW9\n6LvMAboA4W5zVodOHJal7hIYGvMU6P6POGNnYaksN0Amb7Yw8VAU4QNSQt6KXqTKdVDgekR0U4Pq\nghCC7Sx5jeTSKBRdZLJ+jTx/6QvL20j6fd+lmDe3vFNGCiopDy2IVquOKch7xY4NAIA7HlwJQgj6\nCiWlSxqXf6WWt0tCrErNRUXSxMBQEUosCzVGX3KqomJd/4aApekRD//x4m345Yrf4cnXaDLQ+m0D\nw8a8S8l73dYB/PjV22GOXClkfVHiiXpdcq/L4C9xo/0N3Lf+TxhM+fFM/rIujXFvJstgtKwXf3O8\n2rUcAwU/L6DUba5YGcSP/Jv/5ap/H4ekKgO5MmRf8HZ6cpcuSDj25DbfuGMA/3X/soDrfTjyPpRy\nuEuH+HZi3p0lLU7DKjhMQyv77FBBRN6HGHi3MZd4Abd5OkGJQybvmKkFpBh1XUVrbYX4m3gaeofY\nKp9ocDtHiAYn/Bz+viGPQpjbHIA1eSn9n5KYNz1GBzQXRHVE85WYqSNp+eP86/q/+brmtonnV+7E\nt3/zIlxbw2CBveTYGF23xDvAPuelal2d0o+LaDSBTKUdxVy1ECBcriInn58vWlzPg6XSfXVDJu8M\nYiodu6gVB+ApLC9AsUXSXqAcDzSWbjHCdlxPiKcomgvXI9gyRInMy6XYi1AR96efuXeHbE7eKSGB\nC9XzrT5meadNSu6EoKScboiGDYgKkqXPhpyYtq2/F45EzqpnsnI6D6ahwnY9IfIiQzc8NFbHA+1n\nTxt5MmY3TAfgl8ABtAFNxsliTd86keCn6wrk05JhyAYA1m4rr4tWNV8NcHcQOvkshGEnt4ltnHxl\nK01uHSvvs7p3LX6+7C68YD8ItXon9Kb1cFwPHgsDFG0XetPGkkF6tIwS/n0EEBDuKRsv8fZIzm8n\nbC1n4e+N5X3lfzyOl1d34bkV5Tr85eMoP9ef1/zfbtX+9hV7Gnep5e0GLO/wY3mf8mSMl5mW73co\nl5lF5H2IgVvepJS848wFzcmbqIgZQfI2NEVYhAAATxUPKIXix39RalmHrDBD3OoyFATd5lTpTYei\nuVAU1gwFQNzUUBEP9oC2XZ6lreI1Fssjju5b1oxcbVt+80iWM3P9J9W0fz2uCkUliM/5BxQzD6J4\nAVc37bxGaDy9pDOb6xKYjFw1g373uv4NsD0HMS3BxufPnY0C8k4BLmxhvYfNkaHQc7oegcoFDVUX\nRdtFX44nDkpa2GyBMsAWXZt6ekA8BU7RD4koiivkTLmLO2kkqTEqhwYA5NwMVI/OPfdCFIlv+e0c\npLHCZHYU8svno4KwzHvdEfK8A5LL18vTcyUTCpK8xpuPQU+IOZRlcLn17xFPJPFpqor/emCZf97d\nSHgOZotlOR+q7ore67sDf4nzygoS8/MBuFXtegSKlYU17Um83hOMsfMXf3+BHpdVemGNewVG+5so\n2g4efHo9rvrxU1i5cRfUip7AsUrA8vYTqoazvPNOAf/+3K349crf7/aa9qTbrVZ24U3mPQPefsxb\nnvdk3F+YD7eYKP246Nr4+6YncNeqe3Y7vr2F63n42h1LcO9j4W11AZQ6YgJW9HCaCNt76LuxtoL+\n/sLc6283Uc0jBKs29r6jxi97i4i8DzHwhDW3JObNLW9D5a5XD5apo7bSJ0VdV8vIuxQyAQXIhoQ8\nCnsgd/m7EjFK3rL1yWO0MVNDMhaU7LSJLb6DJxEpngHD8jCiIeWXAknhQSK3PWX/5vNAR2MaJ8xs\nDYxXYdnqPMnMMjV/PjSnrDOb43rCbc47hf34lTsAUPUzOugY7E1UCtT2CljeQ12w3mBN4FwyTGbN\nP7NsOx54YpMYe9HxROa9fJ9UaIDioj9TxKadg+jNDQKOibVbB0RCG1RPlCxx8t64Nc+6z0neBcVD\nkRRgEDZ+fq3En1SejY5CEqZdA5PF8hXNgc403Tlx0fmk280YEd4WbnnHtLjwLshKev1539LnBNfd\nnwsmj9VtwD82/QuATzAXnzyOnstxxaLW3jyOnch5W+QtkvGYkp6iEtF5T7a89Za1UONZPLL1kcDx\nolQupMd63inikefpPX3hra1Q48GMZyhyzFsi72Es78c2P4nOXDde3PkqVvS8OSxp7l6m1IM14SX8\nz9q7sXGAVkS83WzzjTv9+/R2Er5K8wGyUmfEMG/NviKTc9DVl8fGnYPIOTlsGiwvS1R3Y3kPN/4u\nFs/mW12XAEYe8SMfwWObnsTTr2/H7//xVuixpXjxjU6ahf9WePjmQCAi70MMgZh3wG1OicVQfOvN\nKnWba6qwfADfsm6o9gl+WPIOURIjw7jNOUbUVfqHqwomtlcFysc42cRMXciJcnDxEz6GuKVjYmsD\nHGLjcxeNRnUFHWdBTiJ2Zbc3V3BTkU4YaKyOB65HJOUxy7syYYpriM96TKjQnbuQkoHjEphsMdLX\n9Dh6cr2iZGt2zZHivM6OUXD7a2ATGy/upPrvXBa11G0OACZbbK3fPuhnzGsOirYrVMrkudVVHVA9\nPL9qJ2761QtQDBq394h0P1TXLxdi5L1lexH5okv34d4JVmGQGeT3UQFxNbjwJ7UvT4nZKRiIW5oY\nLzQbpk7H1S/FefkCSTdcn7wNej5L8cm74PrWZU/Od3trjLxL1dOMjlV4YM1DcDwXhAAT26tw1GRa\nG19winCJC7evDs72MaKigTeMCUPfUAGPv7JVkJDcjc4cuRJqRXfA8pZDQNLFis5hgYQzBqphH1zA\naH3t0gLDldzmEnlb4eQtJ/r9v9fuxLLulaH7DVceV3RtaHV+WODxzc8AKPdqEELws9d/jf9dG1yo\n9PTnWRc8EiDm4cgvb7t4+LmNWL2ZlmxmbT+GvCvfF3rMvoCX5xUdF//96p34/gs/KRM7KvXWBGr3\nh1ns8NACfw4c1xNVDH9a81f86onnsKnTv++7s8K3sERBfr/fDUTkfYhBVcOzzbmVoTGZUUV1ETM0\n1ErkPba1UsiE0pPQ40c2paXPJHeYbJmXan5Lx7O9yzaPb6kVC4yBrI3KlIXjjmgX23lTkpipiZcc\nh6uxlxnLJm+pTeDoFkqSd6+6V4xHdpvLMWthgbsa0gkDDdWJwPWUlsNVJM3A9ei1NN5cnaIuccfz\nhCeBqDbuXf0AvcaqMWhPjQheuEv3W9e3EaYSA8nR+f33y+eXzRG3vAH4Cxtmeedt+sKvTib8/TUD\niurhjU19gOJC0VwQx4TrefA8iJg4J29uUaowaRmT7DYXjVmkBZur0wx5Bp4QZ+cNxC1dkLeiOeLe\n9hcly7tAnzdV84TbnBOXqcZD3ea9WXo8IUy6Vy+KEkYK/9nryXK3ugKTkXPe4wsxQ1wDNCcQ81ZT\nvbjnzT8LBb8f3PMq7n70TSxdxWK3hpSAVbMT1sQXBQm6HvGrGmRIuQWDdjl5F11byMNyD4ilxv3K\nDtUT3gWZvLmGQSl6MoMwSAIN8ToAwxPgcKpyj21+MqDBz/NKZPL9xV9XYe2urVjWvRKPbHxMfH7/\nk2tx9zPPIT7zceitawLqdjIxquke4bnY3p3BH59Yi+/9lraslWV4d2a7Qse4L8ixDP9C0cP6Aerp\nKG3y44sJ2VjbtwHLi0/AGE1DII7r4anXt+FP/1obOIaHRGQJYPkdEZu6BErCf/YzuxGl6WFW/HCS\nvgcCEXkfYjihbSEA4JSO42GZ5daAcFUzy7u1jr4oJnVUY1RzhbAeAQh3LHe5A74rm26XasI7R8Dp\nbIPT2eZvl9zQJJeCN1SJRNHvuGVqJs4/kVoZU0ZS17HIqAZQLPjkHbf0gIAMjAIjW7pPU20Ccxpn\nYFrdJKzr34AhbTs7h3TxcsyaK615GtIJk1qBsopcSUb9rPH1qEun/HMxxboYUzVzXSL01wFfxtXQ\nDJhG8GfCSSTjZGEp/uKptT6FUpiaLITj66sXbQ95FhNorvGTDGO6IVnOvshMNu/QlzBbwMiWNyGA\n6rG+6l65d0K27ImniVp2wE+kKuboPbKYWA50h1U7EAzZQ1DsONA5GvYW1pVMc0SiD0+aMxCDxa5X\nJm/umvcGqJiOVrsNg3JrTql0atsQJVtVVYVlXfAYKTAvCl3EObAMSU9/1HI8ufVZPLLhn3A9V5RM\n9bA2nUqImI3sNlfCyrcUV7yMB/dgeXMPSEz1PUCK6olqD368O1gFNZ4RLm0ZWTuHQk7DB0efAcDv\nA5ArOPjt31eL/Rw3PKntjV1BFy8PXcge7KGcjd89/1TZsX99diNyBiVEo3VtoFc5J38lNgRr0guw\nJtM6/lK1uqyUUf+/ax8WvyEZL+54Bc9L/elLUXSL5fr/kuXN8asVv8OPX75dhMe4Vfzwhn/ihy//\nP2zxVkKv2waoDhzPw6/+7w08tGRjwAshpH+55e2RMg+kKpP3btrfdrPnbLgGPwcCEXkfYphcOwE/\nOf67mNVwBCyj/PYYnGBUDzFTQ1XKwq1XHI0vf5hm+fK4LQBBvgH3ouwelC1vosHeMBVeVs5WD24v\nrJyPKUnfhWxqBj588nh89zPzMGMctRZiElnlmfEbs6jbvLBsIYobJ4rtinR+njRycvvx9OsUjyXE\nlMfdrSnPCsubeNTytgzNT3aDH1fk1xC3NMwd7y88eDcxUzOggCa1aFKHXMI8Dbyvugw5NGCqscC2\ns0afiqq+2eLvQHvOgHyqi6JXBCEKWup8z0jcNH0vCCccx0Qmb9MsbE+lMe+Cr3QH1wAhCLjNdU3y\nTngl91+aJy7/6hQMxE0NMS3OzmvT5iu6DZe4UPKViO+a5hOo6kiWN53LJ1/y+8bLMW+e8ObsGAli\nGzDa30Rfwbcqq6sly3Dlb2BNfRquPgRNVaEqCoqk3PImqitCSfLcPrrxMVz1xNegN9FSL9cjwoPh\nZYOLq0DCWgi5Q/WEBSqTN/GYfKtr+6ED9jwljARkHf3BnA3Xc7FhYBNMLyUWMLe8eBs2D/ou7oJt\nU8liR0c2xyxCRn6PLN0kyc8SbM9tx5cfvwlPrH8hMNwEy81wulqhQRM6/4E4t+KhR6WJX2rp619a\n5PXYdBHVP1TwNQXYgpiXBZaSmWx5bx7ahqU7yrPOf7Xy9/jNyj+Iv4u2G9B8v/Wl/8Y1T92EXfle\n/HrFH7Az0yme9VIZ1NV9a6HX00UQJ+A1fesD+6iJwYDbnH+XJ2nYc0+G63riPnLIev6lynUyIvKO\nAADQVPYjUspdedzyVlRPuNJrKmJCwjBgeTOrNpDMEaKqxlGZNANxW0teCDA0VfrkbmoGFEVBY7Xv\n9pXJm7+EYqbGXKBKILNazlavTtPjWpK+znhpkhx3hauJId8t7lLL2zTUACnpsWLgHKauBRc2jNhM\nzYSmqXC8oOXNyVtX9fLYqpQ3YKlWYNOpI09ERW6cv13zr5d7PbTa7XijawOKjg14KtobZcvbpN4F\nkDLL2/OISNrjpMFnvQAAIABJREFULwlFt0EcAx4jb3gaFIVlC3P3eWkSIRfaAc1GB2huQNzS/fun\nOYiZmp87UIhTS5yoILaBIjJSzLsI4mpYsqwbdz9MXZM524/r8nixl6mE09kORSEYcCh5X/Ghqaiu\nCU6vmhhCwaRuV8NQy8ibXoODdFJ+PoOWqNEuNVlhiwsu7AJQFz63vF3J8k5qKcQd2kRGUT3YbJ4H\n7SHEtBgaN58HZ/toAIBDbL8Gmn1Hykj4vyvFQ6Ho4q3e9cg5eaSdVnj9dWIMXUO+KtiWHuZKdw0M\nZei4RJMdRhp6y1rE5vwdD+/6LYrI4c8rngxc85Cdpde1fip01RALKNntrbe+BcegnhAufyyseOn3\ns9T5E17dvAFf+q9nxCLHigWJqbQ3Oq9l93rpb3ht34bAdtlb4HgObNfG13/+HP7th/8Sn29l5ZM3\nPPtdvLDzZfxr6xLkmOVdCHNJMw8cH2NNrDqwWUkMBkrF1vZuxuceuwbLu94Qn8ltb6EHr4mrJAI0\ncY4QgnvefACvdvnhCdvx0McSE4frQX8gEJH3IYzm2gROmNWKL15whPgskE0eAqOEcDVVCcgbkuEs\nbwC1lbEAoafiQWICgNYaP0lNjudyWJLb3M821/06TEk0hYu4AEBVih4XsFRLM+Cl97OaYGVWsuWt\ny+QddJsbuio0vGWYqgFNU6h2t7SY4Mk3pmqUkbec9CeTsxibNN+xEMtbjWXxt/7f04Q1TxPudtNQ\nYeq+Z0V0RXNMZPIOtSCY5e1fqA04BmzH893mAGJxlLnNP3bqBEHufFvOy9JnytMRs3RhvampPjqn\njLy9giU8EKQYR8YblFzGRX9O2Hf15XwrLONkaEzZMcR+3EqLWRo8zd83zix/T6X3z9BUOGD3UnwH\n/d5kXAqTlHTVcwerpG1sIWeb8DIVYpz8he95HqAX4RXi+FjHlTAddqzqsg531PJOm0nYDoSHpugV\nxQKAex8qrKT4XXHPx7Iu6vKOF5vhDVXD3joGAPDcm742+NZeupghjo6BQWZpMsubX6Wa3hUoAyxk\ng7+/jJ1hc6RAgyFCF778bT/05vVAMQEvkxZiQaVqfRxLNvtlc+NHVOH8U/zcDzW1C/IP8ub/eQld\ng9TFbO8YgYQex/qSJjy25xPj7ct+gy8/eQN2ubQpztbuTKjVammmkKQNI0au9Oc4dCxyxjtA3d6y\nbsCDq/8BAPjjW38Wn/FjHdcLvEMAQElI5J230Vvow5Nbl+Dny+4Sn+8azIuZiCzvCACoxfzRRRNw\nxBh/tT4hdQTcwSoU3pwdekwpucdMLaiQJJM3CZKZkDdlkBOpONpqq6T9yxcSlaZvRfKXWMyULT//\n/LpE3tzy1lTNJ9mSxUVx3RHwWMIUb0nKY96moQlXOOCX9nDi0jU1IBwiX4OuKtjUOYTfPbpOfM7L\no/RQ8vYXKLy0SoYWIG95MVOSw6C6UIkuLNiKhCnlNDjCCiCOQd3mhJM3LwVzoai0tj5XcFDgbnMA\nlim7zTUoACxDCyTNAUCB5JDQaC5CwtIRN+j86rU70KmsgWKypKd8TMwDKcThEgdEp5nJil70xXDY\ngi3LuscRQtDv7PLbz7Lvz7t0e6ezCTuTVNr0kxMvw1ktF9Ahs/71pqHCUYKVA2JRKB5PAhhFjEx3\n4OjmuXRqpC588iKosHIezGKNyDsAaLKiYhQB2/QXSACgUMvbIx6G7AzSZgq27Sc2Op4jkTf9jsp4\nWlLCo9v6cixbv0jnmeTpwAekBc6OPhZGcA08vIS6yGWyo99B5X7zrx1L50hx8M+XttA+5ZkiuocG\nxBzpii5i5tzw1iq7oShAYeN4EMeEogDZYkHEkkvj/tttP8FLU5WA0Iw1+XmYE18AJ/A1W/rx6rrt\nbJ4NjKrsQHd+FwaKfqWCHMte2fMmrftnv+MbfrEUP/uLb81yFN2i0DRwPSJCFv7AWNWJ68sJK0RF\n7oVF9JqsbMBtvnEbnfNdhV6YE16A3voWc6F71I3O3iFpvQLENpnbnC0M8g5eWeu3C/WIhy2dQ3h+\nVac/3ihhLcJwiGtxFFfNg9dfH7q9lLwtUyuxvOWENXr7501uxPc/O5+2/pMs7/GtJf5MMMuCn1sr\nt7wbEv5Cg1tIPGv5e5+Zh7PmjQ0dK7e8AYiM5dIMYJJP4SOTzg1+oWR5u92tZePxX8R+9q0MQzWk\nemH/+3gs2ND0snri0fX+NRoh5C3PtxJI6C8JA6gudMUQZXQVSRNtKRqX16q6AuSbZZa3r89OAuSe\nLTjCbQ4Ahkl8kvc0aBpLAJOS5gACm+TAeSNmakjr/uIrhz5R1uQWLDEPXoH1ZlcGac9yzRPhEL5Y\n4q1fu3O7YKMgLF5ueXsKJYo3B/3yqE1bbdz10Dq2ncdXM7Br3gKgCNLjC4BYjL2U9SIUBUjoSVw0\n4TwoTgyKbqOphu4vXP/FGEBUmAodP7f+O6uepIsgT4Preb4YDot5d2W74REP9fE62K4nFp02sSWl\nO2Z5mwkpt4Fu4yEEl1Vf8DnK2QUUbRe/fGgVVm1hjXpcXWyX8wb4dRLHFGI7ikoT2VZv7sMdf10B\nWymI+VWhS25ztsBgdegkmxZz2DkwhKLtQW9eC60mqKrW7/o1y6qqBKRhAdAOaZK1nnPZ78s10Jyg\nZX49OT80kA35/cmu+tfWlau6FdyicJtD8VhIyQfPc+GLqEwxA89mioKuCkVzgqV10m9Qq+wRTXgc\nhzDLm97Hj7Z/Bl6mIuClsl0Pv3/CL9+7c/n/4DtLbsMDT/niMZHlHWFYqHu4Y2aJNWwZWlD3N1Aq\nRh/kptoE6qvi0FQ1QO6TO2rF/08fU4svXzjdj8ejNL5OURvzCZ94Kl08MJd5Q3UCR030CVa4iAGk\nErIrmrfwLL/YMXV+TJx386pImNA1Bc7WsdQqIeUxfuIRVFoVpaeDqRmi5EgJqXU3VCMgvfjZD07B\n5R+YIf7Ww8hb2t+TpEcntNUFd9RtqIqOuGR5z2+eCxAFesNmP6Pe1ZHJ2zR2yaw6a+ozfptXx0Au\nHyTvbGIjI2h6nzVNoeTLXtqJlEvdpooHz6afxS0dST2Jwlv0+lzFFpa3V4gL0RauVpbxBkUSk8hl\nYN+/dscu/Owvy/HEm5ScSYaFW3jZGrdwTN+789yrg4J8XOY2R7IXUF2MVuaCFBOB73hsgJYUcosx\nriawoycLt6hDt1xRiREgb0BkxOftIjziIR9jFmM+IfIKAFoOt2pjL/64lNbzt6aaqQyqwlu32uVu\n81hSkD/XyOdVBbativtJPy/gsZe34ull27F5Fy2Rq02m/aQ/j7vNFYgcCFsqeWT3dzBrY8122mKX\ne0Dyeep2J4SAe43VWAbEU0AKcXGN37l7KdZvH4AxgvUzcHTkXliEpNMIBzZ4GZ+mKgErmkP+zdhM\nuY84hig5zEreroxdImID3+1N57A8abDgFkTCGkqSyVJGUhCrIyzvrK+q6OqA7gT7juvhSWe268F1\nCfNuqMgXJE8Zu8ai7QbG8GrXcmjpXvEbEfu8S4jI+z2GcLe5VPIVEvMOWJYSudek/Bfr/KlNmDra\nJ3MAAUEYDpnc4WkBlzk9p2+5VyV88RiZ8MQChBGVSGarS6IuLnkDPA26ptDEKkUBoIAUEtAhuarZ\nS8ojBIs6TkBV/wwa72OQyZk37pARqJsHVZKrkhYBCb085i27zT2pfj5VojKnKABcFQ3VCUwbXYu5\nExtQHauC5VZBiQ8FMuppqZhfh6omhkQmLHENDOVtFGzfbd6XXI7p09l99VToqsK6zrH5GPMMVNZb\nmj8TBduFoijwhmhoxEYOipmHQhTAlmLezPLut/tgxIPkzc/Vn83h+VWd+PsKSt6lljePLTpMMCb/\n+kLs7CmIuHavthHfff5HUHU6B5br51qIlynJQzFzIt5tKQlk8g6Ia8BVCjBYtUYpefNcjbybD1iD\nzrYxrByPC9FQ8n19K81gbk01U8ubu80JJe+EpQuXdtKK+d4PI0jeXPdAdPDzCsiKen3672lzx4rf\nnS2XWqksROKYoGI7qng+sgUnQJwAkMl6ICBwPMePeceyIIUEAFXyDri49wnfclR0ByAqFM8MzLWm\nKqHlcmD3UYkPQq3sogtqT4NC6Dhkb1fGLre8lVjWJz+jnFjzbkGUivE5cnc14gOpy2CqlvjMcT04\nnoO8mxeqisTVoaiOmGN6fSXfwd3ujkcXAJoNuAZts+zySgJ/n7LjpTkChkmqO0CIyPsww5566IZa\n3oGYd3mdNycbz/MCCWvyQqBUfrB0eygUglhJrbpsrVt6+PG8QQh/iR49tQkfP20irr5oBkzNpCtu\nNv50wixrSiCTN38RVqUs6KqO6sKkQHtUUzVEOdCYmjZ888jrMK5q9LDXaGhqoJZ9XHMdPnDkCNx0\n2Vz/slUFhVVzoearsKDZF27hgh4yKhMJ6JqKL104HfOnUq+CiQR9YRh+0h0tFSOB8h7enAWOgf6h\nYHY9APSxzm10kaMGLG8A0OtYwhT7bOG0ZurZYQRAyTEPnSQAKNA1BQumNWFEFQ3Z7Mr3wUqweHKJ\n5a0YBVimImWrJ6Brip/YptlQFF+qlQghGVVoxW8Z2ibmwHN8qV23zw8ZKUZRWN4WErSch32HbrJY\nrsVkMJnHIGXRf9/Y0oVfPkL7NDudbSDFOAtNcPJmI0pQi7Ml2URj3sxtTsnbRdzSka70UBmjrV1L\nyZ+7r4u8RxD7DRb1XRhwugHFhcZEgyqsBGrScRAiuc0VOW4vJe0x0ti4Y1C4r0lJ7kHRsyl560W6\nwGChB+Fh01ykkiE04AYXWSqzvHUvjuKGSbC3jaLbmSWqN9MFDsnR3vSKS8chk3e2pH4bAPSGLYjN\noNnmhuWTYFKpggIFBafot2FlY/EKcShOnC7CJLc5j6kHLG/NoUTMUGrdK6oHKJS4HZewcj0ahvIX\nOPQ7MnlbkLfT3QJnRzubA7o9bmmR5R1heBT3QN56iaVoGcGENeLJ2xXpv1wmskSqkyGsLWCY5Q0A\nY6voD5vYFsa0BF3VMtEamoFvf+JI3HrF0SXnpeTI5V1jpoZjp7eIuHiVVcmuRRMNW2QYEnlfcfZ0\nfP7caRjTSo/RNZW9YPh1aSJO1VSTQGOqBhWmbJkH57M0/p00E/jwiePQ3ugfo6kKvMFaJDedgOqY\n/7nc7IGjpabc2rcU+oJVOem4GnIFF7miCy3mZ1VziyWQxyBZ+l051vCFuc0NPRgWEYlVnobPnD0F\nNRUxen+IBuJqKHg5KLotXsS6ruITZ0zGNefT+9WT3wU9zsnbEucCAK1yF5qmv+W77l0dLbVJsVDQ\n67Yj1rLJF3MJkZYFAM9gCnBF+twkYjq83iZUDbIKDL0o9NKTajWyeceP+zJLTjGZNeZpuOXf5iPN\nyHv11h68vtFPsgJo8ppX4vZWrBwsNYZ7/rYJBP4C1CU0YU2zCsg4Q+iobGGeJhWEKFA1Rt6eDVM1\nYLOsZrHAqejBC7gPWsNmaJXsGow4aiuo0EvOLhey4fFuIkkFr9vZi9gUKpzCFy5y3NzziEgMEwtX\nISTjoChJ2Y6qYhnlJeENVSXoLw4grqThdnb4izUtaBUX3jgyMA5ZMjUjZYIHcmMAQPEEeRc3TMIM\n71xYmomiWxAxb1GD7eooFF0YqinKHh2X+Cp2IrFRh6J5yBSkeWTkrQ42wB1gZWWqC9vxqKdDs0Fc\ng3lwuOVNv38wa4v5cLvaUB1jybvsGY9behTzjjA89mR5lwovWKaGie30ITthVmvoS5KngHgEQQlR\nibiUEPIu7fTEccX0T6Bi2wkgmUqcefTIYcdqqDra6lOoqYiVfS6j1Hr33dYkKNTBj5dUz2pSScwc\n71tqhqYG6n0BX7iBZ30nDD9cUGZ5l2Sex8JKxbgbnhDELR1nzO/AvMmNWDituWxfSyuPmXPyVrhl\nzd2sRRfElPpCM3KXQx2Vdf6LWGQrexp0VaVubzlhx2KuVtfXntfZfSaOgSFniMqzshc5d5vH9TgS\nehy7cr2iJI+/zD+4YIw4f6eyRri947qJz35oqug0BwBoXYmCW4ACJZDMJ5frOBq1evM5Rt5snKpD\nCXj2tCTMup3wCnGkvUYqYcleuq/iL4DqUPJm46urjPueE83P6OcvfM8jovWqxvu6aw5yWQVLWJtM\ngy1aHTh0MZ2gNdod6XamSgfAU6FoLJud2DA1U2RNBxfQ/iKNz21N2gI8Ddtz23HHsrsgMvqlccLT\nxMKoM+tnO8ulcIBP3rwlqli4Sm5z3oa3Wm3CN46/KuCh4ffC1bLwiIdxDS04YVYrxrXUse22P5eA\nOG57Fx1vgLyZZXzVjE/jqhmfDswBvRd8gRLDUMaFpZk0YU3EvNn8cfJWTJFQ5rgeduWpp8kX86H/\nDhX8MSi6DS+TRmbVLH8Bwo7f5qyHogDeUGXAbS5yC3K2mA/iGJg7voWek2kimHpkeUfYDfbU67fU\nhZyMGWitT+G2Lx6DxaeMLy9XAkRrPyrmIFnGEonK33v6yJNRH68NxH5lWJqJq886ATd8bI7I+JXB\nm6+UeglKt/NyH8sILjgqmeWt6DbSyZBac0XOXA9+h6YpActbBpf7TOp+LH5PlndY9yQeo+eqcecd\nNwafPnsKmmuTqDQqA/uGldvFWekWfzEeN82vr1WkF7/O483SgqxdnVZ2PuL6lrdcIy7I3/W158e3\nV2H+lEY0pqtEaRBPaNOlhUtNrBo9+V7U1zOyZy/CU+YEdeCJ4oB4KmaOa0RTTQLfuuyowPa8W4Cl\nmaiuKF8EAUBRoyV7OUbeosENK9dzk53wFAfurkbs7M0hm7eF29RGHlrtdroAkcSBYixPQdHcMnf0\n8nW70DfAXMUa70jmBBZIwuOkuFTVLk5Jo6OizV9oen5M2iU2DCk8M5yXAQDqE3WoqYgJ1/1rXcuR\nJf2wJlBJUd/y1oU7l3sv7O0j4Q0wi5aR8zPbn4dLiBAb4QtX/syYY19FhtDx1+ktSFspGLoKj7e5\nZZamrdHjm5J1+OiiCWirpgaBYmVhjnsZWsUudk56n555lWaqb+rpFdfG3ea5IR1JI/heUHQbnsEs\nZ9tAf6YIS7NQCIl5wzGQLzq+qJLqoug6uH/NX+k1MhU7fo0i1q7QOm6xuJcaBdmOh60OFW5xu1tD\nLe+hrB2o8EjH6Hvig8eOwHc/PQ+WoUUx7wjDY97kJswaX4+vLZ4Vup1nVNdZ9Zg2uhZnL6Qu7GSM\nJWaFkTezvUvbBcqiJvKmM0Yvwk3zrw0mp5WgtjKGUc3h5M7JebiYOd8u9MdLkt7SJn0BKbqDdLyc\nvGXLu1RIxtBUv+SoBJwYZMtbLyFXTmAfnXQhxlaNwsiKkqYlAM6Y34HT5rXjU2dNLtv2hWlXBWr0\nwzL2E5rk1lc0jGzyCb+m+xhfI54n+kj3tEFvxw+O/ffgCT1VinlL99RgbnfPz3jXVBWfOmsK6lL+\nvXOKLAFLWrjUxqphezZ67R7qvuS92y3//GkjhYq0CsXTce6xNI9A04KLy135PliaiY+dOhFnzO8o\nmwvCwgCdPVRHnN8jTmI7MszqtC1s685Qy1tWA+WhBdvCSbOobn/CYG1Nx7wu+otzwn9pdZcIJ2ga\nK8nTnMACSSgPcnI26QJjRLoVlsmS+jwNHmhfe9uz0dMnJToNoxxYvXURLM2kSoeyVCk2+vMhW95C\n598RcyD2Y8f/c9OT6NPXQIlTD0ap5a1oHvQxNJuee5Fk8halWCo9vi5OibE6Sc9jtKyHVs3ugfQc\ncm8N11bo7suhM0sJ/Sf3vFn221fTPSC1G0CKFrxsBdZvH0BPn4OcUxAiLVzbgdgx5G0XGgwxxgIZ\nQme2G9PrpooWvdzyzhTZ74Qt1OIaj/v7mgeO6yGDXpCiBZJL0wx1fs/ZImkoZ0uuewOVTGggFgcq\nUxZMQ0XRdvdoYO0vhJs+EQ5ZWKaGK88tt644UkYS35p/HSrMVHhMOqS1J3/Zlbb+k6340pZ77wS+\n5R1O/pogb/riLiXvuJThHeY214gpvqd0gUHJV8XE+Cx0NFQFtnELf3duc+5Wntc8B/Oa54SOP27p\nuOD4saHbUrFYwPIPu0dJPQk4/vbqtH+9llcJe/0UWJOfB1GZFSBZhZahIaZbSOoJP8bo6dBUBTFL\ng9vTgqJuw+zw5SHh6mVd32TLyCkyYpeItyZO44V9hX5U6JXg7RsURUFh9UxY41+B7dlIWAZqrKQI\njWglnouck0M6UYdpo2sxbXQtHlqyEYWVR6FuyhoMEhazJ4Bjq+hoTQjPByfvXqaRbqoxbO/JImbq\nouc44MtbLpzSjounUNnadExanNWzpL1Aq1z6Ha45CKj1rCpAmmON11mzeD57oSeNhL/wJCo8uEhY\nOlzVLbG2gwsYTt5xRp7phBH4neaJn+XtDXBi8suYeLlVIJ9B+v+COgjVGoKXjwsPguyB4z9z/rsy\nNBW9/QRWo+/9Kap0DPUsVl2TKM/VCHw/I/8iyeP1tT348f8+g9gR6+EO1ACuIQiZQ6vugqIAxS3j\nAU9HvujCLChQzaLoC6ym+kEIdWsXii7i4HF5FzZT4RvoKxeEGsxnAZjQ0rS6okKvQhcQ0DywHQ8O\n8ZUCs3lb/K54eCJrboNVxWrfPRWVcbqIzjssVBUbgDbiDeSKp5XNzYFAZHm/B1EXrxk2mexblx9V\n9hmnZbIbgi61yt8JRFx+mFOqSnncXkZCcmvL7U55rJlnm4dZtfzlf3TNSTh7zKmh35PYjdv8ncLU\ntYALN8z7kNb9a7JUU7jhAZoMN29C0NqX3eomW4BUxZi1ThSAKNBUBcmYgfqqONydI6ES2UrSAhYz\nAD+jHxAvYpng5Xr+ilgSs8bXi0XloglzUUmakXcLyDm5gJiPripCHpQjVhL394aqMdM8xf/A1QEo\naK5Jipp8txict7pUBTp7cxjMFuF2t+KktuMB+PKWDekKUXWRNMt/G7LkbYJtH4qvg9FOFzky2QkJ\nYJ6Mp9iIaVbwuXV05L0cYpVDVMSmpLJDBifvGHvuUnED8o8jD2r1FlbPFORbEaf7ajXbpQ575RoO\nAPWsKUYRhCWr3fjxuThnQfnikqvrmYYq7jl3mxdUujyrZ5Z3Q0U5eZeqNxJPRc7JY9naHiEA43bS\nZ3XZup7gHHDPgOwV83QoCi+1I1CT/SC5FFRCyZ9b3takpbAVapWv3iDVkrt8AcF6rTdSQZZRFvOI\nSZZ10XHhoCg8BnLuBPds8GeBjRhxg3fQo+SdSa6F0bwBm3qD7UoPFCLyfp9hREMKN867BjcvuEF8\nxt08Lvt3XP+5uOGoqwPH7U/y5pa1S8KTO9QSy7s05t2UpOpN9bHaQO05fzlzyzuMGPnLP6xjG/+e\nZIjlffHJ4zB+RFVACW5fQL9fETHN/kJ5b2ceFgAASw+St6oquPC4oDu+QrIkeQ9s/pKloQefMK5f\nPBvzpjQibUrk7Oplcyxn3E8f2YxPnDEJx83wBXZ4xj9A5+vKc6dhFksMvPCEsRjZQL8/5+QDSXma\npsLZOg7FdVP9awxJ2otp/vg4cTbX+Za36ypI6v51N1VVwiME67cPQFVUnNi+AIBfTiff0zE1I/zs\neAauugbQOefQG1g3L4mYDFWjiyJmeXuqLeLoHPa2MSDwUGxaRj/gmvNmubdJMQsgRKEd5cDIW/N/\nG0VlqGwMHfXU82GOXAU11ReYJ/n7AKDIiJfY1LXb0ZTG+NagZgPgaxYQ4ru9eYJWERkYqiEWdfXp\n8pBYaSIeHAMFL49ETIfCeoDzkM+ytUHyVpmSn+w14Za8Yuap7oHmwstUwjI1arm7vuVcTG1mx0jN\nhQIxawIt3Y8RqTbUWLWB8Sqai6ydp78VtmjpzxQDx1tSCaCzg4Z3+D1/bPNTeHzz037mPQn3KO5v\nROT9PkRDog6VVvnKmbvGLSWFpmRDYFtIXtY+QxXkHX5SlcfaWcy71FoZXdmBzx7xcXxp9hXB47ik\nNCPv0pp3gFs1VIq0FMJtHmJ5nzJnBK67ZFawZn4foCgKrr5oBmbWURWz0sQdAEiYMVHrbGkmkjFd\nkLKmBUkLAJKmFONn19DMFjgcfOlVmbLw6bOmoDImddjytLJEx3qplOeoiW1YMK1ZzB2AwPOTCLsG\naQ5ly5vfS/klHUbeimv6nhNO3jVJcbzreUhLY2itoZ6ATN5BIqajwkoHqiHkMVZYaehvLkL+9YXi\nM/k+xIzdh5scj8BQLKiJQSjxAXhKOXl7fY1I6Wm4Zl/g+NLKCQFXF2JKybgRVC5TWaxXImdu9QF+\nXH+4RLiiTscwsq4ON1xKQz1hiZIJk96zwWxRsjqZ2xw5VJgp8ZzIz5x8DYCfsEk8DXllANvct0SN\nNo9Db9hRrtYG0JJD/qyLOTviaRH+IPkkYqaGfNHBUM6fI0/3NQ9Kx2N0rGJzRFATr0KMe5mE5e0K\nsR5O/rmCFPPWHDom1QOxTdibJtFxSc/tfW89KMJYils+twcCEXm/j/GlC6ejrT6JY46gJQ+cvGWC\nGsvqoxtq4uUn2EfwOHRYpjbgW+Ya++2kQmq5p9VNLluAcLe5RuiPKszyPml2G7568UyMaPDJ6+On\nTUR7QwpjWqk1EbS8939ayOSRNbjsiPNxycTzsajjxLLtluRaNzUqQiMatygKNFULvDiC5E3nroy8\nSxwnSSNoeZeiPu6Tt0zEHBVSA5rQBUiAvP2xcs+HHDoI08h3XCI8LPwl3taQFDK6iZiBasn676j3\n3fiJmA5VUZHQ/WssXfAYugqST6Hw5mwU101FXaU/3ngIecvEWSi6mKgfDUVzoTdtgIci4iELkLRR\nCaIEO7uVhif880slmpoaIG9P9ZOkOJRA1QDvXS/FsaUua45OiW/e+A7RwS6szDNlMNlbqVaeyt8S\n2MghLXljShd79BqY5rzJa8jpta8k//TVANl5t3aXS6USxwCIhoZqdi8kmeOWDno9Zx45HjGTajP0\nZf2ySRKKQnq/AAAbb0lEQVQbCJxfHo+i29CbN7BrTIjx8Xtijl6OdblV9CBXFzkAckKbodN7Egif\nlNxzy6I/suaaYEXJgUJE3u9jTBtdi29/4ihhhXLXuKym9qULp+Nri2dhTMv+eyC55T0cefPt6YSO\nb19+JCrfpqv6iDHUHdZWzVyKIdZFzNQxqaM68PI5dnoLbrr8SBh6iOUdco79AV3VcXTLkaFjNAzV\n713NSJeTN19Y8aoCUzUQtyTVOmF5S33R6ZkCf6UMyfIOJW/frZowysm7UnqRlxJj6THyNXLrkkus\nAggo1vE6esfxfCEPRkS1FTGcMa8DC6Y14XPnTEW15Sccjm32a/m5cE9ausbSaxC1+P31cLvbxLMD\nhJO3vMAp2C7aY+MB0HI7opAyyxsAKqQmL9yKa2sIL1NEqbWmlv825PvUW/Sbhvglf/52N6RxkRyO\naU+34YNjTgsQZMqS480avEwaWsUuaPVbQBQvcDwAjKroQEqpFp2+xjbV4rR57SK0YW8eL/bV0n1C\nOnU48FBGQzUTKUr6IaW8QRu3jKqvQ8LSkc07yG0ZAQyy6xS96/05mDLC9x7yTPWkkfS9H9K+y7LP\n0jE4BkawBQ4kt7mha9QL4egY1ZzGly+cXrbodPUcTM1EOvHOQmtvFxF5RxDglrfspo5bOsa1VQ13\nyD5B3UPMW1jm8IZ/2YXg46dNxFXnHYGFU6hs4R7lW4eBoRnCZbuv53gnMHVNvMy5vCTPOOfZ2jUx\npg6lKIhJ8WruNi9VsCq1vNvSkmBMyAtVJtQwy1te1ISRe3wYy1tkrEtWolySyMvRbMcT18jjoYqi\nIBEz8IkzJqO5NonqmL+gbK3zn9EjxtJrr5LIPVGywCgNxUwf689XmEuYuJqwVfNFF0nTAnE1EVMP\nI++URN7cyjv32NE4//gxZfsSV99ziZFENos6ThL/z/UQZOudZCqRe/HkQGy/QiJfRVGwqOMEJGzf\nQ5OWyRsKiutoAqLesLnseAC4es7nME+7UCwARjdX4YLjx4pyPrenFflXj5PGT3uNDwfujeGWt73N\nn6ch0Bh52kxhVEsFXI+gp9dD88CxwXOwRe8lp4zHF844DpW5CQAANUkt85SRRJznHYQtJFzd98rx\nksGqbrg1a2jioWvgklMmYOro2jLvQ09uV6gH5kAhIu8IAsfPpAlJs8aHtxvdXxhTORJAmHVIMb1u\nCh1P24K9Om/M1DFjXB3SZhKWZgaSrvYWPEZ6INzme4IpWd5claqmgvc7py8MTmxFtwhLiqNazHug\nqzqumvFptPfRspVSWhhd6ddUf2NxeQWCjHgIecsItbyHiXnLXp20Qq1dXv8L+Ja37bpoSNDnkBBg\nzsRgDgYAVPMFDILiOVzJri3V4o+x1PKWyPu4GS1oqfWvIWUl8K3510IfkhY4ro6PnEItyUVzR9De\n6LYpuqrFJaW9tnrqrm9IBpvoANSDcvq8Dnz9yC/jxBHHBM4vo3YoqONAXE2QCQBMrh+Lr5TkfJSF\nPzxdlNQBCP09aNLzLSc+AsBXPngs4GqC+NIhxxMoYlw8h+Wy0yYKWWRSjAnLXHZpA0DlllNwztgz\n/HOxPIiKBPME9jYF8hIA6k2Z2O7f95baCmiuf2/5d+SLDlRFRYfH+ruzkreUZHmHClY5BkYIqWMF\nbj99RovJbfQjVy+rfvHnItwDc6AQ1XlHEDh9XgcWTmt+227qfcWF4z+ICdVjMatxeuj2SbXj8b2F\n3wyWK+0FNFXDNXM+H3AN7y0Sehx9hf6DYnkbuubXm7JabRHzZqRTKxGX/DIxpSz6CTVjYbkZAD1l\n7D0i5WeOj24O96x8ceZnsHFwS5m7tBRhZYmyNR6WkAYA7eZErCg8g+aUb/0J8nY8zG2cgXV9GxDL\njMTZx00qO77GCo77psvmYiBbFHM1qmoEsCV8DHweZ42vx8dOnRgcu6WjLl6LOSNH47lupn3u6Zg/\npQknzaZCL6+s7mJSpdQzInsqvrZ4Nrr6cuhVN4nvT5oWvvrJo0TYoCXVhIUtR+GxzU/R87s6IHHC\nCGUatrxYg/icf9APSohGVYKJi6RE2lh87vj3Jox8PdZ61elpQoKFX266bC5Wb+7D5JE1UFdUwovv\nYseXPwfHTm/GP5dyOWBK3o01CXz90jlYuWEX/vPe10DsGHXtl1xDtVGHk9tnYkzlKNz94t+xsZMu\ndqaMqsH0TbUY0ZjCX5esByGK8C6kzBTGj/AXXifPbsPmt6rQyYVYuDgMlzw2LZCiJRZZKTOJmDK8\n5U1cHUdPbcIf/klbpBbfnIPYnL/BNQbE9pgxvOt/uGf9QCCyvCMIKIpywIkboC7Go5pn79aqTUuZ\nrfuCpmQjUua+kT9A3c5pI1VWc/5uwNJV1roRqGSJYaUxb9EUAcGyt9KSLz6DpIS9Dc3AhOqxGF9V\n7sLlGFc9Bie3Hzfsdo6w8IdcBx6WkAYAU5Nz8NkjPo6zR/v19kdNpkQ+sb0auqrjkknn47w5c0Q+\nggzZbQ4A7Y1pTB3lx65HVUqysiXPkio66ZW7qrkVP76+zf/Q1f0sZQCmqQUy5mXhoLilo70xHcgb\nOG56G1rqgs9jXPYGuDpOnOV/XzpBVdZ4YlmYlRhIOvR8adLW+iTOOYY1B5LqpsPCHzWDM+H2NqCt\nOE/McXtjGiczmVut6Lv+wyz3uso4KhL02r2S52DyyBpceMJYv1lKyTV091PCHVXZjqnG8aLne1NN\nAl+4YDqV2iWqyI8wVB2WZiIVN/CJMybhyxdOR3tjGpMa/C6AHz91Mlrrk2KR1TdUCCRHUstbCx0P\nAMAxkIobuP6js3Hy7DYACohtwVN5A52g5X3JxPOFpxAID58cKESWd4QIIfjY5ItQcIvvaAGxr9B1\nFW5nG2yjgCvPOx8AUJP21a+AYDKWXH5klpK3SJ0t/56rZn66/MO9wHnjzsIDax7CxJpxZdsaE37o\nZTjytkwd0+qCNevnHDMacyc2BKoBhgOPaZdm1nPwpD7e31lGW30SmzuHUFc1/MtWvi7iagGXv6Vr\nAZd02Eu7OdmI5mQjtmd2oqmy3LshW84nzhyB8SP8fToa0wAUmEocBZIVVutpR7ULyeOEEYcChS7M\nJCI6YWYrTpzVhhNnt+Hz/1WEmuqDSozQZ/mCo+bi0aWNWPyhCaFzoBerwIVd08N4ssZVjcZLna/5\n1QESqtIWyFZekkUt81HNVP50guT+lhdn3PuSZImH3kAN1FgWtudn4C+QmvzMbZyJf215BgBNPj12\nuh8uaapJYHl3HGqKJr8ljSRivN+BY+L85o/h3rcegJryLWuAVtmMba3EkZMace+W5diapS4U4hgB\nsaKjW47ElNqJeK2b9q1/N2PeEXlHiBACUzOHVak70DCYhKuzdRzqE9R6a29M4ZxjRokOaTweXB+v\nDVjbZoj4DDCsmN07wokjjgnGbSXIRCFaNZYgbKyqqgTaq+4Opmbg5gXfGHZxAAC5l06iL9sPBD//\n6AcmoK0hxayr0rHTfyulxjtfuXB2YB/TUAPkHQ/pLqcoCq6Z83m81rUC0+unlm3XVA03zbsWf1n3\nMOY1B88/aSQlNiPXgEJsA5QEJRdNU8X9VhUVcT2GrJMrkTv1O7DpJIbCiqNRlQ4nlTEtlbjinOHl\nluOZDgwZW2FU9pZpP3AsnnQBptVNxuyQMFgypkvtR6llfvyMFpxz7KhABYvrlmfX88WS21/ni+WE\nYGTFCDQmGtCSKs+h+dAxo5B5eRJeGKKqZykjAVPKjxhd3Q4vlxbkXZr1P7atEvW91YK8DcUs03pI\nmynoqg7HcyLLO0KE9zOSMQOXfmBCwPpUFAVnLRgl/q6NV+PauVehxqoWtdMAy1QPwbvUKyGARR0n\n4G8bH0d7upwggXIvwb6gcpjOdhy3f+kUhDlPYqaO044qb4RSio9Nvggv7HwFExpaA5/HTC0QTx7u\npW1qJuY2zRz2/PWJWnxy6uKyz6tSFlrrkujcUgN97Aa/R3XJjeTiIh111Th/8Ww8+vwmHD2Fkpii\nKHBcD4CCUU27n6fhoCkGiqtnY+KY6mFzH3Z3jRPaq1D7Vgp96BYKZJapBcIbANA7SGPSFSG9Cnij\nkXFVo8u2AfQ6bzjqK6GeBUPXcMmcU/DCE4/T79aswH6WqQXaKB83pTyMJHdPTJrhXRJrrCp05roD\nuQ8HGhF5R4hwCIJn/u8OYaSol3Tt8t9T7z57nz36VBzVNCvUnQr4mfEHEqX913eHay6eiQeeWheY\n+yObZuHIpvIOfg3VCZw4bQyeHqB61/EDYHGdd/wY/PTPWRQ3TII3SC3x0kXYpJrxWLVrNU4ffRLG\n1ldibFu4FT1hxL6Ve/L5U/YxPUpTVZw//Rj8YvkGuF10XsMWmJUpujiZPKqmbBscE1+c/CWMqKsu\n38awu/CWoer4ztHXI+fky/bTVQWktxVesg+jnWNw6YfKEyPlBWKlVU7eAK3+6Mx1v6sJaxF5R4jw\nHkLpy4n/fRAMbyiKMixxA8O7+A8WJnZU42sds/e8I8MHpx8JrO2GSzyMZuWP+xMzxtZhbGsVVm30\nPQSliYeXT/kICm4xkMAYhpHDtOfdEy47bRLufvRNXHRSeV7D28XMhmn45lFfxdeefx1AeF+Bs44e\nicqkhWOOaA58fvnpk/Dy6i6MaWh6R9LE1bEqhFG/rqtQMrUoLF+A6qnhYYFKSU2wKhHufeClm2Hh\nkwOFiLwjRHg/4GCw9x5Qmhl/uCFhxHHxxPMO6Hc01iSwamOv+LvU8k4YiVBteY5vXDoHb27uxbi2\nfVNIbKlL4tpLyj0Pe4vGZD14NnxYuMTQNZEhLmPhEc1YWELo+xO6poqSjPgwuvNT6yahVZuADVuK\n+MAJ4eWtnLwjt3mECBH2Cj/43AK4XnnSz26SzQ869kfM+72OxupgedecCeHW4XAY3VKB0S37ZnUf\nKBxKizZFAbhBP5znPWkkcN2xl9Me4lY4ZY6pot6RltSBW2iUIiLvCBHeA6geJpt4yqgavPRmF2aN\nrwvdfjBxqLnND0U0VvtW9e1XH79XMfxDFeYwCmXvJj555iS8sbGPlX3tObSkKsqwxA0A46vH4gfH\n/ntkeUeIEGH/4NjpLRjVVPG26qbfbZjvASI60JgyqhqTOqpx9NSm9wRxA76lezBx9NRmHD2VWsn7\nK6fz3SRuICLvCBHe01AVBR1N+67xfiDwqbMmY/32gVDVtAhBGLqGr148fKnZ4YTT53Xg2eXbUfUu\nqDjuDfzQ0qEYXBoeB3Qpd/PNN+PDH/4wLrroIrz++uuBbc8++yzOP/98fPjDH8Z///d/H8hhRIgQ\n4RDC/ClN+MjJ4/e8Y4T3FM4/fgx+eOXCQBOZQwEXn0wz6WVltsMBB8zyfv7557Fx40bcc889WLt2\nLa6//nrcc889Yvt3vvMd3HnnnWhsbMTixYvxgQ98AGPHjj1Qw4kQIUKECBHKILvQDyccsCXQkiVL\ncPLJJwMAxowZg/7+fgwNDQEANm/ejMrKSjQ3N0NVVRx33HFYsmTJgRpKhAgRIkSI8J7CAbO8u7u7\nMWWK322lpqYGXV1dSKVS6OrqQk1NTWDb5s2bd3u+6uoE9P0cI6uvP7RigYcronl854jm8J0jmsP9\ng2ge3znejTl81xLWSjV59xa9vdn9NBKK+vo0uroG9+s534+I5vGdI5rDd45oDvcPonl859jfczjc\nQuCAuc0bGhrQ3d0t/u7s7ER9fX3otp07d6KhYe/EByJEiBAhQoT3Kw4YeS9YsACPPvooAGDFihVo\naGhAKkVrTdva2jA0NIQtW7bAcRw8/vjjWLBgwYEaSoQIESJEiPCewgFzm8+aNQtTpkzBRRddBEVR\ncOONN+L+++9HOp3GKaecgptuuglf+cpXAACnn346Ro0atYczRogQIUKECBEAQCHvNBj9LmF/x2Gi\n2M7+QTSP7xzRHL5zRHO4fxDN4zvHYR/zjhAhQoQIESIcGETkHSFChAgRIhxmiMg7QoQIESJEOMwQ\nkXeECBEiRIhwmCEi7wgRIkSIEOEww2GTbR4hQoQIESJEoIgs7wgRIkSIEOEwQ0TeESJEiBAhwmGG\niLwjRIgQIUKEwwwReUeIECFChAiHGSLyjhAhQoQIEQ4zROQdIUKECBEiHGY4YF3FDmXcfPPNeO21\n16AoCq6//nocccQRB3tIhzRWr16NK664Ah//+MexePFibN++Hddccw1c10V9fT3+4z/+A6Zp4sEH\nH8RvfvMbqKqKCy+8EBdccMHBHvohg1tuuQUvvfQSHMfBZz7zGUybNi2aw71ALpfDddddh56eHhQK\nBVxxxRWYOHFiNIf7iHw+jzPPPBNXXHEF5s+fH83jXmDp0qX4whe+gHHjxgEAxo8fj09+8pPv/hyS\n9xmWLl1KPv3pTxNCCFmzZg258MILD/KIDm1kMhmyePFi8o1vfIPcfffdhBBCrrvuOvJ///d/hBBC\nfvCDH5Df/va3JJPJkEWLFpGBgQGSy+XIGWecQXp7ew/m0A8ZLFmyhHzyk58khBCya9cuctxxx0Vz\nuJd46KGHyB133EEIIWTLli1k0aJF0Ry+A/zwhz8k5557LvnTn/4UzeNe4rnnniOf//znA58djDl8\n37nNlyxZgpNPPhkAMGbMGPT392NoaOggj+rQhWma+PnPf46Ghgbx2dKlS3HSSScBAE444QQsWbIE\nr732GqZNm4Z0Oo1YLIZZs2bh5ZdfPljDPqQwd+5c/PjHPwYAVFRUIJfLRXO4lzj99NPxqU99CgCw\nfft2NDY2RnO4j1i7di3WrFmD448/HkD0e94fOBhz+L4j7+7ublRXV4u/a2pq0NXVdRBHdGhD13XE\nYrHAZ7lcDqZpAgBqa2vR1dWF7u5u1NTUiH2iefWhaRoSiQQA4L777sOxxx4bzeE+4qKLLsLVV1+N\n66+/PprDfcT3v/99XHfddeLvaB73HmvWrMFnP/tZXHzxxXjmmWcOyhy+L2PeMkikDvuOMNz8RfNa\njn/84x+477778Mtf/hKLFi0Sn0dz+Pbxhz/8AatWrcJXv/rVwPxEc/j28Oc//xkzZszAiBEjQrdH\n87hnjBw5EldeeSVOO+00bN68GZdeeilc1xXb3605fN+Rd0NDA7q7u8XfnZ2dqK+vP4gjOvyQSCSQ\nz+cRi8Wwc+dONDQ0hM7rjBkzDuIoDy089dRT+NnPfoZf/OIXSKfT0RzuJZYvX47a2lo0Nzdj0qRJ\ncF0XyWQymsO9xBNPPIHNmzfjiSeewI4dO2CaZvQs7iUaGxtx+umnAwDa29tRV1eHZcuWvetz+L5z\nmy9YsACPPvooAGDFihVoaGhAKpU6yKM6vHD00UeLOfzb3/6GY445BtOnT8eyZcswMDCATCaDl19+\nGXPmzDnIIz00MDg4iFtuuQW33347qqqqAERzuLd48cUX8ctf/hIADX1ls9loDvcBP/rRj/CnP/0J\n9957Ly644AJcccUV0TzuJR588EHceeedAICuri709PTg3HPPfdfn8H3ZVezWW2/Fiy++CEVRcOON\nN2LixIkHe0iHLJYvX47vf//72Lp1K3RdR2NjI2699VZcd911KBQKaGlpwXe/+10YhoFHHnkEd955\nJxRFweLFi3H22Wcf7OEfErjnnntw2223YdSoUeKz733ve/jGN74RzeHbRD6fx9e//nVs374d+Xwe\nV155JaZOnYprr702msN9xG233YbW1lYsXLgwmse9wNDQEK6++moMDAzAtm1ceeWVmDRp0rs+h+9L\n8o4QIUKECBEOZ7zv3OYRIkSIECHC4Y6IvCNEiBAhQoTDDBF5R4gQIUKECIcZIvKOECFChAgRDjNE\n5B0hQoQIESIcZnjfibREiHC44ZZbbsGyZctQKBSwcuVKzJw5EwBw3nnn4UMf+tDbOscdd9yB8ePH\nCz3rMHz0ox/Fr3/9a2iatj+GHcDOnTuxbt06zJ8/f7+fO0KE9yOiUrEIEQ4TbNmyBR/5yEfw5JNP\nHuyh7DUefPBBrF27Fl/60pcO9lAiRHhPILK8I0Q4jHHbbbdhy5Yt2LZtG6699lrk83nceuutME0T\n+XweN954I6ZMmYLrrrsOs2fPxvz58/Fv//ZvWLhwIV5//XVkMhncfvvtaGxsxIQJE7BixQr89Kc/\nRV9fH3bs2IGNGzfiqKOOwg033IBCoYBrr70WW7duRVNTEzRNw4IFCwI9ijOZDL7yla9gYGAAjuPg\nhBNOwJlnnokf/ehHIISgqqoKl1xyCb797W9j48aNyGQyOPPMM3H55Zfj/vvvx9///ncoioKdO3di\n9OjRuPnmm2EYxkGc4QgRDk1EMe8IEQ5zbNmyBXfddRemTp2Kvr4+3HTTTbjrrrtw6aWX4vbbby/b\nf+3atTj33HPx29/+FpMmTcLDDz9cts/KlSvxk5/8BPfddx/uv/9+9Pf348EHH4TjOPjjH/+Ib37z\nm3jmmWfKjnv22WfhOA5+97vf4Q9/+AMSiQRaW1txzjnn4Oyzz8Zll12Gu+66Cw0NDbj77rvxxz/+\nEQ899BDeeOMNAMCyZctw66234r777sO2bdsOSy9DhAjvBiLLO0KEwxzTp0+HoigAgLq6Otxyyy0o\nFAoYHBxEZWVl2f7V1dUYN24cAKClpQV9fX1l+8yePRuapkHTNFRXV6O/vx+rVq3CkUceCQCor6/H\n7Nmzy46bNWsWfvKTn+ALX/gCjjvuOFxwwQVQ1aCNsHTpUuzYsQMvvPACAKBYLGLTpk3ieN4+debM\nmVi7dq3okxwhQgQfEXlHiHCYQ3YrX3PNNfjWt76F+fPn4/HHHxfNPGSUJqSFpb2E7eN5XoCIS0kZ\noL2M//KXv+CVV17BP//5T5x33nl44IEHAvuYponPfe5zOPXUUwOf33///fA8b7fjihAhAkXkNo8Q\n4T2E7u5ujBs3Dq7r4pFHHkGxWNxv5x49ejReeeUVAEBPTw9eeun/t3eHOAoDYRTHHyGYJlwAMAjg\nAFROSC0STCWCIJCYBhwOwxEqegIkuqLBbRN0LQaBxkBZsdkaDJutmeb/05PJ517eZCbz9bYmSRLF\ncazhcKggCOQ4jm63m2q1mh6Ph6SfVv97VJ/nuXa7XdH+z+ez7ve7Xq+X0jTVYDAobX6gSmjeQIUs\nFgvNZjO1Wi3N53MFQaAoikrZezqdKo5j+b6vTqcj13XfGnq329V6vVYYhqrX6zLGqN1uy3VdrVYr\nNRoNLZdLZVkm3/f1fD7leV7xVWq/39dms9HlclGv15MxppTZgarhqRiAj1yvV6VpqvF4rDzPNZlM\ntN1ui3fn/3U4HHQ6nbTf70vZD6gymjeAjzSbTR2Px+J/4tFoVFpwA/gbmjcAAJbhwhoAAJYhvAEA\nsAzhDQCAZQhvAAAsQ3gDAGAZwhsAAMt8AxJ5C+54P8QOAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAe8AAAFnCAYAAACPasF4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzsvXe8XVWZ///e5dTba3pCQiAJCSWE\nIJGmoSSgjsg4gmCb4Tf+dCwURUdEQXGs41gYFQvDiIyIiKIIJIAgEBJCgJBKertpt59z76m7fv9Y\nu55zboiQBCL783rllXt2WXvttfden6et55Fs27aJECFChAgRIhw1kF/vDkSIECFChAgR/jZE5B0h\nQoQIESIcZYjIO0KECBEiRDjKEJF3hAgRIkSIcJQhIu8IESJEiBDhKENE3hEiRIgQIcJRhoi8I7yp\nMW3aND796U9Xbf/iF7/ItGnTQsfdcMMNoWOWL1/OBz/4QQB2797NCSec4O3btWsXH/vYx1iwYAEL\nFizgkksu4bHHHgPgpptuYuHChSxcuJCZM2fy9re/3fudy+VC19A0jfvvv/9vvq/Vq1dz1VVXHdSx\nDzzwAF/72tde9bVcvNbz3wi46667+P73v/96dyNChFeE+np3IEKE1xsbN24kl8tRX18PCBJas2ZN\n1XErVqxg/fr1IZIeCZ/97Gd597vfzW233QbAqlWr+PCHP8zDDz/MV77yFe+4+fPn8+1vf5vTTjut\nZjvr16/n/vvv55JLLvmb7umkk07i9ttvP6hjly5dyvnnn/+qr+XitZ7/RsAHPvCB17sLESIcFCLN\nO8KbHm95y1t49NFHvd9LlizhxBNPrDruuuuu4+tf//pBtblp0yZOPvlk7/fJJ5/M4sWLGT169EH3\nq6+vj09+8pO89NJLXHHFFYCwAPz0pz9lwYIFmKbJypUrufTSS1m4cCEXX3wxS5cuBYRV4IILLgDg\n1ltv5atf/Sqf+MQnOO+883jve99LT0+Pd53ly5czffr0qmu98MIL/OM//iMXXHAB73vf++jq6gKg\nu7ubD3/4w1x88cWcf/75fO9736vZ18p7ueqqq1i4cCHz58/njjvu8PatXbuWSy+9lAULFvCBD3zA\nu85I26dNm8b+/fu9893fy5cv5/LLL+fqq6/mM5/5DAD33nsvF110ERdeeCFXXnkle/bsAcC2bb7x\njW8wf/58FixYwC9+8QtvrL74xS8CsH///pD15MknnwTAMAy++MUvsmDBAi644AI++clPVllMIkQ4\n3IjIO8KbHhdddBF//vOfvd8PPvggCxcurHmcbdssWrToFds855xz+PSnP82dd97J1q1bARg1ahSS\nJB10v9rb27nuuus45ZRT+PWvf+1tt22bxYsXoygKX/7yl7nqqqtYtGgRH/3oR7nppptqtrVo0SJu\nuOEGHnvsMdra2rjvvvsA2Lp1Kx0dHYwbNy50rVwux8c//nGuu+46Hn30UT70oQ9x9dVXA/C///u/\nzJ07l4ceeogHHniArq4uLMuq2VcXP/nJTxg/fjyLFi3il7/8Jd/97nfZt28fIISiq6++msWLF3P+\n+edzyy23HHD7gbB+/Xouv/xyvvvd79Lf389Xv/pV7rjjDh555BEmTpzIj3/8YwD+9Kc/sXr1ahYv\nXsx9993HXXfdxerVq0Ntff7zn2f69OksXryYn/3sZ3zuc59jcHCQJUuWsHv3bhYtWsQjjzzC1KlT\nWbly5Sv2LUKEQ4mIvCO86XH66aezefNm+vv7KRaLrFy5knnz5tU89oYbbuA///M/KZfLB2zzO9/5\nDldeeSUPPPAA73znO5k/fz533333Ienv2972Nu/v+++/n4suugiAOXPmeNppJU477TTGjRuHJEnM\nmDHDI85ly5bVvNcXXniBUaNGceaZZwLwzne+k127drF3717a2tpYsmQJzz//PPF4nP/6r/+is7Pz\ngH2+8cYb+dKXvgTAhAkT6OjoYPfu3Wzfvp3BwUHOPfdcQJitb7311hG3vxKSyaR3P21tbbzwwgue\nteO0007zxuepp55iwYIFxGIx6uvreeihh0LWlkKhwPLly/nIRz4CwKRJk5gzZw5PPvkkra2tbN26\nlUcffZRiscg111zD2Wef/Yp9ixDhUCLyeUd400NRFC688EIefvhhWltbOeuss1DV2p/GzJkzmTt3\nLnfccQezZ88esc1EIsFVV13FVVddxdDQEIsWLeLrX/8648ePf80TfXNzs/f3Aw88wJ133kk+n8ey\nLEYqVdDQ0OD9rSgKpmkC8Mwzz3gEFcTQ0BBdXV0hC0Q8HmdgYICPfOQjWJbFV77yFXp6erjyyiv5\n1Kc+dcA+r1mzxtO2ZVmmt7cXy7IYHBwM9U1VVVRVHXH7K6Gpqcn72zRNfvjDH/L4449jmib5fJ7J\nkycDMDg4SGNjo3dsOp0OtTM8PIxt21x++eXetkKhwBlnnMFJJ53EjTfeyK9+9Ss+//nPM3/+fG66\n6aZQexEiHG5E5B0hAnDxxRfzve99j5aWlpo+2yCuvfZaLr30UsaPH19z/8DAAC+//LKntTY2NvK+\n972Pp59+mk2bNh0yLa27u5sbb7yRe++9lxkzZrBjxw4WLFhw0OcbhsGaNWtqCiGdnZ1MmTKF3//+\n9zXP/ehHP8pHP/pRtm/fzr/+678yZ86cA17r+uuv58Mf/jDvf//7kSTJG4OWlhYymQyWZSHLMrqu\n093dPeL28ePHI8uyJ3xks9kRr/nQQw/x+OOPc9ddd9Ha2spvf/tbHnjgAe+6g4OD3rF9fX0kk0nv\nd1tbG4qicN9991FXV1fVtrs6IJPJcMMNN3D77bdz7bXXHnAMIkQ4lIjM5hEiALNnz6anp4fNmzdz\n+umnH/DYzs5OrrzyyhHNuKVSiU9/+tM8/fTT3radO3eyatWqEaPKR4KqquRyuZoa9cDAAOl0milT\npmAYBvfccw8A+Xz+oNpevXo106ZNIx6PV13r5JNPpre3l1WrVgHQ1dXF9ddfj23bfPnLX+aZZ54B\nYOLEibS3tyNJ0gH72t/fz6xZs5AkiT/84Q8Ui0UKhQLHHHMMo0eP5pFHHgHgd7/7HV/+8pdH3A7Q\n0dHBhg0bALjvvvuQ5drTWH9/P+PGjaO1tZXBwUEefvhhb2zmz5/Pgw8+iKZpFAoFrrjiCjZt2hQa\n93PPPZff/OY3ABSLRb7whS+wb98+7rvvPn70ox8BwgoyZcqUgxrvCBEOJSLyjhABkCSJCy64gLe+\n9a0jkkEQ//Iv/4Ku6zX3jR07lp/85CdeVPiFF17Itddeyxe+8IVQBPrBYM6cOfT09HD22Wd72qaL\n6dOnc84557BgwQIuu+wy5s+fzymnnOKtPX8lLF26NOTvDl4rFovxwx/+kFtuuYWLLrqIT3ziEyxc\nuBBJkrj88sv53ve+50W4z549m3nz5h2wr1dffTWf+MQneNe73kWhUOCyyy7jS1/6El1dXfzgBz/g\ntttu48ILL+TPf/4zN998M5Ik1dwOwvJx88038+53v5tUKuUt8avEO9/5TjKZDBdccAGf+cxnuOaa\na9i/fz/f/OY3ufjiiznrrLO48MILec973sN73/teTj311ND5N998MytWrGDhwoW85z3vYcKECYwZ\nM4bzzjuPdevWceGFF3LRRRexZcsW/vmf//mgxjxChEMFKarnHSFChAgRIhxdiDTvCBEiRIgQ4ShD\nRN4RIkSIECHCUYaIvCNEiBAhQoSjDBF5R4gQIUKECEcZIvKOECFChAgRjjIcNUlaenuHD2l7LS1p\nBgcLh7TNNyOicXztiMbwtSMaw0ODaBxfOw71GHZ0NNTc/qbVvFVVeb278HeBaBxfO6IxfO2IxvDQ\nIBrH144jNYZvWvKOECFChAgRjlZE5B0hQoQIESIcZYjIO0KECBEiRDjKEJF3hAgRIkSIcJQhIu8I\nESJEiBDhKENE3hEiRIgQIcJRhoi8I0SIECFChKMMEXlHiBAhQoQIRxkOK3lv2rSJ888/n7vuuqtq\n39KlS3nve9/LZZddxo9+9KPD2Y0IESJEiBDh7wqHjbwLhQK33HIL8+bNq7n/a1/7Grfeeit33303\nzzzzDFu2bDlcXYkQIUKECBH+rnDYyDsej/Pzn/+czs7Oqn1dXV00NTUxZswYZFnm3HPPZdmyZYer\nKxEivGmhGxZL1+6jWDZe76542NuXZ822/te7G0cNXtjYy879wyxduw/Lsl/v7rxq9GWKrN8x8Hp3\nA4D9AwVWbekDoKyZPPdyN7Y98tjmSzovbOw54DFHGoetMImqqqhq7eZ7e3tpbW31fre2ttLV1XXA\n9lpa0oc8Z+xICd8j/G2IxvG143CN4d2PbOTXizdw3twc11x+6mG5xt+Kf/nm4wDc/+13oSiHTn/4\ne3wP9/Tm+NEf1ni/48k4F8075rBe83CNo/vcf3XzQpobEoflGn9rX+79+jv4+d0vsmzNPmRV4aK3\nTq55/I9/8SzPv9zNdVecytvnTHjF9o/Eu3jUVBU71JVuOjoaDnmlsjcjonF87TicY7hhu9BwN+wY\neMM9p737syTjh2YK+nt9D7dWaKobt/dz2tS2w3a9IzGOXXsz6K3pw3qNg0V3zzArN/YAsGnnAKcd\n117zuA3Oc3hh/X5mTWw+YJuHegzfUFXFOjs76evr8353d3fXNK9HiBDhtcE180lIr3NPqqEZ1uvd\nhTc8SroZ+m2aR/+YvZFcOJZtY5jiG1EPYAVqrheWgsHh8hHp18HgdSHv8ePHk8vl2L17N4Zh8MQT\nT3DmmWe+Hl2JEOHvGq6LTnrjcTdGRN6viHIFeRtHsc/bRb6kv95d8GBaticQqcrIH0mLY+bP5N44\n5H3YzOZr167lW9/6Fnv27EFVVRYvXsz8+fMZP348F1xwATfffDOf+cxnALj44ouZPLm2ryFChAiv\nHW9E8tYj8n5FaHp4jEzz6CfvQun11byDQWeWZeP+UuSRddn6VAyAzAE072x5iKZE4yHp48HgsJH3\nrFmz+NWvfjXi/rlz53LPPfccrstHiPCGwf6BAo3pOOmk+Nx6MkXSCdWbEGqhe6BAQzpGOukf0z1Y\noLk+QSJWHbiZzZUxLZvWxmRou+Wazd+A7H0kzOYDQyUUWaKp/rUHSFm2TVd3jgmj6pEliZ7BAk11\nCRLx8PMoayZ9QyXGtde9pusVSjq7e3OhbYPDJbJ5jaa6uLetN1MkGVdoSMcrm6BYNtiyJ8u49rqq\ndwOEANWXLTKmrbqvA0Ml4jGF/mzJu+dK2LZNV0+Ose11ntnZtm329OUZ116H5IxTXeBdz5cM9vbl\n6WxJeedYts2W3VniMZljRjfSkynSmI6FYiJ2dQ8zpq2OmFqbZGudUwslzbdmmJZV9bdl2WzenSGV\nUEknVOJxxfuOhgo6fZkidakYqYR/naV7V/B/G+7lIye8n4s7zjng9Q8VjpqAtQgRjkaUNZMbfvYs\njXVxvv+pswD499uWIQG3//v8mufohsXNd6xg9nHtfPQfZgLQny1x48+X8455k7jk7ClV51z7388A\n8D8VbbpKhvw6cLdpmWzL7mBq85SawsOR0Lw/++OlQPW4vBosfm4X9z6xlcvPO45Tj2vn33/6LBOm\nZ7A7N3LdnH+jOdEEwDf/70V2dg/z7Y/No7059aqvd/MdK+jLlkLbNuzKcO2tS0L3c8Mf7sYup/nF\nxy6vauOexzezZNdKmuoVvnvlZVX7n1i5h3v+spmvXHU64zvqve2mZXljB3DlBcdz3pzxVeev2TbA\n9+9dxZknjuaqd5wAwOMv7uH/Ht3E+88/jtOmdfLvP32W9iZfcFi5uZdfLd7I208dxwcvnAbAqi19\n3HqfiKq//v2z+c7dK5k+sZnPXSFWSGzcNci3fr2SudM7+fgls6r6kc2V+ffbljF1XBM3fHBOjdH0\nEdT8g0vvXFJ/YVMvP7l/rbddSuZomLUSuXkaVqaTz922jHHtddzy/73FO+avu5cAsKJ7JRefeGTI\nO0qPGiHCYYRmiAlhKK8BYDj+tQMZP4tlg7Juhvxre/pymJb9N/vcfBPhkWfvezf/ie+v/CkrulfW\n3K8bZs3tbyT8cevDfGHJLWimxsrNIsh21ZY+9vYXIFair/FZ+kuDdA3v8c7Z2S0ijQcOMrhJt2qb\nkSuJuxZ2De0hPmkDieNfrLm/qzdH4riXKI15ofY1MkVsoKsnrOFXmrbXjrAuf9veLADPrNnvbXtx\nUy8AK17uYSivIaWHyE+7H7mpx2lLRG4/8aI/Zv2Be3VzAGzYlfHb3LsRKV5gxYaemv3oHiwCsGVP\ntuZ+0zIxLfG+BX3uZoC8NSe+oC9bDJ0rN/Wjy3kxxoo4d09fPnRM2RDPOqkcuSVwEXlHiHAYURlf\ndDDapghSssnGtzGsiUm1NyMmt6DPc1XvWnoLB0524pL366F5P71HJF7al+/2tlkBf+PR4PN+ZOcT\nDGnD9BbD45zJlVEa/WVcRaOaaJUDBEC52DW8m2v+egNL9jxb+wDJIjZlNXJzd2iz+1yf6FpywPZ7\nC/6qHsuuHm83IK43EyasSvIeyVRdy6Li3rdhWpiWTWzsVtHGpA0j9jN4vf4KoaW/OMCSwu+Jz3hu\nxPMrz6nEf734E7723HeBcLR7Tisi1WWIH/8Cw8ZwVV8AJFXz/pbragsHZVMck1CqXReHCxF5R4hw\nGFG5tEc/iKU+Zd1Ead9LpvU5/mfdrwF/cnU1hf35Hn625k7+47n/8jQGoCoDl6d315hkNw5sYU3f\n+oO+l1eLxri/TjVI2G/0pWLuhAygW+EI6d5MESnuE0ZBD5MfgHIQEtOjO/8KwIPbH625X2ndj9q+\nl8TxYeuFu7xpW3YnALZR7QEtaQZFxSfvWn0cibzzAQKT6rIMpNbXJH+5xj2qTuCXYdqifcVpq0Yf\na12vp6Ivq3qFCVtOjEzQwf6XtDD5dg3vYcfQLnoKfWim7l9L0fn+y98mOfNZlOZetiUfreoLgBQr\nB/7WqIWyeeSj0CPyjvC64I2UZvBwwqwgU10/OPKWG4RWtze3D/AnJ3epUG9RTMq6pYcmm0rh4EBW\n8x++9DNuW/2/r9ifkfBiz2q+8dz3KRrVpBCc6IPEFyTv16p5W7bFrSt/7hFg9X4bpW0vcsv+mvtf\nCbuGdnt/l4zw5NybKSLFAuRtVCeRMi2bIW2Y32/+MzktX7U/eI2JDeOq+g4gN4nnbNvhB1jWTWzb\nJlN2TMuKUUWufZkScr2vKeZr9NGNZu/LhImx4JiWpbosyZnL2Bd/gd3De6vOryWfuEuuTMuirJtI\nqmjLNqsDNF3BsxAwZe9zTNKphAgEfMkhb9saWRjakFvtPefKe1m+z3cZ5PScZzavJGJNyZLT86G+\nACEhDTV8zu83/5lfrL0LzXnHa1lgDhci8n6TwPUvWrZdMYGaNY97pW2vBat61/LJJz7PtuyOmvst\n2+Kl3rUU9dIBr23VEAAOVV9/uf433LT0m6/6fLcfQfI2TCtEriMJMJpmIsUFIbYmW4Cg2VycP1jy\n/YHByaaSED2zeeC3bdvkdJ9MKv3ouiGIwbKtmtqWi9vX3sXu3F5W9673zgHxXILm/JAGG+hfppzh\nzvX30F3oxbQsBofL2LZd9QxFX6rHqrfYz4bBzdy/9aGq4yzbxjBMYsesIzaxtrm2ss18SQ+ZTHcN\n++RdOSn3ZkpIcX/cCjUEGNO0+c7z/81fup7imb3Lvevphskftz7MZ578Mn0lIaTJkh+xXiwb7O4f\nRG7sQ2l0yLscDnzTdJMhLYdhi/5KEhR1v49lzWT7/iGkhE/Y+RoCxLDdR+yYtfQMD4W254o6iZlL\nSc70a04MlAZDxzyzdzm7zZeRm3tInb6IPbl92LbtRZAXk7tZMfCMr3lXQB2zleuXfImCXgwJoK5F\nJp2IYds2e/Ou8CXh2pJ0w2JgqESuqDNYyrIz/gyJ414C/JgDwzK4e+PveaFntX9fWt5/xnJ1v7Kl\noWqzeby25m1ZNn/peoqVgfaPJHlH0eZvAnQPFvjCT5/l4jMm8fLOQbbvG+J//n0+Dy7bwX1PbuPm\nf57LxFEN/PWlPdy5aCPXv382MyYJ0vjNXzbzyIouvvWxeXS8hsjZIO7fIibbv3Y9w5SmY6r2L937\nHHdv/D2N+jF0r5zOj649J7QsA2BTV4Zv/t+LfPySWcydLrLzPfp8F3c/tpkbPjgHPdVNS6KZ0XWv\nLnPfc/tFAJBhGajy3/aZvLxzkO/cvZIPLZzGceP9VIr5khEycRumTUyt1iZKuomUEGSwfVeZNW39\nXhCNaxbvCfgyQ5p3gBwf2LaY4eQQ0IYkSZR1k49/90nOOGEU557lB9Zc96On+MAFM5h/6niG8hrX\n3LqEc04eiz3xRV7u38R/nHUjsQOMgaZb/P//+SRnnTiGf3nHDD73k6VkpN0kRCBxyKToE7PFnwZv\nB6A50cSG5Z1s2JUhEVcoayafuewUZk5u5cWe1WzdrPDw091V0dvd+XDw0n/d8xLb9g3xb+85ke/+\n5iWuuGgikmKCbGFaJorsE+SqLX384Herue59JzNrShvb9w1xyy+fB+CbH5tHZ3OKVV27vON/8dBq\nxqnT/WsnXkJp9f3QtUzSmfKgR3j3P72de34Dn79iNt/69UpSpz8ROvalbfvZ2TlMc32cz922DOnY\n5SSm+wKQVKHxLVu3n98/v5LkTH9btpynLp6mpBlc/+Ol5EsGiVk+mQxr1Zp3X/ol1NR+8oqBbvhR\n0kPlAnJdmNAHy2F/76833AdAfLLQqH+y5EFSPafQ0SKeUXncc7yUAzlZ+x5iEzZj2LBpcAu5cnXf\nDNPi9kWrKDrmckm2QBZC4pduX06PE6TWcEwXeJ+5ze0PvsykUQ3stzZXxRIM6znyJdFfqYZQ8Z3f\nPUeb7FtB5MZ+5PQw2BJIdugeilr1+bWEuMOFSPN+E2CjE7X50LM72b5PfJCWbXPfk9sAPzr0waXC\nf7Z07T7v3EdWiIIxm7p8Te+1wtXmRlp77EbuZhFmuv6hamn2ryvFMfc+4ZeSve+vIjBm+cbd/PdL\nv+CW5f/5mvsa1BoPFktWi34/tGxnyOddKOkhzbsye5aLkmYguf492eLhZ3d6y1hcTb7HMZvHlXhI\nU3DJ0bAMFu34CwPNKwDxvIediPdn13fTlfMjfVEML3p2l6O1PLVqN893v0TeKJAp1Q7ScbEvI/Yv\nWbMPy7YZGCqHTI1lI0jezrNP+pO1bulelHDZuc+la/ezYWAzt6+9i8cHBUm8vCus+e2vIO91OwYp\nlk1+/dgG1NHbWbRGmFslSZivg1i0XBDzn5buAMS6eq/d/gLL973Attxmb5tml0NLBOzOTeJ/UwgE\ntczmPaVe729TEtaRxc+NUIBJ0dm0O0NXTw7dsFCawgFykmqA5L87f1ixKqQVA2RLeedehCY7arRN\nLOW/vzm9uo+u8UFp2093xo84HypWa+lBzTtoNZJi4t4yWYOd3TnvGVYimbb454unM2/mKN4+2yfI\nn6/9FXvG/L7KvVHWTZZt3hHaJsU0hgs6PYNFL9+BlvYtJA2NYoy27xtClqvzIeS0PANDzvtYg7yF\nZu5bshLTxfcjmXHn+v67nCtWZ4orRWbzCIcSlZGiUrzAQCHrmbfcCdUNsKn000LtwJRXCzenkSzV\nfv08U63j56t15ZST8CS47MM1t9nqa5N+g6biVxOIogdyJQfHslAyQj5vbQTyHtZySJK7QNsMCS8e\neRd8Yqg1BtlymKzcyF8Xe3P+RCkpJprmulWcyzb4wlqmXE3ewTEKEpebgSpE3gEBSDMsUDXkQKT2\nYDHnBWC5SMQVT4iT04JUKpPT7A1EsQ+X/D70JlcTm7iR4lh/nfJAKSx8uglz+lIr+dnqX4aEqoHC\nEHe+fA92zH+PJMW3mtiYge2Oz7aW5q0NBI4TRDE40lI/xaA3U/RiG1yhIISA1hcb5QsBVkEEBA47\nhNubKSI39DM0cTGmFCDvcjUhm5Lfn64B35ozXK6+n6DmXWt5m6aJ96w0AnnbisbZJ43lX981k7NO\nGlO1X2kJC2NlKUd8imOSdueCWJk9TuKaGcc0ITf1hPz6l54nhILebBHd9L8LN2ZgWM95Yyyp1fcg\nqToZR8gNCktSsQXbkpDrhjyXVqYQHs/GeEOkeUc4tIhVJNxPnvIUNy3/ukfq7oTvEnStmsG1siu9\nWnjBOCO8fqZDDF6QTsW1C3qRVervUNp3UyxXTxSGUjs46GAR/AC1V6F5uzm7K8k7XzIOSvMeChCv\npBj0Z/0J1jQtbNv2JlLN1MgXq33Kg+UwWelGONZhMKhNK4bXF89H3uATT7YGebtL2AAKZoA43Ykx\noKGEzeYWyROXED/Gj3LPFKsrMCVisqfpuYFKlQUt9hd88t7W7U/87uQaxEDRHw/TMulveB65foBy\n0xZW9a1DM/yJfqhUYwJWDIYdTasQ8C3bloRqJ8g774zhPV+bQS2gPTvrg0dapy8pOn2ZkhfbYBuB\n4C6XuFS/j5Ll77cKIrmKaxbvzRRDhKZaooJXvkLz3jS4FSvhH7c34z/zXMDEbmbbkGyZTDDOooal\nwRVkhgvV38yY5Hh0S/e+p2DSFu/8eFhrjU3Y4AluFB33k6qxs1tsax6TJTFNuLdkW4yHkhTj25sp\nUTQDz0lLOPeV9yPTlWrNWVI133IQ0MylvTPBjCHFyyROehqAgYL/DZzVeiGtyRaKRumIBeNG5P1m\nQ0CajKsVmrdyZDXvkczmXiCRM2lVLrfaNbybItmQ9haEJudqbg9ix9CuUDRxELkAMb0as7k7gcdU\nudpsHiDQkTSUIT3Qf8UMBVaZlk1eL2AENJ+hgJbktl+pLZtWOFguUwoICLJBWQ8njwlOpBkt7PuE\nsHBQNH1hyU0sEgzyqQxYq4zyHa6hESZiCrudSHsMYbKsDCQKmnF39PmWCKxqrTWoea8f2EivuoHE\nCc+BLO47GImdr0HekmJ4VoVcYLy1DXNRiFN0iNHVzmPHrmJDYZV/vqPlDeVqvE+2BIpBT7ZAr5sg\nJBCZHdNF/Elw3EzE3+ZQC9aQKBE6rPmad5D8k5YgviB59xT6+MHKn3r3D9CdC0SmOwKKOdSKtuUU\nVCsVGsOagVmOcDGU10LzjLZsY2voAAAgAElEQVRtFi0J0Qc3ULJWamC5gryDEfbGUKM3Bjv2i/dR\nSfnHj5VEPIIuFVBkib5MMWzCdsZzqJwj4zyDWj5vggKSs9/oHYdeSHrjLzlj1p0V42V0T6TdmEZK\nTWLaZkjjP5yIyPvvEAOlQTQzaEoNkETg5Yx55C32u2ZzzSpXSemHMsmHa3IdSZu3bLe/Yn/l8ifX\nZOx+XJU+tqLtE9NI0dI/X/Mrfvly7dz6w4Go3FdjNvfIW5FCVox8yXCehY0yagd7ctVLbwAKhk/e\nUkVErGnZVcQ8FCC/2uQttO6gmX446ANWzCrNO6g51zKbBzX3vOFf39e8S9imgm0qlMxqn3cQtZaa\nxWOyPz6KDtgVS+L00Du6dzAgyFnV01qwv7XKoxYD1oOc5vfX1WpRDE+wcTXvuvxUrFwrshX3rDWu\nEKS2+W4J28bT8mwIERtA0m5Ckm36snl6B4uOUO2/N0nTIe9gwJfzHWtbT8Z2hJu8JvrQmy2FiClh\nC7N60KKU06sF3IGCP0Yll7wHRoMZQzHTDGnDXpayWm4C1zIwVNA9rdUcGIXZN576mMidviWzXRxb\n49sXAl8wsMA/xsoJ8pcSBU/zjiXEOOp7jmVS8ngAslqWtqYkvZli2IK2XUT2DRQDgmhgjLRts5x7\nEGOcSqj+flMV30dIKLTodSL0bSNGb6ZIWhWBevkaY3M4EEWbH4XY1T3ML/68nk9eeiKdLeGi9nm9\nwJeWfoPx9WP5wunX8H+PbOIvL/oaZnACiFVq3g5Db2m+l889bdO89VLv2B/9YS3nzxnPFRccf1B9\nfOCZ7by8c5Dr3z/b+1DveXwz2ZxGLqWBAiDxg3tXsdkpSPCpfzyJyWMasQhr3pWlI7tdf6/zcXUP\nFkKBa3nTn4SeXtPFI8v3ceOHTvMi1otGkUw5S4MzET350h4ef3EPLQ0Jpk9soXNyteb98PKdrHi5\nhxs/dNorWiFcYUOpMJs//uJu9vUXkFI54pM2cHfXBo4ZfQ0dHdNC5xfMvC9WKz7hqopMMbGXb6y4\nN3R8vkLzfu7lbh7ZvAncVNWySV+2xH/+5iVQyySmr6Bo+WSlxk1PAHIzuAXJ+4k1W2nMdHHh3An8\n/qmtDA6VGTfL13SDxPenZ3aI8+MlbC2JpOrsGxxC003iMSUsSOL4CZ3+j++oZ3dvDqV1L08Ul1G2\nXQ1JRBkPFzRu+eXzDAyVuPyiseLWJBnLtugeGgScSHTJH3NbjyPFNJ7esI1ZiX5mTWmrGVQUFCCW\nb9hLYgZItoK2+VSSJz8VIkPN0kgACScVpmzF0S2DsqlVuUJsSwJLDWt5FRpfPpNAaRVBcXv6ZEa1\npukPHJM0WxlmK/Gpqyi91Iytpfz2jJj4B+zoHeCWX66gN1MiMcrC/WpiOLWoy4M8u24/v35sM23j\nstDqdlJEUvfkMlzzvb/S2ZRkz2CWeDNgim9GNtPY2GTKQ7SlWnjmZT8S37+voNbqru0W53em2wG4\nc/09nNwxq3a8i2yKNtzgMMdaUVpzJraewDZU1NE76Fk3FmjwrmFmOmlKNEFJBLt2NI9l3fYBVm7d\nBzEorT4Lu1SHjEJ/UVhrmuvj5J0xHNf/DrYMlGHKWi/4rqkuTlmvuIf++fR0OMl0VIP+vAZpMUZ9\nmRId44UrIK8XSODniT9ciDTvoxD//fs17O7Nc/+S7VX7smUhDe52tJYgcYP/UUE1eXuk5Ex++/rD\n2vdjL4i2Hti6iJuWfjNkuq3EH57ezoZdmdBktvi5Lp5d3+1V7zEsg1Vb+ymUDTI5jXXbhfZkOaTq\nkXeFGd9dJuVOYNv2DbFuh29CzZm+VnnnY2vZ11/w2ga8NchlS5DDLxdtpKsnx+qt/fz2iS0hs7mr\nzdz7xFZ27B9mYFhM/I/tepIvLf1GTbO6YYj+xlQ51Hd3PIPEuLXGWveiJTRZ24gJE51kkUooJOMK\nQx3LveMkXZBVIUBGmmFy2x/XMRQ0dct+pLrascf3IzqIxy3vOXmacUwTpk5bwlSKvLBR+JT/vHQn\nz6zd7yWPUWWVklXh/5QNpJiOrSWxTRXd0ti0W5hcS3p4vNpTbRhogM2oVnE/8amrKdhhbV+KldnX\nX2D7viGyeY2N+4VmO8FJbtJfEM9/zrQOkulAycfhZmxLwlIK/OB3Ivgp6Au1ikIjdCO1g+M1Sj8Z\nu5wS5tsa5OvmsVY1oRWu69/gCUFuxrPyunnYhhr2V1eQt2vilhQD07LpaEqScuSQs8fOo8mY6B3b\n0C76Kak6tiWDrXiad1e+i+37hsgVdU8rbUk0M8Y6EXO4ha58F89u3k6uqLN70DeB1xtCECrbBbbu\nzrJsXbfXx6ljBOm675rrLtm63/fne/0P3CMO8br7zh53BhMbxmNjM1QerhKgZEMoIak5j3uBeZKq\nYVsSdrEejDj67uORZBu5LksqoVK2xHc0a2In582ayrFNk9kwuJlp08XzH8yL91wkh5EYk5jAgN6L\nlMgzujXtPceGRJpPvWc2tiV5yk1bU9IXnB3yTlvtTFBO8PqWLfrfaaGkM7vzJE5sn0FnXTtHAhF5\nH4UYcgJC6pPVfiP7gCUvCJnN3UxIru9VqTRlSbVNzot2Pk5faaAqorkWatbudYSDsiHuo9NZu+ul\nAK2INh9Z8xb34i6BclG2AxOxc7/BW3Ozk2mmVm1Wl02yAZPycCk8ybhc/IctDzJQGqyZdcowLZAN\nCsmuqr7Hj3+exPTnvd+1AuLKtiBDu5T2+pSIKSiyhGwEAn1KQrovB8hINyykZA65MbBGOKC9h9Jo\n6qItJRYgb9MCbKRYmeZEI7aeQIqXqtJn7sntJ6HEmdQwAc0uhd6Vjsni2tZwC5gKyCb5ongPKrN8\n1cXS4n1QjFCZSxcJSbwbUkwjmw8kRTHFxDyrbTqqpNDPDuIxiX+7ZBbHT/K1HqtUJywA8ZJnBXH9\ntbGhiRyfOA2A/ny1sCNZCiCBWaE5O/uTagJFlogPC3J9dt/zvrAqmyT0Nuxio/C31iB/M9sqtErX\nv+1sb29OolllpjQdw+XT30NCrqO8+RQAzjhFVC5D1T2N2y6lMTPtKI0DyM3i25Ad8rzm1I+RUJKY\nvULI2Ws6Firn2zH6RzPJnOe0GXgXnb5ceubxpBMqaOJdcZMDlZx3Ttt6EqUX345VrAuR97hRTh4B\nh/iS8RjHNYtqeHkjXxWVPaV9lD+8DQNCaFUMMGOeddF2+iCpGsm44rXxrxefQjKhcsGkc0UDDb1M\n7KzHkp3+OO/8xLiwcClt+xjVmvb6m5QTzD6uQ5i9nW1j2tJV1gNVkZkxfpTTB92L1bDNGLppM731\nOD520j8TV0Yu9XsoEZH3UQg3pWFDuka6wVcIsAp+YO5yD9eXqCgyECAb+cDZygyrOjBDMzVRLEEO\ntx3KmuWQd8kh77FO3WOXICqXihkBn7duGV6gkrf8JlS9yaZsB5f4iD4G/es9gexfwSUvUqJA6rRH\nWbTjL962XLmCvKtyh9fI8mZaxI9byZ66p9haeDm0T2nuC/2u5VPXKGDbYDlZtSTFEOStSEiaX3fZ\ndLRGzfKfeX+5j8TMZUiq4S83CvrNAwFKdllMikrM94drugmqjiTbNMQbsEsppHiJTD6Q7U6y6Cn2\nMrZuNC1JQSZBa4LWuB3bkjB6JmBbigjGGhSknXeimG1TYczwOYK8Ee9lMmVWvXMtsrOkKFYmGxDS\nipYg77H1Yzix/QSM2BAtHWUkSfKIxb1HW0tCrOwJGG4msobisXTUif5nQwF8jvbs+DhtUw2Rr/ve\nJeQE8ZiCVaynM93OjuwuQd6ShSTbWIZ/vhCgrND5Vq4Fu9jgkYvSIN7rlqYYNjZJ1dHsFckjLtcq\nIyl6IChNwth/DBAonOFcI6UmUWQJa0jYyPP0e+MNYPaOp0FtAFsKPUM1bnrnx2Iylkve5QyWbVGS\nRTu2HgdkQdJObAKAGjP8sUMQn/usl+xZ7uVkd3FSu59tJnHcS6RmPYuk6khmjNaGROBaQEwnHlMo\nODEPKUX0rTMlNN5sOYuqytiyLtwWtqC5MbFjkWwFdew28vWbkRQD25KIqWIcG2PNSIkCclOvqG8e\n8Hm7z8H1a8tNfdhjnRUThloVVHskEJH3UYxaUeFBM26tJQth8hbHFsoOwcmSZ+6CEaIxAyjVIJ4l\ne5fzu81/Iu6UKHQTHlQm+we8oLr6VIzm+rgXqeznwq4OWOst9PmEqRiA7ZVelOoyJE56GjsogDj3\nIAX81K7mDWHylOurE9G4ZnMXlZp0Lf+pYfpJNoaMAye3qWV216WiiLB2J+eA5m1LgQxteYe8bf8e\n9pZ3ICkmetfxGN2TgLDmHXz+linGRFZNz/pSMjSSJz0FQL1aj1VOIUli+ZUXSZ7MY9kWY+pGe+lb\ng9HphprHLtWBkRBaqwQ9Q4Jsi6azpGr/MSQK47wJXU6UeEL/XxLTnwtZB1plx7edyntCK0DRsa40\nJxqZ2ijiMFIteeceAuRddDRvyRcw3ICidCzFqAZB3iUr8Jwd8rZ0550xYk6kcfC9EwlyknGFsm7S\nnmwjbxTIlYre+YYhuwMi/ne/rQpSsDVBCLGJGyFWoqlRXNclJUWWQBcEVrByoh+q7hEj+OlTvWVy\nAdO+KsvYWgoZGSvmm91BmHx1A5JyGjk9jNK2l8TMZ5AdQSKlJokpMmbJ1byz/HHrwxjNwuftWg2E\ni8dGaXfW5scCPnkH7rNetm8Fd738W4KYN2Yu7516iX8/ySFQdFRJCEiiLdcXrpGIyRSMIkkl4WXO\na3LqqWfKQyKHhaO5e5kiTJVYYTSSbLFOe1ospzNVYoo4/8KxFyFJoI7ewdi2tDf/ueMcU2XqnMC7\n2Litfl/N2EEVHDrUiMj7KEN/IUPylCeQW/bXXCccJCOj1gtVg7xdYrVtO+QTr6V5u/5qqC7WAHjL\nJJTGAZANr22fvG1PA3LN5om4Qkdziv6hEoZp+ZHyjoYeLIPZEyBeSRZtuZp3/NhVyE7mLjcgxtMw\nAm0EyTtoqQhOhi4KevgeNcMK+fprJWUIErxsB9s8sLAFIg7AUHJYpTS25ZyrGCTiCrIsYztadHnT\nqZ42plt+H12t08o3CpO1c76LEHk7yT0kxcS0bAzTIqsNeoFCti152rmUKLK/wmffnGyixZkw3XSu\nSBaWpHt+WDdCtzebc8bL0byNGLph0eFoS7HxIpuZXJ8N9bfDnoptg9wUWAoGDNvid0uiBVsT10qk\nxHlFowSWjLZtltBuXXOrI2C4qTjr4knGNAvhQx29E6VjlzceAGXNyXtQSiPJlne+uz+pCGIp6xYt\nSeH3HigP+uTtnO8SnEsGleZYs38Mck6YY+Vkgfp6cZ6vecvYegJsWJ9ZizJqJ5IEiuW7GWwthW0H\nnoOsE1fiKLLiLAGVSMuNSIkCx09o8oPLzBjDeY3ZTWcgqQbxY1cj1w1jJ7NOH5IidqMk+rJjaCeP\n7XrSfxCGK4CIMY5PWRsao+A35Uac10JKTXJK58zQNkm2ScpJLzmPa2lQO/YwNGoJBb1ISvXT5SbV\nBEklSaacJaZIQrM2VRJxcb5hWpT2jQ9dwzZjnvvw2JaJIj4hVqalIeG9h+51Fdm3HoTa0OO159rD\njIi8jzI8vnOZSBRw3Es10xAGyeBHf1hbtT84eWu2+LusmTy8fKfIchUMOlGq28+X/PaD5kkXwQpS\nUsqv4OOlHJRsQbrgVeJJxAR52zbc9sd1dGcdE6ZD8oWSwc/+tI7t+4ZYvlVIvHaAmAbcDGSBicI1\nobkfYFk3uOuRDdz57OPszPjpX3/1WKAkZg0ff7DYA8AdD7/MMxt2+PsDWt7zG3q4/+ltXoY1gBUv\nB7JG1bBkDBULWJbN/zz0Mt/9zUrW7u0CyRZBOs49SorQvFVZwpZ1UnIdVqbTm1SCfmQ3iMc2Yz75\nB4Uwl0D2noDpaeZim6abFAMa6JT6qb5Glyiw28ls5b5DxbzM48sdM6yreTt+U9fE6fZxy0AXS9fu\n8zVcI0ZXT47lS2JYxXqkej/gUJIgZTfz2TmfADOOlWsRVhGnbbmhnyFpPzNaj6cp0YBWEtey4jl+\nt/lPDJYzKHoDZt94QAqR992PbWbzfiG8NSTqGN/a4l03Ptl5FxzXgusxsYvChy4lnVgKR4BKqkkS\nMZmybnoWiKyWDWhsYdLxfMoBUhjVmgYkFIe8pXiJtMNHSVX0W5Yc06/zPONOXeykEqg1YMvCwuCQ\ntyXrYc0dyA6oSDGd3WN+i5zyg62GCjpnj52HbVbTQVJJEFcV8jlQibNreE9ovyuY6Hum+hslv5JY\ncL16LeKb0Xo8lxx7MZIk0ZRoqNqfjqX9zHqm6iXsKSf3M1jOkI6F6y00JxrZm9/PQMOLYi4zVS+W\n4p7Ht1AeaGFy9p3e8dZwi5dpsi6pYhtxJFVHVWTf8uhp3lLoHqxcE+WNp4KexDBtNnVl+PkD6w95\nIaeREJH3UQYvwtGSR9C8fXJdvbWvan9wOYcZINpHV3RhmFY4KrZm1R0/GKyW5h3URCXVr9Lk19AN\nrNX1yFtm1mThk3txU6+fwcohnadW7eXZ9d3c8svneXHnDsDRLBHLSTzLgeFrIp3pDq8PAC9s7OXJ\nrmdZXlgUShm5bmcgM1egb66/2fXLu9jTm+dXj/vJN4JLjH58/1r+9MyOUCYwSxJ/nzdnfDga10Hf\ncJ6+TJElq/exbscgK7YJ4cQq1vtai2z4ZnPZICY596m7+Zb9PnqWF1MNkH+15l3YO9Zr31svr1ue\nyVnfNY2J6SkhzXt3r/PsHRLq2qvRtdsJ7nK1UvcenWdh9ApNRx23lb+8sMeL5G9K1lPWTdZvz2L2\nj64al3pzFJObJqEbFtZwC5IEckoIdW5ynrdPOFvcS05MY7uNjSLeApDxScMn7yKPPt/FcKmIbcP0\n8e00pfzJ2DYVLjhtgvfeubGK9YrQqpV0nqb6uDdeKVVohZpmehaIgVLGF5Zcn7kTeCincsSOXYXa\nIVZttNbVMXl0g9NH8b6NHS15edBd8vVQIfx1NjaGfttaCjlRIjZlFSYaKZf8HfK2HdO7jS0sHDbU\nxVNccf5xjG2vJ2ZWL29SZMVZlSJR6vOjqMubT0HvOs57zu3pZibEnWWkqk5RFXPP5I4OLxVqXQ3N\n+x+mLOSCSW8T/ayxfGx0U6OnOYPvv3aRVivJ2zGdpzYiyRa2odJQkRBmzqQp3t9mpsPLQJlMqCTk\nJHJcFwKPa4Gq4bcHMHomYGU7aUzHMEyLp1ftZdm6/fRnj0x+84i8jzJ4ZGHEvIQQQYQCoCo0ye99\n8kzGjvJfZMM2mDymkUmjGiiUDQzDqliPWi0cZEv+MqNaPu9g6khJMTzy9uoDy7XIW+GMmaOZM80h\nXOe6biajYPUeOT0szLmOyTc4oU0d3eH9PaqCvLsHi6GsX36DZu2/HXPgLv1lnt+/MnxOILCndi7j\ngHnc6d+0Cc186rJpVUdqZjkkhA3oTgnIYr0n8UuqCNBRFAkUHQXXzxj0AYoJzrUE2IbqTToN9YHP\nXNHF0idLEe3bEpYsyLikGb7mbsQo66YnxMjJgle8xBUWTC0WIka3L+75AP/90X9gUuMElPpBhu0e\n9sbEWH78nbO5/v2zAbDyTVXjIluOS8AwPdLxVg441291TNXZrF1V79pSfWuEbz1wBQwDhRjzZo5B\nkiTU3aeKZUKKiT12LRPGiOuVyqKm9LX/cBYA8+c1M3/2uEAwWIJETMEGGmMueQ/6AW+u5u1o7kpr\nN2rbPuQ6IYTc8E9v83IPWI5PefrxKe+7cjVvL6eMHo7GnzVugvf3tz8+jwmdghzV9n0YUtkjb1fz\ndp+Vi3Qsxa1Xn8Ox45qIqTIzx4r2bEOl05jBW8ecDvhLSs0+EX9QrzZgDY5G6fdzPnz742+lNS2+\nydiY7QzYuzm+ZSpffN+5/MvFM4BqzbshVs/ExrAZ28qFBZJxLS2hnPZSxZyUrqHNB2FrqVBFwpnH\ntPD2U8czPjlZXG+oDdW5P1mSOG5MBzYW31/zQz/4L0DeDTFfwDH7x5KMKzTVJzBMS9R4l6Ct6dBU\nX3wlROR9lMElC3dyrUTIh1rxosdUORSJbdg6MVUmnVTRdKeggHJgzXsoQN7lGpp3KFtWgLx9zdvv\nk0fejmTtfaTudR3ydgPzpGQOuW4IK9vmE1egv2ogAdKYOmfpiUMm/dlSyKzuJVEIEHZQcLED2ZTu\nWH936B6DwVmVZnWlfTdKu798zG1TUYTJOwjbktCscMrUYVMEuNmlOo+0pFiZZFxBloVA4+ZxxlKw\nLRkppnnpJstWwIXg3INhB56pE8ErGEEiIdWhSeKZarqFZrmFMWLifdATyCjIyYDP2yFRvaSCGcM2\nFc9c6xKrazaPqTJtyRaQID9KVGhS+o9lcvNEjxRcK0pojB0TsW5YHmkpDYNI8aInILgTaX8mXPEL\nwFQC5K255F1EqsuKwCzbJ8J0cRJG9zEAPLN/GT3SJm98VUVmVLodWZJZ2buGPnmrFxOQiiW9dzet\nOMVB9Kz/jjvjbzlL+jwyQBBZc6LJS0lslcWzzpQy3mqKSh+xvP2tnNJxovd7Zsdx3t8xVWFWy6zQ\n8UmPvMU4G/smc0LsLG99eqXw3ZRwnoNkc4w1jytnvFcc5xatGWrj7JaFvHvM+wHoqCCphrgjPIze\nSVxOcOnUd4b2B8n7golv4wunX0MlyhtOp7R2nve7M91BIjYyTbkCigvTDs95drE+RN5u8NtFoy6l\n+OLbwYyhBoJZ61RxDz3FXuRkwUmyI85RFZn6eB0fPuFyJmbeAbZMQzqGqsjohk1vtkRrQ7KqENTh\nQkTebzDolsEL3atGXPJVMv3JtbbPO1A4voJ8Y6rirSEGQLZEBKVTYSlb0MKm3Rqad1+gwENNzTto\nNlcM8k4ke9Eh76CJ17CdJTfOB5WIK0h1Wc8n7loO3OVZSpvwVZt943yTslJtogY/o5N7P2U9LJgk\nqHP6GPQHB9ZDl/2JqXISDd5DZWrP+JS1xKesCbTpkLcsoVNhTrNUdFsLZR3TLFdzjnvFFKR4mURM\n8QOeLH+JkK3HQfXJW7c1Z3mM4pnNTcIJQkLEJTWiSQWQRIpUL3LdUB3BSyItNSKlh4jPWYSUzHkC\nUbkozKlWoQEplUNp24OccM+Pe/fdFHdIIZHHzHQwpnwasiT7BXMMv7b4jBZhnVBNJ5LesDxBQB29\nk+QpTwrLhy15/s7eTMl7Z+aOOhWAMaXT/HE2Y9iGitLc65fRDFilEjEFK+dr/xa+5qwqMnElzsJj\nzmNYy/F88REUZy11OuYHU6WoR5UUitKQJxDalkIqoYAR9zK9uVAlX5sDMHSFhBJnsJxlq5NCdErT\nJIKQyg28a8qF3u+JjX5lrrgqc+boed56cPDJ0hUQsFSmpWbzDqeNyY1+8hcARXIEViksCfnr6yVa\n9anETfE8O5rDxOmSN8C5o89hQsPY0H41UBP+H45d6AsLQVgqdqGJy6dcwTWzP8bcUbOrqskFMXfU\n7NDv9x1/ifftg1jn71aQA0g6wlZSjXvvnRog2/p4WJMXFj4xfm5g2+mjT0XVxfuSTsaIKRKGKQJn\nK8fkcCIi7zcYFu94nP9Z93/cv/Xhmvu9IDFbqql5B3OaV5KvLNuUrTC5xlWZdDIG2BhNO/xoVaiK\nNpcb+1jc/cfqvgQQIrMKn7c6ertXHxfAQiz18j5OtRSuUSyHydsNGDKHW3yTskNoqiJj4pN3a7KV\nhJIIpYN1NSZzqIVOyzFhBwQc19xp5ZrQd87wtlea+4Jthgs01Fiap7h542VPsDGHm/nQ8R/ANhVM\nWw/lHNesssiFbckhzTsekz1BwF0/DIARR1I16p01/yaaFyTkBqwFyRtVJyb5ZFkvi0lIbhhk7eAa\nj7xtM0beqaLVoDoR5bKN2tnlkVCx4JiF841IEsSPXYM6QQRTeZYRSQpN0ma2jQ4nKU88oFG17lnI\nl97yWT4y/QOUN84hVRJJRXTDCsUyiPHQUOwEsiRjWlaoZOqxzZP477d/i3GcGDonKIyBsxzPQTyu\nYA2OQt80h45Um3+Q5QsYFx9zPh+Y8T5/V66JZEz1a0obNu2pdqz4MI0Nzn2ZKumEeBZuJjcXHzxB\ntOUSgmHatCRb6C8OsGlwK82JJi8ILvhadaTamdQwgYXHnIcs++MXU2WSCVUkxnFwaufJ4hoBzTKm\nysyfcDafnfMJPjDjn0J9mtYqgs7M3rApOzPsv++9mZIXhNpWURmsMemblMc3d3IgjFQO2MVJ7TM4\nrmUKkiQRj/vvu7Z9JlY5yWzlHXz5LZ9leutxofPG1o/m+jmf8n7bxbqQ5u0+r2CKYzVQddHVvF2Y\nw63e30GN2nUDphNqiPzbm4+MyRwi8j6ieOz5Lrp6wqkphwoaDyz1g5y2O8kLdgyJZSuPrOjinsc3\ne8uhvCQNstCUlq3bz12PbOS3T2xhqKCFfd4V5Js3CuGkIgHNW27qJT55HWpnIA96KEDGJhYo4wjV\nAWuWbYfK5EmKEYo2V8eJ7E5mthUz60ySkuV9nHvl1aH2RE1rvw61p7kYcZ/YHD92XVL1lr5JG9/G\njq4yacXPmAR45KdvO9ExHVdq3k7U9daTwIxTeukcZDPBkBZ+ZkENasPgZm596o/c+9ctxBM1stsF\nNG+3kIax5zjmjj0RLAXN0vnLCr82s47uCCaSuE9bglhZWCVc4UMPrO/V40iK5UUoIxu+VcJdR+ya\n6yUTSbZIyP6k26AKv3HsmHX8pe8BhmNOX4yY9+wanWPASTiiashWjGLJoqUhUdNnbet+bEWQvO1S\n2iPvYKnaJI2MruskEVOxsh1YTlSxHtC8XcjJAoolnv/9T2/HtGxithCwOlLtSJJUpa25goxVrMPM\ntnGccoa3TxwrYWY7mBYkA0vxtFZJkpg35jRa4+K9NbonElMV7zovbupF1uqRFJN0Y9k737VqeX57\nQNtyMjNahb/YNWnbNqT0jnwAACAASURBVExrOZaSWaZgFJnaPLlm8Q5FVvjc3E/xrikLKrZLwrxs\nJLC1BAoqJ7YLAVRRwiQPMLlpkhfU6eLE9hOIbTsXfdf00PZgVbvebJGCM1e11CdCxzUHyLsj3Uot\n/MeZX+Rrb72h5r4g4oHnlwz8bfZOoLzqbYxPTmZUXW0BIRiBbpdTNc3mSlCgCZJ3haBuF/0I+CDJ\nu0pJXSoW2t4RkfffH/b05vj1Y5u56X+eC23/5cMb+MNT2/ijk6fc9dmokkJfpshv/rKZxc+JZTa6\nqfumV4e871y0kcdf3MOi5bt4fkNPyOddGdzh1om2yk6QkWx6Pu9ahelD/uBE0VtD7aLSbL5++wAl\no+RPtorOcN5J0lLWQRITsbb5VH8Nsmx6H+cwvdiWRGnVOZhZ5+OXA+StaiKBhy37EbxOn+rTMXRL\nx9bjFLJJfvC71aTUdCi5hr+EJ4ahy1X36Js7nWIMRh2y1iisCZIFWMSnP4fS0uMtWQHYYDzDw8/u\n8pa+ubBtKeTzdjNC2UYMWZbEGnDZ4PHnffK2EMk3Jo6qByTQ48LnHTCbG5r/2bpaaSJtihSwiu6T\ntvMcOjqcNe+uoBPQLprjzc44OlYBxe+jm9K0PuaTr6Tqjt88TqFk0NaYrE3eAW3ZM5sjfPluLedY\nYFJWHRLz/LPOulndsJDNMEkAyGaSkmbw4DIh7F7UcQVXTv8nprUI7TFoKhX37rgjSmm0jXOZnvTN\n6u77Z9swJu2n6cSWQxM7wPunXIm2YwZm/1hiqkzKuc79T29n5y4np3Z6nTjdUkgnVc49ZayXZAXC\na59PmSpMvP9w5jGcMdrv07njzwx0vur2PbQ42cckSfJIpLT2TK4c93FPuw1mF4yrI5ugAS47ay7Y\nMmec4I/DhXP9wLjeTNEjrvGdgqynTxTvUFOAvNuStcm7OdHkrYmvhfEd4t0MCl+1zOYH8oMDTJFP\nQ983GZBFeteKtoKat+dWIOxDNzPtmAP+OAQ17wWnC5fD204ZGyJvNxvckUBUVewIoViuvfZv/4CY\nLF3Tn0veiqzSE8gnPVzQ6Q/UL0a2KGtmyHReLBuUpaDmHSbkYUeDtMtpSJRANompCnXJcO7lifHj\n2aVtCpO/G6S07xhGG7Pon/DnqoC1fFlDUkyxbjemISkG/UMlLMumZJSRZJvpbZP5+GfO56tPbKef\nHpAt74PSEZnF7HLaXx8qW77ZPKb564fLKSRkSAhLREdTim5LCwWaqXZSRKzLplgj6yWmUCmXbVER\nyCHsKWMbSU2qZ1sOT7CoS6pCY0oBqoYkWSL5DAifbkX0uhtjEJfjDL94BvETlgc0b9kb/1s+IqKX\nZTuGpYhc4u4MLSkGthHnuHHNfPby2Vz/2FKkZE5MHE77Wjkwm7sJJGI6LQ0xioqF5Wb0slQScoJU\nyuCmj8zllj8+AAjTopsfqjXRChUp6t1o9JyjeadUn4ileAlUDb2QwrJt0kmVn3ziHfz48TgbSi+i\nNGRoUBspBsgqpHlrqZqatxfxK0tIkh+kqBsWsRqEI5kJT7iYPKaBK+efSl+fbyFprwim0ndNI3Hc\nS+h7BbknAqbYoJY3OqTNSSGTKMC4pk7MHuGLjqkyHQHTsV1hGscU39aHFkyjdetuFu0Sgkaw1vak\n0Q386NpzhGUFeNv4M2lPtYX93QcoV/Ctj83zAh49Td2Ih4g0SE6V91OJd501hVMmt4a01ffNn8q7\nz5rMt3+9kr39ec+d0tqQ4NZrziYVF8cGY0Nqrek+GHz5I3PRdCtErkGzubftAH5wgOPUuazrEgpR\nkLxdn/dImncwT4W2KRA3QVggPPeUsZw+YxTppMpTq/wA1WT8yFFqpHkfIVg1UpUCVaYxw3KLhMih\nYhCFkkFf0c/JLSt+SktX8ivr1oE1b6fghuf/U0zH5+1XParPT+WczvnO/mCqVD/pQn3MCc6p8Hkv\nyQo/va0lhd9WFVWSBoZLXtnIpkQ9qiLTXu8kvohp3sdZtovexGa7NZklV/O2QdWwveAmmTq5Ednx\ng7c3J0WQn+l/1JYernYkMi4JE2nZ4V3h47dJxhUM1zfsCACphIrlraUuh0zwthFnsiYKIXjJLZzx\nPnPs6SiWWKftEroiS2TKQ0hIjKpvce6gVhIVE0yVeFymPhUThUEUC1s2sJVgoJjTDyeoTZcLtLW4\n5nKfHBpiDWS1Idqakshp8fyntvqaVGstDckQZnt3kp7deiqzW+eIcUgNI8m2l9WsLqkSjym0SZPR\nd8yEgQl8aMpVBNXFUGCSLdf0eYeIXJFFwiBElbRa0buSmfDM+lPGNFV9R5WBQ9bgaN7ffg22YyUI\nam5BIvdWKQT6EkRdYAKPq3LIxxn0N4MTsJZUkSSJ9nTAOmGE1x2nEiqyJCFJEv90/Lt5+4Szqu53\nJKiKHCLaWvcUJKr4K5C3JElV7cnOto7mJLphsddZMphOxqhLxjyiDa7jrmXyPxioilxlNamteR+Y\nvINCyiuZzYPHnth+AnVqmium/2NVm3WBQlCSJHn9DL67ifiRo9SIvI8QauUZr7Xd07wlhd6MT475\nUrXm7aIuJV4iTTdDRSrCmrfN5sw28Zdjcg6bzZ3lN6UpXuIKKWg293Ihq6QSKkk1GdK8dctgW2GD\nc7AFZswj/L5Myat85Urnk5pEQJJclyURU9AtQ0RKuyZ3JxmDu9YbVRdm4YD/s0FpEfV3FZ2OppQg\n74DmrZXcQLhAZivHZFly5CK1bT9Kx26x3MPSPHJXFYlEXMEoOwJATAv5usHGGhjDhORkp9604Res\nUBNiQgsUtFBkiaw2RH28zsvFbLqme1dIkiyRWMJUfFOuQ85lCuiKIF+9GCAMxyeXNftobnL8pwGz\nbGO8kbxeYG+xy8vHPGOUr9U1JeqIyeIeZdstpOFkbnPMo+lEgg/NfC9WOemZ122nIlk66QpbNnax\nAXXvKbSkwlHESSXBuPix6LunoioyTfV+JLoLNaAdKrLkFXoQmncN8jYSnvm2crKH2r7HukCyjrBZ\n1m+/MR7O8hXsl/gd9h8Hr2NraYrPBXzRAZ93Q9zXhG0zTN6viFfBg0HNVKkIbHu1cO91pxO3U1dJ\nskqck9pnVvnjXytqEXWlUHWg/alEtQl+pIC1hng93z7nZs4c+5aqNmu9ZxAm/1cSKg4lIvI+QhiB\nu6vglsNUZYW+rK9590gb+O2m+/0DA8TqlgYtaWaIUIOat9zcy7J9ItLbKjnLpGJlYorsmM2dNddy\niua0Y/JSqoO9MEVO6sZ4AwOlQTQnA1le9zOvGb0TsE0VJSau35sp8v/Yu/P4qMqzf/yfs81MJpls\nkAAJ+yabICgo4i5Qt69WWxUXcKlaRVu1daFUpbUPuFT9Wbva1trqQ12hllddeLpp1YLWlcUVtAjI\nkkD2zHaW3x9nmXMmM5mQZCYZ5vP+h8xkZnLmJMx1rvu+7uuOWevL7eA9uXqMeVyl+/CrDx7Gqzv+\nbZ6npJ7Y/knrzKBmNUZxFy/ZhVRCoB0DyvxQDc0zbG533rKDrjkkbZ6rcLsrWJTvRWNwM3a173Iy\nd0kU4VckaBFX8PZUrsdR3xSBz96yUo45vxO/5IMkCOb6Z8mcKxdFc7cjuwMUAGiqfYFiPq+4OLGk\nx+nnbL3fiNaGmGiNnESCieYeVrOaFr0egZB1fK7gXR4wA+nKj58xHx8NYGBxYs4x4JfNoXMAJdoQ\nCLGg01TEzmz9PslsEqO5A5V5UWF/gNt75IiC0GGeWBAEnFJ9DtQvx6KqPODMwbqzMzkp8/YOm4uY\nV3YhYlumIbLpaMS3j4PSNNK5uEgOIgBQXtJx7tGdOaWbUxUEAQtGXYzohzM7HFcyRRZR2mFnP8HZ\nIcuI+Z2LG89FQYoe+r3NfUHiHjbvjeAdjWnmErqkQCUIAr459RKcMvLkbv+MVFLOb2e4oHFfdLmH\nsv0phs2TL9DSCaYY4QAS9RoAg/dByT1s/t6n9dANsxfuftd2llu/bEJMtTJcw0BdY9jJAPeXJ/aA\n1ttLrAIq8zXNDy8De/WtqI/sT/xQV/C1h5dlQYbeWG0uMSpuhqJ4h819QgClwSLo4SDEUKPzGu7M\ne2d9G0q1oYjpcTyx/jWs/2C3s7etumeY9fqJIri6pjBiVqFdsbWOcnRlrbmOdsBufNGyHau2/MU8\nUCt428PmghL3NOZwF0KVyFaLVCWC0lJ7eU7iP09Ts13oZm1VKCUqsdvaBGCH+SErVdShrug980nW\n/2NZMiuW7QsdqWq7N/OW4tjXHIEMa3jWmuMHzEzTzLytD3YljrgRQVxXUe4aQraXfCnDzKYgpSF7\nIwvZmUqwq5TbjVZExCYYutnD2mn5GPfDiPuwH19ik/oP8xQ0JqqI7S077SmX2JbEOmDAXPs/sMgM\n3pE2H9o3HO08xh42d9bhC4lhUfu47OBk/32bc9YdPwztAJuuGtcdJCVR8BSs+WQRgwODoe0fAqO9\nFOquMdBVn7MbXjDFvvbuzMo5Blfm7Q48/qQ51YmV46C3mFXlyRciboospXyvV0y5GPH35gGaL2Xm\nndziMxv8nmJAd/DufnAZ6JqKSHXBlC2ZsuxMz3FPz9gXAuky784Up/g7AwBZTrwWg/dByJ15P7Rq\nA155dyfuXvmOp9HK8sfeRn2zOTcc0+PY1xxFZWkAJQEFQsRd9GP9J7IztiIFUtUObCsyd/s5acDp\nAOBtB2oF4UvGLwIMEXpbGUR/GIZoNfiQzbaZPsmHoF+GVl8LQdQhVe72PB+agoaWKN57y/xD/vN7\nr+PXaz7Axm3m4+zgamgydMEMmvubo4gLVvC2Mm9REJ0Mz3OeUvTrdg9ZuzPvErnEeZ8lQSvwuTJv\nLe7q2Caaeyy7s57wl0M9VePun2suvZGgt1ZgQukkSKFGSFWJZXSxz6bCMIC4nZl7Mm+/uYey9f7E\nYDPaNHsLy0TWO7bYWspTuQcQVZSUJC5A7A+BgUHz8a/sfx5hcb815SGgusIOgmaTFA0xRIw2xHeM\nxfTBiTXqdvAGzPXtRpv5evaccNAvozpoBqq2Fsks7LOCS0u7+Tu3i3wG+BJroO3gbVfXjq01f85h\nYwc6owLuoFhZGoAAYGhVx9854B16lCUhkXlrZuadnDFqmp5YrpNuODPpQzlV4RLQ8QPXPUfaWYGX\nnbFVliay/AGl5haVMsy/02CK4J343XXN6CHm//3Dxg3M8MhERul+T6LYu5k3kH4IORvsn+WTRQyz\nKtyryjpvhuK+6PLMSSuJkbVU3+/KcSSTPXPeuQverDbPESczKWmAPPQTfLjTu7ymTdgLsbzOmeON\najFEYqq5jlY30KaJEACE6meiWbaWFllV1MGA7GzacP74ryKyx9zooXaIjMtPnAXdMPBKXQPW7/3M\nyXy11lKIZXUIi/sQ9I81s1PV3NtWFAV8e958/PKjTzF1qoh3/55ocFLiC6IZifWPdr/o3U1WW0+7\nGMfOOiUVMVWHjihEeCtSJ9YOxsdNiZaR5vPNDz17pACAOd/tt+daXQ1GrOB95IwifN76mfVzXWug\nnUYuGgTr/BiajOKAbA25CmZBmL9jsxnJNSw4o2w2Pmr+wNmJKbLhWHO/agDtLQJQYgV915y3JEWc\nJVRicTNaVfPnuzPvb596PJa/vBX75E8hKDFUlgWxwzpGe8574UmH4uebXI1rrO5XQwYU47yTxiIU\n9OGdLwdgXcM/UOoL4ZRDvo5h1SFcEDYb5OwT/us89fCRI3HWCWbryTsunYn9zebWh9VNZlCwh8JP\nnjEUf39nBzTdQHFAdoYd506agj98bG7KcsnJhyEkDMCU0WbWfszUIRhUUYTRNWaf7GWXzkSFK6hV\nlRfh9kuPwODK1FXInjlvSUQsrsEwDKfaPDljVDU9MSef5kP1vmuPxusbduGZl825fk+Rmiu4JQc0\nd5CXU2Tw9187B63huJN1L7t0JprazIs+ewcrO4ja2Zq7u9hti7xVzJnMnjIYA8sCGF2ToiNZkvsW\nH43G1ljSnHfXC9Y6M6DU3BfdMHIbvAM+GcsunYnykB+KJODLfe2oTXMRaHNfdCkpRlnSFay53X/t\nHLz18V488Tdzu9p0GXqqi4NcYPDOEbswTSzdB6m0Ac2tewAkrh63la6Fuyg3psUQi5vLqEQB2Cuo\nKJaLIDQMhVi1y1xcJOowAAQUGYJsZn2TB0zA3z7ZZ3bv8oedtZhavVWQ5jOvnu250jZhPwRBgCDH\nYaiJhgMTBtdC/FhE3JpntTPvimAJmvfHzbXWmuQUpe1vbzH/muxqcSdwxs0GNIpVze5aQlJRFAK8\nsdvJrGPbJiIweb35GnIMorVlpN6ayFxDivke3t3/Nt7d/7Z5pyvzdobQ5Rj8483vG9EihII+TB0z\nEOs273aCfak+BOMGV2P9f8zzJEuCk50FjUrobaUQi5s9xwgATc0wg7ccd4oI/ZIPoiA4PbvF4ia0\nqOZzy1xz3n6fhOpQGfaFzfqD8jLRXLalJ4bNq0PeavD4DrOJSHFAdrLYE8dNw4mY5nmcX5FQWQpU\nxkc79w0vH4RqK3sqtiqFAWDW4MPx7KsfQ9s/BJNHVngyQ3e2NW5AotBt+qihngsxURBwyPBEtfWI\nwR23dxw5OH3w6ThsbjhD58mZt98nQdONRJerNMOZpUGf50PeHdB8nmHljnP0smQeQ6oP9oqQ31lf\nDQChoA+hoLeRjP1+3BcCK+bcDkkUUaIcWMFa8rntTFmJH2VJ8/2pmrR0hyyJqAz5sa85mnYIOVvc\nf0/2KE9n5DRLwVIWrKW4QAPM3/PIFH/Hydw1BRw2Pwg5w+ZWT+WInmKHKxd7yZdfkcwPJ1GDIvoQ\njWuQkpYYybIASUlsU1jXGIERC6BdS6x7tVtzhvzmB669WUMM7TAMwwrePkSsHbxkUUaFvxx14X0Q\ny/dCHmAPi7u2WlQVZ/mUvduYMyft7GGsojUchVhsZuYhV+FOyrWg9rB7WzmiH1vLk5QoxFAj9EgR\nEHd1B/OluPp27fhlX0BIFXsgKHFIrYOh7jSDn7OUyPp9+EQ/Lp9yEcT95m5DdsEaYFZdq9aOSoYu\neLL7BmsBgFi6z5mXtzd+QDwAI+aHEGzGjlZzyL0maSlSib1LkRxDaYld7Z0YNncXOk0UToTeYI6q\ndDXzce+6lG7tbUD2I/blKECXMLC8KG27R3exXbHcvXW86aQqWItZ65cVyRu8Az4JqmZkHDYHktY4\np8mQUs2P2wVv7u1dD4Q9kuD+PZX5Qx365OdCb2XeQOJiLpeZd3d4Mu8U1eBdybzNx2U+X+6Lg1R/\nS9nC4J0jTsGatYGCs/uTLWlLQ7tHud9ndmkSJA2KYG4Dam9q4ARvSXSGtQNyAPWNYQhqAO1qO1Td\nvD+shiEKIoKKGbTsIdKI0WbuYiQYgKZ4CuiqigagOdYC//h3nPvawq41yZriFGm1WNXmRofMW0Wj\n8hnE4hZUxMd4AkiqYOLeYcoZQg81QJDjHdbRFrn28lWsYUnPPLp1DGKRtca8bZIzn+tklFa2LMIq\nHrP+I8qS4BS6tEfi0PbVmIFb9cFd6mq0l0JvLYNUXg+p2mxp65f8zsWa3h6C6I9g475NKJKLMCxU\n63kPIcVV+e/XneO2P2R8kmvNtpgYdTiQzOeiCeeiSC7C5AET0j4mZm0vW1bs82Qi7vXSgiDg4gnn\n4uyxp3d7HW86SoqlYnbzEZ8ieoJOQJGg6ZmHzYH0WZV7Pa6U4jF2Zt3Y0vlFdjr2h36uM9RUpG4U\nZ6VjX8wV+/v+fXXGezHoyox9nS8VS9aF2J2x8U22MHhnQTiqoqXduyuY0yXMyvTiRlJ3LtVbgOEE\nb8Xa9UtUAV1GLK5Bttbl2vPjih28NRkCBNQ1heEXzMBoN2Zpj4dRJAcgSaK5VCfuh2EIaFWb8ceP\nVgEAtP2DPB9WA4OuTRosEVenOMOqKPcrIiL2Bh3OnLcVOBUVMdnMumvh3bIwOXifNOxYs2DKZjVZ\nEcvMPa6T23C650EXTVqAG2dcA3XXqMTx6e75bxHlYiLrtT+c7WAfMMzXtocYJUl0/qO3RVRA9SH+\n3ymIb0/sya3IImCIiG01h6ztna38kh/2SgB7HXZYi2B8+egOGzKU+q3aASXmbCBiaHLK5TElUuLi\npegAMp+ja2bivuN+2GlbSltxkeL5MEquDp9dMxNzhx/f5Z/dZUnLxlQtfebt90nQNHPY3C4sTCdd\n5uS+v7Pg3dDazeCdIvPuK+5h855edOVL5q2kec+pMu/OCtbc2/Wm09MLou5i8M6C6x96Ddc/9Jrn\nPrt6Fk7w9gZ3Q9CgR4LmGlZRcTbZ8CsSSopkCJKOPfti5iYMgt061GroIgnOhhRtERXhqIZia3/h\npqg51xpWwwhamar5QWj2zd7eth0fNXyKQfIIaPW1GDwgEVA9OyxZhg8yX3dAqd8pShs62Oc0QnEy\nb6dtpwpVtIbsFe/8kXsI8VuHXdlh/99xQ8xWlfb/PfcmAYD3P+DQkhqMLR8FGCnmvGF1RLOqdodV\nlzgZROyzqYhvH4ehhrk8ys4AZVFwisbs4VmtvhbaPnP4fNSQUqdHtxEtcvrFA4Bf9jkdLY32xEhA\nqsy3LJAI3s7fhC55AtL/G30Kjhx8OAJS9pbq2PPcQyqDnkrsQTnaaMHdrEiWBOiG4azEUBTJO2yu\nmHPebRHVHJXqJCBJXVjDW2o1jXFn92NqzIu5IQO6N8ztVyQU+eU++2B3S3Vx0l2DrL+T0mJfhkf2\nrXS/d/v3ka63eTK7F7zYyd9YV9eJ97b+ffmUp+xCG8MwnA+WRPA2/9XQMXhDC0DdNQbV46LYGfkC\nsLbLnDV5IF54C04wcipXreCtSCIgxWFEfE5L1XJ/KRoANESbMArmnLddLFUe8mPP/nZz/tgXRZEc\nwHdnX463yxow3bUcZXz5mMTxqTK+MuJknHTUZLz7aT3iqo6nt5hrzysrJWyPxc1kU1NwzNQhEMsF\n/CeyCZKswfBFYBhCh0zbfXt8xRgIgoAbzp0Gnyxi9/52zJo4CDe/9ifnfert3jluWRKwaOL52NL4\neYcLjbISH5paDRiGGfyrS8px+lEjURr0YfaUwSgOKLj27EPx8z9thLprDIQae7g8kXnbRU32nuTj\nhpbhzGNGYV9TBDPGV+FnqzbA3GFcgN40EKK1I5tn2Nx1wTFj0FQkq7CaqMiDvsCGfebPOe+YyZ6i\no1NGmu1qX3rjC+e+AaW9u2/wLRdMx0dfNGDK6AGIxTV8/YQxB1Qo1VPupZR2sLH/litCfs8oix3I\nW9oT+5ink/yhe+uF0xGNe7OpMTVluPTUCc4GGwBw8hFD4fdJmDHeu/NWV10wd5wzrN/XejN4H35I\nFRbOH4+jJg/utdfMhuRs+vuLDkdTa+Iz1/130dn5GTE4hEtPnYBDhqcfteqrCzQG7yxStcSmCppm\nz3mbHxya4A3eEPREYxIkdtzy+yRnq0l7DbOdeTt7RUsCdCEOQwtiZ521UUdxJT4PA/sjDeZuZLrq\nZN5V5QEzeFsXEjXFQ1CsFOG4ad4sa3jpUAwOVmN3+15EP5yFY4+ag1DQh+Om1WD9B4lK7WDQgICw\ntbm9gLOPHY09cRn/ec8cNheUKBD3IVDi/aAt9lQrm+996hgzCE8YYfX/1v3QxXarUKxjRe+RVYfj\nyCGHdzj35SV+8z+rIQCCgepQGfyKhLlHJPp6H35IFYJ+Ge1R1cmU7Yst93CsnXkfOnoAJo9MVH+7\nq5zj28cDhoihlRVQRNmpcTDCJdAjRThhzHTPHL1znEWJC5KdrealwNETRnV4HODNEFJ1EOuJytIA\njp4yBIBZiX3aUSMyPKN3uTNve5jX3rSnqjzgyYrt77dFVFRXdF44l5xVpbsYOW5ajee2KAgd7jsQ\n44ZmnqLIld4M3pIo4sQZQzM/sI8lz0PbIympZJpKyPR30JWitmzo+zGdg5j7Cl/Tra+tYXNnj2Xz\nljlfavfztpc7WTtuOXt0W8HSZ+/HbDdOkTSn4GxHnVn1XVtqZtANkUa0W/PRRdY+t1X2jkuy+bo1\nJemvom8+4jpEP5wJI1zq3bQ+oCSGyJUwRH8EmpUZK7LobK0nyCoEXxRGzO/Zlxfo2s5DJa3mHLPe\n2DED6uxDKRS06wLMoFDiSz386QzJG4bntuyZ844797l55v00H+LbJmGYPsN6Qet+Q0R0w3E4b/xZ\nKX9+kc/nFA8CZuFdukpud/FVLqtac8Gdedvnede+xI5x7mFz9+890/RBbwaufCX1g6H7XMvlUHa6\nTaeyrfB+qzlkL7sCEsPmguBu2WmxN9+wMm87wxZE1QreiblQAAiIAc9rGFYWb2gydlrBe0SlOV+8\nP9Jo7kcNIGgFVLvNYeyzQzEiNAynjpyb9j0E5IDTKtL9HyIYkJ3g/Z+ItZtYuGPwhq/dXI8eD3To\nhdyV4F0ePgSRjXMQ+++UDt/r7EMped/idEt07Kvu5P9/dntUIJF5JweCVEU79lW49+VStwwFzGVP\n0Q+OcgJ4mb/jDlm2rhTP5Bv7nRquM2af50TmXeQ59+7fe6bCqT76XO1XCvECJpdD2U5ilmMcNu9F\numF45lI8mbfmLVhzb7cJwargtoKz0SHztrqLWXPefsneDMMM3ppgXQioMnZY2/UNG1AJn6hgQ/1m\nJxjYa4ZDRebws948ELfMPK/L7y8580bS7kh24xdFFlFkWM1gfFZns5i/Q1WwLMo4YegcDAqmn1eU\nZbFDoZqts0KT5Cvv9MHb/Dd52FwUEsHbnitLfs1Uy4DsD8p0u8h1PE4RRqzIXG5WuRdiJzsu2Mv4\netJoo78RBAGGYSRl3lbw3tcOvyIhFFSSNjFxZ96dz3l39fdwMGPwzi7nsz3HDp5PgT62e387rrjn\nn/j7267+1/HEsqrkgjVB6ph5G9awuW7vNuWLwO9zZ97mtZZTdWy9hu7KvJvaYigt9iHgk+GTzCD9\nft0mAMCospEA9OeBHQAAIABJREFUgPJQ9ypFk5tcOHtuW/RwCLIkmPv/Wpm37rOat8T9Kfv+njv+\nLBw39Oi0P7OzZRzJnbHcyoq9c8IBOXWBlz13XGQdWyITTPS/brcadSRn+ikzbyl1Jp+JfcEW19MX\nOdmvOXxQ560h84m9JMtd4S675rQHlgc6jES4f++ZMu+DbXqhOwrxHOTygqWvLqaZefeStz7aCwBY\n+ddPnPvcm444QyuiO/M2AAiJPautYFgSrwX8m6AM+wQ++WRnztvOvINyEIh3zLxrKspR7h/gVMi2\nurbpBIDRZWYR0qSRlTh99ghMH9e1Stprzz4UX9a3ej4EKkJ+nHroNLypbcWYkrF4+8P9MNpKofgT\nFfHmkjd7NzJft1oHdjZ3lep7t1wwHe9vrcfXjh8DUQT+ZT9WTP2zrz17Cl5Y/wVOn20VaLlesqqi\nCKOGhPD5LnP0IPkDIdV8qzNsbkVaWRJxwdxxad8DAFx51hSsa9iBrZFdnuHjZGcdMwqqpuOsY1IX\ntOWj75w/Df/3n+04+fBEEdSx02rQ0h6Hbhg4ekqiHuPCueMQ8MnY+mWip26m4D24MohTjhyOKaMq\nO33cwcyvSDhzzshO29MebIr8svmeh6R/zxfPH9/lTUk6M2N8FU6cXotjpw3p8WsdCAbvA7S18b+o\nC9fjqCHezQXswCYEWiGW1UPbMwLRlJm3GagF0XA2FnGG0q3g7YsOQoV/KBqKdwBSvMOcd5EcgBAX\nnNakLaq5DehXpo/F7JpEj+sLJ3wNL3z+NzRGm5znAeaQ8NeOTywDy+TwQ6pw+CEdA/3Xjp6Mq6uO\nwsdb67B+rbkft/sqNCgH0BSzh/SVbgZvMem22XMaSD1sPmFEhVOpfv5J4/Avc5dMKGLq4dXqiiAu\nPTWx/lpAYthbFARcd85UfPfnr6c8lmCKLlPJAf74aTU4cXpth8e5nXncGEzYfiZ+uWEfFhxydtrH\nBQMyFn7lkLTfz0dDBhTjklO869/H1pbh21/vuKzOXimwbXeLc1+mYXNBEHDeiWN74Ujz21ePHZ35\nQQeZTO/5pF6qmpclsU/+XzJ4H6AH3vkFAGDW4Bmebln2XHdgqtmcJdJa4Q3emrdgDYCZfetyIhu3\nhs1jcQ2KHgREQBOirszbqjZXJBTFi9BqBe9P2z+CKIiYPND7ITin5kjMqTkSb+95DwNTNFzpLe4P\nUPeSnqASRFMssZtXd7bLSw6YPlmCqplDy501TrAdWzsbr+5ch9HWlEEmiepz89+yksQUQ3Kmn7pg\nzTts3tWGVhWBciyddWPXHlzg3Bdt7o0/iAoJ//K7SdVVZ04ZAJKnWARRSxo2TypYgznsbcQDEKwm\nJPYcciSmQdB9ZvAWI4iq1lIxe523LCIoF6FNaoXgb8OeyJeYWDnes4mF2+GDDuvRe83Ep4hmP2rd\n8GTe7mpyo7uZd9J8kt8nOXPQXWn1eN74s/DVMachIHdvXbS3mYP3WFIOm9tz3vbwd+FNN2ad5ClY\n40cYFSYWrHWTmlRY1CGQGELSsLm9zjuRedubeiSGzc3gFo1rEKyGJHEjirC1TttemuWTRQSVIkCO\nQRpgNvaYOWh6z99UNwmC4HyIuueQPOuVXZttHIjkbPdAd0USBfGAAnfyum+3mKp5bqfaijIx523/\nfEbv3uYtWOvfG2QQZQuDdzfFda3zBwg6/vi3T7Hps30AXJm36FoTaFecJw2bb9vdgi92mvPcUSOM\ntri53tVu0qLIEkqUIATRgFS1A7IgY2rV5J6/qR6wP0QVxTtsbjNUxbOTU1clD5tne79cZ847xfda\n2uOe26kL1rpXbU5dx8ybiMG725Izb7ufucMKyA88/T6AdMPmKgZVFGFghbWHtWvplb0dZtQIO3tx\n25m3IotOYBT9EQwtHppoitJHjpo0CANKAzh8fLVzX1BJtAOVoXSrjaA7eI8cHMKF88b37EAzSZEo\nf+/iGZgwvByzJ3v34lZkEbMmVnsKopKHzZl4976JIyowqDKISSMrUFHau21iifIFL1u7SdW9WVgs\nufuVNY9tf3inK1i7Y9FMPP3uK3izHYAuosgvIRzVnK0129VE8LaboiiyiGI90XQk5Ov7db9nHjMK\nZyYtYXIPm/vl7q0td+/zfPslR2R9swdnnbfr1zRuaDluuXBGx8cKAq4+y+z89vQ/twBwVZsbicdQ\n7xo3tBx3XXVUXx8GUZ9i5t1NquEdNjdbV7rms63MuzJkZsTJvc0Bs1GLTxFRVGT9GgzRqdy2s+zW\nWBva4+3wiT5nWN0niyj3JdYvhtIUqvU1d8FadyrNgUTBmiSaLUaz3Xwh0XGte+Pe9pJBnfVqRJRF\nWc28V6xYgffffx+CIGDp0qWYOjWxdnPlypVYs2YNRFHElClT8P3vfz+bh9LrkofN46ruZNsAnK8H\nWMN6qea8BVmFJIoIBqzgrZutIOubIs6weVu8De1qGAEpALs1hSKLKJMSwbs0zaYbfc09593duWq7\nOMkO2tnecEBI7pd6gOSkJi3MvIkoG7KWeb/55pvYtm0bnnrqKSxfvhzLly93vtfa2opHHnkEK1eu\nxBNPPIGtW7fivffey9ahZEVyG8u4qnn7lVvBu9Rqv6m72qMaqnnNJPniePHzv6MdjQDMOe9ia39i\nSfdBgIDWeBva42FnO0/ALFgr8yeCd1mgf3ZOch9z8qYkXWVn3nZGm+3t9xLD5t2M3slLBhm7iSgL\nsvZJuG7dOsyda+5WNWbMGDQ1NaG11exzrSgKFEVBe3s7VFVFOBxGWVn6/Vb7Ql1jGI+t/djZDjJZ\nqszbvVOY0/LUCgLujUnsrFoYsAN/+XwtXtn5uvVY0Wk6IUkiipUgGqNNiGgRTxaryKI3ePeDOe9U\nZDExsNPtzFtK7K8N5K5Pc7eLxa0n9tU2gURUGLI2bF5fX4/JkxPLlyorK1FXV4eSkhL4/X5ce+21\nmDt3Lvx+P04//XSMGtV5v+aKiiBkuXeXCVVVpZ8rXrHyHWzZ3oiyUABXnNVxO8rikOJ5viCJiXXb\ngJN5y4qEqqoQJFkCYEAQAD3uAwLtHY+nrBhl1hy5IokYXl6DD+o+BQBUliSC9eDqEAJFiYA9fNAg\nVA3su3nvdOfRCNYC7wB6OIjSEn+n5zudygpzSkCWRef5AZ+E8cMruvV6mVx46kT86JE3cN68Q7r1\n+iWhAKqqQrj6nKn45aoNOHXO6C69TjbeS6HhOewdPI89l4tzmLNqc/cwZGtrKx5++GG89NJLKCkp\nwSWXXIKPPvoIEyZMSPv8hoaOwa4nqqpCqKtrSfv9fY1h69/2lI/b19CCOiVxf2tbzDNsPn54CB/s\nBMLhOOrqWtAeiSWK1TQZhi4msnPLVacfin+tM3+uKAoY5B+ED2AGb9lINKNoaQ5DiyZ+dVq72Ol7\nyabOzqMAH04oPh8vvl0HjDO6dYzhtqj1WnCe/7MbjoMgICvveVRVMX57y4kQRaFbr9/cHEZdXQtm\njhuIw7v4Opn+FikznsPewfPYc719DtNdCGRt2Ly6uhr19fXO7b1796KqytzcYuvWrRg2bBgqKyvh\n8/lwxBFHYNOmTdk6lG6xLzbSjdJ2GDbX9KRtPs3MW7NeJ2q0QvBHrBcXALXjdZMsys7wuiyJqA3V\nON9zL7tSJNFTCFWi9M9hcwCo8g0GNB/8Svf+1Ox13u4qc9GqPM+W3hqaL8StGIkoN7IWvOfMmYO1\na9cCADZv3ozq6mqUlJhBpra2Flu3bkUkYgazTZs2YeTIkdk6lG4xUqzTdY8edChYi2uA7B42N7Nq\nOxjvqFqDwNRXrRcSYcQ7NpdQRBmqtaRMEgUMK0kEb/dabrt/+IiQuctSd/t254IdfANK9wZ5ZDm3\nc91ERPkga8PmM2bMwOTJk7FgwQIIgoBly5Zh9erVCIVCmDdvHr7xjW9g0aJFkCQJ06dPxxFHHJH5\nRXPIvdRH1VX88aNVmO3aBlQ1OmbeYlFiqMQQzO87Veae1xZgtJVBLDYff+GEr+HThs9RVTQQmlYH\nwCxYqykZjONqZ0MWZcypmYUnsN45JgD47uGLu70eOVfsgjNfN1qjAole6WKWq8x7C+vUiCgXsjrn\nfdNNN3luu+e0FyxYgAULFmTzx/eIu8nGxvoP8cbut/HG7red76tJvc1jqg6xsinxfGgQBQGaYeCL\nvc3eFzdE6K3lQPUOAImtO4FEm1VZFCAKIs7vZH9nScxun+/eYGfe3a02l6zny3mSeff3iykiOjjk\nRzrTBxKZd+pGG8lz3jEtBiHYAr3NrArXoEIUBei6gR/8YZ33yboIvS310rhZE83+2ccfVpPy+/mm\nImQO6Q8o7V7v9UTm3b+D9xGHmPUcIwf3zzX3RHRwYW/zDARBgF/s2Jc7ntzbXG6EIBjQWiogBJuh\nG6qzx7Wn8xoAASLu/8ZXsOaLCMaVj/Z878hJgzC2tgyVKTZc+NkNxyY6teWJMbVluPvq2RhY1r3g\nbQ+79/fg/c2zJuO8ligGlhVlfjARUQ8xeKfhDJsLqbt6JQ+bq4K5lE2PBiHpkpN5R2Kad/03zD2m\ny0sCWDTp/JQ/e0CaQJevexdXl3c/oLl7m/dnkigycBNRznDYPI3EUjEBmqF3+H6HLUFFa9vOmB/Q\nRWiGBkkUsLehvUPmDZ2nvavsXuH9PfMmIsolRpE03FXDWlKWDXirzQ3DgC5Za7jjPhi6BNWIQxIF\nGAY6ZN727mCUmZ1550vBGhFRLjCKpJEp845riYC8e387oJidwIy4HzBEqIaayBalpODP4N1lSp7M\neRMR5RKjSBruOW/N6Dzz/vmfNkFQYgCAUn8I0BKZNwAIHQrWGIi6yqdIkCUBRT6WZxAR2Ri803A3\nadFTDZu75rwjMRWCEoUiKvjhJbNRO6AUcT0OwT67ycPmev9fn91fyJKI755/GM4/aWxfHwoRUb/B\ndCYDM/NOVbCWCOiabkDyx1DmC6G02I/yYDF2RXRIkpW+JxesaflZNd5XDhle0deHQETUrzDzTkN3\nNWlJOeftWuet6ToMKWoOmQPwS+YabdGa6xaS57w1XjMREVH3MXin47RHFVLPebuGzXUhBggGQtbu\nXgEreAv2RiVi0rC51rHpCxERUVcxeKdhrxRLV7AW01QnOzdEs9K8WDG37fRbu3wJaTNvDpsTEVH3\nMXhnIKYpWPt8dyN+9Zy5B7kmmpXmQTt4S1ZmbQftDpk3h82JiKj7GLy7INWcN0QNb31sbt9pWMG7\nWDaDtzNszsybiIiygME7A90wUg6bu7umGZJZvNZh2FxUARgQ/OGkF2XmTURE3cfgnYFupMm8reBt\nGAYMKXnY3NoRTFQhDdoGsbjZ05iFTVqIiKgnGLwzMAwj5Zy3ORRuQDcMCLKdeZu7StnD5pBUSKX7\nAQCXTlqQk+MlIqKDH4N3BuaweYrMGwAkFf/e+SaU2q0AgGDSnLchqBCKWmDEFVQFB+bkeImI6ODH\n4J2BYaRYKmavAZdUPPnpaufu5DlvTYpADIRhREKQRc5zExFR72DwzkDXOxasybDntL33FyctFYvI\n9eY3wqEO+38TERF1F4N3BobRcT9vUU/MaeuRIud+RTSXgNnD5mFxHwBAiIYwpHgQJMOH+M4xEFiv\nRkREPcDgnYFhGNCT5rwF3cysBUmFEU0Eb8GKyvawuV1ULmh++CQfZukLoe4cl/2DJiKigxqDdwYp\nC9bsJiuSCgjmBPjo8Hzn285SMYtgPV4wmHITEVHPMXin8NH+TyH4zMYqqQrWjLgVjK3gbRgCSvUa\n5/uKKENxFagJujeYExER9QSDd5KWWCt++t5v4J/2CgAz8/7vnibPY1S79kxSIQgGYAiQRG9WHfKF\nnK9FnbuIERFR72HwTtIWbwcAp6hM1XTsaWjzPMauX7MzbxgCxKQzGfKVOF+LOnuZExFR72HwThLT\nY57bcVV35rVtumadNsnsXW4Gb++pLHUFbwHmELoB7+sQERF1B4N3koga8dyOxXVAMAvWoh8dgaJw\nLbR6c37bnXlLSeu/Qkpi2Dz5e0RERD3B4J2kPSl4x1XNybz15gEI7T0aajQAABCUqBO8haQzWepP\nBG8hKXgn3yYiIjoQ7NmZJBz3bt8Zs4bNDQMABLRH44Dqg6EqEAJt1lruFAVrimvOW2SwJiKi3sPM\nO0lY6zhsLgg6YJinqj1ilprr4WIIgTAEUYNhCB0CtGepGDNtIiLqRQzeSbyZt4G4pjtD4wAQjpql\n5kakGIJgQPBFAUPskHlLouR8bX+L5WpERNQbGLyTeDJvwUAsrgGCDlmUMLqmFLo5fg4jXJx4nCFA\nTMquJw+YABgCYtsmdPgeERFRTzB4JwnHXcFbVJ05b1EQ4VcS2bQeDSYel2LYPOQrwYTGi6DtGclh\ncyIi6lUM3knCqmvYXNSdanMR3uAN3fV1ig5rAKwit8SwORERUW9g8E4Sdi0VEyTVWectQoLf5w7Y\n7lPXMfMG4AyxC4zeRETUixi8k3gzbw2abkBwhs1dp8u9Q1iKOW8gEbyd2M2KNSIi6gUM3kncTVoE\n0W5ibgZvn2vY3NATpy7VUjEAmDKyEgBw2NiBnvs5BU5ERD3BJi1JYpqrt7lkB28doiBBkdJn3qnm\nvOfOHIZDhldgWHVJh+8RERF1F4N3kqh7Y5KkzFv2BG9vIE+VeYuCgBGDQx3uJyIi6gkGbxfDMBDX\n4s7txLC5DkmQkoK3O1innvMmIiLKBs55u8R11bttp5TIvCVBhCy5h8q9WXiqYfNkrFcjIqLewODt\nYs93S4JZmCaIGiCqEATAJ/o9mbe7YC3dsHk6zNGJiKgnGLxdolbwDohW9zRRg+Azq89LlNABF6wl\nG1xpvu7omrLeOWAiIipInPN2iVvFakVSEG1aCyCp5sYjAEqVEGQxdcGakWadd7KTZtSiOCBj+riB\nGR9LRESUDoO3i5N5C2aGLEgaBMXMvEt9pZCTsm33110ZNpclEXMOHdJ7B0xERAWJw+YuMavS3O8M\nm6vOsHmZrzT9UrE07VGJiIiyIWPw3rp1ay6Oo1+IWcPmPqMIgJV5W8Pm5f4yyHLP5ryJiIh6Q8bg\n/e1vfxsXXHABVq1ahXA4nOnhec0eNldgBm9zztvMvCuKyrwFa8jc25yIiCgbMs55P//88/jkk0/w\n4osvYuHChZg4cSLOPfdcTJ06NRfHl1N2gxZBV2BoIgRJBXwRGLqIkFKMqBRJ/URD5LA5ERHlTJfm\nvMePH4/rr78eS5YswdatW7F48WJcdNFF+O9//5vlw8stO/OGIQG6DLG4GWKgHdq+wVCUpA5rbhw2\nJyKiHMqYee/cuRN/+tOf8Je//AVjx47F1VdfjWOPPRYbN27EzTffjGeeeSYXx5kT9py3oUkwNAmC\nYt6v7hwHRfL2Nvd0WwOYeRMRUc5kDN4LFy7E17/+dfzhD3/AoEGDnPunTp2aceh8xYoVeP/99yEI\nApYuXep5/K5du/Cd73wH8XgckyZNwp133tmDt9E77A5rhi4CmnlqDEOAEQtAlrztURXZvc5b5Jw3\nERHlTMZh8zVr1mDkyJFO4H7iiSfQ1tYGALj99tvTPu/NN9/Etm3b8NRTT2H58uVYvny55/t33303\nLr/8cjz77LOQJAlffvllT95Hr7CXihmqBEO39u6OKwAESJLgqTZP7rbGYXMiIsqVjMH7e9/7Hurr\n653bkUgEt9xyS8YXXrduHebOnQsAGDNmDJqamtDa2goA0HUdb7/9Nk466SQAwLJly1BTU9OtN9Cb\n7DlvXXNl3roMSTSryd0BO3nZGIfNiYgoVzIG78bGRixatMi5fdlll6G5uTnjC9fX16OiosK5XVlZ\nibq6OgDA/v37UVxcjLvuugsXXHAB7r///u4ce6+z57x1VYSzFEyTnEDtnvNOzrwZvImIKFcyznnH\n43Fs3boVY8aMAQBs2rQJ8Xg8w7M6MgzD8/WePXuwaNEi1NbW4qqrrsLLL7+ME044Ie3zKyqCkGXp\ngH9uZ6qqQp7bwqfmMUqiD7D28jZ0CQFFQlVVCEUlifddFFDgXMIYAgYOKO7weoWiUN93b+I57Dme\nw97B89hzuTiHGYP39773PSxevBgtLS3QNA2VlZW49957M75wdXW1Z7h97969qKqqAgBUVFSgpqYG\nw4cPBwDMnj0bn376aafBu6GhPePPPBBVVSHU1bV47mtpN39GuN2A4Lf28tYlSKKAuroWxOJa4sGu\nixEYApoa2+EvwOQ71XmkA8Nz2HM8h72D57HnevscprsQyDhsPm3aNKxduxbPP/881q5dixdffLFL\nmfecOXOwdu1aAMDmzZtRXV2NkpISAIAsyxg2bJizTnzz5s0YNWpUV99L1tjV5mocTuYNXXIqy93z\n3JJnqRg7rBERUe5kzLxbW1vx5z//GQ0NDQDMYfRVq1bhtdde6/R5M2bMwOTJk7FgwQIIgoBly5Zh\n9erVCIVCmDdvHpYuXYolS5bAMAyMHz/eKV7rS1E9BlmUoaqAPedtqDJ8VtB2B2jJ9bVhCDBARESU\nGxmD9w033ICamhq89tpr+MpXvoLXX38dP/jBD7r04jfddJPn9oQJE5yvR4wYgSeeeOLAjjbL4loc\nPlFBXNUhfjEdyvCPEd5+CJSqjgMU7gI1QTAYvImIKGcyDptHo1HceeedqK2txa233orHHnsML774\nYi6OLeeiWgw+yYeYqkNRy1C291hA9UNJUSgnJbdKNRi+iYgoNzIG73g8jvb2dui6joaGBpSXl2P7\n9u25OLaci2kx+CQz81ZkCbpuBmR3NzWbuymLJAuoCAVydpxERFTYMg6bn3XWWXj66adx7rnn4rTT\nTkNlZSVGjBiRi2PLuZgeQ7lYigZVQzCgQNV0AHDmvN3cwfvsY0alDPBERETZkDF42wVngLmka9++\nfZg4cWLWDyzXDMNATIvDJ/kQ13Qosoj2iAogc+bNGW8iIsqljOmiu7vaoEGDMGnSJCeYH0xUXYUB\nwwzeqg6fLELVzcxbSbEVqOgJ3kRERLmTMfOeOHEifvKTn2D69OlQFMW5f/bs2Vk9sFyLWq1RFVGB\nqhlQZBGaZs15KykK1kT3rmIM30RElDsZg/eHH34IAHjrrbec+wRBOOiCt92gRbY28VZkyZnzTpV5\ne4fN9RwcIRERkSlj8H788cdzcRx9zt4ONBG8RSd4y3IiUFeVB1DXGPEOmzPzJiKiHMoYvC+88MKU\nc9wrV67MygH1leTM2yeLUK1hc9k1RL78yqMQi2t4+p9bnftYsEZERLnUpQ5rtng8jvXr1yMYDGb1\noPqCvZe3CHN+293HXHb1MZcl0dkaVI8UQQyEUSQX5fBIiYio0GUM3rNmzfLcnjNnDq688sqsHVBf\nienmsLmExLC5rUM3Nfs5Hx+B4NCdOPb4g2v+n4iI+reMwTu5m9quXbvw+eefZ+2A+krMybzNU+Ju\nzCKLqZbGGTCixZD3TIFPUlJ8n4iIKDsyBu9LLrnE+VoQBJSUlOC6667L6kH1BSd4GzIA3ZN5y510\nTzv4VrwTEVF/lzF4/+Mf/4Cu6xCtoq14PO5Z732wiOnu4B3zbEYipxk2JyIi6gsZo9LatWuxePFi\n5/ZFF12El156KasH1RfsgjUY5ilxr+2WUgybc3UYERH1lYzB+9FHH8WPf/xj5/bvfvc7PProo1k9\nqL4Qt9Z5Q7fmvBV3wVr6wfGDsVUsERH1bxmDt2EYCIVCzu2SkpKDMmBFtCgAQLCCtzvzdq/ztjHx\nJiKivpJxznvKlCm44YYbMGvWLBiGgVdffRVTpkzJxbHllB287cxbUdzrvDnnTURE/UfG4H3bbbdh\nzZo12LBhAwRBwJlnnolTTjklF8eWU1HVCt6anXm7C9YOvpEGIiLKXxmDdzgchqIouP322wEATzzx\nBMLhMIqLi7N+cLlkZ96GZgZt91KxUNDX4fFVZQEAQG3VwXUeiIio/8s4Hnzrrbeivr7euR2JRHDL\nLbdk9aD6gp1566oZvH2yiOVXHolLTjkEIwaHOjz+lCOH44KTx+HKMybl9DiJiIgyBu/GxkYsWrTI\nuX3ZZZehubk5qwfVFyJaBIqoQNPM24osYsiAYhx/WG3KxyuyhHkzh6XMyomIiLIpY/COx+PYujWx\ng9bGjRsRj8ezelB9IaJFEZD8iKvWHt6ddFUjIiLqSxnnvL/3ve9h8eLFaGlpga7rqKiowL333puL\nY8upqBpFQPYjxuBNRET9XMYINW3aNKxduxarVq3CkiVLUF1djWuuuSYXx5ZTyZm3z9UelYiIqD/J\nmHm/9957WL16NV544QXouo4f/ehHmD9/fi6OLWd0Q0dUi8Ev+xHXmHkTEVH/ljZC/eY3v8Fpp52G\nG2+8EZWVlVi1ahWGDx+O008//aDbmMTuax6Q/IjHzYo1Bm8iIuqv0mbeDz74IMaOHYs77rgDRx11\nFICDt4931FrjHZADaGPmTURE/Vza4P3yyy/jT3/6E5YtWwZd13H22WcflFXmABCx1nj7JbNgTRBS\n7yRGRETUH6RNL6uqqnDVVVdh7dq1WLFiBb744gvs3LkTV199NV555ZVcHmPWOZm3VbDmk6WDdpSB\niIjyX5fGhmfOnIm7774br776Kk444QT8/Oc/z/Zx5VRYjQCAWbCm6hwyJyKifu2AolRJSQkWLFiA\np59+OlvH0ye8mbfG4E1ERP0aoxSA9ngYAFAkB9DcHkdx4OCqpiciooMLgzeAdtUM3qLuRzSmoao8\n0MdHRERElB6DN4D2eDsAIBoxT0dVeVFfHg4REVGnGLwBtFmZd6SNwZuIiPo/Bm8kMu+WVvM2gzcR\nEfVnDN5IzHk3NZnd1TjnTURE/RmDN4C2eDsUUUFLmxm8K0L+Pj4iIiKi9Bi8YQ6bFytBRGPmpiQ+\nhduBEhFR/8XgDXPYPCgXIRLX4FNEiGyNSkRE/VjBB2/d0BFWIwgqRYjFNfiZdRMRUT9X8ME7rEZg\nwECxHEQ6NoGoAAAYmElEQVSUwZuIiPIAg7ddad6so6E5Cr+PwZuIiPq3gg/eMc3co3zL9jYYADNv\nIiLq9wo+eMd1M3gbunkqGLyJiKi/Y/DWVfMLBm8iIsoTDN5W5g3dDNo+peBPCRER9XMFH6ni1pw3\nDPNUBFiwRkRE/RyDtzPnbWfeDN5ERNS/MXhzzpuIiPIMg7cz583gTURE+YHBW/MOmzN4ExFRf5fV\n4L1ixQqcf/75WLBgATZs2JDyMffffz8WLlyYzcPolDNsbhWsiSI3JSEiov4ta8H7zTffxLZt2/DU\nU09h+fLlWL58eYfHbNmyBf/5z3+ydQhdkrxUTNP0PjwaIiKizLIWvNetW4e5c+cCAMaMGYOmpia0\ntrZ6HnP33XfjxhtvzNYhdEksqcOapht9eThEREQZZS1419fXo6KiwrldWVmJuro65/bq1asxa9Ys\n1NbWZusQukR1qs3NzLu4SOnDoyEiIspMztUPMoxERtvY2IjVq1fj0UcfxZ49e7r0/IqKIGS5d4vJ\nqqpCkD63bugi5h85Al89aTwkznsfkKqqUF8fQt7jOew5nsPewfPYc7k4h1kL3tXV1aivr3du7927\nF1VVVQCA9evXY//+/bjooosQi8XwxRdfYMWKFVi6dGna12toaO/V46uqCqGurgXN7ebrGrqEuTNq\nsH9fa4Znkpt9Hqn7eA57juewd/A89lxvn8N0FwJZGzafM2cO1q5dCwDYvHkzqqurUVJSAgA45ZRT\n8MILL+Dpp5/Gz372M0yePLnTwJ1NqqvaXBILfuUcERHlgaxl3jNmzMDkyZOxYMECCIKAZcuWYfXq\n1QiFQpg3b162fuwBi7matMgSh8uJiKj/y+qc90033eS5PWHChA6PGTp0KB5//PFsHkannI1JdAmy\nxMybiIj6v4KPVqquWg1aBBaqERFRXij44B3T4xAMs4qdmTcREeWDgo9WcSt4CwJboxIRUX5g8NZU\nCKw0JyKiPFLwESuuxwFDYqU5ERHlDQZvPW4tEyv4U0FERHmioCOWYRiIaXFAl1lpTkREeaOgg7dq\naDBgwGCDFiIiyiMFHbxjWsz8QpcgcdiciIjyREFHLDt4G5rEYXMiIsobhR28rb7mhsaCNSIiyh8F\nHbHszFtn5k1ERHmkwIM3M28iIso/BR2xYnoi82a1ORER5YvCDt4sWCMiojxU4MHb3stb5FIxIiLK\nGwUdsRLrvGXOeRMRUd4o6IjlLBXTRQ6bExFR3ijs4O3qsMaCNSIiyhcM3gCgSdzPm4iI8kZBR6zE\nsLkERSnoU0FERHmkoCOWe9hcYcEaERHliYKOWFFnqZgEHzNvIiLKEwUdseJWhzWDmTcREeWRgo5Y\nMVfmrchS3x4MERFRFxV08I46c94iFLmgTwUREeWRgo5Yqq5CggRAgI/Bm4iI8kRBRyzVUCEKMgAw\n8yYiorxR0BErrschwpzrZvAmIqJ8UdARK66pruDNgjUiIsoPBR28VUOFYJingJk3ERHli4KOWKqu\nQrAybxasERFRvijoiBXXVQgG57yJiCi/FGzEMgzDzLw5bE5ERHmmYCOWqqvmF8y8iYgozxRsxIpr\nVvDW7cyb1eZERJQfCjd423t5W8PmLFgjIqJ8UbARy868DZ1z3kRElF8KNmLFrMwbmghBACRR6NsD\nIiIi6qKCDd6qlXnrugBFFiEIDN5ERJQfCjZ423t5G5oIRSrY00BERHmoYKOWXbCmaQJ8CivNiYgo\nfxRu8LaHzTWBmTcREeWVgo1acd0O3iIUpWBPAxER5aGCjVpxa85bUwXIYsGeBiIiykMFG7Xcw+ay\nzEpzIiLKH4UbvK2CNV0TITHzJiKiPFKwUcvpbW6IkCVm3kRElD8KN3jbvc11ETKrzYmIKI8UbNSy\nm7TAENkalYiI8krBBm9nP29dhMTMm4iI8kjBRq2Ys6uYBJmZNxER5ZGCDd5x97A5C9aIiCiPyNl8\n8RUrVuD999+HIAhYunQppk6d6nxv/fr1eOCBByCKIkaNGoXly5dDzOGSrYgaNb/QJBasERFRXsla\n1HrzzTexbds2PPXUU1i+fDmWL1/u+f4dd9yBhx56CE8++STa2trw6quvZutQUgqrEQCAocssWCMi\norySteC9bt06zJ07FwAwZswYNDU1obW11fn+6tWrMXjwYABAZWUlGhoasnUoKUXiZvCGJjPzJiKi\nvJK1qFVfX4+KigrndmVlJerq6pzbJSUlAIC9e/fi9ddfx/HHH5+tQ0kprEYhQLCqzZl5ExFR/sjq\nnLebYRgd7tu3bx+uvvpqLFu2zBPoU6moCEKWe2/f7XA8Ap/kRzsElJYEUFUV6rXXLjQ8dz3Hc9hz\nPIe9g+ex53JxDrMWvKurq1FfX+/c3rt3L6qqqpzbra2tuPLKK3HDDTfgmGOOyfh6DQ3tvXp8YTUC\nBQoAIBqNo66upVdfv1BUVYV47nqI57DneA57B89jz/X2OUx3IZC1YfM5c+Zg7dq1AIDNmzejurra\nGSoHgLvvvhuXXHIJjjvuuGwdQqci8QgU0QcALFgjIqK8krXMe8aMGZg8eTIWLFgAQRCwbNkyrF69\nGqFQCMcccwyee+45bNu2Dc8++ywA4IwzzsD555+frcPpIKxGUSGXAgAL1oiIKK9kdc77pptu8tye\nMGGC8/WmTZuy+aM7FddVqLoKQTffPoM3EVHfevnlv+OEE07u0mN/8pP7ce65C1BTU5vlo+q/CjJq\nRa0GLbv2xgBw2JyIqC/t2vUl/va3tV1+/PXXf7egAzeQw2rz/iSimcHb0MzqdS4VIyLqOw88cA8+\n/HAzHn30N9B1HV9+uRO7dn2JBx/8Be66607U1e1FOBzG5ZdfhTlzjsV1112F73znFvzzn39HW1sr\nvvhiG3bu3IFvf/u7mD17jvO6qqpi+fIfdHj+J598hPvvvweiKGDKlGm49trrU95n/5zRo8di1aqn\n0NjYiOnTD8eTT/4v2tvbcd11N+Ldd9/Gyy//HbquY/bsObj11u+ipaUFd955G9ra2lBSUoI77vgf\nXH75Rfj9759AMBjEhg3v4cknV2LFih93+5wVZPCOWsEb9rB5DtuyEhH1Z0//Ywv+89HeXn3NmROq\ncd5JY9N+/4ILFmL16qdx2WVX4pFHHoaqxvGLX/wWDQ37MWvWUTj11DOwc+cO3H77EsyZc6znuXv3\n7sF99z2E9ev/jT//eZUneLe0NKd8/oMP3oebb16KsWPH4Uc/ugO7d+9KeV86W7duwRNPrIbP58O7\n776NX/zitxBFEeeddxauvfabeOKJxzFr1myce+4CPPXUSrzzzls47rgT8dpr/8L8+afgtddewbx5\nX+nROS3I4G33NTc08+0z8yYi6j8mTpwMAAiFSvHhh5uxZs1qCIKI5uamDo+dOvUwAObyZHcXz86e\n/8UX2zB27DgAwO2335n2vnTGjh0Hn89crRQIBHDddVdBkiQ0NjaisbERn3zyEa644hoAwPnnXwQA\nqKmpxW9/+0vMn38K3n33bXzjG1cf+IlxKczgrSU2JQFYsEZEZDvvpLGdZsm5oChmD46//vUlNDc3\n4+c//y2am5txxRULOzxWkhLNu5KbgaV7fqpNsFLdJwiJxE5V1Q7Ht3v3Ljz11Er87ncrEQwGsXDh\nedZrSTAM3fNaY8eOw759+/Dhh5sxatQY+P3+zk9CBgUZtSL2piR25s2CNSKiPiOKIjRN63B/Y2Mj\nhgypgSiKeOWVfyAejx/Q66Z7/siRo7B5s7ni6a677sR///t5yvuKi4uxb5/ZbGzjxvdTvn5FRQWC\nwSA+/vgj7N69G/F4HBMnTsLbb/8HAPDcc6vw4ot/AQCcdNI8PPDAPZg375QDeh+pFGTwthlx88qH\nmTcRUd8ZMWIUPv74Izz00P2e+0844ST8+9+v4vrrr0FRURGqq6vx6KO/6fLrpnv+9dffhJ/97P/D\nNdd8A6FQKUaOHJXyvjPPPAf3338vbr75egwcWNXh9ceNG4+ioiCuueZy/P3v/4ezzjoHP/zhD3Hu\nuRdg06YNuO66q/Dvf7+G448/EQBw8snzsHfvXhx++MyenTAAgpGq6Xg/1Jvt5uJaHNf89hnojdWA\nIeKWC6ZjwojOe6tTamyn2HM8hz3Hc9g7eB57rrNz+Pzza7B79y584xvfPKDXS6Ug57wVSYHeMNi5\nzcybiIiy6Z57/gdffrkTd911X6+8XkEG72SsNiciomy69dbbevX1CjLl1HXvTAEL1oiIKJ8UZPCO\nxr1VjRw2JyKifFKQUSvWIXgz8yYiovxRkME7OfOW2B6ViIjySEFGrWjc2/mGmTcRUd96+eW/H/Bz\n3nvvHTQ07M/C0fR/hRm8Y0mZN+e8iYj6zIFuCWp7/vk1BRu8C3KpWMdhc2beRER9xb0l6PnnX4gV\nK36IlpYWaJqGG264GWPHjsP//u/v8cor/4Qoipgz51hMnDgJr776Mj7//DP8z//ci8GDzd4dfbEN\n6OWXX+VsAxqLReD3F2VlG1A3Bm+w2pyIyLZ6y1/w7t6Nvfqa06sPxTljz0j7ffeWoL///W9x5JFH\n4//9v6/i888/w09+ch8efPAXePLJ/8Vzz70ESZLw3HOrMHPmURg7djy+851bnMAN9M02oOeff6Gz\nDejixVfiZz/7VVa2AXVj8AabtBAR9RcbN25AY2MD1q59AQAQjZobSZ1wwsm44YbFmDfvFMyfn35j\nj77YBrS5uTkn24C6FWTwrgz54ZNF6IYBVTMgCgzeREQAcM7YMzrNkrNNUWTceOPNmDJlquf+m276\nHrZt+y/+8Y+/4lvf+iZ+/es/pHz+wbwNqOfYe+2V8sghwyvw1IrT8fBNJ+DXN5/Q14dDRFTQ3FuC\nTpo0Bf/618sAgM8//wxPPvm/aG1txaOP/gYjRozEZZddiVCoDO3tbSm3Ej2YtwH1nLNefbU8Iksi\nBEHgfDcRUR9zbwn69a+fj507t2Px4itwzz3/g8MOm4GSkhI0NjbgyisX4dvfvhqTJ09BaWkZDjts\nBm677VZ89tlW57X6YhvQ+++/x9kGdOHChVnbBtStILcEBbj1XW/heew5nsOe4znsHTyPPZd8Druz\nDWjy66VSkHPeRERE2dbb24C6MXgTERFlQW9vA+rGCV8iIqI8w+BNRESUZxi8iYiI8gyDNxERUZ5h\n8CYiIsozDN5ERER5hsGbiIgozzB4ExER5Zm8aY9KREREJmbeREREeYbBm4iIKM8weBMREeUZBm8i\nIqI8w+BNRESUZxi8iYiI8kxB7ue9YsUKvP/++xAEAUuXLsXUqVP7+pD6tU8++QSLFy/GpZdeiosv\nvhi7du3CLbfcAk3TUFVVhR//+Mfw+XxYs2YN/vCHP0AURZx33nk499xz+/rQ+417770Xb7/9NlRV\nxTe/+U0ceuihPIcHIBwOY8mSJdi3bx+i0SgWL16MCRMm8Bx2UyQSwRlnnIHFixdj9uzZPI8H4I03\n3sD111+PcePGAQDGjx+PK664Ivfn0Cgwb7zxhnHVVVcZhmEYW7ZsMc4777w+PqL+ra2tzbj44ouN\n2267zXj88ccNwzCMJUuWGC+88IJhGIZx//33GytXrjTa2tqM+fPnG83NzUY4HDZOP/10o6GhoS8P\nvd9Yt26dccUVVxiGYRj79+83jj/+eJ7DA/T8888bv/71rw3DMIwdO3YY8+fP5znsgQceeMA455xz\njFWrVvE8HqD169cb3/rWtzz39cU5LLhh83Xr1mHu3LkAgDFjxqCpqQmtra19fFT9l8/nw29+8xtU\nV1c7973xxhs4+eSTAQAnnngi1q1bh/fffx+HHnooQqEQAoEAZsyYgXfeeaevDrtfmTlzJn7yk58A\nAEpLSxEOh3kOD9Bpp52GK6+8EgCwa9cuDBo0iOewm7Zu3YotW7bghBNOAMD/z72hL85hwQXv+vp6\nVFRUOLcrKytRV1fXh0fUv8myjEAg4LkvHA7D5/MBAAYMGIC6ujrU19ejsrLSeQzPa4IkSQgGgwCA\nZ599FscddxzPYTctWLAAN910E5YuXcpz2E333HMPlixZ4tzmeTxwW7ZswdVXX40LLrgAr7/+ep+c\nw4Kc83Yz2B22R9KdP57Xjv72t7/h2Wefxe9+9zvMnz/fuZ/nsOuefPJJfPjhh7j55ps954fnsGue\ne+45HHbYYRg2bFjK7/M8ZjZy5Ehcd911OPXUU7F9+3YsWrQImqY538/VOSy44F1dXY36+nrn9t69\ne1FVVdWHR5R/gsEgIpEIAoEA9uzZg+rq6pTn9bDDDuvDo+xfXn31VfzqV7/Cb3/7W4RCIZ7DA7Rp\n0yYMGDAAQ4YMwcSJE6FpGoqLi3kOD9DLL7+M7du34+WXX8bu3bvh8/n4t3iABg0ahNNOOw0AMHz4\ncAwcOBAbN27M+TksuGHzOXPmYO3atQCAzZs3o7q6GiUlJX18VPnl6KOPds7h//3f/+HYY4/FtGnT\nsHHjRjQ3N6OtrQ3vvPMOjjjiiD4+0v6hpaUF9957Lx5++GGUl5cD4Dk8UG+99RZ+97vfATCnvtrb\n23kOu+HBBx/EqlWr8PTTT+Pcc8/F4sWLeR4P0Jo1a/DII48AAOrq6rBv3z6cc845OT+HBbmr2H33\n3Ye33noLgiBg2bJlmDBhQl8fUr+1adMm3HPPPdi5cydkWcagQYNw3333YcmSJYhGo6ipqcFdd90F\nRVHw0ksv4ZFHHoEgCLj44otx5pln9vXh9wtPPfUUfvrTn2LUqFHOfXfffTduu+02nsMuikQi+P73\nv49du3YhEonguuuuw5QpU3DrrbfyHHbTT3/6U9TW1uKYY47heTwAra2tuOmmm9Dc3Ix4PI7rrrsO\nEydOzPk5LMjgTURElM8KbticiIgo3zF4ExER5RkGbyIiojzD4E1ERJRnGLyJiIjyTME1aSHKN/fe\ney82btyIaDSKDz74ANOnTwcAfO1rX8NXv/rVLr3Gr3/9a4wfP97pZ53KwoUL8fvf/x6SJPXGYXvs\n2bMHn332GWbPnt3rr01UiLhUjChP7NixAxdeeCH+9a9/9fWhHLA1a9Zg69atuPHGG/v6UIgOCsy8\nifLYT3/6U+zYsQNffvklbr31VkQiEdx3333w+XyIRCJYtmwZJk+ejCVLluDwww/H7Nmzcc011+CY\nY47Bhg0b0NbWhocffhiDBg3CIYccgs2bN+OXv/wlGhsbsXv3bmzbtg1HHnkkbr/9dkSjUdx6663Y\nuXMnBg8eDEmSMGfOHM8exW1tbfjud7+L5uZmqKqKE088EWeccQYefPBBGIaB8vJyXHTRRbjzzjux\nbds2tLW14YwzzsDll1+O1atX469//SsEQcCePXswevRorFixAoqi9OEZJuqfOOdNlOd27NiBxx57\nDFOmTEFjYyN+8IMf4LHHHsOiRYvw8MMPd3j81q1bcc4552DlypWYOHEiXnzxxQ6P+eCDD/DQQw/h\n2WefxerVq9HU1IQ1a9ZAVVU888wzuOOOO/D66693eN6///1vqKqKP/7xj3jyyScRDAZRW1uLs88+\nG2eeeSYuu+wyPPbYY6iursbjjz+OZ555Bs8//zw++ugjAMDGjRv///bu2CW1MIzj+NcONQQRQi3W\nYnBsjDoSBFKNOVaEo0M4REO4HGyrKQin5ob+gDBaoiVyECEipakhWkKkQKFoiERPd5DOzYxLlysX\njvw+4+F5X97tx/PyHh7S6TSHh4eUy2VP3jKI/A/qvEU8bmJiAp/PB8DQ0BC7u7u8vb3x8vLC4OBg\nW73f78c0TQACgQBPT09tNZZlYRgGhmHg9/t5fn7m5uaG6elpAIaHh7Esq23d1NQUe3t7bGxsMDc3\nx8rKCj09rT3CxcUFDw8PXF5eAlCr1bi/v3fXf4xPnZyc5O7uzp2TLCK/KbxFPO7ztbJt22xvbzMz\nM8P5+bk7zOOzrw/Svnv28l2N4zgtQfw1lKE5y/j4+JhiscjZ2RnLy8scHR211PT19bG+vs7CwkLL\n90wmg+M4fzyXiDTp2lyki1QqFUzTpNFocHp6Sq1W69jeY2NjFItFAKrVKldXV201uVyObDaLZVnY\ntk1/fz/VahWfz0e9XgeaXf3HVb3jOOzs7Ljd//X1Na+vr7y/v1MoFBgfH+/Y+UW6iTpvkS6SSCSI\nx+MEAgFWV1exbZuDg4OO7L20tEQ2myUWizE6Oko4HG7r0IPBIKlUiv39fQzDIBKJMDIyQjgcJplM\n0tvby9raGre3t8RiMRqNBvPz8+6o1FAoxObmJqVSCdM0iUQiHTm7SLfRr2Ii8iOPj48UCgWi0SiO\n47C4uMjW1pb73/m/ymQy5PN50ul0R/YT6WbqvEXkRwYGBjg5OXHnE8/OznYsuEXk76jzFhER8Rg9\nWBMREfEYhbeIiIjHKLxFREQ8RuEtIiLiMQpvERERj1F4i4iIeMwvRph4T/csGFUAAAAASUVORK5C\nYII=\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "metadata": { + "id": "HNqUFL4deCsL", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# 4. Case study: building an RNN\n" + ] + }, + { + "metadata": { + "id": "YkC1k4HEQ7rw", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "In this exercise we build and train a model similar to the RNNColorbot model that was used in the main Eager notebook. The model is adapted for converting and training in graph mode." + ] + }, + { + "metadata": { + "id": "7nkPDl5CTCNb", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "To get started, we load the colorbot dataset. The code is identical to that used in the other exercise and its details are unimportant." + ] + }, + { + "metadata": { + "id": "A0uREmVXCQEw", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def parse(line):\n", + " \"\"\"Parses a line from the colors dataset.\n", + " \n", + " Args:\n", + " line: A comma-separated string containing four items:\n", + " color_name, red, green, and blue, representing the name and\n", + " respectively the RGB value of the color, as an integer\n", + " between 0 and 255.\n", + "\n", + " Returns:\n", + " A tuple of three tensors (rgb, chars, length), of shapes: (batch_size, 3),\n", + " (batch_size, max_sequence_length, 256) and respectively (batch_size).\n", + " \"\"\"\n", + " items = tf.string_split([line], \",\").values\n", + " rgb = tf.string_to_number(items[1:], out_type=tf.float32) / 255.0\n", + " color_name = items[0]\n", + " chars = tf.one_hot(tf.decode_raw(color_name, tf.uint8), depth=256)\n", + " length = tf.cast(tf.shape(chars)[0], dtype=tf.int64)\n", + " return rgb, chars, length\n", + "\n", + "\n", + "def maybe_download(filename, work_directory, source_url):\n", + " \"\"\"Downloads the data from source url.\"\"\"\n", + " if not tf.gfile.Exists(work_directory):\n", + " tf.gfile.MakeDirs(work_directory)\n", + " filepath = os.path.join(work_directory, filename)\n", + " if not tf.gfile.Exists(filepath):\n", + " temp_file_name, _ = six.moves.urllib.request.urlretrieve(source_url)\n", + " tf.gfile.Copy(temp_file_name, filepath)\n", + " with tf.gfile.GFile(filepath) as f:\n", + " size = f.size()\n", + " print('Successfully downloaded', filename, size, 'bytes.')\n", + " return filepath\n", + "\n", + "\n", + "def load_dataset(data_dir, url, batch_size, training=True):\n", + " \"\"\"Loads the colors data at path into a tf.PaddedDataset.\"\"\"\n", + " path = maybe_download(os.path.basename(url), data_dir, url)\n", + " dataset = tf.data.TextLineDataset(path)\n", + " dataset = dataset.skip(1)\n", + " dataset = dataset.map(parse)\n", + " dataset = dataset.cache()\n", + " dataset = dataset.repeat()\n", + " if training:\n", + " dataset = dataset.shuffle(buffer_size=3000)\n", + " dataset = dataset.padded_batch(batch_size, padded_shapes=([None], [None, None], []))\n", + " return dataset\n", + "\n", + "\n", + "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv\"\n", + "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv\"\n", + "data_dir = \"tmp/rnn/data\"" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "waZ89t3DTUla", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Next, we set up the RNNColobot model, which is very similar to the one we used in the main exercise.\n", + "\n", + "Autograph doesn't fully support classes yet (but it will soon!), so we'll write the model using simple functions." + ] + }, + { + "metadata": { + "id": "9v8AJouiC44V", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def model_components():\n", + " lower_cell = tf.contrib.rnn.LSTMBlockCell(256)\n", + " lower_cell.build(tf.TensorShape((None, 256)))\n", + " upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n", + " upper_cell.build(tf.TensorShape((None, 256)))\n", + " relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n", + " relu_layer.build(tf.TensorShape((None, 128)))\n", + " return lower_cell, upper_cell, relu_layer\n", + "\n", + "\n", + "def rnn_layer(chars, cell, batch_size, training):\n", + " \"\"\"A simple RNN layer.\n", + " \n", + " Args:\n", + " chars: A Tensor of shape (max_sequence_length, batch_size, input_size)\n", + " cell: An object of type tf.contrib.rnn.LSTMBlockCell\n", + " batch_size: Int, the batch size to use\n", + " training: Boolean, whether the layer is used for training\n", + "\n", + " Returns:\n", + " A Tensor of shape (max_sequence_length, batch_size, output_size).\n", + " \"\"\"\n", + " hidden_outputs = []\n", + " autograph.utils.set_element_type(hidden_outputs, tf.float32)\n", + " state, output = cell.zero_state(batch_size, tf.float32)\n", + " n = tf.shape(chars)[0]\n", + " i = 0\n", + " while i < n:\n", + " ch = chars[i]\n", + " cell_output, (state, output) = cell.call(ch, (state, output))\n", + " hidden_outputs.append(cell_output)\n", + " i += 1\n", + " hidden_outputs = hidden_outputs.stack()\n", + " if training:\n", + " hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n", + " return hidden_outputs\n", + "\n", + "\n", + "def model(inputs, lower_cell, upper_cell, relu_layer, batch_size, training):\n", + " \"\"\"RNNColorbot model.\n", + " \n", + " The model consists of two RNN layers (made by lower_cell and upper_cell),\n", + " followed by a fully connected layer with ReLU activation.\n", + " \n", + " Args:\n", + " inputs: A tuple (chars, length)\n", + " lower_cell: An object of type tf.contrib.rnn.LSTMBlockCell\n", + " upper_cell: An object of type tf.contrib.rnn.LSTMBlockCell\n", + " relu_layer: An object of type tf.layers.Dense\n", + " batch_size: Int, the batch size to use\n", + " training: Boolean, whether the layer is used for training\n", + " \n", + " Returns:\n", + " A Tensor of shape (batch_size, 3) - the model predictions.\n", + " \"\"\"\n", + " (chars, length) = inputs\n", + " chars_time_major = tf.transpose(chars, [1, 0, 2])\n", + " chars_time_major.set_shape((None, batch_size, 256))\n", + "\n", + " hidden_outputs = rnn_layer(chars_time_major, lower_cell, batch_size, training)\n", + " final_outputs = rnn_layer(hidden_outputs, upper_cell, batch_size, training)\n", + "\n", + " # Grab just the end-of-sequence from each output.\n", + " indices = tf.stack([length - 1, range(batch_size)], axis=1)\n", + " sequence_ends = tf.gather_nd(final_outputs, indices)\n", + " return relu_layer(sequence_ends)\n", + "\n", + "def loss_fn(labels, predictions):\n", + " return tf.reduce_mean((predictions - labels) ** 2)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "JjK4gXFvFsf4", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "The train and test functions are also similar to the ones used in the Eager notebook. Since the network requires a fixed batch size, we'll train in a single shot, rather than by epoch." + ] + }, + { + "metadata": { + "id": "ZWQMExk0S6X6", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def train(optimizer, train_data, lower_cell, upper_cell, relu_layer, batch_size, num_steps):\n", + " iterator = train_data.make_one_shot_iterator()\n", + " step = 0\n", + " while step < num_steps:\n", + " labels, chars, sequence_length = iterator.get_next()\n", + " predictions = model((chars, sequence_length), lower_cell, upper_cell, relu_layer, batch_size, training=True)\n", + " loss = loss_fn(labels, predictions)\n", + " optimizer.minimize(loss)\n", + " if step % (num_steps // 10) == 0:\n", + " print('Step', step, 'train loss', loss)\n", + " step += 1\n", + " return step\n", + "\n", + "\n", + "def test(eval_data, lower_cell, upper_cell, relu_layer, batch_size, num_steps):\n", + " total_loss = 0.0\n", + " iterator = eval_data.make_one_shot_iterator()\n", + " step = 0\n", + " while step < num_steps:\n", + " labels, chars, sequence_length = iterator.get_next()\n", + " predictions = model((chars, sequence_length), lower_cell, upper_cell, relu_layer, batch_size, training=False)\n", + " total_loss += loss_fn(labels, predictions)\n", + " step += 1\n", + " print('Test loss', total_loss)\n", + " return total_loss\n", + "\n", + "\n", + "def train_model(train_data, eval_data, batch_size, lower_cell, upper_cell, relu_layer, train_steps):\n", + " optimizer = tf.train.AdamOptimizer(learning_rate=0.01)\n", + "\n", + " train(optimizer, train_data, lower_cell, upper_cell, relu_layer, batch_size, num_steps=tf.constant(train_steps))\n", + " test(eval_data, lower_cell, upper_cell, relu_layer, 50, num_steps=tf.constant(2))\n", + "\n", + " print('Colorbot is ready to generate colors!\\n\\n')\n", + " \n", + " # In graph mode, every op needs to be a dependent of another op.\n", + " # Here, we create a no_op that will drive the execution of all other code in\n", + " # this function. Autograph will add the necessary control dependencies.\n", + " return tf.no_op()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "iopcs5hXG2od", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Finally, we add code to run inference on a single input, which we'll read from the input.\n", + "\n", + "Note the `do_not_convert` annotation that lets us disable conversion for certain functions and run them as a `py_func` instead, so you can still call them from compiled code." + ] + }, + { + "metadata": { + "id": "DyU0wnnAFEYj", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "@autograph.do_not_convert(run_as=autograph.RunMode.PY_FUNC)\n", + "def draw_prediction(color_name, pred):\n", + " pred = pred * 255\n", + " pred = pred.astype(np.uint8)\n", + " plt.axis('off')\n", + " plt.imshow(pred)\n", + " plt.title(color_name)\n", + " plt.show()\n", + "\n", + "\n", + "def inference(color_name, lower_cell, upper_cell, relu_layer):\n", + " _, chars, sequence_length = parse(color_name)\n", + " chars = tf.expand_dims(chars, 0)\n", + " sequence_length = tf.expand_dims(sequence_length, 0)\n", + " pred = model((chars, sequence_length), lower_cell, upper_cell, relu_layer, 1, training=False)\n", + " pred = tf.minimum(pred, 1.0)\n", + " pred = tf.expand_dims(pred, 0)\n", + " draw_prediction(color_name, pred)\n", + " # Create an op that will drive the entire function.\n", + " return tf.no_op()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Nt0Kv5OCHip0", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Finally, we put everything together.\n", + "\n", + "Note that the entire training and testing code is all compiled into a single op (`tf_train_model`) that you only execute once! We also still use a `sess.run` loop for the inference part, because that requires keyboard input." + ] + }, + { + "metadata": { + "id": "-GmWa0GtYWdh", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {} + ], + "base_uri": "https://localhost:8080/", + "height": 668 + }, + "outputId": "61f4af1d-c81e-44db-9079-1a7b8ed8ce58", + "executionInfo": { + "status": "ok", + "timestamp": 1522345877153, + "user_tz": 240, + "elapsed": 75500, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def run_input_loop(sess, inference_ops, color_name_placeholder):\n", + " \"\"\"Helper function that reads from input and calls the inference ops in a loop.\"\"\"\n", + "\n", + " tb = widgets.TabBar([\"RNN Colorbot\"])\n", + " while True:\n", + " with tb.output_to(0):\n", + " try:\n", + " color_name = six.moves.input(\"Give me a color name (or press 'enter' to exit): \")\n", + " except (EOFError, KeyboardInterrupt):\n", + " break\n", + " if not color_name:\n", + " break\n", + " with tb.output_to(0):\n", + " tb.clear_tab()\n", + " sess.run(inference_ops, {color_name_placeholder: color_name})\n", + " plt.show()\n", + "\n", + "with tf.Graph().as_default():\n", + " # Read the data.\n", + " batch_size = 64\n", + " train_data = load_dataset(data_dir, train_url, batch_size)\n", + " eval_data = load_dataset(data_dir, test_url, 50, training=False)\n", + " \n", + " # Create the model components.\n", + " lower_cell, upper_cell, relu_layer = model_components()\n", + " # Create the helper placeholder for inference.\n", + " color_name_placeholder = tf.placeholder(tf.string, shape=())\n", + " \n", + " # Compile the train / test code.\n", + " tf_train_model = autograph.to_graph(train_model)\n", + " train_model_ops = tf_train_model(\n", + " train_data, eval_data, batch_size, lower_cell, upper_cell, relu_layer, train_steps=100)\n", + " \n", + " # Compile the inference code.\n", + " tf_inference = autograph.to_graph(inference)\n", + " inference_ops = tf_inference(color_name_placeholder, lower_cell, upper_cell, relu_layer)\n", + " \n", + " with tf.Session() as sess:\n", + " sess.run(tf.global_variables_initializer())\n", + " \n", + " # Run training and testing.\n", + " sess.run(train_model_ops)\n", + " \n", + " # Run the inference loop.\n", + " run_input_loop(sess, inference_ops, color_name_placeholder)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "('Successfully downloaded', 'train.csv', 28010L, 'bytes.')\n", + "('Successfully downloaded', 'test.csv', 2414L, 'bytes.')\n", + "Step 0 train loss 0.37890616\n", + "Step 10 train loss 0.18515904\n", + "Step 20 train loss 0.0892782\n", + "Step 30 train loss 0.07883155\n", + "Step 40 train loss 0.08585831\n", + "Step 50 train loss 0.09302989\n", + "Step 60 train loss 0.089012615\n", + "Step 70 train loss 0.07275697\n", + "Step 80 train loss 0.06644974\n", + "Step 90 train loss 0.0854013\n", + "Test loss 0.13216865Colorbot is ready to generate colors!\n", + "\n", + "\n", + "\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "
" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b102d936-3379-11e8-ac70-0242ac110002\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"borderColor\": [\"#a7a7a7\"], \"tabNames\": [\"RNN Colorbot\"], \"initialSelection\": 0, \"location\": \"top\", \"contentHeight\": [\"initial\"], \"elementId\": \"id1\"});\n", + "//# sourceURL=js_e223a56194" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b103532a-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_b8c6a821fb" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b105b28c-3379-11e8-ac70-0242ac110002\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_44805e254b" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b106197a-3379-11e8-ac70-0242ac110002\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_a63d3c6c47" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b1069f44-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"b106197a-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_7e203b8bce" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b1070f38-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_d53293d4a7" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c6d90d5c-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"b105b28c-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_3000dc2c05" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c6da872c-3379-11e8-ac70-0242ac110002\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_4136f669a3" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c6dac868-3379-11e8-ac70-0242ac110002\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_2f70dd9aee" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c6db07d8-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c6dac868-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_7226726048" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c6dcc6fe-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_72e7709865" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVQAAAFZCAYAAADHDNdrAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAB9JJREFUeJzt3E1Lle0ax+HTF4jeEAyMBhE0DawI\nwsCH0AIlaGBWNJBo0CDoA0TQhmDXuKAGDioiCA2KlEAlnl05FD9Co8BeaGCQoBDa2jPZsXt4Bvu/\n0+o4Rmvd1zW4rsmP84bFamo0Go0C4H/WvNYHAPhVCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKDy\nUxgeHq5Dhw7V4OBgPXz4sHp7e+vWrVt15cqVOnnyZN2/f78ajUbdvn27+vr6qqenp65du1YrKytV\nVfXhw4e6cOFC9fX1VV9fX01PT1dV1dzcXHV3d9eDBw/q+PHj9ccff9TExMRaXpWfWOtaHwD+zuvX\nr+vOnTs1MTFRbW1tdf78+dW16enpGh8fr/b29hobG6upqal6/Phxbdy4sS5evFgjIyM1NDRUly5d\nqv3799fw8HC9efOmTp8+XVNTU1VV9enTp2pubq5nz57V5ORk3bhxo44dO7ZW1+UnZkJl3Zudna2D\nBw9WR0dHbdiwoQYHB1fX9u7dW+3t7VVV9fLlyxocHKytW7dWa2trnTp1qp4/f16Li4s1MzNT586d\nq6qqXbt21YEDB1an1OXl5Tpx4kRVVe3Zs6fevXv3Yy/IL8OEyrr3+fPnamtrW/2+ffv21c//+Xxh\nYaHu3r1bjx49qqqqlZWVam9vr4WFhWo0GnXmzJnVvYuLi9XV1VVVVS0tLbVp06aqqmpubq6vX7/+\nX+/Dr0tQWfe2bNlSi4uLq98/fvz43X0dHR3V29tbQ0ND3zxfXl6ulpaWevLkSW3evPmbtbm5ufyB\n+W155Wfd6+zsrJmZmZqfn68vX77U2NjYd/cdOXKkxsfHa2lpqaqqRkdH6+nTp9Xa2lqHDx+u0dHR\nqqpaWlqqy5cv1/v373/YHfg9CCrrXmdnZw0MDNTAwECdPXu2enp6vrvv6NGj1dPTUwMDA9Xf318v\nXryo7u7uqqq6evVqzc7OVn9/fw0MDNTOnTtrx44dP/Ia/Aaa/B8qP4NGo1FNTU1VVfXq1au6efPm\nX06qsFZMqKx78/Pz1dXVVW/fvq1Go1GTk5O1b9++tT4W/BcTKj+FkZGRunfvXjU1NdXu3bvr+vXr\ntW3btrU+FnxDUAFCvPIDhAgqQMi6+WH/kX8eXesjAPytf/3jz79cM6EChAgqQIigAoQIKkCIoAKE\nCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQI\nKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgq\nQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpA\niKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCI\noAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIig\nAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAC\nhAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKE\nCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQI\nKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgq\nQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpA\niKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCI\noAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIig\nAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAC\nhAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKE\nCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQI\nKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgq\nQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpA\niKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkBI\nU6PRaKz1IQB+BSZUgBBBBQgRVIAQQQUIEVSAEEEFCBFUgBBBBQgRVIAQQQUIEVSAEEEFCBFUgBBB\nBQgRVIAQQQUIEVSAEEEFCBFUgBBBBQgRVIAQQQUIEVSAkH8D1Aj8lNhhe7QAAAAASUVORK5CYII=\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c70592aa-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c6da872c-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_25c3aaf79a" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c70842c0-3379-11e8-ac70-0242ac110002\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_984c56b816" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c708dec4-3379-11e8-ac70-0242ac110002\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_e0451a1217" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c7092726-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c708dec4-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_7aa23d7385" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c7099044-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_5722756ddb" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "stream", + "text": [ + "Give me a color name (or press 'enter' to exit): \n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c7baac12-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c70842c0-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_cdd622e58f" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + } + ] + }, + { + "metadata": { + "id": "AHJ2c47U-A5W", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Where do we go next?\n", + "\n", + "Autograph is available in tensorflow.contrib, but it's still in its early stages. We're excited about the possibilities it brings — write your machine learning code in the flexible Eager style, but still enjoy all the benefits that come with running in graph mode. A beta version will be available soon -- stay tuned!" + ] + } + ] +} diff --git a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..324b23c24b5a7970d7f20ed955839ba1cf1774fc --- /dev/null +++ b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb @@ -0,0 +1,1078 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "LqNpENf-ec0X", + "slideshow": { + "slide_type": "slide" + } + }, + "outputs": [], + "source": [ + "!pip install -U tf-nightly" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "Pa2qpEmoVOGe", + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "import os\n", + "import time\n", + "\n", + "import tensorflow as tf\n", + "from tensorflow.contrib import autograph\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import six\n", + "\n", + "from google.colab import widgets" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "HNqUFL4deCsL", + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "# Case study: training a custom RNN, using Keras and Estimators\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "YkC1k4HEQ7rw", + "slideshow": { + "slide_type": "-" + } + }, + "source": [ + "In this section, we show how you can use AutoGraph to build RNNColorbot, an RNN that takes as input names of colors and predicts their corresponding RGB tuples. The model will be trained by a [custom Estimator](https://www.tensorflow.org/get_started/custom_estimators)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7nkPDl5CTCNb", + "slideshow": { + "slide_type": "-" + } + }, + "source": [ + "To get started, set up the dataset. The following cells defines methods that download and format the data needed for RNNColorbot; the details aren't important (read them in the privacy of your own home if you so wish), but make sure to run the cells before proceeding." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "A0uREmVXCQEw", + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "def parse(line):\n", + " \"\"\"Parses a line from the colors dataset.\"\"\"\n", + " items = tf.string_split([line], \",\").values\n", + " rgb = tf.string_to_number(items[1:], out_type=tf.float32) / 255.0\n", + " color_name = items[0]\n", + " chars = tf.one_hot(tf.decode_raw(color_name, tf.uint8), depth=256)\n", + " length = tf.cast(tf.shape(chars)[0], dtype=tf.int64)\n", + " return rgb, chars, length\n", + "\n", + "\n", + "def set_static_batch_shape(batch_size):\n", + " def apply(rgb, chars, length):\n", + " rgb.set_shape((batch_size, None))\n", + " chars.set_shape((batch_size, None, 256))\n", + " length.set_shape((batch_size,))\n", + " return rgb, chars, length\n", + " return apply\n", + "\n", + "\n", + "def load_dataset(data_dir, url, batch_size, training=True):\n", + " \"\"\"Loads the colors data at path into a tf.PaddedDataset.\"\"\"\n", + " path = tf.keras.utils.get_file(os.path.basename(url), url, cache_dir=data_dir)\n", + " dataset = tf.data.TextLineDataset(path)\n", + " dataset = dataset.skip(1)\n", + " dataset = dataset.map(parse)\n", + " dataset = dataset.cache()\n", + " dataset = dataset.repeat()\n", + " if training:\n", + " dataset = dataset.shuffle(buffer_size=3000)\n", + " dataset = dataset.padded_batch(\n", + " batch_size, padded_shapes=((None,), (None, 256), ()))\n", + " # To simplify the model code, we statically set as many of the shapes that we\n", + " # know.\n", + " dataset = dataset.map(set_static_batch_shape(batch_size))\n", + " return dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "waZ89t3DTUla", + "slideshow": { + "slide_type": "-" + } + }, + "source": [ + "To show the use of control flow, we write the RNN loop by hand, rather than using a pre-built RNN model.\n", + "\n", + "Note how we write the model code in Eager style, with regular `if` and `while` statements. Then, we annotate the functions with `@autograph.convert` to have them automatically compiled to run in graph mode.\n", + "We use Keras to define the model, and we will train it using Estimators." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "9v8AJouiC44V", + "slideshow": { + "slide_type": "slide" + } + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "class RnnColorbot(tf.keras.Model):\n", + " \"\"\"RNN Colorbot model.\"\"\"\n", + "\n", + " def __init__(self):\n", + " super(RnnColorbot, self).__init__()\n", + " self.lower_cell = tf.contrib.rnn.LSTMBlockCell(256)\n", + " self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n", + " self.relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n", + "\n", + "\n", + " def _rnn_layer(self, chars, cell, batch_size, training):\n", + " \"\"\"A single RNN layer.\n", + "\n", + " Args:\n", + " chars: A Tensor of shape (max_sequence_length, batch_size, input_size)\n", + " cell: An object of type tf.contrib.rnn.LSTMBlockCell\n", + " batch_size: Int, the batch size to use\n", + " training: Boolean, whether the layer is used for training\n", + "\n", + " Returns:\n", + " A Tensor of shape (max_sequence_length, batch_size, output_size).\n", + " \"\"\"\n", + " hidden_outputs = []\n", + " autograph.utils.set_element_type(hidden_outputs, tf.float32)\n", + " state, output = cell.zero_state(batch_size, tf.float32)\n", + " for ch in chars:\n", + " cell_output, (state, output) = cell.call(ch, (state, output))\n", + " hidden_outputs.append(cell_output)\n", + " hidden_outputs = hidden_outputs.stack()\n", + " if training:\n", + " hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n", + " return hidden_outputs\n", + "\n", + " def build(self, _):\n", + " \"\"\"Creates the model variables. See keras.Model.build().\"\"\"\n", + " self.lower_cell.build(tf.TensorShape((None, 256)))\n", + " self.upper_cell.build(tf.TensorShape((None, 256)))\n", + " self.relu_layer.build(tf.TensorShape((None, 128))) \n", + " self.built = True\n", + "\n", + "\n", + " def call(self, inputs, training=False):\n", + " \"\"\"The RNN model code. Uses Eager and \n", + "\n", + " The model consists of two RNN layers (made by lower_cell and upper_cell),\n", + " followed by a fully connected layer with ReLU activation.\n", + "\n", + " Args:\n", + " inputs: A tuple (chars, length)\n", + " training: Boolean, whether the layer is used for training\n", + "\n", + " Returns:\n", + " A Tensor of shape (batch_size, 3) - the model predictions.\n", + " \"\"\"\n", + " chars, length = inputs\n", + " batch_size = chars.shape[0]\n", + " seq = tf.transpose(chars, (1, 0, 2))\n", + "\n", + " seq = self._rnn_layer(seq, self.lower_cell, batch_size, training)\n", + " seq = self._rnn_layer(seq, self.upper_cell, batch_size, training)\n", + "\n", + " # Grab just the end-of-sequence from each output.\n", + " indices = tf.stack([length - 1, range(batch_size)], axis=1)\n", + " sequence_ends = tf.gather_nd(seq, indices)\n", + " return self.relu_layer(sequence_ends)\n", + "\n", + "@autograph.convert()\n", + "def loss_fn(labels, predictions):\n", + " return tf.reduce_mean((predictions - labels) ** 2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "JjK4gXFvFsf4", + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "We will now create the model function for the custom Estimator.\n", + "\n", + "In the model function, we simply use the model class we defined above - that's it!" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "-yso_Nx23Gy1", + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "def model_fn(features, labels, mode, params):\n", + " \"\"\"Estimator model function.\"\"\"\n", + " chars = features['chars']\n", + " sequence_length = features['sequence_length']\n", + " inputs = (chars, sequence_length)\n", + "\n", + " # Create the model. Simply using the AutoGraph-ed class just works!\n", + " colorbot = RnnColorbot()\n", + " colorbot.build(None)\n", + "\n", + " if mode == tf.estimator.ModeKeys.TRAIN:\n", + " predictions = colorbot(inputs, training=True)\n", + " loss = loss_fn(labels, predictions)\n", + "\n", + " learning_rate = params['learning_rate']\n", + " optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n", + " global_step = tf.train.get_global_step()\n", + " train_op = optimizer.minimize(loss, global_step=global_step)\n", + " return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)\n", + "\n", + " elif mode == tf.estimator.ModeKeys.EVAL:\n", + " predictions = colorbot(inputs)\n", + " loss = loss_fn(labels, predictions)\n", + "\n", + " return tf.estimator.EstimatorSpec(mode, loss=loss)\n", + "\n", + " elif mode == tf.estimator.ModeKeys.PREDICT:\n", + " predictions = colorbot(inputs)\n", + "\n", + " predictions = tf.minimum(predictions, 1.0)\n", + " return tf.estimator.EstimatorSpec(mode, predictions=predictions)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "HOQfoBnHC9CP", + "slideshow": { + "slide_type": "-" + } + }, + "source": [ + "We'll create an input function that will feed our training and eval data." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "FJZlx7yG2MP0", + "slideshow": { + "slide_type": "slide" + } + }, + "outputs": [], + "source": [ + "def input_fn(data_dir, data_url, params, training=True):\n", + " \"\"\"An input function for training\"\"\"\n", + " batch_size = params['batch_size']\n", + " \n", + " # load_dataset defined above\n", + " dataset = load_dataset(data_dir, data_url, batch_size, training=training)\n", + "\n", + " # Package the pipeline end in a format suitable for the estimator.\n", + " labels, chars, sequence_length = dataset.make_one_shot_iterator().get_next()\n", + " features = {\n", + " 'chars': chars,\n", + " 'sequence_length': sequence_length\n", + " }\n", + "\n", + " return features, labels" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qsvv-lzbDqXd", + "slideshow": { + "slide_type": "-" + } + }, + "source": [ + "We now have everything in place to build our custom estimator and use it for training and eval!" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 35 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 10604, + "status": "ok", + "timestamp": 1524095272039, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "2pg1AfbxBJQq", + "outputId": "9c924b4f-06e1-4538-976c-a3e1ddac5660", + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Eval loss at step 100: 0.0674834\n" + ] + } + ], + "source": [ + "params = {\n", + " 'batch_size': 64,\n", + " 'learning_rate': 0.01,\n", + "}\n", + "\n", + "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv\"\n", + "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv\"\n", + "data_dir = \"tmp/rnn/data\"\n", + "\n", + "regressor = tf.estimator.Estimator(\n", + " model_fn=model_fn,\n", + " params=params)\n", + "\n", + "regressor.train(\n", + " input_fn=lambda: input_fn(data_dir, train_url, params),\n", + " steps=100)\n", + "eval_results = regressor.evaluate(\n", + " input_fn=lambda: input_fn(data_dir, test_url, params, training=False),\n", + " steps=2\n", + ")\n", + "\n", + "print('Eval loss at step %d: %s' % (eval_results['global_step'], eval_results['loss']))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "zG1YAjB_cUnQ", + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "And here's the same estimator used for inference." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 343 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 7990, + "status": "ok", + "timestamp": 1524095280105, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "dxHex2tUN_10", + "outputId": "2b889e5a-b9ed-4645-bf03-d98f26c72101", + "slideshow": { + "slide_type": "slide" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "\u003clink rel=stylesheet type=text/css href='/nbextensions/google.colab/tabbar.css'\u003e\u003c/link\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f3f36aa6cd0\u003e" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\u003cscript src='/nbextensions/google.colab/tabbar_main.min.js'\u003e\u003c/script\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f3eca67f7d0\u003e" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\u003cdiv id=\"id1\"\u003e\u003c/div\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f3eca67f8d0\u003e" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"e8ddfa22-4362-11e8-91ec-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"elementId\": \"id1\", \"borderColor\": [\"#a7a7a7\"], \"contentHeight\": [\"initial\"], \"tabNames\": [\"RNN Colorbot\"], \"location\": \"top\", \"initialSelection\": 0});\n", + "//# sourceURL=js_71b9087b6d" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3eca67f950\u003e" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"e8ddfa23-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_e390445f33" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_241dd76d85" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_60c64e3d50" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"e8ddfa26-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_14ea437cbd" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"e8ddfa27-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_09294c2226" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3eca67fcd0\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ec965514-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_e5e8266997" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_07a097f0ee" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3eca67fc90\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_790d669ca8" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3eca67f8d0\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ec965517-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_d30df771f0" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ec965518-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_8a43a2da4b" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACMBJREFUeJzt3F+I1XX+x/G32zjiFERUpgaFd2JBzOg5joX4h0SiMgmM\n/uhVGIlgFBlERGB3hUEkhkRdtDfRP1ACL6KpLBqcguxCjEAkmGamQcSohFHzsxe7O6zssvsydtff\n+ns8rs758j3f8z7fiyef7/k3o7XWCiDwh4s9APC/QzCAmGAAMcEAYoIBxAQDiAkGF8XTTz9d3W63\n7rvvvhoZGakVK1Zc7JEICMYlbvXq1TU8PHyxxzjPV199VcPDw/XZZ5/V22+/XVVVM2bMuMhTkRAM\n/qt+++23+uGHH+r666+vWbNmXexxuECCcQl76qmnanx8vLZs2VIDAwP1+uuv1zfffFP3339/dTqd\nWr9+fY2MjEzvv2nTpnr55ZfrgQceqIGBgXr44Yfr5MmTVVV1+vTp2r59ey1durQ6nU5t2LChTpw4\nUVVVk5OTtWXLllq6dGmtXbu23nnnnelj7tq1q7Zt21bbt2+vJUuW1HvvvVfPPvtsHTp0qAYGBmrX\nrl1/N/fRo0dr06ZN1el06u67766hoaGqqhodHa1OpzO93zPPPFO33nrr9P3t27fXm2+++e89iZyv\ncUlbtWpVGx4ebq21NjEx0brdbjtw4EBrrbUvvviidbvdduLEidZaaxs3bmxr1qxp33//fZuammob\nN25sO3fubK219tZbb7VHH320TU1NtXPnzrXDhw+3X375pbXW2kMPPdR27NjRTp8+3Y4cOdIGBwen\nn/OVV15pN910U/voo49aa61NTU21999/vz344IPTMx48eLCtWLGitdbamTNn2po1a9qePXvamTNn\n2vDwcOvv72/Hjh2bfj2HDx9urbW2du3advvtt7ejR4+21lpbuXJlO3LkyH/qVNJas8L4f6D95edC\n+/btq5UrV9by5curqmrZsmV1880316effjq977333ls33HBD9fb21h133FFHjhypqqqenp46efJk\nHTt2rGbMmFGLFi2qyy+/vCYmJurrr7+uJ598smbOnFkLFy6sDRs21N69e6eP2d/fX6tXr66qqt7e\n3n8666FDh+rUqVP1yCOPVE9PTw0ODtaqVavqgw8+qKqqJUuW1MjISB0/fryqqtauXVtffvlljY6O\n1q+//loLFy78N501/pGeiz0A/z1jY2O1f//++vjjj6vqzyE5e/ZsLVu2bHqfa665Zvr27Nmz69Sp\nU1VVdc8999TExEQ98cQT9fPPP9e6devq8ccfr8nJybryyitr9uzZ04+bP39+HT58ePr+3Llz4xkn\nJydr3rx5522bP39+TU5OVlVVp9OpoaGhuu6666rb7Va32629e/dWb29vLV68+ALOBr+HYFzi/vbT\nh3nz5tX69etrx44dF3ycnp6e2rp1a23durXGxsZq8+bNtWDBgrrtttvqp59+qlOnTlVfX19VVY2P\nj9ecOXP+4Qz/ypw5c2p8fPy8bWNjY7VgwYKqqup2u/Xiiy/WvHnzqtPp1MDAQD333HPV29tb3W73\ngl8XF8YlySXu2muvrdHR0aqqWrduXQ0NDdXnn39e586dq6mpqRoZGakff/zxXx7n4MGD9d1339W5\nc+eqr6+venp66rLLLqu5c+dWf39/vfTSS3X69On69ttv6913361169b9rnlvueWW6uvrq9dee63O\nnj1bBw8erE8++aTuvPPOqqq68cYba9asWbVv377qdDp1xRVX1NVXX10ffvjheW+I8p8hGJe4zZs3\n1+7du6vb7db+/ftr9+7dtWfPnlq2bFmtWrWq3njjjen3OP7ZSuD48eO1bdu2Wrx4cd111121dOnS\n6Sjs3LmzRkdHa/ny5bVt27Z67LHHzrvMuRAzZ86sV199tQ4cOFCDg4P1/PPP1wsvvDC9wqj68yrj\nqquumr7U+WsoFi1a9Luek9yM1vyBDpCxwgBiggHEBAOICQYQ+z/7PYzjf/QRGVxM12z68u+2WWEA\nMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHE\nBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhAT\nDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEww\ngJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOI\nCQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAm\nGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhg\nADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIB\nxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQ\nEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBM\nMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHB\nAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQD\niAkGEBMMIDajtdYu9hDA/wYrDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4j9CY2LTAbbRbWuAAAAAElFTkSuQmCC\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f3ecc00bf10\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ec965519-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_893ad561f4" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3f31b55c90\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_2d99e0ac17" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3eca67fe50\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_5c19462e32" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3f31b55dd0\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ec96551c-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_b9c8b7567b" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3f31b55a50\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ec96551d-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_fd05186348" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3f31b55810\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\u003cdiv class=id_888646481 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f3f32414810\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n", + "//# sourceURL=js_efef96e882" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ec96551f-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_6eca889864" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 input\");\n", + "//# sourceURL=js_f02070cc60" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3f31b553d0\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ed8ea973-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"].remove();\n", + "//# sourceURL=js_ed9faba660" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3f31a95450\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n", + "//# sourceURL=js_f3458d7074" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3f31a95250\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ed8ea975-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_3ffd97bd6f" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3f31a953d0\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"ed8ea976-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_7f73e8bcca" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + } + ], + "source": [ + "def predict_input_fn(color_name):\n", + " \"\"\"An input function for prediction.\"\"\"\n", + " _, chars, sequence_length = parse(color_name)\n", + "\n", + " # We create a batch of a single element.\n", + " features = {\n", + " 'chars': tf.expand_dims(chars, 0),\n", + " 'sequence_length': tf.expand_dims(sequence_length, 0)\n", + " }\n", + " return features, None\n", + "\n", + "\n", + "def draw_prediction(color_name, pred):\n", + " pred = pred * 255\n", + " pred = pred.astype(np.uint8)\n", + " plt.axis('off')\n", + " plt.imshow(pred)\n", + " plt.title(color_name)\n", + " plt.show()\n", + "\n", + "\n", + "def predict_with_estimator(color_name, regressor):\n", + " predictions = regressor.predict(\n", + " input_fn=lambda:predict_input_fn(color_name))\n", + " pred = next(predictions)\n", + " predictions.close()\n", + " pred = np.minimum(pred, 1.0)\n", + " pred = np.expand_dims(np.expand_dims(pred, 0), 0)\n", + "\n", + " draw_prediction(color_name, pred)\n", + "\n", + "tb = widgets.TabBar([\"RNN Colorbot\"])\n", + "while True:\n", + " with tb.output_to(0):\n", + " try:\n", + " color_name = six.moves.input(\"Give me a color name (or press 'enter' to exit): \")\n", + " except (EOFError, KeyboardInterrupt):\n", + " break\n", + " if not color_name:\n", + " break\n", + " with tb.output_to(0):\n", + " tb.clear_tab()\n", + " predict_with_estimator(color_name, regressor)\n", + " " + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "last_runtime": { + "build_target": "", + "kind": "local" + }, + "name": "RNN Colorbot using Keras and Estimators", + "provenance": [ + { + "file_id": "1CtzefX39ffFibX_BqE6cRbT0UW_DdVKl", + "timestamp": 1523579810961 + }, + { + "file_id": "1DcfimonWU11tmyivKBGVrbpAl3BIOaRG", + "timestamp": 1523016192637 + }, + { + "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K", + "timestamp": 1522238054357 + }, + { + "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ", + "timestamp": 1521743157199 + }, + { + "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-", + "timestamp": 1520522344607 + } + ], + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 2", + "name": "python2" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/py2tf/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD similarity index 64% rename from tensorflow/contrib/py2tf/impl/BUILD rename to tensorflow/contrib/autograph/impl/BUILD index 90ffabbc9bf4524ec2ebf54b6dd847bd8768a486..02f16ae1875d6bd1fb87d19f8bfc5cae900391dd 100644 --- a/tensorflow/contrib/py2tf/impl/BUILD +++ b/tensorflow/contrib/autograph/impl/BUILD @@ -20,15 +20,18 @@ py_library( "api.py", "config.py", "conversion.py", + "directives.py", "naming.py", + "special_functions.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ - "//tensorflow/contrib/py2tf/converters", - "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/converters", + "//tensorflow/contrib/autograph/operators", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis", + "//tensorflow/contrib/autograph/utils", "@gast_archive//:gast", "@six_archive//:six", ], @@ -38,10 +41,12 @@ py_test( name = "api_test", srcs = ["api_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":impl", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/utils", "//tensorflow/python:client_testlib", + "//third_party/py/numpy", ], ) @@ -49,6 +54,7 @@ py_test( name = "conversion_test", srcs = ["conversion_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":impl", "//tensorflow/python:client_testlib", @@ -65,3 +71,13 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "special_functions_test", + srcs = ["special_functions_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":impl", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py new file mode 100644 index 0000000000000000000000000000000000000000..24f87b2c14da4a3523f1e580d4362cbd3679a2cd --- /dev/null +++ b/tensorflow/contrib/autograph/impl/api.py @@ -0,0 +1,296 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Public API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from functools import wraps + +from enum import Enum + +# pylint:disable=g-bad-import-order +import gast +import six +# pylint:enable=g-bad-import-order + +from tensorflow.contrib.autograph.impl import config +from tensorflow.contrib.autograph.impl import conversion +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.utils import builtins +from tensorflow.contrib.autograph.utils import py_func +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_inspect + +# TODO(mdan): Properly document the type hints. +# TODO(mdan): Reduce the type hint information to (module, type). +# (currently we require (module + class name, type)) + + +def convert(recursive=False, verbose=False, arg_types=None): + """Decorator that compiles a function to graph mode. + + The decorator is dynamic - invoking compilation whenever the decorated + function is called. This means the parameter values are known at compilation. + + Args: + recursive: Whether to recursively convert any functions that the decorator + function may call. + verbose: Whether to output the compiled code in the logs. + arg_types: See to_graph. + + Returns: + A decorator that compiles the given function to graph mode. + + Raises: + ValueError: If any of the arguments are illegal. + """ + if arg_types is None: + arg_types = {} + + def decorator(f): + """Decorator implementation.""" + + @wraps(f) + def wrapper(*args, **kwargs): + return converted_call(f, recursive, verbose, arg_types, *args, **kwargs) + + # Sometimes the decorator is just desugared, making it impossible to detect. + # This attribute makes detection easier. + setattr(wrapper, '__pyct_is_compile_decorator', True) + return wrapper + + return decorator + + +class RunMode(Enum): + GRAPH = 1 + PY_FUNC = 2 + + +def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): + """Decorator that suppresses compilation of a function. + + Args: + run_as: RunMode value. Whether to run the function as-is, or wrap it into + a py_func. + return_dtypes: See autograph.utils.py_func.wrap_py_func. Setting to None or + empty list or tuple will create a dummy return value that can be used + to set control dependencies. + + Returns: + A decorator that wraps the original function. + """ + def decorator(f): + """Decorator implementation.""" + + @wraps(f) + def graph_wrapper(*args, **kwargs): + return f(*args, **kwargs) + + @wraps(f) + def py_func_wrapper(*args, **kwargs): + if kwargs: + raise NotImplementedError( + 'RunMode.PY_FUNC does not yet support kwargs') + # TODO(mdan): Add support for kwargs. + return py_func.wrap_py_func( + f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes) + + if run_as == RunMode.GRAPH: + wrapper = graph_wrapper + elif run_as == RunMode.PY_FUNC: + wrapper = py_func_wrapper + else: + raise ValueError('unknown value for run_as: %s' % run_as) + + # Sometimes the decorator is just desugared, making it impossible to detect. + # This attribute makes detection easier. + setattr(wrapper, '__pyct_is_compile_decorator', True) + return wrapper + + return decorator + + +def converted_call(f, recursive, verbose, arg_types, *args, **kwargs): + """Compiles a function call inline.""" + # TODO(mdan): This needs cleanup. + # In particular, we may want to avoid renaming functions altogether. + + if conversion.is_whitelisted_for_graph(f): + return f(*args, **kwargs) + + unknown_arg_value = object() # Sentinel for arguments of unknown value + + if inspect_utils.isbuiltin(f): + return builtins.dynamic_builtin(f, *args, **kwargs) + + if tf_inspect.isfunction(f) or tf_inspect.ismethod(f): + # Regular functions + target_entity = f + arg_map_target = f + effective_args = args + f_class = inspect_utils.getmethodclass(f) + + if f_class is not None: + partial_types = (f_class,) + else: + partial_types = () + + elif tf_inspect.isclass(f): + # Constructors + target_entity = f + arg_map_target = f.__init__ + effective_args = args + partial_types = () + + elif hasattr(f, '__call__') and hasattr(f, '__class__'): + # Callable objects + target_entity = f.__call__ + arg_map_target = f.__call__ + effective_args = (f,) + args + partial_types = (f.__class__,) + + else: + NotImplementedError('unknown callable type "%s"' % type(f)) + + arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs) + for name, arg in arg_values.items(): + if arg is unknown_arg_value: + continue + arg_class = arg.__class__ + # If arg_value_hints specifies any name, use that instead. + if name not in arg_types: + arg_types[name] = (arg_class.__name__, arg_class) + + # When called from within a decorator, this is the only indication that + # the function is a method - it appears that the decorator is applied + # before the method is bound. + if not partial_types: + if 'self' in arg_values: + if tf_inspect.isclass(arg_values['self'].__class__): + partial_types = (arg_values['self'].__class__,) + elif 'cls' in arg_values: + if tf_inspect.isclass(arg_values['cls']): + partial_types = (arg_values['cls'],) + + converted_f = to_graph( + target_entity, + recursive=recursive, + verbose=verbose, + arg_values=arg_values, + arg_types=arg_types, + partial_types=partial_types) + return converted_f(*effective_args, **kwargs) + + +def to_graph(e, + recursive=True, + verbose=False, + arg_values=None, + arg_types=None, + partial_types=None): + """Compile a Python entity into equivalent TensorFlow code. + + Currently supported entities: + * functions + * classes + + Classes are handled by converting all their methods into a new class. + + Args: + e: A Python entity. + recursive: Whether to recursively convert any functions that the decorator + function may call. + verbose: Whether to output the compiled code in the logs. + arg_values: A dict containing value hints for symbols like function + parameters. + arg_types: A dict containing type hints for symbols like function + parameters. + partial_types: A set of types (e.g. classes) that will not be converted + entirely. Calls to member functions for these types will be renamed + independently. + + Returns: + A function with a signature identical to `o`, but which when executed it + creates TF a graph that has the same functionality as the original entity. + """ + conversion_map = conversion.ConversionMap( + recursive=recursive, + nocompile_decorators=(convert, do_not_convert, converted_call), + partial_types=partial_types, + api_module=tf_inspect.getmodule(to_graph)) + _, name, namespace = conversion.entity_to_graph(e, conversion_map, arg_values, + arg_types) + + module = gast.Module([]) + for import_line in config.COMPILED_IMPORT_STATEMENTS: + module.body.extend(parser.parse_str(import_line).body) + for dep in reversed(conversion_map.dependency_cache.values()): + module.body.append(dep) + compiled_node, compiled_src = compiler.ast_to_object(module) + + # The compiled code should see everything the entry entity saw. + # TODO(mdan): This might not work well if the call tree spans modules? + for key, val in namespace.items(): + # Avoid overwriting entities that have been transformed. + if key not in compiled_node.__dict__: + compiled_node.__dict__[key] = val + compiled_fn = getattr(compiled_node, name) + + if verbose: + logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src) + + return compiled_fn + + +def to_code(e, + recursive=True, + arg_values=None, + arg_types=None, + partial_types=None, + indentation=' '): + """Return the equivalent of an entity in TensorFlow code. + + See `to_graph` for more details. + + Args: + e: A Python entity. + recursive: See to_graph. + arg_values: See to_graph. + arg_types: See to_graph. + partial_types: See to_graph. + indentation: String, when to use for each level of indentation. + + Returns: + String. + """ + conversion_map = conversion.ConversionMap( + recursive=recursive, + nocompile_decorators=(convert, do_not_convert, converted_call), + partial_types=partial_types, + api_module=tf_inspect.getmodule(to_graph)) + conversion.entity_to_graph(e, conversion_map, arg_values, arg_types) + + imports = '\n'.join(config.COMPILED_IMPORT_STATEMENTS) + code = '\n'.join( + compiler.ast_to_source(dep, indentation) + for dep in reversed(tuple( + six.itervalues(conversion_map.dependency_cache)))) + + return imports + '\n\n' + code diff --git a/tensorflow/contrib/py2tf/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py similarity index 55% rename from tensorflow/contrib/py2tf/impl/api_test.py rename to tensorflow/contrib/autograph/impl/api_test.py index 02cd8ed2d0ffee8ef2d31ea65902d2b493df9d64..a7737b7f448131b1c54951efa719b481e1f4d0c9 100644 --- a/tensorflow/contrib/py2tf/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -18,23 +18,29 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.impl import api -from tensorflow.contrib.py2tf.impl import config -from tensorflow.contrib.py2tf.pyct import parser +import numpy as np + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.impl import api +from tensorflow.contrib.autograph.impl import config +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.framework import constant_op -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test +tf = utils.fake_tf() + + class ApiTest(test.TestCase): def setUp(self): - config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,)) config.COMPILED_IMPORT_STATEMENTS = ( - 'from tensorflow.python.ops ' - 'import control_flow_ops as tf', - 'from tensorflow.contrib.py2tf import utils as ' - 'py2tf_utils') + 'from __future__ import print_function', + 'from tensorflow.contrib.autograph import utils' + ' as autograph_utils', + 'tf = autograph_utils.fake_tf()', + ) def test_decorator_recurses(self): @@ -47,7 +53,7 @@ class ApiTest(test.TestCase): @api.convert(recursive=True) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: + while tf.reduce_sum(x) > s: x //= self.called_member(a) return x @@ -63,11 +69,11 @@ class ApiTest(test.TestCase): class TestClass(object): def called_member(self, a): - return math_ops.negative(a) + return tf.negative(a) @api.convert(recursive=False) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: + while tf.reduce_sum(x) > s: x //= self.called_member(a) return x @@ -78,17 +84,17 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) - def test_decorator_calls_converted(self): + def test_decorator_calls_unconverted_graph(self): class TestClass(object): - @api.graph_ready + @api.do_not_convert(api.RunMode.GRAPH) def called_member(self, a): - return math_ops.negative(a) + return tf.negative(a) @api.convert(recursive=True) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: + while tf.reduce_sum(x) > s: x //= self.called_member(a) return x @@ -99,20 +105,23 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) - def test_decorator_calls_decorated(self): + def test_decorator_calls_unconverted_py_func(self): class TestClass(object): - @api.convert() + @api.do_not_convert( + api.RunMode.PY_FUNC, return_dtypes=py_func.MatchDType(1)) def called_member(self, a): - if a < 0: - a = -a - return a + return np.negative(a) @api.convert(recursive=True) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: - x //= self.called_member(a) + while tf.reduce_sum(x) > s: + y = self.called_member(a) + # set_shape works around while_loop's limitations. + # TODO(mdan): Allow specifying shapes (or ShapeLike) instead. + y.set_shape(a.shape) + x //= y return x tc = TestClass() @@ -122,10 +131,11 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) - def test_convert_call_site_decorator(self): + def test_decorator_calls_decorated(self): class TestClass(object): + @api.convert() def called_member(self, a): if a < 0: a = -a @@ -133,8 +143,8 @@ class ApiTest(test.TestCase): @api.convert(recursive=True) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: - x //= api.convert_inline(self.called_member, a) + while tf.reduce_sum(x) > s: + x //= self.called_member(a) return x tc = TestClass() @@ -144,17 +154,20 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) - def test_graph_ready_call_site_decorator(self): + def test_convert_call_site_decorator(self): class TestClass(object): def called_member(self, a): - return math_ops.negative(a) + if a < 0: + a = -a + return a @api.convert(recursive=True) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: - x //= api.graph_ready(self.called_member(a)) + while tf.reduce_sum(x) > s: + x //= api.converted_call(self.called_member, False, False, {}, self, + a) return x tc = TestClass() @@ -164,9 +177,96 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) + def test_converted_call_builtin(self): + x = api.converted_call(range, False, False, {}, 3) + self.assertEqual((0, 1, 2), tuple(x)) + + def test_converted_call_function(self): + + def test_fn(x): + if x < 0: + return -x + return x + + with self.test_session() as sess: + x = api.converted_call( + test_fn, False, False, {}, constant_op.constant(-1)) + self.assertEqual(1, sess.run(x)) + + def test_converted_call_method(self): + + class TestClass(object): + + def __init__(self, x): + self.x = x + + def test_method(self): + if self.x < 0: + return -self.x + return self.x + + with self.test_session() as sess: + tc = TestClass(constant_op.constant(-1)) + x = api.converted_call(tc.test_method, False, False, {}, tc) + self.assertEqual(1, sess.run(x)) + + def test_converted_call_method_by_class(self): + + class TestClass(object): + + def __init__(self, x): + self.x = x + + def test_method(self): + if self.x < 0: + return -self.x + return self.x + + with self.test_session() as sess: + tc = TestClass(constant_op.constant(-1)) + x = api.converted_call(TestClass.test_method, False, False, {}, tc) + self.assertEqual(1, sess.run(x)) + + def test_converted_call_callable_object(self): + + class TestClass(object): + + def __init__(self, x): + self.x = x + + def __call__(self): + if self.x < 0: + return -self.x + return self.x + + with self.test_session() as sess: + tc = TestClass(constant_op.constant(-1)) + x = api.converted_call(tc, False, False, {}) + self.assertEqual(1, sess.run(x)) + + def test_converted_call_constructor(self): + + class TestClass(object): + + def __init__(self, x): + self.x = x + + def test_method(self): + if self.x < 0: + return -self.x + return self.x + + with self.test_session() as sess: + tc = api.converted_call( + TestClass, False, False, {}, constant_op.constant(-1)) + # tc is now a converted object. + x = tc.test_method() + self.assertEqual(1, sess.run(x)) + def test_to_graph_basic(self): + def test_fn(x, s): - while math_ops.reduce_sum(x) > s: + while tf.reduce_sum(x) > s: x //= 2 return x @@ -177,15 +277,15 @@ class ApiTest(test.TestCase): self.assertListEqual([1, 2], sess.run(x).tolist()) def test_to_code_basic(self): + def test_fn(x, s): - while math_ops.reduce_sum(x) > s: + while tf.reduce_sum(x) > s: x /= 2 return x compiled_code = api.to_code(test_fn) - # Just check for some key words and that it is parseable Python code. - self.assertRegexpMatches(compiled_code, 'py2tf_utils\\.run_while') + # Just check that it is parseable Python code. self.assertIsNotNone(parser.parse_str(compiled_code)) diff --git a/tensorflow/contrib/py2tf/impl/config.py b/tensorflow/contrib/autograph/impl/config.py similarity index 73% rename from tensorflow/contrib/py2tf/impl/config.py rename to tensorflow/contrib/autograph/impl/config.py index c90e85c96b690b7781358b173e5d83fe60e29c00..878bb7e12f2b39a0ec40004ff2c7ac3ab8031e14 100644 --- a/tensorflow/contrib/py2tf/impl/config.py +++ b/tensorflow/contrib/autograph/impl/config.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf import utils +from tensorflow.contrib.autograph import utils PYTHON_LITERALS = { @@ -31,16 +31,19 @@ PYTHON_LITERALS = { DEFAULT_UNCOMPILED_MODULES = set(( ('tensorflow',), (utils.__name__,), + + # All of tensorflow's subpackages. Unlike the root tf module, they don't + # have well-known names. Not referring to the module directly to avoid + # circular imports. + ( + utils.__name__[:-len('.contrib.autograph.utils')],), )) NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',)) -# TODO(mdan): Also allow controlling the generated names (for testability). -# TODO(mdan): Make sure copybara renames the reference below. +# TODO(mdan): Also allow controlling the generated names. +# TODO(mdan); Consolidate all internal imports into a single __ag module. COMPILED_IMPORT_STATEMENTS = ( 'from __future__ import print_function', 'import tensorflow as tf', - 'from tensorflow.contrib.py2tf.impl import api as ' - 'py2tf_api', - 'from tensorflow.contrib.py2tf import utils as ' - 'py2tf_utils') +) diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py new file mode 100644 index 0000000000000000000000000000000000000000..7802bbbe27ec5fed891440af2f589801918b3bdd --- /dev/null +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -0,0 +1,414 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""High level conversion support.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import imp + +import gast + +from tensorflow.contrib.autograph import operators +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import asserts +from tensorflow.contrib.autograph.converters import break_statements +from tensorflow.contrib.autograph.converters import builtin_functions +from tensorflow.contrib.autograph.converters import call_trees +from tensorflow.contrib.autograph.converters import continue_statements +from tensorflow.contrib.autograph.converters import control_flow +from tensorflow.contrib.autograph.converters import decorators +from tensorflow.contrib.autograph.converters import ifexp +from tensorflow.contrib.autograph.converters import lists +from tensorflow.contrib.autograph.converters import logical_expressions +from tensorflow.contrib.autograph.converters import name_scopes +from tensorflow.contrib.autograph.converters import side_effect_guards +from tensorflow.contrib.autograph.converters import single_return +from tensorflow.contrib.autograph.converters import slices +from tensorflow.contrib.autograph.impl import config +from tensorflow.contrib.autograph.impl import naming +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import type_info +from tensorflow.contrib.autograph.utils import type_hints +from tensorflow.python.util import tf_inspect + + +# TODO(mdan): Might we not need any renaming at all? + + +class ConversionMap(object): + """ConversionMap keeps track of converting function hierarchies. + + This object is mutable, and is updated as functions are converted. + + Attributes: + recursive: Whether to recursively convert any functions that the decorator + function may call. + nocompile_decorators: tuple of decorator functions that toggle compilation + off. + dependency_cache: dict[object]: ast; maps original entities to their + converted AST + additional_imports: set(object); additional entities which for any reason + cannot be attached after loading and need to be explicitly imported + in the generated code + name_map: dict[string]: string; maps original entities to the name of + their converted counterparts + api_module: A reference to the api module. The reference needs to be passed + to avoid circular dependencies. + """ + + # TODO(mdan): Rename to ConversionContext, and pull in additional flags. + + def __init__(self, recursive, nocompile_decorators, partial_types, + api_module): + self.recursive = recursive + self.nocompile_decorators = nocompile_decorators + self.partial_types = partial_types if partial_types else () + # Required to output dependencies in discovery order, which should match + # the reverse dependency order. + self.dependency_cache = collections.OrderedDict() + self.additional_imports = set() + self.name_map = {} + self.api_module = api_module + + def new_namer(self, namespace): + return naming.Namer(namespace, self.recursive, self.name_map, + self.partial_types) + + def update_name_map(self, namer): + for o, name in namer.renamed_calls.items(): + if o in self.name_map: + if self.name_map[o] != name: + raise ValueError( + 'Calls to %s were converted using multiple names (%s). This is ' + 'possible when an entity with one of these names already ' + 'existed. To fix, avoid using any of these names.') + else: + self.name_map[o] = name + + def add_to_cache(self, original_entity, converted_ast): + self.dependency_cache[original_entity] = converted_ast + + +def is_whitelisted_for_graph(o): + """Check whether an entity is whitelisted for use in graph mode. + + Examples of whitelisted entities include all members of the tensorflow + package. + + Args: + o: A Python entity. + Returns: + Boolean + """ + m = tf_inspect.getmodule(o) + for prefix, in config.DEFAULT_UNCOMPILED_MODULES: + if m.__name__.startswith(prefix): + return True + return False + + +def entity_to_graph(o, conversion_map, arg_values, arg_types): + """Compile a Python entity into equivalent TensorFlow. + + The function will also recursively compile all the entities that `o` + references, updating `dependency_cache`. + + This function is reentrant, and relies on dependency_cache to avoid + generating duplicate code. + + Args: + o: A Python entity. + conversion_map: A ConversionMap object. + arg_values: A dict containing value hints for symbols like function + parameters. + arg_types: A dict containing type hints for symbols like function + parameters. + + Returns: + A tuple (ast, new_name, namespace): + * ast: An AST representing an entity with interface equivalent to `o`, + but which when executed it creates TF a graph. + * new_name: The symbol name under which the new entity can be found. + * namespace: A dict mapping all symbols visible to the converted entity, + keyed by their symbol name. + + Raises: + ValueError: if the entity type is not supported. + """ + if tf_inspect.isclass(o): + node, name, ns = class_to_graph(o, conversion_map) + elif tf_inspect.isfunction(o): + # TODO(mdan): This is not a reliable mechanism. + # The most reliable way is to check the source code, the AST will contain + # a Lambda node instead of a FunctionDef + if o.__name__ == '': + raise NotImplementedError( + 'lambda functions are not yet supported; declare the function' + ' using def instead: %s' % o) + else: + node, name, ns = function_to_graph(o, conversion_map, arg_values, + arg_types) + elif tf_inspect.ismethod(o): + node, name, ns = function_to_graph(o, conversion_map, arg_values, arg_types) + else: + raise ValueError( + 'Entity "%s" has unsupported type "%s". Only functions and classes are ' + 'supported for now.' % (o, type(o))) + + conversion_map.add_to_cache(o, node) + if conversion_map.recursive: + while True: + candidate = None + for obj in conversion_map.name_map.keys(): + if obj not in conversion_map.dependency_cache: + candidate = obj + break + if candidate is None: + break + if (hasattr(candidate, 'im_class') and + getattr(candidate, 'im_class') not in conversion_map.partial_types): + # Class members are converted with their objects, unless they're + # only converted partially. + continue + entity_to_graph(candidate, conversion_map, {}, {}) + + return node, name, ns + + +def class_to_graph(c, conversion_map): + """Specialization of `entity_to_graph` for classes.""" + converted_members = {} + method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m) + members = tf_inspect.getmembers(c, predicate=method_filter) + if not members: + raise ValueError('Cannot convert %s: it has no member methods.' % c) + + class_namespace = {} + for _, m in members: + # Only convert the members that are directly defined by the class. + if inspect_utils.getdefiningclass(m, c) is not c: + continue + node, _, namespace = function_to_graph( + m, + conversion_map=conversion_map, + arg_values={}, + arg_types={'self': (c.__name__, c)}, + owner_type=c) + if class_namespace is None: + class_namespace = namespace + else: + class_namespace.update(namespace) + converted_members[m] = node + namer = conversion_map.new_namer(class_namespace) + class_name = namer.compiled_class_name(c.__name__, c) + + # TODO(mdan): This needs to be explained more thoroughly. + # Process any base classes: if the sueprclass if of a whitelisted type, an + # absolute import line is generated. Otherwise, it is marked for conversion + # (as a side effect of the call to namer.compiled_class_name() followed by + # conversion_map.update_name_map(namer)). + output_nodes = [] + renames = {} + bases = [] + for base in c.__bases__: + if isinstance(object, base): + bases.append('object') + continue + if is_whitelisted_for_graph(base): + alias = namer.new_symbol(base.__name__, ()) + output_nodes.append( + gast.ImportFrom( + module=base.__module__, + names=[gast.alias(name=base.__name__, asname=alias)], + level=0)) + else: + # This will trigger a conversion into a class with this name. + alias = namer.compiled_class_name(base.__name__, base) + bases.append(alias) + renames[qual_names.QN(base.__name__)] = qual_names.QN(alias) + conversion_map.update_name_map(namer) + + # Generate the definition of the converted class. + output_nodes.append( + gast.ClassDef( + class_name, + bases=bases, + keywords=[], + body=list(converted_members.values()), + decorator_list=[])) + node = gast.Module(output_nodes) + + # Make a final pass to replace references to the class or its base classes. + # Most commonly, this occurs when making super().__init__() calls. + # TODO(mdan): Making direct references to superclass' superclass will fail. + node = qual_names.resolve(node) + renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name) + node = ast_util.rename_symbols(node, renames) + + return node, class_name, class_namespace + + +def _add_reserved_symbol(namespace, name, entity): + if name not in namespace: + namespace[name] = entity + elif namespace[name] != entity: + raise ValueError('The name "%s" is reserved and may not be used.' % name) + + +ag_internal = None + + +def _add_self_references(namespace, api_module): + """Adds namespace references to the module that exposes the api itself.""" + global ag_internal + if ag_internal is None: + # Craft a module that exposes parts of the external API as well as certain + # internal modules. + ag_internal = imp.new_module('autograph') + ag_internal.converted_call = api_module.converted_call + ag_internal.utils = utils + # TODO(mdan): Add safeguards against name clashes. + # We don't want to create a submodule because we want the operators to be + # accessible as ag__. + ag_internal.__dict__.update(operators.__dict__) + + _add_reserved_symbol(namespace, 'ag__', ag_internal) + + +def function_to_graph(f, conversion_map, arg_values, arg_types, + owner_type=None): + """Specialization of `entity_to_graph` for callable functions.""" + node, source = parser.parse_entity(f) + node = node.body[0] + + namespace = inspect_utils.getnamespace(f) + _add_self_references(namespace, conversion_map.api_module) + namer = conversion_map.new_namer(namespace) + + ctx = context.EntityContext( + namer=namer, + source_code=source, + source_file='', + namespace=namespace, + arg_values=arg_values, + arg_types=arg_types, + owner_type=owner_type, + recursive=conversion_map.recursive, + type_annotation_func=type_hints.set_element_type) + node, deps = node_to_graph(node, ctx, conversion_map.nocompile_decorators) + + # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py + new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type) + if not did_rename: + new_name = f.__name__ + if node.name != f.__name__: + raise NotImplementedError('Strange corner case. Send us offending code!') + + node.name = new_name + conversion_map.update_name_map(namer) + # TODO(mdan): Use this at compilation. + conversion_map.additional_imports.update(deps) + + return node, new_name, namespace + + +def _static_analysis_pass(node, ctx): + node = qual_names.resolve(node) + node = activity.resolve(node, ctx, None) + node = live_values.resolve(node, ctx, config.PYTHON_LITERALS) + node = type_info.resolve(node, ctx) + return node + + +def node_to_graph(node, ctx, nocompile_decorators): + """Convert Python code to equivalent TF graph mode code. + + Args: + node: A Python AST node representing the code to convert. + ctx: An EntityContext object. + nocompile_decorators: A tuple containing decorators to be stripped from + functions during conversion. + + Returns: + A tuple (node, deps): + * node: A Python ast node, representing the converted code. + * deps: A set of strings, the fully qualified names of entity + dependencies that this node has. + """ + # TODO(mdan): Verify arguments for correctness. + + # TODO(mdan): Factor out common elements. + # These include: + # * code move between blocks + # * visiting blocks in transformers + + # Certain steps, especially canonicalization, insert new symbols into the + # tree, which must be accounted. Although less efficient, it is most robust + # to re-run the analysis. + + node = _static_analysis_pass(node, ctx) + + # TODO(mdan): Clean this up. + # Some intermediate analyses are not required, and some comments got orphaned. + + # TODO(mdan): We may assume all converters require analysis to be re-done. + + # Past this point, line numbers are no longer accurate so we ignore the + # source. + # TODO(mdan): Is it feasible to reconstruct intermediate source code? + ctx.source_code = None + node = ifexp.transform(node, ctx) + node, deps = decorators.transform(node, nocompile_decorators) + node = break_statements.transform(node, ctx) + node = _static_analysis_pass(node, ctx) + + node = asserts.transform(node, ctx) + + # Note: sequencing continue canonicalization before for loop one avoids + # dealing with the extra loop increment operation that the for + # canonicalization creates. + node = continue_statements.transform(node, ctx) + ctx.namespace['len'] = len + + node = _static_analysis_pass(node, ctx) + node = single_return.transform(node, ctx) + + node = _static_analysis_pass(node, ctx) + node = lists.transform(node, ctx) + node = _static_analysis_pass(node, ctx) + node = slices.transform(node, ctx) + node = builtin_functions.transform(node, ctx) + + node = _static_analysis_pass(node, ctx) + node = call_trees.transform(node, ctx, config.DEFAULT_UNCOMPILED_MODULES, + nocompile_decorators) + node = control_flow.transform(node, ctx) + + # control_flow may create new symbols and change scopes. + node = _static_analysis_pass(node, ctx) + node = logical_expressions.transform(node, ctx) + node = side_effect_guards.transform(node, ctx) + node = name_scopes.transform(node, ctx) + + return node, deps diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bc61498b5422f5e130bbfeef935d0a796b4f5922 --- /dev/null +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -0,0 +1,165 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for conversion module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.impl import api +from tensorflow.contrib.autograph.impl import conversion +from tensorflow.python.framework import constant_op +from tensorflow.python.keras.engine import training +from tensorflow.python.platform import test + + +class ConversionTest(test.TestCase): + + def _simple_conversion_map(self): + return conversion.ConversionMap(True, (), (), api) + + def test_is_whitelisted_for_graph(self): + + def test_fn(): + return constant_op.constant(1) + + self.assertFalse(conversion.is_whitelisted_for_graph(test_fn)) + self.assertTrue(conversion.is_whitelisted_for_graph(utils)) + self.assertTrue(conversion.is_whitelisted_for_graph(constant_op.constant)) + + def test_entity_to_graph_unsupported_types(self): + with self.assertRaises(ValueError): + conversion_map = self._simple_conversion_map() + conversion.entity_to_graph('dummy', conversion_map, None, None) + + def test_entity_to_graph_callable(self): + b = 2 + def f(a): + return a + b + + conversion_map = self._simple_conversion_map() + ast, name, ns = conversion.entity_to_graph(f, conversion_map, None, None) + self.assertTrue(isinstance(ast, gast.FunctionDef), ast) + self.assertEqual('tf__f', name) + self.assertTrue(ns['b'] is b) + + def test_entity_to_graph_call_tree(self): + + def g(a): + return a + + def f(a): + return g(a) + + conversion_map = self._simple_conversion_map() + conversion.entity_to_graph(f, conversion_map, None, None) + + self.assertTrue(f in conversion_map.dependency_cache) + self.assertTrue(g in conversion_map.dependency_cache) + self.assertEqual('tf__f', conversion_map.dependency_cache[f].name) + # need the extra .body[0] in order to step past the with tf.name_scope('f') + # that is added automatically + self.assertEqual( + 'tf__g', + conversion_map.dependency_cache[f].body[0].body[0].value.func.id) + self.assertEqual('tf__g', conversion_map.dependency_cache[g].name) + + def test_entity_to_graph_class_hierarchy(self): + + class TestBase(object): + + def __init__(self, x='base'): + self.x = x + + def foo(self): + return self.x + + def bar(self): + return self.x + + class TestSubclass(TestBase): + + def __init__(self, y): + super(TestSubclass, self).__init__('sub') + self.y = y + + def foo(self): + return self.y + + def baz(self): + return self.y + + conversion_map = self._simple_conversion_map() + conversion.entity_to_graph(TestSubclass, conversion_map, None, None) + + self.assertTrue(TestBase in conversion_map.dependency_cache) + self.assertTrue(TestSubclass in conversion_map.dependency_cache) + self.assertEqual('TfTestBase', + conversion_map.dependency_cache[TestBase].body[-1].name) + self.assertEqual( + 'TfTestSubclass', + conversion_map.dependency_cache[TestSubclass].body[-1].name) + + def test_entity_to_graph_class_hierarchy_whitelisted(self): + + class TestSubclass(training.Model): + + def __init__(self, y): + super(TestSubclass, self).__init__() + self.built = False + + def call(self, x): + return 3 * x + + conversion_map = self._simple_conversion_map() + conversion.entity_to_graph(TestSubclass, conversion_map, None, None) + + self.assertTrue(TestSubclass in conversion_map.dependency_cache) + self.assertFalse(training.Model in conversion_map.dependency_cache) + self.assertEqual( + 'Model', + conversion_map.dependency_cache[TestSubclass].body[0].names[0].name) + self.assertEqual( + 'TfTestSubclass', + conversion_map.dependency_cache[TestSubclass].body[-1].name) + + def test_entity_to_graph_lambda(self): + f = lambda a: a + + with self.assertRaises(NotImplementedError): + conversion_map = self._simple_conversion_map() + conversion.entity_to_graph(f, conversion_map, None, None) + + def test_ag_module_cached(self): + def callee(): + return range(3) + + def caller(a): + return a() + + conversion_map = self._simple_conversion_map() + _, _, callee_ns = conversion.entity_to_graph( + callee, conversion_map, None, None) + _, _, caller_ns = conversion.entity_to_graph( + caller, conversion_map, None, None) + + self.assertTrue(callee_ns['ag__'] is caller_ns['ag__']) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/impl/directives.py b/tensorflow/contrib/autograph/impl/directives.py new file mode 100644 index 0000000000000000000000000000000000000000..aabe5d99394a0cb921196d1c6a6b2a9496ea7545 --- /dev/null +++ b/tensorflow/contrib/autograph/impl/directives.py @@ -0,0 +1,68 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Directives are special no-op functions that serve as compilation markers. + +They provide static information like type hints, compilation and TensorFlow +overrides. + +These serve as annotations in the compiled code, allowing the user some control +over the compilation process. They have no functional role at runtime. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +UNSPECIFIED = object() + + +def set_element_type(entity, dtype, shape=UNSPECIFIED): + """Indicates that the entity is expected hold items of specified type/shape. + + The staged TensorFlow ops will reflect and assert this data type. Ignored + otherwise. + + Args: + entity: The entity to annotate. + dtype: TensorFlow dtype value to assert for entity. + shape: Optional shape to assert for entity. + """ + del entity + del dtype + del shape + + +def set_loop_options( + parallel_iterations=UNSPECIFIED, + back_prop=UNSPECIFIED, + swap_memory=UNSPECIFIED, + maximum_iterations=UNSPECIFIED): + """Specifies additional arguments to be passed to the enclosing while_loop. + + The parameters apply to and only to the immediately enclosing loop. It only + has effect if the loop is staged as a TF while_loop; otherwise the parameters + have no effect. + + Args: + parallel_iterations: See tf.while_loop. + back_prop: See tf.while_loop. + swap_memory: See tf.while_loop. + maximum_iterations: See tf.while_loop. + """ + del parallel_iterations + del back_prop + del swap_memory + del maximum_iterations diff --git a/tensorflow/contrib/py2tf/impl/naming.py b/tensorflow/contrib/autograph/impl/naming.py similarity index 96% rename from tensorflow/contrib/py2tf/impl/naming.py rename to tensorflow/contrib/autograph/impl/naming.py index 51326091de13715c32d0a79279f1d3274e48ad10..b1d3f76be7763fada88fd0a1da9d3aa43b67ddfa 100644 --- a/tensorflow/contrib/py2tf/impl/naming.py +++ b/tensorflow/contrib/autograph/impl/naming.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.autograph.pyct import qual_names class Namer(object): @@ -62,8 +62,6 @@ class Namer(object): n += 1 new_name = '%s_%d' % (new_name_root, n) - if live_entity is not None: - self.renamed_calls[live_entity] = new_name self.generated_names.add(new_name) if live_entity is not None: self.renamed_calls[live_entity] = new_name diff --git a/tensorflow/contrib/py2tf/impl/naming_test.py b/tensorflow/contrib/autograph/impl/naming_test.py similarity index 98% rename from tensorflow/contrib/py2tf/impl/naming_test.py rename to tensorflow/contrib/autograph/impl/naming_test.py index beb4e54937bbb91b19157c9b9e3c528353206c62..73fc0894655cb49e4f61bf8ca51995b06feb3072 100644 --- a/tensorflow/contrib/py2tf/impl/naming_test.py +++ b/tensorflow/contrib/autograph/impl/naming_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.impl import naming +from tensorflow.contrib.autograph.impl import naming from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/impl/special_functions.py b/tensorflow/contrib/autograph/impl/special_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..b7a8177c44c88217560fb7f72c77d3ac1aa0c9ec --- /dev/null +++ b/tensorflow/contrib/autograph/impl/special_functions.py @@ -0,0 +1,48 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Special functions that only make sense for AutoGraph. + +These functions are meant to ensure feature parity between Python and AutoGraph, +so that the exact same code works in both modes. In general, AutoGraph will +replace these calls. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import data_structures + + +def stack(list_or_tensor, element_dtype=None): + """Stacks the input, if it admits the notion of stacking. No-op otherwise. + + For example, a list of tensors can be stacked into a larger tensor. This + function is similar to tf.stack, but it accepts non-lists and lists of + non-tensors as arguments. In the latter case, the function does nothing. + + Args: + list_or_tensor: Any entity. + element_dtype: Optional dtype for the elements in the list. Required if the + input is stackable, and the list is untyped. + + Returns: + If the input is stackable, a new object representing the stacked inputs. + Otherwise it returns list_or_tensor unchanged. + """ + return data_structures.list_stack( + list_or_tensor, + data_structures.ListStackOpts( + element_dtype=element_dtype, original_call=lambda x: x)) diff --git a/tensorflow/contrib/autograph/impl/special_functions_test.py b/tensorflow/contrib/autograph/impl/special_functions_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9b52d2a59b5a3e3c92f11343197379c773ecc828 --- /dev/null +++ b/tensorflow/contrib/autograph/impl/special_functions_test.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for special_functions module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.impl import special_functions +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SpecialFunctionsTest(test.TestCase): + + def test_basic(self): + self.assertEqual(special_functions.stack(1), 1) + self.assertListEqual(special_functions.stack([1, 2, 3]), [1, 2, 3]) + # TODO(mdan): This should probably forward to tf.stack. + self.assertTrue( + isinstance( + special_functions.stack( + [constant_op.constant(1), + constant_op.constant(2)]), list)) + + t = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor( + t, element_shape=constant_op.constant([], dtype=dtypes.int32)) + self.assertTrue( + tensor_util.is_tensor( + special_functions.stack(l, element_dtype=dtypes.float32))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..0c6ab65505ee03e19588adae73d3134399a34b65 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -0,0 +1,64 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "operators", + srcs = [ + "__init__.py", + "control_flow.py", + "data_structures.py", + "slices.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:tensor_array_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "data_structures_test", + srcs = ["data_structures_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "control_flow_test", + srcs = ["control_flow_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "slices_test", + srcs = ["slices_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c900fd6af2ea5dfb419f731ee8d8822d68424b27 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/__init__.py @@ -0,0 +1,50 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""This module implements operators that we overload. + +Note that "operator" is used loosely here, and includes control structures like +conditionals and loops, implemented in functional form, using for example +closures for the body. +""" + +# Naming conventions: +# * operator names match the name usually used for the respective Python +# idiom; examples: for_stmt, list_append +# * operator arguments match either of: +# - the corresponding Python AST attribute (e.g. the condition of an if +# statement is called test) if the operator represents an AST construct +# - the names used in the Python docs, if the operator is a function (e.g. +# list_ and x for append, see +# https://docs.python.org/3.7/tutorial/datastructures.html) +# +# All operators may accept a final argument named "opts", of a type that +# subclasses namedtuple and contains any arguments that are only required +# for some specializations of the operator. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators.control_flow import for_stmt +from tensorflow.contrib.autograph.operators.control_flow import while_stmt +from tensorflow.contrib.autograph.operators.data_structures import list_append +from tensorflow.contrib.autograph.operators.data_structures import list_pop +from tensorflow.contrib.autograph.operators.data_structures import list_stack +from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts +from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts +from tensorflow.contrib.autograph.operators.data_structures import new_list +from tensorflow.contrib.autograph.operators.slices import get_item +from tensorflow.contrib.autograph.operators.slices import GetItemOpts +from tensorflow.contrib.autograph.operators.slices import set_item diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..671c9ccc13eaa887522cfc248a6d56d7ab9719ca --- /dev/null +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -0,0 +1,227 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Control flow statements: loops, conditionals, etc.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.utils import builtins +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_math_ops + + +def for_stmt(iter_, extra_test, body, init_state): + """Functional form of a for statement. + + The loop operates on a state, which includes all symbols that are + variant across loop iterations, excluding the iterate as well as the + variables local to the loop. + + For example, given the loop below that calculates the geometric and + arithmetic means or some numbers: + + geo_mean = 1 + arith_mean = 0 + for i in range(n): + a = numbers[i] + geo_mean *= a + arith_mean += a + + The state is represented by the variables geo_mean and arith_mean. The + argument for initial_state may contain the tuple (1, 0), the body will + include the arguments geo_mean and arith_mean and will return a tuple + representing the new values for geo_mean and respectively arith_mean. + + Args: + iter_: The entity being iterated over. + extra_test: Callable with the state as arguments, and boolean return type. + An additionnal loop condition. + body: Callable with the iterate and the state as arguments, and + state as return type. The actual loop body. + init_state: Tuple containing the initial state. + + Returns: + Tuple containing the final state. + """ + if tensor_util.is_tensor(iter_): + return _known_len_for_stmt(iter_, extra_test, body, init_state) + elif isinstance(iter_, dataset_ops.Dataset): + return _dataset_for_stmt(iter_, extra_test, body, init_state) + else: + return _py_for_stmt(iter_, extra_test, body, init_state) + + +def _py_for_stmt(iter_, extra_test, body, init_state): + """Overload of for_stmt that executes a Python for loop.""" + state = init_state + for target in iter_: + if not extra_test(*state): + break + state = body(target, *state) + + # TODO(mdan): Remove this special case. + if len(state) == 1: + return state[0] + return state + + +def _known_len_for_stmt(iter_, extra_test, body, init_state): + """Overload of for_stmt that iterates over objects that define a length.""" + n = builtins.dynamic_len(iter_) + + def while_body(iterate_index, *state): + iterate = iter_[iterate_index] + new_state = body(iterate, *state) + return (iterate_index + 1,) + new_state + + def while_cond(iterate_index, *state): + return gen_math_ops.logical_and(iterate_index < n, extra_test(*state)) + + results = while_stmt( + while_cond, + while_body, + init_state=(0,) + init_state, + extra_deps=(iter_,), + opts=dict(maximum_iterations=n)) + # Dropping the iteration index because it's not syntactically visible. + results = results[1:] + + # TODO(mdan): Remove this special case. + if len(results) == 1: + return results[0] + return results + + +def _dataset_for_stmt(ds, extra_test, body, init_state): + """Overload of for_stmt that iterates over TF Datasets.""" + # Because Datsets only expose get_next, in the style of Python iterators, + # we are forced to unpack the loop as: + # + # epoch_number, iterate = ds.get_next() + # while epoch_number < 2: + # + # epoch_number, iterate = ds.get_next() + epoch_numbers = dataset_ops.Dataset.range(2) + def tag_with(ds, tag): + return dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensors(tag).repeat(), ds)) + ds_with_epoch = epoch_numbers.flat_map(lambda i: tag_with(ds, i)) + + iterator = ds_with_epoch.make_initializable_iterator() + with ops.control_dependencies((iterator.initializer,)): + epoch_number, iterate = iterator.get_next() + + def while_body(epoch_number, iterate, *state): + new_state = body(iterate, *state) + epoch_number, iterate = iterator.get_next() + return (epoch_number, iterate) + new_state + + def while_cond(epoch_number, iterate, *state): + del iterate + return gen_math_ops.logical_and(epoch_number < 1, extra_test(*state)) + + results = while_stmt( + while_cond, + while_body, + init_state=(epoch_number, iterate) + init_state, + extra_deps=()) + # Dropping the epoch number and iterate because they are not not syntactically + # visible. + results = results[2:] + + # TODO(mdan): Remove this special case. + if len(results) == 1: + return results[0] + return results + + +def while_stmt(test, body, init_state, extra_deps, opts=None): + """Functional form of a while statement. + + The loop operates on a so-called state, which includes all symbols that are + variant across loop iterations. In what follows we refer to state as either + a tuple of entities that represent an actual state, or a list of arguments + of the corresponding types. + + Args: + test: Callable with the state as arguments, and boolean return type. + The loop condition. + body: Callable with the state as arguments, and state as return type. + The actual loop body. + init_state: Tuple containing the initial state. + extra_deps: Tuple containing additional entities on which the loop may + depend, such as loop invariants referenced by test. Used + exclusively for dispatch control. + opts: Optional dict of extra loop parameters. + + Returns: + Tuple containing the final state. + """ + # TODO(mdan): Consider adding a generic mechanism for dynamic dispatch. + # That could be something as simple as a collection of dispatch rules, with + # some prioritization. + if any(tensor_util.is_tensor(v) for v in init_state + extra_deps): + return _tf_while_stmt(test, body, init_state, opts) + else: + return _py_while_stmt(test, body, init_state, opts) + + +def _tf_while_stmt(test, body, init_state, opts): + """Overload of while_stmt that stages a TF while_stmt.""" + if opts is None: + opts = {} + return control_flow_ops.while_loop(test, body, init_state, **opts) + + +def _py_while_stmt(test, body, init_state, opts): + """Overload of while_stmt that executes a Python while loop.""" + del opts + state = init_state + while test(*state): + state = body(*state) + return state + + +def if_stmt(cond, body, orelse): + """Functional form of an if statement. + + Args: + cond: Boolean. + body: Callable with no arguments, and outputs of the positive (if) branch + as return type. + orelse: Callable with no arguments, and outputs of the negative (else) + branch as return type. + + Returns: + Tuple containing the statement outputs. + """ + if tensor_util.is_tensor(cond): + return _tf_if_stmt(cond, body, orelse) + else: + return _py_if_stmt(cond, body, orelse) + + +def _tf_if_stmt(cond, body, orelse): + """Overload of if_stmt that stages a TF cond.""" + return control_flow_ops.cond(cond, body, orelse) + + +def _py_if_stmt(cond, body, orelse): + """Overload of if_stmt that executes a Python if statement.""" + return body() if cond else orelse() diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/contrib/autograph/operators/control_flow_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b14d7edba38461692d9e999a6ce80a5fd84ba80d --- /dev/null +++ b/tensorflow/contrib/autograph/operators/control_flow_test.py @@ -0,0 +1,99 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for control_flow module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import control_flow +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class ForLoopTest(test.TestCase): + + def test_tensor(self): + s = control_flow.for_stmt( + constant_op.constant([1, 2, 3, 4]), + extra_test=lambda s: True, + body=lambda i, s: (s + i,), + init_state=(0,)) + with self.test_session() as sess: + self.assertEqual((10,), sess.run(s)) + + def test_python(self): + s = control_flow.for_stmt( + range(5), + extra_test=lambda s: True, + body=lambda i, s: (s + i,), + init_state=(0,)) + self.assertEqual(10, s) + + def test_dataset(self): + to_int32 = lambda i: math_ops.cast(i, dtypes.int32) + s = control_flow.for_stmt( + dataset_ops.Dataset.range(5).map(to_int32), + extra_test=lambda s: True, + body=lambda i, s: (s + i,), + init_state=(0,)) + with self.test_session() as sess: + self.assertEqual((10,), sess.run(s)) + + +class WhileLoopTest(test.TestCase): + + def test_tensor(self): + n = constant_op.constant(5) + results = control_flow.while_stmt( + test=lambda i, s: i < n, + body=lambda i, s: (i + 1, s + i,), + init_state=(0, 0), + extra_deps=(n,)) + with self.test_session() as sess: + self.assertEqual((5, 10), sess.run(results)) + + def test_python(self): + n = 5 + results = control_flow.while_stmt( + test=lambda i, s: i < n, + body=lambda i, s: (i + 1, s + i), + init_state=(0, 0), + extra_deps=(n,)) + self.assertEqual((5, 10), results) + + +class IfStmtTest(test.TestCase): + + def test_tensor(self): + def test_if_stmt(cond): + return control_flow.if_stmt( + cond=cond, + body=lambda: 1, + orelse=lambda: -1) + with self.test_session() as sess: + self.assertEqual(1, sess.run(test_if_stmt(constant_op.constant(True)))) + self.assertEqual(-1, sess.run(test_if_stmt(constant_op.constant(False)))) + + def test_python(self): + self.assertEqual(1, control_flow.if_stmt(True, lambda: 1, lambda: -1)) + self.assertEqual(-1, control_flow.if_stmt(False, lambda: 1, lambda: -1)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/contrib/autograph/operators/data_structures.py new file mode 100644 index 0000000000000000000000000000000000000000..06d8727b0fcc30b532b3f11281cd1a83c51ac8bc --- /dev/null +++ b/tensorflow/contrib/autograph/operators/data_structures.py @@ -0,0 +1,267 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operators specific to data structures: list append, subscripts, etc.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import variables + + +# TODO(mdan): Once control flow supports objects, repackage as a class. + + +def new_list(iterable=None): + """The list constructor. + + Args: + iterable: Optional elements to fill the list with. + + Returns: + A list-like object. The exact return value depends on the initial elements. + """ + if iterable: + elements = tuple(iterable) + else: + elements = () + + # TODO(mdan): Extend these criteria. + if any(isinstance(el, variables.Variable) for el in elements): + return _py_list_new(elements) + return _tf_tensor_list_new(elements) + + +def _tf_tensor_list_new(elements): + """Overload of new_list that stages a Tensor list creation.""" + elements = tuple(ops.convert_to_tensor(el) for el in elements) + all_dtypes = set(el.dtype for el in elements) + if len(all_dtypes) == 1: + element_dtype = tuple(all_dtypes)[0] + else: + # Heterogeneous lists are ok. + element_dtype = dtypes.variant + + # TODO(mdan): This may fail for elements of variable shapes. + all_shapes = set(tuple(el.shape.as_list()) for el in elements) + if len(all_shapes) == 1: + element_shape = array_ops.shape(elements[0]) + else: + # Heterogeneous lists are ok. + element_shape = constant_op.constant(-1) # unknown shape, by convention + + l = list_ops.empty_tensor_list( + element_shape=element_shape, element_dtype=element_dtype) + for el in elements: + l = list_ops.tensor_list_push_back(l, el) + return l + + +def _py_list_new(elements): + """Overload of new_list that creates a Python list.""" + return list(elements) + + +def list_append(list_, x): + """The list append function. + + Note: it is unspecified where list_ will be mutated or not. If list_ is + a TensorFlow entity, it will not be typically mutated. If list_ is a plain + list, it will be. In general, if the list is mutated then the return value + should point to the original entity. + + Args: + list_: An entity that supports append semantics. + x: The element to append. + + Returns: + Same as list_, after the append was performed. + + Raises: + ValueError: if list_ is not of a known list-like type. + """ + if isinstance(list_, tensor_array_ops.TensorArray): + return _tf_tensorarray_append(list_, x) + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_append(list_, x) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % list_) + else: + return _py_list_append(list_, x) + + +def _tf_tensor_list_append(list_, x): + """Overload of list_append that stages a Tensor list write.""" + def empty_list_of_elements_like_x(): + tensor_x = ops.convert_to_tensor(x) + return list_ops.empty_tensor_list( + element_shape=array_ops.shape(tensor_x), + element_dtype=tensor_x.dtype) + + list_ = control_flow_ops.cond( + list_ops.tensor_list_length(list_) > 0, + lambda: list_, + empty_list_of_elements_like_x, + ) + return list_ops.tensor_list_push_back(list_, x) + + +def _tf_tensorarray_append(list_, x): + """Overload of list_append that stages a TensorArray write.""" + return list_.write(list_.size(), x) + + +def _py_list_append(list_, x): + """Overload of list_append that executes a Python list append.""" + # Revert to the original call. + list_.append(x) + return list_ + + +class ListPopOpts( + collections.namedtuple('ListPopOpts', ('element_dtype', 'element_shape'))): + pass + + +def list_pop(list_, i, opts): + """The list pop function. + + Note: it is unspecified where list_ will be mutated or not. If list_ is + a TensorFlow entity, it will not be typically mutated. If list_ is a plain + list, it will be. In general, if the list is mutated then the return value + should point to the original entity. + + Args: + list_: An entity that supports pop semantics. + i: Optional index to pop from. May be None. + opts: A ListPopOpts. + + Returns: + Tuple (x, out_list_): + out_list_: same as list_, after the removal was performed. + x: the removed element value. + + Raises: + ValueError: if list_ is not of a known list-like type or the operation is + not supported for that type. + """ + assert isinstance(opts, ListPopOpts) + + if isinstance(list_, tensor_array_ops.TensorArray): + raise ValueError('TensorArray does not support item removal') + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_pop(list_, i, opts) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % list_) + else: + return _py_list_pop(list_, i) + + +def _tf_tensor_list_pop(list_, i, opts): + """Overload of list_pop that stages a Tensor list pop.""" + if i is not None: + raise NotImplementedError('tensor lists only support removing from the end') + + if opts.element_dtype is None: + raise ValueError('cannot pop from a list without knowing its element ' + 'type; use set_element_type to annotate it') + if opts.element_shape is None: + raise ValueError('cannot pop from a list without knowing its element ' + 'shape; use set_element_type to annotate it') + list_out, x = list_ops.tensor_list_pop_back( + list_, element_dtype=opts.element_dtype) + x.set_shape(opts.element_shape) + return list_out, x + + +def _py_list_pop(list_, i): + """Overload of list_pop that executes a Python list append.""" + if i is None: + x = list_.pop() + else: + x = list_.pop(i) + return list_, x + + +# TODO(mdan): Look into reducing duplication between all these containers. +class ListStackOpts( + collections.namedtuple('ListStackOpts', + ('element_dtype', 'original_call'))): + pass + + +def list_stack(list_, opts): + """The list stack function. + + This does not have a direct correspondent in Python. The closest idiom to + this is tf.append or np.stack. It's different from those in the sense that it + accepts a Tensor list, rather than a list of tensors. It can also accept + TensorArray. When the target is anything else, the dispatcher will rely on + ctx.original_call for fallback. + + Args: + list_: An entity that supports append semantics. + opts: A ListStackOpts object. + + Returns: + The output of the stack operation, typically a Tensor. + """ + assert isinstance(opts, ListStackOpts) + + if isinstance(list_, tensor_array_ops.TensorArray): + return _tf_tensorarray_stack(list_) + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_stack(list_, opts) + else: + # No-op for primitive Tensor arguments. + return list_ + else: + return _py_list_stack(list_, opts) + + +def _tf_tensorarray_stack(list_): + """Overload of list_stack that stages a TensorArray stack.""" + return list_.stack() + + +def _tf_tensor_list_stack(list_, opts): + """Overload of list_stack that stages a Tensor list write.""" + if opts.element_dtype is None: + raise ValueError('cannot stack a list without knowing its element type;' + ' use set_element_type to annotate it') + return list_ops.tensor_list_stack(list_, element_dtype=opts.element_dtype) + + +def _py_list_stack(list_, opts): + """Overload of list_stack that executes a Python list append.""" + # Revert to the original call. + return opts.original_call(list_) diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/contrib/autograph/operators/data_structures_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8bbb52d6c10b241ec754c7dea599fa15a869595f --- /dev/null +++ b/tensorflow/contrib/autograph/operators/data_structures_test.py @@ -0,0 +1,117 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for data_structures module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import data_structures +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import test + + +class ListTest(test.TestCase): + + def test_new_list_empty(self): + l = data_structures.new_list() + # Can't evaluate an empty list. + # TODO(mdan): sess.run should allow tf.variant maybe? + self.assertTrue(isinstance(l, ops.Tensor)) + + def test_new_list_tensor(self): + l = data_structures.new_list([3, 4, 5]) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [3, 4, 5]) + + def test_append_tensor_list(self): + l = data_structures.new_list() + x = constant_op.constant([1, 2, 3]) + l = data_structures.list_append(l, x) + + t = list_ops.tensor_list_stack(l, element_dtype=x.dtype) + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [[1, 2, 3]]) + + def test_append_tensorarray(self): + l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True) + l1 = data_structures.list_append(l, 1) + l2 = data_structures.list_append(l1, 2) + with self.test_session() as sess: + self.assertAllEqual(sess.run(l1.stack()), [1]) + self.assertAllEqual(sess.run(l2.stack()), [1, 2]) + + def test_append_python(self): + l = [] + self.assertAllEqual(data_structures.list_append(l, 1), [1]) + self.assertAllEqual(data_structures.list_append(l, 2), [1, 2]) + + def test_pop_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + + opts = data_structures.ListPopOpts( + element_dtype=initial_list.dtype, + element_shape=(2,)) + + with self.assertRaises(NotImplementedError): + data_structures.list_pop(l, 0, opts) + + with self.test_session() as sess: + l, x = data_structures.list_pop(l, None, opts) + self.assertAllEqual(sess.run(x), [3, 4]) + + t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype) + self.assertAllEqual(sess.run(t), [[1, 2]]) + + def test_pop_python(self): + l = [1, 2, 3] + opts = data_structures.ListPopOpts(element_dtype=None, element_shape=()) + self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1, 2], 3)) + self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1], 2)) + + def test_stack_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + + opts = data_structures.ListStackOpts( + element_dtype=initial_list.dtype, original_call=None) + + with self.test_session() as sess: + t = data_structures.list_stack(l, opts) + self.assertAllEqual(sess.run(t), sess.run(initial_list)) + + def test_stack_fallback(self): + + def dummy_function(l): + # Lazy person's mock: just transform the argument in a way in which we + # can check that this function was indeed called. + return [x * 2 for x in l] + + opts = data_structures.ListStackOpts( + element_dtype=None, original_call=dummy_function) + + self.assertAllEqual(data_structures.list_stack([1, 2], opts), [2, 4]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/applications/densenet/__init__.py b/tensorflow/contrib/autograph/operators/dispatch_context.py similarity index 60% rename from tensorflow/python/keras/applications/densenet/__init__.py rename to tensorflow/contrib/autograph/operators/dispatch_context.py index 6b8ea83920733a3a442171616ab460ffaf831521..097002465bd140eb92ee65b9dcd4e3643a0357d2 100644 --- a/tensorflow/python/keras/applications/densenet/__init__.py +++ b/tensorflow/contrib/autograph/operators/dispatch_context.py @@ -12,18 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""DenseNet Keras applications.""" +"""Structures that allow uniform control over the dispatch process.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.densenet import decode_predictions -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet121 -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet169 -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet201 -from tensorflow.python.keras._impl.keras.applications.densenet import preprocess_input +import collections -del absolute_import -del division -del print_function + +# TODO(mdan): This is where macro override controls fit. + + +class DispatchContext(collections.namedtuple( + 'DispatchContext', + ('options',))): + """Allows passing additional parameters to the specific implementations. + + Attributes: + options: Optional dict of extra arguments that may be required by specific + implementations. + """ + + def option(self, name): + return self.options[name] + + +NO_CTX = DispatchContext(options={}) diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py new file mode 100644 index 0000000000000000000000000000000000000000..04fbeb2f6e39234cad139442704fd7a8d0f56172 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/slices.py @@ -0,0 +1,133 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operators specific to slicing operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops + + +# TODO(mdan): Support extended slices. + + +class GetItemOpts(collections.namedtuple('GetItemOpts', ('element_dtype',))): + pass + + +def get_item(target, i, opts): + """The slice read operator (i.e. __getitem__). + + Note: it is unspecified whether target will be mutated or not. In general, + if target is mutable (like Python lists), it will be mutated. + + Args: + target: An entity that supports getitem semantics. + i: Index to read from. + opts: A GetItemOpts object. + + Returns: + The read element. + + Raises: + ValueError: if target is not of a supported type. + """ + assert isinstance(opts, GetItemOpts) + + if isinstance(target, tensor_array_ops.TensorArray): + return _tf_tensorarray_get_item(target, i) + elif tensor_util.is_tensor(target): + if target.dtype == dtypes.variant: + return _tf_tensor_list_get_item(target, i, opts) + else: + return _tf_tensor_get_item(target, i) + else: + return _py_get_item(target, i) + + +def _tf_tensorarray_get_item(target, i): + """Overload of get_item that stages a TensorArray read.""" + return target.read(i) + + +def _tf_tensor_list_get_item(target, i, opts): + """Overload of get_item that stages a Tensor list read.""" + if opts.element_dtype is None: + raise ValueError('cannot retrieve from a list without knowing its ' + 'element type; use set_element_type to annotate it') + x = list_ops.tensor_list_get_item(target, i, element_dtype=opts.element_dtype) + return x + + +def _tf_tensor_get_item(target, i): + """Overload of get_item that stages a Tensor (not Tensor list) read.""" + return target[i] + + +def _py_get_item(target, i): + """Overload of get_item that executes a Python list modification.""" + return target[i] + + +def set_item(target, i, x): + """The slice write operator (i.e. __setitem__). + + Note: it is unspecified whether target will be mutated or not. In general, + if target is mutable (like Python lists), it will be mutated. + + Args: + target: An entity that supports setitem semantics. + i: Index to modify. + x: The new element value. + + Returns: + Same as target, after the update was performed. + + Raises: + ValueError: if target is not of a supported type. + """ + if isinstance(target, tensor_array_ops.TensorArray): + return _tf_tensorarray_set_item(target, i, x) + elif tensor_util.is_tensor(target): + if target.dtype == dtypes.variant: + return _tf_tensor_list_set_item(target, i, x) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % target) + else: + return _py_set_item(target, i, x) + + +def _tf_tensorarray_set_item(target, i, x): + """Overload of set_item that stages a TensorArray write.""" + return target.write(i, x) + + +def _tf_tensor_list_set_item(target, i, x): + """Overload of set_item that stages a Tensor list update.""" + return list_ops.tensor_list_set_item(target, i, x) + + +def _py_set_item(target, i, x): + """Overload of set_item that executes a Python list modification.""" + target[i] = x + return target diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d4aacb9d2015fec56a8df5ad85a20b733765ba26 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/slices_test.py @@ -0,0 +1,51 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for slices module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import slices +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SlicesTest(test.TestCase): + + def test_set_item_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + l = slices.set_item(l, 0, [5, 6]) + + with self.test_session() as sess: + t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype) + self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]]) + + def test_get_item_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + t = slices.get_item( + l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype)) + + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [3, 4]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD similarity index 80% rename from tensorflow/contrib/py2tf/pyct/BUILD rename to tensorflow/contrib/autograph/pyct/BUILD index e3c0da4b10f9ffbee1b2a906b64d4762f41d97b4..989b821e53a5cefbe39095e669f9a9e0bec65b8a 100644 --- a/tensorflow/contrib/py2tf/pyct/BUILD +++ b/tensorflow/contrib/autograph/pyct/BUILD @@ -24,6 +24,7 @@ py_library( "ast_util.py", "compiler.py", "context.py", + "inspect_utils.py", "parser.py", "pretty_printer.py", "qual_names.py", @@ -65,6 +66,18 @@ py_test( name = "compiler_test", srcs = ["compiler_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + +py_test( + name = "inspect_utils_test", + srcs = ["inspect_utils_test.py"], + srcs_version = "PY2AND3", deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -112,3 +125,15 @@ py_test( "@gast_archive//:gast", ], ) + +py_test( + name = "transformer_test", + srcs = ["transformer_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) diff --git a/tensorflow/contrib/py2tf/pyct/__init__.py b/tensorflow/contrib/autograph/pyct/__init__.py similarity index 100% rename from tensorflow/contrib/py2tf/pyct/__init__.py rename to tensorflow/contrib/autograph/pyct/__init__.py diff --git a/tensorflow/contrib/py2tf/pyct/anno.py b/tensorflow/contrib/autograph/pyct/anno.py similarity index 81% rename from tensorflow/contrib/py2tf/pyct/anno.py rename to tensorflow/contrib/autograph/pyct/anno.py index 7a0528b6d0b65b6604930b7a13d8493af9d61f02..ae861627fd65cca057e7bf1af41424e605d4b7a1 100644 --- a/tensorflow/contrib/py2tf/pyct/anno.py +++ b/tensorflow/contrib/autograph/pyct/anno.py @@ -46,8 +46,15 @@ class Basic(NoValue): '`name_map` allows renaming symbols.') -def getanno(node, key, field_name='___pyct_anno'): - return getattr(node, field_name)[key] +FAIL = object() + + +def getanno(node, key, default=FAIL, field_name='___pyct_anno'): + if (default is FAIL or + (hasattr(node, field_name) and (key in getattr(node, field_name)))): + return getattr(node, field_name)[key] + else: + return default def hasanno(node, key, field_name='___pyct_anno'): @@ -70,3 +77,12 @@ def delanno(node, key, field_name='___pyct_anno'): if not annotations: delattr(node, field_name) node._fields = tuple(f for f in node._fields if f != field_name) + + +def copyanno(from_node, to_node, key, field_name='___pyct_anno'): + if hasanno(from_node, key, field_name=field_name): + setanno( + to_node, + key, + getanno(from_node, key, field_name=field_name), + field_name=field_name) diff --git a/tensorflow/contrib/py2tf/pyct/anno_test.py b/tensorflow/contrib/autograph/pyct/anno_test.py similarity index 70% rename from tensorflow/contrib/py2tf/pyct/anno_test.py rename to tensorflow/contrib/autograph/pyct/anno_test.py index ff40bfe1f50ae731648afdf509c26c3a70d3f6cb..f2c0c8cf05ca4b3671eb653ce56f6da61de54aee 100644 --- a/tensorflow/contrib/py2tf/pyct/anno_test.py +++ b/tensorflow/contrib/autograph/pyct/anno_test.py @@ -20,10 +20,13 @@ from __future__ import print_function import ast -from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.autograph.pyct import anno from tensorflow.python.platform import test +# TODO(mdan): Consider strong types instead of primitives. + + class AnnoTest(test.TestCase): def test_basic(self): @@ -35,12 +38,25 @@ class AnnoTest(test.TestCase): anno.setanno(node, 'foo', 3) self.assertTrue(anno.hasanno(node, 'foo')) - self.assertEqual(3, anno.getanno(node, 'foo')) + self.assertEqual(anno.getanno(node, 'foo'), 3) + self.assertEqual(anno.getanno(node, 'bar', default=7), 7) anno.delanno(node, 'foo') self.assertFalse(anno.hasanno(node, 'foo')) with self.assertRaises(AttributeError): anno.getanno(node, 'foo') + self.assertIsNone(anno.getanno(node, 'foo', default=None)) + + def test_copyanno(self): + node_1 = ast.Name() + anno.setanno(node_1, 'foo', 3) + + node_2 = ast.Name() + anno.copyanno(node_1, node_2, 'foo') + anno.copyanno(node_1, node_2, 'bar') + + self.assertTrue(anno.hasanno(node_2, 'foo')) + self.assertFalse(anno.hasanno(node_2, 'bar')) if __name__ == '__main__': diff --git a/tensorflow/contrib/py2tf/pyct/ast_util.py b/tensorflow/contrib/autograph/pyct/ast_util.py similarity index 51% rename from tensorflow/contrib/py2tf/pyct/ast_util.py rename to tensorflow/contrib/autograph/pyct/ast_util.py index f916775b9cf3cec960ec2896c334f1d737862205..c4f82d11708393a6029d3f17be428b47eb9342ff 100644 --- a/tensorflow/contrib/py2tf/pyct/ast_util.py +++ b/tensorflow/contrib/autograph/pyct/ast_util.py @@ -22,13 +22,14 @@ import ast import gast -from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser class CleanCopier(gast.NodeVisitor): - """Copy AST nodes. + """Copies AST nodes. - The copied nodes will ignore almost all fields that prefixed by '__'. + The copied nodes will ignore almost all fields that are prefixed by '__'. Exceptions make some annotations. """ @@ -84,7 +85,10 @@ class SymbolRenamer(gast.NodeTransformer): return self._process(node) def visit_Attribute(self, node): - return self._process(node) + if anno.hasanno(node, anno.Basic.QN): + return self._process(node) + # Attributes of dynamic objects will not have a QN. + return self.generic_visit(node) def rename_symbols(node, name_map): @@ -94,3 +98,88 @@ def rename_symbols(node, name_map): elif isinstance(node, tuple): return tuple(renamer.visit(n) for n in node) return renamer.visit(node) + + +def keywords_to_dict(keywords): + keys = [] + values = [] + for kw in keywords: + keys.append(gast.Str(kw.arg)) + values.append(kw.value) + return gast.Dict(keys=keys, values=values) + + +class PatternMatcher(gast.NodeVisitor): + """Matches a node against a pattern represented by a node. + + The pattern may contain wildcards represented by the symbol '_'. + """ + + def __init__(self, pattern): + self.pattern = pattern + self.pattern_stack = [] + self.matches = True + + def compare_and_visit(self, node, pattern): + self.pattern_stack.append(self.pattern) + self.pattern = pattern + self.generic_visit(node) + self.pattern = self.pattern_stack.pop() + + def no_match(self): + self.matches = False + return False + + def is_wildcard(self, p): + if isinstance(p, (list, tuple)) and len(p) == 1: + p, = p + if isinstance(p, gast.Name) and p.id == '_': + return True + if p == '_': + return True + return False + + def generic_visit(self, node): + if not self.matches: + return + + pattern = self.pattern + for f in node._fields: + if f.startswith('__'): + continue + + if not hasattr(node, f): + if hasattr(pattern, f) and getattr(pattern, f): + return self.no_match() + else: + continue + if not hasattr(pattern, f): + return self.no_match() + + v = getattr(node, f) + p = getattr(pattern, f) + + if self.is_wildcard(p): + continue + if isinstance(v, (list, tuple)): + if not isinstance(p, (list, tuple)) or len(v) != len(p): + return self.no_match() + for v_item, p_item in zip(v, p): + self.compare_and_visit(v_item, p_item) + elif isinstance(v, (gast.AST, ast.AST)): + if not isinstance(v, type(p)) and not isinstance(p, type(v)): + return self.no_match() + self.compare_and_visit(v, p) + else: + # Assume everything else is a value type. + if v != p: + return self.no_match() + + +def matches(node, pattern): + if isinstance(pattern, str): + pattern = parser.parse_expression(pattern) + matcher = PatternMatcher(pattern) + matcher.visit(node) + return matcher.matches + diff --git a/tensorflow/contrib/py2tf/pyct/ast_util_test.py b/tensorflow/contrib/autograph/pyct/ast_util_test.py similarity index 53% rename from tensorflow/contrib/py2tf/pyct/ast_util_test.py rename to tensorflow/contrib/autograph/pyct/ast_util_test.py index e0b00c178168f96e656c57cc75a76e6da8af1d8a..3afa04a50685d19c90944c14ed39f9d3ad35e486 100644 --- a/tensorflow/contrib/py2tf/pyct/ast_util_test.py +++ b/tensorflow/contrib/autograph/pyct/ast_util_test.py @@ -20,8 +20,10 @@ from __future__ import print_function import ast -from tensorflow.contrib.py2tf.pyct import ast_util -from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.python.platform import test @@ -33,15 +35,15 @@ class AstUtilTest(test.TestCase): ast.Name('b', ast.Load()), ast.Attribute(ast.Name('b', None), 'c', ast.Store()), ast.Attribute( - ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', - None) + ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', None) ], None) node = qual_names.resolve(node) node = ast_util.rename_symbols( - node, - { - qual_names.QN('a'): qual_names.QN('renamed_a'), - qual_names.QN('b.c'): qual_names.QN('renamed_b_c'), + node, { + qual_names.QN('a'): + qual_names.QN('renamed_a'), + qual_names.QN(qual_names.QN('b'), attr='c'): + qual_names.QN('renamed_b_c'), }) self.assertEqual(node.elts[0].id, 'renamed_a') @@ -74,6 +76,43 @@ class AstUtilTest(test.TestCase): self.assertFalse(ret is new_node.body[0]) self.assertFalse(hasattr(new_node.body[0], '__foo')) + def test_keywords_to_dict(self): + keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords + d = ast_util.keywords_to_dict(keywords) + # Make sure we generate a usable dict node by attaching it to a variable and + # compiling everything. + output = parser.parse_str('b = 3') + output.body += (ast.Assign([ast.Name(id='d', ctx=ast.Store())], d),) + result, _ = compiler.ast_to_object(output) + self.assertDictEqual(result.d, {'a': 3, 'c': 1, 'd': 'e'}) + + def assertMatch(self, target_str, pattern_str): + node = parser.parse_expression(target_str) + pattern = parser.parse_expression(pattern_str) + self.assertTrue(ast_util.matches(node, pattern)) + + def assertNoMatch(self, target_str, pattern_str): + node = parser.parse_expression(target_str) + pattern = parser.parse_expression(pattern_str) + self.assertFalse(ast_util.matches(node, pattern)) + + def test_matches_symbols(self): + self.assertMatch('foo', '_') + self.assertNoMatch('foo()', '_') + self.assertMatch('foo + bar', 'foo + _') + self.assertNoMatch('bar + bar', 'foo + _') + self.assertNoMatch('foo - bar', 'foo + _') + + def test_matches_function_args(self): + self.assertMatch('super(Foo, self).__init__(arg1, arg2)', + 'super(_).__init__(_)') + self.assertMatch('super().__init__()', 'super(_).__init__(_)') + self.assertNoMatch('super(Foo, self).bar(arg1, arg2)', + 'super(_).__init__(_)') + self.assertMatch('super(Foo, self).__init__()', 'super(Foo, _).__init__(_)') + self.assertNoMatch('super(Foo, self).__init__()', + 'super(Bar, _).__init__(_)') + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/compiler.py b/tensorflow/contrib/autograph/pyct/compiler.py similarity index 90% rename from tensorflow/contrib/py2tf/pyct/compiler.py rename to tensorflow/contrib/autograph/pyct/compiler.py index 51cf6930e8bcb3728ee55bf5d4781f01a5ef73bd..24c4517afa89147101f80af3ef60237132c1144c 100644 --- a/tensorflow/contrib/py2tf/pyct/compiler.py +++ b/tensorflow/contrib/autograph/pyct/compiler.py @@ -31,7 +31,7 @@ import astor import gast -def ast_to_source(node, indentation): +def ast_to_source(node, indentation=' '): """Return the source code of given AST.""" if isinstance(node, gast.AST): node = gast.gast_to_ast(node) @@ -39,7 +39,10 @@ def ast_to_source(node, indentation): astor.string_repr.pretty_string) generator.visit(node) generator.result.append('\n') - return astor.source_repr.pretty_source(generator.result).lstrip() + # In some versions of Python, literals may appear as actual values. This + # ensures everything is string. + code = map(str, generator.result) + return astor.source_repr.pretty_source(code).lstrip() def ast_to_object( diff --git a/tensorflow/contrib/py2tf/pyct/compiler_test.py b/tensorflow/contrib/autograph/pyct/compiler_test.py similarity index 83% rename from tensorflow/contrib/py2tf/pyct/compiler_test.py rename to tensorflow/contrib/autograph/pyct/compiler_test.py index c1f84238efa7dd6fc0748748a2cb4f074572b4c6..98cdc1506b6aced603df99662f1468687a55f92c 100644 --- a/tensorflow/contrib/py2tf/pyct/compiler_test.py +++ b/tensorflow/contrib/autograph/pyct/compiler_test.py @@ -22,12 +22,29 @@ import textwrap import gast -from tensorflow.contrib.py2tf.pyct import compiler +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import parser from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect class CompilerTest(test.TestCase): + def test_parser_compile_idempotent(self): + + def test_fn(x): + a = True + b = '' + if a: + b = x + 1 + return b + + self.assertEqual( + textwrap.dedent(tf_inspect.getsource(test_fn)), + tf_inspect.getsource( + compiler.ast_to_object( + parser.parse_entity(test_fn)[0].body[0])[0].test_fn)) + def test_ast_to_source(self): node = gast.If( test=gast.Num(1), diff --git a/tensorflow/contrib/py2tf/pyct/context.py b/tensorflow/contrib/autograph/pyct/context.py similarity index 82% rename from tensorflow/contrib/py2tf/pyct/context.py rename to tensorflow/contrib/autograph/pyct/context.py index fef74ebefa290369c7310af6d7e4faeef44d9aee..b34015cfd2888f0dbeb6492b9e7335d561bf4763 100644 --- a/tensorflow/contrib/py2tf/pyct/context.py +++ b/tensorflow/contrib/autograph/pyct/context.py @@ -22,6 +22,8 @@ from __future__ import print_function class EntityContext(object): """Contains information about an entity, like source code. + In general, objects of this class should be considered immutable. + Attributes: namer: Namer that matches the contract of all converters. source_code: The entity's source code. @@ -30,14 +32,18 @@ class EntityContext(object): (excluding parameters). arg_values: Dict[str->*], containing parameter values, if known. arg_types: Dict[str->*], containing parameter types, if known. + owner_type: The surrounding class type of the function, if present. """ + # TODO(mdan): Remove the default and update tests. def __init__(self, namer, source_code, source_file, namespace, arg_values, - arg_types, recursive): + arg_types, owner_type, recursive, type_annotation_func=None): self.namer = namer self.source_code = source_code self.source_file = source_file self.namespace = namespace self.arg_values = {} if arg_values is None else arg_values self.arg_types = {} if arg_types is None else arg_types + self.owner_type = owner_type self.recursive = recursive + self.type_annotation_func = type_annotation_func diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils.py b/tensorflow/contrib/autograph/pyct/inspect_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eef74599a7d5415b4b05d2f05fb094b1dcd33323 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/inspect_utils.py @@ -0,0 +1,161 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Live entity inspection utilities. + +This module contains whatever inspect doesn't offer out of the box. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import types + +import six + +from tensorflow.python.util import tf_inspect + + +def isbuiltin(f): + # Note these return false for isinstance(f, types.BuiltinFunctionType) so we + # need to specifically check for them. + if f in (range, int, float): + return True + if isinstance(f, types.BuiltinFunctionType): + return True + if tf_inspect.isbuiltin(f): + return True + return False + + +def getnamespace(f): + """Returns the complete namespace of a function. + + Namespace is defined here as the mapping of all non-local variables to values. + This includes the globals and the closure variables. Note that this captures + the entire globals collection of the function, and may contain extra symbols + that it does not actually use. + + Args: + f: User defined function. + Returns: + A dict mapping symbol names to values. + """ + namespace = dict(six.get_function_globals(f)) + closure = six.get_function_closure(f) + freevars = six.get_function_code(f).co_freevars + if freevars and closure: + for name, cell in zip(freevars, closure): + namespace[name] = cell.cell_contents + return namespace + + +def _get_unbound_function(m): + # TODO(mdan): Figure out why six.get_unbound_function fails in some cases. + # The failure case is for tf.keras.Model. + if hasattr(m, 'im_func'): + return m.im_func + return m + + +def getdefiningclass(m, owner_class): + """Resolves the class (e.g. one of the superclasses) that defined a method.""" + # Normalize bound functions to their respective unbound versions. + m = _get_unbound_function(m) + for superclass in owner_class.__bases__: + if hasattr(superclass, m.__name__): + superclass_m = getattr(superclass, m.__name__) + if _get_unbound_function(superclass_m) is m: + return superclass + elif hasattr(m, '__self__') and m.__self__ == owner_class: + # Python 3 class methods only work this way it seems :S + return superclass + return owner_class + + +def getmethodclass(m): + """Resolves a function's owner, e.g. a method's class. + + Note that this returns the object that the function was retrieved from, not + necessarily the class where it was defined. + + This function relies on Python stack frame support in the interpreter, and + has the same limitations that inspect.currentframe. + + Limitations. This function will only work correctly if the owned class is + visible in the caller's global or local variables. + + Args: + m: A user defined function + + Returns: + The class that this function was retrieved from, or None if the function + is not an object or class method, or the class that owns the object or + method is not visible to m. + + Raises: + ValueError: if the class could not be resolved for any unexpected reason. + """ + + # Callable objects: return their own class. + if (not hasattr(m, '__name__') and hasattr(m, '__class__') and + hasattr(m, '__call__')): + if isinstance(m.__class__, six.class_types): + return m.__class__ + + # Instance method and class methods: should be bound to a non-null "self". + # If self is a class, then it's a class method. + if hasattr(m, '__self__'): + if m.__self__: + if tf_inspect.isclass(m.__self__): + return m.__self__ + return type(m.__self__) + + # Class, static and unbound methods: search all defined classes in any + # namespace. This is inefficient but more robust method. + owners = [] + caller_frame = tf_inspect.currentframe().f_back + try: + # TODO(mdan): This doesn't consider cell variables. + # TODO(mdan): This won't work if the owner is hidden inside a container. + # Cell variables may be pulled using co_freevars and the closure. + for v in itertools.chain(caller_frame.f_locals.values(), + caller_frame.f_globals.values()): + if hasattr(v, m.__name__): + candidate = getattr(v, m.__name__) + # Py2 methods may be bound or unbound, extract im_func to get the + # underlying function. + if hasattr(candidate, 'im_func'): + candidate = candidate.im_func + if hasattr(m, 'im_func'): + m = m.im_func + if candidate is m: + owners.append(v) + finally: + del caller_frame + + if owners: + if len(owners) == 1: + return owners[0] + + # If multiple owners are found, and are not subclasses, raise an error. + owner_types = tuple(o if tf_inspect.isclass(o) else type(o) for o in owners) + for o in owner_types: + if tf_inspect.isclass(o) and issubclass(o, tuple(owner_types)): + return o + raise ValueError('Found too many owners of %s: %s' % (m, owners)) + + return None diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1a212f676a616307b41feafafda9d1d794ba3d2d --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py @@ -0,0 +1,277 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for unspect_utils module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from functools import wraps + +import six + +from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.python.platform import test + + +def decorator(f): + return f + + +def function_decorator(): + def dec(f): + return f + return dec + + +def wrapping_decorator(): + def dec(f): + def replacement(*_): + return None + + @wraps(f) + def wrapper(*args, **kwargs): + return replacement(*args, **kwargs) + return wrapper + return dec + + +class TestClass(object): + + def member_function(self): + pass + + @decorator + def decorated_member(self): + pass + + @function_decorator() + def fn_decorated_member(self): + pass + + @wrapping_decorator() + def wrap_decorated_member(self): + pass + + @staticmethod + def static_method(): + pass + + @classmethod + def class_method(cls): + pass + + +def free_function(): + pass + + +def factory(): + return free_function + + +def free_factory(): + def local_function(): + pass + return local_function + + +class InspectUtilsTest(test.TestCase): + + def test_getnamespace_globals(self): + ns = inspect_utils.getnamespace(factory) + self.assertEqual(ns['free_function'], free_function) + + def test_getnamespace_hermetic(self): + + # Intentionally hiding the global function to make sure we don't overwrite + # it in the global namespace. + free_function = object() # pylint:disable=redefined-outer-name + + def test_fn(): + return free_function + + ns = inspect_utils.getnamespace(test_fn) + globs = six.get_function_globals(test_fn) + self.assertTrue(ns['free_function'] is free_function) + self.assertFalse(globs['free_function'] is free_function) + + def test_getnamespace_locals(self): + + def called_fn(): + return 0 + + closed_over_list = [] + closed_over_primitive = 1 + + def local_fn(): + closed_over_list.append(1) + local_var = 1 + return called_fn() + local_var + closed_over_primitive + + ns = inspect_utils.getnamespace(local_fn) + self.assertEqual(ns['called_fn'], called_fn) + self.assertEqual(ns['closed_over_list'], closed_over_list) + self.assertEqual(ns['closed_over_primitive'], closed_over_primitive) + self.assertTrue('local_var' not in ns) + + def test_getmethodclass(self): + + self.assertEqual( + inspect_utils.getmethodclass(free_function), None) + self.assertEqual( + inspect_utils.getmethodclass(free_factory()), None) + + self.assertEqual( + inspect_utils.getmethodclass(TestClass.member_function), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(TestClass.decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(TestClass.fn_decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(TestClass.wrap_decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(TestClass.static_method), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(TestClass.class_method), + TestClass) + + test_obj = TestClass() + self.assertEqual( + inspect_utils.getmethodclass(test_obj.member_function), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.fn_decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.wrap_decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.static_method), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.class_method), + TestClass) + + def test_getmethodclass_locals(self): + + def local_function(): + pass + + class LocalClass(object): + + def member_function(self): + pass + + @decorator + def decorated_member(self): + pass + + @function_decorator() + def fn_decorated_member(self): + pass + + @wrapping_decorator() + def wrap_decorated_member(self): + pass + + self.assertEqual( + inspect_utils.getmethodclass(local_function), None) + + self.assertEqual( + inspect_utils.getmethodclass(LocalClass.member_function), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(LocalClass.decorated_member), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(LocalClass.fn_decorated_member), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(LocalClass.wrap_decorated_member), + LocalClass) + + test_obj = LocalClass() + self.assertEqual( + inspect_utils.getmethodclass(test_obj.member_function), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.decorated_member), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.fn_decorated_member), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.wrap_decorated_member), + LocalClass) + + def test_getmethodclass_callables(self): + class TestCallable(object): + + def __call__(self): + pass + + c = TestCallable() + self.assertEqual(inspect_utils.getmethodclass(c), TestCallable) + + def test_getdefiningclass(self): + class Superclass(object): + + def foo(self): + pass + + def bar(self): + pass + + @classmethod + def class_method(cls): + pass + + class Subclass(Superclass): + + def foo(self): + pass + + def baz(self): + pass + + self.assertTrue( + inspect_utils.getdefiningclass(Subclass.foo, Subclass) is Subclass) + self.assertTrue( + inspect_utils.getdefiningclass(Subclass.bar, Subclass) is Superclass) + self.assertTrue( + inspect_utils.getdefiningclass(Subclass.baz, Subclass) is Subclass) + self.assertTrue( + inspect_utils.getdefiningclass(Subclass.class_method, Subclass) is + Superclass) + + def test_isbuiltin(self): + self.assertTrue(inspect_utils.isbuiltin(range)) + self.assertTrue(inspect_utils.isbuiltin(float)) + self.assertTrue(inspect_utils.isbuiltin(int)) + self.assertTrue(inspect_utils.isbuiltin(len)) + self.assertFalse(inspect_utils.isbuiltin(function_decorator)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/pyct/parser.py b/tensorflow/contrib/autograph/pyct/parser.py similarity index 64% rename from tensorflow/contrib/py2tf/pyct/parser.py rename to tensorflow/contrib/autograph/pyct/parser.py index dc7df883b349becd860bb0dbceab22cb39c750b5..c961efa892df6a21804dae8f52ef64bf99cd409e 100644 --- a/tensorflow/contrib/py2tf/pyct/parser.py +++ b/tensorflow/contrib/autograph/pyct/parser.py @@ -29,12 +29,30 @@ from tensorflow.python.util import tf_inspect def parse_entity(entity): - """Return the AST of given entity.""" + """Returns the AST of given entity.""" source = tf_inspect.getsource(entity) source = textwrap.dedent(source) return parse_str(source), source def parse_str(src): - """Return the AST of given piece of code.""" + """Returns the AST of given piece of code.""" return gast.parse(src) + + +def parse_expression(src): + """Returns the AST of given identifier. + + Args: + src: A piece of code that represents a single Python expression + Returns: + A gast.AST object. + Raises: + ValueError: if src does not consist of a single Expression. + """ + node = parse_str(src) + assert isinstance(node, gast.Module) + if len(node.body) != 1 and not isinstance(node.body[0], gast.Expr): + raise ValueError( + 'Expected a single expression, found instead %s' % node.body) + return node.body[0].value diff --git a/tensorflow/contrib/py2tf/pyct/parser_test.py b/tensorflow/contrib/autograph/pyct/parser_test.py similarity index 80% rename from tensorflow/contrib/py2tf/pyct/parser_test.py rename to tensorflow/contrib/autograph/pyct/parser_test.py index f35dfa04c70dc191078248c32f9a04d28133129a..007a4c6fb0393b7235808478d55b3ffa469f85d0 100644 --- a/tensorflow/contrib/py2tf/pyct/parser_test.py +++ b/tensorflow/contrib/autograph/pyct/parser_test.py @@ -20,28 +20,33 @@ from __future__ import print_function import textwrap -from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.autograph.pyct import parser from tensorflow.python.platform import test -def f(x): - return x + 1 - - class ParserTest(test.TestCase): def test_parse_entity(self): + + def f(x): + return x + 1 + mod, _ = parser.parse_entity(f) self.assertEqual('f', mod.body[0].name) def test_parse_str(self): mod = parser.parse_str( textwrap.dedent(""" - def f(x): - return x + 1 + def f(x): + return x + 1 """)) self.assertEqual('f', mod.body[0].name) + def test_parse_expression(self): + node = parser.parse_expression('a.b') + self.assertEqual('a', node.value.id) + self.assertEqual('b', node.attr) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/pretty_printer.py b/tensorflow/contrib/autograph/pyct/pretty_printer.py similarity index 100% rename from tensorflow/contrib/py2tf/pyct/pretty_printer.py rename to tensorflow/contrib/autograph/pyct/pretty_printer.py diff --git a/tensorflow/contrib/py2tf/pyct/pretty_printer_test.py b/tensorflow/contrib/autograph/pyct/pretty_printer_test.py similarity index 96% rename from tensorflow/contrib/py2tf/pyct/pretty_printer_test.py rename to tensorflow/contrib/autograph/pyct/pretty_printer_test.py index 81e3f47b80b6cb3bb7ba9f4a1787d03df4151a99..0cb48f35760b7b2655eb5cf73017b70e28dae219 100644 --- a/tensorflow/contrib/py2tf/pyct/pretty_printer_test.py +++ b/tensorflow/contrib/autograph/pyct/pretty_printer_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import ast -from tensorflow.contrib.py2tf.pyct import pretty_printer +from tensorflow.contrib.autograph.pyct import pretty_printer from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/contrib/autograph/pyct/qual_names.py new file mode 100644 index 0000000000000000000000000000000000000000..da07013cf4f4309c0e24adda3017575d942861b7 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/qual_names.py @@ -0,0 +1,233 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for manipulating qualified names. + +A qualified name is a uniform way to refer to simple (e.g. 'foo') and composite +(e.g. 'foo.bar') syntactic symbols. + +This is *not* related to the __qualname__ attribute used by inspect, which +refers to scopes. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import gast + +from tensorflow.contrib.autograph.pyct import anno + + +class Symbol(collections.namedtuple('Symbol', ['name'])): + """Represents a Python symbol.""" + + +class StringLiteral(collections.namedtuple('StringLiteral', ['value'])): + """Represents a Python string literal.""" + + def __str__(self): + return '\'%s\'' % self.value + + def __repr__(self): + return str(self) + + +class NumberLiteral(collections.namedtuple('NumberLiteral', ['value'])): + """Represents a Python numeric literal.""" + + def __str__(self): + return '%s' % self.value + + def __repr__(self): + return str(self) + + +# TODO(mdan): Use subclasses to remove the has_attr has_subscript booleans. +class QN(object): + """Represents a qualified name.""" + + def __init__(self, base, attr=None, subscript=None): + if attr is not None and subscript is not None: + raise ValueError('A QN can only be either an attr or a subscript, not ' + 'both: attr={}, subscript={}.'.format(attr, subscript)) + self._has_attr = False + self._has_subscript = False + + if attr is not None: + if not isinstance(base, QN): + raise ValueError( + 'for attribute QNs, base must be a QN; got instead "%s"' % base) + if not isinstance(attr, str): + raise ValueError('attr may only be a string; got instead "%s"' % attr) + self._parent = base + # TODO(mdan): Get rid of the tuple - it can only have 1 or 2 elements now. + self.qn = (base, attr) + self._has_attr = True + + elif subscript is not None: + if not isinstance(base, QN): + raise ValueError('For subscript QNs, base must be a QN.') + self._parent = base + self.qn = (base, subscript) + self._has_subscript = True + + else: + if not isinstance(base, (str, StringLiteral, NumberLiteral)): + # TODO(mdan): Require Symbol instead of string. + raise ValueError( + 'For simple QNs, base must be a string or a Literal object.') + assert '.' not in base and '[' not in base and ']' not in base + self._parent = None + self.qn = (base,) + + def is_symbol(self): + return isinstance(self.qn[0], str) + + def is_composite(self): + return len(self.qn) > 1 + + def has_subscript(self): + return self._has_subscript + + def has_attr(self): + return self._has_attr + + @property + def parent(self): + if self._parent is None: + raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0]) + return self._parent + + @property + def support_set(self): + """Returns the set of simple symbols that this QN relies on. + + This would be the smallest set of symbols necessary for the QN to + statically resolve (assuming properties and index ranges are verified + at runtime). + + Examples: + 'a.b' has only one support symbol, 'a' + 'a[i]' has two roots, 'a' and 'i' + """ + # TODO(mdan): This might be the set of Name nodes in the AST. Track those? + roots = set() + if self.has_attr(): + roots.update(self.parent.support_set) + elif self.has_subscript(): + roots.update(self.parent.support_set) + roots.update(self.qn[1].support_set) + else: + roots.add(self) + return roots + + def __hash__(self): + return hash(self.qn + (self._has_attr, self._has_subscript)) + + def __eq__(self, other): + return (isinstance(other, QN) and self.qn == other.qn and + self.has_subscript() == other.has_subscript() and + self.has_attr() == other.has_attr()) + + def __str__(self): + if self.has_subscript(): + return str(self.qn[0]) + '[' + str(self.qn[1]) + ']' + if self.has_attr(): + return '.'.join(map(str, self.qn)) + else: + return str(self.qn[0]) + + def __repr__(self): + return str(self) + + def ssf(self): + """Simple symbol form.""" + ssfs = [n.ssf() if isinstance(n, QN) else n for n in self.qn] + ssf_string = '' + for i in range(0, len(self.qn) - 1): + if self.has_subscript(): + delimiter = '_sub_' + else: + delimiter = '_' + ssf_string += ssfs[i] + delimiter + return ssf_string + ssfs[-1] + + def ast(self): + # The caller must adjust the context appropriately. + if self.has_subscript(): + return gast.Subscript(self.parent.ast(), gast.Index(self.qn[-1].ast()), + None) + if self.has_attr(): + return gast.Attribute(self.parent.ast(), self.qn[-1], None) + + base = self.qn[0] + if isinstance(base, str): + return gast.Name(base, None, None) + elif isinstance(base, StringLiteral): + return gast.Str(base.value) + elif isinstance(base, NumberLiteral): + return gast.Num(base.value) + else: + assert False, ('the constructor should prevent types other than ' + 'str, StringLiteral and NumberLiteral') + + +class QnResolver(gast.NodeTransformer): + """Annotates nodes with QN information. + + Note: Not using NodeAnnos to avoid circular dependencies. + """ + + def visit_Name(self, node): + node = self.generic_visit(node) + anno.setanno(node, anno.Basic.QN, QN(node.id)) + return node + + def visit_Attribute(self, node): + node = self.generic_visit(node) + if anno.hasanno(node.value, anno.Basic.QN): + anno.setanno(node, anno.Basic.QN, + QN(anno.getanno(node.value, anno.Basic.QN), attr=node.attr)) + return node + + def visit_Subscript(self, node): + # TODO(mdan): This may no longer apply if we overload getitem. + node = self.generic_visit(node) + s = node.slice + if not isinstance(s, gast.Index): + # TODO(mdan): Support range and multi-dimensional indices. + # Continuing silently because some demos use these. + return node + if isinstance(s.value, gast.Num): + subscript = QN(NumberLiteral(s.value.n)) + elif isinstance(s.value, gast.Str): + subscript = QN(StringLiteral(s.value.s)) + else: + # The index may be an expression, case in which a name doesn't make sense. + if anno.hasanno(node.slice.value, anno.Basic.QN): + subscript = anno.getanno(node.slice.value, anno.Basic.QN) + else: + return node + if anno.hasanno(node.value, anno.Basic.QN): + anno.setanno(node, anno.Basic.QN, + QN(anno.getanno(node.value, anno.Basic.QN), + subscript=subscript)) + return node + + +def resolve(node): + return QnResolver().visit(node) diff --git a/tensorflow/contrib/autograph/pyct/qual_names_test.py b/tensorflow/contrib/autograph/pyct/qual_names_test.py new file mode 100644 index 0000000000000000000000000000000000000000..264afd508cdb847315c486806b531dc1483ef622 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/qual_names_test.py @@ -0,0 +1,246 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for qual_names module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import textwrap + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.qual_names import QN +from tensorflow.contrib.autograph.pyct.qual_names import resolve +from tensorflow.python.platform import test + + +class QNTest(test.TestCase): + + def test_basic(self): + a = QN('a') + self.assertEqual(a.qn, ('a',)) + self.assertEqual(str(a), 'a') + self.assertEqual(a.ssf(), 'a') + self.assertEqual(a.ast().id, 'a') + self.assertFalse(a.is_composite()) + with self.assertRaises(ValueError): + _ = a.parent + + a_b = QN(a, attr='b') + self.assertEqual(a_b.qn, (a, 'b')) + self.assertEqual(str(a_b), 'a.b') + self.assertEqual(a_b.ssf(), 'a_b') + self.assertEqual(a_b.ast().value.id, 'a') + self.assertEqual(a_b.ast().attr, 'b') + self.assertTrue(a_b.is_composite()) + self.assertEqual(a_b.parent.qn, ('a',)) + + def test_subscripts(self): + a = QN('a') + b = QN('b') + a_sub_b = QN(a, subscript=b) + self.assertEqual(a_sub_b.qn, (a, b)) + self.assertEqual(str(a_sub_b), 'a[b]') + self.assertEqual(a_sub_b.ssf(), 'a_sub_b') + self.assertEqual(a_sub_b.ast().value.id, 'a') + self.assertEqual(a_sub_b.ast().slice.value.id, 'b') + self.assertTrue(a_sub_b.is_composite()) + self.assertTrue(a_sub_b.has_subscript()) + self.assertEqual(a_sub_b.parent.qn, ('a',)) + + c = QN('c') + b_sub_c = QN(b, subscript=c) + a_sub_b_sub_c = QN(a, subscript=b_sub_c) + self.assertEqual(a_sub_b_sub_c.qn, (a, b_sub_c)) + self.assertTrue(a_sub_b.is_composite()) + self.assertTrue(a_sub_b_sub_c.is_composite()) + self.assertTrue(a_sub_b.has_subscript()) + self.assertTrue(a_sub_b_sub_c.has_subscript()) + self.assertEqual(b_sub_c.qn, (b, c)) + self.assertEqual(str(a_sub_b_sub_c), 'a[b[c]]') + self.assertEqual(a_sub_b_sub_c.ssf(), 'a_sub_b_sub_c') + self.assertEqual(a_sub_b_sub_c.ast().value.id, 'a') + self.assertEqual(a_sub_b_sub_c.ast().slice.value.value.id, 'b') + self.assertEqual(a_sub_b_sub_c.ast().slice.value.slice.value.id, 'c') + self.assertEqual(b_sub_c.ast().slice.value.id, 'c') + self.assertEqual(a_sub_b_sub_c.parent.qn, ('a',)) + with self.assertRaises(ValueError): + QN('a', 'b') + + def test_equality(self): + a = QN('a') + a2 = QN('a') + a_b = QN(a, attr='b') + self.assertEqual(a2.qn, ('a',)) + with self.assertRaises(ValueError): + _ = a.parent + + a_b2 = QN(a, attr='b') + self.assertEqual(a_b2.qn, (a, 'b')) + self.assertEqual(a_b2.parent.qn, ('a',)) + + self.assertTrue(a2 == a) + self.assertFalse(a2 is a) + + self.assertTrue(a_b.parent == a) + self.assertTrue(a_b2.parent == a) + + self.assertTrue(a_b2 == a_b) + self.assertFalse(a_b2 is a_b) + self.assertFalse(a_b2 == a) + a_sub_b = QN(a, subscript='b') + a_sub_b2 = QN(a, subscript='b') + self.assertTrue(a_sub_b == a_sub_b2) + self.assertFalse(a_sub_b == a_b) + + def test_nested_attrs_subscripts(self): + a = QN('a') + b = QN('b') + c = QN('c') + b_sub_c = QN(b, subscript=c) + a_sub_b_sub_c = QN(a, subscript=b_sub_c) + + b_dot_c = QN(b, attr='c') + a_sub__b_dot_c = QN(a, subscript=b_dot_c) + + a_sub_b = QN(a, subscript=b) + a_sub_b__dot_c = QN(a_sub_b, attr='c') + + a_dot_b = QN(a, attr='b') + a_dot_b_sub_c = QN(a_dot_b, subscript=c) + + self.assertEqual(str(a_sub_b_sub_c), 'a[b[c]]') + self.assertEqual(str(a_sub__b_dot_c), 'a[b.c]') + self.assertEqual(str(a_sub_b__dot_c), 'a[b].c') + self.assertEqual(str(a_dot_b_sub_c), 'a.b[c]') + + self.assertNotEqual(a_sub_b_sub_c, a_sub__b_dot_c) + self.assertNotEqual(a_sub_b_sub_c, a_sub_b__dot_c) + self.assertNotEqual(a_sub_b_sub_c, a_dot_b_sub_c) + + self.assertNotEqual(a_sub__b_dot_c, a_sub_b__dot_c) + self.assertNotEqual(a_sub__b_dot_c, a_dot_b_sub_c) + + self.assertNotEqual(a_sub_b__dot_c, a_dot_b_sub_c) + + def test_hashable(self): + d = {QN('a'): 'a', QN('b'): 'b'} + self.assertEqual(d[QN('a')], 'a') + self.assertEqual(d[QN('b')], 'b') + self.assertTrue(QN('c') not in d) + + def test_literals(self): + a = QN('a') + a_sub_str_b = QN(a, subscript=QN(qual_names.StringLiteral('b'))) + a_sub_b = QN(a, subscript=QN('b')) + + self.assertNotEqual(a_sub_str_b, a_sub_b) + self.assertNotEqual(hash(a_sub_str_b), hash(a_sub_b)) + + a_sub_three = QN(a, subscript=QN(qual_names.NumberLiteral(3))) + self.assertEqual(a_sub_three.ast().slice.value.n, 3) + + def test_support_set(self): + a = QN('a') + b = QN('b') + c = QN('c') + a_sub_b = QN(a, subscript=b) + a_dot_b = QN(a, attr='b') + a_dot_b_dot_c = QN(a_dot_b, attr='c') + a_dot_b_sub_c = QN(a_dot_b, subscript=c) + + self.assertSetEqual(a.support_set, set((a,))) + self.assertSetEqual(a_sub_b.support_set, set((a, b))) + self.assertSetEqual(a_dot_b.support_set, set((a,))) + self.assertSetEqual(a_dot_b_dot_c.support_set, set((a,))) + self.assertSetEqual(a_dot_b_sub_c.support_set, set((a, c))) + + +class QNResolverTest(test.TestCase): + + def assertQNStringIs(self, node, qn_str): + self.assertEqual(str(anno.getanno(node, anno.Basic.QN)), qn_str) + + def test_resolve(self): + samples = """ + a + a.b + (c, d.e) + [f, (g.h.i)] + j(k, l) + """ + nodes = resolve(parser.parse_str(textwrap.dedent(samples))) + nodes = tuple(n.value for n in nodes.body) + + self.assertQNStringIs(nodes[0], 'a') + self.assertQNStringIs(nodes[1], 'a.b') + self.assertQNStringIs(nodes[2].elts[0], 'c') + self.assertQNStringIs(nodes[2].elts[1], 'd.e') + self.assertQNStringIs(nodes[3].elts[0], 'f') + self.assertQNStringIs(nodes[3].elts[1], 'g.h.i') + self.assertQNStringIs(nodes[4].func, 'j') + self.assertQNStringIs(nodes[4].args[0], 'k') + self.assertQNStringIs(nodes[4].args[1], 'l') + + def test_subscript_resolve(self): + samples = """ + x[i] + x[i.b] + a.b[c] + a.b[x.y] + a[z[c]] + a[b[c[d]]] + a[b].c + a.b.c[d].e.f + a.b[c[d]].e.f + a.b[c[d.e.f].g].h + """ + nodes = resolve(parser.parse_str(textwrap.dedent(samples))) + nodes = tuple(n.value for n in nodes.body) + + self.assertQNStringIs(nodes[0], 'x[i]') + self.assertQNStringIs(nodes[1], 'x[i.b]') + self.assertQNStringIs(nodes[2], 'a.b[c]') + self.assertQNStringIs(nodes[3], 'a.b[x.y]') + self.assertQNStringIs(nodes[4], 'a[z[c]]') + self.assertQNStringIs(nodes[5], 'a[b[c[d]]]') + self.assertQNStringIs(nodes[6], 'a[b].c') + self.assertQNStringIs(nodes[7], 'a.b.c[d].e.f') + self.assertQNStringIs(nodes[8], 'a.b[c[d]].e.f') + self.assertQNStringIs(nodes[9], 'a.b[c[d.e.f].g].h') + + def test_function_calls(self): + samples = """ + a.b + a.b() + a().b + z[i] + z[i]() + z()[i] + """ + nodes = resolve(parser.parse_str(textwrap.dedent(samples))) + nodes = tuple(n.value for n in nodes.body) + self.assertQNStringIs(nodes[0], 'a.b') + self.assertQNStringIs(nodes[1].func, 'a.b') + self.assertQNStringIs(nodes[2].value.func, 'a') + self.assertQNStringIs(nodes[3], 'z[i]') + self.assertQNStringIs(nodes[4].func, 'z[i]') + self.assertQNStringIs(nodes[5].value.func, 'z') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD similarity index 66% rename from tensorflow/contrib/py2tf/pyct/static_analysis/BUILD rename to tensorflow/contrib/autograph/pyct/static_analysis/BUILD index fbfce18c60cca4b105e7de3c3ea7b9c3438f6b2a..8064a967cd389e88d3febbeb21cac87b0fef9e18 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD @@ -19,13 +19,14 @@ py_library( srcs = [ "activity.py", "annos.py", + "cfg.py", "live_values.py", "type_info.py", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "@gast_archive//:gast", ], ) @@ -34,9 +35,23 @@ py_test( name = "activity_test", srcs = ["activity_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":static_analysis", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + +py_test( + name = "cfg_test", + srcs = ["cfg_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":static_analysis", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", "@gast_archive//:gast", ], @@ -46,9 +61,10 @@ py_test( name = "live_values_test", srcs = ["live_values_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":static_analysis", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], ) @@ -59,7 +75,8 @@ py_test( srcs_version = "PY2AND3", deps = [ ":static_analysis", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/utils", "//tensorflow/python:client_testlib", ], ) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/__init__.py b/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py similarity index 100% rename from tensorflow/contrib/py2tf/pyct/static_analysis/__init__.py rename to tensorflow/contrib/autograph/pyct/static_analysis/__init__.py diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/activity.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py similarity index 55% rename from tensorflow/contrib/py2tf/pyct/static_analysis/activity.py rename to tensorflow/contrib/autograph/pyct/static_analysis/activity.py index 1c93e1603113d48176af7a97a0f37321e6f67586..4d7b0cbb7b8f6ee5bd64553644dc3ec9b8bca95b 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/activity.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py @@ -22,11 +22,13 @@ import copy import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno # TODO(mdan): Add support for PY3 (e.g. Param vs arg). +# TODO(alexbw): Ignore named literals (e.g. None) class Scope(object): @@ -42,16 +44,20 @@ class Scope(object): used: identifiers referenced in this scope """ - def __init__(self, parent, isolated=True): + def __init__(self, parent, isolated=True, add_unknown_symbols=False): """Create a new scope. Args: parent: A Scope or None. isolated: Whether the scope is isolated, that is, whether variables created in this scope should be visible to the parent scope. + add_unknown_symbols: Whether to handle attributed and subscripts + without having first seen the base name. + E.g., analyzing the statement 'x.y = z' without first having seen 'x'. """ self.isolated = isolated self.parent = parent + self.add_unknown_symbols = add_unknown_symbols self.modified = set() self.created = set() self.used = set() @@ -70,13 +76,33 @@ class Scope(object): tuple(self.modified)) def copy_from(self, other): + """Recursively copies the contents of this scope from another scope.""" + if (self.parent is None) != (other.parent is None): + raise ValueError('cannot copy scopes of different structures') + if other.parent is not None: + self.parent.copy_from(other.parent) + self.isolated = other.isolated self.modified = copy.copy(other.modified) self.created = copy.copy(other.created) self.used = copy.copy(other.used) self.params = copy.copy(other.params) self.returned = copy.copy(other.returned) + @classmethod + def copy_of(cls, other): + if other.parent is not None: + parent = cls.copy_of(other.parent) + else: + parent = None + new_copy = cls(parent) + new_copy.copy_from(other) + return new_copy + def merge_from(self, other): + if (self.parent is None) != (other.parent is None): + raise ValueError('cannot merge scopes of different structures') + if other.parent is not None: + self.parent.merge_from(other.parent) self.modified |= other.modified self.created |= other.created self.used |= other.used @@ -112,18 +138,22 @@ class Scope(object): def mark_param(self, name): self.params.add(name) - def mark_creation(self, name): + def mark_creation(self, name, writes_create_symbol=False): + """Mark a qualified name as created.""" if name.is_composite(): parent = name.parent - if self.has(parent): - # This is considered mutation of the parent, not creation. - # TODO(mdan): Is that really so? + if not writes_create_symbol: return else: - raise ValueError('Unknown symbol "%s".' % parent) + if not self.has(parent): + if self.add_unknown_symbols: + self.mark_read(parent) + else: + raise ValueError('Unknown symbol "%s".' % parent) self.created.add(name) def mark_write(self, name): + """Marks the given symbol as modified in the current scope.""" self.modified.add(name) if self.isolated: self.mark_creation(name) @@ -141,25 +171,62 @@ class Scope(object): self.parent.mark_returned(name) -class ActivityAnalizer(transformer.Base): - """Annotates nodes with local scope information. See Scope.""" +class ActivityAnalyzer(transformer.Base): + """Annotates nodes with local scope information. + + See Scope. - def __init__(self, context, parent_scope): - super(ActivityAnalizer, self).__init__(context) - self.scope = Scope(parent_scope) + The use of this class requires that qual_names.resolve() has been called on + the node. This class will ignore nodes have not been + annotated with their qualified names. + """ + + def __init__(self, context, parent_scope=None, add_unknown_symbols=False): + super(ActivityAnalyzer, self).__init__(context) + self.scope = Scope(parent_scope, None, add_unknown_symbols) self._in_return_statement = False + self._in_aug_assign = False - def _track_symbol(self, node): + @property + def _in_constructor(self): + if len(self.enclosing_entities) > 1: + innermost = self.enclosing_entities[-1] + parent = self.enclosing_entities[-2] + return isinstance(parent, gast.ClassDef) and innermost.name == '__init__' + return False + + def _node_sets_self_attribute(self, node): + if anno.hasanno(node, anno.Basic.QN): + qn = anno.getanno(node, anno.Basic.QN) + # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'. + if qn.has_attr and qn.parent.qn == ('self',): + return True + return False + + def _track_symbol(self, + node, + composite_writes_alter_parent=False, + writes_create_symbol=False): + # A QN may be missing when we have an attribute (or subscript) on a function + # call. Example: a().b + if not anno.hasanno(node, anno.Basic.QN): + return qn = anno.getanno(node, anno.Basic.QN) if isinstance(node.ctx, gast.Store): self.scope.mark_write(qn) + if qn.is_composite and composite_writes_alter_parent: + self.scope.mark_write(qn.parent) + if writes_create_symbol: + self.scope.mark_creation(qn, writes_create_symbol=True) + if self._in_aug_assign: + self.scope.mark_read(qn) elif isinstance(node.ctx, gast.Load): self.scope.mark_read(qn) elif isinstance(node.ctx, gast.Param): # Param contexts appear in function defs, so they have the meaning of # defining a variable. - # TODO(mdan): This bay be incorrect with nested functions. + # TODO(mdan): This may be incorrect with nested functions. # For nested functions, we'll have to add the notion of hiding args from # the parent scope, not writing to them. self.scope.mark_creation(qn) @@ -175,6 +242,14 @@ class ActivityAnalizer(transformer.Base): if self._in_return_statement: self.scope.mark_returned(qn) + def visit_AugAssign(self, node): + # Special rules for AugAssign. In Assign, the target is only written, + # but in AugAssig (e.g. a += b), the target is both read and written. + self._in_aug_assign = True + self.generic_visit(node) + self._in_aug_assign = False + return node + def visit_Name(self, node): self.generic_visit(node) self._track_symbol(node) @@ -182,7 +257,18 @@ class ActivityAnalizer(transformer.Base): def visit_Attribute(self, node): self.generic_visit(node) - self._track_symbol(node) + if self._in_constructor and self._node_sets_self_attribute(node): + self._track_symbol( + node, composite_writes_alter_parent=True, writes_create_symbol=True) + else: + self._track_symbol(node) + return node + + def visit_Subscript(self, node): + self.generic_visit(node) + # Subscript writes (e.g. a[b] = "value") are considered to modify + # both the element itself (a[b]) and its parent (a). + self._track_symbol(node, composite_writes_alter_parent=True) return node def visit_Print(self, node): @@ -224,21 +310,46 @@ class ActivityAnalizer(transformer.Base): # modifies the parent state causing the other child blocks to be # processed incorrectly. So we need to checkpoint the parent scope so that # each child sees the same context. - before_parent = Scope(None) - before_parent.copy_from(self.scope) + before_parent = Scope.copy_of(self.scope) after_children = [] for child, scope_name in children: self.scope.copy_from(before_parent) parent = self._process_block_node(parent, child, scope_name) - after_child = Scope(None) - after_child.copy_from(self.scope) + after_child = Scope.copy_of(self.scope) after_children.append(after_child) for after_child in after_children: self.scope.merge_from(after_child) return parent + def visit_FunctionDef(self, node): + if self.scope: + qn = qual_names.QN(node.name) + self.scope.mark_write(qn) + current_scope = self.scope + body_scope = Scope(current_scope, isolated=True) + self.scope = body_scope + self.generic_visit(node) + anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope) + self.scope = current_scope + return node + + def visit_With(self, node): + current_scope = self.scope + with_scope = Scope(current_scope, isolated=False) + self.scope = with_scope + self.generic_visit(node) + anno.setanno(node, NodeAnno.BODY_SCOPE, with_scope) + self.scope = current_scope + return node + def visit_If(self, node): + current_scope = self.scope + cond_scope = Scope(current_scope, isolated=False) + self.scope = cond_scope self.visit(node.test) + anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope) + self.scope = current_scope + node = self._process_parallel_blocks(node, ((node.body, NodeAnno.BODY_SCOPE), (node.orelse, NodeAnno.ORELSE_SCOPE))) @@ -253,7 +364,13 @@ class ActivityAnalizer(transformer.Base): return node def visit_While(self, node): + current_scope = self.scope + cond_scope = Scope(current_scope, isolated=False) + self.scope = cond_scope self.visit(node.test) + anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope) + self.scope = current_scope + node = self._process_parallel_blocks(node, ((node.body, NodeAnno.BODY_SCOPE), (node.orelse, NodeAnno.ORELSE_SCOPE))) @@ -266,5 +383,32 @@ class ActivityAnalizer(transformer.Base): return node +def get_read(node, context): + """Return the variable names as QNs (qual_names.py) read by this statement.""" + analyzer = ActivityAnalyzer(context, None, True) + analyzer.visit(node) + return analyzer.scope.used + + +def get_updated(node, context): + """Return the variable names created or mutated by this statement. + + This function considers assign statements, augmented assign statements, and + the targets of for loops, as well as function arguments. + For example, `x[0] = 2` will return `x`, `x, y = 3, 4` will return `x` and + `y`, `for i in range(x)` will return `i`, etc. + Args: + node: An AST node + context: An EntityContext instance + + Returns: + A set of variable names (QNs, see qual_names.py) of all the variables + created or mutated. + """ + analyzer = ActivityAnalyzer(context, None, True) + analyzer.visit(node) + return analyzer.scope.created | analyzer.scope.modified + + def resolve(node, context, parent_scope=None): - return ActivityAnalizer(context, parent_scope).visit(node) + return ActivityAnalyzer(context, parent_scope).visit(node) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fdbd349af9d3325af114a7206d89617134278f14 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py @@ -0,0 +1,566 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for activity module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.qual_names import QN +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno +from tensorflow.python.platform import test + + +class ScopeTest(test.TestCase): + + def test_basic(self): + scope = activity.Scope(None) + self.assertFalse(scope.has(QN('foo'))) + + scope.mark_read(QN('foo')) + self.assertFalse(scope.has(QN('foo'))) + + scope.mark_write(QN('foo')) + self.assertTrue(scope.has(QN('foo'))) + + scope.mark_read(QN('bar')) + self.assertFalse(scope.has(QN('bar'))) + + def test_copy_from(self): + scope = activity.Scope(None) + scope.mark_write(QN('foo')) + + other = activity.Scope(None) + other.copy_from(scope) + + self.assertTrue(QN('foo') in other.created) + + scope.mark_write(QN('bar')) + scope.copy_from(other) + + self.assertFalse(QN('bar') in scope.created) + + scope.mark_write(QN('bar')) + scope.merge_from(other) + + self.assertTrue(QN('bar') in scope.created) + self.assertFalse(QN('bar') in other.created) + + def test_copy_of(self): + scope = activity.Scope(None) + scope.mark_read(QN('foo')) + + self.assertTrue(QN('foo') in activity.Scope.copy_of(scope).used) + + child_scope = activity.Scope(scope) + child_scope.mark_read(QN('bar')) + + self.assertTrue(QN('bar') in activity.Scope.copy_of(child_scope).used) + + def test_nesting(self): + scope = activity.Scope(None) + scope.mark_write(QN('foo')) + scope.mark_read(QN('bar')) + + child = activity.Scope(scope) + self.assertTrue(child.has(QN('foo'))) + self.assertTrue(scope.has(QN('foo'))) + + child.mark_write(QN('bar')) + self.assertTrue(child.has(QN('bar'))) + self.assertFalse(scope.has(QN('bar'))) + + def test_referenced(self): + scope = activity.Scope(None) + scope.mark_read(QN('a')) + + child = activity.Scope(scope) + child.mark_read(QN('b')) + + child2 = activity.Scope(child, isolated=False) + child2.mark_read(QN('c')) + + self.assertTrue(QN('c') in child2.referenced) + self.assertTrue(QN('b') in child2.referenced) + self.assertFalse(QN('a') in child2.referenced) + + self.assertTrue(QN('c') in child.referenced) + self.assertTrue(QN('b') in child.referenced) + self.assertFalse(QN('a') in child.referenced) + + +class ActivityAnalyzerTest(test.TestCase): + + def _parse_and_analyze(self, test_fn): + node, source = parser.parse_entity(test_fn) + ctx = context.EntityContext( + namer=None, + source_code=source, + source_file=None, + namespace={}, + arg_values=None, + arg_types=None, + owner_type=None, + recursive=True) + node = qual_names.resolve(node) + node = activity.resolve(node, ctx) + return node, ctx + + def test_local_markers(self): + + def test_fn(a): # pylint:disable=unused-argument + b = c # pylint:disable=undefined-variable + while b > 0: + b -= 1 + return b + + node, _ = self._parse_and_analyze(test_fn) + self.assertFalse( + anno.getanno(node.body[0].body[0].value, + NodeAnno.IS_LOCAL)) # c in b = c + self.assertTrue( + anno.getanno(node.body[0].body[1].test.left, + NodeAnno.IS_LOCAL)) # b in b > 0 + self.assertTrue( + anno.getanno(node.body[0].body[2].value, + NodeAnno.IS_LOCAL)) # b in return b + + def assertSymbolSetsAre(self, expected, actual, name): + expected = set(expected) + actual = set(str(s) for s in actual) + self.assertSetEqual( + expected, actual, 'for symbol set: %s\n' + ' Expected: %s\n' + ' Got: %s\n' + ' Missing: %s\n' + ' Extra: %s\n' % (name.upper(), expected, actual, + expected - actual, actual - expected)) + + def assertScopeIsRmc(self, scope, used, modified, created): + """Assert the scope contains specific used, modified & created variables.""" + self.assertSymbolSetsAre(used, scope.used, 'read') + self.assertSymbolSetsAre(modified, scope.modified, 'modified') + self.assertSymbolSetsAre(created, scope.created, 'created') + + def test_print_statement(self): + + def test_fn(a): + b = 0 + c = 1 + print(a, b) + return c + + node, _ = self._parse_and_analyze(test_fn) + print_node = node.body[0].body[2] + if isinstance(print_node, gast.Print): + # Python 2 + print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE) + else: + # Python 3 + assert isinstance(print_node, gast.Expr) + # The call node should be the one being annotated. + print_node = print_node.value + print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE) + # We basically need to detect which variables are captured by the call + # arguments. + self.assertScopeIsRmc(print_args_scope, ('a', 'b'), (), ()) + + def test_call_args(self): + + def test_fn(a): + b = 0 + c = 1 + foo(a, b) # pylint:disable=undefined-variable + return c + + node, _ = self._parse_and_analyze(test_fn) + call_node = node.body[0].body[2].value + # We basically need to detect which variables are captured by the call + # arguments. + self.assertScopeIsRmc( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), (), ()) + + def test_call_args_attributes(self): + + def foo(*_): + pass + + def test_fn(a): + a.c = 0 + foo(a.b, a.c) + return a.d + + node, _ = self._parse_and_analyze(test_fn) + call_node = node.body[0].body[1].value + self.assertScopeIsRmc( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE), + ('a', 'a.b', 'a.c'), + (), + (), + ) + self.assertScopeIsRmc( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE).parent, + ('a', 'a.b', 'a.c', 'a.d', 'foo'), + ('a.c',), + ('a',), + ) + + def test_call_args_subscripts(self): + + def foo(*_): + pass + + def test_fn(a): + b = 1 + c = 2 + foo(a[0], a[b]) + return a[c] + + node, _ = self._parse_and_analyze(test_fn) + call_node = node.body[0].body[2].value + self.assertScopeIsRmc( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE), + ('a', 'a[0]', 'a[b]', 'b'), + (), + (), + ) + self.assertScopeIsRmc( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE).parent, + ('a', 'a[0]', 'a[b]', 'a[c]', 'b', 'c', 'foo'), + ('b', 'c'), + ('a', 'b', 'c'), + ) + + def test_while(self): + + def test_fn(a): + b = a + while b > 0: + c = b + b -= 1 + return b, c + + node, _ = self._parse_and_analyze(test_fn) + while_node = node.body[0].body[1] + self.assertScopeIsRmc( + anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), + ('c',)) + self.assertScopeIsRmc( + anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'), + ('b', 'c'), ('a', 'b', 'c')) + self.assertScopeIsRmc( + anno.getanno(while_node, NodeAnno.COND_SCOPE), ('b',), (), ()) + + def test_for(self): + + def test_fn(a): + b = a + for _ in a: + c = b + b -= 1 + return b, c + + node, _ = self._parse_and_analyze(test_fn) + for_node = node.body[0].body[1] + self.assertScopeIsRmc( + anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',)) + self.assertScopeIsRmc( + anno.getanno(for_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'), + ('b', 'c', '_'), ('a', 'b', 'c', '_')) + + def test_if(self): + + def test_fn(x): + if x > 0: + x = -x + y = 2 * x + z = -y + else: + x = 2 * x + y = -x + u = -y + return z, u + + node, _ = self._parse_and_analyze(test_fn) + if_node = node.body[0].body[0] + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'), + ('y', 'z')) + # TODO(mdan): Double check: is it ok to not mark a local symbol as not read? + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('x', 'z', 'u'), + ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('x', 'y'), + ('x', 'y', 'u'), ('y', 'u')) + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'z', 'u'), + ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) + + def test_if_attributes(self): + + def test_fn(a): + if a > 0: + a.b = -a.c + d = 2 * a + else: + a.b = a.c + d = 1 + return d + + node, _ = self._parse_and_analyze(test_fn) + if_node = node.body[0].body[0] + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.BODY_SCOPE), + ('a', 'a.c'), + ('a.b', 'd'), + ('d',), + ) + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), + ('a', 'a.c'), + ('a.b', 'd'), + ('d',), + ) + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, + ('a', 'a.c', 'd'), + ('a.b', 'd'), + ('a', 'd'), + ) + + def test_if_subscripts(self): + + def test_fn(a, b, c, e): + if a > 0: + a[b] = -a[c] + d = 2 * a + else: + a[0] = e + d = 1 + return d + + node, _ = self._parse_and_analyze(test_fn) + if_node = node.body[0].body[0] + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.BODY_SCOPE), + ('a', 'b', 'c', 'a[c]'), + ('a', 'a[b]', 'd'), + ('d',), + ) + # TODO(mdan): Should subscript writes (a[0] = 1) be considered to read "a"? + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), + ('a', 'e'), + ('a', 'a[0]', 'd'), + ('d',), + ) + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, + ('a', 'b', 'c', 'd', 'e', 'a[c]'), + ('a', 'd', 'a[b]', 'a[0]'), + ('a', 'b', 'c', 'd', 'e'), + ) + + def test_nested_if(self): + + def test_fn(b): + if b > 0: + if b < 5: + a = b + else: + a = b * b + return a + + node, _ = self._parse_and_analyze(test_fn) + inner_if_node = node.body[0].body[0].body[0] + self.assertScopeIsRmc( + anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',), + ('a',)) + self.assertScopeIsRmc( + anno.getanno(inner_if_node, NodeAnno.ORELSE_SCOPE), ('b',), ('a',), + ('a',)) + + def test_nested_function(self): + + def test_fn(a): + + def f(x): + y = x * x + return y + + b = a + for i in a: + c = b + b -= f(i) + return b, c + + node, _ = self._parse_and_analyze(test_fn) + fn_def_node = node.body[0].body[0] + + self.assertScopeIsRmc( + anno.getanno(fn_def_node, + NodeAnno.BODY_SCOPE).parent, ('b', 'i', 'f', 'c', 'a'), + ('f', 'b', 'c', 'i'), ('f', 'a', 'b', 'c', 'i')) + self.assertScopeIsRmc( + anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',), ( + 'x', + 'y', + )) + + def test_constructor_attributes(self): + + class TestClass(object): + + def __init__(self, a): + self.b = a + self.b.c = 1 + + node, _ = self._parse_and_analyze(TestClass) + init_node = node.body[0].body[0] + self.assertScopeIsRmc( + anno.getanno(init_node, NodeAnno.BODY_SCOPE), + ('self', 'a', 'self.b'), + ('self', 'self.b', 'self.b.c'), + ('self', 'a', 'self.b'), + ) + + def test_aug_assign_subscripts(self): + + def test_fn(a): + a[0] += 1 + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + self.assertScopeIsRmc( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), + ('a', 'a[0]'), + ('a', 'a[0]'), + ('a',), + ) + + def test_return_vars_are_read(self): + + def test_fn(a, b, c): # pylint: disable=unused-argument + return c + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + self.assertScopeIsRmc( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), + ('c',), + (), + ( + 'a', + 'b', + 'c', + ), + ) + + def test_aug_assign(self): + + def test_fn(a, b): + a += b + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + self.assertScopeIsRmc( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), + ('a', 'b'), + ('a'), + ('a', 'b'), + ) + + def test_aug_assign_rvalues(self): + + a = dict(bar=3) + + def foo(): + return a + + def test_fn(x): + foo()['bar'] += x + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + self.assertScopeIsRmc( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), + ('foo', 'x'), + (), + ('x',), + ) + + def test_params_created(self): + + def test_fn(a, b): # pylint: disable=unused-argument + return b + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + self.assertScopeIsRmc( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('b',), (('')), + (('a', 'b'))) + + def test_get_read(self): + + def test_fn(x, y): + z = test_fn(x, y) + return z + + node, ctx = self._parse_and_analyze(test_fn) + node = node.body[0].body[0] + read_vars = activity.get_read(node, ctx) + self.assertEqual(read_vars, set(map(qual_names.QN, ('test_fn', 'x', 'y')))) + + def test_fn2(x, y, z): + z += test_fn2(x, y, z) + return z + + node, ctx = self._parse_and_analyze(test_fn2) + node = node.body[0].body[0] + read_vars = activity.get_read(node, ctx) + self.assertEqual(read_vars, + set(map(qual_names.QN, ('test_fn2', 'x', 'y', 'z')))) + + def test_get_updated(self): + + def test_fn(x, y): + z = test_fn(x, y) + return z + + node, ctx = self._parse_and_analyze(test_fn) + node = node.body[0].body[0] + updated_vars = activity.get_updated(node, ctx) + self.assertEqual(updated_vars, set(map(qual_names.QN, ('z')))) + + def test_fn2(x, y, z): + z += test_fn2(x, y, z) + return z + + node, ctx = self._parse_and_analyze(test_fn2) + node = node.body[0].body[0] + updated_vars = activity.get_updated(node, ctx) + self.assertEqual(updated_vars, set(map(qual_names.QN, ('z')))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/annos.py b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py similarity index 67% rename from tensorflow/contrib/py2tf/pyct/static_analysis/annos.py rename to tensorflow/contrib/autograph/pyct/static_analysis/annos.py index 2d8e49442364fdd4a4752c8a83a5f3b76117fe57..b929b35b79200b0968c9c4f26b10cda28763773a 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/annos.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Annotations used by the static analizer.""" +"""Annotations used by the static analyzer.""" from __future__ import absolute_import from __future__ import division @@ -28,23 +28,32 @@ class NoValue(Enum): class NodeAnno(NoValue): - """Additionnal annotations used by the static analyzer. + """Additional annotations used by the static analyzer. These are in addition to the basic annotations declared in anno.py. """ # Symbols - - IS_LOCAL = 'Symbol is local to the function scope being analized.' - IS_PARAM = 'Symbol is a parameter to the function being analized.' + # These flags are boolean. + IS_LOCAL = 'Symbol is local to the function scope being analyzed.' + IS_PARAM = 'Symbol is a parameter to the function being analyzed.' IS_MODIFIED_SINCE_ENTRY = ( 'Symbol has been explicitly replaced in the current function scope.') # Scopes + # Scopes are represented by objects of type activity.Scope. ARGS_SCOPE = 'The scope for the argument list of a function call.' + COND_SCOPE = 'The scope for the test node of a conditional statement.' BODY_SCOPE = ( 'The scope for the main body of a statement (True branch for if ' 'statements, main body for loops).') ORELSE_SCOPE = ( 'The scope for the orelse body of a statement (False branch for if ' 'statements, orelse body for loops).') + + # Type and Value annotations + # Type annotations are represented by objects of type type_info.Type. + STATIC_INFO = ( + 'The type or value information that should be asserted about the entity ' + 'referenced by the symbol holding this annotation, irrespective of the ' + 'execution context.') diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..ad97fdfa8e78d1fd4c38724612d83519c6609cce --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py @@ -0,0 +1,445 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Control flow graph analysis. + +Given a Python AST we construct a control flow graph, with edges both to the +next and previous statements (so it can easily walk the graph both ways). Its +nodes contain the AST of the statements. It can then perform forward or backward +analysis on this CFG. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple +import functools +import operator + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct.static_analysis import activity + + +class CfgNode(object): + """A node in the CFG.""" + __slots__ = ['next', 'value', 'prev'] + + def __init__(self, value): + self.next = set() + self.prev = set() + self.value = value + + +class Cfg(namedtuple('Cfg', ['entry', 'exit'])): + """A Control Flow Graph. + + Each statement is represented as a node. For control flow statements such + as conditionals and loops the conditional itself is a node which either + branches or cycles, respectively. + Attributes: + entry: The entry node, which contains the `gast.arguments` node of the + function definition. + exit: The exit node. This node is special because it has no value (i.e. no + corresponding AST node). This is because Python functions can have + multiple return statements. + """ + pass + + +class CfgBuilder(gast.NodeVisitor): + """Construct a control flow graph. + + Construct a CFG starting from a FunctionDef node. + Usage: + cfg_obj = CfgBuilder().build_cfg(fndef_node) + """ + + def __init__(self): + # The current leaves of the CFG + self.current_leaves = [] + # TODO(alexbw): generalize to break, return, continue, yield, etc. + # A stack of lists, tracking continue statements + self.continue_ = [] + # A stack of lists tracking break nodes + self.break_ = [] + + def set_current_leaves(self, cfg_node): + """Link this cfg_node to the current leaves. + + This is the central function for building the CFG. It links the current + head cfg_nodes to the passed cfg_node. It then resets the head to the + passed cfg_node. + + Args: + cfg_node: A CfgNode instance. + """ + for head in self.current_leaves: + head.next.add(cfg_node) + # While we're linking the CFG forward, add backlinks + cfg_node.prev.add(head) + self.current_leaves = [cfg_node] + + def build_cfg(self, node): + """Build a CFG for a function. + + Implementation of building a CFG for dataflow analysis. See, e.g.: + https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf + + Args: + node: A function definition the body of which to analyze. + Returns: + A CFG object. + Raises: + TypeError: If the input is not a function definition. + """ + if not isinstance(node, gast.FunctionDef): + raise TypeError('input must be a function definition') + entry_cfg_node = CfgNode(node.args) + self.current_leaves = [entry_cfg_node] + self.visit_statements(node.body) + exit_cfg_node = CfgNode(None) + self.set_current_leaves(exit_cfg_node) + return Cfg(entry_cfg_node, exit_cfg_node) + + def visit_statements(self, nodes): + for node in nodes: + # Check for control flow + if isinstance(node, (gast.For, gast.While, gast.If, gast.Try, gast.Break, + gast.Continue, gast.With)): + self.visit(node) + else: + expr = CfgNode(node) + self.set_current_leaves(expr) + + def generic_visit(self, node): + raise ValueError('unknown control flow') + + def visit_If(self, node): + # TODO(alexbw): change this to use immutable tuples instead of lists + # The current head will hold the conditional + test = CfgNode(node.test) + self.set_current_leaves(test) + # Handle the body + self.visit_statements(node.body) + body_exit = self.current_leaves + self.current_leaves = [test] + # Handle the orelse + self.visit_statements(node.orelse) + self.current_leaves.extend(body_exit) + + def visit_While(self, node): + test = CfgNode(node.test) + self.set_current_leaves(test) + # Start a new level of nesting + self.break_.append([]) + self.continue_.append([]) + # Handle the body + self.visit_statements(node.body) + body_exit = self.current_leaves + self.current_leaves.extend(self.continue_.pop()) + self.set_current_leaves(test) + # Handle the orelse + self.visit_statements(node.orelse) + # The break statements and the test go to the next node + self.current_leaves.extend(self.break_.pop()) + # Body and orelse statements can reach out of the loop + self.current_leaves.extend(body_exit) + + def visit_For(self, node): + iter_ = CfgNode(node.iter) + self.set_current_leaves(iter_) + self.break_.append([]) + self.continue_.append([]) + self.visit_statements(node.body) + body_exit = self.current_leaves + self.current_leaves.extend(self.continue_.pop()) + self.set_current_leaves(iter_) + # Handle the orelse + self.visit_statements(node.orelse) + # The break statements and the test go to the next node + self.current_leaves.extend(self.break_.pop()) + # Body and orelse statements can reach out of the loop + self.current_leaves.extend(body_exit) + + def visit_Break(self, node): + self.break_[-1].extend(self.current_leaves) + self.current_leaves[:] = [] + + def visit_Continue(self, node): + self.continue_[-1].extend(self.current_leaves) + self.current_leaves[:] = [] + + def visit_Try(self, node): + self.visit_statements(node.body) + body = self.current_leaves + handlers = [] + for handler in node.handlers: + self.current_leaves = body[:] + self.visit_statements(handler.body) + handlers.extend(self.current_leaves) + self.current_leaves = body + self.visit_statements(node.orelse) + self.current_leaves = handlers + self.current_leaves + self.visit_statements(node.finalbody) + + def visit_With(self, node): + for item in node.items: + self.set_current_leaves(CfgNode(item)) + self.visit_statements(node.body) + + +# TODO(alexbw): once CFG analysis occurs at a block level, +# this extra class will not be necessary +class PropagateAnalysis(gast.NodeVisitor): + """Port analysis annotations from statements to their enclosing blocks.""" + + def __init__(self, analysis): + self.transfer_fn = analysis.transfer_fn + self.in_label = analysis.in_label + self.out_label = analysis.out_label + super(PropagateAnalysis, self).__init__() + + def visit_If(self, node): + # Depth-first. + self.generic_visit(node) + incoming = anno.getanno(node.body[0], self.in_label) + incoming |= anno.getanno(node.test, self.in_label) + outgoing = anno.getanno(node.body[-1], self.out_label) + outgoing |= anno.getanno(node.test, self.out_label) + if node.orelse: + orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) + outgoing = self.transfer_fn(outgoing, orelse_outgoing) + anno.setanno(node, self.in_label, incoming) + anno.setanno(node, self.out_label, outgoing) + + def visit_For(self, node): + self.generic_visit(node) + incoming = set(anno.getanno(node.body[0], self.in_label)) + incoming -= set((anno.getanno(node.target, anno.Basic.QN),)) + outgoing = anno.getanno(node.body[-1], self.out_label) + if node.orelse: + orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) + outgoing = self.transfer_fn(outgoing, orelse_outgoing) + anno.setanno(node, self.in_label, frozenset(incoming)) + anno.setanno(node, self.out_label, outgoing) + + def visit_While(self, node): + self.generic_visit(node) + incoming = anno.getanno(node.body[0], self.in_label) + incoming |= anno.getanno(node.test, self.in_label) + outgoing = anno.getanno(node.body[-1], self.out_label) + if node.orelse: + orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) + outgoing = self.transfer_fn(outgoing, orelse_outgoing) + anno.setanno(node, self.in_label, incoming) + anno.setanno(node, self.out_label, outgoing) + + def visit_With(self, node): + self.generic_visit(node) + incoming = anno.getanno(node.body[0], self.in_label) + for item in node.items: + incoming |= anno.getanno(item, self.in_label) + outgoing = anno.getanno(node.body[-1], self.out_label) + anno.setanno(node, self.in_label, incoming) + anno.setanno(node, self.out_label, outgoing) + + +# TODO(alexbw): Abstract the CFG walking machinery into a superclass +# which is parameterized on which fields it selects when walking. +# TODO(alexbw): Abstract the application of dataflow analysis +class Forward(object): + """Forward analysis on CFG. + + Args: + label: A name for this analysis e.g. 'active' for activity analysis. The AST + nodes in the CFG will be given annotations 'name_in', 'name_out', + 'name_gen' and 'name_kill' which contain the incoming values, outgoing + values, values generated by the statement, and values deleted by the + statement respectively. + transfer_fn: Either the AND or OR operator. If the AND operator is used it + turns into forward must analysis (i.e. a value will only be carried + forward if it appears on all incoming paths). The OR operator means that + forward may analysis is done (i.e. the union of incoming values will be + taken). + """ + + def __init__(self, label, context, transfer_fn=operator.or_): + self.transfer_fn = transfer_fn + self.context = context + self.out_label = label + '_out' + self.in_label = label + '_in' + self.gen_label = label + '_gen' + self.kill_label = label + '_kill' + + # TODO(alexbw): see if we can simplify by visiting breadth-first + def visit(self, node): + """Depth-first walking the CFG, applying dataflow information propagtion.""" + # node.value is None only for the exit CfgNode. + if not node.value: + return + + if anno.hasanno(node.value, self.out_label): + before = hash(anno.getanno(node.value, self.out_label)) + else: + before = None + preds = [ + anno.getanno(pred.value, self.out_label) + for pred in node.prev + if anno.hasanno(pred.value, self.out_label) + ] + if preds: + incoming = functools.reduce(self.transfer_fn, preds[1:], preds[0]) + else: + incoming = frozenset() + anno.setanno(node.value, self.in_label, incoming) + gen, kill = self.get_gen_kill(node, incoming) + anno.setanno(node.value, self.gen_label, gen) + anno.setanno(node.value, self.kill_label, kill) + anno.setanno(node.value, self.out_label, (incoming - kill) | gen) + + if hash(anno.getanno(node.value, self.out_label)) != before: + for succ in node.next: + self.visit(succ) + + def get_gen_kill(self, cfg_node, incoming): + """Calculate Gen and Kill properties of a CFG node in dataflow analysis. + + A function which takes the CFG node as well as a set of incoming + values. It must return a set of newly generated values by the statement as + well as a set of deleted (killed) values. + + Args: + cfg_node: A CfgNode instance. + incoming: + """ + raise NotImplementedError() + + +class Backward(Forward): + """Backward analysis on CFG.""" + + def visit(self, cfg_node): + # cfg_node.value is None for the exit node, which will be visited only once + if not cfg_node.value: + for pred in cfg_node.prev: + self.visit(pred) + return + + if anno.hasanno(cfg_node.value, self.in_label): + before = hash(anno.getanno(cfg_node.value, self.in_label)) + else: + before = None + succs = [ + anno.getanno(succ.value, self.in_label) + for succ in cfg_node.next + if anno.hasanno(succ.value, self.in_label) + ] + if succs: + incoming = functools.reduce(self.transfer_fn, succs[1:], succs[0]) + else: + incoming = frozenset() + anno.setanno(cfg_node.value, self.out_label, incoming) + gen, kill = self.get_gen_kill(cfg_node, incoming) + anno.setanno(cfg_node.value, self.gen_label, gen) + anno.setanno(cfg_node.value, self.kill_label, kill) + anno.setanno(cfg_node.value, self.in_label, (incoming - kill) | gen) + if hash(anno.getanno(cfg_node.value, self.in_label)) != before: + for pred in cfg_node.prev: + self.visit(pred) + + +def run_analyses(node, analyses): + """Perform dataflow analysis on all functions within an AST. + + Args: + node: An AST node on which to run dataflow analysis. + analyses: Either an instance of the Forward or Backward dataflow analysis + class, or a list or tuple of them. + + Returns: + node: The node, but now with annotations on the AST nodes containing the + results of the dataflow analyses. + """ + if not isinstance(analyses, (tuple, list)): + analyses = (analyses,) + for analysis in analyses: + if not isinstance(analysis, (Forward, Backward)): + raise TypeError('not a valid forward analysis object') + + for child_node in gast.walk(node): + if isinstance(child_node, gast.FunctionDef): + cfg_obj = CfgBuilder().build_cfg(child_node) + for analysis in analyses: + if isinstance(analysis, Backward): + analysis.visit(cfg_obj.exit) + elif isinstance(analysis, Forward): + analysis.visit(cfg_obj.entry) + for analysis in analyses: + PropagateAnalysis(analysis).visit(node) + return node + + +class Liveness(Backward): + """Perform a liveness analysis. + + Each statement is annotated with a set of variables that may be used + later in the program. + """ + + def __init__(self, context): + super(Liveness, self).__init__('live', context) + + def get_gen_kill(self, node, _): + # A variable's parents are live if it is live + # e.g. x is live if x.y is live. This means gen needs to return + # all parents of a variable (if it's an Attribute or Subscript). + # This doesn't apply to kill (e.g. del x.y doesn't affect liveness of x) + gen = activity.get_read(node.value, self.context) + gen = functools.reduce(lambda left, right: left | right.support_set, gen, + gen) + kill = activity.get_updated(node.value, self.context) + return gen, kill + + +class ReachingDefinitions(Forward): + """Perform reaching definition analysis. + + Each statement is annotated with a set of (variable, definition) pairs. + """ + + def __init__(self, context): + super(ReachingDefinitions, self).__init__('definitions', context) + + def get_gen_kill(self, node, incoming): + definitions = activity.get_updated(node.value, self.context) + gen = frozenset((id_, node.value) for id_ in definitions) + kill = frozenset(def_ for def_ in incoming if def_[0] in definitions) + return gen, kill + + +class Defined(Forward): + """Perform defined variable analysis. + + Each statement is annotated with a set of variables which are guaranteed to + be defined at that point. + """ + + def __init__(self, context): + super(Defined, self).__init__('defined', context, transfer_fn=operator.and_) + + def get_gen_kill(self, node, _): + gen = activity.get_updated(node.value, self.context) + return gen, frozenset() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fc07fa3447b23c0595a5893329de8a2d7055ca15 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py @@ -0,0 +1,306 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for cfg module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import cfg +from tensorflow.python.platform import test + + +class CFGTest(test.TestCase): + + def _parse_and_analyze(self, test_fn, namespace, arg_types=None): + arg_types = arg_types or {} + node, source = parser.parse_entity(test_fn) + ctx = context.EntityContext( + namer=None, + source_code=source, + source_file=None, + namespace=namespace, + arg_values=None, + arg_types=arg_types, + owner_type=None, + recursive=True) + node = qual_names.resolve(node) + return node, ctx + + def _check_anno_matches(self, node, anno_name, var_names): + if isinstance(var_names, str): + var_names = (var_names,) + qual_vars = set() + for var_name in var_names: + if isinstance(var_name, str): + if '[' in var_name or ']' in var_name: + raise ValueError('Annotation matching not supported with subscript.') + if '.' not in var_name: + qual_vars.add(qual_names.QN(var_name)) + else: + attrs = var_name.split('.') + this_qn = functools.reduce(qual_names.QN, attrs[1:], + qual_names.QN(attrs[0])) + qual_vars.add(this_qn) + self.assertEqual(anno.getanno(node, anno_name), qual_vars) + + def test_reaching(self): + + def f(x): + print(x) + while True: + x = x + x = x + return x + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) + body = node.body[0].body + # Only the argument reaches the expression + def_in = anno.getanno(body[0], 'definitions_in') + # One element, x, from arguments + self.assertEqual(set(type(d[1]) for d in def_in), set((gast.arguments,))) + + while_body = body[1].body + def_in = anno.getanno(while_body[0], 'definitions_in') + # One definition, two possible sources. + # - One from an assignment (if the loop is entered) + # - The other from the arguments (if loop is not entered) + self.assertEqual( + set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign))) + + def_in = anno.getanno(while_body[1], 'definitions_in') + # If we've reached this line, the only reaching definition of x is the + # Assign node in previous line + self.assertEqual(set(type(d[1]) for d in def_in), set((gast.Assign,))) + + def_in = anno.getanno(body[2], 'definitions_in') + # Same situation as while_body[0] + self.assertEqual( + set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign))) + + def test_defined(self): + + def f(x): + if x: + y = 2 # pylint: disable=unused-variable + return x + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + body = node.body[0].body + # only x is for sure defined at the end + self._check_anno_matches(body[1], 'defined_in', 'x') + # at the end of the if body both x and y are defined + if_body = body[0].body + self._check_anno_matches(if_body[0], 'defined_out', ('x', 'y')) + + def _get_live_annotated_fnbody(self, f): + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Liveness(ctx)) + body = node.body[0].body + return body + + def test_live_straightline(self): + + def f1(x): + a = g(x) # pylint: disable=undefined-variable + b = h(a) # pylint: disable=undefined-variable, unused-variable + return x + + body = self._get_live_annotated_fnbody(f1) + self._check_anno_matches(body[1], 'live_in', ('a', 'h', 'x')) + self._check_anno_matches(body[2], 'live_in', ('x')) + self._check_anno_matches(body[0], 'live_in', ('g', 'h', 'x')) + self._check_anno_matches(body[2], 'live_out', ()) + + def test_live_stacked_conds_with_else(self): + + def f2(x, a): # pylint: disable=unused-argument + if a > 0: # x should not be live + x = 0 + if a > 1: + x = 1 + else: + x = 2 + + body = self._get_live_annotated_fnbody(f2) + self._check_anno_matches(body[0], 'live_in', ('a')) + self._check_anno_matches(body[1], 'live_in', ('a')) + + def test_live_stacked_conds(self): + + def f3(x, a): + if a > 0: # x and a should be live + x = 0 + if a > 1: # x and a should be live_in + x = 1 + return x # x should be live + + body = self._get_live_annotated_fnbody(f3) + self._check_anno_matches(body[0], 'live_in', ('a', 'x')) + self._check_anno_matches(body[1], 'live_in', ('a', 'x')) + self._check_anno_matches(body[2], 'live_in', ('x')) + + def test_live_possibly_unused_cond(self): + + def f4(x, a): + if a > 0: # x should be live + x = 0 + x += 1 + + body = self._get_live_annotated_fnbody(f4) + self._check_anno_matches(body[0], 'live_in', ('x', 'a')) + self._check_anno_matches(body[1], 'live_in', ('x')) + + def test_live_attribute_in_cond(self): + + def f5(x, a): + if a > 0: # x.y should be live + x.y = 0 + return x.y + + body = self._get_live_annotated_fnbody(f5) + self._check_anno_matches(body[0], 'live_in', ('x', 'x.y', 'a')) + + def test_live_noop(self): + + def f6(x): + return x # should this cause x.* to be live? + + body = self._get_live_annotated_fnbody(f6) + self._check_anno_matches(body[0], 'live_in', ('x')) + + def test_live_loop(self): + + def f7(x, n): + for i in range(n): + x += i + return x + + body = self._get_live_annotated_fnbody(f7) + self._check_anno_matches(body[0], 'live_in', ('x', 'n', 'range')) + self._check_anno_matches(body[1], 'live_in', ('x')) + + def test_live_context_manager(self): + + def f8(x, f): + with f: + x += 1 + + body = self._get_live_annotated_fnbody(f8) + self._check_anno_matches(body[0], 'live_in', ('f', 'x')) + + def test_node_equality(self): + node_a = gast.parse('y = x').body[0] + node_b = gast.parse('y = x').body[0] + self.assertNotEqual(node_a, node_b) + + def test_nested_functions_defined(self): + + def f(x): + y = x * 2 + + def g(z): + return z + y + + return g(x) + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + + body = node.body[0].body + self.assertEqual( + anno.getanno(body[2], 'defined_in'), + frozenset(map(qual_names.QN, ('g', 'x', 'y')))) + + # TODO(alexbw): CFG analysis doesn't currently cross FunctionDef boundaries. + # NOTE: 'z' is easy to find, but 'y' is not identified as + # defined, because CFG analysis is applied with each function separately. + # fndef_body = body[1].body + # self.assertEqual( + # anno.getanno(fndef_body[0], 'defined_in'), + # frozenset(map(qual_names.QN, ('z', 'y')))) + + def test_nested_functions_dont_leak_definitions(self): + + def f(x): + print(x) + + def g(): + y = 2 + return y + + return g() # y is not defined here + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + body = node.body[0].body + self.assertEqual( + anno.getanno(body[2], 'defined_in'), + frozenset(map(qual_names.QN, ('x', 'g')))) + + def test_loop_else(self): + + # Disabling useless-else-on-loop error, because 'break' and 'continue' + # canonicalization are a separate analysis pass, and here we test + # the CFG analysis in isolation. + def for_orelse(x): + y = 0 + for i in range(len(x)): + x += i + else: # pylint: disable=useless-else-on-loop + y = 1 + return x, y + + def while_orelse(x, i): + y = 0 + while x < 10: + x += i + else: # pylint: disable=useless-else-on-loop + y = 1 + return x, y + + for f in (for_orelse, while_orelse): + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) + body = node.body[0].body + return_node = body[-1] + reaching_defs = anno.getanno(return_node, 'definitions_in') + + # Y could be defined by Assign(Num(0)) or Assign(Num(1)) + # X could be defined as an argument or an AugAssign. + y_defs = [node for var, node in reaching_defs if str(var) == 'y'] + x_defs = [node for var, node in reaching_defs if str(var) == 'x'] + + self.assertEqual(set((gast.Assign,)), set(type(def_) for def_ in y_defs)) + self.assertEqual(set((0, 1)), set(def_.value.n for def_ in y_defs)) + self.assertEqual(len(y_defs), 2) + self.assertEqual( + set((gast.arguments, gast.AugAssign)), + set(type(def_) for def_ in x_defs)) + self.assertEqual(len(x_defs), 2) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py similarity index 85% rename from tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py rename to tensorflow/contrib/autograph/pyct/static_analysis/live_values.py index 9c0a9a9e74eccb3d22840032e8f0c2b81e051e7e..53ae15459097baff918432a493edd7360ebf209d 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py @@ -25,9 +25,9 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class LiveValueResolver(transformer.Base): @@ -55,11 +55,19 @@ class LiveValueResolver(transformer.Base): if not symbol_is_local and not symbol_is_param: if node.id in self.literals: anno.setanno(node, 'live_val', self.literals[node.id]) - # TODO(mdan): Could live values have FQNs? i.e. 'a'.join() elif node.id in self.context.namespace: obj = self.context.namespace[node.id] anno.setanno(node, 'live_val', obj) - anno.setanno(node, 'fqn', (obj.__name__,)) + if hasattr(obj, '__name__'): + anno.setanno(node, 'fqn', (obj.__name__,)) + elif hasattr(obj, '__class__'): + obj_class = obj.__class__ + anno.setanno(node, 'fqn', + (obj_class.__module__, obj_class.__name__)) + else: + # If the symbol value is for example a primitive, then it will not + # have a name. + pass else: pass # TODO(mdan): Should we raise an error here? @@ -86,6 +94,7 @@ class LiveValueResolver(transformer.Base): if not hasattr(parent_object, node.attr): raise AttributeError('%s has no attribute %s' % (parent_object, node.attr)) + anno.setanno(node, 'parent_type', type(parent_object)) anno.setanno(node, 'live_val', getattr(parent_object, node.attr)) anno.setanno(node, 'fqn', anno.getanno(node.value, 'fqn') + (node.attr,)) # TODO(mdan): Investigate the role built-in annotations can play here. @@ -96,6 +105,7 @@ class LiveValueResolver(transformer.Base): # This would not hold for dynamic members like function attributes. # For the dynamic case, we simply leave the node without an annotation, # and let downstream consumers figure out what to do. + anno.setanno(node, 'parent_type', parent_type) anno.setanno(node, 'live_val', getattr(parent_type, node.attr)) anno.setanno(node, 'fqn', anno.getanno(node.value, 'type_fqn') + (node.attr,)) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py similarity index 75% rename from tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py rename to tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py index 9f64689401e3594a77fbdd7b6f02880bd6e90492..69e428bde109ed43c3cdda1a94970a832dc47852 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py @@ -18,13 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info +import six + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.framework import constant_op from tensorflow.python.platform import test @@ -46,6 +48,7 @@ class LiveValuesResolverTest(test.TestCase): namespace=namespace, arg_values=None, arg_types=arg_types, + owner_type=None, recursive=True) node = qual_names.resolve(node) node = activity.resolve(node, ctx) @@ -56,13 +59,30 @@ class LiveValuesResolverTest(test.TestCase): def test_literals(self): + a = None + def test_fn(): - return Foo # pylint: disable=undefined-variable + return a - node = self._parse_and_analyze(test_fn, {}, {'Foo': 'bar'}) + node = self._parse_and_analyze(test_fn, {}, literals={'a': 'bar'}) retval_node = node.body[0].body[0].value self.assertEquals('bar', anno.getanno(retval_node, 'live_val')) + def test_primitive_values(self): + + a = None + + def test_fn(): + return a + + node = self._parse_and_analyze(test_fn, {'a': True}) + retval_node = node.body[0].body[0].value + if six.PY2: + self.assertEqual( + anno.getanno(retval_node, 'fqn'), ('__builtin__', 'bool')) + else: + self.assertEqual(anno.getanno(retval_node, 'fqn'), ('builtins', 'bool')) + def test_namespace(self): def foo(): @@ -102,6 +122,7 @@ class LiveValuesResolverTest(test.TestCase): arg_types={'self': (TestClass.__name__, TestClass)}) func_node = node.body[0].body[0].value.func self.assertEquals(TestClass.member, anno.getanno(func_node, 'live_val')) + self.assertEquals(TestClass, anno.getanno(func_node, 'parent_type')) self.assertEquals(('TestClass', 'member'), anno.getanno(func_node, 'fqn')) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py new file mode 100644 index 0000000000000000000000000000000000000000..7d1e65c958d7787ef5ed707d4822d14a83092975 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py @@ -0,0 +1,242 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Type resolution. + +This analyzer uses known live values to further infer object types. This +may include for instance constructed objects and object member functions. + +In addition, the analyzer also handles user annotations made in the code (for +example, the autograph.set_element_type function). + +Requires annotations generated by LiveValuesResolver. +""" + +# TODO(mdan): This would be more robust with a CFG. +# Situations with multiple reaching modifications (e.g. modified inside and +# outside a control flow statement) should be more robustly detected and +# analyzed. + +# TODO(mdan): Look into using Python AST's type annotation fields instead. +# It would be desirable to use that mechanism if we can. +# Some caveats to consider: We may need to annotate other nodes like +# Attribute. It may also not be feasible for us to faithfully to replicate +# PY3's type annotations where it isn't available. It would also require us +# to design rigorous type definitions that can accommodate Python types +# as well as TensorFLow dtypes and shapes. + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.python.util import tf_inspect + + +# TODO(mdan): Remove the duplication between this and activity.py. +# In particular, the symbol definitions we track here could as well be tracked +# there because they follow the same rules for visibility. +class Scope(object): + """Tracks symbol value references. + + Attributes: + values: A dict mapping string to gast.Node, containing the value that was + most recently assigned to the symbol. + """ + + def __init__(self, parent): + """Create a new scope. + + Args: + parent: A Scope or None. + """ + self.parent = parent + self.values = {} + + def __repr__(self): + return 'Scope[%s]' % self.values.keys() + + def copy(self): + s = Scope(self.parent) + s.values = self.values.copy() + return s + + def setval(self, name, value): + self.values[name] = value + + def hasval(self, name): + return (name in self.values or + (self.parent is not None and self.parent.hasval(name))) + + def getval(self, name): + if name in self.values: + return self.values[name] + if self.parent is not None: + return self.parent.getval(name) + raise KeyError(name) + + +class TypeInfoResolver(transformer.Base): + """Annotates symbols with type information where possible. + + Nodes currently annotated: + * Call (helps detect class constructors) + * Attribute (helps resolve object methods) + """ + + def __init__(self, context): + super(TypeInfoResolver, self).__init__(context) + self.scope = Scope(None) + + def visit_FunctionDef(self, node): + self.scope = Scope(self.scope) + node = self.generic_visit(node) + self.scope = self.scope.parent + return node + + def _visit_block(self, block): + self.scope = Scope(self.scope) + block = self.visit_block(block) + self.scope = self.scope.parent + return block + + def visit_For(self, node): + self.generic_visit(node.target) + self.generic_visit(node.iter) + node.body = self._visit_block(node.body) + node.orelse = self._visit_block(node.orelse) + return node + + def visit_While(self, node): + self.generic_visit(node.test) + node.body = self._visit_block(node.body) + node.orelse = self._visit_block(node.orelse) + return node + + def visit_If(self, node): + self.generic_visit(node.test) + node.body = self._visit_block(node.body) + node.orelse = self._visit_block(node.orelse) + return node + + def _process_function_arg(self, arg_name): + str_name = str(arg_name) + type_holder = arg_name.ast() + self.scope.setval(arg_name, type_holder) + if len(self.enclosing_entities) == 1 and str_name in self.context.arg_types: + # Forge a node to hold the type information, so that method calls on + # it can resolve the type. + type_string, type_obj = self.context.arg_types[str_name] + anno.setanno(type_holder, 'type', type_obj) + anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.'))) + + def visit_arg(self, node): + self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN)) + return node + + def visit_Name(self, node): + self.generic_visit(node) + qn = anno.getanno(node, anno.Basic.QN) + if isinstance(node.ctx, gast.Param): + self._process_function_arg(qn) + elif isinstance(node.ctx, gast.Load) and self.scope.hasval(qn): + # E.g. if we had + # a = b + # then for future references to `a` we should have definition = `b` + definition = self.scope.getval(qn) + anno.copyanno(definition, node, 'type') + anno.copyanno(definition, node, 'type_fqn') + anno.copyanno(definition, node, 'element_type') + anno.copyanno(definition, node, 'element_shape') + return node + + def _process_variable_assignment(self, target, value): + # Constructors + if isinstance(value, gast.Call): + func = value.func + if anno.hasanno(func, 'live_val'): + func_obj = anno.getanno(func, 'live_val') + if tf_inspect.isclass(func_obj): + anno.setanno(value, 'is_constructor', True) + anno.setanno(value, 'type', func_obj) + anno.setanno(value, 'type_fqn', anno.getanno(func, 'fqn')) + # TODO(mdan): Raise an error if constructor has side effects. + # We can have a whitelist of no-side-effects constructors. + # We can also step inside the constructor and further analyze. + + if isinstance(target, (gast.Name, gast.Attribute)): + target_symbol = anno.getanno(target, anno.Basic.QN) + self.scope.setval(target_symbol, value) + elif isinstance(target, gast.Subscript): + pass + else: + raise ValueError('assignment target has unknown type: %s' % target) + + def visit_With(self, node): + for item in node.items: + if item.optional_vars is not None: + self.apply_to_single_assignments((item.optional_vars,), + item.context_expr, + self._process_variable_assignment) + self.generic_visit(node) + return node + + def visit_Assign(self, node): + self.generic_visit(node) + self.apply_to_single_assignments( + node.targets, node.value, self._process_variable_assignment) + return node + + def visit_Call(self, node): + if anno.hasanno(node.func, 'live_val'): + # Symbols targeted by the "set_type" marker function are assigned the data + # type that it specified. + if (anno.getanno(node.func, 'live_val') is + self.context.type_annotation_func): + + if len(node.args) < 2 or len(node.args) > 3: + raise ValueError('"%s" must have either two or three parameters' + % self.context.type_annotation_func) + if len(node.args) == 2: + target_arg, type_arg = node.args + shape_arg = parser.parse_expression('None') + else: + target_arg, type_arg, shape_arg = node.args + if not anno.hasanno(target_arg, anno.Basic.QN): + raise ValueError('the first argument of "%s" must by a symbol' + % self.context.type_annotation_func) + # TODO(mdan): This is vulnerable to symbol renaming. + element_type = type_arg + element_shape = shape_arg + + target_symbol = anno.getanno(target_arg, anno.Basic.QN) + # Find the definition of this symbol and annotate it with the given + # data type. That in turn will cause future uses of the symbol + # to receive the same type annotation. + definition = self.scope.getval(target_symbol) + anno.setanno(node, 'element_type', element_type) + anno.setanno(node, 'element_shape', element_shape) + anno.setanno(definition, 'element_type', element_type) + anno.setanno(definition, 'element_shape', element_shape) + # TODO(mdan): Should we update references between definition and here? + return self.generic_visit(node) + + +def resolve(node, context): + return TypeInfoResolver(context).visit(node) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py similarity index 66% rename from tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py rename to tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py index 3659f949db9910534870d8dd9e42fd4ee8297253..484562f294bb53a63feeca965b8f94c58aa2a685 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py @@ -18,13 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.client import session from tensorflow.python.platform import test from tensorflow.python.training import training @@ -56,7 +57,10 @@ class ScopeTest(test.TestCase): class TypeInfoResolverTest(test.TestCase): - def _parse_and_analyze(self, test_fn, namespace, arg_types=None): + def _parse_and_analyze(self, + test_fn, + namespace, + arg_types=None): node, source = parser.parse_entity(test_fn) ctx = context.EntityContext( namer=None, @@ -65,7 +69,9 @@ class TypeInfoResolverTest(test.TestCase): namespace=namespace, arg_values=None, arg_types=arg_types, - recursive=True) + owner_type=None, + recursive=True, + type_annotation_func=utils.set_element_type) node = qual_names.resolve(node) node = activity.resolve(node, ctx) node = live_values.resolve(node, ctx, {}) @@ -174,6 +180,75 @@ class TypeInfoResolverTest(test.TestCase): method_call = node.body[0].body[1].value.func self.assertFalse(anno.hasanno(method_call, 'live_val')) + def test_type_annotation(self): + + class Foo(object): + pass + + def test_fn(): + f = [] + f = utils.set_element_type(f, Foo, (1, 2, 3)) + return f + + node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) + f_def = node.body[0].body[0].value + self.assertEqual(anno.getanno(f_def, 'element_type').id, 'Foo') + f_ref = node.body[0].body[1].value + self.assertEqual(anno.getanno(f_ref, 'element_type').id, 'Foo') + + def test_type_annotation_args(self): + + class Foo(object): + pass + + def test_fn(f): + utils.set_element_type(f, Foo) + return f + + node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) + f_ref = node.body[0].body[1].value + self.assertEqual(anno.getanno(f_ref, 'element_type').id, 'Foo') + + def test_nested_unpacking(self): + + class Foo(object): + pass + + class Bar(object): + pass + + def test_fn(): + a, (b, c) = (Foo(), (Bar(), Foo())) + return a, b, c + + node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar}) + a, b, c = node.body[0].body[1].value.elts + self.assertEquals(anno.getanno(a, 'type'), Foo) + self.assertEquals(anno.getanno(b, 'type'), Bar) + self.assertEquals(anno.getanno(c, 'type'), Foo) + self.assertFalse(anno.hasanno(a, 'live_val')) + self.assertFalse(anno.hasanno(b, 'live_val')) + self.assertFalse(anno.hasanno(c, 'live_val')) + + def test_inner_scope(self): + + def test_fn(): + a = [] + utils.set_element_type(a, 1) + for _ in a: + b = [] + utils.set_element_type(b, 2) + return a, b + + node = self._parse_and_analyze(test_fn, {'utils': utils}) + a, b = node.body[0].body[2].body[2].value.elts + self.assertEquals(anno.getanno(a, 'element_type').n, 1) + self.assertEquals(anno.getanno(b, 'element_type').n, 2) + self.assertFalse(anno.hasanno(a, 'type')) + self.assertFalse(anno.hasanno(b, 'type')) + self.assertFalse(anno.hasanno(a, 'live_val')) + self.assertFalse(anno.hasanno(b, 'live_val')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py similarity index 55% rename from tensorflow/contrib/py2tf/pyct/templates.py rename to tensorflow/contrib/autograph/pyct/templates.py index 6ee6c0c5ceb70d87779ee313670135cadc5214b5..9c479ebc2fa83d27dc363ae306daedb556734a1f 100644 --- a/tensorflow/contrib/py2tf/pyct/templates.py +++ b/tensorflow/contrib/autograph/pyct/templates.py @@ -26,9 +26,9 @@ import textwrap import gast -from tensorflow.contrib.py2tf.pyct import ast_util -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names class ReplaceTransformer(gast.NodeTransformer): @@ -44,8 +44,6 @@ class ReplaceTransformer(gast.NodeTransformer): self.replacements = replacements self.in_replacements = False - # TODO(mdan): Make a more detailed pass and clean up if needed. - def visit_Expr(self, node): if (isinstance(node.value, gast.Name) and node.value.id in self.replacements): @@ -53,17 +51,66 @@ class ReplaceTransformer(gast.NodeTransformer): self.generic_visit(node) return node + def visit_keyword(self, node): + if node.arg in self.replacements: + repl = self.replacements[node.arg] + if isinstance(repl, gast.keyword): + return repl + elif (isinstance(repl, (list, tuple)) and repl and + all(isinstance(r, gast.keyword) for r in repl)): + return repl + # TODO(mdan): We may allow replacing with a string as well. + # For example, if one wanted to replace foo with bar in foo=baz, then + # we could allow changing just node arg, so that we end up with bar=baz. + raise ValueError( + 'a keyword argument may only be replaced by another keyword or a ' + 'non-empty list of keywords. Found: %s' % repl) + return self.generic_visit(node) + def visit_FunctionDef(self, node): node = self.generic_visit(node) if node.name in self.replacements: repl = self.replacements[node.name] if not isinstance(repl, (gast.Name, ast.Name)): raise ValueError( - 'A function name can only be replaced by a Name node. Found: %s' % + 'a function name can only be replaced by a Name node. Found: %s' % repl) node.name = repl.id return node + def _check_has_context(self, node): + if not node.ctx: + raise ValueError('node %s is missing ctx value' % node) + + def _check_inner_children_have_context(self, node): + if isinstance(node, gast.Attribute): + self._check_inner_children_have_context(node.value) + self._check_has_context(node) + elif isinstance(node, gast.Tuple): + for e in node.elts: + self._check_inner_children_have_context(e) + self._check_has_context(node) + elif isinstance(node, gast.Dict): + for e in node.keys: + self._check_inner_children_have_context(e) + for e in node.values: + self._check_inner_children_have_context(e) + elif isinstance(node, gast.Subscript): + self._check_inner_children_have_context(node.value) + self._check_inner_children_have_context(node.slice) + elif isinstance(node, gast.Slice): + self._check_inner_children_have_context(node.lower) + if node.upper: + self._check_inner_children_have_context(node.upper) + if node.step: + self._check_inner_children_have_context(node.step) + elif isinstance(node, gast.Name): + self._check_has_context(node) + elif isinstance(node, (gast.Str, gast.Num)): + pass + else: + raise ValueError('unexpected node type "%s"' % node) + def _set_inner_child_context(self, node, ctx): if isinstance(node, gast.Attribute): self._set_inner_child_context(node.value, ctx) @@ -74,11 +121,40 @@ class ReplaceTransformer(gast.NodeTransformer): node.ctx = ctx elif isinstance(node, gast.Name): node.ctx = ctx + elif isinstance(node, gast.Call): + self._set_inner_child_context(node.func, ctx) + # We may be able to override these to Load(), but for now it's simpler + # to just assert that they're set. + for a in node.args: + self._check_inner_children_have_context(a) + for k in node.keywords: + self._check_inner_children_have_context(k.value) + elif isinstance(node, gast.Dict): + # We may be able to override these to Load(), but for now it's simpler + # to just assert that they're set. + for e in node.keys: + self._check_inner_children_have_context(e) + for e in node.values: + self._check_inner_children_have_context(e) + elif isinstance(node, gast.Subscript): + self._set_inner_child_context(node.value, ctx) + self._check_inner_children_have_context(node.slice) elif isinstance(node, (gast.Str, gast.Num)): pass else: raise ValueError('unexpected node type "%s"' % node) + def visit_Attribute(self, node): + node = self.generic_visit(node) + if node.attr not in self.replacements: + return node + repl = self.replacements[node.attr] + if not isinstance(repl, gast.Name): + raise ValueError( + 'An attribute can only be replaced by a Name node. Found: %s' % repl) + node.attr = repl.id + return node + def visit_Name(self, node): if node.id not in self.replacements: return node @@ -154,3 +230,22 @@ def replace(template, **replacements): if isinstance(results, list): return [qual_names.resolve(r) for r in results] return qual_names.resolve(results) + + +def replace_as_expression(template, **replacements): + """Variant of replace that generates expressions, instead of code blocks.""" + replacement = replace(template, **replacements) + if len(replacement) != 1: + raise ValueError( + 'single expression expected; for more general templates use replace') + node = replacement[0] + node = qual_names.resolve(node) + + if isinstance(node, gast.Expr): + return node.value + elif isinstance(node, gast.Name): + return node + + raise ValueError( + 'the template is expected to generate an expression or a name node;' + ' instead found %s' % node) diff --git a/tensorflow/contrib/autograph/pyct/templates_test.py b/tensorflow/contrib/autograph/pyct/templates_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a01f8bf04c4faa6ec1779e0fb306155d99f5bd09 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/templates_test.py @@ -0,0 +1,168 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for templates module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import imp + +import gast + +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.python.platform import test + + +class TemplatesTest(test.TestCase): + + def test_replace_tuple(self): + template = """ + def test_fn(a, c): + return b, + """ + + node = templates.replace(template, b=('a', 'c'))[0] + result, _ = compiler.ast_to_object(node) + + self.assertEquals((2, 3), result.test_fn(2, 3)) + + def test_replace_variable(self): + template = """ + def test_fn(a): + a += 1 + a = 2 * a + 1 + return b + """ + + node = templates.replace(template, a='b')[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(7, result.test_fn(2)) + + def test_replace_function_name(self): + template = """ + def fname(a): + a += 1 + a = 2 * a + 1 + return a + """ + + node = templates.replace(template, fname='test_fn')[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(7, result.test_fn(2)) + + def test_replace_code_block(self): + template = """ + def test_fn(a): + block + return a + """ + + node = templates.replace( + template, + block=[ + gast.Assign([ + gast.Name('a', None, None) + ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))), + ] * 2)[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(3, result.test_fn(1)) + + def test_replace_attribute(self): + template = """ + def test_fn(a): + return a.foo + """ + + node = templates.replace(template, foo='b')[0] + result, _ = compiler.ast_to_object(node) + mod = imp.new_module('test') + mod.b = 3 + self.assertEquals(3, result.test_fn(mod)) + + with self.assertRaises(ValueError): + templates.replace(template, foo=1) + + def test_replace_call_keyword(self): + template = """ + def test_fn(): + def f(a, d, f): + return a + d + f + return f(1, kws=None) + """ + + source = parser.parse_expression('f(d=3, f=5)') + node = templates.replace(template, kws=source.keywords)[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(9, result.test_fn()) + + with self.assertRaises(ValueError): + templates.replace(template, kws=[]) + templates.replace(template, kws=1) + + def test_replace_name_with_call(self): + template = """ + def test_fn(): + b = 5 + def g(a): + return 3 * a + def f(): + return g + return foo + """ + + source = parser.parse_expression('f()(b)') + node = templates.replace(template, foo=source)[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(15, result.test_fn()) + + def test_replace_name_with_dict(self): + template = """ + def test_fn(): + return foo['bar'] + """ + + source = parser.parse_expression('{\'bar\': 3}') + node = templates.replace(template, foo=source)[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(3, result.test_fn()) + + def replace_as_expression(self): + template = """ + foo(a) + """ + + node = templates.replace(template, foo='bar', a='baz') + self.assertTrue(node is gast.Call) + self.assertEqual(node.func.id, 'bar') + self.assertEqual(node.func.args[0].id, 'baz') + + def replace_as_expression_restrictions(self): + template = """ + foo(a) + bar(b) + """ + with self.assertRaises(ValueError): + templates.replace_as_expression(template) + with self.assertRaises(ValueError): + templates.replace('') + with self.assertRaises(ValueError): + templates.replace('a = b') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..60bca8b38dcf62b4e997379d075cfc45511a894f --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -0,0 +1,291 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A node transformer that includes utilities for SCT.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import gast +import six + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import pretty_printer + + +class AutographParseError(SyntaxError): + pass + + +def try_ast_to_source(node): + try: + return compiler.ast_to_source(node) + except AssertionError: + return '' + + +class Base(gast.NodeTransformer): + """Base class for specialized transformers. + + Scope-local state tracking: to keep state across nodes, at the level of + (possibly nested) scopes, use enter/exit_local_scope and set/get_local. + You must call enter/exit_local_scope manually, but the transformer detects + when they are not properly paired. + """ + + def __init__(self, context): + """Initialize the transformer. Subclasses should call this. + + Args: + context: An EntityContext. + """ + self._lineno = 0 + self._col_offset = 0 + self.context = context + self._enclosing_entities = [] + + # A stack that allows keeping mutable, scope-local state where scopes may be + # nested. For example, it can be used to track the usage of break + # statements in each loop, where loops may be nested. + self._local_scope_state = [] + self.enter_local_scope() + + @property + def enclosing_entities(self): + return tuple(self._enclosing_entities) + + @property + def local_scope_level(self): + return len(self._local_scope_state) + + def enter_local_scope(self, inherit=None): + """Marks entry into a new local scope. + + Args: + inherit: Optional enumerable of variable names to copy from the + parent scope. + """ + scope_entered = {} + if inherit: + this_scope = self._local_scope_state[-1] + for name in inherit: + if name in this_scope: + scope_entered[name] = this_scope[name] + self._local_scope_state.append(scope_entered) + + def exit_local_scope(self, keep=None): + """Marks exit from the current local scope. + + Args: + keep: Optional enumerable of variable names to copy into the + parent scope. + Returns: + A dict containing the scope that has just been exited. + """ + scope_left = self._local_scope_state.pop() + if keep: + this_scope = self._local_scope_state[-1] + for name in keep: + if name in scope_left: + this_scope[name] = scope_left[name] + return scope_left + + def set_local(self, name, value): + self._local_scope_state[-1][name] = value + + def get_local(self, name, default=None): + return self._local_scope_state[-1].get(name, default) + + def debug_print(self, node): + """Helper method useful for debugging.""" + if __debug__: + print(pretty_printer.fmt(node)) + return node + + def visit_block(self, nodes, before_visit=None, after_visit=None): + """A more powerful version of generic_visit for statement blocks. + + An example of a block is the body of an if statement. + + This function allows specifying a postprocessing callback (the + after_visit argument) argument which can be used to move nodes to a new + destination. This is done by after_visit by returning a non-null + second return value, e.g. return new_node, new_destination. + + For example, a transformer could perform the following move: + + foo() + bar() + baz() + + foo() + if cond: + bar() + baz() + + The above could be done with a postprocessor of this kind: + + def after_visit(node): + if node_is_function_call(bar): + new_container_node = build_cond() + new_container_node.body.append(node) + return new_container_node, new_container_node.body + else: + # Once we set a new destination, all subsequent items will be + # moved to it, so we don't need to explicitly handle baz. + return node, None + + Args: + nodes: enumerable of AST node objects + before_visit: optional callable that is called before visiting each item + in nodes + after_visit: optional callable that takes in an AST node and + returns a tuple (new_node, new_destination). It is called after + visiting each item in nodes. Is used in the same was as the + visit_* methods: new_node will replace the node; if not None, + new_destination must be a list, and subsequent nodes will be placed + in this list instead of the list returned by visit_block. + Returns: + A list of AST node objects containing the transformed items fron nodes, + except those nodes that have been relocated using after_visit. + """ + results = [] + node_destination = results + for node in nodes: + if before_visit: + # TODO(mdan): We can modify node here too, if ever needed. + before_visit() + + replacement = self.visit(node) + + if after_visit and replacement: + replacement, new_destination = after_visit(replacement) + else: + new_destination = None + + if replacement: + if isinstance(replacement, (list, tuple)): + node_destination.extend(replacement) + else: + node_destination.append(replacement) + + # Allow the postprocessor to reroute the remaining nodes to a new list. + if new_destination is not None: + node_destination = new_destination + return results + + # TODO(mdan): Once we have error tracing, we may be able to just go to SSA. + def apply_to_single_assignments(self, targets, values, apply_fn): + """Applies a fuction to each individual assignment. + + This function can process a possibly-unpacked (e.g. a, b = c, d) assignment. + It tries to break down the unpacking if possible. In effect, it has the same + effect as passing the assigned values in SSA form to apply_fn. + + Examples: + + The following will result in apply_fn(a, c), apply_fn(b, d): + + a, b = c, d + + The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]): + + a, b = c + + The following will result in apply_fn(a, (b, c)): + + a = b, c + + It uses the visitor pattern to allow subclasses to process single + assignments individually. + + Args: + targets: list, tuple of or individual AST node. Should be used with the + targets field of an ast.Assign node. + values: an AST node. + apply_fn: a function of a single argument, which will be called with the + respective nodes of each single assignment. The signaure is + apply_fn(target, value), no return value. + """ + if not isinstance(targets, (list, tuple)): + targets = (targets,) + for target in targets: + if isinstance(target, (gast.Tuple, gast.List)): + for i in range(len(target.elts)): + target_el = target.elts[i] + if isinstance(values, (gast.Tuple, gast.List)): + value_el = values.elts[i] + else: + value_el = gast.Subscript(values, gast.Index(i), ctx=gast.Store()) + self.apply_to_single_assignments(target_el, value_el, apply_fn) + else: + # TODO(mdan): Look into allowing to rewrite the AST here. + apply_fn(target, values) + + def visit(self, node): + source_code = self.context.source_code + source_file = self.context.source_file + did_enter_function = False + local_scope_size_at_entry = len(self._local_scope_state) + + try: + if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)): + did_enter_function = True + + if did_enter_function: + self._enclosing_entities.append(node) + + if source_code and hasattr(node, 'lineno'): + self._lineno = node.lineno + self._col_offset = node.col_offset + + if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING): + result = super(Base, self).visit(node) + + # On exception, the local scope integrity is not guaranteed. + if did_enter_function: + self._enclosing_entities.pop() + + if local_scope_size_at_entry != len(self._local_scope_state): + raise AssertionError( + 'Inconsistent local scope stack. Before entering node %s, the' + ' stack had length %d, after exit it has length %d. This' + ' indicates enter_local_scope and exit_local_scope are not' + ' well paired.' % ( + node, + local_scope_size_at_entry, + len(self._local_scope_state) + )) + return result + + except (ValueError, AttributeError, KeyError, NotImplementedError) as e: + msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % ( + e.__class__.__name__, str(e), try_ast_to_source(node), + pretty_printer.fmt(node, color=False)) + if source_code: + line = source_code.splitlines()[self._lineno - 1] + else: + line = '' + # TODO(mdan): Avoid the printing of the original exception. + # In other words, we need to find how to suppress the "During handling + # of the above exception, another exception occurred" message. + six.reraise(AutographParseError, + AutographParseError( + msg, + (source_file, self._lineno, self._col_offset + 1, line)), + sys.exc_info()[2]) diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f110e79605945e908e8a49112cf758ec29fa1b11 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/transformer_test.py @@ -0,0 +1,213 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for templates module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.python.platform import test + + +class TransformerTest(test.TestCase): + + def _context_for_testing(self): + return context.EntityContext( + namer=None, + source_code=None, + source_file=None, + namespace=None, + arg_values=None, + arg_types=None, + owner_type=None, + recursive=False) + + def test_entity_scope_tracking(self): + + class TestTransformer(transformer.Base): + + # The choice of note to assign to is arbitrary. Using Assign because it's + # easy to find in the tree. + def visit_Assign(self, node): + anno.setanno(node, 'enclosing_entities', self.enclosing_entities) + return self.generic_visit(node) + + # This will show up in the lambda function. + def visit_BinOp(self, node): + anno.setanno(node, 'enclosing_entities', self.enclosing_entities) + return self.generic_visit(node) + + tr = TestTransformer(self._context_for_testing()) + + def test_function(): + a = 0 + + class TestClass(object): + + def test_method(self): + b = 0 + def inner_function(x): + c = 0 + d = lambda y: (x + y) + return c, d + return b, inner_function + return a, TestClass + + node, _ = parser.parse_entity(test_function) + node = tr.visit(node) + + test_function_node = node.body[0] + test_class = test_function_node.body[1] + test_method = test_class.body[0] + inner_function = test_method.body[1] + lambda_node = inner_function.body[1].value + + a = test_function_node.body[0] + b = test_method.body[0] + c = inner_function.body[0] + lambda_expr = lambda_node.body + + self.assertEqual( + (test_function_node,), anno.getanno(a, 'enclosing_entities')) + self.assertEqual((test_function_node, test_class, test_method), + anno.getanno(b, 'enclosing_entities')) + self.assertEqual( + (test_function_node, test_class, test_method, inner_function), + anno.getanno(c, 'enclosing_entities')) + self.assertEqual((test_function_node, test_class, test_method, + inner_function, lambda_node), + anno.getanno(lambda_expr, 'enclosing_entities')) + + def test_local_scope_info_stack(self): + + class TestTransformer(transformer.Base): + + # Extract all string constants from the block. + def visit_Str(self, node): + self.set_local('string', self.get_local('string', default='') + node.s) + return self.generic_visit(node) + + def _annotate_result(self, node): + self.enter_local_scope() + node = self.generic_visit(node) + anno.setanno(node, 'test', self.get_local('string')) + self.exit_local_scope() + return node + + def visit_While(self, node): + return self._annotate_result(node) + + def visit_For(self, node): + return self._annotate_result(node) + + tr = TestTransformer(self._context_for_testing()) + + def test_function(a): + """Docstring.""" + assert a == 'This should not be counted' + for i in range(3): + _ = 'a' + if i > 2: + return 'b' + else: + _ = 'c' + while True: + raise '1' + return 'nor this' + + node, _ = parser.parse_entity(test_function) + node = tr.visit(node) + + for_node = node.body[0].body[2] + while_node = for_node.body[1].orelse[1] + + self.assertFalse(anno.hasanno(for_node, 'string')) + self.assertEqual('abc', anno.getanno(for_node, 'test')) + self.assertFalse(anno.hasanno(while_node, 'string')) + self.assertEqual('1', anno.getanno(while_node, 'test')) + + def test_local_scope_info_stack_checks_integrity(self): + + class TestTransformer(transformer.Base): + + def visit_If(self, node): + self.enter_local_scope() + return self.generic_visit(node) + + def visit_For(self, node): + node = self.generic_visit(node) + self.exit_local_scope() + return node + + tr = TestTransformer(self._context_for_testing()) + + def no_exit(a): + if a > 0: + print(a) + return None + + node, _ = parser.parse_entity(no_exit) + with self.assertRaises(AssertionError): + tr.visit(node) + + def no_entry(a): + for _ in a: + print(a) + + node, _ = parser.parse_entity(no_entry) + with self.assertRaises(AssertionError): + tr.visit(node) + + def test_visit_block_postprocessing(self): + + class TestTransformer(transformer.Base): + + def _process_body_item(self, node): + if isinstance(node, gast.Assign) and (node.value.id == 'y'): + if_node = gast.If(gast.Name('x', gast.Load(), None), [node], []) + return if_node, if_node.body + return node, None + + def visit_FunctionDef(self, node): + node.body = self.visit_block( + node.body, after_visit=self._process_body_item) + return node + + def test_function(x, y): + z = x + z = y + return z + + tr = TestTransformer(self._context_for_testing()) + + node, _ = parser.parse_entity(test_function) + node = tr.visit(node) + node = node.body[0] + + self.assertEqual(len(node.body), 2) + self.assertTrue(isinstance(node.body[0], gast.Assign)) + self.assertTrue(isinstance(node.body[1], gast.If)) + self.assertTrue(isinstance(node.body[1].body[0], gast.Assign)) + self.assertTrue(isinstance(node.body[1].body[1], gast.Return)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD similarity index 85% rename from tensorflow/contrib/py2tf/utils/BUILD rename to tensorflow/contrib/autograph/utils/BUILD index c2fdd40707775783140390e4b5c0186c9c3e562e..d82c17bf2afd01aedf4344f983b02c09abcb9bad 100644 --- a/tensorflow/contrib/py2tf/utils/BUILD +++ b/tensorflow/contrib/autograph/utils/BUILD @@ -20,26 +20,33 @@ py_library( name = "utils", srcs = [ "__init__.py", + "builtins.py", "context_managers.py", "misc.py", "multiple_dispatch.py", - "printing.py", "py_func.py", "tensor_list.py", + "testing.py", "type_check.py", + "type_hints.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:dtypes", + "//tensorflow/python:list_ops", "//tensorflow/python:script_ops", + "//tensorflow/python/data/ops:dataset_ops", "@six_archive//:six", ], ) py_test( - name = "context_managers_test", - srcs = ["context_managers_test.py"], + name = "builtins_test", + srcs = ["builtins_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":utils", "//tensorflow/python:client_testlib", @@ -47,8 +54,8 @@ py_test( ) py_test( - name = "misc_test", - srcs = ["misc_test.py"], + name = "context_managers_test", + srcs = ["context_managers_test.py"], srcs_version = "PY2AND3", deps = [ ":utils", @@ -57,8 +64,8 @@ py_test( ) py_test( - name = "multiple_dispatch_test", - srcs = ["multiple_dispatch_test.py"], + name = "misc_test", + srcs = ["misc_test.py"], srcs_version = "PY2AND3", deps = [ ":utils", @@ -67,8 +74,8 @@ py_test( ) py_test( - name = "py_func_test", - srcs = ["py_func_test.py"], + name = "multiple_dispatch_test", + srcs = ["multiple_dispatch_test.py"], srcs_version = "PY2AND3", deps = [ ":utils", @@ -77,9 +84,10 @@ py_test( ) py_test( - name = "printing_test", - srcs = ["printing_test.py"], + name = "py_func_test", + srcs = ["py_func_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":utils", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..817d4126d106487e1fea3e442712a69bbfccd7f3 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility module that contains APIs usable in the generated code.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.utils.builtins import dynamic_builtin +from tensorflow.contrib.autograph.utils.builtins import dynamic_print +from tensorflow.contrib.autograph.utils.builtins import dynamic_range +from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns +from tensorflow.contrib.autograph.utils.misc import alias_tensors +from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is +from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is_not +from tensorflow.contrib.autograph.utils.multiple_dispatch import run_cond +from tensorflow.contrib.autograph.utils.py_func import wrap_py_func +from tensorflow.contrib.autograph.utils.tensor_list import dynamic_list_append +from tensorflow.contrib.autograph.utils.testing import fake_tf +from tensorflow.contrib.autograph.utils.type_check import is_tensor +from tensorflow.contrib.autograph.utils.type_hints import set_element_type diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py new file mode 100644 index 0000000000000000000000000000000000000000..998087e056c2cd264399982220d6e0528aab9edb --- /dev/null +++ b/tensorflow/contrib/autograph/utils/builtins.py @@ -0,0 +1,126 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Builtin conversion utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import six + +from tensorflow.contrib.autograph.utils import py_func +from tensorflow.contrib.autograph.utils import type_check +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops + + +def dynamic_builtin(f, *args, **kwargs): + """Converts a builtin function call inline.""" + if f is len: + return dynamic_len(*args, **kwargs) + if six.PY2 and f is xrange: + return dynamic_range(*args, **kwargs) + if f is range: + return dynamic_range(*args, **kwargs) + if f is int: + return dynamic_int(*args, **kwargs) + if f is float: + return dynamic_float(*args, **kwargs) + + raise NotImplementedError( + 'The "%s" builtin is not yet supported.' % f.__name__) + + +def dynamic_len(list_or_tensor): + """Implementation of len using dynamic dispatch.""" + if tensor_util.is_tensor(list_or_tensor): + shape = list_or_tensor.shape + if not shape: + raise ValueError( + 'len requires non-zero rank for tensor "%s"' % list_or_tensor) + return array_ops.shape(list_or_tensor)[0] + return len(list_or_tensor) + + +def dynamic_int(num_or_tensor, **kwargs): + """Implementation of int() using dynamic dispatch.""" + if tensor_util.is_tensor(num_or_tensor): + return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs) + return int(num_or_tensor) + + +def dynamic_float(num_or_tensor, **kwargs): + """Implementation of float() using dynamic dispatch.""" + if tensor_util.is_tensor(num_or_tensor): + return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs) + return float(num_or_tensor) + + +def dynamic_range(start_or_stop, stop=None, step=None): + """Implementation of range using dynamic dispatch.""" + if type_check.is_tensor(start_or_stop, stop, step): + if step is not None: + return math_ops.range(start_or_stop, stop, step) + if stop is not None: + return math_ops.range(start_or_stop, stop) + return math_ops.range(start_or_stop) + + if step is not None: + return range(start_or_stop, stop, step) + elif stop is not None: + return range(start_or_stop, stop) + return range(start_or_stop) + + +def is_tf_print_compatible(value): + # TODO(mdan): Enable once we can reliably test this. + # This is currently disabled because we can't capture the output of + # op kernels from Python. + del value + return False + + +def dynamic_print(*values): + """Implementation of print using dynamic dispatch. + + The function attempts to use tf.Print if all the values are compatible. + Otherwise, it will fall back to py_func. + + Args: + *values: values to print + Returns: + A dummy value indicating the print completed. If tf. + """ + + if all(map(is_tf_print_compatible, values)): + return logging_ops.Print(1, values) + + def print_wrapper(*vals): + if six.PY3: + # TensorFlow doesn't seem to generate Unicode when passing strings to + # py_func. This causes the print to add a "b'" wrapper to the output, + # which is probably never what you want. + vals = tuple(v.decode() if isinstance(v, bytes) else v for v in vals) + print(*vals) + # The flush helps avoid garbled output in IPython. + sys.stdout.flush() + + return py_func.wrap_py_func( + print_wrapper, None, values, use_dummy_return=True) diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0c2312178a921037fa419818bf309d671c33914d --- /dev/null +++ b/tensorflow/contrib/autograph/utils/builtins_test.py @@ -0,0 +1,127 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for builtins module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import six + +from tensorflow.contrib.autograph.utils import builtins +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.platform import test + + +class BuiltinsTest(test.TestCase): + + def test_dynamic_len_tf_scalar(self): + a = constant_op.constant(1) + + with self.assertRaises(ValueError): + with self.test_session() as sess: + sess.run(builtins.dynamic_builtin(len, a)) + + def test_dynamic_len_tf_array(self): + a = constant_op.constant([1, 2, 3]) + + with self.test_session() as sess: + self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a))) + + def test_dynamic_len_tf_matrix(self): + a = constant_op.constant([[1, 2], [3, 4]]) + + with self.test_session() as sess: + self.assertEqual(2, sess.run(builtins.dynamic_builtin(len, a))) + + def test_dynamic_len_py_list(self): + a = [3] * 5 + + self.assertEqual(5, builtins.dynamic_builtin(len, a)) + + def test_dynamic_range_all_python(self): + self.assertListEqual(list(builtins.dynamic_builtin(range, 3)), [0, 1, 2]) + self.assertListEqual(list(builtins.dynamic_builtin(range, 1, 3)), [1, 2]) + self.assertListEqual( + list(builtins.dynamic_builtin(range, 2, 0, -1)), [2, 1]) + + def test_dynamic_range_tf(self): + with self.test_session() as sess: + self.assertAllEqual( + sess.run(builtins.dynamic_builtin(range, constant_op.constant(3))), + [0, 1, 2]) + self.assertAllEqual( + sess.run(builtins.dynamic_builtin(range, 1, constant_op.constant(3))), + [1, 2]) + self.assertAllEqual( + sess.run( + builtins.dynamic_builtin(range, 2, 0, constant_op.constant(-1))), + [2, 1]) + + def test_dynamic_range_detection(self): + def range(x): # pylint:disable=redefined-builtin + return x + + # Functions that just have the names of builtins are rejected. + with self.assertRaises(NotImplementedError): + self.assertEqual(builtins.dynamic_builtin(range, 1), 1) + if six.PY2: + self.assertListEqual( + list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2]) + self.assertListEqual( + list(builtins.dynamic_builtin(six.moves.range, 3)), [0, 1, 2]) + self.assertListEqual( + list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2]) + + def test_casts(self): + i = constant_op.constant(2, dtype=dtypes.int32) + f = constant_op.constant(1.0, dtype=dtypes.float32) + + self.assertEqual(builtins.dynamic_builtin(int, i).dtype, dtypes.int32) + self.assertEqual(builtins.dynamic_builtin(int, f).dtype, dtypes.int32) + self.assertEqual(builtins.dynamic_builtin(float, i).dtype, dtypes.float32) + self.assertEqual(builtins.dynamic_builtin(float, f).dtype, dtypes.float32) + + self.assertEqual(builtins.dynamic_builtin(int, True), 1) + self.assertEqual(builtins.dynamic_builtin(int, False), 0) + self.assertEqual(builtins.dynamic_builtin(float, True), 1.0) + self.assertEqual(builtins.dynamic_builtin(float, False), 0.0) + + def test_dynamic_print_tf(self): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + with self.test_session() as sess: + sess.run(builtins.dynamic_print('test message', 1)) + self.assertEqual(out_capturer.getvalue(), 'test message 1\n') + finally: + sys.stdout = sys.__stdout__ + + def test_dynamic_print_complex(self): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + with self.test_session() as sess: + sess.run(builtins.dynamic_print('test message', [1, 2])) + self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n') + finally: + sys.stdout = sys.__stdout__ + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/context_managers.py b/tensorflow/contrib/autograph/utils/context_managers.py similarity index 85% rename from tensorflow/contrib/py2tf/utils/context_managers.py rename to tensorflow/contrib/autograph/utils/context_managers.py index 38d9e11fe9069722b9023fee848bf53e1f72de6a..3d150a95817b83c4d7aaa78dc250092dcc4c5a9b 100644 --- a/tensorflow/contrib/py2tf/utils/context_managers.py +++ b/tensorflow/contrib/autograph/utils/context_managers.py @@ -21,6 +21,7 @@ from __future__ import print_function import contextlib from tensorflow.python.framework import ops +from tensorflow.python.ops import tensor_array_ops def control_dependency_on_returns(return_value): @@ -34,9 +35,15 @@ def control_dependency_on_returns(return_value): Returns: A context manager. """ + def control_dependency_handle(t): + if isinstance(t, tensor_array_ops.TensorArray): + return t.flow + return t + if return_value is None: return contextlib.contextmanager(lambda: (yield))() # TODO(mdan): Filter to tensor objects. if not isinstance(return_value, (list, tuple)): return_value = (return_value,) + return_value = tuple(control_dependency_handle(t) for t in return_value) return ops.control_dependencies(return_value) diff --git a/tensorflow/contrib/py2tf/utils/context_managers_test.py b/tensorflow/contrib/autograph/utils/context_managers_test.py similarity index 82% rename from tensorflow/contrib/py2tf/utils/context_managers_test.py rename to tensorflow/contrib/autograph/utils/context_managers_test.py index 633ba93540e696889a6b2b71b40b999da39d48ff..42e27724b9856f715b524cdd7539897851715638 100644 --- a/tensorflow/contrib/py2tf/utils/context_managers_test.py +++ b/tensorflow/contrib/autograph/utils/context_managers_test.py @@ -18,8 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import context_managers +from tensorflow.contrib.autograph.utils import context_managers from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test @@ -32,6 +34,9 @@ class ContextManagersTest(test.TestCase): with context_managers.control_dependency_on_returns( constant_op.constant(1)): pass + with context_managers.control_dependency_on_returns( + tensor_array_ops.TensorArray(dtypes.int32, size=1)): + pass with context_managers.control_dependency_on_returns( [constant_op.constant(1), constant_op.constant(2)]): diff --git a/tensorflow/contrib/py2tf/utils/misc.py b/tensorflow/contrib/autograph/utils/misc.py similarity index 100% rename from tensorflow/contrib/py2tf/utils/misc.py rename to tensorflow/contrib/autograph/utils/misc.py diff --git a/tensorflow/contrib/py2tf/utils/misc_test.py b/tensorflow/contrib/autograph/utils/misc_test.py similarity index 78% rename from tensorflow/contrib/py2tf/utils/misc_test.py rename to tensorflow/contrib/autograph/utils/misc_test.py index bfcb304c838df69e9e3961907362c7939c065117..71e358c33e1ea9887d267c67bc80362bac26c3a6 100644 --- a/tensorflow/contrib/py2tf/utils/misc_test.py +++ b/tensorflow/contrib/autograph/utils/misc_test.py @@ -18,29 +18,29 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import misc -from tensorflow.python.framework import constant_op -from tensorflow.python.ops import variables +from tensorflow.contrib.autograph.utils.misc import alias_tensors +from tensorflow.python.framework.constant_op import constant +from tensorflow.python.ops.variables import Variable from tensorflow.python.platform import test -class ContextManagersTest(test.TestCase): +class MiscTest(test.TestCase): def test_alias_single_tensor(self): - a = constant_op.constant(1) + a = constant(1) - new_a = misc.alias_tensors(a) + new_a = alias_tensors(a) self.assertFalse(new_a is a) with self.test_session() as sess: self.assertEqual(1, sess.run(new_a)) def test_alias_tensors(self): - a = constant_op.constant(1) - v = variables.Variable(2) + a = constant(1) + v = Variable(2) s = 'a' l = [1, 2, 3] - new_a, new_v, new_s, new_l = misc.alias_tensors(a, v, s, l) + new_a, new_v, new_s, new_l = alias_tensors(a, v, s, l) self.assertFalse(new_a is a) self.assertTrue(new_v is v) diff --git a/tensorflow/contrib/autograph/utils/multiple_dispatch.py b/tensorflow/contrib/autograph/utils/multiple_dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..70eef5676f61bcd978ea53260f0b86a817f2bd7c --- /dev/null +++ b/tensorflow/contrib/autograph/utils/multiple_dispatch.py @@ -0,0 +1,66 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for type-dependent behavior used in autograph-generated code.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.utils.type_check import is_tensor +from tensorflow.python.ops import control_flow_ops + + +def dynamic_is(left, right): + # TODO(alexbw) if we're sure we should leave 'is' in place, + # then change the semantics in converters/logical_expressions.py + return left is right + + +def dynamic_is_not(left, right): + return left is not right + + +def run_cond(condition, true_fn, false_fn): + """Type-dependent functional conditional. + + Args: + condition: A Tensor or Python bool. + true_fn: A Python callable implementing the true branch of the conditional. + false_fn: A Python callable implementing the false branch of the + conditional. + + Returns: + result: The result of calling the appropriate branch. If condition is a + Tensor, tf.cond will be used. Otherwise, a standard Python if statement will + be ran. + """ + if is_tensor(condition): + return control_flow_ops.cond(condition, true_fn, false_fn) + else: + return py_cond(condition, true_fn, false_fn) + + +def py_cond(condition, true_fn, false_fn): + """Functional version of Python's conditional.""" + if condition: + results = true_fn() + else: + results = false_fn() + + # The contract for the branch functions is to return tuples, but they should + # be collapsed to a single element when there is only one output. + if len(results) == 1: + return results[0] + return results diff --git a/tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py b/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py similarity index 50% rename from tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py rename to tensorflow/contrib/autograph/utils/multiple_dispatch_test.py index 5bb4d4086b002211eebb86783bb7212c707a1418..f72f8e94a0df815f7d517e2b81ffc86c5c545f07 100644 --- a/tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py +++ b/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py @@ -17,7 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import multiple_dispatch + +import numpy as np + +from tensorflow.contrib.autograph.utils import multiple_dispatch from tensorflow.python.client.session import Session from tensorflow.python.framework.constant_op import constant from tensorflow.python.platform import test @@ -25,44 +28,47 @@ from tensorflow.python.platform import test class MultipleDispatchTest(test.TestCase): + def test_dynamic_is_python(self): + a = np.eye(3) + also_a = a + not_actually_a = np.eye(3) + should_be_true1 = multiple_dispatch.dynamic_is(a, also_a) + should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a) + should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a) + should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a) + self.assertTrue(should_be_true1) + self.assertTrue(should_be_true2) + self.assertFalse(should_be_false1) + self.assertFalse(should_be_false2) + + def test_dynamic_is_tf(self): + with Session().as_default(): + a = constant([2.0]) + also_a = a + not_actually_a = constant([2.0]) + should_be_true1 = multiple_dispatch.dynamic_is(a, also_a) + should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a) + should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a) + should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a) + self.assertTrue(should_be_true1) + self.assertTrue(should_be_true2) + self.assertFalse(should_be_false1) + self.assertFalse(should_be_false2) + def test_run_cond_python(self): - true_fn = lambda: 2.0 - false_fn = lambda: 3.0 - self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2.0) - self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3.0) + true_fn = lambda: (2,) + false_fn = lambda: (3,) + self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2) + self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3) def test_run_cond_tf(self): - - true_fn = lambda: constant([2.0]) - false_fn = lambda: constant([3.0]) + true_fn = lambda: (constant(2),) + false_fn = lambda: (constant(3),) with Session() as sess: out = multiple_dispatch.run_cond(constant(True), true_fn, false_fn) - self.assertEqual(sess.run(out), 2.0) + self.assertEqual(sess.run(out), 2) out = multiple_dispatch.run_cond(constant(False), true_fn, false_fn) - self.assertEqual(sess.run(out), 3.0) - - def test_run_while_python(self): - cond_fn = lambda x, t, s: x > t - body_fn = lambda x, t, s: (x * s, t, s) - - x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, [3.0, 1.0, 0.5]) - self.assertEqual(x, 0.75) - - x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, [3.0, 4.0, 0.5]) - self.assertEqual(x, 3.0) - - def test_run_while_tf(self): - cond_fn = lambda x, t, s: x > t - body_fn = lambda x, t, s: (x * s, t, s) - - with Session() as sess: - x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, - [constant(3.0), 1.0, 0.5]) - self.assertEqual(sess.run(x), 0.75) - - x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, - [constant(3.0), 4.0, 0.5]) - self.assertEqual(sess.run(x), 3.0) + self.assertEqual(sess.run(out), 3) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/utils/py_func.py b/tensorflow/contrib/autograph/utils/py_func.py new file mode 100644 index 0000000000000000000000000000000000000000..11ebfb2e49f0e762b56ae2cde2b76d2e24032d72 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/py_func.py @@ -0,0 +1,131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Pyfunc creation utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import script_ops + + +class MatchDType(namedtuple('MatchDType', ('arg_number',))): + """Allows matching the dtype of an argument. + + Used in conjunction with function calls. For example, MatchDType(0) will + match the DType of the first argument. + """ + + pass + + +def wrap_py_func(f, return_dtypes, args, kwargs=None, use_dummy_return=False): + """Helper that wraps a callable to py_func. + + The helper passes tensor arguments through the py_func interface. Non-tensor + arguments are allowed, and will be passed to f directly. Note that non-tensor + arguments are captured by f will not update every time the wrapper is + called (this is consistent with its argument list, which only includes + the tensor arguments). In general, it's safest not to reuse this wrapper. + + Args: + f: Callable + return_dtypes: None, individual of tuple/list of DType or MatchDType, the + data type for each of f's return value(s). Set to None if f has no + return values or use_dummy_return is True. Use MatchDType to define a + dtype identical to that of `i`th argument (argument 0 is the first); + an argument must of Tensor type if it is to be used with MatchDType. + args: Positional arguments for f, as list or tuple. + kwargs: Keyword arguments for f, as dict with string keys. May be None. + use_dummy_return: If True, the function will return a dummy value of 1 + and discard its actual return value. + Returns: + The return values of f converted to tensor. + Raises: + ValueError: if any of the arguments are incorrect. + """ + + if return_dtypes and use_dummy_return: + raise ValueError('if use_dummy_return is True, return_dtypes must be empty') + + tensor_args = [] + tensor_args_idx = {} + + # Of the positional arguments, only grab the tensor ones to be passed through + # the py_func. + n_args = len(args) + arg_is_tensor = tuple(map(tensor_util.is_tensor, args)) + for i in range(n_args): + if arg_is_tensor[i]: + tensor_args_idx[i] = len(tensor_args) + tensor_args.append(args[i]) + + # We essentially take the tensor kwargs, if any, and add them to the list of + # positional arguments. The kwargs are then reconstructed inside the py_func. + # + # For example, if + # + # args = [Tensor(1), 'foo'] + # kwargs = {'a': Tensor(2), 'b': 'bar'} + # + # Then + # + # tensor_args = (Tensor(1), Tensor(2)) + # kwarg_keys = ('a', 'b') + if kwargs: + kwarg_keys = tuple(kwargs.keys()) + kwarg_is_tensor = {k: tensor_util.is_tensor(kwargs[k]) for k in kwarg_keys} + for k in kwarg_keys: + if kwarg_is_tensor[k]: + tensor_args_idx[k] = len(tensor_args) + tensor_args.append(kwargs[k]) + else: + kwarg_keys = () + + # Set up return dtypes. + def match_arg_dtype(arg_number): + arg = args[arg_number] + if not arg_is_tensor[arg_number]: + raise ValueError( + 'argument %d was used with MatchDType and must be a tf.Tensor, but ' + 'was %s instead' % (arg_number, type(arg))) + return arg.dtype + + if return_dtypes: + if isinstance(return_dtypes, MatchDType): + return_dtypes = match_arg_dtype(return_dtypes.arg_number) + elif isinstance(return_dtypes, (list, tuple)): + return_dtypes = tuple( + match_arg_dtype(a.arg_number) if isinstance(a, MatchDType) else a + for a in return_dtypes) + else: + assert isinstance(return_dtypes, dtypes.DType) + + def f_wrapper(*tensor_args): + f_args = tuple(tensor_args[tensor_args_idx[i]] if arg_is_tensor[i] else a + for i, a in enumerate(args)) + f_kwargs = { + k: tensor_args[tensor_args_idx[k]] if kwarg_is_tensor[k] else kwargs[k] + for i, k in enumerate(kwarg_keys) + } + retval = f(*f_args, **f_kwargs) + return 1 if use_dummy_return else retval + + return script_ops.py_func(f_wrapper, tensor_args, dtypes.int64 + if use_dummy_return else return_dtypes) diff --git a/tensorflow/contrib/autograph/utils/py_func_test.py b/tensorflow/contrib/autograph/utils/py_func_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2468263142f14332e86db99d198ba0f5c633dc69 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/py_func_test.py @@ -0,0 +1,103 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for wrap_py_func module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.utils import py_func +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.platform import test + + +class PyFuncTest(test.TestCase): + + def test_wrap_py_func_simple(self): + + def test_fn(a, b, c): + return a + b + c + + with self.test_session() as sess: + result = py_func.wrap_py_func(test_fn, dtypes.int64, + (1, constant_op.constant(1), 1)) + self.assertEqual(3, sess.run(result)) + result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, 1, 1)) + self.assertEqual(3, sess.run(result)) + result = py_func.wrap_py_func( + test_fn, dtypes.int64, + (constant_op.constant(1), 1, constant_op.constant(1))) + self.assertEqual(3, sess.run(result)) + + def test_wrap_py_func_complex_args(self): + + class TestClass(object): + + def __init__(self): + self.foo = 5 + + def test_fn(a, b): + return a * b.foo + + with self.test_session() as sess: + result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass())) + self.assertEqual(35, sess.run(result)) + result = py_func.wrap_py_func(test_fn, dtypes.int64, + (constant_op.constant(7), TestClass())) + self.assertEqual(35, sess.run(result)) + + def test_wrap_py_func_kwargs(self): + + class TestClass(object): + + def __init__(self, foo): + self.foo = foo + + def test_fn(a, b, c, d): + return a * b.foo + c * d.foo + + with self.test_session() as sess: + result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass(5)), { + 'c': 11, + 'd': TestClass(13) + }) + self.assertEqual(178, sess.run(result)) + result = py_func.wrap_py_func(test_fn, dtypes.int64, + (constant_op.constant(7), TestClass(5)), { + 'c': constant_op.constant(11), + 'd': TestClass(13) + }) + self.assertEqual(178, sess.run(result)) + + def test_wrap_py_func_dummy_return(self): + + side_counter = [0] + + def test_fn(_): + side_counter[0] += 1 + + with self.test_session() as sess: + result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True) + self.assertEqual(1, sess.run(result)) + self.assertEqual([1], side_counter) + result = py_func.wrap_py_func( + test_fn, None, (constant_op.constant(5),), use_dummy_return=True) + self.assertEqual(1, sess.run(result)) + self.assertEqual([2], side_counter) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/tensor_list.py b/tensorflow/contrib/autograph/utils/tensor_list.py similarity index 67% rename from tensorflow/contrib/py2tf/utils/tensor_list.py rename to tensorflow/contrib/autograph/utils/tensor_list.py index b6ff49e2a0eff384f10903e12212ab929e267804..2556f412891b4f0b954af5a6f0193341a6a5020a 100644 --- a/tensorflow/contrib/py2tf/utils/tensor_list.py +++ b/tensorflow/contrib/autograph/utils/tensor_list.py @@ -18,7 +18,26 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import ops from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops + + +def dynamic_list_append(target, element): + """Converts a list append call inline.""" + if isinstance(target, tensor_array_ops.TensorArray): + return target.write(target.size(), element) + # TODO(mdan): What's the right way to check this? + # TODO(mdan): We may not need this branch. + # It may be possible to use TensorList alone if the loop body will not + # require wrapping it, although we'd have to think about an autoboxing + # mechanism for lists received as parameter. + if isinstance(target, ops.Tensor): + return list_ops.tensor_list_push_back(target, element) + + # Python targets (including TensorList): fallback to their original append. + target.append(element) + return target class TensorList(object): diff --git a/tensorflow/contrib/py2tf/utils/tensor_list_test.py b/tensorflow/contrib/autograph/utils/tensor_list_test.py similarity index 71% rename from tensorflow/contrib/py2tf/utils/tensor_list_test.py rename to tensorflow/contrib/autograph/utils/tensor_list_test.py index b5e554a162674e08da21785dcbe193c54647f128..d58489eb68b6b949a4276520605c62b7c2825558 100644 --- a/tensorflow/contrib/py2tf/utils/tensor_list_test.py +++ b/tensorflow/contrib/autograph/utils/tensor_list_test.py @@ -12,22 +12,50 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for PyFlow list.""" +"""Tests for Autograph lists.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import tensor_list as tl +from tensorflow.contrib.autograph.utils import tensor_list as tl from tensorflow.python.client.session import Session from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework.constant_op import constant +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test class TensorListTest(test.TestCase): + def _shape(self, shape_tuple): + return constant(shape_tuple, dtypes.int32) + + def test_dynamic_list_append(self): + l = [] + l = tl.dynamic_list_append(l, 1) + self.assertListEqual(l, [1]) + + l = list_ops.empty_tensor_list(self._shape(()), dtypes.int32) + l = tl.dynamic_list_append(l, 1) + s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) + with self.test_session() as sess: + self.assertAllEqual(sess.run(s), [1]) + + l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True) + l = tl.dynamic_list_append(l, 1) + s = l.stack() + with self.test_session() as sess: + self.assertAllEqual(sess.run(s), [1]) + + l = tl.TensorList(self._shape(()), dtypes.int32) + l = tl.dynamic_list_append(l, 1) + with self.test_session() as sess: + self.assertAllEqual(sess.run(l[0]), 1) + def test_list_append_python(self): with context.eager_mode(): a = constant(3.0) diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad.py b/tensorflow/contrib/autograph/utils/testing.py similarity index 65% rename from tensorflow/contrib/bayesflow/python/ops/custom_grad.py rename to tensorflow/contrib/autograph/utils/testing.py index ca1ecb9c40204c3c723fa3423cfe148e823adc28..cb4785d0dc0f4674b3560418daeb6733364b21e7 100644 --- a/tensorflow/contrib/bayesflow/python/ops/custom_grad.py +++ b/tensorflow/contrib/autograph/utils/testing.py @@ -12,23 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functions for specifying custom gradients. - -See ${python/contrib.bayesflow.custom_gradient}. -""" +"""Testing utilities.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.custom_grad_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import imp + +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops -_allowed_symbols = [ - 'custom_gradient', -] -remove_undocumented(__name__, _allowed_symbols) +def fake_tf(): + """Creates a fake module that looks like TensorFlow, for testing.""" + mod = imp.new_module('tensorflow') + mod_contents = dict() + mod_contents.update(math_ops.__dict__) + mod_contents.update(ops.__dict__) + mod_contents.update(mod.__dict__) + mod.__dict__.update(mod_contents) + return mod diff --git a/tensorflow/contrib/py2tf/utils/type_check.py b/tensorflow/contrib/autograph/utils/type_check.py similarity index 86% rename from tensorflow/contrib/py2tf/utils/type_check.py rename to tensorflow/contrib/autograph/utils/type_check.py index 9ca2dec872c8a9ca7bedaa8603f70e3214a3e24a..8748abc47bcfb55b4d0b11178a46816249732da9 100644 --- a/tensorflow/contrib/py2tf/utils/type_check.py +++ b/tensorflow/contrib/autograph/utils/type_check.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities used in py2tf-generated code.""" +"""Utilities used in autograph-generated code.""" from __future__ import absolute_import from __future__ import division @@ -22,12 +22,12 @@ from tensorflow.python.framework import tensor_util def is_tensor(*args): - """Check if all arguments are tensors. + """Check if any arguments are tensors. Args: *args: Python objects that may or may not be tensors. Returns: - True if all *args are TensorFlow types, False if one or more are not. + True if any *args are TensorFlow types, False if none are. """ return any([tensor_util.is_tensor(a) for a in args]) diff --git a/tensorflow/contrib/py2tf/utils/type_check_test.py b/tensorflow/contrib/autograph/utils/type_check_test.py similarity index 96% rename from tensorflow/contrib/py2tf/utils/type_check_test.py rename to tensorflow/contrib/autograph/utils/type_check_test.py index 7d0428e9cccecdc67511e236bc00655a055aea29..3b67b7194c5656b193d47860f93986a985cb1aef 100644 --- a/tensorflow/contrib/py2tf/utils/type_check_test.py +++ b/tensorflow/contrib/autograph/utils/type_check_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy -from tensorflow.contrib.py2tf.utils import type_check +from tensorflow.contrib.autograph.utils import type_check from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util from tensorflow.python.platform import test diff --git a/tensorflow/contrib/lite/toco/python/toco_wrapper.py b/tensorflow/contrib/autograph/utils/type_hints.py similarity index 54% rename from tensorflow/contrib/lite/toco/python/toco_wrapper.py rename to tensorflow/contrib/autograph/utils/type_hints.py index e39b5f22c7c8ffafaf72129be6f54090e6761dc3..aeb9e545610460afbe364dfcfc7a54b9aede29fe 100644 --- a/tensorflow/contrib/lite/toco/python/toco_wrapper.py +++ b/tensorflow/contrib/autograph/utils/type_hints.py @@ -12,24 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Wrapper for runninmg toco binary embedded in pip site-package. +"""No-op utilities that provide static type hints. -NOTE: this mainly exists since PIP setup.py cannot install binaries to bin/. -It can only install Python "console-scripts." This will work as a console -script. See tools/pip_package/setup.py (search for CONSOLE_SCRIPTS). +These are used when the data type is not known at creation, for instance in the +case of empty lists. """ + from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import sys -import tensorflow as tf +def set_element_type(entity, dtype, shape=None): + """Indicates that the entity is expected hold items of specified type. + + This function is a no-op. Its presence merely marks the data type of its + argument. The staged TensorFlow ops will reflect and assert this data type. -def main(): - # Pip installs the binary in aux-bin off of main site-package install. - # Just find it and exec, passing all arguments in the process. - # TODO(aselle): it is unfortunate to use all of tensorflow to lookup binary. - binary = os.path.join(tf.__path__[0], 'aux-bin/toco') - os.execvp(binary, sys.argv) + Args: + entity: A Tensor or TensorArray. + dtype: TensorFlow dtype value to assert for entity. + shape: Optional shape to assert for entity. + Returns: + The value of entity, unchanged. + """ + del dtype + del shape + return entity diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index ee67909133fc26ba98355db05a4b90d3dfa6b97b..b27a19b16c08cb588b45949105a6399623e766e1 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -49,6 +49,14 @@ cc_library( ], ) +cc_library( + name = "serial_device_batch_scheduler", + hdrs = ["serial_device_batch_scheduler.h"], + deps = [ + "//tensorflow/core/kernels/batching_util:serial_device_batch_scheduler", + ], +) + cc_library( name = "basic_batch_scheduler", hdrs = ["basic_batch_scheduler.h"], @@ -96,6 +104,7 @@ py_test( name = "batch_ops_test", size = "small", srcs = ["python/ops/batch_ops_test.py"], + shard_count = 5, srcs_version = "PY2AND3", tags = [ "manual", @@ -112,14 +121,3 @@ py_test( "//tensorflow/python:script_ops", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/batching/__init__.py b/tensorflow/contrib/batching/__init__.py index 44fa5f42a73bfb1bf008f6f4eafd14913c88dcfa..1e503a097a7b72d9244b0a1cf57747c4b4122c81 100644 --- a/tensorflow/contrib/batching/__init__.py +++ b/tensorflow/contrib/batching/__init__.py @@ -14,6 +14,7 @@ # ============================================================================== """Ops and modules related to batch. +@@batch_function_v1 @@batch_function """ from __future__ import absolute_import diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index 921d6917a4e478c3e60771fdc3ae99febc33d2e3..012a51f71101471850d312033c41dcbc4805d44c 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_batch_ops # go/tf-wildcard-import @@ -83,6 +84,74 @@ def batch_function(num_batch_threads, SparseTensor is not supported. The return value of the decorated function must be a Tensor or a list/tuple of Tensors. + Args: + num_batch_threads: Number of scheduling threads for processing batches + of work. Determines the number of batches processed in parallel. + max_batch_size: Batch sizes will never be bigger than this. + batch_timeout_micros: Maximum number of microseconds to wait before + outputting an incomplete batch. + allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, + does nothing. Otherwise, supplies a list of batch sizes, causing the op + to pad batches up to one of those sizes. The entries must increase + monotonically, and the final entry must equal max_batch_size. + grad_timeout_micros: The timeout to use for the gradient. See the + documentation of the unbatch op for more details. Defaults to 60s. + unbatch_timeout_micros: The timeout to use for unbatching. See the + documentation of the unbatch op for more details. Defaults to 60s. + max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10. + + Returns: + The decorated function will return the unbatched computation output Tensors. + """ + + def decorator(fn): # pylint: disable=missing-docstring + + def decorated(*args): # pylint: disable=missing-docstring + types = [arg.dtype for arg in args] + + @function.Defun(*types) + def computation(*computation_args): + return fn(*computation_args) + + with ops.name_scope("batch") as name: + for a in args: + if not isinstance(a, ops.Tensor): + raise ValueError("All arguments to functions decorated with " + "`batch_function` are supposed to be Tensors; " + "found %s" % repr(a)) + for inp in computation.captured_inputs: + print("inp: %s" % inp) + for op in inp.consumers(): + print("op: %s" % op) + return gen_batch_ops.batch_function( + num_batch_threads=num_batch_threads, + max_batch_size=max_batch_size, + batch_timeout_micros=batch_timeout_micros, + allowed_batch_sizes=allowed_batch_sizes, + max_enqueued_batches=max_enqueued_batches, + shared_name=name, + f=computation, + in_tensors=list(args), + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + return decorated + + return decorator + + +def batch_function_v1(num_batch_threads, + max_batch_size, + batch_timeout_micros, + allowed_batch_sizes=None, + grad_timeout_micros=60 * 1000 * 1000, + unbatch_timeout_micros=60 * 1000 * 1000, + max_enqueued_batches=10): + """Batches the computation done by the decorated function. + + This is the older version of batch_function(). Please use the former instead + of this. + Args: num_batch_threads: Number of scheduling threads for processing batches of work. Determines the number of batches processed in parallel. diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index fac7aff29f79fa18fa5f7e596db8afedabaa8993..78468145469df216344bc00f116add250dc51dd3 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -23,7 +23,10 @@ import time from tensorflow.contrib.batching.python.ops import batch_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework.errors import InvalidArgumentError from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_batch_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -185,12 +188,38 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) + def testBasicUnbatchV1Decorated(self): + """Tests that the batch_function_v1 decorator works.""" + with self.test_session() as sess: + @batch_ops.batch_function_v1(1, 10, 100000) + def computation(in_t): + return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + def testBasicUnbatchDecorated(self): """Tests that the batch_function decorator works.""" with self.test_session() as sess: + # TODO(apassos): Removing this line causes test flakiness! Ideally should + # be investigated. + default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable + @batch_ops.batch_function(1, 10, 100000) def computation(in_t): return in_t + 1 + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) result = computation(inp) thread_results = [] @@ -205,6 +234,138 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) + def testBatchDecoratedWithCapturedInput(self): + """Tests that the batch_function decorator works.""" + with self.test_session() as sess: + captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) + captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) + + @batch_ops.batch_function(1, 10, 100000) + def computation(in_t): + return in_t + captured_inp0 - captured_inp1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchFunctionOp(self): + """Tests that the batch_function op works.""" + with self.test_session() as sess: + + @function.Defun(dtypes.int32) + def computation(in_t): + return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = gen_batch_ops.batch_function( + [inp], + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, + Tout=[dtypes.int32], + f=computation, + captured_tensors=computation.captured_inputs) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchFunctionOpWithCapturedInput(self): + """Tests that batch_function op works with captured input.""" + with self.test_session() as sess: + captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) + captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + + @function.Defun(dtypes.int32) + def computation(inp): + return inp + captured_inp0 - captured_inp1 + + result = gen_batch_ops.batch_function( + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, # 100ms + allowed_batch_sizes=[3, 10], + batching_queue="", + f=computation, + in_tensors=[inp], + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchFunctionOpWithInputError(self): + """Tests that batch_function op works with error in the inputs.""" + with self.test_session() as sess: + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + + @function.Defun(dtypes.int32, dtypes.int32) + def computation(in0, in1): + return in0 + in1 + + result = gen_batch_ops.batch_function( + [inp], # computation actually expects 2 inputs. + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, # 100ms + batching_queue="", + f=computation, + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + with self.assertRaisesRegexp(InvalidArgumentError, + ".*2 arguments.*but 1.*"): + sess.run([result], feed_dict={inp: [2]}) + + def testBasicUnbatchDecoratedWithReshape(self): + """Tests that the batch_function decorator works.""" + with self.test_session() as sess: + + @batch_ops.batch_function(1, 10, 100000) + def computation(in_t): + return array_ops.reshape(in_t, [-1]) + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [[1]]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [[2]]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + def testUnbatchTimeout(self): """Tests that the unbatch timeout works.""" with self.test_session() as sess: @@ -250,7 +411,7 @@ class BatchOpsTest(test.TestCase): def testUnbatchGrad(self): """Tests that batch and unbatch are differentiable.""" with self.test_session() as sess: - inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1]) batched, index, id_t = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=2, batch_timeout_micros=36000000, grad_timeout_micros=1000000, diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.cc b/tensorflow/contrib/batching/serial_device_batch_scheduler.h similarity index 59% rename from tensorflow/compiler/xla/service/versioned_computation_handle.cc rename to tensorflow/contrib/batching/serial_device_batch_scheduler.h index a693c4695f0e776cf297d0ecd28d6de53bd5c0c6..bf6b7083612018eecf0d1784e60cbbf0c5796fef 100644 --- a/tensorflow/compiler/xla/service/versioned_computation_handle.cc +++ b/tensorflow/contrib/batching/serial_device_batch_scheduler.h @@ -13,20 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" +#ifndef TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ -#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h" -namespace xla { - -string VersionedComputationHandle::ToString() const { - return tensorflow::strings::StrCat(handle.handle(), ":v", version); -} - -std::ostream& operator<<(std::ostream& out, - const VersionedComputationHandle& versioned_handle) { - out << versioned_handle.ToString(); - return out; -} - -} // namespace xla +#endif // TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/test_util/BUILD b/tensorflow/contrib/batching/test_util/BUILD index 6db627faad1df4a4b73082e74e7754829ff2b514..7cb2d8079bd18660f72eab92654629434ce4d6a5 100644 --- a/tensorflow/contrib/batching/test_util/BUILD +++ b/tensorflow/contrib/batching/test_util/BUILD @@ -8,17 +8,6 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) - cc_library( name = "fake_clock_env", testonly = 1, diff --git a/tensorflow/contrib/batching/util/BUILD b/tensorflow/contrib/batching/util/BUILD index 2a84a7712a8fa66e89db41ff4e7ebe4f620029ca..8f81b6702f2807d7da7e72190ce2d86b28e52113 100644 --- a/tensorflow/contrib/batching/util/BUILD +++ b/tensorflow/contrib/batching/util/BUILD @@ -8,18 +8,6 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_test") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "**/google_*", - ], - ), -) - cc_library( name = "periodic_function_dynamic", hdrs = ["periodic_function.h"], diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index 74712aeb67c3f0a31def78f25a0298f9c02c9590..5a2d7f6a3c0ba233299a5790fa80488786712f3c 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -37,126 +37,6 @@ py_library( ], ) -cuda_py_test( - name = "metropolis_hastings_test", - size = "medium", - srcs = ["python/kernel_tests/metropolis_hastings_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], -) - -cuda_py_test( - name = "csiszar_divergence_test", - size = "medium", - srcs = ["python/kernel_tests/csiszar_divergence_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - ], - tags = [ - "manual", # b/64490288 - "notap", - ], -) - -cuda_py_test( - name = "custom_grad_test", - size = "small", - srcs = ["python/kernel_tests/custom_grad_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:init_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], -) - -cuda_py_test( - name = "layers_conv_variational_test", - size = "small", - srcs = ["python/kernel_tests/layers_conv_variational_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - ], -) - -cuda_py_test( - name = "layers_dense_variational_test", - size = "small", - srcs = ["python/kernel_tests/layers_dense_variational_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - ], -) - -cuda_py_test( - name = "mcmc_diagnostics_test", - size = "small", - srcs = ["python/kernel_tests/mcmc_diagnostics_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/python:spectral_ops_test_util", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_seed", - ], -) - cuda_py_test( name = "monte_carlo_test", size = "small", @@ -177,117 +57,3 @@ cuda_py_test( "//tensorflow/python:random_seed", ], ) - -cuda_py_test( - name = "halton_sequence_test", - size = "small", - srcs = ["python/kernel_tests/halton_sequence_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], - tags = ["no_mac"], # b/73192243 -) - -cuda_py_test( - name = "hmc_test", - size = "medium", - srcs = ["python/kernel_tests/hmc_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_seed", - ], -) - -cuda_py_test( - name = "sgld_optimizer_test", - size = "small", - srcs = ["python/kernel_tests/sgld_optimizer_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_seed", - ], - tags = ["notsan"], -) - -cuda_py_test( - name = "variable_utils_test", - size = "small", - srcs = ["python/kernel_tests/variable_utils_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_test( - name = "variational_sgd_optimizer_test", - size = "small", - srcs = ["python/kernel_tests/variational_sgd_optimizer_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_seed", - ], - tags = ["notsan"], -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/bayesflow/README.md b/tensorflow/contrib/bayesflow/README.md new file mode 100644 index 0000000000000000000000000000000000000000..10323dc6d59918a9f8cf1840d06dcd219dfe3568 --- /dev/null +++ b/tensorflow/contrib/bayesflow/README.md @@ -0,0 +1,17 @@ +# Notice + +`tf.contrib.bayesflow` has moved! + +See new code at [github.com/tensorflow/probability]( +https://github.com/tensorflow/probability). + +Switch imports with: + +```python +# old +import tensorflow as tf +tfp = tf.contrib.bayesflow + +# new +import tensorflow_probability as tfp +``` diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py index 528c4fbacd06c7b0defa0e32bd24a98b2bc07b64..41a8c920fc4e81af90f4c94a149d8c404c58b747 100644 --- a/tensorflow/contrib/bayesflow/__init__.py +++ b/tensorflow/contrib/bayesflow/__init__.py @@ -21,36 +21,14 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long -from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence -from tensorflow.contrib.bayesflow.python.ops import custom_grad -from tensorflow.contrib.bayesflow.python.ops import halton_sequence -from tensorflow.contrib.bayesflow.python.ops import hmc -from tensorflow.contrib.bayesflow.python.ops import layers -from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics -from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings from tensorflow.contrib.bayesflow.python.ops import monte_carlo -from tensorflow.contrib.bayesflow.python.ops import optimizers -from tensorflow.contrib.bayesflow.python.ops import variable_utils # pylint: enable=unused-import,line-too-long from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'csiszar_divergence', - 'custom_grad', - 'entropy', - 'halton_sequence', - 'hmc', - 'layers', - 'metropolis_hastings', - 'mcmc_diagnostics', 'monte_carlo', - 'optimizers', - 'special_math', - 'stochastic_variables', - 'variable_utils', - 'variational_inference', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py deleted file mode 100644 index 2e94b7206de4f7c40c89f083f3bfa2a22bb7b917..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py +++ /dev/null @@ -1,1004 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Csiszar Divergence Ops.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence_impl -from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib -from tensorflow.contrib.distributions.python.ops import mvn_full_covariance as mvn_full_lib -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import kullback_leibler -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.platform import test - - -cd = csiszar_divergence_impl - - -def tridiag(d, diag_value, offdiag_value): - """d x d matrix with given value on diag, and one super/sub diag.""" - diag_mat = linalg_ops.eye(d) * (diag_value - offdiag_value) - three_bands = array_ops.matrix_band_part( - array_ops.fill([d, d], offdiag_value), 1, 1) - return diag_mat + three_bands - - -class AmariAlphaTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - for alpha in [-1., 0., 1., 2.]: - for normalized in [True, False]: - with self.test_session(graph=ops.Graph()): - self.assertAllClose( - cd.amari_alpha(0., alpha=alpha, - self_normalized=normalized).eval(), - 0.) - - def test_correct_when_alpha0(self): - with self.test_session(): - self.assertAllClose( - cd.amari_alpha(self._logu, alpha=0.).eval(), - -self._logu) - - self.assertAllClose( - cd.amari_alpha(self._logu, alpha=0., self_normalized=True).eval(), - -self._logu + (self._u - 1.)) - - def test_correct_when_alpha1(self): - with self.test_session(): - self.assertAllClose( - cd.amari_alpha(self._logu, alpha=1.).eval(), - self._u * self._logu) - - self.assertAllClose( - cd.amari_alpha(self._logu, alpha=1., self_normalized=True).eval(), - self._u * self._logu - (self._u - 1.)) - - def test_correct_when_alpha_not_01(self): - for alpha in [-2, -1., -0.5, 0.5, 2.]: - with self.test_session(graph=ops.Graph()): - self.assertAllClose( - cd.amari_alpha(self._logu, - alpha=alpha, - self_normalized=False).eval(), - ((self._u**alpha - 1)) / (alpha * (alpha - 1.))) - - self.assertAllClose( - cd.amari_alpha(self._logu, - alpha=alpha, - self_normalized=True).eval(), - ((self._u**alpha - 1.) - - alpha * (self._u - 1)) / (alpha * (alpha - 1.))) - - -class KLReverseTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - for normalized in [True, False]: - with self.test_session(graph=ops.Graph()): - self.assertAllClose( - cd.kl_reverse(0., self_normalized=normalized).eval(), - 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.kl_reverse(self._logu).eval(), - -self._logu) - - self.assertAllClose( - cd.kl_reverse(self._logu, self_normalized=True).eval(), - -self._logu + (self._u - 1.)) - - -class KLForwardTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - for normalized in [True, False]: - with self.test_session(graph=ops.Graph()): - self.assertAllClose( - cd.kl_forward(0., self_normalized=normalized).eval(), - 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.kl_forward(self._logu).eval(), - self._u * self._logu) - - self.assertAllClose( - cd.kl_forward(self._logu, self_normalized=True).eval(), - self._u * self._logu - (self._u - 1.)) - - -class JensenShannonTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.jensen_shannon(0.).eval(), np.log(0.25)) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.jensen_shannon(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.jensen_shannon).eval()) - - self.assertAllClose( - cd.jensen_shannon(self._logu, self_normalized=True).eval(), - cd.symmetrized_csiszar_function( - self._logu, - lambda x: cd.jensen_shannon(x, self_normalized=True)).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.jensen_shannon(self._logu).eval(), - (self._u * self._logu - - (1 + self._u) * np.log1p(self._u))) - - self.assertAllClose( - cd.jensen_shannon(self._logu, self_normalized=True).eval(), - (self._u * self._logu - - (1 + self._u) * np.log((1 + self._u) / 2))) - - -class ArithmeticGeometricMeanTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.arithmetic_geometric(0.).eval(), np.log(4)) - self.assertAllClose( - cd.arithmetic_geometric(0., self_normalized=True).eval(), 0.) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.arithmetic_geometric(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.arithmetic_geometric).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.arithmetic_geometric(self._logu).eval(), - (1. + self._u) * np.log((1. + self._u) / np.sqrt(self._u))) - - self.assertAllClose( - cd.arithmetic_geometric(self._logu, self_normalized=True).eval(), - (1. + self._u) * np.log(0.5 * (1. + self._u) / np.sqrt(self._u))) - - -class TotalVariationTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.total_variation(0.).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.total_variation(self._logu).eval(), - 0.5 * np.abs(self._u - 1)) - - -class PearsonTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.pearson(0.).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.pearson(self._logu).eval(), - np.square(self._u - 1)) - - -class SquaredHellingerTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.squared_hellinger(0.).eval(), 0.) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.squared_hellinger(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.squared_hellinger).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.squared_hellinger(self._logu).eval(), - np.square(np.sqrt(self._u) - 1)) - - -class TriangularTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.triangular(0.).eval(), 0.) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.triangular(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.triangular).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.triangular(self._logu).eval(), - np.square(self._u - 1) / (1 + self._u)) - - -class TPowerTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.t_power(0., t=-0.1).eval(), 0.) - self.assertAllClose(cd.t_power(0., t=0.5).eval(), 0.) - self.assertAllClose(cd.t_power(0., t=1.1).eval(), 0.) - self.assertAllClose( - cd.t_power(0., t=-0.1, self_normalized=True).eval(), 0.) - self.assertAllClose( - cd.t_power(0., t=0.5, self_normalized=True).eval(), 0.) - self.assertAllClose( - cd.t_power(0., t=1.1, self_normalized=True).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(-0.1)).eval(), - self._u ** -0.1 - 1.) - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(0.5)).eval(), - -self._u ** 0.5 + 1.) - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(1.1)).eval(), - self._u ** 1.1 - 1.) - - def test_correct_self_normalized(self): - with self.test_session(): - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(-0.1), - self_normalized=True).eval(), - self._u ** -0.1 - 1. + 0.1 * (self._u - 1.)) - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(0.5), - self_normalized=True).eval(), - -self._u ** 0.5 + 1. + 0.5 * (self._u - 1.)) - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(1.1), - self_normalized=True).eval(), - self._u ** 1.1 - 1. - 1.1 * (self._u - 1.)) - - -class Log1pAbsTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.log1p_abs(0.).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.log1p_abs(self._logu).eval(), - self._u**(np.sign(self._u - 1)) - 1) - - -class JeffreysTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.jeffreys(0.).eval(), 0.) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.jeffreys(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.jeffreys).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.jeffreys(self._logu).eval(), - 0.5 * (self._u * self._logu - self._logu)) - - -class ChiSquareTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.chi_square(0.).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.chi_square(self._logu).eval(), - self._u**2 - 1) - - -class ModifiedGanTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose( - cd.modified_gan(0.).eval(), np.log(2)) - self.assertAllClose( - cd.modified_gan(0., self_normalized=True).eval(), np.log(2)) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.modified_gan(self._logu).eval(), - np.log1p(self._u) - self._logu) - - self.assertAllClose( - cd.modified_gan(self._logu, self_normalized=True).eval(), - np.log1p(self._u) - self._logu + 0.5 * (self._u - 1)) - - -class SymmetrizedCsiszarFunctionTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10., 100) - self._u = np.exp(self._logu) - - def test_jensen_shannon(self): - with self.test_session(): - - # The following functions come from the claim made in the - # symmetrized_csiszar_function docstring. - def js1(logu): - return (-logu - - (1. + math_ops.exp(logu)) * ( - nn_ops.softplus(logu))) - - def js2(logu): - return 2. * (math_ops.exp(logu) * ( - logu - nn_ops.softplus(logu))) - - self.assertAllClose( - cd.symmetrized_csiszar_function(self._logu, js1).eval(), - cd.jensen_shannon(self._logu).eval()) - - self.assertAllClose( - cd.symmetrized_csiszar_function(self._logu, js2).eval(), - cd.jensen_shannon(self._logu).eval()) - - def test_jeffreys(self): - with self.test_session(): - self.assertAllClose( - cd.symmetrized_csiszar_function(self._logu, cd.kl_reverse).eval(), - cd.jeffreys(self._logu).eval()) - - self.assertAllClose( - cd.symmetrized_csiszar_function(self._logu, cd.kl_forward).eval(), - cd.jeffreys(self._logu).eval()) - - -class DualCsiszarFunctionTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10., 100) - self._u = np.exp(self._logu) - - def test_kl_forward(self): - with self.test_session(): - self.assertAllClose( - cd.dual_csiszar_function(self._logu, cd.kl_forward).eval(), - cd.kl_reverse(self._logu).eval()) - - def test_kl_reverse(self): - with self.test_session(): - self.assertAllClose( - cd.dual_csiszar_function(self._logu, cd.kl_reverse).eval(), - cd.kl_forward(self._logu).eval()) - - -class MonteCarloCsiszarFDivergenceTest(test.TestCase): - - def test_kl_forward(self): - with self.test_session() as sess: - q = normal_lib.Normal( - loc=np.ones(6), - scale=np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])) - - p = normal_lib.Normal(loc=q.loc + 0.1, scale=q.scale - 0.2) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_forward, - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_forward(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - exact_kl = kullback_leibler.kl_divergence(p, q) - - [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ - approx_kl, approx_kl_self_normalized, exact_kl]) - - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.08, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.02, atol=0.) - - def test_kl_reverse(self): - with self.test_session() as sess: - - q = normal_lib.Normal( - loc=np.ones(6), - scale=np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])) - - p = normal_lib.Normal(loc=q.loc + 0.1, scale=q.scale - 0.2) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_reverse, - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_reverse(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - exact_kl = kullback_leibler.kl_divergence(q, p) - - [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ - approx_kl, approx_kl_self_normalized, exact_kl]) - - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.07, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.02, atol=0.) - - def test_kl_reverse_multidim(self): - - with self.test_session() as sess: - d = 5 # Dimension - - p = mvn_full_lib.MultivariateNormalFullCovariance( - covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5)) - - q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[0.5]*d) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_reverse, - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_reverse(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - exact_kl = kullback_leibler.kl_divergence(q, p) - - [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ - approx_kl, approx_kl_self_normalized, exact_kl]) - - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.02, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.08, atol=0.) - - def test_kl_forward_multidim(self): - - with self.test_session() as sess: - d = 5 # Dimension - - p = mvn_full_lib.MultivariateNormalFullCovariance( - covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5)) - - # Variance is very high when approximating Forward KL, so we make - # scale_diag larger than in test_kl_reverse_multidim. This ensures q - # "covers" p and thus Var_q[p/q] is smaller. - q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[1.]*d) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_forward, - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_forward(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - exact_kl = kullback_leibler.kl_divergence(p, q) - - [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ - approx_kl, approx_kl_self_normalized, exact_kl]) - - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.06, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.05, atol=0.) - - def test_score_trick(self): - - with self.test_session() as sess: - d = 5 # Dimension - num_draws = int(1e5) - seed = 1 - - p = mvn_full_lib.MultivariateNormalFullCovariance( - covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5)) - - # Variance is very high when approximating Forward KL, so we make - # scale_diag larger than in test_kl_reverse_multidim. This ensures q - # "covers" p and thus Var_q[p/q] is smaller. - s = array_ops.constant(1.) - q = mvn_diag_lib.MultivariateNormalDiag( - scale_diag=array_ops.tile([s], [d])) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_reverse, - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - seed=seed) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_reverse(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - seed=seed) - - approx_kl_score_trick = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_reverse, - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - use_reparametrization=False, - seed=seed) - - approx_kl_self_normalized_score_trick = ( - cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_reverse(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - use_reparametrization=False, - seed=seed)) - - exact_kl = kullback_leibler.kl_divergence(q, p) - - grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0] - - [ - approx_kl_grad_, - approx_kl_self_normalized_grad_, - approx_kl_score_trick_grad_, - approx_kl_self_normalized_score_trick_grad_, - exact_kl_grad_, - approx_kl_, - approx_kl_self_normalized_, - approx_kl_score_trick_, - approx_kl_self_normalized_score_trick_, - exact_kl_, - ] = sess.run([ - grad_sum(approx_kl), - grad_sum(approx_kl_self_normalized), - grad_sum(approx_kl_score_trick), - grad_sum(approx_kl_self_normalized_score_trick), - grad_sum(exact_kl), - approx_kl, - approx_kl_self_normalized, - approx_kl_score_trick, - approx_kl_self_normalized_score_trick, - exact_kl, - ]) - - # Test average divergence. - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.02, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.08, atol=0.) - - self.assertAllClose(approx_kl_score_trick_, exact_kl_, - rtol=0.02, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_score_trick_, exact_kl_, - rtol=0.08, atol=0.) - - # Test average gradient-divergence. - self.assertAllClose(approx_kl_grad_, exact_kl_grad_, - rtol=0.007, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_grad_, exact_kl_grad_, - rtol=0.011, atol=0.) - - self.assertAllClose(approx_kl_score_trick_grad_, exact_kl_grad_, - rtol=0.018, atol=0.) - - self.assertAllClose( - approx_kl_self_normalized_score_trick_grad_, exact_kl_grad_, - rtol=0.017, atol=0.) - - -class CsiszarVIMCOTest(test.TestCase): - - def _csiszar_vimco_helper(self, logu): - """Numpy implementation of `csiszar_vimco_helper`.""" - - # Since this is a naive/intuitive implementation, we compensate by using the - # highest precision we can. - logu = np.float128(logu) - n = logu.shape[0] - u = np.exp(logu) - loogeoavg_u = [] # Leave-one-out geometric-average of exp(logu). - for j in range(n): - loogeoavg_u.append(np.exp(np.mean( - [logu[i, ...] for i in range(n) if i != j], - axis=0))) - loogeoavg_u = np.array(loogeoavg_u) - - loosum_u = [] # Leave-one-out sum of exp(logu). - for j in range(n): - loosum_u.append(np.sum( - [u[i, ...] for i in range(n) if i != j], - axis=0)) - loosum_u = np.array(loosum_u) - - # Natural log of the average u except each is swapped-out for its - # leave-`i`-th-out Geometric average. - log_sooavg_u = np.log(loosum_u + loogeoavg_u) - np.log(n) - - log_avg_u = np.log(np.mean(u, axis=0)) - return log_avg_u, log_sooavg_u - - def _csiszar_vimco_helper_grad(self, logu, delta): - """Finite difference approximation of `grad(csiszar_vimco_helper, logu)`.""" - - # This code actually estimates the sum of the Jacobiab because that's what - # TF's `gradients` does. - np_log_avg_u1, np_log_sooavg_u1 = self._csiszar_vimco_helper( - logu[..., None] + np.diag([delta]*len(logu))) - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper( - logu[..., None]) - return [ - (np_log_avg_u1 - np_log_avg_u) / delta, - np.sum(np_log_sooavg_u1 - np_log_sooavg_u, axis=0) / delta, - ] - - def test_vimco_helper_1(self): - """Tests that function calculation correctly handles batches.""" - - logu = np.linspace(-100., 100., 100).reshape([10, 2, 5]) - with self.test_session() as sess: - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu) - [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu)) - self.assertAllClose(np_log_avg_u, log_avg_u, - rtol=1e-8, atol=0.) - self.assertAllClose(np_log_sooavg_u, log_sooavg_u, - rtol=1e-8, atol=0.) - - def test_vimco_helper_2(self): - """Tests that function calculation correctly handles overflow.""" - - # Using 700 (rather than 1e3) since naive numpy version can't handle higher. - logu = np.float32([0., 700, -1, 1]) - with self.test_session() as sess: - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu) - [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu)) - self.assertAllClose(np_log_avg_u, log_avg_u, - rtol=1e-6, atol=0.) - self.assertAllClose(np_log_sooavg_u, log_sooavg_u, - rtol=1e-5, atol=0.) - - def test_vimco_helper_3(self): - """Tests that function calculation correctly handles underlow.""" - - logu = np.float32([0., -1000, -1, 1]) - with self.test_session() as sess: - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu) - [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu)) - self.assertAllClose(np_log_avg_u, log_avg_u, - rtol=1e-5, atol=0.) - self.assertAllClose(np_log_sooavg_u, log_sooavg_u, - rtol=1e-4, atol=1e-15) - - def test_vimco_helper_gradient_using_finite_difference_1(self): - """Tests that gradient calculation correctly handles batches.""" - - logu_ = np.linspace(-100., 100., 100).reshape([10, 2, 5]) - with self.test_session() as sess: - logu = array_ops.constant(logu_) - - grad = lambda flogu: gradients_impl.gradients(flogu, logu)[0] - log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu) - - [ - grad_log_avg_u, - grad_log_sooavg_u, - ] = sess.run([grad(log_avg_u), grad(log_sooavg_u)]) - - # We skip checking against finite-difference approximation since it - # doesn't support batches. - - # Verify claim in docstring. - self.assertAllClose( - np.ones_like(grad_log_avg_u.sum(axis=0)), - grad_log_avg_u.sum(axis=0)) - self.assertAllClose( - np.ones_like(grad_log_sooavg_u.mean(axis=0)), - grad_log_sooavg_u.mean(axis=0)) - - def test_vimco_helper_gradient_using_finite_difference_2(self): - """Tests that gradient calculation correctly handles overflow.""" - - delta = 1e-3 - logu_ = np.float32([0., 1000, -1, 1]) - with self.test_session() as sess: - logu = array_ops.constant(logu_) - - [ - np_grad_log_avg_u, - np_grad_log_sooavg_u, - ] = self._csiszar_vimco_helper_grad(logu_, delta) - - grad = lambda flogu: gradients_impl.gradients(flogu, logu)[0] - log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu) - - [ - grad_log_avg_u, - grad_log_sooavg_u, - ] = sess.run([grad(log_avg_u), grad(log_sooavg_u)]) - - self.assertAllClose(np_grad_log_avg_u, grad_log_avg_u, - rtol=delta, atol=0.) - self.assertAllClose(np_grad_log_sooavg_u, grad_log_sooavg_u, - rtol=delta, atol=0.) - # Verify claim in docstring. - self.assertAllClose( - np.ones_like(grad_log_avg_u.sum(axis=0)), - grad_log_avg_u.sum(axis=0)) - self.assertAllClose( - np.ones_like(grad_log_sooavg_u.mean(axis=0)), - grad_log_sooavg_u.mean(axis=0)) - - def test_vimco_helper_gradient_using_finite_difference_3(self): - """Tests that gradient calculation correctly handles underlow.""" - - delta = 1e-3 - logu_ = np.float32([0., -1000, -1, 1]) - with self.test_session() as sess: - logu = array_ops.constant(logu_) - - [ - np_grad_log_avg_u, - np_grad_log_sooavg_u, - ] = self._csiszar_vimco_helper_grad(logu_, delta) - - grad = lambda flogu: gradients_impl.gradients(flogu, logu)[0] - log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu) - - [ - grad_log_avg_u, - grad_log_sooavg_u, - ] = sess.run([grad(log_avg_u), grad(log_sooavg_u)]) - - self.assertAllClose(np_grad_log_avg_u, grad_log_avg_u, - rtol=delta, atol=0.) - self.assertAllClose(np_grad_log_sooavg_u, grad_log_sooavg_u, - rtol=delta, atol=0.) - # Verify claim in docstring. - self.assertAllClose( - np.ones_like(grad_log_avg_u.sum(axis=0)), - grad_log_avg_u.sum(axis=0)) - self.assertAllClose( - np.ones_like(grad_log_sooavg_u.mean(axis=0)), - grad_log_sooavg_u.mean(axis=0)) - - def test_vimco_and_gradient(self): - - with self.test_session() as sess: - dims = 5 # Dimension - num_draws = int(20) - num_batch_draws = int(3) - seed = 1 - - f = lambda logu: cd.kl_reverse(logu, self_normalized=False) - np_f = lambda logu: -logu - - p = mvn_full_lib.MultivariateNormalFullCovariance( - covariance_matrix=tridiag(dims, diag_value=1, offdiag_value=0.5)) - - # Variance is very high when approximating Forward KL, so we make - # scale_diag larger than in test_kl_reverse_multidim. This ensures q - # "covers" p and thus Var_q[p/q] is smaller. - s = array_ops.constant(1.) - q = mvn_diag_lib.MultivariateNormalDiag( - scale_diag=array_ops.tile([s], [dims])) - - vimco = cd.csiszar_vimco( - f=f, - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - num_batch_draws=num_batch_draws, - seed=seed) - - x = q.sample(sample_shape=[num_draws, num_batch_draws], - seed=seed) - x = array_ops.stop_gradient(x) - logu = p.log_prob(x) - q.log_prob(x) - f_log_sum_u = f(cd.csiszar_vimco_helper(logu)[0]) - - grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0] - - def jacobian(x): - # Warning: this function is slow and may not even finish if prod(shape) - # is larger than, say, 100. - shape = x.shape.as_list() - assert all(s is not None for s in shape) - x = array_ops.reshape(x, shape=[-1]) - r = [grad_sum(x[i]) for i in range(np.prod(shape))] - return array_ops.reshape(array_ops.stack(r), shape=shape) - - [ - logu_, - jacobian_logqx_, - vimco_, - grad_vimco_, - f_log_sum_u_, - grad_mean_f_log_sum_u_, - ] = sess.run([ - logu, - jacobian(q.log_prob(x)), - vimco, - grad_sum(vimco), - f_log_sum_u, - grad_sum(f_log_sum_u) / num_batch_draws, - ]) - - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu_) - - # Test VIMCO loss is correct. - self.assertAllClose(np_f(np_log_avg_u).mean(axis=0), vimco_, - rtol=1e-5, atol=0.) - - # Test gradient of VIMCO loss is correct. - # - # To make this computation we'll inject two gradients from TF: - # - grad[mean(f(log(sum(p(x)/q(x)))))] - # - jacobian[log(q(x))]. - # - # We now justify why using these (and only these) TF values for - # ground-truth does not undermine the completeness of this test. - # - # Regarding `grad_mean_f_log_sum_u_`, note that we validate the - # correctness of the zero-th order derivative (for each batch member). - # Since `cd.csiszar_vimco_helper` itself does not manipulate any gradient - # information, we can safely rely on TF. - self.assertAllClose(np_f(np_log_avg_u), f_log_sum_u_, rtol=1e-4, atol=0.) - # - # Regarding `jacobian_logqx_`, note that testing the gradient of - # `q.log_prob` is outside the scope of this unit-test thus we may safely - # use TF to find it. - - # The `mean` is across batches and the `sum` is across iid samples. - np_grad_vimco = ( - grad_mean_f_log_sum_u_ - + np.mean( - np.sum( - jacobian_logqx_ * (np_f(np_log_avg_u) - - np_f(np_log_sooavg_u)), - axis=0), - axis=0)) - - self.assertAllClose(np_grad_vimco, grad_vimco_, - rtol=1e-5, atol=0.) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py deleted file mode 100644 index a95df31ac1fd9f5038abe779391ccba5f7fe408d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Custom Gradient Ops.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import custom_grad_impl -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -cg = custom_grad_impl - - -class CustomGradientTest(test.TestCase): - - def test_works_correctly(self): - with self.test_session() as sess: - f = lambda x: x**2 / 2 - g = lambda x: (x - 1)**3 / 3 - x_ = np.linspace(-100, 100, int(1e4)) + [0.] - - x = constant_op.constant(x_) - fx = cg.custom_gradient(f(x), g(x), x) - gx = gradients_impl.gradients(fx, x)[0] - [fx_, gx_] = sess.run([fx, gx]) - - self.assertAllClose(f(x_), fx_) - self.assertAllClose(g(x_), gx_) - - def test_works_correctly_both_f_g_zero(self): - with self.test_session() as sess: - f = lambda x: x**2 / 2 - g = lambda x: x**3 / 3 - x_ = np.linspace(-100, 100, int(1e4)) + [0.] - - x = constant_op.constant(x_) - fx = cg.custom_gradient(f(x), g(x), x) - gx = gradients_impl.gradients(fx, x)[0] - [fx_, gx_] = sess.run([fx, gx]) - - self.assertAllClose(f(x_), fx_) - self.assertAllClose(g(x_), gx_) - - def test_works_correctly_vector_of_vars(self): - with self.test_session() as sess: - x = variable_scope.get_variable( - name="x", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(2)) - y = variable_scope.get_variable( - name="y", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(3)) - sess.run([variables.global_variables_initializer()]) - - f = lambda z: z[0] * z[1] - g = lambda z: z[0]**2 * z[1]**2 / 2 - - z = array_ops.stack([x, y]) - fz = cg.custom_gradient(f(z), g(z), z, axis=0) - gz = gradients_impl.gradients(fz, variables.trainable_variables()) - [z_, fz_, gx_, gy_] = sess.run([z, fz, gz[0], gz[1]]) - - self.assertEqual(f(z_), fz_) - self.assertEqual(g(z_), gx_) - self.assertEqual(g(z_), gy_) - - def test_works_correctly_side_vars(self): - with self.test_session() as sess: - x_ = np.float32(2.1) # Adding extra tenth to force imprecision. - y_ = np.float32(3.1) - x = variable_scope.get_variable( - name="x", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(x_)) - y = variable_scope.get_variable( - name="y", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(y_)) - sess.run([variables.global_variables_initializer()]) - - f = lambda x: x * y - g = lambda z: math_ops.square(x) * y - - fx = cg.custom_gradient(f(x), g(x), x) - gx = gradients_impl.gradients(fx, variables.trainable_variables()) - [x_, fx_, gx_] = sess.run([x, fx, gx[0]]) - gy_ = gx[1] - - self.assertEqual(x_ * y_, fx_) - self.assertEqual(np.square(x_) * y_, gx_) - self.assertEqual(None, gy_) - - def test_works_correctly_fx_gx_manually_stopped(self): - with self.test_session() as sess: - x_ = np.float32(2.1) # Adding extra tenth to force imprecision. - y_ = np.float32(3.1) - x = variable_scope.get_variable( - name="x", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(x_)) - y = variable_scope.get_variable( - name="y", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(y_)) - sess.run([variables.global_variables_initializer()]) - - stop = array_ops.stop_gradient # For readability. - - # Basically we need to stop the `x` portion of `f`. And when we supply the - # arg to `custom_gradient` we need to stop the complement, i.e., the `y` - # part. - f = lambda x: stop(x) * y - g = lambda x: stop(math_ops.square(x)) * y - fx = cg.custom_gradient(f(x), g(x), x + stop(y), - fx_gx_manually_stopped=True) - - gx = gradients_impl.gradients(fx, variables.trainable_variables()) - [x_, fx_, gx_, gy_] = sess.run([x, fx, gx[0], gx[1]]) - - self.assertEqual(x_ * y_, fx_) - self.assertEqual(np.square(x_) * y_, gx_) - self.assertEqual(x_, gy_) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py deleted file mode 100644 index 0a85862abfd744a86b9a38e10dbb5b985d0a0e94..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for halton_sequence.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import halton_sequence as halton -from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.platform import test - - -mc = monte_carlo_lib - - -class HaltonSequenceTest(test.TestCase): - - def test_known_values_small_bases(self): - with self.test_session(): - # The first five elements of the Halton sequence with base 2 and 3 - expected = np.array(((1. / 2, 1. / 3), - (1. / 4, 2. / 3), - (3. / 4, 1. / 9), - (1. / 8, 4. / 9), - (5. / 8, 7. / 9)), dtype=np.float32) - sample = halton.sample(2, num_samples=5) - self.assertAllClose(expected, sample.eval(), rtol=1e-6) - - def test_sample_indices(self): - with self.test_session(): - dim = 5 - indices = math_ops.range(10, dtype=dtypes.int32) - sample_direct = halton.sample(dim, num_samples=10) - sample_from_indices = halton.sample(dim, sample_indices=indices) - self.assertAllClose(sample_direct.eval(), sample_from_indices.eval(), - rtol=1e-6) - - def test_dtypes_works_correctly(self): - with self.test_session(): - dim = 3 - sample_float32 = halton.sample(dim, num_samples=10, dtype=dtypes.float32) - sample_float64 = halton.sample(dim, num_samples=10, dtype=dtypes.float64) - self.assertEqual(sample_float32.eval().dtype, np.float32) - self.assertEqual(sample_float64.eval().dtype, np.float64) - - def test_normal_integral_mean_and_var_correctly_estimated(self): - n = int(1000) - # This test is almost identical to the similarly named test in - # monte_carlo_test.py. The only difference is that we use the Halton - # samples instead of the random samples to evaluate the expectations. - # MC with pseudo random numbers converges at the rate of 1/ Sqrt(N) - # (N=number of samples). For QMC in low dimensions, the expected convergence - # rate is ~ 1/N. Hence we should only need 1e3 samples as compared to the - # 1e6 samples used in the pseudo-random monte carlo. - with self.test_session(): - mu_p = array_ops.constant([-1.0, 1.0], dtype=dtypes.float64) - mu_q = array_ops.constant([0.0, 0.0], dtype=dtypes.float64) - sigma_p = array_ops.constant([0.5, 0.5], dtype=dtypes.float64) - sigma_q = array_ops.constant([1.0, 1.0], dtype=dtypes.float64) - p = normal_lib.Normal(loc=mu_p, scale=sigma_p) - q = normal_lib.Normal(loc=mu_q, scale=sigma_q) - - cdf_sample = halton.sample(2, num_samples=n, dtype=dtypes.float64) - q_sample = q.quantile(cdf_sample) - - # Compute E_p[X]. - e_x = mc.expectation_importance_sampler( - f=lambda x: x, log_p=p.log_prob, sampling_dist_q=q, z=q_sample, - seed=42) - - # Compute E_p[X^2]. - e_x2 = mc.expectation_importance_sampler( - f=math_ops.square, log_p=p.log_prob, sampling_dist_q=q, z=q_sample, - seed=42) - - stddev = math_ops.sqrt(e_x2 - math_ops.square(e_x)) - # Keep the tolerance levels the same as in monte_carlo_test.py. - self.assertEqual(p.batch_shape, e_x.get_shape()) - self.assertAllClose(p.mean().eval(), e_x.eval(), rtol=0.01) - self.assertAllClose(p.stddev().eval(), stddev.eval(), rtol=0.02) - - def test_docstring_example(self): - # Produce the first 1000 members of the Halton sequence in 3 dimensions. - num_samples = 1000 - dim = 3 - with self.test_session(): - sample = halton.sample(dim, num_samples=num_samples) - - # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional - # hypercube. - powers = math_ops.range(1.0, limit=dim + 1) - integral = math_ops.reduce_mean( - math_ops.reduce_prod(sample ** powers, axis=-1)) - true_value = 1.0 / math_ops.reduce_prod(powers + 1.0) - - # Produces a relative absolute error of 1.7%. - self.assertAllClose(integral.eval(), true_value.eval(), rtol=0.02) - - # Now skip the first 1000 samples and recompute the integral with the next - # thousand samples. The sample_indices argument can be used to do this. - - sample_indices = math_ops.range(start=1000, limit=1000 + num_samples, - dtype=dtypes.int32) - sample_leaped = halton.sample(dim, sample_indices=sample_indices) - - integral_leaped = math_ops.reduce_mean( - math_ops.reduce_prod(sample_leaped ** powers, axis=-1)) - self.assertAllClose(integral_leaped.eval(), true_value.eval(), rtol=0.001) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py deleted file mode 100644 index 5bd834e56245ab4d874544cfd014fe59ae521ea8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py +++ /dev/null @@ -1,863 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Hamiltonian Monte Carlo.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections - -import numpy as np -from scipy import stats - -from tensorflow.contrib.bayesflow.python.ops import hmc -from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _compute_energy_change -from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _leapfrog_integrator - -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_linalg_ops -from tensorflow.python.ops import gradients_impl as gradients_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import gamma as gamma_lib -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging as logging_ops - - -def _reduce_variance(x, axis=None, keepdims=False): - sample_mean = math_ops.reduce_mean(x, axis, keepdims=True) - return math_ops.reduce_mean( - math_ops.squared_difference(x, sample_mean), axis, keepdims) - - -class HMCTest(test.TestCase): - - def setUp(self): - self._shape_param = 5. - self._rate_param = 10. - - random_seed.set_random_seed(10003) - np.random.seed(10003) - - def assertAllFinite(self, x): - self.assertAllEqual(np.ones_like(x).astype(bool), np.isfinite(x)) - - def _log_gamma_log_prob(self, x, event_dims=()): - """Computes log-pdf of a log-gamma random variable. - - Args: - x: Value of the random variable. - event_dims: Dimensions not to treat as independent. - - Returns: - log_prob: The log-pdf up to a normalizing constant. - """ - return math_ops.reduce_sum(self._shape_param * x - - self._rate_param * math_ops.exp(x), - event_dims) - - def _integrator_conserves_energy(self, x, independent_chain_ndims, sess, - feed_dict=None): - step_size = array_ops.placeholder(np.float32, [], name="step_size") - hmc_lf_steps = array_ops.placeholder(np.int32, [], name="hmc_lf_steps") - - if feed_dict is None: - feed_dict = {} - feed_dict[hmc_lf_steps] = 1000 - - event_dims = math_ops.range(independent_chain_ndims, - array_ops.rank(x)) - - m = random_ops.random_normal(array_ops.shape(x)) - log_prob_0 = self._log_gamma_log_prob(x, event_dims) - grad_0 = gradients_ops.gradients(log_prob_0, x) - old_energy = -log_prob_0 + 0.5 * math_ops.reduce_sum(m**2., event_dims) - - new_m, _, log_prob_1, _ = _leapfrog_integrator( - current_momentums=[m], - target_log_prob_fn=lambda x: self._log_gamma_log_prob(x, event_dims), - current_state_parts=[x], - step_sizes=[step_size], - num_leapfrog_steps=hmc_lf_steps, - current_target_log_prob=log_prob_0, - current_grads_target_log_prob=grad_0) - new_m = new_m[0] - - new_energy = -log_prob_1 + 0.5 * math_ops.reduce_sum(new_m * new_m, - event_dims) - - x_shape = sess.run(x, feed_dict).shape - event_size = np.prod(x_shape[independent_chain_ndims:]) - feed_dict[step_size] = 0.1 / event_size - old_energy_, new_energy_ = sess.run([old_energy, new_energy], - feed_dict) - logging_ops.vlog(1, "average energy relative change: {}".format( - (1. - new_energy_ / old_energy_).mean())) - self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02) - - def _integrator_conserves_energy_wrapper(self, independent_chain_ndims): - """Tests the long-term energy conservation of the leapfrog integrator. - - The leapfrog integrator is symplectic, so for sufficiently small step - sizes it should be possible to run it more or less indefinitely without - the energy of the system blowing up or collapsing. - - Args: - independent_chain_ndims: Python `int` scalar representing the number of - dims associated with independent chains. - """ - with self.test_session(graph=ops.Graph()) as sess: - x_ph = array_ops.placeholder(np.float32, name="x_ph") - feed_dict = {x_ph: np.random.rand(50, 10, 2)} - self._integrator_conserves_energy(x_ph, independent_chain_ndims, - sess, feed_dict) - - def testIntegratorEnergyConservationNullShape(self): - self._integrator_conserves_energy_wrapper(0) - - def testIntegratorEnergyConservation1(self): - self._integrator_conserves_energy_wrapper(1) - - def testIntegratorEnergyConservation2(self): - self._integrator_conserves_energy_wrapper(2) - - def testIntegratorEnergyConservation3(self): - self._integrator_conserves_energy_wrapper(3) - - def testSampleChainSeedReproducibleWorksCorrectly(self): - with self.test_session(graph=ops.Graph()) as sess: - num_results = 10 - independent_chain_ndims = 1 - - def log_gamma_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, - array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - kwargs = dict( - target_log_prob_fn=log_gamma_log_prob, - current_state=np.random.rand(4, 3, 2), - step_size=0.1, - num_leapfrog_steps=2, - num_burnin_steps=150, - seed=52, - ) - - samples0, kernel_results0 = hmc.sample_chain( - **dict(list(kwargs.items()) + list(dict( - num_results=2 * num_results, - num_steps_between_results=0).items()))) - - samples1, kernel_results1 = hmc.sample_chain( - **dict(list(kwargs.items()) + list(dict( - num_results=num_results, - num_steps_between_results=1).items()))) - - [ - samples0_, - samples1_, - target_log_prob0_, - target_log_prob1_, - ] = sess.run([ - samples0, - samples1, - kernel_results0.current_target_log_prob, - kernel_results1.current_target_log_prob, - ]) - self.assertAllClose(samples0_[::2], samples1_, - atol=1e-5, rtol=1e-5) - self.assertAllClose(target_log_prob0_[::2], target_log_prob1_, - atol=1e-5, rtol=1e-5) - - def _chain_gets_correct_expectations(self, x, independent_chain_ndims, - sess, feed_dict=None): - counter = collections.Counter() - def log_gamma_log_prob(x): - counter["target_calls"] += 1 - event_dims = math_ops.range(independent_chain_ndims, - array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - num_results = array_ops.placeholder( - np.int32, [], name="num_results") - step_size = array_ops.placeholder( - np.float32, [], name="step_size") - num_leapfrog_steps = array_ops.placeholder( - np.int32, [], name="num_leapfrog_steps") - - if feed_dict is None: - feed_dict = {} - feed_dict.update({num_results: 150, - step_size: 0.05, - num_leapfrog_steps: 2}) - - samples, kernel_results = hmc.sample_chain( - num_results=num_results, - target_log_prob_fn=log_gamma_log_prob, - current_state=x, - step_size=step_size, - num_leapfrog_steps=num_leapfrog_steps, - num_burnin_steps=150, - seed=42) - - self.assertAllEqual(dict(target_calls=2), counter) - - expected_x = (math_ops.digamma(self._shape_param) - - np.log(self._rate_param)) - - expected_exp_x = self._shape_param / self._rate_param - - acceptance_probs_, samples_, expected_x_ = sess.run( - [kernel_results.acceptance_probs, samples, expected_x], - feed_dict) - - actual_x = samples_.mean() - actual_exp_x = np.exp(samples_).mean() - - logging_ops.vlog(1, "True E[x, exp(x)]: {}\t{}".format( - expected_x_, expected_exp_x)) - logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format( - actual_x, actual_exp_x)) - self.assertNear(actual_x, expected_x_, 2e-2) - self.assertNear(actual_exp_x, expected_exp_x, 2e-2) - self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool), - acceptance_probs_ > 0.5) - self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool), - acceptance_probs_ <= 1.) - - def _chain_gets_correct_expectations_wrapper(self, independent_chain_ndims): - with self.test_session(graph=ops.Graph()) as sess: - x_ph = array_ops.placeholder(np.float32, name="x_ph") - feed_dict = {x_ph: np.random.rand(50, 10, 2)} - self._chain_gets_correct_expectations(x_ph, independent_chain_ndims, - sess, feed_dict) - - def testHMCChainExpectationsNullShape(self): - self._chain_gets_correct_expectations_wrapper(0) - - def testHMCChainExpectations1(self): - self._chain_gets_correct_expectations_wrapper(1) - - def testHMCChainExpectations2(self): - self._chain_gets_correct_expectations_wrapper(2) - - def testKernelResultsUsingTruncatedDistribution(self): - def log_prob(x): - return array_ops.where( - x >= 0., - -x - x**2, # Non-constant gradient. - array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype))) - # This log_prob has the property that it is likely to attract - # the HMC flow toward, and below, zero...but for x <=0, - # log_prob(x) = -inf, which should result in rejection, as well - # as a non-finite log_prob. Thus, this distribution gives us an opportunity - # to test out the kernel results ability to correctly capture rejections due - # to finite AND non-finite reasons. - # Why use a non-constant gradient? This ensures the leapfrog integrator - # will not be exact. - - num_results = 1000 - # Large step size, will give rejections due to integration error in addition - # to rejection due to going into a region of log_prob = -inf. - step_size = 0.1 - num_leapfrog_steps = 5 - num_chains = 2 - - with self.test_session(graph=ops.Graph()) as sess: - - # Start multiple independent chains. - initial_state = ops.convert_to_tensor([0.1] * num_chains) - - states, kernel_results = hmc.sample_chain( - num_results=num_results, - target_log_prob_fn=log_prob, - current_state=initial_state, - step_size=step_size, - num_leapfrog_steps=num_leapfrog_steps, - seed=42) - - states_, kernel_results_ = sess.run([states, kernel_results]) - pstates_ = kernel_results_.proposed_state - - neg_inf_mask = np.isneginf(kernel_results_.proposed_target_log_prob) - - # First: Test that the mathematical properties of the above log prob - # function in conjunction with HMC show up as expected in kernel_results_. - - # We better have log_prob = -inf some of the time. - self.assertLess(0, neg_inf_mask.sum()) - # We better have some rejections due to something other than -inf. - self.assertLess(neg_inf_mask.sum(), (~kernel_results_.is_accepted).sum()) - # We better have been accepted a decent amount, even near the end of the - # chain, or else this HMC run just got stuck at some point. - self.assertLess( - 0.1, kernel_results_.is_accepted[int(0.9 * num_results):].mean()) - # We better not have any NaNs in proposed state or log_prob. - # We may have some NaN in grads, which involve multiplication/addition due - # to gradient rules. This is the known "NaN grad issue with tf.where." - self.assertAllEqual(np.zeros_like(states_), - np.isnan(kernel_results_.proposed_target_log_prob)) - self.assertAllEqual(np.zeros_like(states_), - np.isnan(states_)) - # We better not have any +inf in states, grads, or log_prob. - self.assertAllEqual(np.zeros_like(states_), - np.isposinf(kernel_results_.proposed_target_log_prob)) - self.assertAllEqual( - np.zeros_like(states_), - np.isposinf(kernel_results_.proposed_grads_target_log_prob[0])) - self.assertAllEqual(np.zeros_like(states_), - np.isposinf(states_)) - - # Second: Test that kernel_results is congruent with itself and - # acceptance/rejection of states. - - # Proposed state is negative iff proposed target log prob is -inf. - np.testing.assert_array_less(pstates_[neg_inf_mask], 0.) - np.testing.assert_array_less(0., pstates_[~neg_inf_mask]) - - # Acceptance probs are zero whenever proposed state is negative. - self.assertAllEqual( - np.zeros_like(pstates_[neg_inf_mask]), - kernel_results_.acceptance_probs[neg_inf_mask]) - - # The move is accepted ==> state = proposed state. - self.assertAllEqual( - states_[kernel_results_.is_accepted], - pstates_[kernel_results_.is_accepted], - ) - # The move was rejected <==> state[t] == state[t - 1]. - for t in range(1, num_results): - for i in range(num_chains): - if kernel_results_.is_accepted[t, i]: - self.assertNotEqual(states_[t, i], states_[t - 1, i]) - else: - self.assertEqual(states_[t, i], states_[t - 1, i]) - - def _kernel_leaves_target_invariant(self, initial_draws, - independent_chain_ndims, - sess, feed_dict=None): - def log_gamma_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - def fake_log_prob(x): - """Cooled version of the target distribution.""" - return 1.1 * log_gamma_log_prob(x) - - step_size = array_ops.placeholder(np.float32, [], name="step_size") - - if feed_dict is None: - feed_dict = {} - - feed_dict[step_size] = 0.4 - - sample, kernel_results = hmc.kernel( - target_log_prob_fn=log_gamma_log_prob, - current_state=initial_draws, - step_size=step_size, - num_leapfrog_steps=5, - seed=43) - - bad_sample, bad_kernel_results = hmc.kernel( - target_log_prob_fn=fake_log_prob, - current_state=initial_draws, - step_size=step_size, - num_leapfrog_steps=5, - seed=44) - - [ - acceptance_probs_, - bad_acceptance_probs_, - initial_draws_, - updated_draws_, - fake_draws_, - ] = sess.run([ - kernel_results.acceptance_probs, - bad_kernel_results.acceptance_probs, - initial_draws, - sample, - bad_sample, - ], feed_dict) - - # Confirm step size is small enough that we usually accept. - self.assertGreater(acceptance_probs_.mean(), 0.5) - self.assertGreater(bad_acceptance_probs_.mean(), 0.5) - - # Confirm step size is large enough that we sometimes reject. - self.assertLess(acceptance_probs_.mean(), 0.99) - self.assertLess(bad_acceptance_probs_.mean(), 0.99) - - _, ks_p_value_true = stats.ks_2samp(initial_draws_.flatten(), - updated_draws_.flatten()) - _, ks_p_value_fake = stats.ks_2samp(initial_draws_.flatten(), - fake_draws_.flatten()) - - logging_ops.vlog(1, "acceptance rate for true target: {}".format( - acceptance_probs_.mean())) - logging_ops.vlog(1, "acceptance rate for fake target: {}".format( - bad_acceptance_probs_.mean())) - logging_ops.vlog(1, "K-S p-value for true target: {}".format( - ks_p_value_true)) - logging_ops.vlog(1, "K-S p-value for fake target: {}".format( - ks_p_value_fake)) - # Make sure that the MCMC update hasn't changed the empirical CDF much. - self.assertGreater(ks_p_value_true, 1e-3) - # Confirm that targeting the wrong distribution does - # significantly change the empirical CDF. - self.assertLess(ks_p_value_fake, 1e-6) - - def _kernel_leaves_target_invariant_wrapper(self, independent_chain_ndims): - """Tests that the kernel leaves the target distribution invariant. - - Draws some independent samples from the target distribution, - applies an iteration of the MCMC kernel, then runs a - Kolmogorov-Smirnov test to determine if the distribution of the - MCMC-updated samples has changed. - - We also confirm that running the kernel with a different log-pdf - does change the target distribution. (And that we can detect that.) - - Args: - independent_chain_ndims: Python `int` scalar representing the number of - dims associated with independent chains. - """ - with self.test_session(graph=ops.Graph()) as sess: - initial_draws = np.log(np.random.gamma(self._shape_param, - size=[50000, 2, 2])) - initial_draws -= np.log(self._rate_param) - x_ph = array_ops.placeholder(np.float32, name="x_ph") - - feed_dict = {x_ph: initial_draws} - - self._kernel_leaves_target_invariant(x_ph, independent_chain_ndims, - sess, feed_dict) - - def testKernelLeavesTargetInvariant1(self): - self._kernel_leaves_target_invariant_wrapper(1) - - def testKernelLeavesTargetInvariant2(self): - self._kernel_leaves_target_invariant_wrapper(2) - - def testKernelLeavesTargetInvariant3(self): - self._kernel_leaves_target_invariant_wrapper(3) - - def _ais_gets_correct_log_normalizer(self, init, independent_chain_ndims, - sess, feed_dict=None): - counter = collections.Counter() - - def proposal_log_prob(x): - counter["proposal_calls"] += 1 - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi), - axis=event_dims) - - def target_log_prob(x): - counter["target_calls"] += 1 - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - if feed_dict is None: - feed_dict = {} - - num_steps = 200 - - _, ais_weights, _ = hmc.sample_annealed_importance_chain( - proposal_log_prob_fn=proposal_log_prob, - num_steps=num_steps, - target_log_prob_fn=target_log_prob, - step_size=0.5, - current_state=init, - num_leapfrog_steps=2, - seed=45) - - # We have three calls because the calculation of `ais_weights` entails - # another call to the `convex_combined_log_prob_fn`. We could refactor - # things to avoid this, if needed (eg, b/72994218). - self.assertAllEqual(dict(target_calls=3, proposal_calls=3), counter) - - event_shape = array_ops.shape(init)[independent_chain_ndims:] - event_size = math_ops.reduce_prod(event_shape) - - log_true_normalizer = ( - -self._shape_param * math_ops.log(self._rate_param) - + math_ops.lgamma(self._shape_param)) - log_true_normalizer *= math_ops.cast(event_size, log_true_normalizer.dtype) - - log_estimated_normalizer = (math_ops.reduce_logsumexp(ais_weights) - - np.log(num_steps)) - - ratio_estimate_true = math_ops.exp(ais_weights - log_true_normalizer) - ais_weights_size = array_ops.size(ais_weights) - standard_error = math_ops.sqrt( - _reduce_variance(ratio_estimate_true) - / math_ops.cast(ais_weights_size, ratio_estimate_true.dtype)) - - [ - ratio_estimate_true_, - log_true_normalizer_, - log_estimated_normalizer_, - standard_error_, - ais_weights_size_, - event_size_, - ] = sess.run([ - ratio_estimate_true, - log_true_normalizer, - log_estimated_normalizer, - standard_error, - ais_weights_size, - event_size, - ], feed_dict) - - logging_ops.vlog(1, " log_true_normalizer: {}\n" - " log_estimated_normalizer: {}\n" - " ais_weights_size: {}\n" - " event_size: {}\n".format( - log_true_normalizer_, - log_estimated_normalizer_, - ais_weights_size_, - event_size_)) - self.assertNear(ratio_estimate_true_.mean(), 1., 4. * standard_error_) - - def _ais_gets_correct_log_normalizer_wrapper(self, independent_chain_ndims): - """Tests that AIS yields reasonable estimates of normalizers.""" - with self.test_session(graph=ops.Graph()) as sess: - x_ph = array_ops.placeholder(np.float32, name="x_ph") - initial_draws = np.random.normal(size=[30, 2, 1]) - self._ais_gets_correct_log_normalizer( - x_ph, - independent_chain_ndims, - sess, - feed_dict={x_ph: initial_draws}) - - def testAIS1(self): - self._ais_gets_correct_log_normalizer_wrapper(1) - - def testAIS2(self): - self._ais_gets_correct_log_normalizer_wrapper(2) - - def testAIS3(self): - self._ais_gets_correct_log_normalizer_wrapper(3) - - def testSampleAIChainSeedReproducibleWorksCorrectly(self): - with self.test_session(graph=ops.Graph()) as sess: - independent_chain_ndims = 1 - x = np.random.rand(4, 3, 2) - - def proposal_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi), - axis=event_dims) - - def target_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - ais_kwargs = dict( - proposal_log_prob_fn=proposal_log_prob, - num_steps=200, - target_log_prob_fn=target_log_prob, - step_size=0.5, - current_state=x, - num_leapfrog_steps=2, - seed=53) - - _, ais_weights0, _ = hmc.sample_annealed_importance_chain( - **ais_kwargs) - - _, ais_weights1, _ = hmc.sample_annealed_importance_chain( - **ais_kwargs) - - [ais_weights0_, ais_weights1_] = sess.run([ - ais_weights0, ais_weights1]) - - self.assertAllClose(ais_weights0_, ais_weights1_, - atol=1e-5, rtol=1e-5) - - def testNanRejection(self): - """Tests that an update that yields NaN potentials gets rejected. - - We run HMC with a target distribution that returns NaN - log-likelihoods if any element of x < 0, and unit-scale - exponential log-likelihoods otherwise. The exponential potential - pushes x towards 0, ensuring that any reasonably large update will - push us over the edge into NaN territory. - """ - def _unbounded_exponential_log_prob(x): - """An exponential distribution with log-likelihood NaN for x < 0.""" - per_element_potentials = array_ops.where( - x < 0., - array_ops.fill(array_ops.shape(x), x.dtype.as_numpy_dtype(np.nan)), - -x) - return math_ops.reduce_sum(per_element_potentials) - - with self.test_session(graph=ops.Graph()) as sess: - initial_x = math_ops.linspace(0.01, 5, 10) - updated_x, kernel_results = hmc.kernel( - target_log_prob_fn=_unbounded_exponential_log_prob, - current_state=initial_x, - step_size=2., - num_leapfrog_steps=5, - seed=46) - initial_x_, updated_x_, acceptance_probs_ = sess.run( - [initial_x, updated_x, kernel_results.acceptance_probs]) - - logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) - logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) - logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_)) - - self.assertAllEqual(initial_x_, updated_x_) - self.assertEqual(acceptance_probs_, 0.) - - def testNanFromGradsDontPropagate(self): - """Test that update with NaN gradients does not cause NaN in results.""" - def _nan_log_prob_with_nan_gradient(x): - return np.nan * math_ops.reduce_sum(x) - - with self.test_session(graph=ops.Graph()) as sess: - initial_x = math_ops.linspace(0.01, 5, 10) - updated_x, kernel_results = hmc.kernel( - target_log_prob_fn=_nan_log_prob_with_nan_gradient, - current_state=initial_x, - step_size=2., - num_leapfrog_steps=5, - seed=47) - initial_x_, updated_x_, acceptance_probs_ = sess.run( - [initial_x, updated_x, kernel_results.acceptance_probs]) - - logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) - logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) - logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_)) - - self.assertAllEqual(initial_x_, updated_x_) - self.assertEqual(acceptance_probs_, 0.) - - self.assertAllFinite( - gradients_ops.gradients(updated_x, initial_x)[0].eval()) - self.assertAllEqual([True], [g is None for g in gradients_ops.gradients( - kernel_results.proposed_grads_target_log_prob, initial_x)]) - self.assertAllEqual([False], [g is None for g in gradients_ops.gradients( - kernel_results.proposed_grads_target_log_prob, - kernel_results.proposed_state)]) - - # Gradients of the acceptance probs and new log prob are not finite. - # self.assertAllFinite( - # gradients_ops.gradients(acceptance_probs, initial_x)[0].eval()) - # self.assertAllFinite( - # gradients_ops.gradients(new_log_prob, initial_x)[0].eval()) - - def _testChainWorksDtype(self, dtype): - with self.test_session(graph=ops.Graph()) as sess: - states, kernel_results = hmc.sample_chain( - num_results=10, - target_log_prob_fn=lambda x: -math_ops.reduce_sum(x**2., axis=-1), - current_state=np.zeros(5).astype(dtype), - step_size=0.01, - num_leapfrog_steps=10, - seed=48) - states_, acceptance_probs_ = sess.run( - [states, kernel_results.acceptance_probs]) - self.assertEqual(dtype, states_.dtype) - self.assertEqual(dtype, acceptance_probs_.dtype) - - def testChainWorksIn64Bit(self): - self._testChainWorksDtype(np.float64) - - def testChainWorksIn16Bit(self): - self._testChainWorksDtype(np.float16) - - def testChainWorksCorrelatedMultivariate(self): - dtype = np.float32 - true_mean = dtype([0, 0]) - true_cov = dtype([[1, 0.5], - [0.5, 1]]) - num_results = 2000 - counter = collections.Counter() - with self.test_session(graph=ops.Graph()) as sess: - def target_log_prob(x, y): - counter["target_calls"] += 1 - # Corresponds to unnormalized MVN. - # z = matmul(inv(chol(true_cov)), [x, y] - true_mean) - z = array_ops.stack([x, y], axis=-1) - true_mean - z = array_ops.squeeze( - gen_linalg_ops.matrix_triangular_solve( - np.linalg.cholesky(true_cov), - z[..., array_ops.newaxis]), - axis=-1) - return -0.5 * math_ops.reduce_sum(z**2., axis=-1) - states, _ = hmc.sample_chain( - num_results=num_results, - target_log_prob_fn=target_log_prob, - current_state=[dtype(-2), dtype(2)], - step_size=[0.5, 0.5], - num_leapfrog_steps=2, - num_burnin_steps=200, - num_steps_between_results=1, - seed=54) - self.assertAllEqual(dict(target_calls=2), counter) - states = array_ops.stack(states, axis=-1) - self.assertEqual(num_results, states.shape[0].value) - sample_mean = math_ops.reduce_mean(states, axis=0) - x = states - sample_mean - sample_cov = math_ops.matmul(x, x, transpose_a=True) / dtype(num_results) - [sample_mean_, sample_cov_] = sess.run([ - sample_mean, sample_cov]) - self.assertAllClose(true_mean, sample_mean_, - atol=0.05, rtol=0.) - self.assertAllClose(true_cov, sample_cov_, - atol=0., rtol=0.1) - - -class _EnergyComputationTest(object): - - def testHandlesNanFromPotential(self): - with self.test_session(graph=ops.Graph()) as sess: - x = [1, np.inf, -np.inf, np.nan] - target_log_prob, proposed_target_log_prob = [ - self.dtype(x.flatten()) for x in np.meshgrid(x, x)] - num_chains = len(target_log_prob) - dummy_momentums = [-1, 1] - momentums = [self.dtype([dummy_momentums] * num_chains)] - proposed_momentums = [self.dtype([dummy_momentums] * num_chains)] - - target_log_prob = ops.convert_to_tensor(target_log_prob) - momentums = [ops.convert_to_tensor(momentums[0])] - proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) - proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] - - energy = _compute_energy_change( - target_log_prob, - momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims=1) - grads = gradients_ops.gradients(energy, momentums) - - [actual_energy, grads_] = sess.run([energy, grads]) - - # Ensure energy is `inf` (note: that's positive inf) in weird cases and - # finite otherwise. - expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) - self.assertAllEqual(expected_energy, actual_energy) - - # Ensure gradient is finite. - self.assertAllEqual(np.ones_like(grads_).astype(np.bool), - np.isfinite(grads_)) - - def testHandlesNanFromKinetic(self): - with self.test_session(graph=ops.Graph()) as sess: - x = [1, np.inf, -np.inf, np.nan] - momentums, proposed_momentums = [ - [np.reshape(self.dtype(x), [-1, 1])] - for x in np.meshgrid(x, x)] - num_chains = len(momentums[0]) - target_log_prob = np.ones(num_chains, self.dtype) - proposed_target_log_prob = np.ones(num_chains, self.dtype) - - target_log_prob = ops.convert_to_tensor(target_log_prob) - momentums = [ops.convert_to_tensor(momentums[0])] - proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) - proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] - - energy = _compute_energy_change( - target_log_prob, - momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims=1) - grads = gradients_ops.gradients(energy, momentums) - - [actual_energy, grads_] = sess.run([energy, grads]) - - # Ensure energy is `inf` (note: that's positive inf) in weird cases and - # finite otherwise. - expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) - self.assertAllEqual(expected_energy, actual_energy) - - # Ensure gradient is finite. - g = grads_[0].reshape([len(x), len(x)])[:, 0] - self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isfinite(g)) - - # The remaining gradients are nan because the momentum was itself nan or - # inf. - g = grads_[0].reshape([len(x), len(x)])[:, 1:] - self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isnan(g)) - - -class EnergyComputationTest16(test.TestCase, _EnergyComputationTest): - dtype = np.float16 - - -class EnergyComputationTest32(test.TestCase, _EnergyComputationTest): - dtype = np.float32 - - -class EnergyComputationTest64(test.TestCase, _EnergyComputationTest): - dtype = np.float64 - - -class _HMCHandlesLists(object): - - def testStateParts(self): - with self.test_session(graph=ops.Graph()) as sess: - dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1)) - dist_y = independent_lib.Independent( - gamma_lib.Gamma(concentration=self.dtype([1, 2]), - rate=self.dtype([0.5, 0.75])), - reinterpreted_batch_ndims=1) - def target_log_prob(x, y): - return dist_x.log_prob(x) + dist_y.log_prob(y) - x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)] - samples, _ = hmc.sample_chain( - num_results=int(2e3), - target_log_prob_fn=target_log_prob, - current_state=x0, - step_size=0.85, - num_leapfrog_steps=3, - num_burnin_steps=int(250), - seed=49) - actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples] - actual_vars = [_reduce_variance(s, axis=0) for s in samples] - expected_means = [dist_x.mean(), dist_y.mean()] - expected_vars = [dist_x.variance(), dist_y.variance()] - [ - actual_means_, - actual_vars_, - expected_means_, - expected_vars_, - ] = sess.run([ - actual_means, - actual_vars, - expected_means, - expected_vars, - ]) - self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16) - self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.25) - - -class HMCHandlesLists32(_HMCHandlesLists, test.TestCase): - dtype = np.float32 - - -class HMCHandlesLists64(_HMCHandlesLists, test.TestCase): - dtype = np.float64 - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py deleted file mode 100644 index 750afb6654311fea30a1dc6b31b20aa3b4160ae2..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py +++ /dev/null @@ -1,521 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for convolutional Bayesian layers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import layers_conv_variational as prob_layers_lib -from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.ops.distributions import util as distribution_util -from tensorflow.python.platform import test - - -class Counter(object): - """Helper class to manage incrementing a counting `int`.""" - - def __init__(self): - self._value = -1 - - @property - def value(self): - return self._value - - def __call__(self): - self._value += 1 - return self._value - - -class MockDistribution(independent_lib.Independent): - """Monitors layer calls to the underlying distribution.""" - - def __init__(self, result_sample, result_log_prob, loc=None, scale=None): - self.result_sample = result_sample - self.result_log_prob = result_log_prob - self.result_loc = loc - self.result_scale = scale - self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0) - if loc is not None and scale is not None: - self.result_distribution = normal_lib.Normal(loc=self.result_loc, - scale=self.result_scale) - self.called_log_prob = Counter() - self.called_sample = Counter() - self.called_loc = Counter() - self.called_scale = Counter() - - def log_prob(self, *args, **kwargs): - self.called_log_prob() - return self.result_log_prob - - def sample(self, *args, **kwargs): - self.called_sample() - return self.result_sample - - @property - def distribution(self): # for dummy check on Independent(Normal) - return self.result_distribution - - @property - def loc(self): - self.called_loc() - return self.result_loc - - @property - def scale(self): - self.called_scale() - return self.result_scale - - -class MockKLDivergence(object): - """Monitors layer calls to the divergence implementation.""" - - def __init__(self, result): - self.result = result - self.args = [] - self.called = Counter() - - def __call__(self, *args, **kwargs): - self.called() - self.args.append(args) - return self.result - - -class ConvVariational(test.TestCase): - - def _testKLPenaltyKernel(self, layer_class): - with self.test_session(): - layer = layer_class(filters=2, kernel_size=3) - if layer_class in (prob_layers_lib.Conv1DReparameterization, - prob_layers_lib.Conv1DFlipout): - inputs = random_ops.random_uniform([2, 3, 1], seed=1) - elif layer_class in (prob_layers_lib.Conv2DReparameterization, - prob_layers_lib.Conv2DFlipout): - inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1) - elif layer_class in (prob_layers_lib.Conv3DReparameterization, - prob_layers_lib.Conv3DFlipout): - inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1) - - # No keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 0) - self.assertListEqual(layer.losses, losses) - - _ = layer(inputs) - - # Yes keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 1) - self.assertListEqual(layer.losses, losses) - - def _testKLPenaltyBoth(self, layer_class): - def _make_normal(dtype, *args): # pylint: disable=unused-argument - return normal_lib.Normal( - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)) - with self.test_session(): - layer = layer_class( - filters=2, - kernel_size=3, - bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(), - bias_prior_fn=_make_normal) - if layer_class in (prob_layers_lib.Conv1DReparameterization, - prob_layers_lib.Conv1DFlipout): - inputs = random_ops.random_uniform([2, 3, 1], seed=1) - elif layer_class in (prob_layers_lib.Conv2DReparameterization, - prob_layers_lib.Conv2DFlipout): - inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1) - elif layer_class in (prob_layers_lib.Conv3DReparameterization, - prob_layers_lib.Conv3DFlipout): - inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1) - - # No keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 0) - self.assertListEqual(layer.losses, losses) - - _ = layer(inputs) - - # Yes keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 2) - self.assertListEqual(layer.losses, losses) - - def _testConvSetUp(self, layer_class, batch_size, depth=None, - height=None, width=None, channels=None, filters=None, - **kwargs): - seed = Counter() - if layer_class in (prob_layers_lib.Conv1DReparameterization, - prob_layers_lib.Conv1DFlipout): - inputs = random_ops.random_uniform( - [batch_size, width, channels], seed=seed()) - kernel_size = (2,) - elif layer_class in (prob_layers_lib.Conv2DReparameterization, - prob_layers_lib.Conv2DFlipout): - inputs = random_ops.random_uniform( - [batch_size, height, width, channels], seed=seed()) - kernel_size = (2, 2) - elif layer_class in (prob_layers_lib.Conv3DReparameterization, - prob_layers_lib.Conv3DFlipout): - inputs = random_ops.random_uniform( - [batch_size, depth, height, width, channels], seed=seed()) - kernel_size = (2, 2, 2) - - kernel_shape = kernel_size + (channels, filters) - kernel_posterior = MockDistribution( - loc=random_ops.random_uniform(kernel_shape, seed=seed()), - scale=random_ops.random_uniform(kernel_shape, seed=seed()), - result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()), - result_sample=random_ops.random_uniform(kernel_shape, seed=seed())) - kernel_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()), - result_sample=random_ops.random_uniform(kernel_shape, seed=seed())) - kernel_divergence = MockKLDivergence( - result=random_ops.random_uniform(kernel_shape, seed=seed())) - - bias_size = (filters,) - bias_posterior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_divergence = MockKLDivergence( - result=random_ops.random_uniform(bias_size, seed=seed())) - - layer = layer_class( - filters=filters, - kernel_size=kernel_size, - padding="SAME", - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - kernel_prior_fn=lambda *args: kernel_prior, - kernel_divergence_fn=kernel_divergence, - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - bias_prior_fn=lambda *args: bias_prior, - bias_divergence_fn=bias_divergence, - **kwargs) - - outputs = layer(inputs) - - kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - return (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, - layer, inputs, outputs, kl_penalty, kernel_shape) - - def _testConvReparameterization(self, layer_class): - batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 - with self.test_session() as sess: - (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, layer, inputs, - outputs, kl_penalty, kernel_shape) = self._testConvSetUp( - layer_class, batch_size, - depth=depth, height=height, width=width, channels=channels, - filters=filters) - - convolution_op = nn_ops.Convolution( - tensor_shape.TensorShape(inputs.shape), - filter_shape=tensor_shape.TensorShape(kernel_shape), - padding="SAME") - expected_outputs = convolution_op(inputs, kernel_posterior.result_sample) - expected_outputs = nn.bias_add(expected_outputs, - bias_posterior.result_sample, - data_format="NHWC") - - [ - expected_outputs_, actual_outputs_, - expected_kernel_, actual_kernel_, - expected_kernel_divergence_, actual_kernel_divergence_, - expected_bias_, actual_bias_, - expected_bias_divergence_, actual_bias_divergence_, - ] = sess.run([ - expected_outputs, outputs, - kernel_posterior.result_sample, layer.kernel_posterior_tensor, - kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, layer.bias_posterior_tensor, - bias_divergence.result, kl_penalty[1], - ]) - - self.assertAllClose( - expected_kernel_, actual_kernel_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_, actual_bias_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_outputs_, actual_outputs_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_kernel_divergence_, actual_kernel_divergence_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_divergence_, actual_bias_divergence_, - rtol=1e-6, atol=0.) - - self.assertAllEqual( - [[kernel_posterior.distribution, - kernel_prior.distribution, - kernel_posterior.result_sample]], - kernel_divergence.args) - - self.assertAllEqual( - [[bias_posterior.distribution, - bias_prior.distribution, - bias_posterior.result_sample]], - bias_divergence.args) - - def _testConvFlipout(self, layer_class): - batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 - with self.test_session() as sess: - (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, layer, inputs, - outputs, kl_penalty, kernel_shape) = self._testConvSetUp( - layer_class, batch_size, - depth=depth, height=height, width=width, channels=channels, - filters=filters, seed=44) - - convolution_op = nn_ops.Convolution( - tensor_shape.TensorShape(inputs.shape), - filter_shape=tensor_shape.TensorShape(kernel_shape), - padding="SAME") - - expected_kernel_posterior_affine = normal_lib.Normal( - loc=array_ops.zeros_like(kernel_posterior.result_loc), - scale=kernel_posterior.result_scale) - expected_kernel_posterior_affine_tensor = ( - expected_kernel_posterior_affine.sample(seed=42)) - - expected_outputs = convolution_op( - inputs, kernel_posterior.distribution.loc) - - input_shape = array_ops.shape(inputs) - output_shape = array_ops.shape(expected_outputs) - batch_shape = array_ops.expand_dims(input_shape[0], 0) - channels = input_shape[-1] - rank = len(inputs.get_shape()) - 2 - - sign_input = random_ops.random_uniform( - array_ops.concat([batch_shape, - array_ops.expand_dims(channels, 0)], 0), - minval=0, - maxval=2, - dtype=dtypes.int32, - seed=layer.seed) - sign_input = math_ops.cast(2 * sign_input - 1, inputs.dtype) - sign_output = random_ops.random_uniform( - array_ops.concat([batch_shape, - array_ops.expand_dims(filters, 0)], 0), - minval=0, - maxval=2, - dtype=dtypes.int32, - seed=distribution_util.gen_new_seed( - layer.seed, salt="conv_flipout")) - sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype) - for _ in range(rank): - sign_input = array_ops.expand_dims(sign_input, 1) # 2D ex: (B, 1, 1, C) - sign_output = array_ops.expand_dims(sign_output, 1) - - sign_input = array_ops.tile( # tile for element-wise op broadcasting - sign_input, - [1] + [input_shape[i + 1] for i in range(rank)] + [1]) - sign_output = array_ops.tile( - sign_output, - [1] + [output_shape[i + 1] for i in range(rank)] + [1]) - - perturbed_inputs = convolution_op( - inputs * sign_input, expected_kernel_posterior_affine_tensor) - perturbed_inputs *= sign_output - - expected_outputs += perturbed_inputs - expected_outputs = nn.bias_add(expected_outputs, - bias_posterior.result_sample, - data_format="NHWC") - - [ - expected_outputs_, actual_outputs_, - expected_kernel_divergence_, actual_kernel_divergence_, - expected_bias_, actual_bias_, - expected_bias_divergence_, actual_bias_divergence_, - ] = sess.run([ - expected_outputs, outputs, - kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, layer.bias_posterior_tensor, - bias_divergence.result, kl_penalty[1], - ]) - - self.assertAllClose( - expected_bias_, actual_bias_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_outputs_, actual_outputs_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_kernel_divergence_, actual_kernel_divergence_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_divergence_, actual_bias_divergence_, - rtol=1e-6, atol=0.) - - self.assertAllEqual( - [[kernel_posterior.distribution, kernel_prior.distribution, None]], - kernel_divergence.args) - - self.assertAllEqual( - [[bias_posterior.distribution, - bias_prior.distribution, - bias_posterior.result_sample]], - bias_divergence.args) - - def _testRandomConvFlipout(self, layer_class): - batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 - with self.test_session() as sess: - seed = Counter() - if layer_class in (prob_layers_lib.Conv1DReparameterization, - prob_layers_lib.Conv1DFlipout): - inputs = random_ops.random_uniform( - [batch_size, width, channels], seed=seed()) - kernel_size = (2,) - elif layer_class in (prob_layers_lib.Conv2DReparameterization, - prob_layers_lib.Conv2DFlipout): - inputs = random_ops.random_uniform( - [batch_size, height, width, channels], seed=seed()) - kernel_size = (2, 2) - elif layer_class in (prob_layers_lib.Conv3DReparameterization, - prob_layers_lib.Conv3DFlipout): - inputs = random_ops.random_uniform( - [batch_size, depth, height, width, channels], seed=seed()) - kernel_size = (2, 2, 2) - - kernel_shape = kernel_size + (channels, filters) - bias_size = (filters,) - - kernel_posterior = MockDistribution( - loc=random_ops.random_uniform( - kernel_shape, seed=seed()), - scale=random_ops.random_uniform( - kernel_shape, seed=seed()), - result_log_prob=random_ops.random_uniform( - kernel_shape, seed=seed()), - result_sample=random_ops.random_uniform( - kernel_shape, seed=seed())) - bias_posterior = MockDistribution( - loc=random_ops.random_uniform( - bias_size, seed=seed()), - scale=random_ops.random_uniform( - bias_size, seed=seed()), - result_log_prob=random_ops.random_uniform( - bias_size, seed=seed()), - result_sample=random_ops.random_uniform( - bias_size, seed=seed())) - layer_one = layer_class( - filters=filters, - kernel_size=kernel_size, - padding="SAME", - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - seed=44) - layer_two = layer_class( - filters=filters, - kernel_size=kernel_size, - padding="SAME", - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - seed=45) - - outputs_one = layer_one(inputs) - outputs_two = layer_two(inputs) - - outputs_one_, outputs_two_ = sess.run([ - outputs_one, outputs_two]) - - self.assertLess(np.sum(np.isclose(outputs_one_, outputs_two_)), - np.prod(outputs_one_.shape)) - - def testKLPenaltyKernelConv1DReparameterization(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv1DReparameterization) - - def testKLPenaltyKernelConv2DReparameterization(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv2DReparameterization) - - def testKLPenaltyKernelConv3DReparameterization(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv3DReparameterization) - - def testKLPenaltyKernelConv1DFlipout(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv1DFlipout) - - def testKLPenaltyKernelConv2DFlipout(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv2DFlipout) - - def testKLPenaltyKernelConv3DFlipout(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv3DFlipout) - - def testKLPenaltyBothConv1DReparameterization(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv1DReparameterization) - - def testKLPenaltyBothConv2DReparameterization(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv2DReparameterization) - - def testKLPenaltyBothConv3DReparameterization(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv3DReparameterization) - - def testKLPenaltyBothConv1DFlipout(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv1DFlipout) - - def testKLPenaltyBothConv2DFlipout(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv2DFlipout) - - def testKLPenaltyBothConv3DFlipout(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv3DFlipout) - - def testConv1DReparameterization(self): - self._testConvReparameterization(prob_layers_lib.Conv1DReparameterization) - - def testConv2DReparameterization(self): - self._testConvReparameterization(prob_layers_lib.Conv2DReparameterization) - - def testConv3DReparameterization(self): - self._testConvReparameterization(prob_layers_lib.Conv3DReparameterization) - - def testConv1DFlipout(self): - self._testConvFlipout(prob_layers_lib.Conv1DFlipout) - - def testConv2DFlipout(self): - self._testConvFlipout(prob_layers_lib.Conv2DFlipout) - - def testConv3DFlipout(self): - self._testConvFlipout(prob_layers_lib.Conv3DFlipout) - - def testRandomConv1DFlipout(self): - self._testRandomConvFlipout(prob_layers_lib.Conv1DFlipout) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py deleted file mode 100644 index 342f38ccec7ec74db1b393d6cdc22300205cc547..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py +++ /dev/null @@ -1,443 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for dense Bayesian layers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational as prob_layers_lib -from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.ops.distributions import util as distribution_util -from tensorflow.python.platform import test - - -class Counter(object): - """Helper class to manage incrementing a counting `int`.""" - - def __init__(self): - self._value = -1 - - @property - def value(self): - return self._value - - def __call__(self): - self._value += 1 - return self._value - - -class MockDistribution(independent_lib.Independent): - """Monitors layer calls to the underlying distribution.""" - - def __init__(self, result_sample, result_log_prob, loc=None, scale=None): - self.result_sample = result_sample - self.result_log_prob = result_log_prob - self.result_loc = loc - self.result_scale = scale - self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0) - if loc is not None and scale is not None: - self.result_distribution = normal_lib.Normal(loc=self.result_loc, - scale=self.result_scale) - self.called_log_prob = Counter() - self.called_sample = Counter() - self.called_loc = Counter() - self.called_scale = Counter() - - def log_prob(self, *args, **kwargs): - self.called_log_prob() - return self.result_log_prob - - def sample(self, *args, **kwargs): - self.called_sample() - return self.result_sample - - @property - def distribution(self): # for dummy check on Independent(Normal) - return self.result_distribution - - @property - def loc(self): - self.called_loc() - return self.result_loc - - @property - def scale(self): - self.called_scale() - return self.result_scale - - -class MockKLDivergence(object): - """Monitors layer calls to the divergence implementation.""" - - def __init__(self, result): - self.result = result - self.args = [] - self.called = Counter() - - def __call__(self, *args, **kwargs): - self.called() - self.args.append(args) - return self.result - - -class DenseVariational(test.TestCase): - - def _testKLPenaltyKernel(self, layer_class): - with self.test_session(): - layer = layer_class(units=2) - inputs = random_ops.random_uniform([2, 3], seed=1) - - # No keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 0) - self.assertListEqual(layer.losses, losses) - - _ = layer(inputs) - - # Yes keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 1) - self.assertListEqual(layer.losses, losses) - - def _testKLPenaltyBoth(self, layer_class): - def _make_normal(dtype, *args): # pylint: disable=unused-argument - return normal_lib.Normal( - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)) - with self.test_session(): - layer = layer_class( - units=2, - bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(), - bias_prior_fn=_make_normal) - inputs = random_ops.random_uniform([2, 3], seed=1) - - # No keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 0) - self.assertListEqual(layer.losses, losses) - - _ = layer(inputs) - - # Yes keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 2) - self.assertListEqual(layer.losses, losses) - - def _testDenseSetUp(self, layer_class, batch_size, in_size, out_size, - **kwargs): - seed = Counter() - inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) - - kernel_size = [in_size, out_size] - kernel_posterior = MockDistribution( - loc=random_ops.random_uniform(kernel_size, seed=seed()), - scale=random_ops.random_uniform(kernel_size, seed=seed()), - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_divergence = MockKLDivergence( - result=random_ops.random_uniform(kernel_size, seed=seed())) - - bias_size = [out_size] - bias_posterior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_divergence = MockKLDivergence( - result=random_ops.random_uniform(bias_size, seed=seed())) - - layer = layer_class( - units=out_size, - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - kernel_prior_fn=lambda *args: kernel_prior, - kernel_divergence_fn=kernel_divergence, - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - bias_prior_fn=lambda *args: bias_prior, - bias_divergence_fn=bias_divergence, - **kwargs) - - outputs = layer(inputs) - - kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - return (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, - layer, inputs, outputs, kl_penalty) - - def testKLPenaltyKernelReparameterization(self): - self._testKLPenaltyKernel(prob_layers_lib.DenseReparameterization) - - def testKLPenaltyKernelLocalReparameterization(self): - self._testKLPenaltyKernel(prob_layers_lib.DenseLocalReparameterization) - - def testKLPenaltyKernelFlipout(self): - self._testKLPenaltyKernel(prob_layers_lib.DenseFlipout) - - def testKLPenaltyBothReparameterization(self): - self._testKLPenaltyBoth(prob_layers_lib.DenseReparameterization) - - def testKLPenaltyBothLocalReparameterization(self): - self._testKLPenaltyBoth(prob_layers_lib.DenseLocalReparameterization) - - def testKLPenaltyBothFlipout(self): - self._testKLPenaltyBoth(prob_layers_lib.DenseFlipout) - - def testDenseReparameterization(self): - batch_size, in_size, out_size = 2, 3, 4 - with self.test_session() as sess: - (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, layer, inputs, - outputs, kl_penalty) = self._testDenseSetUp( - prob_layers_lib.DenseReparameterization, - batch_size, in_size, out_size) - - expected_outputs = ( - math_ops.matmul(inputs, kernel_posterior.result_sample) + - bias_posterior.result_sample) - - [ - expected_outputs_, actual_outputs_, - expected_kernel_, actual_kernel_, - expected_kernel_divergence_, actual_kernel_divergence_, - expected_bias_, actual_bias_, - expected_bias_divergence_, actual_bias_divergence_, - ] = sess.run([ - expected_outputs, outputs, - kernel_posterior.result_sample, layer.kernel_posterior_tensor, - kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, layer.bias_posterior_tensor, - bias_divergence.result, kl_penalty[1], - ]) - - self.assertAllClose( - expected_kernel_, actual_kernel_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_, actual_bias_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_outputs_, actual_outputs_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_kernel_divergence_, actual_kernel_divergence_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_divergence_, actual_bias_divergence_, - rtol=1e-6, atol=0.) - - self.assertAllEqual( - [[kernel_posterior.distribution, - kernel_prior.distribution, - kernel_posterior.result_sample]], - kernel_divergence.args) - - self.assertAllEqual( - [[bias_posterior.distribution, - bias_prior.distribution, - bias_posterior.result_sample]], - bias_divergence.args) - - def testDenseLocalReparameterization(self): - batch_size, in_size, out_size = 2, 3, 4 - with self.test_session() as sess: - (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, layer, inputs, - outputs, kl_penalty) = self._testDenseSetUp( - prob_layers_lib.DenseLocalReparameterization, - batch_size, in_size, out_size) - - expected_kernel_posterior_affine = normal_lib.Normal( - loc=math_ops.matmul(inputs, kernel_posterior.result_loc), - scale=math_ops.matmul( - inputs**2., kernel_posterior.result_scale**2)**0.5) - expected_kernel_posterior_affine_tensor = ( - expected_kernel_posterior_affine.sample(seed=42)) - expected_outputs = (expected_kernel_posterior_affine_tensor + - bias_posterior.result_sample) - - [ - expected_outputs_, actual_outputs_, - expected_kernel_divergence_, actual_kernel_divergence_, - expected_bias_, actual_bias_, - expected_bias_divergence_, actual_bias_divergence_, - ] = sess.run([ - expected_outputs, outputs, - kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, layer.bias_posterior_tensor, - bias_divergence.result, kl_penalty[1], - ]) - - self.assertAllClose( - expected_bias_, actual_bias_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_outputs_, actual_outputs_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_kernel_divergence_, actual_kernel_divergence_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_divergence_, actual_bias_divergence_, - rtol=1e-6, atol=0.) - - self.assertAllEqual( - [[kernel_posterior.distribution, - kernel_prior.distribution, - None]], - kernel_divergence.args) - - self.assertAllEqual( - [[bias_posterior.distribution, - bias_prior.distribution, - bias_posterior.result_sample]], - bias_divergence.args) - - def testDenseFlipout(self): - batch_size, in_size, out_size = 2, 3, 4 - with self.test_session() as sess: - (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, layer, inputs, - outputs, kl_penalty) = self._testDenseSetUp( - prob_layers_lib.DenseFlipout, - batch_size, in_size, out_size, seed=44) - - expected_kernel_posterior_affine = normal_lib.Normal( - loc=array_ops.zeros_like(kernel_posterior.result_loc), - scale=kernel_posterior.result_scale) - expected_kernel_posterior_affine_tensor = ( - expected_kernel_posterior_affine.sample(seed=42)) - - sign_input = random_ops.random_uniform( - [batch_size, in_size], - minval=0, - maxval=2, - dtype=dtypes.int32, - seed=layer.seed) - sign_input = math_ops.cast(2 * sign_input - 1, inputs.dtype) - sign_output = random_ops.random_uniform( - [batch_size, out_size], - minval=0, - maxval=2, - dtype=dtypes.int32, - seed=distribution_util.gen_new_seed( - layer.seed, salt="dense_flipout")) - sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype) - perturbed_inputs = math_ops.matmul( - inputs * sign_input, expected_kernel_posterior_affine_tensor) - perturbed_inputs *= sign_output - - expected_outputs = math_ops.matmul(inputs, kernel_posterior.result_loc) - expected_outputs += perturbed_inputs - expected_outputs += bias_posterior.result_sample - - [ - expected_outputs_, actual_outputs_, - expected_kernel_divergence_, actual_kernel_divergence_, - expected_bias_, actual_bias_, - expected_bias_divergence_, actual_bias_divergence_, - ] = sess.run([ - expected_outputs, outputs, - kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, layer.bias_posterior_tensor, - bias_divergence.result, kl_penalty[1], - ]) - - self.assertAllClose( - expected_bias_, actual_bias_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_outputs_, actual_outputs_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_kernel_divergence_, actual_kernel_divergence_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_divergence_, actual_bias_divergence_, - rtol=1e-6, atol=0.) - - self.assertAllEqual( - [[kernel_posterior.distribution, kernel_prior.distribution, None]], - kernel_divergence.args) - - self.assertAllEqual( - [[bias_posterior.distribution, - bias_prior.distribution, - bias_posterior.result_sample]], - bias_divergence.args) - - def testRandomDenseFlipout(self): - batch_size, in_size, out_size = 2, 3, 4 - with self.test_session() as sess: - seed = Counter() - inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) - - kernel_posterior = MockDistribution( - loc=random_ops.random_uniform( - [in_size, out_size], seed=seed()), - scale=random_ops.random_uniform( - [in_size, out_size], seed=seed()), - result_log_prob=random_ops.random_uniform( - [in_size, out_size], seed=seed()), - result_sample=random_ops.random_uniform( - [in_size, out_size], seed=seed())) - bias_posterior = MockDistribution( - loc=random_ops.random_uniform( - [out_size], seed=seed()), - scale=random_ops.random_uniform( - [out_size], seed=seed()), - result_log_prob=random_ops.random_uniform( - [out_size], seed=seed()), - result_sample=random_ops.random_uniform( - [out_size], seed=seed())) - layer_one = prob_layers_lib.DenseFlipout( - units=out_size, - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - seed=44) - layer_two = prob_layers_lib.DenseFlipout( - units=out_size, - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - seed=45) - - outputs_one = layer_one(inputs) - outputs_two = layer_two(inputs) - - outputs_one_, outputs_two_ = sess.run([ - outputs_one, outputs_two]) - - self.assertLess(np.sum(np.isclose(outputs_one_, outputs_two_)), out_size) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py deleted file mode 100644 index 52e36e135d95c1ec919c710f35d59073c2134d05..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py +++ /dev/null @@ -1,445 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for MCMC diagnostic utilities.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics_impl as mcmc_diagnostics -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import spectral_ops_test_util -from tensorflow.python.platform import test - -rng = np.random.RandomState(42) - - -class _EffectiveSampleSizeTest(object): - - @property - def use_static_shape(self): - raise NotImplementedError( - "Subclass failed to implement `use_static_shape`.") - - def _check_versus_expected_effective_sample_size(self, - x_, - expected_ess, - sess, - atol=1e-2, - rtol=1e-2, - filter_threshold=None, - filter_beyond_lag=None): - x = array_ops.placeholder_with_default( - input=x_, shape=x_.shape if self.use_static_shape else None) - ess = mcmc_diagnostics.effective_sample_size( - x, - filter_threshold=filter_threshold, - filter_beyond_lag=filter_beyond_lag) - if self.use_static_shape: - self.assertAllEqual(x.shape[1:], ess.shape) - - ess_ = sess.run(ess) - - self.assertAllClose( - np.ones_like(ess_) * expected_ess, ess_, atol=atol, rtol=rtol) - - def testIidRank1NormalHasFullEssMaxLags10(self): - # With a length 5000 iid normal sequence, and filter_beyond_lag = 10, we - # should have a good estimate of ESS, and it should be close to the full - # sequence length of 5000. - # The choice of filter_beyond_lag = 10 is a short cutoff, reasonable only - # since we know the correlation length should be zero right away. - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=rng.randn(5000).astype(np.float32), - expected_ess=5000, - sess=sess, - filter_beyond_lag=10, - filter_threshold=None, - rtol=0.3) - - def testIidRank2NormalHasFullEssMaxLags10(self): - # See similar test for Rank1Normal for reasoning. - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=rng.randn(5000, 2).astype(np.float32), - expected_ess=5000, - sess=sess, - filter_beyond_lag=10, - filter_threshold=None, - rtol=0.3) - - def testIidRank1NormalHasFullEssMaxLagThresholdZero(self): - # With a length 5000 iid normal sequence, and filter_threshold = 0, - # we should have a super-duper estimate of ESS, and it should be very close - # to the full sequence length of 5000. - # The choice of filter_beyond_lag = 0 means we cutoff as soon as the - # auto-corris below zero. This should happen very quickly, due to the fact - # that the theoretical auto-corr is [1, 0, 0,...] - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=rng.randn(5000).astype(np.float32), - expected_ess=5000, - sess=sess, - filter_beyond_lag=None, - filter_threshold=0., - rtol=0.1) - - def testIidRank2NormalHasFullEssMaxLagThresholdZero(self): - # See similar test for Rank1Normal for reasoning. - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=rng.randn(5000, 2).astype(np.float32), - expected_ess=5000, - sess=sess, - filter_beyond_lag=None, - filter_threshold=0., - rtol=0.1) - - def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLags50(self): - # Create x_, such that - # x_[i] = iid_x_[0], i = 0,...,9 - # x_[i] = iid_x_[1], i = 10,..., 19, - # and so on. - iid_x_ = rng.randn(5000, 1).astype(np.float32) - x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=x_, - expected_ess=50000 // 10, - sess=sess, - filter_beyond_lag=50, - filter_threshold=None, - rtol=0.2) - - def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLagsThresholdZero( - self): - # Create x_, such that - # x_[i] = iid_x_[0], i = 0,...,9 - # x_[i] = iid_x_[1], i = 10,..., 19, - # and so on. - iid_x_ = rng.randn(5000, 1).astype(np.float32) - x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=x_, - expected_ess=50000 // 10, - sess=sess, - filter_beyond_lag=None, - filter_threshold=0., - rtol=0.1) - - def testListArgs(self): - # x_ has correlation length 10 ==> ESS = N / 10 - # y_ has correlation length 1 ==> ESS = N - iid_x_ = rng.randn(5000, 1).astype(np.float32) - x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) - y_ = rng.randn(50000).astype(np.float32) - states = [x_, x_, y_, y_] - filter_threshold = [0., None, 0., None] - filter_beyond_lag = [None, 5, None, 5] - - # See other tests for reasoning on tolerance. - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - ess = mcmc_diagnostics.effective_sample_size( - states, - filter_threshold=filter_threshold, - filter_beyond_lag=filter_beyond_lag) - ess_ = sess.run(ess) - self.assertAllEqual(4, len(ess_)) - - self.assertAllClose(50000 // 10, ess_[0], rtol=0.3) - self.assertAllClose(50000 // 10, ess_[1], rtol=0.3) - self.assertAllClose(50000, ess_[2], rtol=0.1) - self.assertAllClose(50000, ess_[3], rtol=0.1) - - def testMaxLagsThresholdLessThanNeg1SameAsNone(self): - # Setting both means we filter out items R_k from the auto-correlation - # sequence if k > filter_beyond_lag OR k >= j where R_j < filter_threshold. - - # x_ has correlation length 10. - iid_x_ = rng.randn(500, 1).astype(np.float32) - x_ = (iid_x_ * np.ones((500, 10)).astype(np.float32)).reshape((5000,)) - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - x = array_ops.placeholder_with_default( - input=x_, shape=x_.shape if self.use_static_shape else None) - - ess_none_none = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=None, filter_beyond_lag=None) - ess_none_200 = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=None, filter_beyond_lag=200) - ess_neg2_200 = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=-2., filter_beyond_lag=200) - ess_neg2_none = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=-2., filter_beyond_lag=None) - ess_none_none_, ess_none_200_, ess_neg2_200_, ess_neg2_none_ = sess.run( - [ess_none_none, ess_none_200, ess_neg2_200, ess_neg2_none]) - - # filter_threshold=-2 <==> filter_threshold=None. - self.assertAllClose(ess_none_none_, ess_neg2_none_) - self.assertAllClose(ess_none_200_, ess_neg2_200_) - - def testMaxLagsArgsAddInAnOrManner(self): - # Setting both means we filter out items R_k from the auto-correlation - # sequence if k > filter_beyond_lag OR k >= j where R_j < filter_threshold. - - # x_ has correlation length 10. - iid_x_ = rng.randn(500, 1).astype(np.float32) - x_ = (iid_x_ * np.ones((500, 10)).astype(np.float32)).reshape((5000,)) - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - x = array_ops.placeholder_with_default( - input=x_, shape=x_.shape if self.use_static_shape else None) - - ess_1_9 = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=1., filter_beyond_lag=9) - ess_1_none = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=1., filter_beyond_lag=None) - ess_none_9 = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=1., filter_beyond_lag=9) - ess_1_9_, ess_1_none_, ess_none_9_ = sess.run( - [ess_1_9, ess_1_none, ess_none_9]) - - # Since R_k = 1 for k < 10, and R_k < 1 for k >= 10, - # filter_threshold = 1 <==> filter_beyond_lag = 9. - self.assertAllClose(ess_1_9_, ess_1_none_) - self.assertAllClose(ess_1_9_, ess_none_9_) - - -class EffectiveSampleSizeStaticTest(test.TestCase, _EffectiveSampleSizeTest): - - @property - def use_static_shape(self): - return True - - -class EffectiveSampleSizeDynamicTest(test.TestCase, _EffectiveSampleSizeTest): - - @property - def use_static_shape(self): - return False - - -class _PotentialScaleReductionTest(object): - - @property - def use_static_shape(self): - raise NotImplementedError( - "Subclass failed to impliment `use_static_shape`.") - - def testListOfStatesWhereFirstPassesSecondFails(self): - """Simple test showing API with two states. Read first!.""" - n_samples = 1000 - - # state_0 is two scalar chains taken from iid Normal(0, 1). Will pass. - state_0 = rng.randn(n_samples, 2) - - # state_1 is three 4-variate chains taken from Normal(0, 1) that have been - # shifted. Since every chain is shifted, they are not the same, and the - # test should fail. - offset = np.array([1., -1., 2.]).reshape(3, 1) - state_1 = rng.randn(n_samples, 3, 4) + offset - - rhat = mcmc_diagnostics.potential_scale_reduction( - chains_states=[state_0, state_1], independent_chain_ndims=1) - - self.assertIsInstance(rhat, list) - with self.test_session() as sess: - rhat_0_, rhat_1_ = sess.run(rhat) - - # r_hat_0 should be close to 1, meaning test is passed. - self.assertAllEqual((), rhat_0_.shape) - self.assertAllClose(1., rhat_0_, rtol=0.02) - - # r_hat_1 should be greater than 1.2, meaning test has failed. - self.assertAllEqual((4,), rhat_1_.shape) - self.assertAllEqual(np.ones_like(rhat_1_).astype(bool), rhat_1_ > 1.2) - - def check_results(self, state_, independent_chain_shape, should_pass): - sample_ndims = 1 - independent_chain_ndims = len(independent_chain_shape) - with self.test_session(): - state = array_ops.placeholder_with_default( - input=state_, shape=state_.shape if self.use_static_shape else None) - - rhat = mcmc_diagnostics.potential_scale_reduction( - state, independent_chain_ndims=independent_chain_ndims) - - if self.use_static_shape: - self.assertAllEqual( - state_.shape[sample_ndims + independent_chain_ndims:], rhat.shape) - - rhat_ = rhat.eval() - if should_pass: - self.assertAllClose(np.ones_like(rhat_), rhat_, atol=0, rtol=0.02) - else: - self.assertAllEqual(np.ones_like(rhat_).astype(bool), rhat_ > 1.2) - - def iid_normal_chains_should_pass_wrapper(self, - sample_shape, - independent_chain_shape, - other_shape, - dtype=np.float32): - """Check results with iid normal chains.""" - - state_shape = sample_shape + independent_chain_shape + other_shape - state_ = rng.randn(*state_shape).astype(dtype) - - # The "other" dimensions do not have to be identical, just independent, so - # force them to not be identical. - if other_shape: - state_ *= rng.rand(*other_shape).astype(dtype) - - self.check_results(state_, independent_chain_shape, should_pass=True) - - def testPassingIIDNdimsAreIndependentOneOtherZero(self): - self.iid_normal_chains_should_pass_wrapper( - sample_shape=[10000], independent_chain_shape=[4], other_shape=[]) - - def testPassingIIDNdimsAreIndependentOneOtherOne(self): - self.iid_normal_chains_should_pass_wrapper( - sample_shape=[10000], independent_chain_shape=[3], other_shape=[7]) - - def testPassingIIDNdimsAreIndependentOneOtherTwo(self): - self.iid_normal_chains_should_pass_wrapper( - sample_shape=[10000], independent_chain_shape=[2], other_shape=[5, 7]) - - def testPassingIIDNdimsAreIndependentTwoOtherTwo64Bit(self): - self.iid_normal_chains_should_pass_wrapper( - sample_shape=[10000], - independent_chain_shape=[2, 3], - other_shape=[5, 7], - dtype=np.float64) - - def offset_normal_chains_should_fail_wrapper( - self, sample_shape, independent_chain_shape, other_shape): - """Check results with normal chains that are offset from each other.""" - - state_shape = sample_shape + independent_chain_shape + other_shape - state_ = rng.randn(*state_shape) - - # Add a significant offset to the different (formerly iid) chains. - offset = np.linspace( - 0, 2, num=np.prod(independent_chain_shape)).reshape([1] * len( - sample_shape) + independent_chain_shape + [1] * len(other_shape)) - state_ += offset - - self.check_results(state_, independent_chain_shape, should_pass=False) - - def testFailingOffsetNdimsAreSampleOneIndependentOneOtherOne(self): - self.offset_normal_chains_should_fail_wrapper( - sample_shape=[10000], independent_chain_shape=[2], other_shape=[5]) - - -class PotentialScaleReductionStaticTest(test.TestCase, - _PotentialScaleReductionTest): - - @property - def use_static_shape(self): - return True - - def testIndependentNdimsLessThanOneRaises(self): - with self.assertRaisesRegexp(ValueError, "independent_chain_ndims"): - mcmc_diagnostics.potential_scale_reduction( - rng.rand(2, 3, 4), independent_chain_ndims=0) - - -class PotentialScaleReductionDynamicTest(test.TestCase, - _PotentialScaleReductionTest): - - @property - def use_static_shape(self): - return False - - -class _ReduceVarianceTest(object): - - @property - def use_static_shape(self): - raise NotImplementedError( - "Subclass failed to impliment `use_static_shape`.") - - def check_versus_numpy(self, x_, axis, biased, keepdims): - with self.test_session(): - x_ = np.asarray(x_) - x = array_ops.placeholder_with_default( - input=x_, shape=x_.shape if self.use_static_shape else None) - var = mcmc_diagnostics._reduce_variance( - x, axis=axis, biased=biased, keepdims=keepdims) - np_var = np.var(x_, axis=axis, ddof=0 if biased else 1, keepdims=keepdims) - - if self.use_static_shape: - self.assertAllEqual(np_var.shape, var.shape) - - var_ = var.eval() - # We will mask below, which changes shape, so check shape explicitly here. - self.assertAllEqual(np_var.shape, var_.shape) - - # We get NaN when we divide by zero due to the size being the same as ddof - nan_mask = np.isnan(np_var) - if nan_mask.any(): - self.assertTrue(np.isnan(var_[nan_mask]).all()) - self.assertAllClose(np_var[~nan_mask], var_[~nan_mask], atol=0, rtol=0.02) - - def testScalarBiasedTrue(self): - self.check_versus_numpy(x_=-1.234, axis=None, biased=True, keepdims=False) - - def testScalarBiasedFalse(self): - # This should result in NaN. - self.check_versus_numpy(x_=-1.234, axis=None, biased=False, keepdims=False) - - def testShape2x3x4AxisNoneBiasedFalseKeepdimsFalse(self): - self.check_versus_numpy( - x_=rng.randn(2, 3, 4), axis=None, biased=True, keepdims=False) - - def testShape2x3x4Axis1BiasedFalseKeepdimsTrue(self): - self.check_versus_numpy( - x_=rng.randn(2, 3, 4), axis=1, biased=True, keepdims=True) - - def testShape2x3x4x5Axis13BiasedFalseKeepdimsTrue(self): - self.check_versus_numpy( - x_=rng.randn(2, 3, 4, 5), axis=1, biased=True, keepdims=True) - - def testShape2x3x4x5Axis13BiasedFalseKeepdimsFalse(self): - self.check_versus_numpy( - x_=rng.randn(2, 3, 4, 5), axis=1, biased=False, keepdims=False) - - -class ReduceVarianceTestStaticShape(test.TestCase, _ReduceVarianceTest): - - @property - def use_static_shape(self): - return True - - -class ReduceVarianceTestDynamicShape(test.TestCase, _ReduceVarianceTest): - - @property - def use_static_shape(self): - return False - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py deleted file mode 100644 index 63d93fad64d077aa385b72428665e841b6784b90..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for metropolis_hastings.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings_impl as mh -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -class McmcStepTest(test.TestCase): - - def test_density_increasing_step_accepted(self): - """Tests that if a transition increases density, it is always accepted.""" - target_log_density = lambda x: - x * x - state = variable_scope.get_variable('state', initializer=10.) - state_log_density = variable_scope.get_variable( - 'state_log_density', - initializer=target_log_density(state.initialized_value())) - log_accept_ratio = variable_scope.get_variable( - 'log_accept_ratio', initializer=0.) - - get_next_proposal = lambda x: (x - 1., None) - step = mh.evolve(state, state_log_density, log_accept_ratio, - target_log_density, get_next_proposal, seed=1234) - init = variables.initialize_all_variables() - with self.test_session() as sess: - sess.run(init) - for j in range(9): - sess.run(step) - sample = sess.run(state) - sample_log_density = sess.run(state_log_density) - self.assertAlmostEqual(sample, 9 - j) - self.assertAlmostEqual(sample_log_density, - (9 - j) * (9 - j)) - - def test_sample_properties(self): - """Tests that the samples converge to the target distribution.""" - - def target_log_density(x): - """Log-density corresponding to a normal distribution with mean = 4.""" - return - (x - 2.0) * (x - 2.0) * 0.5 - - # Use the uniform random walker to generate proposals. - proposal_fn = mh.uniform_random_proposal( - step_size=1.0, seed=1234) - - state = variable_scope.get_variable('state', initializer=0.0) - state_log_density = variable_scope.get_variable( - 'state_log_density', - initializer=target_log_density(state.initialized_value())) - - log_accept_ratio = variable_scope.get_variable( - 'log_accept_ratio', initializer=0.) - # Random walk MCMC converges slowly so need to put in enough iterations. - num_iterations = 5000 - step = mh.evolve(state, state_log_density, log_accept_ratio, - target_log_density, proposal_fn, seed=4321) - - init = variables.global_variables_initializer() - - sample_sum, sample_sq_sum = 0.0, 0.0 - with self.test_session() as sess: - sess.run(init) - for _ in np.arange(num_iterations): - # Allow for the mixing of the chain and discard these samples. - sess.run(step) - for _ in np.arange(num_iterations): - sess.run(step) - sample = sess.run(state) - sample_sum += sample - sample_sq_sum += sample * sample - - sample_mean = sample_sum / num_iterations - sample_variance = sample_sq_sum / num_iterations - sample_mean * sample_mean - # The samples have large autocorrelation which reduces the effective sample - # size. - self.assertAlmostEqual(sample_mean, 2.0, delta=0.1) - self.assertAlmostEqual(sample_variance, 1.0, delta=0.1) - - def test_normal_proposals(self): - """Tests that the normal proposals are correctly distributed.""" - - initial_points = array_ops.ones([10000], dtype=dtypes.float32) - proposal_fn = mh.normal_random_proposal( - scale=2.0, seed=1234) - proposal_points, _ = proposal_fn(initial_points) - - with self.test_session() as sess: - sample = sess.run(proposal_points) - - # It is expected that the elements in proposal_points have the same mean as - # initial_points and have the standard deviation that was supplied to the - # proposal scheme. - self.assertAlmostEqual(np.mean(sample), 1.0, delta=0.1) - self.assertAlmostEqual(np.std(sample), 2.0, delta=0.1) - - def test_docstring_example(self): - """Tests the simplified docstring example with multiple chains.""" - - n = 2 # dimension of the problem - - # Generate 300 initial values randomly. Each of these would be an - # independent starting point for a Markov chain. - state = variable_scope.get_variable( - 'state', initializer=random_ops.random_normal( - [300, n], mean=3.0, dtype=dtypes.float32, seed=42)) - - # Computes the log(p(x)) for the unit normal density and ignores the - # normalization constant. - def log_density(x): - return - math_ops.reduce_sum(x * x, reduction_indices=-1) / 2.0 - - # Initial log-density value - state_log_density = variable_scope.get_variable( - 'state_log_density', - initializer=log_density(state.initialized_value())) - - # A variable to store the log_acceptance_ratio: - log_acceptance_ratio = variable_scope.get_variable( - 'log_acceptance_ratio', - initializer=array_ops.zeros([300], dtype=dtypes.float32)) - - # Generates random proposals by moving each coordinate uniformly and - # independently in a box of size 2 centered around the current value. - # Returns the new point and also the log of the Hastings ratio (the - # ratio of the probability of going from the proposal to origin and the - # probability of the reverse transition). When this ratio is 1, the value - # may be omitted and replaced by None. - def random_proposal(x): - return (x + random_ops.random_uniform( - array_ops.shape(x), minval=-1, maxval=1, - dtype=x.dtype, seed=12)), None - - # Create the op to propagate the chain for 100 steps. - stepper = mh.evolve( - state, state_log_density, log_acceptance_ratio, - log_density, random_proposal, n_steps=100, seed=123) - init = variables.initialize_all_variables() - with self.test_session() as sess: - sess.run(init) - # Run the chains for a total of 1000 steps. - for _ in range(10): - sess.run(stepper) - samples = sess.run(state) - covariance = np.eye(n) - # Verify that the estimated mean and covariance are close to the true - # values. - self.assertAlmostEqual( - np.max(np.abs(np.mean(samples, 0) - - np.zeros(n))), 0, - delta=0.1) - self.assertAlmostEqual( - np.max(np.abs(np.reshape(np.cov(samples, rowvar=False), [n**2]) - - np.reshape(covariance, [n**2]))), 0, - delta=0.2) - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py deleted file mode 100644 index 756c25683bd4b0c8c77e9e28485ca2a85582999c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functional test for GradientDescent.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import math -from tensorflow.contrib.bayesflow.python.ops.optimizers import SGLDOptimizer -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -class SGLDOptimizerTest(test.TestCase): - - def testBasic(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.53 - sgd_optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate) - sgd_op = sgd_optimizer.apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + - (1 - decay_rate) * 0.1**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) - grads_scaled = (0.5 * 0.01 / math.sqrt( - decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) - self.assertAllCloseAccordingToType(1, sgd_optimizer._counter.eval()) - - def testBasicMultiInstance(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - vara = variables.Variable([1.1, 2.1], dtype=dtype) - varb = variables.Variable([3.0, 4.0], dtype=dtype) - gradsa = constant_op.constant([0.1, 0.1], dtype=dtype) - gradsb = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.5 - sgd_optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate) - sgd_op = sgd_optimizer.apply_gradients( - zip([grads0, grads1], [var0, var1])) - sgd_optimizer2 = SGLDOptimizer( - 3.0, preconditioner_decay_rate=decay_rate) - sgd_op2 = sgd_optimizer2.apply_gradients( - zip([gradsa, gradsb], [vara, varb])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - self.assertAllCloseAccordingToType([1.1, 2.1], vara.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], varb.eval()) - - # Run 1 step of sgd - sgd_op.run() - sgd_op2.run() - # Validate updated params - grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + - (1 - decay_rate) * 0.1**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) - self.assertAllCloseAccordingToType( - [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], vara.eval()) - - grads_scaled = (0.5 * 0.01 / math.sqrt( - decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) - self.assertAllCloseAccordingToType( - [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], varb.eval()) - self.assertNotEqual(sgd_optimizer.variable_scope, - sgd_optimizer2.variable_scope) - self.assertNotEqual(sgd_optimizer.variable_scope.name, - sgd_optimizer2.variable_scope.name) - self.assertAllCloseAccordingToType(1, sgd_optimizer._counter.eval()) - self.assertAllCloseAccordingToType(1, sgd_optimizer2._counter.eval()) - - def testTensorLearningRate(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - lrate = constant_op.constant(3.0) - decay_rate = 0.5 - sgd_op = SGLDOptimizer( - lrate, preconditioner_decay_rate=constant_op.constant( - decay_rate)).apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + - (1 - decay_rate) * 0.1**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) - grads_scaled = (0.5 * 0.01 / math.sqrt( - decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) - - def testGradWrtRef(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - opt = SGLDOptimizer(3.0) - values = [1.0, 3.0] - vars_ = [variables.Variable([v], dtype=dtype) for v in values] - grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_) - variables.global_variables_initializer().run() - for grad, _ in grads_and_vars: - self.assertAllCloseAccordingToType([1.0], grad.eval()) - - def testWithGlobalStep(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - global_step = variables.Variable(0, trainable=False) - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.1 - sgd_op = SGLDOptimizer( - 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( - zip([grads0, grads1], [var0, var1]), global_step=global_step) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - - # Validate updated params and global_step - grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + - (1 - decay_rate) * 0.1**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) - grads_scaled = (0.5 * 0.01 / math.sqrt( - decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) - self.assertAllCloseAccordingToType(1, global_step.eval()) - - def testSparseBasic(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([[1.1], [2.1]], dtype=dtype) - var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) - grads0 = ops.IndexedSlices( - constant_op.constant([0.1], shape=[1, 1], dtype=dtype), - constant_op.constant([0]), constant_op.constant([2, 1])) - grads1 = ops.IndexedSlices( - constant_op.constant([0.01], shape=[1, 1], dtype=dtype), - constant_op.constant([1]), constant_op.constant([2, 1])) - decay_rate = 0.9 - sgd_op = SGLDOptimizer( - 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([[1.1], [2.1]], var0.eval()) - self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + - (1 - decay_rate) * 0.1**2 + 1e-8)) - self.assertAllCloseAccordingToType([[1.1 - 3.0 * grads_scaled], [2.1]], - var0.eval()) - grads_scaled = (0.5 * 0.01 / math.sqrt( - decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [[3.0 - 3.0 * 0], [4.0 - 3.0 * grads_scaled]], var1.eval()) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py deleted file mode 100644 index f978cf86417dc5ff5412a3eee584330a266e0964..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for utility functions related to managing `tf.Variable`s.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import warnings - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import variable_utils - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.ops import variable_scope as varscope_ops -from tensorflow.python.ops import variables as variables_ops -from tensorflow.python.platform import test - - -def test_fn(x): - x = ops.convert_to_tensor(x, name="x") - dtype = x.dtype.as_numpy_dtype - s = x.shape.as_list() - z = varscope_ops.get_variable( - name="z", - dtype=dtype, - initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)) - y = varscope_ops.get_variable( - name="y", - dtype=dtype, - initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)**2) - return x + y + z - - -class _WrapCallableTest(object): - - def testDefaultArgsWorkCorrectly(self): - with self.test_session(): - x = constant_op.constant(self.dtype([0.1, 0.2])) - wrapped_fn, vars_args = variable_utils.externalize_variables_as_args( - test_fn, [x]) - - varscope_ops.get_variable_scope().reuse_variables() - - result = wrapped_fn(self.dtype(2), [3, 4, 5], 0.5) - - y_actual = varscope_ops.get_variable("y", dtype=self.dtype) - z_actual = varscope_ops.get_variable("z", dtype=self.dtype) - - variables_ops.global_variables_initializer().run() - result_ = result.eval() - - self.assertEqual(self.dtype, result_.dtype) - self.assertAllEqual([5.5, 6.5, 7.5], result_) - self.assertAllEqual([y_actual, z_actual], vars_args) - - def testNonDefaultArgsWorkCorrectly(self): - with self.test_session(): - x = constant_op.constant(self.dtype([0.1, 0.2])) - - _ = test_fn(self.dtype([0., 0.])) # Needed to create vars. - varscope_ops.get_variable_scope().reuse_variables() - - y_actual = varscope_ops.get_variable("y", dtype=self.dtype) - - wrapped_fn, vars_args = variable_utils.externalize_variables_as_args( - test_fn, [x], possible_ancestor_vars=[y_actual]) - - result = wrapped_fn(self.dtype([2, 3]), 0.5) # x, y - - variables_ops.global_variables_initializer().run() - result_ = result.eval() - - self.assertEqual(self.dtype, result_.dtype) - self.assertAllEqual([2.5, 4.5], result_) - self.assertAllEqual([y_actual], vars_args) - - def testWarnings(self): - with self.test_session(): - x = constant_op.constant(self.dtype([0.1, 0.2])) - wrapped_fn, _ = variable_utils.externalize_variables_as_args( - test_fn, [x], possible_ancestor_vars=[]) - varscope_ops.get_variable_scope().reuse_variables() - with warnings.catch_warnings(record=True) as w: - wrapped_fn(self.dtype(2)) - w = sorted(w, key=lambda w: str(w.message)) - self.assertEqual(2, len(w)) - self.assertRegexpMatches( - str(w[0].message), - r"Variable .* 'y:0' .* not found in bypass dict.") - self.assertRegexpMatches( - str(w[1].message), - r"Variable .* 'z:0' .* not found in bypass dict.") - - def testExceptions(self): - with self.test_session(): - x = constant_op.constant(self.dtype([0.1, 0.2])) - wrapped_fn, _ = variable_utils.externalize_variables_as_args( - test_fn, - [x], - possible_ancestor_vars=[], - assert_variable_override=True) - varscope_ops.get_variable_scope().reuse_variables() - with self.assertRaisesRegexp(ValueError, r"not found"): - wrapped_fn(self.dtype(2)) - - -class WrapCallableTest16(test.TestCase, _WrapCallableTest): - dtype = np.float16 - - -class WrapCallableTest32(test.TestCase, _WrapCallableTest): - dtype = np.float32 - - -class WrapCallableTest64(test.TestCase, _WrapCallableTest): - dtype = np.float64 - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_sgd_optimizer_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_sgd_optimizer_test.py deleted file mode 100644 index 83c64dbe0fd586edcb784a5c09a4c133aaa99cff..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_sgd_optimizer_test.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functional test for GradientDescent.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from tensorflow.contrib.bayesflow.python.ops.optimizers import VariationalSGDOptimizer -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -class VariationalSGDOptimizerTest(test.TestCase): - - def testBasic(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.53 - sgd_op = VariationalSGDOptimizer( - 1, - 1, - preconditioner_decay_rate=decay_rate, - max_learning_rate=3.0, - burnin_max_learning_rate=3.0, - use_single_learning_rate=True).apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], - var1.eval()) - - def testBasicMultiInstance(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - vara = variables.Variable([1.1, 2.1], dtype=dtype) - varb = variables.Variable([3.0, 4.0], dtype=dtype) - gradsa = constant_op.constant([0.1, 0.1], dtype=dtype) - gradsb = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.5 - batch_size = 2 - total_num_examples = 10 - optimizer = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=1.0, - burnin_max_learning_rate=3.0, - preconditioner_decay_rate=decay_rate) - sgd_op = optimizer.apply_gradients( - zip([grads0, grads1], [var0, var1])) - optimizer2 = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=1.0, - burnin_max_learning_rate=10.0, - burnin=0, - preconditioner_decay_rate=decay_rate) - sgd_op2 = optimizer2.apply_gradients( - zip([gradsa, gradsb], [vara, varb])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - self.assertAllCloseAccordingToType([1.1, 2.1], vara.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], varb.eval()) - - # Run 1 step of sgd - sgd_op.run() - sgd_op2.run() - # Validate updated params - self.assertAllCloseAccordingToType([1.1 - 3. * 0.1, 2.1 - 3. * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([1.1 - 0.1, 2.1 - 0.1], vara.eval()) - - self.assertAllCloseAccordingToType([3.0 - 3. * 0.01, 4.0 - 3. * 0.01], - var1.eval()) - self.assertAllCloseAccordingToType([3.0 - 0.01, 4.0 - 0.01], - varb.eval()) - self.assertNotEqual(optimizer.variable_scope, - optimizer2.variable_scope) - self.assertNotEqual(optimizer.variable_scope.name, - optimizer2.variable_scope.name) - self.assertAllCloseAccordingToType(1, optimizer._counter.eval()) - self.assertAllCloseAccordingToType(1, optimizer2._counter.eval()) - - def testTensorLearningRate(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - lrate = constant_op.constant(3.0) - decay_rate = 0.5 - batch_size = 2 - total_num_examples = 10 - sgd_op = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=lrate, - burnin=0, - preconditioner_decay_rate=decay_rate).apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], - var1.eval()) - - def testTensorDecayLearningRate(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - lrate = variables.Variable(3.0) - lrate_decay_op = lrate.assign_add(-3.) - decay_rate = 0.5 - batch_size = 2 - total_num_examples = 10 - optimizer = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=lrate, - burnin=0, - preconditioner_decay_rate=decay_rate) - sgd_op = optimizer.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], - var1.eval()) - # Update learning rate to 0 - lrate_decay_op.eval() - sgd_op.run() - # Validate params haven't changed - self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], - var1.eval()) - lrate_decay_op.eval() - - with self.assertRaises(errors.InvalidArgumentError): - sgd_op.run() - - def testGradWrtRef(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - opt = VariationalSGDOptimizer(1, 1, max_learning_rate=1.0) - values = [1.0, 3.0] - vars_ = [variables.Variable([v], dtype=dtype) for v in values] - grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_) - variables.global_variables_initializer().run() - for grad, _ in grads_and_vars: - self.assertAllCloseAccordingToType([1.0], grad.eval()) - - def testWithGlobalStep(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - global_step = variables.Variable(0, trainable=False) - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.1 - batch_size = 2 - total_num_examples = 10 - sgd_optimizer = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=3.0, - burnin=0, - preconditioner_decay_rate=decay_rate) - sgd_op = sgd_optimizer.apply_gradients( - zip([grads0, grads1], [var0, var1]), global_step=global_step) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - - # Validate updated params and global_step - self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], - var1.eval()) - self.assertAllCloseAccordingToType(1, global_step.eval()) - self.assertAllCloseAccordingToType(1, sgd_optimizer._counter.eval()) - - def testSparseBasic(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([[1.1], [2.1]], dtype=dtype) - var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) - grads0 = ops.IndexedSlices( - constant_op.constant([0.1], shape=[1, 1], dtype=dtype), - constant_op.constant([0]), constant_op.constant([2, 1])) - grads1 = ops.IndexedSlices( - constant_op.constant([0.01], shape=[1, 1], dtype=dtype), - constant_op.constant([1]), constant_op.constant([2, 1])) - decay_rate = 0.1 - batch_size = 2 - total_num_examples = 10 - sgd_op = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=3.0, - burnin=0, - preconditioner_decay_rate=decay_rate).apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([[1.1], [2.1]], var0.eval()) - self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType([[1.1 - 3.0 * 0.1], [2.1]], - var0.eval()) - self.assertAllCloseAccordingToType( - [[3.0 - 3.0 * 0], [4.0 - 3.0 * 0.01]], var1.eval()) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py deleted file mode 100644 index 8efd59d6516924bea538717d45bb4ae303583421..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py +++ /dev/null @@ -1,1105 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Csiszar f-Divergence and helpers. - -@@amari_alpha -@@arithmetic_geometric -@@chi_square -@@csiszar_vimco -@@dual_csiszar_function -@@jeffreys -@@jensen_shannon -@@kl_forward -@@kl_reverse -@@log1p_abs -@@modified_gan -@@monte_carlo_csiszar_f_divergence -@@pearson -@@squared_hellinger -@@symmetrized_csiszar_function -@@total_variation -@@triangular - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib import framework as contrib_framework -from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util - - -def amari_alpha(logu, alpha=1., self_normalized=False, name=None): - """The Amari-alpha Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True`, the Amari-alpha Csiszar-function is: - - ```none - f(u) = { -log(u) + (u - 1), alpha = 0 - { u log(u) - (u - 1), alpha = 1 - { [(u**alpha - 1) - alpha (u - 1)] / (alpha (alpha - 1)), otherwise - ``` - - When `self_normalized = False` the `(u - 1)` terms are omitted. - - Warning: when `alpha != 0` and/or `self_normalized = True` this function makes - non-log-space calculations and may therefore be numerically unstable for - `|logu| >> 0`. - - For more information, see: - A. Cichocki and S. Amari. "Families of Alpha-Beta-and GammaDivergences: - Flexible and Robust Measures of Similarities." Entropy, vol. 12, no. 6, pp. - 1532-1568, 2010. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - alpha: `float`-like Python scalar. (See Mathematical Details for meaning.) - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - amari_alpha_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - - Raises: - TypeError: if `alpha` is `None` or a `Tensor`. - TypeError: if `self_normalized` is `None` or a `Tensor`. - """ - with ops.name_scope(name, "amari_alpha", [logu]): - if alpha is None or contrib_framework.is_tensor(alpha): - raise TypeError("`alpha` cannot be `None` or `Tensor` type.") - if self_normalized is None or contrib_framework.is_tensor(self_normalized): - raise TypeError("`self_normalized` cannot be `None` or `Tensor` type.") - - logu = ops.convert_to_tensor(logu, name="logu") - - if alpha == 0.: - f = -logu - elif alpha == 1.: - f = math_ops.exp(logu) * logu - else: - f = math_ops.expm1(alpha * logu) / (alpha * (alpha - 1.)) - - if not self_normalized: - return f - - if alpha == 0.: - return f + math_ops.expm1(logu) - elif alpha == 1.: - return f - math_ops.expm1(logu) - else: - return f - math_ops.expm1(logu) / (alpha - 1.) - - -def kl_reverse(logu, self_normalized=False, name=None): - """The reverse Kullback-Leibler Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True`, the KL-reverse Csiszar-function is: - - ```none - f(u) = -log(u) + (u - 1) - ``` - - When `self_normalized = False` the `(u - 1)` term is omitted. - - Observe that as an f-Divergence, this Csiszar-function implies: - - ```none - D_f[p, q] = KL[q, p] - ``` - - The KL is "reverse" because in maximum likelihood we think of minimizing `q` - as in `KL[p, q]`. - - Warning: when self_normalized = True` this function makes non-log-space - calculations and may therefore be numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - kl_reverse_of_u: `float`-like `Tensor` of the Csiszar-function evaluated at - `u = exp(logu)`. - - Raises: - TypeError: if `self_normalized` is `None` or a `Tensor`. - """ - - with ops.name_scope(name, "kl_reverse", [logu]): - return amari_alpha(logu, alpha=0., self_normalized=self_normalized) - - -def kl_forward(logu, self_normalized=False, name=None): - """The forward Kullback-Leibler Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True`, the KL-forward Csiszar-function is: - - ```none - f(u) = u log(u) - (u - 1) - ``` - - When `self_normalized = False` the `(u - 1)` term is omitted. - - Observe that as an f-Divergence, this Csiszar-function implies: - - ```none - D_f[p, q] = KL[p, q] - ``` - - The KL is "forward" because in maximum likelihood we think of minimizing `q` - as in `KL[p, q]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - kl_forward_of_u: `float`-like `Tensor` of the Csiszar-function evaluated at - `u = exp(logu)`. - - Raises: - TypeError: if `self_normalized` is `None` or a `Tensor`. - """ - - with ops.name_scope(name, "kl_forward", [logu]): - return amari_alpha(logu, alpha=1., self_normalized=self_normalized) - - -def jensen_shannon(logu, self_normalized=False, name=None): - """The Jensen-Shannon Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True`, the Jensen-Shannon Csiszar-function is: - - ```none - f(u) = u log(u) - (1 + u) log(1 + u) + (u + 1) log(2) - ``` - - When `self_normalized = False` the `(u + 1) log(2)` term is omitted. - - Observe that as an f-Divergence, this Csiszar-function implies: - - ```none - D_f[p, q] = KL[p, m] + KL[q, m] - m(x) = 0.5 p(x) + 0.5 q(x) - ``` - - In a sense, this divergence is the "reverse" of the Arithmetic-Geometric - f-Divergence. - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - For more information, see: - Lin, J. "Divergence measures based on the Shannon entropy." IEEE Trans. - Inf. Th., 37, 145-151, 1991. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - jensen_shannon_of_u: `float`-like `Tensor` of the Csiszar-function - evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "jensen_shannon", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - npdt = logu.dtype.as_numpy_dtype - y = nn_ops.softplus(logu) - if self_normalized: - y -= np.log(2).astype(npdt) - return math_ops.exp(logu) * logu - (1. + math_ops.exp(logu)) * y - - -def arithmetic_geometric(logu, self_normalized=False, name=None): - """The Arithmetic-Geometric Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True` the Arithmetic-Geometric Csiszar-function is: - - ```none - f(u) = (1 + u) log( (1 + u) / sqrt(u) ) - (1 + u) log(2) - ``` - - When `self_normalized = False` the `(1 + u) log(2)` term is omitted. - - Observe that as an f-Divergence, this Csiszar-function implies: - - ```none - D_f[p, q] = KL[m, p] + KL[m, q] - m(x) = 0.5 p(x) + 0.5 q(x) - ``` - - In a sense, this divergence is the "reverse" of the Jensen-Shannon - f-Divergence. - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - arithmetic_geometric_of_u: `float`-like `Tensor` of the - Csiszar-function evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "arithmetic_geometric", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - y = nn_ops.softplus(logu) - 0.5 * logu - if self_normalized: - y -= np.log(2.).astype(logu.dtype.as_numpy_dtype) - return (1. + math_ops.exp(logu)) * y - - -def total_variation(logu, name=None): - """The Total Variation Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Total-Variation Csiszar-function is: - - ```none - f(u) = 0.5 |u - 1| - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - total_variation_of_u: `float`-like `Tensor` of the Csiszar-function - evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "total_variation", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return 0.5 * math_ops.abs(math_ops.expm1(logu)) - - -def pearson(logu, name=None): - """The Pearson Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Pearson Csiszar-function is: - - ```none - f(u) = (u - 1)**2 - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - pearson_of_u: `float`-like `Tensor` of the Csiszar-function evaluated at - `u = exp(logu)`. - """ - - with ops.name_scope(name, "pearson", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return math_ops.square(math_ops.expm1(logu)) - - -def squared_hellinger(logu, name=None): - """The Squared-Hellinger Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Squared-Hellinger Csiszar-function is: - - ```none - f(u) = (sqrt(u) - 1)**2 - ``` - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - squared_hellinger_of_u: `float`-like `Tensor` of the Csiszar-function - evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "squared_hellinger", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return pearson(0.5 * logu) - - -def triangular(logu, name=None): - """The Triangular Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Triangular Csiszar-function is: - - ```none - f(u) = (u - 1)**2 / (1 + u) - ``` - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - triangular_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "triangular", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return pearson(logu) / (1. + math_ops.exp(logu)) - - -def t_power(logu, t, self_normalized=False, name=None): - """The T-Power Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True` the T-Power Csiszar-function is: - - ```none - f(u) = s [ u**t - 1 - t(u - 1) ] - s = { -1 0 < t < 1 - { +1 otherwise - ``` - - When `self_normalized = False` the `- t(u - 1)` term is omitted. - - This is similar to the `amari_alpha` Csiszar-function, with the associated - divergence being the same up to factors depending only on `t`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - t: `Tensor` of same `dtype` as `logu` and broadcastable shape. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - t_power_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - with ops.name_scope(name, "t_power", [logu, t]): - logu = ops.convert_to_tensor(logu, name="logu") - t = ops.convert_to_tensor(t, dtype=logu.dtype.base_dtype, name="t") - fu = math_ops.expm1(t * logu) - if self_normalized: - fu -= t * math_ops.expm1(logu) - fu *= array_ops.where(math_ops.logical_and(0. < t, t < 1.), - -array_ops.ones_like(t), - array_ops.ones_like(t)) - return fu - - -def log1p_abs(logu, name=None): - """The log1p-abs Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Log1p-Abs Csiszar-function is: - - ```none - f(u) = u**(sign(u-1)) - 1 - ``` - - This function is so-named because it was invented from the following recipe. - Choose a convex function g such that g(0)=0 and solve for f: - - ```none - log(1 + f(u)) = g(log(u)). - <=> - f(u) = exp(g(log(u))) - 1 - ``` - - That is, the graph is identically `g` when y-axis is `log1p`-domain and x-axis - is `log`-domain. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - log1p_abs_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "log1p_abs", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return math_ops.expm1(math_ops.abs(logu)) - - -def jeffreys(logu, name=None): - """The Jeffreys Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Jeffreys Csiszar-function is: - - ```none - f(u) = 0.5 ( u log(u) - log(u) ) - = 0.5 kl_forward + 0.5 kl_reverse - = symmetrized_csiszar_function(kl_reverse) - = symmetrized_csiszar_function(kl_forward) - ``` - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - jeffreys_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "jeffreys", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return 0.5 * math_ops.expm1(logu) * logu - - -def chi_square(logu, name=None): - """The chi-Square Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Chi-square Csiszar-function is: - - ```none - f(u) = u**2 - 1 - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - chi_square_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "chi_square", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return math_ops.expm1(2. * logu) - - -def modified_gan(logu, self_normalized=False, name=None): - """The Modified-GAN Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True` the modified-GAN (Generative/Adversarial - Network) Csiszar-function is: - - ```none - f(u) = log(1 + u) - log(u) + 0.5 (u - 1) - ``` - - When `self_normalized = False` the `0.5 (u - 1)` is omitted. - - The unmodified GAN Csiszar-function is identical to Jensen-Shannon (with - `self_normalized = False`). - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - chi_square_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "chi_square", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - y = nn_ops.softplus(logu) - logu - if self_normalized: - y += 0.5 * math_ops.expm1(logu) - return y - - -def dual_csiszar_function(logu, csiszar_function, name=None): - """Calculates the dual Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Csiszar-dual is defined as: - - ```none - f^*(u) = u f(1 / u) - ``` - - where `f` is some other Csiszar-function. - - For example, the dual of `kl_reverse` is `kl_forward`, i.e., - - ```none - f(u) = -log(u) - f^*(u) = u f(1 / u) = -u log(1 / u) = u log(u) - ``` - - The dual of the dual is the original function: - - ```none - f^**(u) = {u f(1/u)}^*(u) = u (1/u) f(1/(1/u)) = f(u) - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - csiszar_function: Python `callable` representing a Csiszar-function over - log-domain. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - dual_f_of_u: `float`-like `Tensor` of the result of calculating the dual of - `f` at `u = exp(logu)`. - """ - - with ops.name_scope(name, "dual_csiszar_function", [logu]): - return math_ops.exp(logu) * csiszar_function(-logu) - - -def symmetrized_csiszar_function(logu, csiszar_function, name=None): - """Symmetrizes a Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The symmetrized Csiszar-function is defined as: - - ```none - f_g(u) = 0.5 g(u) + 0.5 u g (1 / u) - ``` - - where `g` is some other Csiszar-function. - - We say the function is "symmetrized" because: - - ```none - D_{f_g}[p, q] = D_{f_g}[q, p] - ``` - - for all `p << >> q` (i.e., `support(p) = support(q)`). - - There exists alternatives for symmetrizing a Csiszar-function. For example, - - ```none - f_g(u) = max(f(u), f^*(u)), - ``` - - where `f^*` is the dual Csiszar-function, also implies a symmetric - f-Divergence. - - Example: - - When either of the following functions are symmetrized, we obtain the - Jensen-Shannon Csiszar-function, i.e., - - ```none - g(u) = -log(u) - (1 + u) log((1 + u) / 2) + u - 1 - h(u) = log(4) + 2 u log(u / (1 + u)) - ``` - - implies, - - ```none - f_g(u) = f_h(u) = u log(u) - (1 + u) log((1 + u) / 2) - = jensen_shannon(log(u)). - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - csiszar_function: Python `callable` representing a Csiszar-function over - log-domain. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - symmetrized_g_of_u: `float`-like `Tensor` of the result of applying the - symmetrization of `g` evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "symmetrized_csiszar_function", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return 0.5 * (csiszar_function(logu) - + dual_csiszar_function(logu, csiszar_function)) - - -def monte_carlo_csiszar_f_divergence( - f, - p_log_prob, - q, - num_draws, - use_reparametrization=None, - seed=None, - name=None): - """Monte-Carlo approximation of the Csiszar f-Divergence. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Csiszar f-Divergence for Csiszar-function f is given by: - - ```none - D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ] - ~= m**-1 sum_j^m f( p(x_j) / q(x_j) ), - where x_j ~iid q(X) - ``` - - Tricks: Reparameterization and Score-Gradient - - When q is "reparameterized", i.e., a diffeomorphic transformation of a - parameterless distribution (e.g., - `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and - expectation, i.e., - `grad[Avg{ s_i : i=1...n }] = Avg{ grad[s_i] : i=1...n }` where `S_n=Avg{s_i}` - and `s_i = f(x_i), x_i ~iid q(X)`. - - However, if q is not reparameterized, TensorFlow's gradient will be incorrect - since the chain-rule stops at samples of unreparameterized distributions. In - this circumstance using the Score-Gradient trick results in an unbiased - gradient, i.e., - - ```none - grad[ E_q[f(X)] ] - = grad[ int dx q(x) f(x) ] - = int dx grad[ q(x) f(x) ] - = int dx [ q'(x) f(x) + q(x) f'(x) ] - = int dx q(x) [q'(x) / q(x) f(x) + f'(x) ] - = int dx q(x) grad[ f(x) q(x) / stop_grad[q(x)] ] - = E_q[ grad[ f(x) q(x) / stop_grad[q(x)] ] ] - ``` - - Unless `q.reparameterization_type != distribution.FULLY_REPARAMETERIZED` it is - usually preferable to set `use_reparametrization = True`. - - Example Application: - - The Csiszar f-Divergence is a useful framework for variational inference. - I.e., observe that, - - ```none - f(p(x)) = f( E_{q(Z | x)}[ p(x, Z) / q(Z | x) ] ) - <= E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ] - := D_f[p(x, Z), q(Z | x)] - ``` - - The inequality follows from the fact that the "perspective" of `f`, i.e., - `(s, t) |-> t f(s / t))`, is convex in `(s, t)` when `s/t in domain(f)` and - `t` is a real. Since the above framework includes the popular Evidence Lower - BOund (ELBO) as a special case, i.e., `f(u) = -log(u)`, we call this framework - "Evidence Divergence Bound Optimization" (EDBO). - - Args: - f: Python `callable` representing a Csiszar-function in log-space, i.e., - takes `p_log_prob(q_samples) - q.log_prob(q_samples)`. - p_log_prob: Python `callable` taking (a batch of) samples from `q` and - returning the natural-log of the probability under distribution `p`. - (In variational inference `p` is the joint distribution.) - q: `tf.Distribution`-like instance; must implement: - `reparameterization_type`, `sample(n, seed)`, and `log_prob(x)`. - (In variational inference `q` is the approximate posterior distribution.) - num_draws: Integer scalar number of draws used to approximate the - f-Divergence expectation. - use_reparametrization: Python `bool`. When `None` (the default), - automatically set to: - `q.reparameterization_type == distribution.FULLY_REPARAMETERIZED`. - When `True` uses the standard Monte-Carlo average. When `False` uses the - score-gradient trick. (See above for details.) When `False`, consider - using `csiszar_vimco`. - seed: Python `int` seed for `q.sample`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - monte_carlo_csiszar_f_divergence: `float`-like `Tensor` Monte Carlo - approximation of the Csiszar f-Divergence. - - Raises: - ValueError: if `q` is not a reparameterized distribution and - `use_reparametrization = True`. A distribution `q` is said to be - "reparameterized" when its samples are generated by transforming the - samples of another distribution which does not depend on the - parameterization of `q`. This property ensures the gradient (with respect - to parameters) is valid. - TypeError: if `p_log_prob` is not a Python `callable`. - """ - with ops.name_scope(name, "monte_carlo_csiszar_f_divergence", [num_draws]): - if use_reparametrization is None: - use_reparametrization = (q.reparameterization_type - == distribution.FULLY_REPARAMETERIZED) - elif (use_reparametrization and - q.reparameterization_type != distribution.FULLY_REPARAMETERIZED): - # TODO(jvdillon): Consider only raising an exception if the gradient is - # requested. - raise ValueError( - "Distribution `q` must be reparameterized, i.e., a diffeomorphic " - "transformation of a parameterless distribution. (Otherwise this " - "function has a biased gradient.)") - if not callable(p_log_prob): - raise TypeError("`p_log_prob` must be a Python `callable` function.") - return monte_carlo.expectation( - f=lambda q_samples: f(p_log_prob(q_samples) - q.log_prob(q_samples)), - samples=q.sample(num_draws, seed=seed), - log_prob=q.log_prob, # Only used if use_reparametrization=False. - use_reparametrization=use_reparametrization) - - -def csiszar_vimco(f, - p_log_prob, - q, - num_draws, - num_batch_draws=1, - seed=None, - name=None): - """Use VIMCO to lower the variance of gradient[csiszar_function(Avg(logu))]. - - This function generalizes "Variational Inference for Monte Carlo Objectives" - (VIMCO), i.e., https://arxiv.org/abs/1602.06725, to Csiszar f-Divergences. - - Note: if `q.reparameterization_type = distribution.FULLY_REPARAMETERIZED`, - consider using `monte_carlo_csiszar_f_divergence`. - - The VIMCO loss is: - - ```none - vimco = f(Avg{logu[i] : i=0,...,m-1}) - where, - logu[i] = log( p(x, h[i]) / q(h[i] | x) ) - h[i] iid~ q(H | x) - ``` - - Interestingly, the VIMCO gradient is not the naive gradient of `vimco`. - Rather, it is characterized by: - - ```none - grad[vimco] - variance_reducing_term - where, - variance_reducing_term = Sum{ grad[log q(h[i] | x)] * - (vimco - f(log Avg{h[j;i] : j=0,...,m-1})) - : i=0, ..., m-1 } - h[j;i] = { u[j] j!=i - { GeometricAverage{ u[k] : k!=i} j==i - ``` - - (We omitted `stop_gradient` for brevity. See implementation for more details.) - - The `Avg{h[j;i] : j}` term is a kind of "swap-out average" where the `i`-th - element has been replaced by the leave-`i`-out Geometric-average. - - This implementation prefers numerical precision over efficiency, i.e., - `O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape))`. - (The constant may be fairly large, perhaps around 12.) - - Args: - f: Python `callable` representing a Csiszar-function in log-space. - p_log_prob: Python `callable` representing the natural-log of the - probability under distribution `p`. (In variational inference `p` is the - joint distribution.) - q: `tf.Distribution`-like instance; must implement: `sample(n, seed)`, and - `log_prob(x)`. (In variational inference `q` is the approximate posterior - distribution.) - num_draws: Integer scalar number of draws used to approximate the - f-Divergence expectation. - num_batch_draws: Integer scalar number of draws used to approximate the - f-Divergence expectation. - seed: Python `int` seed for `q.sample`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - vimco: The Csiszar f-Divergence generalized VIMCO objective. - - Raises: - ValueError: if `num_draws < 2`. - """ - with ops.name_scope(name, "csiszar_vimco", [num_draws, num_batch_draws]): - if num_draws < 2: - raise ValueError("Must specify num_draws > 1.") - stop = array_ops.stop_gradient # For readability. - x = stop(q.sample(sample_shape=[num_draws, num_batch_draws], - seed=seed)) - logqx = q.log_prob(x) - logu = p_log_prob(x) - logqx - f_log_avg_u, f_log_sooavg_u = [f(r) for r in csiszar_vimco_helper(logu)] - dotprod = math_ops.reduce_sum( - logqx * stop(f_log_avg_u - f_log_sooavg_u), - axis=0) # Sum over iid samples. - # We now rewrite f_log_avg_u so that: - # `grad[f_log_avg_u] := grad[f_log_avg_u + dotprod]`. - # To achieve this, we use a trick that - # `f(x) - stop(f(x)) == zeros_like(f(x))` - # but its gradient is grad[f(x)]. - # Note that IEEE754 specifies that `x - x == 0.` and `x + 0. == x`, hence - # this trick loses no precision. For more discussion regarding the relevant - # portions of the IEEE754 standard, see the StackOverflow question, - # "Is there a floating point value of x, for which x-x == 0 is false?" - # http://stackoverflow.com/q/2686644 - f_log_avg_u += dotprod - stop(dotprod) # Add zeros_like(dot_prod). - return math_ops.reduce_mean(f_log_avg_u, axis=0) # Avg over batches. - - -def csiszar_vimco_helper(logu, name=None): - """Helper to `csiszar_vimco`; computes `log_avg_u`, `log_sooavg_u`. - - `axis = 0` of `logu` is presumed to correspond to iid samples from `q`, i.e., - - ```none - logu[j] = log(u[j]) - u[j] = p(x, h[j]) / q(h[j] | x) - h[j] iid~ q(H | x) - ``` - - Args: - logu: Floating-type `Tensor` representing `log(p(x, h) / q(h | x))`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - log_avg_u: `logu.dtype` `Tensor` corresponding to the natural-log of the - average of `u`. The sum of the gradient of `log_avg_u` is `1`. - log_sooavg_u: `logu.dtype` `Tensor` characterized by the natural-log of the - average of `u`` except that the average swaps-out `u[i]` for the - leave-`i`-out Geometric-average. The mean of the gradient of - `log_sooavg_u` is `1`. Mathematically `log_sooavg_u` is, - ```none - log_sooavg_u[i] = log(Avg{h[j ; i] : j=0, ..., m-1}) - h[j ; i] = { u[j] j!=i - { GeometricAverage{u[k] : k != i} j==i - ``` - - """ - with ops.name_scope(name, "csiszar_vimco_helper", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - - n = logu.shape.with_rank_at_least(1)[0].value - if n is None: - n = array_ops.shape(logu)[0] - log_n = math_ops.log(math_ops.cast(n, dtype=logu.dtype)) - nm1 = math_ops.cast(n - 1, dtype=logu.dtype) - else: - log_n = np.log(n).astype(logu.dtype.as_numpy_dtype) - nm1 = np.asarray(n - 1, dtype=logu.dtype.as_numpy_dtype) - - # Throughout we reduce across axis=0 since this is presumed to be iid - # samples. - - log_max_u = math_ops.reduce_max(logu, axis=0) - log_sum_u_minus_log_max_u = math_ops.reduce_logsumexp( - logu - log_max_u, axis=0) - - # log_loosum_u[i] = - # = logsumexp(logu[j] : j != i) - # = log( exp(logsumexp(logu)) - exp(logu[i]) ) - # = log( exp(logsumexp(logu - logu[i])) exp(logu[i]) - exp(logu[i])) - # = logu[i] + log(exp(logsumexp(logu - logu[i])) - 1) - # = logu[i] + log(exp(logsumexp(logu) - logu[i]) - 1) - # = logu[i] + softplus_inverse(logsumexp(logu) - logu[i]) - d = log_sum_u_minus_log_max_u + (log_max_u - logu) - # We use `d != 0` rather than `d > 0.` because `d < 0.` should never - # happens; if it does we want to complain loudly (which `softplus_inverse` - # will). - d_ok = math_ops.not_equal(d, 0.) - safe_d = array_ops.where(d_ok, d, array_ops.ones_like(d)) - d_ok_result = logu + distribution_util.softplus_inverse(safe_d) - - inf = np.array(np.inf, dtype=logu.dtype.as_numpy_dtype) - - # When not(d_ok) and is_positive_and_largest then we manually compute the - # log_loosum_u. (We can efficiently do this for any one point but not all, - # hence we still need the above calculation.) This is good because when - # this condition is met, we cannot use the above calculation; its -inf. - # We now compute the log-leave-out-max-sum, replicate it to every - # point and make sure to select it only when we need to. - is_positive_and_largest = math_ops.logical_and( - logu > 0., - math_ops.equal(logu, log_max_u[array_ops.newaxis, ...])) - log_lomsum_u = math_ops.reduce_logsumexp( - array_ops.where(is_positive_and_largest, - array_ops.fill(array_ops.shape(logu), -inf), - logu), - axis=0, keep_dims=True) - log_lomsum_u = array_ops.tile( - log_lomsum_u, - multiples=1 + array_ops.pad([n-1], [[0, array_ops.rank(logu)-1]])) - - d_not_ok_result = array_ops.where( - is_positive_and_largest, - log_lomsum_u, - array_ops.fill(array_ops.shape(d), -inf)) - - log_loosum_u = array_ops.where(d_ok, d_ok_result, d_not_ok_result) - - # The swap-one-out-sum ("soosum") is n different sums, each of which - # replaces the i-th item with the i-th-left-out average, i.e., - # soo_sum_u[i] = [exp(logu) - exp(logu[i])] + exp(mean(logu[!=i])) - # = exp(log_loosum_u[i]) + exp(looavg_logu[i]) - looavg_logu = (math_ops.reduce_sum(logu, axis=0) - logu) / nm1 - log_soosum_u = math_ops.reduce_logsumexp( - array_ops.stack([log_loosum_u, looavg_logu]), - axis=0) - - log_avg_u = log_sum_u_minus_log_max_u + log_max_u - log_n - log_sooavg_u = log_soosum_u - log_n - - log_avg_u.set_shape(logu.shape.with_rank_at_least(1)[1:]) - log_sooavg_u.set_shape(logu.shape) - - return log_avg_u, log_sooavg_u diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py deleted file mode 100644 index d44fe6529a7ff0da0c6747e193fdb98a272a8da3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functions for specifying custom gradients. - -@@custom_gradient - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops - -__all__ = [ - "custom_gradient", -] - - -def custom_gradient(fx, gx, x, axis=(), fx_gx_manually_stopped=False, - name=None): - """Enables specifying a custom gradient. - - This function works by clever application of `stop_gradient`. I.e., observe - that: - - ```none - h(x) = x * stop_gradient(g(x)) + stop_gradient(f(x) - x * g(x)) - ``` - - is such that `h(x) = stop_gradient(f(x))` and `grad[h(x), x] = - stop_gradient(g(x)).` - - In addition to scalar-domain/scalar-range functions, this function also - supports tensor-domain/scalar-range functions. However, in the latter case it - is necessary to reduce `x` to a scalar. This can be done by indicating the - `axis` over which `f` operates or by appropriately `reduce_sum`-ing `x`, prior - to calling this function. - - Partial Custom Gradient: - - Suppose `h(x) = htilde(x, y)`. Note that `dh/dx = stop(g(x))` but `dh/dy = - None`. This is because a `Tensor` cannot have only a portion of its gradient - stopped. To circumvent this issue, one must manually `stop_gradient` the - relevant portions of `f`, `g`. For example see the unit-test, - `test_works_correctly_fx_gx_manually_stopped`. - - Args: - fx: `Tensor`. Output of function evaluated at `x`. - gx: `Tensor`. Gradient of function evaluated at `x`. - x: `Tensor`. Point of evaluation for `f, g`. - axis: 1D `int` `Tensor` representing dimensions of `x` which are the domain - of `f`. If `()` (the default), `f` is assumed scalar-domain/scalar-range. - If `None` `f` is assumed to render one scalar given all of `x`. Otherwise - `f` is assumed to output one scalar for each of `axis` dimensions of `x`. - fx_gx_manually_stopped: Python `bool` indicating that `fx`, `gx` manually - have `stop_gradient` applied. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - fx: Floating-type `Tensor` equal to `f(x)` but which has gradient - `stop_gradient(g(x))`. - """ - with ops.name_scope(name, "custom_gradient", [fx, gx, x]): - fx = ops.convert_to_tensor(fx, name="fx") - # We don't want to bother eagerly computing `gx` since we may not even need - # it. - with ops.control_dependencies([fx]): - gx = ops.convert_to_tensor(gx, dtype=fx.dtype, name="gx") - gx = array_ops.identity(gx, name="gx") - # Proof of correctness: - # - # f(x) = x * stop[gx] + stop[fx - x * gx] - # = stop[fx] - # - # g(x) = grad[fx] - # = stop[gx] + grad[stop[fx - x * gx]] - # = stop[gx] + 0 - # - # Notice that when x is zero it still works: - # grad[x * stop(gx) + stop(fx - x * gx)] = 1 * stop[gx] + 0 = stop[gx] - # - # The proof is similar for the tensor-domain case, except that `x` is - # replaced by `reduce_sum(x)`. - sum_x = math_ops.reduce_sum(x, axis=axis, name="sum_x") - if not fx_gx_manually_stopped: - fx = array_ops.stop_gradient(fx) - gx = array_ops.stop_gradient(gx) - # IEEE754 ensures `(x-x)==0.` and that `0.*x==0.` so we make sure to write - # the code this way, rather than, e.g., - # `sum_x * stop(gx) + stop(fx - sum_x * gx)`. - # For more discussion regarding the relevant portions of the IEEE754 - # standard, see the StackOverflow question, - # "Is there a floating point value of x, for which x-x == 0 is false?" - # http://stackoverflow.com/q/2686644 - return (sum_x - array_ops.stop_gradient(sum_x)) * gx + fx diff --git a/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py b/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py deleted file mode 100644 index 8cabf18903b5f15002470acdfb8fdd3ec31a7413..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Quasi Monte Carlo support: Halton sequence. - -@@sample -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops - - -__all__ = [ - 'sample', -] - - -# The maximum dimension we support. This is limited by the number of primes -# in the _PRIMES array. -_MAX_DIMENSION = 1000 - - -def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None): - r"""Returns a sample from the `m` dimensional Halton sequence. - - Warning: The sequence elements take values only between 0 and 1. Care must be - taken to appropriately transform the domain of a function if it differs from - the unit cube before evaluating integrals using Halton samples. It is also - important to remember that quasi-random numbers are not a replacement for - pseudo-random numbers in every context. Quasi random numbers are completely - deterministic and typically have significant negative autocorrelation (unless - randomized). - - Computes the members of the low discrepancy Halton sequence in dimension - `dim`. The d-dimensional sequence takes values in the unit hypercube in d - dimensions. Currently, only dimensions up to 1000 are supported. The prime - base for the `k`-th axes is the k-th prime starting from 2. For example, - if dim = 3, then the bases will be [2, 3, 5] respectively and the first - element of the sequence will be: [0.5, 0.333, 0.2]. For a more complete - description of the Halton sequences see: - https://en.wikipedia.org/wiki/Halton_sequence. For low discrepancy sequences - and their applications see: - https://en.wikipedia.org/wiki/Low-discrepancy_sequence. - - The user must supply either `num_samples` or `sample_indices` but not both. - The former is the number of samples to produce starting from the first - element. If `sample_indices` is given instead, the specified elements of - the sequence are generated. For example, sample_indices=tf.range(10) is - equivalent to specifying n=10. - - Example Use: - - ```python - bf = tf.contrib.bayesflow - - # Produce the first 1000 members of the Halton sequence in 3 dimensions. - num_samples = 1000 - dim = 3 - sample = bf.halton_sequence.sample(dim, num_samples=num_samples) - - # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional - # hypercube. - powers = tf.range(1.0, limit=dim + 1) - integral = tf.reduce_mean(tf.reduce_prod(sample ** powers, axis=-1)) - true_value = 1.0 / tf.reduce_prod(powers + 1.0) - with tf.Session() as session: - values = session.run((integral, true_value)) - - # Produces a relative absolute error of 1.7%. - print ("Estimated: %f, True Value: %f" % values) - - # Now skip the first 1000 samples and recompute the integral with the next - # thousand samples. The sample_indices argument can be used to do this. - - - sample_indices = tf.range(start=1000, limit=1000 + num_samples, - dtype=tf.int32) - sample_leaped = halton.sample(dim, sample_indices=sample_indices) - - integral_leaped = tf.reduce_mean(tf.reduce_prod(sample_leaped ** powers, - axis=-1)) - with tf.Session() as session: - values = session.run((integral_leaped, true_value)) - # Now produces a relative absolute error of 0.05%. - print ("Leaped Estimated: %f, True Value: %f" % values) - ``` - - Args: - dim: Positive Python `int` representing each sample's `event_size.` Must - not be greater than 1000. - num_samples: (Optional) positive Python `int`. The number of samples to - generate. Either this parameter or sample_indices must be specified but - not both. If this parameter is None, then the behaviour is determined by - the `sample_indices`. - sample_indices: (Optional) `Tensor` of dtype int32 and rank 1. The elements - of the sequence to compute specified by their position in the sequence. - The entries index into the Halton sequence starting with 0 and hence, - must be whole numbers. For example, sample_indices=[0, 5, 6] will produce - the first, sixth and seventh elements of the sequence. If this parameter - is None, then the `num_samples` parameter must be specified which gives - the number of desired samples starting from the first sample. - dtype: (Optional) The dtype of the sample. One of `float32` or `float64`. - Default is `float32`. - name: (Optional) Python `str` describing ops managed by this function. If - not supplied the name of this function is used. - - Returns: - halton_elements: Elements of the Halton sequence. `Tensor` of supplied dtype - and `shape` `[num_samples, dim]` if `num_samples` was specified or shape - `[s, dim]` where s is the size of `sample_indices` if `sample_indices` - were specified. - - Raises: - ValueError: if both `sample_indices` and `num_samples` were specified or - if dimension `dim` is less than 1 or greater than 1000. - """ - if dim < 1 or dim > _MAX_DIMENSION: - raise ValueError( - 'Dimension must be between 1 and {}. Supplied {}'.format(_MAX_DIMENSION, - dim)) - if (num_samples is None) == (sample_indices is None): - raise ValueError('Either `num_samples` or `sample_indices` must be' - ' specified but not both.') - - dtype = dtype or dtypes.float32 - if not dtype.is_floating: - raise ValueError('dtype must be of `float`-type') - - with ops.name_scope(name, 'sample', values=[sample_indices]): - # Here and in the following, the shape layout is as follows: - # [sample dimension, event dimension, coefficient dimension]. - # The coefficient dimension is an intermediate axes which will hold the - # weights of the starting integer when expressed in the (prime) base for - # an event dimension. - indices = _get_indices(num_samples, sample_indices, dtype) - radixes = array_ops.constant(_PRIMES[0:dim], dtype=dtype, shape=[dim, 1]) - - max_sizes_by_axes = _base_expansion_size(math_ops.reduce_max(indices), - radixes) - - max_size = math_ops.reduce_max(max_sizes_by_axes) - - # The powers of the radixes that we will need. Note that there is a bit - # of an excess here. Suppose we need the place value coefficients of 7 - # in base 2 and 3. For 2, we will have 3 digits but we only need 2 digits - # for base 3. However, we can only create rectangular tensors so we - # store both expansions in a [2, 3] tensor. This leads to the problem that - # we might end up attempting to raise large numbers to large powers. For - # example, base 2 expansion of 1024 has 10 digits. If we were in 10 - # dimensions, then the 10th prime (29) we will end up computing 29^10 even - # though we don't need it. We avoid this by setting the exponents for each - # axes to 0 beyond the maximum value needed for that dimension. - exponents_by_axes = array_ops.tile([math_ops.range(max_size)], [dim, 1]) - weight_mask = exponents_by_axes > max_sizes_by_axes - capped_exponents = array_ops.where( - weight_mask, array_ops.zeros_like(exponents_by_axes), exponents_by_axes) - weights = radixes ** capped_exponents - coeffs = math_ops.floor_div(indices, weights) - coeffs *= 1 - math_ops.cast(weight_mask, dtype) - coeffs = (coeffs % radixes) / radixes - return math_ops.reduce_sum(coeffs / weights, axis=-1) - - -def _get_indices(n, sample_indices, dtype, name=None): - """Generates starting points for the Halton sequence procedure. - - The k'th element of the sequence is generated starting from a positive integer - which must be distinct for each `k`. It is conventional to choose the starting - point as `k` itself (or `k+1` if k is zero based). This function generates - the starting integers for the required elements and reshapes the result for - later use. - - Args: - n: Positive `int`. The number of samples to generate. If this - parameter is supplied, then `sample_indices` should be None. - sample_indices: `Tensor` of dtype int32 and rank 1. The entries - index into the Halton sequence starting with 0 and hence, must be whole - numbers. For example, sample_indices=[0, 5, 6] will produce the first, - sixth and seventh elements of the sequence. If this parameter is not None - then `n` must be None. - dtype: The dtype of the sample. One of `float32` or `float64`. - Default is `float32`. - name: Python `str` name which describes ops created by this function. - - Returns: - indices: `Tensor` of dtype `dtype` and shape = `[n, 1, 1]`. - """ - with ops.name_scope(name, 'get_indices', [n, sample_indices]): - if sample_indices is None: - sample_indices = math_ops.range(n, dtype=dtype) - else: - sample_indices = math_ops.cast(sample_indices, dtype) - - # Shift the indices so they are 1 based. - indices = sample_indices + 1 - - # Reshape to make space for the event dimension and the place value - # coefficients. - return array_ops.reshape(indices, [-1, 1, 1]) - - -def _base_expansion_size(num, bases): - """Computes the number of terms in the place value expansion. - - Let num = a0 + a1 b + a2 b^2 + ... ak b^k be the place value expansion of - `num` in base b (ak <> 0). This function computes and returns `k` for each - base `b` specified in `bases`. - - This can be inferred from the base `b` logarithm of `num` as follows: - $$k = Floor(log_b (num)) + 1 = Floor( log(num) / log(b)) + 1$$ - - Args: - num: Scalar `Tensor` of dtype either `float32` or `float64`. The number to - compute the base expansion size of. - bases: `Tensor` of the same dtype as num. The bases to compute the size - against. - - Returns: - Tensor of same dtype and shape as `bases` containing the size of num when - written in that base. - """ - return math_ops.floor(math_ops.log(num) / math_ops.log(bases)) + 1 - - -def _primes_less_than(n): - # Based on - # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188 - """Returns sorted array of primes such that `2 <= prime < n`.""" - small_primes = np.array((2, 3, 5)) - if n <= 6: - return small_primes[small_primes < n] - sieve = np.ones(n // 3 + (n % 6 == 2), dtype=np.bool) - sieve[0] = False - m = int(n ** 0.5) // 3 + 1 - for i in range(m): - if not sieve[i]: - continue - k = 3 * i + 1 | 1 - sieve[k ** 2 // 3::2 * k] = False - sieve[(k ** 2 + 4 * k - 2 * k * (i & 1)) // 3::2 * k] = False - return np.r_[2, 3, 3 * np.nonzero(sieve)[0] + 1 | 1] - -_PRIMES = _primes_less_than(7919+1) - -assert len(_PRIMES) == _MAX_DIMENSION diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py deleted file mode 100644 index f724910c59315867a42a56fab3deb36f5d3adb7a..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py +++ /dev/null @@ -1,1185 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm. - -@@sample_chain -@@sample_annealed_importance_chain -@@kernel -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import numpy as np - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gradients_impl as gradients_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import util as distributions_util - -__all__ = [ - "sample_chain", - "sample_annealed_importance_chain", - "kernel", -] - - -KernelResults = collections.namedtuple( - "KernelResults", - [ - "acceptance_probs", - "current_grads_target_log_prob", # "Current result" means "accepted". - "current_target_log_prob", # "Current result" means "accepted". - "energy_change", - "is_accepted", - "proposed_grads_target_log_prob", - "proposed_state", - "proposed_target_log_prob", - "random_positive", - ]) - - -def _make_dummy_kernel_results( - dummy_state, - dummy_target_log_prob, - dummy_grads_target_log_prob): - return KernelResults( - acceptance_probs=dummy_target_log_prob, - current_grads_target_log_prob=dummy_grads_target_log_prob, - current_target_log_prob=dummy_target_log_prob, - energy_change=dummy_target_log_prob, - is_accepted=array_ops.ones_like(dummy_target_log_prob, dtypes.bool), - proposed_grads_target_log_prob=dummy_grads_target_log_prob, - proposed_state=dummy_state, - proposed_target_log_prob=dummy_target_log_prob, - random_positive=dummy_target_log_prob, - ) - - -def sample_chain( - num_results, - target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - num_burnin_steps=0, - num_steps_between_results=0, - seed=None, - current_target_log_prob=None, - current_grads_target_log_prob=None, - name=None): - """Runs multiple iterations of one or more Hamiltonian Monte Carlo chains. - - Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm - that takes a series of gradient-informed steps to produce a Metropolis - proposal. This function samples from an HMC Markov chain at `current_state` - and whose stationary distribution has log-unnormalized-density - `target_log_prob_fn()`. - - This function samples from multiple chains in parallel. It assumes that the - the leftmost dimensions of (each) `current_state` (part) index an independent - chain. The function `target_log_prob_fn()` sums log-probabilities across - event dimensions (i.e., current state (part) rightmost dimensions). Each - element of the output of `target_log_prob_fn()` represents the (possibly - unnormalized) log-probability of the joint distribution over (all) the current - state (parts). - - The `current_state` can be represented as a single `Tensor` or a `list` of - `Tensors` which collectively represent the current state. When specifying a - `list`, one must also specify a list of `step_size`s. - - Note: `target_log_prob_fn` is called exactly twice. - - Only one out of every `num_steps_between_samples + 1` steps is included in the - returned results. This "thinning" comes at a cost of reduced statistical - power, while reducing memory requirements and autocorrelation. For more - discussion see [1]. - - [1]: "Statistically efficient thinning of a Markov chain sampler." - Art B. Owen. April 2017. - http://statweb.stanford.edu/~owen/reports/bestthinning.pdf - - #### Examples: - - ##### Sample from a diagonal-variance Gaussian. - - ```python - tfd = tf.contrib.distributions - - def make_likelihood(true_variances): - return tfd.MultivariateNormalDiag( - scale_diag=tf.sqrt(true_variances)) - - dims = 10 - dtype = np.float32 - true_variances = tf.linspace(dtype(1), dtype(3), dims) - likelihood = make_likelihood(true_variances) - - states, kernel_results = hmc.sample_chain( - num_results=1000, - target_log_prob_fn=likelihood.log_prob, - current_state=tf.zeros(dims), - step_size=0.5, - num_leapfrog_steps=2, - num_burnin_steps=500) - - # Compute sample stats. - sample_mean = tf.reduce_mean(states, axis=0) - sample_var = tf.reduce_mean( - tf.squared_difference(states, sample_mean), - axis=0) - ``` - - ##### Sampling from factor-analysis posteriors with known factors. - - I.e., - - ```none - for i=1..n: - w[i] ~ Normal(0, eye(d)) # prior - x[i] ~ Normal(loc=matmul(w[i], F)) # likelihood - ``` - - where `F` denotes factors. - - ```python - tfd = tf.contrib.distributions - - def make_prior(dims, dtype): - return tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)) - - def make_likelihood(weights, factors): - return tfd.MultivariateNormalDiag( - loc=tf.tensordot(weights, factors, axes=[[0], [-1]])) - - # Setup data. - num_weights = 10 - num_factors = 4 - num_chains = 100 - dtype = np.float32 - - prior = make_prior(num_weights, dtype) - weights = prior.sample(num_chains) - factors = np.random.randn(num_factors, num_weights).astype(dtype) - x = make_likelihood(weights, factors).sample(num_chains) - - def target_log_prob(w): - # Target joint is: `f(w) = p(w, x | factors)`. - return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x) - - # Get `num_results` samples from `num_chains` independent chains. - chains_states, kernels_results = hmc.sample_chain( - num_results=1000, - target_log_prob_fn=target_log_prob, - current_state=tf.zeros([num_chains, dims], dtype), - step_size=0.1, - num_leapfrog_steps=2, - num_burnin_steps=500) - - # Compute sample stats. - sample_mean = tf.reduce_mean(chains_states, axis=[0, 1]) - sample_var = tf.reduce_mean( - tf.squared_difference(chains_states, sample_mean), - axis=[0, 1]) - ``` - - Args: - num_results: Integer number of Markov chain draws. - target_log_prob_fn: Python callable which takes an argument like - `current_state` (or `*current_state` if it's a list) and returns its - (possibly unnormalized) log-density under the target distribution. - current_state: `Tensor` or Python `list` of `Tensor`s representing the - current state(s) of the Markov chain(s). The first `r` dimensions index - independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. - step_size: `Tensor` or Python `list` of `Tensor`s representing the step size - for the leapfrog integrator. Must broadcast with the shape of - `current_state`. Larger step sizes lead to faster progress, but too-large - step sizes make rejection exponentially more likely. When possible, it's - often helpful to match per-variable step sizes to the standard deviations - of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - num_burnin_steps: Integer number of chain steps to take before starting to - collect results. - Default value: 0 (i.e., no burn-in). - num_steps_between_results: Integer number of chain steps between collecting - a result. Only one out of every `num_steps_between_samples + 1` steps is - included in the returned results. This "thinning" comes at a cost of - reduced statistical power, while reducing memory requirements and - autocorrelation. For more discussion see [1]. - Default value: 0 (i.e., no subsampling). - seed: Python integer to seed the random number generator. - current_target_log_prob: (Optional) `Tensor` representing the value of - `target_log_prob_fn` at the `current_state`. The only reason to specify - this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - current_grads_target_log_prob: (Optional) Python list of `Tensor`s - representing gradient of `target_log_prob` at the `current_state` and wrt - the `current_state`. Must have same shape as `current_state`. The only - reason to specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_sample_chain"). - - Returns: - accepted_states: Tensor or Python list of `Tensor`s representing the - state(s) of the Markov chain(s) at each result step. Has same shape as - input `current_state` but with a prepended `num_results`-size dimension. - kernel_results: `collections.namedtuple` of internal calculations used to - advance the chain. - """ - with ops.name_scope( - name, "hmc_sample_chain", - [num_results, current_state, step_size, num_leapfrog_steps, - num_burnin_steps, num_steps_between_results, seed, - current_target_log_prob, current_grads_target_log_prob]): - with ops.name_scope("initialize"): - [ - current_state, - step_size, - current_target_log_prob, - current_grads_target_log_prob, - ] = _prepare_args( - target_log_prob_fn, - current_state, - step_size, - current_target_log_prob, - current_grads_target_log_prob) - num_results = ops.convert_to_tensor( - num_results, - dtype=dtypes.int32, - name="num_results") - num_leapfrog_steps = ops.convert_to_tensor( - num_leapfrog_steps, - dtype=dtypes.int32, - name="num_leapfrog_steps") - num_burnin_steps = ops.convert_to_tensor( - num_burnin_steps, - dtype=dtypes.int32, - name="num_burnin_steps") - num_steps_between_results = ops.convert_to_tensor( - num_steps_between_results, - dtype=dtypes.int32, - name="num_steps_between_results") - - def _run_chain(num_steps, current_state, kernel_results): - """Runs the chain(s) for `num_steps`.""" - def _loop_body(iter_, current_state, kernel_results): - return [iter_ + 1] + list(kernel( - target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - seed, - kernel_results.current_target_log_prob, - kernel_results.current_grads_target_log_prob)) - while_loop_kwargs = dict( - cond=lambda iter_, *args: iter_ < num_steps, - body=_loop_body, - loop_vars=[ - np.int32(0), - current_state, - kernel_results, - ], - ) - if seed is not None: - while_loop_kwargs["parallel_iterations"] = 1 - return control_flow_ops.while_loop( - **while_loop_kwargs)[1:] # Lop-off "iter_". - - def _scan_body(args_list, iter_): - """Closure which implements `tf.scan` body.""" - current_state, kernel_results = args_list - return _run_chain( - 1 + array_ops.where(math_ops.equal(iter_, 0), - num_burnin_steps, - num_steps_between_results), - current_state, - kernel_results) - - scan_kwargs = dict( - fn=_scan_body, - elems=math_ops.range(num_results), # iter_: used to choose burnin. - initializer=[ - current_state, - _make_dummy_kernel_results( - current_state, - current_target_log_prob, - current_grads_target_log_prob), - ]) - if seed is not None: - scan_kwargs["parallel_iterations"] = 1 - return functional_ops.scan(**scan_kwargs) - - -def sample_annealed_importance_chain( - proposal_log_prob_fn, - num_steps, - target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - seed=None, - name=None): - """Runs annealed importance sampling (AIS) to estimate normalizing constants. - - This function uses Hamiltonian Monte Carlo to sample from a series of - distributions that slowly interpolates between an initial "proposal" - distribution: - - `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)` - - and the target distribution: - - `exp(target_log_prob_fn(x) - target_log_normalizer)`, - - accumulating importance weights along the way. The product of these - importance weights gives an unbiased estimate of the ratio of the - normalizing constants of the initial distribution and the target - distribution: - - `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`. - - Note: `proposal_log_prob_fn` and `target_log_prob_fn` are called exactly three - times (although this may be reduced to two times, in the future). - - #### Examples: - - ##### Estimate the normalizing constant of a log-gamma distribution. - - ```python - tfd = tf.contrib.distributions - - # Run 100 AIS chains in parallel - num_chains = 100 - dims = 20 - dtype = np.float32 - - proposal = tfd.MultivatiateNormalDiag( - loc=tf.zeros([dims], dtype=dtype)) - - target = tfd.TransformedDistribution( - distribution=tfd.Gamma(concentration=dtype(2), - rate=dtype(3)), - bijector=tfd.bijectors.Invert(tfd.bijectors.Exp()), - event_shape=[dims]) - - chains_state, ais_weights, kernels_results = ( - hmc.sample_annealed_importance_chain( - proposal_log_prob_fn=proposal.log_prob, - num_steps=1000, - target_log_prob_fn=target.log_prob, - step_size=0.2, - current_state=proposal.sample(num_chains), - num_leapfrog_steps=2)) - - log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights) - - np.log(num_chains)) - log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.) - ``` - - ##### Estimate marginal likelihood of a Bayesian regression model. - - ```python - tfd = tf.contrib.distributions - - def make_prior(dims, dtype): - return tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)) - - def make_likelihood(weights, x): - return tfd.MultivariateNormalDiag( - loc=tf.tensordot(weights, x, axes=[[0], [-1]])) - - # Run 100 AIS chains in parallel - num_chains = 100 - dims = 10 - dtype = np.float32 - - # Make training data. - x = np.random.randn(num_chains, dims).astype(dtype) - true_weights = np.random.randn(dims).astype(dtype) - y = np.dot(x, true_weights) + np.random.randn(num_chains) - - # Setup model. - prior = make_prior(dims, dtype) - def target_log_prob_fn(weights): - return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y) - - proposal = tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)) - - weight_samples, ais_weights, kernel_results = ( - hmc.sample_annealed_importance_chain( - num_steps=1000, - proposal_log_prob_fn=proposal.log_prob, - target_log_prob_fn=target_log_prob_fn - current_state=tf.zeros([num_chains, dims], dtype), - step_size=0.1, - num_leapfrog_steps=2)) - log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights) - - np.log(num_chains)) - ``` - - Args: - proposal_log_prob_fn: Python callable that returns the log density of the - initial distribution. - num_steps: Integer number of Markov chain updates to run. More - iterations means more expense, but smoother annealing between q - and p, which in turn means exponentially lower variance for the - normalizing constant estimator. - target_log_prob_fn: Python callable which takes an argument like - `current_state` (or `*current_state` if it's a list) and returns its - (possibly unnormalized) log-density under the target distribution. - current_state: `Tensor` or Python `list` of `Tensor`s representing the - current state(s) of the Markov chain(s). The first `r` dimensions index - independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. - step_size: `Tensor` or Python `list` of `Tensor`s representing the step size - for the leapfrog integrator. Must broadcast with the shape of - `current_state`. Larger step sizes lead to faster progress, but too-large - step sizes make rejection exponentially more likely. When possible, it's - often helpful to match per-variable step sizes to the standard deviations - of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - seed: Python integer to seed the random number generator. - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_sample_annealed_importance_chain"). - - Returns: - accepted_state: `Tensor` or Python list of `Tensor`s representing the - state(s) of the Markov chain(s) at the final iteration. Has same shape as - input `current_state`. - ais_weights: Tensor with the estimated weight(s). Has shape matching - `target_log_prob_fn(current_state)`. - kernel_results: `collections.namedtuple` of internal calculations used to - advance the chain. - """ - def make_convex_combined_log_prob_fn(iter_): - def _fn(*args): - p = proposal_log_prob_fn(*args) - t = target_log_prob_fn(*args) - dtype = p.dtype.base_dtype - beta = (math_ops.cast(iter_ + 1, dtype) - / math_ops.cast(num_steps, dtype)) - return (1. - beta) * p + beta * t - return _fn - - with ops.name_scope( - name, "hmc_sample_annealed_importance_chain", - [num_steps, current_state, step_size, num_leapfrog_steps, seed]): - with ops.name_scope("initialize"): - [ - current_state, - step_size, - current_log_prob, - current_grads_log_prob, - ] = _prepare_args( - make_convex_combined_log_prob_fn(iter_=0), - current_state, - step_size, - description="convex_combined_log_prob") - num_steps = ops.convert_to_tensor( - num_steps, - dtype=dtypes.int32, - name="num_steps") - num_leapfrog_steps = ops.convert_to_tensor( - num_leapfrog_steps, - dtype=dtypes.int32, - name="num_leapfrog_steps") - def _loop_body(iter_, ais_weights, current_state, kernel_results): - """Closure which implements `tf.while_loop` body.""" - current_state_parts = (list(current_state) - if _is_list_like(current_state) - else [current_state]) - # TODO(b/72994218): Consider refactoring things to avoid this unecessary - # call. - ais_weights += ((target_log_prob_fn(*current_state_parts) - - proposal_log_prob_fn(*current_state_parts)) - / math_ops.cast(num_steps, ais_weights.dtype)) - return [iter_ + 1, ais_weights] + list(kernel( - make_convex_combined_log_prob_fn(iter_), - current_state, - step_size, - num_leapfrog_steps, - seed, - kernel_results.current_target_log_prob, - kernel_results.current_grads_target_log_prob)) - - while_loop_kwargs = dict( - cond=lambda iter_, *args: iter_ < num_steps, - body=_loop_body, - loop_vars=[ - np.int32(0), # iter_ - array_ops.zeros_like(current_log_prob), # ais_weights - current_state, - _make_dummy_kernel_results(current_state, - current_log_prob, - current_grads_log_prob), - ]) - if seed is not None: - while_loop_kwargs["parallel_iterations"] = 1 - - [ais_weights, current_state, kernel_results] = control_flow_ops.while_loop( - **while_loop_kwargs)[1:] # Lop-off "iter_". - - return [current_state, ais_weights, kernel_results] - - -def kernel(target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - seed=None, - current_target_log_prob=None, - current_grads_target_log_prob=None, - name=None): - """Runs one iteration of Hamiltonian Monte Carlo. - - Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) - algorithm that takes a series of gradient-informed steps to produce - a Metropolis proposal. This function applies one step of HMC to - randomly update the variable `x`. - - This function can update multiple chains in parallel. It assumes that all - leftmost dimensions of `current_state` index independent chain states (and are - therefore updated independently). The output of `target_log_prob_fn()` should - sum log-probabilities across all event dimensions. Slices along the rightmost - dimensions may have different target distributions; for example, - `current_state[0, :]` could have a different target distribution from - `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of - independent chains is `tf.size(target_log_prob_fn(*current_state))`.) - - #### Examples: - - ##### Simple chain with warm-up. - - ```python - tfd = tf.contrib.distributions - - # Tuning acceptance rates: - dtype = np.float32 - target_accept_rate = 0.631 - num_warmup_iter = 500 - num_chain_iter = 500 - - x = tf.get_variable(name="x", initializer=dtype(1)) - step_size = tf.get_variable(name="step_size", initializer=dtype(1)) - - target = tfd.Normal(loc=dtype(0), scale=dtype(1)) - - new_x, other_results = hmc.kernel( - target_log_prob_fn=target.log_prob, - current_state=x, - step_size=step_size, - num_leapfrog_steps=3)[:4] - - x_update = x.assign(new_x) - - step_size_update = step_size.assign_add( - step_size * tf.where( - other_results.acceptance_probs > target_accept_rate, - 0.01, -0.01)) - - warmup = tf.group([x_update, step_size_update]) - - tf.global_variables_initializer().run() - - sess.graph.finalize() # No more graph building. - - # Warm up the sampler and adapt the step size - for _ in xrange(num_warmup_iter): - sess.run(warmup) - - # Collect samples without adapting step size - samples = np.zeros([num_chain_iter]) - for i in xrange(num_chain_iter): - _, x_, target_log_prob_, grad_ = sess.run([ - x_update, - x, - other_results.target_log_prob, - other_results.grads_target_log_prob]) - samples[i] = x_ - - print(samples.mean(), samples.std()) - ``` - - ##### Sample from more complicated posterior. - - I.e., - - ```none - W ~ MVN(loc=0, scale=sigma * eye(dims)) - for i=1...num_samples: - X[i] ~ MVN(loc=0, scale=eye(dims)) - eps[i] ~ Normal(loc=0, scale=1) - Y[i] = X[i].T * W + eps[i] - ``` - - ```python - tfd = tf.contrib.distributions - - def make_training_data(num_samples, dims, sigma): - dt = np.asarray(sigma).dtype - zeros = tf.zeros(dims, dtype=dt) - x = tfd.MultivariateNormalDiag( - loc=zeros).sample(num_samples, seed=1) - w = tfd.MultivariateNormalDiag( - loc=zeros, - scale_identity_multiplier=sigma).sample(seed=2) - noise = tfd.Normal( - loc=dt(0), - scale=dt(1)).sample(num_samples, seed=3) - y = tf.tensordot(x, w, axes=[[1], [0]]) + noise - return y, x, w - - def make_prior(sigma, dims): - # p(w | sigma) - return tfd.MultivariateNormalDiag( - loc=tf.zeros([dims], dtype=sigma.dtype), - scale_identity_multiplier=sigma) - - def make_likelihood(x, w): - # p(y | x, w) - return tfd.MultivariateNormalDiag( - loc=tf.tensordot(x, w, axes=[[1], [0]])) - - # Setup assumptions. - dtype = np.float32 - num_samples = 150 - dims = 10 - num_iters = int(5e3) - - true_sigma = dtype(0.5) - y, x, true_weights = make_training_data(num_samples, dims, true_sigma) - - # Estimate of `log(true_sigma)`. - log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0)) - sigma = tf.exp(log_sigma) - - # State of the Markov chain. - weights = tf.get_variable( - name="weights", - initializer=np.random.randn(dims).astype(dtype)) - - prior = make_prior(sigma, dims) - - def joint_log_prob_fn(w): - # f(w) = log p(w, y | x) - return prior.log_prob(w) + make_likelihood(x, w).log_prob(y) - - weights_update = weights.assign( - hmc.kernel(target_log_prob_fn=joint_log_prob, - current_state=weights, - step_size=0.1, - num_leapfrog_steps=5)[0]) - - with tf.control_dependencies([weights_update]): - loss = -prior.log_prob(weights) - - optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) - log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma]) - - sess.graph.finalize() # No more graph building. - - tf.global_variables_initializer().run() - - sigma_history = np.zeros(num_iters, dtype) - weights_history = np.zeros([num_iters, dims], dtype) - - for i in xrange(num_iters): - _, sigma_, weights_, _ = sess.run([log_sigma_update, sigma, weights]) - weights_history[i, :] = weights_ - sigma_history[i] = sigma_ - - true_weights_ = sess.run(true_weights) - - # Should converge to something close to true_sigma. - plt.plot(sigma_history); - plt.ylabel("sigma"); - plt.xlabel("iteration"); - ``` - - Args: - target_log_prob_fn: Python callable which takes an argument like - `current_state` (or `*current_state` if it's a list) and returns its - (possibly unnormalized) log-density under the target distribution. - current_state: `Tensor` or Python `list` of `Tensor`s representing the - current state(s) of the Markov chain(s). The first `r` dimensions index - independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. - step_size: `Tensor` or Python `list` of `Tensor`s representing the step size - for the leapfrog integrator. Must broadcast with the shape of - `current_state`. Larger step sizes lead to faster progress, but too-large - step sizes make rejection exponentially more likely. When possible, it's - often helpful to match per-variable step sizes to the standard deviations - of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - seed: Python integer to seed the random number generator. - current_target_log_prob: (Optional) `Tensor` representing the value of - `target_log_prob_fn` at the `current_state`. The only reason to - specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - current_grads_target_log_prob: (Optional) Python list of `Tensor`s - representing gradient of `current_target_log_prob` at the `current_state` - and wrt the `current_state`. Must have same shape as `current_state`. The - only reason to specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_kernel"). - - Returns: - accepted_state: Tensor or Python list of `Tensor`s representing the state(s) - of the Markov chain(s) at each result step. Has same shape as - `current_state`. - kernel_results: `collections.namedtuple` of internal calculations used to - advance the chain. - - Raises: - ValueError: if there isn't one `step_size` or a list with same length as - `current_state`. - """ - with ops.name_scope( - name, "hmc_kernel", - [current_state, step_size, num_leapfrog_steps, seed, - current_target_log_prob, current_grads_target_log_prob]): - with ops.name_scope("initialize"): - [current_state_parts, step_sizes, current_target_log_prob, - current_grads_target_log_prob] = _prepare_args( - target_log_prob_fn, current_state, step_size, - current_target_log_prob, current_grads_target_log_prob, - maybe_expand=True) - independent_chain_ndims = distributions_util.prefer_static_rank( - current_target_log_prob) - current_momentums = [] - for s in current_state_parts: - current_momentums.append(random_ops.random_normal( - shape=array_ops.shape(s), - dtype=s.dtype.base_dtype, - seed=seed)) - seed = distributions_util.gen_new_seed( - seed, salt="hmc_kernel_momentums") - - num_leapfrog_steps = ops.convert_to_tensor( - num_leapfrog_steps, - dtype=dtypes.int32, - name="num_leapfrog_steps") - [ - proposed_momentums, - proposed_state_parts, - proposed_target_log_prob, - proposed_grads_target_log_prob, - ] = _leapfrog_integrator(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - num_leapfrog_steps, - current_target_log_prob, - current_grads_target_log_prob) - - energy_change = _compute_energy_change(current_target_log_prob, - current_momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims) - - # u < exp(min(-energy, 0)), where u~Uniform[0,1) - # ==> -log(u) >= max(e, 0) - # ==> -log(u) >= e - # (Perhaps surprisingly, we don't have a better way to obtain a random - # uniform from positive reals, i.e., `tf.random_uniform(minval=0, - # maxval=np.inf)` won't work.) - random_uniform = random_ops.random_uniform( - shape=array_ops.shape(energy_change), - dtype=energy_change.dtype, - seed=seed) - random_positive = -math_ops.log(random_uniform) - is_accepted = random_positive >= energy_change - - accepted_target_log_prob = array_ops.where(is_accepted, - proposed_target_log_prob, - current_target_log_prob) - - accepted_state_parts = [_choose(is_accepted, - proposed_state_part, - current_state_part, - independent_chain_ndims) - for current_state_part, proposed_state_part - in zip(current_state_parts, proposed_state_parts)] - - accepted_grads_target_log_prob = [ - _choose(is_accepted, - proposed_grad, - grad, - independent_chain_ndims) - for proposed_grad, grad - in zip(proposed_grads_target_log_prob, current_grads_target_log_prob)] - - maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0] - return [ - maybe_flatten(accepted_state_parts), - KernelResults( - acceptance_probs=math_ops.exp(math_ops.minimum(-energy_change, 0.)), - current_grads_target_log_prob=accepted_grads_target_log_prob, - current_target_log_prob=accepted_target_log_prob, - energy_change=energy_change, - is_accepted=is_accepted, - proposed_grads_target_log_prob=proposed_grads_target_log_prob, - proposed_state=maybe_flatten(proposed_state_parts), - proposed_target_log_prob=proposed_target_log_prob, - random_positive=random_positive, - ), - ] - - -def _leapfrog_integrator(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - num_leapfrog_steps, - current_target_log_prob=None, - current_grads_target_log_prob=None, - name=None): - """Applies `num_leapfrog_steps` of the leapfrog integrator. - - Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`. - - #### Examples: - - ##### Simple quadratic potential. - - ```python - tfd = tf.contrib.distributions - - dims = 10 - num_iter = int(1e3) - dtype = np.float32 - - position = tf.placeholder(np.float32) - momentum = tf.placeholder(np.float32) - - [ - new_momentums, - new_positions, - ] = hmc._leapfrog_integrator( - current_momentums=[momentum], - target_log_prob_fn=tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)).log_prob, - current_state_parts=[position], - step_sizes=0.1, - num_leapfrog_steps=3)[:2] - - sess.graph.finalize() # No more graph building. - - momentum_ = np.random.randn(dims).astype(dtype) - position_ = np.random.randn(dims).astype(dtype) - - positions = np.zeros([num_iter, dims], dtype) - for i in xrange(num_iter): - position_, momentum_ = sess.run( - [new_momentums[0], new_position[0]], - feed_dict={position: position_, momentum: momentum_}) - positions[i] = position_ - - plt.plot(positions[:, 0]); # Sinusoidal. - ``` - - Args: - current_momentums: Tensor containing the value(s) of the momentum - variable(s) to update. - target_log_prob_fn: Python callable which takes an argument like - `*current_state_parts` and returns its (possibly unnormalized) log-density - under the target distribution. - current_state_parts: Python `list` of `Tensor`s representing the current - state(s) of the Markov chain(s). The first `independent_chain_ndims` of - the `Tensor`(s) index different chains. - step_sizes: Python `list` of `Tensor`s representing the step size for the - leapfrog integrator. Must broadcast with the shape of - `current_state_parts`. Larger step sizes lead to faster progress, but - too-large step sizes make rejection exponentially more likely. When - possible, it's often helpful to match per-variable step sizes to the - standard deviations of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - current_target_log_prob: (Optional) `Tensor` representing the value of - `target_log_prob_fn(*current_state_parts)`. The only reason to specify - this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - current_grads_target_log_prob: (Optional) Python list of `Tensor`s - representing gradient of `target_log_prob_fn(*current_state_parts`) wrt - `current_state_parts`. Must have same shape as `current_state_parts`. The - only reason to specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_leapfrog_integrator"). - - Returns: - proposed_momentums: Updated value of the momentum. - proposed_state_parts: Tensor or Python list of `Tensor`s representing the - state(s) of the Markov chain(s) at each result step. Has same shape as - input `current_state_parts`. - proposed_target_log_prob: `Tensor` representing the value of - `target_log_prob_fn` at `accepted_state`. - proposed_grads_target_log_prob: Gradient of `proposed_target_log_prob` wrt - `accepted_state`. - - Raises: - ValueError: if `len(momentums) != len(state_parts)`. - ValueError: if `len(state_parts) != len(step_sizes)`. - ValueError: if `len(state_parts) != len(grads_target_log_prob)`. - TypeError: if `not target_log_prob.dtype.is_floating`. - """ - def _loop_body(step, - current_momentums, - current_state_parts, - ignore_current_target_log_prob, # pylint: disable=unused-argument - current_grads_target_log_prob): - return [step + 1] + list(_leapfrog_step(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - current_grads_target_log_prob)) - - with ops.name_scope( - name, "hmc_leapfrog_integrator", - [current_momentums, current_state_parts, step_sizes, num_leapfrog_steps, - current_target_log_prob, current_grads_target_log_prob]): - if len(current_momentums) != len(current_state_parts): - raise ValueError("`momentums` must be in one-to-one correspondence " - "with `state_parts`") - num_leapfrog_steps = ops.convert_to_tensor(num_leapfrog_steps, - name="num_leapfrog_steps") - current_target_log_prob, current_grads_target_log_prob = ( - _maybe_call_fn_and_grads( - target_log_prob_fn, - current_state_parts, - current_target_log_prob, - current_grads_target_log_prob)) - return control_flow_ops.while_loop( - cond=lambda iter_, *args: iter_ < num_leapfrog_steps, - body=_loop_body, - loop_vars=[ - np.int32(0), # iter_ - current_momentums, - current_state_parts, - current_target_log_prob, - current_grads_target_log_prob, - ], - back_prop=False)[1:] # Lop-off "iter_". - - -def _leapfrog_step(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - current_grads_target_log_prob, - name=None): - """Applies one step of the leapfrog integrator.""" - with ops.name_scope( - name, "_leapfrog_step", - [current_momentums, current_state_parts, step_sizes, - current_grads_target_log_prob]): - proposed_momentums = [m + 0.5 * ss * g for m, ss, g - in zip(current_momentums, - step_sizes, - current_grads_target_log_prob)] - proposed_state_parts = [x + ss * m for x, ss, m - in zip(current_state_parts, - step_sizes, - proposed_momentums)] - proposed_target_log_prob = target_log_prob_fn(*proposed_state_parts) - if not proposed_target_log_prob.dtype.is_floating: - raise TypeError("`target_log_prob_fn` must produce a `Tensor` " - "with `float` `dtype`.") - proposed_grads_target_log_prob = gradients_ops.gradients( - proposed_target_log_prob, proposed_state_parts) - if any(g is None for g in proposed_grads_target_log_prob): - raise ValueError( - "Encountered `None` gradient. Does your target `target_log_prob_fn` " - "access all `tf.Variable`s via `tf.get_variable`?\n" - " current_state_parts: {}\n" - " proposed_state_parts: {}\n" - " proposed_grads_target_log_prob: {}".format( - current_state_parts, - proposed_state_parts, - proposed_grads_target_log_prob)) - proposed_momentums = [m + 0.5 * ss * g for m, ss, g - in zip(proposed_momentums, - step_sizes, - proposed_grads_target_log_prob)] - return [ - proposed_momentums, - proposed_state_parts, - proposed_target_log_prob, - proposed_grads_target_log_prob, - ] - - -def _compute_energy_change(current_target_log_prob, - current_momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims, - name=None): - """Helper to `kernel` which computes the energy change.""" - with ops.name_scope( - name, "compute_energy_change", - ([current_target_log_prob, proposed_target_log_prob, - independent_chain_ndims] + - current_momentums + proposed_momentums)): - # Abbreviate lk0=log_kinetic_energy and lk1=proposed_log_kinetic_energy - # since they're a mouthful and lets us inline more. - lk0, lk1 = [], [] - for current_momentum, proposed_momentum in zip(current_momentums, - proposed_momentums): - axis = math_ops.range(independent_chain_ndims, - array_ops.rank(current_momentum)) - lk0.append(_log_sum_sq(current_momentum, axis)) - lk1.append(_log_sum_sq(proposed_momentum, axis)) - - lk0 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk0, axis=-1), - axis=-1) - lk1 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk1, axis=-1), - axis=-1) - lp0 = -current_target_log_prob # log_potential - lp1 = -proposed_target_log_prob # proposed_log_potential - x = array_ops.stack([lp1, math_ops.exp(lk1), -lp0, -math_ops.exp(lk0)], - axis=-1) - - # The sum is NaN if any element is NaN or we see both +Inf and -Inf. - # Thus we will replace such rows with infinite energy change which implies - # rejection. Recall that float-comparisons with NaN are always False. - is_sum_determinate = ( - math_ops.reduce_all(math_ops.is_finite(x) | (x >= 0.), axis=-1) & - math_ops.reduce_all(math_ops.is_finite(x) | (x <= 0.), axis=-1)) - is_sum_determinate = array_ops.tile( - is_sum_determinate[..., array_ops.newaxis], - multiples=array_ops.concat([ - array_ops.ones(array_ops.rank(is_sum_determinate), - dtype=dtypes.int32), - [4], - ], axis=0)) - x = array_ops.where(is_sum_determinate, - x, - array_ops.fill(array_ops.shape(x), - value=x.dtype.as_numpy_dtype(np.inf))) - - return math_ops.reduce_sum(x, axis=-1) - - -def _choose(is_accepted, - accepted, - rejected, - independent_chain_ndims, - name=None): - """Helper to `kernel` which expand_dims `is_accepted` to apply tf.where.""" - def _expand_is_accepted_like(x): - with ops.name_scope("_choose"): - expand_shape = array_ops.concat([ - array_ops.shape(is_accepted), - array_ops.ones([array_ops.rank(x) - array_ops.rank(is_accepted)], - dtype=dtypes.int32), - ], axis=0) - multiples = array_ops.concat([ - array_ops.ones([array_ops.rank(is_accepted)], dtype=dtypes.int32), - array_ops.shape(x)[independent_chain_ndims:], - ], axis=0) - m = array_ops.tile(array_ops.reshape(is_accepted, expand_shape), - multiples) - m.set_shape(x.shape) - return m - with ops.name_scope(name, "_choose", values=[ - is_accepted, accepted, rejected, independent_chain_ndims]): - return array_ops.where(_expand_is_accepted_like(accepted), - accepted, - rejected) - - -def _maybe_call_fn_and_grads(fn, - fn_arg_list, - fn_result=None, - grads_fn_result=None, - description="target_log_prob"): - """Helper which computes `fn_result` and `grads` if needed.""" - fn_arg_list = (list(fn_arg_list) if _is_list_like(fn_arg_list) - else [fn_arg_list]) - if fn_result is None: - fn_result = fn(*fn_arg_list) - if not fn_result.dtype.is_floating: - raise TypeError("`{}` must be a `Tensor` with `float` `dtype`.".format( - description)) - if grads_fn_result is None: - grads_fn_result = gradients_ops.gradients( - fn_result, fn_arg_list) - if len(fn_arg_list) != len(grads_fn_result): - raise ValueError("`{}` must be in one-to-one correspondence with " - "`grads_{}`".format(*[description]*2)) - if any(g is None for g in grads_fn_result): - raise ValueError("Encountered `None` gradient.") - return fn_result, grads_fn_result - - -def _prepare_args(target_log_prob_fn, state, step_size, - target_log_prob=None, grads_target_log_prob=None, - maybe_expand=False, description="target_log_prob"): - """Helper which processes input args to meet list-like assumptions.""" - state_parts = list(state) if _is_list_like(state) else [state] - state_parts = [ops.convert_to_tensor(s, name="state") - for s in state_parts] - target_log_prob, grads_target_log_prob = _maybe_call_fn_and_grads( - target_log_prob_fn, - state_parts, - target_log_prob, - grads_target_log_prob, - description) - step_sizes = list(step_size) if _is_list_like(step_size) else [step_size] - step_sizes = [ - ops.convert_to_tensor( - s, name="step_size", dtype=target_log_prob.dtype) - for s in step_sizes] - if len(step_sizes) == 1: - step_sizes *= len(state_parts) - if len(state_parts) != len(step_sizes): - raise ValueError("There should be exactly one `step_size` or it should " - "have same length as `current_state`.") - maybe_flatten = lambda x: x if maybe_expand or _is_list_like(state) else x[0] - return [ - maybe_flatten(state_parts), - maybe_flatten(step_sizes), - target_log_prob, - grads_target_log_prob, - ] - - -def _is_list_like(x): - """Helper which returns `True` if input is `list`-like.""" - return isinstance(x, (tuple, list)) - - -def _log_sum_sq(x, axis=None): - """Computes log(sum(x**2)).""" - return math_ops.reduce_logsumexp(2. * math_ops.log(math_ops.abs(x)), axis) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers.py b/tensorflow/contrib/bayesflow/python/ops/layers.py deleted file mode 100644 index a742b7c1aa593d6c08bf9d8d597c99c9fc4e7aed..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/layers.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Probabilistic neural layers. - -See ${python/contrib.bayesflow.layers}. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.layers_conv_variational import * -from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational import * -from tensorflow.contrib.bayesflow.python.ops.layers_util import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = [ - 'Convolution1DReparameterization', - 'Convolution2DReparameterization', - 'Convolution3DReparameterization', - 'Convolution1DFlipout', - 'Convolution2DFlipout', - 'Convolution3DFlipout', - 'Conv1DReparameterization', - 'Conv2DReparameterization', - 'Conv3DReparameterization', - 'Conv1DFlipout', - 'Conv2DFlipout', - 'Conv3DFlipout', - 'convolution1d_reparameterization', - 'convolution2d_reparameterization', - 'convolution3d_reparameterization', - 'convolution1d_flipout', - 'convolution2d_flipout', - 'convolution3d_flipout', - 'conv1d_reparameterization', - 'conv2d_reparameterization', - 'conv3d_reparameterization', - 'conv1d_flipout', - 'conv2d_flipout', - 'conv3d_flipout', - 'DenseReparameterization', - 'DenseLocalReparameterization', - 'DenseFlipout', - 'dense_reparameterization', - 'dense_local_reparameterization', - 'dense_flipout', - 'default_loc_scale_fn', - 'default_mean_field_normal_fn', -] - -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py b/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py deleted file mode 100644 index 7723cfb442712626ff415f1412e3362f2392ce9f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py +++ /dev/null @@ -1,2943 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Convolutional variational layer classes and their functional aliases. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.bayesflow.python.ops import layers_util -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.layers import base as layers_lib -from tensorflow.python.layers import utils -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import standard_ops -from tensorflow.python.ops.distributions import kullback_leibler as kl_lib -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.ops.distributions import util as distribution_util - - -class _ConvVariational(layers_lib.Layer): - """Abstract nD convolution layer (private, used as implementation base). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of n integers, specifying the - length of the convolution window. - strides: An integer or tuple/list of n integers, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, ..., channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, ...)`. - dilation_rate: An integer or tuple/list of n integers, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: A string, the name of the layer. - - Properties: - rank: Python integer, dimensionality of convolution. - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - """ - - def __init__( - self, - rank, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(_ConvVariational, self).__init__( - trainable=trainable, - name=name, - activity_regularizer=activity_regularizer, - **kwargs) - self.rank = rank - self.filters = filters - self.kernel_size = utils.normalize_tuple(kernel_size, rank, "kernel_size") - self.strides = utils.normalize_tuple(strides, rank, "strides") - self.padding = utils.normalize_padding(padding) - self.data_format = utils.normalize_data_format(data_format) - self.dilation_rate = utils.normalize_tuple( - dilation_rate, rank, "dilation_rate") - self.activation = activation - self.input_spec = layers_lib.InputSpec(ndim=self.rank + 2) - self.kernel_posterior_fn = kernel_posterior_fn - self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn - self.kernel_prior_fn = kernel_prior_fn - self.kernel_divergence_fn = kernel_divergence_fn - self.bias_posterior_fn = bias_posterior_fn - self.bias_posterior_tensor_fn = bias_posterior_tensor_fn - self.bias_prior_fn = bias_prior_fn - self.bias_divergence_fn = bias_divergence_fn - - def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape) - if self.data_format == "channels_first": - channel_axis = 1 - else: - channel_axis = -1 - if input_shape[channel_axis].value is None: - raise ValueError("The channel dimension of the inputs " - "should be defined. Found `None`.") - input_dim = input_shape[channel_axis].value - kernel_shape = self.kernel_size + (input_dim, self.filters) - dtype = dtypes.as_dtype(self.dtype) - - # Must have a posterior kernel. - self.kernel_posterior = self.kernel_posterior_fn( - dtype, kernel_shape, "kernel_posterior", - self.trainable, self.add_variable) - - if self.kernel_prior_fn is None: - self.kernel_prior = None - else: - self.kernel_prior = self.kernel_prior_fn( - dtype, kernel_shape, "kernel_prior", - self.trainable, self.add_variable) - self._built_kernel_divergence = False - - if self.bias_posterior_fn is None: - self.bias_posterior = None - else: - self.bias_posterior = self.bias_posterior_fn( - dtype, (self.filters,), "bias_posterior", - self.trainable, self.add_variable) - - if self.bias_prior_fn is None: - self.bias_prior = None - else: - self.bias_prior = self.bias_prior_fn( - dtype, (self.filters,), "bias_prior", - self.trainable, self.add_variable) - self._built_bias_divergence = False - - self.input_spec = layers_lib.InputSpec(ndim=self.rank + 2, - axes={channel_axis: input_dim}) - self._convolution_op = nn_ops.Convolution( - input_shape, - filter_shape=tensor_shape.TensorShape(kernel_shape), - dilation_rate=self.dilation_rate, - strides=self.strides, - padding=self.padding.upper(), - data_format=utils.convert_data_format(self.data_format, - self.rank + 2)) - - self.built = True - - def call(self, inputs): - inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) - - outputs = self._apply_variational_kernel(inputs) - outputs = self._apply_variational_bias(outputs) - if self.activation is not None: - outputs = self.activation(outputs) - if not self._built_kernel_divergence: - kernel_posterior = self.kernel_posterior - kernel_prior = self.kernel_prior - if isinstance(self.kernel_posterior, independent_lib.Independent): - kernel_posterior = kernel_posterior.distribution - if isinstance(self.kernel_prior, independent_lib.Independent): - kernel_prior = kernel_prior.distribution - self._apply_divergence(self.kernel_divergence_fn, - kernel_posterior, - kernel_prior, - self.kernel_posterior_tensor, - name="divergence_kernel") - self._built_kernel_divergence = True - if not self._built_bias_divergence: - bias_posterior = self.bias_posterior - bias_prior = self.bias_prior - if isinstance(self.bias_posterior, independent_lib.Independent): - bias_posterior = bias_posterior.distribution - if isinstance(self.bias_prior, independent_lib.Independent): - bias_prior = bias_prior.distribution - self._apply_divergence(self.bias_divergence_fn, - bias_posterior, - bias_prior, - self.bias_posterior_tensor, - name="divergence_bias") - self._built_bias_divergence = True - return outputs - - def _apply_variational_bias(self, inputs): - if self.bias_posterior is None: - self.bias_posterior_tensor = None - return inputs - self.bias_posterior_tensor = self.bias_posterior_tensor_fn( - self.bias_posterior) - outputs = inputs - if self.data_format == "channels_first": - if self.rank == 1: - # nn.bias_add does not accept a 1D input tensor. - bias = array_ops.reshape(self.bias_posterior_tensor, - (1, self.filters, 1)) - outputs += bias - if self.rank == 2: - outputs = nn.bias_add(outputs, - self.bias_posterior_tensor, - data_format="NCHW") - if self.rank == 3: - # As of Mar 2017, direct addition is significantly slower than - # bias_add when computing gradients. To use bias_add, we collapse Z - # and Y into a single dimension to obtain a 4D input tensor. - outputs_shape = outputs.shape.as_list() - outputs_4d = array_ops.reshape(outputs, - [outputs_shape[0], outputs_shape[1], - outputs_shape[2] * outputs_shape[3], - outputs_shape[4]]) - outputs_4d = nn.bias_add(outputs_4d, - self.bias_posterior_tensor, - data_format="NCHW") - outputs = array_ops.reshape(outputs_4d, outputs_shape) - else: - outputs = nn.bias_add(outputs, - self.bias_posterior_tensor, - data_format="NHWC") - return outputs - - def _apply_divergence(self, divergence_fn, posterior, prior, - posterior_tensor, name): - if (divergence_fn is None or - posterior is None or - prior is None): - divergence = None - return - divergence = standard_ops.identity( - divergence_fn( - posterior, prior, posterior_tensor), - name=name) - self.add_loss(divergence) - - def _compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).as_list() - if self.data_format == "channels_last": - space = input_shape[1:-1] - new_space = [] - for i in range(len(space)): - new_dim = utils.conv_output_length( - space[i], - self.kernel_size[i], - padding=self.padding, - stride=self.strides[i], - dilation=self.dilation_rate[i]) - new_space.append(new_dim) - return tensor_shape.TensorShape([input_shape[0]] + new_space + - [self.filters]) - else: - space = input_shape[2:] - new_space = [] - for i in range(len(space)): - new_dim = utils.conv_output_length( - space[i], - self.kernel_size[i], - padding=self.padding, - stride=self.strides[i], - dilation=self.dilation_rate[i]) - new_space.append(new_dim) - return tensor_shape.TensorShape([input_shape[0], self.filters] + - new_space) - - -class _ConvReparameterization(_ConvVariational): - """Abstract nD convolution layer (private, used as implementation base). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of n integers, specifying the - length of the convolution window. - strides: An integer or tuple/list of n integers, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, ..., channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, ...)`. - dilation_rate: An integer or tuple/list of n integers, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: A string, the name of the layer. - - Properties: - rank: Python integer, dimensionality of convolution. - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - - def __init__( - self, - rank, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(_ConvReparameterization, self).__init__( - rank=rank, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, **kwargs) - - def _apply_variational_kernel(self, inputs): - self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn( - self.kernel_posterior) - self.kernel_posterior_affine = None - self.kernel_posterior_affine_tensor = None - outputs = self._convolution_op(inputs, self.kernel_posterior_tensor) - return outputs - - -class Conv1DReparameterization(_ConvReparameterization): - """1D convolution layer (e.g. temporal convolution). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of a single integer, specifying the - length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - dilation_rate: An integer or tuple/list of a single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 128, 1]) - net = tfp.layers.Conv1DReparameterization(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.reshape(net, [-1, 128 * 64]) - logits = tfp.layers.DenseReparameterization(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - - def __init__( - self, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(Conv1DReparameterization, self).__init__( - rank=1, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, **kwargs) - - -def conv1d_reparameterization( - inputs, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - """Functional interface for 1D convolution layer (e.g. temporal convolution). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of a single integer, specifying the - length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - dilation_rate: An integer or tuple/list of a single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 128, 1]) - net = tfp.layers.conv1d_reparameterization(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.reshape(net, [-1, 128 * 64]) - logits = tfp.layers.dense_reparameterization(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - layer = Conv1DReparameterization( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class Conv2DReparameterization(_ConvReparameterization): - """2D convolution layer (e.g. spatial convolution over images). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 32, 32, 3]) - net = tfp.layers.Conv2DReparameterization(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.layers.MaxPooling2D(pool_size=2, - strides=2, - padding="SAME")(net) - net = tf.reshape(net, [-1, 8 * 8 * 64]) - logits = tfp.layers.DenseReparameterization(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - - def __init__( - self, - filters, - kernel_size, - strides=(1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(Conv2DReparameterization, self).__init__( - rank=2, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, **kwargs) - - -def conv2d_reparameterization( - inputs, - filters, - kernel_size, - strides=(1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - """Functional interface for the 2D convolution layer. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 32, 32, 3]) - net = tfp.layers.conv2d_reparameterization(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.layers.max_pooling2d(net, - pool_size=2, - strides=2, - padding="SAME") - net = tf.reshape(net, [-1, 8 * 8 * 64]) - logits = tfp.layers.dense_reparameterization(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - layer = Conv2DReparameterization( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class Conv3DReparameterization(_ConvReparameterization): - """3D convolution layer (e.g. spatial convolution over volumes). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, - height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - dilation_rate: An integer or tuple/list of 3 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 256, 32, 32, 3]) - net = tfp.layers.Conv3DReparameterization(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.layers.MaxPooling2D(pool_size=2, - strides=2, - padding="SAME")(net) - net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) - logits = tfp.layers.DenseReparameterization(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - - def __init__( - self, - filters, - kernel_size, - strides=(1, 1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(Conv3DReparameterization, self).__init__( - rank=3, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, **kwargs) - - -def conv3d_reparameterization( - inputs, - filters, - kernel_size, - strides=(1, 1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - """Functional interface for the 3D convolution layer. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, - height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - dilation_rate: An integer or tuple/list of 3 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 256, 32, 32, 3]) - net = tfp.layers.conv3d_reparameterization(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.layers.max_pooling2d(net, - pool_size=2, - strides=2, - padding="SAME") - net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) - logits = tfp.layers.dense_reparameterization(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - layer = Conv3DReparameterization( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class _ConvFlipout(_ConvVariational): - """Abstract nD convolution layer (private, used as implementation base). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of n integers, specifying the - length of the convolution window. - strides: An integer or tuple/list of n integers, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, ..., channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, ...)`. - dilation_rate: An integer or tuple/list of n integers, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - - Properties: - rank: Python integer, dimensionality of convolution. - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - seed: Python integer, used to create random seeds. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - - def __init__( - self, - rank, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - **kwargs): - super(_ConvFlipout, self).__init__( - rank=rank, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, **kwargs) - self.seed = seed - - def _apply_variational_kernel(self, inputs): - if (not isinstance(self.kernel_posterior, independent_lib.Independent) or - not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): - raise TypeError( - "`{}` requires " - "`kernel_posterior_fn` produce an instance of " - "`tf.distributions.Independent(tf.distributions.Normal)` " - "(saw: \"{}\").".format( - type(self).__name__, self.kernel_posterior.name)) - self.kernel_posterior_affine = normal_lib.Normal( - loc=array_ops.zeros_like(self.kernel_posterior.distribution.loc), - scale=self.kernel_posterior.distribution.scale) - self.kernel_posterior_affine_tensor = ( - self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) - self.kernel_posterior_tensor = None - - outputs = self._convolution_op( - inputs, self.kernel_posterior.distribution.loc) - - input_shape = array_ops.shape(inputs) - output_shape = array_ops.shape(outputs) - batch_shape = array_ops.expand_dims(input_shape[0], 0) - channels = input_shape[-1] - - sign_input = layers_util.random_sign( - array_ops.concat([batch_shape, - array_ops.expand_dims(channels, 0)], 0), - dtype=inputs.dtype, - seed=self.seed) - sign_output = layers_util.random_sign( - array_ops.concat([batch_shape, - array_ops.expand_dims(self.filters, 0)], 0), - dtype=inputs.dtype, - seed=distribution_util.gen_new_seed( - self.seed, salt="conv_flipout")) - for _ in range(self.rank): - sign_input = array_ops.expand_dims(sign_input, 1) # 2D ex: (B, 1, 1, C) - sign_output = array_ops.expand_dims(sign_output, 1) - - sign_input = array_ops.tile( # tile for element-wise op broadcasting - sign_input, - [1] + [input_shape[i + 1] for i in range(self.rank)] + [1]) - sign_output = array_ops.tile( - sign_output, - [1] + [output_shape[i + 1] for i in range(self.rank)] + [1]) - - perturbed_inputs = self._convolution_op( - inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output - - outputs += perturbed_inputs - return outputs - - -class Conv1DFlipout(_ConvFlipout): - """1D convolution layer (e.g. temporal convolution) with Flipout. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of a single integer, specifying the - length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - dilation_rate: An integer or tuple/list of a single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - seed: Python integer, used to create random seeds. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 128, 1]) - net = tfp.layers.Conv1DFlipout(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.reshape(net, [-1, 128 * 64]) - logits = tfp.layers.DenseFlipout(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - - def __init__( - self, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - **kwargs): - super(Conv1DFlipout, self).__init__( - rank=1, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, **kwargs) - - -def conv1d_flipout( - inputs, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - reuse=None): - """Functional interface for 1D convolution layer (e.g. temporal convolution). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of a single integer, specifying the - length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - dilation_rate: An integer or tuple/list of a single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 128, 1]) - net = tfp.layers.conv1d_flipout(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.reshape(net, [-1, 128 * 64]) - logits = tfp.layers.dense_flipout(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - layer = Conv1DFlipout( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class Conv2DFlipout(_ConvFlipout): - """2D convolution layer (e.g. spatial convolution over images) with Flipout. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - seed: Python integer, used to create random seeds. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 32, 32, 3]) - net = tfp.layers.Conv2DFlipout(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.layers.MaxPooling2D(pool_size=2, - strides=2, - padding="SAME")(net) - net = tf.reshape(net, [-1, 8 * 8 * 64]) - logits = tfp.layers.DenseFlipout(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - - def __init__( - self, - filters, - kernel_size, - strides=(1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - **kwargs): - super(Conv2DFlipout, self).__init__( - rank=2, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, **kwargs) - - -def conv2d_flipout( - inputs, - filters, - kernel_size, - strides=(1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - reuse=None): - """Functional interface for the 2D convolution layer. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 32, 32, 3]) - net = tfp.layers.conv2d_flipout(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.layers.max_pooling2d(net, - pool_size=2, - strides=2, - padding="SAME") - net = tf.reshape(net, [-1, 8 * 8 * 64]) - logits = tfp.layers.dense_flipout(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - layer = Conv2DFlipout( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class Conv3DFlipout(_ConvFlipout): - """3D convolution layer (e.g. spatial convolution over volumes) with Flipout. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, - height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - dilation_rate: An integer or tuple/list of 3 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - seed: Python integer, used to create random seeds. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 256, 32, 32, 3]) - net = tfp.layers.Conv3DFlipout(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.layers.MaxPooling2D(pool_size=2, - strides=2, - padding="SAME")(net) - net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) - logits = tfp.layers.DenseFlipout(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - - def __init__( - self, - filters, - kernel_size, - strides=(1, 1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - **kwargs): - super(Conv3DFlipout, self).__init__( - rank=3, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, **kwargs) - - -def conv3d_flipout( - inputs, - filters, - kernel_size, - strides=(1, 1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - reuse=None): - """Functional interface for the 3D convolution layer. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, - height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - dilation_rate: An integer or tuple/list of 3 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 256, 32, 32, 3]) - net = tfp.layers.conv3d_flipout(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.layers.max_pooling2d(net, - pool_size=2, - strides=2, - padding="SAME") - net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) - logits = tfp.layers.dense_flipout(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - layer = Conv3DFlipout( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -# Aliases - -Convolution1DReparameterization = Conv1DReparameterization -Convolution2DReparameterization = Conv2DReparameterization -Convolution3DReparameterization = Conv3DReparameterization -convolution1d_reparameterization = conv1d_reparameterization -convolution2d_reparameterization = conv2d_reparameterization -convolution3d_reparameterization = conv3d_reparameterization -Convolution1DFlipout = Conv1DFlipout -Convolution2DFlipout = Conv2DFlipout -Convolution3DFlipout = Conv3DFlipout -convolution1d_flipout = conv1d_flipout -convolution2d_flipout = conv2d_flipout -convolution3d_flipout = conv3d_flipout diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py deleted file mode 100644 index 591a8e553de0c194786c7ee8693665f762711b2d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py +++ /dev/null @@ -1,1176 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Dense Bayesian layer using KL-divergence based variational inference. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.bayesflow.python.ops import layers_util -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.layers import base as layers_lib -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import standard_ops -from tensorflow.python.ops.distributions import kullback_leibler as kl_lib -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.ops.distributions import util as distribution_util - - -class _DenseVariational(layers_lib.Layer): - """Abstract densely-connected class (private, used as implementation base). - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Properties: - units: Python integer, dimensionality of the output space. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - """ - - def __init__( - self, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(_DenseVariational, self).__init__( - trainable=trainable, - name=name, - activity_regularizer=activity_regularizer, - **kwargs) - self.units = units - self.activation = activation - self.input_spec = layers_lib.InputSpec(min_ndim=2) - self.kernel_posterior_fn = kernel_posterior_fn - self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn - self.kernel_prior_fn = kernel_prior_fn - self.kernel_divergence_fn = kernel_divergence_fn - self.bias_posterior_fn = bias_posterior_fn - self.bias_posterior_tensor_fn = bias_posterior_tensor_fn - self.bias_prior_fn = bias_prior_fn - self.bias_divergence_fn = bias_divergence_fn - - def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape) - in_size = input_shape.with_rank_at_least(2)[-1].value - if in_size is None: - raise ValueError("The last dimension of the inputs to `Dense` " - "should be defined. Found `None`.") - self._input_spec = layers_lib.InputSpec(min_ndim=2, axes={-1: in_size}) - dtype = dtypes.as_dtype(self.dtype) - - # Must have a posterior kernel. - self.kernel_posterior = self.kernel_posterior_fn( - dtype, [in_size, self.units], "kernel_posterior", - self.trainable, self.add_variable) - - if self.kernel_prior_fn is None: - self.kernel_prior = None - else: - self.kernel_prior = self.kernel_prior_fn( - dtype, [in_size, self.units], "kernel_prior", - self.trainable, self.add_variable) - self._built_kernel_divergence = False - - if self.bias_posterior_fn is None: - self.bias_posterior = None - else: - self.bias_posterior = self.bias_posterior_fn( - dtype, [self.units], "bias_posterior", - self.trainable, self.add_variable) - - if self.bias_prior_fn is None: - self.bias_prior = None - else: - self.bias_prior = self.bias_prior_fn( - dtype, [self.units], "bias_prior", - self.trainable, self.add_variable) - self._built_bias_divergence = False - - self.built = True - - def call(self, inputs): - inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) - - outputs = self._apply_variational_kernel(inputs) - outputs = self._apply_variational_bias(outputs) - if self.activation is not None: - outputs = self.activation(outputs) # pylint: disable=not-callable - if not self._built_kernel_divergence: - kernel_posterior = self.kernel_posterior - kernel_prior = self.kernel_prior - if isinstance(self.kernel_posterior, independent_lib.Independent): - kernel_posterior = kernel_posterior.distribution - if isinstance(self.kernel_prior, independent_lib.Independent): - kernel_prior = kernel_prior.distribution - self._apply_divergence(self.kernel_divergence_fn, - kernel_posterior, - kernel_prior, - self.kernel_posterior_tensor, - name="divergence_kernel") - self._built_kernel_divergence = True - if not self._built_bias_divergence: - bias_posterior = self.bias_posterior - bias_prior = self.bias_prior - if isinstance(self.bias_posterior, independent_lib.Independent): - bias_posterior = bias_posterior.distribution - if isinstance(self.bias_prior, independent_lib.Independent): - bias_prior = bias_prior.distribution - self._apply_divergence(self.bias_divergence_fn, - bias_posterior, - bias_prior, - self.bias_posterior_tensor, - name="divergence_bias") - self._built_bias_divergence = True - return outputs - - def _apply_variational_bias(self, inputs): - if self.bias_posterior is None: - self.bias_posterior_tensor = None - return inputs - self.bias_posterior_tensor = self.bias_posterior_tensor_fn( - self.bias_posterior) - return nn.bias_add(inputs, self.bias_posterior_tensor) - - def _apply_divergence(self, divergence_fn, posterior, prior, - posterior_tensor, name): - if (divergence_fn is None or - posterior is None or - prior is None): - divergence = None - return - divergence = standard_ops.identity( - divergence_fn( - posterior, prior, posterior_tensor), - name=name) - self.add_loss(divergence) - - def _matmul(self, inputs, kernel): - if inputs.shape.ndims <= 2: - return standard_ops.matmul(inputs, kernel) - # To handle broadcasting, we must use `tensordot`. - return standard_ops.tensordot(inputs, kernel, axes=[[-1], [0]]) - - def _compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).with_rank_at_least(2) - if input_shape[-1].value is None: - raise ValueError( - "The innermost dimension of input_shape must be defined, " - "but saw: {}".format(input_shape)) - return input_shape[:-1].concatenate(self.units) - - -class DenseReparameterization(_DenseVariational): - """Densely-connected layer class with reparameterization estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the reparameterization estimator [1], which performs a Monte Carlo - approximation of the distribution integrating over the `kernel` and - `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Properties: - units: Python integer, dimensionality of the output space. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.DenseReparameterization( - 512, activation=tf.nn.relu)(features) - logits = tfp.layers.DenseReparameterization(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - - def __init__( - self, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn( - is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(DenseReparameterization, self).__init__( - units=units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - **kwargs) - - def _apply_variational_kernel(self, inputs): - self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn( - self.kernel_posterior) - self.kernel_posterior_affine = None - self.kernel_posterior_affine_tensor = None - return self._matmul(inputs, self.kernel_posterior_tensor) - - -def dense_reparameterization( - inputs, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - """Densely-connected layer with reparameterization estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the reparameterization estimator [1], which performs a Monte Carlo - approximation of the distribution integrating over the `kernel` and - `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - inputs: Tensor input. - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Returns: - output: `Tensor` representing a the affine transformed input under a random - draw from the surrogate posterior distribution. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.dense_reparameterization( - features, 512, activation=tf.nn.relu) - logits = tfp.layers.dense_reparameterization(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - layer = DenseReparameterization( - units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class DenseLocalReparameterization(_DenseVariational): - """Densely-connected layer class with local reparameterization estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the local reparameterization estimator [1], which performs a - Monte Carlo approximation of the distribution on the hidden units - induced by the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Properties: - units: Python integer, dimensionality of the output space. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.DenseLocalReparameterization( - 512, activation=tf.nn.relu)(features) - logits = tfp.layers.DenseLocalReparameterization(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses local reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Variational Dropout and the Local Reparameterization Trick." - Diederik P. Kingma, Tim Salimans, Max Welling. - Neural Information Processing Systems, 2015. - """ - - def __init__( - self, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn( - is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(DenseLocalReparameterization, self).__init__( - units=units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - **kwargs) - - def _apply_variational_kernel(self, inputs): - if (not isinstance(self.kernel_posterior, independent_lib.Independent) or - not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): - raise TypeError( - "`DenseLocalReparameterization` requires " - "`kernel_posterior_fn` produce an instance of " - "`tf.distributions.Independent(tf.distributions.Normal)` " - "(saw: \"{}\").".format(self.kernel_posterior.name)) - self.kernel_posterior_affine = normal_lib.Normal( - loc=self._matmul(inputs, self.kernel_posterior.distribution.loc), - scale=standard_ops.sqrt(self._matmul( - standard_ops.square(inputs), - standard_ops.square(self.kernel_posterior.distribution.scale)))) - self.kernel_posterior_affine_tensor = ( - self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) - self.kernel_posterior_tensor = None - return self.kernel_posterior_affine_tensor - - -def dense_local_reparameterization( - inputs, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn( - is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - """Densely-connected layer with local reparameterization estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the local reparameterization estimator [1], which performs a - Monte Carlo approximation of the distribution on the hidden units - induced by the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - inputs: Tensor input. - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Returns: - output: `Tensor` representing a the affine transformed input under a random - draw from the surrogate posterior distribution. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.dense_local_reparameterization( - features, 512, activation=tf.nn.relu) - logits = tfp.layers.dense_local_reparameterization(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses local reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Variational Dropout and the Local Reparameterization Trick." - Diederik P. Kingma, Tim Salimans, Max Welling. - Neural Information Processing Systems, 2015. - """ - layer = DenseLocalReparameterization( - units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class DenseFlipout(_DenseVariational): - """Densely-connected layer class with Flipout estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the Flipout estimator [1], which performs a Monte Carlo - approximation of the distribution integrating over the `kernel` and - `bias`. Flipout uses roughly twice as many floating point operations - as the reparameterization estimator but has the advantage of - significantly lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Properties: - units: Python integer, dimensionality of the output space. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - seed: Python integer, used to create random seeds. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.DenseFlipout( - 512, activation=tf.nn.relu)(features) - logits = tfp.layers.DenseFlipout(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - - def __init__( - self, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn( - is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - **kwargs): - super(DenseFlipout, self).__init__( - units=units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - **kwargs) - self.seed = seed - - def _apply_variational_kernel(self, inputs): - if (not isinstance(self.kernel_posterior, independent_lib.Independent) or - not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): - raise TypeError( - "`DenseFlipout` requires " - "`kernel_posterior_fn` produce an instance of " - "`tf.distributions.Independent(tf.distributions.Normal)` " - "(saw: \"{}\").".format(self.kernel_posterior.name)) - self.kernel_posterior_affine = normal_lib.Normal( - loc=array_ops.zeros_like(self.kernel_posterior.distribution.loc), - scale=self.kernel_posterior.distribution.scale) - self.kernel_posterior_affine_tensor = ( - self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) - self.kernel_posterior_tensor = None - - input_shape = array_ops.shape(inputs) - batch_shape = input_shape[:-1] - - sign_input = layers_util.random_sign( - input_shape, - dtype=inputs.dtype, - seed=self.seed) - sign_output = layers_util.random_sign( - array_ops.concat([batch_shape, - array_ops.expand_dims(self.units, 0)], 0), - dtype=inputs.dtype, - seed=distribution_util.gen_new_seed( - self.seed, salt="dense_flipout")) - perturbed_inputs = self._matmul( - inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output - - outputs = self._matmul(inputs, self.kernel_posterior.distribution.loc) - outputs += perturbed_inputs - return outputs - - -def dense_flipout( - inputs, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn( - is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - reuse=None): - """Densely-connected layer with Flipout estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the Flipout estimator [1], which performs a Monte Carlo - approximation of the distribution integrating over the `kernel` and - `bias`. Flipout uses roughly twice as many floating point operations - as the reparameterization estimator but has the advantage of - significantly lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - inputs: Tensor input. - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Returns: - output: `Tensor` representing a the affine transformed input under a random - draw from the surrogate posterior distribution. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.dense_flipout( - features, 512, activation=tf.nn.relu) - logits = tfp.layers.dense_flipout(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - layer = DenseFlipout( - units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_util.py b/tensorflow/contrib/bayesflow/python/ops/layers_util.py deleted file mode 100644 index 8c1fb203f7328e8260e49b4326d813fbe133613e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/layers_util.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities for probabilistic layers. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import normal as normal_lib - - -def default_loc_scale_fn( - is_singular=False, - loc_initializer=init_ops.random_normal_initializer(stddev=0.1), - untransformed_scale_initializer=init_ops.random_normal_initializer( - mean=-3., stddev=0.1), - loc_regularizer=None, - untransformed_scale_regularizer=None, - loc_constraint=None, - untransformed_scale_constraint=None): - """Makes closure which creates `loc`, `scale` params from `tf.get_variable`. - - This function produces a closure which produces `loc`, `scale` using - `tf.get_variable`. The closure accepts the following arguments: - - dtype: Type of parameter's event. - shape: Python `list`-like representing the parameter's event shape. - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Args: - is_singular: Python `bool` indicating if `scale is None`. Default: `False`. - loc_initializer: Initializer function for the `loc` parameters. - The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`. - untransformed_scale_initializer: Initializer function for the `scale` - parameters. Default value: `tf.random_normal_initializer(mean=-3., - stddev=0.1)`. This implies the softplus transformed result has mean - approximately `0.05` and std. deviation approximately `0.005`. - loc_regularizer: Regularizer function for the `loc` parameters. - The default (`None`) is to use the `tf.get_variable` default. - untransformed_scale_regularizer: Regularizer function for the `scale` - parameters. The default (`None`) is to use the `tf.get_variable` default. - loc_constraint: An optional projection function to be applied to the - loc after being updated by an `Optimizer`. The function must take as input - the unprojected variable and must return the projected variable (which - must have the same shape). Constraints are not safe to use when doing - asynchronous distributed training. - The default (`None`) is to use the `tf.get_variable` default. - untransformed_scale_constraint: An optional projection function to be - applied to the `scale` parameters after being updated by an `Optimizer` - (e.g. used to implement norm constraints or value constraints). The - function must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are not - safe to use when doing asynchronous distributed training. The default - (`None`) is to use the `tf.get_variable` default. - - Returns: - default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale` - parameters from args: `dtype, shape, name, trainable, add_variable_fn`. - """ - def _fn(dtype, shape, name, trainable, add_variable_fn): - """Creates `loc`, `scale` parameters.""" - loc = add_variable_fn( - name=name + "_loc", - shape=shape, - initializer=loc_initializer, - regularizer=loc_regularizer, - constraint=loc_constraint, - dtype=dtype, - trainable=trainable) - if is_singular: - return loc, None - untransformed_scale = add_variable_fn( - name=name + "_untransformed_scale", - shape=shape, - initializer=untransformed_scale_initializer, - regularizer=untransformed_scale_regularizer, - constraint=untransformed_scale_constraint, - dtype=dtype, - trainable=trainable) - scale = (np.finfo(dtype.as_numpy_dtype).eps + - nn_ops.softplus(untransformed_scale)) - return loc, scale - return _fn - - -def default_mean_field_normal_fn( - is_singular=False, - loc_initializer=None, - untransformed_scale_initializer=None, - loc_regularizer=None, - untransformed_scale_regularizer=None, - loc_constraint=None, - untransformed_scale_constraint=None): - """Creates a function to build Normal distributions with trainable params. - - This function produces a closure which produces `tf.distributions.Normal` - parameterized by a loc` and `scale` each created using `tf.get_variable`. The - produced closure accepts the following arguments: - - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - shape: Python `list`-like representing the parameter's event shape. - dtype: Type of parameter's event. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Args: - is_singular: Python `bool` if `True`, forces the special case limit of - `scale->0`, i.e., a `Deterministic` distribution. - loc_initializer: Initializer function for the `loc` parameters. - If `None` (default), values are initialized using the default - initializer used by `tf.get_variable`. - untransformed_scale_initializer: Initializer function for the `scale` - parameters. If `None` (default), values are initialized using the default - initializer used by `tf.get_variable`. - loc_regularizer: Regularizer function for the `loc` parameters. - untransformed_scale_regularizer: Regularizer function for the `scale` - parameters. - loc_constraint: An optional projection function to be applied to the - loc after being updated by an `Optimizer`. The function must take as input - the unprojected variable and must return the projected variable (which - must have the same shape). Constraints are not safe to use when doing - asynchronous distributed training. - untransformed_scale_constraint: An optional projection function to be - applied to the `scale` parameters after being updated by an `Optimizer` - (e.g. used to implement norm constraints or value constraints). The - function must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are not - safe to use when doing asynchronous distributed training. - - Returns: - make_normal_fn: Python `callable` which creates a `tf.distributions.Normal` - using from args: `dtype, shape, name, trainable, add_variable_fn`. - """ - loc_scale_fn_ = default_loc_scale_fn( - is_singular, - loc_initializer, - untransformed_scale_initializer, - loc_regularizer, - untransformed_scale_regularizer, - loc_constraint, - untransformed_scale_constraint) - def _fn(dtype, shape, name, trainable, add_variable_fn): - """Creates multivariate `Deterministic` or `Normal` distribution.""" - loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn) - if scale is None: - dist = deterministic_lib.Deterministic(loc=loc) - else: - dist = normal_lib.Normal(loc=loc, scale=scale) - reinterpreted_batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0] - return independent_lib.Independent( - dist, reinterpreted_batch_ndims=reinterpreted_batch_ndims) - return _fn - - -def random_sign(shape, dtype=dtypes.float32, seed=None): - """Draw values from {-1, 1} uniformly, i.e., Rademacher distribution.""" - random_bernoulli = random_ops.random_uniform(shape, minval=0, maxval=2, - dtype=dtypes.int32, - seed=seed) - return math_ops.cast(2 * random_bernoulli - 1, dtype) diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py deleted file mode 100644 index 0424b6952bc89ce7fe5b00b0135c9a5fe1faa8cf..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities for Markov Chain Monte Carlo (MCMC) sampling. - -@@effective_sample_size -@@potential_scale_reduction -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops import sample_stats -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops - -__all__ = [ - "effective_sample_size", - "potential_scale_reduction", -] - - -def effective_sample_size(states, - filter_threshold=0., - filter_beyond_lag=None, - name=None): - """Estimate a lower bound on effective sample size for each independent chain. - - Roughly speaking, "effective sample size" (ESS) is the size of an iid sample - with the same variance as `state`. - - More precisely, given a stationary sequence of possibly correlated random - variables `X_1, X_2,...,X_N`, each identically distributed ESS is the number - such that - - ```Variance{ N**-1 * Sum{X_i} } = ESS**-1 * Variance{ X_1 }.``` - - If the sequence is uncorrelated, `ESS = N`. In general, one should expect - `ESS <= N`, with more highly correlated sequences having smaller `ESS`. - - #### Example of using ESS to estimate standard error. - - ``` - tfd = tf.contrib.distributions - tfb = tf.contrib.bayesflow - - target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.]) - - # Get 1000 states from one chain. - states = tfb.hmc.sample_chain( - num_results=1000, - target_log_prob_fn=target.log_prob, - current_state=tf.constant([0., 0.]), - step_size=0.05, - num_leapfrog_steps=20, - num_burnin_steps=200) - states.shape - ==> (1000, 2) - - ess = effective_sample_size(states) - ==> Shape (2,) Tensor - - mean, variance = tf.nn.moments(states, axis=0) - standard_error = tf.sqrt(variance / ess) - ``` - - Some math shows that, with `R_k` the auto-correlation sequence, - `R_k := Covariance{X_1, X_{1+k}} / Variance{X_1}`, we have - - ```ESS(N) = N / [ 1 + 2 * ( (N - 1) / N * R_1 + ... + 1 / N * R_{N-1} ) ]``` - - This function estimates the above by first estimating the auto-correlation. - Since `R_k` must be estimated using only `N - k` samples, it becomes - progressively noisier for larger `k`. For this reason, the summation over - `R_k` should be truncated at some number `filter_beyond_lag < N`. Since many - MCMC methods generate chains where `R_k > 0`, a reasonable critera is to - truncate at the first index where the estimated auto-correlation becomes - negative. - - The arguments `filter_beyond_lag`, `filter_threshold` are filters intended to - remove noisy tail terms from `R_k`. They combine in an "OR" manner meaning - terms are removed if they were to be filtered under the `filter_beyond_lag` OR - `filter_threshold` criteria. - - Args: - states: `Tensor` or list of `Tensor` objects. Dimension zero should index - identically distributed states. - filter_threshold: `Tensor` or list of `Tensor` objects. - Must broadcast with `state`. The auto-correlation sequence is truncated - after the first appearance of a term less than `filter_threshold`. - Setting to `None` means we use no threshold filter. Since `|R_k| <= 1`, - setting to any number less than `-1` has the same effect. - filter_beyond_lag: `Tensor` or list of `Tensor` objects. Must be - `int`-like and scalar valued. The auto-correlation sequence is truncated - to this length. Setting to `None` means we do not filter based on number - of lags. - name: `String` name to prepend to created ops. - - Returns: - ess: `Tensor` or list of `Tensor` objects. The effective sample size of - each component of `states`. Shape will be `states.shape[1:]`. - - Raises: - ValueError: If `states` and `filter_threshold` or `states` and - `filter_beyond_lag` are both lists with different lengths. - """ - states_was_list = _is_list_like(states) - - # Convert all args to lists. - if not states_was_list: - states = [states] - - filter_beyond_lag = _broadcast_maybelist_arg(states, filter_beyond_lag, - "filter_beyond_lag") - filter_threshold = _broadcast_maybelist_arg(states, filter_threshold, - "filter_threshold") - - # Process items, one at a time. - with ops.name_scope(name, "effective_sample_size"): - ess_list = [ - _effective_sample_size_single_state(s, ml, mlt) - for (s, ml, mlt) in zip(states, filter_beyond_lag, filter_threshold) - ] - - if states_was_list: - return ess_list - return ess_list[0] - - -def _effective_sample_size_single_state(states, filter_beyond_lag, - filter_threshold): - """ESS computation for one single Tensor argument.""" - - with ops.name_scope( - "effective_sample_size_single_state", - values=[states, filter_beyond_lag, filter_threshold]): - - states = ops.convert_to_tensor(states, name="states") - dt = states.dtype - - # filter_beyond_lag == None ==> auto_corr is the full sequence. - auto_corr = sample_stats.auto_correlation( - states, axis=0, max_lags=filter_beyond_lag) - if filter_threshold is not None: - filter_threshold = ops.convert_to_tensor( - filter_threshold, dtype=dt, name="filter_threshold") - # Get a binary mask to zero out values of auto_corr below the threshold. - # mask[i, ...] = 1 if auto_corr[j, ...] > threshold for all j <= i, - # mask[i, ...] = 0, otherwise. - # So, along dimension zero, the mask will look like [1, 1, ..., 0, 0,...] - # Building step by step, - # Assume auto_corr = [1, 0.5, 0.0, 0.3], and filter_threshold = 0.2. - # Step 1: mask = [False, False, True, False] - mask = auto_corr < filter_threshold - # Step 2: mask = [0, 0, 1, 1] - mask = math_ops.cast(mask, dtype=dt) - # Step 3: mask = [0, 0, 1, 2] - mask = math_ops.cumsum(mask, axis=0) - # Step 4: mask = [1, 1, 0, 0] - mask = math_ops.maximum(1. - mask, 0.) - auto_corr *= mask - - # With R[k] := auto_corr[k, ...], - # ESS = N / {1 + 2 * Sum_{k=1}^N (N - k) / N * R[k]} - # = N / {-1 + 2 * Sum_{k=0}^N (N - k) / N * R[k]} (since R[0] = 1) - # approx N / {-1 + 2 * Sum_{k=0}^M (N - k) / N * R[k]} - # where M is the filter_beyond_lag truncation point chosen above. - - # Get the factor (N - k) / N, and give it shape [M, 1,...,1], having total - # ndims the same as auto_corr - n = _axis_size(states, axis=0) - k = math_ops.range(0., _axis_size(auto_corr, axis=0)) - nk_factor = (n - k) / n - if auto_corr.shape.ndims is not None: - new_shape = [-1] + [1] * (auto_corr.shape.ndims - 1) - else: - new_shape = array_ops.concat( - ([-1], - array_ops.ones([array_ops.rank(auto_corr) - 1], dtype=dtypes.int32)), - axis=0) - nk_factor = array_ops.reshape(nk_factor, new_shape) - - return n / (-1 + 2 * math_ops.reduce_sum(nk_factor * auto_corr, axis=0)) - - -def potential_scale_reduction(chains_states, - independent_chain_ndims=1, - name=None): - """Gelman and Rubin's potential scale reduction factor for chain convergence. - - Given `N > 1` states from each of `C > 1` independent chains, the potential - scale reduction factor, commonly referred to as R-hat, measures convergence of - the chains (to the same target) by testing for equality of means. - Specifically, R-hat measures the degree to which variance (of the means) - between chains exceeds what one would expect if the chains were identically - distributed. See [1], [2]. - - Some guidelines: - - * The initial state of the chains should be drawn from a distribution - overdispersed with respect to the target. - * If all chains converge to the target, then as `N --> infinity`, R-hat --> 1. - Before that, R-hat > 1 (except in pathological cases, e.g. if the chain - paths were identical). - * The above holds for any number of chains `C > 1`. Increasing `C` does - improves effectiveness of the diagnostic. - * Sometimes, R-hat < 1.2 is used to indicate approximate convergence, but of - course this is problem depedendent. See [2]. - * R-hat only measures non-convergence of the mean. If higher moments, or other - statistics are desired, a different diagnostic should be used. See [2]. - - #### Examples - - Diagnosing convergence by monitoring 10 chains that each attempt to - sample from a 2-variate normal. - - ```python - tfd = tf.contrib.distributions - tfb = tf.contrib.bayesflow - - target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.]) - - # Get 10 (2x) overdispersed initial states. - initial_state = target.sample(10) * 2. - ==> (10, 2) - - # Get 1000 samples from the 10 independent chains. - chains_states, _ = tfb.hmc.sample_chain( - num_results=1000, - target_log_prob_fn=target.log_prob, - current_state=initial_state, - step_size=0.05, - num_leapfrog_steps=20, - num_burnin_steps=200) - chains_states.shape - ==> (1000, 10, 2) - - rhat = tfb.mcmc_diagnostics.potential_scale_reduction( - chains_states, independent_chain_ndims=1) - - # The second dimension needed a longer burn-in. - rhat.eval() - ==> [1.05, 1.3] - ``` - - To see why R-hat is reasonable, let `X` be a random variable drawn uniformly - from the combined states (combined over all chains). Then, in the limit - `N, C --> infinity`, with `E`, `Var` denoting expectation and variance, - - ```R-hat = ( E[Var[X | chain]] + Var[E[X | chain]] ) / E[Var[X | chain]].``` - - Using the law of total variance, the numerator is the variance of the combined - states, and the denominator is the total variance minus the variance of the - the individual chain means. If the chains are all drawing from the same - distribution, they will have the same mean, and thus the ratio should be one. - - [1] "Inference from Iterative Simulation Using Multiple Sequences" - Andrew Gelman and Donald B. Rubin - Statist. Sci. Volume 7, Number 4 (1992), 457-472. - [2] "General Methods for Monitoring Convergence of Iterative Simulations" - Stephen P. Brooks and Andrew Gelman - Journal of Computational and Graphical Statistics, 1998. Vol 7, No. 4. - - Args: - chains_states: `Tensor` or Python `list` of `Tensor`s representing the - state(s) of a Markov Chain at each result step. The `ith` state is - assumed to have shape `[Ni, Ci1, Ci2,...,CiD] + A`. - Dimension `0` indexes the `Ni > 1` result steps of the Markov Chain. - Dimensions `1` through `D` index the `Ci1 x ... x CiD` independent - chains to be tested for convergence to the same target. - The remaining dimensions, `A`, can have any shape (even empty). - independent_chain_ndims: Integer type `Tensor` with value `>= 1` giving the - number of giving the number of dimensions, from `dim = 1` to `dim = D`, - holding independent chain results to be tested for convergence. - name: `String` name to prepend to created ops. Default: - `potential_scale_reduction`. - - Returns: - `Tensor` or Python `list` of `Tensor`s representing the R-hat statistic for - the state(s). Same `dtype` as `state`, and shape equal to - `state.shape[1 + independent_chain_ndims:]`. - - Raises: - ValueError: If `independent_chain_ndims < 1`. - """ - chains_states_was_list = _is_list_like(chains_states) - if not chains_states_was_list: - chains_states = [chains_states] - - # tensor_util.constant_value returns None iff a constant value (as a numpy - # array) is not efficiently computable. Therefore, we try constant_value then - # check for None. - icn_const_ = tensor_util.constant_value( - ops.convert_to_tensor(independent_chain_ndims)) - if icn_const_ is not None: - independent_chain_ndims = icn_const_ - if icn_const_ < 1: - raise ValueError( - "Argument `independent_chain_ndims` must be `>= 1`, found: {}".format( - independent_chain_ndims)) - - with ops.name_scope(name, "potential_scale_reduction"): - rhat_list = [ - _potential_scale_reduction_single_state(s, independent_chain_ndims) - for s in chains_states - ] - - if chains_states_was_list: - return rhat_list - return rhat_list[0] - - -def _potential_scale_reduction_single_state(state, independent_chain_ndims): - """potential_scale_reduction for one single state `Tensor`.""" - with ops.name_scope( - "potential_scale_reduction_single_state", - values=[state, independent_chain_ndims]): - # We assume exactly one leading dimension indexes e.g. correlated samples - # from each Markov chain. - state = ops.convert_to_tensor(state, name="state") - sample_ndims = 1 - - sample_axis = math_ops.range(0, sample_ndims) - chain_axis = math_ops.range(sample_ndims, - sample_ndims + independent_chain_ndims) - sample_and_chain_axis = math_ops.range( - 0, sample_ndims + independent_chain_ndims) - - n = _axis_size(state, sample_axis) - m = _axis_size(state, chain_axis) - - # In the language of [2], - # B / n is the between chain variance, the variance of the chain means. - # W is the within sequence variance, the mean of the chain variances. - b_div_n = _reduce_variance( - math_ops.reduce_mean(state, sample_axis, keepdims=True), - sample_and_chain_axis, - biased=False) - w = math_ops.reduce_mean( - _reduce_variance(state, sample_axis, keepdims=True, biased=True), - sample_and_chain_axis) - - # sigma^2_+ is an estimate of the true variance, which would be unbiased if - # each chain was drawn from the target. c.f. "law of total variance." - sigma_2_plus = w + b_div_n - - return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n) - - -# TODO(b/72873233) Move some variant of this to sample_stats. -def _reduce_variance(x, axis=None, biased=True, keepdims=False): - with ops.name_scope("reduce_variance"): - x = ops.convert_to_tensor(x, name="x") - mean = math_ops.reduce_mean(x, axis=axis, keepdims=True) - biased_var = math_ops.reduce_mean( - math_ops.squared_difference(x, mean), axis=axis, keepdims=keepdims) - if biased: - return biased_var - n = _axis_size(x, axis) - return (n / (n - 1.)) * biased_var - - -def _axis_size(x, axis=None): - """Get number of elements of `x` in `axis`, as type `x.dtype`.""" - if axis is None: - return math_ops.cast(array_ops.size(x), x.dtype) - return math_ops.cast( - math_ops.reduce_prod(array_ops.gather(array_ops.shape(x), axis)), x.dtype) - - -def _is_list_like(x): - """Helper which returns `True` if input is `list`-like.""" - return isinstance(x, (tuple, list)) - - -def _broadcast_maybelist_arg(states, secondary_arg, name): - """Broadcast a listable secondary_arg to that of states.""" - if _is_list_like(secondary_arg): - if len(secondary_arg) != len(states): - raise ValueError("Argument `%s` was a list of different length ({}) than " - "`states` ({})".format(name, len(states))) - else: - secondary_arg = [secondary_arg] * len(states) - - return secondary_arg diff --git a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py deleted file mode 100644 index dc1ac68ce009fa46d6c05a3200a29d9fdf245707..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py +++ /dev/null @@ -1,426 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functions to create a Markov Chain Monte Carlo Metropolis step. - -@@evolve -@@uniform_random_proposal -@@normal_random_proposal -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import state_ops - -__all__ = [ - 'evolve', - 'uniform_random_proposal', - 'normal_random_proposal', -] - - -def _single_iteration(current_state, current_log_density, - log_unnormalized_prob_fn, proposal_fn, seed=None, - name='None'): - """Performs a single Metropolis-Hastings step. - - Args: - current_state: Float-like `Tensor` (i.e., `dtype` is either - `tf.float16`, `tf.float32` or `tf.float64`) of any shape that can - be consumed by the `log_unnormalized_prob_fn` and `proposal_fn` - callables. - current_log_density: Float-like `Tensor` with `dtype` and shape equivalent - to `log_unnormalized_prob_fn(current_state)`, i.e., matching the result of - `log_unnormalized_prob_fn` invoked at `current_state`. - log_unnormalized_prob_fn: A Python callable evaluated at - `current_state` and returning a float-like `Tensor` of log target-density - up to a normalizing constant. In other words, - `log_unnormalized_prob_fn(x) = log(g(x))`, where - `target_density = g(x)/Z` for some constant `A`. The shape of the input - tensor is the same as the shape of the `current_state`. The shape of the - output tensor is either - (a). Same as the input shape if the density being sampled is one - dimensional, or - (b). If the density is defined for `events` of shape - `event_shape = [E1, E2, ... Ee]`, then the input tensor should be of - shape `batch_shape + event_shape`, where `batch_shape = [B1, ..., Bb]` - and the result must be of shape [B1, ..., Bb]. For example, if the - distribution that is being sampled is a 10 dimensional normal, - then the input tensor may be of shape [100, 10] or [30, 20, 10]. The - last dimension will then be 'consumed' by `log_unnormalized_prob_fn` - and it should return tensors of shape [100] and [30, 20] respectively. - proposal_fn: A callable accepting a real valued `Tensor` of current sample - points and returning a tuple of two `Tensors`. The first element of the - pair is a `Tensor` containing the proposal state and should have - the same shape as the input `Tensor`. The second element of the pair gives - the log of the ratio of the probability of transitioning from the - proposal points to the input points and the probability of transitioning - from the input points to the proposal points. If the proposal is - symmetric (e.g., random walk, where the proposal is either - normal or uniform centered at `current_state`), i.e., - Probability(Proposal -> Current) = Probability(Current -> Proposal) - the second value should be set to `None` instead of explicitly supplying a - tensor of zeros. In addition to being convenient, this also leads to a - more efficient graph. - seed: `int` or None. The random seed for this `Op`. If `None`, no seed is - applied. - name: Python `str` name prefix for ops managed by this function. - - Returns: - next_state: `Tensor` with `dtype` and shape matching `current_state`. - Created by propagating the chain by one step, starting from - `current_state`. - next_log_density: `Tensor` with `dtype` and shape matching - `current_log_density`, which is equal to the value of the unnormalized - `log_unnormalized_prob_fn` computed at `next_state`. - log_accept_ratio: `Tensor` with `dtype` and shape matching - `current_log_density`. Stands for the log of Metropolis-Hastings - acceptance ratio used in generating the `next_state`. - """ - - with ops.name_scope(name, 'single_iteration', [current_state]): - # The proposed state and the log of the corresponding Hastings ratio. - proposal_state, log_transit_ratio = proposal_fn(current_state) - - # If the log ratio is None, assume that the transitions are symmetric, - # i.e., Prob(Current -> Proposed) = Prob(Proposed -> Current). - if log_transit_ratio is None: - log_transit_ratio = 0. - - # Log-density of the proposal state. - proposal_log_density = log_unnormalized_prob_fn(proposal_state) - - # Ops to compute the log of the acceptance ratio. Recall that the - # acceptance ratio is: [Prob(Proposed) / Prob(Current)] * - # [Prob(Proposed -> Current) / Prob(Current -> Proposed)]. The log of the - # second term is the log_transit_ratio. - with ops.name_scope('accept_reject'): - # The log of the acceptance ratio. - log_accept_ratio = (proposal_log_density - current_log_density - + log_transit_ratio) - - # A proposal is accepted or rejected depending on the acceptance ratio. - # If the acceptance ratio is greater than 1 then it is always accepted. - # If the acceptance ratio is less than 1 then the proposal is accepted - # with probability = acceptance ratio. As we are working in log space to - # prevent over/underflows, this logic is expressed in log terms below. - # If a proposal is accepted we place a True in the acceptance state - # tensor and if it is to be rejected we place a False. - # The log_draws below have to be compared to the log_accept_ratio so we - # make sure that they have the same data type. - log_draws = math_ops.log(random_ops.random_uniform( - array_ops.shape(current_log_density), seed=seed, - dtype=log_accept_ratio.dtype)) - is_proposal_accepted = log_draws < log_accept_ratio - - # The acceptance state decides which elements of the current state are to - # be replaced with the corresponding elements in the proposal state. - with ops.name_scope(name, 'metropolis_single_step', - [current_state, current_log_density]): - next_log_density = array_ops.where(is_proposal_accepted, - proposal_log_density, - current_log_density) - next_state = array_ops.where(is_proposal_accepted, proposal_state, - current_state) - - return next_state, next_log_density, log_accept_ratio - - -def evolve(initial_sample, - initial_log_density, - initial_log_accept_ratio, - log_unnormalized_prob_fn, - proposal_fn, - n_steps=1, - seed=None, - name=None): - """Performs `n_steps` of the Metropolis-Hastings update. - - Given a probability density function, `f(x)` and a proposal scheme which - generates new points from old, this `Op` returns a tensor - which may be used to generate approximate samples from the target distribution - using the Metropolis-Hastings algorithm. These samples are from a Markov chain - whose equilibrium distribution matches the target distribution. - - The probability distribution may have an unknown normalization constan. - We parameterize the probability density as follows: - ``` - f(x) = exp(L(x) + constant) - ``` - Here `L(x)` is any continuous function with an (possibly unknown but finite) - upper bound, i.e. there exists a number beta such that - `L(x)< beta < infinity` for all x. The constant is the normalization needed - to make `f(x)` a probability density (as opposed to just a finite measure). - - Although `initial_sample` can be arbitrary, a poor choice may result in a - slow-to-mix chain. In many cases the best choice is the one that maximizes - the target density, i.e., choose `initial_sample` such that - `f(initial_sample) >= f(x)` for all `x`. - - - If the support of the distribution is a strict subset of R^n (but of non zero - measure), then the unnormalized log-density `L(x)` should return `-infinity` - outside the support domain. This effectively forces the sampler to only - explore points in the regions of finite support. - - Usage: - This function is meant to be wrapped up with some of the common proposal - schemes (e.g. random walk, Langevin diffusion etc) to produce a more user - friendly interface. However, it may also be used to create bespoke samplers. - - The following example, demonstrates the use to generate a 1000 uniform random - walk Metropolis samplers run in parallel for the normal target distribution. - ```python - n = 3 # dimension of the problem - - # Generate 1000 initial values randomly. Each of these would be an - # independent starting point for a Markov chain. - state = tf.get_variable( - 'state',initializer=tf.random_normal([1000, n], mean=3.0, - dtype=tf.float64, seed=42)) - - # Computes the log(p(x)) for the unit normal density and ignores the - # normalization constant. - def log_density(x): - return - tf.reduce_sum(x * x, reduction_indices=-1) / 2.0 - - # Initial log-density value - state_log_density = tf.get_variable( - 'state_log_density', initializer=log_density(state.initialized_value())) - - # A variable to store the log_acceptance_ratio: - log_acceptance_ratio = tf.get_variable( - 'log_acceptance_ratio', initializer=tf.zeros([1000], dtype=tf.float64)) - - # Generates random proposals by moving each coordinate uniformly and - # independently in a box of size 2 centered around the current value. - # Returns the new point and also the log of the Hastings ratio (the - # ratio of the probability of going from the proposal to origin and the - # probability of the reverse transition). When this ratio is 1, the value - # may be omitted and replaced by None. - def random_proposal(x): - return (x + tf.random_uniform(tf.shape(x), minval=-1, maxval=1, - dtype=x.dtype, seed=12)), None - - # Create the op to propagate the chain for 100 steps. - stepper = mh.evolve( - state, state_log_density, log_acceptance_ratio, - log_density, random_proposal, n_steps=100, seed=123) - init = tf.initialize_all_variables() - with tf.Session() as sess: - sess.run(init) - # Run the chains for a total of 1000 steps and print out the mean across - # the chains every 100 iterations. - for n_iter in range(10): - # Executing the stepper advances the chain to the next state. - sess.run(stepper) - # Print out the current value of the mean(sample) for every dimension. - print(np.mean(sess.run(state), 0)) - # Estimated covariance matrix - samples = sess.run(state) - print('') - print(np.cov(samples, rowvar=False)) - ``` - - Args: - initial_sample: A float-like `tf.Variable` of any shape that can - be consumed by the `log_unnormalized_prob_fn` and `proposal_fn` - callables. - initial_log_density: Float-like `tf.Variable` with `dtype` and shape - equivalent to `log_unnormalized_prob_fn(initial_sample)`, i.e., matching - the result of `log_unnormalized_prob_fn` invoked at `current_state`. - initial_log_accept_ratio: A `tf.Variable` with `dtype` and shape matching - `initial_log_density`. Stands for the log of Metropolis-Hastings - acceptance ratio after propagating the chain for `n_steps`. - log_unnormalized_prob_fn: A Python callable evaluated at - `current_state` and returning a float-like `Tensor` of log target-density - up to a normalizing constant. In other words, - `log_unnormalized_prob_fn(x) = log(g(x))`, where - `target_density = g(x)/Z` for some constant `A`. The shape of the input - tensor is the same as the shape of the `current_state`. The shape of the - output tensor is either - (a). Same as the input shape if the density being sampled is one - dimensional, or - (b). If the density is defined for `events` of shape - `event_shape = [E1, E2, ... Ee]`, then the input tensor should be of - shape `batch_shape + event_shape`, here `batch_shape = [B1, ..., Bb]` - and the result must be of shape [B1, ..., Bb]. For example, if the - distribution that is being sampled is a 10 dimensional normal, - then the input tensor may be of shape [100, 10] or [30, 20, 10]. The - last dimension will then be 'consumed' by `log_unnormalized_prob_fn` - and it should return tensors of shape [100] and [30, 20] respectively. - proposal_fn: A callable accepting a real valued `Tensor` of current sample - points and returning a tuple of two `Tensors`. The first element of the - pair should be a `Tensor` containing the proposal state and should have - the same shape as the input `Tensor`. The second element of the pair gives - the log of the ratio of the probability of transitioning from the - proposal points to the input points and the probability of transitioning - from the input points to the proposal points. If the proposal is - symmetric, i.e. - Probability(Proposal -> Current) = Probability(Current -> Proposal) - the second value should be set to None instead of explicitly supplying a - tensor of zeros. In addition to being convenient, this also leads to a - more efficient graph. - n_steps: A positive `int` or a scalar `int32` tensor. Sets the number of - iterations of the chain. - seed: `int` or None. The random seed for this `Op`. If `None`, no seed is - applied. - name: A string that sets the name for this `Op`. - - Returns: - forward_step: an `Op` to step the Markov chain forward for `n_steps`. - """ - - with ops.name_scope(name, 'metropolis_hastings', [initial_sample]): - current_state = initial_sample - current_log_density = initial_log_density - log_accept_ratio = initial_log_accept_ratio - - # Stop condition for the while_loop - def stop_condition(i, _): - return i < n_steps - - def step(i, loop_vars): - """Wrap `_single_iteration` for `while_loop`.""" - state = loop_vars[0] - state_log_density = loop_vars[1] - return i + 1, list(_single_iteration(state, state_log_density, - log_unnormalized_prob_fn, - proposal_fn, seed=seed)) - - loop_vars = [current_state, current_log_density, log_accept_ratio] - # Build an `Op` to evolve the Markov chain for `n_steps` - (_, [end_state, end_log_density, end_log_acceptance]) = ( - control_flow_ops.while_loop( - stop_condition, step, - (0, loop_vars), - parallel_iterations=1, swap_memory=1)) - - forward_step = control_flow_ops.group( - state_ops.assign(current_log_density, end_log_density), - state_ops.assign(current_state, end_state), - state_ops.assign(log_accept_ratio, end_log_acceptance)) - - return forward_step - - -def uniform_random_proposal(step_size=1., - seed=None, - name=None): - """Returns a callable that adds a random uniform tensor to the input. - - This function returns a callable that accepts one `Tensor` argument of any - shape and a real data type (i.e. `tf.float32` or `tf.float64`). It adds a - sample from a random uniform distribution drawn from [-stepsize, stepsize] - to its input. It also returns the log of the ratio of the probability of - moving from the input point to the proposed point, but since this log ratio is - identically equal to 0 (because the probability of drawing a value `x` from - the symmetric uniform distribution is the same as the probability of drawing - `-x`), it simply returns None for the second element of the returned tuple. - - Args: - step_size: A positive `float` or a scalar tensor of real dtype - controlling the scale of the uniform distribution. - If step_size = a, then draws are made uniformly from [-a, a]. - seed: `int` or None. The random seed for this `Op`. If `None`, no seed is - applied. - name: A string that sets the name for this `Op`. - - Returns: - proposal_fn: A callable accepting one float-like `Tensor` and returning a - 2-tuple. The first value in the tuple is a `Tensor` of the same shape and - dtype as the input argument and the second element of the tuple is None. - """ - - with ops.name_scope(name, 'uniform_random_proposal', [step_size]): - def proposal_fn(input_state, name=None): - """Adds a uniform perturbation to the input state. - - Args: - input_state: A `Tensor` of any shape and real dtype. - name: A string that sets the name for this `Op`. - - Returns: - proposal_state: A float-like `Tensot` with `dtype` and shape matching - `input_state`. - log_transit_ratio: `None`. Proposal is symmetric. - """ - with ops.name_scope(name, 'proposer', [input_state]): - input_state = ops.convert_to_tensor(input_state, name='input_state') - return input_state + random_ops.random_uniform( - array_ops.shape(input_state), - minval=-step_size, - maxval=step_size, - seed=seed), None - return proposal_fn - - -def normal_random_proposal(scale=1., - seed=None, - name=None): - """Returns a callable that adds a random normal tensor to the input. - - This function returns a callable that accepts one `Tensor` argument of any - shape and a real data type (i.e. `tf.float32` or `tf.float64`). The callable - adds a sample from a normal distribution with the supplied standard deviation - and zero mean to its input argument (called the proposal point). - The callable returns a tuple with the proposal point as the first element. - The second element is identically `None`. It is included so the callable is - compatible with the expected signature of the proposal scheme argument in the - `metropolis_hastings` function. A value of `None` indicates that the - probability of going from the input point to the proposal point is equal to - the probability of going from the proposal point to the input point. - - Args: - scale: A positive `float` or a scalar tensor of any real dtype controlling - the scale of the normal distribution. - seed: `int` or None. The random seed for this `Op`. If `None`, no seed is - applied. - name: A string that sets the name for this `Op`. - - Returns: - proposal_fn: A callable accepting one float-like `Tensor` and returning a - 2-tuple. The first value in the tuple is a `Tensor` of the same shape and - dtype as the input argument and the second element of the tuple is None. - """ - - with ops.name_scope(name, 'normal_random_proposal', [scale]): - def proposal_fn(input_state, name=None): - """Adds a normal perturbation to the input state. - - Args: - input_state: A `Tensor` of any shape and real dtype. - name: A string that sets the name for this `Op`. - - Returns: - proposal_state: A float-like `Tensot` with `dtype` and shape matching - `input_state`. - log_transit_ratio: `None`. Proposal is symmetric. - """ - - with ops.name_scope(name, 'proposer', [input_state]): - input_state = ops.convert_to_tensor(input_state, name='input_state') - return input_state + random_ops.random_normal( - array_ops.shape(input_state), - mean=0., - stddev=scale, - seed=seed), None - return proposal_fn diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py index 5770bcdd706723394bb06196d24aeb32b8b8491a..68fa415eeaf1d1ae7c2ecf1be1c300eddbfa4e69 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Monte Carlo integration and helpers. - -See the @{$python/contrib.bayesflow.monte_carlo} guide. -""" +"""Monte Carlo integration and helpers.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py index 985177e897f443989e466d1a498c461a30aeb5cb..032b859d469ee5039e08e4af4c2f4ebf35c2ff19 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py @@ -44,15 +44,13 @@ def expectation_importance_sampler(f, n=None, seed=None, name='expectation_importance_sampler'): - r"""Monte Carlo estimate of `E_p[f(Z)] = E_q[f(Z) p(Z) / q(Z)]`. + r"""Monte Carlo estimate of \\(E_p[f(Z)] = E_q[f(Z) p(Z) / q(Z)]\\). - With `p(z) := exp{log_p(z)}`, this `Op` returns + With \\(p(z) := exp^{log_p(z)}\\), this `Op` returns - ``` - n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ], z_i ~ q, - \approx E_q[ f(Z) p(Z) / q(Z) ] - = E_p[f(Z)] - ``` + \\(n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ], z_i ~ q,\\) + \\(\approx E_q[ f(Z) p(Z) / q(Z) ]\\) + \\(= E_p[f(Z)]\\) This integral is done in log-space with max-subtraction to better handle the often extreme values that `f(z) p(z) / q(z)` can take on. @@ -95,9 +93,9 @@ def expectation_importance_sampler(f, log_values = log_f_z + log_p_z - q_log_prob_z return _logspace_mean(log_values) - # With f_plus(z) = max(0, f(z)), f_minus(z) = max(0, -f(z)), - # E_p[f(Z)] = E_p[f_plus(Z)] - E_p[f_minus(Z)] - # = E_p[f_plus(Z) + 1] - E_p[f_minus(Z) + 1] + # With \\(f_{plus}(z) = max(0, f(z)), f_{minus}(z) = max(0, -f(z))\\), + # \\(E_p[f(Z)] = E_p[f_{plus}(Z)] - E_p[f_{minus}(Z)]\\) + # \\( = E_p[f_{plus}(Z) + 1] - E_p[f_{minus}(Z) + 1]\\) # Without incurring bias, 1 is added to each to prevent zeros in logspace. # The logarithm is approximately linear around 1 + epsilon, so this is good # for small values of 'z' as well. @@ -121,14 +119,12 @@ def expectation_importance_sampler_logspace( name='expectation_importance_sampler_logspace'): r"""Importance sampling with a positive function, in log-space. - With `p(z) := exp{log_p(z)}`, and `f(z) = exp{log_f(z)}`, this `Op` - returns + With \\(p(z) := exp^{log_p(z)}\\), and \\(f(z) = exp{log_f(z)}\\), + this `Op` returns - ``` - Log[ n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ] ], z_i ~ q, - \approx Log[ E_q[ f(Z) p(Z) / q(Z) ] ] - = Log[E_p[f(Z)]] - ``` + \\(Log[ n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ] ], z_i ~ q,\\) + \\(\approx Log[ E_q[ f(Z) p(Z) / q(Z) ] ]\\) + \\(= Log[E_p[f(Z)]]\\) This integral is done in log-space with max-subtraction to better handle the often extreme values that `f(z) p(z) / q(z)` can take on. @@ -196,13 +192,11 @@ def _logspace_mean(log_values): def expectation(f, samples, log_prob=None, use_reparametrization=True, axis=0, keep_dims=False, name=None): - """Computes the Monte-Carlo approximation of `E_p[f(X)]`. + """Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\). This function computes the Monte-Carlo approximation of an expectation, i.e., - ```none - E_p[f(X)] approx= m**-1 sum_i^m f(x_j), x_j ~iid p(X) - ``` + \\(E_p[f(X)] \approx= m^{-1} sum_i^m f(x_j), x_j\ ~iid\ p(X)\\) where: @@ -216,8 +210,8 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, parameterless distribution (e.g., `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and expectation, i.e., - `grad[ Avg{ s_i : i=1...n } ] = Avg{ grad[s_i] : i=1...n }` where - `S_n = Avg{s_i}` and `s_i = f(x_i), x_i ~ p`. + grad[ Avg{ \\(s_i : i=1...n\\) } ] = Avg{ grad[\\(s_i\\)] : i=1...n } where + S_n = Avg{\\(s_i\\)}` and `\\(s_i = f(x_i), x_i ~ p\\). However, if p is not reparameterized, TensorFlow's gradient will be incorrect since the chain-rule stops at samples of non-reparameterized distributions. @@ -296,7 +290,8 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, Args: f: Python callable which can return `f(samples)`. samples: `Tensor` of samples used to form the Monte-Carlo approximation of - `E_p[f(X)]`. A batch of samples should be indexed by `axis` dimensions. + \\(E_p[f(X)]\\). A batch of samples should be indexed by `axis` + dimensions. log_prob: Python callable which can return `log_prob(samples)`. Must correspond to the natural-logarithm of the pdf/pmf of each sample. Only required/used if `use_reparametrization=False`. @@ -316,7 +311,7 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, Returns: approx_expectation: `Tensor` corresponding to the Monte-Carlo approximation - of `E_p[f(X)]`. + of \\(E_p[f(X)]\\). Raises: ValueError: if `f` is not a Python `callable`. @@ -328,7 +323,7 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, if not callable(f): raise ValueError('`f` must be a callable function.') if use_reparametrization: - return math_ops.reduce_mean(f(samples), axis=axis, keep_dims=keep_dims) + return math_ops.reduce_mean(f(samples), axis=axis, keepdims=keep_dims) else: if not callable(log_prob): raise ValueError('`log_prob` must be a callable function.') @@ -348,7 +343,7 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, # "Is there a floating point value of x, for which x-x == 0 is false?" # http://stackoverflow.com/q/2686644 fx += stop(fx) * (logpx - stop(logpx)) # Add zeros_like(logpx). - return math_ops.reduce_mean(fx, axis=axis, keep_dims=keep_dims) + return math_ops.reduce_mean(fx, axis=axis, keepdims=keep_dims) def _sample_mean(values): diff --git a/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py b/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py deleted file mode 100644 index 7786656398e3c87704227be95b3cd23a38785249..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""An optimizer module for stochastic gradient Langevin dynamics.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope as varscope_ops -from tensorflow.python.training import optimizer -from tensorflow.python.training import training_ops - - -class SGLDOptimizer(optimizer.Optimizer): - """An optimizer module for stochastic gradient Langevin dynamics. - - This implements the preconditioned Stochastic Gradient Langevin Dynamics - optimizer [1]. The optimization variable is regarded as a sample from the - posterior under Stochastic Gradient Langevin Dynamics with noise rescaled in - each dimension according to RMSProp [2]. - - Note: If a prior is included in the loss, it should be scaled by - `1/num_pseudo_batches`, where num_pseudo_batches is the number of minibatches - in the data. I.e., it should be divided by the `num_pseudo_batches` term - described below. - - [1]: "Preconditioned Stochastic Gradient Langevin Dynamics for Deep Neural - Networks." Chunyuan Li, Changyou Chen, David Carlson, Lawrence Carin. - ArXiv:1512.07666, 2015. https://arxiv.org/abs/1512.07666 - [2]: http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf - - Args: - learning_rate: Scalar `float`-like `Tensor`. The base learning rate for the - optimizer. Must be tuned to the specific function being minimized. - preconditioner_decay_rate: Scalar `float`-like `Tensor`. The exponential - decay rate of the rescaling of the preconditioner (RMSprop). (This is - "alpha" in [1]). Should be smaller than but nearly `1` to approximate - sampling from the posterior. (Default: `0.95`) - num_pseudo_batches: Scalar `int`-like `Tensor`. The effective number of - minibatches in the data set. Trades off noise and prior with the SGD - likelihood term. Note: Assumes the loss is taken as the mean over a - minibatch. Otherwise if the sum was taken, divide this number by the - batch size. (Default: `1`) - burnin: Scalar `int`-like `Tensor`. The number of iterations to collect - gradient statistics to update the preconditioner before starting to draw - noisy samples. (Default: `25`) - diagonal_bias: Scalar `float`-like `Tensor`. Term added to the diagonal of - the preconditioner to prevent the preconditioner from degenerating. - (Default: `1e-8`) - name: Python `str` describing ops managed by this function. - (Default: `"SGLDOptimizer"`) - variable_scope: Variable scope used for calls to `tf.get_variable`. - If `None`, a new variable scope is created using name - `ops.get_default_graph().unique_name(name or default_name)`. - - Raises: - InvalidArgumentError: If preconditioner_decay_rate is a `Tensor` not in - `(0,1]`. - """ - - def __init__(self, - learning_rate, - preconditioner_decay_rate=0.95, - num_pseudo_batches=1, - burnin=25, - diagonal_bias=1e-8, - name=None, - variable_scope=None): - default_name = 'SGLDOptimizer' - with ops.name_scope(name, default_name, [ - learning_rate, preconditioner_decay_rate, num_pseudo_batches, burnin, - diagonal_bias - ]): - if variable_scope is None: - var_scope_name = ops.get_default_graph().unique_name( - name or default_name) - with varscope_ops.variable_scope(var_scope_name) as scope: - self._variable_scope = scope - else: - self._variable_scope = variable_scope - - self._preconditioner_decay_rate = ops.convert_to_tensor( - preconditioner_decay_rate, name='preconditioner_decay_rate') - self._num_pseudo_batches = ops.convert_to_tensor( - num_pseudo_batches, name='num_pseudo_batches') - self._burnin = ops.convert_to_tensor(burnin, name='burnin') - self._diagonal_bias = ops.convert_to_tensor( - diagonal_bias, name='diagonal_bias') - self._learning_rate = ops.convert_to_tensor( - learning_rate, name='learning_rate') - - with varscope_ops.variable_scope(self._variable_scope): - self._counter = varscope_ops.get_variable( - 'counter', initializer=0, trainable=False) - - self._preconditioner_decay_rate = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._preconditioner_decay_rate, - message='`preconditioner_decay_rate` must be non-negative'), - check_ops.assert_less_equal( - self._preconditioner_decay_rate, - 1., - message='`preconditioner_decay_rate` must be at most 1.'), - ], self._preconditioner_decay_rate) - - self._num_pseudo_batches = control_flow_ops.with_dependencies([ - check_ops.assert_greater( - self._num_pseudo_batches, - 0, - message='`num_pseudo_batches` must be greater than zero') - ], self._num_pseudo_batches) - - self._burnin = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._burnin, message='`burnin` must be non-negative'), - check_ops.assert_integer( - self._burnin, message='`burnin` must be an integer') - ], self._burnin) - - self._diagonal_bias = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._diagonal_bias, - message='`diagonal_bias` must be non-negative') - ], self._diagonal_bias) - - super(SGLDOptimizer, self).__init__(use_locking=False, - name=name or default_name) - - def _create_slots(self, var_list): - for v in var_list: - init_rms = init_ops.ones_initializer(dtype=v.dtype) - self._get_or_make_slot_with_initializer(v, init_rms, v.get_shape(), - v.dtype, 'rms', self._name) - - def _prepare(self): - # We need to put the conversion and check here because a user will likely - # want to decay the learning rate dynamically. - self._learning_rate_tensor = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._learning_rate, message='`learning_rate` must be non-negative') - ], ops.convert_to_tensor(self._learning_rate, name='learning_rate_tensor')) - self._decay_tensor = ops.convert_to_tensor( - self._preconditioner_decay_rate, name='preconditioner_decay_rate') - - super(SGLDOptimizer, self)._prepare() - - def _apply_dense(self, grad, var): - rms = self.get_slot(var, 'rms') - - with ops.control_dependencies([ - self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor, - var.dtype.base_dtype))]): - new_grad = self._apply_noisy_update(rms, grad) - - return training_ops.apply_gradient_descent( - var, - math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), - new_grad, - use_locking=self._use_locking).op - - def _apply_sparse(self, grad, var): - rms = self.get_slot(var, 'rms') - - with ops.control_dependencies([ - self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor, - var.dtype.base_dtype))]): - new_grad = self._apply_noisy_update(rms, grad) - - return training_ops.apply_gradient_descent( - var, - math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), - new_grad, - use_locking=self._use_locking).op - - def _finish(self, update_ops, name_scope): - update_ops.append([self._counter.assign_add(1)]) - return control_flow_ops.group(*update_ops, name=name_scope) - - @property - def variable_scope(self): - """Variable scope of all calls to `tf.get_variable`.""" - return self._variable_scope - - def _apply_noisy_update(self, mom, grad): - # Compute and apply the gradient update following - # preconditioned Langevin dynamics - stddev = array_ops.where( - array_ops.squeeze(self._counter > self._burnin), - math_ops.cast(math_ops.rsqrt(self._learning_rate), grad.dtype), - array_ops.zeros([], grad.dtype)) - - preconditioner = math_ops.rsqrt( - mom + math_ops.cast(self._diagonal_bias, grad.dtype)) - return ( - 0.5 * preconditioner * grad * math_ops.cast(self._num_pseudo_batches, - grad.dtype) + - random_ops.random_normal(array_ops.shape(grad), 1.0, dtype=grad.dtype) * - stddev * math_ops.sqrt(preconditioner)) - - def _update_momentum(self, mom, grad, decay): - # Keep an exponentially weighted moving average of squared gradients. - # Not thread safe - return mom.assign_add((1.0 - decay) * (math_ops.square(grad) - mom)) diff --git a/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py b/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py deleted file mode 100644 index ca3d75b5bfee093449026c7d1d62e3bdeff6b096..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility functions related to managing `tf.Variable`s. - -@@externalize_variables_as_args -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import warnings - -from tensorflow.python.framework import ops -from tensorflow.python.ops import gradients_impl as gradients_ops -from tensorflow.python.ops import variable_scope as varscope_ops -from tensorflow.python.ops import variables as variables_ops - -__all__ = [ - "externalize_variables_as_args", -] - - -# Cause all warnings to always be triggered. -# Not having this means subsequent calls wont trigger the warning. -warnings.simplefilter("always") - - -def externalize_variables_as_args(fn, - fn_args=(), - ancestor_variables=None, - possible_ancestor_vars=None, - assert_variable_override=False, - name=None): - """"Converts variables within a callable into explicit args. - - Makes a new callable from `fn` which has arguments `list(fn_args) + - list(ancestor_variables)`. If `ancestor_variables` is not specified, it is - inferred by checking which of `possible_ancestor_vars` actually influences the - return value of `fn` (concretely, gradient of `fn(*fn_args)` is not `None`). - By default `possible_ancestor_vars` is `tf.trainable_variables() + - tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)`. - - #### Examples: - - ```python - num_samples = 2 - num_dims = 1 - dtype = np.float32 - - def foo(x): - x = tf.convert_to_tensor(x, dtype=dtype, name="x") - s = x.shape.as_list() - y = tf.get_variable( - name="y", - dtype=dtype, - initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)) - return x + y - - x = tf.constant(dtype([0.1, 0.2])) - - wrapped_foo, discovered_ancestor_variables = ( - externalize_variables_as_args(foo, [x])) - - new_x = dtype([[1.], [2.]]) - new_y = dtype([[3.], [4.]]) - new_result = wrapped_foo(new_x, new_y) - # ==> [[4.], [6.]] - - discovered_ancestor_variables == [tf.get_variable("y", dtype)] - # ==> [True] - ``` - - Args: - fn: Python callable which returns a `Tensor` and accepts `*fn_args`. - fn_args: Python list of args to `fn`. Represents dummy arguments passed to - `fn` to trace its execution; actual values are unimportant. These args are - only used to construct the output of `fn` and to resolve the ancestor - `tf.Variable`s. - Default value: `()` (i.e., `fn` takes no args). - ancestor_variables: Python list of `tf.Variable`s. When `None` the list is - expanded to non-`None` gradients of `fn(*fn_args)`. By directly providing - the `ancestor_variables` the internal call to `fn` is avoided. - Default value: `None` (i.e., `tf.Variable` dependencies are discovered). - possible_ancestor_vars: Python list of possible `tf.Variable`s which might - be a dependency of computing `fn(*fn_args)`. - Default value: `None` (i.e., expanded as described above). - assert_variable_override: Python `bool` indicating that not finding a - `tf.Variable` in the override list is an exception. - Default value: `False` (i.e., missing a `Variable` triggers a `warning`). - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "externalize_variables_as_args"). - - Returns: - wrapped_fn: Python callable taking arguments like - `*(list(fn_args) + discovered_ancestor_variables)`. - discovered_ancestor_variables: Python list of `tf.Variable`s known to be a - dependency of `fn(*fn_args)`. - - Raises: - ValueError: if `assert_variable_override` is `True` and `Variable` is - requested but not overridden. - """ - def _make_bypassing_custom_getter_fn(new_var_dict): - """Return dict value rather than what would otherwise be dict key.""" - def _custom_getter(getter, *args, **kwargs): - v = getter(*args, **kwargs) - new_v = new_var_dict.get(v, None) - if new_v is None: - msg = "Variable \"{}\" not found in bypass dict.".format(v) - if assert_variable_override: - raise ValueError(msg) - warnings.warn(msg) - return v - return new_v - return _custom_getter - - with ops.name_scope(name, "externalize_variables_as_args"): - if ancestor_variables is not None and not ancestor_variables: - return fn, () - if ancestor_variables is None: - y = fn(*fn_args) # Side-effect: adds trainable vars. - if possible_ancestor_vars is None: - possible_ancestor_vars = ( - variables_ops.trainable_variables() + - ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) - # TODO(b/72873296): Add a dedicated op for identifying ancestors. - ancestors = [v for g, v - in zip(gradients_ops.gradients(y, possible_ancestor_vars), - possible_ancestor_vars) - if g is not None] - ancestor_variables = sorted(ancestors, key=lambda v: v.name) - n = len(fn_args) - def _fn(*args): - with ops.name_scope("wrapped_fn"): - vars_dict = dict( - (k, ops.convert_to_tensor( - v, dtype=k.dtype.base_dtype, name=k.op.name)) - for k, v in zip(ancestor_variables, args[n:])) - with varscope_ops.variable_scope( - varscope_ops.get_variable_scope(), - reuse=True, - custom_getter=_make_bypassing_custom_getter_fn(vars_dict)): - return fn(*args[:n]) - return _fn, ancestor_variables diff --git a/tensorflow/contrib/bayesflow/python/ops/variational_sgd_optimizer.py b/tensorflow/contrib/bayesflow/python/ops/variational_sgd_optimizer.py deleted file mode 100644 index 4d5f0cfe9713a011b32c5aba8d429847d81f33e2..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/variational_sgd_optimizer.py +++ /dev/null @@ -1,279 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""An optimizer module for constant stochastic gradient descent.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope as varscope_ops -from tensorflow.python.training import optimizer -from tensorflow.python.training import training_ops - - -class VariationalSGDOptimizer(optimizer.Optimizer): - """An optimizer module for constant stochastic gradient descent. - - This implements an optimizer module for the constant stochastic gradient - descent algorithm [1]. The optimization variable is regarded as an - approximate sample from the posterior . - - Note: If a prior is included in the loss, it should be scaled by - `1/num_pseudo_batches`, where num_pseudo_batches is the number of minibatches - in the data. I.e., it should be divided by the `num_pseudo_batches` term - described below. - - [1]: "Stochastic Gradient Descent as Approximate Bayesian Inference - Stephan Mandt, Matthew D. Hoffman, David M. Blei. - ArXiv:1704.04289, 2017. https://arxiv.org/abs/1704.04289 - - Args: - batch_size: Scalar `int`-like `Tensor`. The number of examples in a - minibatch in the data set. Note: Assumes the loss is taken as the mean - over a minibatch. Otherwise if the sum was taken set this to 1. - total_num_examples: Scalar `int`-like `Tensor`. The total number of examples - in the data set. - max_learning_rate: Scalar `float`-like `Tensor`. A maximum allowable - effective coordinate-wise learning rate. The algorithm scales down any - effective learning rate (i.e. after preconditioning) that is larger than - this. (Default: `1`) - preconditioner_decay_rate: Scalar `float`-like `Tensor`. The exponential - decay rate of the rescaling of the preconditioner (RMSprop). (This is - "alpha" in [1]). Should be smaller than but nearly `1` to approximate - sampling from the posterior. (Default: `0.95`) - burnin: Scalar `int`-like `Tensor`. The number of iterations to collect - gradient statistics to update the preconditioner before starting to draw - noisy samples. (Default: `25`) - burnin_max_learning_rate: Scalar `float`-like `Tensor`. Maximum learning - rate to use during the burnin period. - (Default: `1e-8`) - use_single_learning_rate: Boolean Indicates whether one single learning - rate is used or coordinate_wise learning rates are used. - (Default: `False`) - name: Python `str` describing ops managed by this function. - (Default: `"VariationalSGDOptimizer"`) - variable_scope: Variable scope used for calls to `tf.get_variable`. - If `None`, a new variable scope is created using name - `ops.get_default_graph().unique_name(name or default_name)`. - - Raises: - InvalidArgumentError: If preconditioner_decay_rate is a `Tensor` not in - `(0,1]`. - """ - - def __init__(self, - batch_size, - total_num_examples, - max_learning_rate=1.0, - preconditioner_decay_rate=0.95, - burnin=25, - burnin_max_learning_rate=1e-6, - use_single_learning_rate=False, - name=None, - variable_scope=None): - default_name = 'VariationalSGDOptimizer' - with ops.name_scope(name, default_name, [ - max_learning_rate, preconditioner_decay_rate, batch_size, burnin, - burnin_max_learning_rate - ]): - if variable_scope is None: - var_scope_name = ops.get_default_graph().unique_name( - name or default_name) - with varscope_ops.variable_scope(var_scope_name) as scope: - self._variable_scope = scope - else: - self._variable_scope = variable_scope - - self._preconditioner_decay_rate = ops.convert_to_tensor( - preconditioner_decay_rate, name='preconditioner_decay_rate') - self._batch_size = ops.convert_to_tensor(batch_size, name='batch_size') - self._total_num_examples = ops.convert_to_tensor( - total_num_examples, name='total_num_examples') - self._burnin = ops.convert_to_tensor(burnin, name='burnin') - self._burnin_max_learning_rate = ops.convert_to_tensor( - burnin_max_learning_rate, name='burnin_max_learning_rate') - self._max_learning_rate = ops.convert_to_tensor( - max_learning_rate, name='max_learning_rate') - self._use_single_learning_rate = use_single_learning_rate - - with varscope_ops.variable_scope(self._variable_scope): - self._counter = varscope_ops.get_variable( - 'counter', initializer=0, trainable=False) - - self._preconditioner_decay_rate = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._preconditioner_decay_rate, - message='`preconditioner_decay_rate` must be non-negative'), - check_ops.assert_less_equal( - self._preconditioner_decay_rate, - 1., - message='`preconditioner_decay_rate` must be at most 1.'), - ], self._preconditioner_decay_rate) - - self._batch_size = control_flow_ops.with_dependencies([ - check_ops.assert_greater( - self._batch_size, - 0, - message='`batch_size` must be greater than zero') - ], self._batch_size) - - self._total_num_examples = control_flow_ops.with_dependencies([ - check_ops.assert_greater( - self._total_num_examples, - 0, - message='`total_num_examples` must be greater than zero') - ], self._total_num_examples) - - self._burnin = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._burnin, message='`burnin` must be non-negative'), - check_ops.assert_integer( - self._burnin, message='`burnin` must be an integer') - ], self._burnin) - - self._burnin_max_learning_rate = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._burnin_max_learning_rate, - message='`burnin_max_learning_rate` must be non-negative') - ], self._burnin_max_learning_rate) - - self._max_learning_rate = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._max_learning_rate, - message='`max_learning_rate` must be non-negative') - ], self._max_learning_rate) - - super(VariationalSGDOptimizer, self).__init__( - use_locking=False, name=name or default_name) - - def _create_slots(self, var_list): - for v in var_list: - init_moment = init_ops.zeros_initializer(dtype=v.dtype) - self._get_or_make_slot_with_initializer( - v, init_moment, v.get_shape(), v.dtype, 'first_moment', self._name) - self._get_or_make_slot_with_initializer( - v, init_moment, v.get_shape(), v.dtype, 'second_moment', self._name) - - def _prepare(self): - self._decay_tensor = ops.convert_to_tensor( - self._preconditioner_decay_rate, name='preconditioner_decay_rate') - self._batch_size_tensor = ops.convert_to_tensor( - self._batch_size, name='batch_size_tensor') - - super(VariationalSGDOptimizer, self)._prepare() - - def _get_coordinatewise_learning_rate(self, grad, var): - # Compute the learning rate using a moving average for the diagonal of BB^T - avg_first = self.get_slot(var, 'first_moment') - avg_second = self.get_slot(var, 'second_moment') - decay_tensor = math_ops.cast(self._decay_tensor, var.dtype) - batch_size = math_ops.cast(self._batch_size_tensor, var.dtype) - - # Create an estimator for the moving average of gradient mean and variance - # via Welford's algorithm - if isinstance(grad, ops.Tensor): - delta = grad - avg_first - first_moment_update = avg_first.assign_add( - array_ops.where(self._counter < 1, math_ops.cast(1, var.dtype), - 1. - decay_tensor) * delta) - - with ops.control_dependencies([first_moment_update]): - second_moment_update = avg_second.assign_add( - math_ops.cast(self._counter < 1, var.dtype) * - -(1. - decay_tensor) * ( - avg_second - decay_tensor * math_ops.square(delta))) - diag_preconditioner = control_flow_ops.with_dependencies( - [second_moment_update], - clip_ops.clip_by_value(avg_second, 1e-12, 1e12)) - elif isinstance(grad, ops.IndexedSlices): - delta = grad.values - array_ops.gather_nd(avg_first, grad.indices) - first_moment_update = state_ops.scatter_add( - avg_first, - grad.indices, - array_ops.where(self._counter < 1, - math_ops.cast(1., var.dtype), - 1. - decay_tensor) * delta) - - with ops.control_dependencies([first_moment_update]): - avg_second = state_ops.scatter_add( - avg_second, - grad.indices, - math_ops.cast(self._counter < 1, var.dtype) * - -(1. - decay_tensor) * ( - array_ops.gather_nd(avg_second, grad.indices) - decay_tensor * - math_ops.square(delta))) - avg_second = array_ops.gather_nd(avg_second, grad.indices) - # TODO(b/70783772) - diag_preconditioner = clip_ops.clip_by_value(avg_second, 1e-12, 1e12) - else: - raise errors.InvalidArgumentError( - None, None, 'grad must of type Tensor or IndexedSlice') - - diag_preconditioner *= batch_size - - if self._use_single_learning_rate: - diag_preconditioner = math_ops.reduce_mean(diag_preconditioner) - - # From Theorem 2 Corollary 1 of Mandt et al. 2017 - return 2. * batch_size / ( - math_ops.cast(self._total_num_examples, var.dtype.base_dtype) * - diag_preconditioner) - - def _apply_dense(self, grad, var): - - max_learning_rate = array_ops.where(self._counter < self._burnin, - self._burnin_max_learning_rate, - self._max_learning_rate) - - learn_rates = clip_ops.clip_by_value( - self._get_coordinatewise_learning_rate(grad, var), 0.0, - math_ops.cast(max_learning_rate, var.dtype.base_dtype)) - - newgrad = grad * learn_rates - return training_ops.apply_gradient_descent( - var, - math_ops.cast(1.0, var.dtype), - newgrad, - use_locking=self._use_locking).op - - def _apply_sparse(self, grad, var): - - max_learning_rate = array_ops.where(self._counter < self._burnin, - self._burnin_max_learning_rate, - self._max_learning_rate) - - learn_rate = clip_ops.clip_by_value( - self._get_coordinatewise_learning_rate(grad, var), 0.0, - math_ops.cast(max_learning_rate, var.dtype)) - delta = grad.values * learn_rate - - return state_ops.scatter_sub(var, grad.indices, delta, - use_locking=self._use_locking) - - def _finish(self, update_ops, name_scope): - update_ops.append([self._counter.assign_add(1)]) - return control_flow_ops.group(*update_ops, name=name_scope) - - @property - def variable_scope(self): - """Variable scope of all calls to `tf.get_variable`.""" - return self._variable_scope diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index 6fdcd0f996ee011842a5add79f06264a28a2145c..8eac1243ef63dd09c5c5dad4bcd9bd7a15f58900 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -14,15 +14,6 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = ["**/OWNERS"], - ), - visibility = ["//tensorflow:__subpackages__"], -) - package_group(name = "friends") cc_library( @@ -128,7 +119,7 @@ py_library( py_test( name = "gbdt_batch_test", - size = "small", + size = "medium", srcs = ["python/training/functions/gbdt_batch_test.py"], srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 289f5bb3140974d8c37f4938ceef27275b099f9a..8cff1a3bb1d11aff6a264636291a7149b40de516 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -10,23 +10,17 @@ package( load("//tensorflow:tensorflow.bzl", "py_test") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - py_library( name = "init_py", - srcs = [ - "__init__.py", - ], + srcs = ["__init__.py"], srcs_version = "PY2AND3", + deps = [ + "custom_export_strategy", + ":custom_loss_head", + ":estimator", + ":model", + ":trainer_hooks", + ], ) py_library( @@ -34,12 +28,13 @@ py_library( srcs = ["model.py"], srcs_version = "PY2AND3", deps = [ + ":estimator_utils", ":trainer_hooks", "//tensorflow/contrib/boosted_trees:gbdt_batch", "//tensorflow/contrib/boosted_trees:model_ops_py", "//tensorflow/python:framework_ops", "//tensorflow/python:state_ops", - "//tensorflow/python:training", + "//tensorflow/python:training_util", ], ) @@ -57,6 +52,18 @@ py_library( ], ) +py_library( + name = "estimator_utils", + srcs = ["estimator_utils.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/learn", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + ], +) + py_test( name = "trainer_hooks_test", size = "small", @@ -124,6 +131,7 @@ py_library( srcs = ["estimator.py"], srcs_version = "PY2AND3", deps = [ + ":estimator_utils", ":model", "//tensorflow/contrib/boosted_trees:losses", "//tensorflow/contrib/learn", @@ -136,6 +144,7 @@ py_library( srcs = ["dnn_tree_combined_estimator.py"], srcs_version = "PY2AND3", deps = [ + ":estimator_utils", ":trainer_hooks", "//tensorflow/contrib/boosted_trees:gbdt_batch", "//tensorflow/contrib/boosted_trees:model_ops_py", @@ -149,7 +158,7 @@ py_library( py_test( name = "dnn_tree_combined_estimator_test", - size = "small", + size = "medium", srcs = ["dnn_tree_combined_estimator_test.py"], srcs_version = "PY2AND3", tags = [ @@ -165,3 +174,22 @@ py_test( "//tensorflow/python:framework_for_generated_wrappers", ], ) + +py_test( + name = "estimator_test", + size = "medium", + srcs = ["estimator_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_gpu", + "no_pip_gpu", + "notsan", + ], + deps = [ + ":estimator", + "//tensorflow/contrib/boosted_trees:gbdt_batch", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + ], +) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 31f5c444817b9b82723c86bea3504d4934e57eb8..62f1f4122b05b56a708823df4246d618bd3fa5d4 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -39,7 +39,8 @@ _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE = "%s_%d" def make_custom_export_strategy(name, convert_fn, feature_columns, - export_input_fn): + export_input_fn, + use_core_columns=False): """Makes custom exporter of GTFlow tree format. Args: @@ -54,11 +55,11 @@ def make_custom_export_strategy(name, An `ExportStrategy`. """ base_strategy = saved_model_export_utils.make_export_strategy( - serving_input_fn=export_input_fn) + serving_input_fn=export_input_fn, strip_default_attrs=True) input_fn = export_input_fn() (sorted_feature_names, dense_floats, sparse_float_indices, _, _, sparse_int_indices, _, _) = gbdt_batch.extract_features( - input_fn.features, feature_columns) + input_fn.features, feature_columns, use_core_columns) def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None): """A wrapper to export to SavedModel, and convert it to other formats.""" @@ -93,7 +94,9 @@ def make_custom_export_strategy(name, "w") as f: f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance)) return result_dir - return export_strategy.ExportStrategy(name, export_fn) + + return export_strategy.ExportStrategy( + name, export_fn, strip_default_attrs=True) def convert_to_universal_format(dtec, sorted_feature_names, diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index cec3892b57655dc967b4e7926f7f5a6a30084487..911d87fa10570382ee5f03edfc1bfd1d116c8360 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -19,14 +19,13 @@ logits of the DNN. The input layer of the DNN (including the embeddings learned over sparse features) can optionally be provided to the boosted trees as an additional input feature. """ - from __future__ import absolute_import from __future__ import division from __future__ import print_function import six - from tensorflow.contrib import layers +from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch @@ -34,6 +33,7 @@ from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn +from tensorflow.python.feature_column import feature_column as feature_column_lib from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn @@ -43,7 +43,6 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.summary import summary from tensorflow.python.training import training_util - _DNN_LEARNING_RATE = 0.001 @@ -59,16 +58,26 @@ def _add_hidden_layer_summary(value, tag): summary.histogram("%s_activation" % tag, value) -def _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, - dnn_feature_columns, tree_learner_config, num_trees, - tree_examples_per_layer, - config=None, dnn_optimizer="Adagrad", - dnn_activation_fn=nn.relu, dnn_dropout=None, - dnn_input_layer_partitioner=None, - dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, - tree_feature_columns=None, - tree_center_bias=True): +def _dnn_tree_combined_model_fn(features, + labels, + mode, + head, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + config=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + predict_with_tree_only=False, + tree_feature_columns=None, + tree_center_bias=False, + use_core_versions=False): """DNN and GBDT combined model_fn. Args: @@ -101,11 +110,15 @@ def _dnn_tree_combined_model_fn( as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_versions: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. Returns: A `ModelFnOps` object. @@ -123,8 +136,7 @@ def _dnn_tree_combined_model_fn( dnn_parent_scope = "dnn" dnn_partitioner = dnn_input_layer_partitioner or ( partitioned_variables.min_max_variable_partitioner( - max_partitions=config.num_ps_replicas, - min_slice_size=64 << 20)) + max_partitions=config.num_ps_replicas, min_slice_size=64 << 20)) with variable_scope.variable_scope( dnn_parent_scope, @@ -135,11 +147,17 @@ def _dnn_tree_combined_model_fn( "input_from_feature_columns", values=tuple(six.itervalues(features)), partitioner=dnn_partitioner) as input_layer_scope: - input_layer = layers.input_from_feature_columns( - columns_to_tensors=features, - feature_columns=dnn_feature_columns, - weight_collections=[dnn_parent_scope], - scope=input_layer_scope) + if use_core_versions: + input_layer = feature_column_lib.input_layer( + features=features, + feature_columns=dnn_feature_columns, + weight_collections=[dnn_parent_scope]) + else: + input_layer = layers.input_from_feature_columns( + columns_to_tensors=features, + feature_columns=dnn_feature_columns, + weight_collections=[dnn_parent_scope], + scope=input_layer_scope) previous_layer = input_layer for layer_id, num_hidden_units in enumerate(dnn_hidden_units): with variable_scope.variable_scope( @@ -156,8 +174,7 @@ def _dnn_tree_combined_model_fn( _add_hidden_layer_summary(net, hidden_layer_scope.name) previous_layer = net with variable_scope.variable_scope( - "logits", - values=(previous_layer,)) as logits_scope: + "logits", values=(previous_layer,)) as logits_scope: dnn_logits = layers.fully_connected( previous_layer, head.logits_dimension, @@ -175,8 +192,7 @@ def _dnn_tree_combined_model_fn( optimizer=_get_optimizer(dnn_optimizer), name=dnn_parent_scope, variables=ops.get_collection( - ops.GraphKeys.TRAINABLE_VARIABLES, - scope=dnn_parent_scope), + ops.GraphKeys.TRAINABLE_VARIABLES, scope=dnn_parent_scope), # Empty summaries to prevent optimizers from logging training_loss. summaries=[]) @@ -215,41 +231,77 @@ def _dnn_tree_combined_model_fn( update_op = state_ops.assign_add(global_step, 1).op return update_op - tree_train_logits = dnn_logits + tree_logits + if predict_with_tree_only: + if mode == model_fn.ModeKeys.TRAIN or mode == model_fn.ModeKeys.PREDICT: + tree_train_logits = tree_logits + else: + tree_train_logits = control_flow_ops.cond( + global_step > dnn_steps_to_train, + lambda: tree_logits, + lambda: dnn_logits) + else: + tree_train_logits = dnn_logits + tree_logits def _no_train_op_fn(loss): """Returns a no-op.""" del loss return control_flow_ops.no_op() - model_fn_ops = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - train_op_fn=_no_train_op_fn, - logits=tree_train_logits) - dnn_train_op = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - train_op_fn=_dnn_train_op_fn, - logits=dnn_logits).train_op - tree_train_op = head.create_model_fn_ops( - features=tree_features, - mode=mode, - labels=labels, - train_op_fn=_tree_train_op_fn, - logits=tree_train_logits).train_op + if use_core_versions: + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits) + dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops( + dnn_train_op).train_op + + tree_train_op = head.create_estimator_spec( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits) + tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops( + tree_train_op).train_op + + model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits).train_op + tree_train_op = head.create_model_fn_ops( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits).train_op if tree_center_bias: num_trees += 1 finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() model_fn_ops.training_hooks.extend([ - trainer_hooks.SwitchTrainOp( - dnn_train_op, dnn_steps_to_train, tree_train_op), - trainer_hooks.StopAfterNTrees( - num_trees, attempted_trees, finalized_trees)]) + trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train, + tree_train_op), + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, finalized_trees) + ]) return model_fn_ops @@ -276,8 +328,10 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, - tree_center_bias=True): + tree_center_bias=False, + use_core_versions=False): """Initializes a DNNBoostedTreeCombinedClassifier instance. Args: @@ -317,11 +371,15 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_versions: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. """ head = head_lib.multi_class_head( n_classes=n_classes, @@ -332,16 +390,32 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, - tree_feature_columns, tree_center_bias) + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedClassifier, self).__init__( - model_fn=_model_fn, model_dir=model_dir, - config=config, feature_engineering_fn=feature_engineering_fn) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) class DNNBoostedTreeCombinedRegressor(estimator.Estimator): @@ -365,8 +439,10 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, - tree_center_bias=True): + tree_center_bias=False, + use_core_versions=False): """Initializes a DNNBoostedTreeCombinedRegressor instance. Args: @@ -406,11 +482,15 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_versions: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. """ head = head_lib.regression_head( label_name=label_name, @@ -426,15 +506,32 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, tree_feature_columns, tree_center_bias) + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedRegressor, self).__init__( - model_fn=_model_fn, model_dir=model_dir, - config=config, feature_engineering_fn=feature_engineering_fn) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) class DNNBoostedTreeCombinedEstimator(estimator.Estimator): @@ -459,8 +556,10 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, - tree_center_bias=True): + tree_center_bias=False, + use_core_versions=False): """Initializes a DNNBoostedTreeCombinedEstimator instance. Args: @@ -495,21 +594,42 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_versions: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. """ + def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, - tree_feature_columns, tree_center_bias) + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedEstimator, self).__init__( - model_fn=_model_fn, model_dir=model_dir, - config=config, feature_engineering_fn=feature_engineering_fn) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py index 83d58c561008e8a5a69eb503d1605bb9e940f281..f495edc62f0909880c170ccb4cf5d11e3f20f55c 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py @@ -19,15 +19,17 @@ from __future__ import division from __future__ import print_function import tempfile - from tensorflow.contrib.boosted_trees.estimator_batch import dnn_tree_combined_estimator as estimator from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils from tensorflow.contrib.learn.python.learn.estimators import run_config +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest @@ -100,6 +102,35 @@ class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): classifier.fit(input_fn=_train_input_fn, steps=15) classifier.evaluate(input_fn=_eval_input_fn, steps=1) + def testFitAndEvaluateDontThrowExceptionWithCore(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + # Use core head + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) + + classifier = estimator.DNNBoostedTreeCombinedEstimator( + head=head_fn, + dnn_hidden_units=[1], + # Use core feature columns + dnn_feature_columns=[core_feature_column.numeric_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=True, + tree_feature_columns=[], + use_core_versions=True) + + classifier.fit(input_fn=_train_input_fn, steps=15) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 01752416b347dd0a5e646283b6b5572592df4690..9c36c302210185bc390751a0229a61f2f8cd91b8 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -40,7 +40,9 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): label_keys=None, feature_engineering_fn=None, logits_modifier_function=None, - center_bias=True): + center_bias=True, + use_core_libs=False, + output_leaf_index=False): """Initializes a GradientBoostedDecisionTreeClassifier estimator instance. Args: @@ -63,6 +65,17 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): logits_modifier_function: A modifier function for the logits. center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_libs: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is + [batch_size, num_trees]. + For example, + result_iter = classifier.predict(...) + for result_dict in result_iter: + # access leaf index list by result_dict["leaf_index"] + # which contains one leaf index per tree Raises: ValueError: If learner_config is not valid. @@ -72,7 +85,9 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): # supports second order derivative. def loss_fn(labels, logits, weights=None): result = losses.per_example_maxent_loss( - labels=labels, logits=logits, weights=weights, + labels=labels, + logits=logits, + weights=weights, num_classes=n_classes) return math_ops.reduce_mean(result[0]) else: @@ -81,7 +96,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): n_classes=n_classes, weight_column_name=weight_column_name, enable_centered_bias=False, - loss_fn=loss_fn) + loss_fn=loss_fn, + label_keys=label_keys) if learner_config.num_classes == 0: learner_config.num_classes = n_classes elif learner_config.num_classes != n_classes: @@ -98,6 +114,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): 'examples_per_layer': examples_per_layer, 'center_bias': center_bias, 'logits_modifier_function': logits_modifier_function, + 'use_core_libs': use_core_libs, + 'output_leaf_index': output_leaf_index, }, model_dir=model_dir, config=config, @@ -119,7 +137,9 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): config=None, feature_engineering_fn=None, logits_modifier_function=None, - center_bias=True): + center_bias=True, + use_core_libs=False, + output_leaf_index=False): """Initializes a GradientBoostedDecisionTreeRegressor estimator instance. Args: @@ -144,6 +164,15 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): logits_modifier_function: A modifier function for the logits. center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_libs: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree """ head = head_lib.regression_head( label_name=label_name, @@ -165,6 +194,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): 'examples_per_layer': examples_per_layer, 'logits_modifier_function': logits_modifier_function, 'center_bias': center_bias, + 'use_core_libs': use_core_libs, + 'output_leaf_index': False, }, model_dir=model_dir, config=config, @@ -188,7 +219,9 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): config=None, feature_engineering_fn=None, logits_modifier_function=None, - center_bias=True): + center_bias=True, + use_core_libs=False, + output_leaf_index=False): """Initializes a GradientBoostedDecisionTreeEstimator estimator instance. Args: @@ -209,6 +242,15 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): logits_modifier_function: A modifier function for the logits. center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_libs: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree """ super(GradientBoostedDecisionTreeEstimator, self).__init__( model_fn=model.model_builder, @@ -221,6 +263,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): 'examples_per_layer': examples_per_layer, 'logits_modifier_function': logits_modifier_function, 'center_bias': center_bias, + 'use_core_libs': use_core_libs, + 'output_leaf_index': False, }, model_dir=model_dir, config=config, diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..75ef1b050028b6462b255827c06e836e5c481844 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -0,0 +1,160 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for GBDT estimator.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tempfile +from tensorflow.contrib.boosted_trees.estimator_batch import estimator +from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.layers.python.layers import feature_column as contrib_feature_column +from tensorflow.contrib.learn.python.learn.estimators import run_config +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.feature_column import feature_column_lib as core_feature_column +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import gfile +from tensorflow.python.platform import googletest + + +def _train_input_fn(): + features = {"x": constant_op.constant([[2.], [1.], [1.]])} + label = constant_op.constant([[1], [0], [0]], dtype=dtypes.int32) + return features, label + + +def _eval_input_fn(): + features = {"x": constant_op.constant([[1.], [2.], [2.]])} + label = constant_op.constant([[0], [1], [1]], dtype=dtypes.int32) + return features, label + + +class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): + + def setUp(self): + self._export_dir_base = tempfile.mkdtemp() + "export/" + gfile.MkDir(self._export_dir_base) + + def testFitAndEvaluateDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_train_input_fn, steps=15) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + classifier.export(self._export_dir_base) + + def testThatLeafIndexIsInPredictions(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")], + output_leaf_index=True) + + classifier.fit(input_fn=_train_input_fn, steps=15) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("leaf_index" in prediction_dict) + self.assertTrue("logits" in prediction_dict) + + def testFitAndEvaluateDontThrowExceptionWithCoreForEstimator(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + # Use core head + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) + + model = estimator.GradientBoostedDecisionTreeEstimator( + head=head_fn, + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")], + use_core_libs=True) + + model.fit(input_fn=_train_input_fn, steps=15) + model.evaluate(input_fn=_eval_input_fn, steps=1) + model.export(self._export_dir_base) + + def testFitAndEvaluateDontThrowExceptionWithCoreForClassifier(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")], + use_core_libs=True) + + classifier.fit(input_fn=_train_input_fn, steps=15) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + classifier.export(self._export_dir_base) + + def testFitAndEvaluateDontThrowExceptionWithCoreForRegressor(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + regressor = estimator.GradientBoostedDecisionTreeRegressor( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")], + use_core_libs=True) + + regressor.fit(input_fn=_train_input_fn, steps=15) + regressor.evaluate(input_fn=_eval_input_fn, steps=1) + regressor.export(self._export_dir_base) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_utils.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..48a7f85eada8c72de83b814af2f00e97a62a073e --- /dev/null +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_utils.py @@ -0,0 +1,74 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for converting between core and contrib feature columns.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.learn.python.learn.estimators import constants +from tensorflow.contrib.learn.python.learn.estimators import model_fn +from tensorflow.contrib.learn.python.learn.estimators import model_fn as contrib_model_fn_lib +from tensorflow.contrib.learn.python.learn.estimators import prediction_key +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export_output + +_CORE_MODE_TO_CONTRIB_MODE_ = { + model_fn_lib.ModeKeys.TRAIN: contrib_model_fn_lib.ModeKeys.TRAIN, + model_fn_lib.ModeKeys.EVAL: contrib_model_fn_lib.ModeKeys.EVAL, + model_fn_lib.ModeKeys.PREDICT: contrib_model_fn_lib.ModeKeys.INFER +} + + +def _core_mode_to_contrib_mode(mode): + return _CORE_MODE_TO_CONTRIB_MODE_[mode] + + +def _export_outputs_to_output_alternatives(export_outputs): + """Converts EstimatorSpec.export_outputs to output_alternatives. + + Args: + export_outputs: export_outputs created by create_estimator_spec. + Returns: + converted output_alternatives. + """ + output = dict() + if export_outputs is not None: + for key, value in export_outputs.items(): + if isinstance(value, export_output.ClassificationOutput): + exported_predictions = { + prediction_key.PredictionKey.SCORES: value.scores, + prediction_key.PredictionKey.CLASSES: value.classes + } + output[key] = (constants.ProblemType.CLASSIFICATION, + exported_predictions) + return output + return None + + +def estimator_spec_to_model_fn_ops(estimator_spec, export_alternatives=False): + if export_alternatives: + alternatives = _export_outputs_to_output_alternatives( + estimator_spec.export_outputs) + else: + alternatives = [] + + return model_fn.ModelFnOps( + mode=_core_mode_to_contrib_mode(estimator_spec.mode), + predictions=estimator_spec.predictions, + loss=estimator_spec.loss, + train_op=estimator_spec.train_op, + eval_metric_ops=estimator_spec.eval_metric_ops, + output_alternatives=alternatives) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index c6455a7ea3d18eb358edee034cee58b2bed21024..1ee891198939e53fc5913104b2c2e65dc977823f 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -20,6 +20,7 @@ from __future__ import print_function import copy +from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch @@ -60,7 +61,10 @@ def model_builder(features, labels, mode, params, config): feature_columns = params["feature_columns"] weight_column_name = params["weight_column_name"] num_trees = params["num_trees"] + use_core_libs = params["use_core_libs"] logits_modifier_function = params["logits_modifier_function"] + output_leaf_index = params["output_leaf_index"] + if features is None: raise ValueError("At least one feature must be specified.") @@ -93,7 +97,9 @@ def model_builder(features, labels, mode, params, config): learner_config=learner_config, feature_columns=feature_columns, logits_dimension=head.logits_dimension, - features=training_features) + features=training_features, + use_core_columns=use_core_libs, + output_leaf_index=output_leaf_index) with ops.name_scope("gbdt", "gbdt_optimizer"): predictions_dict = gbdt_model.predict(mode) logits = predictions_dict["predictions"] @@ -108,12 +114,25 @@ def model_builder(features, labels, mode, params, config): update_op = state_ops.assign_add(global_step, 1).op return update_op - model_fn_ops = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - train_op_fn=_train_op_fn, - logits=logits) + create_estimator_spec_op = getattr(head, "create_estimator_spec", None) + if use_core_libs and callable(create_estimator_spec_op): + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: + model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ + gbdt_batch.LEAF_INDEX] if num_trees: if center_bias: num_trees += 1 diff --git a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc index 754b7bc3270d647fc381033b769eadd7b791771e..3bf33186ec13f5ff991db938d59849c0124a30a0 100644 --- a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc @@ -137,6 +137,61 @@ class TreeEnsembleDeserializeOp : public OpKernel { } }; +class TreeEnsembleUsedHandlersOp : public OpKernel { + public: + explicit TreeEnsembleUsedHandlersOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("num_all_handlers", &num_handlers_)); + } + + void Compute(OpKernelContext* context) override { + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; + + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &ensemble_resource)); + tf_shared_lock l(*ensemble_resource->get_mutex()); + core::ScopedUnref unref_me(ensemble_resource); + + // Get the stamp token. + const Tensor* stamp_token_t; + OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t)); + int64 stamp_token = stamp_token_t->scalar()(); + + // Only the Chief should run this Op and it is guaranteed to be in + // a consistent state so the stamps must always match. + CHECK(ensemble_resource->is_stamp_valid(stamp_token)); + + Tensor* output_used_handlers_t = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output("used_handlers_mask", {num_handlers_}, + &output_used_handlers_t)); + auto output_used_handlers = output_used_handlers_t->vec(); + + Tensor* output_num_used_handlers_t = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("num_used_handlers", {}, + &output_num_used_handlers_t)); + int handler_idx = 0; + std::vector used_handlers = ensemble_resource->GetUsedHandlers(); + output_num_used_handlers_t->scalar()() = used_handlers.size(); + for (int64 i = 0; i < num_handlers_; ++i) { + if (handler_idx >= used_handlers.size() || + used_handlers[handler_idx] > i) { + output_used_handlers(i) = false; + } else { + OP_REQUIRES(context, used_handlers[handler_idx] == i, + errors::InvalidArgument("Handler IDs should be sorted.")); + ++handler_idx; + output_used_handlers(i) = true; + } + } + } + + private: + int64 num_handlers_; +}; + REGISTER_RESOURCE_HANDLE_KERNEL(DecisionTreeEnsembleResource); REGISTER_KERNEL_BUILDER( @@ -155,5 +210,7 @@ REGISTER_KERNEL_BUILDER(Name("TreeEnsembleSerialize").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("TreeEnsembleDeserialize").Device(DEVICE_CPU), TreeEnsembleDeserializeOp); +REGISTER_KERNEL_BUILDER(Name("TreeEnsembleUsedHandlers").Device(DEVICE_CPU), + TreeEnsembleUsedHandlersOp); } // namespace boosted_trees } // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc index b3fe38614e05801b223f0c96f7a70ce7e432a70b..9493c1a1394040db3b744f1b382b20bd5bd1988d 100644 --- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc @@ -59,6 +59,7 @@ const char* kApplyDropoutAttributeName = "apply_dropout"; const char* kApplyAveragingAttributeName = "apply_averaging"; const char* kDropoutInfoOutputTensorName = "drop_out_tree_indices_weights"; const char* kPredictionsTensorName = "predictions"; +const char* kLeafIndexTensorName = "leaf_index"; void CalculateTreesToInclude( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, @@ -170,15 +171,22 @@ class GradientTreesPredictionOp : public OpKernel { core::ScopedUnref unref_me(ensemble_resource); if (use_locking_) { tf_shared_lock l(*ensemble_resource->get_mutex()); - DoCompute(context, ensemble_resource); + DoCompute(context, ensemble_resource, + /*return_output_leaf_index=*/false); } else { - DoCompute(context, ensemble_resource); + DoCompute(context, ensemble_resource, + /*return_output_leaf_index=*/false); } } - private: - void DoCompute(OpKernelContext* context, - DecisionTreeEnsembleResource* ensemble_resource) { + protected: + // return_output_leaf_index is a boolean variable indicating whether to output + // leaf index in prediction. Though this class invokes only with this param + // value as false, the subclass GradientTreesPredictionVerboseOp will invoke + // with the true value. + virtual void DoCompute(OpKernelContext* context, + DecisionTreeEnsembleResource* ensemble_resource, + const bool return_output_leaf_index) { // Read dense float features list; OpInputList dense_float_features_list; OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures( @@ -267,6 +275,14 @@ class GradientTreesPredictionOp : public OpKernel { &output_predictions_t)); auto output_predictions = output_predictions_t->matrix(); + // Allocate output leaf index matrix. + Tensor* output_leaf_index_t = nullptr; + if (return_output_leaf_index) { + OP_REQUIRES_OK(context, context->allocate_output( + kLeafIndexTensorName, + {batch_size, ensemble_resource->num_trees()}, + &output_leaf_index_t)); + } // Run predictor. thread::ThreadPool* const worker_threads = context->device()->tensorflow_cpu_worker_threads()->workers; @@ -288,11 +304,13 @@ class GradientTreesPredictionOp : public OpKernel { i, weight * (num_ensembles - i + start_averaging) / num_ensembles); } MultipleAdditiveTrees::Predict(adjusted, trees_to_include, batch_features, - worker_threads, output_predictions); + worker_threads, output_predictions, + output_leaf_index_t); } else { MultipleAdditiveTrees::Predict( ensemble_resource->decision_tree_ensemble(), trees_to_include, - batch_features, worker_threads, output_predictions); + batch_features, worker_threads, output_predictions, + output_leaf_index_t); } // Output dropped trees and original weights. @@ -302,7 +320,6 @@ class GradientTreesPredictionOp : public OpKernel { {2, static_cast(dropped_trees.size())}, &output_dropout_info_t)); auto output_dropout_info = output_dropout_info_t->matrix(); - for (int32 i = 0; i < dropped_trees.size(); ++i) { output_dropout_info(0, i) = dropped_trees[i]; output_dropout_info(1, i) = original_weights[i]; @@ -326,6 +343,27 @@ class GradientTreesPredictionOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("GradientTreesPrediction").Device(DEVICE_CPU), GradientTreesPredictionOp); +// GradientTreesPredictionVerboseOp is derived from GradientTreesPredictionOp +// and have an additional output of tensor of rank 2 containing leaf ids for +// each tree where an instance ended up with. +class GradientTreesPredictionVerboseOp : public GradientTreesPredictionOp { + public: + explicit GradientTreesPredictionVerboseOp(OpKernelConstruction* const context) + : GradientTreesPredictionOp(context) {} + + protected: + void DoCompute(OpKernelContext* context, + DecisionTreeEnsembleResource* ensemble_resource, + bool return_output_leaf_index) override { + GradientTreesPredictionOp::DoCompute(context, ensemble_resource, + /*return_output_leaf_index=*/true); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("GradientTreesPredictionVerbose").Device(DEVICE_CPU), + GradientTreesPredictionVerboseOp); + class GradientTreesPartitionExamplesOp : public OpKernel { public: explicit GradientTreesPartitionExamplesOp(OpKernelConstruction* const context) diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index 0f4c2298f56be48bb32f52d5d44cff8afe284f1e..0b28f81e7ca9a1228adc5bde19c429265e0aa9b8 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -253,7 +253,7 @@ class CreateQuantileAccumulatorOp : public OpKernel { private: float epsilon_; int32 num_quantiles_; - // An upperbound on the number of enteries that the summaries might have + // An upper bound on the number of entries that the summaries might have // for a feature. int64 max_elements_; bool generate_quantiles_; diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 44a8ffaf4b2f5a9c11b3abc46ce55a18c80ad318..401bec84a20a0fefcddbfa1039a117e65f853633 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -43,47 +43,60 @@ namespace { const int32 DUMMY_FEATURE_DIMENSION = -1; } // namespace -class BaseBuildSplitOp : public OpKernel { +class SplitBuilderState { public: - explicit BaseBuildSplitOp(OpKernelConstruction* const context) - : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("feature_column_group_id", - &feature_column_group_id_)); + explicit SplitBuilderState(OpKernelContext* const context) { + const Tensor* l1_regularization_t; OP_REQUIRES_OK(context, - context->GetAttr("l1_regularization", &l1_regularization_)); + context->input("l1_regularization", &l1_regularization_t)); + const Tensor* l2_regularization_t; OP_REQUIRES_OK(context, - context->GetAttr("l2_regularization", &l2_regularization_)); - OP_REQUIRES_OK(context, context->GetAttr("tree_complexity_regularization", - &tree_complexity_regularization_)); + context->input("l2_regularization", &l2_regularization_t)); + const Tensor* tree_complexity_regularization_t; + OP_REQUIRES_OK(context, context->input("tree_complexity_regularization", + &tree_complexity_regularization_t)); + const Tensor* min_node_weight_t; OP_REQUIRES_OK(context, - context->GetAttr("min_node_weight", &min_node_weight_)); + context->input("min_node_weight", &min_node_weight_t)); - int strategy; - OP_REQUIRES_OK(context, context->GetAttr("multiclass_strategy", &strategy)); + const Tensor* feature_column_group_id_t; + OP_REQUIRES_OK(context, context->input("feature_column_group_id", + &feature_column_group_id_t)); + + const Tensor* multiclass_strategy_t; + OP_REQUIRES_OK( + context, context->input("multiclass_strategy", &multiclass_strategy_t)); + int strategy = multiclass_strategy_t->scalar()(); OP_REQUIRES( context, boosted_trees::learner::LearnerConfig_MultiClassStrategy_IsValid( strategy), errors::InvalidArgument("Wrong multiclass strategy passed.")); - multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy); - } - NodeStats ComputeNodeStats(const GradientStats& grad_stats) { - return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_, - multiclass_strategy_, grad_stats); - } + multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy); - void ReadClassId(OpKernelContext* const context, int32* class_id) { const Tensor* class_id_t; OP_REQUIRES_OK(context, context->input("class_id", &class_id_t)); OP_REQUIRES(context, TensorShapeUtils::IsScalar(class_id_t->shape()), errors::InvalidArgument("class_id must be a scalar.")); - *class_id = class_id_t->scalar()(); + class_id_ = class_id_t->scalar()(); + + l1_regularization_ = l1_regularization_t->scalar()(); + l2_regularization_ = l2_regularization_t->scalar()(); + tree_complexity_regularization_ = + tree_complexity_regularization_t->scalar()(); + min_node_weight_ = min_node_weight_t->scalar()(); + feature_column_group_id_ = feature_column_group_id_t->scalar()(); } - void FillLeaf(const int class_id, const NodeStats& best_node_stats, + NodeStats ComputeNodeStats(const GradientStats& grad_stats) { + return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_, + multiclass_strategy_, grad_stats); + } + + void FillLeaf(const NodeStats& best_node_stats, boosted_trees::trees::Leaf* leaf) const { - if (class_id == -1) { + if (class_id_ == -1) { // This would be the case either for TREE_PER_CLASS with only 2 classes, // or for other multiclass strategies. for (float f : best_node_stats.weight_contribution) { @@ -93,25 +106,31 @@ class BaseBuildSplitOp : public OpKernel { CHECK(best_node_stats.weight_contribution.size() == 1) << "Weight contribution size = " << best_node_stats.weight_contribution.size(); - leaf->mutable_sparse_vector()->add_index(class_id); + leaf->mutable_sparse_vector()->add_index(class_id_); leaf->mutable_sparse_vector()->add_value( best_node_stats.weight_contribution[0]); } } - protected: + int32 feature_column_group_id() { return feature_column_group_id_; } + float tree_complexity_regularization() { + return tree_complexity_regularization_; + } + + private: LearnerConfig_MultiClassStrategy multiclass_strategy_; - int32 feature_column_group_id_; float l1_regularization_; float l2_regularization_; - float min_node_weight_; float tree_complexity_regularization_; + float min_node_weight_; + int32 class_id_; + int32 feature_column_group_id_; }; -class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { +class BuildDenseInequalitySplitsOp : public OpKernel { public: explicit BuildDenseInequalitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) {} + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -139,9 +158,6 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); - // Find the number of unique partitions before we allocate the output. std::vector partition_boundaries; partition_boundaries.push_back(0); @@ -185,6 +201,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits::lowest(); int start_index = partition_boundaries[root_idx]; @@ -196,7 +213,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats(*gradients_t, *hessians_t, bucket_idx); } root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); int32 best_bucket_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -206,10 +223,10 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats g(*gradients_t, *hessians_t, bucket_idx); g *= normalizer_ratio; left_gradient_stats += g; - NodeStats left_stats = ComputeNodeStats(left_gradient_stats); + NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats right_stats = ComputeNodeStats(right_gradient_stats); + NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -220,18 +237,18 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { SplitInfo split_info; auto* dense_split = split_info.mutable_split_node()->mutable_dense_float_binary_split(); - dense_split->set_feature_column(feature_column_group_id_); + dense_split->set_feature_column(state.feature_column_group_id()); dense_split->set_threshold( bucket_boundaries(bucket_ids(best_bucket_idx, 0))); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(start_index); } } @@ -239,13 +256,10 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU), BuildDenseInequalitySplitsOp); -class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { +class BuildSparseInequalitySplitsOp : public OpKernel { public: explicit BuildSparseInequalitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) { - OP_REQUIRES_OK(context, - context->GetAttr("bias_feature_id", &bias_feature_id_)); - } + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -275,8 +289,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); + const Tensor* bias_feature_id_t; + OP_REQUIRES_OK(context, + context->input("bias_feature_id", &bias_feature_id_t)); + int64 bias_feature_id = bias_feature_id_t->scalar()(); // For each partition (tree node), store starting index for each dimension. PartitionAndDimensionBoundaries partition_boundaries; @@ -354,6 +370,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); // For each tree node that needs to be split. for (int root_idx = 0; root_idx < num_elements; ++root_idx) { const auto& dimension_boundaries = @@ -372,7 +389,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { OP_REQUIRES( context, - bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id_, + bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id, errors::InvalidArgument("Bias feature ID missing.")); // Dimension for bias feature is always 0 @@ -388,7 +405,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats root_gradient_stats(*gradients_t, *hessians_t, bias_start_index); root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); // Iterate through dimensions. for (int j = 0; j < dimension_boundaries.size() - 1; ++j) { @@ -408,7 +425,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { << bucket_ids_and_dimensions(start_index, 1) << " and for " << bucket_ids_and_dimensions(end_index - 1, 0) << " " << bucket_ids_and_dimensions(end_index - 1, 1); - if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id_) { + if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id) { // 0-dimension case which has a first bucket for catch all feature. CHECK(bucket_ids_and_dimensions(start_index, 1) == 0) << "Dimension of bias feature should be 0"; @@ -422,6 +439,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats(*gradients_t, *hessians_t, bucket_idx); } present_gradient_stats *= normalizer_ratio; + GradientStats not_present = + root_gradient_stats - present_gradient_stats; + // If there was (almost) no sparsity, fix the default direction to LEFT. + bool fixed_default_direction = not_present.IsAlmostZero(); GradientStats left_gradient_stats; for (int64 element_idx = start_index; element_idx < end_index; @@ -441,11 +462,12 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { // backward pass gradients. GradientStats right_gradient_stats = present_gradient_stats - left_gradient_stats; + { - NodeStats left_stats_default_left = - ComputeNodeStats(root_gradient_stats - right_gradient_stats); + NodeStats left_stats_default_left = state.ComputeNodeStats( + root_gradient_stats - right_gradient_stats); NodeStats right_stats_default_left = - ComputeNodeStats(right_gradient_stats); + state.ComputeNodeStats(right_gradient_stats); if (left_stats_default_left.gain + right_stats_default_left.gain > best_gain) { best_gain = @@ -457,11 +479,13 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { best_dimension_idx = dimension_id; } } - { + // Consider calculating the default direction only when there were + // enough missing examples. + if (!fixed_default_direction) { NodeStats left_stats_default_right = - ComputeNodeStats(left_gradient_stats); - NodeStats right_stats_default_right = - ComputeNodeStats(root_gradient_stats - left_gradient_stats); + state.ComputeNodeStats(left_gradient_stats); + NodeStats right_stats_default_right = state.ComputeNodeStats( + root_gradient_stats - left_gradient_stats); if (left_stats_default_right.gain + right_stats_default_right.gain > best_gain) { best_gain = left_stats_default_right.gain + @@ -487,7 +511,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { ->mutable_sparse_float_binary_split_default_left() ->mutable_split(); } - dense_split->set_feature_column(feature_column_group_id_); + dense_split->set_feature_column(state.feature_column_group_id()); // Set the feature index for the best feature column. const int64 best_dimension_id = bucket_ids_and_dimensions(best_element_idx, 1); @@ -498,11 +522,11 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(bias_start_index); } } @@ -519,19 +543,14 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { // For each partition, store start indices of feature column dimensions. typedef std::vector> PartitionAndDimensionBoundaries; - - int64 bias_feature_id_; }; REGISTER_KERNEL_BUILDER(Name("BuildSparseInequalitySplits").Device(DEVICE_CPU), BuildSparseInequalitySplitsOp); -class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { +class BuildCategoricalEqualitySplitsOp : public OpKernel { public: explicit BuildCategoricalEqualitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) { - OP_REQUIRES_OK(context, - context->GetAttr("bias_feature_id", &bias_feature_id_)); - } + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -554,8 +573,10 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); + const Tensor* bias_feature_id_t; + OP_REQUIRES_OK(context, + context->input("bias_feature_id", &bias_feature_id_t)); + int64 bias_feature_id = bias_feature_id_t->scalar()(); // Find the number of unique partitions before we allocate the output. std::vector partition_boundaries; @@ -598,16 +619,17 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits::lowest(); int start_index = partition_boundaries[non_empty_partitions[root_idx]]; int end_index = partition_boundaries[non_empty_partitions[root_idx] + 1]; // First feature ID in each partition should be the bias feature. - OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id_, + OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id, errors::InvalidArgument("Bias feature ID missing.")); GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index); root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); int32 best_feature_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -618,8 +640,8 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { left_gradient_stats *= normalizer_ratio; GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats left_stats = ComputeNodeStats(left_gradient_stats); - NodeStats right_stats = ComputeNodeStats(right_gradient_stats); + NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); + NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -630,21 +652,18 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { SplitInfo split_info; auto* equality_split = split_info.mutable_split_node() ->mutable_categorical_id_binary_split(); - equality_split->set_feature_column(feature_column_group_id_); + equality_split->set_feature_column(state.feature_column_group_id()); equality_split->set_feature_id(feature_ids(best_feature_idx, 0)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(start_index); } } - - private: - int64 bias_feature_id_; }; REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 7f8dea1d3c2a04b725843f6e2932a0cdfbc7733c..1bfeed306641111718984b2097512e5ec3fa8630 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -361,27 +361,10 @@ class GrowTreeEnsembleOp : public OpKernel { // Increment attempt stats. ensemble_resource->IncrementAttempts(); - // In case we want to do feature selection and we have reached the limit, - // build a list of handlers used so far to avoid adding new features. - std::vector allowed_handlers; - if (learner_config_.constraints().max_number_of_unique_feature_columns() > - 0) { - allowed_handlers = ensemble_resource->GetUsedHandlers(); - // TODO(soroush): We can disable handlers that are not going to be used to - // avoid unnecessary computations. - if (allowed_handlers.size() < - learner_config_.constraints() - .max_number_of_unique_feature_columns()) { - // We have not reached the limit yet. Empty the list of allow features - // which means we can keep adding new features. - allowed_handlers.clear(); - } - } - // Find best splits for each active partition. std::map best_splits; - FindBestSplitsPerPartition(context, allowed_handlers, partition_ids_list, - gains_list, splits_list, &best_splits); + FindBestSplitsPerPartition(context, partition_ids_list, gains_list, + splits_list, &best_splits); // No-op if no new splits can be considered. if (best_splits.empty()) { @@ -422,19 +405,12 @@ class GrowTreeEnsembleOp : public OpKernel { // and finds the best split for each partition. void FindBestSplitsPerPartition( OpKernelContext* const context, - const std::vector& allowed_handlers, // Empty means all handlers. const OpInputList& partition_ids_list, const OpInputList& gains_list, const OpInputList& splits_list, std::map* best_splits) { // Find best split per partition going through every feature candidate. // TODO(salehay): Is this worth parallelizing? for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) { - if (!allowed_handlers.empty()) { - if (!std::binary_search(allowed_handlers.begin(), - allowed_handlers.end(), handler_id)) { - continue; - } - } const auto& partition_ids = partition_ids_list[handler_id].vec(); const auto& gains = gains_list[handler_id].vec(); const auto& splits = splits_list[handler_id].vec(); diff --git a/tensorflow/contrib/boosted_trees/lib/BUILD b/tensorflow/contrib/boosted_trees/lib/BUILD index 131bd48562a55a08981ac73277e93024db0d85d3..3028c2281705bd7e34b212332160d25386559d4e 100644 --- a/tensorflow/contrib/boosted_trees/lib/BUILD +++ b/tensorflow/contrib/boosted_trees/lib/BUILD @@ -15,17 +15,6 @@ load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - # Utils cc_library( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 7df514cd207c5e781f3b4abaa2020016b197669d..409a2d8f46c331c13aec10542c4967d50575e94a 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -64,6 +64,8 @@ from __future__ import print_function import re from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler +from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops +from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops from tensorflow.contrib.boosted_trees.python.ops import quantile_ops from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops @@ -72,9 +74,11 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops + _BIAS_FEATURE_ID = -1 # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") @@ -130,11 +134,14 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): gradient_shape, hessian_shape, name="StatsAccumulator/{}".format(self._name)) - self._quantile_accumulator = quantile_ops.QuantileAccumulator( - init_stamp_token, - epsilon=epsilon, - num_quantiles=num_quantiles, - name="QuantileAccumulator/{}".format(self._name)) + # Allocate both stats accumulator and quantile accumulator on the same + # device so that we can build splits with fewer RPCs. + with ops.colocate_with(self._stats_accumulator.resource()): + self._quantile_accumulator = quantile_ops.QuantileAccumulator( + init_stamp_token, + epsilon=epsilon, + num_quantiles=num_quantiles, + name="QuantileAccumulator/{}".format(self._name)) class DenseSplitHandler(InequalitySplitHandler): @@ -236,45 +243,74 @@ class DenseSplitHandler(InequalitySplitHandler): def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state.""" - # Get the bucket boundaries - are_splits_ready, buckets = ( - self._quantile_accumulator.get_buckets(stamp_token)) - # After we receive the boundaries from previous iteration we can flush - # the quantile accumulator. - with ops.control_dependencies([buckets]): - flush_quantiles = self._quantile_accumulator.flush( - stamp_token=stamp_token, next_stamp_token=next_stamp_token) - - # Get the aggregated gradients and hessians per - # pair. - # In order to distribute the computation on all the PSs we use the PS that - # had the stats accumulator on. - with ops.device(None): - with ops.device(self._stats_accumulator.resource().device): - num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( - self._stats_accumulator.flush(stamp_token, next_stamp_token)) - - # Put quantile and stats accumulator flushing in the dependency path. - are_splits_ready = control_flow_ops.with_dependencies( - [flush_quantiles, partition_ids], are_splits_ready) - - partition_ids, gains, split_infos = ( - split_handler_ops.build_dense_inequality_splits( - num_minibatches=num_minibatches, - bucket_boundaries=buckets, - partition_ids=partition_ids, - bucket_ids=bucket_ids, - gradients=gradients, - hessians=hessians, - class_id=class_id, - feature_column_group_id=self._feature_column_group_id, - l1_regularization=self._l1_regularization, - l2_regularization=self._l2_regularization, - tree_complexity_regularization=self. - _tree_complexity_regularization, - min_node_weight=self._min_node_weight, - multiclass_strategy=self._multiclass_strategy)) - return (are_splits_ready, partition_ids, gains, split_infos) + if (self._gradient_shape == tensor_shape.scalar() and + self._hessian_shape == tensor_shape.scalar()): + handler = make_dense_split_scalar + else: + handler = make_dense_split_tensor + + are_splits_ready, partition_ids, gains, split_infos = ( + handler(self._quantile_accumulator.resource(), + self._stats_accumulator.resource(), stamp_token, + next_stamp_token, self._multiclass_strategy, class_id, + self._feature_column_group_id, self._l1_regularization, + self._l2_regularization, self._tree_complexity_regularization, + self._min_node_weight)) + return are_splits_ready, partition_ids, gains, split_infos + + +def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, + class_id, feature_column_id, l1_regularization, + l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional): + """Function that builds splits for a dense feature column.""" + # Get the bucket boundaries + are_splits_ready, buckets = ( + gen_quantile_ops.quantile_accumulator_get_buckets( + quantile_accumulator_handles=[quantile_accumulator_handle], + stamp_token=stamp_token)) + # quantile_accumulator_get_buckets returns a list of results per handle that + # we pass to it. In this case we're getting results just for one resource. + are_splits_ready = are_splits_ready[0] + buckets = buckets[0] + + # After we receive the boundaries from previous iteration we can flush + # the quantile accumulator. + with ops.control_dependencies([buckets]): + flush_quantiles = gen_quantile_ops.quantile_accumulator_flush( + quantile_accumulator_handle=quantile_accumulator_handle, + stamp_token=stamp_token, + next_stamp_token=next_stamp_token) + + if is_multi_dimentional: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_tensor_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + else: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_scalar_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + + # Put quantile and stats accumulator flushing in the dependency path. + with ops.control_dependencies([flush_quantiles, partition_ids]): + are_splits_ready = array_ops.identity(are_splits_ready) + partition_ids, gains, split_infos = ( + split_handler_ops.build_dense_inequality_splits( + num_minibatches=num_minibatches, + bucket_boundaries=buckets, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + class_id=class_id, + feature_column_group_id=feature_column_id, + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, + min_node_weight=min_node_weight, + multiclass_strategy=multiclass_strategy)) + return are_splits_ready, partition_ids, gains, split_infos class SparseSplitHandler(InequalitySplitHandler): @@ -327,9 +363,6 @@ class SparseSplitHandler(InequalitySplitHandler): multiclass_strategy=multiclass_strategy, init_stamp_token=init_stamp_token, name=name) - # Register sparse_make_stats_update function as an Op to the graph. - g = ops.get_default_graph() - sparse_make_stats_update.add_to_graph(g) self._sparse_float_column = sparse_float_column def scheduled_reads(self): @@ -361,8 +394,8 @@ class SparseSplitHandler(InequalitySplitHandler): are_buckets_ready, buckets = scheduled_reads[0] with ops.name_scope(self._name, "SparseSplitHandler"): (quantile_indices, quantile_values, quantile_shapes, quantile_weights, - example_partition_ids, - feature_ids, gradients, hessians) = sparse_make_stats_update( + example_partition_ids, feature_ids, gradients, + hessians) = sparse_make_stats_update( is_active, are_buckets_ready, self._sparse_float_column.indices, self._sparse_float_column.values, self._sparse_float_column.dense_shape, buckets, @@ -379,47 +412,129 @@ class SparseSplitHandler(InequalitySplitHandler): def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state.""" - # Get the bucket boundaries - are_splits_ready, buckets = ( - self._quantile_accumulator.get_buckets(stamp_token)) - - # After we receive the boundaries from previous iteration we can flush - # the quantile accumulator. - with ops.control_dependencies([buckets]): - flush_quantiles = self._quantile_accumulator.flush( - stamp_token=stamp_token, next_stamp_token=next_stamp_token) - - with ops.device(None): - with ops.device(self._stats_accumulator.resource().device): - num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( - self._stats_accumulator.flush(stamp_token, next_stamp_token)) - - # Put quantile and stats accumulator flushing in the dependency path. - are_splits_ready = control_flow_ops.with_dependencies( - [flush_quantiles, partition_ids], are_splits_ready) - partition_ids, gains, split_infos = ( - split_handler_ops.build_sparse_inequality_splits( - num_minibatches=num_minibatches, - bucket_boundaries=buckets, - partition_ids=partition_ids, - bucket_ids=bucket_ids, - gradients=gradients, - hessians=hessians, - class_id=class_id, - feature_column_group_id=self._feature_column_group_id, - l1_regularization=self._l1_regularization, - l2_regularization=self._l2_regularization, - tree_complexity_regularization=self. - _tree_complexity_regularization, - min_node_weight=self._min_node_weight, - bias_feature_id=_BIAS_FEATURE_ID, - multiclass_strategy=self._multiclass_strategy)) - return (are_splits_ready, partition_ids, gains, split_infos) - - -@function.Defun(dtypes.bool, dtypes.bool, dtypes.float32, dtypes.float32, - dtypes.int32, dtypes.float32, dtypes.float32, dtypes.float32, - dtypes.float32, dtypes.float32) + if (self._gradient_shape == tensor_shape.scalar() and + self._hessian_shape == tensor_shape.scalar()): + handler = make_sparse_split_scalar + else: + handler = make_sparse_split_tensor + + are_splits_ready, partition_ids, gains, split_infos = ( + handler(self._quantile_accumulator.resource(), + self._stats_accumulator.resource(), stamp_token, + next_stamp_token, self._multiclass_strategy, class_id, + self._feature_column_group_id, self._l1_regularization, + self._l2_regularization, self._tree_complexity_regularization, + self._min_node_weight)) + return are_splits_ready, partition_ids, gains, split_infos + + +def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, + class_id, feature_column_id, l1_regularization, + l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional): + """Function that builds splits for a sparse feature column.""" + # Get the bucket boundaries + are_splits_ready, buckets = ( + gen_quantile_ops.quantile_accumulator_get_buckets( + quantile_accumulator_handles=[quantile_accumulator_handle], + stamp_token=stamp_token)) + # quantile_accumulator_get_buckets returns a list of results per handle that + # we pass to it. In this case we're getting results just for one resource. + are_splits_ready = are_splits_ready[0] + buckets = buckets[0] + + # After we receive the boundaries from previous iteration we can flush + # the quantile accumulator. + with ops.control_dependencies([buckets]): + flush_quantiles = gen_quantile_ops.quantile_accumulator_flush( + quantile_accumulator_handle=quantile_accumulator_handle, + stamp_token=stamp_token, + next_stamp_token=next_stamp_token) + + if is_multi_dimentional: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_tensor_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + else: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_scalar_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + + # Put quantile and stats accumulator flushing in the dependency path. + with ops.control_dependencies([flush_quantiles, partition_ids]): + are_splits_ready = array_ops.identity(are_splits_ready) + partition_ids, gains, split_infos = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=num_minibatches, + bucket_boundaries=buckets, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + class_id=class_id, + feature_column_group_id=feature_column_id, + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, + min_node_weight=min_node_weight, + bias_feature_id=_BIAS_FEATURE_ID, + multiclass_strategy=multiclass_strategy)) + return are_splits_ready, partition_ids, gains, split_infos + + +def _specialize_make_split(func, is_multi_dimentional): + """Builds a specialized version of the function.""" + + @function.Defun( + dtypes.resource, + dtypes.resource, + dtypes.int64, + dtypes.int64, + dtypes.int32, + dtypes.int32, + dtypes.int32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + noinline=True) + def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight): + """Function that builds splits for a sparse feature column.""" + return func( + quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional) + + return f + +make_dense_split_scalar = _specialize_make_split(_make_dense_split, + is_multi_dimentional=False) +make_dense_split_tensor = _specialize_make_split(_make_dense_split, + is_multi_dimentional=True) + +make_sparse_split_scalar = _specialize_make_split(_make_sparse_split, + is_multi_dimentional=False) +make_sparse_split_tensor = _specialize_make_split(_make_sparse_split, + is_multi_dimentional=True) + + +@function.Defun( + dtypes.bool, + dtypes.bool, + dtypes.float32, + dtypes.float32, + dtypes.int32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + noinline=True) def dense_make_stats_update(is_active, are_buckets_ready, float_column, quantile_buckets, example_partition_ids, gradients, hessians, weights, empty_gradients, empty_hessians): @@ -452,9 +567,20 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, gradients, hessians) -@function.Defun(dtypes.bool, dtypes.bool, dtypes.int64, dtypes.float32, - dtypes.int64, dtypes.float32, dtypes.int32, dtypes.float32, - dtypes.float32, dtypes.float32, dtypes.float32, dtypes.float32) +@function.Defun( + dtypes.bool, + dtypes.bool, + dtypes.int64, + dtypes.float32, + dtypes.int64, + dtypes.float32, + dtypes.int32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + noinline=True) def sparse_make_stats_update( is_active, are_buckets_ready, sparse_column_indices, sparse_column_values, sparse_column_shape, quantile_buckets, example_partition_ids, gradients, @@ -481,11 +607,18 @@ def sparse_make_stats_update( example_partition_ids) # Compute aggregate stats for each partition. + # Since unsorted_segment_sum can be numerically unstable, use 64bit + # operation. + gradients64 = math_ops.cast(gradients, dtypes.float64) + hessians64 = math_ops.cast(hessians, dtypes.float64) per_partition_gradients = math_ops.unsorted_segment_sum( - gradients, mapped_partitions, array_ops.size(unique_partitions)) + gradients64, mapped_partitions, array_ops.size(unique_partitions)) per_partition_hessians = math_ops.unsorted_segment_sum( - hessians, mapped_partitions, array_ops.size(unique_partitions)) - + hessians64, mapped_partitions, array_ops.size(unique_partitions)) + per_partition_gradients = math_ops.cast(per_partition_gradients, + dtypes.float32) + per_partition_hessians = math_ops.cast(per_partition_hessians, + dtypes.float32) # Prepend a bias feature per partition that accumulates the stats for all # examples in that partition. bias_feature_ids = array_ops.fill( @@ -513,8 +646,9 @@ def sparse_make_stats_update( empty_float = constant_op.constant([], dtype=dtypes.float32) handler_not_active = (constant_op.constant( - [], dtype=dtypes.int64, shape=[0, 2]), empty_float, constant_op.constant( - [0, 1], dtype=dtypes.int64), empty_float) + [], dtype=dtypes.int64, shape=[0, 2]), empty_float, + constant_op.constant([0, 1], dtype=dtypes.int64), + empty_float) handler_active = (sparse_column_indices, sparse_column_values, sparse_column_shape, weights) quantile_indices, quantile_values, quantile_shape, quantile_weights = ( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 54d03018d9e266beabbbabd78ebbb80cfe689c04..2f2c2302113bf59d6a065d5005c934dc76c2148d 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.boosted_trees.lib.learner.batch import ordinal_split_handler from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import split_info_pb2 @@ -65,9 +67,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): hessian_shape = tensor_shape.scalar() split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=10, feature_column_group_id=0, @@ -92,7 +94,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -105,7 +109,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -199,10 +203,10 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): hessian_shape = tensor_shape.TensorShape([2, 2]) split_handler = ordinal_split_handler.DenseSplitHandler( - l1_regularization=0, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0., + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=3, feature_column_group_id=0, @@ -227,7 +231,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -240,7 +246,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -285,10 +291,10 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): hessian_shape = tensor_shape.TensorShape([2]) split_handler = ordinal_split_handler.DenseSplitHandler( - l1_regularization=0, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0., + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=3, feature_column_group_id=0, @@ -313,7 +319,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -326,7 +333,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -369,9 +376,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=10, feature_column_group_id=0, @@ -396,7 +403,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, False])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -409,7 +417,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([False, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -443,9 +451,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, + l2_regularization=1., tree_complexity_regularization=0.5, - min_node_weight=0, + min_node_weight=0., epsilon=0.001, num_quantiles=10, feature_column_group_id=0, @@ -470,7 +478,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -483,7 +492,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -576,7 +585,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, + l2_regularization=1., tree_complexity_regularization=0.5, min_node_weight=1.5, epsilon=0.001, @@ -603,7 +612,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -616,7 +626,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -685,10 +695,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -713,8 +723,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] - + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -727,7 +737,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -811,10 +821,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -839,7 +849,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -853,7 +864,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -905,10 +916,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -933,7 +944,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -947,7 +959,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -996,10 +1008,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1024,7 +1036,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, False])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1038,7 +1051,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([False, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -1065,10 +1078,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1096,7 +1109,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1110,7 +1124,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -1138,10 +1152,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1166,7 +1180,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1180,7 +1195,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h index cd925f6b65e569538212e9c26aef0abc8482960b..794ba2bcb0aafa26c5e1c90fcd66caf9dd5bf7d5 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h @@ -137,7 +137,7 @@ struct NodeStats { Eigen::MatrixXf hessian = TensorToEigenMatrix(grad_stats.second.t, grad_dim, grad_dim); // I is an identity matrix. - // The gain in general form is -g^T (H+l2 I)^-1 g. + // The gain in general form is g^T (H+l2 I)^-1 g. // The node weights are -(H+l2 I)^-1 g. Eigen::MatrixXf identity; identity.setIdentity(grad_dim, grad_dim); @@ -240,7 +240,7 @@ struct NodeStats { // given regularized Hessian and gradient vector g. void CalculateWeightAndGain(const Eigen::MatrixXf& hessian_and_reg, const Eigen::VectorXf& g) { - // The gain in general form is -g^T (Hessian_and_regularization)^-1 g. + // The gain in general form is g^T (Hessian_and_regularization)^-1 g. // The node weights are -(Hessian_and_regularization)^-1 g. Eigen::VectorXf weight; // If we want to calculate x = K^-1 v, instead of explicitly calculating diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc index 43b00d4c6dc2e0066810012292874314215c41be..c9223afeab233497bce9f680bd44bd10ccfc6491 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc @@ -26,7 +26,8 @@ void MultipleAdditiveTrees::Predict( const std::vector& trees_to_include, const boosted_trees::utils::BatchFeatures& features, tensorflow::thread::ThreadPool* const worker_threads, - tensorflow::TTypes::Matrix output_predictions) { + tensorflow::TTypes::Matrix output_predictions, + Tensor* const output_leaf_index) { // Zero out predictions as the model is additive. output_predictions.setZero(); @@ -38,8 +39,13 @@ void MultipleAdditiveTrees::Predict( // Lambda for doing a block of work. auto update_predictions = [&config, &features, &trees_to_include, - &output_predictions](int64 start, int64 end) { + &output_predictions, + &output_leaf_index](int64 start, int64 end) { auto examples_iterable = features.examples_iterable(start, end); + Tensor dummy_tensor(DT_INT32, TensorShape({1, 1})); + tensorflow::TTypes::Matrix output_leaf_index_mat = + output_leaf_index != nullptr ? output_leaf_index->matrix() + : dummy_tensor.matrix(); for (const auto& example : examples_iterable) { for (const int32 tree_idx : trees_to_include) { const boosted_trees::trees::DecisionTreeConfig& tree = @@ -47,6 +53,10 @@ void MultipleAdditiveTrees::Predict( const float tree_weight = config.tree_weights(tree_idx); const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example); QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString(); + // Checks if output leaf tree index is required. + if (output_leaf_index != nullptr) { + output_leaf_index_mat(example.example_idx, tree_idx) = leaf_idx; + } const auto& leaf_node = tree.nodes(leaf_idx); QCHECK(leaf_node.has_leaf()) << "Invalid leaf node: " << leaf_node.DebugString(); diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h index cc3dc226cdbc88fc7010ada1e7f0e6c0a3913c5f..940531c4ba4bcac19fa980deb091e55b48e0693b 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h @@ -33,12 +33,17 @@ class MultipleAdditiveTrees { public: // Predict runs tree ensemble on the given batch and updates // output predictions accordingly, for the given list of trees. + // output_leaf_indices is a pointer to a 2 dimensional tensor. If it is not + // nullptr, this method fills output_leaf_indices with a per-tree leaf id + // where each of the instances from 'features' ended up in. Its shape is num + // examples X num of trees. static void Predict( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, const std::vector& trees_to_include, const boosted_trees::utils::BatchFeatures& features, tensorflow::thread::ThreadPool* const worker_threads, - tensorflow::TTypes::Matrix output_predictions); + tensorflow::TTypes::Matrix output_predictions, + Tensor* const output_leaf_index); }; } // namespace models diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc index 4ca18bedb1054ef64c6d4b25bbad04842bab1a6a..462a9ac86fe51d07cfb958d9be49bef84811a52e 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc @@ -62,7 +62,8 @@ TEST_F(MultipleAdditiveTreesTest, Empty) { tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test", kNumThreadsSingleThreaded); MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, + /*output_leaf_index=*/nullptr); EXPECT_EQ(0, output_matrix(0, 0)); EXPECT_EQ(0, output_matrix(1, 0)); } @@ -99,17 +100,38 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) { // Normal case. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, - batch_features_, &threads, output_matrix); + batch_features_, &threads, output_matrix, + /*output_leaf_index=*/nullptr); EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). } + // Normal case with leaf node. + { + // Initialize output leaf index tensor, since leaf index is positive in this + // case, initialize with the value of -1. Since there are 2 examples and + // there are 2 trees, initialize leaf output index by 2 * 2. + Tensor output_leaf_index_tensor(DT_INT32, TensorShape({2, 2})); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, + batch_features_, &threads, output_matrix, + &output_leaf_index_tensor); + EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). + EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). + EXPECT_FLOAT_EQ(0, output_leaf_index_tensor.matrix()( + 0, 0)); // 1st leaf for the first example + EXPECT_FLOAT_EQ(0, output_leaf_index_tensor.matrix()( + 1, 0)); // 1st leaf for the second example + EXPECT_FLOAT_EQ(2, output_leaf_index_tensor.matrix()( + 0, 1)); // 2nd leaf for the first example + EXPECT_FLOAT_EQ(1, output_leaf_index_tensor.matrix()( + 1, 1)); // 2nd leaf for the second example + } // Weighted case { DecisionTreeEnsembleConfig weighted = tree_ensemble_config; weighted.set_tree_weights(0, 6.0); weighted.set_tree_weights(1, 3.2); MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads, - output_matrix); + output_matrix, nullptr); // -0.4 (bias) + 0.2 (leaf 2). EXPECT_FLOAT_EQ(-0.4f * 6 + 0.2 * 3.2, output_matrix(0, 0)); // -0.4 (bias) + 0.9 (leaf 1). @@ -118,21 +140,21 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) { // Drop first tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 0)); // 0.2 (leaf 2). EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 1). } // Drop second tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias). EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias). } // Drop all trees. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.0, output_matrix(1, 0)); } @@ -172,7 +194,8 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Normal case. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, - batch_features_, &threads, output_matrix); + batch_features_, &threads, output_matrix, + nullptr); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias) EXPECT_FLOAT_EQ(-0.5f, output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1) @@ -184,7 +207,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { weighted.set_tree_weights(0, 6.0); weighted.set_tree_weights(1, 3.2); MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads, - output_matrix); + output_matrix, nullptr); // bias EXPECT_FLOAT_EQ(-0.4f * 6, output_matrix(0, 0)); // bias + leaf 2 @@ -197,7 +220,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Dropout first tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 1)); // 0.2 (leaf 2) EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 2) @@ -206,7 +229,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Dropout second tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias) EXPECT_FLOAT_EQ(-0.7f, output_matrix(0, 1)); // -0.7 (bias) EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias) @@ -215,7 +238,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Drop both trees. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 1)); EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 0)); @@ -258,7 +281,8 @@ TEST_F(MultipleAdditiveTreesTest, DenseLeaves) { // Normal case. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, - batch_features_, &threads, output_matrix); + batch_features_, &threads, output_matrix, + nullptr); EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (tree1) + 0.2 (leaf 2) EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 1)); // -0.7 (tree1) + 0.3 (leaf 2) EXPECT_FLOAT_EQ(3.4f, output_matrix(0, 2)); // 3.0 -(tree1) + 0.4 (leaf 2) diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h index 8ad97fedc923ac50bcaad86e0ba2c2e46df6821b..c120dd8a6c156ec9eb7ba0b6c552f5138bd21a16 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h @@ -295,7 +295,7 @@ WeightedQuantilesStream::GetQuantileSpecs( if (eps <= std::numeric_limits::epsilon()) { // Exact quantile computation at the expense of RAM. max_level = 1; - block_size = std::max(max_elements, 2LL); + block_size = std::max(max_elements, int64{2}); } else { // The bottom-most level will become full at most // (max_elements / block_size) times, the level above will become full @@ -315,7 +315,7 @@ WeightedQuantilesStream::GetQuantileSpecs( block_size = static_cast(ceil(max_level / eps)) + 1; } } - return std::make_tuple(max_level, std::max(block_size, 2LL)); + return std::make_tuple(max_level, std::max(block_size, int64{2})); } } // namespace quantiles diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream_test.cc b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream_test.cc index 4481c0d0e4400acd93c9a277de185db7aaf9bcb0..67ac9bf387ae9b3ca29e610c2c4138c28302ca33 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream_test.cc @@ -138,6 +138,12 @@ void GenerateOneValue(int32 worker_id, int64 max_elements, double *total_weight, stream->Finalize(); } +void GenerateOneZeroWeightedValue(int32 worker_id, int64 max_elements, + double *total_weight, Stream *stream) { + stream->PushEntry(10, 0); + stream->Finalize(); +} + TEST(WeightedQuantilesStreamTest, OneValue) { const double eps = 0.01; const int64 max_elements = 1 << 16; @@ -145,6 +151,13 @@ TEST(WeightedQuantilesStreamTest, OneValue) { {10.0, 10.0, 10.0, 10.0, 10.0}, 1e-2); } +TEST(WeightedQuantilesStreamTest, OneZeroWeightValue) { + const double eps = 0.01; + const int64 max_elements = 1 << 16; + TestSingleWorkerStreams(eps, max_elements, GenerateOneZeroWeightedValue, {}, + 1e-2); +} + TEST(WeightedQuantilesStreamTest, FixedUniform) { const double eps = 0.01; const int64 max_elements = 1 << 16; diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h index aec232f3cbb096f0aa51e4362a821882391f8027..a7e7bfc13cadcea4d29d33e0dbd955bdad6ffcb9 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h @@ -195,7 +195,7 @@ class WeightedQuantilesSummary { // designed to be cache-friendly. void Compress(int64 size_hint, double min_eps = 0) { // No-op if we're already within the size requirement. - size_hint = std::max(size_hint, 2LL); + size_hint = std::max(size_hint, int64{2}); if (entries_.size() <= size_hint) { return; } @@ -235,6 +235,11 @@ class WeightedQuantilesSummary { // The resulting boundaries are guaranteed to both contain at least // num_boundaries unique elements and maintain approximation bounds. std::vector GenerateBoundaries(int64 num_boundaries) const { + std::vector output; + if (entries_.empty()) { + return output; + } + // Generate soft compressed summary. WeightedQuantilesSummary compressed_summary; @@ -246,7 +251,6 @@ class WeightedQuantilesSummary { compressed_summary.Compress(num_boundaries, compression_eps); // Return boundaries. - std::vector output; output.reserve(compressed_summary.entries_.size()); for (const auto& entry : compressed_summary.entries_) { output.push_back(entry.value); @@ -260,7 +264,10 @@ class WeightedQuantilesSummary { // full rank queries O(nlogn). std::vector GenerateQuantiles(int64 num_quantiles) const { std::vector output; - num_quantiles = std::max(num_quantiles, 2LL); + if (entries_.empty()) { + return output; + } + num_quantiles = std::max(num_quantiles, int64{2}); output.reserve(num_quantiles + 1); // Make successive rank queries to get boundaries. diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc index cf4f9a097a3368465fd4d9afb981bbaa68b4df49..35b059f3496dbc8fb2b3d4fe6ec6b55a9d73dd0c 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc @@ -54,7 +54,7 @@ Status BatchFeatures::Initialize( TF_CHECK_AND_RETURN_IF_ERROR( dense_float_feature.dim_size(1) == 1, errors::InvalidArgument( - "Dense float features may not be multi-valent: dim_size(1) = ", + "Dense float features may not be multivalent: dim_size(1) = ", dense_float_feature.dim_size(1))); dense_float_feature_columns_.emplace_back(dense_float_feature); } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h index da5e7448519cb7f4092f7bbbe1b526271008ec22..a3b1b013e3a40116f74d6ed2df78d87ed3a11ac7 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h @@ -48,9 +48,9 @@ class BatchFeatures { Status GetFeatureColumnSizes(int64* const num_dense_float_features, int64* const num_sparse_float_features, int64* const num_sparse_int_features) const { - QCHECK_NE(num_dense_float_features, nullptr); - QCHECK_NE(num_sparse_float_features, nullptr); - QCHECK_NE(num_sparse_int_features, nullptr); + QCHECK_NE(num_dense_float_features, static_cast(nullptr)); + QCHECK_NE(num_sparse_float_features, static_cast(nullptr)); + QCHECK_NE(num_sparse_int_features, static_cast(nullptr)); *num_dense_float_features = dense_float_feature_columns_.size(); *num_sparse_float_features = sparse_float_feature_columns_.size(); *num_sparse_int_features = sparse_int_feature_columns_.size(); diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc index 609519e8b1153a27d987c5f9ca9bfcc9ee6717d6..cfe9101e7435cd798569f3e52a87fc8ed7b6a239 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc @@ -59,7 +59,7 @@ TEST_F(BatchFeaturesTest, DenseFloatFeatures_Multivalent) { BatchFeatures batch_features(1); auto dense_vec = AsTensor({3.0f, 7.0f}, {1, 2}); auto expected_error = InvalidArgument( - "Dense float features may not be multi-valent: dim_size(1) = 2"); + "Dense float features may not be multivalent: dim_size(1) = 2"); EXPECT_EQ(expected_error, batch_features.Initialize({dense_vec}, {}, {}, {}, {}, {}, {})); } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc index db34db998a7442c69f2ab468f4557d991429f4ee..ce67db797ded54f5023eaa89369d4781aad31a7c 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc @@ -54,7 +54,7 @@ Status DropoutUtils::DropOutTrees( if (probability_of_skipping_dropout < 0 || probability_of_skipping_dropout > 1) { return errors::InvalidArgument( - "Probability of skiping dropout must be in [0,1] range"); + "Probability of skipping dropout must be in [0,1] range"); } const auto num_trees = weights.size(); diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h index 928bfbfe5c9394ab4083aabced4c8e1149bb10aa..77c16da5410fe65b20839c7b6bc677067d7ff297 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h @@ -66,7 +66,7 @@ class DropoutUtils { // Current weights and num_updates will be updated as a result of this // func std::vector* current_weights, - // How many weight assignements have been done for each tree already. + // How many weight assignments have been done for each tree already. std::vector* num_updates); }; diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc index 0138aae3dbd3773241cb6644db625b99f9bf1372..cc7604745e6bb90837eeca1123faa88dc914e4fc 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc @@ -34,7 +34,7 @@ TEST_F(SparseColumnIterableTest, Empty) { } TEST_F(SparseColumnIterableTest, Iterate) { - // 8 examples having 7 sparse features with the 3rd and 7th multi-valent. + // 8 examples having 7 sparse features with the 3rd and 7th multivalent. // This can be visualized like the following: // Instance | Sparse | // 0 | x | diff --git a/tensorflow/contrib/boosted_trees/ops/model_ops.cc b/tensorflow/contrib/boosted_trees/ops/model_ops.cc index 0786c4166410720e8d4d70960e5747ff111076d8..9d6343c7e80f369bf6a5465821c5f4bacb984cd0 100644 --- a/tensorflow/contrib/boosted_trees/ops/model_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/model_ops.cc @@ -110,5 +110,32 @@ stamp_token: Token to use as the new value of the resource stamp. tree_ensemble_config: Serialized proto of the ensemble. )doc"); +REGISTER_OP("TreeEnsembleUsedHandlers") + .Attr("num_all_handlers: int >= 0") + .Input("tree_ensemble_handle: resource") + .Input("stamp_token: int64") + .Output("num_used_handlers: int64") + .Output("used_handlers_mask: bool") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused_input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + c->set_output(0, c->Scalar()); + int num_all_handlers; + c->GetAttr("num_all_handlers", &num_all_handlers).IgnoreError(); + c->set_output(1, {c->Vector(num_all_handlers)}); + + return Status::OK(); + }) + .Doc(R"doc( +Returns the mask of used handlers along with the number of non-zero elements in +this mask. Used in feature selection. + +tree_ensemble_handle: Handle to the tree ensemble. +stamp_token: Token to use as the new value of the resource stamp. +num_used_handlers: number of feature column handlers used in the model. +used_handlers_mask: A boolean vector of showing which handlers are used in the + model. +)doc"); + } // namespace boosted_trees } // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc index d66f645f62aba84261337eb37d6e3204930f8f15..6491d58794332e9417951753532e018aafb652b1 100644 --- a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc @@ -40,6 +40,24 @@ static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) { return Status::OK(); } +static Status ApplyGradientTreesPredictionVerboseShapeFn(InferenceContext* c) { + string learner_config_str; + c->GetAttr("learner_config", &learner_config_str).IgnoreError(); + LearnerConfig learner_config; + ParseProtoUnlimited(&learner_config, learner_config_str); + + bool reduce_dim; + c->GetAttr("reduce_dim", &reduce_dim).IgnoreError(); + // Sets the shape of the output as a matrix. + c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim, + reduce_dim ? learner_config.num_classes() - 1 + : learner_config.num_classes())}); + c->set_output(1, {c->UnknownShape()}); + c->set_output(2, {c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)}); + return Status::OK(); +} + REGISTER_OP("GradientTreesPrediction") .Attr("learner_config: string") .Attr("num_dense_float_features: int >= 0") @@ -90,6 +108,58 @@ drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices and original weights of those trees during prediction. )doc"); +REGISTER_OP("GradientTreesPredictionVerbose") + .Attr("learner_config: string") + .Attr("num_dense_float_features: int >= 0") + .Attr("num_sparse_float_features: int >= 0") + .Attr("num_sparse_int_features: int >= 0") + .Attr("use_locking: bool = false") + .Attr("apply_dropout: bool") + .Attr("apply_averaging: bool") + .Attr("center_bias: bool") + .Attr("reduce_dim: bool") + .Input("tree_ensemble_handle: resource") + .Input("seed: int64") + .Input("dense_float_features: num_dense_float_features * float") + .Input("sparse_float_feature_indices: num_sparse_float_features * int64") + .Input("sparse_float_feature_values: num_sparse_float_features * float") + .Input("sparse_float_feature_shapes: num_sparse_float_features * int64") + .Input("sparse_int_feature_indices: num_sparse_int_features * int64") + .Input("sparse_int_feature_values: num_sparse_int_features * int64") + .Input("sparse_int_feature_shapes: num_sparse_int_features * int64") + .Output("predictions: float") + .Output("drop_out_tree_indices_weights: float") + .Output("leaf_index: int32") + .SetShapeFn(ApplyGradientTreesPredictionVerboseShapeFn) + .Doc(R"doc( +Runs multiple additive regression forests predictors on input instances +and computes the final prediction for each class, and outputs a matrix of +leaf ids per each tree in an ensemble. + +learner_config: Config for the learner of type LearnerConfig proto. Prediction +ops for now uses only LearningRateDropoutDrivenConfig config from the learner. +num_dense_float_features: Number of dense float features. +num_sparse_float_features: Number of sparse float features. +num_sparse_int_features: Number of sparse int features. +use_locking: Whether to use locking. +seed: random seed to be used for dropout. +reduce_dim: whether to reduce the dimension (legacy impl) or not. +apply_dropout: whether to apply dropout during prediction. +apply_averaging: whether averaging of tree ensembles should take place. If set +to true, will be based on AveragingConfig from learner_config. +tree_ensemble_handle: The handle to the tree ensemble. +dense_float_features: Rank 2 Tensors containing dense float feature values. +sparse_float_feature_indices: Rank 2 Tensors containing sparse float indices. +sparse_float_feature_values: Rank 1 Tensors containing sparse float values. +sparse_float_feature_shapes: Rank 1 Tensors containing sparse float shapes. +sparse_int_feature_indices: Rank 2 Tensors containing sparse int indices. +sparse_int_feature_values: Rank 1 Tensors containing sparse int values. +sparse_int_feature_shapes: Rank 1 Tensors containing sparse int shapes. +predictions: Rank 2 Tensor containing predictions per example per class. +drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices +leaf_index: tensor of rank 2 containing leaf ids for each tree where an instance ended up. +)doc"); + REGISTER_OP("GradientTreesPartitionExamples") .Attr("num_dense_float_features: int >= 0") .Attr("num_sparse_float_features: int >= 0") diff --git a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc index ae99d53a2cf805d70d60746cd44f73f7fd9dc6e2..6aa52463987b55a54b7308765920cbe94c15b8d1 100644 --- a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc @@ -272,6 +272,20 @@ REGISTER_OP("Quantiles") .Input("sparse_indices: num_sparse_features * int64") .Output("dense_quantiles: num_dense_features * int32") .Output("sparse_quantiles: num_sparse_features * int32") + .SetShapeFn([](InferenceContext* c) { + int num_dense_features; + TF_RETURN_IF_ERROR(c->GetAttr("num_dense_features", &num_dense_features)); + int num_sparse_features; + TF_RETURN_IF_ERROR( + c->GetAttr("num_sparse_features", &num_sparse_features)); + // Set output shapes (dense_quantiles and sparse_quantiles) by the + // relevant inputs (dense_values and sparse_values). Note that the output + // has an additional dimension for dimension_ids. + for (int i = 0; i < num_dense_features + num_sparse_features; ++i) { + c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0), 2})); + } + return Status::OK(); + }) .Doc(R"doc( Computes quantile for each a given list of dense and sparse feature values using the given buckets. diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index 5d0ebbf73ce1272b51a475f67984db3a181b7130..ca5c7f3d8c78a543c63fbfa9f7eb7c3d348f11b8 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -23,12 +23,6 @@ using shape_inference::InferenceContext; using shape_inference::ShapeHandle; REGISTER_OP("BuildDenseInequalitySplits") - .Attr("feature_column_group_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("bucket_ids: int64") @@ -36,6 +30,12 @@ REGISTER_OP("BuildDenseInequalitySplits") .Input("hessians: float32") .Input("bucket_boundaries: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -73,6 +73,17 @@ bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are spiltting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -81,13 +92,6 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); REGISTER_OP("BuildSparseInequalitySplits") - .Attr("feature_column_group_id: int") - .Attr("bias_feature_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("bucket_ids: int64") @@ -95,6 +99,13 @@ REGISTER_OP("BuildSparseInequalitySplits") .Input("hessians: float32") .Input("bucket_boundaries: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("bias_feature_id: int64") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -133,6 +144,17 @@ bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are spiltting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -141,19 +163,19 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); REGISTER_OP("BuildCategoricalEqualitySplits") - .Attr("feature_column_group_id: int") - .Attr("bias_feature_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("feature_ids: int64") .Input("gradients: float32") .Input("hessians: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("bias_feature_id: int64") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -188,6 +210,17 @@ partition_ids: A rank 1 tensor of partition IDs. feature_ids: A rank 2 tensor of feature IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are spiltting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -196,4 +229,3 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); } // namespace tensorflow - // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/proto/BUILD b/tensorflow/contrib/boosted_trees/proto/BUILD index 9a61e163eb5ff51dc75de4e40e0f43b090d03c0c..b07f0a4314246eea63764bb6d5e166dd720644fb 100644 --- a/tensorflow/contrib/boosted_trees/proto/BUILD +++ b/tensorflow/contrib/boosted_trees/proto/BUILD @@ -4,17 +4,6 @@ exports_files(["LICENSE"]) load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - tf_proto_library( name = "learner_proto", srcs = [ diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto index 4407c4d981785a279b6296f4726a221cacb4c5b1..81411aa84ae848cfaa1392e82a1e38c3df19cdb6 100644 --- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto +++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto @@ -53,7 +53,7 @@ message DenseFloatBinarySplit { // Float feature column and split threshold describing // the rule feature <= threshold. int32 feature_column = 1; - // If feature column is multivalent, this holds the index of the dimensiong + // If feature column is multivalent, this holds the index of the dimension // for the split. Defaults to 0. int32 dimension_id = 5; float threshold = 2; diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py index 27c288bbf78b3b593d0807e92ac7fd9afc4d2725..63b9c5fddf0d9967d53077608664b59d9ae00481 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py @@ -310,6 +310,22 @@ class ModelOpsTest(test_util.TensorFlowTestCase): # The third tree was added after the save. self.assertAllClose(result.eval(), [[-1.1], [-1.1]]) + def testUsedHandlers(self): + with self.test_session(): + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + tree_ensemble_config.growing_metadata.used_handler_ids.append(1) + tree_ensemble_config.growing_metadata.used_handler_ids.append(5) + stamp_token = 3 + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=stamp_token, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="create_tree") + resources.initialize_resources(resources.shared_resources()).run() + result = model_ops.tree_ensemble_used_handlers( + tree_ensemble_handle, stamp_token, num_all_handlers=6) + self.assertAllEqual([0, 1, 0, 0, 0, 1], result.used_handlers_mask.eval()) + self.assertEqual(2, result.num_used_handlers.eval()) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py index c1acf351603dd80c2d14c7ee0a5b4c89706bc1bf..cf55759aaabfb265466f4bbf8b2806d4347ca0b1 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py @@ -120,8 +120,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): """Sets up the prediction tests. Create a batch of two examples having one dense float, two sparse float - single valued, one sparse float multidimensionl and one sparse int features. - The data looks like the following: + single valued, one sparse float multidimensional and one sparse int + features. The data looks like the following: | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 | SparseM | 0 | 7 | -3 | | 9,1 | __, 5.0 | 1 | -2 | | 4 | | 3, ___ @@ -810,7 +810,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # building. This tree should never be dropped. num_trees = 10 with self.test_session(): - # Empty tree ensenble. + # Empty tree ensemble. tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() # Add 10 trees with some weights. for i in range(0, num_trees): @@ -951,7 +951,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): def testDropOutZeroProb(self): with self.test_session(): - # Empty tree ensenble. + # Empty tree ensemble. tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() # Add 1000 trees with some weights. for i in range(0, 999): @@ -994,7 +994,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): def testAveragingAllTrees(self): with self.test_session(): - # Empty tree ensenble. + # Empty tree ensemble. tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() adjusted_tree_ensemble_config = ( tree_config_pb2.DecisionTreeEnsembleConfig()) diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py index 81f58de28cbe98bb996c6665114eeb0030ee52f9..074623699d9d82f999c9cbc483ddcd8a959f4bad 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py @@ -482,7 +482,7 @@ class QuantilesOpTest(test_util.TensorFlowTestCase): """Sets up the quantile op tests. Create a batch of 4 examples having 2 dense and 4 sparse features. - Forth sparse feature is multivalent (3 dimensional) + Fourth sparse feature is multivalent (3 dimensional) The data looks like this | Instance | Dense 0 | Dense 1 | Sparse 0 | Sparse 1 |Sparse 2| SparseM | 0 | -0.1 | -1 | -2 | 0.1 | |_ ,1,_ diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py index 28834ef55bf8e1f32cc8f2380a4be3bf3824d8e1..5cd37ec67ec3bdefb6ea19049a7a12249162d45a 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import random + from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import split_info_pb2 from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops @@ -399,6 +401,65 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): self.assertAllClose(0.6, split_node.split.threshold) + def testMakeSparseSplitDefaultDirectionIsStable(self): + """Tests default direction is stable when no sparsity.""" + random.seed(1123) + for _ in range(50): + with self.test_session() as sess: + grad = random.random() + hessian = random.random() + # The data looks like the following (divide by the num of steps 2). + # Gradients | Partition | bucket ID | + # (grad, hessian) | 0 | -1 | + # And then 100 buckets of + # (grad/100, hessian/100), so there is no sparsity. + n_buckets = 100 + + # 1 for the overall sum, and 100 buckets. + partition_ids = array_ops.constant( + [0] * (n_buckets + 1), dtype=dtypes.int32) + # We have only 1 dimension in our sparse feature column. + + bucket_ids = [-1] + [n for n in range(100)] + bucket_ids = array_ops.constant(bucket_ids, dtype=dtypes.int64) + dimension_ids = array_ops.constant( + [0] * (n_buckets + 1), dtype=dtypes.int64) + bucket_ids = array_ops.stack([bucket_ids, dimension_ids], axis=1) + + gradients = [grad] + [grad / n_buckets] * n_buckets + gradients = array_ops.constant(gradients) + hessians = [hessian] + [hessian / n_buckets] * n_buckets + hessians = array_ops.constant(hessians) + + boundaries = [x * 1 for x in range(n_buckets + 1)] + bucket_boundaries = array_ops.constant(boundaries, dtype=dtypes.float32) + + partitions, gains, splits = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=2, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + bucket_boundaries=bucket_boundaries, + l1_regularization=0, + l2_regularization=2, + tree_complexity_regularization=0, + min_node_weight=0, + feature_column_group_id=0, + bias_feature_id=-1, + class_id=-1, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + partitions, gains, splits = (sess.run([partitions, gains, splits])) + self.assertAllEqual([0], partitions) + self.assertEqual(1, len(splits)) + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + self.assertTrue( + split_info.split_node.HasField( + 'sparse_float_binary_split_default_left')) + def testMakeMulticlassSparseSplit(self): """Tests split handler op.""" with self.test_session() as sess: diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index 8ca1aabacaf53b66aaba184962922294427d6803..3e524efbeac74ff754d63cae92b3e194411cb2de 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -1588,7 +1588,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): self.assertEqual( 2, tree_ensemble_config.tree_metadata[2].num_tree_weight_updates) - def testGrowExistingEnsembleTreeWithFeatureSelectionCanStillGrow(self): + def testGrowExistingEnsembleTreeWithFeatureSelectionUsedHandlers(self): """Test growing a tree with feature selection.""" with self.test_session() as session: # Create existing ensemble with one root split and one bias tree. @@ -1649,7 +1649,6 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): num_trees_attempted: 2 num_layers_attempted: 2 used_handler_ids: 2 - used_handler_ids: 5 } """, tree_ensemble_config) tree_ensemble_handle = model_ops.tree_ensemble_variable( @@ -1668,183 +1667,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) - # There are 2 handler_ids in used_handler_ids already but one of them - # is handler 2, so we can still grow trees. - learner_config.constraints.max_number_of_unique_feature_columns = 2 - learner_config = learner_config.SerializeToString() - # Prepare handler inputs. - handler1_partitions = np.array([0], dtype=np.int32) - handler1_gains = np.array([7.62], dtype=np.float32) - handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)] - handler2_partitions = np.array([0], dtype=np.int32) - handler2_gains = np.array([0.63], dtype=np.float32) - handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)] - handler3_partitions = np.array([0], dtype=np.int32) - handler3_gains = np.array([7.62], dtype=np.float32) - handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)] - - # Grow tree ensemble. - grow_op = training_ops.grow_tree_ensemble( - tree_ensemble_handle, - stamp_token=0, - next_stamp_token=1, - learning_rate=1, - partition_ids=[ - handler1_partitions, handler2_partitions, handler3_partitions - ], - gains=[handler1_gains, handler2_gains, handler3_gains], - splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, - dropout_seed=123, - center_bias=True) - session.run(grow_op) - - # Expect a new tree to be added with the split from handler 1. - _, serialized = session.run( - model_ops.tree_ensemble_serialize(tree_ensemble_handle)) - tree_ensemble_config.ParseFromString(serialized) - self.assertEqual(3, len(tree_ensemble_config.trees)) - self.assertEqual( - 2, len(tree_ensemble_config.growing_metadata.used_handler_ids)) - - def testGrowExistingEnsembleTreeWithFeatureSelectionEmptyEnsemble(self): - """Test growing a tree with feature selection with empty ensemble.""" - with self.test_session() as session: - # Create existing ensemble with one root split and one bias tree. - tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() - tree_ensemble_handle = model_ops.tree_ensemble_variable( - stamp_token=0, - tree_ensemble_config=tree_ensemble_config.SerializeToString(), - name="tree_ensemble") - resources.initialize_resources(resources.shared_resources()).run() - - # Prepare learner config. - learner_config = _gen_learner_config( - num_classes=2, - l1_reg=0, - l2_reg=0, - tree_complexity=0, - max_depth=1, - min_node_weight=0, - pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) - learner_config.constraints.max_number_of_unique_feature_columns = 2 - learner_config = learner_config.SerializeToString() - # Prepare handler inputs. - handler1_partitions = np.array([0], dtype=np.int32) - handler1_gains = np.array([7.62], dtype=np.float32) - handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)] - handler2_partitions = np.array([0], dtype=np.int32) - handler2_gains = np.array([0.63], dtype=np.float32) - handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)] - handler3_partitions = np.array([0], dtype=np.int32) - handler3_gains = np.array([7.62], dtype=np.float32) - handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)] - - # Grow tree ensemble. - grow_op = training_ops.grow_tree_ensemble( - tree_ensemble_handle, - stamp_token=0, - next_stamp_token=1, - learning_rate=1, - partition_ids=[ - handler1_partitions, handler2_partitions, handler3_partitions - ], - gains=[handler1_gains, handler2_gains, handler3_gains], - splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, - dropout_seed=123, - center_bias=True) - session.run(grow_op) - - _, serialized = session.run( - model_ops.tree_ensemble_serialize(tree_ensemble_handle)) - tree_ensemble_config.ParseFromString(serialized) - self.assertEqual(1, len(tree_ensemble_config.trees)) - self.assertEqual( - 1, len(tree_ensemble_config.growing_metadata.used_handler_ids)) - - def testGrowExistingEnsembleTreeWithFeatureSelectionCantGrow(self): - """Test growing a tree with feature selection with empty ensemble.""" - with self.test_session() as session: - # Create existing ensemble with one root split and one bias tree. - tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() - text_format.Merge(""" - trees { - nodes { - leaf { - vector { - value: -0.32 - value: 0.28 - } - } - } - } - trees { - nodes { - categorical_id_binary_split { - feature_column: 3 - feature_id: 7 - left_id: 1 - right_id: 2 - } - node_metadata { - gain: 1.3 - } - } - nodes { - leaf { - sparse_vector { - index: 0 - value: 2.3 - } - } - } - nodes { - leaf { - sparse_vector { - index: 0 - value: -0.9 - } - } - } - } - tree_weights: 0.7 - tree_weights: 1 - tree_metadata { - num_tree_weight_updates: 1 - num_layers_grown: 1 - is_finalized: true - } - tree_metadata { - num_tree_weight_updates: 5 - num_layers_grown: 1 - is_finalized: true - } - growing_metadata { - num_trees_attempted: 2 - num_layers_attempted: 2 - used_handler_ids: 4 - used_handler_ids: 5 - } - """, tree_ensemble_config) - tree_ensemble_handle = model_ops.tree_ensemble_variable( - stamp_token=0, - tree_ensemble_config=tree_ensemble_config.SerializeToString(), - name="tree_ensemble") - resources.initialize_resources(resources.shared_resources()).run() - # Prepare learner config. - learner_config = _gen_learner_config( - num_classes=2, - l1_reg=0, - l2_reg=0, - tree_complexity=0, - max_depth=1, - min_node_weight=0, - pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) - learner_config.constraints.max_number_of_unique_feature_columns = 2 + learner_config.constraints.max_number_of_unique_feature_columns = 3 learner_config = learner_config.SerializeToString() # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) @@ -1876,12 +1700,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): _, serialized = session.run( model_ops.tree_ensemble_serialize(tree_ensemble_handle)) tree_ensemble_config.ParseFromString(serialized) - # We can't grow a tree since we have reached the limit of 2 unique - # features [4, 5] and the only available splits are from - # handlers [0, 1, 2]. - self.assertEqual(2, len(tree_ensemble_config.trees)) - self.assertEqual( - 2, len(tree_ensemble_config.growing_metadata.used_handler_ids)) + self.assertEqual(3, len(tree_ensemble_config.trees)) + # 2 was already used. handler 0 is being added in this tree. + self.assertAllEqual( + [0, 2], tree_ensemble_config.growing_metadata.used_handler_ids) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py index 7a5f329b7ab3216972180ccbb4c85f2537175422..843420968ac6a6716fdf6b4967146e131139f67c 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py +++ b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py @@ -20,6 +20,8 @@ from __future__ import print_function import abc import collections +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -60,6 +62,7 @@ def _move_tensors(tensors, device): """Moves a list of tensors to a device by concatenating/splitting them.""" # Reset the device setting to avoid weird interactions with device merging # logic. + zero = constant_op.constant(0, dtype=dtypes.int32) with ops.device(None): if all(tensor.shape == tensor_shape.scalar() for tensor in tensors): with ops.device(tensors[0].device): @@ -68,12 +71,11 @@ def _move_tensors(tensors, device): return array_ops.unstack(values) else: with ops.device(tensors[0].device): - sizes = array_ops.stack( - [array_ops.shape(tensor)[0] for tensor in tensors]) - values = array_ops.concat(tensors, axis=0) + sizes = array_ops.stack(array_ops.shape_n(tensors))[:, 0] + values = array_ops.concat(tensors, axis=zero) with ops.device(device): sizes = array_ops.unstack(sizes) - return list(array_ops.split(values, sizes, axis=0)) + return list(array_ops.split(values, sizes, axis=zero)) def _scheduled_stamp_resource_op_runner(batch, stamp): diff --git a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py index 7a5f509047d46549ba81039a23d29ec987ca7920..25b2c9e2fd72bd018717e8a87fce726f26bad968 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py @@ -25,6 +25,7 @@ from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensem from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensemble_serialize # pylint: disable=unused-import from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensemble_stamp_token +from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensemble_used_handlers # pylint: enable=unused-import from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py index 58f0d36b0f78eeed6abcec1c4fa696f4ccffa615..7f6e55ae5888fc4ef50e34690d61c3ed303e971a 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py @@ -21,4 +21,5 @@ from __future__ import print_function from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_partition_examples from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction +from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction_verbose # pylint: enable=unused-import diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 97d57e8b23608d4c3a8719426a75056fc6417d1d..19b6b3296db394b07f57a25dbde187eb9195af38 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -184,10 +184,10 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): """Finalizes quantile summary stream and resets it for next iteration. Args: - stamp_token: Exepcted current token. + stamp_token: Expected current token. next_stamp_token: Next value for the token. Returns: - A list of quantiles or approximate boundaries. + The flush operation. """ return gen_quantile_ops.quantile_accumulator_flush( quantile_accumulator_handle=self._quantile_accumulator_handle, @@ -201,3 +201,6 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): stamp_token=stamp_token, next_stamp_token=next_stamp_token) return result + + def resource(self): + return self._quantile_accumulator_handle diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index f0b66dcbbe1c5167b9993e66b30b1dc8a839c380..47698d45c81478f2b694aaadc603f742c44d5351 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -23,7 +23,6 @@ import copy from tensorflow.contrib import learn from tensorflow.contrib import stateless - from tensorflow.contrib.boosted_trees.lib.learner.batch import categorical_split_handler from tensorflow.contrib.boosted_trees.lib.learner.batch import ordinal_split_handler from tensorflow.contrib.boosted_trees.proto import learner_pb2 @@ -57,6 +56,9 @@ PREDICTIONS = "predictions" PARTITION_IDS = "partition_ids" NUM_LAYERS_ATTEMPTED = "num_layers" NUM_TREES_ATTEMPTED = "num_trees" +NUM_USED_HANDLERS = "num_used_handlers" +USED_HANDLERS_MASK = "used_handlers_mask" +LEAF_INDEX = "leaf_index" _FEATURE_NAME_TEMPLATE = "%s_%d" @@ -70,15 +72,24 @@ def _get_column_by_index(tensor, indices): return array_ops.reshape(array_ops.gather(p_flat, i_flat), [shape[0], -1]) -def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats): +def _make_predictions_dict(stamp, + logits, + partition_ids, + ensemble_stats, + used_handlers, + leaf_index=None): """Returns predictions for the given logits and n_classes. Args: stamp: The ensemble stamp. - logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1]. - that contains predictions when no dropout was applied. + logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1]. that + contains predictions when no dropout was applied. partition_ids: A rank 1 `Tensor` with shape [batch_size]. ensemble_stats: A TreeEnsembleStatsOp result tuple. + used_handlers: A TreeEnsembleUsedHandlerOp result tuple of an int and a + boolean mask. + leaf_index: A rank 2 `Tensor` with shape [batch_size, number of trees]. that + contains leaf id for each example prediction. Returns: A dict of predictions. @@ -89,6 +100,10 @@ def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats): result[PARTITION_IDS] = partition_ids result[NUM_LAYERS_ATTEMPTED] = ensemble_stats.attempted_layers result[NUM_TREES_ATTEMPTED] = ensemble_stats.attempted_trees + result[NUM_USED_HANDLERS] = used_handlers.num_used_handlers + result[USED_HANDLERS_MASK] = used_handlers.used_handlers_mask + if leaf_index is not None: + result[LEAF_INDEX] = leaf_index return result @@ -134,7 +149,7 @@ class _OpRoundRobinStrategy(object): return task -def extract_features(features, feature_columns): +def extract_features(features, feature_columns, use_core_columns): """Extracts columns from a dictionary of features. Args: @@ -167,11 +182,14 @@ def extract_features(features, feature_columns): transformed_features = collections.OrderedDict() for fc in feature_columns: # pylint: disable=protected-access - if isinstance(fc, feature_column_lib._EmbeddingColumn): + if use_core_columns: + # pylint: disable=protected-access + tensor = fc_core._transform_features(features, [fc])[fc] + transformed_features[fc.name] = tensor + elif isinstance(fc, feature_column_lib._EmbeddingColumn): # pylint: enable=protected-access transformed_features[fc.name] = fc_core.input_layer( - features, [fc], - weight_collections=[scope]) + features, [fc], weight_collections=[scope]) else: result = feature_column_ops.transform_features(features, [fc]) if len(result) > 1: @@ -258,7 +276,9 @@ class GradientBoostedDecisionTreeModel(object): learner_config, features, logits_dimension, - feature_columns=None): + feature_columns=None, + use_core_columns=False, + output_leaf_index=False): """Construct a new GradientBoostedDecisionTreeModel function. Args: @@ -266,13 +286,15 @@ class GradientBoostedDecisionTreeModel(object): num_ps_replicas: Number of parameter server replicas, can be 0. ensemble_handle: A handle to the ensemble variable. center_bias: Whether to center the bias before growing trees. - examples_per_layer: Number of examples to accumulate before growing - a tree layer. It can also be a function that computes the number of - examples based on the depth of the layer that's being built. + examples_per_layer: Number of examples to accumulate before growing a tree + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. learner_config: A learner config. features: `dict` of `Tensor` objects. logits_dimension: An int, the dimension of logits. feature_columns: A list of feature columns. + output_leaf_index: A boolean variable indicating whether to output leaf + index into predictions dictionary. Raises: ValueError: if inputs are not valid. @@ -323,16 +345,19 @@ class GradientBoostedDecisionTreeModel(object): self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() self._attempted_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + initial_value=array_ops.zeros([], dtypes.int64), + trainable=False, name="attempted_trees") self._finalized_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + initial_value=array_ops.zeros([], dtypes.int64), + trainable=False, name="finalized_trees") if not features: raise ValueError("Features dictionary must be specified.") (fc_names, dense_floats, sparse_float_indices, sparse_float_values, - sparse_float_shapes, sparse_int_indices, sparse_int_values, - sparse_int_shapes) = extract_features(features, self._feature_columns) + sparse_float_shapes, sparse_int_indices, + sparse_int_values, sparse_int_shapes) = extract_features( + features, self._feature_columns, use_core_columns) logging.info("Active Feature Columns: " + str(fc_names)) self._fc_names = fc_names self._dense_floats = dense_floats @@ -342,9 +367,11 @@ class GradientBoostedDecisionTreeModel(object): self._sparse_int_indices = sparse_int_indices self._sparse_int_values = sparse_int_values self._sparse_int_shapes = sparse_int_shapes - self._reduce_dim = (self._learner_config.multi_class_strategy == - learner_pb2.LearnerConfig.TREE_PER_CLASS and - learner_config.num_classes == 2) + self._reduce_dim = ( + self._learner_config.multi_class_strategy == + learner_pb2.LearnerConfig.TREE_PER_CLASS and + learner_config.num_classes == 2) + self._output_leaf_index = output_leaf_index def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode): """Runs prediction and returns a dictionary of the prediction results. @@ -357,32 +384,61 @@ class GradientBoostedDecisionTreeModel(object): Returns: a dictionary of prediction results - ENSEMBLE_STAMP, PREDICTION, PARTITION_IDS, - NUM_LAYER_ATTEMPTED, NUM_TREES_ATTEMPED. + NUM_LAYER_ATTEMPTED, NUM_TREES_ATTEMPTED. """ ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle, ensemble_stamp) + num_handlers = ( + len(self._dense_floats) + len(self._sparse_float_shapes) + len( + self._sparse_int_shapes)) + # Used during feature selection. + used_handlers = model_ops.tree_ensemble_used_handlers( + ensemble_handle, ensemble_stamp, num_all_handlers=num_handlers) + # We don't need dropout info - we can always restore it based on the # seed. apply_dropout, seed = _dropout_params(mode, ensemble_stats) # Make sure ensemble stats run. This will check that the ensemble has # the right stamp. with ops.control_dependencies(ensemble_stats): - predictions, _ = prediction_ops.gradient_trees_prediction( - ensemble_handle, - seed, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - learner_config=self._learner_config_serialized, - apply_dropout=apply_dropout, - apply_averaging=mode != learn.ModeKeys.TRAIN, - use_locking=True, - center_bias=self._center_bias, - reduce_dim=self._reduce_dim) + leaf_index = None + # Only used in infer (predict), not used in train and eval. + if self._output_leaf_index and mode == learn.ModeKeys.INFER: + predictions, _, leaf_index = ( + prediction_ops).gradient_trees_prediction_verbose( + ensemble_handle, + seed, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + learner_config=self._learner_config_serialized, + apply_dropout=apply_dropout, + apply_averaging=mode != learn.ModeKeys.TRAIN, + use_locking=True, + center_bias=self._center_bias, + reduce_dim=self._reduce_dim) + else: + leaf_index = None + predictions, _ = prediction_ops.gradient_trees_prediction( + ensemble_handle, + seed, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + learner_config=self._learner_config_serialized, + apply_dropout=apply_dropout, + apply_averaging=mode != learn.ModeKeys.TRAIN, + use_locking=True, + center_bias=self._center_bias, + reduce_dim=self._reduce_dim) partition_ids = prediction_ops.gradient_trees_partition_examples( ensemble_handle, self._dense_floats, @@ -395,7 +451,7 @@ class GradientBoostedDecisionTreeModel(object): use_locking=True) return _make_predictions_dict(ensemble_stamp, predictions, partition_ids, - ensemble_stats) + ensemble_stats, used_handlers, leaf_index) def predict(self, mode): """Returns predictions given the features and mode. @@ -413,8 +469,9 @@ class GradientBoostedDecisionTreeModel(object): # Use the current ensemble to predict on the current batch of input. # For faster prediction we check if the inputs are on the same device # as the model. If not, we create a copy of the model on the worker. - input_deps = (self._dense_floats + self._sparse_float_indices + - self._sparse_int_indices) + input_deps = ( + self._dense_floats + self._sparse_float_indices + + self._sparse_int_indices) if not input_deps: raise ValueError("No input tensors for prediction.") @@ -438,8 +495,8 @@ class GradientBoostedDecisionTreeModel(object): # Determine whether the local ensemble is stale and update it if needed. def _refresh_local_ensemble_fn(): - # Serialize the model from parameter server after reading all inputs. - with ops.control_dependencies(input_deps): + # Serialize the model from parameter server after reading the inputs. + with ops.control_dependencies([input_deps[0]]): (ensemble_stamp, serialized_model) = ( model_ops.tree_ensemble_serialize(self._ensemble_handle)) @@ -481,8 +538,9 @@ class GradientBoostedDecisionTreeModel(object): ValueError: if inputs are not valid. """ # Get the worker device from input dependencies. - input_deps = (self._dense_floats + self._sparse_float_indices + - self._sparse_int_indices) + input_deps = ( + self._dense_floats + self._sparse_float_indices + + self._sparse_int_indices) worker_device = input_deps[0].device # Get tensors relevant for training and form the loss. @@ -498,7 +556,7 @@ class GradientBoostedDecisionTreeModel(object): aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy - class_id = -1 + class_id = constant_op.constant(-1, dtype=dtypes.int32) # Handle different multiclass strategies. if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS: # We build one vs rest trees. @@ -552,31 +610,39 @@ class GradientBoostedDecisionTreeModel(object): # Get the weights for each example for quantiles calculation, weights = self._get_weights(hessian_shape, squeezed_hessians) - regularization_config = self._learner_config.regularization - min_node_weight = self._learner_config.constraints.min_node_weight # Create all handlers ensuring resources are evenly allocated across PS. fc_name_idx = 0 handlers = [] init_stamp_token = constant_op.constant(0, dtype=dtypes.int64) + l1_regularization = constant_op.constant( + self._learner_config.regularization.l1, dtypes.float32) + l2_regularization = constant_op.constant( + self._learner_config.regularization.l2, dtypes.float32) + tree_complexity_regularization = constant_op.constant( + self._learner_config.regularization.tree_complexity, dtypes.float32) + min_node_weight = constant_op.constant( + self._learner_config.constraints.min_node_weight, dtypes.float32) + epsilon = 0.01 + num_quantiles = 100 + strategy_tensor = constant_op.constant(strategy) with ops.device(self._get_replica_device_setter(worker_device)): # Create handlers for dense float columns for dense_float_column_idx in range(len(self._dense_floats)): fc_name = self._fc_names[fc_name_idx] handlers.append( ordinal_split_handler.DenseSplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, feature_column_group_id=dense_float_column_idx, - epsilon=0.01, - num_quantiles=100, + epsilon=epsilon, + num_quantiles=num_quantiles, dense_float_column=self._dense_floats[dense_float_column_idx], name=fc_name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=strategy, + multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -585,14 +651,13 @@ class GradientBoostedDecisionTreeModel(object): fc_name = self._fc_names[fc_name_idx] handlers.append( ordinal_split_handler.SparseSplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, feature_column_group_id=sparse_float_column_idx, - epsilon=0.01, - num_quantiles=100, + epsilon=epsilon, + num_quantiles=num_quantiles, sparse_float_column=sparse_tensor.SparseTensor( self._sparse_float_indices[sparse_float_column_idx], self._sparse_float_values[sparse_float_column_idx], @@ -600,7 +665,7 @@ class GradientBoostedDecisionTreeModel(object): name=fc_name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=strategy, + multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -609,10 +674,9 @@ class GradientBoostedDecisionTreeModel(object): fc_name = self._fc_names[fc_name_idx] handlers.append( categorical_split_handler.EqualitySplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, feature_column_group_id=sparse_int_column_idx, sparse_int_column=sparse_tensor.SparseTensor( @@ -622,7 +686,7 @@ class GradientBoostedDecisionTreeModel(object): name=fc_name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=strategy, + multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -675,11 +739,11 @@ class GradientBoostedDecisionTreeModel(object): name="continue_centering", trainable=False) stats_update_ops.append( - control_flow_ops.cond(continue_centering, - self._make_update_bias_stats_fn( - ensemble_stamp, predictions, gradients, - bias_stats_accumulator), - control_flow_ops.no_op)) + control_flow_ops.cond( + continue_centering, + self._make_update_bias_stats_fn(ensemble_stamp, predictions, + gradients, bias_stats_accumulator), + control_flow_ops.no_op)) # Update handler stats. handler_reads = collections.OrderedDict() @@ -701,8 +765,8 @@ class GradientBoostedDecisionTreeModel(object): shape=[len(handlers)], seed=[seed + 1, 1]) active_handlers = array_ops.stack( [active_handlers_current_layer, active_handlers_next_layer], axis=1) - active_handlers = (active_handlers < - self._learner_config.feature_fraction_per_level) + active_handlers = ( + active_handlers < self._learner_config.feature_fraction_per_level) elif subsampling_type == "feature_fraction_per_tree": seed = predictions_dict[NUM_TREES_ATTEMPTED] active_handlers_current_layer = stateless.stateless_random_uniform( @@ -710,12 +774,31 @@ class GradientBoostedDecisionTreeModel(object): active_handlers_current_layer = ( active_handlers_current_layer < self._learner_config.feature_fraction_per_tree) - active_handlers = array_ops.stack(active_handlers_current_layer, - array_ops.ones( - [len(handlers)], dtype=dtypes.bool)) + active_handlers = array_ops.stack( + [ + active_handlers_current_layer, + array_ops.ones([len(handlers)], dtype=dtypes.bool) + ], + axis=1) else: active_handlers = array_ops.ones([len(handlers), 2], dtype=dtypes.bool) + if self._learner_config.constraints.max_number_of_unique_feature_columns: + target = ( + self._learner_config.constraints.max_number_of_unique_feature_columns) + + def _feature_selection_active_handlers(): + # The active list for current and the next iteration. + used_handlers = array_ops.reshape(predictions_dict[USED_HANDLERS_MASK], + [-1, 1]) + used_handlers = array_ops.concat([used_handlers, used_handlers], axis=1) + return math_ops.logical_and(used_handlers, active_handlers) + + active_handlers = ( + control_flow_ops.cond(predictions_dict[NUM_USED_HANDLERS] >= target, + _feature_selection_active_handlers, + lambda: active_handlers)) + # Prepare empty gradients and hessians when handlers are not ready. empty_hess_shape = [1] + hessian_shape.as_list() empty_grad_shape = [1] + gradient_shape.as_list() @@ -725,6 +808,7 @@ class GradientBoostedDecisionTreeModel(object): empty_hessians = constant_op.constant( [], dtype=dtypes.float32, shape=empty_hess_shape) + active_handlers = array_ops.unstack(active_handlers, axis=0) for handler_idx in range(len(handlers)): handler = handlers[handler_idx] is_active = active_handlers[handler_idx] @@ -866,7 +950,6 @@ class GradientBoostedDecisionTreeModel(object): "DecisionTreeEnsembleResourceHandleOp", "StatsAccumulatorScalarResourceHandleOp", "StatsAccumulatorTensorResourceHandleOp", - "QuantileStreamResourceHandleOp", ] ps_strategy = _OpRoundRobinStrategy(ps_ops, ps_tasks) return device_setter.replica_device_setter( @@ -935,10 +1018,8 @@ class GradientBoostedDecisionTreeModel(object): # Stack all the inputs to one tensor per type. # This is a workaround for the slowness of graph building in tf.cond. # See (b/36554864). - split_sizes = array_ops.stack([ - array_ops.shape(partition_id)[0] - for partition_id in partition_ids_list - ]) + split_sizes = array_ops.reshape( + array_ops.shape_n(partition_ids_list), [len(partition_ids_list)]) partition_ids = array_ops.concat(partition_ids_list, axis=0) gains = array_ops.concat(gains_list, axis=0) split_infos = array_ops.concat(split_info_list, axis=0) @@ -1003,8 +1084,11 @@ class GradientBoostedDecisionTreeModel(object): # Update ensemble. update_ops = [are_all_splits_ready] - update_model = control_flow_ops.cond(continue_centering, _center_bias_fn, - _grow_ensemble_fn) + if self._center_bias: + update_model = control_flow_ops.cond(continue_centering, + _center_bias_fn, _grow_ensemble_fn) + else: + update_model = _grow_ensemble_fn() update_ops.append(update_model) # Update ensemble stats. diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index dba51d4f527792d2a8dedc693f74c07119fd231d..e3d4397fadcbaf148f7f6cfaca13e850639786cf 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -19,17 +19,15 @@ from __future__ import division from __future__ import print_function from google.protobuf import text_format - from tensorflow.contrib import layers from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import tree_config_pb2 from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch from tensorflow.contrib.boosted_trees.python.utils import losses - from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn - +from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util @@ -43,10 +41,42 @@ from tensorflow.python.platform import googletest def _squared_loss(label, unused_weights, predictions): """Unweighted loss implementation.""" loss = math_ops.reduce_sum( - math_ops.square(predictions - label), 1, keep_dims=True) + math_ops.square(predictions - label), 1, keepdims=True) return loss +def _append_to_leaf(leaf, c_id, w): + """Helper method for building tree leaves. + + Appends weight contributions for the given class index to a leaf node. + + Args: + leaf: leaf node to append to. + c_id: class Id for the weight update. + w: weight contribution value. + """ + leaf.sparse_vector.index.append(c_id) + leaf.sparse_vector.value.append(w) + + +def _set_float_split(split, feat_col, thresh, l_id, r_id): + """Helper method for building tree float splits. + + Sets split feature column, threshold and children. + + Args: + split: split node to update. + feat_col: feature column for the split. + thresh: threshold to split on forming rule x <= thresh. + l_id: left child Id. + r_id: right child Id. + """ + split.feature_column = feat_col + split.threshold = thresh + split.left_id = l_id + split.right_id = r_id + + class GbdtTest(test_util.TensorFlowTestCase): def setUp(self): @@ -63,11 +93,12 @@ class GbdtTest(test_util.TensorFlowTestCase): array_ops.zeros([2], dtypes.int64)) features["sparse_int"] = sparse_tensor.SparseTensor( array_ops.zeros([2, 2], dtypes.int64), - array_ops.zeros([2], dtypes.int64), - array_ops.zeros([2], dtypes.int64)) + array_ops.zeros([2], dtypes.int64), array_ops.zeros([2], + dtypes.int64)) (fc_names, dense_floats, sparse_float_indices, sparse_float_values, sparse_float_shapes, sparse_int_indices, sparse_int_values, - sparse_int_shapes) = (gbdt_batch.extract_features(features, None)) + sparse_int_shapes) = ( + gbdt_batch.extract_features(features, None, use_core_columns=False)) self.assertEqual(len(fc_names), 3) self.assertAllEqual(fc_names, ["dense_float", "sparse_float", "sparse_int"]) @@ -104,8 +135,8 @@ class GbdtTest(test_util.TensorFlowTestCase): array_ops.zeros([2], dtypes.int64)) features["sparse_categorical"] = sparse_tensor.SparseTensor( array_ops.zeros([2, 2], dtypes.int64), - array_ops.zeros( - [2], dtypes.string), array_ops.zeros([2], dtypes.int64)) + array_ops.zeros([2], dtypes.string), array_ops.zeros([2], + dtypes.int64)) feature_columns = set() feature_columns.add(layers.real_valued_column("dense_float")) feature_columns.add( @@ -116,8 +147,9 @@ class GbdtTest(test_util.TensorFlowTestCase): "sparse_categorical", hash_bucket_size=1000000)) (fc_names, dense_floats, sparse_float_indices, sparse_float_values, sparse_float_shapes, sparse_int_indices, sparse_int_values, - sparse_int_shapes) = (gbdt_batch.extract_features( - features, feature_columns)) + sparse_int_shapes) = ( + gbdt_batch.extract_features( + features, feature_columns, use_core_columns=False)) self.assertEqual(len(fc_names), 3) self.assertAllEqual(fc_names, ["dense_float", "sparse_float", "sparse_categorical"]) @@ -142,6 +174,41 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertAllEqual(sparse_int_shapes[0].eval(), features["sparse_categorical"].dense_shape.eval()) + def testExtractFeaturesFromCoreFeatureColumns(self): + """Tests feature extraction when using core columns.""" + with self.test_session(): + features = {} + # Sparse float column does not exist in core, so only dense numeric and + # categorical. + features["dense_float"] = array_ops.zeros([2, 1], dtypes.float32) + features["sparse_categorical"] = sparse_tensor.SparseTensor( + array_ops.zeros([2, 2], dtypes.int64), + array_ops.zeros([2], dtypes.string), array_ops.zeros([2], + dtypes.int64)) + + feature_columns = set() + feature_columns.add(core_feature_column.numeric_column("dense_float")) + feature_columns.add( + core_feature_column.categorical_column_with_hash_bucket( + "sparse_categorical", hash_bucket_size=1000000)) + (fc_names, dense_floats, _, _, _, sparse_int_indices, sparse_int_values, + sparse_int_shapes) = ( + gbdt_batch.extract_features( + features, feature_columns, use_core_columns=True)) + self.assertEqual(len(fc_names), 2) + self.assertAllEqual(fc_names, ["dense_float", "sparse_categorical"]) + self.assertEqual(len(dense_floats), 1) + self.assertEqual(len(sparse_int_indices), 1) + self.assertEqual(len(sparse_int_values), 1) + self.assertEqual(len(sparse_int_shapes), 1) + self.assertAllEqual(dense_floats[0].eval(), + features["dense_float"].eval()) + self.assertAllEqual(sparse_int_indices[0].eval(), + features["sparse_categorical"].indices.eval()) + self.assertAllEqual(sparse_int_values[0].eval(), [397263, 397263]) + self.assertAllEqual(sparse_int_shapes[0].eval(), + features["sparse_categorical"].dense_shape.eval()) + def testTrainFnChiefNoBiasCentering(self): """Tests the train function running on chief without bias centering.""" with self.test_session() as sess: @@ -164,7 +231,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -245,6 +313,113 @@ class GbdtTest(test_util.TensorFlowTestCase): }""" self.assertProtoEquals(expected_tree, output.trees[0]) + def testTrainFnChiefSparseAndDense(self): + """Tests the train function with sparse and dense features.""" + with self.test_session() as sess: + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float"] = array_ops.ones([4, 1], dtypes.float32) + features["sparse_float"] = sparse_tensor.SparseTensor( + array_ops.zeros([2, 2], dtypes.int64), + array_ops.zeros([2], dtypes.float32), + array_ops.constant([4, 1], dtypes.int64)) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = variables.Variable( + initial_value=0, + name="ensemble_stamp", + trainable=False, + dtype=dtypes.int64) + + predictions_dict = { + "predictions": predictions, + "predictions_no_dropout": predictions, + "partition_ids": partition_ids, + "ensemble_stamp": ensemble_stamp, + "num_trees": 12, + } + + labels = array_ops.ones([4, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + # Create train op. + train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + _squared_loss(labels, weights, predictions)), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect no splits to be chosen because the quantile + # buckets will not be ready. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 0) + self.assertEquals(len(output.tree_weights), 0) + self.assertEquals(stamp_token.eval(), 1) + + # Update the stamp to be able to run a second time. + sess.run([ensemble_stamp.assign_add(1)]) + + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertAllClose(output.tree_weights, [0.1]) + self.assertEquals(stamp_token.eval(), 2) + expected_tree = """ + nodes { + sparse_float_binary_split_default_right { + split{ + left_id: 1 + right_id: 2 + } + } + node_metadata { + gain: 1.125 + } + } + nodes { + leaf { + vector { + value: 1.0 + } + } + } + nodes { + leaf { + vector { + value: -0.5 + } + } + }""" + self.assertProtoEquals(expected_tree, output.trees[0]) + def testTrainFnChiefScalingNumberOfExamples(self): """Tests the train function running on chief without bias centering.""" with self.test_session() as sess: @@ -268,7 +443,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=num_examples_fn, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -371,7 +547,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -442,7 +619,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -505,7 +683,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -551,7 +730,8 @@ class GbdtTest(test_util.TensorFlowTestCase): with self.test_session() as sess: # Create ensemble with one bias node. ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() - text_format.Merge(""" + text_format.Merge( + """ trees { nodes { leaf { @@ -588,15 +768,128 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) # Create predict op. mode = model_fn.ModeKeys.EVAL predictions_dict = sess.run(gbdt_model.predict(mode)) self.assertEquals(predictions_dict["ensemble_stamp"], 3) - self.assertAllClose(predictions_dict["predictions"], [[0.25], [0.25], - [0.25], [0.25]]) + self.assertAllClose(predictions_dict["predictions"], + [[0.25], [0.25], [0.25], [0.25]]) + self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0]) + + def testPredictFnWithLeafIndexAdvancedLeft(self): + """Tests the predict function with output leaf ids.""" + with self.test_session() as sess: + # Create ensemble with one bias node. + ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge( + """ + trees { + nodes { + dense_float_binary_split { + threshold: 1.0 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 0 + } + } + nodes { + leaf { + vector { + value: 0.25 + } + } + } + nodes { + leaf { + vector { + value: 0.15 + } + } + } + } + trees { + nodes { + dense_float_binary_split { + threshold: 0.99 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 00 + } + } + nodes { + leaf { + vector { + value: 0.25 + } + } + } + nodes { + leaf { + vector { + value: 0.23 + } + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: true + } + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: true + }""", ensemble_config) + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=3, + tree_ensemble_config=ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float"] = array_ops.constant( + [[0.0], [1.0], [1.1], [2.0]], dtype=dtypes.float32) + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=False, + num_ps_replicas=0, + center_bias=True, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features, + output_leaf_index=True) + + # Create predict op. + mode = model_fn.ModeKeys.INFER + predictions_dict = sess.run(gbdt_model.predict(mode)) + self.assertEquals(predictions_dict["ensemble_stamp"], 3) + # here are how the numbers in expected results are calculated, + # 0.5 = 0.25 + 0.25 + # 0.48 = 0.25 + 0.23 + # 0.38 = 0.15 + 0.23 + # 0.38 = 0.15 + 0.23 + self.assertAllClose(predictions_dict["predictions"], + [[0.5], [0.48], [0.38], [0.38]]) self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0]) + self.assertAllClose(predictions_dict["leaf_index"], + [[1, 1], [1, 2], [2, 2], [2, 2]]) def testTrainFnMulticlassFullHessian(self): """Tests the GBDT train for multiclass full hessian.""" @@ -627,7 +920,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=5, features=features) + logits_dimension=5, + features=features) predictions = array_ops.constant( [[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0], @@ -730,7 +1024,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=5, features=features) + logits_dimension=5, + features=features) predictions = array_ops.constant( [[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0], @@ -822,8 +1117,8 @@ class GbdtTest(test_util.TensorFlowTestCase): learner_config.constraints.max_tree_depth = 1 learner_config.constraints.min_node_weight = 0 features = { - "dense_float": array_ops.constant( - [[1.0], [1.5], [2.0]], dtypes.float32), + "dense_float": + array_ops.constant([[1.0], [1.5], [2.0]], dtypes.float32), } gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( @@ -833,7 +1128,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=5, features=features) + logits_dimension=5, + features=features) batch_size = 3 predictions = array_ops.constant( @@ -915,7 +1211,354 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertAllClose( 0.893284678459, output.trees[0].nodes[2].leaf.sparse_vector.value[0], - atol=1e-4, rtol=1e-4) + atol=1e-4, + rtol=1e-4) + + def testTrainFnChiefFeatureSelectionReachedLimitNoGoodSplit(self): + """Tests the train function running on chief with feature selection.""" + with self.test_session() as sess: + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.max_number_of_unique_feature_columns = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float_0"] = array_ops.ones([4, 1], dtypes.float32) + # Feature 1 is predictive but it won't be used because we have reached the + # limit of num_used_handlers >= max_number_of_unique_feature_columns + features["dense_float_1"] = array_ops.constant([0, 0, 1, 1], + dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = variables.Variable( + initial_value=0, + name="ensemble_stamp", + trainable=False, + dtype=dtypes.int64) + + predictions_dict = { + "predictions": + predictions, + "predictions_no_dropout": + predictions, + "partition_ids": + partition_ids, + "ensemble_stamp": + ensemble_stamp, + "num_trees": + 12, + "num_used_handlers": + array_ops.constant(1, dtype=dtypes.int64), + "used_handlers_mask": + array_ops.constant([True, False], dtype=dtypes.bool), + } + + labels = array_ops.constant([0, 0, 1, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + # Create train op. + train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + _squared_loss(labels, weights, predictions)), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect no splits to be chosen because the quantile + # buckets will not be ready. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 0) + self.assertEquals(len(output.tree_weights), 0) + self.assertEquals(stamp_token.eval(), 1) + + # Update the stamp to be able to run a second time. + sess.run([ensemble_stamp.assign_add(1)]) + + # On second run, expect a trivial split to be chosen to basically + # predict the average. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertAllClose(output.tree_weights, [0.1]) + self.assertEquals(stamp_token.eval(), 2) + expected_tree = """ + nodes { + dense_float_binary_split { + feature_column: 0 + threshold: 1.0 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 0 + } + } + nodes { + leaf { + vector { + value: -0.25 + } + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + }""" + self.assertProtoEquals(expected_tree, output.trees[0]) + + def testTrainFnChiefFeatureSelectionWithGoodSplits(self): + """Tests the train function running on chief with feature selection.""" + with self.test_session() as sess: + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.max_number_of_unique_feature_columns = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float_0"] = array_ops.ones([4, 1], dtypes.float32) + # Feature 1 is predictive and is in our selected features so it will be + # used even when we're at the limit. + features["dense_float_1"] = array_ops.constant([0, 0, 1, 1], + dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = variables.Variable( + initial_value=0, + name="ensemble_stamp", + trainable=False, + dtype=dtypes.int64) + + predictions_dict = { + "predictions": + predictions, + "predictions_no_dropout": + predictions, + "partition_ids": + partition_ids, + "ensemble_stamp": + ensemble_stamp, + "num_trees": + 12, + "num_used_handlers": + array_ops.constant(1, dtype=dtypes.int64), + "used_handlers_mask": + array_ops.constant([False, True], dtype=dtypes.bool), + } + + labels = array_ops.constant([0, 0, 1, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + # Create train op. + train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + _squared_loss(labels, weights, predictions)), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect no splits to be chosen because the quantile + # buckets will not be ready. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 0) + self.assertEquals(len(output.tree_weights), 0) + self.assertEquals(stamp_token.eval(), 1) + + # Update the stamp to be able to run a second time. + sess.run([ensemble_stamp.assign_add(1)]) + + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + + self.assertEquals(len(output.trees), 1) + self.assertAllClose(output.tree_weights, [0.1]) + self.assertEquals(stamp_token.eval(), 2) + expected_tree = """ + nodes { + dense_float_binary_split { + feature_column: 1 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 0.5 + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + nodes { + leaf { + vector { + value: -0.5 + } + } + }""" + self.assertProtoEquals(expected_tree, output.trees[0]) + + def testTrainFnChiefFeatureSelectionReachedLimitIncrementAttemptedLayer(self): + """Tests the train function running on chief with feature selection.""" + with self.test_session() as sess: + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + tree = tree_ensemble_config.trees.add() + + _set_float_split( + tree.nodes.add().sparse_float_binary_split_default_right.split, 2, + 4.0, 1, 2) + _append_to_leaf(tree.nodes.add().leaf, 0, 0.5) + _append_to_leaf(tree.nodes.add().leaf, 1, 1.2) + tree_ensemble_config.tree_weights.append(1.0) + metadata = tree_ensemble_config.tree_metadata.add() + metadata.is_finalized = False + metadata.num_layers_grown = 1 + tree_ensemble_config = tree_ensemble_config.SerializeToString() + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config, + name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.max_number_of_unique_feature_columns = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + # Both features will be disabled since the feature selection limit is + # already reached. + features["dense_float_0"] = array_ops.ones([4, 1], dtypes.float32) + features["dense_float_1"] = array_ops.constant([0, 0, 1, 1], + dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = variables.Variable( + initial_value=0, + name="ensemble_stamp", + trainable=False, + dtype=dtypes.int64) + + predictions_dict = { + "predictions": + predictions, + "predictions_no_dropout": + predictions, + "partition_ids": + partition_ids, + "ensemble_stamp": + ensemble_stamp, + "num_trees": + 12, + # We have somehow reached our limit 1. Both of the handlers will be + # disabled. + "num_used_handlers": + array_ops.constant(1, dtype=dtypes.int64), + "used_handlers_mask": + array_ops.constant([False, False], dtype=dtypes.bool), + } + + labels = array_ops.constant([0, 0, 1, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + # Create train op. + train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + _squared_loss(labels, weights, predictions)), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect no splits to be chosen because the quantile + # buckets will not be ready. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertEquals(output.growing_metadata.num_layers_attempted, 1) + self.assertEquals(stamp_token.eval(), 1) + + # Update the stamp to be able to run a second time. + sess.run([ensemble_stamp.assign_add(1)]) + + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + # Make sure the trees are not modified, but the num_layers_attempted is + # incremented so that eventually the training stops. + self.assertEquals(len(output.trees), 1) + self.assertEquals(len(output.trees[0].nodes), 3) + + self.assertEquals(output.growing_metadata.num_layers_attempted, 2) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/resources/BUILD b/tensorflow/contrib/boosted_trees/resources/BUILD index 9fc101612f1e2a6bf6c5d86ea8c7199936dbb069..c0651868453d40d57e842862855f89e6845c507f 100644 --- a/tensorflow/contrib/boosted_trees/resources/BUILD +++ b/tensorflow/contrib/boosted_trees/resources/BUILD @@ -9,17 +9,6 @@ package( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - cc_library( name = "stamped_resource", hdrs = ["stamped_resource.h"], diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h index 3ebf28ea442edf87815c39971ae9e01a2a8aae9a..94aeb2c7bb48c6eddb6c7894f8bf6f1567470113 100644 --- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h @@ -126,7 +126,8 @@ class DecisionTreeEnsembleResource : public StampedResource { return; } used_ids->Add(handler_id); - std::rotate(first, used_ids->end() - 1, used_ids->end()); + // Keep the list of used handlers sorted. + std::sort(used_ids->begin(), used_ids->end()); } std::vector GetUsedHandlers() const { diff --git a/tensorflow/contrib/checkpoint/README.md b/tensorflow/contrib/checkpoint/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d35c5bae3b702c0fea5194e5e653660e319e38c5 --- /dev/null +++ b/tensorflow/contrib/checkpoint/README.md @@ -0,0 +1,2 @@ +Tools for working with object-based checkpoints produced by +`tf.train.Checkpoint`. diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8ae493ba998bd882b5ef946f927ec1882d91f61d --- /dev/null +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -0,0 +1,50 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tools for working with object-based checkpoints. + +Visualization and inspection: +@@dot_graph_from_checkpoint +@@object_metadata + +Managing dependencies: +@@Checkpointable +@@CheckpointableObjectGraph +@@NoDependency +@@split_dependency + +Checkpointable data structures: +@@List +@@Mapping +@@UniqueNameTracker +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker +from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency +from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint +from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph +from tensorflow.python.training.checkpointable.base import Checkpointable +from tensorflow.python.training.checkpointable.base import NoDependency +from tensorflow.python.training.checkpointable.data_structures import List +from tensorflow.python.training.checkpointable.data_structures import Mapping +from tensorflow.python.training.checkpointable.util import object_metadata + +from tensorflow.python.util.all_util import remove_undocumented + +remove_undocumented(module_name=__name__) + diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..7b200a29bf60087d6da1010b0be05c04faec80cd --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -0,0 +1,95 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "checkpoint", + srcs_version = "PY2AND3", + deps = [ + ":containers", + ":split_dependency", + ":visualize", + "//tensorflow/python/training/checkpointable:data_structures", + ], +) + +py_library( + name = "containers", + srcs = ["containers.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:data_structures", + ], +) + +py_test( + name = "containers_test", + srcs = ["containers_test.py"], + deps = [ + ":containers", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", + "@six_archive//:six", + ], +) + +py_library( + name = "split_dependency", + srcs = ["split_dependency.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:training", + "//tensorflow/python/training/checkpointable:base", + ], +) + +py_test( + name = "split_dependency_test", + srcs = ["split_dependency_test.py"], + deps = [ + ":split_dependency", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/eager:test", + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", + ], +) + +py_library( + name = "visualize", + srcs = ["visualize.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", + ], +) + +py_test( + name = "visualize_test", + srcs = ["visualize_test.py"], + deps = [ + ":visualize", + "//tensorflow/python:constant_op", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:training", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "//tensorflow/python/keras:engine", + "//tensorflow/python/keras:layers", + "//tensorflow/python/training/checkpointable:util", + ], +) diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py new file mode 100644 index 0000000000000000000000000000000000000000..4d3d5312993740636709cb732c0b8e3e2626262d --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -0,0 +1,80 @@ +"""Checkpointable data structures.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.training.checkpointable import base as checkpointable_lib +from tensorflow.python.training.checkpointable import data_structures + + +class UniqueNameTracker(data_structures.CheckpointableDataStructure): + """Adds dependencies on checkpointable objects with name hints. + + Useful for creating dependencies with locally unique names. + + Example usage: + ```python + class SlotManager(tf.contrib.checkpoint.Checkpointable): + + def __init__(self): + # Create a dependency named "slotdeps" on the container. + self.slotdeps = tf.contrib.checkpoint.UniqueNameTracker() + slotdeps = self.slotdeps + slots = [] + slots.append(slotdeps.track(tfe.Variable(3.), "x")) # Named "x" + slots.append(slotdeps.track(tfe.Variable(4.), "y")) + slots.append(slotdeps.track(tfe.Variable(5.), "x")) # Named "x_1" + ``` + """ + + def __init__(self): + super(UniqueNameTracker, self).__init__() + self._maybe_initialize_checkpointable() + self._name_counts = {} + + def track(self, checkpointable, base_name): + """Add a dependency on `checkpointable`. + + Args: + checkpointable: An object to add a checkpoint dependency on. + base_name: A name hint, which is uniquified to determine the dependency + name. + Returns: + `checkpointable`, for chaining. + Raises: + ValueError: If `checkpointable` is not a checkpointable object. + """ + + if not isinstance(checkpointable, checkpointable_lib.CheckpointableBase): + raise ValueError( + ("Expected a checkpointable value, got %s which does not inherit " + "from CheckpointableBase.") % (checkpointable,)) + + def _format_name(prefix, number): + if number > 0: + return "%s_%d" % (prefix, number) + else: + return prefix + + count = self._name_counts.get(base_name, 0) + candidate = _format_name(base_name, count) + while self._lookup_dependency(candidate) is not None: + count += 1 + candidate = _format_name(base_name, count) + self._name_counts[base_name] = count + 1 + self._track_value(checkpointable, name=candidate) + return checkpointable diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3717d7f583ffdc205a279d45df60cddbc5cbf08e --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -0,0 +1,108 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import six + +from tensorflow.contrib.checkpoint.python import containers +from tensorflow.python.framework import test_util +from tensorflow.python.keras import layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_utils + + +class UniqueNameTrackerTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testNames(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + x1 = resource_variable_ops.ResourceVariable(2.) + x2 = resource_variable_ops.ResourceVariable(3.) + x3 = resource_variable_ops.ResourceVariable(4.) + y = resource_variable_ops.ResourceVariable(5.) + slots = containers.UniqueNameTracker() + slots.track(x1, "x") + slots.track(x2, "x") + slots.track(x3, "x_1") + slots.track(y, "y") + self.evaluate((x1.initializer, x2.initializer, x3.initializer, + y.initializer)) + save_root = checkpointable_utils.Checkpoint(slots=slots) + save_path = save_root.save(checkpoint_prefix) + + restore_slots = checkpointable.Checkpointable() + restore_root = checkpointable_utils.Checkpoint( + slots=restore_slots) + status = restore_root.restore(save_path) + restore_slots.x = resource_variable_ops.ResourceVariable(0.) + restore_slots.x_1 = resource_variable_ops.ResourceVariable(0.) + restore_slots.x_1_1 = resource_variable_ops.ResourceVariable(0.) + restore_slots.y = resource_variable_ops.ResourceVariable(0.) + status.assert_consumed().run_restore_ops() + self.assertEqual(2., self.evaluate(restore_slots.x)) + self.assertEqual(3., self.evaluate(restore_slots.x_1)) + self.assertEqual(4., self.evaluate(restore_slots.x_1_1)) + self.assertEqual(5., self.evaluate(restore_slots.y)) + + @test_util.run_in_graph_and_eager_modes() + def testExample(self): + class SlotManager(checkpointable.Checkpointable): + + def __init__(self): + self.slotdeps = containers.UniqueNameTracker() + slotdeps = self.slotdeps + slots = [] + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(3.), "x")) + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(4.), "y")) + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(5.), "x")) + self.slots = slots + + manager = SlotManager() + self.evaluate([v.initializer for v in manager.slots]) + checkpoint = checkpointable_utils.Checkpoint(slot_manager=manager) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = checkpoint.save(checkpoint_prefix) + metadata = checkpointable_utils.object_metadata(save_path) + dependency_names = [] + for node in metadata.nodes: + for child in node.children: + dependency_names.append(child.local_name) + six.assertCountEqual( + self, + dependency_names, + ["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"]) + + @test_util.run_in_graph_and_eager_modes() + def testLayers(self): + tracker = containers.UniqueNameTracker() + tracker.track(layers.Dense(3), "dense") + tracker.layers[0](array_ops.zeros([1, 1])) + self.assertEqual(2, len(tracker.trainable_weights)) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/checkpoint/python/split_dependency.py b/tensorflow/contrib/checkpoint/python/split_dependency.py new file mode 100644 index 0000000000000000000000000000000000000000..7e77453f3d848c2e321ed2ba66917a742d95459a --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/split_dependency.py @@ -0,0 +1,136 @@ +"""Utility for creating multiple dependencies with synchronized save/restore.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training.checkpointable import base as checkpointable + + +class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): + """Wraps save and restore callbacks as a `SaveableObject`.""" + + def __init__(self, name, dtype, save_callback, restore_callback): + self._restore_callback = restore_callback + spec = saver_lib.BaseSaverBuilder.SaveSpec( + tensor=save_callback, + slice_spec="", + name=name, + dtype=dtype) + super(_CallbackSaveable, self).__init__( + save_callback, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into both variables.""" + tensor, = restored_tensors + return self._restore_callback(tensor) + + +class _SplitDependency(checkpointable.CheckpointableBase): + """Looks like a regular variable while synchronizing save/restores.""" + + def __init__(self, save_buffer, restore_buffer, name, dtype, num_components, + fill_save_buffer_fn, consume_restore_buffer_fn): + self._save_buffer = save_buffer + self._restore_buffer = restore_buffer + self._name = name + self._dtype = dtype + self._num_components = num_components + self._fill_save_buffer_fn = fill_save_buffer_fn + self._consume_restore_buffer_fn = consume_restore_buffer_fn + + def _save(self): + """Pull from the shared buffer, populating it if necessary.""" + if self._name not in self._save_buffer: + if self._save_buffer: + raise AssertionError( + ("Split dependency %s (%s) unsynchronized. Split dependencies must " + "be saved together.") % (self._name, self)) + self._fill_save_buffer_fn(self._save_buffer) + return self._save_buffer.pop(self._name) + + def _restore(self, tensor): + """Push into the shared buffer, flushing it if necessary.""" + if self._name in self._restore_buffer: + raise AssertionError( + ("Split dependency %s (%s) unsynchronized. Split dependencies must " + "be restored together.") % (self._name, self)) + self._restore_buffer[self._name] = tensor + if len(self._restore_buffer) == self._num_components: + op = self._consume_restore_buffer_fn(self._restore_buffer) + self._restore_buffer.clear() + return op + else: + return control_flow_ops.no_op() + + def _gather_saveables_for_checkpoint(self): + """Looks to Checkpointable like a regular variable.""" + return { + checkpointable.VARIABLE_VALUE_KEY: + functools.partial(_CallbackSaveable, + dtype=self._dtype, + save_callback=self._save, + restore_callback=self._restore) + } + + +def split_dependency(component_names, component_dtypes, + fill_save_buffer_fn, consume_restore_buffer_fn): + """Creates multiple dependencies with a synchronized save/restore. + + Useful when a single op produces `Tensor`s which should each be saved under + different objects, or when `Tensor`s saved with many different objects need to + be restored together as inputs to a single op (i.e. an object which uses a + single fused op may be swapped out for a subgraph of objects, and these two + programs are checkpoint compatible). + + Args: + component_names: A sequence of names for the split + dependencies. `fill_save_buffer_fn` must add these keys to the dictionary + it is passed, and `consume_restore_buffer_fn` will receive a dictionary + with these keys. + component_dtypes: Data types for the `Tensor`s being saved and restored, a + sequence corresponding to `component_names`. + fill_save_buffer_fn: A function which takes an empty dictionary as an + argument and adds `Tensor`s with `component_names` as keys. These + `Tensor`s will be saved as if they were individual variables. + consume_restore_buffer_fn: A function which takes a dictionary with + `component_names` as keys mapping to restored individual `Tensor`s and + returns a restore op (or if executing eagerly, runs the restoration and + may return `None`). + + Returns: + A dictionary mapping from names to Checkpointable objects. If one is + reachable from an object as a dependency, the others should be too; adding + dependencies on some but not all of the objects will result in errors. + """ + save_buffer = {} + restore_buffer = {} + split_dependencies = {} + for name, dtype in zip(component_names, component_dtypes): + split_dependencies[name] = _SplitDependency( + save_buffer=save_buffer, + restore_buffer=restore_buffer, + name=name, + dtype=dtype, + num_components=len(component_names), + fill_save_buffer_fn=fill_save_buffer_fn, + consume_restore_buffer_fn=consume_restore_buffer_fn) + return split_dependencies diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py new file mode 100644 index 0000000000000000000000000000000000000000..69dc0b9be2d5548852c37552a64a0d31c9557b43 --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py @@ -0,0 +1,112 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.checkpoint.python import split_dependency +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_utils + + +def _split_variable_closure(variable): + def _fill_save_buffer_fn(save_buffer): + save_buffer["first_half"] = variable[:2] + save_buffer["second_half"] = variable[2:] + return _fill_save_buffer_fn + + +def _combine_variable_closure(variable): + def _consume_restore_buffer_fn(restore_buffer): + return variable.assign( + array_ops.concat([restore_buffer["first_half"], + restore_buffer["second_half"]], + axis=0)) + return _consume_restore_buffer_fn + + +class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase): + + def __init__(self): + self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.]) + split_dependencies = split_dependency.split_dependency( + component_names=("first_half", "second_half"), + component_dtypes=(self.combined.dtype,) * 2, + fill_save_buffer_fn=_split_variable_closure( + self.combined), + consume_restore_buffer_fn=_combine_variable_closure( + self.combined)) + for name, dep in split_dependencies.items(): + self._track_checkpointable(dep, name=name) + + +class HasRegularDeps(checkpointable.Checkpointable): + + def __init__(self): + self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) + self.second_half = resource_variable_ops.ResourceVariable([0., 0.]) + + +class OnlyOneDep(checkpointable.Checkpointable): + + def __init__(self): + self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) + + +class SplitTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testSaveRestoreSplitDep(self): + save_checkpoint = checkpointable_utils.Checkpoint( + dep=SaveTensorSlicesAsDeps()) + self.evaluate(save_checkpoint.dep.combined.assign([1., 2., 3., 4.])) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = save_checkpoint.save(checkpoint_prefix) + + regular_deps = HasRegularDeps() + regular_restore_checkpoint = checkpointable_utils.Checkpoint( + dep=regular_deps) + regular_restore_checkpoint.restore( + save_path).assert_consumed().run_restore_ops() + self.assertAllEqual([1., 2.], self.evaluate(regular_deps.first_half)) + self.assertAllEqual([3., 4.], self.evaluate(regular_deps.second_half)) + + one_dep = OnlyOneDep() + one_dep_restore_checkpoint = checkpointable_utils.Checkpoint(dep=one_dep) + status = one_dep_restore_checkpoint.restore(save_path) + with self.assertRaises(AssertionError): + # Missing the second dependency. + status.assert_consumed() + status.run_restore_ops() + self.assertAllEqual([1., 2.], self.evaluate(one_dep.first_half)) + + restore_checkpoint = checkpointable_utils.Checkpoint() + status = restore_checkpoint.restore(save_path) + restore_checkpoint.dep = SaveTensorSlicesAsDeps() + status.assert_consumed().run_restore_ops() + self.assertAllEqual( + [1., 2., 3., 4.], + self.evaluate(restore_checkpoint.dep.combined)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/checkpoint/python/visualize.py b/tensorflow/contrib/checkpoint/python/visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..bac071c4cff383f60b707b6e42c13faf5e0ac948 --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/visualize.py @@ -0,0 +1,99 @@ +"""Utilities for visualizing dependency graphs.""" +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_utils + + +def dot_graph_from_checkpoint(save_path): + r"""Visualizes an object-based checkpoint (from `tf.train.Checkpoint`). + + Useful for inspecting checkpoints and debugging loading issues. + + Example usage from Python (requires pydot): + ```python + import tensorflow as tf + import pydot + + dot_string = tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt') + parsed, = pydot.graph_from_dot_data(dot_string) + parsed.write_svg('/tmp/tensorflow/visualized_checkpoint.svg') + ``` + + Example command line usage: + ```sh + python -c "import tensorflow as tf;\ + print(tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt'))"\ + | dot -Tsvg > /tmp/tensorflow/checkpoint_viz.svg + ``` + + Args: + save_path: The checkpoint prefix, as returned by `tf.train.Checkpoint.save` + or `tf.train.latest_checkpoint`. + Returns: + A graph in DOT format as a string. + """ + reader = pywrap_tensorflow.NewCheckpointReader(save_path) + object_graph = checkpointable_utils.object_metadata(save_path) + shape_map = reader.get_variable_to_shape_map() + dtype_map = reader.get_variable_to_dtype_map() + graph = 'digraph {\n' + def _escape(name): + return name.replace('"', '\\"') + slot_ids = set() + for node in object_graph.nodes: + for slot_reference in node.slot_variables: + slot_ids.add(slot_reference.slot_variable_node_id) + for node_id, node in enumerate(object_graph.nodes): + if (len(node.attributes) == 1 + and node.attributes[0].name == checkpointable.VARIABLE_VALUE_KEY): + if node_id in slot_ids: + color = 'orange' + tooltip_prefix = 'Slot variable' + else: + color = 'blue' + tooltip_prefix = 'Variable' + attribute = node.attributes[0] + graph += ('N_%d [shape=point label="" color=%s width=.25' + ' tooltip="%s %s shape=%s %s"]\n') % ( + node_id, + color, + tooltip_prefix, + _escape(attribute.full_name), + shape_map[attribute.checkpoint_key], + dtype_map[attribute.checkpoint_key].name) + elif node.slot_variables: + graph += ('N_%d [shape=point label="" width=.25 color=red,' + 'tooltip="Optimizer"]\n') % node_id + else: + graph += 'N_%d [shape=point label="" width=.25]\n' % node_id + for reference in node.children: + graph += 'N_%d -> N_%d [label="%s"]\n' % ( + node_id, reference.node_id, _escape(reference.local_name)) + for slot_reference in node.slot_variables: + graph += 'N_%d -> N_%d [label="%s" style=dotted]\n' % ( + node_id, + slot_reference.slot_variable_node_id, + _escape(slot_reference.slot_name)) + graph += 'N_%d -> N_%d [style=dotted]\n' % ( + slot_reference.original_variable_node_id, + slot_reference.slot_variable_node_id) + graph += '}\n' + return graph diff --git a/tensorflow/contrib/checkpoint/python/visualize_test.py b/tensorflow/contrib/checkpoint/python/visualize_test.py new file mode 100644 index 0000000000000000000000000000000000000000..583e3bc442893d825c337d73fb999d1e586738a1 --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/visualize_test.py @@ -0,0 +1,97 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os + +from tensorflow.contrib.checkpoint.python import visualize + +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training import adam +from tensorflow.python.training.checkpointable import util as checkpointable_utils + +try: + import pydot # pylint: disable=g-import-not-at-top +except ImportError: + pydot = None + + +class MyModel(training.Model): + """A concrete Model for testing.""" + + def __init__(self): + super(MyModel, self).__init__() + self._named_dense = core.Dense(1, use_bias=True) + self._second = core.Dense(1, use_bias=False) + + def call(self, values): + ret = self._second(self._named_dense(values)) + return ret + + +class DotGraphTests(test.TestCase): + + def testMakeDotGraph(self): + with context.eager_mode(): + input_value = constant_op.constant([[3.]]) + model = MyModel() + optimizer = adam.AdamOptimizer(0.001) + optimizer_step = resource_variable_ops.ResourceVariable(12) + save_checkpoint = checkpointable_utils.Checkpoint( + optimizer=optimizer, model=model, optimizer_step=optimizer_step) + optimizer.minimize(functools.partial(model, input_value)) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + save_path = save_checkpoint.save(checkpoint_prefix) + prefix = save_checkpoint.save(save_path) + + dot_graph_string = visualize.dot_graph_from_checkpoint(prefix) + + # The remainder of this test is more-or-less optional since it's so + # dependent on pydot/platform/Python versions. + if pydot is None: + self.skipTest('pydot is required for the remainder of this test.') + try: + parsed, = pydot.graph_from_dot_data(dot_graph_string) + except NameError as e: + if "name 'dot_parser' is not defined" in str(e): + self.skipTest("pydot isn't working") + else: + raise + # Check that the graph isn't completely trivial + self.assertEqual( + '"model"', + parsed.obj_dict['edges'][('N_0', 'N_1')][0]['attributes']['label']) + image_path = os.path.join(self.get_temp_dir(), 'saved.svg') + try: + parsed.write_svg(image_path) + except Exception as e: # pylint: disable=broad-except + # For some reason PyDot's "dot not available" error is an Exception, not + # something more specific. + if '"dot" not found in path' in str(e): + self.skipTest("pydot won't save SVGs (dot not available)") + else: + raise + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index fe8bd072afd43a64fa62a65bd8900b5a98dbe761..42ba368531468b789a87429f88ca84937f9b909d 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -14,20 +14,11 @@ load( "tf_py_test", ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - tf_gen_op_libs( - op_lib_names = ["bigquery_reader_ops"], + op_lib_names = [ + "bigquery_reader_ops", + "gcs_config_ops", + ], deps = [ "//tensorflow/core:lib", ], @@ -40,15 +31,25 @@ tf_gen_op_wrapper_py( deps = [":bigquery_reader_ops_op_lib"], ) +tf_gen_op_wrapper_py( + name = "gen_gcs_config_ops", + out = "python/ops/gen_gcs_config_ops.py", + require_shape_functions = True, + visibility = ["//tensorflow:internal"], + deps = [":gcs_config_ops_op_lib"], +) + py_library( name = "cloud_py", srcs = [ "__init__.py", "python/ops/bigquery_reader_ops.py", + "python/ops/gcs_config_ops.py", ], srcs_version = "PY2AND3", deps = [ ":gen_bigquery_reader_ops", + ":gen_gcs_config_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:io_ops", "//tensorflow/python:util", diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py index 8870264b95dfd9f8c4b1655c475fe23e0639924f..a6e13ea3ae938444b9ead0772e52fb8797a847da 100644 --- a/tensorflow/contrib/cloud/__init__.py +++ b/tensorflow/contrib/cloud/__init__.py @@ -20,9 +20,15 @@ from __future__ import print_function # pylint: disable=line-too-long,wildcard-import from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import * +from tensorflow.contrib.cloud.python.ops.gcs_config_ops import * # pylint: enable=line-too-long,wildcard-import from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['BigQueryReader'] +_allowed_symbols = [ + 'BigQueryReader', + 'ConfigureColabSession', + 'ConfigureGcs', + 'ConfigureGcsHook', +] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 56f930a9a8d32c5c3a025163ef56c9562f17d864..40160706f70e8fa8323005dd183770ed51c8c415 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -20,20 +20,6 @@ load( "tf_proto_library", ) -filegroup( - name = "all_files", - srcs = glob( - include = [ - "**/*", - ], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - tf_kernel_library( name = "bigquery_reader_ops", srcs = ["bigquery_reader_ops.cc"], @@ -73,6 +59,7 @@ tf_cc_test( ], deps = [ ":bigquery_table_accessor", + "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -86,3 +73,17 @@ tf_proto_library( srcs = ["bigquery_table_partition.proto"], cc_api_version = 2, ) + +tf_kernel_library( + name = "gcs_config_ops", + srcs = ["gcs_config_ops.cc"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/platform/cloud:curl_http_request", + "//tensorflow/core/platform/cloud:gcs_file_system", + "//tensorflow/core/platform/cloud:oauth_client", + "@jsoncpp_git//:jsoncpp", + ], +) diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc index e9b79a066def566096d6c3f3745974423e3371d1..7416eb19d3324fad84876cde5353bc25bac8f648 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/example/feature.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/cloud/http_request_fake.h" #include "tensorflow/core/platform/test.h" @@ -28,8 +29,8 @@ constexpr char kTestProject[] = "test-project"; constexpr char kTestDataset[] = "test-dataset"; constexpr char kTestTable[] = "test-table"; -bool HasSubstr(const string& base, const string& substr) { - bool ok = StringPiece(base).contains(substr); +bool HasSubstr(StringPiece base, StringPiece substr) { + bool ok = str_util::StrContains(base, substr); EXPECT_TRUE(ok) << base << ", expected substring " << substr; return ok; } diff --git a/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..648a219fb87a6ebc64767a7da780013ef6b95443 --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc @@ -0,0 +1,205 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "include/json/json.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/cloud/curl_http_request.h" +#include "tensorflow/core/platform/cloud/gcs_file_system.h" +#include "tensorflow/core/platform/cloud/oauth_client.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace { + +// The default initial delay between retries with exponential backoff. +constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec + +// The minimum time delta between now and the token expiration time +// for the token to be re-used. +constexpr int kExpirationTimeMarginSec = 60; + +// The URL to retrieve the auth bearer token via OAuth with a refresh token. +constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token"; + +// The URL to retrieve the auth bearer token via OAuth with a private key. +constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token"; + +// The authentication token scope to request. +constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform"; + +Status RetrieveGcsFs(OpKernelContext* ctx, RetryingGcsFileSystem** fs) { + DCHECK(fs != nullptr); + *fs = nullptr; + + FileSystem* filesystem = nullptr; + TF_RETURN_IF_ERROR( + ctx->env()->GetFileSystemForFile("gs://fake/file.text", &filesystem)); + if (filesystem == nullptr) { + return errors::FailedPrecondition("The GCS file system is not registered."); + } + + *fs = dynamic_cast(filesystem); + if (*fs == nullptr) { + return errors::Internal( + "The filesystem registered under the 'gs://' scheme was not a " + "tensorflow::RetryingGcsFileSystem*."); + } + return Status::OK(); +} + +template +Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name, + T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar()(); + return Status::OK(); +} + +// GcsCredentialsOpKernel overrides the credentials used by the gcs_filesystem. +class GcsCredentialsOpKernel : public OpKernel { + public: + explicit GcsCredentialsOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + // Get a handle to the GCS file system. + RetryingGcsFileSystem* gcs = nullptr; + OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); + + string json_string; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "json", &json_string)); + + Json::Value json; + Json::Reader reader; + std::stringstream json_stream(json_string); + OP_REQUIRES(ctx, reader.parse(json_stream, json), + errors::InvalidArgument("Could not parse json: ", json_string)); + + OP_REQUIRES( + ctx, json.isMember("refresh_token") || json.isMember("private_key"), + errors::InvalidArgument("JSON format incompatible; did not find fields " + "`refresh_token` or `private_key`.")); + + auto provider = + tensorflow::MakeUnique(json, ctx->env()); + + // Test getting a token + string dummy_token; + OP_REQUIRES_OK(ctx, provider->GetToken(&dummy_token)); + OP_REQUIRES(ctx, !dummy_token.empty(), + errors::InvalidArgument( + "Could not retrieve a token with the given credentials.")); + + // Set the provider. + gcs->underlying()->SetAuthProvider(std::move(provider)); + } + + private: + class ConstantAuthProvider : public AuthProvider { + public: + ConstantAuthProvider(const Json::Value& json, + std::unique_ptr oauth_client, Env* env, + int64 initial_retry_delay_usec) + : json_(json), + oauth_client_(std::move(oauth_client)), + env_(env), + initial_retry_delay_usec_(initial_retry_delay_usec) {} + + ConstantAuthProvider(const Json::Value& json, Env* env) + : ConstantAuthProvider(json, tensorflow::MakeUnique(), env, + kInitialRetryDelayUsec) {} + + ~ConstantAuthProvider() override {} + + Status GetToken(string* token) override { + mutex_lock l(mu_); + const uint64 now_sec = env_->NowSeconds(); + + if (!current_token_.empty() && + now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) { + *token = current_token_; + return Status::OK(); + } + if (json_.isMember("refresh_token")) { + TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson( + json_, kOAuthV3Url, ¤t_token_, &expiration_timestamp_sec_)); + } else if (json_.isMember("private_key")) { + TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson( + json_, kOAuthV4Url, kOAuthScope, ¤t_token_, + &expiration_timestamp_sec_)); + } else { + return errors::FailedPrecondition( + "Unexpected content of the JSON credentials file."); + } + + *token = current_token_; + return Status::OK(); + } + + private: + Json::Value json_; + std::unique_ptr oauth_client_; + Env* env_; + + mutex mu_; + string current_token_ GUARDED_BY(mu_); + uint64 expiration_timestamp_sec_ GUARDED_BY(mu_) = 0; + + // The initial delay for exponential backoffs when retrying failed calls. + const int64 initial_retry_delay_usec_; + TF_DISALLOW_COPY_AND_ASSIGN(ConstantAuthProvider); + }; +}; + +REGISTER_KERNEL_BUILDER(Name("GcsConfigureCredentials").Device(DEVICE_CPU), + GcsCredentialsOpKernel); + +class GcsBlockCacheOpKernel : public OpKernel { + public: + explicit GcsBlockCacheOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + // Get a handle to the GCS file system. + RetryingGcsFileSystem* gcs = nullptr; + OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); + + size_t max_cache_size, block_size, max_staleness; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "max_cache_size", + &max_cache_size)); + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "block_size", &block_size)); + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "max_staleness", &max_staleness)); + + if (gcs->underlying()->block_size() == block_size && + gcs->underlying()->max_bytes() == max_cache_size && + gcs->underlying()->max_staleness() == max_staleness) { + LOG(INFO) << "Skipping resetting the GCS block cache."; + return; + } + gcs->underlying()->ResetFileBlockCache(block_size, max_cache_size, + max_staleness); + } +}; + +REGISTER_KERNEL_BUILDER(Name("GcsConfigureBlockCache").Device(DEVICE_CPU), + GcsBlockCacheOpKernel); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/ops/gcs_config_ops.cc b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..9cf85f5f1811d873075b6d2e1931d8badfd6e32c --- /dev/null +++ b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("GcsConfigureCredentials") + .Input("json: string") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Configures the credentials used by the GCS client of the local TF runtime. + +The json input can be of the format: + +1. Refresh Token: +{ + "client_id": "", + "client_secret": "", + "refresh_token: "", + "type": "authorized_user", +} + +2. Service Account: +{ + "type": "service_account", + "project_id": "", + "private_key_id": "", + "private_key": "------BEGIN PRIVATE KEY-----\n\n-----END PRIVATE KEY------\n", + "client_email": "@.iam.gserviceaccount.com", + "client_id": "", + # Some additional fields elided +} + +Note the credentials established through this method are shared across all +sessions run on this runtime. + +Note be sure to feed the inputs to this op to ensure the credentials are not +stored in a constant op within the graph that might accidentally be checkpointed +or in other ways be persisted or exfiltrated. +)doc"); + +REGISTER_OP("GcsConfigureBlockCache") + .Input("max_cache_size: uint64") + .Input("block_size: uint64") + .Input("max_staleness: uint64") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Re-configures the GCS block cache with the new configuration values. + +If the values are the same as already configured values, this op is a no-op. If +they are different, the current contents of the block cache is dropped, and a +new block cache is created fresh. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..8c8c5acb31af69b4f738a13c6548cdd31947d71a --- /dev/null +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py @@ -0,0 +1,188 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GCS file system configuration for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + +from tensorflow.contrib.cloud.python.ops import gen_gcs_config_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.training import training + + +# @tf_export('contrib.cloud.BlockCacheParams') +class BlockCacheParams(object): + """BlockCacheParams is a struct used for configuring the GCS Block Cache.""" + + def __init__(self, block_size=None, max_bytes=None, max_staleness=None): + self._block_size = block_size or 128 * 1024 * 1024 + self._max_bytes = max_bytes or 2 * self._block_size + self._max_staleness = max_staleness or 0 + + @property + def block_size(self): + return self._block_size + + @property + def max_bytes(self): + return self._max_bytes + + @property + def max_staleness(self): + return self._max_staleness + + +# @tf_export('contrib.cloud.ConfigureGcsHook') +class ConfigureGcsHook(training.SessionRunHook): + """ConfigureGcsHook configures GCS when used with Estimator/TPUEstimator. + + Warning: GCS `credentials` may be transmitted over the network unencrypted. + Please ensure that the network is trusted before using this function. For + users running code entirely within Google Cloud, your data is protected by + encryption in between data centers. For more information, please take a look + at https://cloud.google.com/security/encryption-in-transit/. + + Example: + + ``` + sess = tf.Session() + refresh_token = raw_input("Refresh token: ") + client_secret = raw_input("Client secret: ") + client_id = "" + creds = { + "client_id": client_id, + "refresh_token": refresh_token, + "client_secret": client_secret, + "type": "authorized_user", + } + tf.contrib.cloud.configure_gcs(sess, credentials=creds) + ``` + + """ + + def _verify_dictionary(self, creds_dict): + if 'refresh_token' in creds_dict or 'private_key' in creds_dict: + return True + return False + + def __init__(self, credentials=None, block_cache=None): + """Constructs a ConfigureGcsHook. + + Args: + credentials: A json-formatted string. + block_cache: A `BlockCacheParams` + + Raises: + ValueError: If credentials is improperly formatted or block_cache is not a + BlockCacheParams. + """ + if credentials is not None: + if isinstance(credentials, str): + try: + data = json.loads(credentials) + except ValueError as e: + raise ValueError('credentials was not a well formed JSON string.', e) + if not self._verify_dictionary(data): + raise ValueError( + 'credentials has neither a "refresh_token" nor a "private_key" ' + 'field.') + elif isinstance(credentials, dict): + if not self._verify_dictionary(credentials): + raise ValueError('credentials has neither a "refresh_token" nor a ' + '"private_key" field.') + credentials = json.dumps(credentials) + else: + raise ValueError('credentials is of an unknown type') + + self._credentials = credentials + + if block_cache and not isinstance(block_cache, BlockCacheParams): + raise ValueError('block_cache must be an instance of BlockCacheParams.') + self._block_cache = block_cache + + def begin(self): + if self._credentials: + self._credentials_placeholder = array_ops.placeholder(dtypes.string) + self._credentials_ops = gen_gcs_config_ops.gcs_configure_credentials( + self._credentials_placeholder) + if self._block_cache: + self._block_cache_op = gen_gcs_config_ops.gcs_configure_block_cache( + max_cache_size=self._block_cache.max_bytes, + block_size=self._block_cache.block_size, + max_staleness=self._block_cache.max_staleness) + + def after_create_session(self, session, coord): + del coord + if self._credentials_op: + session.run( + self._credentials_op, + feed_dict={self._credentials_placeholder: self._credentials}) + if self._block_cache_op: + session.run(self._block_cache_op) + + +def configure_gcs(session, credentials=None, block_cache=None, device=None): + """Configures the GCS file system for a given a session. + + Warning: GCS `credentials` may be transmitted over the network unencrypted. + Please ensure that the network is trusted before using this function. For + users running code entirely within Google Cloud, your data is protected by + encryption in between data centers. For more information, please take a look + at https://cloud.google.com/security/encryption-in-transit/. + + Args: + session: A `tf.Session` session that should be used to configure the GCS + file system. + credentials: [Optional.] A JSON string + block_cache: [Optional.] A BlockCacheParams to configure the block cache . + device: [Optional.] The device to place the configure ops. + """ + + def configure(credentials, block_cache): + """Helper function to actually configure GCS.""" + if credentials: + if isinstance(credentials, dict): + credentials = json.dumps(credentials) + placeholder = array_ops.placeholder(dtypes.string) + op = gen_gcs_config_ops.gcs_configure_credentials(placeholder) + session.run(op, feed_dict={placeholder: credentials}) + if block_cache: + op = gen_gcs_config_ops.gcs_configure_block_cache( + max_cache_size=block_cache.max_bytes, + block_size=block_cache.block_size, + max_staleness=block_cache.max_staleness) + session.run(op) + + if device: + with ops.device(device): + return configure(credentials, block_cache) + return configure(credentials, block_cache) + + +def configure_colab_session(session): + """ConfigureColabSession configures the GCS file system in Colab. + + Args: + session: A `tf.Session` session. + """ + # Read from the application default credentials (adc). + with open('/content/datalab/adc.json') as f: + data = json.load(f) + configure_gcs(session, credentials=data) diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index 80e18a43a71cc9d6c9e2ccf5836e50c6427a30f6..c239e6f8f960910cee14e1df7c4678c643496f54 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -10,19 +10,6 @@ package( licenses(["notice"]) # Apache 2.0 -filegroup( - name = "all_files", - srcs = glob( - include = [ - "**/*", - ], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) - py_library( name = "cluster_resolver_pip", srcs = [ @@ -30,6 +17,7 @@ py_library( "python/training/__init__.py", ], srcs_version = "PY2AND3", + visibility = ["//visibility:public"], deps = [ ":cluster_resolver_py", ":gce_cluster_resolver_py", @@ -109,5 +97,6 @@ tf_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:training", ], + grpc_enabled = True, main = "python/training/tpu_cluster_resolver_test.py", ) diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py index b04822fa9d66465e34a545d3b00c399bbb196514..1c480b25134b1e54200e0ddb780bd7bb0f122341 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py @@ -53,11 +53,16 @@ class ClusterResolver(object): raise NotImplementedError( 'cluster_spec is not implemented for {}.'.format(self)) + @abc.abstractmethod + def master(self): + """...""" + raise NotImplementedError('master is not implemented for {}.'.format(self)) + class SimpleClusterResolver(ClusterResolver): """Simple implementation of ClusterResolver that accepts a ClusterSpec.""" - def __init__(self, cluster_spec): + def __init__(self, cluster_spec, master=''): """Creates a SimpleClusterResolver from a ClusterSpec.""" super(SimpleClusterResolver, self).__init__() @@ -65,10 +70,18 @@ class SimpleClusterResolver(ClusterResolver): raise TypeError('cluster_spec must be a ClusterSpec.') self._cluster_spec = cluster_spec + if not isinstance(master, str): + raise TypeError('master must be a string.') + self._master = master + def cluster_spec(self): """Returns the ClusterSpec passed into the constructor.""" return self._cluster_spec + def master(self): + """Returns the master address to use when creating a session.""" + return self._master + class UnionClusterResolver(ClusterResolver): """Performs a union on underlying ClusterResolvers. @@ -87,9 +100,13 @@ class UnionClusterResolver(ClusterResolver): Raises: TypeError: If any argument is not a subclass of `ClusterResolvers`. + ValueError: If there are no arguments passed. """ super(UnionClusterResolver, self).__init__() + if not args: + raise ValueError('At least one ClusterResolver is required.') + for cluster_resolver in args: if not isinstance(cluster_resolver, ClusterResolver): raise TypeError('All arguments must be a sub-class of ' @@ -169,3 +186,7 @@ class UnionClusterResolver(ClusterResolver): merged_cluster[job_name].update(task_dict) return ClusterSpec(merged_cluster) + + def master(self): + """master returns the master address from the first cluster resolver.""" + return self._cluster_resolvers[0].master() diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py index dbfb77723cdaab66e29bb41b764593bb5fd61b35..d9c97d53eb3663f6ab2f7b40395592dc7638b896 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py @@ -234,5 +234,7 @@ class UnionClusterResolverTest(test.TestCase): self._verifyClusterSpecEquality(cluster_spec, expected_proto) +# TODO(saeta): Include tests for master resolution + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py index d6f2eced93ba4fda5ac27f9412b6f729981f4f40..3f5824128948453634bc5e5a7d6fdeedae60f5bd 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py @@ -134,3 +134,6 @@ class GceClusterResolver(ClusterResolver): worker_list.sort() return ClusterSpec({self._job_name: worker_list}) + + def master(self): + return '' diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index a6a6e642e4e4c721b94821a70d55d6fe931347d6..8f521ffee4d31e090c13bac98290656d6e1d330e 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -18,12 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os from six.moves.urllib.request import Request from six.moves.urllib.request import urlopen from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.python.training.server_lib import ClusterSpec +from tensorflow.python.training import server_lib +from tensorflow.python.util import compat _GOOGLE_API_CLIENT_INSTALLED = True try: @@ -33,6 +35,12 @@ except ImportError: _GOOGLE_API_CLIENT_INSTALLED = False +_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' +_ENDPOINTS_SEPARATOR = ',' +_DEFAULT_ENV_VARIABLE = 'TPU_NAME' +_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL' + + class TPUClusterResolver(ClusterResolver): """Cluster Resolver for Google Cloud TPUs. @@ -46,22 +54,56 @@ class TPUClusterResolver(ClusterResolver): req = Request('http://metadata/computeMetadata/v1/%s' % path, headers={'Metadata-Flavor': 'Google'}) resp = urlopen(req) - return resp.read() + return compat.as_bytes(resp.read()) + + def _shouldResolve(self): + if (self._tpu == compat.as_bytes('') or + self._tpu == compat.as_bytes('local') or + self._tpu.startswith(compat.as_bytes('/bns')) or + self._tpu.startswith(compat.as_bytes('grpc://'))): + return False + return True + + @staticmethod + def _inGke(): + """When running in GKE, the environment variable will be set.""" + return _GKE_ENV_VARIABLE in os.environ + + @staticmethod + def _gkeEndpoints(): + return os.environ[_GKE_ENV_VARIABLE] + + @staticmethod + def _envVarFallback(): + if _DEFAULT_ENV_VARIABLE in os.environ: + return os.environ[_DEFAULT_ENV_VARIABLE] + return None + + @staticmethod + def _discoveryUrl(): + return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) def __init__(self, - tpu_names, + tpu=None, zone=None, project=None, - job_name='tpu_worker', + job_name='worker', + coordinator_name=None, + coordinator_address=None, credentials='default', - service=None): + service=None, + discovery_url=None): """Creates a new TPUClusterResolver object. The ClusterResolver will then use the parameters to query the Cloud TPU APIs for the IP addresses and ports of each Cloud TPU listed. Args: - tpu_names: A list of names of the target Cloud TPUs. + tpu: Either a string, or a list of strings corresponding to the TPUs to + use. If the single string is the empty string, the string 'local', or a + string that begins with 'grpc://' or '/bns', then it is assumed to not + correspond with a Cloud TPU and will instead be passed as the session + master and no ClusterSpec propagation will be done. zone: Zone where the TPUs are located. If omitted or empty, we will assume that the zone of the TPU is the same as the zone of the GCE VM, which we will try to discover from the GCE metadata service. @@ -69,63 +111,121 @@ class TPUClusterResolver(ClusterResolver): empty, we will try to discover the project name of the GCE VM from the GCE metadata service. job_name: Name of the TensorFlow job the TPUs belong to. + coordinator_name: The name to use for the coordinator. Set to None if the + coordinator should not be included in the computed ClusterSpec. + coordinator_address: The address of the coordinator (typically an ip:port + pair). If set to None, a TF server will be started. If coordinator_name + is None, a TF server will not be started even if coordinator_address is + None. credentials: GCE Credentials. If None, then we use default credentials from the oauth2client service: The GCE API object returned by the googleapiclient.discovery function. If you specify a custom service object, then the credentials parameter will be ignored. + discovery_url: A URL template that points to the location of + the discovery service. It should have two parameters {api} and + {apiVersion} that when filled in produce an absolute URL to the + discovery document for that service. The environment variable + 'TPU_API_DISCOVERY_URL' will override this. Raises: ImportError: If the googleapiclient is not installed. + ValueError: If no TPUs are specified. """ + if isinstance(tpu, list): + if not tpu: + raise ValueError('At least one TPU must be specified.') + if len(tpu) != 1: + raise NotImplementedError( + 'Using multiple TPUs in a single session is not yet implemented') + tpu = tpu[0] - if not project: - project = self._requestComputeMetadata('/project/project-id') + in_gke = self._inGke() + # When using GKE with Cloud TPUs, the env variable will be set. + if tpu is None: + if in_gke: + tpu = self._gkeEndpoints() + else: + tpu = self._envVarFallback() - if not zone: - zone_path = self._requestComputeMetadata('/instance/zone') + self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes + self._job_name = job_name + self._credentials = credentials + + should_resolve = self._shouldResolve() + + if not project and should_resolve: + project = compat.as_str( + self._requestComputeMetadata('project/project-id')) + + if not zone and should_resolve: + zone_path = compat.as_str(self._requestComputeMetadata('instance/zone')) zone = zone_path.split('/')[-1] self._project = project self._zone = zone - self._tpu_names = tpu_names - self._job_name = job_name - self._credentials = credentials - if credentials == 'default': + if credentials == 'default' and should_resolve: if _GOOGLE_API_CLIENT_INSTALLED: self._credentials = GoogleCredentials.get_application_default() - if service is None: + if service is None and should_resolve: if not _GOOGLE_API_CLIENT_INSTALLED: - raise ImportError('googleapiclient must be installed before using the ' - 'TPU cluster resolver') + raise ImportError('googleapiclient and oauth2client must be installed ' + 'before using the TPU cluster resolver. Execute: ' + '`pip install --upgrade google-api-python-client` ' + 'and `pip install --upgrade oauth2client` to ' + 'install with pip.') - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials) + final_discovery_url = self._discoveryUrl() or discovery_url + if final_discovery_url: + self._service = discovery.build( + 'tpu', 'v1alpha1', + credentials=self._credentials, + discoveryServiceUrl=final_discovery_url) + else: + self._service = discovery.build( + 'tpu', 'v1alpha1', + credentials=self._credentials) else: self._service = service - def get_master(self): - """Get the ClusterSpec grpc master path. + self._coordinator_name = coordinator_name + if coordinator_name and not coordinator_address and (should_resolve or + in_gke): + self._start_local_server() + else: + self._coordinator_address = coordinator_address + + def master(self): + """Get the Master string to be used for the session. - This returns the grpc path (grpc://1.2.3.4:8470) of first instance in the - ClusterSpec returned by the cluster_spec function. This is suitable for use - for the `master` argument in tf.Session() when you are using one TPU. + In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of + first instance in the ClusterSpec returned by the cluster_spec function. + + If a non-TPU name is used when constructing a TPUClusterResolver, that will + be returned instead (e.g. If the tpus argument's value when constructing + this TPUClusterResolver was 'grpc://10.240.1.2:8470', + 'grpc://10.240.1.2:8470' will be returned). Returns: - string, the grpc path of the first instance in the ClusterSpec. + string, the connection string to use when creating a session. Raises: ValueError: If none of the TPUs specified exists. """ + if not self._shouldResolve(): + return self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0] + job_tasks = self.cluster_spec().job_tasks(self._job_name) if not job_tasks: raise ValueError('No TPUs exists with the specified names exist.') return 'grpc://' + job_tasks[0] + def get_master(self): + return self.master() + def cluster_spec(self): """Returns a ClusterSpec object based on the latest TPU information. @@ -134,17 +234,81 @@ class TPUClusterResolver(ClusterResolver): Returns: A ClusterSpec containing host information returned from Cloud TPUs. + + Raises: + RuntimeError: If the provided TPU is not healthy. """ - worker_list = [] + ############################################################################ + # There are 5 potential cases this code must handle: + # 1. [Normal case.] We should resolve the TPU name to a set of tasks, and + # a. Create a ClusterSpec that includes the coordinator job + # b. Create a ClusterSpec without the coordinator job. + # 2. [GKE / No API Access.] We should not resolve the TPU name to a set of + # tasks and + # a. Create a ClusterSpec with the coordinator + # b. Create a ClusterSpec without the coordinator + # 3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec. + ############################################################################ - for tpu_name in self._tpu_names: + if self._shouldResolve(): + # Case 1. full_name = 'projects/%s/locations/%s/nodes/%s' % ( - self._project, self._zone, tpu_name) + self._project, self._zone, compat.as_text(self._tpu)) request = self._service.projects().locations().nodes().get(name=full_name) response = request.execute() - if 'health' in response and response['health'] == 'HEALTHY': + if 'state' in response and response['state'] != 'READY': + raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % + (self._tpu, response['state'])) + + if 'health' in response and response['health'] != 'HEALTHY': + raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu, + response['health'])) + + if 'networkEndpoints' in response: + worker_list = [ + '%s:%s' % (endpoint['ipAddress'], endpoint['port']) + for endpoint in response['networkEndpoints'] + ] + else: + # Fall back to the deprecated response format instance_url = '%s:%s' % (response['ipAddress'], response['port']) - worker_list.append(instance_url) + worker_list = [instance_url] + + cluster_spec = {self._job_name: worker_list} + else: + if not self._tpu.startswith(compat.as_bytes('grpc://')): + # Case 3. + return None + # Case 2. + cluster_spec = { + self._job_name: [ + x[len(compat.as_bytes('grpc://')):] + for x in self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR)) + ] + } + + if self._coordinator_address: + # {1, 2}.a + cluster_spec[self._coordinator_name] = [self._coordinator_address] + + return server_lib.ClusterSpec(cluster_spec) + + def _start_local_server(self): + address = self._requestComputeMetadata('instance/network-interfaces/0/ip') + self._server = server_lib.Server( + { + 'local': ['0.0.0.0:0'] + }, protocol='grpc', config=None, start=True) + # self._server.target is of the form: grpc://ipaddress:port + target = compat.as_bytes(self._server.target) + splits = target.split(compat.as_bytes(':')) + assert len(splits) == 3, self._server.target + assert splits[0] == compat.as_bytes('grpc'), self._server.target + self._coordinator_port = compat.as_text(splits[2]) + self._coordinator_address = '%s:%s' % ( + address, compat.as_text(self._coordinator_port)) - return ClusterSpec({self._job_name: worker_list}) + def __deepcopy__(self, memo): + # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy. + return self diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 4fd34629cf74f90869c77b8cb098d3c585a49404..ad4f6432630be44a7de6e778f55f1fb7fd66f307 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -18,10 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver from tensorflow.python.platform import test from tensorflow.python.training import server_lib - +from tensorflow.python.util import compat mock = test.mock @@ -50,10 +52,12 @@ class MockNodeClass(object): def mock_request_compute_metadata(cls, *args, **kwargs): del cls, kwargs # Unused. - if args[0] == '/project/project-id': + if args[0] == 'project/project-id': return 'test-project' - elif args[0] == '/instance/zone': + elif args[0] == 'instance/zone': return 'projects/test-project/locations/us-central1-c' + elif args[0] == 'instance/network-interfaces/0/ip': + return '10.128.1.2' return '' @@ -71,18 +75,17 @@ class TPUClusterResolverTest(test.TestCase): expected_proto: Expected protobuf """ self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def()) - self.assertProtoEquals( - expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def()) - self.assertProtoEquals( - expected_proto, - server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def()) self.assertProtoEquals( expected_proto, - server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def()) + server_lib.ClusterSpec(cluster_spec).as_cluster_def()) + self.assertProtoEquals(expected_proto, + server_lib.ClusterSpec( + cluster_spec.as_cluster_def()).as_cluster_def()) + self.assertProtoEquals(expected_proto, + server_lib.ClusterSpec( + cluster_spec.as_dict()).as_cluster_def()) - def mock_service_client( - self, - tpu_map=None): + def mock_service_client(self, tpu_map=None): if tpu_map is None: tpu_map = {} @@ -98,8 +101,7 @@ class TPUClusterResolverTest(test.TestCase): return mock_client - @mock.patch.object(TPUClusterResolver, - '_requestComputeMetadata', + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', mock_request_compute_metadata) def testRetrieveProjectAndZoneFromMetadata(self): tpu_map = { @@ -113,17 +115,27 @@ class TPUClusterResolverTest(test.TestCase): tpu_cluster_resolver = TPUClusterResolver( project=None, zone=None, - tpu_names=['test-tpu-1'], + tpu=['test-tpu-1'], credentials=None, - service=self.mock_service_client(tpu_map=tpu_map)) + service=self.mock_service_client(tpu_map=tpu_map), + coordinator_name='coordinator') actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ - job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } } - """ - self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + job { + name: 'coordinator' + tasks { key: 0 value: '10.128.1.2:%s' } + } + job { + name: 'worker' + tasks { key: 0 value: '10.1.2.3:8470' } + } + """ % tpu_cluster_resolver._coordinator_port + self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto)) - def testSimpleSuccessfulRetrieval(self): + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', @@ -133,27 +145,67 @@ class TPUClusterResolverTest(test.TestCase): } tpu_cluster_resolver = TPUClusterResolver( - project='test-project', - zone='us-central1-c', - tpu_names=['test-tpu-1'], + project=None, + zone=None, + tpu=['test-tpu-1'], + coordinator_name=None, credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ - job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } } + job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) - def testMultipleSuccessfulRetrieval(self): + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testUnhealthyCloudTpu(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', 'port': '8470', - 'health': 'HEALTHY' - }, - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { - 'ipAddress': '10.4.5.6', + 'health': 'UNHEALTHY' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project=None, + zone=None, + tpu='test-tpu-1', + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + with self.assertRaises(RuntimeError): + tpu_cluster_resolver.cluster_spec() + + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testNotReadyCloudTpu(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'state': 'CREATING' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project=None, + zone=None, + tpu='test-tpu-1', + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + with self.assertRaises(RuntimeError): + tpu_cluster_resolver.cluster_spec() + + def testSimpleSuccessfulRetrieval(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', 'port': '8470', 'health': 'HEALTHY' } @@ -162,87 +214,255 @@ class TPUClusterResolverTest(test.TestCase): tpu_cluster_resolver = TPUClusterResolver( project='test-project', zone='us-central1-c', - tpu_names=['test-tpu-2', 'test-tpu-1'], + tpu=['test-tpu-1'], + coordinator_name='coordinator', + coordinator_address='10.128.1.5:10203', credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ - job { name: 'tpu_worker' tasks { key: 0 value: '10.4.5.6:8470' } - tasks { key: 1 value: '10.1.2.3:8470' } } + job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } } + job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) - def testHealthyTpuNodeRetrieval(self): + def testNewNetworkEndpointFormat(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { - 'ipAddress': '10.1.2.3', - 'port': '8470', - 'health': 'HEALTHY' - }, - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { - 'ipAddress': '10.4.5.6', - 'port': '8470', - }, - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-3': { - 'ipAddress': '10.7.8.9', - 'port': '8470', - 'health': 'UNHEALTHY' + 'health': 'HEALTHY', + 'networkEndpoints': [{ + 'ipAddress': '10.2.3.4', + 'port': 8470, + }] } } tpu_cluster_resolver = TPUClusterResolver( project='test-project', zone='us-central1-c', - tpu_names=['test-tpu-2', 'test-tpu-1', 'test-tpu-3'], + tpu='test-tpu-1', + coordinator_name='coordinator', + coordinator_address='10.128.1.5:10203', credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ - job { - name: 'tpu_worker' - tasks { - key: 0 - value: '10.1.2.3:8470' - } - } + job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } } + job { name: 'worker' tasks { key: 0 value: '10.2.3.4:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + self.assertEqual('grpc://10.2.3.4:8470', tpu_cluster_resolver.master()) - def testGetMasterMultipleEntries(self): + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testPodResolution(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { - 'ipAddress': '10.1.2.3', - 'port': '8470', - 'health': 'HEALTHY' - }, - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { - 'ipAddress': '10.4.5.6', - 'port': '8470', - 'health': 'HEALTHY' + 'health': + 'HEALTHY', + 'networkEndpoints': [ + { + 'ipAddress': '10.2.3.4', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.5', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.6', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.7', + 'port': 8470, + }, + ] + } + } + + tpu_cluster_resolver = TPUClusterResolver( + tpu='test-tpu-1', + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map), + coordinator_name='coordinator') + + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'coordinator', + tasks { key: 0 value: '10.128.1.2:%s'} + } + job { + name: 'worker' + tasks { key: 0 value: '10.2.3.4:8470' } + tasks { key: 1 value: '10.2.3.5:8470' } + tasks { key: 2 value: '10.2.3.6:8470' } + tasks { key: 3 value: '10.2.3.7:8470' } + } + """ % tpu_cluster_resolver._coordinator_port + self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto)) + + def testPodResolutionNoCoordinator(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'health': + 'HEALTHY', + 'networkEndpoints': [ + { + 'ipAddress': '10.2.3.4', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.5', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.6', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.7', + 'port': 8470, + }, + ] } } tpu_cluster_resolver = TPUClusterResolver( project='test-project', zone='us-central1-c', - tpu_names=['test-tpu-2', 'test-tpu-1'], + tpu='test-tpu-1', + coordinator_name=None, credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) - self.assertEqual('grpc://10.4.5.6:8470', tpu_cluster_resolver.get_master()) + + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'worker' + tasks { key: 0 value: '10.2.3.4:8470' } + tasks { key: 1 value: '10.2.3.5:8470' } + tasks { key: 2 value: '10.2.3.6:8470' } + tasks { key: 3 value: '10.2.3.7:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) def testGetMasterNoEntries(self): tpu_map = {} + with self.assertRaises(ValueError): + TPUClusterResolver( + project='test-project', + zone='us-central1-c', + tpu=[], + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + # TODO(saeta): Convert to parameterized test when included in OSS TF. + def verifyShouldResolve(self, tpu, should_resolve): tpu_cluster_resolver = TPUClusterResolver( project='test-project', zone='us-central1-c', - tpu_names=[], + tpu=tpu, + coordinator_name=None, credentials=None, - service=self.mock_service_client(tpu_map=tpu_map)) - with self.assertRaises(ValueError): - tpu_cluster_resolver.get_master() + service=self.mock_service_client(tpu_map={})) + self.assertEqual(should_resolve, tpu_cluster_resolver._shouldResolve(), + "TPU: '%s'" % tpu) + + def testShouldResolveNoName(self): + self.verifyShouldResolve('', False) + + def testShouldResolveLocal(self): + self.verifyShouldResolve('local', False) + + def testShouldResolveGrpc(self): + self.verifyShouldResolve('grpc://10.1.2.3:8470', False) + + def testShouldResolveBns(self): + self.verifyShouldResolve('/bns/foo/bar', False) + + def testShouldResolveName(self): + self.verifyShouldResolve('mytpu', True) + + def testShouldResolveList(self): + self.verifyShouldResolve(['myothertpu'], True) + + def testShouldResolveGrpcPrefix(self): + self.verifyShouldResolve('grpctpu', True) + + def testNoCallComputeMetadata(self): + tpu_cluster_resolver = TPUClusterResolver(tpu='/bns/foo/bar') + self.assertEqual( + compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master()) + self.assertEqual(None, tpu_cluster_resolver.cluster_spec()) + + def testGkeEnvironmentForDonut(self): + os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470' + + self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ) + self.assertTrue(TPUClusterResolver._inGke()) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + compat.as_bytes(TPUClusterResolver._gkeEndpoints())) + + tpu_cluster_resolver = TPUClusterResolver() + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + compat.as_bytes(tpu_cluster_resolver.master())) + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'worker' + tasks { key: 0 value: '10.120.27.5:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + + del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] + + def testGkeEnvironmentForPod(self): + os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = ('grpc://10.120.27.5:8470,' + 'grpc://10.120.27.6:8470,' + 'grpc://10.120.27.7:8470,' + 'grpc://10.120.27.8:8470') + + self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ) + self.assertTrue(TPUClusterResolver._inGke()) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470,' + 'grpc://10.120.27.6:8470,' + 'grpc://10.120.27.7:8470,' + 'grpc://10.120.27.8:8470'), + compat.as_bytes(TPUClusterResolver._gkeEndpoints())) + + tpu_cluster_resolver = TPUClusterResolver() + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + compat.as_bytes(tpu_cluster_resolver.master())) + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'worker' + tasks { key: 0 value: '10.120.27.5:8470' } + tasks { key: 1 value: '10.120.27.6:8470' } + tasks { key: 2 value: '10.120.27.7:8470' } + tasks { key: 3 value: '10.120.27.8:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + + del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] + + def testDiscoveryUrl(self): + os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}' + self.assertEqual('https://{api}.internal/{apiVersion}', + TPUClusterResolver._discoveryUrl()) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 23b31ae1dcc83d8a7152354ac147de9ada320429..e524e9e7437b19e0d117fe7b85042e8154773a02 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -18,7 +18,16 @@ cmake_policy(SET CMP0022 NEW) # Options option(tensorflow_VERBOSE "Enable for verbose output" OFF) + +if(WIN32) +# BoringSSL is disabled for windows as it currently doesn't build with +# MSBuild. (Ninja is required.) option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" OFF) +else() +# BoringSSL is enabled for gRPC. +option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" ON) +endif() + option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON) option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF) option(tensorflow_ENABLE_JEMALLOC_SUPPORT "Enable jemalloc support" OFF) @@ -31,10 +40,14 @@ option(tensorflow_BUILD_PYTHON_TESTS "Build python unit tests " OFF) option(tensorflow_BUILD_MORE_PYTHON_TESTS "Build more python unit tests for contrib packages" OFF) option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF) option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON) -option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions") option(tensorflow_ENABLE_SNAPPY_SUPPORT "Enable SNAPPY compression support" ON) option(tensorflow_DISABLE_EIGEN_FORCEINLINE "Disable forceinline, to speed up build on windows." OFF) +# SIMD, MKL and MKLDNN options +option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions" OFF) +option(tensorflow_ENABLE_MKL_SUPPORT "Enable Intel MKL support" OFF) +option(tensorflow_ENABLE_MKLDNN_SUPPORT "Enable Intel MKLDNN support, requires MKL enabled" OFF) + # GPU, CUDA and cuDNN options option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) set(tensorflow_CUDA_VERSION "9.0" CACHE STRING "CUDA version to build against") @@ -80,7 +93,7 @@ if (NOT WIN32) option(systemlib_ALL "Turn on every possible systemlib_* options" OFF) if (systemlib_ALL) - set (systmelib_ZLIB ON) + set (systemlib_ZLIB ON) endif (systemlib_ALL) endif() @@ -124,8 +137,16 @@ endif() add_definitions(-DEIGEN_AVOID_STL_ARRAY) if(WIN32) - add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11 -DCOMPILER_MSVC) - add_definitions(-DWIN32 -DOS_WIN -D_MBCS -DWIN64 -DWIN32_LEAN_AND_MEAN -DNOGDI -DPLATFORM_WINDOWS) + if(CMAKE_SIZEOF_VOID_P EQUAL 8) + # 64 bits + add_definitions(-DWIN64) + elseif(CMAKE_SIZEOF_VOID_P EQUAL 4) + # 32 bits + # temporary fix for #18241 + add_definitions(-DEIGEN_DEFAULT_DENSE_INDEX_TYPE=std::int64_t) + endif() + add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11) + add_definitions(-DWIN32 -DOS_WIN -D_MBCS -DWIN32_LEAN_AND_MEAN -DNOGDI -DPLATFORM_WINDOWS) add_definitions(-DTENSORFLOW_USE_EIGEN_THREADPOOL -DEIGEN_HAS_C99_MATH) add_definitions(-DTF_COMPILE_LIBRARY) add_definitions(/bigobj /nologo /EHsc /GF /MP /Gm-) @@ -160,14 +181,24 @@ if (tensorflow_OPTIMIZE_FOR_NATIVE_ARCH) endif() endif() +include(CheckCXXCompilerFlag) + +# OpenMP Support +CHECK_CXX_COMPILER_FLAG("-fopenmp" GCC_OPENMP_SUPPORT) +if (GCC_OPENMP_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") +endif() +CHECK_CXX_COMPILER_FLAG("/openmp" MSVC_OPENMP_SUPPORT) +if (MSVC_OPENMP_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp") +endif() + # MSVC SIMD instructions if (tensorflow_WIN_CPU_SIMD_OPTIONS) if (WIN32) - CHECK_CXX_COMPILER_FLAG("${tensorflow_WIN_CPU_SIMD_OPTIONS}" COMPILER_OPT_WIN_CPU_SIMD_SUPPORTED) + CHECK_CXX_COMPILER_FLAG(${tensorflow_WIN_CPU_SIMD_OPTIONS} COMPILER_OPT_WIN_CPU_SIMD_SUPPORTED) if(COMPILER_OPT_WIN_CPU_SIMD_SUPPORTED) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${tensorflow_WIN_CPU_SIMD_OPTIONS}") - else() - message(FATAL_ERROR "${tensorflow_WIN_CPU_SIMD_OPTIONS} not supported") endif() endif() endif() @@ -193,6 +224,7 @@ include(protobuf) include(re2) include(cub) include(sqlite) +include(double_conversion) if (tensorflow_BUILD_CC_TESTS) include(googletest) endif() @@ -213,6 +245,7 @@ set(tensorflow_EXTERNAL_LIBRARIES ${protobuf_STATIC_LIBRARIES} ${re2_STATIC_LIBRARIES} ${sqlite_STATIC_LIBRARIES} + ${double_conversion_STATIC_LIBRARIES} ) if (systemlib_ZLIB) @@ -240,6 +273,7 @@ set(tensorflow_EXTERNAL_DEPENDENCIES fft2d re2 sqlite_copy_headers_to_destination + double_conversion ) include_directories( @@ -262,6 +296,7 @@ include_directories( ${PROTOBUF_INCLUDE_DIRS} ${re2_INCLUDE_DIR} ${sqlite_INCLUDE_DIR} + ${double_conversion_INCLUDE_DIR} ) if(tensorflow_ENABLE_SSL_SUPPORT) @@ -298,6 +333,49 @@ if(HAIKU) list(APPEND tensorflow_EXTERNAL_LIBRARIES network) endif() +# MKL Support +if (tensorflow_ENABLE_MKL_SUPPORT) + add_definitions(-DINTEL_MKL -DEIGEN_USE_VML) + if (WIN32) + find_path(MKL_HOME_PLATFORM mkl + PATHS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../ + $ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../ + PATH_SUFFIXES windows) + set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include) + set(MKL_LINK_DIRS + ${MKL_HOME_PLATFORM}/mkl/lib/intel64 + ${MKL_HOME_PLATFORM}/tbb/lib/intel64/vc_mt + ${MKL_HOME_PLATFORM}/compiler/lib/intel64 + ${MKL_HOME_PLATFORM}/mkl/tools/builder/lib) + set(MKL_REDIST_DLL_DIRS + ${MKL_HOME_PLATFORM}/redist/intel64/mkl + ${MKL_HOME_PLATFORM}/redist/intel64/tbb/vc_mt + ${MKL_HOME_PLATFORM}/redist/intel64/compiler) + list(APPEND tensorflow_EXTERNAL_LIBRARIES + mkl_intel_lp64_dll mkl_sequential_dll mkl_core_dll mkl_rt mkl_cdll_intel64) + endif() + if (UNIX) + # Fix me: complete the path on linux + find_path(MKL_HOME_PLATFORM mkl + HINTS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../ + $ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../ + PATH_SUFFIXES linux) + set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include) + set(MKL_LINK_DIRS) # incompleted + set(MKL_REDIST_SO_DIRS) # incompleted + endif() + include_directories(${MKL_INCLUDE_DIRS}) + link_directories(${MKL_LINK_DIRS}) + if (tensorflow_ENABLE_MKLDNN_SUPPORT) + include(mkldnn) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES}) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn) + include_directories(${mkldnn_INCLUDE_DIRS}) + else (tensorflow_ENABLE_MKLDNN_SUPPORT) + add_definitions(-DINTEL_MKL_ML) + endif() +endif (tensorflow_ENABLE_MKL_SUPPORT) + if (tensorflow_ENABLE_GPU) if (NOT WIN32) # Default install paths for cuda libraries in Linux @@ -409,6 +487,10 @@ if (tensorflow_ENABLE_GPU) include_directories(${tensorflow_source_dir}/third_party/gpus) # add cuda libraries to tensorflow_EXTERNAL_LIBRARIES list(APPEND tensorflow_EXTERNAL_LIBRARIES ${CUDA_LIBRARIES}) + if(NOT WIN32) + # add gomp to tensorflow_EXTERNAL_LIBRARIES, needed by libcusolver.so + list(APPEND tensorflow_EXTERNAL_LIBRARIES gomp) + endif() # NOTE(mrry): Update these flags when the version of CUDA or cuDNN used # in the default build is upgraded. diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 8f85a75ee466dbac524a1266dc2522109ca77cd5..0b79f718d4823a987e02804f59a432ee46d0ada3 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -26,7 +26,7 @@ The CMake files in this directory can build the core TensorFlow runtime, an example C++ binary, and a PIP package containing the runtime and Python bindings. -### Pre-requisites +### Prerequisites * CMake version 3.5 or later. @@ -34,14 +34,16 @@ bindings. * [SWIG](http://www.swig.org/download.html) -* Additional pre-requisites for Microsoft Windows: +* Additional prerequisites for Microsoft Windows: - Visual Studio 2015 - Python 3.5 - - NumPy 1.11.0 or later -* Additional pre-requisites for Linux: +* Additional prerequisites for Linux: - Python 2.7 or later - [Docker](https://www.docker.com/) (for automated testing) + +* Python dependencies: + - wheel - NumPy 1.11.0 or later ### Known-good configurations @@ -102,7 +104,7 @@ ops or APIs. Step-by-step Windows build ========================== -1. Install the pre-requisites detailed above, and set up your environment. +1. Install the prerequisites detailed above, and set up your environment. * The following commands assume that you are using the Windows Command Prompt (`cmd.exe`). You will need to set up your environment to use the @@ -126,6 +128,18 @@ Step-by-step Windows build D:\local\cuda\bin ``` + * When building with MKL support after installing [MKL](https://software.intel.com/en-us/mkl) from INTEL, append its bin directories to your PATH environment variable. + + In case TensorFlow fails to find the MKL dll's during initialization, check your PATH environment variable. + It should contain the directory of the MKL dlls. For example: + + ``` + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\mkl + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\compiler + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\tbb\vc_mt + ``` + + * We assume that `cmake` and `git` are installed and in your `%PATH%`. If for example `cmake` is not in your path and it is installed in `C:\Program Files (x86)\CMake\bin\cmake.exe`, you can add this directory @@ -164,7 +178,15 @@ Step-by-step Windows build More? -Dtensorflow_ENABLE_GPU=ON ^ More? -DCUDNN_HOME="D:\...\cudnn" ``` + To build with MKL support add "^" at the end of the last line above following with: + + ``` + More? -Dtensorflow_ENABLE_MKL_SUPPORT=ON ^ + More? -DMKL_HOME="D:\...\compilers_and_libraries" + ``` + To enable SIMD instructions with MSVC, as AVX and SSE, define it as follows: + ``` More? -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX ``` @@ -224,6 +246,7 @@ Step-by-step Windows build ``` ctest -C RelWithDebInfo ``` + * `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables python tests on serveral major packages. This option is only valid if this and tensorflow_BUILD_PYTHON_TESTS are both set as `ON`. After building the python wheel, you need to install the new wheel before running the tests. @@ -232,6 +255,12 @@ Step-by-step Windows build ctest -C RelWithDebInfo ``` + * `-Dtensorflow_ENABLE_MKL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include MKL support. If MKL is enabled you need to install the [Intel Math Kernal Library](https://software.intel.com/en-us/mkl). + CMake will expect the location of MKL in -MKL_HOME=path_you_install_mkl. + + * `-Dtensorflow_ENABLE_MKLDNN_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include MKL DNN support. MKL DNN is [Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN)](https://github.com/intel/mkl-dnn). You have to add `-Dtensorflow_ENABLE_MKL_SUPPORT=ON` before including MKL DNN support. + + 4. Invoke MSBuild to build TensorFlow. To build the C++ example program, which will be created as a `.exe` @@ -249,6 +278,7 @@ Step-by-step Windows build D:\...\build> MSBuild /p:Configuration=Release tf_python_build_pip_package.vcxproj ``` + Linux Continuous Integration build ================================== diff --git a/tensorflow/contrib/cmake/external/cub.cmake b/tensorflow/contrib/cmake/external/cub.cmake index 836889895567f679d9960e29ece1600d1a7a58eb..98a8c7e736e5c8c407b90e8eac440cdc7ab21579 100644 --- a/tensorflow/contrib/cmake/external/cub.cmake +++ b/tensorflow/contrib/cmake/external/cub.cmake @@ -14,8 +14,8 @@ # ============================================================================== include (ExternalProject) -set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip) -set(cub_HASH SHA256=20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31) +set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip) +set(cub_HASH SHA256=6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3) set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_ARCHIVE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/cub_archive) diff --git a/tensorflow/contrib/cmake/external/double_conversion.cmake b/tensorflow/contrib/cmake/external/double_conversion.cmake new file mode 100644 index 0000000000000000000000000000000000000000..5c5adaf5798289fba1c5d0b3f9e0489dc242043e --- /dev/null +++ b/tensorflow/contrib/cmake/external/double_conversion.cmake @@ -0,0 +1,54 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +include (ExternalProject) + +set(double_conversion_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/double_conversion/src/double_conversion) +set(double_conversion_URL https://github.com/google/double-conversion.git) +set(double_conversion_TAG 3992066a95b823efc8ccc1baf82a1cfc73f6e9b8) +set(double_conversion_BUILD ${double_conversion_INCLUDE_DIR}) +set(double_conversion_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.so) +set(double_conversion_INCLUDES ${double_conversion_BUILD}) + +if(WIN32) + set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/$(Configuration)/double-conversion.lib) +else() + set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/libdouble-conversion.a) +endif() + +set(double_conversion_HEADERS + "${double_conversion_INCLUDE_DIR}/double-conversion/bignum-dtoa.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/cached-powers.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/double-conversion.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/fixed-dtoa.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/strtod.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/bignum.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/diy-fp.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/fast-dtoa.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/ieee.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/utils.h" +) + +ExternalProject_Add(double_conversion + PREFIX double_conversion + GIT_REPOSITORY ${double_conversion_URL} + GIT_TAG ${double_conversion_TAG} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + BUILD_IN_SOURCE 1 + INSTALL_COMMAND "" + CMAKE_CACHE_ARGS + -DCMAKE_BUILD_TYPE:STRING=Release + -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON +) diff --git a/tensorflow/contrib/cmake/external/gemmlowp.cmake b/tensorflow/contrib/cmake/external/gemmlowp.cmake index a235442dc5c0a07e249653381436eeae81575883..cdaa6b73b93666d272faacb869e8272561a2c74c 100644 --- a/tensorflow/contrib/cmake/external/gemmlowp.cmake +++ b/tensorflow/contrib/cmake/external/gemmlowp.cmake @@ -14,8 +14,8 @@ # ============================================================================== include (ExternalProject) -set(gemmlowp_URL https://github.com/google/gemmlowp/archive/6a2a90822e8546fc2bfa7044de0faf1c1cb4862f.zip) -set(gemmlowp_HASH SHA256=3447948d219f3270383766bbe08942888c0eb4e0ca6663c0e0548502ec5bb77d) +set(gemmlowp_URL https://github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip) +set(gemmlowp_HASH SHA256=b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658) set(gemmlowp_BUILD ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) set(gemmlowp_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index a9f43a3ecba4830533efcc13f8c4c1c61fe1ef78..b1e64aa55c80ad59cfdc0f4767c0282b4f73367f 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,9 +17,13 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG 730b778632e79cc3c96ad237f282d687ee325ce7) +set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f) if(WIN32) + # We use unsecure gRPC because boringssl does not build on windows + set(grpc_TARGET grpc++_unsecure) + set(grpc_DEPENDS protobuf zlib) + set(grpc_SSL_PROVIDER NONE) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib @@ -32,9 +36,14 @@ if(WIN32) ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/gpr.lib) endif() else() + set(grpc_TARGET grpc++) + set(grpc_DEPENDS boringssl protobuf zlib) + set(grpc_SSL_PROVIDER module) set(grpc_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a) endif() @@ -42,13 +51,13 @@ add_definitions(-DGRPC_ARES=0) ExternalProject_Add(grpc PREFIX grpc - DEPENDS protobuf zlib + DEPENDS ${grpc_DEPENDS} GIT_REPOSITORY ${GRPC_URL} GIT_TAG ${GRPC_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${grpc_STATIC_LIBRARIES} - BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc++_unsecure + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target ${grpc_TARGET} COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc_cpp_plugin INSTALL_COMMAND "" CMAKE_CACHE_ARGS @@ -57,7 +66,7 @@ ExternalProject_Add(grpc -DPROTOBUF_INCLUDE_DIRS:STRING=${PROTOBUF_INCLUDE_DIRS} -DPROTOBUF_LIBRARIES:STRING=${protobuf_STATIC_LIBRARIES} -DZLIB_ROOT:STRING=${ZLIB_INSTALL} - -DgRPC_SSL_PROVIDER:STRING=NONE + -DgRPC_SSL_PROVIDER:STRING=${grpc_SSL_PROVIDER} ) # grpc/src/core/ext/census/tracing.c depends on the existence of openssl/rand.h. diff --git a/tensorflow/contrib/cmake/external/mkldnn.cmake b/tensorflow/contrib/cmake/external/mkldnn.cmake new file mode 100644 index 0000000000000000000000000000000000000000..a639fdee367f060d4c8a79267803da6ffe3dc503 --- /dev/null +++ b/tensorflow/contrib/cmake/external/mkldnn.cmake @@ -0,0 +1,44 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +include (ExternalProject) + +set(mkldnn_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/include) +set(mkldnn_URL https://github.com/01org/mkl-dnn.git) +set(mkldnn_BUILD ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src) +set(mkldnn_TAG 3063b2e4c943983f6bf5f2fb9a490d4a998cd291) + +if(WIN32) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.lib) + else() + set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.lib) + endif() +else() + set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/libmkldnn.a) +endif() + +ExternalProject_Add(mkldnn + PREFIX mkldnn + GIT_REPOSITORY ${mkldnn_URL} + GIT_TAG ${mkldnn_TAG} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${mkldnn_STATIC_LIBRARIES} + INSTALL_COMMAND "" + CMAKE_CACHE_ARGS + -DCMAKE_BUILD_TYPE:STRING=Release + -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + -DMKLINC:STRING=${MKL_INCLUDE_DIRS} +) diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index f3a37ff5088e3f9e54e38c0edb5777c27b26969f..b9d1dd88d4c2d3c9141ba56e14911e06b4d33f7c 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 8502189abfa44c249c01c2cad64e6ed660a9a668) +set(nsync_TAG 0559ce013feac8db639ee1bf776aca0325d28777) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index 6cd66a65990e7a2b963b52b310061b551752cd4d..ad2af01bc002555ce48f8b9bfb7d8d724a1a7dc8 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -15,32 +15,33 @@ include (ExternalProject) set(png_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/png_archive) -set(png_URL https://storage.googleapis.com/libpng-public-archive/libpng-1.2.53.tar.gz) -set(png_HASH SHA256=e05c9056d7f323088fd7824d8c6acc03a4a758c4b4916715924edc5dd3223a72) +set(png_URL https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz) +set(png_HASH SHA256=e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef) set(png_BUILD ${CMAKE_BINARY_DIR}/png/src/png) set(png_INSTALL ${CMAKE_BINARY_DIR}/png/install) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(png_STATIC_LIBRARIES - debug ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_staticd.lib - optimized ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_static.lib) + debug ${CMAKE_BINARY_DIR}/png/install/lib/libpng16_staticd.lib + optimized ${CMAKE_BINARY_DIR}/png/install/lib/libpng16_static.lib) else() if(CMAKE_BUILD_TYPE EQUAL Debug) set(png_STATIC_LIBRARIES - ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_staticd.lib) + ${CMAKE_BINARY_DIR}/png/install/lib/libpng16_staticd.lib) else() set(png_STATIC_LIBRARIES - ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_static.lib) + ${CMAKE_BINARY_DIR}/png/install/lib/libpng16_static.lib) endif() endif() else() - set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng12.a) + set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng16.a) endif() set(png_HEADERS - "${png_INSTALL}/include/libpng12/png.h" - "${png_INSTALL}/include/libpng12/pngconf.h" + "${png_INSTALL}/include/libpng16/png.h" + "${png_INSTALL}/include/libpng16/pngconf.h" + "${png_INSTALL}/include/libpng16/pnglibconf.h" ) ExternalProject_Add(png diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index aba8a5244e17d717293deec6d9b6e8e725ef010e..ab464bc99a43138130bb2758ae28ecef29805c31 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) set(PROTOBUF_URL https://github.com/google/protobuf.git) -set(PROTOBUF_TAG 396336eb961b75f03b25824fe86cf6490fb75e3a) +set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") diff --git a/tensorflow/contrib/cmake/external/sqlite.cmake b/tensorflow/contrib/cmake/external/sqlite.cmake index 57c4ae76517e4d7247093edd5e5bd95a83258d87..7f835d2d519273a6d52d12f92ed585a4ddbeb973 100644 --- a/tensorflow/contrib/cmake/external/sqlite.cmake +++ b/tensorflow/contrib/cmake/external/sqlite.cmake @@ -15,8 +15,8 @@ include (ExternalProject) set(sqlite_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/sqlite) -set(sqlite_URL https://mirror.bazel.build/www.sqlite.org/2017/sqlite-amalgamation-3200000.zip) -set(sqlite_HASH SHA256=208780b3616f9de0aeb50822b7a8f5482f6515193859e91ed61637be6ad74fd4) +set(sqlite_URL https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3230100.zip) +set(sqlite_HASH SHA256=4239a1f69e5721d07d9a374eb84d594225229e54be4ee628da2995f4315d8dfc) set(sqlite_BUILD ${CMAKE_CURRENT_BINARY_DIR}/sqlite/src/sqlite) set(sqlite_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/sqlite/install) diff --git a/tensorflow/contrib/cmake/external/zlib.cmake b/tensorflow/contrib/cmake/external/zlib.cmake index 116d42309394b92407cef79c9d3a975f494bc3ff..8942f3eecf07fff893884795a104422529357bf8 100644 --- a/tensorflow/contrib/cmake/external/zlib.cmake +++ b/tensorflow/contrib/cmake/external/zlib.cmake @@ -31,7 +31,8 @@ else (systemlib_ZLIB) set(ZLIB_URL https://github.com/madler/zlib) set(ZLIB_BUILD ${CMAKE_CURRENT_BINARY_DIR}/zlib/src/zlib) set(ZLIB_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/zlib/install) - set(ZLIB_TAG 50893291621658f355bc5b4d450a8d06a563053d) + # Match zlib version in tensorflow/workspace.bzl + set(ZLIB_TAG v1.2.11) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") diff --git a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt index aaae18a313dd082b428654091c9411600c981ec9..6f059c7225dd0938b758e8f9c28ec36fcff6db4c 100644 --- a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt +++ b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt @@ -42,7 +42,6 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/c++11") add_definitions ("-DNSYNC_USE_CPP11_TIMEPOINT -DNSYNC_ATOMIC_CPP11") set (NSYNC_OS_CPP_SRC - "platform/c++11/src/nsync_semaphore_mutex.cc" "platform/c++11/src/per_thread_waiter.cc" "platform/c++11/src/yield.cc" "platform/c++11/src/time_rep_timespec.cc" @@ -52,6 +51,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/win32") add_compile_options ("/TP") set (NSYNC_OS_SRC + "platform/c++11/src/nsync_semaphore_mutex.cc" "platform/win32/src/clock_gettime.c" "platform/win32/src/pthread_key_win32.cc" ${NSYNC_OS_CPP_SRC} @@ -68,6 +68,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC ${NSYNC_OS_CPP_SRC} + "platform/c++11/src/nsync_semaphore_mutex.cc" "platform/posix/src/clock_gettime.c" "platform/posix/src/nsync_semaphore_mutex.c" ) @@ -75,9 +76,11 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") "platform/posix/src/start_thread.c" ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "LinuxX") + include_directories (BEFORE "${PROJECT_SOURCE_DIR}/platform/c++11.futex") include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC + "platform/linux/src/nsync_semaphore_futex.c" ${NSYNC_OS_CPP_SRC} ) set (NSYNC_TEST_OS_SRC @@ -87,6 +90,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC + "platform/c++11/src/nsync_semaphore_mutex.cc" ${NSYNC_OS_CPP_SRC} ) set (NSYNC_TEST_OS_SRC @@ -96,6 +100,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC + "platform/c++11/src/nsync_semaphore_mutex.cc" ${NSYNC_OS_CPP_SRC} ) set (NSYNC_TEST_OS_SRC @@ -105,6 +110,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC + "platform/c++11/src/nsync_semaphore_mutex.cc" ${NSYNC_OS_CPP_SRC} ) set (NSYNC_TEST_OS_SRC diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index bfe53c01b3b5fb9db8a5d8fa280d1d7f98974882..015cb73bbd93bb77f6748a364b263d99eb305c27 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -32,56 +32,19 @@ tensorflow/python/feature_column tensorflow/python/framework tensorflow/python/grappler tensorflow/python/keras -tensorflow/python/keras/activations tensorflow/python/keras/applications -tensorflow/python/keras/applications/densenet -tensorflow/python/keras/applications/inception_resnet_v2 -tensorflow/python/keras/applications/inception_v3 -tensorflow/python/keras/applications/mobilenet -tensorflow/python/keras/applications/nasnet -tensorflow/python/keras/applications/resnet50 -tensorflow/python/keras/applications/vgg16 -tensorflow/python/keras/applications/vgg19 -tensorflow/python/keras/applications/xception -tensorflow/python/keras/backend -tensorflow/python/keras/callbacks -tensorflow/python/keras/constraints tensorflow/python/keras/datasets -tensorflow/python/keras/datasets/boston_housing -tensorflow/python/keras/datasets/cifar10 -tensorflow/python/keras/datasets/cifar100 -tensorflow/python/keras/datasets/fashion_mnist -tensorflow/python/keras/datasets/imdb -tensorflow/python/keras/datasets/mnist -tensorflow/python/keras/datasets/reuters -tensorflow/python/keras/estimator -tensorflow/python/keras/initializers +tensorflow/python/keras/engine tensorflow/python/keras/layers -tensorflow/python/keras/losses -tensorflow/python/keras/metrics -tensorflow/python/keras/models -tensorflow/python/keras/optimizers tensorflow/python/keras/preprocessing -tensorflow/python/keras/preprocessing/image -tensorflow/python/keras/preprocessing/sequence -tensorflow/python/keras/preprocessing/text -tensorflow/python/keras/regularizers tensorflow/python/keras/utils tensorflow/python/keras/wrappers -tensorflow/python/keras/wrappers/scikit_learn -tensorflow/python/keras/_impl -tensorflow/python/keras/_impl/keras -tensorflow/python/keras/_impl/keras/applications -tensorflow/python/keras/_impl/keras/datasets -tensorflow/python/keras/_impl/keras/engine -tensorflow/python/keras/_impl/keras/layers -tensorflow/python/keras/_impl/keras/preprocessing -tensorflow/python/keras/_impl/keras/utils -tensorflow/python/keras/_impl/keras/wrappers tensorflow/python/kernel_tests +tensorflow/python/kernel_tests/boosted_trees tensorflow/python/kernel_tests/distributions tensorflow/python/kernel_tests/linalg tensorflow/python/kernel_tests/random +tensorflow/python/kernel_tests/testdata tensorflow/python/layers tensorflow/python/lib tensorflow/python/lib/core @@ -98,10 +61,13 @@ tensorflow/python/summary tensorflow/python/summary/writer tensorflow/python/tools tensorflow/python/training +tensorflow/python/training/checkpointable tensorflow/python/user_ops tensorflow/python/util tensorflow/python/util/protobuf tensorflow/tools +tensorflow/tools/api +tensorflow/tools/api/generator tensorflow/tools/graph_transforms tensorflow/contrib tensorflow/contrib/all_reduce @@ -125,7 +91,13 @@ tensorflow/contrib/boosted_trees/kernels tensorflow/contrib/boosted_trees/ops tensorflow/contrib/boosted_trees/proto tensorflow/contrib/boosted_trees/python +tensorflow/contrib/boosted_trees/python/kernel_tests tensorflow/contrib/boosted_trees/python/ops +tensorflow/contrib/boosted_trees/python/training +tensorflow/contrib/boosted_trees/python/training/functions +tensorflow/contrib/boosted_trees/python/utils +tensorflow/contrib/checkpoint +tensorflow/contrib/checkpoint/python tensorflow/contrib/cloud tensorflow/contrib/cloud/kernels tensorflow/contrib/cloud/ops @@ -138,8 +110,13 @@ tensorflow/contrib/coder tensorflow/contrib/coder/kernels tensorflow/contrib/coder/ops tensorflow/contrib/coder/python +tensorflow/contrib/coder/python/layers tensorflow/contrib/coder/python/ops tensorflow/contrib/compiler +tensorflow/contrib/constrained_optimization +tensorflow/contrib/constrained_optimization/python +tensorflow/contrib/control_flow +tensorflow/contrib/control_flow/python tensorflow/contrib/copy_graph tensorflow/contrib/copy_graph/python tensorflow/contrib/copy_graph/python/util @@ -147,8 +124,6 @@ tensorflow/contrib/crf tensorflow/contrib/crf/python tensorflow/contrib/crf/python/ops tensorflow/contrib/cudnn_rnn -tensorflow/contrib/cudnn_rnn/kernels -tensorflow/contrib/cudnn_rnn/ops tensorflow/contrib/cudnn_rnn/python tensorflow/contrib/cudnn_rnn/python/layers tensorflow/contrib/cudnn_rnn/python/ops @@ -160,6 +135,9 @@ tensorflow/contrib/data/python/ops tensorflow/contrib/decision_trees tensorflow/contrib/decision_trees/proto tensorflow/contrib/deprecated +tensorflow/contrib/distribute +tensorflow/contrib/distribute/python +tensorflow/contrib/distribute/python/examples tensorflow/contrib/distributions tensorflow/contrib/distributions/python tensorflow/contrib/distributions/python/ops @@ -319,6 +297,8 @@ tensorflow/contrib/metrics tensorflow/contrib/metrics/python tensorflow/contrib/metrics/python/metrics tensorflow/contrib/metrics/python/ops +tensorflow/contrib/mixed_precision +tensorflow/contrib/mixed_precision/python tensorflow/contrib/mpi_collectives/python tensorflow/contrib/mpi_collectives/python/ops tensorflow/contrib/model_pruning @@ -331,6 +311,7 @@ tensorflow/contrib/nccl/kernels tensorflow/contrib/nccl/ops tensorflow/contrib/nccl/python tensorflow/contrib/nccl/python/ops +tensorflow/contrib/nearest_neighbor tensorflow/contrib/nearest_neighbor/kernels tensorflow/contrib/nearest_neighbor/ops tensorflow/contrib/nearest_neighbor/python @@ -341,6 +322,7 @@ tensorflow/contrib/nn/python/ops tensorflow/contrib/opt tensorflow/contrib/opt/python tensorflow/contrib/opt/python/training +tensorflow/contrib/optimizer_v2 tensorflow/contrib/pi_examples tensorflow/contrib/pi_examples/camera tensorflow/contrib/pi_examples/label_image @@ -349,6 +331,9 @@ tensorflow/contrib/periodic_resample tensorflow/contrib/periodic_resample/python tensorflow/contrib/periodic_resample/python/ops tensorflow/contrib/predictor +tensorflow/contrib/proto +tensorflow/contrib/proto/python +tensorflow/contrib/proto/python/ops tensorflow/contrib/quantization tensorflow/contrib/quantization/python tensorflow/contrib/quantize @@ -357,6 +342,10 @@ tensorflow/contrib/receptive_field tensorflow/contrib/receptive_field/python tensorflow/contrib/receptive_field/python/util tensorflow/contrib/receptive_field/python/util/examples +tensorflow/contrib/recurrent +tensorflow/contrib/recurrent/python +tensorflow/contrib/recurrent/python/ops +tensorflow/contrib/recurrent/python/kernel_tests tensorflow/contrib/reduce_slice_ops tensorflow/contrib/reduce_slice_ops/kernels tensorflow/contrib/reduce_slice_ops/ops @@ -377,6 +366,9 @@ tensorflow/contrib/rnn/ops tensorflow/contrib/rnn/python tensorflow/contrib/rnn/python/kernel_tests tensorflow/contrib/rnn/python/ops +tensorflow/contrib/rpc +tensorflow/contrib/rpc/python +tensorflow/contrib/rpc/python/ops tensorflow/contrib/saved_model tensorflow/contrib/saved_model/python tensorflow/contrib/saved_model/python/saved_model diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt index 8a9c406d8b118c10ddcaafb0e4fc242aa79cdb57..cf1ee2ad76f2cc9f58dbe90182a3e17f1edc7ed3 100644 --- a/tensorflow/contrib/cmake/python_protos.txt +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -1,4 +1,5 @@ tensorflow/core +tensorflow/core/kernels/boosted_trees tensorflow/core/profiler tensorflow/python tensorflow/contrib/boosted_trees/proto @@ -10,7 +11,6 @@ tensorflow/contrib/mpi tensorflow/contrib/mpi_collectives tensorflow/contrib/session_bundle tensorflow/contrib/tensor_forest/proto -tensorflow/contrib/tensorboard/graph_explorer/proto tensorflow/contrib/tensorboard/plugins/projector tensorflow/contrib/tensorboard/plugins/trace tensorflow/contrib/tpu/proto diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index c6a15f2ca075c8de96786a580c7ddb89541df5bc..2e0a2fcef4cbdc50f0521296c4a25a864dbd8b77 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -21,9 +21,8 @@ set(tf_c_srcs "${tensorflow_source_dir}/tensorflow/c/c_api_function.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h" + "${tensorflow_source_dir}/tensorflow/c/eager/c_api_debug.cc" "${tensorflow_source_dir}/tensorflow/c/eager/tape.h" - "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc" - "${tensorflow_source_dir}/tensorflow/c/eager/runtime.h" "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc" "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h" "${tensorflow_source_dir}/tensorflow/c/tf_status_helper.cc" @@ -38,13 +37,15 @@ add_dependencies( tf_core_lib tf_protos_cc) -add_library(tf_c_python_api OBJECT - "${tensorflow_source_dir}/tensorflow/c/python_api.cc" - "${tensorflow_source_dir}/tensorflow/c/python_api.h" -) -add_dependencies( - tf_c_python_api - tf_c - tf_core_lib - tf_core_framework - tf_protos_cc) +if(tensorflow_BUILD_PYTHON_BINDINGS) + add_library(tf_c_python_api OBJECT + "${tensorflow_source_dir}/tensorflow/c/python_api.cc" + "${tensorflow_source_dir}/tensorflow/c/python_api.h" + ) + add_dependencies( + tf_c_python_api + tf_c + tf_core_lib + tf_core_framework + tf_protos_cc) +endif() diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake index f73da0b8ab18af1eca4c2bd577604595f8b8ec6d..6c90cf398c69c8c1b22ea75e0c407f258e2535f9 100644 --- a/tensorflow/contrib/cmake/tf_cc_ops.cmake +++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake @@ -155,7 +155,7 @@ if (WIN32) set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib") endif() else (WIN32) - set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so") + set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal${CMAKE_SHARED_LIBRARY_SUFFIX}") endif (WIN32) add_custom_target(tf_extension_ops) diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index 96ac60d095dbc84470ff1be92f4bf52bb420fc52..a54cbff33b66d63d7229fa2f50b8a4ca962111ed 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -63,6 +63,12 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" "${tensorflow_source_dir}/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc" ) +file(GLOB_RECURSE tf_core_cpu_whitelisted_srcs + "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_id.h" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_id.cc" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc" +) +list(REMOVE_ITEM tf_core_cpu_exclude_srcs ${tf_core_cpu_whitelisted_srcs}) list(REMOVE_ITEM tf_core_cpu_srcs ${tf_core_cpu_exclude_srcs}) if (tensorflow_ENABLE_GPU) @@ -79,6 +85,7 @@ if (tensorflow_ENABLE_GPU) "${tensorflow_source_dir}/tensorflow/core/*test*.cc" ) list(REMOVE_ITEM tf_core_gpu_srcs ${tf_core_gpu_exclude_srcs}) + list(REMOVE_ITEM tf_core_gpu_srcs ${tf_core_cpu_whitelisted_srcs}) list(APPEND tf_core_cpu_srcs ${tf_core_gpu_srcs}) endif() diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index a1c320347fe60f87806736befc677541a93e7e93..dac84ccb0dbf4848329e35a6e9bcf6213d8c0e55 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -213,10 +213,6 @@ else() list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_srcs_exclude}) endif() -file(GLOB tf_core_platform_exclude_srcs - "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.cc") -list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_exclude_srcs}) - list(APPEND tf_core_lib_srcs ${tf_core_platform_srcs}) if(UNIX) @@ -276,7 +272,7 @@ add_custom_command(OUTPUT __force_rebuild COMMAND ${CMAKE_COMMAND} -E echo) add_custom_command(OUTPUT ${VERSION_INFO_CC} COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/git/gen_git_source.py - --raw_generate ${VERSION_INFO_CC} + ARGS --raw_generate ${VERSION_INFO_CC} --source_dir ${tensorflow_source_dir} --git_tag_override=${GIT_TAG_OVERRIDE} DEPENDS __force_rebuild) set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc) @@ -286,8 +282,6 @@ set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.c file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/framework/*.h" "${tensorflow_source_dir}/tensorflow/core/framework/*.cc" - "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.h" - "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.cc" "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.h" "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc" "${tensorflow_source_dir}/tensorflow/core/graph/graph.h" @@ -341,9 +335,3 @@ add_dependencies(tf_core_framework tf_core_lib proto_text ) - -if(WIN32) - # Cmake > 3.6 will quote this as -D"__VERSION__=\"MSVC\"" which nvcc fails on. - # Instead of defining this global, limit it to tf_core_framework where its used. - target_compile_definitions(tf_core_framework PRIVATE __VERSION__="MSVC") -endif() diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index f219d5eb577afa9edaadca09aef9869c81d2bd87..2d76bf530a2100b2afa80a16a5d64b6ec51ffc68 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -63,14 +63,17 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/training_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/csv_dataset_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/unique_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" @@ -176,6 +179,16 @@ if(WIN32) "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc" ) list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_windows_exclude_srcs}) +else(WIN32) + if(tensorflow_ENABLE_GPU) + file(GLOB_RECURSE tf_core_kernels_gpu_exclude_srcs + # temporarily disable nccl as it needs to be ported with gpu + "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc" + "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc" + ) + list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_gpu_exclude_srcs}) + endif(tensorflow_ENABLE_GPU) endif(WIN32) file(GLOB_RECURSE tf_core_gpu_kernels_srcs diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 59e094812aaf4da2549d96314fc550e5635f9de8..bc753333dba4f67eee0114c4022743dd59a05982 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -15,19 +15,23 @@ set(tf_op_lib_names "audio_ops" "array_ops" - "batch_ops" + "batch_ops" "bitwise_ops" + "boosted_trees_ops" "candidate_sampling_ops" "checkpoint_ops" "control_flow_ops" "ctc_ops" + "cudnn_rnn_ops" "data_flow_ops" "dataset_ops" + "decode_proto_ops" + "encode_proto_ops" "functional_ops" "image_ops" "io_ops" "linalg_ops" - "list_ops" + "list_ops" "lookup_ops" "logging_ops" "manip_ops" @@ -38,6 +42,7 @@ set(tf_op_lib_names "random_ops" "remote_fused_graph_ops" "resource_variable_ops" + "rpc_ops" "script_ops" "sdca_ops" "set_ops" @@ -47,7 +52,7 @@ set(tf_op_lib_names "state_ops" "stateless_random_ops" "string_ops" - "summary_ops" + "summary_ops" "training_ops" ) @@ -84,7 +89,6 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(coder "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") @@ -109,6 +113,7 @@ GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_stats "${tensorflow_source_dir}/tensor GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tpu "${tpu_ops_srcs}") GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(gcs_config "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/gcs_config_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(reduce_slice_ops "${tensorflow_source_dir}/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc") ######################################################## diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index b730ebd3baacafe8ae401e8987104f3062372954..92446044892127284ecb8753a250b77cb2a5743a 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -244,13 +244,11 @@ add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD # tf_python_op_gen_main library ######################################################## set(tf_python_op_gen_main_srcs - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.h" - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" - "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" - "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc" ) add_library(tf_python_op_gen_main OBJECT ${tf_python_op_gen_main_srcs}) @@ -319,6 +317,7 @@ GENERATE_PYTHON_OP_LIB("audio_ops") GENERATE_PYTHON_OP_LIB("array_ops") GENERATE_PYTHON_OP_LIB("batch_ops") GENERATE_PYTHON_OP_LIB("bitwise_ops") +GENERATE_PYTHON_OP_LIB("boosted_trees_ops") GENERATE_PYTHON_OP_LIB("math_ops") GENERATE_PYTHON_OP_LIB("functional_ops") GENERATE_PYTHON_OP_LIB("candidate_sampling_ops") @@ -326,8 +325,13 @@ GENERATE_PYTHON_OP_LIB("checkpoint_ops") GENERATE_PYTHON_OP_LIB("control_flow_ops" ADDITIONAL_LIBRARIES $) GENERATE_PYTHON_OP_LIB("ctc_ops") +GENERATE_PYTHON_OP_LIB("cudnn_rnn_ops") GENERATE_PYTHON_OP_LIB("data_flow_ops") GENERATE_PYTHON_OP_LIB("dataset_ops") +GENERATE_PYTHON_OP_LIB("decode_proto_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_decode_proto_op.py) +GENERATE_PYTHON_OP_LIB("encode_proto_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_encode_proto_op.py) GENERATE_PYTHON_OP_LIB("image_ops") GENERATE_PYTHON_OP_LIB("io_ops") GENERATE_PYTHON_OP_LIB("linalg_ops") @@ -341,6 +345,8 @@ GENERATE_PYTHON_OP_LIB("random_ops") GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/remote_fused_graph/pylib/python/ops/gen_remote_fused_graph_ops.py) GENERATE_PYTHON_OP_LIB("resource_variable_ops") +GENERATE_PYTHON_OP_LIB("rpc_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rpc/python/ops/gen_rpc_op.py) GENERATE_PYTHON_OP_LIB("script_ops") GENERATE_PYTHON_OP_LIB("sdca_ops") GENERATE_PYTHON_OP_LIB("set_ops") @@ -348,6 +354,7 @@ GENERATE_PYTHON_OP_LIB("state_ops") GENERATE_PYTHON_OP_LIB("sparse_ops") GENERATE_PYTHON_OP_LIB("spectral_ops") GENERATE_PYTHON_OP_LIB("string_ops") +GENERATE_PYTHON_OP_LIB("summary_ops") GENERATE_PYTHON_OP_LIB("user_ops") GENERATE_PYTHON_OP_LIB("training_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/training/gen_training_ops.py) @@ -366,8 +373,6 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) GENERATE_PYTHON_OP_LIB("contrib_coder_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/coder/python/ops/gen_coder_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py) GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" @@ -415,12 +420,12 @@ GENERATE_PYTHON_OP_LIB("contrib_text_skip_gram_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/text/python/ops/gen_skip_gram_ops.py) GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_gcs_config_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_gcs_config_ops.py) GENERATE_PYTHON_OP_LIB("stateless_random_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py) GENERATE_PYTHON_OP_LIB("debug_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/debug/ops/gen_debug_ops.py) -GENERATE_PYTHON_OP_LIB("summary_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/summary/gen_summary_ops.py) add_custom_target(tf_python_ops SOURCES ${tf_python_ops_generated_files} ${PYTHON_PROTO_GENFILES}) add_dependencies(tf_python_ops tf_python_op_gen_main) @@ -459,12 +464,12 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe_src.cc" "${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.h" "${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.cc" - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.h" - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.cc" "${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.h" "${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/numpy.h" @@ -475,6 +480,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/lib/core/ndarray_tensor_bridge.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.h" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h" @@ -547,12 +554,13 @@ if(WIN32) set(pywrap_tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow.def") endif() set_source_files_properties(${pywrap_tensorflow_deffile} PROPERTIES GENERATED TRUE) - + math(EXPR tensorflow_target_bitness "${CMAKE_SIZEOF_VOID_P}*8") add_custom_command(TARGET pywrap_tensorflow_internal_static POST_BUILD COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/tools/create_def_file.py --input "${pywrap_tensorflow_internal_static_dependencies}" --output "${pywrap_tensorflow_deffile}" --target _pywrap_tensorflow_internal.pyd + --bitness "${tensorflow_target_bitness}" BYPRODUCTS ${pywrap_tensorflow_deffile} # Required for Ninja ) endif(WIN32) @@ -582,6 +590,12 @@ add_library(pywrap_tensorflow_internal SHARED ${pywrap_tensorflow_deffile} ) +# There is a bug in GCC 5 resulting in undefined reference to a __cpu_model function when +# linking to the tensorflow library. Adding the following libraries fixes it. +if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5.0) + target_link_libraries(pywrap_tensorflow_internal PRIVATE gcc_s gcc) +endif() + if(WIN32) add_dependencies(pywrap_tensorflow_internal pywrap_tensorflow_internal_static) endif(WIN32) @@ -685,6 +699,117 @@ AddUserOps(TARGET _beam_search_ops DEPENDS pywrap_tensorflow_internal tf_python_ops DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/seq2seq/python/ops/) +if(WIN32) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + add_custom_command(TARGET pywrap_tensorflow_internal POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.dll + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.lib + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/) + else() + add_custom_command(TARGET pywrap_tensorflow_internal POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.dll + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/) + endif() +else() + add_custom_command(TARGET pywrap_tensorflow_internal POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal${CMAKE_SHARED_LIBRARY_SUFFIX} + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.so) +endif() + + +######################################################## +# Generate API __init__.py files. +######################################################## + +# Parse tensorflow/tools/api/generator/BUILD to get list of generated files. +FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text) +STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text}) +string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "," ";" api_init_files_list ${api_init_files_text}) + +set(api_init_files "") +foreach(api_init_file ${api_init_files_list}) + string(STRIP "${api_init_file}" api_init_file) + if(api_init_file) + string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes + list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/${api_init_file}") + endif() +endforeach(api_init_file) +set(api_init_list_file "${tensorflow_source_dir}/api_init_files_list.txt") +file(WRITE "${api_init_list_file}" "${api_init_files}") + +# Run create_python_api.py to generate __init__.py files. +add_custom_command( + OUTPUT ${api_init_files} + DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops + + # tensorflow/__init__.py depends on files generated in this step. So, remove it while + # this step is running since the files aren't there yet. + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + + # Run create_python_api.py to generate API init files. + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" + "--package=tensorflow.python" + "--apiname=tensorflow" + "${api_init_list_file}" + + COMMENT "Generating __init__.py files for Python API." + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" +) + +add_custom_target(tf_python_api SOURCES ${api_init_files}) +add_dependencies(tf_python_api tf_python_ops) + +# TODO(mikecase): This can be removed once tf.estimator is moved +# out of TensorFlow. +######################################################## +# Generate API __init__.py files for tf.estimator. +######################################################## + +# Parse tensorflow/tools/api/generator/BUILD to get list of generated files. +FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text) +STRING(REGEX MATCH "# BEGIN GENERATED ESTIMATOR FILES.*# END GENERATED ESTIMATOR FILES" api_init_files_text ${api_generator_BUILD_text}) +string(REPLACE "# BEGIN GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "# END GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "," ";" api_init_files_list ${api_init_files_text}) + +set(api_init_files "") +foreach(api_init_file ${api_init_files_list}) + string(STRIP "${api_init_file}" api_init_file) + if(api_init_file) + string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes + list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api/${api_init_file}") + endif() +endforeach(api_init_file) +set(estimator_api_init_list_file "${tensorflow_source_dir}/estimator_api_init_files_list.txt") +file(WRITE "${estimator_api_init_list_file}" "${api_init_files}") + +# Run create_python_api.py to generate __init__.py files. +add_custom_command( + OUTPUT ${api_init_files} + DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops + + # Run create_python_api.py to generate API init files. + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api" + "--package=tensorflow.python.estimator" + "--apiname=estimator" + "${estimator_api_init_list_file}" + + COMMENT "Generating __init__.py files for Python API." + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" +) + +add_custom_target(estimator_python_api SOURCES ${api_init_files}) +add_dependencies(estimator_python_api tf_python_ops) ############################################################ # Build a PIP package containing the TensorFlow runtime. ############################################################ @@ -694,6 +819,8 @@ add_dependencies(tf_python_build_pip_package tf_python_copy_scripts_to_destination tf_python_touchup_modules tf_python_ops + tf_python_api + estimator_python_api tf_extension_ops) # Fix-up Python files that were not included by the add_python_module() macros. @@ -705,26 +832,6 @@ add_custom_command(TARGET tf_python_build_pip_package POST_BUILD add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/contrib/testing/python/framework/util_test.py ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/testing/python/framework/) - -if(WIN32) - if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") - add_custom_command(TARGET tf_python_build_pip_package POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.dll - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.lib - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/) - else() - add_custom_command(TARGET tf_python_build_pip_package POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.dll - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/) - endif() -else() - add_custom_command(TARGET tf_python_build_pip_package POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.so) -endif() add_custom_command(TARGET tf_python_build_pip_package POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/tools/pip_package/README ${CMAKE_CURRENT_BINARY_DIR}/tf_python/) diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 6d36d5fc5c2854b2d7d2542a3cb12e033e193b88..38f40452b533fdc0dba6ac686a0ff43a2ef13cb8 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -52,12 +52,13 @@ if(WIN32) set(tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/tensorflow.def") endif() set_source_files_properties(${tensorflow_deffile} PROPERTIES GENERATED TRUE) - + math(EXPR tensorflow_target_bitness "${CMAKE_SIZEOF_VOID_P}*8") add_custom_command(TARGET tensorflow_static POST_BUILD COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/tools/create_def_file.py --input "${tensorflow_static_dependencies}" --output "${tensorflow_deffile}" --target tensorflow.dll + --bitness "${tensorflow_target_bitness}" ) endif(WIN32) @@ -100,8 +101,7 @@ if(WIN32) endif(WIN32) target_include_directories(tensorflow PUBLIC - $ - $) + $) install(TARGETS tensorflow EXPORT tensorflow_export RUNTIME DESTINATION bin @@ -133,10 +133,6 @@ install(DIRECTORY ${tensorflow_source_dir}/tensorflow/stream_executor/ install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src/google/ DESTINATION include/google FILES_MATCHING PATTERN "*.h") -# nsync headers -install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/ - DESTINATION include/external/nsync - FILES_MATCHING PATTERN "*.h") # Eigen directory install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/Eigen/ DESTINATION include/Eigen) diff --git a/tensorflow/contrib/cmake/tf_stream_executor.cmake b/tensorflow/contrib/cmake/tf_stream_executor.cmake index 91ca33f4c4d5f6c822f45b0676e6e46d2e4c2860..9a37b681194d4ef82b27a0160dd969f733ecad67 100644 --- a/tensorflow/contrib/cmake/tf_stream_executor.cmake +++ b/tensorflow/contrib/cmake/tf_stream_executor.cmake @@ -64,7 +64,15 @@ file(GLOB tf_stream_executor_srcs if (tensorflow_ENABLE_GPU) file(GLOB tf_stream_executor_gpu_srcs "${tensorflow_source_dir}/tensorflow/stream_executor/cuda/*.cc" + "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.h" + "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.cc" ) + if (NOT tensorflow_BUILD_CC_TESTS) + file(GLOB tf_stream_executor_gpu_tests + "${tensorflow_source_dir}/tensorflow/stream_executor/cuda/*_test.cc" + ) + list(REMOVE_ITEM tf_stream_executor_gpu_srcs ${tf_stream_executor_gpu_tests}) + endif() list(APPEND tf_stream_executor_srcs ${tf_stream_executor_gpu_srcs}) endif() diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 1c4ebd7f0c1113bcd0857fb0858df2248499f920..eb9482dc25f2be8ce46cc38bf3dd28889b09a9d4 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -195,9 +195,11 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/profiler/model_analyzer_test.py" # Fails because uses data dependencies with bazel "${tensorflow_source_dir}/tensorflow/python/saved_model/saved_model_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py" # requires scipy "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py" # Takes very long to run without sharding (defined in bazel build file). "${tensorflow_source_dir}/tensorflow/python/kernel_tests/cwise_ops_test.py" # Loading resources in contrib doesn't seem to work on Windows @@ -208,6 +210,13 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py" # Test is flaky on Windows GPU builds (b/38283730). "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/gmm_test.py" + # Disable following manual tag in BUILD. + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py" + # These tests depend on a .so file + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/duplicate_op_test.py + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/invalid_op_test.py + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/ackermann_test.py + ) if (WIN32) set(tf_test_src_py_exclude @@ -222,6 +231,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/debug/cli/curses_ui_test.py" # TFDBG grpc:// mode is not yet available on Windows. "${tensorflow_source_dir}/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py" + "${tensorflow_source_dir}/tensorflow/python/debug/lib/grpc_large_data_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/source_remote_test.py" # stl on windows handles overflows different @@ -261,6 +271,8 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/kernel_tests/variable_scope_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/py_func_test.py" + # Flaky on Windows cpu with py36 (b/73556968) + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/sparse_reshape_op_test.py" # Windows file management related issues. "${tensorflow_source_dir}/tensorflow/python/training/evaluation_test.py" # training tests @@ -278,6 +290,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py" # Deadlocks "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561 + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py" # Segfaults on Windows. # tensor_forest tests (also note that we exclude the hybrid tests for now) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order. "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order. @@ -475,6 +488,10 @@ if (tensorflow_BUILD_CC_TESTS) "${tensorflow_source_dir}/tensorflow/core/profiler/internal/advisor/*_test.cc" ) + list(REMOVE_ITEM tf_test_src_simple + ${tf_core_profiler_test_srcs} + ) + set(tf_test_lib tf_test_lib) add_library(${tf_test_lib} STATIC ${tf_src_testlib}) diff --git a/tensorflow/contrib/cmake/tools/create_def_file.py b/tensorflow/contrib/cmake/tools/create_def_file.py index 53c2285699a6ca94e1e6b147080338b507f4d768..4f957f1e0b46fde5daacbc59657af994e13c42d5 100644 --- a/tensorflow/contrib/cmake/tools/create_def_file.py +++ b/tensorflow/contrib/cmake/tools/create_def_file.py @@ -44,7 +44,8 @@ UNDNAME = "undname.exe" DUMPBIN = "dumpbin.exe" # Exclude if matched -EXCLUDE_RE = re.compile(r"RTTI|deleting destructor|::internal::") +EXCLUDE_RE = re.compile(r"RTTI|deleting destructor|::internal::|Internal|" + r"python_op_gen_internal|grappler") # Include if matched before exclude INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|" @@ -56,6 +57,10 @@ INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|" r"tensorflow::ops::internal::Enter|" r"tensorflow::strings::internal::AppendPieces|" r"tensorflow::strings::internal::CatPieces|" + r"tensorflow::errors::Internal|" + r"tensorflow::Tensor::CopyFromInternal|" + r"tensorflow::kernel_factory::" + r"OpKernelRegistrar::InitInternal|" r"tensorflow::io::internal::JoinPathImpl") # Include if matched after exclude @@ -63,8 +68,8 @@ INCLUDE_RE = re.compile(r"^(TF_\w*)$|" r"^(TFE_\w*)$|" r"tensorflow::|" r"functor::|" - r"nsync_|" - r"perftools::gputools") + r"\?nsync_|" + r"stream_executor::") # We want to identify data members explicitly in the DEF file, so that no one # can implicitly link against the DLL if they use one of the variables exported @@ -87,6 +92,7 @@ def get_args(): required=True) parser.add_argument("--output", help="output deffile", required=True) parser.add_argument("--target", help="name of the target", required=True) + parser.add_argument("--bitness", help="build target bitness", required=True) args = parser.parse_args() return args @@ -125,7 +131,10 @@ def main(): # Header for the def file. def_fp.write("LIBRARY " + args.target + "\n") def_fp.write("EXPORTS\n") - def_fp.write("\t ??1OpDef@tensorflow@@UEAA@XZ\n") + if args.bitness == "64": + def_fp.write("\t??1OpDef@tensorflow@@UEAA@XZ\n") + else: + def_fp.write("\t??1OpDef@tensorflow@@UAE@XZ\n") # Each symbols returned by undname matches the same position in candidates. # We compare on undname but use the decorated name from candidates. diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD index ec3d550b70d2aaa23b989c44f3d86fa87cffb335..a2c6e413039ee3b5af3cb53d1af3325037536d36 100644 --- a/tensorflow/contrib/coder/BUILD +++ b/tensorflow/contrib/coder/BUILD @@ -1,5 +1,5 @@ # Description: -# Contains entropy coding related modules. +# Contains tools related to data compression. package(default_visibility = [ "//learning/brain:__subpackages__", @@ -54,19 +54,27 @@ tf_gen_op_libs( ], ) +cc_library( + name = "range_coder_ops_util", + srcs = ["kernels/range_coder_ops_util.cc"], + hdrs = ["kernels/range_coder_ops_util.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + tf_kernel_library( name = "range_coder_ops", srcs = [ "kernels/range_coder_ops.cc", - "kernels/range_coder_ops_util.cc", - ], - hdrs = [ - "kernels/range_coder_ops_util.h", ], visibility = ["//visibility:public"], deps = [ ":coder_ops_op_lib", ":range_coder", + ":range_coder_ops_util", "//tensorflow/core:framework", "//tensorflow/core:lib", ], @@ -92,6 +100,34 @@ tf_cc_test( ], ) +tf_kernel_library( + name = "pmf_to_cdf_op", + srcs = ["kernels/pmf_to_cdf_op.cc"], + visibility = ["//visibility:public"], + deps = [ + ":coder_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "pmf_to_cdf_op_test", + size = "small", + srcs = ["kernels/pmf_to_cdf_op_test.cc"], + deps = [ + ":pmf_to_cdf_op", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + ], +) + cc_library( name = "all_ops", deps = [":coder_ops_op_lib"], @@ -99,12 +135,16 @@ cc_library( cc_library( name = "all_kernels", - deps = [":range_coder_ops"], + deps = [ + ":pmf_to_cdf_op", + ":range_coder_ops", + ], ) tf_custom_op_library( name = "python/ops/_coder_ops.so", srcs = [ + "kernels/pmf_to_cdf_op.cc", "kernels/range_coder.cc", "kernels/range_coder.h", "kernels/range_coder_ops.cc", @@ -120,10 +160,21 @@ tf_gen_op_wrapper_py( deps = [":coder_ops_op_lib"], ) +py_library( + name = "coder_py", + srcs = [ + "__init__.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":coder_ops_py", + ":entropybottleneck_py", + ], +) + tf_custom_op_py_library( name = "coder_ops_py", srcs = [ - "__init__.py", "python/ops/coder_ops.py", ], dso = [ @@ -155,13 +206,43 @@ tf_py_test( main = "python/ops/coder_ops_test.py", ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), +py_library( + name = "entropybottleneck_py", + srcs = [ + "python/layers/entropybottleneck.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":coder_ops_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:functional_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:summary_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/keras:engine", + "//third_party/py/numpy", + ], +) + +tf_py_test( + name = "entropybottleneck_py_test", + srcs = [ + "python/layers/entropybottleneck_test.py", + ], + additional_deps = [ + ":entropybottleneck_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:variables", + "//tensorflow/python:training", + ], + main = "python/layers/entropybottleneck_test.py", ) diff --git a/tensorflow/contrib/coder/__init__.py b/tensorflow/contrib/coder/__init__.py index b7e663e6f1359f399cdaa80e037635a8f7546b37..99b8ac7595ec632b2918e6b7ca22c06dd7f0a8b3 100644 --- a/tensorflow/contrib/coder/__init__.py +++ b/tensorflow/contrib/coder/__init__.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Entropy code operations.""" +"""Data compression tools.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import +from tensorflow.contrib.coder.python.layers.entropybottleneck import * from tensorflow.contrib.coder.python.ops.coder_ops import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..bd5272ee6f20ac3537a2e378225ede5ee90782c5 --- /dev/null +++ b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc @@ -0,0 +1,196 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { +using errors::InvalidArgument; + +class PmfToCdfOp : public OpKernel { + public: + explicit PmfToCdfOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_)); + OP_REQUIRES( + context, 0 < precision_ && precision_ <= 16, + InvalidArgument("`precision` must be in [1, 16]: ", precision_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& pmf_tensor = context->input(0); + + TensorShape shape = pmf_tensor.shape(); + OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(shape), + InvalidArgument("`pmf` should be at least 1-D.")); + OP_REQUIRES( + context, shape.dim_size(shape.dims() - 1) > 1, + InvalidArgument("`pmf` size should be at least 2 in the last axis.")); + shape.set_dim(shape.dims() - 1, shape.dim_size(shape.dims() - 1) + 1); + + Tensor* cdf_tensor; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &cdf_tensor)); + + auto pmf = pmf_tensor.flat_inner_dims(); + auto cdf = cdf_tensor->flat_inner_dims(); + CHECK_EQ(pmf.dimension(0), cdf.dimension(0)); + CHECK_EQ(pmf.dimension(1) + 1, cdf.dimension(1)); + + const double n = pmf.dimension(1); + const int64 cost_per_unit = static_cast(50.0 * n * std::log2(n)); + thread::ThreadPool* thread_pool = + context->device()->tensorflow_cpu_worker_threads()->workers; + thread_pool->ParallelFor( + pmf.dimension(0), cost_per_unit, + [this, pmf, &cdf](int64 start, int64 limit) { + const gtl::ArraySlice::size_type pmf_size = pmf.dimension(1); + for (int64 i = start; i < limit; ++i) { + cdf(i, 0) = 0; + PerShard({&pmf(i, 0), pmf_size}, {&cdf(i, 1), pmf_size}); + } + }); + } + + private: + struct PenaltyItem { + PenaltyItem(int32* p, double mass) : pointer(p), mass(mass) { + penalty = ComputeNextPenalty(); + } + + void Decrease() { + CHECK_GT(*pointer, 1); + --*pointer; + penalty = ComputeNextPenalty(); + } + + friend bool operator<(const PenaltyItem& lhs, const PenaltyItem& rhs) { + return lhs.penalty < rhs.penalty; + } + + double ComputeNextPenalty() { + if (*pointer <= 1) { + return std::numeric_limits::infinity(); + } + return mass * (std::log2(*pointer) - std::log2(*pointer - 1)); + } + + int32* pointer; + double mass; + double penalty; + }; + + struct GainItem { + GainItem(int32* p, double mass) : pointer(p), mass(mass) { + gain = ComputeNextGain(); + } + + void Increase() { + CHECK_GT(*pointer, 0); + ++*pointer; + gain = ComputeNextGain(); + } + + friend bool operator>(const GainItem& lhs, const GainItem& rhs) { + return lhs.gain > rhs.gain; + } + + double ComputeNextGain() { + // Never increment zero value to non-zero value. + if (*pointer < 1) { + return -std::numeric_limits::infinity(); + } + return mass * (std::log2(*pointer + 1) - std::log2(*pointer)); + } + + int32* pointer; + double mass; + double gain; + }; + + void PerShard(gtl::ArraySlice pmf, + gtl::MutableArraySlice cdf) const { + CHECK_EQ(pmf.size(), cdf.size()); + + const int32 normalizer = 1 << precision_; + std::transform(pmf.begin(), pmf.end(), cdf.begin(), + [normalizer](float mass) { + int32 value = std::rint(mass * normalizer); + // NOTE: Consider checking if mass > 0. + value = std::max(value, 1); + return value; + }); + + int32 sum = std::accumulate(cdf.begin(), cdf.end(), 0); + if (sum > normalizer) { + std::vector queue; + queue.reserve(cdf.size()); + for (int i = 0; i < cdf.size(); ++i) { + queue.emplace_back(&cdf[i], pmf[i]); + } + + std::sort(queue.begin(), queue.end()); + while (sum-- > normalizer) { + queue[0].Decrease(); + // Performs a linear search because this find_if is likely to return + // iterator very close to the begin. + auto iter = std::find_if( + std::next(queue.begin()), queue.end(), + [&queue](const PenaltyItem& rhs) { return queue[0] < rhs; }); + std::rotate(queue.begin(), std::next(queue.begin()), iter); + } + } else if (sum < normalizer) { + std::vector queue; + queue.reserve(cdf.size()); + for (int i = 0; i < cdf.size(); ++i) { + queue.emplace_back(&cdf[i], pmf[i]); + } + + std::sort(queue.begin(), queue.end(), std::greater()); + while (sum++ < normalizer) { + queue[0].Increase(); + // Performs a linear search because this find_if is likely to return + // iterator very close to the begin. + auto iter = std::find_if( + std::next(queue.begin()), queue.end(), + [&queue](const GainItem& rhs) { return queue[0] > rhs; }); + std::rotate(queue.begin(), std::next(queue.begin()), iter); + } + } + std::partial_sum(cdf.begin(), cdf.end(), cdf.begin()); + } + + int precision_; +}; + +REGISTER_KERNEL_BUILDER(Name("PmfToQuantizedCdf").Device(DEVICE_CPU), + PmfToCdfOp); +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3408f6b519a33fbb8f23d19c16bc7138fc34c121 --- /dev/null +++ b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc @@ -0,0 +1,142 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +class PmfToQuantizedCdfOpTest : public OpsTestBase { + protected: + void SetupOp(int precision, Tensor* input) { + TF_ASSERT_OK(NodeDefBuilder("pmf_to_cdf", "PmfToQuantizedCdf") + .Input(FakeInput(DT_FLOAT)) + .Attr("precision", precision) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + + inputs_.clear(); + inputs_.emplace_back(input); + } + + void GenerateData(random::SimplePhilox* rand, + gtl::MutableArraySlice slice) { + constexpr float minimum = std::numeric_limits::epsilon(); + float sum = 0; + for (float& value : slice) { + value = std::max(rand->RandFloat(), minimum); + sum += value; + } + for (float& value : slice) { + value /= sum; + } + } + + void Verify(int precision, const Tensor& pmf_tensor, + const Tensor& cdf_tensor) { + ASSERT_EQ(pmf_tensor.dims(), cdf_tensor.dims()); + const int n = pmf_tensor.dims(); + + for (int i = 0; i < n - 1; ++i) { + EXPECT_EQ(pmf_tensor.dim_size(i), cdf_tensor.dim_size(i)); + } + + auto pmf = pmf_tensor.flat_inner_dims(); + auto cdf = cdf_tensor.flat_inner_dims(); + EXPECT_EQ(pmf.dimension(1) + 1, cdf.dimension(1)); + + const int normalizer = 1 << precision; + for (int i = 0; i < pmf.dimension(0); ++i) { + EXPECT_EQ(0, cdf(i, 0)); + + TTypes::UnalignedConstVec cdf_slice(&cdf(i, 0), cdf.dimension(1)); + + for (int j = 1; j < cdf_slice.size(); ++j) { + const int32 diff = cdf_slice(j) - cdf_slice(j - 1); + EXPECT_GT(diff, 0); + } + + EXPECT_EQ(cdf_slice(cdf_slice.size() - 1), normalizer); + } + } +}; + +TEST_F(PmfToQuantizedCdfOpTest, UnderSum) { + Tensor pmf(DT_FLOAT, {1, 10, 1, 32}); + auto matrix = pmf.flat_inner_dims(); + const std::size_t n = matrix.dimension(1); + + random::PhiloxRandom gen(random::New64(), random::New64()); + random::SimplePhilox rand(&gen); + for (int64 i = 0; i < matrix.dimension(0); ++i) { + GenerateData(&rand, {&matrix(i, 0), n}); + } + + pmf.flat() = pmf.flat() * 0.85f; + + constexpr int kPrecision = 10; + SetupOp(kPrecision, &pmf); + TF_ASSERT_OK(RunOpKernel()); + + Verify(kPrecision, pmf, *GetOutput(0)); +} + +TEST_F(PmfToQuantizedCdfOpTest, OverSum) { + Tensor pmf(DT_FLOAT, {10, 1, 1, 100}); + auto matrix = pmf.flat_inner_dims(); + + // Half of each PMF is filled with zeros. The op will round up zeros to ones, + // post quantization. These round ups are likely to make the sum over + // normalizer value. + matrix.setZero(); + const std::size_t n = matrix.dimension(1) / 2; + + random::PhiloxRandom gen(random::New64(), random::New64()); + random::SimplePhilox rand(&gen); + for (int64 i = 0; i < matrix.dimension(0); ++i) { + GenerateData(&rand, {&matrix(i, 0), n}); + } + + constexpr int kPrecision = 7; + SetupOp(kPrecision, &pmf); + TF_ASSERT_OK(RunOpKernel()); + + Verify(kPrecision, pmf, *GetOutput(0)); +} + +TEST_F(PmfToQuantizedCdfOpTest, ShapeFn) { + ShapeInferenceTestOp op("PmfToQuantizedCdf"); + + INFER_OK(op, "?", "?"); + INFER_OK(op, "[3]", "[4]"); + INFER_OK(op, "[3,4]", "[d0_0,5]"); + INFER_OK(op, "[3,4,5]", "[d0_0,d0_1,6]"); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc b/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc index ae4d9d2836a0f89a9765004a85bc3c292b0e484f..81b36ca902b82220d9c5282a1ec72324a6d95922 100644 --- a/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc +++ b/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" diff --git a/tensorflow/contrib/coder/ops/coder_ops.cc b/tensorflow/contrib/coder/ops/coder_ops.cc index 9056d1a6963d7be92f499db31385fb6afe2dc515..a185e07913f84a813d76a8c63741bd22a832c8b9 100644 --- a/tensorflow/contrib/coder/ops/coder_ops.cc +++ b/tensorflow/contrib/coder/ops/coder_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" namespace tensorflow { +using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; @@ -76,7 +77,7 @@ are incorrect. For this reason, the range coder uses integer arithmetics and avoids using any floating point operations internally, and `cdf` should contain integers representing quantized probability mass rather than floating points. -data: An int32 tensor. +data: An int16 tensor. cdf: An int32 tensor representing the CDF's of `data`. Each integer is divided by `2^precision` to represent a fraction. encoded: A range-coded scalar string. @@ -111,9 +112,38 @@ potential performance issues, the decoder does not return error status. encoded: A scalar string tensor from RangeEncode. shape: An int32 1-D tensor representing the shape of the data encoded by RangeEncode. -decoded: An int32 tensor with shape equal to `shape`. +decoded: An int16 tensor with shape equal to `shape`. precision: The number of bits for probability quantization. Must be <= 16, and must match the precision used by RangeEncode that produced `encoded`. )doc"); + +REGISTER_OP("PmfToQuantizedCdf") + .Input("pmf: float") + .Output("cdf: int32") + .Attr("precision: int >= 1") + .SetShapeFn([] (InferenceContext* c) { + ShapeHandle in; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &in)); + DimensionHandle last; + TF_RETURN_IF_ERROR(c->Add(c->Dim(in, -1), 1, &last)); + ShapeHandle out; + TF_RETURN_IF_ERROR(c->ReplaceDim(in, -1, last, &out)); + c->set_output(0, out); + return Status::OK(); + }) + .Doc(R"doc( +Converts PMF to quantized CDF. This op uses floating-point operations +internally. Therefore the quantized output may not be consistent across multiple +platforms. For entropy encoders and decoders to have the same quantized CDF on +different platforms, the quantized CDF should be produced once and saved, then +the saved quantized CDF should be used everywhere. + +After quantization, if PMF does not sum to 2^precision, then some values of PMF +are increased or decreased to adjust the sum to equal to 2^precision. + +Note that the input PMF is pre-quantization. The input PMF is not normalized +by this op prior to quantization. Therefore the user is responsible for +normalizing PMF if necessary. +)doc"); // clang-format on } // namespace tensorflow diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck.py b/tensorflow/contrib/coder/python/layers/entropybottleneck.py new file mode 100644 index 0000000000000000000000000000000000000000..0fbe3081af0b4de7f116918b3f49efe91a2d83bd --- /dev/null +++ b/tensorflow/contrib/coder/python/layers/entropybottleneck.py @@ -0,0 +1,697 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Entropy bottleneck layer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.coder.python.ops import coder_ops + +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import engine +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.summary import summary + + +class EntropyBottleneck(engine.Layer): + """Entropy bottleneck layer. + + This layer can be used to model the entropy (the amount of information + conveyed) of the tensor passing through it. During training, this can be used + to impose a (soft) entropy constraint on its activations, limiting the amount + of information flowing through the layer. Note that this is distinct from + other types of bottlenecks, which reduce the dimensionality of the space, for + example. Dimensionality reduction does not limit the amount of information, + and does not enable efficient data compression per se. + + After training, this layer can be used to compress any input tensor to a + string, which may be written to a file, and to decompress a file which it + previously generated back to a reconstructed tensor (possibly on a different + machine having access to the same model checkpoint). The entropies estimated + during training or evaluation are approximately equal to the average length of + the strings in bits. + + The layer implements a flexible probability density model to estimate entropy, + which is described in the appendix of the paper (please cite the paper if you + use this code for scientific work): + + "Variational image compression with a scale hyperprior" + + Johannes Ballé, David Minnen, Saurabh Singh, Sung Jin Hwang, Nick Johnston + + https://arxiv.org/abs/1802.01436 + + The layer assumes that the input tensor is at least 2D, with a batch dimension + at the beginning and a channel dimension as specified by `data_format`. The + layer trains an independent probability density model for each channel, but + assumes that across all other dimensions, the inputs are i.i.d. (independent + and identically distributed). Because the entropy (and hence, average + codelength) is a function of the densities, this assumption may have a direct + effect on the compression performance. + + Because data compression always involves discretization, the outputs of the + layer are generally only approximations of its inputs. During training, + discretization is modeled using additive uniform noise to ensure + differentiability. The entropies computed during training are differential + entropies. During evaluation, the data is actually quantized, and the + entropies are discrete (Shannon entropies). To make sure the approximated + tensor values are good enough for practical purposes, the training phase must + be used to balance the quality of the approximation with the entropy, by + adding an entropy term to the training loss, as in the following example. + + Here, we use the entropy bottleneck to compress the latent representation of + an autoencoder. The data vectors `x` in this case are 4D tensors in + `'channels_last'` format (for example, 16x16 pixel grayscale images). + + The layer always produces exactly one auxiliary loss and one update op which + are only significant for compression and decompression. To use the compression + feature, the auxiliary loss must be minimized during or after training. After + that, the update op must be executed at least once. Here, we simply attach + them to the main training step. + + Training: + ``` + # Build autoencoder. + x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) + y = forward_transform(x) + entropy_bottleneck = EntropyBottleneck() + y_, likelihoods = entropy_bottleneck(y, training=True) + x_ = backward_transform(y_) + + # Information content (= predicted codelength) in bits of each batch element + # (note that taking the natural logarithm and dividing by `log(2)` is + # equivalent to taking base-2 logarithms): + bits = tf.reduce_sum(tf.log(likelihoods), axis=(1, 2, 3)) / -np.log(2) + + # Squared difference of each batch element: + squared_error = tf.reduce_sum(tf.squared_difference(x, x_), axis=(1, 2, 3)) + + # The loss is a weighted sum of mean squared error and entropy (average + # information content), where the weight controls the trade-off between + # approximation error and entropy. + main_loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits) + + # Minimize loss and auxiliary loss, and execute update op. + main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4) + main_step = optimizer.minimize(main_loss) + # 1e-2 is a good starting point for the learning rate of the auxiliary loss, + # assuming Adam is used. + aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-2) + aux_step = optimizer.minimize(entropy_bottleneck.losses[0]) + step = tf.group(main_step, aux_step, entropy_bottleneck.updates[0]) + ``` + + Evaluation: + ``` + # Build autoencoder. + x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) + y = forward_transform(x) + y_, likelihoods = EntropyBottleneck()(y, training=False) + x_ = backward_transform(y_) + + # Information content (= predicted codelength) in bits of each batch element: + bits = tf.reduce_sum(tf.log(likelihoods), axis=(1, 2, 3)) / -np.log(2) + + # Squared difference of each batch element: + squared_error = tf.reduce_sum(tf.squared_difference(x, x_), axis=(1, 2, 3)) + + # The loss is a weighted sum of mean squared error and entropy (average + # information content), where the weight controls the trade-off between + # approximation error and entropy. + loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits) + ``` + + To be able to compress the bottleneck tensor and decompress it in a different + session, or on a different machine, you need three items: + - The compressed representations stored as strings. + - The shape of the bottleneck for these string representations as a `Tensor`, + as well as the number of channels of the bottleneck at graph construction + time. + - The checkpoint of the trained model that was used for compression. Note: + It is crucial that the auxiliary loss produced by this layer is minimized + during or after training, and that the update op is run after training and + minimization of the auxiliary loss, but *before* the checkpoint is saved. + + Compression: + ``` + x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) + y = forward_transform(x) + strings = EntropyBottleneck().compress(y) + shape = tf.shape(y)[1:] + ``` + + Decompression: + ``` + strings = tf.placeholder(tf.string, shape=[None]) + shape = tf.placeholder(tf.int32, shape=[3]) + entropy_bottleneck = EntropyBottleneck(dtype=tf.float32) + y_ = entropy_bottleneck.decompress(strings, shape, channels=5) + x_ = backward_transform(y_) + ``` + Here, we assumed that the tensor produced by the forward transform has 5 + channels. + + The above four use cases can also be implemented within the same session (i.e. + on the same `EntropyBottleneck` instance), for testing purposes, etc., by + calling the object more than once. + + Arguments: + init_scale: Float. A scaling factor determining the initial width of the + probability densities. This should be chosen big enough so that the + range of values of the layer inputs roughly falls within the interval + [`-init_scale`, `init_scale`] at the beginning of training. + filters: An iterable of ints, giving the number of filters at each layer of + the density model. Generally, the more filters and layers, the more + expressive is the density model in terms of modeling more complicated + distributions of the layer inputs. For details, refer to the paper + referenced above. The default is `[3, 3, 3]`, which should be sufficient + for most practical purposes. + tail_mass: Float, between 0 and 1. The bottleneck layer automatically + determines the range of input values that should be represented based on + their frequency of occurrence. Values occurring in the tails of the + distributions will be clipped to that range during compression. + `tail_mass` determines the amount of probability mass in the tails which + is cut off in the worst case. For example, the default value of `1e-9` + means that at most 1 in a billion input samples will be clipped to the + range. + optimize_integer_offset: Boolean. Typically, the input values of this layer + are floats, which means that quantization during evaluation can be + performed with an arbitrary offset. By default, the layer determines that + offset automatically. In special situations, such as when it is known that + the layer will receive only full integer values during evaluation, it can + be desirable to set this argument to `False` instead, in order to always + quantize to full integer values. + likelihood_bound: Float. If positive, the returned likelihood values are + ensured to be greater than or equal to this value. This prevents very + large gradients with a typical entropy loss (defaults to 1e-9). + range_coder_precision: Integer, between 1 and 16. The precision of the range + coder used for compression and decompression. This trades off computation + speed with compression efficiency, where 16 is the slowest but most + efficient setting. Choosing lower values may increase the average + codelength slightly compared to the estimated entropies. + data_format: Either `'channels_first'` or `'channels_last'` (default). + trainable: Boolean. Whether the layer should be trained. + name: String. The name of the layer. + dtype: Default dtype of the layer's parameters (default of `None` means use + the type of the first input). + + Read-only properties: + init_scale: See above. + filters: See above. + tail_mass: See above. + optimize_integer_offset: See above. + likelihood_bound: See above. + range_coder_precision: See above. + data_format: See above. + name: String. See above. + dtype: See above. + trainable_variables: List of trainable variables. + non_trainable_variables: List of non-trainable variables. + variables: List of all variables of this layer, trainable and non-trainable. + updates: List of update ops of this layer. Always contains exactly one + update op, which must be run once after the last training step, before + `compress` or `decompress` is used. + losses: List of losses added by this layer. Always contains exactly one + auxiliary loss, which must be added to the training loss. + + Mutable properties: + trainable: Boolean. Whether the layer should be trained. + input_spec: Optional `InputSpec` object specifying the constraints on inputs + that can be accepted by the layer. + """ + + def __init__(self, init_scale=10, filters=(3, 3, 3), tail_mass=1e-9, + optimize_integer_offset=True, likelihood_bound=1e-9, + range_coder_precision=16, data_format="channels_last", **kwargs): + super(EntropyBottleneck, self).__init__(**kwargs) + self._init_scale = float(init_scale) + self._filters = tuple(int(f) for f in filters) + self._tail_mass = float(tail_mass) + if not 0 < self.tail_mass < 1: + raise ValueError( + "`tail_mass` must be between 0 and 1, got {}.".format(self.tail_mass)) + self._optimize_integer_offset = bool(optimize_integer_offset) + self._likelihood_bound = float(likelihood_bound) + self._range_coder_precision = int(range_coder_precision) + self._data_format = data_format + self._channel_axis(2) # trigger ValueError early + self.input_spec = engine.InputSpec(min_ndim=2) + + @property + def init_scale(self): + return self._init_scale + + @property + def filters(self): + return self._filters + + @property + def tail_mass(self): + return self._tail_mass + + @property + def optimize_integer_offset(self): + return self._optimize_integer_offset + + @property + def likelihood_bound(self): + return self._likelihood_bound + + @property + def range_coder_precision(self): + return self._range_coder_precision + + @property + def data_format(self): + return self._data_format + + def _channel_axis(self, ndim): + try: + return {"channels_first": 1, "channels_last": ndim - 1}[self.data_format] + except KeyError: + raise ValueError("Unsupported `data_format` for {} layer: {}.".format( + self.__class__.__name__, self.data_format)) + + def _logits_cumulative(self, inputs, stop_gradient): + """Evaluate logits of the cumulative densities. + + Args: + inputs: The values at which to evaluate the cumulative densities, expected + to be a `Tensor` of shape `(channels, 1, batch)`. + stop_gradient: Boolean. Whether to add `array_ops.stop_gradient` calls so + that the gradient of the output with respect to the density model + parameters is disconnected (the gradient with respect to `inputs` is + left untouched). + + Returns: + A `Tensor` of the same shape as `inputs`, containing the logits of the + cumulative densities evaluated at the given inputs. + """ + logits = inputs + + for i in range(len(self.filters) + 1): + matrix = self._matrices[i] + if stop_gradient: + matrix = array_ops.stop_gradient(matrix) + logits = math_ops.matmul(matrix, logits) + + bias = self._biases[i] + if stop_gradient: + bias = array_ops.stop_gradient(bias) + logits += bias + + if i < len(self._factors): + factor = self._factors[i] + if stop_gradient: + factor = array_ops.stop_gradient(factor) + logits += factor * math_ops.tanh(logits) + + return logits + + def build(self, input_shape): + """Builds the layer. + + Creates the variables for the network modeling the densities, creates the + auxiliary loss estimating the median and tail quantiles of the densities, + and then uses that to create the probability mass functions and the update + op that produces the discrete cumulative density functions used by the range + coder. + + Args: + input_shape: Shape of the input tensor, used to get the number of + channels. + + Raises: + ValueError: if `input_shape` doesn't specify the length of the channel + dimension. + """ + input_shape = tensor_shape.TensorShape(input_shape) + channel_axis = self._channel_axis(input_shape.ndims) + channels = input_shape[channel_axis].value + if channels is None: + raise ValueError("The channel dimension of the inputs must be defined.") + self.input_spec = engine.InputSpec( + ndim=input_shape.ndims, axes={channel_axis: channels}) + filters = (1,) + self.filters + (1,) + scale = self.init_scale ** (1 / (len(self.filters) + 1)) + + # Create variables. + self._matrices = [] + self._biases = [] + self._factors = [] + for i in range(len(self.filters) + 1): + init = np.log(np.expm1(1 / scale / filters[i + 1])) + matrix = self.add_variable( + "matrix_{}".format(i), dtype=self.dtype, + shape=(channels, filters[i + 1], filters[i]), + initializer=init_ops.Constant(init)) + matrix = nn.softplus(matrix) + self._matrices.append(matrix) + + bias = self.add_variable( + "bias_{}".format(i), dtype=self.dtype, + shape=(channels, filters[i + 1], 1), + initializer=init_ops.RandomUniform(-.5, .5)) + self._biases.append(bias) + + if i < len(self.filters): + factor = self.add_variable( + "factor_{}".format(i), dtype=self.dtype, + shape=(channels, filters[i + 1], 1), + initializer=init_ops.Zeros()) + factor = math_ops.tanh(factor) + self._factors.append(factor) + + # To figure out what range of the densities to sample, we need to compute + # the quantiles given by `tail_mass / 2` and `1 - tail_mass / 2`. Since we + # can't take inverses of the cumulative directly, we make it an optimization + # problem: + # `quantiles = argmin(|logit(cumulative) - target|)` + # where `target` is `logit(tail_mass / 2)` or `logit(1 - tail_mass / 2)`. + # Taking the logit (inverse of sigmoid) of the cumulative makes the + # representation of the right target more numerically stable. + + # Numerically stable way of computing logits of `tail_mass / 2` + # and `1 - tail_mass / 2`. + target = np.log(2 / self.tail_mass - 1) + # Compute lower and upper tail quantile as well as median. + target = constant_op.constant([-target, 0, target], dtype=self.dtype) + + def quantiles_initializer(shape, dtype=None, partition_info=None): + del partition_info # unused + assert tuple(shape[1:]) == (1, 3) + init = constant_op.constant( + [[[-self.init_scale, 0, self.init_scale]]], dtype=dtype) + return array_ops.tile(init, (shape[0], 1, 1)) + + quantiles = self.add_variable( + "quantiles", shape=(channels, 1, 3), dtype=self.dtype, + initializer=quantiles_initializer) + logits = self._logits_cumulative(quantiles, stop_gradient=True) + loss = math_ops.reduce_sum(abs(logits - target)) + self.add_loss(loss, inputs=None) + + # Save medians for `call`, `compress`, and `decompress`. + self._medians = quantiles[:, :, 1:2] + if not self.optimize_integer_offset: + self._medians = math_ops.round(self._medians) + + # Largest distance observed between lower tail quantile and median, + # or between median and upper tail quantile. + minima = math_ops.reduce_max(self._medians - quantiles[:, :, 0:1]) + maxima = math_ops.reduce_max(quantiles[:, :, 2:3] - self._medians) + minmax = math_ops.maximum(minima, maxima) + minmax = math_ops.ceil(minmax) + minmax = math_ops.maximum(minmax, 1) + + # Sample the density up to `minmax` around the median. + samples = math_ops.range(-minmax, minmax + 1, dtype=self.dtype) + samples += self._medians + + half = constant_op.constant(.5, dtype=self.dtype) + # We strip the sigmoid from the end here, so we can use the special rule + # below to only compute differences in the left tail of the sigmoid. + # This increases numerical stability (see explanation in `call`). + lower = self._logits_cumulative(samples - half, stop_gradient=True) + upper = self._logits_cumulative(samples + half, stop_gradient=True) + # Flip signs if we can move more towards the left tail of the sigmoid. + sign = -math_ops.sign(math_ops.add_n([lower, upper])) + pmf = abs(math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower)) + # Add tail masses to first and last bin of pmf, as we clip values for + # compression, meaning that out-of-range values get mapped to these bins. + pmf = array_ops.concat([ + math_ops.add_n([pmf[:, 0, :1], math_ops.sigmoid(lower[:, 0, :1])]), + pmf[:, 0, 1:-1], + math_ops.add_n([pmf[:, 0, -1:], math_ops.sigmoid(-upper[:, 0, -1:])]), + ], axis=-1) + self._pmf = pmf + + cdf = coder_ops.pmf_to_quantized_cdf( + pmf, precision=self.range_coder_precision) + def cdf_getter(*args, **kwargs): + del args, kwargs # ignored + return variable_scope.get_variable( + "quantized_cdf", dtype=dtypes.int32, initializer=cdf, + trainable=False, validate_shape=False, collections=()) + # Need to provide a fake shape here since add_variable insists on it. + self._quantized_cdf = self.add_variable( + "quantized_cdf", shape=(channels, 1), dtype=dtypes.int32, + getter=cdf_getter, trainable=False) + + update_op = state_ops.assign( + self._quantized_cdf, cdf, validate_shape=False) + self.add_update(update_op, inputs=None) + + super(EntropyBottleneck, self).build(input_shape) + + def call(self, inputs, training): + """Pass a tensor through the bottleneck. + + Args: + inputs: The tensor to be passed through the bottleneck. + training: Boolean. If `True`, returns a differentiable approximation of + the inputs, and their likelihoods under the modeled probability + densities. If `False`, returns the quantized inputs and their + likelihoods under the corresponding probability mass function. These + quantities can't be used for training, as they are not differentiable, + but represent actual compression more closely. + + Returns: + values: `Tensor` with the same shape as `inputs` containing the perturbed + or quantized input values. + likelihood: `Tensor` with the same shape as `inputs` containing the + likelihood of `values` under the modeled probability distributions. + + Raises: + ValueError: if `inputs` has different `dtype` or number of channels than + a previous set of inputs the model was invoked with earlier. + """ + inputs = ops.convert_to_tensor(inputs) + ndim = self.input_spec.ndim + channel_axis = self._channel_axis(ndim) + half = constant_op.constant(.5, dtype=self.dtype) + + # Convert to (channels, 1, batch) format by commuting channels to front + # and then collapsing. + order = list(range(ndim)) + order.pop(channel_axis) + order.insert(0, channel_axis) + values = array_ops.transpose(inputs, order) + shape = array_ops.shape(values) + values = array_ops.reshape(values, (shape[0], 1, -1)) + + # Add noise or quantize. + if training: + noise = random_ops.random_uniform(array_ops.shape(values), -half, half) + values = math_ops.add_n([values, noise]) + elif self.optimize_integer_offset: + values = math_ops.round(values - self._medians) + self._medians + else: + values = math_ops.round(values) + + # Evaluate densities. + # We can use the special rule below to only compute differences in the left + # tail of the sigmoid. This increases numerical stability: sigmoid(x) is 1 + # for large x, 0 for small x. Subtracting two numbers close to 0 can be done + # with much higher precision than subtracting two numbers close to 1. + lower = self._logits_cumulative(values - half, stop_gradient=False) + upper = self._logits_cumulative(values + half, stop_gradient=False) + # Flip signs if we can move more towards the left tail of the sigmoid. + sign = -math_ops.sign(math_ops.add_n([lower, upper])) + sign = array_ops.stop_gradient(sign) + likelihood = abs( + math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower)) + if self.likelihood_bound > 0: + likelihood_bound = constant_op.constant( + self.likelihood_bound, dtype=self.dtype) + # TODO(jballe): Override gradients. + likelihood = math_ops.maximum(likelihood, likelihood_bound) + + # Convert back to input tensor shape. + order = list(range(1, ndim)) + order.insert(channel_axis, 0) + values = array_ops.reshape(values, shape) + values = array_ops.transpose(values, order) + likelihood = array_ops.reshape(likelihood, shape) + likelihood = array_ops.transpose(likelihood, order) + + if not context.executing_eagerly(): + values_shape, likelihood_shape = self.compute_output_shape(inputs.shape) + values.set_shape(values_shape) + likelihood.set_shape(likelihood_shape) + + return values, likelihood + + def compress(self, inputs): + """Compress inputs and store their binary representations into strings. + + Args: + inputs: `Tensor` with values to be compressed. + + Returns: + String `Tensor` vector containing the compressed representation of each + batch element of `inputs`. + """ + with ops.name_scope(self._name_scope()): + inputs = ops.convert_to_tensor(inputs) + if not self.built: + # Check input assumptions set before layer building, e.g. input rank. + self._assert_input_compatibility(inputs) + if self.dtype is None: + self._dtype = inputs.dtype.base_dtype.name + self.build(inputs.shape) + + # Check input assumptions set after layer building, e.g. input shape. + if not context.executing_eagerly(): + self._assert_input_compatibility(inputs) + + ndim = self.input_spec.ndim + channel_axis = self._channel_axis(ndim) + # Tuple of slices for expanding dimensions of tensors below. + slices = ndim * [None] + [slice(None)] + slices[channel_axis] = slice(None) + slices = tuple(slices) + + # Expand dimensions of CDF to input dimensions, keeping the channels along + # the right dimension. + cdf = self._quantized_cdf[slices[1:]] + num_levels = array_ops.shape(cdf)[-1] - 1 + + # Bring inputs to the right range by centering the range on the medians. + half = constant_op.constant(.5, dtype=self.dtype) + medians = array_ops.squeeze(self._medians, [1, 2]) + offsets = (math_ops.cast(num_levels // 2, self.dtype) + half) - medians + # Expand offsets to input dimensions and add to inputs. + values = inputs + offsets[slices[:-1]] + + # Clip to range and cast to integers. Because we have added .5 above, and + # all values are positive, the cast effectively implements rounding. + values = math_ops.maximum(values, half) + values = math_ops.minimum( + values, math_ops.cast(num_levels, self.dtype) - half) + values = math_ops.cast(values, dtypes.int16) + + def loop_body(tensor): + return coder_ops.range_encode( + tensor, cdf, precision=self.range_coder_precision) + strings = functional_ops.map_fn( + loop_body, values, dtype=dtypes.string, back_prop=False) + + if not context.executing_eagerly(): + strings.set_shape(inputs.shape[:1]) + + return strings + + def decompress(self, strings, shape, channels=None): + """Decompress values from their compressed string representations. + + Args: + strings: A string `Tensor` vector containing the compressed data. + shape: A `Tensor` vector of int32 type. Contains the shape of the tensor + to be decompressed, excluding the batch dimension. + channels: Integer. Specifies the number of channels statically. Needs only + be set if the layer hasn't been built yet (i.e., this is the first input + it receives). + + Returns: + The decompressed `Tensor`. Its shape will be equal to `shape` prepended + with the batch dimension from `strings`. + + Raises: + ValueError: If the length of `shape` isn't available at graph construction + time. + """ + with ops.name_scope(self._name_scope()): + strings = ops.convert_to_tensor(strings) + shape = ops.convert_to_tensor(shape) + if self.built: + ndim = self.input_spec.ndim + channel_axis = self._channel_axis(ndim) + if channels is None: + channels = self.input_spec.axes[channel_axis] + else: + if not (shape.shape.is_fully_defined() and shape.shape.ndims == 1): + raise ValueError("`shape` must be a vector with known length.") + ndim = shape.shape[0].value + 1 + channel_axis = self._channel_axis(ndim) + input_shape = ndim * [None] + input_shape[channel_axis] = channels + self.build(input_shape) + + # Tuple of slices for expanding dimensions of tensors below. + slices = ndim * [None] + [slice(None)] + slices[channel_axis] = slice(None) + slices = tuple(slices) + + # Expand dimensions of CDF to input dimensions, keeping the channels along + # the right dimension. + cdf = self._quantized_cdf[slices[1:]] + num_levels = array_ops.shape(cdf)[-1] - 1 + + def loop_body(string): + return coder_ops.range_decode( + string, shape, cdf, precision=self.range_coder_precision) + outputs = functional_ops.map_fn( + loop_body, strings, dtype=dtypes.int16, back_prop=False) + outputs = math_ops.cast(outputs, self.dtype) + + medians = array_ops.squeeze(self._medians, [1, 2]) + offsets = math_ops.cast(num_levels // 2, self.dtype) - medians + outputs -= offsets[slices[:-1]] + + if not context.executing_eagerly(): + outputs_shape = ndim * [None] + outputs_shape[0] = strings.shape[0] + outputs_shape[channel_axis] = channels + outputs.set_shape(outputs_shape) + + return outputs + + def visualize(self): + """Multi-channel visualization of densities as images. + + Creates and returns an image summary visualizing the current probabilty + density estimates. The image contains one row for each channel. Within each + row, the pixel intensities are proportional to probability values, and each + row is centered on the median of the corresponding distribution. + + Returns: + The created image summary. + """ + with ops.name_scope(self._name_scope()): + image = self._pmf + image *= 255 / math_ops.reduce_max(image, axis=1, keepdims=True) + image = math_ops.cast(image + .5, dtypes.uint8) + image = image[None, :, :, None] + return summary.image("pmf", image, max_outputs=1) + + def compute_output_shape(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + return input_shape, input_shape diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py b/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py new file mode 100644 index 0000000000000000000000000000000000000000..798b0234ebcce7df108a0da65d1305502ce0253a --- /dev/null +++ b/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py @@ -0,0 +1,315 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests of EntropyBottleneck class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.coder.python.layers import entropybottleneck + +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent + + +class EntropyBottleneckTest(test.TestCase): + + def test_noise(self): + # Tests that the noise added is uniform noise between -0.5 and 0.5. + inputs = array_ops.placeholder(dtypes.float32, (None, 1)) + layer = entropybottleneck.EntropyBottleneck() + noisy, _ = layer(inputs, training=True) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + values = np.linspace(-50, 50, 100)[:, None] + noisy, = sess.run([noisy], {inputs: values}) + self.assertFalse(np.allclose(values, noisy, rtol=0, atol=.49)) + self.assertAllClose(values, noisy, rtol=0, atol=.5) + + def test_quantization(self): + # Tests that inputs are quantized to full integer values, even after + # quantiles have been updated. + inputs = array_ops.placeholder(dtypes.float32, (None, 1)) + layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=False) + quantized, _ = layer(inputs, training=False) + opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) + self.assertTrue(len(layer.losses) == 1) + step = opt.minimize(layer.losses[0]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + sess.run(step) + values = np.linspace(-50, 50, 100)[:, None] + quantized, = sess.run([quantized], {inputs: values}) + self.assertAllClose(np.around(values), quantized, rtol=0, atol=1e-6) + + def test_quantization_optimized_offset(self): + # Tests that inputs are not quantized to full integer values after quantiles + # have been updated. However, the difference between input and output should + # be between -0.5 and 0.5, and the offset must be consistent. + inputs = array_ops.placeholder(dtypes.float32, (None, 1)) + layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=True) + quantized, _ = layer(inputs, training=False) + opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) + self.assertTrue(len(layer.losses) == 1) + step = opt.minimize(layer.losses[0]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + sess.run(step) + values = np.linspace(-50, 50, 100)[:, None] + quantized, = sess.run([quantized], {inputs: values}) + self.assertAllClose(values, quantized, rtol=0, atol=.5) + diff = np.ravel(np.around(values) - quantized) % 1 + self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6) + self.assertNotEqual(diff[0], 0) + + def test_codec(self): + # Tests that inputs are compressed and decompressed correctly, and quantized + # to full integer values, even after quantiles have been updated. + inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) + layer = entropybottleneck.EntropyBottleneck( + data_format="channels_last", init_scale=60, + optimize_integer_offset=False) + bitstrings = layer.compress(inputs) + decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) + opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) + self.assertTrue(len(layer.losses) == 1) + step = opt.minimize(layer.losses[0]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + sess.run(step) + self.assertTrue(len(layer.updates) == 1) + sess.run(layer.updates[0]) + values = np.linspace(-50, 50, 100)[None, :, None] + decoded, = sess.run([decoded], {inputs: values}) + self.assertAllClose(np.around(values), decoded, rtol=0, atol=1e-6) + + def test_codec_optimized_offset(self): + # Tests that inputs are compressed and decompressed correctly, and not + # quantized to full integer values after quantiles have been updated. + # However, the difference between input and output should be between -0.5 + # and 0.5, and the offset must be consistent. + inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) + layer = entropybottleneck.EntropyBottleneck( + data_format="channels_last", init_scale=60, + optimize_integer_offset=True) + bitstrings = layer.compress(inputs) + decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) + opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) + self.assertTrue(len(layer.losses) == 1) + step = opt.minimize(layer.losses[0]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + sess.run(step) + self.assertTrue(len(layer.updates) == 1) + sess.run(layer.updates[0]) + values = np.linspace(-50, 50, 100)[None, :, None] + decoded, = sess.run([decoded], {inputs: values}) + self.assertAllClose(values, decoded, rtol=0, atol=.5) + diff = np.ravel(np.around(values) - decoded) % 1 + self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6) + self.assertNotEqual(diff[0], 0) + + def test_codec_clipping(self): + # Tests that inputs are compressed and decompressed correctly, and clipped + # to the expected range. + inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) + layer = entropybottleneck.EntropyBottleneck( + data_format="channels_last", init_scale=40) + bitstrings = layer.compress(inputs) + decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + self.assertTrue(len(layer.updates) == 1) + sess.run(layer.updates[0]) + values = np.linspace(-50, 50, 100)[None, :, None] + decoded, = sess.run([decoded], {inputs: values}) + expected = np.clip(np.around(values), -40, 40) + self.assertAllClose(expected, decoded, rtol=0, atol=1e-6) + + def test_channels_last(self): + # Test the layer with more than one channel and multiple input dimensions, + # with the channels in the last dimension. + inputs = array_ops.placeholder(dtypes.float32, (None, None, None, 2)) + layer = entropybottleneck.EntropyBottleneck( + data_format="channels_last", init_scale=50) + noisy, _ = layer(inputs, training=True) + quantized, _ = layer(inputs, training=False) + bitstrings = layer.compress(inputs) + decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + self.assertTrue(len(layer.updates) == 1) + sess.run(layer.updates[0]) + values = 5 * np.random.normal(size=(7, 5, 3, 2)) + noisy, quantized, decoded = sess.run( + [noisy, quantized, decoded], {inputs: values}) + self.assertAllClose(values, noisy, rtol=0, atol=.5) + self.assertAllClose(values, quantized, rtol=0, atol=.5) + self.assertAllClose(values, decoded, rtol=0, atol=.5) + + def test_channels_first(self): + # Test the layer with more than one channel and multiple input dimensions, + # with the channel dimension right after the batch dimension. + inputs = array_ops.placeholder(dtypes.float32, (None, 3, None, None)) + layer = entropybottleneck.EntropyBottleneck( + data_format="channels_first", init_scale=50) + noisy, _ = layer(inputs, training=True) + quantized, _ = layer(inputs, training=False) + bitstrings = layer.compress(inputs) + decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + self.assertTrue(len(layer.updates) == 1) + sess.run(layer.updates[0]) + values = 5 * np.random.normal(size=(2, 3, 5, 7)) + noisy, quantized, decoded = sess.run( + [noisy, quantized, decoded], {inputs: values}) + self.assertAllClose(values, noisy, rtol=0, atol=.5) + self.assertAllClose(values, quantized, rtol=0, atol=.5) + self.assertAllClose(values, decoded, rtol=0, atol=.5) + + def test_compress(self): + # Test compression and decompression, and produce test data for + # `test_decompress`. If you set the constant at the end to `True`, this test + # will fail and the log will contain the new test data. + inputs = array_ops.placeholder(dtypes.float32, (2, 3, 10)) + layer = entropybottleneck.EntropyBottleneck( + data_format="channels_first", filters=(), init_scale=2) + bitstrings = layer.compress(inputs) + decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + self.assertTrue(len(layer.updates) == 1) + sess.run(layer.updates[0]) + values = 5 * np.random.uniform(size=(2, 3, 10)) - 2.5 + bitstrings, quantized_cdf, decoded = sess.run( + [bitstrings, layer._quantized_cdf, decoded], {inputs: values}) + self.assertAllClose(values, decoded, rtol=0, atol=.5) + # Set this constant to `True` to log new test data for `test_decompress`. + if False: # pylint:disable=using-constant-test + assert False, (bitstrings, quantized_cdf, decoded) + + # Data generated by `test_compress`. + # pylint:disable=g-inconsistent-quotes,bad-whitespace + bitstrings = np.array([ + b'\x1e\xbag}\xc2\xdaN\x8b\xbd.', + b'\x8dF\xf0%\x1cv\xccllW' + ], dtype=object) + + quantized_cdf = np.array([ + [ 0, 15636, 22324, 30145, 38278, 65536], + [ 0, 19482, 26927, 35052, 42904, 65535], + [ 0, 21093, 28769, 36919, 44578, 65536] + ], dtype=np.int32) + + expected = np.array([ + [[-2., 1., 0., -2., -1., -2., -2., -2., 2., -1.], + [ 1., 2., 1., 0., -2., -2., 1., 2., 0., 1.], + [ 2., 0., -2., 2., 0., -1., -2., 0., 2., 0.]], + [[ 1., 2., 0., -1., 1., 2., 1., 1., 2., -2.], + [ 2., -1., -1., 0., -1., 2., 0., 2., -2., 2.], + [ 2., -2., -2., -1., -2., 1., -2., 0., 0., 0.]] + ], dtype=np.float32) + # pylint:enable=g-inconsistent-quotes,bad-whitespace + + def test_decompress(self): + # Test that decompression of values compressed with a previous version + # works, i.e. that the file format doesn't change across revisions. + bitstrings = array_ops.placeholder(dtypes.string) + input_shape = array_ops.placeholder(dtypes.int32) + quantized_cdf = array_ops.placeholder(dtypes.int32) + layer = entropybottleneck.EntropyBottleneck( + data_format="channels_first", filters=(), dtype=dtypes.float32) + layer.build(self.expected.shape) + layer._quantized_cdf = quantized_cdf + decoded = layer.decompress(bitstrings, input_shape[1:]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + decoded, = sess.run([decoded], { + bitstrings: self.bitstrings, input_shape: self.expected.shape, + quantized_cdf: self.quantized_cdf}) + self.assertAllClose(self.expected, decoded, rtol=0, atol=1e-6) + + def test_build_decompress(self): + # Test that layer can be built when `decompress` is the first call to it. + bitstrings = array_ops.placeholder(dtypes.string) + input_shape = array_ops.placeholder(dtypes.int32, shape=[3]) + layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) + layer.decompress(bitstrings, input_shape[1:], channels=5) + self.assertTrue(layer.built) + + def test_pmf_normalization(self): + # Test that probability mass functions are normalized correctly. + layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) + layer.build((None, 10)) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + pmf, = sess.run([layer._pmf]) + self.assertAllClose(np.ones(10), np.sum(pmf, axis=-1), rtol=0, atol=1e-6) + + def test_visualize(self): + # Test that summary op can be constructed. + layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) + layer.build((None, 10)) + summary = layer.visualize() + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + sess.run([summary]) + + def test_normalization(self): + # Test that densities are normalized correctly. + inputs = array_ops.placeholder(dtypes.float32, (None, 1)) + layer = entropybottleneck.EntropyBottleneck(filters=(2,)) + _, likelihood = layer(inputs, training=True) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + x = np.repeat(np.arange(-200, 201), 1000)[:, None] + likelihood, = sess.run([likelihood], {inputs: x}) + self.assertEqual(x.shape, likelihood.shape) + integral = np.sum(likelihood) * .001 + self.assertAllClose(1, integral, rtol=0, atol=1e-4) + + def test_entropy_estimates(self): + # Test that entropy estimates match actual range coding. + inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) + layer = entropybottleneck.EntropyBottleneck( + filters=(2, 3), data_format="channels_last") + _, likelihood = layer(inputs, training=True) + diff_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2) + _, likelihood = layer(inputs, training=False) + disc_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2) + bitstrings = layer.compress(inputs) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + self.assertTrue(len(layer.updates) == 1) + sess.run(layer.updates[0]) + diff_entropy, disc_entropy, bitstrings = sess.run( + [diff_entropy, disc_entropy, bitstrings], + {inputs: np.random.normal(size=(1, 10000, 1))}) + codelength = 8 * sum(len(bitstring) for bitstring in bitstrings) + self.assertAllClose(diff_entropy, disc_entropy, rtol=5e-3, atol=0) + self.assertAllClose(disc_entropy, codelength, rtol=5e-3, atol=0) + self.assertGreater(codelength, disc_entropy) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index 388d8e6ed6d9cb9400b0bfbe8e3f50b80149ea1a..bcee0b04c8430588c2dcbc199504bede0436f8f1 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -46,15 +46,3 @@ cuda_py_test( ], xla_enabled = True, ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py index 29a593f6bcfa05dcafcdb2f94087380ad720dba1..a56a01b16356e12b83344474c7fbe427530f0c74 100644 --- a/tensorflow/contrib/compiler/jit_test.py +++ b/tensorflow/contrib/compiler/jit_test.py @@ -24,7 +24,6 @@ from tensorflow.python.framework import function from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.framework import test_util from tensorflow.python.ops import gradients from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -170,12 +169,11 @@ class JITTest(test.TestCase): self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s) -@test_util.with_c_api class CompilationEnabledInGradientTest(test.TestCase): def testCompilationInGradient(self): with self.test_session(): - x = constant_op.constant([[3]]) + x = constant_op.constant([[3.]]) y_nc = math_ops.matmul(x, x, name="not_compiled") with jit.experimental_jit_scope(): y_c = math_ops.matmul(y_nc, y_nc, name="compiled") @@ -200,11 +198,11 @@ class CompilationEnabledInGradientTest(test.TestCase): with self.test_session(graph=ops.Graph()): with jit.experimental_jit_scope(): # XlaScope 0 - a1 = constant_op.constant([[1]]) + a1 = constant_op.constant([[1.]]) a1t = math_ops.matmul(a1, a1) with jit.experimental_jit_scope(): # XlaScope 1 - a2 = constant_op.constant([[1]]) + a2 = constant_op.constant([[1.]]) a2t = math_ops.matmul(a2, a2) self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope")) @@ -222,11 +220,11 @@ class CompilationEnabledInGradientTest(test.TestCase): with self.test_session(graph=ops.Graph()): with jit.experimental_jit_scope(True, separate_compiled_gradients=True): # XlaScope 0 - a1 = constant_op.constant([[1]]) + a1 = constant_op.constant([[1.]]) a1t = math_ops.matmul(a1, a1) with jit.experimental_jit_scope(True, separate_compiled_gradients=True): # XlaScope 1 - a2 = constant_op.constant([[1]]) + a2 = constant_op.constant([[1.]]) a2t = math_ops.matmul(a2, a2) self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope")) diff --git a/tensorflow/contrib/constrained_optimization/BUILD b/tensorflow/contrib/constrained_optimization/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..619153df67c90cea5a5082a411972948bac5fe90 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/BUILD @@ -0,0 +1,91 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +# Transitive dependencies of this target will be included in the pip package. +py_library( + name = "constrained_optimization_pip", + deps = [ + ":constrained_optimization", + ":test_util", + ], +) + +py_library( + name = "constrained_optimization", + srcs = [ + "__init__.py", + "python/candidates.py", + "python/constrained_minimization_problem.py", + "python/constrained_optimizer.py", + "python/external_regret_optimizer.py", + "python/swap_regret_optimizer.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:standard_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_test( + name = "candidates_test", + srcs = ["python/candidates_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":constrained_optimization", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + +# NOTE: This library can't be "testonly" since it needs to be included in the +# pip package. +py_library( + name = "test_util", + srcs = ["python/test_util.py"], + srcs_version = "PY2AND3", + deps = [ + ":constrained_optimization", + "//tensorflow/python:dtypes", + "//tensorflow/python:standard_ops", + ], +) + +py_test( + name = "external_regret_optimizer_test", + srcs = ["python/external_regret_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":constrained_optimization", + ":test_util", + "//tensorflow/python:client_testlib", + "//tensorflow/python:standard_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + +py_test( + name = "swap_regret_optimizer_test", + srcs = ["python/swap_regret_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":constrained_optimization", + ":test_util", + "//tensorflow/python:client_testlib", + "//tensorflow/python:standard_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/constrained_optimization/README.md b/tensorflow/contrib/constrained_optimization/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c65a150464efc1e77419040f66f36fc6756325aa --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/README.md @@ -0,0 +1,345 @@ + + +# ConstrainedOptimization (TFCO) + +TFCO is a library for optimizing inequality-constrained problems in TensorFlow. +Both the objective function and the constraints are represented as Tensors, +giving users the maximum amount of flexibility in specifying their optimization +problems. + +This flexibility makes optimization considerably more difficult: on a non-convex +problem, if one uses the "standard" approach of introducing a Lagrange +multiplier for each constraint, and then jointly maximizing over the Lagrange +multipliers and minimizing over the model parameters, then a stable stationary +point might not even *exist*. Hence, in some cases, oscillation, instead of +convergence, is inevitable. + +Thankfully, it turns out that even if, over the course of optimization, no +*particular* iterate does a good job of minimizing the objective while +satisfying the constraints, the *sequence* of iterates, on average, usually +will. This observation suggests the following approach: at training time, we'll +periodically snapshot the model state during optimization; then, at evaluation +time, each time we're given a new example to evaluate, we'll sample one of the +saved snapshots uniformly at random, and apply it to the example. This +*stochastic model* will generally perform well, both with respect to the +objective function, and the constraints. + +In fact, we can do better: it's possible to post-process the set of snapshots to +find a distribution over at most $$m+1$$ snapshots, where $$m$$ is the number of +constraints, that will be at least as good (and will usually be much better) +than the (much larger) uniform distribution described above. If you're unable or +unwilling to use a stochastic model at all, then you can instead use a heuristic +to choose the single best snapshot. + +For full details, motivation, and theoretical results on the approach taken by +this library, please refer to: + +> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex +> Constrained Optimization". +> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + +which will be referred to as [CoJiSr18] throughout the remainder of this +document. + +### Proxy Constraints + +Imagine that we want to constrain the recall of a binary classifier to be at +least 90%. Since the recall is proportional to the number of true positive +classifications, which itself is a sum of indicator functions, this constraint +is non-differentible, and therefore cannot be used in a problem that will be +optimized using a (stochastic) gradient-based algorithm. + +For this and similar problems, TFCO supports so-called *proxy constraints*, +which are (at least semi-differentiable) approximations of the original +constraints. For example, one could create a proxy recall function by replacing +the indicator functions with sigmoids. During optimization, each proxy +constraint function will be penalized, with the magnitude of the penalty being +chosen to satisfy the corresponding *original* (non-proxy) constraint. + +On a problem including proxy constraints—even a convex problem—the +Lagrangian approach discussed above isn't guaranteed to work. However, a +different algorithm, based on minimizing *swap regret*, does work. Aside from +this difference, the recommended procedure for optimizing a proxy-constrained +problem remains the same: periodically snapshot the model during optimization, +and then either find the best $$m+1$$-sized distribution, or heuristically +choose the single best snapshot. + +## Components + +* [constrained_minimization_problem](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py): + contains the `ConstrainedMinimizationProblem` interface. Your own + constrained optimization problems should be represented using + implementations of this interface. + +* [constrained_optimizer](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py): + contains the `ConstrainedOptimizer` interface, which is similar to (but + different from) `tf.train.Optimizer`, with the main difference being that + `ConstrainedOptimizer`s are given `ConstrainedMinimizationProblem`s to + optimize, and perform constrained optimization. + + * [external_regret_optimizer](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py): + contains the `AdditiveExternalRegretOptimizer` implementation, which is + a `ConstrainedOptimizer` implementing the Lagrangian approach discussed + above (with additive updates to the Lagrange multipliers). You should + use this optimizer for problems *without* proxy constraints. It may also + work for problems with proxy constraints, but we recommend using a swap + regret optimizer, instead. + + This optimizer is most similar to Algorithm 3 in Appendix C.3 of + [CoJiSr18], and is discussed in Section 3. The two differences are that + it uses proxy constraints (if they're provided) in the update of the + model parameters, and uses `tf.train.Optimizer`s, instead of SGD, for + the "inner" updates. + + * [swap_regret_optimizer](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py): + contains the `AdditiveSwapRegretOptimizer` and + `MultiplicativeSwapRegretOptimizer` implementations, which are + `ConstrainedOptimizer`s implementing the swap-regret minimization + approach mentioned above (with additive or multiplicative updates, + respectively, to the parameters associated with the + constraints—these parameters are not Lagrange multipliers, but + play a similar role). You should use one of these optimizers (we suggest + `MultiplicativeSwapRegretOptimizer`) for problems *with* proxy + constraints. + + The `MultiplicativeSwapRegretOptimizer` is most similar to Algorithm 2 + in Section 4 of [CoJiSr18], with the difference being that it uses + `tf.train.Optimizer`s, instead of SGD, for the "inner" updates. The + `AdditiveSwapRegretOptimizer` differs further in that it performs + additive (instead of multiplicative) updates of the stochastic matrix. + +* [candidates](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/candidates.py): + contains two functions, `find_best_candidate_distribution` and + `find_best_candidate_index`. Both of these functions are given a set of + candidate solutions to a constrained optimization problem, from which the + former finds the best distribution over at most $$m+1$$ candidates, and the + latter heuristically finds the single best candidate. As discussed above, + the set of candidates will typically be model snapshots saved periodically + during optimization. Both of these functions require that scipy be + installed. + + The `find_best_candidate_distribution` function implements the approach + described in Lemma 3 of [CoJiSr18], while `find_best_candidate_index` + implements the heuristic used for hyperparameter search in the experiments + of Section 5.2. + +## Convex Example with Proxy Constraints + +This is a simple example of recall-constrained optimization on simulated data: +we will try to find a classifier that minimizes the average hinge loss while +constraining recall to be at least 90%. + +We'll start with the required imports—notice the definition of `tfco`: + +```python +import math +import numpy as np +import tensorflow as tf + +tfco = tf.contrib.constrained_optimization +``` + +We'll now create an implementation of the `ConstrainedMinimizationProblem` class +for this problem. The constructor takes three parameters: a Tensor containing +the classification labels (0 or 1) for every training example, another Tensor +containing the model's predictions on every training example (sometimes called +the "logits"), and the lower bound on recall that will be enforced using a +constraint. + +This implementation will contain both constraints *and* proxy constraints: the +former represents the constraint that the true recall (defined in terms of the +*number* of true positives) be at least `recall_lower_bound`, while the latter +represents the same constraint, but on a hinge approximation of the recall. + +```python +class ExampleProblem(tfco.ConstrainedMinimizationProblem): + + def __init__(self, labels, predictions, recall_lower_bound): + self._labels = labels + self._predictions = predictions + self._recall_lower_bound = recall_lower_bound + # The number of positively-labeled examples. + self._positive_count = tf.reduce_sum(self._labels) + + @property + def objective(self): + return tf.losses.hinge_loss(labels=self._labels, logits=self._predictions) + + @property + def constraints(self): + true_positives = self._labels * tf.to_float(self._predictions > 0) + true_positive_count = tf.reduce_sum(true_positives) + recall = true_positive_count / self._positive_count + # The constraint is (recall >= self._recall_lower_bound), which we convert + # to (self._recall_lower_bound - recall <= 0) because + # ConstrainedMinimizationProblems must always provide their constraints in + # the form (tensor <= 0). + # + # The result of this function should be a tensor, with each element being + # a quantity that is constrained to be nonpositive. We only have one + # constraint, so we return a one-element tensor. + return self._recall_lower_bound - recall + + @property + def proxy_constraints(self): + # Use 1 - hinge since we're SUBTRACTING recall in the constraint function, + # and we want the proxy constraint function to be convex. + true_positives = self._labels * tf.minimum(1.0, self._predictions) + true_positive_count = tf.reduce_sum(true_positives) + recall = true_positive_count / self._positive_count + # Please see the corresponding comment in the constraints property. + return self._recall_lower_bound - recall +``` + +We'll now create a simple simulated dataset by sampling 1000 random +10-dimensional feature vectors from a Gaussian, finding their labels using a +random "ground truth" linear model, and then adding noise by randomly flipping +200 labels. + +```python +# Create a simulated 10-dimensional training dataset consisting of 1000 labeled +# examples, of which 800 are labeled correctly and 200 are mislabeled. +num_examples = 1000 +num_mislabeled_examples = 200 +dimension = 10 +# We will constrain the recall to be at least 90%. +recall_lower_bound = 0.9 + +# Create random "ground truth" parameters to a linear model. +ground_truth_weights = np.random.normal(size=dimension) / math.sqrt(dimension) +ground_truth_threshold = 0 + +# Generate a random set of features for each example. +features = np.random.normal(size=(num_examples, dimension)).astype( + np.float32) / math.sqrt(dimension) +# Compute the labels from these features given the ground truth linear model. +labels = (np.matmul(features, ground_truth_weights) > + ground_truth_threshold).astype(np.float32) +# Add noise by randomly flipping num_mislabeled_examples labels. +mislabeled_indices = np.random.choice( + num_examples, num_mislabeled_examples, replace=False) +labels[mislabeled_indices] = 1 - labels[mislabeled_indices] +``` + +We're now ready to construct our model, and the corresponding optimization +problem. We'll use a linear model of the form $$f(x) = w^T x - t$$, where $$w$$ +is the `weights`, and $$t$$ is the `threshold`. The `problem` variable will hold +an instance of the `ExampleProblem` class we created earlier. + +```python +# Create variables containing the model parameters. +weights = tf.Variable(tf.zeros(dimension), dtype=tf.float32, name="weights") +threshold = tf.Variable(0.0, dtype=tf.float32, name="threshold") + +# Create the optimization problem. +constant_labels = tf.constant(labels, dtype=tf.float32) +constant_features = tf.constant(features, dtype=tf.float32) +predictions = tf.tensordot(constant_features, weights, axes=(1, 0)) - threshold +problem = ExampleProblem( + labels=constant_labels, + predictions=predictions, + recall_lower_bound=recall_lower_bound, +) +``` + +We're almost ready to train our model, but first we'll create a couple of +functions to measure its performance. We're interested in two quantities: the +average hinge loss (which we seek to minimize), and the recall (which we +constrain). + +```python +def average_hinge_loss(labels, predictions): + num_examples, = np.shape(labels) + signed_labels = (labels * 2) - 1 + total_hinge_loss = np.sum(np.maximum(0.0, 1.0 - signed_labels * predictions)) + return total_hinge_loss / num_examples + +def recall(labels, predictions): + positive_count = np.sum(labels) + true_positives = labels * (predictions > 0) + true_positive_count = np.sum(true_positives) + return true_positive_count / positive_count +``` + +As was mentioned earlier, external regret optimizers suffice for problems +without proxy constraints, but swap regret optimizers are recommended for +problems *with* proxy constraints. Since this problem contains proxy +constraints, we use the `MultiplicativeSwapRegretOptimizer`. + +For this problem, the constraint is fairly easy to satisfy, so we can use the +same "inner" optimizer (an `AdagradOptimizer` with a learning rate of 1) for +optimization of both the model parameters (`weights` and `threshold`), and the +internal parameters associated with the constraints (these are the analogues of +the Lagrange multipliers used by the `MultiplicativeSwapRegretOptimizer`). For +more difficult problems, it will often be necessary to use different optimizers, +with different learning rates (presumably found via a hyperparameter search): to +accomplish this, pass *both* the `optimizer` and `constraint_optimizer` +parameters to `MultiplicativeSwapRegretOptimizer`'s constructor. + +Since this is a convex problem (both the objective and proxy constraint +functions are convex), we can just take the last iterate. Periodic snapshotting, +and the use of the `find_best_candidate_distribution` or +`find_best_candidate_index` functions, is generally only necessary for +non-convex problems (and even then, it isn't *always* necessary). + +```python +with tf.Session() as session: + optimizer = tfco.MultiplicativeSwapRegretOptimizer( + optimizer=tf.train.AdagradOptimizer(learning_rate=1.0)) + train_op = optimizer.minimize(problem) + + session.run(tf.global_variables_initializer()) + for ii in xrange(1000): + session.run(train_op) + + trained_weights, trained_threshold = session.run((weights, threshold)) + +trained_predictions = np.matmul(features, trained_weights) - trained_threshold +print("Constrained average hinge loss = %f" % average_hinge_loss( + labels, trained_predictions)) +print("Constrained recall = %f" % recall(labels, trained_predictions)) +``` + +Running the above code gives the following output (due to the randomness of the +dataset, you'll get a different result when you run it): + +```none +Constrained average hinge loss = 0.710019 +Constrained recall = 0.899811 +``` + +As we hoped, the recall is extremely close to 90%—and, thanks to the use +of proxy constraints, this is the *true* recall, not a hinge approximation. + +For comparison, let's try optimizing the same problem *without* the recall +constraint: + +```python +with tf.Session() as session: + optimizer = tf.train.AdagradOptimizer(learning_rate=1.0) + # For optimizing the unconstrained problem, we just minimize the "objective" + # portion of the minimization problem. + train_op = optimizer.minimize(problem.objective) + + session.run(tf.global_variables_initializer()) + for ii in xrange(1000): + session.run(train_op) + + trained_weights, trained_threshold = session.run((weights, threshold)) + +trained_predictions = np.matmul(features, trained_weights) - trained_threshold +print("Unconstrained average hinge loss = %f" % average_hinge_loss( + labels, trained_predictions)) +print("Unconstrained recall = %f" % recall(labels, trained_predictions)) +``` + +This code gives the following output (again, you'll get a different answer, +since the dataset is random): + +```none +Unconstrained average hinge loss = 0.627271 +Unconstrained recall = 0.793951 +``` + +Because there is no constraint, the unconstrained problem does a better job of +minimizing the average hinge loss, but naturally doesn't approach 90% recall. diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py b/tensorflow/contrib/constrained_optimization/__init__.py similarity index 53% rename from tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py rename to tensorflow/contrib/constrained_optimization/__init__.py index 9f7a95f138f7fd3e726f095dc16f41abb6182e17..1e49ba9f179ea98aaa9c35f79787605b53a1ec53 100644 --- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py +++ b/tensorflow/contrib/constrained_optimization/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,40 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Csiszar f-Divergence and helpers. - -See ${python/contrib.bayesflow.csiszar_divergence}. -""" +"""A library for performing constrained optimization in TensorFlow.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.csiszar_divergence_impl import * +from tensorflow.contrib.constrained_optimization.python.candidates import * +from tensorflow.contrib.constrained_optimization.python.constrained_minimization_problem import * +from tensorflow.contrib.constrained_optimization.python.constrained_optimizer import * +from tensorflow.contrib.constrained_optimization.python.external_regret_optimizer import * +from tensorflow.contrib.constrained_optimization.python.swap_regret_optimizer import * # pylint: enable=wildcard-import + from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'amari_alpha', - 'arithmetic_geometric', - 'chi_square', - 'csiszar_vimco', - 'dual_csiszar_function', - 'jeffreys', - 'jensen_shannon', - 'kl_forward', - 'kl_reverse', - 'log1p_abs', - 'modified_gan', - 'monte_carlo_csiszar_f_divergence', - 'pearson', - 'squared_hellinger', - 'symmetrized_csiszar_function', - 'total_variation', - 't_power', - 'triangular', + "AdditiveExternalRegretOptimizer", + "AdditiveSwapRegretOptimizer", + "ConstrainedMinimizationProblem", + "ConstrainedOptimizer", + "find_best_candidate_distribution", + "find_best_candidate_index", + "MultiplicativeSwapRegretOptimizer", ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/constrained_optimization/python/candidates.py b/tensorflow/contrib/constrained_optimization/python/candidates.py new file mode 100644 index 0000000000000000000000000000000000000000..ac86a6741be1f244476f917d0e151166db65524b --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/candidates.py @@ -0,0 +1,319 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Code for optimizing over a set of candidate solutions. + +The functions in this file deal with the constrained problem: + +> minimize f(w) +> s.t. g_i(w) <= 0 for all i in {0,1,...,m-1} + +Here, f(w) is the "objective function", and g_i(w) is the ith (of m) "constraint +function". Given the values of the objective and constraint functions for a set +of n "candidate solutions" {w_0,w_1,...,w_{n-1}} (for a total of n objective +function values, and n*m constraint function values), the +`find_best_candidate_distribution` function finds the best DISTRIBUTION over +these candidates, while `find_best_candidate_index' heuristically finds the +single best candidate. + +Both of these functions have dependencies on `scipy`, so if you want to call +them, then you must make sure that `scipy` is available. The imports are +performed inside the functions themselves, so if they're not actually called, +then `scipy` is not needed. + +For more specifics, please refer to: + +> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex +> Constrained Optimization". +> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + +The `find_best_candidate_distribution` function implements the approach +described in Lemma 3, while `find_best_candidate_index` implements the heuristic +used for hyperparameter search in the experiments of Section 5.2. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + + +def _find_best_candidate_distribution_helper(objective_vector, + constraints_matrix, + maximum_violation=0.0): + """Finds a distribution minimizing an objective subject to constraints. + + This function deals with the constrained problem: + + > minimize f(w) + > s.t. g_i(w) <= 0 for all i in {0,1,...,m-1} + + Here, f(w) is the "objective function", and g_i(w) is the ith (of m) + "constraint function". Given a set of n "candidate solutions" + {w_0,w_1,...,w_{n-1}}, this function finds a distribution over these n + candidates that, in expectation, minimizes the objective while violating + the constraints by no more than `maximum_violation`. If no such distribution + exists, it returns an error (using Go-style error reporting). + + The `objective_vector` parameter should be a numpy array with shape (n,), for + which objective_vector[i] = f(w_i). Likewise, `constraints_matrix` should be a + numpy array with shape (m,n), for which constraints_matrix[i,j] = g_i(w_j). + + This function will return a distribution for which at most m+1 probabilities, + and often fewer, are nonzero. + + Args: + objective_vector: numpy array of shape (n,), where n is the number of + "candidate solutions". Contains the objective function values. + constraints_matrix: numpy array of shape (m,n), where m is the number of + constraints and n is the number of "candidate solutions". Contains the + constraint violation magnitudes. + maximum_violation: nonnegative float, the maximum amount by which any + constraint may be violated, in expectation. + + Returns: + A pair (`result`, `message`), exactly one of which is None. If `message` is + None, then the `result` contains the optimal distribution as a numpy array + of shape (n,). If `result` is None, then `message` contains an error + message. + + Raises: + ValueError: If `objective_vector` and `constraints_matrix` have inconsistent + shapes, or if `maximum_violation` is negative. + ImportError: If we're unable to import `scipy.optimize`. + """ + if maximum_violation < 0.0: + raise ValueError("maximum_violation must be nonnegative") + + mm, nn = np.shape(constraints_matrix) + if (nn,) != np.shape(objective_vector): + raise ValueError( + "objective_vector must have shape (n,), and constraints_matrix (m, n)," + " where n is the number of candidates, and m is the number of " + "constraints") + + # We import scipy inline, instead of at the top of the file, so that a scipy + # dependency is only introduced if either find_best_candidate_distribution() + # or find_best_candidate_index() are actually called. + import scipy.optimize # pylint: disable=g-import-not-at-top + + # Feasibility (within maximum_violation) constraints. + a_ub = constraints_matrix + b_ub = np.full((mm, 1), maximum_violation) + # Sum-to-one constraint. + a_eq = np.ones((1, nn)) + b_eq = np.ones((1, 1)) + # Nonnegativity constraints. + bounds = (0, None) + + result = scipy.optimize.linprog( + objective_vector, + A_ub=a_ub, + b_ub=b_ub, + A_eq=a_eq, + b_eq=b_eq, + bounds=bounds) + # Go-style error reporting. We don't raise on error, since + # find_best_candidate_distribution() needs to handle the failure case, and we + # shouldn't use exceptions as flow-control. + if not result.success: + return (None, result.message) + else: + return (result.x, None) + + +def find_best_candidate_distribution(objective_vector, + constraints_matrix, + epsilon=0.0): + """Finds a distribution minimizing an objective subject to constraints. + + This function deals with the constrained problem: + + > minimize f(w) + > s.t. g_i(w) <= 0 for all i in {0,1,...,m-1} + + Here, f(w) is the "objective function", and g_i(w) is the ith (of m) + "constraint function". Given a set of n "candidate solutions" + {w_0,w_1,...,w_{n-1}}, this function finds a distribution over these n + candidates that, in expectation, minimizes the objective while violating + the constraints by the smallest possible amount (with the amount being found + via bisection search). + + The `objective_vector` parameter should be a numpy array with shape (n,), for + which objective_vector[i] = f(w_i). Likewise, `constraints_matrix` should be a + numpy array with shape (m,n), for which constraints_matrix[i,j] = g_i(w_j). + + This function will return a distribution for which at most m+1 probabilities, + and often fewer, are nonzero. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + This function implements the approach described in Lemma 3. + + Args: + objective_vector: numpy array of shape (n,), where n is the number of + "candidate solutions". Contains the objective function values. + constraints_matrix: numpy array of shape (m,n), where m is the number of + constraints and n is the number of "candidate solutions". Contains the + constraint violation magnitudes. + epsilon: nonnegative float, the threshold at which to terminate the binary + search while searching for the minimal expected constraint violation + magnitude. + + Returns: + The optimal distribution, as a numpy array of shape (n,). + + Raises: + ValueError: If `objective_vector` and `constraints_matrix` have inconsistent + shapes, or if `epsilon` is negative. + ImportError: If we're unable to import `scipy.optimize`. + """ + if epsilon < 0.0: + raise ValueError("epsilon must be nonnegative") + + # If there is a feasible solution (i.e. with maximum_violation=0), then that's + # what we'll return. + pp, _ = _find_best_candidate_distribution_helper(objective_vector, + constraints_matrix) + if pp is not None: + return pp + + # The bound is the minimum over all candidates, of the maximum per-candidate + # constraint violation. + lower = 0.0 + upper = np.min(np.amax(constraints_matrix, axis=0)) + best_pp, _ = _find_best_candidate_distribution_helper( + objective_vector, constraints_matrix, maximum_violation=upper) + assert best_pp is not None + + # Throughout this loop, a maximum_violation of "lower" is not achievable, + # but a maximum_violation of "upper" is achiveable. + while True: + middle = 0.5 * (lower + upper) + if (middle - lower <= epsilon) or (upper - middle <= epsilon): + break + else: + pp, _ = _find_best_candidate_distribution_helper( + objective_vector, constraints_matrix, maximum_violation=middle) + if pp is None: + lower = middle + else: + best_pp = pp + upper = middle + + return best_pp + + +def find_best_candidate_index(objective_vector, + constraints_matrix, + rank_objectives=False): + """Heuristically finds the best candidate solution to a constrained problem. + + This function deals with the constrained problem: + + > minimize f(w) + > s.t. g_i(w) <= 0 for all i in {0,1,...,m-1} + + Here, f(w) is the "objective function", and g_i(w) is the ith (of m) + "constraint function". Given a set of n "candidate solutions" + {w_0,w_1,...,w_{n-1}}, this function finds the "best" solution according + to the following heuristic: + + 1. Across all models, the ith constraint violations (i.e. max{0, g_i(0)}) + are ranked, as are the objectives (if rank_objectives=True). + 2. Each model is then associated its MAXIMUM rank across all m constraints + (and the objective, if rank_objectives=True). + 3. The model with the minimal maximum rank is then identified. Ties are + broken using the objective function value. + 4. The index of this "best" model is returned. + + The `objective_vector` parameter should be a numpy array with shape (n,), for + which objective_vector[i] = f(w_i). Likewise, `constraints_matrix` should be a + numpy array with shape (m,n), for which constraints_matrix[i,j] = g_i(w_j). + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + This function implements the heuristic used for hyperparameter search in the + experiments of Section 5.2. + + Args: + objective_vector: numpy array of shape (n,), where n is the number of + "candidate solutions". Contains the objective function values. + constraints_matrix: numpy array of shape (m,n), where m is the number of + constraints and n is the number of "candidate solutions". Contains the + constraint violation magnitudes. + rank_objectives: bool, whether the objective function values should be + included in the initial ranking step. If True, both the objective and + constraints will be ranked. If False, only the constraints will be ranked. + In either case, the objective function values will be used for + tiebreaking. + + Returns: + The index (in {0,1,...,n-1}) of the "best" model according to the above + heuristic. + + Raises: + ValueError: If `objective_vector` and `constraints_matrix` have inconsistent + shapes. + ImportError: If we're unable to import `scipy.stats`. + """ + mm, nn = np.shape(constraints_matrix) + if (nn,) != np.shape(objective_vector): + raise ValueError( + "objective_vector must have shape (n,), and constraints_matrix (m, n)," + " where n is the number of candidates, and m is the number of " + "constraints") + + # We import scipy inline, instead of at the top of the file, so that a scipy + # dependency is only introduced if either find_best_candidate_distribution() + # or find_best_candidate_index() are actually called. + import scipy.stats # pylint: disable=g-import-not-at-top + + if rank_objectives: + maximum_ranks = scipy.stats.rankdata(objective_vector, method="min") + else: + maximum_ranks = np.zeros(nn, dtype=np.int64) + for ii in xrange(mm): + # Take the maximum of the constraint functions with zero, since we want to + # rank the magnitude of constraint *violations*. If the constraint is + # satisfied, then we don't care how much it's satisfied by (as a result, we + # we expect all models satisfying a constraint to be tied at rank 1). + ranks = scipy.stats.rankdata( + np.maximum(0.0, constraints_matrix[ii, :]), method="min") + maximum_ranks = np.maximum(maximum_ranks, ranks) + + best_index = None + best_rank = float("Inf") + best_objective = float("Inf") + for ii in xrange(nn): + if maximum_ranks[ii] < best_rank: + best_index = ii + best_rank = maximum_ranks[ii] + best_objective = objective_vector[ii] + elif (maximum_ranks[ii] == best_rank) and (objective_vector[ii] <= + best_objective): + best_index = ii + best_objective = objective_vector[ii] + + return best_index diff --git a/tensorflow/contrib/constrained_optimization/python/candidates_test.py b/tensorflow/contrib/constrained_optimization/python/candidates_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a4c49d48bc5c763489215261a909573af0f19055 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/candidates_test.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for constrained_optimization.python.candidates.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.constrained_optimization.python import candidates +from tensorflow.python.platform import test + + +class CandidatesTest(test.TestCase): + + def test_inconsistent_shapes_for_best_distribution(self): + """An error is raised when parameters have inconsistent shapes.""" + objective_vector = np.array([1, 2, 3]) + constraints_matrix = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]) + with self.assertRaises(ValueError): + _ = candidates.find_best_candidate_distribution(objective_vector, + constraints_matrix) + + def test_inconsistent_shapes_for_best_index(self): + """An error is raised when parameters have inconsistent shapes.""" + objective_vector = np.array([1, 2, 3]) + constraints_matrix = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]) + with self.assertRaises(ValueError): + _ = candidates.find_best_candidate_index(objective_vector, + constraints_matrix) + + def test_best_distribution(self): + """Distribution should match known solution.""" + objective_vector = np.array( + [0.03053309, -0.06667082, 0.88355145, 0.46529806]) + constraints_matrix = np.array( + [[-0.60164551, 0.36676229, 0.7856454, -0.8441711], + [0.00371592, -0.16392108, -0.59778071, -0.56908492]]) + distribution = candidates.find_best_candidate_distribution( + objective_vector, constraints_matrix) + # Verify that the solution is a probability distribution. + self.assertTrue(np.all(distribution >= 0)) + self.assertAlmostEqual(np.sum(distribution), 1.0) + # Verify that the solution satisfies the constraints. + maximum_constraint_violation = np.amax( + np.dot(constraints_matrix, distribution)) + self.assertLessEqual(maximum_constraint_violation, 0) + # Verify that the solution matches that which we expect. + expected_distribution = np.array([0.37872711, 0.62127289, 0, 0]) + self.assertAllClose(expected_distribution, distribution, rtol=0, atol=1e-6) + + def test_best_index_rank_objectives_true(self): + """Index should match known solution.""" + # Objective ranks = [2, 1, 4, 3]. + objective_vector = np.array( + [0.03053309, -0.06667082, 0.88355145, 0.46529806]) + # Constraint ranks = [[1, 3, 4, 1], [4, 1, 1, 1]]. + constraints_matrix = np.array( + [[-0.60164551, 0.36676229, 0.7856454, -0.8441711], + [0.00371592, -0.16392108, -0.59778071, -0.56908492]]) + # Maximum ranks = [4, 3, 4, 3]. + index = candidates.find_best_candidate_index( + objective_vector, constraints_matrix, rank_objectives=True) + self.assertEqual(1, index) + + def test_best_index_rank_objectives_false(self): + """Index should match known solution.""" + # Objective ranks = [2, 1, 4, 3]. + objective_vector = np.array( + [0.03053309, -0.06667082, 0.88355145, 0.46529806]) + # Constraint ranks = [[1, 3, 4, 1], [4, 1, 1, 1]]. + constraints_matrix = np.array( + [[-0.60164551, 0.36676229, 0.7856454, -0.8441711], + [0.00371592, -0.16392108, -0.59778071, -0.56908492]]) + # Maximum ranks = [4, 3, 4, 1]. + index = candidates.find_best_candidate_index( + objective_vector, constraints_matrix, rank_objectives=False) + self.assertEqual(3, index) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py new file mode 100644 index 0000000000000000000000000000000000000000..70813fb217956b167b80a7e1d555c8ba79088fdb --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines abstract class for `ConstrainedMinimizationProblem`s. + +A ConstrainedMinimizationProblem consists of an objective function to minimize, +and a set of constraint functions that are constrained to be nonpositive. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + + +@six.add_metaclass(abc.ABCMeta) +class ConstrainedMinimizationProblem(object): + """Abstract class representing a `ConstrainedMinimizationProblem`. + + A ConstrainedMinimizationProblem consists of an objective function to + minimize, and a set of constraint functions that are constrained to be + nonpositive. + + In addition to the constraint functions, there may (optionally) be proxy + constraint functions: a ConstrainedOptimizer will attempt to penalize these + proxy constraint functions so as to satisfy the (non-proxy) constraints. Proxy + constraints could be used if the constraints functions are difficult or + impossible to optimize (e.g. if they're piecewise constant), in which case the + proxy constraints should be some approximation of the original constraints + that is well-enough behaved to permit successful optimization. + """ + + @abc.abstractproperty + def objective(self): + """Returns the objective function. + + Returns: + A 0d tensor that should be minimized. + """ + pass + + @property + def num_constraints(self): + """Returns the number of constraints. + + Returns: + An int containing the number of constraints. + + Raises: + ValueError: If the constraints (or proxy_constraints, if present) do not + have fully-known shapes, OR if proxy_constraints are present, and the + shapes of constraints and proxy_constraints are fully-known, but they're + different. + """ + constraints_shape = self.constraints.get_shape() + if self.proxy_constraints is None: + proxy_constraints_shape = constraints_shape + else: + proxy_constraints_shape = self.proxy_constraints.get_shape() + + if (constraints_shape is None or proxy_constraints_shape is None or + any([ii is None for ii in constraints_shape.as_list()]) or + any([ii is None for ii in proxy_constraints_shape.as_list()])): + raise ValueError( + "constraints and proxy_constraints must have fully-known shapes") + if constraints_shape != proxy_constraints_shape: + raise ValueError( + "constraints and proxy_constraints must have the same shape") + + size = 1 + for ii in constraints_shape.as_list(): + size *= ii + return int(size) + + @abc.abstractproperty + def constraints(self): + """Returns the vector of constraint functions. + + Letting g_i be the ith element of the constraints vector, the ith constraint + will be g_i <= 0. + + Returns: + A tensor of constraint functions. + """ + pass + + # This is a property, instead of an abstract property, since it doesn't need + # to be overridden: if proxy_constraints returns None, then there are no + # proxy constraints. + @property + def proxy_constraints(self): + """Returns the optional vector of proxy constraint functions. + + The difference between `constraints` and `proxy_constraints` is that, when + proxy constraints are present, the `constraints` are merely EVALUATED during + optimization, whereas the `proxy_constraints` are DIFFERENTIATED. If there + are no proxy constraints, then the `constraints` are both evaluated and + differentiated. + + For example, if we want to impose constraints on step functions, then we + could use these functions for `constraints`. However, because a step + function has zero gradient almost everywhere, we can't differentiate these + functions, so we would take `proxy_constraints` to be some differentiable + approximation of `constraints`. + + Returns: + A tensor of proxy constraint functions. + """ + return None diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..805554536610a5e2cc650ff0b47185f4fbd6fac5 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py @@ -0,0 +1,208 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines base class for `ConstrainedOptimizer`s.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.training import optimizer as train_optimizer + + +@six.add_metaclass(abc.ABCMeta) +class ConstrainedOptimizer(object): + """Base class representing a constrained optimizer. + + A ConstrainedOptimizer wraps a tf.train.Optimizer (or more than one), and + applies it to a ConstrainedMinimizationProblem. Unlike a tf.train.Optimizer, + which takes a tensor to minimize as a parameter to its minimize() method, a + constrained optimizer instead takes a ConstrainedMinimizationProblem. + """ + + def __init__(self, optimizer): + """Constructs a new `ConstrainedOptimizer`. + + Args: + optimizer: tf.train.Optimizer, used to optimize the + ConstraintedMinimizationProblem. + + Returns: + A new `ConstrainedOptimizer`. + """ + self._optimizer = optimizer + + @property + def optimizer(self): + """Returns the `tf.train.Optimizer` used for optimization.""" + return self._optimizer + + def minimize_unconstrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Op` for minimizing the unconstrained problem. + + Unlike `minimize_constrained`, this function ignores the `constraints` (and + `proxy_constraints`) portion of the minimization problem entirely, and only + minimizes `objective`. + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. + name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + TensorFlow Op. + """ + return self.optimizer.minimize( + minimization_problem.objective, + global_step=global_step, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + name=name, + grad_loss=grad_loss) + + @abc.abstractmethod + def minimize_constrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Op` for minimizing the constrained problem. + + Unlike `minimize_unconstrained`, this function attempts to find a solution + that minimizes the `objective` portion of the minimization problem while + satisfying the `constraints` portion. + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. + name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + TensorFlow Op. + """ + pass + + def minimize(self, + minimization_problem, + unconstrained_steps=None, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Op` for minimizing the constrained problem. + + This method combines the functionality of `minimize_unconstrained` and + `minimize_constrained`. If global_step < unconstrained_steps, it will + perform an unconstrained update, and if global_step >= unconstrained_steps, + it will perform a constrained update. + + The reason for this functionality is that it may be best to initialize the + constrained optimizer with an approximate optimum of the unconstrained + problem. + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + unconstrained_steps: int, number of steps for which we should perform + unconstrained updates, before transitioning to constrained updates. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. + name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + TensorFlow Op. + + Raises: + ValueError: If unconstrained_steps is provided, but global_step is not. + """ + + def unconstrained_fn(): + """Returns an `Op` for minimizing the unconstrained problem.""" + return self.minimize_unconstrained( + minimization_problem=minimization_problem, + global_step=global_step, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + name=name, + grad_loss=grad_loss) + + def constrained_fn(): + """Returns an `Op` for minimizing the constrained problem.""" + return self.minimize_constrained( + minimization_problem=minimization_problem, + global_step=global_step, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + name=name, + grad_loss=grad_loss) + + if unconstrained_steps is not None: + if global_step is None: + raise ValueError( + "global_step cannot be None if unconstrained_steps is provided") + unconstrained_steps_tensor = ops.convert_to_tensor(unconstrained_steps) + dtype = unconstrained_steps_tensor.dtype + return control_flow_ops.cond( + standard_ops.cast(global_step, dtype) < unconstrained_steps_tensor, + true_fn=unconstrained_fn, + false_fn=constrained_fn) + else: + return constrained_fn() diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..01c6e4f08afb93e37aa124f31ca7faa10b07d4d6 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py @@ -0,0 +1,375 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines `AdditiveExternalRegretOptimizer`. + +This optimizer minimizes a `ConstrainedMinimizationProblem` by introducing +Lagrange multipliers, and using `tf.train.Optimizer`s to jointly optimize over +the model parameters and Lagrange multipliers. + +For the purposes of constrained optimization, at least in theory, +external-regret minimization suffices if the `ConstrainedMinimizationProblem` +we're optimizing doesn't have any `proxy_constraints`, while swap-regret +minimization should be used if `proxy_constraints` are present. + +For more specifics, please refer to: + +> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex +> Constrained Optimization". +> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + +The formulation used by the AdditiveExternalRegretOptimizer--which is simply the +usual Lagrangian formulation--can be found in Definition 1, and is discussed in +Section 3. This optimizer is most similar to Algorithm 3 in Appendix C.3, with +the two differences being that it uses proxy constraints (if they're provided) +in the update of the model parameters, and uses `tf.train.Optimizer`s, instead +of SGD, for the "inner" updates. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.contrib.constrained_optimization.python import constrained_optimizer + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import optimizer as train_optimizer + + +def _project_multipliers_wrt_euclidean_norm(multipliers, radius): + """Projects its argument onto the feasible region. + + The feasible region is the set of all vectors with nonnegative elements that + sum to at most `radius`. + + Args: + multipliers: 1d tensor, the Lagrange multipliers to project. + radius: float, the radius of the feasible region. + + Returns: + The 1d tensor that results from projecting `multipliers` onto the feasible + region w.r.t. the Euclidean norm. + + Raises: + ValueError: if the `multipliers` tensor does not have a fully-known shape, + or is not one-dimensional. + """ + multipliers_shape = multipliers.get_shape() + if multipliers_shape is None: + raise ValueError("multipliers must have known shape") + if multipliers_shape.ndims != 1: + raise ValueError( + "multipliers must be one dimensional (instead is %d-dimensional)" % + multipliers_shape.ndims) + dimension = multipliers_shape[0].value + if dimension is None: + raise ValueError("multipliers must have fully-known shape") + + def while_loop_condition(iteration, multipliers, inactive, old_inactive): + """Returns false if the while loop should terminate.""" + del multipliers # Needed by the body, but not the condition. + not_done = (iteration < dimension) + not_converged = standard_ops.reduce_any( + standard_ops.not_equal(inactive, old_inactive)) + return standard_ops.logical_and(not_done, not_converged) + + def while_loop_body(iteration, multipliers, inactive, old_inactive): + """Performs one iteration of the projection.""" + del old_inactive # Needed by the condition, but not the body. + iteration += 1 + scale = standard_ops.minimum( + 0.0, + (radius - standard_ops.reduce_sum(multipliers)) / standard_ops.maximum( + 1.0, standard_ops.reduce_sum(inactive))) + multipliers += scale * inactive + new_inactive = standard_ops.to_float(multipliers > 0) + multipliers *= new_inactive + return (iteration, multipliers, new_inactive, inactive) + + iteration = standard_ops.constant(0) + inactive = standard_ops.ones_like(multipliers) + + # We actually want a do-while loop, so we explicitly call while_loop_body() + # once before tf.while_loop(). + iteration, multipliers, inactive, old_inactive = while_loop_body( + iteration, multipliers, inactive, inactive) + iteration, multipliers, inactive, old_inactive = control_flow_ops.while_loop( + while_loop_condition, + while_loop_body, + loop_vars=(iteration, multipliers, inactive, old_inactive), + name="euclidean_projection") + + return multipliers + + +@six.add_metaclass(abc.ABCMeta) +class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): + """Base class representing an `_ExternalRegretOptimizer`. + + This class contains most of the logic for performing constrained + optimization, minimizing external regret for the constraints player. What it + *doesn't* do is keep track of the internal state (the Lagrange multipliers). + Instead, the state is accessed via the _initial_state(), + _lagrange_multipliers(), _constraint_grad_and_var() and _projection_op() + methods. + + The reason for this is that we want to make it easy to implement different + representations of the internal state. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + The formulation used by `_ExternalRegretOptimizer`s--which is simply the usual + Lagrangian formulation--can be found in Definition 1, and is discussed in + Section 3. Such optimizers are most similar to Algorithm 3 in Appendix C.3. + """ + + def __init__(self, optimizer, constraint_optimizer=None): + """Constructs a new `_ExternalRegretOptimizer`. + + The difference between `optimizer` and `constraint_optimizer` (if the latter + is provided) is that the former is used for learning the model parameters, + while the latter us used for the Lagrange multipliers. If no + `constraint_optimizer` is provided, then `optimizer` is used for both. + + Args: + optimizer: tf.train.Optimizer, used to optimize the objective and + proxy_constraints portion of the ConstrainedMinimizationProblem. If + constraint_optimizer is not provided, this will also be used to optimize + the Lagrange multipliers. + constraint_optimizer: optional tf.train.Optimizer, used to optimize the + Lagrange multipliers. + + Returns: + A new `_ExternalRegretOptimizer`. + """ + super(_ExternalRegretOptimizer, self).__init__(optimizer=optimizer) + self._constraint_optimizer = constraint_optimizer + + @property + def constraint_optimizer(self): + """Returns the `tf.train.Optimizer` used for the Lagrange multipliers.""" + return self._constraint_optimizer + + @abc.abstractmethod + def _initial_state(self, num_constraints): + pass + + @abc.abstractmethod + def _lagrange_multipliers(self, state): + pass + + @abc.abstractmethod + def _constraint_grad_and_var(self, state, gradient): + pass + + @abc.abstractmethod + def _projection_op(self, state, name=None): + pass + + def minimize_constrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Op` for minimizing the constrained problem. + + The `optimizer` constructor parameter will be used to update the model + parameters, while the Lagrange multipliers will be updated using + `constrained_optimizer` (if provided) or `optimizer` (if not). + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. + name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + TensorFlow Op. + """ + objective = minimization_problem.objective + + constraints = minimization_problem.constraints + proxy_constraints = minimization_problem.proxy_constraints + if proxy_constraints is None: + proxy_constraints = constraints + # Flatten both constraints tensors to 1d. + num_constraints = minimization_problem.num_constraints + constraints = standard_ops.reshape(constraints, shape=(num_constraints,)) + proxy_constraints = standard_ops.reshape( + proxy_constraints, shape=(num_constraints,)) + + # We use a lambda to initialize the state so that, if this function call is + # inside the scope of a tf.control_dependencies() block, the dependencies + # will not be applied to the initializer. + state = standard_ops.Variable( + lambda: self._initial_state(num_constraints), + trainable=False, + name="external_regret_optimizer_state") + + multipliers = self._lagrange_multipliers(state) + loss = ( + objective + standard_ops.tensordot(multipliers, proxy_constraints, 1)) + multipliers_gradient = constraints + + update_ops = [] + if self.constraint_optimizer is None: + # If we don't have a separate constraint_optimizer, then we use + # self._optimizer for both the update of the model parameters, and that of + # the internal state. + grads_and_vars = self.optimizer.compute_gradients( + loss, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + grads_and_vars.append( + self._constraint_grad_and_var(state, multipliers_gradient)) + update_ops.append( + self.optimizer.apply_gradients(grads_and_vars, name="update")) + else: + # If we have a separate constraint_optimizer, then we use self._optimizer + # for the update of the model parameters, and self._constraint_optimizer + # for that of the internal state. + grads_and_vars = self.optimizer.compute_gradients( + loss, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + multiplier_grads_and_vars = [ + self._constraint_grad_and_var(state, multipliers_gradient) + ] + + gradients = [ + gradient for gradient, _ in grads_and_vars + multiplier_grads_and_vars + if gradient is not None + ] + with ops.control_dependencies(gradients): + update_ops.append( + self.optimizer.apply_gradients(grads_and_vars, name="update")) + update_ops.append( + self.constraint_optimizer.apply_gradients( + multiplier_grads_and_vars, name="optimizer_state_update")) + + with ops.control_dependencies(update_ops): + if global_step is None: + # If we don't have a global step, just project, and we're done. + return self._projection_op(state, name=name) + else: + # If we have a global step, then we need to increment it in addition to + # projecting. + projection_op = self._projection_op(state, name="project") + with ops.colocate_with(global_step): + global_step_op = state_ops.assign_add( + global_step, 1, name="global_step_increment") + return control_flow_ops.group(projection_op, global_step_op, name=name) + + +class AdditiveExternalRegretOptimizer(_ExternalRegretOptimizer): + """A `ConstrainedOptimizer` based on external-regret minimization. + + This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly + minimize over the model parameters, and maximize over Lagrange multipliers, + with the latter maximization using additive updates and an algorithm that + minimizes external regret. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + The formulation used by this optimizer--which is simply the usual Lagrangian + formulation--can be found in Definition 1, and is discussed in Section 3. It + is most similar to Algorithm 3 in Appendix C.3, with the two differences being + that it uses proxy constraints (if they're provided) in the update of the + model parameters, and uses `tf.train.Optimizer`s, instead of SGD, for the + "inner" updates. + """ + + def __init__(self, + optimizer, + constraint_optimizer=None, + maximum_multiplier_radius=None): + """Constructs a new `AdditiveExternalRegretOptimizer`. + + Args: + optimizer: tf.train.Optimizer, used to optimize the objective and + proxy_constraints portion of ConstrainedMinimizationProblem. If + constraint_optimizer is not provided, this will also be used to optimize + the Lagrange multipliers. + constraint_optimizer: optional tf.train.Optimizer, used to optimize the + Lagrange multipliers. + maximum_multiplier_radius: float, an optional upper bound to impose on the + sum of the Lagrange multipliers. + + Returns: + A new `AdditiveExternalRegretOptimizer`. + + Raises: + ValueError: If the maximum_multiplier_radius parameter is nonpositive. + """ + super(AdditiveExternalRegretOptimizer, self).__init__( + optimizer=optimizer, constraint_optimizer=constraint_optimizer) + + if maximum_multiplier_radius and (maximum_multiplier_radius <= 0.0): + raise ValueError("maximum_multiplier_radius must be strictly positive") + + self._maximum_multiplier_radius = maximum_multiplier_radius + + def _initial_state(self, num_constraints): + # For an AdditiveExternalRegretOptimizer, the internal state is simply a + # tensor of Lagrange multipliers with shape (m,), where m is the number of + # constraints. + return standard_ops.zeros((num_constraints,), dtype=dtypes.float32) + + def _lagrange_multipliers(self, state): + return state + + def _constraint_grad_and_var(self, state, gradient): + # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True? + return (-gradient, state) + + def _projection_op(self, state, name=None): + with ops.colocate_with(state): + if self._maximum_multiplier_radius: + projected_multipliers = _project_multipliers_wrt_euclidean_norm( + state, self._maximum_multiplier_radius) + else: + projected_multipliers = standard_ops.maximum(state, 0.0) + return state_ops.assign(state, projected_multipliers, name=name) diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9b4bf6271009161c4c449cd9c3cdab9fba90aa59 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py @@ -0,0 +1,136 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for constrained_optimization.python.external_regret_optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.constrained_optimization.python import external_regret_optimizer +from tensorflow.contrib.constrained_optimization.python import test_util + +from tensorflow.python.ops import standard_ops +from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent + + +class AdditiveExternalRegretOptimizerWrapper( + external_regret_optimizer.AdditiveExternalRegretOptimizer): + """Testing wrapper class around AdditiveExternalRegretOptimizer. + + This class is identical to AdditiveExternalRegretOptimizer, except that it + caches the internal optimization state when _lagrange_multipliers() is called, + so that we can test that the Lagrange multipliers take on their expected + values. + """ + + def __init__(self, + optimizer, + constraint_optimizer=None, + maximum_multiplier_radius=None): + """Same as AdditiveExternalRegretOptimizer.__init__.""" + super(AdditiveExternalRegretOptimizerWrapper, self).__init__( + optimizer=optimizer, + constraint_optimizer=constraint_optimizer, + maximum_multiplier_radius=maximum_multiplier_radius) + self._cached_lagrange_multipliers = None + + @property + def lagrange_multipliers(self): + """Returns the cached Lagrange multipliers.""" + return self._cached_lagrange_multipliers + + def _lagrange_multipliers(self, state): + """Caches the internal state for testing.""" + self._cached_lagrange_multipliers = super( + AdditiveExternalRegretOptimizerWrapper, + self)._lagrange_multipliers(state) + return self._cached_lagrange_multipliers + + +class ExternalRegretOptimizerTest(test.TestCase): + + def test_project_multipliers_wrt_euclidean_norm(self): + """Tests Euclidean projection routine on some known values.""" + multipliers1 = standard_ops.constant([-0.1, -0.6, -0.3]) + expected_projected_multipliers1 = np.array([0.0, 0.0, 0.0]) + + multipliers2 = standard_ops.constant([-0.1, 0.6, 0.3]) + expected_projected_multipliers2 = np.array([0.0, 0.6, 0.3]) + + multipliers3 = standard_ops.constant([0.4, 0.7, -0.2, 0.5, 0.1]) + expected_projected_multipliers3 = np.array([0.2, 0.5, 0.0, 0.3, 0.0]) + + with self.test_session() as session: + projected_multipliers1 = session.run( + external_regret_optimizer._project_multipliers_wrt_euclidean_norm( + multipliers1, 1.0)) + projected_multipliers2 = session.run( + external_regret_optimizer._project_multipliers_wrt_euclidean_norm( + multipliers2, 1.0)) + projected_multipliers3 = session.run( + external_regret_optimizer._project_multipliers_wrt_euclidean_norm( + multipliers3, 1.0)) + + self.assertAllClose( + expected_projected_multipliers1, + projected_multipliers1, + rtol=0, + atol=1e-6) + self.assertAllClose( + expected_projected_multipliers2, + projected_multipliers2, + rtol=0, + atol=1e-6) + self.assertAllClose( + expected_projected_multipliers3, + projected_multipliers3, + rtol=0, + atol=1e-6) + + def test_additive_external_regret_optimizer(self): + """Tests that the Lagrange multipliers update as expected.""" + minimization_problem = test_util.ConstantMinimizationProblem( + np.array([0.6, -0.1, 0.4])) + optimizer = AdditiveExternalRegretOptimizerWrapper( + gradient_descent.GradientDescentOptimizer(1.0), + maximum_multiplier_radius=1.0) + train_op = optimizer.minimize_constrained(minimization_problem) + + expected_multipliers = [ + np.array([0.0, 0.0, 0.0]), + np.array([0.6, 0.0, 0.4]), + np.array([0.7, 0.0, 0.3]), + np.array([0.8, 0.0, 0.2]), + np.array([0.9, 0.0, 0.1]), + np.array([1.0, 0.0, 0.0]), + np.array([1.0, 0.0, 0.0]), + ] + + multipliers = [] + with self.test_session() as session: + session.run(standard_ops.global_variables_initializer()) + while len(multipliers) < len(expected_multipliers): + multipliers.append(session.run(optimizer.lagrange_multipliers)) + session.run(train_op) + + for expected, actual in zip(expected_multipliers, multipliers): + self.assertAllClose(expected, actual, rtol=0, atol=1e-6) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..04014ab4aebd6d9cd70653c53f9361320e803329 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py @@ -0,0 +1,595 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines `{Additive,Multiplicative}SwapRegretOptimizer`s. + +These optimizers minimize a `ConstrainedMinimizationProblem` by using a +swap-regret minimizing algorithm (either SGD or multiplicative weights) to learn +what weights should be associated with the objective function and constraints. +These algorithms do *not* use Lagrange multipliers, but the idea is similar. +The main differences between the formulation used here, and the standard +Lagrangian formulation, are that (i) the objective function is weighted, in +addition to the constraints, and (ii) we learn a matrix of weights, instead of a +vector. + +For the purposes of constrained optimization, at least in theory, +external-regret minimization suffices if the `ConstrainedMinimizationProblem` +we're optimizing doesn't have any `proxy_constraints`, while swap-regret +minimization should be used if `proxy_constraints` are present. + +For more specifics, please refer to: + +> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex +> Constrained Optimization". +> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + +The formulation used by both of the SwapRegretOptimizers can be found in +Definition 2, and is discussed in Section 4. The +`MultiplicativeSwapRegretOptimizer` is most similar to Algorithm 2 in Section 4, +with the difference being that it uses `tf.train.Optimizer`s, instead of SGD, +for the "inner" updates. The `AdditiveSwapRegretOptimizer` differs further in +that it performs additive (instead of multiplicative) updates of the stochastic +matrix. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import math + +import six + +from tensorflow.contrib.constrained_optimization.python import constrained_optimizer + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import optimizer as train_optimizer + + +def _maximal_eigenvector_power_method(matrix, + epsilon=1e-6, + maximum_iterations=100): + """Returns the maximal right-eigenvector of `matrix` using the power method. + + Args: + matrix: 2D Tensor, the matrix of which we will find the maximal + right-eigenvector. + epsilon: nonnegative float, if two iterations of the power method differ (in + L2 norm) by no more than epsilon, we will terminate. + maximum_iterations: nonnegative int, if we perform this many iterations, we + will terminate. + + Result: + The maximal right-eigenvector of `matrix`. + + Raises: + ValueError: If the epsilon or maximum_iterations parameters violate their + bounds. + """ + if epsilon <= 0.0: + raise ValueError("epsilon must be strictly positive") + if maximum_iterations <= 0: + raise ValueError("maximum_iterations must be strictly positive") + + def while_loop_condition(iteration, eigenvector, old_eigenvector): + """Returns false if the while loop should terminate.""" + not_done = (iteration < maximum_iterations) + not_converged = (standard_ops.norm(eigenvector - old_eigenvector) > epsilon) + return standard_ops.logical_and(not_done, not_converged) + + def while_loop_body(iteration, eigenvector, old_eigenvector): + """Performs one iteration of the power method.""" + del old_eigenvector # Needed by the condition, but not the body. + iteration += 1 + # We need to use tf.matmul() and tf.expand_dims(), instead of + # tf.tensordot(), since the former will infer the shape of the result, while + # the latter will not (tf.while_loop() needs the shapes). + new_eigenvector = standard_ops.matmul( + matrix, standard_ops.expand_dims(eigenvector, 1))[:, 0] + new_eigenvector /= standard_ops.norm(new_eigenvector) + return (iteration, new_eigenvector, eigenvector) + + iteration = standard_ops.constant(0) + eigenvector = standard_ops.ones_like(matrix[:, 0]) + eigenvector /= standard_ops.norm(eigenvector) + + # We actually want a do-while loop, so we explicitly call while_loop_body() + # once before tf.while_loop(). + iteration, eigenvector, old_eigenvector = while_loop_body( + iteration, eigenvector, eigenvector) + iteration, eigenvector, old_eigenvector = control_flow_ops.while_loop( + while_loop_condition, + while_loop_body, + loop_vars=(iteration, eigenvector, old_eigenvector), + name="power_method") + + return eigenvector + + +def _project_stochastic_matrix_wrt_euclidean_norm(matrix): + """Projects its argument onto the set of left-stochastic matrices. + + This algorithm is O(n^3) at worst, where `matrix` is n*n. It can be done in + O(n^2 * log(n)) time by sorting each column (and maybe better with a different + algorithm), but the algorithm implemented here is easier to implement in + TensorFlow. + + Args: + matrix: 2d square tensor, the matrix to project. + + Returns: + The 2d square tensor that results from projecting `matrix` onto the set of + left-stochastic matrices w.r.t. the Euclidean norm applied column-wise + (i.e. the Frobenius norm). + + Raises: + ValueError: if the `matrix` tensor does not have a fully-known shape, or is + not two-dimensional and square. + """ + matrix_shape = matrix.get_shape() + if matrix_shape is None: + raise ValueError("matrix must have known shape") + if matrix_shape.ndims != 2: + raise ValueError( + "matrix must be two dimensional (instead is %d-dimensional)" % + matrix_shape.ndims) + if matrix_shape[0] != matrix_shape[1]: + raise ValueError("matrix must be be square (instead has shape (%d,%d))" % + (matrix_shape[0], matrix_shape[1])) + dimension = matrix_shape[0].value + if dimension is None: + raise ValueError("matrix must have fully-known shape") + + def while_loop_condition(iteration, matrix, inactive, old_inactive): + """Returns false if the while loop should terminate.""" + del matrix # Needed by the body, but not the condition. + not_done = (iteration < dimension) + not_converged = standard_ops.reduce_any( + standard_ops.not_equal(inactive, old_inactive)) + return standard_ops.logical_and(not_done, not_converged) + + def while_loop_body(iteration, matrix, inactive, old_inactive): + """Performs one iteration of the projection.""" + del old_inactive # Needed by the condition, but not the body. + iteration += 1 + scale = (1.0 - standard_ops.reduce_sum( + matrix, axis=0, keep_dims=True)) / standard_ops.maximum( + 1.0, standard_ops.reduce_sum(inactive, axis=0, keep_dims=True)) + matrix += scale * inactive + new_inactive = standard_ops.to_float(matrix > 0) + matrix *= new_inactive + return (iteration, matrix, new_inactive, inactive) + + iteration = standard_ops.constant(0) + inactive = standard_ops.ones_like(matrix) + + # We actually want a do-while loop, so we explicitly call while_loop_body() + # once before tf.while_loop(). + iteration, matrix, inactive, old_inactive = while_loop_body( + iteration, matrix, inactive, inactive) + iteration, matrix, inactive, old_inactive = control_flow_ops.while_loop( + while_loop_condition, + while_loop_body, + loop_vars=(iteration, matrix, inactive, old_inactive), + name="euclidean_projection") + + return matrix + + +def _project_log_stochastic_matrix_wrt_kl_divergence(log_matrix): + """Projects its argument onto the set of log-left-stochastic matrices. + + Args: + log_matrix: 2d square tensor, the element-wise logarithm of the matrix to + project. + + Returns: + The 2d square tensor that results from projecting exp(`matrix`) onto the set + of left-stochastic matrices w.r.t. the KL-divergence applied column-wise. + """ + + # For numerical reasons, make sure that the largest matrix element is zero + # before exponentiating. + log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keep_dims=True) + log_matrix -= standard_ops.log( + standard_ops.reduce_sum( + standard_ops.exp(log_matrix), axis=0, keep_dims=True)) + return log_matrix + + +@six.add_metaclass(abc.ABCMeta) +class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): + """Base class representing a `_SwapRegretOptimizer`. + + This class contains most of the logic for performing constrained optimization, + minimizing external regret for the constraints player. What it *doesn't* do is + keep track of the internal state (the stochastic matrix). Instead, the state + is accessed via the _initial_state(), _stochastic_matrix(), + _constraint_grad_and_var() and _projection_op() methods. + + The reason for this is that we want to make it easy to implement different + representations of the internal state. For example, for additive updates, it's + most natural to store the stochastic matrix directly, whereas for + multiplicative updates, it's most natural to store its element-wise logarithm. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + The formulation used by `_SwapRegretOptimizer`s can be found in Definition 2, + and is discussed in Section 4. Such optimizers are most similar to Algorithm + 2 in Section 4. Most notably, the internal state is a left-stochastic matrix + of shape (m+1,m+1), where m is the number of constraints. + """ + + def __init__(self, optimizer, constraint_optimizer=None): + """Constructs a new `_SwapRegretOptimizer`. + + The difference between `optimizer` and `constraint_optimizer` (if the latter + is provided) is that the former is used for learning the model parameters, + while the latter us used for the update to the constraint/objective weight + matrix (the analogue of Lagrange multipliers). If no `constraint_optimizer` + is provided, then `optimizer` is used for both. + + Args: + optimizer: tf.train.Optimizer, used to optimize the objective and + proxy_constraints portion of ConstrainedMinimizationProblem. If + constraint_optimizer is not provided, this will also be used to optimize + the Lagrange multiplier analogues. + constraint_optimizer: optional tf.train.Optimizer, used to optimize the + Lagrange multiplier analogues. + + Returns: + A new `_SwapRegretOptimizer`. + """ + super(_SwapRegretOptimizer, self).__init__(optimizer=optimizer) + self._constraint_optimizer = constraint_optimizer + + @property + def constraint_optimizer(self): + """Returns the `tf.train.Optimizer` used for the matrix.""" + return self._constraint_optimizer + + @abc.abstractmethod + def _initial_state(self, num_constraints): + pass + + @abc.abstractmethod + def _stochastic_matrix(self, state): + pass + + def _distribution(self, state): + distribution = _maximal_eigenvector_power_method( + self._stochastic_matrix(state)) + distribution = standard_ops.abs(distribution) + distribution /= standard_ops.reduce_sum(distribution) + return distribution + + @abc.abstractmethod + def _constraint_grad_and_var(self, state, gradient): + pass + + @abc.abstractmethod + def _projection_op(self, state, name=None): + pass + + def minimize_constrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Op` for minimizing the constrained problem. + + The `optimizer` constructor parameter will be used to update the model + parameters, while the constraint/objective weight matrix (the analogue of + Lagrange multipliers) will be updated using `constrained_optimizer` (if + provided) or `optimizer` (if not). Whether the matrix updates are additive + or multiplicative depends on the derived class. + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. + name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + TensorFlow Op. + """ + objective = minimization_problem.objective + + constraints = minimization_problem.constraints + proxy_constraints = minimization_problem.proxy_constraints + if proxy_constraints is None: + proxy_constraints = constraints + # Flatten both constraints tensors to 1d. + num_constraints = minimization_problem.num_constraints + constraints = standard_ops.reshape(constraints, shape=(num_constraints,)) + proxy_constraints = standard_ops.reshape( + proxy_constraints, shape=(num_constraints,)) + + # We use a lambda to initialize the state so that, if this function call is + # inside the scope of a tf.control_dependencies() block, the dependencies + # will not be applied to the initializer. + state = standard_ops.Variable( + lambda: self._initial_state(num_constraints), + trainable=False, + name="swap_regret_optimizer_state") + + zero_and_constraints = standard_ops.concat( + (standard_ops.zeros((1,)), constraints), axis=0) + objective_and_proxy_constraints = standard_ops.concat( + (standard_ops.expand_dims(objective, 0), proxy_constraints), axis=0) + + distribution = self._distribution(state) + loss = standard_ops.tensordot(distribution, objective_and_proxy_constraints, + 1) + matrix_gradient = standard_ops.matmul( + standard_ops.expand_dims(zero_and_constraints, 1), + standard_ops.expand_dims(distribution, 0)) + + update_ops = [] + if self.constraint_optimizer is None: + # If we don't have a separate constraint_optimizer, then we use + # self._optimizer for both the update of the model parameters, and that of + # the internal state. + grads_and_vars = self.optimizer.compute_gradients( + loss, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + grads_and_vars.append( + self._constraint_grad_and_var(state, matrix_gradient)) + update_ops.append( + self.optimizer.apply_gradients(grads_and_vars, name="update")) + else: + # If we have a separate constraint_optimizer, then we use self._optimizer + # for the update of the model parameters, and self._constraint_optimizer + # for that of the internal state. + grads_and_vars = self.optimizer.compute_gradients( + loss, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + matrix_grads_and_vars = [ + self._constraint_grad_and_var(state, matrix_gradient) + ] + + gradients = [ + gradient for gradient, _ in grads_and_vars + matrix_grads_and_vars + if gradient is not None + ] + with ops.control_dependencies(gradients): + update_ops.append( + self.optimizer.apply_gradients(grads_and_vars, name="update")) + update_ops.append( + self.constraint_optimizer.apply_gradients( + matrix_grads_and_vars, name="optimizer_state_update")) + + with ops.control_dependencies(update_ops): + if global_step is None: + # If we don't have a global step, just project, and we're done. + return self._projection_op(state, name=name) + else: + # If we have a global step, then we need to increment it in addition to + # projecting. + projection_op = self._projection_op(state, name="project") + with ops.colocate_with(global_step): + global_step_op = state_ops.assign_add( + global_step, 1, name="global_step_increment") + return control_flow_ops.group(projection_op, global_step_op, name=name) + + +class AdditiveSwapRegretOptimizer(_SwapRegretOptimizer): + """A `ConstrainedOptimizer` based on swap-regret minimization. + + This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly + minimize over the model parameters, and maximize over constraint/objective + weight matrix (the analogue of Lagrange multipliers), with the latter + maximization using additive updates and an algorithm that minimizes swap + regret. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + The formulation used by this optimizer can be found in Definition 2, and is + discussed in Section 4. It is most similar to Algorithm 2 in Section 4, with + the differences being that it uses `tf.train.Optimizer`s, instead of SGD, for + the "inner" updates, and performs additive (instead of multiplicative) updates + of the stochastic matrix. + """ + + def __init__(self, optimizer, constraint_optimizer=None): + """Constructs a new `AdditiveSwapRegretOptimizer`. + + Args: + optimizer: tf.train.Optimizer, used to optimize the objective and + proxy_constraints portion of ConstrainedMinimizationProblem. If + constraint_optimizer is not provided, this will also be used to optimize + the Lagrange multiplier analogues. + constraint_optimizer: optional tf.train.Optimizer, used to optimize the + Lagrange multiplier analogues. + + Returns: + A new `AdditiveSwapRegretOptimizer`. + """ + # TODO(acotter): add a parameter determining the initial values of the + # matrix elements (like initial_multiplier_radius in + # MultiplicativeSwapRegretOptimizer). + super(AdditiveSwapRegretOptimizer, self).__init__( + optimizer=optimizer, constraint_optimizer=constraint_optimizer) + + def _initial_state(self, num_constraints): + # For an AdditiveSwapRegretOptimizer, the internal state is a tensor of + # shape (m+1,m+1), where m is the number of constraints, representing a + # left-stochastic matrix. + dimension = num_constraints + 1 + # Initialize by putting all weight on the objective, and none on the + # constraints. + return standard_ops.concat( + (standard_ops.ones( + (1, dimension)), standard_ops.zeros((dimension - 1, dimension))), + axis=0) + + def _stochastic_matrix(self, state): + return state + + def _constraint_grad_and_var(self, state, gradient): + # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True? + return (-gradient, state) + + def _projection_op(self, state, name=None): + with ops.colocate_with(state): + return state_ops.assign( + state, + _project_stochastic_matrix_wrt_euclidean_norm(state), + name=name) + + +class MultiplicativeSwapRegretOptimizer(_SwapRegretOptimizer): + """A `ConstrainedOptimizer` based on swap-regret minimization. + + This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly + minimize over the model parameters, and maximize over constraint/objective + weight matrix (the analogue of Lagrange multipliers), with the latter + maximization using multiplicative updates and an algorithm that minimizes swap + regret. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + The formulation used by this optimizer can be found in Definition 2, and is + discussed in Section 4. It is most similar to Algorithm 2 in Section 4, with + the difference being that it uses `tf.train.Optimizer`s, instead of SGD, for + the "inner" updates. + """ + + def __init__(self, + optimizer, + constraint_optimizer=None, + minimum_multiplier_radius=1e-3, + initial_multiplier_radius=None): + """Constructs a new `MultiplicativeSwapRegretOptimizer`. + + Args: + optimizer: tf.train.Optimizer, used to optimize the objective and + proxy_constraints portion of ConstrainedMinimizationProblem. If + constraint_optimizer is not provided, this will also be used to optimize + the Lagrange multiplier analogues. + constraint_optimizer: optional tf.train.Optimizer, used to optimize the + Lagrange multiplier analogues. + minimum_multiplier_radius: float, each element of the matrix will be lower + bounded by `minimum_multiplier_radius` divided by one plus the number of + constraints. + initial_multiplier_radius: float, the initial value of each element of the + matrix associated with a constraint (i.e. excluding those elements + associated with the objective) will be `initial_multiplier_radius` + divided by one plus the number of constraints. Defaults to the value of + `minimum_multiplier_radius`. + + Returns: + A new `MultiplicativeSwapRegretOptimizer`. + + Raises: + ValueError: If the two radius parameters are inconsistent. + """ + super(MultiplicativeSwapRegretOptimizer, self).__init__( + optimizer=optimizer, constraint_optimizer=constraint_optimizer) + + if (minimum_multiplier_radius <= 0.0) or (minimum_multiplier_radius >= 1.0): + raise ValueError("minimum_multiplier_radius must be in the range (0,1)") + if initial_multiplier_radius is None: + initial_multiplier_radius = minimum_multiplier_radius + elif (initial_multiplier_radius < + minimum_multiplier_radius) or (minimum_multiplier_radius > 1.0): + raise ValueError("initial_multiplier_radius must be in the range " + "[minimum_multiplier_radius,1]") + + self._minimum_multiplier_radius = minimum_multiplier_radius + self._initial_multiplier_radius = initial_multiplier_radius + + def _initial_state(self, num_constraints): + # For a MultiplicativeSwapRegretOptimizer, the internal state is a tensor of + # shape (m+1,m+1), where m is the number of constraints, representing the + # element-wise logarithm of a left-stochastic matrix. + dimension = num_constraints + 1 + # Initialize by putting as much weight as possible on the objective, and as + # little as possible on the constraints. + log_initial_one = math.log(1.0 - (self._initial_multiplier_radius * + (dimension - 1) / (dimension))) + log_initial_zero = math.log(self._initial_multiplier_radius / dimension) + return standard_ops.concat( + (standard_ops.constant( + log_initial_one, dtype=dtypes.float32, shape=(1, dimension)), + standard_ops.constant( + log_initial_zero, + dtype=dtypes.float32, + shape=(dimension - 1, dimension))), + axis=0) + + def _stochastic_matrix(self, state): + return standard_ops.exp(state) + + def _constraint_grad_and_var(self, state, gradient): + # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True? + return (-gradient, state) + + def _projection_op(self, state, name=None): + with ops.colocate_with(state): + # Gets the dimension of the state (num_constraints + 1)--all of these + # assertions are of things that should be impossible, since the state + # passed into this method will have the same shape as that returned by + # _initial_state(). + state_shape = state.get_shape() + assert state_shape is not None + assert state_shape.ndims == 2 + assert state_shape[0] == state_shape[1] + dimension = state_shape[0].value + assert dimension is not None + + minimum_log_multiplier = standard_ops.log( + self._minimum_multiplier_radius / standard_ops.to_float(dimension)) + + return state_ops.assign( + state, + standard_ops.maximum( + _project_log_stochastic_matrix_wrt_kl_divergence(state), + minimum_log_multiplier), + name=name) diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..34c4543dca97e12c8335e4c90b849820edaefa81 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py @@ -0,0 +1,212 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for constrained_optimization.python.swap_regret_optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.constrained_optimization.python import swap_regret_optimizer +from tensorflow.contrib.constrained_optimization.python import test_util + +from tensorflow.python.ops import standard_ops +from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent + + +class AdditiveSwapRegretOptimizerWrapper( + swap_regret_optimizer.AdditiveSwapRegretOptimizer): + """Testing wrapper class around AdditiveSwapRegretOptimizer. + + This class is identical to AdditiveSwapRegretOptimizer, except that it caches + the internal optimization state when _stochastic_matrix() is called, so that + we can test that the stochastic matrices take on their expected values. + """ + + def __init__(self, optimizer, constraint_optimizer=None): + """Same as AdditiveSwapRegretOptimizer.__init__().""" + super(AdditiveSwapRegretOptimizerWrapper, self).__init__( + optimizer=optimizer, constraint_optimizer=constraint_optimizer) + self._cached_stochastic_matrix = None + + @property + def stochastic_matrix(self): + """Returns the cached stochastic matrix.""" + return self._cached_stochastic_matrix + + def _stochastic_matrix(self, state): + """Caches the internal state for testing.""" + self._cached_stochastic_matrix = super(AdditiveSwapRegretOptimizerWrapper, + self)._stochastic_matrix(state) + return self._cached_stochastic_matrix + + +class MultiplicativeSwapRegretOptimizerWrapper( + swap_regret_optimizer.MultiplicativeSwapRegretOptimizer): + """Testing wrapper class around MultiplicativeSwapRegretOptimizer. + + This class is identical to MultiplicativeSwapRegretOptimizer, except that it + caches the internal optimization state when _stochastic_matrix() is called, so + that we can test that the stochastic matrices take on their expected values. + """ + + def __init__(self, + optimizer, + constraint_optimizer=None, + minimum_multiplier_radius=None, + initial_multiplier_radius=None): + """Same as MultiplicativeSwapRegretOptimizer.__init__().""" + super(MultiplicativeSwapRegretOptimizerWrapper, self).__init__( + optimizer=optimizer, + constraint_optimizer=constraint_optimizer, + minimum_multiplier_radius=1e-3, + initial_multiplier_radius=initial_multiplier_radius) + self._cached_stochastic_matrix = None + + @property + def stochastic_matrix(self): + """Returns the cached stochastic matrix.""" + return self._cached_stochastic_matrix + + def _stochastic_matrix(self, state): + """Caches the internal state for testing.""" + self._cached_stochastic_matrix = super( + MultiplicativeSwapRegretOptimizerWrapper, + self)._stochastic_matrix(state) + return self._cached_stochastic_matrix + + +class SwapRegretOptimizerTest(test.TestCase): + + def test_maximum_eigenvector_power_method(self): + """Tests power method routine on some known left-stochastic matrices.""" + matrix1 = np.matrix([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], [0.4, 0.3, 0.0]]) + matrix2 = np.matrix([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], [0.4, 0.5, 0.3]]) + + with self.test_session() as session: + eigenvector1 = session.run( + swap_regret_optimizer._maximal_eigenvector_power_method( + standard_ops.constant(matrix1))) + eigenvector2 = session.run( + swap_regret_optimizer._maximal_eigenvector_power_method( + standard_ops.constant(matrix2))) + + # Check that eigenvector1 and eigenvector2 are eigenvectors of matrix1 and + # matrix2 (respectively) with associated eigenvalue 1. + matrix_eigenvector1 = np.tensordot(matrix1, eigenvector1, axes=1) + matrix_eigenvector2 = np.tensordot(matrix2, eigenvector2, axes=1) + self.assertAllClose(eigenvector1, matrix_eigenvector1, rtol=0, atol=1e-6) + self.assertAllClose(eigenvector2, matrix_eigenvector2, rtol=0, atol=1e-6) + + def test_project_stochastic_matrix_wrt_euclidean_norm(self): + """Tests Euclidean projection routine on some known values.""" + matrix = standard_ops.constant([[-0.1, -0.1, 0.4], [-0.8, 0.4, 1.2], + [-0.3, 0.1, 0.2]]) + expected_projected_matrix = np.array([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], + [0.4, 0.3, 0.0]]) + + with self.test_session() as session: + projected_matrix = session.run( + swap_regret_optimizer._project_stochastic_matrix_wrt_euclidean_norm( + matrix)) + + self.assertAllClose( + expected_projected_matrix, projected_matrix, rtol=0, atol=1e-6) + + def test_project_log_stochastic_matrix_wrt_kl_divergence(self): + """Tests KL-divergence projection routine on some known values.""" + matrix = standard_ops.constant([[0.2, 0.8, 0.6], [0.1, 0.2, 1.5], + [0.2, 1.0, 0.9]]) + expected_projected_matrix = np.array([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], + [0.4, 0.5, 0.3]]) + + with self.test_session() as session: + projected_matrix = session.run( + standard_ops.exp( + swap_regret_optimizer. + _project_log_stochastic_matrix_wrt_kl_divergence( + standard_ops.log(matrix)))) + + self.assertAllClose( + expected_projected_matrix, projected_matrix, rtol=0, atol=1e-6) + + def test_additive_swap_regret_optimizer(self): + """Tests that the stochastic matrices update as expected.""" + minimization_problem = test_util.ConstantMinimizationProblem( + np.array([0.6, -0.1, 0.4])) + optimizer = AdditiveSwapRegretOptimizerWrapper( + gradient_descent.GradientDescentOptimizer(1.0)) + train_op = optimizer.minimize_constrained(minimization_problem) + + # Calculated using a numpy+python implementation of the algorithm. + expected_matrices = [ + np.array([[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]), + np.array([[0.66666667, 1.0, 1.0, 1.0], [0.26666667, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], [0.06666667, 0.0, 0.0, 0.0]]), + np.array([[0.41666667, 0.93333333, 1.0, + 0.98333333], [0.46666667, 0.05333333, 0.0, + 0.01333333], [0.0, 0.0, 0.0, 0.0], + [0.11666667, 0.01333333, 0.0, 0.00333333]]), + ] + + matrices = [] + with self.test_session() as session: + session.run(standard_ops.global_variables_initializer()) + while len(matrices) < len(expected_matrices): + matrices.append(session.run(optimizer.stochastic_matrix)) + session.run(train_op) + + for expected, actual in zip(expected_matrices, matrices): + self.assertAllClose(expected, actual, rtol=0, atol=1e-6) + + def test_multiplicative_swap_regret_optimizer(self): + """Tests that the stochastic matrices update as expected.""" + minimization_problem = test_util.ConstantMinimizationProblem( + np.array([0.6, -0.1, 0.4])) + optimizer = MultiplicativeSwapRegretOptimizerWrapper( + gradient_descent.GradientDescentOptimizer(1.0), + initial_multiplier_radius=0.8) + train_op = optimizer.minimize_constrained(minimization_problem) + + # Calculated using a numpy+python implementation of the algorithm. + expected_matrices = [ + np.array([[0.4, 0.4, 0.4, 0.4], [0.2, 0.2, 0.2, 0.2], + [0.2, 0.2, 0.2, 0.2], [0.2, 0.2, 0.2, 0.2]]), + np.array([[0.36999014, 0.38528351, 0.38528351, 0.38528351], [ + 0.23517483, 0.21720297, 0.21720297, 0.21720297 + ], [0.17774131, 0.18882719, 0.18882719, 0.18882719], + [0.21709373, 0.20868632, 0.20868632, 0.20868632]]), + np.array([[0.33972109, 0.36811863, 0.37118462, 0.36906575], [ + 0.27114826, 0.23738228, 0.23376693, 0.23626491 + ], [0.15712313, 0.17641793, 0.17858959, 0.17708679], + [0.23200752, 0.21808115, 0.21645886, 0.21758255]]), + ] + + matrices = [] + with self.test_session() as session: + session.run(standard_ops.global_variables_initializer()) + while len(matrices) < len(expected_matrices): + matrices.append(session.run(optimizer.stochastic_matrix)) + session.run(train_op) + + for expected, actual in zip(expected_matrices, matrices): + self.assertAllClose(expected, actual, rtol=0, atol=1e-6) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/constrained_optimization/python/test_util.py b/tensorflow/contrib/constrained_optimization/python/test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..704b36ca4c9cf94e7c304f9bed4f6ac7ca275deb --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/test_util.py @@ -0,0 +1,58 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contains helpers used by tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.constrained_optimization.python import constrained_minimization_problem + +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import standard_ops + + +class ConstantMinimizationProblem( + constrained_minimization_problem.ConstrainedMinimizationProblem): + """A `ConstrainedMinimizationProblem` with constant constraint violations. + + This minimization problem is intended for use in performing simple tests of + the Lagrange multiplier (or equivalent) update in the optimizers. There is a + one-element "dummy" model parameter, but it should be ignored. + """ + + def __init__(self, constraints): + """Constructs a new `ConstantMinimizationProblem'. + + Args: + constraints: 1d numpy array, the constant constraint violations. + + Returns: + A new `ConstantMinimizationProblem'. + """ + # We make an fake 1-parameter linear objective so that we don't get a "no + # variables to optimize" error. + self._objective = standard_ops.Variable(0.0, dtype=dtypes.float32) + self._constraints = standard_ops.constant(constraints, dtype=dtypes.float32) + + @property + def objective(self): + """Returns the objective function.""" + return self._objective + + @property + def constraints(self): + """Returns the constant constraint violations.""" + return self._constraints diff --git a/tensorflow/contrib/control_flow/BUILD b/tensorflow/contrib/control_flow/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..e8036d63aeeac224b226899c036416a06b4ffe65 --- /dev/null +++ b/tensorflow/contrib/control_flow/BUILD @@ -0,0 +1,53 @@ +# New implementations of control flow ops + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +py_library( + name = "control_flow", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":cond_v2", + ], +) + +py_library( + name = "cond_v2", + srcs = ["python/cond_v2.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:c_api_util", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python:function_def_to_graph", + "//tensorflow/python:functional_ops_gen", + "//tensorflow/python:gradients", + "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:util", + ], +) + +tf_py_test( + name = "cond_v2_test", + size = "small", + srcs = ["python/cond_v2_test.py"], + additional_deps = [ + ":cond_v2", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:training", + ], + grpc_enabled = True, +) diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py b/tensorflow/contrib/control_flow/__init__.py similarity index 70% rename from tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py rename to tensorflow/contrib/control_flow/__init__.py index f3a645eafc249d1c39e0d4a238ae7ec8755c78d8..582af2cf10a3d92dd8611b0f2826625e3acfb099 100644 --- a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py +++ b/tensorflow/contrib/control_flow/__init__.py @@ -12,21 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities for Markov Chain Monte Carlo (MCMC) sampling.""" + +"""New implementations of TF control flow ops. + +@@cond_v2 +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.mcmc_diagnostics_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +# pylint: disable=unused-import +from tensorflow.contrib.control_flow.python.cond_v2 import cond_v2 +# pylint: enable=unused-import -_allowed_symbols = [ - "effective_sample_size", - "potential_scale_reduction", -] +from tensorflow.python.util.all_util import remove_undocumented -remove_undocumented(__name__, _allowed_symbols) +remove_undocumented(__name__) diff --git a/tensorflow/contrib/control_flow/python/cond_v2.py b/tensorflow/contrib/control_flow/python/cond_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..9ffad9caa92d2d3be8f598758a443b0eceb8d4d8 --- /dev/null +++ b/tensorflow/contrib/control_flow/python/cond_v2.py @@ -0,0 +1,429 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""cond_v2 and gradient. + +This is a version of cond that emits a single If op, as well as the gradient +function for If ops produced by cond_v2. This will eventually replace the +current tf.cond implementation once it reaches feature and performance parity. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import pywrap_tensorflow as c_api +from tensorflow.python.framework import c_api_util +from tensorflow.python.framework import function +from tensorflow.python.framework import function_def_to_graph +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_functional_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.util import compat + + +# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify +# that they aren't part of the official public API. These protected members +# often need to be used by implementation code however. Rather than litter the +# code with pylint comments, we ignore protected access violations for +# readability. +# pylint: disable=protected-access + + +def cond_v2(pred, true_fn, false_fn, name="cond"): + """Like tf.cond, except emits a single If op.""" + with ops.name_scope(name) as scope: + true_graph = function.func_graph_from_py_func(true_fn, [], [], + name="%s_true" % scope) + false_graph = function.func_graph_from_py_func(false_fn, [], [], + name="%s_false" % scope) + _check_same_outputs(true_graph, false_graph) + + # Add inputs to true_graph and false_graph to make them match. Note that + # this modifies true_graph and false_graph. + cond_inputs = _make_inputs_match(true_graph, false_graph, + true_graph.extra_inputs, + false_graph.extra_inputs) + + # Add all intermediate tensors as function outputs so they're available for + # the gradient computation. + + true_intermediates = _get_intermediates(true_graph) + false_intermediates = _get_intermediates(false_graph) + + # Save the original number of outputs to return to the caller. + num_cond_outputs = len(true_graph.outputs) + + # Make the number/type of new intermediate outputs match. + extra_true_outputs, extra_false_outputs = _pad_params( + true_graph, false_graph, true_intermediates, false_intermediates) + + true_graph.outputs.extend(extra_true_outputs) + false_graph.outputs.extend(extra_false_outputs) + + # Create the If op. + tensors = gen_functional_ops._if( + pred, cond_inputs, [t.dtype for t in true_graph.outputs], + _create_new_tf_function(true_graph), + _create_new_tf_function(false_graph), + name=scope) + + return tensors[:num_cond_outputs] + + +@ops.RegisterGradient("If") +def _IfGrad(op, *grads): # pylint: disable=invalid-name + """The gradient of an If op produced by cond_v2.""" + true_graph, false_graph = _get_func_graphs(op) + + # Create grad functions that compute the gradient of the true/false forward + # graphs. These functions will capture tensors from the forward pass + # functions. + true_grad_graph = _create_grad_func( + true_graph, grads, _get_grad_fn_name(true_graph)) + false_grad_graph = _create_grad_func( + false_graph, grads, _get_grad_fn_name(false_graph)) + + assert ([t.dtype for t in true_grad_graph.outputs] == + [t.dtype for t in false_grad_graph.outputs]) + + # Match up the captured grad function inputs with outputs of 'op' and other + # external tensors. + true_grad_inputs = _get_grad_inputs(op, true_graph, true_grad_graph) + false_grad_inputs = _get_grad_inputs(op, false_graph, false_grad_graph) + + # Make the inputs to true_grad_graph and false_grad_graph match. Note that + # this modifies true_grad_graph and false_grad_graph. + grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph, + true_grad_inputs, false_grad_inputs) + + # Add all intermediate tensors as function outputs so they're available for + # higher-order gradient computations. + + true_grad_intermediates = _get_intermediates(true_grad_graph) + false_grad_intermediates = _get_intermediates(false_grad_graph) + + # Save the original number of gradient outputs to return. + num_grad_outputs = len(true_grad_graph.outputs) + + # Make the number/type of new intermediate outputs match. + extra_true_grad_outputs, extra_false_grad_outputs = _pad_params( + true_grad_graph, false_grad_graph, + true_grad_intermediates, false_grad_intermediates) + + true_grad_graph.outputs.extend(extra_true_grad_outputs) + false_grad_graph.outputs.extend(extra_false_grad_outputs) + + # Create the gradient If op. + tensors = gen_functional_ops._if( + op.inputs[0], grad_inputs, [t.dtype for t in true_grad_graph.outputs], + _create_new_tf_function(true_grad_graph), + _create_new_tf_function(false_grad_graph)) + + # The predicate has no gradient. + return [None] + tensors[:num_grad_outputs] + + +def _get_func_graphs(if_op): + """Returns `_FuncGraph`s for the input op branches. + + Args: + if_op: The _If Operation. + + Returns: + A 2-tuple of the `_FuncGraph`s of the then_branch and else_branch. + """ + def _get_func_graph_for_branch(branch_name): + extra_inputs = if_op.inputs[1:] # First input is pred. + input_shapes = [t.shape for t in extra_inputs] + func_name = if_op.get_attr(branch_name).name + fdef = if_op.graph._get_function(func_name).definition + func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes) + func_graph.extra_inputs = extra_inputs + func_graph.extra_args = func_graph.inputs + func_graph._captured = dict(zip(extra_inputs, func_graph.inputs)) + return func_graph + + return (_get_func_graph_for_branch("then_branch"), + _get_func_graph_for_branch("else_branch")) + + +def _grad_fn(func_graph, grads): + """The gradient function for each conditional branch. + + This function builds the gradient graph of the corresponding forward-pass + conditional branch in `func_graph`. This is done by differentiating + func_graph's outputs w.r.t. its inputs. + + Args: + func_graph: function._FuncGraph. The corresponding forward-pass function. + grads: The list of input gradient Tensors. + + Returns: + The output gradient Tensors. + """ + # Filter out untrainable function outputs. + # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes + # cause _GradientsHelper to raise an exception (e.g. the implementation + # doesn't expect 'ys' to contain boolean tensors). + assert len(func_graph.outputs) == len(grads) + ys = [] + grad_ys = [] + for y, grad_y in zip(func_graph.outputs, grads): + if not gradients_impl._IsTrainable(y): + continue + ys.append(y) + grad_ys.append(grad_y) + + # Build the gradient graph. Note that this builds the gradient computation of + # func_graph in the current graph, which requires capturing tensors from + # func_graph. The captured func_graph tensors are resolved to external tensors + # in _get_grad_inputs. + result = gradients_impl._GradientsHelper( + ys, func_graph.inputs, grad_ys=grad_ys, + src_graph=func_graph) + + # Functions can't return None; replace Nones with zero tensors. + # TODO(b/80444525): don't return anything here and make _IfGrad return None if + # both branches have zero gradient. + for i in range(len(result)): + if result[i] is None: + result[i] = array_ops.zeros_like(func_graph.inputs[i]) + + return result + + +def _create_grad_func(func_graph, grads, name): + """Returns the _FuncGraph representation of _grad_fn.""" + return function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads), + [], [], name) + + +def _get_grad_inputs(if_op, cond_graph, grad_graph): + """Returns the tensors we should pass to grad_graph. + + This method handles tensors captured from cond_graph in grad_graph. It + converts these to suitable input tensors from the outer graph. + + Args: + if_op: Operation. The forward-pass If op that uses cond_graph. + cond_graph: function._FuncGraph. The forward-pass function. + grad_graph: function._FuncGraph. The gradients function. + + Returns: + A list of inputs tensors to be passed to grad_graph. + """ + inputs = [] + + # Maps placeholders in cond_graph -> input tensor in outer graph. + forward_input_map = {v: k for k, v in cond_graph._captured.items()} + + for t in grad_graph.extra_inputs: + if t.graph == ops.get_default_graph(): + # t is in the outer graph (e.g. one of the input gradients). + inputs.append(t) + elif t in forward_input_map: + # t is an input placeholder in cond_graph. Get the corresponding input + # tensor in the outer graph. + assert t.graph == cond_graph + assert forward_input_map[t].graph == ops.get_default_graph() + inputs.append(forward_input_map[t]) + else: + # t is an intermediate value in cond_graph. Get the corresponding output + # of 'if_op' (note that all intermediate values are outputs). + assert t.graph == cond_graph + output_idx = cond_graph.outputs.index(t) + inputs.append(if_op.outputs[output_idx]) + + return inputs + + +def _create_new_tf_function(func_graph): + """Converts func_graph to a TF_Function and adds it to the current graph. + + Args: + func_graph: function._FuncGraph + + Returns: + The name of the new TF_Function. + """ + c_func = c_api.TF_GraphToFunction_wrapper( + func_graph._c_graph, + compat.as_str(func_graph.name), + False, # append_hash_to_fn_name + None, # opers + [t._as_tf_output() for t in func_graph.inputs], + [t._as_tf_output() for t in func_graph.outputs], + [], + None, # opts + None) # description + _ = c_api_util.ScopedTFFunction(c_func) + + # TODO(b/109833212): this sucks, we're serializing the TF_Function*, + # deserializing it into a Python FunctionDef, then reserializing it to create + # a new TF_Function that we add to the graph. + fdef = function.function_def_from_tf_function(c_func) + defined_func = function._from_definition(fdef) + defined_func.add_to_graph(ops.get_default_graph()) + + return func_graph.name + + +def _get_intermediates(func_graph): + """Returns all tensors in `func_graph` that aren't inputs or outputs.""" + intermediates = [] + for op in func_graph.get_operations(): + for t in op.outputs: + if t in func_graph.inputs: continue + if t in func_graph.outputs: continue + intermediates.append(t) + return intermediates + + +def _separate_unique_inputs(true_inputs, false_inputs): + """Separates tensors appearing only in true_inputs or false_inputs, or both. + + Args: + true_inputs: list of Tensors + false_inputs: list of Tensors + + Returns: + Three lists of Tensors: + 1. The tensors that appear in both true_inputs and false_inputs + 2. The tensors that only appear in true_inputs + 3. The tensors that only appear in false_inputs + """ + true_inputs = set(true_inputs) + false_inputs = set(false_inputs) + + shared_inputs = true_inputs.intersection(false_inputs) + true_only_inputs = true_inputs - false_inputs + false_only_inputs = false_inputs - true_inputs + + return list(shared_inputs), list(true_only_inputs), list(false_only_inputs) + + +def _pad_params(true_graph, false_graph, true_params, false_params): + """Returns new param lists that have matching signatures. + + This is done by mirroring each param list in the other using dummy params. + There is no merging of params. + + Args: + true_graph: function._FuncGraph + false_graph: function._FuncGraph + true_params: a list of Tensors from true_graph + false_params: a list of Tensors from false_graph + + Returns: + A new list of Tensors in true_graph and a new list of Tensors in + false_graph. The two lists have the same number of Tensors, with matching + types and shapes across the lists. + """ + new_true_params = (true_params + + _create_dummy_params(true_graph, false_params)) + new_false_inputs = (_create_dummy_params(false_graph, true_params) + + false_params) + return new_true_params, new_false_inputs + + +def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs): + """Modifies true_graph and false_graph so they have the same input signature. + + This method reorders and/or adds parameters to true_graph and false_graph so + they have the same input signature, and updates the 'inputs', 'extra_inputs', + and '_captured' fields of both graphs accordingly. It uses the input tensors + from the outer graph to avoid duplicating shared arguments. + + Args: + true_graph: function._FuncGraph + false_graph: function._FuncGraph + true_inputs: a list of Tensors in the outer graph. The inputs for + true_graph. + false_inputs: a list of Tensors in the outer graph. The inputs for + false_graph. + + Returns: + A new list of Tensors from the outer graph that are the new inputs for both + true_graph and false_graph. This is a deduped version of true_inputs + + false_inputs. + """ + shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs( + true_inputs, false_inputs) + + new_inputs = shared_inputs + true_only_inputs + false_only_inputs + + true_input_to_param = dict(zip(true_inputs, true_graph.inputs)) + false_input_to_param = dict(zip(false_inputs, false_graph.inputs)) + + true_graph.inputs = ( + [true_input_to_param[t] for t in shared_inputs] + + [true_input_to_param[t] for t in true_only_inputs] + + _create_dummy_params(true_graph, false_only_inputs)) + + false_graph.inputs = ( + [false_input_to_param[t] for t in shared_inputs] + + _create_dummy_params(false_graph, true_only_inputs) + + [false_input_to_param[t] for t in false_only_inputs]) + + # Rewrite the _FuncGraphs' state to reflect the new inputs. + true_graph.extra_inputs = new_inputs + false_graph.extra_inputs = new_inputs + + true_graph._captured = dict(zip(new_inputs, true_graph.inputs)) + false_graph._captured = dict(zip(new_inputs, false_graph.inputs)) + + return new_inputs + + +def _create_dummy_params(func_graph, template_tensors): + """Creates tensors in func_graph to represent template_tensors. + + Args: + func_graph: function._FuncGraph. + template_tensors: a list of tensors in the outer graph. + + Returns: + A list of tensors in func_graph. + """ + with func_graph.as_default(): + return [gen_functional_ops.fake_param(dtype=t.dtype, shape=t.shape) + for t in template_tensors] + + +def _get_grad_fn_name(func_graph): + """Returns a unique name to use for the grad function of `func_graph`.""" + name = "%s_grad" % func_graph.name + + base_name = name + counter = 1 + if ops.get_default_graph()._is_function(name): + name = "%s_%s" % (base_name, counter) + counter += 1 + + return name + + +def _check_same_outputs(true_graph, false_graph): + """Raises an error if true_graph and false_graph have different outputs.""" + true_output_types = [t.dtype for t in true_graph.outputs] + false_output_types = [t.dtype for t in false_graph.outputs] + if (len(true_graph.outputs) != len(false_graph.outputs) or + true_output_types != false_output_types): + raise ValueError( + "true_fn() and false_fn() must return the same number and type of " + "arguments, got:\n" + " true_fn: %s\n" + " false_fn: %s" % (true_output_types, false_output_types)) diff --git a/tensorflow/contrib/control_flow/python/cond_v2_test.py b/tensorflow/contrib/control_flow/python/cond_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..338601aa2c5ee8ffc97aa2e07ff05a2d17531936 --- /dev/null +++ b/tensorflow/contrib/control_flow/python/cond_v2_test.py @@ -0,0 +1,171 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for cond_v2.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.control_flow.python import cond_v2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test +from tensorflow.python.training import saver + + +class NewCondTest(test.TestCase): + + def _testCond(self, true_fn, false_fn, train_vals): + pred = array_ops.placeholder(dtypes.bool, name="pred") + + expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected") + actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual") + + expected_grad = gradients_impl.gradients(expected, train_vals) + actual_grad = gradients_impl.gradients(actual, train_vals) + + with self.test_session() as sess: + expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( + (expected, actual, expected_grad, actual_grad), {pred: True}) + self.assertEqual(expected_val, actual_val) + self.assertEqual(expected_grad_val, actual_grad_val) + + expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( + (expected, actual, expected_grad, actual_grad), {pred: False}) + self.assertEqual(expected_val, actual_val) + self.assertEqual(expected_grad_val, actual_grad_val) + + def testBasic(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return x * 2.0 + + def false_fn(): + return y * 3.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testBasic2(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return x * y * 2.0 + + def false_fn(): + return 2.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testNoInputs(self): + pred = array_ops.placeholder(dtypes.bool, name="pred") + + def true_fn(): + return constant_op.constant(1.0) + + def false_fn(): + return constant_op.constant(2.0) + + out = cond_v2.cond_v2(pred, true_fn, false_fn) + + with self.test_session() as sess: + self.assertEqual(sess.run(out, {pred: True}), [1.0]) + self.assertEqual(sess.run(out, {pred: False}), [2.0]) + + def testSecondDerivative(self): + pred = array_ops.placeholder(dtypes.bool, name="pred") + x = constant_op.constant(3.0, name="x") + + def true_fn(): + return math_ops.pow(x, 3) + + def false_fn(): + return x + + cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") + cond_grad = gradients_impl.gradients(cond, [x]) + cond_grad_grad = gradients_impl.gradients(cond_grad, [x]) + + with self.test_session() as sess: + # d[x^3]/dx = 3x^2 + true_val = sess.run(cond_grad, {pred: True}) + self.assertEqual(true_val, [27.0]) + # d[x]/dx = 1 + false_val = sess.run(cond_grad, {pred: False}) + self.assertEqual(false_val, [1.0]) + + true_val = sess.run(cond_grad_grad, {pred: True}) + # d2[x^3]/dx2 = 6x + self.assertEqual(true_val, [18.0]) + false_val = sess.run(cond_grad_grad, {pred: False}) + # d2[x]/dx2 = 0 + self.assertEqual(false_val, [0.0]) + + def testGradientOfDeserializedCond(self): + with ops.Graph().as_default(): + pred = array_ops.placeholder(dtypes.bool, name="pred") + x = constant_op.constant(3.0, name="x") + ops.add_to_collection("x", x) + + def true_fn(): + return math_ops.pow(x, 3) + + def false_fn(): + return x + + ops.add_to_collection("pred", pred) + cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") + for c in cond: + ops.add_to_collection("cond", c) + meta_graph = saver.export_meta_graph() + + with ops.Graph().as_default() as g: + saver.import_meta_graph(meta_graph) + x = ops.get_collection("x")[0] + pred = ops.get_collection("pred")[0] + cond = ops.get_collection("cond") + cond_grad = gradients_impl.gradients(cond, [x], name="cond_grad") + cond_grad_grad = gradients_impl.gradients( + cond_grad, [x], name="cond_grad_grad") + with self.test_session(graph=g) as sess: + # d[x^3]/dx = 3x^2 + true_val = sess.run(cond_grad, {pred: True}) + self.assertEqual(true_val, [27.0]) + # d[x]/dx = 1 + false_val = sess.run(cond_grad, {pred: False}) + self.assertEqual(false_val, [1.0]) + + true_val = sess.run(cond_grad_grad, {pred: True}) + # d2[x^3]/dx2 = 6x + self.assertEqual(true_val, [18.0]) + false_val = sess.run(cond_grad_grad, {pred: False}) + # d2[x]/dx2 = 0 + self.assertEqual(false_val, [0.0]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/copy_graph/BUILD b/tensorflow/contrib/copy_graph/BUILD index 8ec706df74e2c91345c4bf7a506fdb424a996773..fa44c4d54e1ee871feb425115525b1cf8b732214 100644 --- a/tensorflow/contrib/copy_graph/BUILD +++ b/tensorflow/contrib/copy_graph/BUILD @@ -41,15 +41,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index b806799202bff4f2f6dbf717fbeea74a04b8cd6e..a0dd3881a86c19e47ccb65f84a2477a55626b81c 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -201,7 +201,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): #An instance of tensorflow.core.framework.node_def_pb2.NodeDef, it #stores String-based info such as name, device and type of the op. #Unique to every Operation instance. - new_node_def = deepcopy(op._node_def) + new_node_def = deepcopy(op.node_def) #Change the name new_node_def.name = new_name @@ -211,14 +211,13 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): #Make a copy of the op_def too. #Its unique to every _type_ of Operation. - op_def = deepcopy(op._op_def) + op_def = deepcopy(op.op_def) #Initialize a new Operation instance new_op = ops.Operation(new_node_def, to_graph, new_inputs, output_types, new_control_inputs, input_types, new_original_op, op_def) #Use Graph's hidden methods to add the op - to_graph._add_op(new_op) # pylint: disable=protected-access to_graph._record_op_seen_by_control_dependencies(new_op) for device_function in reversed(to_graph._device_function_stack): new_op._set_device(device_function(new_op)) diff --git a/tensorflow/contrib/crf/BUILD b/tensorflow/contrib/crf/BUILD index 7aad4abdb908d0284b85137bff842bd0f38d09c6..5c1a17df4f95f3c4d05b286de0e3d7b009a76bd7 100644 --- a/tensorflow/contrib/crf/BUILD +++ b/tensorflow/contrib/crf/BUILD @@ -40,15 +40,3 @@ cuda_py_tests( "//tensorflow/python:platform_test", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index 721dc4d0801d1f0e116921888e3851a95e0b72b0..74f2ec22ffaab1654e5cd38169258fb87d307ad4 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -152,6 +152,22 @@ class CrfTest(test.TestCase): self.assertAllClose(tf_log_norm, tf_brute_force_log_norm) + def testCrfLogNormZeroSeqLength(self): + """ + Test `crf_log_norm` when `sequence_lengths` contains one or more zeros. + """ + with self.test_session() as sess: + inputs = constant_op.constant(np.ones([2, 10, 5], + dtype=np.float32)) + transition_params = constant_op.constant(np.ones([5, 5], + dtype=np.float32)) + sequence_lengths = constant_op.constant(np.zeros([2], + dtype=np.int32)) + expected_log_norm = np.zeros([2], dtype=np.float32) + log_norm = crf.crf_log_norm(inputs, sequence_lengths, transition_params) + tf_log_norm = sess.run(log_norm) + self.assertAllClose(tf_log_norm, expected_log_norm) + def testCrfLogLikelihood(self): inputs = np.array( [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) @@ -281,6 +297,21 @@ class CrfTest(test.TestCase): self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]), expected_max_sequence[:sequence_lengths]) + def testCrfDecodeZeroSeqLength(self): + """ + Test that crf_decode works when sequence_length contains one or more zeros. + """ + with self.test_session() as sess: + inputs = constant_op.constant(np.ones([2, 10, 5], + dtype=np.float32)) + transition_params = constant_op.constant(np.ones([5, 5], + dtype=np.float32)) + sequence_lengths = constant_op.constant(np.zeros([2], + dtype=np.int32)) + tags, scores = crf.crf_decode(inputs, transition_params, sequence_lengths) + tf_tags, tf_scores = sess.run([tags, scores]) + self.assertEqual(len(tf_tags.shape), 2) + self.assertEqual(len(tf_scores.shape), 1) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index a30bf06396117034f7cfe461d5e365b8a4a38a3f..2d2cbdc1990ed9d8e58c0032cbc141a52271838f 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -52,6 +52,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops @@ -90,9 +91,13 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths, batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0] example_inds = array_ops.reshape( math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1]) - return array_ops.gather_nd( + sequence_scores = array_ops.gather_nd( array_ops.squeeze(inputs, [1]), array_ops.concat([example_inds, tag_indices], axis=1)) + sequence_scores = array_ops.where(math_ops.less_equal(sequence_lengths, 0), + array_ops.zeros_like(sequence_scores), + sequence_scores) + return sequence_scores def _multi_seq_fn(): # Compute the scores of the given tag sequence. @@ -128,7 +133,12 @@ def crf_log_norm(inputs, sequence_lengths, transition_params): # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over # the "initial state" (the unary potentials). def _single_seq_fn(): - return math_ops.reduce_logsumexp(first_input, [1]) + log_norm = math_ops.reduce_logsumexp(first_input, [1]) + # Mask `log_norm` of the sequences with length <= zero. + log_norm = array_ops.where(math_ops.less_equal(sequence_lengths, 0), + array_ops.zeros_like(log_norm), + log_norm) + return log_norm def _multi_seq_fn(): """Forward computation of alpha values.""" @@ -137,13 +147,21 @@ def crf_log_norm(inputs, sequence_lengths, transition_params): # Compute the alpha values in the forward algorithm in order to get the # partition function. forward_cell = CrfForwardRnnCell(transition_params) + # Sequence length is not allowed to be less than zero. + sequence_lengths_less_one = math_ops.maximum( + constant_op.constant(0, dtype=sequence_lengths.dtype), + sequence_lengths - 1) _, alphas = rnn.dynamic_rnn( cell=forward_cell, inputs=rest_of_input, - sequence_length=sequence_lengths - 1, + sequence_length=sequence_lengths_less_one, initial_state=first_input, dtype=dtypes.float32) log_norm = math_ops.reduce_logsumexp(alphas, [1]) + # Mask `log_norm` of the sequences with length <= zero. + log_norm = array_ops.where(math_ops.less_equal(sequence_lengths, 0), + array_ops.zeros_like(log_norm), + log_norm) return log_norm max_seq_len = array_ops.shape(inputs)[1] @@ -479,15 +497,17 @@ def crf_decode(potentials, transition_params, sequence_length): initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]) initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] + # Sequence length is not allowed to be less than zero. + sequence_length_less_one = math_ops.maximum(0, sequence_length - 1) backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O] crf_fwd_cell, inputs=inputs, - sequence_length=sequence_length - 1, + sequence_length=sequence_length_less_one, initial_state=initial_state, time_major=False, dtype=dtypes.int32) backpointers = gen_array_ops.reverse_sequence( # [B, T - 1, O] - backpointers, sequence_length - 1, seq_dim=1) + backpointers, sequence_length_less_one, seq_dim=1) # Computes backward decoding. Extract tag indices from backpointers. crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) @@ -497,7 +517,7 @@ def crf_decode(potentials, transition_params, sequence_length): decode_tags, _ = rnn.dynamic_rnn( # [B, T - 1, 1] crf_bwd_cell, inputs=backpointers, - sequence_length=sequence_length - 1, + sequence_length=sequence_length_less_one, initial_state=initial_state, time_major=False, dtype=dtypes.int32) @@ -511,7 +531,7 @@ def crf_decode(potentials, transition_params, sequence_length): return decode_tags, best_score return utils.smart_cond( - pred=math_ops.equal( - potentials.shape[1].value or array_ops.shape(potentials)[1], 1), + pred=math_ops.equal(potentials.shape[1].value or + array_ops.shape(potentials)[1], 1), true_fn=_single_seq_fn, false_fn=_multi_seq_fn) diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index fec358c4e1067dc8dc8173d1b9d05dc90b90ca05..aeefa3cee62281c74388765ea5e2cbc7f16ff927 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -9,52 +9,10 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -tf_custom_op_library( - name = "python/ops/_cudnn_rnn_ops.so", - srcs = [ - "kernels/cudnn_rnn_ops.cc", - "ops/cudnn_rnn_ops.cc", - ], - deps = [ - "//tensorflow/core/kernels:bounds_check_lib", - "@farmhash_archive//:farmhash", - ], -) - -tf_kernel_library( - name = "cudnn_rnn_kernels", - srcs = ["kernels/cudnn_rnn_ops.cc"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor", - "//tensorflow/core/kernels:bounds_check_lib", - "//third_party/eigen3", - "@farmhash_archive//:farmhash", - ], -) - -tf_gen_op_libs( - op_lib_names = ["cudnn_rnn_ops"], - deps = [ - "//tensorflow/core:lib", - ], -) - -tf_gen_op_wrapper_py( - name = "cudnn_rnn_ops", - deps = [":cudnn_rnn_ops_op_lib"], -) tf_custom_op_py_library( name = "cudnn_rnn_py", @@ -64,20 +22,14 @@ tf_custom_op_py_library( "python/layers/cudnn_rnn.py", "python/ops/cudnn_rnn_ops.py", ], - dso = [ - ":python/ops/_cudnn_rnn_ops.so", - ], - kernels = [ - ":cudnn_rnn_kernels", - ":cudnn_rnn_ops_op_lib", - ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - ":cudnn_rnn_ops", + "//tensorflow/contrib/checkpoint/python:split_dependency", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:cudnn_rnn_ops_gen", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:init_ops", @@ -172,32 +124,3 @@ cuda_py_test( "requires_cudnn5", ], ) - -tf_cc_test( - name = "cudnn_rnn_ops_test_cc", - size = "small", - srcs = [ - "ops/cudnn_rnn_ops_test.cc", - ], - deps = [ - ":cudnn_rnn_ops_op_lib", - "//tensorflow/core", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc deleted file mode 100644 index ba9686e94ee7072cc485c955decb2287bd4a56f3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc +++ /dev/null @@ -1,1145 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#define EIGEN_USE_THREADS - -#include -#include -#include -#include -#include -#include -#include - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/device_base.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_def_builder.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/fingerprint.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/env_var.h" - -#if GOOGLE_CUDA -#include "tensorflow/core/platform/stream_executor.h" -#include "tensorflow/core/util/stream_executor_util.h" -#endif // GOOGLE_CUDA - -/* - * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model - * using the underlying Cudnn library. - * - * Cudnn RNN library exposes an opaque parameter buffer with unknown layout and - * format. And it is very likely that if saved, they cannot be used across - * different GPUs. So users need to first query the size of the opaque - * parameter buffer, and convert it to and from its canonical forms. But each - * actual training step is carried out with the parameter buffer. - * - * Similar to many other ops, the forward op has two flavors: training and - * inference. When training is specified, additional data in reserve_space will - * be produced for the backward pass. So there is a performance penalty. - * - * In addition to the actual data and reserve_space, Cudnn also needs more - * memory as temporary workspace. The memory management to and from - * stream-executor is done through ScratchAllocator. In general, - * stream-executor is responsible for creating the memory of proper size. And - * TensorFlow is responsible for making sure the memory is alive long enough - * and recycles afterwards. - * - */ -namespace tensorflow { - -using CPUDevice = Eigen::ThreadPoolDevice; - -#if GOOGLE_CUDA - -using GPUDevice = Eigen::GpuDevice; - -template -class CudnnRNNParamsSizeOp; - -template -class CudnnRNNParamsToCanonical; - -template -class CudnnRNNCanonicalToParams; - -template -class CudnnRNNForwardOp; - -template -class CudnnRNNBackwardOp; - -enum class TFRNNInputMode { - kRNNLinearInput = 0, - kRNNSkipInput = 1, - kAutoSelect = 9999999 -}; - -namespace { -using perftools::gputools::DeviceMemory; -using perftools::gputools::DeviceMemoryBase; -using perftools::gputools::ScratchAllocator; -using perftools::gputools::dnn::RnnDirectionMode; -using perftools::gputools::dnn::RnnInputMode; -using perftools::gputools::dnn::RnnMode; -using perftools::gputools::dnn::ToDataType; -using perftools::gputools::port::StatusOr; - -Status ParseRNNMode(const string& str, RnnMode* rnn_mode) { - if (str == "rnn_relu") { - *rnn_mode = RnnMode::kRnnRelu; - return Status::OK(); - } else if (str == "rnn_tanh") { - *rnn_mode = RnnMode::kRnnTanh; - return Status::OK(); - } else if (str == "lstm") { - *rnn_mode = RnnMode::kRnnLstm; - return Status::OK(); - } else if (str == "gru") { - *rnn_mode = RnnMode::kRnnGru; - return Status::OK(); - } - return errors::InvalidArgument("Invalid RNN mode: ", str); -} - -Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) { - if (str == "linear_input") { - *rnn_input_mode = TFRNNInputMode::kRNNLinearInput; - return Status::OK(); - } else if (str == "skip_input") { - *rnn_input_mode = TFRNNInputMode::kRNNSkipInput; - return Status::OK(); - } else if (str == "auto_select") { - *rnn_input_mode = TFRNNInputMode::kAutoSelect; - return Status::OK(); - } - return errors::InvalidArgument("Invalid RNN input mode: ", str); -} - -Status ParseRNNDirectionMode(const string& str, - RnnDirectionMode* rnn_dir_mode) { - if (str == "unidirectional") { - *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional; - return Status::OK(); - } else if (str == "bidirectional") { - *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional; - return Status::OK(); - } - return errors::InvalidArgument("Invalid RNN direction mode: ", str); -} - -Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units, - int input_size, RnnInputMode* input_mode) { - switch (tf_input_mode) { - case TFRNNInputMode::kRNNLinearInput: - *input_mode = RnnInputMode::kRnnLinearSkip; - break; - case TFRNNInputMode::kRNNSkipInput: - *input_mode = RnnInputMode::kRnnSkipInput; - break; - case TFRNNInputMode::kAutoSelect: - *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput - : RnnInputMode::kRnnLinearSkip; - break; - default: - return errors::InvalidArgument("Invalid TF input mode: ", - static_cast(tf_input_mode)); - } - return Status::OK(); -} - -// TODO(zhengxq): Merge those into stream_executor_util.h. -template -const DeviceMemory AsDeviceMemory(const Tensor* tensor) { - return DeviceMemory::MakeFromByteSize( - const_cast(tensor->template flat().data()), - tensor->template flat().size() * sizeof(T)); -} - -template -DeviceMemory AsDeviceMemory(Tensor* tensor) { - return DeviceMemory::MakeFromByteSize( - tensor->template flat().data(), - tensor->template flat().size() * sizeof(T)); -} - -template -DeviceMemory CastDeviceMemory(Tensor* tensor) { - return DeviceMemory::MakeFromByteSize( - tensor->template flat().data(), - tensor->template flat().size() * sizeof(T)); -} - -DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory, - int64 offset, int64 size) { - const void* base_ptr = device_memory.opaque(); - void* offset_ptr = - const_cast(reinterpret_cast(base_ptr) + offset); - CHECK(offset + size <= device_memory.size()) - << "The slice is not within the region of DeviceMemory."; - return DeviceMemoryBase(offset_ptr, size); -} - -inline Status FromExecutorStatus(const perftools::gputools::port::Status& s) { - return s.ok() ? Status::OK() - : Status(static_cast( - static_cast(s.code())), - s.error_message()); -} - -template -inline Status FromExecutorStatus( - const perftools::gputools::port::StatusOr& s) { - return FromExecutorStatus(s.status()); -} - -inline perftools::gputools::port::Status ToExecutorStatus(const Status& s) { - return s.ok() ? perftools::gputools::port::Status::OK() - : perftools::gputools::port::Status( - static_cast( - static_cast(s.code())), - s.error_message()); -} - -// A helper to allocate temporary scratch memory for Cudnn RNN models. It takes -// the ownership of the underlying memory. The expectation is that the memory -// should be alive for the span of the Cudnn RNN itself. -class CudnnRNNWorkspaceAllocator : public ScratchAllocator { - public: - ~CudnnRNNWorkspaceAllocator() override {} - explicit CudnnRNNWorkspaceAllocator(OpKernelContext* context) - : context_(context) {} - int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override { - return std::numeric_limits::max(); - } - StatusOr> AllocateBytes( - perftools::gputools::Stream* stream, int64 byte_size) override { - Tensor temporary_memory; - Status allocation_status(context_->allocate_temp( - DT_UINT8, TensorShape({byte_size}), &temporary_memory)); - if (!allocation_status.ok()) { - return ToExecutorStatus(allocation_status); - } - // Hold the reference of the allocated tensors until the end of the - // allocator. - allocated_tensors_.push_back(temporary_memory); - total_byte_size_ += byte_size; - return StatusOr>( - AsDeviceMemory(&temporary_memory)); - } - int64 TotalByteSize() { return total_byte_size_; } - - private: - int64 total_byte_size_ = 0; - OpKernelContext* context_; // not owned - std::vector allocated_tensors_; -}; - -// A helper to allocate reserve-space memory for Cudnn RNN models. The tensors -// are allocated as a kernel output, and will be fed into the backward pass. -// The memory is expected to live long enough after the backward pass is -// finished. -template -class CudnnRNNReserveSpaceAllocator : public ScratchAllocator { - public: - ~CudnnRNNReserveSpaceAllocator() override {} - CudnnRNNReserveSpaceAllocator(OpKernelContext* context, int output_index) - : context_(context), output_index_(output_index) {} - int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override { - return std::numeric_limits::max(); - } - StatusOr> AllocateBytes( - perftools::gputools::Stream* stream, int64 byte_size) override { - CHECK(total_byte_size_ == 0) - << "Reserve space allocator can only be called once"; - int64 allocate_count = - Eigen::divup(byte_size, static_cast(sizeof(T))); - - Tensor* temporary_memory = nullptr; - Status allocation_status(context_->allocate_output( - output_index_, TensorShape({allocate_count}), &temporary_memory)); - if (!allocation_status.ok()) { - return ToExecutorStatus(allocation_status); - } - total_byte_size_ += byte_size; - auto memory_uint8 = DeviceMemory::MakeFromByteSize( - temporary_memory->template flat().data(), - temporary_memory->template flat().size() * sizeof(T)); - return StatusOr>(memory_uint8); - } - int64 TotalByteSize() { return total_byte_size_; } - - private: - int64 total_byte_size_ = 0; - OpKernelContext* context_; // not owned - int output_index_; -}; - -// A helper to allocate persistent memory for Cudnn RNN models, which is -// expected to live between kernel invocations. -// This class is not thread-safe. -class CudnnRNNPersistentSpaceAllocator : public ScratchAllocator { - public: - explicit CudnnRNNPersistentSpaceAllocator(OpKernelContext* context) - : context_(context) {} - - ~CudnnRNNPersistentSpaceAllocator() override {} - - int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override { - return std::numeric_limits::max(); - } - - StatusOr> AllocateBytes( - perftools::gputools::Stream* stream, int64 byte_size) override { - if (total_byte_size_ != 0) { - return Status(error::FAILED_PRECONDITION, - "Persistent space allocator can only be called once"); - } - - Status allocation_status = context_->allocate_persistent( - DT_UINT8, TensorShape({byte_size}), &handle_, nullptr); - if (!allocation_status.ok()) { - return ToExecutorStatus(allocation_status); - } - total_byte_size_ += byte_size; - return AsDeviceMemory(handle_.AccessTensor(context_)); - } - int64 TotalByteSize() { return total_byte_size_; } - - private: - int64 total_byte_size_ = 0; - PersistentTensor handle_; - OpKernelContext* context_; // not owned -}; - -struct CudnnModelTypes { - RnnMode rnn_mode; - TFRNNInputMode rnn_input_mode; - RnnDirectionMode rnn_direction_mode; - bool HasInputC() const { - // For Cudnn 5.0, only LSTM has input-c. All other models use only input-h. - return rnn_mode == RnnMode::kRnnLstm; - } -}; - -// A helper class that collects the shapes to describe a RNN model. -struct CudnnModelShapes { - int num_layers; - int input_size; - int num_units; - int seq_length; - int batch_size; - int dir_count; - TensorShape input_shape; - TensorShape output_shape; - TensorShape hidden_state_shape; - // At present only fields related to cached RnnDescriptor are concerned. - bool IsCompatibleWith(const CudnnModelShapes& rhs) const { - return num_layers == rhs.num_layers && input_size == rhs.input_size && - num_units == rhs.num_units && dir_count == rhs.dir_count; - } - string RnnDescDebugString() { - return strings::Printf( - "[num_layers, input_size, num_units, dir_count]: [%d, %d, %d, %d]", - num_layers, input_size, num_units, dir_count); - } -}; - -// Utility class for using CudnnModelShapes as a hash table key. -struct CudnnModelShapesHasher { - uint64 operator()(const CudnnModelShapes& to_hash) const { - uint64 hash = static_cast(to_hash.num_layers); - hash = tensorflow::FingerprintCat64( - hash, static_cast(to_hash.input_size)); - hash = tensorflow::FingerprintCat64(hash, - static_cast(to_hash.num_units)); - return tensorflow::FingerprintCat64(hash, - static_cast(to_hash.dir_count)); - } -}; - -// Utility class for using CudnnModelShapes as a hash table key. -struct CudnnModelShapesComparator { - bool operator()(const CudnnModelShapes& first, - const CudnnModelShapes& second) const { - return first.IsCompatibleWith(second); - } -}; - -// Extract and checks the forward input tensors, parameters, and shapes from the -// OpKernelContext. -Status ExtractForwardInput(OpKernelContext* context, - const CudnnModelTypes& model_types, - const Tensor** input, const Tensor** input_h, - const Tensor** input_c, const Tensor** params, - CudnnModelShapes* model_shapes) { - TF_RETURN_IF_ERROR(context->input("input", input)); - TF_RETURN_IF_ERROR(context->input("input_h", input_h)); - if (model_types.HasInputC()) { - TF_RETURN_IF_ERROR(context->input("input_c", input_c)); - } - TF_RETURN_IF_ERROR(context->input("params", params)); - - if ((*input)->dims() != 3) { - return errors::InvalidArgument("RNN input must be a 3-D vector."); - } - model_shapes->seq_length = (*input)->dim_size(0); - model_shapes->batch_size = (*input)->dim_size(1); - model_shapes->input_size = (*input)->dim_size(2); - model_shapes->input_shape = (*input)->shape(); - model_shapes->dir_count = - (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional) - ? 2 - : 1; - - if ((*input_h)->dims() != 3) { - return errors::InvalidArgument("RNN input must be a 3-D vector."); - } - model_shapes->num_layers = (*input_h)->dim_size(0) / model_shapes->dir_count; - model_shapes->num_units = (*input_h)->dim_size(2); - - model_shapes->hidden_state_shape = - TensorShape({model_shapes->dir_count * model_shapes->num_layers, - model_shapes->batch_size, model_shapes->num_units}); - if ((*input_h)->shape() != model_shapes->hidden_state_shape) { - return errors::InvalidArgument( - "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ", - model_shapes->hidden_state_shape.DebugString()); - } - if (model_types.HasInputC()) { - if ((*input_h)->shape() != (*input_c)->shape()) { - return errors::InvalidArgument( - "input_h and input_c must have the same shape: ", - (*input_h)->shape().DebugString(), " ", - (*input_c)->shape().DebugString()); - } - } - model_shapes->output_shape = - TensorShape({model_shapes->seq_length, model_shapes->batch_size, - model_shapes->dir_count * model_shapes->num_units}); - return Status::OK(); -} - -using perftools::gputools::dnn::RnnDescriptor; - -template -void RestoreParams(const OpInputList params_input, - const std::vector& params, - DeviceMemoryBase* data_dst, - perftools::gputools::Stream* stream) { - int num_params = params.size(); - CHECK(params_input.size() == num_params) - << "Number of params mismatch. Expected " << params_input.size() - << ", got " << num_params; - for (int i = 0; i < params.size(); i++) { - int64 size_in_bytes = params[i].size; - int64 size = size_in_bytes / sizeof(T); - CHECK(size == params_input[i].NumElements()) - << "Params size mismatch. Expected " << size << ", got " - << params_input[i].NumElements(); - auto data_src_ptr = StreamExecutorUtil::AsDeviceMemory(params_input[i]); - DeviceMemoryBase data_dst_ptr = - SliceDeviceMemory(*data_dst, params[i].offset, size_in_bytes); - stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes); - } -} - -} // namespace - -// Note: all following kernels depend on a RnnDescriptor instance, which -// according to Cudnn official doc should be kept around and reused across all -// Cudnn kernels in the same model. -// In Tensorflow, we don't pass the reference across different OpKernels, -// rather, recreate it separately in each OpKernel, which does no cause issue: -// CudnnDropoutDescriptor keeps a reference to a memory for -// random number generator state. During recreation, this state is lost. -// However, only forward-pass Cudnn APIs make use of the state. - -// A common base class for RNN kernels. It extracts common attributes and -// shape validations. -class CudnnRNNKernelCommon : public OpKernel { - protected: - explicit CudnnRNNKernelCommon(OpKernelConstruction* context) - : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("dropout", &dropout_)); - OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_)); - OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_)); - string str; - OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str)); - OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode)); - OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str)); - OP_REQUIRES_OK(context, - ParseTFRNNInputMode(str, &model_types_.rnn_input_mode)); - OP_REQUIRES_OK(context, context->GetAttr("direction", &str)); - OP_REQUIRES_OK( - context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode)); - // Reset CudnnRnnDescriptor and related random number generate states in - // every Compute() call. - OP_REQUIRES_OK(context, ReadBoolFromEnvVar("TF_CUDNN_RESET_RND_GEN_STATE", - false, &reset_rnd_gen_state_)); - } - - bool HasInputC() const { return model_types_.HasInputC(); } - RnnMode rnn_mode() const { return model_types_.rnn_mode; } - TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; } - RnnDirectionMode rnn_direction_mode() const { - return model_types_.rnn_direction_mode; - } - CudnnModelTypes model_types() const { return model_types_; } - float dropout() const { return dropout_; } - uint64 seed() { return (static_cast(seed_) << 32) | seed2_; } - bool ResetRndGenState() { return reset_rnd_gen_state_; } - - template - Status ExtractCudnnRNNParamsInfo(OpKernelContext* context, - std::unique_ptr* rnn_desc) { - const Tensor* num_layers_t = nullptr; - TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t)); - if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) { - return errors::InvalidArgument("num_layers is not a scalar"); - } - int num_layers = num_layers_t->scalar()(); - const Tensor* num_units_t = nullptr; - TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t)); - if (!TensorShapeUtils::IsScalar(num_units_t->shape())) { - return errors::InvalidArgument("num_units is not a scalar"); - } - int num_units = num_units_t->scalar()(); - const Tensor* input_size_t = nullptr; - TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t)); - if (!TensorShapeUtils::IsScalar(input_size_t->shape())) { - return errors::InvalidArgument("input_size is not a scalar"); - } - int input_size = input_size_t->scalar()(); - - RnnInputMode input_mode; - TF_RETURN_IF_ERROR( - ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode)); - - auto* stream = context->op_device_context()->stream(); - // ExtracCudnnRNNParamsInfo is only called by op_kernels that do not require - // random number generator, therefore set state_allocator to nullptr. - auto rnn_desc_s = stream->parent()->createRnnDescriptor( - num_layers, num_units, input_size, input_mode, rnn_direction_mode(), - rnn_mode(), ToDataType::value, dropout(), seed(), - nullptr /* state_allocator */); - if (!rnn_desc_s.ok()) { - return FromExecutorStatus(rnn_desc_s); - } - *rnn_desc = rnn_desc_s.ConsumeValueOrDie(); - return Status::OK(); - } - - private: - int seed_; - int seed2_; - float dropout_; - bool reset_rnd_gen_state_; - - CudnnModelTypes model_types_; -}; - -// A class that returns the size of the opaque parameter buffer. The user should -// use that to create the actual parameter buffer for training. However, it -// should not be used for saving and restoring. -template -class CudnnRNNParamsSizeOp : public CudnnRNNKernelCommon { - public: - typedef GPUDevice Device; - explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context) - : CudnnRNNKernelCommon(context) {} - - void Compute(OpKernelContext* context) override { - std::unique_ptr rnn_desc; - OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo(context, &rnn_desc)); - int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes(); - CHECK(params_size_in_bytes % sizeof(T) == 0) - << "params_size_in_bytes must be multiple of element size"; - int64 params_size = params_size_in_bytes / sizeof(T); - - Tensor* output_t = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t)); - *output_t->template flat().data() = params_size; - } -}; - -#define REGISTER_GPU(T) \ - REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize") \ - .Device(DEVICE_GPU) \ - .HostMemory("num_layers") \ - .HostMemory("num_units") \ - .HostMemory("input_size") \ - .HostMemory("params_size") \ - .TypeConstraint("T") \ - .TypeConstraint("S"), \ - CudnnRNNParamsSizeOp); - -TF_CALL_half(REGISTER_GPU); -TF_CALL_float(REGISTER_GPU); -TF_CALL_double(REGISTER_GPU); -#undef REGISTER_GPU - -// Convert weight and bias params from a platform-specific layout to the -// canonical form. -template -class CudnnRNNParamsToCanonical : public CudnnRNNKernelCommon { - public: - typedef GPUDevice Device; - explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context) - : CudnnRNNKernelCommon(context) { - OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_)); - } - - void Compute(OpKernelContext* context) override { - const Tensor& input = context->input(3); - auto input_ptr = StreamExecutorUtil::AsDeviceMemory(input); - auto* stream = context->op_device_context()->stream(); - - std::unique_ptr rnn_desc; - OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo(context, &rnn_desc)); - int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes(); - CHECK(params_size_in_bytes % sizeof(T) == 0) - << "params_size_in_bytes must be multiple of element size"; - - const Tensor* num_units_t = nullptr; - OP_REQUIRES_OK(context, context->input("num_units", &num_units_t)); - CHECK(TensorShapeUtils::IsScalar(num_units_t->shape())) - << "num_units is not a scalar"; - int num_units = num_units_t->scalar()(); - - const Tensor* input_size_t = nullptr; - OP_REQUIRES_OK(context, context->input("input_size", &input_size_t)); - CHECK(TensorShapeUtils::IsScalar(input_size_t->shape())) - << "input_size is not a scalar"; - int input_size = input_size_t->scalar()(); - - const Tensor* num_layers_t = nullptr; - OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t)); - CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape())) - << "num_layers is not a scalar"; - int num_layers = num_layers_t->scalar()(); - int num_dirs = 1; - if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) { - num_dirs = 2; - } - const int num_params_per_layer = num_params_ / num_layers / num_dirs; - // Number of params applied on inputs. The rest are applied on recurrent - // hidden states. - const int num_params_input_state = num_params_per_layer / 2; - CHECK(num_params_ % (num_layers * num_dirs) == 0) - << "Number of params is not a multiple of num_layers * num_dirs."; - CHECK(num_params_per_layer % 2 == 0) - << "Number of params per layer is not a even number."; - - CHECK(num_params_ == rnn_desc->ParamsWeightRegions().size()) - << "Number of params mismatch. Expected " << num_params_ << ", got " - << rnn_desc->ParamsWeightRegions().size(); - for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) { - int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size; - int64 size = size_in_bytes / sizeof(T); - const int layer_idx = i / num_params_per_layer; - const int index_within_layer = i % num_params_per_layer; - int width = 0, height = num_units; - // In CuDNN layout, each layer has num_params_per_layer params, with the - // first half a.k.a num_params_input_state params applied on the inputs, - // and the second half on the recurrent hidden states. - bool apply_on_input_state = index_within_layer < num_params_input_state; - if (rnn_direction_mode() == RnnDirectionMode::kRnnUnidirectional) { - if (layer_idx == 0 && apply_on_input_state) { - width = input_size; - } else { - width = num_units; - } - } else { - if (apply_on_input_state) { - if (layer_idx <= 1) { - // First fwd or bak layer. - width = input_size; - } else { - // Following layers, cell inputs are concatenated outputs of - // its prior layer. - width = 2 * num_units; - } - } else { - width = num_units; - } - } - CHECK(size == width * height) << "Params size mismatch. Expected " - << width * height << ", got " << size; - Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->allocate_output( - i, TensorShape({height, width}), &output)); - DeviceMemoryBase data_src_ptr = SliceDeviceMemory( - input_ptr, rnn_desc->ParamsWeightRegions()[i].offset, size_in_bytes); - auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory(*output); - stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes); - } - - OP_REQUIRES(context, num_params_ == rnn_desc->ParamsBiasRegions().size(), - errors::InvalidArgument("Number of params mismatch. Expected ", - num_params_, ", got ", - rnn_desc->ParamsBiasRegions().size())); - for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) { - int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size; - int64 size = size_in_bytes / sizeof(T); - OP_REQUIRES(context, size == num_units, - errors::InvalidArgument("Params size mismatch. Expected ", - num_units, ", got ", size)); - - Tensor* output = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(num_params_ + i, - TensorShape({size}), &output)); - DeviceMemoryBase data_src_ptr = SliceDeviceMemory( - input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes); - auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory(*output); - stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes); - } - } - - private: - int num_params_; -}; - -#define REGISTER_GPU(T) \ - REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical") \ - .Device(DEVICE_GPU) \ - .HostMemory("num_layers") \ - .HostMemory("num_units") \ - .HostMemory("input_size") \ - .TypeConstraint("T"), \ - CudnnRNNParamsToCanonical); -TF_CALL_half(REGISTER_GPU); -TF_CALL_float(REGISTER_GPU); -TF_CALL_double(REGISTER_GPU); -#undef REGISTER_GPU - -// Convert weight and bias params from the canonical form to a -// platform-specific layout. -template -class CudnnRNNCanonicalToParams : public CudnnRNNKernelCommon { - public: - typedef GPUDevice Device; - explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context) - : CudnnRNNKernelCommon(context) {} - - void Compute(OpKernelContext* context) override { - std::unique_ptr rnn_desc; - OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo(context, &rnn_desc)); - int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes(); - CHECK(params_size_in_bytes % sizeof(T) == 0) - << "params_size_in_bytes must be multiple of element size"; - Tensor* output = nullptr; - int params_size = params_size_in_bytes / sizeof(T); - OP_REQUIRES_OK(context, - context->allocate_output(0, {params_size}, &output)); - auto output_ptr = StreamExecutorUtil::AsDeviceMemory(*output); - auto* stream = context->op_device_context()->stream(); - - OpInputList weights; - OP_REQUIRES_OK(context, context->input_list("weights", &weights)); - RestoreParams(weights, rnn_desc->ParamsWeightRegions(), &output_ptr, - stream); - - OpInputList biases; - OP_REQUIRES_OK(context, context->input_list("biases", &biases)); - RestoreParams(biases, rnn_desc->ParamsBiasRegions(), &output_ptr, - stream); - } -}; - -#define REGISTER_GPU(T) \ - REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams") \ - .Device(DEVICE_GPU) \ - .HostMemory("num_layers") \ - .HostMemory("num_units") \ - .HostMemory("input_size") \ - .TypeConstraint("T"), \ - CudnnRNNCanonicalToParams); -TF_CALL_half(REGISTER_GPU); -TF_CALL_float(REGISTER_GPU); -TF_CALL_double(REGISTER_GPU); -#undef REGISTER_GPU - -// Pointers to RNN scratch space for a specific set of shape parameters (used as -// a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp). -struct RnnScratchSpace { - std::unique_ptr rnn_desc; - std::unique_ptr dropout_state_allocator; -}; - -// Run the forward operation of the RNN model. -template -class CudnnRNNForwardOp : public CudnnRNNKernelCommon { - public: - typedef GPUDevice Device; - explicit CudnnRNNForwardOp(OpKernelConstruction* context) - : CudnnRNNKernelCommon(context) { - OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_)); - } - - void Compute(OpKernelContext* context) override { - const Tensor* input = nullptr; - const Tensor* input_h = nullptr; - const Tensor* input_c = nullptr; - const Tensor* params = nullptr; - CudnnModelShapes model_shapes; - OP_REQUIRES_OK(context, - ExtractForwardInput(context, model_types(), &input, &input_h, - &input_c, ¶ms, &model_shapes)); - const auto& input_shape = model_shapes.input_shape; - const auto& hidden_state_shape = model_shapes.hidden_state_shape; - const auto& output_shape = model_shapes.output_shape; - - Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); - Tensor* output_h = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(1, hidden_state_shape, &output_h)); - Tensor* output_c = nullptr; - if (HasInputC()) { - // Only LSTM uses input_c and output_c. So for all other models, we only - // need to create dummy outputs. - OP_REQUIRES_OK( - context, context->allocate_output(2, hidden_state_shape, &output_c)); - } else { - OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_c)); - } - - auto* stream = context->op_device_context()->stream(); - auto* executor = stream->parent(); - RnnInputMode input_mode; - OP_REQUIRES_OK(context, - ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, - model_shapes.input_size, &input_mode)); - auto data_type = ToDataType::value; - - auto input_desc_s = executor->createRnnSequenceTensorDescriptor( - input_shape.dim_size(0), input_shape.dim_size(1), - input_shape.dim_size(2), data_type); - OP_REQUIRES_OK(context, FromExecutorStatus(input_desc_s)); - auto input_desc = input_desc_s.ConsumeValueOrDie(); - - auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor( - hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1), - hidden_state_shape.dim_size(2), data_type); - OP_REQUIRES_OK(context, FromExecutorStatus(hidden_state_desc_s)); - auto hidden_state_desc = hidden_state_desc_s.ConsumeValueOrDie(); - - auto output_desc_s = executor->createRnnSequenceTensorDescriptor( - output_shape.dim_size(0), output_shape.dim_size(1), - output_shape.dim_size(2), data_type); - OP_REQUIRES_OK(context, FromExecutorStatus(output_desc_s)); - auto output_desc = output_desc_s.ConsumeValueOrDie(); - - auto input_data = AsDeviceMemory(input); - auto input_h_data = AsDeviceMemory(input_h); - DeviceMemory input_c_data; - if (HasInputC()) { - input_c_data = AsDeviceMemory(input_c); - } - auto params_data = AsDeviceMemory(params); - auto output_data = AsDeviceMemory(output); - auto output_h_data = AsDeviceMemory(output_h); - DeviceMemory output_c_data; - if (HasInputC()) { - output_c_data = AsDeviceMemory(output_c); - } - - // Creates a memory callback for the reserve_space. The memory lives in the - // output of this kernel. And it will be fed into the backward pass when - // needed. - CudnnRNNReserveSpaceAllocator reserve_space_allocator(context, 3); - if (!is_training_) { - Tensor* dummy_reserve_space = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(3, {}, &dummy_reserve_space)); - } - // Creates a memory callback for the workspace. The memory lives to the end - // of this kernel calls. - CudnnRNNWorkspaceAllocator workspace_allocator(context); - bool launch_status = false; - { - mutex_lock l(mu_); - RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes]; - if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) { - CudnnRNNPersistentSpaceAllocator* dropout_state_allocator = - new CudnnRNNPersistentSpaceAllocator(context); - rnn_state.dropout_state_allocator.reset(dropout_state_allocator); - auto rnn_desc_s = executor->createRnnDescriptor( - model_shapes.num_layers, model_shapes.num_units, - model_shapes.input_size, input_mode, rnn_direction_mode(), - rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator); - OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s)); - rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie()); - } - launch_status = - stream - ->ThenRnnForward(*rnn_state.rnn_desc, *input_desc, input_data, - *hidden_state_desc, input_h_data, - *hidden_state_desc, input_c_data, params_data, - *output_desc, &output_data, *hidden_state_desc, - &output_h_data, *hidden_state_desc, - &output_c_data, is_training_, - &reserve_space_allocator, &workspace_allocator) - .ok(); - } - OP_REQUIRES(context, launch_status, - errors::Internal("Failed to call ThenRnnForward")); - } - - private: - mutex mu_; - bool is_training_; - std::unordered_map - rnn_state_cache_ GUARDED_BY(mu_); -}; - -#define REGISTER_GPU(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint("T"), \ - CudnnRNNForwardOp); - -TF_CALL_half(REGISTER_GPU); -TF_CALL_float(REGISTER_GPU); -TF_CALL_double(REGISTER_GPU); -#undef REGISTER_GPU - -// Run the backward operation of the RNN model. -template -class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { - public: - typedef GPUDevice Device; - - explicit CudnnRNNBackwardOp(OpKernelConstruction* context) - : CudnnRNNKernelCommon(context) {} - - void Compute(OpKernelContext* context) override { - const Tensor* input = nullptr; - const Tensor* input_h = nullptr; - const Tensor* input_c = nullptr; - const Tensor* params = nullptr; - CudnnModelShapes model_shapes; - OP_REQUIRES_OK(context, - ExtractForwardInput(context, model_types(), &input, &input_h, - &input_c, ¶ms, &model_shapes)); - - const auto& input_shape = model_shapes.input_shape; - const auto& hidden_state_shape = model_shapes.hidden_state_shape; - const auto& output_shape = model_shapes.output_shape; - - auto data_type = ToDataType::value; - const Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->input("output", &output)); - OP_REQUIRES(context, output_shape == output->shape(), - errors::InvalidArgument( - "input_h and input_c must have the same shape: ", - input_h->shape().DebugString(), " ", - input_c->shape().DebugString())); - const Tensor* output_h = nullptr; - OP_REQUIRES_OK(context, context->input("output_h", &output_h)); - OP_REQUIRES(context, output_h->shape() == hidden_state_shape, - errors::InvalidArgument( - "Invalid output_h shape: ", output_h->shape().DebugString(), - " ", hidden_state_shape.DebugString())); - const Tensor* output_c = nullptr; - if (HasInputC()) { - // Only LSTM uses input_c and output_c. So for all other models, we only - // need to create dummy outputs. - OP_REQUIRES_OK(context, context->input("output_c", &output_c)); - OP_REQUIRES(context, output_c->shape() == hidden_state_shape, - errors::InvalidArgument("Invalid output_c shape: ", - output_c->shape().DebugString(), " ", - hidden_state_shape.DebugString())); - } - - const Tensor* output_backprop = nullptr; - OP_REQUIRES_OK(context, - context->input("output_backprop", &output_backprop)); - OP_REQUIRES(context, output_backprop->shape() == output_shape, - errors::InvalidArgument("Invalid output_backprop shapes: ", - output_backprop->shape().DebugString(), - " ", output_shape.DebugString())); - - const Tensor* output_h_backprop = nullptr; - OP_REQUIRES_OK(context, - context->input("output_h_backprop", &output_h_backprop)); - OP_REQUIRES( - context, output_h_backprop->shape() == hidden_state_shape, - errors::InvalidArgument("Invalid output_h_backprop shapes: ", - output_h_backprop->shape().DebugString(), " ", - hidden_state_shape.DebugString())); - const Tensor* output_c_backprop = nullptr; - if (HasInputC()) { - OP_REQUIRES_OK(context, - context->input("output_c_backprop", &output_c_backprop)); - OP_REQUIRES( - context, output_c_backprop->shape() == hidden_state_shape, - errors::InvalidArgument("Invalid output_c_backprop shapes: ", - output_c_backprop->shape().DebugString(), " ", - hidden_state_shape.DebugString())); - } - const Tensor* reserve_space_const = nullptr; - // This is the same "reserve_space" created by the forward op. - // It can also be modified by this backward operation. - OP_REQUIRES_OK(context, - context->input("reserve_space", &reserve_space_const)); - // Cudnn needs the reserve space to be writeable. This is fine because they - // are opaque. - Tensor* reserve_space = const_cast(reserve_space_const); - - Tensor* input_backprop = nullptr; - OP_REQUIRES_OK( - context, context->allocate_output(0, input->shape(), &input_backprop)); - Tensor* input_h_backprop = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(1, input_h->shape(), - &input_h_backprop)); - Tensor* input_c_backprop = nullptr; - if (HasInputC()) { - OP_REQUIRES_OK(context, context->allocate_output(2, input_c->shape(), - &input_c_backprop)); - } else { - OP_REQUIRES_OK(context, - context->allocate_output(2, {}, &input_c_backprop)); - } - Tensor* params_backprop = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(3, params->shape(), - ¶ms_backprop)); - - auto* stream = context->op_device_context()->stream(); - auto* executor = stream->parent(); - RnnInputMode input_mode; - OP_REQUIRES_OK(context, - ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, - model_shapes.input_size, &input_mode)); - - auto input_desc_s = executor->createRnnSequenceTensorDescriptor( - input_shape.dim_size(0), input_shape.dim_size(1), - input_shape.dim_size(2), data_type); - OP_REQUIRES_OK(context, FromExecutorStatus(input_desc_s)); - auto input_desc = input_desc_s.ConsumeValueOrDie(); - - auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor( - hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1), - hidden_state_shape.dim_size(2), data_type); - OP_REQUIRES_OK(context, FromExecutorStatus(hidden_state_desc_s)); - auto hidden_state_desc = hidden_state_desc_s.ConsumeValueOrDie(); - - auto output_desc_s = executor->createRnnSequenceTensorDescriptor( - output_shape.dim_size(0), output_shape.dim_size(1), - output_shape.dim_size(2), data_type); - OP_REQUIRES_OK(context, FromExecutorStatus(output_desc_s)); - auto output_desc = output_desc_s.ConsumeValueOrDie(); - - auto input_data = AsDeviceMemory(input); - auto input_h_data = AsDeviceMemory(input_h); - DeviceMemory input_c_data; - if (HasInputC()) { - input_c_data = AsDeviceMemory(input_c); - } - auto params_data = AsDeviceMemory(params); - auto output_data = AsDeviceMemory(output); - auto output_h_data = AsDeviceMemory(output_h); - DeviceMemory output_c_data; - if (HasInputC()) { - output_c_data = AsDeviceMemory(output_c); - } - auto output_backprop_data = AsDeviceMemory(output_backprop); - auto output_h_backprop_data = AsDeviceMemory(output_h_backprop); - DeviceMemory output_c_backprop_data; - if (HasInputC()) { - output_c_backprop_data = AsDeviceMemory(output_c_backprop); - } - auto input_backprop_data = AsDeviceMemory(input_backprop); - auto input_h_backprop_data = AsDeviceMemory(input_h_backprop); - DeviceMemory input_c_backprop_data; - if (HasInputC()) { - input_c_backprop_data = AsDeviceMemory(input_c_backprop); - } - auto params_backprop_data = AsDeviceMemory(params_backprop); - auto reserve_space_uint8 = CastDeviceMemory(reserve_space); - // Creates a memory callback for the workspace. The memory lives to the end - // of this kernel calls. - CudnnRNNWorkspaceAllocator workspace_allocator(context); - bool launch_status = false; - { - mutex_lock l(mu_); - RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes]; - if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) { - CudnnRNNPersistentSpaceAllocator* dropout_state_allocator = - new CudnnRNNPersistentSpaceAllocator(context); - rnn_state.dropout_state_allocator.reset(dropout_state_allocator); - auto rnn_desc_s = executor->createRnnDescriptor( - model_shapes.num_layers, model_shapes.num_units, - model_shapes.input_size, input_mode, rnn_direction_mode(), - rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator); - OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s)); - rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie()); - } - launch_status = - stream - ->ThenRnnBackward(*rnn_state.rnn_desc, *input_desc, input_data, - *hidden_state_desc, input_h_data, - *hidden_state_desc, input_c_data, params_data, - *output_desc, output_data, *hidden_state_desc, - output_h_data, *hidden_state_desc, - output_c_data, output_backprop_data, - output_h_backprop_data, output_c_backprop_data, - &input_backprop_data, &input_h_backprop_data, - &input_c_backprop_data, ¶ms_backprop_data, - &reserve_space_uint8, &workspace_allocator) - .ok(); - } - OP_REQUIRES(context, launch_status, - errors::Internal("Failed to call ThenRnnBackward")); - } - - private: - mutex mu_; - std::unordered_map - rnn_state_cache_ GUARDED_BY(mu_); -}; - -#define REGISTER_GPU(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint("T"), \ - CudnnRNNBackwardOp); - -TF_CALL_half(REGISTER_GPU); -TF_CALL_float(REGISTER_GPU); -TF_CALL_double(REGISTER_GPU); -#undef REGISTER_GPU - -// TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to -// its canonical form. - -#endif // GOOGLE_CUDA - -} // namespace tensorflow diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 9897c31a98e0b335c18a84825fc518ed1fc310a2..8285ea04926d3a24e9c22bd6d69eb7a48f5e3a85 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import argparse import collections +import functools import itertools import os import sys @@ -34,7 +35,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.framework.test_util import TensorFlowTestCase +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_nn_ops @@ -57,6 +58,7 @@ from tensorflow.python.training import gradient_descent from tensorflow.python.training import momentum from tensorflow.python.training import rmsprop from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training.checkpointable import util as checkpointable_utils CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM @@ -265,7 +267,7 @@ def _CreateCudnnCompatibleCanonicalRNN(rnn, inputs, is_bidi=False, scope=None): return outputs, (output_state_fw, output_state_bw) -class CudnnRNNTestBasic(TensorFlowTestCase): +class CudnnRNNTestBasic(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") @@ -467,7 +469,7 @@ class CudnnRNNTestBasic(TensorFlowTestCase): # TODO(jamesqin): Transform to parameterized test after it is included in the # TF open source codebase. -class CudnnRNNTestSaveRestore(TensorFlowTestCase): +class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase): def _CompareWeights(self, lhs, rhs): self.assertEqual(len(lhs), len(rhs)) @@ -701,9 +703,146 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase): self._TestSaveRestoreHelper(CUDNN_RNN_RELU) +class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): + + def _VerifyCheckpoint( + self, checkpoint_path, compatible_cell_fn, cudnn_cell_fn, + num_layers, input_size, expected_variable_values, num_applications=3): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + with ops.device("gpu:0"): + cudnn_layer = cudnn_cell_fn() + cudnn_checkpoint = checkpointable_utils.Checkpoint(cell=cudnn_layer) + status = cudnn_checkpoint.restore(checkpoint_path) + inputs = 3. * array_ops.ones([num_applications, num_layers, input_size], + dtype=dtypes.float32) + cudnn_output, _ = cudnn_layer(inputs) + status.run_restore_ops() + second_save_path = cudnn_checkpoint.save(checkpoint_prefix) + restore_layer = compatible_cell_fn() + restore_layer_checkpoint = checkpointable_utils.Checkpoint( + cell=restore_layer) + status = restore_layer_checkpoint.restore(second_save_path) + current_state = restore_layer.zero_state(1, dtypes.float32) + for _ in range(num_applications): + restore_layer_output, current_state = restore_layer( + inputs=3. * array_ops.ones([1, input_size]), + state=current_state) + status.run_restore_ops() + self.assertTrue(restore_layer.variables) + for variable, expected_value in zip( + restore_layer.variables, expected_variable_values): + self.assertAllClose(expected_value, self.evaluate(variable)) + self.assertAllClose(self.evaluate(restore_layer_output), + self.evaluate(cudnn_output)[-1, -1:, ...]) + + def _CheckpointableSingleCellUnidirectionalTestTemplate( + self, single_cell_fn, cudnn_cell_fn): + # Single-layer cuDNN cells with object-based checkpointing should be + # checkpoint compatible with either single CudnnCompatible cells or + # MultiRnnCells with one cell. + input_size = 3 + save_cell_layer = single_cell_fn() + save_cell_layer( + inputs=array_ops.ones([1, input_size]), + state=save_cell_layer.zero_state(1, dtypes.float32)) + self.assertTrue(save_cell_layer.variables) + expected_values = [] + np.random.seed(10) + for variable in save_cell_layer.variables: + value = np.random.normal(size=variable.shape) + expected_values.append(value) + self.evaluate(variable.assign(value)) + save_checkpoint = checkpointable_utils.Checkpoint(cell=save_cell_layer) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + first_save_path = save_checkpoint.save(checkpoint_prefix) + self._VerifyCheckpoint( + checkpoint_path=first_save_path, + compatible_cell_fn= + lambda: rnn_cell_impl.MultiRNNCell([single_cell_fn()]), + cudnn_cell_fn=cudnn_cell_fn, + num_layers=1, + expected_variable_values=expected_values, + input_size=input_size) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + @test_util.run_in_graph_and_eager_modes() + def testLSTMCheckpointableSingleLayer(self): + num_units = 2 + direction = CUDNN_RNN_UNIDIRECTION + self._CheckpointableSingleCellUnidirectionalTestTemplate( + single_cell_fn=functools.partial( + cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units), + cudnn_cell_fn=functools.partial( + cudnn_rnn.CudnnLSTM, num_layers=1, num_units=num_units, + direction=direction, name="awesome_lstm")) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + @test_util.run_in_graph_and_eager_modes() + def testGRUCheckpointableSingleLayer(self): + num_units = 2 + direction = CUDNN_RNN_UNIDIRECTION + with self.assertRaises(NotImplementedError): + # TODO(allenl): Implement object-based saving for GRUs and other cells. + self._CheckpointableSingleCellUnidirectionalTestTemplate( + single_cell_fn=functools.partial( + cudnn_rnn_ops.CudnnCompatibleGRUCell, num_units=num_units), + cudnn_cell_fn=functools.partial( + cudnn_rnn.CudnnGRU, num_layers=1, num_units=num_units, + direction=direction, name="awesome_gru")) + + def _CheckpointableMultiLayerTestTemplate( + self, single_cell_fn, cudnn_cell_fn, num_layers): + + def _MultiCellFn(): + return rnn_cell_impl.MultiRNNCell( + [single_cell_fn() for _ in range(num_layers)]) + input_size = 3 + save_graph = ops.Graph() + with save_graph.as_default(), self.test_session(graph=save_graph): + save_layer = _MultiCellFn() + save_layer(inputs=array_ops.ones([1, input_size]), + state=save_layer.zero_state(1, dtypes.float32)) + self.assertTrue(save_layer.variables) + expected_values = [] + np.random.seed(10) + for variable in save_layer.variables: + value = np.random.normal(size=variable.shape) + expected_values.append(value) + self.evaluate(variable.assign(value)) + save_checkpoint = checkpointable_utils.Checkpoint(cell=save_layer) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + first_save_path = save_checkpoint.save(checkpoint_prefix) + self._VerifyCheckpoint( + checkpoint_path=first_save_path, + compatible_cell_fn=_MultiCellFn, cudnn_cell_fn=cudnn_cell_fn, + num_layers=num_layers, + expected_variable_values=expected_values, + input_size=input_size) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + @test_util.run_in_graph_and_eager_modes() + def testCudnnCompatibleLSTMCheckpointablMultiLayer(self): + num_units = 2 + num_layers = 3 + direction = CUDNN_RNN_UNIDIRECTION + self._CheckpointableMultiLayerTestTemplate( + single_cell_fn=functools.partial( + cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units), + cudnn_cell_fn=functools.partial( + cudnn_rnn.CudnnLSTM, num_layers=num_layers, num_units=num_units, + direction=direction, name="awesome_lstm"), + num_layers=num_layers) + + # TODO(jamesqin): Transform to parameterized test after it is included in the # TF open source codebase. -class CudnnRNNTestCompatibleRNNCells(TensorFlowTestCase): +class CudnnRNNTestCompatibleRNNCells(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") @@ -884,7 +1023,7 @@ class CudnnRNNTestCompatibleRNNCells(TensorFlowTestCase): rtol=2e-5) -class CudnnRNNTestParamsSize(TensorFlowTestCase): +class CudnnRNNTestParamsSize(test_util.TensorFlowTestCase): def _TestOpaqueParamsSize(self, rnn_mode, num_layers, num_units, input_size, dtype, direction): @@ -931,7 +1070,18 @@ class CudnnRNNTestParamsSize(TensorFlowTestCase): dtype, direction) -class CudnnRNNTestTraining(TensorFlowTestCase): +class CudnnRNNTestTraining(test_util.TensorFlowTestCase): + + def setUp(self): + super(CudnnRNNTestTraining, self).setUp() + self._reset_rnd_gen_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", + str(False)) + self._rnn_use_v2 = os.environ.get("TF_CUDNN_RNN_USE_V2", "0") + + def tearDown(self): + super(CudnnRNNTestTraining, self).tearDown() + os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = self._reset_rnd_gen_state + os.environ["TF_CUDNN_RNN_USE_V2"] = self._rnn_use_v2 def _ComputeNumericGrad(self, sess, y, x, delta=1e-4, step=1): """Compute the numeric gradient of y wrt to x. @@ -1045,11 +1195,10 @@ class CudnnRNNTestTraining(TensorFlowTestCase): def _TestOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size, batch_size, seq_length, dir_count, dropout, dtype, - delta, tolerance): + use_v2, delta, tolerance): # Gradient checking runs two forward ops with almost the same input. Need to # make sure the drop patterns across the two runs are the same. logging.info("Training test with config: %s", locals()) - old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False)) os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True) np.random.seed(1234) @@ -1057,6 +1206,10 @@ class CudnnRNNTestTraining(TensorFlowTestCase): has_input_c = (rnn_mode == CUDNN_LSTM) direction = (CUDNN_RNN_UNIDIRECTION if dir_count == 1 else CUDNN_RNN_BIDIRECTION) + if use_v2: + os.environ["TF_CUDNN_RNN_USE_V2"] = "1" + else: + os.environ["TF_CUDNN_RNN_USE_V2"] = "0" model = CudnnTestModel( rnn_mode, num_layers, @@ -1106,22 +1259,22 @@ class CudnnRNNTestTraining(TensorFlowTestCase): self._GradientCheck( sess, total_sum, all_inputs, tolerance=tolerance, delta=delta) - os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state def _TestSimpleTrainingHelper(self, rnn_mode, test_configs): dropouts = [0, 0.5, 1.] - for config, dropout in itertools.product(test_configs, dropouts): + v2_options = [str(False), str(True)] + for config, dropout, use_v2 in itertools.product(test_configs, dropouts, + v2_options): dtype = config.get("dtype", dtypes.float32) delta = config.get("delta", 1e-4) tolerance = config.get("tolerance", 1e-6) dir_count = config.get("dir_count", 1) shape = config["shape"] with ops.Graph().as_default(): - self._TestOneSimpleTraining(rnn_mode, shape["num_layers"], - shape["num_units"], shape["input_size"], - shape["batch_size"], shape["seq_length"], - dir_count, dropout, dtype, delta, - tolerance) + self._TestOneSimpleTraining( + rnn_mode, shape["num_layers"], shape["num_units"], + shape["input_size"], shape["batch_size"], shape["seq_length"], + dir_count, dropout, dtype, use_v2, delta, tolerance) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 36fba917a8f56c26fd5b4c3468d1d980a8ba2ba5..d58198faf353aab68430d2fa153a18de359112de 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -142,6 +142,9 @@ class _CudnnRNN(base_layer.Layer): """ # pylint:enable=line-too-long + # TODO(allenl): Document object-based saving and checkpoint compatibility once + # it's implemented for more cuDNN Layers. + # The following are constants defined by subclasses. # Type of RNN cell. _rnn_mode = None @@ -355,7 +358,8 @@ class _CudnnRNN(base_layer.Layer): "CUDA/CuDNN generations.") # Initialize opaque params with a tensor. self.kernel = vs.get_variable( - "opaque_kernel", initializer=opaque_params_t, validate_shape=False) + "opaque_kernel", dtype=self._plain_dtype, + initializer=opaque_params_t, validate_shape=False) # Create saveable in the outer scope of the cudnn subgraph, such that # alternative subgraph with platform-independent rnn cells can load the # checkpoints directly. @@ -363,6 +367,11 @@ class _CudnnRNN(base_layer.Layer): self._create_saveable() self.built = True + def _gather_saveables_for_checkpoint(self): + raise NotImplementedError( + "This cell does not yet support object-based saving. File a feature " + "request if this limitation bothers you.") + def call(self, inputs, initial_state=None, training=True): """Runs the forward step for the RNN model. @@ -499,6 +508,8 @@ class _CudnnRNN(base_layer.Layer): direction=self.direction, scope=vs.get_variable_scope(), name="%s_saveable" % self.trainable_variables[0].name.split(":")[0]) + self._saveable._add_checkpointable_dependencies( # pylint: disable=protected-access + checkpointable=self, dtype=self._plain_dtype) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) @@ -521,6 +532,16 @@ class CudnnLSTM(_CudnnRNN): return ([self.num_layers * self.num_dirs, batch_size, self.num_units], [self.num_layers * self.num_dirs, batch_size, self.num_units]) + @property + def _gather_saveables_for_checkpoint(self): + if self._direction == CUDNN_RNN_UNIDIRECTION: + # Skip one inheritance level to avoid NotImplementedError. + return super(_CudnnRNN, self)._gather_saveables_for_checkpoint + else: + raise NotImplementedError( + "Object-based saving does not currently support bidirectional LSTM " + "cells. File a feature request if this limitation bothers you.") + class _CudnnRNNNoInputC(_CudnnRNN): """Abstract simple CudnnRNN layer without input_c.""" diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index e87162f0ee9cc4eed795555171f55a93639e83cf..8822a7523f6b168f569e29970c9c29f2eb3614fc 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -17,26 +17,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops +import os +from tensorflow.contrib.checkpoint.python import split_dependency from tensorflow.contrib.rnn.python.ops import lstm_ops -from tensorflow.contrib.util import loader -from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.layers import base as base_layer +from tensorflow.python.keras.engine import base_layer from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_cudnn_rnn_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver - -_cudnn_rnn_ops_so = loader.load_op_library( - resource_loader.get_path_to_datafile("_cudnn_rnn_ops.so")) +from tensorflow.python.training.checkpointable import base as checkpointable_lib CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" @@ -91,19 +88,23 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): Cudnn compatible GRU (from Cudnn library user guide): ```python - r_t = sigma(x_t * W_r + h_t-1 * R_h + b_Wr + b_Rr) # reset gate - u_t = sigma(x_t * W_u + h_t-1 * R_u + b_Wu + b_Ru) # update gate - h'_t = tanh(x_t * W_h + r_t .* (h_t-1 * R_h + b_Rh) + b_Wh) # new memory gate - h_t = (1 - u_t) .* h'_t + u_t .* h_t-1 + # reset gate + $$r_t = \sigma(x_t * W_r + h_t-1 * R_h + b_{Wr} + b_{Rr})$$ + # update gate + $$u_t = \sigma(x_t * W_u + h_t-1 * R_u + b_{Wu} + b_{Ru})$$ + # new memory gate + $$h'_t = tanh(x_t * W_h + r_t .* (h_t-1 * R_h + b_{Rh}) + b_{Wh})$$ + $$h_t = (1 - u_t) .* h'_t + u_t .* h_t-1$$ ``` Other GRU (see @{tf.nn.rnn_cell.GRUCell} and @{tf.contrib.rnn.GRUBlockCell}): ```python - h'_t = tanh(x_t * W_h + (r_t .* h_t-1) * R_h + b_Wh) # new memory gate + # new memory gate + \\(h'_t = tanh(x_t * W_h + (r_t .* h_t-1) * R_h + b_{Wh})\\) ``` which is not equivalent to Cudnn GRU: in addition to the extra bias term b_Rh, ```python - r .* (h * R) != (r .* h) * R + \\(r .* (h * R) != (r .* h) * R\\) ``` """ @@ -267,13 +268,16 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): # instead of having the master pull all slices and then save them. slice_spec = "" params = weights + biases - param_names = weight_names + bias_names + self._weight_names = weight_names + self._bias_names = bias_names + self._param_names = weight_names + bias_names + prefixed_param_names = weight_names + bias_names if self._scope: - param_names = ["%s/%s" % (self._scope, pn) for pn in param_names] - + prefixed_param_names = [ + "%s/%s" % (self._scope, pn) for pn in prefixed_param_names] specs = [ saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name) - for param, param_name in zip(params, param_names) + for param, param_name in zip(params, prefixed_param_names) ] super(CudnnOpaqueParamsSaveable, self).__init__( array_ops.identity(self._variables), specs, name) @@ -286,6 +290,45 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): return state_ops.assign( self._variables, opaque_params, validate_shape=False) + def _checkpointable_save(self, save_buffer): + weights, biases = self._OpaqueParamsToCanonical() + with ops.device("gpu:0"): + (weights, _), (biases, _) = self._TransformCanonical( + weights, biases) + for name, tensor in zip(self._param_names, weights + biases): + save_buffer[name] = array_ops.identity(tensor) + + def _checkpointable_restore(self, restore_buffer): + tensors = [array_ops.identity(restore_buffer[name]) + for name in self._param_names] + return self.restore( + restored_tensors=tensors, + restored_shapes=None # Unused + ) + + def _add_checkpointable_dependencies(self, checkpointable, dtype): + """Add canonical weight dependencies to `checkpointable`. + + When saving or restoring, converts to or from the opaque buffer + format. Weights are saved and loaded in the configuration expected by + cuDNN-compatible cells. + + Args: + checkpointable: An object inheriting from `CheckpointableBase` to add + dependencies too (typically the cuDNN `Layer`). + dtype: The dtype for the canonical parameter Tensors. + """ + split_dependencies = split_dependency.split_dependency( + component_names=self._param_names, + component_dtypes=(dtype,) * len(self._param_names), + fill_save_buffer_fn=self._checkpointable_save, + consume_restore_buffer_fn=self._checkpointable_restore) + self._checkpointable_track_params(checkpointable, split_dependencies) + + def _checkpointable_track_params(self, checkpointable, params): + """Tracks parameters in a canonical configuration.""" + return # NotImplementedError raised by the Layer. + def _TFCanonicalNamePrefix(self, layer, is_fwd=True): if self._direction == CUDNN_RNN_UNIDIRECTION: return "rnn/multi_rnn_cell/cell_%d/%s" % (layer, self._rnn_cell_name) @@ -481,10 +524,7 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): _rnn_mode = CUDNN_LSTM _num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER - # pylint:disable=protected-access - _rnn_cell_name = base_layer._to_snake_case(CudnnCompatibleLSTMCell.__name__) - - # pylint:enable=protected-access + _rnn_cell_name = base_layer.to_snake_case(CudnnCompatibleLSTMCell.__name__) def _cudnn_to_tf_gate_params(self, *cu_gate_order): i_g, f_g, c_g, o_g = cu_gate_order @@ -575,6 +615,29 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): tf_biases.append(b) tf_bias_names.append(prefix + "/bias") + def _checkpointable_track_params(self, checkpointable, params): + """Track parameters for compatibility with CudnnCompatibleLSTMCell.""" + biases = [] + weights = [] + for name in self._weight_names: + weights.append(params[name]) + for name in self._bias_names: + biases.append(params[name]) + assert len(params) == len(weights) + len(biases) + if len(weights) == 1 and len(biases) == 1: + # For single-layer cells, allow substituting a cell with no MultiRNNCell + # wrapping. + kernel, = weights # pylint: disable=unbalanced-tuple-unpacking + bias, = biases # pylint: disable=unbalanced-tuple-unpacking + checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access + checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access + assert len(biases) == len(weights) + for cell_index, (bias, kernel) in enumerate(zip(biases, weights)): + cell = checkpointable_lib.Checkpointable() + checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access + cell.bias = bias + cell.kernel = kernel + class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): """SaveableObject implementation handling Cudnn GRU opaque params.""" @@ -582,10 +645,7 @@ class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): _rnn_mode = CUDNN_GRU _num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER - # pylint:disable=protected-access - _rnn_cell_name = base_layer._to_snake_case(CudnnCompatibleGRUCell.__name__) - - # pylint:enable=protected-access + _rnn_cell_name = base_layer.to_snake_case(CudnnCompatibleGRUCell.__name__) def _cudnn_to_tf_weights(self, *cu_weights): r"""Stitching cudnn canonical weights to generate tf canonical weights.""" @@ -664,11 +724,7 @@ class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): class CudnnRNNSimpleSaveable(CudnnLSTMSaveable): """SaveableObject implementation handling Cudnn RNN Tanh opaque params.""" - # pylint:disable=protected-access - _rnn_cell_name = base_layer._to_snake_case( - rnn_cell_impl.BasicRNNCell.__name__) - - # pylint:enable=protected-access + _rnn_cell_name = base_layer.to_snake_case(rnn_cell_impl.BasicRNNCell.__name__) def _cudnn_to_tf_weights(self, *cu_weights): r"""Stitching cudnn canonical weights to generate tf canonical weights.""" @@ -845,19 +901,27 @@ def _cudnn_rnn(inputs, check_direction(direction) check_input_mode(input_mode) seed, seed2 = random_seed.get_seed(seed) - outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn( - input=inputs, - input_h=input_h, - input_c=input_c, - params=params, - is_training=is_training, - rnn_mode=rnn_mode, - input_mode=input_mode, - direction=direction, - dropout=dropout, - seed=seed, - seed2=seed2, - name=name) + # TODO(jamesqin): switch default value to "1" on May 25th 2018, and get rid + # of V1 ops. + use_cudnn_v2 = os.environ.get("TF_CUDNN_RNN_USE_V2", "0") + args = { + "input": inputs, + "input_h": input_h, + "input_c": input_c, + "params": params, + "is_training": is_training, + "rnn_mode": rnn_mode, + "input_mode": input_mode, + "direction": direction, + "dropout": dropout, + "seed": seed, + "seed2": seed2, + "name": name + } + if use_cudnn_v2 is not "1": + outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args) + else: + outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv2(**args) return (outputs, output_h, output_c) @@ -1582,35 +1646,3 @@ class CudnnRNNRelu(_CudnnRNNNoInputC): # 1 set of weight and bias parameters for the recurrent input, and 1 for the # previous layer input. _NUM_PARAMS_PER_LAYER = CUDNN_RNN_RELU_PARAMS_PER_LAYER - - -@ops.RegisterGradient("CudnnRNN") -def _cudnn_rnn_backward(op, *grad): - if not op.get_attr("is_training"): - raise ValueError( - "CudnnRNN must set is_training to True to be used in gradients") - return gen_cudnn_rnn_ops.cudnn_rnn_backprop( - input=op.inputs[0], - input_h=op.inputs[1], - input_c=op.inputs[2], - params=op.inputs[3], - output=op.outputs[0], - output_h=op.outputs[1], - output_c=op.outputs[2], - output_backprop=grad[0], - output_h_backprop=grad[1], - output_c_backprop=grad[2], - reserve_space=op.outputs[3], - dropout=op.get_attr("dropout"), - seed=op.get_attr("seed"), - seed2=op.get_attr("seed2"), - rnn_mode=op.get_attr("rnn_mode"), - input_mode=op.get_attr("input_mode"), - direction=op.get_attr("direction")) - - -ops.RegisterShape("CudnnRNNParamsSize")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNParamsToCanonical")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNCanonicalToParams")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNN")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNBackprop")(common_shapes.call_cpp_shape_fn) diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 0458199ff771bc45603106411550a39448e515b8..8bdbba83ef6a8541158d956e36caf6a9be435c5b 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -8,6 +8,11 @@ load( "//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_libs", + "if_not_windows", +) +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "if_static", ) py_library( @@ -17,35 +22,25 @@ py_library( deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:iterator_ops", - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/contrib/data/python/ops:shuffle_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", - "//tensorflow/python:parsing_ops", "//tensorflow/python:util", - "//tensorflow/python/data/ops:iterator_ops", ], ) +cc_library( + name = "lib_proto_parsing_for_dataset_ops", + deps = if_not_windows(["//tensorflow/core:lib_proto_parsing"]), +) + tf_custom_op_library( name = "_dataset_ops.so", srcs = ["ops/dataset_ops.cc"], - deps = ["//tensorflow/contrib/data/kernels:dataset_kernels"], + deps = ["//tensorflow/contrib/data/kernels:dataset_kernels"] + + if_static( + extra_deps = [":lib_proto_parsing_for_dataset_ops"], + otherwise = [], + ), ) tf_gen_op_libs( op_lib_names = ["dataset_ops"], ) - -filegroup( - name = "all_files", - srcs = glob( - include = [ - "**/*", - ], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index fcdccdd26ca1824bf13f1fd0cfd80b20ca8a10c3..1af1ed08b53ee04367eb316d5c9caa0216f2e88d 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -23,20 +23,31 @@ removing existing functionality. See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@Counter +@@CheckpointInputPipelineHook +@@CsvDataset +@@SqlDataset +@@assert_element_shape @@batch_and_drop_remainder +@@bucket_by_sequence_length +@@choose_from_datasets @@dense_to_sparse_batch @@enumerate_dataset @@group_by_window @@ignore_errors +@@make_batched_features_dataset +@@make_csv_dataset @@make_saveable_from_iterator @@map_and_batch @@padded_batch_and_drop_remainder @@parallel_interleave +@@prefetch_to_device @@read_batch_features @@rejection_resample +@@sample_from_datasets @@scan @@shuffle_and_repeat +@@sliding_window_batch @@sloppy_interleave @@unbatch @@ -49,6 +60,7 @@ from __future__ import print_function # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops.batching import assert_element_shape from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch from tensorflow.contrib.data.python.ops.batching import map_and_batch @@ -58,16 +70,27 @@ from tensorflow.contrib.data.python.ops.counter import Counter from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset from tensorflow.contrib.data.python.ops.error_ops import ignore_errors from tensorflow.contrib.data.python.ops.get_single_element import get_single_element +from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length from tensorflow.contrib.data.python.ops.grouping import group_by_window from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave +from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave +from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator +from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device +from tensorflow.contrib.data.python.ops.readers import CsvDataset +from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset +from tensorflow.contrib.data.python.ops.readers import make_csv_dataset from tensorflow.contrib.data.python.ops.readers import read_batch_features from tensorflow.contrib.data.python.ops.readers import SqlDataset from tensorflow.contrib.data.python.ops.resampling import rejection_resample from tensorflow.contrib.data.python.ops.scan_ops import scan from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat +from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(__name__) + +# A constant that can be used to enable auto-tuning. +AUTOTUNE = -1 diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index 56471911c5c0d1c1825955c67997b5bbc0786463..7b69e10441eba3e38c979d5715c16699ac2710ed 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -10,6 +10,7 @@ cc_library( name = "prefetching_kernels", srcs = ["prefetching_kernels.cc"], deps = [ + "//tensorflow/core:core_cpu_headers_lib", "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", @@ -17,6 +18,27 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "directed_interleave_dataset_op", + srcs = ["directed_interleave_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + +cc_library( + name = "csv_dataset_op", + srcs = ["csv_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], +) + cc_library( name = "ignore_errors_dataset_op", srcs = ["ignore_errors_dataset_op.cc"], @@ -28,24 +50,37 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "threadpool_dataset_op", + srcs = ["threadpool_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], +) + +cc_library( + name = "unique_dataset_op", + srcs = ["unique_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], +) + cc_library( name = "dataset_kernels", deps = [ + ":csv_dataset_op", + ":directed_interleave_dataset_op", ":ignore_errors_dataset_op", ":prefetching_kernels", + ":threadpool_dataset_op", + ":unique_dataset_op", "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4657807785d58727d34f37172bd30c56a5b7cde6 --- /dev/null +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -0,0 +1,734 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/parsing_ops.cc. +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/io/random_inputstream.h" + +namespace tensorflow { +namespace { + +class CSVDatasetOp : public DatasetOpKernel { + public: + explicit CSVDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* filenames_tensor; + OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); + OP_REQUIRES( + ctx, filenames_tensor->dims() <= 1, + errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + + OpInputList record_defaults_list; + OP_REQUIRES_OK(ctx, + ctx->input_list("record_defaults", &record_defaults_list)); + for (int i = 0; i < record_defaults_list.size(); ++i) { + OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2, + errors::InvalidArgument( + "There should only be 1 default per field but field ", i, + " has ", record_defaults_list[i].NumElements())); + } + + const Tensor* select_cols_tensor; + OP_REQUIRES_OK(ctx, ctx->input("select_cols", &select_cols_tensor)); + OP_REQUIRES(ctx, select_cols_tensor->dims() == 1, + errors::InvalidArgument("`select_cols` must be a vector.")); + + int64 buffer_size; + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "buffer_size", &buffer_size)); + OP_REQUIRES(ctx, buffer_size > 0, + errors::InvalidArgument("buffer_size should be positive")); + + string delim; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "field_delim", &delim)); + OP_REQUIRES(ctx, delim.size() == 1, + errors::InvalidArgument("field_delim should be only 1 char")); + + bool header; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "header", &header)); + + bool use_quote_delim; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "use_quote_delim", + &use_quote_delim)); + string na_value; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "na_value", &na_value)); + + std::vector record_defaults; + record_defaults.reserve(record_defaults_list.size()); + for (const Tensor& t : record_defaults_list) { + record_defaults.push_back(t); + } + + std::vector filenames; + filenames.reserve(filenames_tensor->NumElements()); + for (int i = 0; i < filenames_tensor->NumElements(); ++i) { + filenames.push_back(filenames_tensor->flat()(i)); + } + + std::vector select_cols; + select_cols.reserve(select_cols_tensor->NumElements()); + for (int i = 0; i < select_cols_tensor->NumElements(); ++i) { + select_cols.push_back(select_cols_tensor->flat()(i)); + } + OP_REQUIRES( + ctx, output_types_.size() == select_cols.size() || select_cols.empty(), + errors::InvalidArgument("select_cols should match output size")); + for (int i = 1; i < select_cols.size(); i++) { + OP_REQUIRES(ctx, select_cols[i - 1] < select_cols[i], + errors::InvalidArgument( + "select_cols should be strictly increasing indices")); + } + OP_REQUIRES( + ctx, select_cols.empty() || select_cols.front() >= 0, + errors::InvalidArgument("select_cols should be non-negative indices")); + + *output = new Dataset(ctx, std::move(filenames), header, buffer_size, + output_types_, output_shapes_, + std::move(record_defaults), std::move(select_cols), + use_quote_delim, delim[0], std::move(na_value)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, std::vector filenames, bool header, + int64 buffer_size, const DataTypeVector& output_types, + const std::vector& output_shapes, + std::vector record_defaults, std::vector select_cols, + bool use_quote_delim, char delim, string na_value) + : GraphDatasetBase(ctx), + filenames_(std::move(filenames)), + header_(header), + buffer_size_(buffer_size), + out_type_(output_types), + output_shapes_(output_shapes), + record_defaults_(std::move(record_defaults)), + select_cols_(std::move(select_cols)), + use_quote_delim_(use_quote_delim), + delim_(delim), + na_value_(std::move(na_value)) {} + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::CSV")})); + } + + const DataTypeVector& output_dtypes() const override { return out_type_; } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { return "CSVDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Node** output) const override { + // TODO(rachelim): Implement this + std::vector input_tensors; + TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output)); + return errors::Unimplemented("CSVDataset: AsGraphDefInternal"); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + bool select_all = dataset()->select_cols_.empty(); + do { + // We are currently processing a file, so try to read the next record + if (input_stream_) { + Status s = ReadRecord(ctx, out_tensors, select_all, + dataset()->select_cols_); + if (s.ok()) { + // Validate output + if (out_tensors->size() != dataset()->out_type_.size()) { + return errors::InvalidArgument( + "Expect ", dataset()->out_type_.size(), " fields but have ", + out_tensors->size(), " in record"); + } + + *end_of_sequence = false; + return s; + } + if (!errors::IsOutOfRange(s)) { + // Not at the end of file, return OK or non-EOF errors to caller. + *end_of_sequence = false; + return s; + } + // We have reached the end of the current file, so maybe + // move on to next file. + ResetStreamsLocked(); + ++current_file_index_; + } + // Iteration ends when there are no more files to process. + if (current_file_index_ == dataset()->filenames_.size()) { + *end_of_sequence = true; + return Status::OK(); + } + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + // TODO(rachelim): Implement save + return errors::Unimplemented("CSVDataset: SaveInternal"); + } + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + // TODO(rachelim): Implement restore + return errors::Unimplemented("CSVDataset: RestoreInternal"); + } + + private: + // Reads an entire CSV row from the input stream, either from the + // existing buffer or by filling the buffer as needed. Converts extracted + // fields to output tensors as we go. + // + // When this function is called, pos_ should be the index of the first + // character of the record in buffer_, or past the end of the buffer. + // Note: ctx and out_tensors are only used in this function + // when fields are included in the record. + Status ReadRecord(IteratorContext* ctx, std::vector* out_tensors, + bool select_all, const std::vector& selected) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + // At the end of the file, this will return errors::OutOfRange + TF_RETURN_IF_ERROR(FillBuffer(&buffer_)); + pos_ = 0; + } + + // The first character may be \n if this is the continuation of a + // \r\n linebreak between this and the previous record. If so, skip it. + + bool end_of_record = false; // Keep track of when we find \n, \r or EOF + size_t num_parsed = 0; + size_t num_selected_parsed = 0; + + Status result; + + while (!end_of_record) { // Read till we reach \n, \r or EOF + bool include = + select_all || (num_selected_parsed < selected.size() && + selected[num_selected_parsed] == num_parsed); + + // Don't fail fast, so that the next call to GetNext may still return + // a valid record + result.Update( + ParseOneField(ctx, out_tensors, &end_of_record, include)); + + num_parsed++; + if (include) num_selected_parsed++; + } + + return result; + } + + // Parses one field from position pos_ in the buffer. Fields are + // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of + // the next field. + Status ParseOneField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + // If we get here, this means the previous field's end coincided + // with the end of the buffer. We can fill the buffer without abandon. + Status s = FillBuffer(&buffer_); + + if (errors::IsOutOfRange(s)) { + // Reached EOF, and last field is empty + *end_of_record = true; + if (include) { + return FieldToOutput(ctx, StringPiece(), out_tensors); + } else { + return Status::OK(); + } + } else if (!s.ok()) { + return s; // Surface other errors back to caller + } + + pos_ = 0; + } + + if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') { + return ParseQuotedField(ctx, out_tensors, end_of_record, include); + } + + return ParseUnquotedField(ctx, out_tensors, end_of_record, include); + } + + // For keeping track of relevant parts of a field from a previous buffer + struct Piece { + size_t start; + size_t len; + string buffer; + + Piece(string buffer, size_t start, size_t len) + : start(start), len(len), buffer(std::move(buffer)) {} + }; + + // Given that pos_ exceeds the buffer, saves the relevant part of the + // current buffer (if necessary), fills the buffer, and resets indices to + // 0. + Status SaveAndFillBuffer(std::vector* earlier_pieces, + size_t* start, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + string temp_buffer; + + buffer_.swap(temp_buffer); + if (include && pos_ > *start) { + earlier_pieces->push_back( + Piece(std::move(temp_buffer), *start, pos_ - *start)); + } + pos_ = 0; + *start = 0; + return FillBuffer(&buffer_); + } + + // Parses unquoted field from position pos_ in the buffer. Continually + // reads from buffer until end of field is reached (delim, CRLF, or EOF). + // Advances pos_ to keep track of our position in the buffer as we go, + // stopping at the first character of the next field. + Status ParseQuotedField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::vector earlier_pieces; + size_t start = pos_; + pos_++; // Starting quotation mark + + Status parse_result; + while (true) { // Each iter reads 1 char, filling buffer if necessary + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + if (errors::IsOutOfRange(s)) { + return errors::InvalidArgument( + "Reached end of file without closing quoted field in " + "record"); + } else if (!s.ok()) { + return s; // Surface all other errors to caller + } + } + + char ch = buffer_[pos_]; + if (ch == '"') { + // When we encounter a quote, we look ahead to the next character to + // decide what to do + pos_++; + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + if (errors::IsOutOfRange(s)) { + // This was the last field. We are done + *end_of_record = true; + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(), out_tensors, earlier_pieces, include)); + return parse_result; + } else if (!s.ok()) { + return s; + } + } + + char next = buffer_[pos_]; + pos_++; + if (next == dataset()->delim_) { + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - 1 - start), + out_tensors, earlier_pieces, include)); + return parse_result; + + } else if (next == '\n' || next == '\r') { + *end_of_record = true; + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - 1 - start), + out_tensors, earlier_pieces, include)); + if (next == '\r') SkipNewLineIfNecessary(); + return parse_result; + } else if (next != '"') { + // Take note of the error, but keep going to end of field. + include = false; // So we don't get funky errors when trying to + // unescape the quotes. + parse_result.Update(errors::InvalidArgument( + "Quote inside a string has to be escaped by another quote")); + } + + } else { + pos_++; + } + } + } + + // Converts quoted field to an output tensor, removing the starting + // and ending quotes from it and unescaping double quotations if + // necessary. + Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors, + const std::vector& earlier_pieces, + bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!include) return Status::OK(); + + if (earlier_pieces.empty()) { + if (field.find('\"', 1) == field.size() - 1) { + // `field` contains no escaped quotation marks. + // Exclude framing quotation marks + field.remove_prefix(1); + field.remove_suffix(1); + return FieldToOutput(ctx, field, out_tensors); + } + } + string field_complete; + size_t str_len = field.size(); + for (const Piece& p : earlier_pieces) { + str_len += p.len; + } + field_complete.reserve(str_len); + + // This bool flips every time we see a quote, so that we skip the second + // quote of every pair of adjacent quotes in the field. We need to track + // this across iterations of the for loop because adjacent double quotes + // may be in different buffers. Initialize to true because we also skip + // the opening quotation mark of the quoted field. + bool skip_next_quote = true; + for (const Piece& p : earlier_pieces) { + AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len), + &field_complete, &skip_next_quote); + } + AppendUnescapedPiece(field, &field_complete, &skip_next_quote); + StringPiece result = StringPiece(field_complete); + result.remove_suffix(1); // Skip final quote + + return FieldToOutput(ctx, result, out_tensors); + } + + void AppendUnescapedPiece(StringPiece piece, string* field_complete, + bool* skip_next_quote) { + size_t from = 0; + size_t found = piece.find('\"', from); + while (found != string::npos) { + if (!*skip_next_quote) { + // This is the first quote in a pair of adjacent double quotes + field_complete->append(piece.data() + from, found + 1 - from); + } + *skip_next_quote = !*skip_next_quote; + from = found + 1; + found = piece.find('\"', from); + } + // Include the chunk after the last quotation mark in the string + if (from < piece.size()) { + field_complete->append(piece.data() + from, piece.size() - from); + } + } + + // Parses unquoted field from position pos_ in the buffer. Continually + // reads from buffer until end of field is reached (delim, CRLF, or EOF). + // Advances pos_ to keep track of our position in the buffer as we go, + // stopping at the first character of the next field. + Status ParseUnquotedField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::vector earlier_pieces; + size_t start = pos_; + Status parse_result; + + while (true) { // Each iter reads 1 char, filling buffer if necessary + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + // Handle errors + if (errors::IsOutOfRange(s)) { + // Whatever we have is the last field of the last record + *end_of_record = true; + parse_result.Update(UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include)); + return parse_result; + } else if (!s.ok()) { + return s; // Surface all other errors to caller + } + } + + char ch = buffer_[pos_]; + + if (ch == dataset()->delim_) { + parse_result.Update(UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include)); + pos_++; + return parse_result; + } + if (ch == '\n' || ch == '\r') { + // need special case to skip over first \n of record if the line + // breaks are \r\n + parse_result.Update(UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include)); + *end_of_record = true; + pos_++; + if (ch == '\r') SkipNewLineIfNecessary(); + return parse_result; + } + if (dataset()->use_quote_delim_ && ch == '"') { + // Take note of the error, but keep going to end of field. + parse_result.Update(errors::InvalidArgument( + "Unquoted fields cannot have quotes inside")); + } + // Otherwise, go to next character + pos_++; + } + } + + Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + result->clear(); + Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result); + + if (errors::IsOutOfRange(s) && !result->empty()) { + // Ignore OutOfRange error when ReadNBytes read < N bytes. + return Status::OK(); + } + return s; + } + + // Given a field, converts it to the right output tensor type + Status FieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors) { + size_t output_idx = out_tensors->size(); + if (output_idx >= dataset()->out_type_.size()) { + // We can get here if we're selecting all columns, but the number of + // fields exceeds the number of defaults provided + return errors::InvalidArgument("Expect ", dataset()->out_type_.size(), + " fields but have more in record"); + } + const DataType& dtype = dataset()->out_type_[output_idx]; + Tensor component(ctx->allocator({}), dtype, {}); + if ((field.empty() || field == dataset()->na_value_) && + dataset()->record_defaults_[output_idx].NumElements() != 1) { + // If the field is empty or NA value, and default is not given, + // report error. + return errors::InvalidArgument("Field ", output_idx, + " is required but missing in record!"); + } + + switch (dtype) { + // For each case, if the field is empty, we use the default. + // Otherwise, we convert it to the right type. + case DT_INT32: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar()() = + dataset()->record_defaults_[output_idx].flat()(0); + } else { + int32 value; + if (!strings::safe_strto32(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid int32: ", field); + } + component.scalar()() = value; + } + break; + } + case DT_INT64: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar()() = + dataset()->record_defaults_[output_idx].flat()(0); + } else { + int64 value; + if (!strings::safe_strto64(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid int64: ", field); + } + component.scalar()() = value; + } + break; + } + case DT_FLOAT: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar()() = + dataset()->record_defaults_[output_idx].flat()(0); + } else { + float value; + if (!strings::safe_strtof(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid float: ", field); + } + component.scalar()() = value; + } + break; + } + case DT_DOUBLE: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar()() = + dataset()->record_defaults_[output_idx].flat()(0); + } else { + double value; + if (!strings::safe_strtod(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid double: ", field); + } + component.scalar()() = value; + } + break; + } + case DT_STRING: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar()() = + dataset()->record_defaults_[output_idx].flat()(0); + } else { + component.scalar()() = field.ToString(); + } + break; + } + default: + return errors::InvalidArgument("csv: data type ", dtype, + " not supported in field ", + output_idx); + } + out_tensors->push_back(std::move(component)); + return Status::OK(); + } + + // Records can be delimited by "\r\n" line breaks. When we encounter a + // '\r', we have to check the next character to see if it is part of the + // linebreak, and ignore it if so. + void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + Status s = FillBuffer(&buffer_); + pos_ = 0; + // If we failed to fill buffer, it doesn't matter because we're done + // with the record + if (!s.ok()) return; + } + if (buffer_[pos_] == '\n') { + pos_++; + } + } + + // Given a string field, and its index in the output, + // converts it to a Tensor of the right type and adds it to the + // out_tensors vector. + Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors, + const std::vector& earlier_pieces, + bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!include) return Status::OK(); + + if (earlier_pieces.empty()) { + return FieldToOutput(ctx, field, out_tensors); + } + + size_t str_len = field.size(); + for (const Piece& p : earlier_pieces) { + str_len += p.len; + } + string field_complete; + field_complete.reserve(str_len); + + for (const Piece& p : earlier_pieces) { + field_complete.append(p.buffer, p.start, p.len); + } + + field_complete.append(field.data(), field.size()); + return FieldToOutput(ctx, field_complete, out_tensors); + } + + // Sets up reader streams to read from the file at `current_file_index_`. + Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_file_index_ >= dataset()->filenames_.size()) { + return errors::InvalidArgument( + "current_file_index_:", current_file_index_, + " >= filenames_.size():", dataset()->filenames_.size()); + } + + // Actually move on to next file. + TF_RETURN_IF_ERROR(env->NewRandomAccessFile( + dataset()->filenames_[current_file_index_], &file_)); + input_stream_.reset( + new io::RandomAccessInputStream(file_.get(), false)); + buffer_.clear(); + pos_ = 0; + if (dataset()->header_) { + // Read one line, but don't include it. Pass nullptrs as dummy + // pointers to objects that shouldn't be invoked anyway + // We need to process this as a record here instead of just finding + // the first newline because it might contain quoted fields with + // newlines in the header as well + std::vector empty; + Status s = ReadRecord(nullptr, nullptr, false, empty); + if (!s.ok()) { + return errors::InvalidArgument("Can't read header of file"); + } + } + return Status::OK(); + } + + // Resets all reader streams. + void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + input_stream_.reset(); + file_.reset(); + } + + mutex mu_; + string buffer_ GUARDED_BY(mu_); // Maintain our own buffer + size_t pos_ GUARDED_BY( + mu_); // Index into the buffer must be maintained between iters + std::unique_ptr input_stream_ + GUARDED_BY(mu_); + size_t current_file_index_ GUARDED_BY(mu_) = 0; + std::unique_ptr file_ + GUARDED_BY(mu_); // must outlive input_stream_ + }; // class Iterator + + const std::vector filenames_; + const bool header_; + const int64 buffer_size_; + const DataTypeVector out_type_; + const std::vector output_shapes_; + const std::vector record_defaults_; + const std::vector select_cols_; + const bool use_quote_delim_; + const char delim_; + const string na_value_; + }; // class Dataset + + DataTypeVector output_types_; + std::vector output_shapes_; +}; // class CSVDatasetOp + +// Register the kernel implementation for CSVDataset. +REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6a12ca06f4d6cc2096aaf8191a01a899881b43db --- /dev/null +++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc @@ -0,0 +1,280 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/hash/hash.h" + +namespace tensorflow { + +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. + +class DirectedInterleaveDatasetOp : public DatasetOpKernel { + public: + explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + DatasetBase* selector_input; + OP_REQUIRES_OK(ctx, + GetDatasetFromVariantTensor(ctx->input(0), &selector_input)); + + OP_REQUIRES( + ctx, + selector_input->output_dtypes().size() == 1 && + selector_input->output_dtypes()[0] == DT_INT64 && + selector_input->output_shapes().size() == 1 && + selector_input->output_shapes()[0].IsCompatibleWith( + PartialTensorShape({})), + errors::InvalidArgument( + "The selector input must be a dataset of scalar int64 elements.")); + + std::vector data_inputs; + for (size_t i = 1; i < ctx->num_inputs(); ++i) { + DatasetBase* input; + OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input)); + data_inputs.push_back(input); + + OP_REQUIRES( + ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(), + errors::InvalidArgument( + "All inputs must have the same output_dtypes. First input " + "has types ", + DataTypeVectorString(data_inputs[0]->output_dtypes()), + ", and input ", i - 1, " has types ", + DataTypeVectorString(input->output_dtypes()))); + } + *output = new Dataset(ctx, selector_input, std::move(data_inputs)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* selector_input, + std::vector data_inputs) + : GraphDatasetBase(ctx), + selector_input_(selector_input), + data_inputs_(std::move(data_inputs)) { + selector_input_->Ref(); + + output_shapes_ = data_inputs_[0]->output_shapes(); + data_inputs_[0]->Ref(); + for (size_t i = 1; i < data_inputs_.size(); ++i) { + const DatasetBase* data_input = data_inputs_[i]; + data_input->Ref(); + for (size_t j = 0; j < output_shapes_.size(); ++j) { + output_shapes_[j] = MostSpecificCompatibleShape( + output_shapes_[j], data_input->output_shapes()[j]); + } + } + } + + ~Dataset() override { + selector_input_->Unref(); + for (DatasetBase* data_input : data_inputs_) { + data_input->Unref(); + } + } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::DirectedInterleave")})); + } + + const DataTypeVector& output_dtypes() const override { + return data_inputs_[0]->output_dtypes(); + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return strings::StrCat("DirectedInterleaveDatasetOp::Dataset"); + } + + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* selector_input_node; + TF_RETURN_IF_ERROR( + b->AddParentDataset(ctx, selector_input_, &selector_input_node)); + std::vector data_input_nodes(data_inputs_.size()); + for (size_t i = 0; i < data_inputs_.size(); ++i) { + TF_RETURN_IF_ERROR( + b->AddParentDataset(ctx, data_inputs_[i], &data_input_nodes[i])); + } + TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}}, + {{1, data_input_nodes}}, {}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params), + num_active_inputs_(params.dataset->data_inputs_.size()) {} + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator( + ctx, strings::StrCat(prefix(), ".selector"), + &selector_input_impl_)); + data_input_impls_.resize(dataset()->data_inputs_.size()); + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + const DatasetBase* data_input = dataset()->data_inputs_[i]; + TF_RETURN_IF_ERROR(data_input->MakeIterator( + ctx, strings::StrCat(prefix(), "[", i, "]"), + &data_input_impls_[i])); + } + return Status::OK(); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (!selector_input_impl_) { + *end_of_sequence = true; + return Status::OK(); + } + + while (true) { + std::vector selector_result; + *end_of_sequence = false; + TF_RETURN_IF_ERROR(selector_input_impl_->GetNext( + ctx, &selector_result, end_of_sequence)); + if (*end_of_sequence) { + selector_input_impl_.reset(); + for (auto& data_input_impl : data_input_impls_) { + data_input_impl.reset(); + } + return Status::OK(); + } + + int64 selected_input = selector_result[0].scalar()(); + if (selected_input < 0 || selected_input > data_input_impls_.size()) { + return errors::InvalidArgument( + "Selector index out of range: ", selected_input, + " >= ", data_input_impls_.size()); + } + + if (data_input_impls_[selected_input]) { + bool end_of_selected_input = false; + TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext( + ctx, out_tensors, &end_of_selected_input)); + + if (!end_of_selected_input) { + return Status::OK(); + } + + data_input_impls_[selected_input].reset(); + --num_active_inputs_; + + if (num_active_inputs_ == 0) { + selector_input_impl_.reset(); + *end_of_sequence = true; + return Status::OK(); + } + } + + LOG(WARNING) << "DirectedInterleave selected an exhausted input: " + << selected_input; + } + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + if (selector_input_impl_) { + TF_RETURN_IF_ERROR(SaveParent(writer, selector_input_impl_)); + } else { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("selector_input_impl_empty"), "")); + } + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + const auto& data_input_impl = data_input_impls_[i]; + if (data_input_impl) { + TF_RETURN_IF_ERROR(SaveParent(writer, data_input_impl)); + } else { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("data_input_impl_empty[", i, "]")), + "")); + } + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + if (!reader->Contains(full_name("selector_input_impl_empty"))) { + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, selector_input_impl_)); + } else { + selector_input_impl_.reset(); + } + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + if (!reader->Contains(full_name( + strings::StrCat("data_input_impl_empty[", i, "]")))) { + TF_RETURN_IF_ERROR( + RestoreParent(ctx, reader, data_input_impls_[i])); + } else { + data_input_impls_[i].reset(); + } + } + return Status::OK(); + } + + private: + mutex mu_; + std::unique_ptr selector_input_impl_ GUARDED_BY(mu_); + std::vector> data_input_impls_ + GUARDED_BY(mu_); + int64 num_active_inputs_ GUARDED_BY(mu_); + }; + + static PartialTensorShape MostSpecificCompatibleShape( + const PartialTensorShape& ts1, const PartialTensorShape& ts2) { + PartialTensorShape output_tensorshape; + if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank()) + return output_tensorshape; + auto dims1 = ts1.dim_sizes(); + auto dims2 = ts2.dim_sizes(); + for (int d = 0; d < ts1.dims(); d++) { + if (dims1[d] == dims2[d]) + output_tensorshape.Concatenate(dims1[d]); + else + output_tensorshape.Concatenate(-1); + } + return output_tensorshape; + } + + const DatasetBase* const selector_input_; + const std::vector data_inputs_; + std::vector output_shapes_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU), + DirectedInterleaveDatasetOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc index bb29df60e8f114aaa50f578c43e73874f72ab0a3..bbec50681c6f5decec5a3b5fbf09cc3011a21199 100644 --- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc @@ -44,7 +44,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")})); @@ -57,7 +57,9 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "IgnoreErrorsDatasetOp::Dataset"; } + string DebugString() const override { + return "IgnoreErrorsDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, @@ -72,8 +74,11 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index d3df14bdd03476e9ee4015b374512e5bb9893a63..a2bfce03620a1482f5b21cbf23c66833bc5cd480 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" @@ -35,38 +36,25 @@ using FunctionBufferCallback = std::function; class FunctionBufferingResource : public ResourceBase { public: FunctionBufferingResource(FunctionLibraryRuntime* lib, + std::unique_ptr pflr, const NameAttrList& func, int64 buffer_size, const string& source_device, const string& target_device, - const std::vector& func_args, - int64 thread_pool_size) + const std::vector& func_args) : lib_(lib), + pflr_(std::move(pflr)), func_(func), buffer_size_(buffer_size), source_device_(source_device), target_device_(target_device), func_args_(func_args), - thread_pool_(new thread::ThreadPool(Env::Default(), ThreadOptions(), - "buffer_resource", thread_pool_size, - false /* low_latency_hint */)), handle_(kInvalidHandle), is_buffering_(false), end_of_sequence_(false), - cancelled_(false) { - runner_ = [this](std::function c) { - thread_pool_->Schedule(std::move(c)); - }; - } + cancelled_(false) {} ~FunctionBufferingResource() override { Cancel(); - { - mutex_lock l(mu_); - while (is_buffering_) { - cond_var_.wait(l); - } - } - delete thread_pool_; } string DebugString() override { @@ -100,6 +88,20 @@ class FunctionBufferingResource : public ResourceBase { void Cancel() LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); cancelled_ = true; + while (is_buffering_) { + cond_var_.wait(l); + } + } + + // Cancels all pending operations and then clears out the state. + void Reset() LOCKS_EXCLUDED(mu_) { + Cancel(); + mutex_lock l(mu_); + buffer_.clear(); + requests_.clear(); + is_buffering_ = false; + end_of_sequence_ = false; + cancelled_ = false; } // If the buffer has anything, runs `callback` on the first element in the @@ -164,15 +166,12 @@ class FunctionBufferingResource : public ResourceBase { for (int i = 0; i < cancellation_callbacks.size(); ++i) { cancellation_callbacks[i](cancellation_buffer_elements[i]); } - // We only wait on cond_var_ in the destructor, so there would atmost be - // one waiter to notify. - cond_var_.notify_one(); + cond_var_.notify_all(); return; } FunctionLibraryRuntime::Options opts; // Copied from CapturedFunction::generate_step_id(); opts.step_id = -std::abs(static_cast(random::New64())); - opts.runner = &runner_; opts.source_device = source_device_; AllocatorAttributes arg_alloc_attr; arg_alloc_attr.set_on_host(true); @@ -191,13 +190,12 @@ class FunctionBufferingResource : public ResourceBase { mutex_lock l(mu_); BufferElement buffer_element; buffer_element.status = status; - if (!status.ok()) { + if (status.ok()) { + buffer_element.value.swap(*rets); + } else { end_of_sequence_ = true; is_buffering_ = false; - buffer_.push_back(std::move(buffer_element)); - return; } - buffer_element.value.swap(*rets); buffer_.push_back(std::move(buffer_element)); if (!requests_.empty()) { buffer_front = std::move(buffer_.front()); @@ -205,9 +203,16 @@ class FunctionBufferingResource : public ResourceBase { callback = std::move(requests_.front()); requests_.pop_front(); } - if (buffer_.size() < buffer_size_) { + if (buffer_.size() < buffer_size_ && !end_of_sequence_) { restart_buffering = true; } else { + // When the buffer is full, we don't want to call + // FillBuffer() unless we're in cancellation phase in which + // case FillBuffer() will do the final cleanup post + // cancellation. + if (cancelled_) { + restart_buffering = true; + } is_buffering_ = false; } } @@ -222,16 +227,15 @@ class FunctionBufferingResource : public ResourceBase { mutex mu_; FunctionLibraryRuntime* lib_; + std::unique_ptr pflr_; NameAttrList func_; const int64 buffer_size_; const string source_device_; const string target_device_; const std::vector func_args_; - thread::ThreadPool* thread_pool_; FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_); std::deque buffer_ GUARDED_BY(mu_); std::deque requests_ GUARDED_BY(mu_); - std::function)> runner_ = nullptr; bool is_buffering_ GUARDED_BY(mu_); bool end_of_sequence_ GUARDED_BY(mu_); bool cancelled_ GUARDED_BY(mu_); @@ -241,12 +245,22 @@ class FunctionBufferingResource : public ResourceBase { class FunctionBufferResourceHandleOp : public OpKernel { public: explicit FunctionBufferResourceHandleOp(OpKernelConstruction* ctx) - : OpKernel(ctx) { + : OpKernel(ctx), flib_def_(nullptr) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("thread_pool_size", &thread_pool_size_)); + } + + ~FunctionBufferResourceHandleOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } } void Compute(OpKernelContext* ctx) override { @@ -267,33 +281,44 @@ class FunctionBufferResourceHandleOp : public OpKernel { const string& source_device = ctx->device()->name(); - ContainerInfo cinfo; - OP_REQUIRES_OK(ctx, cinfo.Init(ctx->resource_manager(), def())); - // Create the resource. - FunctionBufferingResource* buffer; - OP_REQUIRES_OK( - ctx, ctx->resource_manager()->LookupOrCreate( - cinfo.container(), cinfo.name(), &buffer, - [lib, &source_device, &target_device, func_args, - this](FunctionBufferingResource** ptr) { - *ptr = new FunctionBufferingResource( - lib, func_, buffer_size_, source_device, target_device, - func_args, thread_pool_size_); - return Status::OK(); - })); - OP_REQUIRES_OK(ctx, buffer->Instantiate()); + mutex_lock l(mu_); + if (!initialized_) { + OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def())); + FunctionLibraryRuntime* clone_lib; + std::unique_ptr pflr; + OP_REQUIRES_OK(ctx, lib->Clone(&flib_def_, &pflr, &clone_lib)); + // Create the resource. + FunctionBufferingResource* buffer; + OP_REQUIRES_OK( + ctx, + ctx->resource_manager()->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &buffer, + [clone_lib, &pflr, &source_device, &target_device, func_args, + this](FunctionBufferingResource** ptr) { + *ptr = new FunctionBufferingResource( + clone_lib, std::move(pflr), func_, buffer_size_, + source_device, target_device, func_args); + return Status::OK(); + })); + core::ScopedUnref s(buffer); + OP_REQUIRES_OK(ctx, buffer->Instantiate()); + initialized_ = true; + } OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( - ctx, 0, cinfo.container(), cinfo.name(), + ctx, 0, cinfo_.container(), cinfo_.name(), MakeTypeIndex())); } private: + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; + std::unique_ptr flib_def_; NameAttrList func_; int64 buffer_size_; string container_; string name_; - int64 thread_pool_size_; }; REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource") @@ -334,25 +359,27 @@ class FunctionBufferingResourceGetNextOp : public AsyncOpKernel { OP_REQUIRES_OK_ASYNC( ctx, LookupResource(ctx, handle, &buffer), done); - core::ScopedUnref s(buffer); if (buffer->Finished()) { + buffer->Unref(); ctx->SetStatus(errors::OutOfRange("end_of_sequence")); done(); return; } FunctionBufferCallback callback = - [ctx, done](const BufferElement& buffer_element) { + [ctx, buffer, done](const BufferElement& buffer_element) { Status s = buffer_element.status; if (!s.ok()) { ctx->SetStatus(s); + buffer->Unref(); done(); return; } for (size_t i = 0; i < buffer_element.value.size(); ++i) { ctx->set_output(i, buffer_element.value[i]); } + buffer->Unref(); done(); }; buffer->MaybeGet(std::move(callback)); @@ -374,4 +401,62 @@ REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext") FunctionBufferingResourceGetNextOp); #endif // TENSORFLOW_USE_SYCL +// Resets the FunctionBufferingResource, cancelling all pending requests and +// clearing out the buffer. +class FunctionBufferingResourceResetOp : public OpKernel { + public: + explicit FunctionBufferingResourceResetOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + ~FunctionBufferingResourceResetOp() override {} + + void Compute(OpKernelContext* ctx) override { + ResourceHandle handle; + OP_REQUIRES_OK(ctx, + HandleFromInput(ctx, "function_buffer_resource", &handle)); + FunctionBufferingResource* buffer = nullptr; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, handle, &buffer)); + core::ScopedUnref s(buffer); + + buffer->Reset(); + } +}; + +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset") + .Device(DEVICE_CPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset") + .Device(DEVICE_GPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +#if TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset") + .Device(DEVICE_SYCL) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +#endif // TENSORFLOW_USE_SYCL + +class IteratorGetDeviceOp : public OpKernel { + public: + using OpKernel::OpKernel; + + void Compute(OpKernelContext* ctx) override { + // NOTE(mrry): We do not currently Validate that the handle + // corresponds to a real IteratorResource, because that symbol is + // not exposed from the framework library. + Tensor* device_name_t; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({}), &device_name_t)); + // NOTE(mrry): Since the operation's input is a resource, we must be + // colocated with it, and so we can simply return the current device's + // name without looking at the input. + device_name_t->scalar()() = ctx->device()->name(); + } +}; + +REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU), + IteratorGetDeviceOp); + } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3dfc3741c2b040dd5be3223c24f0715ba3be4248 --- /dev/null +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -0,0 +1,198 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { +namespace { + +class ThreadPoolResource : public ResourceBase { + public: + ThreadPoolResource(Env* env, const ThreadOptions& thread_options, + const string& name, int num_threads, bool low_latency_hint) + : thread_pool_(env, thread_options, name, num_threads, low_latency_hint) { + } + + // Schedules fn() for execution in the pool of threads. + void Schedule(std::function fn) { + thread_pool_.Schedule(std::move(fn)); + } + + string DebugString() override { return "ThreadPoolResource"; } + + private: + thread::ThreadPool thread_pool_; +}; + +// Creates a handle to a ThreadPool resource. Note that we don't use +// ResourceOpKernel here because the ThreadPoolResource constructor requires +// access to `OpKernelContext::env()`, which isn't provided by +// `ResourceOpKernel::CreateResource()`. +class ThreadPoolHandleOp : public OpKernel { + public: + explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_)); + OP_REQUIRES( + ctx, num_threads_ > 0, + errors::InvalidArgument("`num_threads` must be greater than zero.")); + } + + // The resource is deleted from the resource manager only when it is private + // to kernel. Ideally the resource should be deleted when it is no longer held + // by anyone, but it would break backward compatibility. + ~ThreadPoolHandleOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + ThreadPoolResource* resource; + OP_REQUIRES_OK(ctx, mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, ctx](ThreadPoolResource** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + *ret = new ThreadPoolResource( + ctx->env(), {}, display_name_, + num_threads_, + false /* low_latency_hint */); + return Status::OK(); + })); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; + string display_name_; + int num_threads_; +}; + +class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { + public: + explicit ThreadPoolDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + ThreadPoolResource* threadpool_resource; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), + &threadpool_resource)); + core::ScopedUnref unref_iterator(threadpool_resource); + + *output = new Dataset(ctx, input, threadpool_resource); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + ThreadPoolResource* threadpool) + : GraphDatasetBase(ctx), input_(input), threadpool_(threadpool) { + input_->Ref(); + threadpool_->Ref(); + } + + ~Dataset() override { + input_->Unref(); + threadpool_->Unref(); + } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::ThreadPool")})); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + const std::vector& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() const override { + return "ThreadPoolDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented( + "Cannot currently serialize the thread pool for a " + "ThreadPoolDataset."); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + ThreadPoolResource* pool = dataset()->threadpool_; + IteratorContext::Params params; + params.env = ctx->env(); + params.runner = [pool](std::function c) { + pool->Schedule(std::move(c)); + }; + params.stats_aggregator_getter = ctx->stats_aggregator_getter(); + params.lib = ctx->lib(); + params.function_library = ctx->function_library(); + params.allocator_getter = ctx->allocator_getter(); + IteratorContext threadpool_ctx(params); + return input_impl_->GetNext(&threadpool_ctx, out_tensors, + end_of_sequence); + } + + private: + std::unique_ptr input_impl_; + }; + + const DatasetBase* const input_; + ThreadPoolResource* const threadpool_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("ThreadPoolHandle").Device(DEVICE_CPU), + ThreadPoolHandleOp); +REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU), + ThreadPoolDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc similarity index 95% rename from tensorflow/core/kernels/data/unique_dataset_op.cc rename to tensorflow/contrib/data/kernels/unique_dataset_op.cc index 7726ee0edf71b34cb65fe5fceb2b60dd30bb58e2..67c237799c10a2724f18bb0df99e4bf8f5cd2b8a 100644 --- a/tensorflow/core/kernels/data/unique_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/lib/hash/hash.h" namespace tensorflow { @@ -56,7 +56,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Unique")})); @@ -70,7 +70,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { + string DebugString() const override { return strings::StrCat("UniqueDatasetOp::Dataset"); } @@ -87,8 +87,11 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const typename Iterator::Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index 289ffa1d9c29092cdf434e86ed5553ff9644d43e..f271d269ab1b9339de4657e459dcbbd462890f0a 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -17,6 +17,57 @@ limitations under the License. namespace tensorflow { +REGISTER_OP("DirectedInterleaveDataset") + .Input("selector_input_dataset: variant") + .Input("data_input_datasets: N * variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .Attr("N: int >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +A substitute for `InterleaveDataset` on a fixed list of `N` datasets. + +selector_input_dataset: A dataset of scalar `DT_INT64` elements that determines + which of the `N` data inputs should produce the next output element. +data_input_datasets: `N` datasets with the same type that will be interleaved + according to the values of `selector_input_dataset`. +)doc"); + +REGISTER_OP("CSVDataset") + .Input("filenames: string") + .Input("buffer_size: int64") + .Input("header: bool") + .Input("field_delim: string") + .Input("use_quote_delim: bool") + .Input("na_value: string") + .Input("select_cols: int64") + .Input("record_defaults: output_types") + .Output("handle: variant") + .Attr("output_types: list({float,double,int32,int64,string}) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // `filenames` must be a scalar or a vector. + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); + // `buffer_size`, `header`, `field_delim`, `use_quote_delim`, + // `na_value` must be scalars + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + // `select_cols` must be a vector + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &unused)); + // `record_defaults` must be a list of scalars...? + for (size_t i = 7; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused)); + } + return shape_inference::ScalarShape(c); + }); + REGISTER_OP("IgnoreErrorsDataset") .Input("input_dataset: variant") .Output("handle: variant") @@ -27,6 +78,24 @@ REGISTER_OP("IgnoreErrorsDataset") Creates a dataset that contains the elements of `input_dataset` ignoring errors. )doc"); +REGISTER_OP("UniqueDataset") + .Input("input_dataset: variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that contains the unique elements of `input_dataset`. +)doc"); + +REGISTER_OP("IteratorGetDevice") + .Input("resource: resource") + .Output("device: string") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Returns the name of the device on which `resource` has been placed. +)doc"); + REGISTER_OP("FunctionBufferingResource") .Input("string_arg: string") .Input("target_device: string") @@ -35,7 +104,6 @@ REGISTER_OP("FunctionBufferingResource") .Attr("container: string") .Attr("f: func") .Attr("buffer_size: int") - .Attr("thread_pool_size: int") .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Creates a resource that fills up a buffer by making function calls. @@ -45,7 +113,6 @@ target_device: Target device to execute the function on. resource: Handle to the resource created. f: Function to be executed. buffer_size: Size of the buffer. -thread_pool_size: Size of the threadpool doing the prefetching. container: If non-empty, this resource is placed in the given container. Otherwise, a default container is used. shared_name: If non-empty, this resource will be shared under the given name @@ -65,4 +132,42 @@ output: A list of return values. output_types: The type list for the return values. )doc"); +REGISTER_OP("FunctionBufferingResourceReset") + .Input("function_buffer_resource: resource") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Resets the FunctionBufferingResource. + +function_buffer_resource: The FunctionBufferingResource handle. +)doc"); + +REGISTER_OP("ThreadPoolDataset") + .Input("input_dataset: variant") + .Input("thread_pool: resource") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that uses a custom thread pool to compute `input_dataset`. + +handle: A resource produced by the ThreadPoolHandle op. +)doc"); + +REGISTER_OP("ThreadPoolHandle") + .Output("handle: resource") + .SetShapeFn(shape_inference::ScalarShape) + .Attr("num_threads: int") + .Attr("display_name: string") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Doc(R"doc( +Creates a custom thread pool with the given number of threads. + +handle: A resource that can be consumed by one or more ThreadPoolDataset ops. +num_threads: The number of threads in the thread pool. +display_name: A human-readable name for the threads that may be visible in + some visualizations. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index e51d57cc896dc32d8e11912cd89f34a04a858c78..0dfd249ec27d96d6f1a4ae65d623df456db9991f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -4,24 +4,27 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test", "tf_py_test") py_test( name = "batch_dataset_op_test", - size = "small", + size = "medium", srcs = ["batch_dataset_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_oss", # (b/79552534) + "no_pip", + ], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:math_ops", + "//tensorflow/python:script_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", @@ -32,13 +35,12 @@ py_test( py_test( name = "bucketing_test", - size = "small", + size = "medium", srcs = ["bucketing_test.py"], srcs_version = "PY2AND3", deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:grouping", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -52,6 +54,19 @@ py_test( ], ) +py_test( + name = "cache_dataset_op_test", + size = "small", + srcs = ["cache_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "concatenate_dataset_op_test", size = "small", @@ -59,10 +74,10 @@ py_test( srcs_version = "PY2AND3", deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], @@ -79,8 +94,7 @@ py_test( ], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -119,21 +133,38 @@ py_library( ], ) +py_test( + name = "csv_dataset_op_test", + size = "small", + srcs = ["csv_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:error_ops", + "//tensorflow/contrib/data/python/ops:readers", + "//third_party/py/numpy", + ], +) + py_test( name = "filter_dataset_op_test", size = "small", srcs = ["filter_dataset_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_pip", + "optonly", + ], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:functional_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) @@ -145,7 +176,7 @@ tf_py_test( additional_deps = [ ":dataset_serialization_test", "//third_party/py/numpy", - "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -168,13 +199,14 @@ py_test( srcs = ["interleave_dataset_op_test.py"], srcs_version = "PY2AND3", tags = [ + "manual", "no_oss", "no_pip", + "notap", ], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:interleave_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", @@ -185,6 +217,24 @@ py_test( "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "directed_interleave_dataset_test", + size = "medium", + srcs = ["directed_interleave_dataset_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) @@ -195,7 +245,8 @@ tf_py_test( srcs = ["get_single_element_test.py"], additional_deps = [ "//third_party/py/numpy", - "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:get_single_element", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -210,11 +261,14 @@ py_test( size = "medium", srcs = ["map_dataset_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_pip", + "noasan", # times out + "optonly", + ], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:error_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -239,6 +293,19 @@ py_test( ], ) +py_test( + name = "optimize_dataset_op_test", + size = "small", + srcs = ["optimize_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:platform", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "prefetch_dataset_op_test", size = "small", @@ -259,8 +326,8 @@ py_test( srcs_version = "PY2AND3", deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:counter", + "//tensorflow/contrib/data/python/ops:enumerate_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -272,6 +339,27 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_library( + name = "reader_dataset_ops_test_base", + testonly = 1, + srcs = [ + "reader_dataset_ops_test_base.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:lib", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:readers", ], ) @@ -279,12 +367,13 @@ py_test( name = "reader_dataset_ops_test", size = "medium", srcs = ["reader_dataset_ops_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ ":dataset_serialization_test", + ":reader_dataset_ops_test_base", "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -293,8 +382,11 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", + "//tensorflow/python:string_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:readers", + "//third_party/py/numpy", ], ) @@ -304,15 +396,22 @@ py_test( srcs = ["resample_test.py"], shard_count = 2, srcs_version = "PY2AND3", - tags = ["noasan"], + tags = [ + "noasan", + "optonly", + ], deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:resampling", "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", "//tensorflow/python:string_ops", "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -324,13 +423,14 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:scan_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:context", "//third_party/py/numpy", ], ) @@ -343,11 +443,11 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) @@ -375,7 +475,7 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -383,6 +483,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:training", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", @@ -395,6 +496,7 @@ py_test( srcs = ["sql_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -412,10 +514,30 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + ":reader_dataset_ops_test_base", + "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "threadpool_dataset_ops_test", + size = "small", + srcs = ["threadpool_dataset_ops_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/contrib/data/python/ops:threadpool", + "//tensorflow/contrib/data/python/ops:unique", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -427,13 +549,13 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:unique", "//tensorflow/contrib/stateless", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) @@ -446,25 +568,20 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) -py_test( +cuda_py_test( name = "prefetching_ops_test", size = "small", srcs = ["prefetching_ops_test.py"], - srcs_version = "PY2AND3", - tags = [ - "manual", - "no_oss", # b/68785503 - ], - deps = [ + additional_deps = [ "//tensorflow/contrib/data/python/ops:prefetching_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", @@ -479,16 +596,39 @@ py_test( ], ) -filegroup( - name = "all_files", - srcs = glob( - include = [ - "**/*", - ], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], +tf_py_test( + name = "slide_dataset_op_test", + size = "small", + srcs = ["slide_dataset_op_test.py"], + additional_deps = [ + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:sliding", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", + "//third_party/py/numpy", + ], +) + +tf_py_test( + name = "writer_ops_test", + size = "small", + srcs = ["writer_ops_test.py"], + additional_deps = [ + "//tensorflow/contrib/data/python/ops:writers", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:lib", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:readers", + ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 71dc1c1172c9d515d4c85f85257c952135098329..b5fbc45ad3d8d262c1c79b5723ffeb38ff6a34c2 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -18,20 +18,26 @@ from __future__ import division from __future__ import print_function import math +import time import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import script_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test +from tensorflow.python.util import compat class BatchDatasetTest(test.TestCase): @@ -149,6 +155,69 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(op) + def testUnbatchDatasetWithStrings(self): + data = tuple([math_ops.range(10) for _ in range(3)]) + data = dataset_ops.Dataset.from_tensor_slices(data) + data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z)) + expected_types = (dtypes.int32, dtypes.string, dtypes.int32) + data = data.batch(2) + self.assertEqual(expected_types, data.output_types) + data = data.apply(batching.unbatch()) + self.assertEqual(expected_types, data.output_types) + + iterator = data.make_one_shot_iterator() + op = iterator.get_next() + + with self.test_session() as sess: + for i in range(10): + self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op)) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(op) + + def testUnbatchDatasetWithSparseTensor(self): + st = sparse_tensor.SparseTensorValue( + indices=[[i, i] for i in range(10)], + values=list(range(10)), + dense_shape=[10, 10]) + data = dataset_ops.Dataset.from_tensors(st) + data = data.apply(batching.unbatch()) + data = data.batch(5) + data = data.apply(batching.unbatch()) + iterator = data.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in range(10): + st_row = sess.run(next_element) + self.assertEqual([i], st_row.indices) + self.assertEqual([i], st_row.values) + self.assertEqual([10], st_row.dense_shape) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testUnbatchDatasetWithDenseAndSparseTensor(self): + st = sparse_tensor.SparseTensorValue( + indices=[[i, i] for i in range(10)], + values=list(range(10)), + dense_shape=[10, 10]) + data = dataset_ops.Dataset.from_tensors((list(range(10)), st)) + data = data.apply(batching.unbatch()) + data = data.batch(5) + data = data.apply(batching.unbatch()) + iterator = data.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in range(10): + dense_elem, st_row = sess.run(next_element) + self.assertEqual(i, dense_elem) + self.assertEqual([i], st_row.indices) + self.assertEqual([i], st_row.values) + self.assertEqual([10], st_row.dense_shape) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + def testUnbatchSingleElementTupleDataset(self): data = tuple([(math_ops.range(10),) for _ in range(3)]) data = dataset_ops.Dataset.from_tensor_slices(data) @@ -189,6 +258,53 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(op) + def testUnbatchEmpty(self): + data = dataset_ops.Dataset.from_tensors( + (constant_op.constant([]), constant_op.constant([], shape=[0, 4]), + constant_op.constant([], shape=[0, 4, 0]))) + data = data.apply(batching.unbatch()) + iterator = data.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testUnbatchStaticShapeMismatch(self): + data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8), + np.arange(9))) + with self.assertRaises(ValueError): + data.apply(batching.unbatch()) + + def testUnbatchDynamicShapeMismatch(self): + ph1 = array_ops.placeholder(dtypes.int32, shape=[None]) + ph2 = array_ops.placeholder(dtypes.int32, shape=None) + data = dataset_ops.Dataset.from_tensors((ph1, ph2)) + data = data.apply(batching.unbatch()) + iterator = data.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + # Mismatch in the 0th dimension. + sess.run( + iterator.initializer, + feed_dict={ + ph1: np.arange(7).astype(np.int32), + ph2: np.arange(8).astype(np.int32) + }) + with self.assertRaises(errors.InvalidArgumentError): + print(sess.run(next_element)) + + # No 0th dimension (i.e. scalar value) for one component. + sess.run( + iterator.initializer, + feed_dict={ + ph1: np.arange(7).astype(np.int32), + ph2: 7 + }) + with self.assertRaises(errors.InvalidArgumentError): + print(sess.run(next_element)) + def testBatchAndDropRemainder(self): components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], @@ -311,10 +427,12 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def _testBatchAndMapDatasetHelper(self, num_parallel_batches=1): + def _testMapAndBatchDatasetHelper(self, + num_parallel_calls=None, + num_parallel_batches=None): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> - # RepeatDataset(count) -> BatchAndMapDataset(square_3, batch_size). + # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) @@ -330,6 +448,7 @@ class BatchDatasetTest(test.TestCase): batching.map_and_batch( map_func=_map_fn, batch_size=batch_size, + num_parallel_calls=num_parallel_calls, num_parallel_batches=num_parallel_batches)) .make_initializable_iterator()) init_op = iterator.initializer @@ -381,11 +500,95 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def testBatchAndMapDataset(self): - return self._testBatchAndMapDatasetHelper() + def testMapAndBatch(self): + return self._testMapAndBatchDatasetHelper() + + def testMapAndBatchWithParallelBatches(self): + return self._testMapAndBatchDatasetHelper(num_parallel_batches=10) + + def testMapAndBatchWithSequentialCalls(self): + return self._testMapAndBatchDatasetHelper(num_parallel_calls=1) + + def testMapAndBatchWithParallelCalls(self): + return self._testMapAndBatchDatasetHelper(num_parallel_calls=2) + + def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False): + iterator = ( + dataset_ops.Dataset.range(10).apply( + batching.map_and_batch( + lambda x: array_ops.reshape(x * x, [1]), + batch_size=4, + drop_remainder=drop_remainder)).make_one_shot_iterator()) + if drop_remainder: + self.assertEqual([4, 1], iterator.output_shapes.as_list()) + else: + self.assertEqual([None, 1], iterator.output_shapes.as_list()) + next_element = iterator.get_next() + with self.test_session() as sess: + self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) + self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) + if not drop_remainder: + self.assertAllEqual([[64], [81]], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testMapAndBatchPartialBatch(self): + return self._testMapAndBatchPartialBatchHelper() + + def testMapAndBatchPartialBatchDropRemainder(self): + return self._testMapAndBatchPartialBatchHelper(drop_remainder=True) + + def testMapAndBatchYieldsPartialBatch(self): + iterator = (dataset_ops.Dataset.range(10) + .apply(batching.map_and_batch( + lambda x: array_ops.reshape(x * x, [1]), 4)) + .make_one_shot_iterator()) + self.assertEqual([None, 1], iterator.output_shapes.as_list()) + next_element = iterator.get_next() + with self.test_session() as sess: + self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) + self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) + self.assertAllEqual([[64], [81]], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testMapAndBatchParallelGetNext(self): + iterator = (dataset_ops.Dataset.range(50000) + .apply(batching.map_and_batch(lambda x: x, batch_size=100)) + .make_one_shot_iterator()) + elements = [] + for _ in range(100): + elements.append(iterator.get_next()) + with self.test_session() as sess: + for i in range(5): + got = sess.run(elements) + got.sort(key=lambda x: x[0]) + expected = [] + for j in range(100): + expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) + self.assertAllEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elements) - def testBatchAndMapDatasetWithParallelBatching(self): - return self._testBatchAndMapDatasetHelper(num_parallel_batches=10) + def testMapAndBatchParallelGetNextDropRemainder(self): + iterator = ( + dataset_ops.Dataset.range(49999).apply( + batching.map_and_batch( + lambda x: x, batch_size=100, drop_remainder=True)) + .make_one_shot_iterator()) + elements = [] + for _ in range(100): + elements.append(iterator.get_next()) + with self.test_session() as sess: + for i in range(4): + got = sess.run(elements) + got.sort(key=lambda x: x[0]) + expected = [] + for j in range(100): + expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) + self.assertAllEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elements) def testMapAndBatchSparse(self): @@ -411,7 +614,7 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testBatchAndMapDatasetFails(self): + def testMapAndBatchDatasetFails(self): """Test a dataset that maps a TF function across its input elements.""" dataset = dataset_ops.Dataset.from_tensors( array_ops.check_numerics( @@ -425,7 +628,7 @@ class BatchDatasetTest(test.TestCase): with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(init_op, feed_dict={batch_size: 14}) - def testBatchAndMapDatasetShapeMismatch(self): + def testMapAndBatchDatasetShapeMismatch(self): """Test a dataset that maps a TF function across its input elements.""" def generator(): @@ -474,9 +677,7 @@ class BatchDatasetSerializationTest( lambda x: array_ops.fill([x], x)).apply( batching.dense_to_sparse_batch(4, [12])) - # TODO(b/70988345): Re-enable when sparse tensors are properly supported by - # the DatasetSerializationTestBase. - def _testDenseToSparseBatchDatasetCore(self): + def testDenseToSparseBatchDatasetCore(self): components = np.random.randint(5, size=(40,)).astype(np.int32) diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) @@ -503,6 +704,86 @@ class BatchDatasetSerializationTest( self.run_core_tests(self._build_dataset_nested_sparse, None, 1) +class UnbatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): + components = ( + np.arange(tensor_slice_len), + np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(tensor_slice_len)) + + return dataset_ops.Dataset.from_tensor_slices(components).batch( + batch_size).apply(batching.unbatch()) + + def testCore(self): + tensor_slice_len = 8 + batch_size = 2 + num_outputs = tensor_slice_len + self.run_core_tests( + lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), + lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), + num_outputs) + + +class MapAndBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testNumParallelBatches(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_batches = 2 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_batches=num_parallel_batches, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + + def testNumParallelCalls(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_calls = 7 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_calls=num_parallel_calls, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + + class PaddedBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): @@ -539,5 +820,149 @@ class PaddedBatchDatasetSerializationTest( lambda: build_dataset(seq_lens2), 8) +class RestructuredDatasetTest(test.TestCase): + + def test_assert_element_shape(self): + + def create_unknown_shape_dataset(x): + return script_ops.py_func( + lambda _: ( # pylint: disable=g-long-lambda + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) + + dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset) + unknown_shapes = (tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)) + self.assertEqual(unknown_shapes, dataset.output_shapes) + + expected_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 4))) + result = dataset.apply(batching.assert_element_shape(expected_shapes)) + self.assertEqual(expected_shapes, result.output_shapes) + + iterator = result.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run(get_next) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_assert_wrong_element_shape(self): + + def create_dataset(_): + return (array_ops.ones(2, dtype=dtypes.float32), + array_ops.zeros((3, 4), dtype=dtypes.int32)) + + dataset = dataset_ops.Dataset.range(3).map(create_dataset) + wrong_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 10))) + with self.assertRaises(ValueError): + dataset.apply(batching.assert_element_shape(wrong_shapes)) + + def test_assert_wrong_element_shape_on_unknown_shape_dataset(self): + + def create_unknown_shape_dataset(x): + return script_ops.py_func( + lambda _: ( # pylint: disable=g-long-lambda + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) + + dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset) + unknown_shapes = (tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)) + self.assertEqual(unknown_shapes, dataset.output_shapes) + + wrong_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 10))) + iterator = ( + dataset.apply(batching.assert_element_shape(wrong_shapes)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + +class UnbatchDatasetBenchmark(test.Benchmark): + + def benchmarkNativeUnbatch(self): + batch_sizes = [1, 2, 5, 10, 20, 50] + elems_per_trial = 10000 + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors("element").repeat(None) + batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) + dataset = dataset.batch(batch_size_placeholder) + dataset = dataset.apply(batching.unbatch()) + dataset = dataset.skip(elems_per_trial) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for batch_size in batch_sizes: + deltas = [] + for _ in range(5): + sess.run( + iterator.initializer, + feed_dict={batch_size_placeholder: batch_size}) + start = time.time() + sess.run(next_element.op) + end = time.time() + deltas.append((end - start) / elems_per_trial) + + median_wall_time = np.median(deltas) + print("Unbatch (native) batch size: %d Median wall time per element:" + " %f microseconds" % (batch_size, median_wall_time * 1e6)) + self.report_benchmark( + iters=10000, + wall_time=median_wall_time, + name="benchmark_unbatch_dataset_native_batch_size_%d" % + batch_size) + + # Include a benchmark of the previous `unbatch()` implementation that uses + # a composition of more primitive ops. Eventually we'd hope to generate code + # that is as good in both cases. + def benchmarkOldUnbatchImplementation(self): + batch_sizes = [1, 2, 5, 10, 20, 50] + elems_per_trial = 10000 + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors("element").repeat(None) + batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) + dataset = dataset.batch(batch_size_placeholder) + dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices) + dataset = dataset.skip(elems_per_trial) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for batch_size in batch_sizes: + deltas = [] + for _ in range(5): + sess.run( + iterator.initializer, + feed_dict={batch_size_placeholder: batch_size}) + start = time.time() + sess.run(next_element.op) + end = time.time() + deltas.append((end - start) / elems_per_trial) + + median_wall_time = np.median(deltas) + print("Unbatch (unfused) batch size: %d Median wall time per element:" + " %f microseconds" % (batch_size, median_wall_time * 1e6)) + self.report_benchmark( + iters=10000, + wall_time=median_wall_time, + name="benchmark_unbatch_dataset_unfused_batch_size_%d" % + batch_size) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index f1b494e1a620992365ed75613b508e32f94b40a4..bd3e034211c4aa454e4f8f6b09f14935d7a3b35c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import random + import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base @@ -26,6 +28,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -33,6 +36,179 @@ from tensorflow.python.ops import string_ops from tensorflow.python.platform import test +class GroupByReducerTest(test.TestCase): + + def checkResults(self, dataset, shapes, values): + self.assertEqual(shapes, dataset.output_shapes) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + for expected in values: + got = sess.run(get_next) + self.assertEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSum(self): + reducer = grouping.Reducer( + init_func=lambda _: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + for i in range(1, 11): + dataset = dataset_ops.Dataset.range(2 * i).apply( + grouping.group_by_reducer(lambda x: x % 2, reducer)) + self.checkResults( + dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i]) + + def testAverage(self): + + def reduce_fn(x, y): + return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / ( + x[1] + 1), x[1] + 1 + + reducer = grouping.Reducer( + init_func=lambda _: (0.0, 0.0), + reduce_func=reduce_fn, + finalize_func=lambda x: x[0]) + for i in range(1, 11): + dataset = dataset_ops.Dataset.range(2 * i).apply( + grouping.group_by_reducer( + lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer)) + self.checkResults( + dataset, shapes=tensor_shape.scalar(), values=[i - 1, i]) + + def testConcat(self): + components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray) + reducer = grouping.Reducer( + init_func=lambda x: "", + reduce_func=lambda x, y: x + y[0], + finalize_func=lambda x: x) + for i in range(1, 11): + dataset = dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensor_slices(components), + dataset_ops.Dataset.range(2 * i))).apply( + grouping.group_by_reducer(lambda x, y: y % 2, reducer)) + self.checkResults( + dataset, + shapes=tensor_shape.scalar(), + values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]]) + + def testSparseSum(self): + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1], dtype=np.int64)), + dense_shape=np.array([1, 1])) + + reducer = grouping.Reducer( + init_func=lambda _: _sparse(np.int64(0)), + reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]), + finalize_func=lambda x: x.values[0]) + for i in range(1, 11): + dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply( + grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer)) + self.checkResults( + dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i]) + + def testChangingStateShape(self): + + def reduce_fn(x, _): + # Statically known rank, but dynamic length. + larger_dim = array_ops.concat([x[0], x[0]], 0) + # Statically unknown rank. + larger_rank = array_ops.expand_dims(x[1], 0) + return larger_dim, larger_rank + + reducer = grouping.Reducer( + init_func=lambda x: ([0], 1), + reduce_func=reduce_fn, + finalize_func=lambda x: x) + + for i in range(1, 11): + dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply( + grouping.group_by_reducer(lambda x: x, reducer)) + self.assertEqual([None], dataset.output_shapes[0].as_list()) + self.assertIs(None, dataset.output_shapes[1].ndims) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.test_session() as sess: + x, y = sess.run(get_next) + self.assertAllEqual([0] * (2**i), x) + self.assertAllEqual(np.array(1, ndmin=i), y) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testTypeMismatch(self): + reducer = grouping.Reducer( + init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32), + reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64), + finalize_func=lambda x: x) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + TypeError, + "The element types for the new state must match the initial state."): + dataset.apply( + grouping.group_by_reducer(lambda _: np.int64(0), reducer)) + + # TODO(b/78665031): Remove once non-scalar keys are supported. + def testInvalidKeyShape(self): + reducer = grouping.Reducer( + init_func=lambda x: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + ValueError, "`key_func` must return a single tf.int64 tensor."): + dataset.apply( + grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer)) + + # TODO(b/78665031): Remove once non-int64 keys are supported. + def testInvalidKeyType(self): + reducer = grouping.Reducer( + init_func=lambda x: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + ValueError, "`key_func` must return a single tf.int64 tensor."): + dataset.apply( + grouping.group_by_reducer(lambda _: "wrong", reducer)) + + +class GroupByReducerSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, components): + reducer = grouping.Reducer( + init_func=lambda _: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + return dataset_ops.Dataset.from_tensor_slices(components).apply( + grouping.group_by_reducer(lambda x: x % 5, reducer)) + + def testCoreGroupByReducer(self): + components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64) + self.verify_unused_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_init_before_restore( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_multiple_breaks( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_reset_restored_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_restore_in_empty_graph( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64) + self.verify_restore_in_modified_graph( + lambda: self._build_dataset(components), + lambda: self._build_dataset(diff_components), + 5, + verify_exhausted=True) + + class GroupByWindowTest(test.TestCase): def testSimple(self): @@ -59,7 +235,7 @@ class GroupByWindowTest(test.TestCase): self.assertEqual(len(components), sum(counts)) num_full_batches = len([c for c in counts if c == 4]) - self.assertGreaterEqual(num_full_batches, 23) + self.assertGreaterEqual(num_full_batches, 24) self.assertTrue(all(c == 4 for c in counts[:num_full_batches])) def testImmediateOutput(self): @@ -102,6 +278,21 @@ class GroupByWindowTest(test.TestCase): self.assertAllEqual([0, 0, 0], sess.run(get_next)) self.assertAllEqual([1], sess.run(get_next)) + def testEmpty(self): + iterator = ( + dataset_ops.Dataset.range(4).apply( + grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Window size must be greater than zero, but got 0."): + print(sess.run(get_next)) + def testReduceFuncError(self): components = np.random.randint(100, size=(200,)).astype(np.int64) @@ -379,5 +570,118 @@ class BucketTest(test.TestCase): self.assertEqual(batches, 15) +class BucketBySequenceLength(test.TestCase): + + def testBucket(self): + + boundaries = [10, 20, 30] + batch_sizes = [10, 8, 4, 2] + lengths = [8, 13, 25, 35] + + def element_gen(): + # Produce 1 batch for each bucket + elements = [] + for batch_size, length in zip(batch_sizes, lengths): + for _ in range(batch_size): + elements.append([1] * length) + random.shuffle(elements) + for el in elements: + yield (el,) + + element_len = lambda el: array_ops.shape(el)[0] + dataset = dataset_ops.Dataset.from_generator( + element_gen, (dtypes.int64,), ([None],)).apply( + grouping.bucket_by_sequence_length( + element_len, boundaries, batch_sizes)) + batch, = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + batches = [] + for _ in range(4): + batches.append(sess.run(batch)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(batch) + batch_sizes_val = [] + lengths_val = [] + for batch in batches: + batch_size = batch.shape[0] + length = batch.shape[1] + batch_sizes_val.append(batch_size) + lengths_val.append(length) + self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) + self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) + self.assertEqual(sorted(lengths), sorted(lengths_val)) + + def testPadToBoundary(self): + + boundaries = [10, 20, 30] + batch_sizes = [10, 8, 4, 2] + lengths = [8, 13, 25] + + def element_gen(): + # Produce 1 batch for each bucket + elements = [] + for batch_size, length in zip(batch_sizes[:-1], lengths): + for _ in range(batch_size): + elements.append([1] * length) + random.shuffle(elements) + for el in elements: + yield (el,) + for _ in range(batch_sizes[-1]): + el = [1] * (boundaries[-1] + 5) + yield (el,) + + element_len = lambda el: array_ops.shape(el)[0] + dataset = dataset_ops.Dataset.from_generator( + element_gen, (dtypes.int64,), ([None],)).apply( + grouping.bucket_by_sequence_length( + element_len, boundaries, batch_sizes, + pad_to_bucket_boundary=True)) + batch, = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + batches = [] + for _ in range(3): + batches.append(sess.run(batch)) + with self.assertRaisesOpError("bucket_boundaries"): + sess.run(batch) + batch_sizes_val = [] + lengths_val = [] + for batch in batches: + batch_size = batch.shape[0] + length = batch.shape[1] + batch_sizes_val.append(batch_size) + lengths_val.append(length) + batch_sizes = batch_sizes[:-1] + self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) + self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) + self.assertEqual(sorted(boundaries), sorted(lengths_val)) + + def testTupleElements(self): + + def elements_gen(): + text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]] + label = [1, 2, 1, 2] + for x, y in zip(text, label): + yield (x, y) + + def element_length_fn(x, y): + del y + return array_ops.shape(x)[0] + + dataset = dataset_ops.Dataset.from_generator( + generator=elements_gen, + output_shapes=(tensor_shape.TensorShape([None]), + tensor_shape.TensorShape([])), + output_types=(dtypes.int32, dtypes.int32)) + dataset = dataset.apply(grouping.bucket_by_sequence_length( + element_length_func=element_length_fn, + bucket_batch_sizes=[2, 2, 2], + bucket_boundaries=[0, 8])) + shapes = dataset.output_shapes + self.assertEqual([None, None], shapes[0].as_list()) + self.assertEqual([None], shapes[1].as_list()) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f08216a303e2d7dee155ccadcdb9f42f1b24ea0f --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental features of CacheDataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class CacheToFileDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self.range_size = 10 + self.num_repeats = 3 + self.num_outputs = self.range_size * self.num_repeats + self.cache_file_prefix = 'test' + + def ds_fn(self): + return dataset_ops.Dataset.range(self.range_size).cache( + os.path.join(self.get_temp_dir(), + self.cache_file_prefix)).repeat(self.num_repeats) + + def expected_outputs(self): + return list(range(self.range_size)) * self.num_repeats + + def testCheckpointBeforeOneEpoch(self): + # Generate 5 entries from iterator and save checkpoint. + outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint and produce the rest of the elements from the + # iterator. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, self.expected_outputs()) + + def testCheckpointBeforeOneEpochThenRunFewSteps(self): + # Generate 8 entries from iterator but save checkpoint after producing + # 5. + outputs = self.gen_outputs( + self.ds_fn, [5], + 8, + verify_exhausted=False, + save_checkpoint_at_end=False) + self.assertSequenceEqual(outputs, range(8)) + + # Restoring from checkpoint and running GetNext should return a + # `AlreadExistsError` now because the lockfile already exists. + with self.assertRaises(errors.AlreadyExistsError): + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False) + + def testCheckpointAfterOneEpoch(self): + # Generate 15 entries from iterator and save checkpoint. + outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) + + # Restore from checkpoint and produce the rest of the elements from the + # iterator. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 15, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, self.expected_outputs()) + + def testCheckpointAfterOneEpochThenRunFewSteps(self): + # Generate 18 entries from iterator but save checkpoint after producing + # 15. + outputs = self.gen_outputs( + self.ds_fn, [15], + 18, + verify_exhausted=False, + save_checkpoint_at_end=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(8))) + + outputs = list(range(10)) + list(range(5)) + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 15, + ckpt_saved=True, + verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testCheckpointBeforeOneEpochButRunCompleteEpoch(self): + # Generate 13 entries from iterator but save checkpoint after producing + # 5. + outputs = self.gen_outputs( + self.ds_fn, [5], + 13, + verify_exhausted=False, + save_checkpoint_at_end=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(3))) + + # Since we ran for more than one epoch, the cache was completely written. + # The ckpt was saved when the iterator was in cache-write mode. Test that + # the iterator falls back to read mode after restoring if the cache has + # been completely written. + + outputs = list(range(5)) + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testCheckpointUnusedWriterIterator(self): + # Checkpoint before get_next is called even once. + outputs = self.gen_outputs(self.ds_fn, [], 0, verify_exhausted=False) + self.assertSequenceEqual(outputs, []) + + outputs = self.gen_outputs( + self.ds_fn, [], + self.num_outputs, + ckpt_saved=True, + verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testCheckpointUnusedMidwayWriterIterator(self): + # Produce 5 elements and checkpoint. + outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint, then produce no elements and checkpoint. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False)) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint and produce rest of the elements. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testUnusedCheckpointError(self): + # Produce 5 elements and save ckpt. + outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Since the complete cache has not been written, a new iterator which does + # not restore the checkpoint will throw an error since there is a partial + # cache shard. + with self.assertRaises(errors.AlreadyExistsError): + outputs = self.gen_outputs( + self.ds_fn, [], self.num_outputs, verify_exhausted=False) + + def testIgnoreCheckpointIfCacheWritten(self): + # Produce 15 elements and save ckpt. This will write the complete cache. + outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) + + # Build the iterator again but do not restore from ckpt. Since the cache + # has already been written we should be able to use it. + outputs = self.gen_outputs( + self.ds_fn, [], self.num_outputs, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..97b5e9416521dcad9ee5047a8275f8fd0142e338 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -0,0 +1,619 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CsvDatasetOp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import string +import tempfile +import time + +import numpy as np + +from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.client import session +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.platform import gfile +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test + + +class CsvDatasetOpTest(test.TestCase): + + def _assert_datasets_equal(self, g, ds1, ds2): + assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, ' + '%s') % (ds1.output_shapes, + ds2.output_shapes) + assert ds1.output_types == ds2.output_types + assert ds1.output_classes == ds2.output_classes + next1 = ds1.make_one_shot_iterator().get_next() + next2 = ds2.make_one_shot_iterator().get_next() + with self.test_session(graph=g) as sess: + # Run through datasets and check that outputs match, or errors match. + while True: + try: + op1 = sess.run(next1) + except (errors.OutOfRangeError, ValueError) as e: + # If op1 throws an exception, check that op2 throws same exception. + with self.assertRaises(type(e)): + sess.run(next2) + break + op2 = sess.run(next2) + self.assertAllEqual(op1, op2) + + def setup_files(self, inputs, linebreak='\n'): + filenames = [] + for i, ip in enumerate(inputs): + fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i) + with open(fn, 'wb') as f: + f.write(linebreak.join(ip).encode('utf-8')) + filenames.append(fn) + return filenames + + def _make_test_datasets(self, inputs, **kwargs): + # Test by comparing its output to what we could get with map->decode_csv + filenames = self.setup_files(inputs) + dataset_expected = core_readers.TextLineDataset(filenames) + dataset_expected = dataset_expected.map( + lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) + dataset_actual = readers.CsvDataset(filenames, **kwargs) + return (dataset_actual, dataset_expected) + + def _test_by_comparison(self, inputs, **kwargs): + """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv).""" + with ops.Graph().as_default() as g: + dataset_actual, dataset_expected = self._make_test_datasets( + inputs, **kwargs) + self._assert_datasets_equal(g, dataset_actual, dataset_expected) + + def _verify_output_or_err(self, + sess, + dataset, + expected_output=None, + expected_err_re=None): + nxt = dataset.make_one_shot_iterator().get_next() + if expected_err_re is None: + # Verify that output is expected, without errors + expected_output = [[ + v.encode('utf-8') if isinstance(v, str) else v for v in op + ] for op in expected_output] + for value in expected_output: + op = sess.run(nxt) + self.assertAllEqual(op, value) + with self.assertRaises(errors.OutOfRangeError): + sess.run(nxt) + else: + # Verify that OpError is produced as expected + with self.assertRaisesOpError(expected_err_re): + while True: + try: + sess.run(nxt) + except errors.OutOfRangeError: + break + + def _test_dataset(self, + inputs, + expected_output=None, + expected_err_re=None, + linebreak='\n', + **kwargs): + """Checks that elements produced by CsvDataset match expected output.""" + # Convert str type because py3 tf strings are bytestrings + filenames = self.setup_files(inputs, linebreak) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, **kwargs) + self._verify_output_or_err(sess, dataset, expected_output, + expected_err_re) + + def testCsvDataset_requiredFields(self): + record_defaults = [[]] * 4 + inputs = [['1,2,3,4']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_int(self): + record_defaults = [[0]] * 4 + inputs = [['1,2,3,4', '5,6,7,8']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_float(self): + record_defaults = [[0.0]] * 4 + inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_string(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withEmptyFields(self): + record_defaults = [[0]] * 4 + inputs = [[',,,', '1,1,1,', ',2,2,2']] + self._test_dataset( + inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], + record_defaults=record_defaults) + + def testCsvDataset_errWithUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4']] + self._test_dataset( + inputs, + expected_err_re='Unquoted fields cannot have quotes inside', + record_defaults=record_defaults) + + def testCsvDataset_errWithUnescapedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['"a"b","c","d"']] + self._test_dataset( + inputs, + expected_err_re= + 'Quote inside a string has to be escaped by another quote', + record_defaults=record_defaults) + + def testCsvDataset_ignoreErrWithUnescapedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']] + filenames = self.setup_files(inputs) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + + def testCsvDataset_ignoreErrWithUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']] + filenames = self.setup_files(inputs) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + + def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, use_quote_delim=False) + + def testCsvDataset_mixedTypes(self): + record_defaults = [ + constant_op.constant([], dtype=dtypes.int32), + constant_op.constant([], dtype=dtypes.float32), + constant_op.constant([], dtype=dtypes.string), + constant_op.constant([], dtype=dtypes.float64) + ] + inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withUseQuoteDelimFalse(self): + record_defaults = [['']] * 4 + inputs = [['1,2,"3,4"', '"5,6",7,8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, use_quote_delim=False) + + def testCsvDataset_withFieldDelim(self): + record_defaults = [[0]] * 4 + inputs = [['1:2:3:4', '5:6:7:8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, field_delim=':') + + def testCsvDataset_withNaValue(self): + record_defaults = [[0]] * 4 + inputs = [['1,NA,3,4', 'NA,6,7,8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, na_value='NA') + + def testCsvDataset_withSelectCols(self): + record_defaults = [['']] * 2 + inputs = [['1,2,3,4', '"5","6","7","8"']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, select_cols=[1, 2]) + + def testCsvDataset_withSelectColsTooHigh(self): + record_defaults = [[0]] * 2 + inputs = [['1,2,3,4', '5,6,7,8']] + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have 1 in record', + record_defaults=record_defaults, + select_cols=[3, 4]) + + def testCsvDataset_withOneCol(self): + record_defaults = [['NA']] + inputs = [['0', '', '2']] + self._test_dataset( + inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults) + + def testCsvDataset_withMultipleFiles(self): + record_defaults = [[0]] * 4 + inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withLeadingAndTrailingSpaces(self): + record_defaults = [[0.0]] * 4 + inputs = [['0, 1, 2, 3']] + expected = [[0.0, 1.0, 2.0, 3.0]] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_errorWithMissingDefault(self): + record_defaults = [[]] * 2 + inputs = [['0,']] + self._test_dataset( + inputs, + expected_err_re='Field 1 is required but missing in record!', + record_defaults=record_defaults) + + def testCsvDataset_errorWithFewerDefaultsThanFields(self): + record_defaults = [[0.0]] * 2 + inputs = [['0,1,2,3']] + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have more in record', + record_defaults=record_defaults) + + def testCsvDataset_errorWithMoreDefaultsThanFields(self): + record_defaults = [[0.0]] * 5 + inputs = [['0,1,2,3']] + self._test_dataset( + inputs, + expected_err_re='Expect 5 fields but have 4 in record', + record_defaults=record_defaults) + + def testCsvDataset_withHeader(self): + record_defaults = [[0]] * 2 + inputs = [['col1,col2', '1,2']] + expected = [[1, 2]] + self._test_dataset( + inputs, + expected, + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_withHeaderAndNoRecords(self): + record_defaults = [[0]] * 2 + inputs = [['col1,col2']] + expected = [] + self._test_dataset( + inputs, + expected, + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_errorWithHeaderEmptyFile(self): + record_defaults = [[0]] * 2 + inputs = [[]] + expected_err_re = "Can't read header of file" + self._test_dataset( + inputs, + expected_err_re=expected_err_re, + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_withEmptyFile(self): + record_defaults = [['']] * 2 + inputs = [['']] # Empty file + self._test_dataset( + inputs, expected_output=[], record_defaults=record_defaults) + + def testCsvDataset_errorWithEmptyRecord(self): + record_defaults = [['']] * 2 + inputs = [['', '1,2']] # First record is empty + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have 1 in record', + record_defaults=record_defaults) + + def testCsvDataset_withChainedOps(self): + # Testing that one dataset can create multiple iterators fine. + # `repeat` creates multiple iterators from the same C++ Dataset. + record_defaults = [[0]] * 4 + inputs = [['1,,3,4', '5,6,,8']] + ds_actual, ds_expected = self._make_test_datasets( + inputs, record_defaults=record_defaults) + with ops.Graph().as_default() as g: + self._assert_datasets_equal(g, + ds_actual.repeat(5).prefetch(1), + ds_expected.repeat(5).prefetch(1)) + + def testCsvDataset_withTypeDefaults(self): + # Testing using dtypes as record_defaults for required fields + record_defaults = [dtypes.float32, [0.0]] + inputs = [['1.0,2.0', '3.0,4.0']] + self._test_dataset( + inputs, + [[1.0, 2.0], [3.0, 4.0]], + record_defaults=record_defaults, + ) + + def testMakeCsvDataset_fieldOrder(self): + data = [[ + '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19', + '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19' + ]] + file_path = self.setup_files(data) + + with ops.Graph().as_default() as g: + ds = readers.make_csv_dataset( + file_path, batch_size=1, shuffle=False, num_epochs=1) + next_batch = ds.make_one_shot_iterator().get_next() + + with self.test_session(graph=g) as sess: + result = list(sess.run(next_batch).values()) + + self.assertEqual(result, sorted(result)) + +## The following tests exercise parsing logic for quoted fields + + def testCsvDataset_withQuoted(self): + record_defaults = [['']] * 4 + inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withOneColAndQuotes(self): + record_defaults = [['']] + inputs = [['"0"', '"1"', '"2"']] + self._test_dataset( + inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults) + + def testCsvDataset_withNewLine(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_withNewLineInUnselectedCol(self): + record_defaults = [['']] + inputs = [['1,"2\n3",4', '5,6,7']] + self._test_dataset( + inputs, + expected_output=[['1'], ['5']], + record_defaults=record_defaults, + select_cols=[0]) + + def testCsvDataset_withMultipleNewLines(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_errorWithTerminateMidRecord(self): + record_defaults = [['']] * 4 + inputs = [['a,b,c,"a']] + self._test_dataset( + inputs, + expected_err_re= + 'Reached end of file without closing quoted field in record', + record_defaults=record_defaults) + + def testCsvDataset_withEscapedQuotes(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + +## Testing that parsing works with all buffer sizes, quoted/unquoted fields, +## and different types of line breaks + + def testCsvDataset_withInvalidBufferSize(self): + record_defaults = [['']] * 4 + inputs = [['a,b,c,d']] + self._test_dataset( + inputs, + expected_err_re='buffer_size should be positive', + record_defaults=record_defaults, + buffer_size=0) + + def testCsvDataset_withBufferSize(self): + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, expected, record_defaults=record_defaults, buffer_size=i + 1) + + def testCsvDataset_withCR(self): + # Test that when the line separator is '\r', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r', + record_defaults=record_defaults, + buffer_size=i + 1) + + def testCsvDataset_withCRLF(self): + # Test that when the line separator is '\r\n', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r\n', + record_defaults=record_defaults, + buffer_size=i + 1) + + def testCsvDataset_withBufferSizeAndQuoted(self): + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\n', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\n', record_defaults=record_defaults) + + def testCsvDataset_withCRAndQuoted(self): + # Test that when the line separator is '\r', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\r', record_defaults=record_defaults) + + def testCsvDataset_withCRLFAndQuoted(self): + # Test that when the line separator is '\r\n', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r\n', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\r\n', record_defaults=record_defaults) + + +class CsvDatasetBenchmark(test.Benchmark): + """Benchmarks for the various ways of creating a dataset from CSV files. + """ + FLOAT_VAL = '1.23456E12' + STR_VAL = string.ascii_letters * 10 + + def _setUp(self, str_val): + # Since this isn't test.TestCase, have to manually create a test dir + gfile.MakeDirs(googletest.GetTempDir()) + self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir()) + + self._num_cols = [4, 64, 256] + self._num_per_iter = 5000 + self._filenames = [] + for n in self._num_cols: + fn = os.path.join(self._temp_dir, 'file%d.csv' % n) + with open(fn, 'wb') as f: + # Just write 100 rows and use `repeat`... Assumes the cost + # of creating an iterator is not significant + row = ','.join([str_val for _ in range(n)]) + f.write('\n'.join([row for _ in range(100)])) + self._filenames.append(fn) + + def _tearDown(self): + gfile.DeleteRecursively(self._temp_dir) + + def _runBenchmark(self, dataset, num_cols, prefix): + dataset = dataset.skip(self._num_per_iter - 1) + deltas = [] + for _ in range(10): + next_element = dataset.make_one_shot_iterator().get_next() + with session.Session() as sess: + start = time.time() + # NOTE: This depends on the underlying implementation of skip, to have + # the net effect of calling `GetNext` num_per_iter times on the + # input dataset. We do it this way (instead of a python for loop, or + # batching N inputs in one iter) so that the overhead from session.run + # or batch doesn't dominate. If we eventually optimize skip, this has + # to change. + sess.run(next_element) + end = time.time() + deltas.append(end - start) + # Median wall time per CSV record read and decoded + median_wall_time = np.median(deltas) / self._num_per_iter + print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols, + median_wall_time)) + self.report_benchmark( + iters=self._num_per_iter, + wall_time=median_wall_time, + name='%s_with_cols_%d' % (prefix, num_cols)) + + def benchmarkMapWithFloats(self): + self._setUp(self.FLOAT_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [[0.0]] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv') + self._tearDown() + + def benchmarkMapWithStrings(self): + self._setUp(self.STR_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [['']] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv') + self._tearDown() + + def benchmarkCsvDatasetWithFloats(self): + self._setUp(self.FLOAT_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [[0.0]] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_float_fused_dataset') + self._tearDown() + + def benchmarkCsvDatasetWithStrings(self): + self._setUp(self.STR_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [['']] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset') + self._tearDown() + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index dbc35097ddda9f0375060d43aeb43efa8107f929..393f08850b1865180a8b94e9209b2445b54c8b69 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -163,7 +163,7 @@ class DatasetSerializationTestBase(test.TestCase): num_outputs, sparse_tensors=False, verify_exhausted=True): - """Verifies that restoring into an already initilized iterator works. + """Verifies that restoring into an already initialized iterator works. Args: ds_fn: See `run_core_tests`. @@ -467,7 +467,8 @@ class DatasetSerializationTestBase(test.TestCase): ckpt_saved=False, init_before_restore=False, sparse_tensors=False, - verify_exhausted=True): + verify_exhausted=True, + save_checkpoint_at_end=True): """Generates elements from input dataset while stopping at break points. Produces `num_outputs` outputs and saves the state of the iterator in the @@ -490,6 +491,10 @@ class DatasetSerializationTestBase(test.TestCase): sparse_tensors: Whether dataset is built from SparseTensor(s). verify_exhausted: Whether to verify that the iterator has been exhausted after producing `num_outputs` elements. + save_checkpoint_at_end: Whether to save a checkpoint after producing all + outputs. If False, checkpoints are saved each break point but not at the + end. Note that checkpoints overwrite each other so there is always only + a single checkpoint available. Defaults to True. Returns: A list of `num_outputs` items. @@ -526,8 +531,9 @@ class DatasetSerializationTestBase(test.TestCase): if i == len(break_points) and verify_exhausted: with self.assertRaises(errors.OutOfRangeError): sess.run(get_next_op) - self._save(sess, saver) - ckpt_saved = True + if save_checkpoint_at_end or i < len(break_points): + self._save(sess, saver) + ckpt_saved = True return outputs diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py new file mode 100644 index 0000000000000000000000000000000000000000..34b6a080c0aae7dfc228746139acc52cea4e6f28 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py @@ -0,0 +1,167 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import random_seed +from tensorflow.python.platform import test + + +class DirectedInterleaveDatasetTest(test.TestCase): + + def testBasic(self): + selector_dataset = dataset_ops.Dataset.range(10).repeat(100) + input_datasets = [ + dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10) + ] + dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset, + input_datasets) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for _ in range(100): + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def _normalize(self, vec): + return vec / vec.sum() + + def _chi2(self, expected, actual): + actual = np.asarray(actual) + expected = np.asarray(expected) + diff = actual - expected + chi2 = np.sum(diff * diff / expected, axis=0) + return chi2 + + def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples): + # Create a dataset that samples each integer in `[0, num_datasets)` + # with probability given by `weights[i]`. + dataset = interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(num_datasets) + ], weights) + dataset = dataset.take(num_samples) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + freqs = np.zeros([num_datasets]) + for _ in range(num_samples): + freqs[sess.run(next_element)] += 1 + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + return freqs + + def testSampleFromDatasets(self): + random_seed.set_random_seed(1619) + num_samples = 5000 + rand_probs = self._normalize(np.random.random_sample((15,))) + + # Use chi-squared test to assert that the observed distribution matches the + # expected distribution. Based on the implementation in + # "tensorflow/python/kernel_tests/multinomial_op_test.py". + for probs in [[.85, .05, .1], rand_probs]: + probs = np.asarray(probs) + classes = len(probs) + freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2) + + # Also check that `weights` as a dataset samples correctly. + probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat() + freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2) + + def testSelectFromDatasets(self): + words = [b"foo", b"bar", b"baz"] + datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words] + choice_array = np.random.randint(3, size=(15,), dtype=np.int64) + choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array) + dataset = interleave_ops.choose_from_datasets(datasets, choice_dataset) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in choice_array: + self.assertEqual(words[i], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testErrors(self): + with self.assertRaisesRegexp(ValueError, + r"vector of length `len\(datasets\)`"): + interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.range(10), + dataset_ops.Dataset.range(20)], + weights=[0.25, 0.25, 0.25, 0.25]) + + with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"): + interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.range(10), + dataset_ops.Dataset.range(20)], + weights=[1, 1]) + + with self.assertRaisesRegexp(TypeError, "must have the same type"): + interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(0.0) + ]) + + with self.assertRaisesRegexp(TypeError, "tf.int64"): + interleave_ops.choose_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(1) + ], choice_dataset=dataset_ops.Dataset.from_tensors(1.0)) + + with self.assertRaisesRegexp(TypeError, "scalar"): + interleave_ops.choose_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(1) + ], choice_dataset=dataset_ops.Dataset.from_tensors([1.0])) + + +class SampleFromDatasetsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, probs, num_samples): + dataset = interleave_ops.sample_from_datasets( + [ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(len(probs)) + ], + probs, + seed=1813) + return dataset.take(num_samples) + + def testSerializationCore(self): + self.run_core_tests( + lambda: self._build_dataset([0.5, 0.5], 100), + lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py index 32ea44f7c7ba329dc253bb9fbbcac0a1ed16aec7..87b7c6ddb7afcbaaf8fe97cd8be87e6f5af8cd4d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py @@ -22,6 +22,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -33,17 +34,25 @@ class GetSingleElementTest(test.TestCase): take_value = array_ops.placeholder_with_default( constant_op.constant(1, dtype=dtypes.int64), shape=[]) + def make_sparse(x): + x_1d = array_ops.reshape(x, [1]) + x_2d = array_ops.reshape(x, [1, 1]) + return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d) + dataset = (dataset_ops.Dataset.range(100) .skip(skip_value) - .map(lambda x: x * x) + .map(lambda x: (x * x, make_sparse(x))) .take(take_value)) element = get_single_element.get_single_element(dataset) with self.test_session() as sess: - self.assertEqual(0, sess.run(element, feed_dict={skip_value: 0})) - self.assertEqual(25, sess.run(element, feed_dict={skip_value: 5})) - self.assertEqual(100, sess.run(element, feed_dict={skip_value: 10})) + for x in [0, 5, 10]: + dense_val, sparse_val = sess.run(element, feed_dict={skip_value: x}) + self.assertEqual(x * x, dense_val) + self.assertAllEqual([[x]], sparse_val.indices) + self.assertAllEqual([x], sparse_val.values) + self.assertAllEqual([x], sparse_val.dense_shape) with self.assertRaisesRegexp(errors.InvalidArgumentError, "Dataset was empty."): diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 256ad8d94dc1a7c2b26df3f1ebf8e8e321882c15..bee561e3e23a2ab6f314894caa21785347e6ca8b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -94,6 +94,76 @@ class InterleaveDatasetSerializationTest( self.run_core_tests(_build_dataset, None, 20) +class ParallelInterleaveDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self.input_values = np.array([4, 5, 6], dtype=np.int64) + self.num_repeats = 2 + self.num_outputs = np.sum(self.input_values) * 2 + + def _build_ds(self, cycle_length, block_length, sloppy=False): + return (dataset_ops.Dataset.from_tensor_slices( + self.input_values).repeat(self.num_repeats).apply( + interleave_ops.parallel_interleave( + lambda x: dataset_ops.Dataset.range(10 * x, 11 * x), + cycle_length, block_length, sloppy))) + + def testSerializationCore(self): + # cycle_length > 1, block_length > 1 + cycle_length = 2 + block_length = 3 + self.run_core_tests( + lambda: self._build_ds(cycle_length, block_length), + lambda: self._build_ds(cycle_length * 2, block_length * 1), + self.num_outputs) + # cycle_length = 1 + cycle_length = 1 + block_length = 3 + self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), + None, self.num_outputs) + # block_length = 1 + cycle_length = 2 + block_length = 1 + self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), + None, self.num_outputs) + + def testSerializationWithSloppy(self): + break_points = self.gen_break_points(self.num_outputs, 10) + expected_outputs = np.repeat( + np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]), + self.num_repeats).tolist() + + def run_test(cycle_length, block_length): + actual = self.gen_outputs( + lambda: self._build_ds(cycle_length, block_length, True), + break_points, self.num_outputs) + self.assertSequenceEqual(sorted(actual), expected_outputs) + + # cycle_length > 1, block_length > 1 + run_test(2, 3) + # cycle_length = 1 + run_test(1, 3) + # block_length = 1 + run_test(2, 1) + + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_dataset(): + return dataset_ops.Dataset.range(10).map(_map_fn).apply( + interleave_ops.parallel_interleave(_interleave_fn, 1)) + + self.run_core_tests(_build_dataset, None, 20) + + class ParallelInterleaveDatasetTest(test.TestCase): def setUp(self): @@ -338,7 +408,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testTwoThreadsNoContentionWithRaces(self, sloppy=False): """Tests where all the workers race in producing elements. - Note: this is in contrast with the prevous test which carefully sequences + Note: this is in contrast with the previous test which carefully sequences the execution of the map functions. Args: @@ -424,7 +494,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False): """Tests where all the workers race in producing elements. - Note: this is in contrast with the prevous test which carefully sequences + Note: this is in contrast with the previous test which carefully sequences the execution of the map functions. diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..30f1847dcddbfaf379ef2b09185f7a8db4aaeae2 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -0,0 +1,89 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class OptimizeDatasetTest(test.TestCase): + + def testDefaultOptimizations(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize()) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testEmptyOptimizations(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize([])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testOptimization(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize(["map_and_batch_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + any([node.op == "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +class OptimizeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + + def build_dataset(num_elements, batch_size): + return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch( + batch_size).apply(optimization.optimize(["map_and_batch_fusion"])) + + self.run_core_tests(lambda: build_dataset(200, 10), None, 20) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index dc3e38db59301bf1819999f479171af35930e9d2..b08132cd72254326d965907a1fdafb8a820926a1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools import threading from tensorflow.contrib.data.python.ops import prefetching_ops @@ -26,37 +25,43 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -class StagingAreaOpsTest(test.TestCase): +class PrefetchingKernelsOpsTest(test.TestCase): def setUp(self): self._event = threading.Event() - def _prefetch_fn_helper(self, buffer_name, device0, device1): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + def _create_ds_and_iterator(self, device0, initializable=False): def gen(): - for i in itertools.count(start=1, step=1): - yield [i + 0.0] + for i in range(1, 10): + yield [float(i)] if i == 6: self._event.set() with ops.device(device0): - dataset_3 = dataset_ops.Dataset.from_generator(gen, (dtypes.float32)) - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_3_handle = iterator_3.string_handle() + ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32)) + if initializable: + ds_iterator = ds.make_initializable_iterator() + else: + ds_iterator = ds.make_one_shot_iterator() + return (ds, ds_iterator) + + def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1): + ds_iterator_handle = ds_iterator.string_handle() @function.Defun(dtypes.string) def _remote_fn(h): remote_iterator = iterator_ops.Iterator.from_string_handle( - h, dataset_3.output_types, dataset_3.output_shapes) + h, ds.output_types, ds.output_shapes) return remote_iterator.get_next() target = constant_op.constant(device0) @@ -64,15 +69,28 @@ class StagingAreaOpsTest(test.TestCase): buffer_resource_handle = prefetching_ops.function_buffering_resource( f=_remote_fn, target_device=target, - string_arg=iterator_3_handle, + string_arg=ds_iterator_handle, buffer_size=3, - thread_pool_size=2, shared_name=buffer_name) with ops.device(device1): prefetch_op = prefetching_ops.function_buffering_resource_get_next( function_buffer_resource=buffer_resource_handle, output_types=[dtypes.float32]) + reset_op = prefetching_ops.function_buffering_resource_reset( + function_buffer_resource=buffer_resource_handle) + destroy_op = resource_variable_ops.destroy_resource_op( + buffer_resource_handle, ignore_lookup_error=True) + + return (prefetch_op, reset_op, destroy_op) + + def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + + ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False) + prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name, + device0, device1) with self.test_session(config=worker_config) as sess: elem = sess.run(prefetch_op) @@ -86,26 +104,277 @@ class StagingAreaOpsTest(test.TestCase): self._event.wait() elem = sess.run(prefetch_op) self.assertEqual(elem, [5.0]) - sess.run( - resource_variable_ops.destroy_resource_op( - buffer_resource_handle, ignore_lookup_error=True)) + sess.run(destroy_op) def testSameDeviceCPU(self): - self._prefetch_fn_helper("same_device_cpu", - "/job:localhost/replica:0/task:0/cpu:0", - "/job:localhost/replica:0/task:0/cpu:0") + self._prefetch_fn_helper_one_shot("same_device_cpu", + "/job:localhost/replica:0/task:0/cpu:0", + "/job:localhost/replica:0/task:0/cpu:0") def testDifferentDeviceCPU(self): - self._prefetch_fn_helper("diff_device_cpu", - "/job:localhost/replica:0/task:0/cpu:0", - "/job:localhost/replica:0/task:0/cpu:1") + self._prefetch_fn_helper_one_shot("diff_device_cpu", + "/job:localhost/replica:0/task:0/cpu:0", + "/job:localhost/replica:0/task:0/cpu:1") def testDifferentDeviceCPUGPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") - self._prefetch_fn_helper("cpu_gpu", "/job:localhost/replica:0/task:0/cpu:0", - "/job:localhost/replica:0/task:0/gpu:0") + self._prefetch_fn_helper_one_shot("cpu_gpu", + "/job:localhost/replica:0/task:0/cpu:0", + "/job:localhost/replica:0/task:0/gpu:0") + + def testReinitialization(self): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + + device0 = "/job:localhost/replica:0/task:0/cpu:0" + device1 = "/job:localhost/replica:0/task:0/cpu:1" + ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True) + prefetch_op, reset_op, destroy_op = self._create_ops( + ds, ds_iterator, "reinit", device0, device1) + + with self.test_session(config=worker_config) as sess: + sess.run(ds_iterator.initializer) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [1.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [2.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [3.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [4.0]) + self._event.wait() + elem = sess.run(prefetch_op) + self.assertEqual(elem, [5.0]) + # Lets reset the function buffering resource and reinitialize the + # iterator. Should be able to go through this again. + self._event.clear() + sess.run(reset_op) + sess.run(ds_iterator.initializer) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [1.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [2.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [3.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [4.0]) + self._event.wait() + elem = sess.run(prefetch_op) + self.assertEqual(elem, [5.0]) + sess.run(destroy_op) + + def testReinitializationOutOfRange(self): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + + device0 = "/job:localhost/replica:0/task:0/cpu:0" + device1 = "/job:localhost/replica:0/task:0/cpu:1" + ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True) + prefetch_op, reset_op, destroy_op = self._create_ops( + ds, ds_iterator, "reinit", device0, device1) + + with self.test_session(config=worker_config) as sess: + sess.run(ds_iterator.initializer) + for i in range(1, 10): + elem = sess.run(prefetch_op) + self.assertEqual(elem, [float(i)]) + # Try fetching after its over twice to test out end of sequence. + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + + # Now reset everything and try it out again. + self._event.clear() + sess.run(reset_op) + sess.run(ds_iterator.initializer) + for i in range(1, 10): + elem = sess.run(prefetch_op) + self.assertEqual(elem, [float(i)]) + # Try fetching after its over twice to test out end of sequence. + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + + sess.run(destroy_op) + + +class PrefetchToDeviceTest(test.TestCase): + + def testPrefetchToDevice(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchDictToDevice(self): + host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element["a"].dtype) + self.assertEqual([], next_element["a"].shape) + + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual({"a": i}, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchSparseTensorsToDevice(self): + def make_tensor(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2]) + host_dataset = dataset_ops.Dataset.range(10).map(make_tensor) + + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + with self.test_session(config=worker_config) as sess: + for i in range(10): + actual = sess.run(next_element) + self.assertAllEqual([i], actual.values) + self.assertAllEqual([[0, 0]], actual.indices) + self.assertAllEqual([2, 2], actual.dense_shape) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToDeviceGpu(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/gpu:0")) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToDeviceWithReInit(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_initializable_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + with self.test_session(config=worker_config) as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToDeviceGpuWithReInit(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/gpu:0")) + + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 6efe97444a375febc550ff3a3ea04bcd9330a3a5..3b07ef290bc38daa37472ef8919f3350851fe370 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -21,10 +21,11 @@ import gzip import os import zlib +import numpy as np + from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import readers -from tensorflow.core.example import example_pb2 -from tensorflow.core.example import feature_pb2 from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.framework import constant_op @@ -34,6 +35,7 @@ from tensorflow.python.framework import ops from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import string_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -254,128 +256,31 @@ class TFRecordDatasetSerializationTest( lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) -class ReadBatchFeaturesTest(test.TestCase): +def _interleave(iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 - def setUp(self): - super(ReadBatchFeaturesTest, self).setUp() - self._num_files = 2 - self._num_records = 7 - self.test_filenames = self._createFiles() + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 - def _read_batch_features(self, filenames, num_epochs, batch_size): - self.filenames = filenames - self.num_epochs = num_epochs - self.batch_size = batch_size - - return readers.read_batch_features( - file_pattern=self.filenames, - batch_size=self.batch_size, - features={ - "file": parsing_ops.FixedLenFeature([], dtypes.int64), - "record": parsing_ops.FixedLenFeature([], dtypes.int64), - "keywords": parsing_ops.VarLenFeature(dtypes.string) - }, - reader=core_readers.TFRecordDataset, - randomize_input=False, - num_epochs=self.num_epochs) - - def _record(self, f, r): - example = example_pb2.Example(features=feature_pb2.Features( - feature={ - "file": - feature_pb2.Feature(int64_list=feature_pb2.Int64List( - value=[f])), - "record": - feature_pb2.Feature(int64_list=feature_pb2.Int64List( - value=[r])), - "keywords": - feature_pb2.Feature(bytes_list=feature_pb2.BytesList( - value=self._get_keywords(f, r))) - })) - return example.SerializeToString() - - def _get_keywords(self, f, r): - num_keywords = 1 + (f + r) % 2 - keywords = [] - for index in range(num_keywords): - keywords.append(compat.as_bytes("keyword%d" % index)) - return keywords - - def _createFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) - filenames.append(fn) - writer = python_io.TFRecordWriter(fn) - for j in range(self._num_records): - writer.write(self._record(i, j)) - writer.close() - return filenames - - def _next_actual_batch(self, sess): - file_op = self.outputs["file"] - keywords_indices_op = self.outputs["keywords"].indices - keywords_values_op = self.outputs["keywords"].values - keywords_dense_shape_op = self.outputs["keywords"].dense_shape - record_op = self.outputs["record"] - return sess.run([ - file_op, keywords_indices_op, keywords_values_op, - keywords_dense_shape_op, record_op - ]) - - def _next_expected_batch(self, file_indices, batch_size, num_epochs): - - def _next_record(file_indices): - for j in file_indices: - for i in range(self._num_records): - yield j, i - - file_batch = [] - keywords_batch_indices = [] - keywords_batch_values = [] - keywords_batch_max_len = 0 - record_batch = [] - batch_index = 0 - for _ in range(num_epochs): - for record in _next_record(file_indices): - f = record[0] - r = record[1] - file_batch.append(f) - record_batch.append(r) - keywords = self._get_keywords(f, r) - keywords_batch_values.extend(keywords) - keywords_batch_indices.extend([[batch_index, i] - for i in range(len(keywords))]) - batch_index += 1 - keywords_batch_max_len = max(keywords_batch_max_len, len(keywords)) - if len(file_batch) == batch_size: - yield [ - file_batch, keywords_batch_indices, keywords_batch_values, - [batch_size, keywords_batch_max_len], record_batch - ] - file_batch = [] - keywords_batch_indices = [] - keywords_batch_values = [] - keywords_batch_max_len = 0 - record_batch = [] - batch_index = 0 - if file_batch: - yield [ - file_batch, keywords_batch_indices, keywords_batch_values, - [len(file_batch), keywords_batch_max_len], record_batch - ] - - def _verify_records(self, sess, batch_size, file_index=None, num_epochs=1): - if file_index is not None: - file_indices = [file_index] - else: - file_indices = range(self._num_files) - for expected_batch in self._next_expected_batch(file_indices, batch_size, - num_epochs): - actual_batch = self._next_actual_batch(sess) - for i in range(len(expected_batch)): - self.assertAllEqual(expected_batch[i], actual_batch[i]) +class ReadBatchFeaturesTest( + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): def testRead(self): for batch_size in [1, 2]: @@ -383,33 +288,33 @@ class ReadBatchFeaturesTest(test.TestCase): with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from file 0. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, 0, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, 0, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from file 1. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[1], num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, 1, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, 1, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from both files. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames, num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) @@ -418,9 +323,10 @@ class ReadBatchFeaturesTest(test.TestCase): "file": parsing_ops.FixedLenFeature([], dtypes.int64), "record": parsing_ops.FixedLenFeature([], dtypes.int64), } - dataset = (core_readers.TFRecordDataset(self.test_filenames) - .map(lambda x: parsing_ops.parse_single_example(x, features)) - .repeat(10).batch(2)) + dataset = ( + core_readers.TFRecordDataset(self.test_filenames) + .map(lambda x: parsing_ops.parse_single_example(x, features)) + .repeat(10).batch(2)) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer next_element = iterator.get_next() @@ -435,6 +341,762 @@ class ReadBatchFeaturesTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testReadWithFusedShuffleRepeatDataset(self): + num_epochs = 5 + total_records = num_epochs * self._num_records + for batch_size in [1, 2]: + # Test that shuffling with same seed produces the same result. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + outputs1 = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5).make_one_shot_iterator().get_next() + outputs2 = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5).make_one_shot_iterator().get_next() + for _ in range(total_records // batch_size): + batch1 = self._run_actual_batch(outputs1, sess) + batch2 = self._run_actual_batch(outputs2, sess) + for i in range(len(batch1)): + self.assertAllEqual(batch1[i], batch2[i]) + + # Test that shuffling with different seeds produces a different order. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + outputs1 = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5).make_one_shot_iterator().get_next() + outputs2 = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=15).make_one_shot_iterator().get_next() + all_equal = True + for _ in range(total_records // batch_size): + batch1 = self._run_actual_batch(outputs1, sess) + batch2 = self._run_actual_batch(outputs2, sess) + for i in range(len(batch1)): + all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) + self.assertFalse(all_equal) + + def testParallelReadersAndParsers(self): + num_epochs = 5 + for batch_size in [1, 2]: + for reader_num_threads in [2, 4]: + for parser_num_threads in [2, 4]: + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + self.outputs = self.make_batch_feature( + filenames=self.test_filenames, + num_epochs=num_epochs, + batch_size=batch_size, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads).make_one_shot_iterator( + ).get_next() + self.verify_records( + sess, + batch_size, + num_epochs=num_epochs, + interleave_cycle_length=reader_num_threads) + with self.assertRaises(errors.OutOfRangeError): + self._next_actual_batch(sess) + + def testDropFinalBatch(self): + for batch_size in [1, 2]: + for num_epochs in [1, 10]: + with ops.Graph().as_default(): + # Basic test: read from file 0. + self.outputs = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + drop_final_batch=True).make_one_shot_iterator().get_next() + for _, tensor in self.outputs.items(): + if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. + self.assertEqual(tensor.shape[0], batch_size) + + +class MakeCsvDatasetTest(test.TestCase): + + COLUMN_TYPES = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string + ] + COLUMNS = ["col%d" % i for i in range(len(COLUMN_TYPES))] + DEFAULT_VALS = [[], [], [], [], ["NULL"]] + DEFAULTS = [ + constant_op.constant([], dtype=dtypes.int32), + constant_op.constant([], dtype=dtypes.int64), + constant_op.constant([], dtype=dtypes.float32), + constant_op.constant([], dtype=dtypes.float64), + constant_op.constant(["NULL"], dtype=dtypes.string) + ] + LABEL = COLUMNS[0] + + def setUp(self): + super(MakeCsvDatasetTest, self).setUp() + self._num_files = 2 + self._num_records = 11 + self._test_filenames = self._create_files() + + def _csv_values(self, fileno, recordno): + return [ + fileno, + recordno, + fileno * recordno * 0.5, + fileno * recordno + 0.5, + "record %d" % recordno if recordno % 2 == 1 else "", + ] + + def _write_file(self, filename, rows): + for i in range(len(rows)): + if isinstance(rows[i], list): + rows[i] = ",".join(str(v) if v is not None else "" for v in rows[i]) + fn = os.path.join(self.get_temp_dir(), filename) + f = open(fn, "w") + f.write("\n".join(rows)) + f.close() + return fn + + def _create_file(self, fileno, header=True): + rows = [] + if header: + rows.append(self.COLUMNS) + for recno in range(self._num_records): + rows.append(self._csv_values(fileno, recno)) + return self._write_file("csv_file%d.csv" % fileno, rows) + + def _create_files(self): + filenames = [] + for i in range(self._num_files): + filenames.append(self._create_file(i)) + return filenames + + def _make_csv_dataset( + self, + filenames, + defaults, + column_names=COLUMNS, + label_name=LABEL, + select_cols=None, + batch_size=1, + num_epochs=1, + shuffle=False, + shuffle_seed=None, + header=True, + na_value="", + ): + return readers.make_csv_dataset( + filenames, + batch_size=batch_size, + column_names=column_names, + column_defaults=defaults, + label_name=label_name, + num_epochs=num_epochs, + shuffle=shuffle, + shuffle_seed=shuffle_seed, + header=header, + na_value=na_value, + select_columns=select_cols, + ) + + def _next_actual_batch(self, file_indices, batch_size, num_epochs, defaults): + features = {col: list() for col in self.COLUMNS} + for _ in range(num_epochs): + for i in file_indices: + for j in range(self._num_records): + values = self._csv_values(i, j) + for n, v in enumerate(values): + if v == "": # pylint: disable=g-explicit-bool-comparison + values[n] = defaults[n][0] + values[-1] = values[-1].encode("utf-8") + + # Regroup lists by column instead of row + for n, col in enumerate(self.COLUMNS): + features[col].append(values[n]) + if len(list(features.values())[0]) == batch_size: + yield features + features = {col: list() for col in self.COLUMNS} + + def _run_actual_batch(self, outputs, sess): + features, labels = sess.run(outputs) + batch = [features[k] for k in self.COLUMNS if k != self.LABEL] + batch.append(labels) + return batch + + def _verify_records( + self, + sess, + dataset, + file_indices, + defaults=tuple(DEFAULT_VALS), + label_name=LABEL, + batch_size=1, + num_epochs=1, + ): + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + for expected_features in self._next_actual_batch(file_indices, batch_size, + num_epochs, defaults): + actual_features = sess.run(get_next) + + if label_name is not None: + expected_labels = expected_features.pop(label_name) + # Compare labels + self.assertAllEqual(expected_labels, actual_features[1]) + actual_features = actual_features[0] # Extract features dict from tuple + + for k in expected_features.keys(): + # Compare features + self.assertAllEqual(expected_features[k], actual_features[k]) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testMakeCSVDataset(self): + defaults = self.DEFAULTS + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Basic test: read from file 0. + dataset = self._make_csv_dataset(self._test_filenames[0], defaults) + self._verify_records(sess, dataset, [0]) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Basic test: read from file 1. + dataset = self._make_csv_dataset(self._test_filenames[1], defaults) + self._verify_records(sess, dataset, [1]) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Read from both files. + dataset = self._make_csv_dataset(self._test_filenames, defaults) + self._verify_records(sess, dataset, range(self._num_files)) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Read from both files. Exercise the `batch` and `num_epochs` parameters + # of make_csv_dataset and make sure they work. + dataset = self._make_csv_dataset( + self._test_filenames, defaults, batch_size=2, num_epochs=10) + self._verify_records( + sess, dataset, range(self._num_files), batch_size=2, num_epochs=10) + + def testMakeCSVDataset_withBadColumns(self): + """Tests that exception is raised when input is malformed. + """ + dupe_columns = self.COLUMNS[:-1] + self.COLUMNS[:1] + defaults = self.DEFAULTS + + # Duplicate column names + with self.assertRaises(ValueError): + self._make_csv_dataset( + self._test_filenames, defaults, column_names=dupe_columns) + + # Label key not one of column names + with self.assertRaises(ValueError): + self._make_csv_dataset( + self._test_filenames, defaults, label_name="not_a_real_label") + + def testMakeCSVDataset_withNoLabel(self): + """Tests that CSV datasets can be created when no label is specified. + """ + defaults = self.DEFAULTS + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Read from both files. Make sure this works with no label key supplied. + dataset = self._make_csv_dataset( + self._test_filenames, + defaults, + batch_size=2, + num_epochs=10, + label_name=None) + self._verify_records( + sess, + dataset, + range(self._num_files), + batch_size=2, + num_epochs=10, + label_name=None) + + def testMakeCSVDataset_withNoHeader(self): + """Tests that datasets can be created from CSV files with no header line. + """ + defaults = self.DEFAULTS + file_without_header = self._create_file( + len(self._test_filenames), header=False) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + file_without_header, + defaults, + batch_size=2, + num_epochs=10, + header=False, + ) + self._verify_records( + sess, + dataset, + [len(self._test_filenames)], + batch_size=2, + num_epochs=10, + ) + + def testMakeCSVDataset_withTypes(self): + """Tests that defaults can be a dtype instead of a Tensor for required vals. + """ + defaults = [d for d in self.COLUMN_TYPES[:-1]] + defaults.append(constant_op.constant(["NULL"], dtype=dtypes.string)) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset(self._test_filenames, defaults) + self._verify_records(sess, dataset, range(self._num_files)) + + def testMakeCSVDataset_withNoColNames(self): + """Tests that datasets can be created when column names are not specified. + + In that case, we should infer the column names from the header lines. + """ + defaults = self.DEFAULTS + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Read from both files. Exercise the `batch` and `num_epochs` parameters + # of make_csv_dataset and make sure they work. + dataset = self._make_csv_dataset( + self._test_filenames, + defaults, + column_names=None, + batch_size=2, + num_epochs=10) + self._verify_records( + sess, dataset, range(self._num_files), batch_size=2, num_epochs=10) + + def testMakeCSVDataset_withTypeInferenceMismatch(self): + # Test that error is thrown when num fields doesn't match columns + with self.assertRaises(ValueError): + self._make_csv_dataset( + self._test_filenames, + column_names=self.COLUMNS + ["extra_name"], + defaults=None, + batch_size=2, + num_epochs=10) + + def testMakeCSVDataset_withTypeInference(self): + """Tests that datasets can be created when no defaults are specified. + + In that case, we should infer the types from the first N records. + """ + # Test that it works with standard test files (with header, etc) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + self._test_filenames, defaults=None, batch_size=2, num_epochs=10) + self._verify_records( + sess, + dataset, + range(self._num_files), + batch_size=2, + num_epochs=10, + defaults=[[], [], [], [], [""]]) + + def testMakeCSVDataset_withTypeInferenceTricky(self): + # Test on a deliberately tricky file (type changes as we read more rows, and + # there are null values) + fn = os.path.join(self.get_temp_dir(), "file.csv") + expected_dtypes = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float32, + dtypes.string, dtypes.string + ] + col_names = ["col%d" % i for i in range(len(expected_dtypes))] + rows = [[None, None, None, "NAN", "", + "a"], [1, 2**31 + 1, 2**64, 123, "NAN", ""], + ['"123"', 2, 2**64, 123.4, "NAN", '"cd,efg"']] + expected = [[0, 0, 0, 0, "", "a"], [1, 2**31 + 1, 2**64, 123, "", ""], + [123, 2, 2**64, 123.4, "", "cd,efg"]] + for row in expected: + row[-1] = row[-1].encode("utf-8") # py3 expects byte strings + row[-2] = row[-2].encode("utf-8") # py3 expects byte strings + self._write_file("file.csv", [col_names] + rows) + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + fn, + defaults=None, + column_names=None, + label_name=None, + na_value="NAN", + ) + features = dataset.make_one_shot_iterator().get_next() + # Check that types match + for i in range(len(expected_dtypes)): + print(features["col%d" % i].dtype, expected_dtypes[i]) + assert features["col%d" % i].dtype == expected_dtypes[i] + for i in range(len(rows)): + assert sess.run(features) == dict(zip(col_names, expected[i])) + + def testMakeCSVDataset_withTypeInferenceAllTypes(self): + # Test that we make the correct inference for all types with fallthrough + fn = os.path.join(self.get_temp_dir(), "file.csv") + expected_dtypes = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, + dtypes.string, dtypes.string + ] + col_names = ["col%d" % i for i in range(len(expected_dtypes))] + rows = [[1, 2**31 + 1, 1.0, 4e40, "abc", ""]] + expected = [[ + 1, 2**31 + 1, 1.0, 4e40, "abc".encode("utf-8"), "".encode("utf-8") + ]] + self._write_file("file.csv", [col_names] + rows) + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + fn, + defaults=None, + column_names=None, + label_name=None, + na_value="NAN", + ) + features = dataset.make_one_shot_iterator().get_next() + # Check that types match + for i in range(len(expected_dtypes)): + self.assertAllEqual(features["col%d" % i].dtype, expected_dtypes[i]) + for i in range(len(rows)): + self.assertAllEqual( + sess.run(features), dict(zip(col_names, expected[i]))) + + def testMakeCSVDataset_withSelectColsError(self): + data = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] + col_names = ["col%d" % i for i in range(5)] + fn = self._write_file("file.csv", [col_names] + data) + with self.assertRaises(ValueError): + # Mismatch in number of defaults and number of columns selected, + # should raise an error + self._make_csv_dataset( + fn, + defaults=[[0]] * 5, + column_names=col_names, + label_name=None, + select_cols=[1, 3]) + with self.assertRaises(ValueError): + # Invalid column name should raise an error + self._make_csv_dataset( + fn, + defaults=[[0]], + column_names=col_names, + label_name=None, + select_cols=["invalid_col_name"]) + + def testMakeCSVDataset_withSelectCols(self): + data = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] + col_names = ["col%d" % i for i in range(5)] + fn = self._write_file("file.csv", [col_names] + data) + # If select_cols is specified, should only yield a subset of columns + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + fn, + defaults=[[0], [0]], + column_names=col_names, + label_name=None, + select_cols=[1, 3]) + expected = [[1, 3], [6, 8]] + features = dataset.make_one_shot_iterator().get_next() + for i in range(len(data)): + self.assertAllEqual( + sess.run(features), + dict(zip([col_names[1], col_names[3]], expected[i]))) + # Can still do default inference with select_cols + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + fn, + defaults=None, + column_names=col_names, + label_name=None, + select_cols=[1, 3]) + expected = [[1, 3], [6, 8]] + features = dataset.make_one_shot_iterator().get_next() + for i in range(len(data)): + self.assertAllEqual( + sess.run(features), + dict(zip([col_names[1], col_names[3]], expected[i]))) + # Can still do column name inference + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + fn, + defaults=None, + column_names=None, + label_name=None, + select_cols=[1, 3]) + expected = [[1, 3], [6, 8]] + features = dataset.make_one_shot_iterator().get_next() + for i in range(len(data)): + self.assertAllEqual( + sess.run(features), + dict(zip([col_names[1], col_names[3]], expected[i]))) + # Can specify column names instead of indices + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + fn, + defaults=None, + column_names=None, + label_name=None, + select_cols=[col_names[1], col_names[3]]) + expected = [[1, 3], [6, 8]] + features = dataset.make_one_shot_iterator().get_next() + for i in range(len(data)): + self.assertAllEqual( + sess.run(features), + dict(zip([col_names[1], col_names[3]], expected[i]))) + + def testMakeCSVDataset_withShuffle(self): + total_records = self._num_files * self._num_records + defaults = self.DEFAULTS + for batch_size in [1, 2]: + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Test that shuffling with the same seed produces the same result + dataset1 = self._make_csv_dataset( + self._test_filenames, + defaults, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + dataset2 = self._make_csv_dataset( + self._test_filenames, + defaults, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + outputs1 = dataset1.make_one_shot_iterator().get_next() + outputs2 = dataset2.make_one_shot_iterator().get_next() + for _ in range(total_records // batch_size): + batch1 = self._run_actual_batch(outputs1, sess) + batch2 = self._run_actual_batch(outputs2, sess) + for i in range(len(batch1)): + self.assertAllEqual(batch1[i], batch2[i]) + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Test that shuffling with a different seed produces different results + dataset1 = self._make_csv_dataset( + self._test_filenames, + defaults, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + dataset2 = self._make_csv_dataset( + self._test_filenames, + defaults, + batch_size=batch_size, + shuffle=True, + shuffle_seed=6) + outputs1 = dataset1.make_one_shot_iterator().get_next() + outputs2 = dataset2.make_one_shot_iterator().get_next() + all_equal = False + for _ in range(total_records // batch_size): + batch1 = self._run_actual_batch(outputs1, sess) + batch2 = self._run_actual_batch(outputs2, sess) + for i in range(len(batch1)): + all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) + self.assertFalse(all_equal) + + +class MakeTFRecordDatasetTest(TFRecordDatasetTestBase): + + def _next_expected_batch(self, + file_indices, + batch_size, + num_epochs, + cycle_length, + drop_final_batch, + use_parser_fn): + + def _next_record(file_indices): + for j in file_indices: + for i in range(self._num_records): + yield j, i + + def _next_record_interleaved(file_indices, cycle_length): + return _interleave([_next_record([i]) for i in file_indices], + cycle_length) + + record_batch = [] + batch_index = 0 + for _ in range(num_epochs): + if cycle_length == 1: + next_records = _next_record(file_indices) + else: + next_records = _next_record_interleaved(file_indices, cycle_length) + for f, r in next_records: + record = self._record(f, r) + if use_parser_fn: + record = record[1:] + record_batch.append(record) + batch_index += 1 + if len(record_batch) == batch_size: + yield record_batch + record_batch = [] + batch_index = 0 + if record_batch and not drop_final_batch: + yield record_batch + + def _verify_records(self, + sess, + outputs, + batch_size, + file_index, + num_epochs, + interleave_cycle_length, + drop_final_batch, + use_parser_fn): + if file_index is not None: + file_indices = [file_index] + else: + file_indices = range(self._num_files) + + for expected_batch in self._next_expected_batch( + file_indices, batch_size, num_epochs, interleave_cycle_length, + drop_final_batch, use_parser_fn): + actual_batch = sess.run(outputs) + self.assertAllEqual(expected_batch, actual_batch) + + def _read_test(self, batch_size, num_epochs, file_index=None, + num_parallel_reads=1, drop_final_batch=False, parser_fn=False): + if file_index is None: + file_pattern = self.test_filenames + else: + file_pattern = self.test_filenames[file_index] + + if parser_fn: + fn = lambda x: string_ops.substr(x, 1, 999) + else: + fn = None + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + outputs = readers.make_tf_record_dataset( + file_pattern=file_pattern, + num_epochs=num_epochs, + batch_size=batch_size, + parser_fn=fn, + num_parallel_reads=num_parallel_reads, + drop_final_batch=drop_final_batch, + shuffle=False).make_one_shot_iterator().get_next() + self._verify_records( + sess, outputs, batch_size, file_index, num_epochs=num_epochs, + interleave_cycle_length=num_parallel_reads, + drop_final_batch=drop_final_batch, use_parser_fn=parser_fn) + with self.assertRaises(errors.OutOfRangeError): + sess.run(outputs) + + def testRead(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + # Basic test: read from file 0. + self._read_test(batch_size, num_epochs, 0) + + # Basic test: read from file 1. + self._read_test(batch_size, num_epochs, 1) + + # Basic test: read from both files. + self._read_test(batch_size, num_epochs) + + # Basic test: read from both files, with parallel reads. + self._read_test(batch_size, num_epochs, num_parallel_reads=8) + + def testDropFinalBatch(self): + for batch_size in [1, 2, 10]: + for num_epochs in [1, 3]: + # Read from file 0. + self._read_test(batch_size, num_epochs, 0, drop_final_batch=True) + + # Read from both files. + self._read_test(batch_size, num_epochs, drop_final_batch=True) + + # Read from both files, with parallel reads. + self._read_test(batch_size, num_epochs, num_parallel_reads=8, + drop_final_batch=True) + + def testParserFn(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + for drop_final_batch in [False, True]: + self._read_test(batch_size, num_epochs, parser_fn=True, + drop_final_batch=drop_final_batch) + self._read_test(batch_size, num_epochs, num_parallel_reads=8, + parser_fn=True, drop_final_batch=drop_final_batch) + + def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1, + seed=None): + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.make_tf_record_dataset( + file_pattern=self.test_filenames, + num_epochs=num_epochs, + batch_size=batch_size, + num_parallel_reads=num_parallel_reads, + shuffle=True, + shuffle_seed=seed) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + sess.run(iterator.initializer) + first_batches = [] + try: + while True: + first_batches.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + + sess.run(iterator.initializer) + second_batches = [] + try: + while True: + second_batches.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + + self.assertEqual(len(first_batches), len(second_batches)) + if seed is not None: + # if you set a seed, should get the same results + for i in range(len(first_batches)): + self.assertAllEqual(first_batches[i], second_batches[i]) + + expected = [] + for f in range(self._num_files): + for r in range(self._num_records): + expected.extend([self._record(f, r)] * num_epochs) + + for batches in (first_batches, second_batches): + actual = [] + for b in batches: + actual.extend(b) + self.assertAllEqual(sorted(expected), sorted(actual)) + + def testShuffle(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + for num_parallel_reads in [1, 2]: + # Test that all expected elements are produced + self._shuffle_test(batch_size, num_epochs, num_parallel_reads) + # Test that elements are produced in a consistent order if + # you specify a seed. + self._shuffle_test(batch_size, num_epochs, num_parallel_reads, + seed=21345) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..805a7c7b7384d53cc166a48ba243502ef8643280 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py @@ -0,0 +1,218 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.ops import readers +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.framework import dtypes +from tensorflow.python.lib.io import python_io +from tensorflow.python.ops import parsing_ops +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class ReadBatchFeaturesTestBase(test.TestCase): + """Base class for setting up and testing `make_batched_feature_dataset`.""" + + def setUp(self): + super(ReadBatchFeaturesTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + self.test_filenames = self._createFiles() + + def make_batch_feature(self, + filenames, + num_epochs, + batch_size, + reader_num_threads=1, + parser_num_threads=1, + shuffle=False, + shuffle_seed=None, + drop_final_batch=False): + self.filenames = filenames + self.num_epochs = num_epochs + self.batch_size = batch_size + + return readers.make_batched_features_dataset( + file_pattern=self.filenames, + batch_size=self.batch_size, + features={ + "file": parsing_ops.FixedLenFeature([], dtypes.int64), + "record": parsing_ops.FixedLenFeature([], dtypes.int64), + "keywords": parsing_ops.VarLenFeature(dtypes.string) + }, + reader=core_readers.TFRecordDataset, + num_epochs=self.num_epochs, + shuffle=shuffle, + shuffle_seed=shuffle_seed, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads, + drop_final_batch=drop_final_batch) + + def _record(self, f, r): + example = example_pb2.Example( + features=feature_pb2.Features( + feature={ + "file": + feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=[f])), + "record": + feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=[r])), + "keywords": + feature_pb2.Feature( + bytes_list=feature_pb2.BytesList( + value=self._get_keywords(f, r))) + })) + return example.SerializeToString() + + def _get_keywords(self, f, r): + num_keywords = 1 + (f + r) % 2 + keywords = [] + for index in range(num_keywords): + keywords.append(compat.as_bytes("keyword%d" % index)) + return keywords + + def _sum_keywords(self, num_files): + sum_keywords = 0 + for i in range(num_files): + for j in range(self._num_records): + sum_keywords += 1 + (i + j) % 2 + return sum_keywords + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) + filenames.append(fn) + writer = python_io.TFRecordWriter(fn) + for j in range(self._num_records): + writer.write(self._record(i, j)) + writer.close() + return filenames + + def _run_actual_batch(self, outputs, sess): + file_op = outputs["file"] + keywords_indices_op = outputs["keywords"].indices + keywords_values_op = outputs["keywords"].values + keywords_dense_shape_op = outputs["keywords"].dense_shape + record_op = outputs["record"] + return sess.run([ + file_op, keywords_indices_op, keywords_values_op, + keywords_dense_shape_op, record_op + ]) + + def _next_actual_batch(self, sess): + return self._run_actual_batch(self.outputs, sess) + + def _interleave(self, iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 + + def _next_expected_batch(self, + file_indices, + batch_size, + num_epochs, + cycle_length=1): + + def _next_record(file_indices): + for j in file_indices: + for i in range(self._num_records): + yield j, i + + def _next_record_interleaved(file_indices, cycle_length): + return self._interleave([_next_record([i]) for i in file_indices], + cycle_length) + + file_batch = [] + keywords_batch_indices = [] + keywords_batch_values = [] + keywords_batch_max_len = 0 + record_batch = [] + batch_index = 0 + for _ in range(num_epochs): + if cycle_length == 1: + next_records = _next_record(file_indices) + else: + next_records = _next_record_interleaved(file_indices, cycle_length) + for record in next_records: + f = record[0] + r = record[1] + file_batch.append(f) + record_batch.append(r) + keywords = self._get_keywords(f, r) + keywords_batch_values.extend(keywords) + keywords_batch_indices.extend( + [[batch_index, i] for i in range(len(keywords))]) + batch_index += 1 + keywords_batch_max_len = max(keywords_batch_max_len, len(keywords)) + if len(file_batch) == batch_size: + yield [ + file_batch, keywords_batch_indices, keywords_batch_values, + [batch_size, keywords_batch_max_len], record_batch + ] + file_batch = [] + keywords_batch_indices = [] + keywords_batch_values = [] + keywords_batch_max_len = 0 + record_batch = [] + batch_index = 0 + if file_batch: + yield [ + file_batch, keywords_batch_indices, keywords_batch_values, + [len(file_batch), keywords_batch_max_len], record_batch + ] + + def verify_records(self, + sess, + batch_size, + file_index=None, + num_epochs=1, + interleave_cycle_length=1): + if file_index is not None: + file_indices = [file_index] + else: + file_indices = range(self._num_files) + + for expected_batch in self._next_expected_batch( + file_indices, batch_size, num_epochs, interleave_cycle_length): + actual_batch = self._next_actual_batch(sess) + for i in range(len(expected_batch)): + self.assertAllEqual(expected_batch[i], actual_batch[i]) diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index 3c7b46629edb13459766b5ef3f392e8d00ad4db8..bdc003a8a5bd646e1d5c598befa2694da512d0a9 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -18,58 +18,163 @@ from __future__ import division from __future__ import print_function import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin +import time +from absl.testing import parameterized from tensorflow.contrib.data.python.ops import resampling from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test from tensorflow.python.util import compat -class ResampleTest(test.TestCase): +def _time_resampling( + test_obj, data_np, target_dist, init_dist, num_to_sample): + dataset = dataset_ops.Dataset.from_tensor_slices(data_np).repeat() - def testInitialKnownDistribution(self): - self._testDistribution(initial_known=True) + # Reshape distribution via rejection sampling. + dataset = dataset.apply( + resampling.rejection_resample( + class_func=lambda x: x, + target_dist=target_dist, + initial_dist=init_dist, + seed=142)) - def testInitialNotKnownDistribution(self): - self._testDistribution(initial_known=False) + get_next = dataset.make_one_shot_iterator().get_next() - def _testDistribution(self, initial_known): + with test_obj.test_session() as sess: + start_time = time.time() + for _ in xrange(num_to_sample): + sess.run(get_next) + end_time = time.time() + + return end_time - start_time + + +class ResampleTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + ("InitialDistributionKnown", True), + ("InitialDistributionUnknown", False)) + def testDistribution(self, initial_known): classes = np.random.randint(5, size=(20000,)) # Uniformly sampled target_dist = [0.9, 0.05, 0.05, 0.0, 0.0] initial_dist = [0.2] * 5 if initial_known else None - iterator = (dataset_ops.Dataset.from_tensor_slices(classes).shuffle( - 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).apply( - resampling.rejection_resample( - target_dist=target_dist, - initial_dist=initial_dist, - class_func=lambda c, _: c, - seed=27)).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() + classes = math_ops.to_int64(classes) # needed for Windows build. + dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle( + 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat() + + get_next = dataset.apply( + resampling.rejection_resample( + target_dist=target_dist, + initial_dist=initial_dist, + class_func=lambda c, _: c, + seed=27)).make_one_shot_iterator().get_next() with self.test_session() as sess: - sess.run(init_op) returned = [] - with self.assertRaises(errors.OutOfRangeError): - while True: - returned.append(sess.run(get_next)) + while len(returned) < 4000: + returned.append(sess.run(get_next)) returned_classes, returned_classes_and_data = zip(*returned) _, returned_data = zip(*returned_classes_and_data) self.assertAllEqual([compat.as_bytes(str(c)) for c in returned_classes], returned_data) total_returned = len(returned_classes) - # Subsampling rejects a large percentage of the initial data in - # this case. - self.assertGreater(total_returned, 20000 * 0.2) class_counts = np.array([ len([True for v in returned_classes if v == c]) for c in range(5)]) returned_dist = class_counts / total_returned self.assertAllClose(target_dist, returned_dist, atol=1e-2) + @parameterized.named_parameters( + ("OnlyInitial", True), + ("NotInitial", False)) + def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist): + init_dist = [0.5, 0.5] + target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0] + num_classes = len(init_dist) + # We don't need many samples to test that this works. + num_samples = 100 + data_np = np.random.choice(num_classes, num_samples, p=init_dist) + + dataset = dataset_ops.Dataset.from_tensor_slices(data_np) + + # Reshape distribution. + dataset = dataset.apply( + resampling.rejection_resample( + class_func=lambda x: x, + target_dist=target_dist, + initial_dist=init_dist)) + + get_next = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + returned = [] + with self.assertRaises(errors.OutOfRangeError): + while True: + returned.append(sess.run(get_next)) + + def testRandomClasses(self): + init_dist = [0.25, 0.25, 0.25, 0.25] + target_dist = [0.0, 0.0, 0.0, 1.0] + num_classes = len(init_dist) + # We don't need many samples to test a dirac-delta target distribution. + num_samples = 100 + data_np = np.random.choice(num_classes, num_samples, p=init_dist) + + dataset = dataset_ops.Dataset.from_tensor_slices(data_np) + + # Apply a random mapping that preserves the data distribution. + def _remap_fn(_): + return math_ops.cast(random_ops.random_uniform([1]) * num_classes, + dtypes.int32)[0] + dataset = dataset.map(_remap_fn) + + # Reshape distribution. + dataset = dataset.apply( + resampling.rejection_resample( + class_func=lambda x: x, + target_dist=target_dist, + initial_dist=init_dist)) + + get_next = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + returned = [] + with self.assertRaises(errors.OutOfRangeError): + while True: + returned.append(sess.run(get_next)) + + classes, _ = zip(*returned) + bincount = np.bincount( + np.array(classes), + minlength=num_classes).astype(np.float32) / len(classes) + + self.assertAllClose(target_dist, bincount, atol=1e-2) + + +class ResampleDatasetBenchmark(test.Benchmark): + + def benchmarkResamplePerformance(self): + init_dist = [0.25, 0.25, 0.25, 0.25] + target_dist = [0.0, 0.0, 0.0, 1.0] + num_classes = len(init_dist) + # We don't need many samples to test a dirac-delta target distribution + num_samples = 1000 + data_np = np.random.choice(num_classes, num_samples, p=init_dist) + + resample_time = _time_resampling( + self, data_np, target_dist, init_dist, num_to_sample=1000) + + self.report_benchmark( + iters=1000, wall_time=resample_time, name="benchmark_resample") + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index e0494736b72ae52f586cb80d42a5c1e50ac17a61..eb2ceff893543f710d4f0246adf4e6367a2deeb0 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -24,24 +24,31 @@ import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import scan_ops from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import test class ScanDatasetTest(test.TestCase): - def _count(self, start, step): - return dataset_ops.Dataset.from_tensors(0).repeat(None).apply( - scan_ops.scan(start, lambda state, _: (state + step, state))) + def _counting_dataset(self, start, scan_fn): + return dataset_ops.Dataset.from_tensors(0).repeat().apply( + scan_ops.scan(start, scan_fn)) def testCount(self): + def make_scan_fn(step): + return lambda state, _: (state + step, state) + start = array_ops.placeholder(dtypes.int32, shape=[]) step = array_ops.placeholder(dtypes.int32, shape=[]) take = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = self._count(start, step).take(take).make_initializable_iterator() + iterator = self._counting_dataset( + start, make_scan_fn(step)).take(take).make_initializable_iterator() next_element = iterator.get_next() with self.test_session() as sess: @@ -57,19 +64,55 @@ class ScanDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + @test_util.run_in_graph_and_eager_modes() def testFibonacci(self): iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply( scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])) ).make_one_shot_iterator() + + if context.executing_eagerly(): + next_element = iterator.get_next + else: + get_next = iterator.get_next() + next_element = lambda: get_next + + self.assertEqual(1, self.evaluate(next_element())) + self.assertEqual(1, self.evaluate(next_element())) + self.assertEqual(2, self.evaluate(next_element())) + self.assertEqual(3, self.evaluate(next_element())) + self.assertEqual(5, self.evaluate(next_element())) + self.assertEqual(8, self.evaluate(next_element())) + + def testSparseCount(self): + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + def make_scan_fn(step): + return lambda state, _: (_sparse(state.values[0] + step), state) + + start = array_ops.placeholder(dtypes.int32, shape=[]) + step = array_ops.placeholder(dtypes.int32, shape=[]) + take = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = self._counting_dataset( + _sparse(start), + make_scan_fn(step)).take(take).make_initializable_iterator() next_element = iterator.get_next() with self.test_session() as sess: - self.assertEqual(1, sess.run(next_element)) - self.assertEqual(1, sess.run(next_element)) - self.assertEqual(2, sess.run(next_element)) - self.assertEqual(3, sess.run(next_element)) - self.assertEqual(5, sess.run(next_element)) - self.assertEqual(8, sess.run(next_element)) + + for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10), + (10, 2, 10), (10, -1, 10), + (10, -2, 10)]: + sess.run(iterator.initializer, + feed_dict={start: start_val, step: step_val, take: take_val}) + for expected, _ in zip( + itertools.count(start_val, step_val), range(take_val)): + self.assertEqual(expected, sess.run(next_element).values[0]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) def testChangingStateShape(self): # Test the fixed-point shape invariant calculations: start with @@ -125,7 +168,7 @@ class ScanDatasetTest(test.TestCase): scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) -class ScanDatasetSerialzationTest( +class ScanDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): def _build_dataset(self, num_elements): diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py index 36ddf3004237ed042f21d691d83eafbaa20621e6..d0cb203a3afd2775756c8542a1e86faedc5cee53 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py @@ -47,6 +47,11 @@ class SequenceDatasetSerializationTest( # Skip nothing self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10) + def testInvalidSkip(self): + with self.assertRaisesRegexp(ValueError, + 'Shape must be rank 0 but is rank 1'): + self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0) + def _build_take_dataset(self, count): components = (np.arange(10),) return dataset_ops.Dataset.from_tensor_slices(components).take(count) @@ -69,6 +74,11 @@ class SequenceDatasetSerializationTest( # Take nothing self.run_core_tests(lambda: self._build_take_dataset(0), None, 0) + def testInvalidTake(self): + with self.assertRaisesRegexp(ValueError, + 'Shape must be rank 0 but is rank 1'): + self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0) + def _build_repeat_dataset(self, count, take_count=3): components = (np.arange(10),) return dataset_ops.Dataset.from_tensor_slices(components).take( @@ -100,6 +110,12 @@ class SequenceDatasetSerializationTest( # Test repeat empty dataset self.run_core_tests(lambda: self._build_repeat_dataset(-1, 0), None, 0) + def testInvalidRepeat(self): + with self.assertRaisesRegexp( + ValueError, 'Shape must be rank 0 but is rank 1'): + self.run_core_tests(lambda: self._build_repeat_dataset([1, 2], 0), + None, 0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index bcc644c0971854d948025009dc7add2fea214048..1b67a33f04b0c2ac80402e163005123a4b3e4400 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -20,11 +20,13 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib class ShuffleDatasetSerializationTest( @@ -50,26 +52,100 @@ class ShuffleDatasetSerializationTest( num_repeats = 5 num_outputs = range_limit * num_repeats buffer_sizes = [1, 3, 8, 10, 25, 50] - reshuffle_each_iteration = False # pylint: disable=cell-var-from-loop # pylint: disable=g-long-lambda - for buffer_size in buffer_sizes: - self.run_core_tests( - lambda: self._build_shuffle_dataset( + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + self.run_core_tests( + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration), + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=10, + reshuffle_each_iteration=reshuffle_each_iteration), + num_outputs) + # pylint: enable=cell-var-from-loop + # pylint: enable=g-long-lambda + + def testNonDeterministicSeeding(self): + + range_limit = 10 + num_repeats = 5 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 8, 10, 25, 50] + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( range_limit=range_limit, num_repeats=num_repeats, buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration), - lambda: self._build_shuffle_dataset( + seed=None, # Iterator seeds are generated non-deterministically. + reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + # We checkpoint the initial state of the Dataset so that we can restore + # the seeds in the next run. Since the seeding is non-deterministic + # the dataset gets initialized with different seeds each time. + expected = self.gen_outputs( + ds_fn, + break_points=[0], + num_outputs=num_outputs, + ckpt_saved=False, + verify_exhausted=False, + save_checkpoint_at_end=False) + actual = self.gen_outputs( + ds_fn, + break_points=self.gen_break_points(num_outputs), + num_outputs=num_outputs, + ckpt_saved=True, + verify_exhausted=False) + self.match(expected, actual) + + def testMultipleIterators(self): + range_limit = 10 + num_repeats = 5 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 8, 10, 25, 50] + + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( range_limit=range_limit, num_repeats=num_repeats, buffer_size=buffer_size, - seed=10, - reshuffle_each_iteration=reshuffle_each_iteration), - num_outputs) - # pylint: enable=cell-var-from-loop - # pylint: enable=g-long-lambda + seed=None, # Iterator seeds are generated non-deterministically. + reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + with ops.Graph().as_default() as g: + ds = ds_fn() + iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()] + get_next_ops = [it.get_next() for it in iterators] + saveables = [ + contrib_iterator_ops.make_saveable_from_iterator(it) + for it in iterators + ] + for saveable in saveables: + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + saver = saver_lib.Saver(allow_empty=True) + with self.test_session(graph=g) as sess: + self._save(sess, saver) + expected = [sess.run(get_next_ops) for _ in range(num_outputs)] + self._restore(saver, sess) + actual = [sess.run(get_next_ops) for _ in range(num_outputs)] + self.match(expected, actual) class ShuffleAndRepeatTest( diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..33c48e20bea53b88d69a59e715af38b22dd2cbd4 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -0,0 +1,242 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.ops import sliding +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class SlideDatasetTest(test.TestCase): + + def testSlideDataset(self): + """Test an dataset that maps a TF function across its input elements.""" + components = (np.arange(7), + np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], + np.array(37.0) * np.arange(7)) + + count = array_ops.placeholder(dtypes.int64, shape=[]) + window_size = array_ops.placeholder(dtypes.int64, shape=[]) + stride = array_ops.placeholder(dtypes.int64, shape=[]) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> + # RepeatDataset(count) -> _SlideDataset(window_size, stride). + iterator = (dataset_ops.Dataset.from_tensor_slices(components) + .map(_map_fn) + .repeat(count) + .apply(sliding.sliding_window_batch(window_size, stride)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual([[None] + list(c.shape[1:]) for c in components], + [t.shape.as_list() for t in get_next]) + + with self.test_session() as sess: + # Slide over a finite input, where the window_size divides the + # total number of elements. + sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7}) + # Same formula with convolution layer. + num_batches = (20 * 7 - 14) // 7 + 1 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(14): + self.assertAllEqual(component[(i*7 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Slide over a finite input, where the window_size does not + # divide the total number of elements. + sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9}) + + num_batches = (20 * 7 - 17) // 9 + 1 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(17): + self.assertAllEqual(component[(i*9 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Slide over a finite input, which is less than window_size, + # should fail straight away. + sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 8}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Slide over an empty input should fail straight away. + sess.run(init_op, feed_dict={count: 0, window_size: 8, stride: 4}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Empty window_size should be an initialization time error. + with self.assertRaises(errors.InvalidArgumentError): + sess.run(init_op, feed_dict={count: 14, window_size: 0, stride: 0}) + + # Invalid stride should be an initialization time error. + with self.assertRaises(errors.InvalidArgumentError): + sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0}) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 3}) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 5}) + + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testSlideSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(5, 3)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + num_batches = (10 - 5) // 3 + 1 + for i in range(num_batches): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], + values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4], + dense_shape=[5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSlideSparseWithDifferentDenseShapes(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=array_ops.expand_dims( + math_ops.range(i, dtype=dtypes.int64), 1), + values=array_ops.fill([math_ops.to_int32(i)], i), + dense_shape=[i]) + + iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(5, 3)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + num_batches = (10 - 5) // 3 + 1 + for i in range(num_batches): + actual = sess.run(get_next) + expected_indices = [] + expected_values = [] + for j in range(5): + for k in range(i * 3 + j): + expected_indices.append([j, k]) + expected_values.append(i * 3 + j) + expected = sparse_tensor.SparseTensorValue( + indices=expected_indices, + values=expected_values, + dense_shape=[5, i * 3 + 5 - 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testNestedSlideSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = (dataset_ops.Dataset.range(10) + .map(_sparse) + .apply(sliding.sliding_window_batch(4, 2)) + .apply(sliding.sliding_window_batch(3, 1)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + # Slide: 1st batch. + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], + [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], + [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]], + values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7], + dense_shape=[3, 4, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + # Slide: 2nd batch. + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], + [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], + [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]], + values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9], + dense_shape=[3, 4, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSlideShapeError(self): + + def generator(): + yield [1.0, 2.0, 3.0] + yield [4.0, 5.0, 6.0] + yield [7.0, 8.0, 9.0, 10.0] + + iterator = (dataset_ops.Dataset.from_generator(generator, dtypes.float32, + output_shapes=[None]) + .apply(sliding.sliding_window_batch(3, 1)) + .make_initializable_iterator()) + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r"Cannot batch tensors with different shapes in component 0. " + r"First element had shape \[3\] and element 2 had shape \[4\]."): + sess.run(next_element) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index e26cef8ec522c7e69a0c19b2b30a969bbfc0ad78..4148addf2878c99f47ebe1454edf69ad7f38dfbc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -22,6 +22,7 @@ import os import sqlite3 +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import readers from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class SqlDatasetTest(test.TestCase): +class SqlDatasetTestBase(test.TestCase): def _createSqlDataset(self, output_types, num_repeats=1): dataset = readers.SqlDataset(self.driver_name, self.data_source_name, @@ -92,6 +93,9 @@ class SqlDatasetTest(test.TestCase): conn.commit() conn.close() + +class SqlDatasetTest(SqlDatasetTestBase): + # Test that SqlDataset can read from a database table. def testReadResultSet(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, @@ -652,5 +656,27 @@ class SqlDatasetTest(test.TestCase): sess.run(get_next) +class SqlDatasetSerializationTest( + SqlDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_repeats): + data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) + query = ("SELECT first_name, last_name, motto FROM students ORDER BY " + "first_name DESC") + output_types = (dtypes.string, dtypes.string, dtypes.string) + return readers.SqlDataset(driver_name, data_source_name, query, + output_types).repeat(num_repeats) + + def testSQLSaveable(self): + num_repeats = 4 + num_outputs = num_repeats * 2 + self.run_core_tests(lambda: self._build_dataset(num_repeats), + lambda: self._build_dataset(num_repeats // 2), + num_outputs) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 07bdf920446e953c2a1abaf495d2e9e1256106fd..17b6644759e53f84b23e070a71267aa15dcffe49 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.core.framework import summary_pb2 from tensorflow.python.data.ops import dataset_ops @@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class StatsDatasetTest(test.TestCase): +class StatsDatasetTestBase(test.TestCase): def _assertSummaryHasCount(self, summary_str, tag, expected_value): summary_proto = summary_pb2.Summary() @@ -49,18 +50,21 @@ class StatsDatasetTest(test.TestCase): return self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + +class StatsDatasetTest(StatsDatasetTestBase): + def testBytesProduced(self): + stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).map( lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( - stats_ops.bytes_produced_stats("bytes_produced")) + stats_ops.bytes_produced_stats("bytes_produced")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) iterator = dataset.make_initializable_iterator() - stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run([iterator.initializer, stats_aggregator_subscriber]) + sess.run(iterator.initializer) expected_sum = 0.0 for i in range(100): self.assertAllEqual( @@ -76,16 +80,16 @@ class StatsDatasetTest(test.TestCase): self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum) def testLatencyStats(self): + stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( - stats_ops.latency_stats("record_latency")) + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) iterator = dataset.make_initializable_iterator() - stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run([iterator.initializer, stats_aggregator_subscriber]) + sess.run(iterator.initializer) for i in range(100): self.assertEqual(i, sess.run(next_element)) self._assertSummaryHasCount( @@ -95,16 +99,15 @@ class StatsDatasetTest(test.TestCase): self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0) def testReinitialize(self): + stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( - stats_ops.latency_stats("record_latency")) + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) iterator = dataset.make_initializable_iterator() - stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run(stats_aggregator_subscriber) for j in range(5): sess.run(iterator.initializer) for i in range(100): @@ -130,17 +133,17 @@ class StatsDatasetTest(test.TestCase): sess.run(next_element) def testMultipleTags(self): + stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( stats_ops.latency_stats("record_latency")).apply( - stats_ops.latency_stats("record_latency_2")) + stats_ops.latency_stats("record_latency_2")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) iterator = dataset.make_initializable_iterator() - stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run([iterator.initializer, stats_aggregator_subscriber]) + sess.run(iterator.initializer) for i in range(100): self.assertEqual(i, sess.run(next_element)) self._assertSummaryHasCount( @@ -154,17 +157,17 @@ class StatsDatasetTest(test.TestCase): sess.run(summary_t), "record_latency_2", 100.0) def testRepeatedTags(self): + stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( stats_ops.latency_stats("record_latency")).apply( - stats_ops.latency_stats("record_latency")) + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) iterator = dataset.make_initializable_iterator() - stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run([iterator.initializer, stats_aggregator_subscriber]) + sess.run(iterator.initializer) for i in range(100): self.assertEqual(i, sess.run(next_element)) self._assertSummaryHasCount( @@ -174,19 +177,17 @@ class StatsDatasetTest(test.TestCase): self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) def testMultipleIteratorsSameAggregator(self): + stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( - stats_ops.latency_stats("record_latency")) + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) iterator_0 = dataset.make_initializable_iterator() iterator_1 = dataset.make_initializable_iterator() - stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscribers = [stats_aggregator.subscribe(iterator_0), - stats_aggregator.subscribe(iterator_1)] next_element = iterator_0.get_next() + iterator_1.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run([iterator_0.initializer, iterator_1.initializer, - stats_aggregator_subscribers]) + sess.run([iterator_0.initializer, iterator_1.initializer]) for i in range(100): self.assertEqual(i * 2, sess.run(next_element)) self._assertSummaryHasCount( @@ -195,19 +196,44 @@ class StatsDatasetTest(test.TestCase): sess.run(next_element) self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) - def testMultipleStatsAggregatorsSameIteratorFail(self): - dataset = dataset_ops.Dataset.range(100).apply( - stats_ops.latency_stats("record_latency")) + +class FeatureStatsDatasetTest( + StatsDatasetTestBase, + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): + + def testFeaturesStats(self): + num_epochs = 5 + total_records = num_epochs * self._num_records + batch_size = 2 + stats_aggregator = stats_ops.StatsAggregator() + dataset = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5, + drop_final_batch=True).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) iterator = dataset.make_initializable_iterator() - stats_aggregator_0 = stats_ops.StatsAggregator() - stats_aggregator_1 = stats_ops.StatsAggregator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run(stats_aggregator_0.subscribe(iterator)) - # TODO(mrry): Consider making this allowable (and also allowing - # aggregators to unsubscribe). - with self.assertRaises(errors.FailedPreconditionError): - sess.run(stats_aggregator_1.subscribe(iterator)) + sess.run(iterator.initializer) + for _ in range(total_records // batch_size): + sess.run(next_element) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount( + sess.run(summary_t), "record_stats:features", total_records) + self._assertSummaryHasCount( + sess.run(summary_t), "record_stats:feature-values", total_records) + self._assertSummaryHasSum( + sess.run(summary_t), "record_stats:features", total_records * 3) + self._assertSummaryHasSum( + sess.run(summary_t), "record_stats:feature-values", + self._sum_keywords(1) * num_epochs + 2 * total_records) class StatsDatasetSerializationTest( @@ -218,6 +244,14 @@ class StatsDatasetSerializationTest( lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( stats_ops.bytes_produced_stats("bytes_produced")) + def test_bytes_produced_stats_invalid_tag_shape(self): + with self.assertRaisesRegexp( + ValueError, 'Shape must be rank 0 but is rank 1'): + self.run_core_tests( + lambda: dataset_ops.Dataset.range(100).apply( + stats_ops.bytes_produced_stats(["bytes_produced"])), + None, 100) + def testBytesStatsDatasetSaveableCore(self): num_outputs = 100 self.run_core_tests( @@ -235,6 +269,14 @@ class StatsDatasetSerializationTest( return dataset_ops.Dataset.range(num_elements).apply( stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2)) + def test_latency_stats_invalid_tag_shape(self): + with self.assertRaisesRegexp( + ValueError, 'Shape must be rank 0 but is rank 1'): + self.run_core_tests( + lambda: dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats(["record_latency", "record_latency_2"])), + None, 100) + def testLatencyStatsDatasetSaveableCore(self): num_outputs = 100 @@ -253,5 +295,9 @@ class StatsDatasetSerializationTest( None, num_outputs) +# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the +# transformation `stats_ops.set_stats_aggregator`, since we don't support +# serializing StatsAggregator yet. + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9167cb3379bba5cb1ba76a96549395c45dca9e35 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py @@ -0,0 +1,77 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline statistics gathering ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +import numpy as np + +from tensorflow.contrib.data.python.ops import threadpool +from tensorflow.contrib.data.python.ops import unique +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import script_ops +from tensorflow.python.platform import test + + +class OverrideThreadpoolDatasetTest(test.TestCase): + + def testNumThreads(self): + + def get_thread_id(_): + # Python creates a dummy thread object to represent the current + # thread when called from an "alien" thread (such as a + # `PrivateThreadPool` thread in this case). It does not include + # the TensorFlow-given display name, but it has a unique + # identifier that maps one-to-one with the underlying OS thread. + return np.array(threading.current_thread().ident).astype(np.int64) + + for num_threads in [1, 2, 4, 8, 16]: + + dataset = ( + dataset_ops.Dataset.range(1000).map( + lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64), + num_parallel_calls=32).apply(unique.unique())) + + dataset = threadpool.override_threadpool( + dataset, + threadpool.PrivateThreadPool( + num_threads, display_name="private_thread_pool_%d" % num_threads)) + + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + thread_ids = [] + try: + while True: + thread_ids.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + self.assertEqual(len(thread_ids), len(set(thread_ids))) + self.assertGreater(len(thread_ids), 0) + # NOTE(mrry): We don't control the thread pool scheduling, and + # so cannot guarantee that all of the threads in the pool will + # perform work. + self.assertLessEqual(len(thread_ids), num_threads) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c603ecc5ab27a711557376246b093fd5f80f8aec --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py @@ -0,0 +1,117 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.ops import writers +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.lib.io import python_io +from tensorflow.python.lib.io import tf_record +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class TFRecordWriterTest(test.TestCase): + + def setUp(self): + super(TFRecordWriterTest, self).setUp() + self._num_records = 7 + self.filename = array_ops.placeholder(dtypes.string, shape=[]) + self.compression_type = array_ops.placeholder_with_default("", shape=[]) + + input_dataset = readers.TFRecordDataset([self.filename], + self.compression_type) + self.writer = writers.TFRecordWriter( + self._outputFilename(), self.compression_type).write(input_dataset) + + def _record(self, i): + return compat.as_bytes("Record %d" % (i)) + + def _createFile(self, options=None): + filename = self._inputFilename() + writer = python_io.TFRecordWriter(filename, options) + for i in range(self._num_records): + writer.write(self._record(i)) + writer.close() + return filename + + def _inputFilename(self): + return os.path.join(self.get_temp_dir(), "tf_record.in.txt") + + def _outputFilename(self): + return os.path.join(self.get_temp_dir(), "tf_record.out.txt") + + def testWrite(self): + with self.test_session() as sess: + sess.run( + self.writer, feed_dict={ + self.filename: self._createFile(), + }) + for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())): + self.assertAllEqual(self._record(i), r) + + def testWriteZLIB(self): + options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB) + with self.test_session() as sess: + sess.run( + self.writer, + feed_dict={ + self.filename: self._createFile(options), + self.compression_type: "ZLIB", + }) + for i, r in enumerate( + tf_record.tf_record_iterator(self._outputFilename(), options=options)): + self.assertAllEqual(self._record(i), r) + + def testWriteGZIP(self): + options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP) + with self.test_session() as sess: + sess.run( + self.writer, + feed_dict={ + self.filename: self._createFile(options), + self.compression_type: "GZIP", + }) + for i, r in enumerate( + tf_record.tf_record_iterator(self._outputFilename(), options=options)): + self.assertAllEqual(self._record(i), r) + + def testFailDataset(self): + with self.assertRaises(TypeError): + writers.TFRecordWriter(self._outputFilename(), + self.compression_type).write("whoops") + + def testFailDType(self): + input_dataset = dataset_ops.Dataset.from_tensors(10) + with self.assertRaises(TypeError): + writers.TFRecordWriter(self._outputFilename(), + self.compression_type).write(input_dataset) + + def testFailShape(self): + input_dataset = dataset_ops.Dataset.from_tensors([["hello"], ["world"]]) + with self.assertRaises(TypeError): + writers.TFRecordWriter(self._outputFilename(), + self.compression_type).write(input_dataset) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index b488357f226d0922bba3799cc1f4b5c75e2e8328..33b7a75046cf2acfa3d787833b907aa2b28dbdca 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -12,18 +12,26 @@ load( load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") py_library( - name = "dataset_ops", - srcs = [ - "counter.py", - "get_single_element.py", + name = "counter", + srcs = ["counter.py"], + srcs_version = "PY2AND3", + deps = [ + ":scan_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", ], +) + +py_library( + name = "get_single_element", + srcs = ["get_single_element.py"], srcs_version = "PY2AND3", deps = [ - ":transformation_ops", "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", ], ) @@ -37,6 +45,27 @@ py_library( "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:framework_ops", "//tensorflow/python:training", + "//tensorflow/python/data/ops:iterator_ops", + ], +) + +py_test( + name = "iterator_ops_test", + size = "small", + srcs = ["iterator_ops_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", ], ) @@ -66,18 +95,27 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":dataset_ops", + ":batching", + ":gen_dataset_ops", + ":interleave_ops", + ":shuffle_ops", + ":stats_ops", + "//tensorflow/python:constant_op", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:lib", + "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", + "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", + "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", + "//third_party/py/numpy", ], ) @@ -88,49 +126,212 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":random_ops", - ":transformation_ops", "//tensorflow/python/data/ops:dataset_ops", ], ) py_library( - name = "transformation_ops", - srcs = [ - "batching.py", - "enumerate_ops.py", - "error_ops.py", - "grouping.py", - "interleave_ops.py", - "resampling.py", - "scan_ops.py", - "stats_ops.py", - "unique.py", + name = "batching", + srcs = ["batching.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:convert", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", ], +) + +py_library( + name = "enumerate_ops", + srcs = ["enumerate_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dtypes", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_library( + name = "error_ops", + srcs = ["error_ops.py"], srcs_version = "PY2AND3", deps = [ ":contrib_op_loader", ":gen_dataset_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "grouping", + srcs = ["grouping.py"], + srcs_version = "PY2AND3", + deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:check_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:function", - "//tensorflow/python:logging_ops", "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", "//tensorflow/python:tensor_shape", - "//tensorflow/python:tensor_util", - "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "interleave_ops", + srcs = ["interleave_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", + ":random_ops", + "//tensorflow/contrib/stateless", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:readers", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "optimization", + srcs = ["optimization.py"], + srcs_version = "PY2AND3", + deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "resampling", + srcs = ["resampling.py"], + srcs_version = "PY2AND3", + deps = [ + ":batching", + ":interleave_ops", + ":scan_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:logging_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) +py_library( + name = "scan_ops", + srcs = ["scan_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "sliding", + srcs = ["sliding.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "stats_ops", + srcs = ["stats_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "threadpool", + srcs = ["threadpool.py"], + srcs_version = "PY2AND3", + deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + "//tensorflow/python/eager:context", + ], +) + +py_library( + name = "unique", + srcs = [ + "unique.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "writers", + srcs = [ + "writers.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dtypes", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + tf_gen_op_wrapper_py( name = "gen_dataset_ops", out = "gen_dataset_ops.py", @@ -167,17 +368,37 @@ py_library( srcs = ["prefetching_ops.py"], deps = [ ":contrib_op_loader", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], +py_library( + name = "dataset_ops", + deps = [ + ":batching", + ":counter", + ":enumerate_ops", + ":error_ops", + ":get_single_element", + ":grouping", + ":interleave_ops", + ":optimization", + ":prefetching_ops", + ":readers", + ":resampling", + ":scan_ops", + ":shuffle_ops", + ":sliding", + ":stats_ops", + ":threadpool", + ":unique", + ":writers", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], ) diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 6eb512dec67cb7b9c8c4518d03aee0b436205f9a..052618e08c8f204613db5a20d42e078f17f12840 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -17,7 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.framework import with_shape from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -28,6 +30,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import deprecation def dense_to_sparse_batch(batch_size, row_shape): @@ -79,28 +82,95 @@ def dense_to_sparse_batch(batch_size, row_shape): return _apply_fn +class UnbatchDataset(dataset_ops.Dataset): + """A dataset that splits the elements of its input into multiple elements.""" + + def __init__(self, input_dataset): + """See `unbatch()` for more details.""" + super(UnbatchDataset, self).__init__() + flat_shapes = nest.flatten(input_dataset.output_shapes) + if any(s.ndims == 0 for s in flat_shapes): + raise ValueError("Cannot unbatch an input with scalar components.") + known_batch_dim = tensor_shape.Dimension(None) + for s in flat_shapes: + try: + known_batch_dim = known_batch_dim.merge_with(s[0]) + except ValueError: + raise ValueError("Cannot unbatch an input whose components have " + "different batch sizes.") + self._input_dataset = input_dataset + + def _as_variant_tensor(self): + return gen_dataset_ops.unbatch_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return nest.map_structure(lambda s: s[1:], + self._input_dataset.output_shapes) + + @property + def output_types(self): + return self._input_dataset.output_types + + def unbatch(): - """A Transformation which splits the elements of a dataset. + """Splits elements of a dataset into multiple elements on the batch dimension. For example, if elements of the dataset are shaped `[B, a0, a1, ...]`, - where `B` may vary from element to element, then for each element in - the dataset, the unbatched dataset will contain `B` consecutive elements + where `B` may vary for each input element, then for each element in the + dataset, the unbatched dataset will contain `B` consecutive elements of shape `[a0, a1, ...]`. + ```python + # NOTE: The following example uses `{ ... }` to represent the contents + # of a dataset. + a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] } + + a.apply(tf.contrib.data.unbatch()) == { + 'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'} + ``` + Returns: A `Dataset` transformation function, which can be passed to @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): - - def unbatch_map(arg, *rest): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + if not sparse.any_sparse(dataset.output_classes): + return UnbatchDataset(dataset) + + # NOTE(mrry): We must ensure that any SparseTensors in `dataset` + # are normalized to the rank-1 dense representation, so that the + # sparse-oblivious unbatching logic will slice them + # appropriately. This leads to a somewhat inefficient re-encoding step + # for all SparseTensor components. + # TODO(mrry): Consider optimizing this in future + # if it turns out to be a bottleneck. + def normalize(arg, *rest): if rest: - return dataset_ops.Dataset.from_tensor_slices((arg,) + rest) + return sparse.serialize_many_sparse_tensors((arg,) + rest) else: - return dataset_ops.Dataset.from_tensor_slices(arg) + return sparse.serialize_many_sparse_tensors(arg) - return dataset.flat_map(map_func=unbatch_map) + normalized_dataset = dataset.map(normalize) + + # NOTE(mrry): Our `map()` has lost information about the sparseness + # of any SparseTensor components, so re-apply the structure of the + # original dataset. + restructured_dataset = _RestructuredDataset( + normalized_dataset, + dataset.output_types, + dataset.output_shapes, + dataset.output_classes, + allow_unsafe_cast=True) + return UnbatchDataset(restructured_dataset) return _apply_fn @@ -147,6 +217,8 @@ def filter_irregular_batches(batch_size): return _apply_fn +@deprecation.deprecated( + None, "Use `tf.data.Dataset.batch(..., drop_remainder=True)`.") def batch_and_drop_remainder(batch_size): """A batching transformation that omits the final small batch (if present). @@ -179,12 +251,16 @@ def batch_and_drop_remainder(batch_size): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" + # TODO(jsimsa): Switch to using `batch(..., drop_remainder=True)` any time + # after 6/30/2018. batched = dataset.batch(batch_size) return filter_irregular_batches(batch_size)(batched) return _apply_fn +@deprecation.deprecated( + None, "Use `tf.data.Dataset.padded_batch(..., drop_remainder=True)`.") def padded_batch_and_drop_remainder(batch_size, padded_shapes, padding_values=None): @@ -213,6 +289,8 @@ def padded_batch_and_drop_remainder(batch_size, def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" + # TODO(jsimsa): Switch to using `padded_batch(..., drop_remainder=True)` + # any time after 6/30/2018. batched = dataset.padded_batch( batch_size, padded_shapes=padded_shapes, padding_values=padding_values) return filter_irregular_batches(batch_size)(batched) @@ -238,11 +316,8 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset): return gen_dataset_ops.dense_to_sparse_batch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._batch_size, - row_shape=dataset_ops._partial_shape_to_tensor(self._row_shape), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + row_shape=convert.partial_shape_to_tensor(self._row_shape), + **dataset_ops.flat_structure(self)) @property def output_classes(self): @@ -264,7 +339,8 @@ class _RestructuredDataset(dataset_ops.Dataset): dataset, output_types, output_shapes=None, - output_classes=None): + output_classes=None, + allow_unsafe_cast=False): """Creates a new dataset with the given output types and shapes. The given `dataset` must have a structure that is convertible: @@ -282,22 +358,27 @@ class _RestructuredDataset(dataset_ops.Dataset): If omitted, the shapes will be inherited from `dataset`. output_classes: (Optional.) A nested structure of class types. If omitted, the class types will be inherited from `dataset`. + allow_unsafe_cast: (Optional.) If `True`, the caller may switch the + reported output types and shapes of the restructured dataset, e.g. to + switch a sparse tensor represented as `tf.variant` to its user-visible + type and shape. Raises: ValueError: If either `output_types` or `output_shapes` is not compatible with the structure of `dataset`. """ super(_RestructuredDataset, self).__init__() - self._dataset = dataset - - # Validate that the types are compatible. - output_types = nest.map_structure(dtypes.as_dtype, output_types) - flat_original_types = nest.flatten(dataset.output_types) - flat_new_types = nest.flatten(output_types) - if flat_original_types != flat_new_types: - raise ValueError( - "Dataset with output types %r cannot be restructured to have output " - "types %r" % (dataset.output_types, output_types)) + self._input_dataset = dataset + + if not allow_unsafe_cast: + # Validate that the types are compatible. + output_types = nest.map_structure(dtypes.as_dtype, output_types) + flat_original_types = nest.flatten(dataset.output_types) + flat_new_types = nest.flatten(output_types) + if flat_original_types != flat_new_types: + raise ValueError( + "Dataset with output types %r cannot be restructured to have " + "output types %r" % (dataset.output_types, output_types)) self._output_types = output_types @@ -307,18 +388,19 @@ class _RestructuredDataset(dataset_ops.Dataset): nest.flatten( dataset.output_shapes)) else: - # Validate that the shapes are compatible. - nest.assert_same_structure(output_types, output_shapes) - flat_original_shapes = nest.flatten(dataset.output_shapes) - flat_new_shapes = nest.flatten_up_to(output_types, output_shapes) - - for original_shape, new_shape in zip(flat_original_shapes, - flat_new_shapes): - if not original_shape.is_compatible_with(new_shape): - raise ValueError( - "Dataset with output shapes %r cannot be restructured to have " - "incompatible output shapes %r" % (dataset.output_shapes, - output_shapes)) + if not allow_unsafe_cast: + # Validate that the shapes are compatible. + nest.assert_same_structure(output_types, output_shapes) + flat_original_shapes = nest.flatten(dataset.output_shapes) + flat_new_shapes = nest.flatten_up_to(output_types, output_shapes) + + for original_shape, new_shape in zip(flat_original_shapes, + flat_new_shapes): + if not original_shape.is_compatible_with(new_shape): + raise ValueError( + "Dataset with output shapes %r cannot be restructured to have " + "incompatible output shapes %r" % (dataset.output_shapes, + output_shapes)) self._output_shapes = nest.map_structure_up_to( output_types, tensor_shape.as_shape, output_shapes) if output_classes is None: @@ -330,7 +412,7 @@ class _RestructuredDataset(dataset_ops.Dataset): self._output_classes = output_classes def _as_variant_tensor(self): - return self._dataset._as_variant_tensor() # pylint: disable=protected-access + return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access @property def output_classes(self): @@ -345,37 +427,81 @@ class _RestructuredDataset(dataset_ops.Dataset): return self._output_shapes +def assert_element_shape(expected_shapes): + """Assert the shape of this `Dataset`. + + ```python + shapes = [tf.TensorShape([16, 256]), tf.TensorShape(None)] + result = dataset.apply(tf.contrib.data.assert_element_shape(shapes)) + print(result.output_shapes) # ==> "((16, 256), )" + ``` + + If dataset shapes and expected_shape, are fully defined, assert they match. + Otherwise, add assert op that will validate the shapes when tensors are + evaluated, and set shapes on tensors, respectively. + + Args: + expected_shapes: A nested structure of `tf.TensorShape` objects. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply} + """ + + def _check_shape(*elements): + flatten_tensors = nest.flatten(elements) + flatten_shapes = nest.flatten(expected_shapes) + checked_tensors = [ + with_shape(shape, tensor) + for shape, tensor in zip(flatten_shapes, flatten_tensors) + ] + return nest.pack_sequence_as(elements, checked_tensors) + + def _apply_fn(dataset): + return _RestructuredDataset( + dataset.map(_check_shape), + dataset.output_types, + output_shapes=expected_shapes, + output_classes=dataset.output_classes) + + return _apply_fn + + class _MapAndBatchDataset(dataset_ops.MapDataset): """A `Dataset` that maps a function over a batch of elements.""" - def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches): + def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls, + drop_remainder): """See `Dataset.map()` for details.""" super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) - self._batch_size = ops.convert_to_tensor( + self._batch_size_t = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - self._num_parallel_batches = ops.convert_to_tensor( - num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches") + self._num_parallel_calls_t = ops.convert_to_tensor( + num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") + self._drop_remainder_t = ops.convert_to_tensor( + drop_remainder, dtype=dtypes.bool, name="drop_remainder") + + self._batch_size = batch_size + self._drop_remainder = drop_remainder def _as_variant_tensor(self): # pylint: disable=protected-access input_resource = self._input_dataset._as_variant_tensor() - return gen_dataset_ops.map_and_batch_dataset( + return gen_dataset_ops.map_and_batch_dataset_v2( input_resource, self._map_func.captured_inputs, f=self._map_func, - batch_size=self._batch_size, - num_parallel_batches=self._num_parallel_batches, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + batch_size=self._batch_size_t, + num_parallel_calls=self._num_parallel_calls_t, + drop_remainder=self._drop_remainder_t, + **dataset_ops.flat_structure(self)) # pylint: enable=protected-access @property def output_shapes(self): + dim = self._batch_size if self._drop_remainder else None return nest.pack_sequence_as(self._output_shapes, [ - tensor_shape.vector(tensor_util.constant_value( - self._batch_size)).concatenate(s) + tensor_shape.vector(dim).concatenate(s) for s in nest.flatten(self._output_shapes) ]) @@ -384,7 +510,11 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): return self._output_types -def map_and_batch(map_func, batch_size, num_parallel_batches=1): +def map_and_batch(map_func, + batch_size, + num_parallel_batches=None, + drop_remainder=False, + num_parallel_calls=None): """Fused implementation of `map` and `batch`. Maps `map_func` across `batch_size` consecutive elements of this dataset @@ -400,18 +530,37 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): nested structure of tensors. batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of consecutive elements of this dataset to combine in a single batch. - num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the - number of batches to create in parallel. On one hand, higher values can - help mitigate the effect of stragglers. On the other hand, higher values - can increase contention if CPU is scarce. + num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`, + representing the number of batches to create in parallel. On one hand, + higher values can help mitigate the effect of stragglers. On the other + hand, higher values can increase contention if CPU is scarce. + drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing + whether the last batch should be dropped in case its size is smaller than + desired; the default behavior is not to drop the smaller batch. + num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, + representing the number of elements to process in parallel. If not + specified, `batch_size * num_parallel_batches` elements will be + processed in parallel. Returns: A `Dataset` transformation function, which can be passed to @{tf.data.Dataset.apply}. + + Raises: + ValueError: If both `num_parallel_batches` and `num_parallel_calls` are + specified. """ + if num_parallel_batches is None and num_parallel_calls is None: + num_parallel_calls = batch_size + elif num_parallel_batches is not None and num_parallel_calls is None: + num_parallel_calls = batch_size * num_parallel_batches + elif num_parallel_batches is not None and num_parallel_calls is not None: + raise ValueError("The `num_parallel_batches` and `num_parallel_calls` " + "arguments are mutually exclusive.") + def _apply_fn(dataset): return _MapAndBatchDataset(dataset, map_func, batch_size, - num_parallel_batches) + num_parallel_calls, drop_remainder) return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/counter.py b/tensorflow/contrib/data/python/ops/counter.py index 63226fe78163c59025623a362d17c400fbe57c67..6ef65f9624601286691505a795a86dd6226eead1 100644 --- a/tensorflow/contrib/data/python/ops/counter.py +++ b/tensorflow/contrib/data/python/ops/counter.py @@ -25,7 +25,7 @@ from tensorflow.python.framework import ops def Counter(start=0, step=1, dtype=dtypes.int64): - """Creates a `Dataset` of a `step`-separated count startin from `start`. + """Creates a `Dataset` that counts from `start` in steps of size `step`. For example: @@ -38,12 +38,13 @@ def Counter(start=0, step=1, dtype=dtypes.int64): ``` Args: - start: starting value for count. - step: step size. - dtype: counter data type. + start: (Optional.) The starting value for the counter. Defaults to 0. + step: (Optional.) The step size for the counter. Defaults to 1. + dtype: (Optional.) The data type for counter elements. Defaults to + `tf.int64`. Returns: - A `Dataset` of scalar elements. + A `Dataset` of scalar `dtype` elements. """ with ops.name_scope("counter"): start = ops.convert_to_tensor(start, dtype=dtype, name="start") diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py deleted file mode 100644 index 214641bb9a62e6cbdece07b511864a5cff20944d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ /dev/null @@ -1,690 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Python wrappers for Datasets and Iterators.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.data.python.ops import batching -from tensorflow.contrib.data.python.ops import enumerate_ops -from tensorflow.contrib.data.python.ops import error_ops -from tensorflow.contrib.data.python.ops import grouping -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import gen_io_ops -from tensorflow.python.util import deprecation - - -class Dataset(dataset_ops.Dataset): - """Represents a potentially large set of elements. - - A `Dataset` can be used to represent an input pipeline as a - collection of elements (nested structures of tensors) and a "logical - plan" of transformations that act on those elements. - """ - - def __init__(self, dataset): - super(Dataset, self).__init__() - self._dataset = dataset - - @deprecation.deprecated(None, "Use `ds._as_variant_tensor()`.") - def make_dataset_resource(self): - return self._as_variant_tensor() - - def _as_variant_tensor(self): - return self._dataset._as_variant_tensor() # pylint: disable=protected-access - - @property - def output_classes(self): - return self._dataset.output_classes - - @property - def output_shapes(self): - return self._dataset.output_shapes - - @property - def output_types(self): - return self._dataset.output_types - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensors()`.") - def from_tensors(tensors): - """Creates a `Dataset` with a single element, comprising the given tensors. - - Args: - tensors: A nested structure of tensors. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.TensorDataset(tensors)) - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.") - def from_tensor_slices(tensors): - """Creates a `Dataset` whose elements are slices of the given tensors. - - Args: - tensors: A nested structure of tensors, each having the same size in the - 0th dimension. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.TensorSliceDataset(tensors)) - - @staticmethod - @deprecation.deprecated(None, - "Use `tf.data.Dataset.from_sparse_tensor_slices()`.") - def from_sparse_tensor_slices(sparse_tensor): - """Splits each rank-N `tf.SparseTensor` in this dataset row-wise. - - Args: - sparse_tensor: A `tf.SparseTensor`. - - Returns: - A `Dataset` of rank-(N-1) sparse tensors. - """ - return Dataset(dataset_ops.SparseTensorSliceDataset(sparse_tensor)) - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.from_generator()`.") - def from_generator(generator, output_types, output_shapes=None): - """Creates a `Dataset` whose elements are generated by `generator`. - - The `generator` argument must be a callable object that returns - an object that support the `iter()` protocol (e.g. a generator function). - The elements generated by `generator` must be compatible with the given - `output_types` and (optional) `output_shapes` arguments. - - For example: - - ```python - import itertools - - def gen(): - for i in itertools.count(1): - yield (i, [1] * i) - - ds = Dataset.from_generator( - gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None]))) - value = ds.make_one_shot_iterator().get_next() - - sess.run(value) # (1, array([1])) - sess.run(value) # (2, array([1, 1])) - ``` - - Args: - generator: A callable object that takes no arguments and returns an - object that supports the `iter()` protocol. - output_types: A nested structure of `tf.DType` objects corresponding to - each component of an element yielded by `generator`. - output_shapes: (Optional.) A nested structure of `tf.TensorShape` - objects corresponding to each component of an element yielded by - `generator`. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.Dataset.from_generator( - generator, output_types, output_shapes)) - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.range()`.") - def range(*args): - """Creates a `Dataset` of a step-separated range of values. - - For example: - - ```python - Dataset.range(5) == [0, 1, 2, 3, 4] - Dataset.range(2, 5) == [2, 3, 4] - Dataset.range(1, 5, 2) == [1, 3] - Dataset.range(1, 5, -2) == [] - Dataset.range(5, 1) == [] - Dataset.range(5, 1, -2) == [5, 3] - ``` - - Args: - *args: follow same semantics as python's xrange. - len(args) == 1 -> start = 0, stop = args[0], step = 1 - len(args) == 2 -> start = args[0], stop = args[1], step = 1 - len(args) == 3 -> start = args[0], stop = args[1, stop = args[2] - - Returns: - A `RangeDataset`. - - Raises: - ValueError: if len(args) == 0. - """ - return Dataset(dataset_ops.RangeDataset(*args)) - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.zip()`.") - def zip(datasets): - """Creates a `Dataset` by zipping together the given datasets. - - This method has similar semantics to the built-in `zip()` function - in Python, with the main difference being that the `datasets` - argument can be an arbitrary nested structure of `Dataset` objects. - For example: - - ```python - # NOTE: The following examples use `{ ... }` to represent the - # contents of a dataset. - a = { 1, 2, 3 } - b = { 4, 5, 6 } - c = { (7, 8), (9, 10), (11, 12) } - d = { 13, 14 } - - # The nested structure of the `datasets` argument determines the - # structure of elements in the resulting dataset. - Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) } - Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) } - - # The `datasets` argument may contain an arbitrary number of - # datasets. - Dataset.zip((a, b, c)) == { (1, 4, (7, 8)), - (2, 5, (9, 10)), - (3, 6, (11, 12)) } - - # The number of elements in the resulting dataset is the same as - # the size of the smallest dataset in `datasets`. - Dataset.zip((a, d)) == { (1, 13), (2, 14) } - ``` - - Args: - datasets: A nested structure of datasets. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.ZipDataset(datasets)) - - def concatenate(self, dataset): - """Creates a `Dataset` by concatenating given dataset with this dataset. - - ```python - # NOTE: The following examples use `{ ... }` to represent the - # contents of a dataset. - a = { 1, 2, 3 } - b = { 4, 5, 6, 7 } - - # Input dataset and dataset to be concatenated should have same - # nested structures and output types. - # c = { (8, 9), (10, 11), (12, 13) } - # d = { 14.0, 15.0, 16.0 } - # a.concatenate(c) and a.concatenate(d) would result in error. - - a.concatenate(b) == { 1, 2, 3, 4, 5, 6, 7 } - ``` - - Args: - dataset: `Dataset` to be concatenated. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.ConcatenateDataset(self._dataset, dataset)) - - def prefetch(self, buffer_size): - """Creates a `Dataset` that prefetches elements from this dataset. - - Args: - buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the - maximum number elements that will be buffered when prefetching. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.PrefetchDataset(self._dataset, buffer_size)) - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.list_files()`.") - def list_files(file_pattern): - """A dataset of all files matching a pattern. - - Example: - If we had the following files on our filesystem: - - /path/to/dir/a.txt - - /path/to/dir/b.py - - /path/to/dir/c.py - If we pass "/path/to/dir/*.py" as the directory, the dataset would - produce: - - /path/to/dir/b.py - - /path/to/dir/c.py - - Args: - file_pattern: A string or scalar string `tf.Tensor`, representing - the filename pattern that will be matched. - - Returns: - A `Dataset` of strings corresponding to file names. - """ - return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern)) - - def repeat(self, count=None): - """Repeats this dataset `count` times. - - Args: - count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the - number of times the elements of this dataset should be repeated. The - default behavior (if `count` is `None` or `-1`) is for the elements to - be repeated indefinitely. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.RepeatDataset(self._dataset, count)) - - @deprecation.deprecated( - None, "Use `ds.apply(tf.contrib.data.enumerate_dataset())`.") - def enumerate(self, start=0): - """Deprecated: Use `Dataset.apply(tf.contrib.data.enumerate_dataset(..)`.""" - - return self.apply(enumerate_ops.enumerate_dataset(start)) - - def shuffle(self, buffer_size, seed=None): - """Randomly shuffles the elements of this dataset. - - Args: - buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the - number of elements from this dataset from which the new - dataset will sample. - seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the - random seed that will be used to create the distribution. See - @{tf.set_random_seed} for behavior. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.ShuffleDataset(self._dataset, buffer_size, seed)) - - def cache(self, filename=""): - """Caches the elements in this dataset. - - Args: - filename: A `tf.string` scalar `tf.Tensor`, representing the name of a - directory on the filesystem to use for caching tensors in this Dataset. - If a filename is not provided, the dataset will be cached in memory. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.CacheDataset(self._dataset, filename)) - - def take(self, count): - """Creates a `Dataset` with at most `count` elements from this dataset. - - Args: - count: A `tf.int64` scalar `tf.Tensor`, representing the number of - elements of this dataset that should be taken to form the new dataset. - If `count` is -1, or if `count` is greater than the size of this - dataset, the new dataset will contain all elements of this dataset. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.TakeDataset(self._dataset, count)) - - def skip(self, count): - """Creates a `Dataset` that skips `count` elements from this dataset. - - Args: - count: A `tf.int64` scalar `tf.Tensor`, representing the number - of elements of this dataset that should be skipped to form the - new dataset. If `count` is greater than the size of this - dataset, the new dataset will contain no elements. If `count` - is -1, skips the entire dataset. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.SkipDataset(self._dataset, count)) - - def shard(self, num_shards, index): - """Creates a `Dataset` that includes only 1/`num_shards` of this dataset. - - This dataset operator is very useful when running distributed training, as - it allows each worker to read a unique subset. - - When reading a single input file, you can skip elements as follows: - - ```python - d = tf.data.TFRecordDataset(FLAGS.input_file) - d = d.shard(FLAGS.num_workers, FLAGS.worker_index) - d = d.repeat(FLAGS.num_epochs) - d = d.shuffle(FLAGS.shuffle_buffer_size) - d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads) - ``` - - Important caveats: - - - Be sure to shard before you use any randomizing operator (such as - shuffle). - - Generally it is best if the shard operator is used early in the dataset - pipeline. For example, when reading from a set of TFRecord files, shard - before converting the dataset to input samples. This avoids reading every - file on every worker. The following is an example of an efficient - sharding strategy within a complete pipeline: - - ```python - d = tf.data.Dataset.list_files(FLAGS.pattern) - d = d.shard(FLAGS.num_workers, FLAGS.worker_index) - d = d.repeat(FLAGS.num_epochs) - d = d.shuffle(FLAGS.shuffle_buffer_size) - d = d.interleave(tf.data.TFRecordDataset, - cycle_length=FLAGS.num_readers, block_length=1) - d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads) - ``` - - Args: - num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of - shards operating in parallel. - index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. - - Returns: - A `Dataset`. - - Raises: - ValueError: if `num_shards` or `index` are illegal values. Note: error - checking is done on a best-effort basis, and aren't guaranteed to be - caught upon dataset creation. (e.g. providing in a placeholder tensor - bypasses the early checking, and will instead result in an error during - a session.run call.) - """ - return Dataset(self._dataset.shard(num_shards, index)) - - @deprecation.deprecated( - None, "Use `ds.apply(tf.contrib.data.ignore_errors())`.") - def ignore_errors(self): - """Deprecated: Use `Dataset.apply(tf.contrib.data.ignore_errors())`.""" - - return self.apply(error_ops.ignore_errors()) - - def batch(self, batch_size): - """Combines consecutive elements of this dataset into batches. - - Args: - batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of - consecutive elements of this dataset to combine in a single batch. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.BatchDataset(self._dataset, batch_size)) - - def padded_batch(self, batch_size, padded_shapes, padding_values=None): - """Combines consecutive elements of this dataset into padded batches. - - Like `Dataset.dense_to_sparse_batch()`, this method combines - multiple consecutive elements of this dataset, which might have - different shapes, into a single element. The tensors in the - resulting element have an additional outer dimension, and are - padded to the respective shape in `padded_shapes`. - - Args: - batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of - consecutive elements of this dataset to combine in a single batch. - padded_shapes: A nested structure of `tf.TensorShape` or - `tf.int64` vector tensor-like objects representing the shape - to which the respective component of each input element should - be padded prior to batching. Any unknown dimensions - (e.g. `tf.Dimension(None)` in a `tf.TensorShape` or `-1` in a - tensor-like object) will be padded to the maximum size of that - dimension in each batch. - padding_values: (Optional.) A nested structure of scalar-shaped - `tf.Tensor`, representing the padding values to use for the - respective components. Defaults are `0` for numeric types and - the empty string for string types. - - Returns: - A `Dataset`. - """ - return Dataset( - dataset_ops.PaddedBatchDataset(self._dataset, batch_size, padded_shapes, - padding_values)) - - @deprecation.deprecated( - None, "Use `ds.apply(tf.contrib.data.dense_to_sparse_batch())`.") - def dense_to_sparse_batch(self, batch_size, row_shape): - """Use: `Dataset.apply(tf.contrib.data.dense_to_sparse_batch(...))`.""" - - return self.apply(batching.dense_to_sparse_batch(batch_size, row_shape)) - - @deprecation.deprecated( - None, "Use `ds.apply(tf.contrib.data.group_by_window())`.") - def group_by_window(self, key_func, reduce_func, window_size): - """Deprecated: Use `Dataset.apply(tf.contrib.data.group_by_window(...))`.""" - - return self.apply( - grouping.group_by_window(key_func, reduce_func, window_size)) - - @deprecation.deprecated_args( - None, - "`output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.", - "num_threads", "output_buffer_size") - def map(self, - map_func, - num_threads=None, - output_buffer_size=None, - num_parallel_calls=None): - """Maps `map_func` across this dataset. - - Args: - map_func: A function mapping a nested structure of tensors (having - shapes and types defined by `self.output_shapes` and - `self.output_types`) to another nested structure of tensors. - num_threads: (Optional.) Deprecated, use `num_parallel_calls` instead. - output_buffer_size: (Optional.) A `tf.int64` scalar `tf.Tensor`, - representing the maximum number of processed elements that will be - buffered. - num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, - representing the number elements to process in parallel. If not - specified, elements will be processed sequentially. - - Returns: - A `Dataset`. - """ - if num_threads is None and num_parallel_calls is None: - ret = Dataset(dataset_ops.MapDataset(self._dataset, map_func)) - else: - if num_threads is None: - ret = Dataset( - dataset_ops.ParallelMapDataset(self._dataset, map_func, - num_parallel_calls)) - else: - ret = Dataset( - dataset_ops.ParallelMapDataset(self._dataset, map_func, - num_threads)) - if output_buffer_size is not None: - ret = ret.prefetch(output_buffer_size) - return ret - - def flat_map(self, map_func): - """Maps `map_func` across this dataset and flattens the result. - - Args: - map_func: A function mapping a nested structure of tensors (having shapes - and types defined by `self.output_shapes` and `self.output_types`) to a - `Dataset`. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.FlatMapDataset(self._dataset, map_func)) - - def interleave(self, map_func, cycle_length, block_length=1): - """Maps `map_func` across this dataset, and interleaves the results. - - For example, you can use `Dataset.interleave()` to process many input files - concurrently: - - ```python - # Preprocess 4 files concurrently, and interleave blocks of 16 records from - # each file. - filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...] - dataset = (Dataset.from_tensor_slices(filenames) - .interleave(lambda x: - TextLineDataset(x).map(parse_fn, num_parallel_calls=1), - cycle_length=4, block_length=16)) - ``` - - The `cycle_length` and `block_length` arguments control the order in which - elements are produced. `cycle_length` controls the number of input elements - that are processed concurrently. If you set `cycle_length` to 1, this - transformation will handle one input element at a time, and will produce - identical results = to @{tf.data.Dataset.flat_map}. In general, - this transformation will apply `map_func` to `cycle_length` input elements, - open iterators on the returned `Dataset` objects, and cycle through them - producing `block_length` consecutive elements from each iterator, and - consuming the next input element each time it reaches the end of an - iterator. - - For example: - - ```python - # NOTE: The following examples use `{ ... }` to represent the - # contents of a dataset. - a = { 1, 2, 3, 4, 5 } - - # NOTE: New lines indicate "block" boundaries. - a.interleave(lambda x: Dataset.from_tensors(x).repeat(6), - cycle_length=2, block_length=4) == { - 1, 1, 1, 1, - 2, 2, 2, 2, - 1, 1, - 2, 2, - 3, 3, 3, 3, - 4, 4, 4, 4, - 3, 3, - 4, 4, - 5, 5, 5, 5, - 5, 5, - } - ``` - - NOTE: The order of elements yielded by this transformation is - deterministic, as long as `map_func` is a pure function. If - `map_func` contains any stateful operations, the order in which - that state is accessed is undefined. - - Args: - map_func: A function mapping a nested structure of tensors (having shapes - and types defined by `self.output_shapes` and `self.output_types`) to a - `Dataset`. - cycle_length: The number of elements from this dataset that will be - processed concurrently. - block_length: The number of consecutive elements to produce from each - input element before cycling to another input element. - - Returns: - A `Dataset`. - """ - return Dataset( - dataset_ops.InterleaveDataset(self._dataset, map_func, cycle_length, - block_length)) - - @deprecation.deprecated(None, "Use `ds.apply(tf.contrib.data.unbatch())`.") - def unbatch(self): - """Deprecated: Use `Dataset.apply(tf.contrib.data.unbatch()`.""" - - return self.apply(batching.unbatch()) - - def filter(self, predicate): - """Filters this dataset according to `predicate`. - - Args: - predicate: A function mapping a nested structure of tensors (having shapes - and types defined by `self.output_shapes` and `self.output_types`) to a - scalar `tf.bool` tensor. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.FilterDataset(self._dataset, predicate)) - - def apply(self, transformation_func): - """Apply a transformation function to this dataset. - - `apply` enables chaining of custom `Dataset` transformations, which are - represented as functions that take one `Dataset` argument and return a - transformed `Dataset`. - - For example: - - ``` - dataset = (dataset.map(lambda x: x ** 2) - .(group_by_window(key_func, reduce_func, window_size)) - .map(lambda x: x ** 3)) - ``` - - Args: - transformation_func: A function that takes one `Dataset` argument and - returns a `Dataset`. - - Returns: - The `Dataset` returned by applying `transformation_func` to this dataset. - """ - dataset = transformation_func(self) - if not isinstance(dataset, dataset_ops.Dataset): - raise TypeError("`transformation_func` must return a Dataset.") - return Dataset(dataset) - - -def get_single_element(dataset): - """Returns the single element in `dataset` as a nested structure of tensors. - - This function enables you to use a @{tf.data.Dataset} in a stateless - "tensor-in tensor-out" expression, without creating a @{tf.data.Iterator}. - This can be useful when your preprocessing transformations are expressed - as a `Dataset`, and you want to use the transformation at serving time. - For example: - - ```python - input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE]) - - def preprocessing_fn(input_str): - # ... - return image, label - - dataset = (tf.data.Dataset.from_tensor_slices(input_batch) - .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) - .batch(BATCH_SIZE)) - - image_batch, label_batch = tf.contrib.data.get_single_element(dataset) - ``` - - Args: - dataset: A @{tf.data.Dataset} object containing a single element. - - Returns: - A nested structure of @{tf.Tensor} objects, corresponding to the single - element of `dataset`. - - Raises: - TypeError: if `dataset` is not a `tf.data.Dataset` object. - InvalidArgumentError (at runtime): if `dataset` does not contain exactly - one element. - """ - if not isinstance(dataset, dataset_ops.Dataset): - raise TypeError("`dataset` must be a `tf.data.Dataset` object.") - return nest.pack_sequence_as( - dataset.output_types, - gen_dataset_ops.dataset_to_single_element( - dataset._as_variant_tensor(), # pylint: disable=protected-access - output_types=nest.flatten(dataset.output_types), - output_shapes=nest.flatten(dataset.output_shapes))) diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index 6c21e489f7c35484ebacd465e3b46d6920df5933..5f5513849cb29a18b86ba8bcee1ab6c9c60674cb 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -20,8 +20,6 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse def ignore_errors(): @@ -64,10 +62,7 @@ class IgnoreErrorsDataset(dataset_ops.Dataset): def _as_variant_tensor(self): return gen_dataset_ops.ignore_errors_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py index a817b45b71b608810a9d7536ec123ab84f7cdc3b..0f4cd8e20c5727a5bcfa1dce4dadbfa8f90bd551 100644 --- a/tensorflow/contrib/data/python/ops/get_single_element.py +++ b/tensorflow/contrib/data/python/ops/get_single_element.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.ops import gen_dataset_ops @@ -59,9 +60,11 @@ def get_single_element(dataset): """ if not isinstance(dataset, dataset_ops.Dataset): raise TypeError("`dataset` must be a `tf.data.Dataset` object.") - return nest.pack_sequence_as( - dataset.output_types, - gen_dataset_ops.dataset_to_single_element( + + nested_ret = nest.pack_sequence_as( + dataset.output_types, gen_dataset_ops.dataset_to_single_element( dataset._as_variant_tensor(), # pylint: disable=protected-access - output_types=nest.flatten(dataset.output_types), - output_shapes=nest.flatten(dataset.output_shapes))) + **dataset_ops.flat_structure(dataset))) + return sparse.deserialize_sparse_tensors( + nested_ret, dataset.output_types, dataset.output_shapes, + dataset.output_classes) diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 67b085002aa7797d858837fea4646fb968ad5d97..f9f25e6a0687fe7167525847c64743f52a551fb0 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -17,13 +17,50 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import math_ops + + +def group_by_reducer(key_func, reducer): + """A transformation that groups elements and performs a reduction. + + This transformation maps element of a dataset to a key using `key_func` and + groups the elements by key. The `reducer` is used to process each group; its + `init_func` is used to initialize state for each group when it is created, the + `reduce_func` is used to update the state every time an element is mapped to + the matching group, and the `finalize_func` is used to map the final state to + an output value. + + Args: + key_func: A function mapping a nested structure of tensors + (having shapes and types defined by `self.output_shapes` and + `self.output_types`) to a scalar `tf.int64` tensor. + reducer: An instance of `Reducer`, which captures the reduction logic using + the `init_func`, `reduce_func`, and `finalize_func` functions. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return GroupByReducerDataset(dataset, key_func, reducer) + + return _apply_fn def group_by_window(key_func, @@ -35,7 +72,7 @@ def group_by_window(key_func, This transformation maps each consecutive element in a dataset to a key using `key_func` and groups the elements by key. It then applies `reduce_func` to at most `window_size_func(key)` elements matching the same - key. All execpt the final window for each key will contain + key. All except the final window for each key will contain `window_size_func(key)` elements; the final window may be smaller. You may provide either a constant `window_size` or a window size determined by @@ -85,6 +122,114 @@ def group_by_window(key_func, return _apply_fn +def bucket_by_sequence_length(element_length_func, + bucket_boundaries, + bucket_batch_sizes, + padded_shapes=None, + padding_values=None, + pad_to_bucket_boundary=False): + """A transformation that buckets elements in a `Dataset` by length. + + Elements of the `Dataset` are grouped together by length and then are padded + and batched. + + This is useful for sequence tasks in which the elements have variable length. + Grouping together elements that have similar lengths reduces the total + fraction of padding in a batch which increases training step efficiency. + + Args: + element_length_func: function from element in `Dataset` to `tf.int32`, + determines the length of the element, which will determine the bucket it + goes into. + bucket_boundaries: `list`, upper length boundaries of the buckets. + bucket_batch_sizes: `list`, batch size per bucket. Length should be + `len(bucket_boundaries) + 1`. + padded_shapes: Nested structure of `tf.TensorShape` to pass to + @{tf.data.Dataset.padded_batch}. If not provided, will use + `dataset.output_shapes`, which will result in variable length dimensions + being padded out to the maximum length in each batch. + padding_values: Values to pad with, passed to + @{tf.data.Dataset.padded_batch}. Defaults to padding with 0. + pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown + size to maximum length in batch. If `True`, will pad dimensions with + unknown size to bucket boundary, and caller must ensure that the source + `Dataset` does not contain any elements with length longer than + `max(bucket_boundaries)`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + + Raises: + ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`. + """ + with ops.name_scope("bucket_by_seq_length"): + if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1): + raise ValueError( + "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1") + + batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64) + + def element_to_bucket_id(*args): + """Return int64 id of the length bucket for this element.""" + seq_length = element_length_func(*args) + + boundaries = list(bucket_boundaries) + buckets_min = [np.iinfo(np.int32).min] + boundaries + buckets_max = boundaries + [np.iinfo(np.int32).max] + conditions_c = math_ops.logical_and( + math_ops.less_equal(buckets_min, seq_length), + math_ops.less(seq_length, buckets_max)) + bucket_id = math_ops.reduce_min(array_ops.where(conditions_c)) + + return bucket_id + + def window_size_fn(bucket_id): + # The window size is set to the batch size for this bucket + window_size = batch_sizes[bucket_id] + return window_size + + def make_padded_shapes(shapes, none_filler=None): + padded = [] + for shape in nest.flatten(shapes): + shape = tensor_shape.TensorShape(shape) + shape = [ + none_filler if d.value is None else d + for d in shape + ] + padded.append(shape) + return nest.pack_sequence_as(shapes, padded) + + def batching_fn(bucket_id, grouped_dataset): + """Batch elements in dataset.""" + batch_size = batch_sizes[bucket_id] + none_filler = None + if pad_to_bucket_boundary: + err_msg = ("When pad_to_bucket_boundary=True, elements must have " + "length <= max(bucket_boundaries).") + check = check_ops.assert_less( + bucket_id, + constant_op.constant(len(bucket_batch_sizes) - 1, + dtype=dtypes.int64), + message=err_msg) + with ops.control_dependencies([check]): + boundaries = constant_op.constant(bucket_boundaries, + dtype=dtypes.int64) + bucket_boundary = boundaries[bucket_id] + none_filler = bucket_boundary + shapes = make_padded_shapes( + padded_shapes or grouped_dataset.output_shapes, + none_filler=none_filler) + return grouped_dataset.padded_batch(batch_size, shapes, padding_values) + + def _apply_fn(dataset): + return dataset.apply( + group_by_window(element_to_bucket_id, batching_fn, + window_size_func=window_size_fn)) + + return _apply_fn + + class _VariantDataset(dataset_ops.Dataset): """A Dataset wrapper for a tf.variant-typed function argument.""" @@ -112,6 +257,254 @@ class _VariantDataset(dataset_ops.Dataset): return self._output_types +class GroupByReducerDataset(dataset_ops.Dataset): + """A `Dataset` that groups its input and performs a reduction.""" + + def __init__(self, input_dataset, key_func, reducer): + """See `group_by_reducer()` for details.""" + super(GroupByReducerDataset, self).__init__() + + self._input_dataset = input_dataset + + self._make_key_func(key_func, input_dataset) + self._make_init_func(reducer.init_func) + self._make_reduce_func(reducer.reduce_func, input_dataset) + self._make_finalize_func(reducer.finalize_func) + + def _make_key_func(self, key_func, input_dataset): + """Make wrapping Defun for key_func.""" + + @function.Defun(*nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes))) + def tf_key_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + # Pass in shape information from the input_dataset. + dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes) + for arg, shape in zip(args, nest.flatten(dense_shapes)): + arg.set_shape(shape) + + nested_args = nest.pack_sequence_as(input_dataset.output_types, args) + nested_args = sparse.deserialize_sparse_tensors( + nested_args, input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes) + # pylint: disable=protected-access + if dataset_ops._should_unpack_args(nested_args): + ret = key_func(*nested_args) + # pylint: enable=protected-access + else: + ret = key_func(nested_args) + ret = ops.convert_to_tensor(ret) + if ret.dtype != dtypes.int64 or ret.get_shape() != tensor_shape.scalar(): + raise ValueError( + "`key_func` must return a single tf.int64 tensor. " + "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape())) + dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access + return ret + + self._key_func = tf_key_func + self._key_func.add_to_graph(ops.get_default_graph()) + + def _make_init_func(self, init_func): + """Make wrapping Defun for init_func.""" + + @function.Defun(dtypes.int64) + def tf_init_func(key): + """A wrapper for Defun that facilitates shape inference.""" + key.set_shape([]) + ret = init_func(key) + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) + ]) + + self._state_classes = sparse.get_classes(ret) + self._state_shapes = nest.pack_sequence_as( + ret, [t.get_shape() for t in nest.flatten(ret)]) + self._state_types = nest.pack_sequence_as( + ret, [t.dtype for t in nest.flatten(ret)]) + + dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access + + # Serialize any sparse tensors. + ret = nest.pack_sequence_as( + ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) + return nest.flatten(ret) + + self._init_func = tf_init_func + self._init_func.add_to_graph(ops.get_default_graph()) + + def _make_reduce_func(self, reduce_func, input_dataset): + """Make wrapping Defun for reduce_func.""" + + # Iteratively rerun the reduce function until reaching a fixed point on + # `self._state_shapes`. + need_to_rerun = True + while need_to_rerun: + + # Create a list in which `tf_reduce_func` will store the new shapes. + flat_new_state_shapes = [] + + @function.Defun(*(nest.flatten( + sparse.as_dense_types( + self._state_types, self._state_classes)) + nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes)))) + def tf_reduce_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + for arg, shape in zip( + args, + nest.flatten( + sparse.as_dense_shapes(self._state_shapes, self._state_classes)) + + nest.flatten( + sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes))): + arg.set_shape(shape) + + pivot = len(nest.flatten(self._state_shapes)) + nested_state_args = nest.pack_sequence_as(self._state_types, + args[:pivot]) + nested_state_args = sparse.deserialize_sparse_tensors( + nested_state_args, self._state_types, self._state_shapes, + self._state_classes) + nested_input_args = nest.pack_sequence_as(input_dataset.output_types, + args[pivot:]) + nested_input_args = sparse.deserialize_sparse_tensors( + nested_input_args, input_dataset.output_types, + input_dataset.output_shapes, input_dataset.output_classes) + + ret = reduce_func(nested_state_args, nested_input_args) + + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) + ]) + + # Extract shape information from the returned values. + flat_new_state = nest.flatten(ret) + flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state]) + + # Extract and validate type information from the returned values. + for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)): + if t.dtype != dtype: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_types, + nest.pack_sequence_as(self._state_types, + [t.dtype for t in flat_new_state]))) + + dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access + + # Serialize any sparse tensors. + ret = nest.pack_sequence_as( + ret, + [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) + return nest.flatten(ret) + + # Use the private method that will execute `tf_reduce_func` but delay + # adding it to the graph in case we need to rerun the function. + tf_reduce_func._create_definition_if_needed() # pylint: disable=protected-access + + flat_state_shapes = nest.flatten(self._state_shapes) + weakened_state_shapes = [ + old.most_specific_compatible_shape(new) + for old, new in zip(flat_state_shapes, flat_new_state_shapes) + ] + + need_to_rerun = False + for old_shape, weakened_shape in zip(flat_state_shapes, + weakened_state_shapes): + if old_shape.ndims is not None and ( + weakened_shape.ndims is None or + old_shape.as_list() != weakened_shape.as_list()): + need_to_rerun = True + break + + if need_to_rerun: + self._state_shapes = nest.pack_sequence_as(self._state_shapes, + weakened_state_shapes) + + self._reduce_func = tf_reduce_func + self._reduce_func.add_to_graph(ops.get_default_graph()) + + def _make_finalize_func(self, finalize_func): + """Make wrapping Defun for finalize_func.""" + + @function.Defun(*(nest.flatten( + sparse.as_dense_types(self._state_types, self._state_classes)))) + def tf_finalize_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + for arg, shape in zip( + args, + nest.flatten( + sparse.as_dense_shapes(self._state_shapes, self._state_classes))): + arg.set_shape(shape) + + nested_args = nest.pack_sequence_as(self._state_types, args) + nested_args = sparse.deserialize_sparse_tensors( + nested_args, self._state_types, self._state_shapes, + self._state_classes) + + ret = finalize_func(nested_args) + + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) + ]) + + self._output_classes = sparse.get_classes(ret) + self._output_shapes = nest.pack_sequence_as( + ret, [t.get_shape() for t in nest.flatten(ret)]) + self._output_types = nest.pack_sequence_as( + ret, [t.dtype for t in nest.flatten(ret)]) + + dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access + + # Serialize any sparse tensors. + ret = nest.pack_sequence_as( + ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) + return nest.flatten(ret) + + self._finalize_func = tf_finalize_func + self._finalize_func.add_to_graph(ops.get_default_graph()) + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + def _as_variant_tensor(self): + return gen_dataset_ops.group_by_reducer_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._key_func.captured_inputs, + self._init_func.captured_inputs, + self._reduce_func.captured_inputs, + self._finalize_func.captured_inputs, + key_func=self._key_func, + init_func=self._init_func, + reduce_func=self._reduce_func, + finalize_func=self._finalize_func, + **dataset_ops.flat_structure(self)) + + class GroupByWindowDataset(dataset_ops.Dataset): """A `Dataset` that groups its input and performs a windowed reduction.""" @@ -136,6 +529,7 @@ class GroupByWindowDataset(dataset_ops.Dataset): if window_size.dtype != dtypes.int64: raise ValueError( "`window_size_func` must return a single tf.int64 tensor.") + dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access return window_size self._window_size_func = tf_window_size_func @@ -168,6 +562,7 @@ class GroupByWindowDataset(dataset_ops.Dataset): ret = ops.convert_to_tensor(ret, dtype=dtypes.int64) if ret.dtype != dtypes.int64: raise ValueError("`key_func` must return a single tf.int64 tensor.") + dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access return ret self._key_func = tf_key_func @@ -191,6 +586,7 @@ class GroupByWindowDataset(dataset_ops.Dataset): self._output_classes = output_dataset.output_classes self._output_types = output_dataset.output_types self._output_shapes = output_dataset.output_shapes + dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access return output_dataset._as_variant_tensor() # pylint: disable=protected-access self._reduce_func = tf_reduce_func @@ -217,7 +613,31 @@ class GroupByWindowDataset(dataset_ops.Dataset): key_func=self._key_func, reduce_func=self._reduce_func, window_size_func=self._window_size_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) + + +class Reducer(object): + """A reducer is used for reducing a set of elements. + + A reducer is represented as a tuple of the three functions: + 1) initialization function: key => initial state + 2) reduce function: (old state, input) => new state + 3) finalization function: state => result + """ + + def __init__(self, init_func, reduce_func, finalize_func): + self._init_func = init_func + self._reduce_func = reduce_func + self._finalize_func = finalize_func + + @property + def init_func(self): + return self._init_func + + @property + def reduce_func(self): + return self._reduce_func + + @property + def finalize_func(self): + return self._finalize_func diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 3124ca1d1540e12d949dded88ce1c66181be3595..70153ac575758f16beff373941dfefb32bd342cf 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -17,101 +17,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib import stateless +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops +from tensorflow.contrib.data.python.ops import random_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import convert +from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.util import deprecation -class ParallelInterleaveDataset(dataset_ops.Dataset): - """A `Dataset` that maps a function over its input and flattens the result.""" - - def __init__(self, input_dataset, map_func, cycle_length, block_length, - sloppy, buffer_output_elements, prefetch_input_elements): - """See `tf.contrib.data.parallel_interleave()` for details.""" - super(ParallelInterleaveDataset, self).__init__() - self._input_dataset = input_dataset - - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_map_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - if dataset_ops._should_unpack_args(nested_args): # pylint: disable=protected-access - dataset = map_func(*nested_args) - else: - dataset = map_func(nested_args) - - if not isinstance(dataset, dataset_ops.Dataset): - raise TypeError("`map_func` must return a `Dataset` object.") - - self._output_classes = dataset.output_classes - self._output_types = dataset.output_types - self._output_shapes = dataset.output_shapes - - return dataset._as_variant_tensor() # pylint: disable=protected-access - - self._map_func = tf_map_func - self._map_func.add_to_graph(ops.get_default_graph()) - - self._cycle_length = ops.convert_to_tensor( - cycle_length, dtype=dtypes.int64, name="cycle_length") - self._block_length = ops.convert_to_tensor( - block_length, dtype=dtypes.int64, name="block_length") - self._sloppy = ops.convert_to_tensor( - sloppy, dtype=dtypes.bool, name="sloppy") - self._buffer_output_elements = convert.optional_param_to_tensor( - "buffer_output_elements", - buffer_output_elements, - argument_default=2 * block_length) - self._prefetch_input_elements = convert.optional_param_to_tensor( - "prefetch_input_elements", - prefetch_input_elements, - argument_default=2 * cycle_length) - - def _as_variant_tensor(self): - return gen_dataset_ops.parallel_interleave_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._map_func.captured_inputs, - self._cycle_length, - self._block_length, - self._sloppy, - self._buffer_output_elements, - self._prefetch_input_elements, - f=self._map_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) - - @property - def output_classes(self): - return self._output_classes - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types - - def parallel_interleave(map_func, cycle_length, block_length=1, @@ -162,7 +82,7 @@ def parallel_interleave(map_func, @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): - return ParallelInterleaveDataset( + return readers.ParallelInterleaveDataset( dataset, map_func, cycle_length, block_length, sloppy, buffer_output_elements, prefetch_input_elements) @@ -221,7 +141,7 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): - return ParallelInterleaveDataset( + return readers.ParallelInterleaveDataset( dataset, map_func, cycle_length, @@ -231,3 +151,133 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): prefetch_input_elements=None) return _apply_fn + + +class DirectedInterleaveDataset(dataset_ops.Dataset): + """A substitute for `Dataset.interleave()` on a fixed list of datasets.""" + + def __init__(self, selector_input, data_inputs): + self._selector_input = selector_input + self._data_inputs = list(data_inputs) + + for data_input in data_inputs[1:]: + if (data_input.output_types != data_inputs[0].output_types or + data_input.output_classes != data_inputs[0].output_classes): + raise TypeError("All datasets must have the same type.") + + def _as_variant_tensor(self): + # pylint: disable=protected-access + return gen_dataset_ops.directed_interleave_dataset( + self._selector_input._as_variant_tensor(), + [data_input._as_variant_tensor() for data_input in self._data_inputs], + **dataset_ops.flat_structure(self)) + # pylint: enable=protected-access + + @property + def output_classes(self): + return self._data_inputs[0].output_classes + + @property + def output_shapes(self): + ret = self._data_inputs[0].output_shapes + for data_input in self._data_inputs[1:]: + ret = nest.pack_sequence_as(ret, [ + ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip( + nest.flatten(ret), nest.flatten(data_input.output_shapes)) + ]) + return ret + + @property + def output_types(self): + return self._data_inputs[0].output_types + + +def sample_from_datasets(datasets, weights=None, seed=None): + """Samples elements at random from the datasets in `datasets`. + + Args: + datasets: A list of @{tf.data.Dataset} objects with compatible structure. + weights: (Optional.) A list of `len(datasets)` floating-point values where + `weights[i]` represents the probability with which an element should be + sampled from `datasets[i]`, or a @{tf.data.Dataset} object where each + element is such a list. Defaults to a uniform distribution across + `datasets`. + seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + random seed that will be used to create the distribution. See + @{tf.set_random_seed} for behavior. + + Returns: + A dataset that interleaves elements from `datasets` at random, according to + `weights` if provided, otherwise with uniform probability. + + Raises: + TypeError: If the `datasets` or `weights` arguments have the wrong type. + ValueError: If the `weights` argument is specified and does not match the + length of the `datasets` element. + """ + num_datasets = len(datasets) + if weights is None: + weights = dataset_ops.Dataset.from_tensors([1.0] * num_datasets).repeat() + elif not isinstance(weights, dataset_ops.Dataset): + weights = ops.convert_to_tensor(weights, name="weights") + if weights.dtype not in (dtypes.float32, dtypes.float64): + raise TypeError("`weights` must be convertible to a tensor of " + "`tf.float32` or `tf.float64` elements.") + if not weights.shape.is_compatible_with([num_datasets]): + raise ValueError("`weights` must be a vector of length `len(datasets)`.") + weights = dataset_ops.Dataset.from_tensors(weights).repeat() + + # The `stateless_multinomial()` op expects log-probabilities, as opposed to + # weights. + logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits")) + def select_dataset(logits, seed): + return array_ops.squeeze( + stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) + selector_input = dataset_ops.Dataset.zip( + (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset) + + return DirectedInterleaveDataset(selector_input, datasets) + + +def choose_from_datasets(datasets, choice_dataset): + """Creates a dataset that deterministically chooses elements from `datasets`. + + For example, given the following datasets: + + ```python + datasets = [tf.data.Dataset.from_tensors("foo").repeat(), + tf.data.Dataset.from_tensors("bar").repeat(), + tf.data.Dataset.from_tensors("baz").repeat()] + + # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`. + choice_dataset = tf.data.Dataset.range(3).repeat(3) + + result = tf.contrib.data.choose_from_datasets(datasets, choice_dataset) + ``` + + The elements of `result` will be: + + ``` + "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz" + ``` + + Args: + datasets: A list of @{tf.data.Dataset} objects with compatible structure. + choice_dataset: A @{tf.data.Dataset} of scalar `tf.int64` tensors between + `0` and `len(datasets) - 1`. + + Returns: + A dataset that interleaves elements from `datasets` according to the values + of `choice_dataset`. + + Raises: + TypeError: If the `datasets` or `choice_dataset` arguments have the wrong + type. + """ + if not (choice_dataset.output_types == dtypes.int64 + and choice_dataset.output_shapes.is_compatible_with( + tensor_shape.scalar()) + and choice_dataset.output_classes == ops.Tensor): + raise TypeError("`choice_dataset` must be a dataset of scalar " + "`tf.int64` tensors.") + return DirectedInterleaveDataset(choice_dataset, datasets) diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index d736029fb035e573b70e8b19570e4e8ceca3c005..0d71be66018eeebe60de9deff24ceb6854d209d9 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -16,10 +16,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.training import saver +from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import session_run_hook def make_saveable_from_iterator(iterator): @@ -60,14 +62,14 @@ def make_saveable_from_iterator(iterator): return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access -class _Saveable(saver.BaseSaverBuilder.SaveableObject): +class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject): """SaveableObject for saving/restoring iterator state.""" def __init__(self, iterator_resource): serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource) specs = [ - saver.BaseSaverBuilder.SaveSpec(serialized_iterator, "", - iterator_resource.name + "-state") + saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "", + iterator_resource.name + "-state") ] super(_Saveable, self).__init__(iterator_resource, specs, iterator_resource.name) @@ -75,3 +77,182 @@ class _Saveable(saver.BaseSaverBuilder.SaveableObject): def restore(self, restored_tensors, unused_restored_shapes): with ops.colocate_with(self.op): return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0]) + + +class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): + """Checkpoints input pipeline state every N steps or seconds. + + This hook saves the state of the iterators in the `Graph` so that when + training is resumed the input pipeline continues from where it left off. + This could potentially avoid overfitting in certain pipelines where the + number of training steps per eval are small compared to the dataset + size or if the training pipeline is pre-empted. + + Differences from `CheckpointSaverHook`: + 1. Saves only the input pipelines in the "iterators" collection and not the + global variables or other saveable objects. + 2. Does not write the `GraphDef` and `MetaGraphDef` to the summary. + + Example of checkpointing the training pipeline: + + ```python + est = tf.estimator.Estimator(model_fn) + while True: + est.train( + train_input_fn, + hooks=[tf.contrib.data.CheckpointInputPipelineHook(est)], + steps=train_steps_per_eval) + # Note: We do not pass the hook here. + metrics = est.evaluate(eval_input_fn) + if should_stop_the_training(metrics): + break + ``` + + This hook should be used if the input pipeline state needs to be saved + separate from the model checkpoint. Doing so may be useful for a few reasons: + 1. The input pipeline checkpoint may be large, if there are large shuffle + or prefetch buffers for instance, and may bloat the checkpoint size. + 2. If the input pipeline is shared between training and validation, restoring + the checkpoint during validation may override the validation input + pipeline. + + For saving the input pipeline checkpoint alongside the model weights use + @{tf.contrib.data.make_saveable_from_iterator} directly to create a + `SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however, + that you will need to be careful not to restore the training iterator during + eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS + collector when building the eval graph. + """ + + def __init__(self, estimator): + """Initializes a `CheckpointInputPipelineHook`. + + Args: + estimator: Estimator. + + Raises: + ValueError: One of `save_steps` or `save_secs` should be set. + ValueError: At most one of saver or scaffold should be set. + """ + # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or + # of the form "input__.ckpt" for distributed pipelines. + # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is + # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix + # to be different to avoid conflicts with the model checkpoint. + + # pylint: disable=protected-access + checkpoint_prefix = "input" + if estimator._config.num_worker_replicas > 1: + # Distributed setting. + suffix = "_{}_{}".format(estimator._config.task_type, + estimator._config.task_id) + checkpoint_prefix += suffix + # pylint: enable=protected-access + + # We use a composition paradigm instead of inheriting from + # `CheckpointSaverHook` because `Estimator` does an `isinstance` check + # to check whether a `CheckpointSaverHook` is already present in the list + # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook` + # would thwart this behavior. This hook checkpoints *only the iterators* + # and not the graph variables. + self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook( + estimator.model_dir, + save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access + save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access + checkpoint_basename=checkpoint_prefix + ".ckpt") + + # Name for the protocol buffer file that will contain the list of most + # recent checkpoints stored as a `CheckpointState` protocol buffer. + # This file, kept in the same directory as the checkpoint files, is + # automatically managed by the `Saver` to keep track of recent checkpoints. + # The default name used by the `Saver` for this file is "checkpoint". Here + # we use the name "checkpoint_" so that in case the + # `checkpoint_dir` is the same as the model checkpoint directory, there are + # no conflicts during restore. + self._latest_filename = "checkpoint_" + checkpoint_prefix + self._first_run = True + + def begin(self): + # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS` + # collection if no `Saver` or `Scaffold` is provided. + # pylint: disable=protected-access + if (self._checkpoint_saver_hook._saver is None and + self._checkpoint_saver_hook._scaffold is None): + iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS) + saveables = [_Saveable(i) for i in iterators] + self._checkpoint_saver_hook._saver = _CustomSaver(saveables, + self._latest_filename) + # pylint: enable=protected-access + self._checkpoint_saver_hook.begin() + + def _restore_or_save_initial_ckpt(self, session): + # Ideally this should be run in after_create_session but is not for the + # following reason: + # Currently there is no way of enforcing an order of running the + # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook` + # is run *after* this hook. That is troublesome because + # 1. If a checkpoint exists and this hook restores it, the initializer hook + # will override it. + # 2. If no checkpoint exists, this hook will try to save an initialized + # iterator which will result in an exception. + # + # As a temporary fix we enter the following implicit contract between this + # hook and the _DatasetInitializerHook. + # 1. The _DatasetInitializerHook initializes the iterator in the call to + # after_create_session. + # 2. This hook saves the iterator on the first call to `before_run()`, which + # is guaranteed to happen after `after_create_session()` of all hooks + # have been run. + + # Check if there is an existing checkpoint. If so, restore from it. + # pylint: disable=protected-access + latest_checkpoint_path = saver_lib.latest_checkpoint( + self._checkpoint_saver_hook._checkpoint_dir, + latest_filename=self._latest_filename) + if latest_checkpoint_path: + self._checkpoint_saver_hook._get_saver().restore(session, + latest_checkpoint_path) + else: + # The checkpoint saved here is the state at step "global_step". + # Note: We do not save the GraphDef or MetaGraphDef here. + global_step = session.run(self._checkpoint_saver_hook._global_step_tensor) + self._checkpoint_saver_hook._save(session, global_step) + self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step) + # pylint: enable=protected-access + + def before_run(self, run_context): + if self._first_run: + self._restore_or_save_initial_ckpt(run_context.session) + self._first_run = False + return self._checkpoint_saver_hook.before_run(run_context) + + def after_run(self, run_context, run_values): + self._checkpoint_saver_hook.after_run(run_context, run_values) + + def end(self, session): + self._checkpoint_saver_hook.end(session) + + +class _CustomSaver(saver_lib.Saver): + """`Saver` with a different default `latest_filename`. + + This is used in the `CheckpointInputPipelineHook` to avoid conflicts with + the model ckpt saved by the `CheckpointSaverHook`. + """ + + def __init__(self, var_list, latest_filename): + super(_CustomSaver, self).__init__(var_list) + self._latest_filename = latest_filename + + def save(self, + sess, + save_path, + global_step=None, + latest_filename=None, + meta_graph_suffix="meta", + write_meta_graph=True, + write_state=True, + strip_default_attrs=False): + return super(_CustomSaver, self).save( + sess, save_path, global_step, latest_filename or self._latest_filename, + meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs) diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/ops/iterator_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..30a993b1f7056b9726f524b2279131339c80c5eb --- /dev/null +++ b/tensorflow/contrib/data/python/ops/iterator_ops_test.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for experimental iterator_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import iterator_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import training_util + + +class CheckpointInputPipelineHookTest(test.TestCase): + + @staticmethod + def _model_fn(features, labels, mode, config): + del labels + del mode + del config + global_step = training_util.get_or_create_global_step() + update_global_step_op = global_step.assign_add(1) + latest_feature = variables.Variable( + 0, name='latest_feature', dtype=dtypes.int64) + store_latest_feature_op = latest_feature.assign(features) + ops.add_to_collection('my_vars', global_step) + ops.add_to_collection('my_vars', latest_feature) + return model_fn.EstimatorSpec( + mode='train', + train_op=control_flow_ops.group( + [update_global_step_op, store_latest_feature_op]), + loss=constant_op.constant(2.0)) + + def _read_vars(self, model_dir): + """Returns (global_step, latest_feature).""" + with ops.Graph().as_default() as g: + ckpt_path = saver_lib.latest_checkpoint(model_dir) + meta_filename = ckpt_path + '.meta' + saver_lib.import_meta_graph(meta_filename) + saver = saver_lib.Saver() + with self.test_session(graph=g) as sess: + saver.restore(sess, ckpt_path) + return sess.run(ops.get_collection('my_vars')) + + def _build_iterator_saver_hook(self, est): + return iterator_ops.CheckpointInputPipelineHook(est) + + def testReturnDatasetFromInputFn(self): + + def _input_fn(): + return dataset_ops.Dataset.range(10) + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + + def testBuildIteratorInInputFn(self): + + def _input_fn(): + ds = dataset_ops.Dataset.range(10) + iterator = ds.make_one_shot_iterator() + return iterator.get_next() + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + + def testDoNotRestore(self): + + def _input_fn(): + return dataset_ops.Dataset.range(10) + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + # Hook not provided, input pipeline was not restored. + est.train(_input_fn, steps=2) + self.assertSequenceEqual(self._read_vars(est.model_dir), (6, 1)) + + def testRaiseErrorIfNoIterator(self): + + def _input_fn(): + return constant_op.constant(1, dtype=dtypes.int64) + + est = estimator.Estimator(model_fn=self._model_fn) + + with self.assertRaises(ValueError): + est.train( + _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..9612ac5ae910f8ee08d4b3ed9097a5c80266fcfd --- /dev/null +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -0,0 +1,74 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for optimizing `tf.data` pipelines.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops + + +def optimize(optimizations=None): + """A transformation that applies optimizations. + + Args: + optimizations: (Optional.) A `tf.string` vector `tf.Tensor` identifying + optimizations to use. If not specified, the default set of optimizations + is applied. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return OptimizeDataset(dataset, optimizations) + + return _apply_fn + + +class OptimizeDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and applies optimizations.""" + + def __init__(self, input_dataset, optimizations): + """See `optimize()` for details.""" + super(OptimizeDataset, self).__init__() + self._input_dataset = input_dataset + if optimizations is None: + optimizations = [] + self._optimizations = ops.convert_to_tensor( + optimizations, dtype=dtypes.string, name="optimizations") + + def _as_variant_tensor(self): + return gen_dataset_ops.optimize_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._optimizations, + **dataset_ops.flat_structure(self)) + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index 96a9e9ed6649444dac5e56d7dd2fcdb62fc56459..e4c9f8b58a2a4390004b0ad318163526b443d44f 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -17,27 +17,38 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import warnings + from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops # TODO(rohanj): Add a python class that constructs resource in the __init__ # method and provides a get_next() that calls the prefetch op. def function_buffering_resource(string_arg, target_device, - shared_name, f, buffer_size, - thread_pool_size=1, container="", + shared_name=None, name=None): + if shared_name is None: + shared_name = "" return gen_dataset_ops.function_buffering_resource( string_arg=string_arg, target_device=target_device, shared_name=shared_name, f=f, buffer_size=buffer_size, - thread_pool_size=thread_pool_size, container=container, name=name) @@ -49,3 +60,266 @@ def function_buffering_resource_get_next(function_buffer_resource, function_buffer_resource=function_buffer_resource, output_types=output_types, name=name) + + +def function_buffering_resource_reset(function_buffer_resource, name=None): + return gen_dataset_ops.function_buffering_resource_reset( + function_buffer_resource=function_buffer_resource, name=name) + + +# pylint: disable=protected-access +class _PrefetchToDeviceIterator(object): + """A replacement for @{tf.data.Iterator} that prefetches to another device. + + Args: + input_dataset: The input dataset + one_shot: If true, we make a one shot iterator that's already initialized. + device: A fully specified device string where we want to prefetch to + buffer_size: Size of the prefetching buffer. + shared_name: (Optional.) If non-empty, the returned iterator will be + shared under the given name across multiple sessions that share the + same devices (e.g. when using a remote server). + + Returns: + An Iterator type object. + """ + + def __init__(self, + input_dataset, + one_shot, + device, + buffer_size, + shared_name=None): + self._input_dataset = input_dataset + self._get_next_call_count = 0 + self._one_shot = one_shot + if shared_name is None: + shared_name = "" + + if self._one_shot: + self._input_iterator = input_dataset.make_one_shot_iterator() + else: + self._input_iterator = iterator_ops.Iterator.from_structure( + self._input_dataset.output_types, self._input_dataset.output_shapes, + shared_name, self._input_dataset.output_classes) + input_iterator_handle = self._input_iterator.string_handle() + + @function.Defun(dtypes.string) + def _prefetch_fn(handle): + """Prefetches one element from `input_iterator`.""" + remote_iterator = iterator_ops.Iterator.from_string_handle( + handle, self._input_iterator.output_types, + self._input_iterator.output_shapes, + self._input_iterator.output_classes) + ret = remote_iterator.get_next() + return nest.flatten(sparse.serialize_sparse_tensors(ret)) + + iterator_device = gen_dataset_ops.iterator_get_device( + self._input_iterator._iterator_resource) + + with ops.device(device): + self._buffering_resource = function_buffering_resource( + f=_prefetch_fn, + target_device=iterator_device, + string_arg=input_iterator_handle, + buffer_size=buffer_size, + shared_name=shared_name) + + if not self._one_shot: + reset_op = function_buffering_resource_reset(self._buffering_resource) + with ops.control_dependencies([reset_op]): + self._initializer = self._input_iterator.make_initializer( + self._input_dataset) + + def get_next(self, name=None): + """See @{tf.data.Iterator.get_next}.""" + self._get_next_call_count += 1 + if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: + warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) + + flat_ret = gen_dataset_ops.function_buffering_resource_get_next( + self._buffering_resource, + output_types=nest.flatten(sparse.as_dense_types( + self.output_types, self.output_classes)), name=name) + + ret = sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self.output_types, flat_ret), + self.output_types, self.output_shapes, self.output_classes) + + for tensor, shape in zip( + nest.flatten(ret), nest.flatten(self.output_shapes)): + if isinstance(tensor, ops.Tensor): + tensor.set_shape(shape) + + return ret + + @property + def initializer(self): + if self._one_shot: + raise NotImplementedError("Can't initialize a one_shot_iterator") + return self._initializer + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + +class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator): + """A replacement for @{tf.data.Iterator} that prefetches to another device. + + Args: + input_dataset: The input dataset + one_shot: If true, we make a one shot iterator that's already initialized. + device: A fully specified device string where we want to prefetch to + buffer_size: Size of the prefetching buffer. + shared_name: (Optional.) If non-empty, the returned iterator will be + shared under the given name across multiple sessions that share the + same devices (e.g. when using a remote server). + + Returns: + An Iterator type object. + """ + + def __init__(self, + input_dataset, + device, + buffer_size): + with ops.device("/device:CPU:0"): + super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset) + input_iterator_handle = core_gen_dataset_ops.iterator_to_string_handle( + self._resource) + + self._device = device + + @function.Defun(dtypes.string) + def _prefetch_fn(handle): + """Prefetches one element from `input_iterator`.""" + remote_iterator = iterator_ops.Iterator.from_string_handle( + handle, self.output_types, self.output_shapes, self.output_classes) + ret = remote_iterator.get_next() + return nest.flatten(sparse.serialize_sparse_tensors(ret)) + + _prefetch_fn.add_to_graph(None) + + with ops.device(device): + self._buffering_resource = function_buffering_resource( + f=_prefetch_fn, + target_device=gen_dataset_ops.iterator_get_device(self._resource), + string_arg=input_iterator_handle, + buffer_size=buffer_size, + shared_name=iterator_ops._generate_shared_name( + "function_buffer_resource")) + + def _next_internal(self): + """Returns a nested structure of `tf.Tensor`s containing the next element. + """ + # This runs in sync mode as iterators use an error status to communicate + # that there is no more data to iterate over. + # TODO(b/77291417): Fix + with context.execution_mode(context.SYNC): + with ops.device(self._device): + ret = gen_dataset_ops.function_buffering_resource_get_next( + function_buffer_resource=self._buffering_resource, + output_types=self._flat_output_types) + return sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self._output_types, ret), self._output_types, + self._output_shapes, self._output_classes) +# pylint: enable=protected-access + + +class _PrefetchToDeviceDataset(dataset_ops.Dataset): + """A `Dataset` whose iterator prefetches elements to another device.""" + + def __init__(self, input_dataset, device, buffer_size): + self._input_dataset = input_dataset + self._device = device + self._buffer_size = buffer_size if buffer_size is not None else 1 + + # The static analysis cannot tell that the eager iterator's superclass has + # a `next()` method. + # pylint: disable=non-iterator-returned + def __iter__(self): + """Creates an `Iterator` for enumerating the elements of this dataset. + + The returned iterator implements the Python iterator protocol and therefore + can only be used in eager mode. + + Returns: + An `Iterator` over the elements of this dataset. + + Raises: + RuntimeError: If eager execution is enabled. + """ + if context.executing_eagerly(): + return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device, + self._buffer_size) + else: + raise RuntimeError("dataset.__iter__() is only supported when eager " + "execution is enabled.") + # pylint: enable=non-iterator-returned + + def make_one_shot_iterator(self): + if context.executing_eagerly(): + return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device, + self._buffer_size) + else: + return _PrefetchToDeviceIterator(self._input_dataset, one_shot=True, + device=self._device, + buffer_size=self._buffer_size) + + def make_initializable_iterator(self, shared_name=None): + return _PrefetchToDeviceIterator( + self._input_dataset, + one_shot=False, + device=self._device, + buffer_size=self._buffer_size, + shared_name=shared_name) + + def _as_variant_tensor(self): + # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset + # transformation methods is called. + # TODO(mrry): Investigate support for chaining further transformations after + # the prefetch, including GPU support. + raise NotImplementedError("`prefetch_to_device()` must be the last " + "transformation in a dataset pipeline.") + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +def prefetch_to_device(device, buffer_size=None): + """A transformation that prefetches dataset values to the given `device`. + + NOTE: Although the transformation creates a @{tf.data.Dataset}, the + transformation must be the final `Dataset` in the input pipeline. + + Args: + device: A string. The name of a device to which elements will be prefetched. + buffer_size: (Optional.) The number of elements to buffer on `device`. + Defaults to an automatically chosen value. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + def _apply_fn(dataset): + return _PrefetchToDeviceDataset(dataset, device, buffer_size) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py index 7d727165feabb101549567f28a2dfa07083de244..e670c4c8354f4067eb21c9b1fce708147c162967 100644 --- a/tensorflow/contrib/data/python/ops/random_ops.py +++ b/tensorflow/contrib/data/python/ops/random_ops.py @@ -18,12 +18,9 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse -from tensorflow.python.framework import constant_op +from tensorflow.python.data.util import random_seed from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops @@ -34,25 +31,13 @@ class RandomDataset(dataset_ops.Dataset): def __init__(self, seed=None): """A `Dataset` of pseudorandom values.""" super(RandomDataset, self).__init__() - seed, seed2 = random_seed.get_seed(seed) - if seed is None: - self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") - else: - self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") - if seed2 is None: - self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") - else: - self._seed2 = ops.convert_to_tensor( - seed2, dtype=dtypes.int64, name="seed2") + self._seed, self._seed2 = random_seed.get_seed(seed) def _as_variant_tensor(self): return gen_dataset_ops.random_dataset( seed=self._seed, seed2=self._seed2, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 57f30102778f3bac47580f9bdf94e411dfe1b621..83095c7ba1c6465d18490e5197f71bf7f1fe2497 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -17,20 +17,769 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections +import csv + +import numpy as np + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.contrib.data.python.ops import shuffle_ops +from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.lib.io import file_io from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile +from tensorflow.python.util import deprecation + +_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32, + dtypes.int64, dtypes.string) + + +def _is_valid_int32(str_val): + try: + # Checks equality to prevent int32 overflow + return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype( + str_val) + except (ValueError, OverflowError): + return False + + +def _is_valid_int64(str_val): + try: + dtypes.int64.as_numpy_dtype(str_val) + return True + except (ValueError, OverflowError): + return False + + +def _is_valid_float(str_val, float_dtype): + try: + return float_dtype.as_numpy_dtype(str_val) < np.inf + except ValueError: + return False + + +def _infer_type(str_val, na_value, prev_type): + """Given a string, infers its tensor type. + + Infers the type of a value by picking the least 'permissive' type possible, + while still allowing the previous type inference for this column to be valid. + + Args: + str_val: String value to infer the type of. + na_value: Additional string to recognize as a NA/NaN CSV value. + prev_type: Type previously inferred based on values of this column that + we've seen up till now. + Returns: + Inferred dtype. + """ + if str_val in ("", na_value): + # If the field is null, it gives no extra information about its type + return prev_type + + type_list = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string + ] # list of types to try, ordered from least permissive to most + + type_functions = [ + _is_valid_int32, + _is_valid_int64, + lambda str_val: _is_valid_float(str_val, dtypes.float32), + lambda str_val: _is_valid_float(str_val, dtypes.float64), + lambda str_val: True, + ] # Corresponding list of validation functions + + for i in range(len(type_list)): + validation_fn = type_functions[i] + if validation_fn(str_val) and (prev_type is None or + prev_type in type_list[:i + 1]): + return type_list[i] + + +def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header): + """Generator that yields rows of CSV file(s) in order.""" + for fn in filenames: + with file_io.FileIO(fn, "r") as f: + rdr = csv.reader( + f, + delimiter=field_delim, + quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE) + if header: + next(rdr) # Skip header lines + + for csv_row in rdr: + if len(csv_row) != num_cols: + raise ValueError( + "Problem inferring types: CSV row has different number of fields " + "than expected.") + yield csv_row + + +def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim, + na_value, header, num_rows_for_inference, + select_columns): + """Infers column types from the first N valid CSV records of files.""" + if select_columns is None: + select_columns = range(num_cols) + inferred_types = [None] * len(select_columns) + + for i, csv_row in enumerate( + _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)): + if num_rows_for_inference is not None and i >= num_rows_for_inference: + break + + for j, col_index in enumerate(select_columns): + inferred_types[j] = _infer_type(csv_row[col_index], na_value, + inferred_types[j]) + + # Replace None's with a default type + inferred_types = [t or dtypes.string for t in inferred_types] + # Default to 0 or '' for null values + return [ + constant_op.constant([0 if t is not dtypes.string else ""], dtype=t) + for t in inferred_types + ] + + +def _infer_column_names(filenames, field_delim, use_quote_delim): + """Infers column names from first rows of files.""" + csv_kwargs = { + "delimiter": field_delim, + "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE + } + with file_io.FileIO(filenames[0], "r") as f: + try: + column_names = next(csv.reader(f, **csv_kwargs)) + except StopIteration: + raise ValueError(("Received StopIteration when reading the header line " + "of %s. Empty file?") % filenames[0]) + + for name in filenames[1:]: + with file_io.FileIO(name, "r") as f: + try: + if next(csv.reader(f, **csv_kwargs)) != column_names: + raise ValueError( + "Files have different column names in the header row.") + except StopIteration: + raise ValueError(("Received StopIteration when reading the header line " + "of %s. Empty file?") % filenames[0]) + return column_names + + +def _get_sorted_col_indices(select_columns, column_names): + """Transforms select_columns argument into sorted column indices.""" + names_to_indices = {n: i for i, n in enumerate(column_names)} + num_cols = len(column_names) + for i, v in enumerate(select_columns): + if isinstance(v, int): + if v < 0 or v >= num_cols: + raise ValueError( + "Column index %d specified in select_columns out of valid range." % + v) + continue + if v not in names_to_indices: + raise ValueError( + "Value '%s' specified in select_columns not a valid column index or " + "name." % v) + select_columns[i] = names_to_indices[v] + + # Sort and ensure there are no duplicates + result = sorted(set(select_columns)) + if len(result) != len(select_columns): + raise ValueError("select_columns contains duplicate columns") + return result + + +def _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed): + """Optionally shuffle and repeat dataset, as requested.""" + if num_epochs != 1 and shuffle: + # Use shuffle_and_repeat for perf + return dataset.apply( + shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, + shuffle_seed)) + elif shuffle: + return dataset.shuffle(shuffle_buffer_size, shuffle_seed) + elif num_epochs != 1: + return dataset.repeat(num_epochs) + return dataset + + +def make_tf_record_dataset( + file_pattern, + batch_size, + parser_fn=None, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=None, + shuffle_seed=None, + prefetch_buffer_size=None, + num_parallel_reads=None, + num_parallel_parser_calls=None, + drop_final_batch=False): + """Reads and optionally parses TFRecord files into a dataset. + + Provides common functionality such as batching, optional parsing, shuffling, + and performant defaults. + + Args: + file_pattern: List of files or patterns of TFRecord file paths. + See @{tf.gfile.Glob} for pattern rules. + batch_size: An int representing the number of records to combine + in a single batch. + parser_fn: (Optional.) A function accepting string input to parse + and process the record contents. This function must map records + to components of a fixed shape, so they may be batched. By + default, uses the record contents unmodified. + num_epochs: (Optional.) An int specifying the number of times this + dataset is repeated. If None (the default), cycles through the + dataset forever. + shuffle: (Optional.) A bool that indicates whether the input + should be shuffled. Defaults to `True`. + shuffle_buffer_size: (Optional.) Buffer size to use for + shuffling. A large buffer size ensures better shuffling, but + increases memory usage and startup time. + shuffle_seed: (Optional.) Randomization seed to use for shuffling. + prefetch_buffer_size: (Optional.) An int specifying the number of + feature batches to prefetch for performance improvement. + Defaults to auto-tune. Set to 0 to disable prefetching. + num_parallel_reads: (Optional.) Number of threads used to read + records from files. By default or if set to a value >1, the + results will be interleaved. + num_parallel_parser_calls: (Optional.) Number of parallel + records to parse in parallel. Defaults to an automatic selection. + drop_final_batch: (Optional.) Whether the last batch should be + dropped in case its size is smaller than `batch_size`; the + default behavior is not to drop the smaller batch. + + Returns: + A dataset, where each element matches the output of `parser_fn` + except it will have an additional leading `batch-size` dimension, + or a `batch_size`-length 1-D tensor of strings if `parser_fn` is + unspecified. + """ + files = dataset_ops.Dataset.list_files( + file_pattern, shuffle=shuffle, seed=shuffle_seed) + + if num_parallel_reads is None: + # Note: We considered auto-tuning this value, but there is a concern + # that this affects the mixing of records from different files, which + # could affect training convergence/accuracy, so we are defaulting to + # a constant for now. + num_parallel_reads = 24 + dataset = core_readers.TFRecordDataset( + files, num_parallel_reads=num_parallel_reads) + + if shuffle_buffer_size is None: + # TODO(josh11b): Auto-tune this value when not specified + shuffle_buffer_size = 10000 + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + + if parser_fn is None: + if drop_final_batch: + dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) + else: + dataset = dataset.batch(batch_size) + else: + # TODO(josh11b): if num_parallel_parser_calls is None, use some function + # of num cores instead of map_and_batch's default behavior of one batch. + dataset = dataset.apply(batching.map_and_batch( + parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls, + drop_remainder=drop_final_batch)) + + if prefetch_buffer_size is None: + prefetch_buffer_size = -1 # tf.config.data.AUTOTUNE + if prefetch_buffer_size == 0: + return dataset + else: + return dataset.prefetch(buffer_size=prefetch_buffer_size) + + +def make_csv_dataset( + file_pattern, + batch_size, + column_names=None, + column_defaults=None, + label_name=None, + select_columns=None, + field_delim=",", + use_quote_delim=True, + na_value="", + header=True, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=10000, + shuffle_seed=None, + prefetch_buffer_size=1, + num_parallel_reads=1, + num_parallel_parser_calls=2, + sloppy=False, + num_rows_for_inference=100, +): + """Reads CSV files into a dataset. + + Reads CSV files into a dataset, where each element is a (features, labels) + tuple that corresponds to a batch of CSV rows. The features dictionary + maps feature column names to `Tensor`s containing the corresponding + feature data, and labels is a `Tensor` containing the batch's label data. + + Args: + file_pattern: List of files or patterns of file paths containing CSV + records. See @{tf.gfile.Glob} for pattern rules. + batch_size: An int representing the number of records to combine + in a single batch. + column_names: An optional list of strings that corresponds to the CSV + columns, in order. One per column of the input record. If this is not + provided, infers the column names from the first row of the records. + These names will be the keys of the features dict of each dataset element. + column_defaults: A optional list of default values for the CSV fields. One + item per selected column of the input record. Each item in the list is + either a valid CSV dtype (float32, float64, int32, int64, or string), or a + `Tensor` with one of the aforementioned types. The tensor can either be + a scalar default value (if the column is optional), or an empty tensor (if + the column is required). If a dtype is provided instead of a tensor, the + column is also treated as required. If this list is not provided, tries + to infer types based on reading the first num_rows_for_inference rows of + files specified, and assumes all columns are optional, defaulting to `0` + for numeric values and `""` for string values. If both this and + `select_columns` are specified, these must have the same lengths, and + `column_defaults` is assumed to be sorted in order of increasing column + index. + label_name: A optional string corresponding to the label column. If + provided, the data for this column is returned as a separate `Tensor` from + the features dictionary, so that the dataset complies with the format + expected by a `tf.Estimator.train` or `tf.Estimator.evaluate` input + function. + select_columns: An optional list of integer indices or string column + names, that specifies a subset of columns of CSV data to select. If + column names are provided, these must correspond to names provided in + `column_names` or inferred from the file header lines. When this argument + is specified, only a subset of CSV columns will be parsed and returned, + corresponding to the columns specified. Using this results in faster + parsing and lower memory usage. If both this and `column_defaults` are + specified, these must have the same lengths, and `column_defaults` is + assumed to be sorted in order of increasing column index. + field_delim: An optional `string`. Defaults to `","`. Char delimiter to + separate fields in a record. + use_quote_delim: An optional bool. Defaults to `True`. If false, treats + double quotation marks as regular characters inside of the string fields. + na_value: Additional string to recognize as NA/NaN. + header: A bool that indicates whether the first rows of provided CSV files + correspond to header lines with column names, and should not be included + in the data. + num_epochs: An int specifying the number of times this dataset is repeated. + If None, cycles through the dataset forever. + shuffle: A bool that indicates whether the input should be shuffled. + shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size + ensures better shuffling, but increases memory usage and startup time. + shuffle_seed: Randomization seed to use for shuffling. + prefetch_buffer_size: An int specifying the number of feature batches to + prefetch for performance improvement. Recommended value is the number of + batches consumed per training step. + num_parallel_reads: Number of threads used to read CSV records from files. + If >1, the results will be interleaved. + num_parallel_parser_calls: Number of parallel invocations of the CSV parsing + function on CSV records. + sloppy: If `True`, reading performance will be improved at + the cost of non-deterministic ordering. If `False`, the order of elements + produced is deterministic prior to shuffling (elements are still + randomized if `shuffle=True`. Note that if the seed is set, then order + of elements after shuffling is deterministic). Defaults to `False`. + num_rows_for_inference: Number of rows of a file to use for type inference + if record_defaults is not provided. If None, reads all the rows of all + the files. Defaults to 100. + + Returns: + A dataset, where each element is a (features, labels) tuple that corresponds + to a batch of `batch_size` CSV rows. The features dictionary maps feature + column names to `Tensor`s containing the corresponding column data, and + labels is a `Tensor` containing the column data for the label column + specified by `label_name`. + + Raises: + ValueError: If any of the arguments is malformed. + """ + # Create dataset of all matching filenames + filenames = _get_file_names(file_pattern, False) + dataset = dataset_ops.Dataset.from_tensor_slices(filenames) + if shuffle: + dataset = dataset.shuffle(len(filenames), shuffle_seed) + + # Clean arguments; figure out column names and defaults + + if column_names is None: + if not header: + raise ValueError("Cannot infer column names without a header line.") + # If column names are not provided, infer from the header lines + column_names = _infer_column_names(filenames, field_delim, use_quote_delim) + if len(column_names) != len(set(column_names)): + raise ValueError("Cannot have duplicate column names.") + + if select_columns is not None: + select_columns = _get_sorted_col_indices(select_columns, column_names) + + if column_defaults is not None: + column_defaults = [ + constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x + for x in column_defaults + ] + else: + # If column defaults are not provided, infer from records at graph + # construction time + column_defaults = _infer_column_defaults( + filenames, len(column_names), field_delim, use_quote_delim, na_value, + header, num_rows_for_inference, select_columns) + + if select_columns is not None and len(column_defaults) != len(select_columns): + raise ValueError( + "If specified, column_defaults and select_columns must have same " + "length." + ) + if select_columns is not None and len(column_names) > len(select_columns): + # Pick the relevant subset of column names + column_names = [column_names[i] for i in select_columns] + + if label_name is not None and label_name not in column_names: + raise ValueError("`label_name` provided must be one of the columns.") + + def filename_to_dataset(filename): + return CsvDataset( + filename, + record_defaults=column_defaults, + field_delim=field_delim, + use_quote_delim=use_quote_delim, + na_value=na_value, + select_cols=select_columns, + header=header) + + def map_fn(*columns): + """Organizes columns into a features dictionary. + + Args: + *columns: list of `Tensor`s corresponding to one csv record. + Returns: + An OrderedDict of feature names to values for that particular record. If + label_name is provided, extracts the label feature to be returned as the + second element of the tuple. + """ + features = collections.OrderedDict(zip(column_names, columns)) + if label_name is not None: + label = features.pop(label_name) + return features, label + return features + + # Read files sequentially (if num_parallel_reads=1) or in parallel + dataset = dataset.apply( + interleave_ops.parallel_interleave( + filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy)) + + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + + # Apply batch before map for perf, because map has high overhead relative + # to the size of the computation in each map + dataset = dataset.batch(batch_size=batch_size) + dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls) + dataset = dataset.prefetch(prefetch_buffer_size) + + return dataset + + +_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB + + +class CsvDataset(dataset_ops.Dataset): + """A Dataset comprising lines from one or more CSV files.""" + + def __init__(self, + filenames, + record_defaults, + buffer_size=None, + header=False, + field_delim=",", + use_quote_delim=True, + na_value="", + select_cols=None): + """Creates a `CsvDataset` by reading and decoding CSV files. + + The elements of this dataset correspond to records from the file(s). + RFC 4180 format is expected for CSV files + (https://tools.ietf.org/html/rfc4180) + Note that we allow leading and trailing spaces with int or float field. + + + For example, suppose we have a file 'my_file0.csv' with four CSV columns of + different data types: + ``` + abcdefg,4.28E10,5.55E6,12 + hijklmn,-5.3E14,,2 + ``` + + We can construct a CsvDataset from it as follows: + ```python + dataset = tf.contrib.data.CsvDataset( + "my_file*.csv", + [tf.float32, # Required field, use dtype or empty tensor + tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0 + tf.int32, # Required field, use dtype or empty tensor + ], + select_cols=[1,2,3] # Only parse last three columns + ) + ``` + + The expected output of its iterations is: + ```python + next = dataset.make_one_shot_iterator().get_next() + with tf.Session() as sess: + while True: + try: + print(sess.run(nxt)) + except tf.errors.OutOfRangeError: + break + + >> (4.28e10, 5.55e6, 12) + >> (-5.3e14, 0.0, 2) + ``` + Args: + filenames: A `tf.string` tensor containing one or more filenames. + record_defaults: A list of default values for the CSV fields. Each item in + the list is either a valid CSV `DType` (float32, float64, int32, int64, + string), or a `Tensor` object with one of the above types. One per + column of CSV data, with either a scalar `Tensor` default value for the + column if it is optional, or `DType` or empty `Tensor` if required. If + both this and `select_columns` are specified, these must have the same + lengths, and `column_defaults` is assumed to be sorted in order of + increasing column index. + buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes + to buffer while reading files. Defaults to 4MB. + header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s) + have header line(s) that should be skipped when parsing. Defaults to + `False`. + field_delim: (Optional.) A `tf.string` scalar containing the delimiter + character that separates fields in a record. Defaults to `","`. + use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats + double quotation marks as regular characters inside of string fields + (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`. + na_value: (Optional.) A `tf.string` scalar indicating a value that will + be treated as NA/NaN. + select_cols: (Optional.) A sorted list of column indices to select from + the input data. If specified, only this subset of columns will be + parsed. Defaults to parsing all columns. + """ + super(CsvDataset, self).__init__() + self._filenames = ops.convert_to_tensor( + filenames, dtype=dtypes.string, name="filenames") + record_defaults = [ + constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x + for x in record_defaults + ] + self._record_defaults = ops.convert_n_to_tensor( + record_defaults, name="record_defaults") + self._buffer_size = convert.optional_param_to_tensor( + "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES) + self._header = ops.convert_to_tensor( + header, dtype=dtypes.bool, name="header") + self._field_delim = ops.convert_to_tensor( + field_delim, dtype=dtypes.string, name="field_delim") + self._use_quote_delim = ops.convert_to_tensor( + use_quote_delim, dtype=dtypes.bool, name="use_quote_delim") + self._na_value = ops.convert_to_tensor( + na_value, dtype=dtypes.string, name="na_value") + self._select_cols = convert.optional_param_to_tensor( + "select_cols", + select_cols, + argument_default=[], + argument_dtype=dtypes.int64, + ) + self._output_shapes = tuple( + tensor_shape.scalar() for _ in range(len(record_defaults))) + self._output_types = tuple(d.dtype for d in self._record_defaults) + self._output_classes = tuple( + ops.Tensor for _ in range(len(record_defaults))) + + def _as_variant_tensor(self): + # Constructs graph node for the dataset op. + return contrib_gen_dataset_ops.csv_dataset( + filenames=self._filenames, + record_defaults=self._record_defaults, + buffer_size=self._buffer_size, + header=self._header, + output_shapes=self._output_shapes, + field_delim=self._field_delim, + use_quote_delim=self._use_quote_delim, + na_value=self._na_value, + select_cols=self._select_cols, + ) + @property + def output_types(self): + return self._output_types + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_classes(self): + return self._output_classes + + +def make_batched_features_dataset(file_pattern, + batch_size, + features, + reader=core_readers.TFRecordDataset, + reader_args=None, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=10000, + shuffle_seed=None, + prefetch_buffer_size=1, + reader_num_threads=1, + parser_num_threads=2, + sloppy_ordering=False, + drop_final_batch=False): + """Returns a `Dataset` of feature dictionaries from `Example` protos. + + Example: + + ``` + serialized_examples = [ + features { + feature { key: "age" value { int64_list { value: [ 0 ] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } + feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } } + }, + features { + feature { key: "age" value { int64_list { value: [] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } + feature { key: "kws" value { bytes_list { value: [ "sports" ] } } } + } + ] + ``` + + We can use arguments: + + ``` + features: { + "age": FixedLenFeature([], dtype=tf.int64, default_value=-1), + "gender": FixedLenFeature([], dtype=tf.string), + "kws": VarLenFeature(dtype=tf.string), + } + ``` + + And the expected output is: + + ```python + { + "age": [[0], [-1]], + "gender": [["f"], ["f"]], + "kws": SparseTensor( + indices=[[0, 0], [0, 1], [1, 0]], + values=["code", "art", "sports"] + dense_shape=[2, 2]), + } + ``` + + Args: + file_pattern: List of files or patterns of file paths containing + `Example` records. See `tf.gfile.Glob` for pattern rules. + batch_size: An int representing the number of records to combine + in a single batch. + features: A `dict` mapping feature keys to `FixedLenFeature` or + `VarLenFeature` values. See `tf.parse_example`. + reader: A function or class that can be + called with a `filenames` tensor and (optional) `reader_args` and returns + a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`. + reader_args: Additional arguments to pass to the reader class. + num_epochs: Integer specifying the number of times to read through the + dataset. If None, cycles through the dataset forever. Defaults to `None`. + shuffle: A boolean, indicates whether the input should be shuffled. Defaults + to `True`. + shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity + ensures better shuffling but would increase memory usage and startup time. + shuffle_seed: Randomization seed to use for shuffling. + prefetch_buffer_size: Number of feature batches to prefetch in order to + improve performance. Recommended value is the number of batches consumed + per training step (default is 1). + reader_num_threads: Number of threads used to read `Example` records. If >1, + the results will be interleaved. + parser_num_threads: Number of threads to use for parsing `Example` tensors + into a dictionary of `Feature` tensors. + sloppy_ordering: If `True`, reading performance will be improved at + the cost of non-deterministic ordering. If `False`, the order of elements + produced is deterministic prior to shuffling (elements are still + randomized if `shuffle=True`. Note that if the seed is set, then order + of elements after shuffling is deterministic). Defaults to `False`. + drop_final_batch: If `True`, and the batch size does not evenly divide the + input dataset size, the final smaller batch will be dropped. Defaults to + `False`. + + Returns: + A dataset of `dict` elements. Each `dict` maps feature keys to + `Tensor` or `SparseTensor` objects. + """ + # Create dataset of all matching filenames + filenames = _get_file_names(file_pattern, False) + dataset = dataset_ops.Dataset.from_tensor_slices(filenames) + if shuffle: + dataset = dataset.shuffle(len(filenames), shuffle_seed) + + # Read `Example` records from files as tensor objects. + if reader_args is None: + reader_args = [] + + # Read files sequentially (if reader_num_threads=1) or in parallel + dataset = dataset.apply( + interleave_ops.parallel_interleave( + lambda filename: reader(filename, *reader_args), + cycle_length=reader_num_threads, + sloppy=sloppy_ordering)) + + # Extract values if the `Example` tensors are stored as key-value tuples. + if dataset.output_types == (dtypes.string, dtypes.string): + dataset = dataset.map(lambda _, v: v) + + # Apply dataset repeat and shuffle transformations. + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + + dataset = dataset.apply(stats_ops.feature_stats("record_stats")) + + if drop_final_batch: + dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) + else: + dataset = dataset.batch(batch_size) + + # Parse `Example` tensors to a dictionary of `Feature` tensors. + dataset = dataset.map( + lambda x: parsing_ops.parse_example(x, features), + num_parallel_calls=parser_num_threads) + + # TODO(rachelim): Add an optional label_name argument for extracting the label + # from the features dictionary, to comply with the type expected by the + # input_fn to a `tf.Estimator.train` or `tf.Estimator.evaluate` function. + dataset = dataset.prefetch(prefetch_buffer_size) + return dataset + + +@deprecation.deprecated(None, + "Use `tf.contrib.data.make_batched_features_dataset`") def read_batch_features(file_pattern, batch_size, features, - reader, + reader=core_readers.TFRecordDataset, reader_args=None, randomize_input=True, num_epochs=None, @@ -80,47 +829,42 @@ def read_batch_features(file_pattern, Args: file_pattern: List of files or patterns of file paths containing `Example` records. See `tf.gfile.Glob` for pattern rules. - batch_size: An int representing the number of consecutive elements of this - dataset to combine in a single batch. + batch_size: An int representing the number of records to combine + in a single batch. features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. See `tf.parse_example`. - reader: A function or class that can be called with a `filenames` tensor - and (optional) `reader_args` and returns a `Dataset` of Examples. + reader: A function or class that can be + called with a `filenames` tensor and (optional) `reader_args` and returns + a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`. reader_args: Additional arguments to pass to the reader class. randomize_input: Whether the input should be randomized. num_epochs: Integer specifying the number of times to read through the dataset. If None, cycles through the dataset forever. - capacity: Capacity of the ShuffleDataset. A large capacity ensures better + capacity: Buffer size of the ShuffleDataset. A large capacity ensures better shuffling but would increase memory usage and startup time. - Returns: A dict from keys in features to `Tensor` or `SparseTensor` objects. """ - filenames = _get_file_names(file_pattern, randomize_input) - if reader_args: - dataset = reader(filenames, *reader_args) - else: - dataset = reader(filenames) - if dataset.output_types == (dtypes.string, dtypes.string): - dataset = dataset.map(lambda _, v: v) - if num_epochs != 1: - dataset = dataset.repeat(num_epochs) - if randomize_input: - dataset = dataset.shuffle(capacity) - dataset = dataset.batch(batch_size) - dataset = dataset.map(lambda x: parsing_ops.parse_example(x, features)) - dataset = dataset.prefetch(1) + dataset = make_batched_features_dataset( + file_pattern, + batch_size, + features, + reader=reader, + reader_args=reader_args, + shuffle=randomize_input, + num_epochs=num_epochs, + shuffle_buffer_size=capacity) iterator = dataset.make_one_shot_iterator() outputs = iterator.get_next() return outputs -def _get_file_names(file_pattern, randomize_input): +def _get_file_names(file_pattern, shuffle): """Parse list of file names from pattern, optionally shuffled. Args: file_pattern: File glob pattern, or list of glob patterns. - randomize_input: Whether to shuffle the order of file names. + shuffle: Whether to shuffle the order of file names. Returns: List of file names matching `file_pattern`. @@ -141,7 +885,7 @@ def _get_file_names(file_pattern, randomize_input): raise ValueError("No files match %s." % file_pattern) # Sort files so it will be deterministic for unit tests. - if not randomize_input: + if not shuffle: file_names = sorted(file_names) return file_names diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py index 56f526a330bfbea7305b0754bfd114c5e97db506..182a5c6ff36fcda8c9e2c522cce07bed0c2daec9 100644 --- a/tensorflow/contrib/data/python/ops/resampling.py +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -20,10 +20,12 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.data.python.ops import scan_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import logging_ops @@ -50,78 +52,182 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): A `Dataset` transformation function, which can be passed to @{tf.data.Dataset.apply}. """ - def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - dist_estimation_batch_size = 32 - target_dist_t = ops.convert_to_tensor(target_dist, name="initial_dist") + target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist") class_values_ds = dataset.map(class_func) + + # Get initial distribution. if initial_dist is not None: initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") - acceptance_dist = _calculate_acceptance_probs(initial_dist_t, - target_dist_t) + acceptance_dist, prob_of_original = ( + _calculate_acceptance_probs_with_mixing(initial_dist_t, + target_dist_t)) initial_dist_ds = dataset_ops.Dataset.from_tensors( initial_dist_t).repeat() acceptance_dist_ds = dataset_ops.Dataset.from_tensors( acceptance_dist).repeat() + prob_of_original_ds = dataset_ops.Dataset.from_tensors( + prob_of_original).repeat() + else: + initial_dist_ds = _estimate_initial_dist_ds( + target_dist_t, class_values_ds) + acceptance_and_original_prob_ds = initial_dist_ds.map( + lambda initial: _calculate_acceptance_probs_with_mixing( + initial, target_dist_t)) + acceptance_dist_ds = acceptance_and_original_prob_ds.map( + lambda accept_prob, _: accept_prob) + prob_of_original_ds = acceptance_and_original_prob_ds.map( + lambda _, prob_original: prob_original) + filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, + class_values_ds, seed) + # Prefetch filtered dataset for speed. + filtered_ds = filtered_ds.prefetch(3) + + prob_original_static = _get_prob_original_static( + initial_dist_t, target_dist_t) if initial_dist is not None else None + if prob_original_static == 1: + return dataset_ops.Dataset.zip((class_values_ds, dataset)) + elif prob_original_static == 0: + return filtered_ds else: - num_classes = (target_dist_t.shape[0].value or - array_ops.shape(target_dist_t)[0]) - smoothing_constant = 10 - initial_examples_per_class_seen = array_ops.fill( - [num_classes], np.int64(smoothing_constant)) - - def update_estimate_and_tile(num_examples_per_class_seen, c): - updated_examples_per_class_seen, dist = _estimate_data_distribution( - c, num_examples_per_class_seen) - tiled_dist = array_ops.tile( - array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) - return updated_examples_per_class_seen, tiled_dist - - initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size) - .apply(scan_ops.scan(initial_examples_per_class_seen, - update_estimate_and_tile)) - .apply(batching.unbatch())) - acceptance_dist_ds = initial_dist_ds.map( - lambda initial: _calculate_acceptance_probs(initial, target_dist_t)) - - def maybe_warn_on_large_rejection(accept_dist, initial_dist): - proportion_rejected = math_ops.reduce_sum( - (1 - accept_dist) * initial_dist) - return control_flow_ops.cond( - math_ops.less(proportion_rejected, .5), - lambda: accept_dist, - lambda: logging_ops.Print( # pylint: disable=g-long-lambda - accept_dist, [proportion_rejected, initial_dist, accept_dist], - message="Proportion of examples rejected by sampler is high: ", - summarize=100, - first_n=10)) - - acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds, - initial_dist_ds)) - .map(maybe_warn_on_large_rejection)) - - current_probabilities_ds = dataset_ops.Dataset.zip( - (acceptance_dist_ds, class_values_ds)).map(array_ops.gather) - filtered_ds = ( - dataset_ops.Dataset.zip((class_values_ds, current_probabilities_ds, - dataset)) - .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) - return filtered_ds.map(lambda class_value, _, data: (class_value, data)) + return interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds], + weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]), + seed=seed) return _apply_fn -def _calculate_acceptance_probs(initial_probs, target_probs): - """Calculate the per-class acceptance rates. +def _get_prob_original_static(initial_dist_t, target_dist_t): + """Returns the static probability of sampling from the original. + + `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters + an Op that it isn't defined for. We have some custom logic to avoid this. + + Args: + initial_dist_t: A tensor of the initial distribution. + target_dist_t: A tensor of the target distribution. + + Returns: + The probability of sampling from the original distribution as a constant, + if it is a constant, or `None`. + """ + init_static = tensor_util.constant_value(initial_dist_t) + target_static = tensor_util.constant_value(target_dist_t) + + if init_static is None or target_static is None: + return None + else: + return np.min(target_static / init_static) + + +def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds, + seed): + """Filters a dataset based on per-class acceptance probabilities. Args: - initial_probs: The class probabilities of the data. - target_probs: The desired class proportion in minibatches. + dataset: The dataset to be filtered. + acceptance_dist_ds: A dataset of acceptance probabilities. + initial_dist_ds: A dataset of the initial probability distribution, given or + estimated. + class_values_ds: A dataset of the corresponding classes. + seed: (Optional.) Python integer seed for the resampler. + Returns: - A list of the per-class acceptance probabilities. + A dataset of (class value, data) after filtering. + """ + def maybe_warn_on_large_rejection(accept_dist, initial_dist): + proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist) + return control_flow_ops.cond( + math_ops.less(proportion_rejected, .5), + lambda: accept_dist, + lambda: logging_ops.Print( # pylint: disable=g-long-lambda + accept_dist, [proportion_rejected, initial_dist, accept_dist], + message="Proportion of examples rejected by sampler is high: ", + summarize=100, + first_n=10)) + + acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds, + initial_dist_ds)) + .map(maybe_warn_on_large_rejection)) + + def _gather_and_copy(class_val, acceptance_prob, data): + return class_val, array_ops.gather(acceptance_prob, class_val), data + + current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip( + (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy) + filtered_ds = ( + current_probabilities_and_class_and_data_ds + .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) + return filtered_ds.map(lambda class_value, _, data: (class_value, data)) + + +def _estimate_initial_dist_ds( + target_dist_t, class_values_ds, dist_estimation_batch_size=32, + smoothing_constant=10): + num_classes = (target_dist_t.shape[0].value or + array_ops.shape(target_dist_t)[0]) + initial_examples_per_class_seen = array_ops.fill( + [num_classes], np.int64(smoothing_constant)) + + def update_estimate_and_tile(num_examples_per_class_seen, c): + updated_examples_per_class_seen, dist = _estimate_data_distribution( + c, num_examples_per_class_seen) + tiled_dist = array_ops.tile( + array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) + return updated_examples_per_class_seen, tiled_dist - This method is based on solving the following analysis: + initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size) + .apply(scan_ops.scan(initial_examples_per_class_seen, + update_estimate_and_tile)) + .apply(batching.unbatch())) + + return initial_dist_ds + + +def _get_target_to_initial_ratio(initial_probs, target_probs): + # Add tiny to initial_probs to avoid divide by zero. + denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny) + return target_probs / denom + + +def _estimate_data_distribution(c, num_examples_per_class_seen): + """Estimate data distribution as labels are seen. + + Args: + c: The class labels. Type `int32`, shape `[batch_size]`. + num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, + containing counts. + + Returns: + num_examples_per_lass_seen: Updated counts. Type `int64`, shape + `[num_classes]`. + dist: The updated distribution. Type `float32`, shape `[num_classes]`. + """ + num_classes = num_examples_per_class_seen.get_shape()[0].value + # Update the class-count based on what labels are seen in batch. + num_examples_per_class_seen = math_ops.add( + num_examples_per_class_seen, math_ops.reduce_sum( + array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) + init_prob_estimate = math_ops.truediv( + num_examples_per_class_seen, + math_ops.reduce_sum(num_examples_per_class_seen)) + dist = math_ops.cast(init_prob_estimate, dtypes.float32) + return num_examples_per_class_seen, dist + + +def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs): + """Calculates the acceptance probabilities and mixing ratio. + + In this case, we assume that we can *either* sample from the original data + distribution with probability `m`, or sample from a reshaped distribution + that comes from rejection sampling on the original distribution. This + rejection sampling is done on a per-class basis, with `a_i` representing the + probability of accepting data from class `i`. + + This method is based on solving the following analysis for the reshaped + distribution: Let F be the probability of a rejection (on any example). Let p_i be the proportion of examples in the data in class i (init_probs) @@ -150,39 +256,39 @@ def _calculate_acceptance_probs(initial_probs, target_probs): 0 <= t_i <= 1, sum_i(t_i) = 1 ``` - - A solution for a_i in terms of the other variabes is the following: + A solution for a_i in terms of the other variables is the following: ```a_i = (t_i / p_i) / max_i[t_i / p_i]``` - """ - # Add tiny to initial_probs to avoid divide by zero. - denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny) - ratio_l = target_probs / denom - # Calculate list of acceptance probabilities. - max_ratio = math_ops.reduce_max(ratio_l) - return ratio_l / max_ratio + If we try to minimize the amount of data rejected, we get the following: + M_max = max_i [ t_i / p_i ] + M_min = min_i [ t_i / p_i ] -def _estimate_data_distribution(c, num_examples_per_class_seen): - """Estimate data distribution as labels are seen. + The desired probability of accepting data if it comes from class `i`: + + a_i = (t_i/p_i - m) / (M_max - m) + + The desired probability of pulling a data element from the original dataset, + rather than the filtered one: + + m = M_min Args: - c: The class labels. Type `int32`, shape `[batch_size]`. - num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, - containing counts. + initial_probs: A Tensor of the initial probability distribution, given or + estimated. + target_probs: A Tensor of the corresponding classes. Returns: - num_examples_per_lass_seen: Updated counts. Type `int64`, shape - `[num_classes]`. - dist: The updated distribution. Type `float32`, shape `[num_classes]`. + (A 1D Tensor with the per-class acceptance probabilities, the desired + probability of pull from the original distribution.) """ - num_classes = num_examples_per_class_seen.get_shape()[0].value - # Update the class-count based on what labels are seen in batch. - num_examples_per_class_seen = math_ops.add( - num_examples_per_class_seen, math_ops.reduce_sum( - array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) - init_prob_estimate = math_ops.truediv( - num_examples_per_class_seen, - math_ops.reduce_sum(num_examples_per_class_seen)) - dist = math_ops.cast(init_prob_estimate, dtypes.float32) - return num_examples_per_class_seen, dist + ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs) + max_ratio = math_ops.reduce_max(ratio_l) + min_ratio = math_ops.reduce_min(ratio_l) + + # Target prob to sample from original distribution. + m = min_ratio + + # TODO(joelshor): Simplify fraction, if possible. + a_i = (ratio_l - m) / (max_ratio - m) + return a_i, m diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index 1c88366273f5d186509454188e02350d4ea9f66b..67eede981cb8f685ba4840d1f3c12bfea54c7646 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -24,6 +24,7 @@ from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import gen_dataset_ops @@ -36,18 +37,22 @@ class _ScanDataset(dataset_ops.Dataset): self._input_dataset = input_dataset with ops.name_scope("initial_state"): + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. self._initial_state = nest.pack_sequence_as(initial_state, [ - ops.convert_to_tensor(t, name="component_%d" % i) + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor( + t, name="component_%d" % i) for i, t in enumerate(nest.flatten(initial_state)) ]) - # Compute initial values for the state shapes and types based on - # the initial state. These will be refined by running - # `tf_scan_func` one or more times below. - # TODO(b/68937811): Allow the initial state to be a tf.SparseTensor. + # Compute initial values for the state classes, shapes and types based on + # the initial state. The shapes may be refined by running `tf_scan_func` one + # or more times below. + self._state_classes = sparse.get_classes(self._initial_state) self._state_shapes = nest.pack_sequence_as( self._initial_state, - [t.shape for t in nest.flatten(self._initial_state)]) + [t.get_shape() for t in nest.flatten(self._initial_state)]) self._state_types = nest.pack_sequence_as( self._initial_state, [t.dtype for t in nest.flatten(self._initial_state)]) @@ -57,72 +62,109 @@ class _ScanDataset(dataset_ops.Dataset): self._output_shapes = None self._output_types = None - # Iteratively rerun the scan function until reaching a fixed pont on + # Iteratively rerun the scan function until reaching a fixed point on # `self._state_shapes`. need_to_rerun = True while need_to_rerun: - flat_state_shapes = nest.flatten(self._state_shapes) - flat_state_types = nest.flatten(self._state_types) - - # Create a list in which `tf_scan_func` will store the s + # Create a list in which `tf_scan_func` will store the new shapes. flat_new_state_shapes = [] - @function.Defun(*(flat_state_types + nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes)))) + @function.Defun(*(nest.flatten( + sparse.as_dense_types( + self._state_types, self._state_classes)) + nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes)))) def tf_scan_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the state and input_dataset. - # TODO(b/69424092): Check that neither inputs nor outputs are sparse. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, - flat_state_shapes + nest.flatten(dense_shapes)): + for arg, shape in zip( + args, + nest.flatten( + sparse.as_dense_shapes(self._state_shapes, self._state_classes)) + + nest.flatten( + sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes))): arg.set_shape(shape) - pivot = len(flat_state_shapes) - old_state = nest.pack_sequence_as(self._initial_state, args[:pivot]) - input_value = nest.pack_sequence_as(input_dataset.output_types, - args[pivot:]) - - ret = scan_func(old_state, input_value) + pivot = len(nest.flatten(self._state_shapes)) + print(self._state_classes) + nested_state_args = nest.pack_sequence_as(self._state_types, + args[:pivot]) + nested_state_args = sparse.deserialize_sparse_tensors( + nested_state_args, self._state_types, self._state_shapes, + self._state_classes) + print(input_dataset.output_classes) + nested_input_args = nest.pack_sequence_as(input_dataset.output_types, + args[pivot:]) + nested_input_args = sparse.deserialize_sparse_tensors( + nested_input_args, input_dataset.output_types, + input_dataset.output_shapes, input_dataset.output_classes) + + ret = scan_func(nested_state_args, nested_input_args) if not isinstance(ret, collections.Sequence) or len(ret) != 2: raise TypeError("The scan function must return a pair comprising the " "new state and the output value.") + + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) + ]) new_state, output_value = ret - flat_new_state = [ - ops.convert_to_tensor(t) for t in nest.flatten(new_state) - ] - flat_output_value = [ - ops.convert_to_tensor(t) for t in nest.flatten(output_value) - ] + # Extract and validate class information from the returned values. + for t, clazz in zip( + nest.flatten(new_state), nest.flatten(self._state_classes)): + if not isinstance(t, clazz): + raise TypeError( + "The element classes for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_classes, + nest.pack_sequence_as( + self._state_types, + [type(t) for t in nest.flatten(new_state)]))) + self._output_classes = sparse.get_classes(output_value) # Extract shape information from the returned values. - flat_new_state_shapes.extend([t.shape for t in flat_new_state]) + flat_new_state_shapes.extend( + [t.get_shape() for t in nest.flatten(new_state)]) self._output_shapes = nest.pack_sequence_as( - output_value, [t.shape for t in flat_output_value]) + output_value, [t.get_shape() for t in nest.flatten(output_value)]) # Extract and validate type information from the returned values. - for t, dtype in zip(flat_new_state, flat_state_types): + for t, dtype in zip( + nest.flatten(new_state), nest.flatten(self._state_types)): if t.dtype != dtype: raise TypeError( "The element types for the new state must match the initial " "state. Expected %s; got %s." % - (self._state_types, nest.pack_sequence_as( - self._state_types, [t.dtype for t in flat_new_state]))) - self._output_classes = nest.pack_sequence_as( - output_value, [ops.Tensor for _ in flat_output_value]) + (self._state_types, + nest.pack_sequence_as( + self._state_types, + [t.dtype for t in nest.flatten(new_state)]))) self._output_types = nest.pack_sequence_as( - output_value, [t.dtype for t in flat_output_value]) + output_value, [t.dtype for t in nest.flatten(output_value)]) - return flat_new_state + flat_output_value + dataset_ops._warn_if_collections("tf.contrib.data.scan()") # pylint: disable=protected-access + + # Serialize any sparse tensors. + new_state = nest.pack_sequence_as(new_state, [ + t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state)) + ]) + output_value = nest.pack_sequence_as(output_value, [ + t for t in nest.flatten( + sparse.serialize_sparse_tensors(output_value)) + ]) + return nest.flatten(new_state) + nest.flatten(output_value) # Use the private method that will execute `tf_scan_func` but delay # adding it to the graph in case we need to rerun the function. tf_scan_func._create_definition_if_needed() # pylint: disable=protected-access + flat_state_shapes = nest.flatten(self._state_shapes) weakened_state_shapes = [ original.most_specific_compatible_shape(new) for original, new in zip(flat_state_shapes, flat_new_state_shapes) @@ -144,18 +186,16 @@ class _ScanDataset(dataset_ops.Dataset): weakened_state_shapes) self._scan_func = tf_scan_func + self._scan_func.add_to_graph(ops.get_default_graph()) def _as_variant_tensor(self): input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access return gen_dataset_ops.scan_dataset( input_t, - nest.flatten(self._initial_state), + nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)), self._scan_func.captured_inputs, f=self._scan_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py index 99bb79bc06a421f811869ca9169aaa11deaca2f3..d7f8a73fe3d67bb83e44e962832ce34c116aef66 100644 --- a/tensorflow/contrib/data/python/ops/shuffle_ops.py +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -18,12 +18,10 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse +from tensorflow.python.data.util import random_seed from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed from tensorflow.python.ops import gen_dataset_ops @@ -45,17 +43,7 @@ class _ShuffleAndRepeatDataset(dataset_ops.Dataset): else: self._count = ops.convert_to_tensor( count, dtype=dtypes.int64, name="count") - - seed, seed2 = random_seed.get_seed(seed) - if seed is None: - self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") - else: - self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") - if seed2 is None: - self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") - else: - self._seed2 = ops.convert_to_tensor( - seed2, dtype=dtypes.int64, name="seed2") + self._seed, self._seed2 = random_seed.get_seed(seed) def _as_variant_tensor(self): # pylint: disable=protected-access @@ -66,10 +54,7 @@ class _ShuffleAndRepeatDataset(dataset_ops.Dataset): count=self._count, seed=self._seed, seed2=self._seed2, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) # pylint: enable=protected-access @property diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py new file mode 100644 index 0000000000000000000000000000000000000000..f935beb1a9e85d4901857e7781a5ed8473838fa5 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Sliding dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +class _SlideDataset(dataset_ops.Dataset): + """A `Dataset` that passes a sliding window over its input.""" + + def __init__(self, input_dataset, window_size, stride=1): + """See `sliding_window_batch` for details.""" + super(_SlideDataset, self).__init__() + self._input_dataset = input_dataset + self._window_size = ops.convert_to_tensor( + window_size, dtype=dtypes.int64, name="window_size") + self._stride = ops.convert_to_tensor( + stride, dtype=dtypes.int64, name="stride") + + def _as_variant_tensor(self): + return gen_dataset_ops.slide_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + window_size=self._window_size, + stride=self._stride, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + input_shapes = self._input_dataset.output_shapes + return nest.pack_sequence_as(input_shapes, [ + tensor_shape.vector(None).concatenate(s) + for s in nest.flatten(self._input_dataset.output_shapes) + ]) + + @property + def output_types(self): + return self._input_dataset.output_types + + +def sliding_window_batch(window_size, stride=1): + """A sliding window with size of `window_size` and step of `stride`. + + This transformation passes a sliding window over this dataset. The + window size is `window_size` and step size is `stride`. If the left + elements cannot fill up the sliding window, this transformation will + drop the final smaller element. For example: + + ```python + # NOTE: The following examples use `{ ... }` to represent the + # contents of a dataset. + a = { [1], [2], [3], [4], [5], [6] } + + a.apply(tf.contrib.data.sliding_window_batch(window_size=3, stride=2)) == + { + [[1], [2], [3]], + [[3], [4], [5]], + } + ``` + + Args: + window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + elements in the sliding window. + stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + steps moving the sliding window forward for one iteration. The default + is `1`. It must be in `[1, window_size)`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + def _apply_fn(dataset): + return _SlideDataset(dataset, window_size, stride) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index 9cd1701c397b5a0bf5cc47c1bcab033704794d80..3c82a03df1745d855b2d3f918f7bbde113600556 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -18,9 +18,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops @@ -47,7 +44,7 @@ class StatsAggregator(object): dataset = ... iterator = dataset.make_one_shot_iterator() stats_aggregator = stats_ops.StatsAggregator() - set_op = stats_op.set_stats_aggregator_op(iterator, stats_aggregator) + set_op = stats_aggregator.subscribe(iterator) with tf.Session() as sess: # Running `set_op` will associate `iterator` with `stats_aggregator`. @@ -85,32 +82,57 @@ class StatsAggregator(object): """ return gen_dataset_ops.stats_aggregator_summary(self._resource) - def subscribe(self, iterator): - """Returns a @{tf.Operation} to associate this aggregator with `iterator`. - Note: Each @{tf.data.Iterator} can be associated with at most one - `StatsAggregator`. After running the operation that this function - returns, all statistics recorded in the iteration of `iterator` - will be stored in `stats_aggregator`. +class _SetStatsAggregatorDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and sets given stats_aggregator.""" - Args: - iterator: A @{tf.data.Iterator} object. + def __init__(self, input_dataset, stats_aggregator): + super(_SetStatsAggregatorDataset, self).__init__() + self._input_dataset = input_dataset + self._stats_aggregator = stats_aggregator - Returns: - A @{tf.Operation} that, when run, associates this aggregator with - `iterator`. - """ - if not isinstance(iterator, iterator_ops.Iterator): - raise TypeError("`iterator` must be a `tf.data.Iterator` object.") - return gen_dataset_ops.iterator_set_stats_aggregator( - iterator._iterator_resource, self._resource) # pylint: disable=protected-access + def _as_variant_tensor(self): + return gen_dataset_ops.set_stats_aggregator_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._stats_aggregator._resource, # pylint: disable=protected-access + **dataset_ops.flat_structure(self)) + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +# TODO(shivaniagrawal): Expose these methods in `tf.contrib.data`. +def set_stats_aggregator(stats_aggregator): + """Set the given stats_aggregator for aggregating the input dataset stats. + + Args: + stats_aggregator: A `StatsAggregator` object. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _SetStatsAggregatorDataset(dataset, stats_aggregator) + + return _apply_fn def bytes_produced_stats(tag): """Records the number of bytes produced by each element of the input dataset. - To consume the statistics, associate a `StatsAggregator` with an iterator - over the output dataset. + To consume the statistics, associate a `StatsAggregator` with the output + dataset. Args: tag: String. All statistics recorded by the returned transformation will @@ -131,8 +153,8 @@ def bytes_produced_stats(tag): def latency_stats(tag): """Records the latency of producing each element of the input dataset. - To consume the statistics, associate a `StatsAggregator` with an iterator - over the output dataset. + To consume the statistics, associate a `StatsAggregator` with the output + dataset. Args: tag: String. All statistics recorded by the returned transformation will @@ -149,6 +171,27 @@ def latency_stats(tag): return _apply_fn +def feature_stats(tag): + """Records the features stats from `Example` records of the input dataset. + + To consume the statistics, associate a `StatsAggregator` with the output + dataset. + + Args: + tag: String. All statistics recorded by the returned transformation will be + associated with the given `tag`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _StatsDataset(dataset, gen_dataset_ops.feature_stats_dataset, tag) + + return _apply_fn + + class _StatsDataset(dataset_ops.Dataset): """A `Dataset` that acts as an identity, and also records statistics.""" @@ -162,10 +205,7 @@ class _StatsDataset(dataset_ops.Dataset): return self._op_function( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._tag, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py new file mode 100644 index 0000000000000000000000000000000000000000..bb49604d4de90d726418684124608438aa33e6cf --- /dev/null +++ b/tensorflow/contrib/data/python/ops/threadpool.py @@ -0,0 +1,97 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for controlling threading in `tf.data` pipelines.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.ops import resource_variable_ops + +_uid_counter = 0 +_uid_lock = threading.Lock() + + +def _generate_shared_name(prefix): + with _uid_lock: + global _uid_counter + uid = _uid_counter + _uid_counter += 1 + return "{}{}".format(prefix, uid) + + +class PrivateThreadPool(object): + """A stateful resource that represents a private thread pool.""" + + def __init__(self, num_threads, display_name=None): + """Creates a `PrivateThreadPool` with the given number of threads.""" + if context.executing_eagerly(): + shared_name = _generate_shared_name("privatethreadpool") + self._resource = gen_dataset_ops.thread_pool_handle( + num_threads=num_threads, + display_name=display_name, + shared_name=shared_name) + self._resource_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._resource, handle_device=context.context().device_name) + else: + self._resource = gen_dataset_ops.thread_pool_handle( + num_threads=num_threads, display_name=display_name) + + +class _ThreadPoolDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and sets a custom threadpool.""" + + def __init__(self, input_dataset, thread_pool): + super(_ThreadPoolDataset, self).__init__() + self._input_dataset = input_dataset + self._thread_pool = thread_pool + + def _as_variant_tensor(self): + return gen_dataset_ops.thread_pool_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._thread_pool._resource, # pylint: disable=protected-access + **dataset_ops.flat_structure(self)) + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +def override_threadpool(dataset, thread_pool): + """Returns a new dataset that uses the given thread pool for its operations. + + Args: + dataset: A `tf.data.Dataset` object. + thread_pool: A `PrivateThreadPool` object. + + Returns: + A dataset containing the same values as `dataset`, but which uses + `thread_pool` to compute any of its parallel operations (such as + @{tf.data.Dataset.map}). + """ + return _ThreadPoolDataset(dataset, thread_pool) diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py index 133e17d20d0fc4c8d52cef3c95c132374e927a0b..4ce6ddede8350735636fd152fdc9df0319265990 100644 --- a/tensorflow/contrib/data/python/ops/unique.py +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -17,11 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes -from tensorflow.python.ops import gen_dataset_ops def unique(): @@ -64,10 +63,7 @@ class UniqueDataset(dataset_ops.Dataset): def _as_variant_tensor(self): return gen_dataset_ops.unique_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/writers.py b/tensorflow/contrib/data/python/ops/writers.py new file mode 100644 index 0000000000000000000000000000000000000000..f53bd3f7383950d6cfdb35e12811fb1daf24b320 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/writers.py @@ -0,0 +1,58 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python wrappers for tf.data writers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +class TFRecordWriter(object): + """Writes data to a TFRecord file.""" + + def __init__(self, filename, compression_type=None): + self._filename = ops.convert_to_tensor( + filename, dtypes.string, name="filename") + self._compression_type = convert.optional_param_to_tensor( + "compression_type", + compression_type, + argument_default="", + argument_dtype=dtypes.string) + + def write(self, dataset): + """Returns a @{tf.Operation} to write a dataset to a file. + + Args: + dataset: a @{tf.data.Dataset} whose elements are to be written to a file + + Returns: + A @{tf.Operation} that, when run, writes contents of `dataset` to a file. + """ + if not isinstance(dataset, dataset_ops.Dataset): + raise TypeError("`dataset` must be a `tf.data.Dataset` object.") + if (dataset.output_types != dtypes.string or + dataset.output_shapes != tensor_shape.scalar()): + raise TypeError( + "`dataset` must produce scalar `DT_STRING` tensors whereas it " + "produces shape {0} and types {1}".format(dataset.output_shapes, + dataset.output_types)) + return gen_dataset_ops.dataset_to_tf_record( + dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD index f6de5998d73a4869d2444cd90c9b64d1a2c889ac..3b50a48336d77ebd9327fa24e5612a95d5d0c372 100644 --- a/tensorflow/contrib/decision_trees/proto/BUILD +++ b/tensorflow/contrib/decision_trees/proto/BUILD @@ -13,19 +13,10 @@ load( "tf_pyclif_proto_library", ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - ), - visibility = ["//tensorflow:__subpackages__"], -) - tf_proto_library( name = "generic_tree_model", srcs = ["generic_tree_model.proto"], cc_api_version = 2, - go_api_version = 2, java_api_version = 2, visibility = ["//visibility:public"], ) @@ -34,7 +25,6 @@ tf_proto_library( name = "generic_tree_model_extensions", srcs = ["generic_tree_model_extensions.proto"], cc_api_version = 2, - go_api_version = 2, protodeps = [":generic_tree_model"], visibility = ["//visibility:public"], ) diff --git a/tensorflow/contrib/deprecated/BUILD b/tensorflow/contrib/deprecated/BUILD index 3dfbbf55273848afb8ad74ad444f0d85b45610bd..401527f1e74f7725d02a3b92a2c661d8ffc11e21 100644 --- a/tensorflow/contrib/deprecated/BUILD +++ b/tensorflow/contrib/deprecated/BUILD @@ -30,15 +30,3 @@ py_test( "//tensorflow/python:logging_ops", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..74b2cd90a187159fd2da8ce236c14e813cc43c49 --- /dev/null +++ b/tensorflow/contrib/distribute/BUILD @@ -0,0 +1,36 @@ +# Implementation of a prototype TF distributed computation library. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "distribute", + srcs = ["__init__.py"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/contrib/distribute/python:cross_tower_ops", + "//tensorflow/contrib/distribute/python:mirrored_strategy", + "//tensorflow/contrib/distribute/python:monitor", + "//tensorflow/contrib/distribute/python:one_device_strategy", + "//tensorflow/contrib/distribute/python:step_fn", + "//tensorflow/python:training", + "//tensorflow/python:util", + ], +) diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md new file mode 100644 index 0000000000000000000000000000000000000000..44a4481021c380e72b535cf0aca39df2bf04d3b7 --- /dev/null +++ b/tensorflow/contrib/distribute/README.md @@ -0,0 +1,141 @@ +# Distribution Strategy + +> *NOTE*: This is a experimental feature. The API and performance +> characteristics are subject to change. + +## Overview + +[`DistributionStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/DistributionStrategy) +API is an easy way to distribute your training +across multiple devices/machines. Our goal is to allow users to use existing +models and training code with minimal changes to enable distributed training. +Moreover, we've design the API in such a way that it works with both eager and +graph execution. + +Currently we support one type of strategy, called +[`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy). +It does in-graph replication with synchronous training +on many GPUs on one machine. Essentially, we create copies of all variables in +the model's layers on each device. We then use all-reduce to combine gradients +across the devices before applying them to the variables to keep them in sync. +In the future, we intend to support other kinds of training configurations such +as multi-node, synchronous, +[asynchronous](https://www.tensorflow.org/deploy/distributed#putting_it_all_together_example_trainer_program), +parameter servers and model parallelism. + +## Example + +Let's demonstrate how to use this API with a simple example. We will use the +[`Estimator`](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator) +approach, and show you how to scale your model to run on multiple GPUs on one +machine using `MirroredStrategy`. + +Let's consider a very simple model function which tries to learn a simple +function. + +```python +def model_fn(features, labels, mode): + layer = tf.layers.Dense(1) + logits = layer(features) + + if mode == tf.estimator.ModeKeys.PREDICT: + predictions = {"logits": logits} + return tf.estimator.EstimatorSpec(mode, predictions=predictions) + + loss = tf.losses.mean_squared_error( + labels=labels, predictions=tf.reshape(logits, [])) + + if mode == tf.estimator.ModeKeys.EVAL: + return tf.estimator.EstimatorSpec(mode, loss=loss) + + if mode == tf.estimator.ModeKeys.TRAIN: + train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss_fn()) + return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) +``` + +Let's also define a simple input function to feed data for training this model. +Note that we require using +[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) +with `DistributionStrategy`. + + +```python +def input_fn(): + features = tf.data.Dataset.from_tensors([[1.]]).repeat(100) + labels = tf.data.Dataset.from_tensors(1.).repeat(100) + return dataset_ops.Dataset.zip((features, labels)) +``` + +Now that we have a model function and input function defined, we can define the +estimator. To use `MirroredStrategy`, all we need to do is: + +* Create an instance of the `MirroredStrategy` class. +* Pass it to the +[`RunConfig`](https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig) +parameter of `Estimator`. + + +```python +distribution = tf.contrib.distribute.MirroredStrategy() +config = tf.estimator.RunConfig(train_distribute=distribution) +classifier = tf.estimator.Estimator(model_fn=model_fn, config=config) +classifier.train(input_fn=input_fn) +``` + +That's it! This change will now configure estimator to run on all GPUs on your +machine, with the `MirroredStrategy` approach. It will take care of distributing +the input dataset, replicating layers and variables on each device, and +combining and applying gradients. + +The model and input functions do not have to change because we have changed the +underlying components of TensorFlow (such as +optimizer, batch norm and summaries) to become distribution-aware. +That means those components know how to +combine their state across devices. Further, saving and checkpointing works +seamlessly, so you can save with one or no distribution strategy and resume with +another. + +Above, we showed the easiest way to use [`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy#__init__). +There are few things you can customize in practice: + +* You can specify a list of specific GPUs (using param `devices`) or the number +of GPUs (using param `num_gpus`), in case you don't want auto detection. +* You can specify various parameters for all reduce with the `cross_tower_ops` +param, such as the all reduce algorithm to use, and gradient repacking. + +## Performance Tips + +We've tried to make it such that you get the best performance for your existing +model. We also recommend you follow the tips from +[Input Pipeline Performance Guide](https://www.tensorflow.org/performance/datasets_performance). +Specifically, we found using [`map_and_batch`](https://www.tensorflow.org/performance/datasets_performance#map_and_batch) +and [`dataset.prefetch`](https://www.tensorflow.org/performance/datasets_performance#pipelining) +in the input function gives a solid boost in performance. When using +`dataset.prefetch`, use `buffer_size=None` to let it detect optimal buffer size. + +## Caveats +This feature is in early stages and there are a lot of improvements forthcoming: + +* Metrics are not yet supported during distributed training. They are still +supported during the evaluation. +* Summaries are only computed in the first tower in `MirroredStrategy`. +* Evaluation is not yet distributed. +* Eager support is in the works; performance can be more challenging with eager +execution. +* As mentioned earlier, multi-node and other distributed strategies will be +introduced in the future. +* If you are [`batching`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch) +your input data, we will place one batch on each GPU in each step. So your +effective batch size will be `num_gpus * batch_size`. Therefore, consider +adjusting your learning rate or batch size according to the number of GPUs. +We are working on addressing this limitation by splitting each batch across GPUs +instead. +* PartitionedVariables are not supported yet. + +## What's next? + +Please give distribution strategies a try. This feature is in early stages and +is evolving, so we welcome your feedback via +[issues on GitHub](https://github.com/tensorflow/tensorflow/issues/new). + + diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..76711baf3a11c8978fbb5770ec173ff74a153158 --- /dev/null +++ b/tensorflow/contrib/distribute/__init__.py @@ -0,0 +1,52 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Prototype of a distributed computation library for TF.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.distribute.python.cross_tower_ops import * +from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy +from tensorflow.contrib.distribute.python.monitor import Monitor +from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy +from tensorflow.contrib.distribute.python.step_fn import * +from tensorflow.python.training.distribute import * + +from tensorflow.python.util.all_util import remove_undocumented + + +_allowed_symbols = [ + 'AllReduceCrossTowerOps', + 'CrossTowerOps', + 'DistributionStrategy', + 'MirroredStrategy', + 'Monitor', + 'OneDeviceStrategy', + 'ReductionToOneDeviceCrossTowerOps', + 'Step', + 'StandardInputStep', + 'StandardSingleLossStep', + 'TowerContext', + 'get_cross_tower_context', + 'get_distribution_strategy', + 'get_loss_reduction', + 'get_tower_context', + 'has_distribution_strategy', + 'require_tower_context', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..9dfb8552f1b0f058b44f8ed09c2ed681367293d5 --- /dev/null +++ b/tensorflow/contrib/distribute/python/BUILD @@ -0,0 +1,612 @@ +# Implementation of a prototype TF distributed computation library. + +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# TODO(priyag): Figure out testonly issues that are preventing us from +# including our tests in pip for now. + +py_library( + name = "values", + srcs = ["values.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":input_ops", + ":prefetching_ops_v2", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/contrib/eager/python:datasets", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:device_util", + "//tensorflow/python:distribute", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python/eager:context", + "//tensorflow/python/training/checkpointable:base", + "@six_archive//:six", + ], +) + +cuda_py_test( + name = "values_test", + srcs = ["values_test.py"], + additional_deps = [ + ":mirrored_strategy", + ":multi_worker_test_base", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python:errors", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python:device_util", + "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:model_fn", + ], + tags = [ + "no_pip", + ], +) + +py_library( + name = "mirrored_strategy", + srcs = ["mirrored_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":cross_tower_ops", + ":shared_variable_creator", + ":values", + "//tensorflow/python:array_ops", + "//tensorflow/python:device", + "//tensorflow/python:device_util", + "//tensorflow/python:distribute", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:tape", + "@six_archive//:six", + ], +) + +py_library( + name = "multi_worker_strategy", + srcs = ["multi_worker_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":mirrored_strategy", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:training", + "//tensorflow/python:util", + ], +) + +py_library( + name = "one_device_strategy", + srcs = ["one_device_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":values", + "//tensorflow/contrib/eager/python:datasets", + "//tensorflow/python:array_ops", + "//tensorflow/python:distribute", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python/eager:context", + "@six_archive//:six", + ], +) + +py_library( + name = "strategy_test_lib", + testonly = 1, + srcs = ["strategy_test_lib.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:distribute", + "//tensorflow/python:framework_ops", + "//tensorflow/python:layers", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "combinations", + testonly = 1, + srcs = ["combinations.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":mirrored_strategy", + ":multi_worker_strategy", + ":one_device_strategy", + ":tpu_strategy", + "//tensorflow/contrib/optimizer_v2:training", + "//tensorflow/python:distribute", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python/eager:context", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "combinations_test", + srcs = ["combinations_test.py"], + tags = [ + "no_pip", + ], + deps = [ + ":combinations", + "//tensorflow/python/eager:test", + ], +) + +py_test( + name = "mirrored_strategy_test", + srcs = ["mirrored_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":mirrored_strategy", + ":strategy_test_lib", + "//tensorflow/python:distribute", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], +) + +py_test( + name = "one_device_strategy_test", + srcs = ["one_device_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":one_device_strategy", + ":strategy_test_lib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/eager:test", + ], +) + +cuda_py_test( + name = "mirrored_strategy_multigpu_test", + srcs = ["mirrored_strategy_multigpu_test.py"], + additional_deps = [ + ":mirrored_strategy", + ":values", + ":strategy_test_lib", + "//tensorflow/python:distribute", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:constant_op", + "//tensorflow/python:layers", + "//tensorflow/python:variable_scope", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "guitar", + "no_pip", + "multi_and_single_gpu", + # Do not perform the extra analysis on this test, because it is already + # performed for the `:mirrored_strategy_test` target. + "no_oss", + "noasan", + "notap", + "notsan", + ], +) + +py_library( + name = "multi_worker_test_base", + testonly = 1, + srcs = ["multi_worker_test_base.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:distributed_framework_test_lib", + "//tensorflow/python:platform", + "//tensorflow/python:session", + "//tensorflow/python:training", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "step_fn", + srcs = ["step_fn.py"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:training", + "//tensorflow/python/eager:backprop", + ], +) + +py_library( + name = "tpu_strategy", + srcs = ["tpu_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":one_device_strategy", + ":values", + "//tensorflow/contrib/tpu", + "//tensorflow/contrib/tpu:tpu_py", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + ], +) + +py_library( + name = "minimize_loss_test_lib", + testonly = 1, + srcs = ["minimize_loss_test.py"], + deps = [ + ":combinations", + ":mirrored_strategy", + ":single_loss_example", + "//tensorflow/contrib/tpu:tpu_lib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "//tensorflow/python/ops/losses", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +cuda_py_test( + name = "minimize_loss_test", + srcs = ["minimize_loss_test.py"], + additional_deps = [ + ":minimize_loss_test_lib", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +cuda_py_test( + name = "optimizer_v2_test", + srcs = ["optimizer_v2_test.py"], + additional_deps = [ + ":combinations", + ":single_loss_example", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +cuda_py_test( + name = "estimator_integration_test", + srcs = ["estimator_integration_test.py"], + additional_deps = [ + ":combinations", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/contrib/optimizer_v2:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:dnn_linear_combined", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/estimator:run_config", + "//tensorflow/python/feature_column", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +py_library( + name = "single_loss_example", + srcs = ["single_loss_example.py"], + deps = [ + ":step_fn", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:layers", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +cuda_py_test( + name = "step_fn_test", + srcs = ["step_fn_test.py"], + additional_deps = [ + ":single_loss_example", + ":combinations", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +py_library( + name = "monitor", + srcs = ["monitor.py"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + ], +) + +cuda_py_test( + name = "monitor_test", + srcs = ["monitor_test.py"], + additional_deps = [ + ":combinations", + ":monitor", + ":one_device_strategy", + ":single_loss_example", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +py_library( + name = "shared_variable_creator", + srcs = ["shared_variable_creator.py"], + visibility = ["//tensorflow:internal"], +) + +py_test( + name = "shared_variable_creator_test", + srcs = ["shared_variable_creator_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":shared_variable_creator", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "cross_tower_utils", + srcs = ["cross_tower_utils.py"], + srcs_version = "PY2AND3", + deps = [ + ":values", + "//tensorflow/contrib/all_reduce:all_reduce_py", + "//tensorflow/contrib/nccl:nccl_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + ], +) + +cuda_py_test( + name = "cross_tower_utils_test", + srcs = ["cross_tower_utils_test.py"], + additional_deps = [ + ":combinations", + ":cross_tower_utils", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "no_pip", + ], +) + +py_library( + name = "cross_tower_ops", + srcs = ["cross_tower_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":cross_tower_utils", + ":values", + "//tensorflow/python:array_ops", + "//tensorflow/python:device_lib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python/eager:context", + "@six_archive//:six", + ], +) + +cuda_py_test( + name = "cross_tower_ops_test", + srcs = ["cross_tower_ops_test.py"], + additional_deps = [ + ":combinations", + ":cross_tower_ops", + ":multi_worker_test_base", + ":values", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + shard_count = 15, + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +py_library( + name = "prefetching_ops_v2", + srcs = ["prefetching_ops_v2.py"], + deps = [ + "//tensorflow/contrib/data/python/ops:contrib_op_loader", + "//tensorflow/contrib/data/python/ops:prefetching_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +cuda_py_test( + name = "prefetching_ops_v2_test", + srcs = ["prefetching_ops_v2_test.py"], + additional_deps = [ + ":prefetching_ops_v2", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + ], +) + +py_library( + name = "input_ops", + srcs = ["input_ops.py"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/util:nest", + ], +) + +cuda_py_test( + name = "input_ops_test", + srcs = ["input_ops_test.py"], + additional_deps = [ + ":input_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/python:errors", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:io_ops", + "//tensorflow/python/data/ops:readers", + "//tensorflow/python:util", + ], + tags = [ + "no_pip", + ], +) + +cuda_py_test( + name = "keras_test", + srcs = ["keras_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/contrib/distribute/python:mirrored_strategy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:training", + "//tensorflow/python/estimator:keras", + "//tensorflow/python/estimator:run_config", + "//tensorflow/python/keras", + ], + tags = [ + "multi_and_single_gpu", + "noguitar", + "notsan", + ], +) + +cuda_py_test( + name = "metrics_v1_test", + srcs = ["metrics_v1_test.py"], + additional_deps = [ + ":combinations", + "@absl_py//absl/testing:parameterized", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:test", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py new file mode 100644 index 0000000000000000000000000000000000000000..ba03b14deb9a3897dae29382ce601c0319f84735 --- /dev/null +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -0,0 +1,401 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Facilities for creating multiple test combinations. + +Here is an example of testing various optimizers in Eager and Graph mode: + +class AdditionExample(test.TestCase, parameterized.TestCase): + @combinations.generate( + combinations.combine(mode=["graph", "eager"], + optimizer=[AdamOptimizer(), + GradientDescentOptimizer()])) + def testOptimizer(self, optimizer): + ... f(optimizer)... + +This will run `testOptimizer` 4 times with the specified optimizers: 2 in +Eager and 2 in Graph mode. +The test will be provided with arguments that match the arguments of combine +by name. It is necessary to request all arguments, except for `mode`, which is +optional. + +`combine()` function is available for creating a cross product of various +options. `times()` function exists for creating a product of N `combine()`-ed +results. See below. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict +import sys +import types +import unittest +from absl.testing import parameterized +import six + +from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib +from tensorflow.contrib.distribute.python import multi_worker_strategy +from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib +from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib +from tensorflow.contrib.optimizer_v2 import adam as adam_v2 +from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.training import adam +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import gradient_descent +from tensorflow.python.util import tf_inspect + + +GPU_TEST = "test_gpu" in sys.argv[0] +TPU_TEST = "test_tpu" in sys.argv[0] + + +def generate(combinations): + """A decorator for generating test cases of a test method or a test class. + + Args: + combinations: a list of dictionaries created using combine() and times(). + + Restrictions: + -- the "mode" argument can be either "eager" or "graph". It's "graph" by + default. + -- arguments of the test method must match by name to get the corresponding + value of the combination. Tests must accept all arguments except the + "mode", "required_tpu" and "required_gpus". + -- "distribution" argument is special and optional. It is meant for passing + instances of DistributionStrategy. Each instance is to be passed as via + `NamedDistribution`. If using "distribution", "required_gpus" and + "required_tpu" should be specified via the NamedDistribution instance, + rather than as separate arguments. + -- "required_tpu" argument is special and optional. If not `None`, then the + test will be skipped if TPUs aren't available. + -- "required_gpus" argument is special and optional. If not `None`, then the + test will be skipped if the specified number of GPUs aren't available. + + Returns: + a decorator that will cause the test method or the test class to be run + under the specified conditions. + + Raises: + ValueError - if "mode" argument wasn't either "eager" or "graph" or if other + arguments were not accepted by the test method. + """ + + def decorator(test_method_or_class): + """The decorator to be returned.""" + + # Generate good test names that can be used with --test_filter. + named_combinations = [] + for combination in combinations: + # We use OrderedDicts in `combine()` and `times()` to ensure stable + # order of keys in each dictionary. + assert isinstance(combination, OrderedDict) + name = "".join([ + "_{}_{}".format( + "".join(filter(str.isalnum, key)), + "".join(filter(str.isalnum, str(value)))) + for key, value in combination.items() + ]) + named_combinations.append( + OrderedDict( + list(combination.items()) + [("testcase_name", + "_test{}".format(name))])) + + if isinstance(test_method_or_class, type): + class_object = test_method_or_class + class_object._test_method_ids = test_method_ids = {} + for name, test_method in six.iteritems(class_object.__dict__.copy()): + if (name.startswith(unittest.TestLoader.testMethodPrefix) and + isinstance(test_method, types.FunctionType)): + delattr(class_object, name) + methods = {} + parameterized._update_class_dict_for_param_test_case( + class_object.__name__, methods, test_method_ids, name, + parameterized._ParameterizedTestIter( + _augment_with_special_arguments(test_method), + named_combinations, parameterized._NAMED, name)) + for method_name, method in six.iteritems(methods): + setattr(class_object, method_name, method) + + return class_object + else: + test_method = _augment_with_special_arguments(test_method_or_class) + return parameterized.named_parameters(*named_combinations)(test_method) + + return decorator + + +def _augment_with_special_arguments(test_method): + def decorated(self, **kwargs): + """A wrapped test method that treats some arguments in a special way.""" + mode = kwargs.pop("mode", "graph") + + distribution = kwargs.pop("distribution", None) + required_tpu = kwargs.pop("required_tpu", False) + required_gpus = kwargs.pop("required_gpus", None) + + if distribution: + assert required_gpus is None, ( + "Do not use `required_gpus` and `distribution` together.") + assert required_tpu is False, ( + "Do not use `required_tpu` and `distribution` together.") + kwargs["distribution"] = distribution.strategy + required_gpus = distribution.required_gpus + required_tpu = distribution.required_tpu + + if required_tpu and not TPU_TEST: + self.skipTest("Test requires a TPU, but it's not available.") + if not required_tpu and TPU_TEST: + self.skipTest("Test that doesn't require a TPU.") + + if not required_gpus: + if GPU_TEST: + self.skipTest("Test that doesn't require GPUs.") + elif context.num_gpus() < required_gpus: + self.skipTest( + "{} GPUs are not available for this test. {} GPUs are available". + format(required_gpus, context.num_gpus())) + + # At this point, `kwargs` doesn't have `required_gpus` or `required_tpu` + # that the user might have specified. `kwargs` still has `mode`, which + # the test is allowed to accept or ignore. + requested_arguments = tf_inspect.getfullargspec(test_method).args + missing_arguments = set(list(kwargs.keys()) + ["self"]).difference( + set(requested_arguments + ["mode"])) + if missing_arguments: + raise ValueError("The test is missing arguments {} .".format( + missing_arguments)) + + kwargs_to_pass = {} + for arg in requested_arguments: + if arg == "self": + kwargs_to_pass[arg] = self + else: + kwargs_to_pass[arg] = kwargs[arg] + + if mode == "eager": + with ops.Graph().as_default(), context.eager_mode(): + test_method(**kwargs_to_pass) + elif mode == "graph": + with ops.Graph().as_default(), context.graph_mode(): + test_method(**kwargs_to_pass) + else: + raise ValueError( + "'mode' has to be either 'eager' or 'graph' and not {}".format( + mode)) + return decorated + + +def combine(**kwargs): + """Generate combinations based on its keyword arguments. + + Two sets of returned combinations can be concatenated using +. Their product + can be computed using `times()`. + + Args: + **kwargs: keyword arguments of form `option=[possibilities, ...]` + or `option=the_only_possibility`. + + Returns: + a list of dictionaries for each combination. Keys in the dictionaries are + the keyword argument names. Each key has one value - one of the + corresponding keyword argument values. + """ + if not kwargs: + return [OrderedDict()] + + sort_by_key = lambda k: k[0][0] + kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key)) + first = list(kwargs.items())[0] + + rest = dict(list(kwargs.items())[1:]) + rest_combined = combine(**rest) + + key = first[0] + values = first[1] + if not isinstance(values, list): + values = [values] + + return [ + OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key)) + for v in values + for combined in rest_combined + ] + + +def times(*combined): + """Generate a product of N sets of combinations. + + times(combine(a=[1,2]), combine(b=[3,4])) == combine(a=[1,2], b=[3,4]) + + Args: + *combined: N lists of dictionaries that specify combinations. + + Returns: + a list of dictionaries for each combination. + + Raises: + ValueError: if some of the inputs have overlapping keys. + """ + assert combined + + if len(combined) == 1: + return combined[0] + + first = combined[0] + rest_combined = times(*combined[1:]) + + combined_results = [] + for a in first: + for b in rest_combined: + if set(a.keys()).intersection(set(b.keys())): + raise ValueError("Keys need to not overlap: {} vs {}".format( + a.keys(), b.keys())) + + combined_results.append(OrderedDict(list(a.items()) + list(b.items()))) + return combined_results + + +class NamedObject(object): + """A class that translates an object into a good test name.""" + + def __init__(self, name, obj): + self._name = name + self._obj = obj + + def __getattr__(self, name): + return getattr(self._obj, name) + + def __call__(self, *args, **kwargs): + return self._obj(*args, **kwargs) + + def __repr__(self): + return self._name + + +class NamedDistribution(object): + """Translates DistributionStrategy and its data into a good name.""" + + def __init__(self, name, distribution_fn, required_gpus=None, + required_tpu=False): + self._distribution_fn = distribution_fn + self._name = name + self._required_gpus = required_gpus + self._required_tpu = required_tpu + + def __repr__(self): + return self._name + + @property + def strategy(self): + return self._distribution_fn() + + @property + def required_gpus(self): + return self._required_gpus + + @property + def required_tpu(self): + return self._required_tpu + + +# pylint: disable=g-long-lambda +default_strategy = NamedDistribution( + "Default", + lambda: distribute_lib._default_distribution_strategy, # pylint: disable=protected-access + required_gpus=None) +one_device_strategy = NamedDistribution( + "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), + required_gpus=None) +tpu_strategy_single_iteration = NamedDistribution( + "TPUSingleIteration", + lambda: tpu_lib.TPUStrategy(iterations_per_step=1), + required_tpu=True) +tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True) +# Note that we disable prefetching for testing since prefetching makes +# the input non-deterministic. +mirrored_strategy_with_gpu_and_cpu = NamedDistribution( + "MirroredCPUAndGPU", + lambda: mirrored_lib.MirroredStrategy( + ["/gpu:0", "/cpu:0"], prefetch_on_device=False), + required_gpus=1) +mirrored_strategy_with_two_gpus = NamedDistribution( + "Mirrored2GPUs", + lambda: mirrored_lib.MirroredStrategy( + ["/gpu:0", "/gpu:1"], prefetch_on_device=False), + required_gpus=2) + +multi_worker_strategy_with_cpu = NamedDistribution( + "MultiWorkerCPU", + lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={ + "worker": [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + }, + num_gpus_per_worker=0), 0) +multi_worker_strategy_with_one_gpu = NamedDistribution( + "MultiWorker1GPU", + lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={ + "worker": [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + }, + num_gpus_per_worker=1), 1) +multi_worker_strategy_with_two_gpus = NamedDistribution( + "MultiWorker2GPUs", + lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={ + "worker": [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + }, + num_gpus_per_worker=2), 2) + +adam_optimizer_v1_fn = NamedObject( + "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) +gradient_descent_optimizer_v1_fn = NamedObject( + "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2)) + +adam_optimizer_v2_fn = NamedObject( + "AdamV2", lambda: adam_v2.AdamOptimizer(0.2, epsilon=1)) +gradient_descent_optimizer_v2_fn = NamedObject( + "GradientDescentV2", + lambda: gradient_descent_v2.GradientDescentOptimizer(0.2)) + +graph_and_eager_modes = ["graph", "eager"] + + +def distributions_and_v1_optimizers(): + """A common set of combination with DistributionStrategies and Optimizers.""" + return combine( + distribution=[ + one_device_strategy, mirrored_strategy_with_gpu_and_cpu, + mirrored_strategy_with_two_gpus + ], + optimizer_fn=[adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn]) + + +def distributions_and_v2_optimizers(): + """DistributionStrategies and V2 Optimizers.""" + return combine( + distribution=[ + one_device_strategy, mirrored_strategy_with_gpu_and_cpu, + mirrored_strategy_with_two_gpus + ], + optimizer_fn=[adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn]) diff --git a/tensorflow/contrib/distribute/python/combinations_test.py b/tensorflow/contrib/distribute/python/combinations_test.py new file mode 100644 index 0000000000000000000000000000000000000000..86aa48cea889c6c2ce169b18bcabb6d08890fbed --- /dev/null +++ b/tensorflow/contrib/distribute/python/combinations_test.py @@ -0,0 +1,148 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for some testing utils from strategy_test_lib.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.eager import test + + +class TestingCombinationsTest(test.TestCase): + + def test_combine(self): + self.assertEqual([{ + "a": 1, + "b": 2 + }, { + "a": 1, + "b": 3 + }, { + "a": 2, + "b": 2 + }, { + "a": 2, + "b": 3 + }], combinations.combine(a=[1, 2], b=[2, 3])) + + def test_combine_single_parameter(self): + self.assertEqual([{ + "a": 1, + "b": 2 + }, { + "a": 2, + "b": 2 + }], combinations.combine(a=[1, 2], b=2)) + + def test_add(self): + self.assertEqual( + [{ + "a": 1 + }, { + "a": 2 + }, { + "b": 2 + }, { + "b": 3 + }], + combinations.combine(a=[1, 2]) + + combinations.combine(b=[2, 3])) + + def test_times(self): + c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"]) + c2 = combinations.combine(mode=["eager"], loss=["callable"]) + c3 = combinations.combine(distribution=["d1", "d2"]) + c4 = combinations.times(c3, c1 + c2) + self.assertEqual([ + OrderedDict([("distribution", "d1"), ("loss", "callable"), + ("mode", "graph")]), + OrderedDict([("distribution", "d1"), ("loss", "tensor"), + ("mode", "graph")]), + OrderedDict([("distribution", "d1"), ("loss", "callable"), + ("mode", "eager")]), + OrderedDict([("distribution", "d2"), ("loss", "callable"), + ("mode", "graph")]), + OrderedDict([("distribution", "d2"), ("loss", "tensor"), + ("mode", "graph")]), + OrderedDict([("distribution", "d2"), ("loss", "callable"), + ("mode", "eager")]) + ], c4) + + def test_times_variable_arguments(self): + c1 = combinations.combine(mode=["graph", "eager"]) + c2 = combinations.combine(optimizer=["adam", "gd"]) + c3 = combinations.combine(distribution=["d1", "d2"]) + c4 = combinations.times(c3, c1, c2) + self.assertEqual([ + OrderedDict([("distribution", "d1"), ("mode", "graph"), + ("optimizer", "adam")]), + OrderedDict([("distribution", "d1"), ("mode", "graph"), + ("optimizer", "gd")]), + OrderedDict([("distribution", "d1"), ("mode", "eager"), + ("optimizer", "adam")]), + OrderedDict([("distribution", "d1"), ("mode", "eager"), + ("optimizer", "gd")]), + OrderedDict([("distribution", "d2"), ("mode", "graph"), + ("optimizer", "adam")]), + OrderedDict([("distribution", "d2"), ("mode", "graph"), + ("optimizer", "gd")]), + OrderedDict([("distribution", "d2"), ("mode", "eager"), + ("optimizer", "adam")]), + OrderedDict([("distribution", "d2"), ("mode", "eager"), + ("optimizer", "gd")]) + ], c4) + self.assertEqual( + combinations.combine( + mode=["graph", "eager"], + optimizer=["adam", "gd"], + distribution=["d1", "d2"]), c4) + + def test_overlapping_keys(self): + c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"]) + c2 = combinations.combine(mode=["eager"], loss=["callable"]) + with self.assertRaisesRegexp(ValueError, ".*Keys.+overlap.+"): + _ = combinations.times(c1, c2) + + +@combinations.generate(combinations.combine(a=[1, 0], b=[2, 3], c=[1])) +class CombineTheTestSuite(parameterized.TestCase): + + def test_add_things(self, a, b, c): + self.assertLessEqual(3, a + b + c) + self.assertLessEqual(a + b + c, 5) + + def test_add_things_one_more(self, a, b, c): + self.assertLessEqual(3, a + b + c) + self.assertLessEqual(a + b + c, 5) + + def not_a_test(self, a=0, b=0, c=0): + del a, b, c + self.fail() + + def _test_but_private(self, a=0, b=0, c=0): + del a, b, c + self.fail() + + # Check that nothing funny happens to a non-callable that starts with "_test". + test_member = 0 + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ae8b9712c392fa948c8598dd123cdea01d9866 --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -0,0 +1,753 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Classes for different algorithms of reduction and broadcasting.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import six + +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.client import device_lib +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import device_util + + +def _validate_destinations(destinations): + if not isinstance(destinations, + (value_lib.DistributedValues, six.string_types, list)): + raise ValueError("destinations must be one of a `DistributedValues` object," + " a device string, a list of device strings or None") + + if not destinations: + raise ValueError("destinations can not be empty") + + +def _validate_value_destination_pairs(value_destination_pairs): + # pylint: disable=g-missing-docstring + if not value_destination_pairs: return False + if not isinstance(value_destination_pairs, (list, tuple)): return False + if not all([isinstance(pair, tuple) for pair in value_destination_pairs]): + return False + if not all([isinstance(v[0], value_lib.PerDevice) + for v in value_destination_pairs]): + return False + return True + + +# TODO(yuefengz): consider calling this function in the caller of CrossTowerOps. +def _get_devices_from(destinations): + if isinstance(destinations, value_lib.DistributedValues): + return list(destinations.devices) + elif isinstance(destinations, six.string_types): + return [device_util.resolve(destinations)] + else: + return [device_util.resolve(destination) for destination in destinations] + + +def _devices_match(left, right): + return set(_get_devices_from(left)) == set(_get_devices_from(right)) + + +def _all_devices_match(value_destination_pairs): + if not all([d is None or _devices_match(v, d) + for v, d in value_destination_pairs]): + return False + if not all([_devices_match(v, value_destination_pairs[0][0]) + for v, _ in value_destination_pairs[1:]]): + return False + return True + + +def _simple_broadcast(value, destinations): + index = {} + devices = _get_devices_from(destinations) + for d in devices: + index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + value, d) + return value_lib.Mirrored(index) + + +def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, + method_string): + # pylint: disable=g-missing-docstring + all_values = [] + count = 0 + for v in per_device_value._index.values(): # pylint: disable=protected-access + if isinstance(v, value_lib.MapOutput): + v_list = v.get() + if not v_list: + continue + count += len(v_list) + # Sum within each device before aggregating across devices. + # TODO(yuefengz): Check whether it helps to use accumulation_fn here. + v = cross_tower_utils.aggregate_tensors_or_indexed_slices( + v_list, math_ops.add_n) + else: + count += 1 + all_values.append(v) + if not all_values: + raise ValueError("`per_device_value` must be non-empty") + + with ops.device(reduce_to_device): + with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): + reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices( + all_values, accumulation_fn) + if method_string == "mean": + reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices( + reduced, count) + elif method_string != "sum": + raise ValueError("`method_string` must be 'sum' or 'mean'") + return reduced + + +class CrossTowerOps(object): + """Base class for cross-tower reduction and broadcasting algorithms.""" + + def __init__(self): + pass + + def reduce(self, method_string, per_device_value, destinations=None): + """Reduce `per_device_value` to `destinations`. + + It runs the reduction operation defined by `method_string` and put the + result on `destinations`. + + Args: + method_string: either 'sum' or 'mean' specifying the reduction method. + per_device_value: a PerDevice object. + destinations: the reduction destinations. + + Returns: + a Mirrored object. + + Raises: + ValueError: if per_device_value is not a PerDevice object. + """ + if not isinstance(per_device_value, value_lib.PerDevice): + raise ValueError("`per_device_value` must be a `PerDevice` object.") + if destinations is not None: + _validate_destinations(destinations) + return self._reduce(method_string, per_device_value, destinations) + + def batch_reduce(self, method_string, value_destination_pairs): + """Reduce PerDevice objects in a batch. + + Reduce each first element in `value_destination_pairs` to each second + element which indicates the destinations. + + Args: + method_string: either 'sum' or 'mean' specifying the reduction method. + value_destination_pairs: a list or a tuple of tuples of PerDevice objects + and destinations. If a destination is None, then the destinations + are set to match the devices of the input PerDevice object. + + Returns: + a list of Mirrored objects. + + Raises: + ValueError: if `value_destination_pairs` is not a list or a tuple of + tuples of PerDevice objects and destinations + """ + if not _validate_value_destination_pairs(value_destination_pairs): + raise ValueError("`value_destination_pairs` must be a list or a tuple of " + "tuples of PerDevice objects and destinations") + for _, d in value_destination_pairs: + if d is not None: + _validate_destinations(d) + + return self._batch_reduce(method_string, value_destination_pairs) + + def broadcast(self, tensor, destinations): + """Broadcast the `tensor` to destinations. + + Args: + tensor: the tensor to broadcast. + destinations: the broadcast destinations. + + Returns: + a Mirrored object. + """ + _validate_destinations(destinations) + return self._broadcast(tensor, destinations) + + def _reduce(self, method_string, per_device_value, destinations): + raise NotImplementedError( + "_reduce method must be implemented in descendants.") + + def _batch_reduce(self, method_string, value_destination_pairs): + raise NotImplementedError( + "_batch_reduce method must be implemented in descendants.") + + def _broadcast(self, tensor, destinations): + return _simple_broadcast(tensor, destinations) + + +class ReductionToOneDeviceCrossTowerOps(CrossTowerOps): + """Always do reduction to one device first and then do broadcasting. + + Batch reduction is done by reduction on each element one by one. + """ + + def __init__(self, reduce_to_device=None, accumulation_fn=math_ops.add_n): + """Constructor. + + Args: + reduce_to_device: the intermediate device to reduce to. If None, reduce + to the first device in `destinations` of the reduce() method. + accumulation_fn: a function that does accumulation. + """ + self.reduce_to_device = reduce_to_device + self.accumulation_fn = accumulation_fn + super(ReductionToOneDeviceCrossTowerOps, self).__init__() + + def _reduce(self, method_string, per_device_value, destinations): + devices = _get_devices_from(destinations or per_device_value) + reduce_to_device = self.reduce_to_device or devices[0] + reduced = _simple_reduce(per_device_value, reduce_to_device, + self.accumulation_fn, method_string) + return self.broadcast(reduced, devices) + + def _batch_reduce(self, method_string, value_destination_pairs): + return [self._reduce(method_string, t, destinations=v) + for t, v in value_destination_pairs] + + +def _group_value_by_device(per_device_values): + """Group values into sublists by their devices. + + This grouping is needed to call the all-reduce library because it expects a + list of the following form: + [(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ... + (grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ... + (grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ... + ... + ] + + Args: + per_device_values: a list of PerDevice obejcts. + + Returns: + a list of lists, each sublist has components for its corresponding device of + PerDevice objects, paired with a None. + """ + destinations = per_device_values[0].devices + grouped = [[] for _ in range(len(destinations))] + for per_device_value in per_device_values: + # pylint: disable=protected-access + for i, v in enumerate(per_device_value._index.values()): + assert per_device_value.devices == destinations + grouped[i].append((v, None)) + return grouped + + +def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string): + """Ungroup results from all-reduce and make Mirrored objects. + + Each all-reduce result will be divided by the number of destinations before + Mirrored objects are created if method_string is "mean". + + Args: + grouped_reduced: a list of lists, each sublist has components for each + device, paired with a None. It is the result from + cross_tower_utils.aggregate_gradients_using*. + destinations: a list of device strings for returned Mirrored objects. + method_string: "mean" or "sum". + + Returns: + a list of Mirrored objects. + """ + index = [{} for _ in range(len(grouped_reduced[0]))] + for d, per_device_reduced in enumerate(grouped_reduced): + for i, (v, _) in enumerate(per_device_reduced): + if method_string == "mean": + index[i][destinations[d]] = v / len(destinations) + else: + index[i][destinations[d]] = v + return [value_lib.Mirrored(v) for v in index] + + +class ConcatAndSplitPacker(object): + """Concatenate and split tensors for reduction.""" + + def __init__(self, num_packs=1): + """Initialize the ConcatAndSplitPacker object. + + Args: + num_packs: specifies the number of split packs that will be + formed. + + Raises: + ValueError: if num_packs is not greater than 0. + """ + if num_packs <= 0: + raise ValueError("num_packs must be greater than zero.") + self.num_packs = num_packs + + def pack(self, grouped_grads_and_vars): + """Pack tensors.""" + self.grouped_grads_and_vars = grouped_grads_and_vars + self.all_tower_shapes = [] + self.all_tower_sizes = [] + + device_grad_packs = [] + for tower_grads_and_vars in grouped_grads_and_vars: + with ops.colocate_with(tower_grads_and_vars[0][0]): + # Flatten all the grads. + flat_grads = [ + array_ops.reshape(g, [-1]) for g, _ in tower_grads_and_vars + ] + # Remember the original shape of all the grads. + tower_shapes = [array_ops.shape(g) for g, _ in tower_grads_and_vars] + # Remember the original sizes of all the grads. + tower_sizes = [array_ops.size(g) for g, _ in tower_grads_and_vars] + # Concat all the flat grads into a big flat tensor. + concat_grads = array_ops.concat(flat_grads, 0) + + # Split the big tensor into num_splits packs. In cases where the + # total size is not divisible num_splits, the last pack gets + # more elements. + # TODO(zhengxq): it is also possible to optimize away all the concat + # as well. + num_splits = self.num_packs + + # The array_ops.size function will sometimes remove static shapes. So if + # all gradient shapes are defined, we use another method to get the + # total size. + # TODO(yuefengz): move this logic to array_ops.size. + if all([g.shape.is_fully_defined() for g, _ in tower_grads_and_vars]): + total_grad_size = sum( + [g.shape.num_elements() for g, _ in tower_grads_and_vars]) + else: + total_grad_size = array_ops.size(concat_grads) + + split_size = total_grad_size // num_splits + split_size_last = total_grad_size - split_size * (num_splits - 1) + split_sizes = [split_size] * (num_splits - 1) + [split_size_last] + grad_packs = array_ops.split(concat_grads, split_sizes) + + # Ready to aggregate the repacked gradients, with fake variables. + # TODO(zhengxq): It is hacky to have to use fake variables. + # We should remove the need for variables in + # aggregate_gradients_using*. + device_grad_packs.append(zip(grad_packs, [None] * num_splits)) + self.all_tower_shapes.append(tower_shapes) + self.all_tower_sizes.append(tower_sizes) + + return device_grad_packs + + def unpack(self, summed_device_grad_packs): + """Reverse the pack.""" + aggregated_device_grads = [] + for (summed_tower_grad_packs, + tower_grads_and_vars, tower_shapes, tower_sizes) in zip( + summed_device_grad_packs, self.grouped_grads_and_vars, + self.all_tower_shapes, self.all_tower_sizes): + # pylint: enable=line-too-long + # Reverse the packing operations in the previous steps. Form the + # summed gradients back into their original shapes. + with ops.colocate_with(summed_tower_grad_packs[0][0]): + # Form a list of the summed grad packs. + device_grad_packs = [g for g, _ in summed_tower_grad_packs] + + # Concat them back into a big flat tensor. + device_grads_concat = array_ops.concat(device_grad_packs, 0) + + # Split the tensors back into their original sizes. + grads_with_sizes = array_ops.split(device_grads_concat, tower_sizes) + + # Reshape the tensors back into their original shapes. + grads_with_shapes = [ + array_ops.reshape(grad, shape) + for shape, grad in zip(tower_shapes, grads_with_sizes) + ] + + # Form the list with the original list of variables. + summed_tower_grads = [ + (g, v) for g, (_, v) in zip(grads_with_shapes, tower_grads_and_vars) + ] + aggregated_device_grads.append(summed_tower_grads) + return aggregated_device_grads + + +class AggregateSmallTensorPacker(object): + """Concatenate small gradient tensors together for reduction.""" + + def __init__(self, + agg_small_grads_max_bytes=1048576, + agg_small_grads_max_group=16): + """Initialize the AggregateSmallTensorPacker object. + + Args: + agg_small_grads_max_bytes: largest tensor eligible for aggregation, + in number of bytes. + agg_small_grads_max_group: largest permitted aggregation of small + tensors. + + Raises: + ValueError: if `agg_small_grads_max_bytes` or `agg_small_grads_max_group` + is not greater than 0. + """ + if agg_small_grads_max_bytes <= 0 or agg_small_grads_max_group <= 0: + raise ValueError("agg_small_grads_max_bytes and agg_small_grads_max_group" + " should both be greater than zero.") + self.agg_small_grads_max_bytes = agg_small_grads_max_bytes + self.agg_small_grads_max_group = agg_small_grads_max_group + + def pack(self, grouped_grads_and_vars): + """Aggregate small tensors.""" + if (self.agg_small_grads_max_bytes > 0 and + self.agg_small_grads_max_group > 0): + tower_grads, self.packing = cross_tower_utils.pack_small_tensors( + grouped_grads_and_vars, + max_bytes=self.agg_small_grads_max_bytes, + max_group=self.agg_small_grads_max_group) + return tower_grads + + def unpack(self, summed_device_grad_packs): + """Reverse the aggregation process.""" + return cross_tower_utils.unpack_small_tensors(summed_device_grad_packs, + self.packing) + + +def _pack_tensors(device_grads, + num_packs=0, + agg_small_grads_max_bytes=0, + agg_small_grads_max_group=0): + """Pack tensors if specified.""" + if num_packs > 0: + tensor_packer = ConcatAndSplitPacker(num_packs) + device_grad_packs = tensor_packer.pack(device_grads) + elif agg_small_grads_max_bytes > 0 and agg_small_grads_max_group > 0: + tensor_packer = AggregateSmallTensorPacker(agg_small_grads_max_bytes, + agg_small_grads_max_group) + device_grad_packs = tensor_packer.pack(device_grads) + else: + tensor_packer = None + device_grad_packs = device_grads + return device_grad_packs, tensor_packer + + +def _unpack_tensors(reduced, tensor_packer=None): + """Unpack tensors if they are packed before all-reduce.""" + if tensor_packer: + return tensor_packer.unpack(reduced) + return reduced + + +class AllReduceCrossTowerOps(CrossTowerOps): + """Reduction using all reduce.""" + + def __init__(self, + all_reduce_alg="nccl", + num_packs=1, + agg_small_grads_max_bytes=0, + agg_small_grads_max_group=10): + """All-reduce implementation of CrossTowerOps. + + Before performing all-reduce, tensors will be repacked or aggregated for + more efficient cross-device transportation: + 1) If `num_packs` is non-zero, pack values into + `num_packs` splits. + 2) Otherwise, if `agg_small_grads_max_bytes` > 0 and + `agg_small_grads_max_group` > 0, aggregate values smaller than + `agg_small_grads_max_bytes` into groups with at most + `agg_small_grads_max_group` values. + 3) Otherwise, no repacking or grouping will happen. + + Args: + all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or + "hierarchical_copy" are supported. + num_packs: see above. + agg_small_grads_max_bytes: see above. + agg_small_grads_max_group: see above. + tensors. + """ + self._all_reduce_alg = all_reduce_alg + self._num_packs = num_packs + self._agg_small_grads_max_bytes = agg_small_grads_max_bytes + self._agg_small_grads_max_group = agg_small_grads_max_group + super(AllReduceCrossTowerOps, self).__init__() + + def _reduce(self, method_string, per_device_value, destinations): + contains_indexed_slices = cross_tower_utils.contains_indexed_slices( + per_device_value) + if ((destinations is None or _devices_match(per_device_value, destinations)) + and not context.executing_eagerly() + and not contains_indexed_slices): + return self._batch_all_reduce(method_string, [per_device_value])[0] + else: + if contains_indexed_slices: + logging.log_first_n( + logging.WARN, + "Efficient allreduce is not supported for IndexedSlices.", 10) + + devices = _get_devices_from(destinations or per_device_value) + reduce_to_device = devices[0] + reduced = _simple_reduce(per_device_value, reduce_to_device, + math_ops.add_n, method_string) + return self.broadcast(reduced, devices) + + def _batch_reduce(self, method_string, value_destination_pairs): + all_devices_match = _all_devices_match(value_destination_pairs) + contains_indexed_slices = cross_tower_utils.contains_indexed_slices( + value_destination_pairs) + if (all_devices_match and not context.executing_eagerly() + and not contains_indexed_slices): + return self._batch_all_reduce(method_string, + [v[0] for v in value_destination_pairs]) + else: + if not all_devices_match: + logging.warning("Efficient batch_reduce is not supported if " + "destinations are different.") + + return [ + self._reduce(method_string, t, destinations=v) + for t, v in value_destination_pairs + ] + + def _batch_all_reduce(self, method_string, per_device_values): + """All reduce algorithm in a batch.""" + logging.info( + "batch_all_reduce invoked for batches size = %d with " + "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and " + "agg_small_grads_max_group = %d", len(per_device_values), + self._all_reduce_alg, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) + destinations = per_device_values[0].devices + grouped = _group_value_by_device(per_device_values) + + device_grad_packs, self._tensor_packer = _pack_tensors( + grouped, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) + + # The actual aggregation of the repacked gradients. Note that they are + # sharded among different aggregation trees. So it is important to strike + # the balance on num_splits. + if self._all_reduce_alg == "nccl": + # TODO(yuefengz): merge this into the all-reduce library. + reduced = cross_tower_utils.aggregate_gradients_using_nccl( + device_grad_packs) + else: + # TODO(yuefengz): check that gpu ids in `destinations` are in ascending + # order. + reduced = ( + cross_tower_utils.aggregate_gradients_using_hierarchical_copy( + destinations, device_grad_packs)) + + reduced = _unpack_tensors(reduced, self._tensor_packer) + return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices, + method_string) + + +AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple", + "alg shards limit") + + +class MultiWorkerAllReduce(AllReduceCrossTowerOps): + """All-reduce algorithms for distributed TensorFlow.""" + + def __init__(self, + worker_devices, + num_gpus_per_worker, + all_reduce_spec=("pscpu/pscpu", 2, -1), + num_packs=0, + agg_small_grads_max_bytes=0, + agg_small_grads_max_group=10): + """Initialize the all-reduce algorithm. + + Args: + worker_devices: a list of device strings for workers participating in + all-reduce. + num_gpus_per_worker: number of GPU devices per worker. + all_reduce_spec: a tuple or a named tuple or a list of tuples specifying + the all-reduce algorithm. + 1. The first element of a tuple is the name of the all-reduce algorithm. + Valid algorithm names are: "nccl", "nccl/xring", "nccl/rechd", + "nccl/pscpu", "xring", "pscpu", "psgpu", "pscpu/pscpu". Algorithms with + a "/" are hierarchical, so two all-reduces are executed, the first one + aggregates tensors within a worker and the second aggregates across + workers. + 2. The second element of a tuple is the number of shards when doing + all-reduce. Let's say its values is M, each tensor after packing will be + split into M shards and then M parallel all-reduces would be performed + before finally they are concatenated backed into a complete tensor. + 3. The third element is the maximum size of tensors that will be + applicable for the algorithm specified by the first element. For + example, if all_reduce_spec=[("nccl", 2, 1024), ("pscpu/pscpu", 2, -1)], + tensors with size not larger than 1024 bytes will be applied a 2-shard + "nccl" all-reduce and other tensors will be applied a 2-shard + "pscpu/pscpu" algorithm. The third elements should be in increasing + order across tuples and end with -1 which indicates infinity. + num_packs: see AllReduceCrossTowerOps. + agg_small_grads_max_bytes: see AllReduceCrossTowerOps. + agg_small_grads_max_group: see AllReduceCrossTowerOps. + """ + self._worker_devices = worker_devices + self._num_gpus_per_worker = num_gpus_per_worker + super(MultiWorkerAllReduce, self).__init__( + num_packs=num_packs, + agg_small_grads_max_bytes=agg_small_grads_max_bytes, + agg_small_grads_max_group=agg_small_grads_max_group) + + def validate_and_complete_spec(spec): + """Validate and complete the all-reduce spec.""" + # TODO(yuefengz): support namedtuple. + if not isinstance(spec, tuple): + raise ValueError( + "A tuple is expected for all-reduce spec: %r" % all_reduce_spec) + if not spec or len(spec) > 3: + raise ValueError( + "Too many elements in the all-reduce spec tuple: %r" % spec) + if len(spec) == 1: + return AllReduceSpecTuple(spec[0], 1, -1) + elif len(spec) == 2: + return AllReduceSpecTuple(spec[0], spec[1], -1) + else: + return AllReduceSpecTuple(*spec) + + self._all_reduce_spec = [] + if isinstance(all_reduce_spec, six.string_types): + self._all_reduce_spec.append(AllReduceSpecTuple(all_reduce_spec, 1, -1)) + elif isinstance(all_reduce_spec, tuple): + self._all_reduce_spec.append(validate_and_complete_spec(all_reduce_spec)) + elif isinstance(all_reduce_spec, list): + self._all_reduce_spec = [ + validate_and_complete_spec(spec) for spec in all_reduce_spec + ] + + def _batch_all_reduce(self, method_string, per_device_values): + """All reduce algorithm in a batch.""" + logging.info( + "distributed batch_all_reduce invoked for batches size = %d with " + "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d " + "and agg_small_grads_max_group = %d", len(per_device_values), + self._all_reduce_spec, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) + + destinations = sorted(per_device_values[0].devices) + device_grads = _group_value_by_device(per_device_values) + + # The all reduce library requires fully defined shapes. + # TODO(yuefengz): when tensor sharding is not needed, static shapes are not + # required as well. + for device_grad in device_grads: + for grad, _ in device_grad: + if not grad.shape.is_fully_defined(): + raise ValueError("Shape is unknown for node %r" % grad) + + remaining_grads = device_grads + aggregated_grads = [] + for spec_tuple in self._all_reduce_spec: + if spec_tuple.limit < 0: + this_grads = remaining_grads + remaining_grads = [] + else: + (this_grads, remaining_grads) = cross_tower_utils.split_grads_by_size( + spec_tuple.limit, remaining_grads) + if this_grads: + device_grad_packs, self._tensor_packer = _pack_tensors( + this_grads, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) + range_agg_grads = cross_tower_utils.sum_gradients_all_reduce( + self._worker_devices, device_grad_packs, len(self._worker_devices), + spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker)) + range_agg_grads = _unpack_tensors(range_agg_grads, self._tensor_packer) + + if not aggregated_grads: + aggregated_grads = range_agg_grads + else: + assert len(aggregated_grads) == len(range_agg_grads) + for i in range(len(aggregated_grads)): + aggregated_grads[i] += range_agg_grads[i] + assert not remaining_grads + + return _ungroup_and_make_mirrored(aggregated_grads, destinations, + method_string) + + +_dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], + [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] + + +def _has_dgx1_like_links(gpu_links): + if not gpu_links: + return False + # TODO(yuefengz): figure out the right topology for hierarchial copy if + # number of gpus are less than 8. + if len(gpu_links) < 8: + return False + for i, (gpu_link, dgx1_link) in enumerate(zip(gpu_links, _dgx1_links)): + if (set(gpu_link) != set(dgx1_link) and + set(gpu_link) != set(dgx1_link + [i])): + return False + return True + + +def _choose_all_reduce_algorithm(device_links): + if _has_dgx1_like_links(device_links): + logging.info("Configured hierarchical_copy with num_packs=%d", + len(device_links)) + return AllReduceCrossTowerOps( + "hierarchical_copy", num_packs=len(device_links)) + else: + logging.info("Configured nccl all-reduce.") + return AllReduceCrossTowerOps("nccl", num_packs=1) + + +def choose_the_best(devices, session_config=None): + """Find the best subclass of CrossTowerOps given a tensorflow session. + + Args: + devices: a list of devices passed for distribute strategy. + session_config: a tensorflow session config or None. If None, it will make + deciesion based on all local devices. + + Returns: + a subclass of CrossTowerOps. + """ + requested_devices = set([device_util.canonicalize(d) for d in devices]) + machine_devices = device_lib.list_local_devices(session_config=session_config) + using_devices = [] + for d in machine_devices: + if device_util.canonicalize(d.name) in requested_devices: + using_devices.append(d) + else: + logging.info( + "Device is available but not used by distribute strategy: %s", d.name) + + if len(using_devices) != len(requested_devices): + logging.warning("Not all devices in distribute strategy are visible by " + "TensorFlow sessions.") + return ReductionToOneDeviceCrossTowerOps() + + if any([d.device_type.lower() != "gpu" for d in using_devices]): + logging.warning("Not all devices in DistributionStrategy are visible to " + "TensorFlow session.") + return ReductionToOneDeviceCrossTowerOps() + + device_links = [[] for _ in range(len(using_devices))] + for i, device in enumerate(using_devices): + for link in device.locality.links.link: + device_links[i].append(link.device_id) + + return _choose_all_reduce_algorithm(device_links) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fed5505d92ef2544215069736c166a67d6141708 --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -0,0 +1,366 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CrossTowerOps.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.training import device_util + + +def _make_per_device(values, devices): + devices = cross_tower_ops_lib._get_devices_from(devices) + assert len(values) == len(devices) + index = {} + for d, v in zip(devices, values): + with ops.device(d): + placed_v = array_ops.identity(v) + index[d] = placed_v + return value_lib.PerDevice(index) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def _fake_mirrored(value, devices): + """Create a faked Mirrored object for testing. + + All components of the returned Mirrored have the same objects, which is not + true in reality. + """ + devices = cross_tower_ops_lib._get_devices_from(devices) + return value_lib.Mirrored( + {d: v for d, v in zip(devices, [value] * len(devices))}) + + +def _make_indexed_slices(values, indices, dense_shape, device): + with ops.device(device): + tensor = ops.IndexedSlices( + values=constant_op.constant(values), + indices=constant_op.constant(indices), + dense_shape=constant_op.constant(dense_shape)) + return tensor + + +def _make_mirrored_indexed_slices(devices, values, indices, dense_shape): + return value_lib.Mirrored({ + d: _make_indexed_slices(values, indices, dense_shape, d) for d in devices + }) + + +_cpu_device = "/device:CPU:0" + + +class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase): + + def _assert_indexed_slices_equal(self, left, right): + self.assertIsInstance(left, ops.IndexedSlices) + self.assertIsInstance(right, ops.IndexedSlices) + self.assertEqual(device_util.resolve(left.device), + device_util.resolve(right.device)) + self.assertAllEqual( + self.evaluate(ops.convert_to_tensor(left)), + self.evaluate(ops.convert_to_tensor(right))) + + def _assert_values_equal(self, left, right): + if isinstance(left, list): + for l, r in zip(left, right): + self._assert_values_equal(l, r) + else: + self.assertEqual(type(left), type(right)) + self.assertEqual(left.devices, right.devices) + if isinstance(list(left._index.values())[0], ops.IndexedSlices): + for (d, v) in left._index.items(): + self._assert_indexed_slices_equal(v, right._index[d]) + elif context.executing_eagerly(): + self.assertEqual([v.numpy() for v in left._index.values()], + list(right._index.values())) + else: + with self.test_session() as sess: + self.assertEqual( + sess.run(list(left._index.values())), list(right._index.values())) + + def _testReductionAndBroadcast(self, cross_tower_ops, distribution): + devices = distribution.worker_devices + + values = [constant_op.constant(float(d)) for d in range(len(devices))] + per_device = _make_per_device(values, devices) + mean = (len(devices) - 1.) / 2. + + values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))] + per_device_2 = _make_per_device(values_2, devices) + mean_2 = mean + 1. + + destination_mirrored = _fake_mirrored(1., devices) + destination_different = _fake_mirrored(1., _cpu_device) + destination_str = _cpu_device + destination_list = devices + + all_destinations = [ + None, destination_mirrored, destination_different, destination_str, + destination_list + ] + + # test reduce() + for destinations in all_destinations: + self._assert_values_equal( + cross_tower_ops.reduce("mean", per_device, destinations=destinations), + _fake_mirrored(mean, destinations or per_device)) + self._assert_values_equal( + cross_tower_ops.reduce( + "mean", per_device_2, destinations=destinations), + _fake_mirrored(mean_2, destinations or per_device)) + self._assert_values_equal( + cross_tower_ops.reduce("sum", per_device, destinations=destinations), + _fake_mirrored(mean * len(devices), destinations or per_device)) + self._assert_values_equal( + cross_tower_ops.reduce( + "sum", per_device_2, destinations=destinations), + _fake_mirrored(mean_2 * len(devices), destinations or per_device)) + + # test batch_reduce() + for d1, d2 in itertools.product(all_destinations, all_destinations): + self._assert_values_equal( + cross_tower_ops.batch_reduce( + "mean", [(per_device, d1), (per_device_2, d2)]), + [_fake_mirrored(mean, d1 or per_device), + _fake_mirrored(mean_2, d2 or per_device_2)]) + self._assert_values_equal( + cross_tower_ops.batch_reduce( + "sum", [(per_device, d1), (per_device_2, d2)]), + [_fake_mirrored(mean * len(devices), d1 or per_device), + _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)]) + + # test broadcast() + for destinations in all_destinations: + if destinations is None: + continue + else: + self._assert_values_equal( + cross_tower_ops.broadcast(constant_op.constant(1.), destinations), + _fake_mirrored(1., destinations)) + + +class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase): + # TODO(yuefengz): decouple the num_gpus check from distribution in + # combinations module so that we can pass in devices instead of a distribution + # strategy. + reduction_to_one_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "DefaultReductionToOneDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), + combinations.NamedObject( + "ReductionToCPUDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + reduce_to_device=_cpu_device)), + combinations.NamedObject( + "AccumulateNCrossTowerOp", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + accumulation_fn=math_ops.accumulate_n)), + ], + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus + ], + mode=["graph", "eager"]) + allreduce_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "AllReduce", + cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)), + combinations.NamedObject( + "HierarchicalCopy", + cross_tower_ops_lib.AllReduceCrossTowerOps( + "hierarchical_copy", 8, 0, 0)), + combinations.NamedObject( + "AllReduceNoGradientRepacking", + cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)), + combinations.NamedObject( + "HierarchicalCopyAggregateSmallTensors", + cross_tower_ops_lib.AllReduceCrossTowerOps( + "hierarchical_copy", 0, 100, 10)) + ], + distribution=[combinations.mirrored_strategy_with_two_gpus], + mode=["graph", "eager"]) + + @combinations.generate(reduction_to_one_combinations + allreduce_combinations) + def testReductionAndBroadcast(self, cross_tower_ops, distribution): + with distribution.scope(): + self._testReductionAndBroadcast(cross_tower_ops, distribution) + + def testChooseAlgorithm(self): + device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], + [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] + result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) + self.assertEqual(result._all_reduce_alg, "hierarchical_copy") + self.assertEqual(result._num_packs, 8) + + # if there are only 4 devices + device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]] + result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) + self.assertEqual(result._all_reduce_alg, "nccl") + self.assertEqual(result._num_packs, 1) + + # if devices links contain each device itself + device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6], + [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7], + [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]] + result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) + self.assertEqual(result._all_reduce_alg, "hierarchical_copy") + self.assertEqual(result._num_packs, 8) + + # if not dgx1-like links + device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], + [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]] + result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) + self.assertEqual(result._all_reduce_alg, "nccl") + self.assertEqual(result._num_packs, 1) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testSimpleReduceWithIndexedSlices(self): + devices = ["/cpu:0", "/gpu:0"] + t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) + t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) + per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) + result = cross_tower_ops_lib._simple_reduce(per_device, devices[0], + math_ops.add_n, "sum") + + # Test that the result is semantically equal to both the concatenated + # IndexedSlices with and without duplicate indices. + total_with_dups = _make_indexed_slices( + [[1., 2.], [3., 4.], [5., 6.]], [1, 1, 3], [5, 2], devices[0]) + total_without_dups = _make_indexed_slices( + [[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0]) + self._assert_indexed_slices_equal(total_with_dups, result) + self._assert_indexed_slices_equal(total_without_dups, result) + + @combinations.generate(combinations.combine( + cross_tower_ops_instance=[ + combinations.NamedObject( + "ReductionToOneDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), + combinations.NamedObject( + "AllReduceCrossTowerOps", + cross_tower_ops_lib.AllReduceCrossTowerOps()) + ], + method_string=["sum", "mean"], + batch_reduce=[True, False], + mode=["graph", "eager"], + required_gpus=1)) + def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, + method_string, batch_reduce): + devices = ["/cpu:0", "/gpu:0"] + dense_shape = [5, 2] + t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0]) + t1 = _make_indexed_slices( + [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1]) + per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) + + if batch_reduce: + result = cross_tower_ops_instance.batch_reduce(method_string, + [(per_device, devices)]) + else: + result = cross_tower_ops_instance.reduce(method_string, per_device, + devices) + + total_indices_with_dups = [1, 1, 3] + total_indices_without_dups = [1, 3] + + if method_string == "sum": + total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]] + total_values_without_dups = [[4., 6.], [5., 6.]] + else: + assert method_string == "mean" + total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]] + total_values_without_dups = [[2., 3.], [2.5, 3.]] + + total_mirrored_with_dups = _make_mirrored_indexed_slices( + devices, total_values_with_dups, total_indices_with_dups, dense_shape) + total_mirrored_without_dups = _make_mirrored_indexed_slices( + devices, total_values_without_dups, total_indices_without_dups, + dense_shape) + + # Test that the result is semantically equal to both the concatenated + # IndexedSlices, as well as when the duplicate indices are summed up. + if batch_reduce: + total_mirrored_with_dups = [total_mirrored_with_dups] + total_mirrored_without_dups = [total_mirrored_without_dups] + + self._assert_values_equal(total_mirrored_with_dups, result) + self._assert_values_equal(total_mirrored_without_dups, result) + + +class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase, + CrossTowerOpsTestBase): + + worker_devices = [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + multi_worker_allreduce_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "MultiWorkerAllReduce", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 0, 0)), + combinations.NamedObject( + "MultiWorkerAllReducePack", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 1, 0, 0)), + combinations.NamedObject( + "MultiWorkerAllReduceAggregation", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 100, 10)), + combinations.NamedObject( + "MultiWorkerAllReduceMultipleSpecs", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, [("pscpu/pscpu", 2, 100), + ("xring", 2, -1)], 0, 0, 0)), + ], + distribution=[ + combinations.multi_worker_strategy_with_cpu, + combinations.multi_worker_strategy_with_one_gpu, + combinations.multi_worker_strategy_with_two_gpus + ], + mode=["graph"]) + + @combinations.generate(multi_worker_allreduce_combinations) + def testReductionAndBroadcast(self, cross_tower_ops, distribution): + with distribution.scope(): + self._testReductionAndBroadcast(cross_tower_ops, distribution) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2bb088e704c584598b863b1b836166af2a5bb12c --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -0,0 +1,527 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for cross_tower_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections as pycoll + +from tensorflow.contrib import nccl +from tensorflow.contrib.all_reduce.python import all_reduce +from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops + + +def aggregate_gradients_using_nccl(tower_grads): + """Aggregate gradients using nccl allreduce.""" + agg_all_g_and_v = [] + for single_g_and_v in zip(*tower_grads): + single_grads = [g for g, _ in single_g_and_v] + agg_grads = nccl.all_sum(single_grads) + agg_all_g_and_v.append( + [(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)]) + + agg_all_g_and_v = list(zip(*agg_all_g_and_v)) + + return agg_all_g_and_v + + +def aggregate_gradients_using_hierarchical_copy(avail_devices, tower_grads): + """Aggregate gradients using hierarchical copies. + + Args: + avail_devices: available GPU devices. + tower_grads: List of lists of (gradient, variable) tuples. The outer list + is over towers. The inner list is over individual gradients. + + Returns: + The list of (aggregated_gradient, variable), where the gradient has been + summed across all towers and the variable is chosen from the first tower. + """ + # This only works for DGX-1 type of machine topology + # Device peer to peer matrix + # DMA: 0 1 2 3 4 5 6 7 + # 0: Y Y Y Y Y N N N + # 1: Y Y Y Y N Y N N + # 2: Y Y Y Y N N Y N + # 3: Y Y Y Y N N N Y + # 4: Y N N N Y Y Y Y + # 5: N Y N N Y Y Y Y + # 6: N N Y N Y Y Y Y + # 7: N N N Y Y Y Y Y + agg_grads = [] + num_devices = len(avail_devices) + # In the special case of DGX-1 machine topology, the two groups have equal + # size. + group_size = num_devices // 2 + for i, single_grads in enumerate(zip(*tower_grads)): + group_0_main_device = i % num_devices + group_1_main_device = (group_0_main_device + group_size) % num_devices + if group_0_main_device < group_size: + group_0_begin = 0 + group_1_begin = group_size + else: + group_0_begin = group_size + group_1_begin = 0 + + # Aggregate the first group. + group_0_device_grads = single_grads[group_0_begin: + group_0_begin + group_size] + with ops.device(avail_devices[group_0_main_device]): + group_0_agg_grads, _ = aggregate_single_gradient_using_copy( + group_0_device_grads, False, False) + + # Aggregate the second group. + group_1_device_grads = single_grads[group_1_begin: + group_1_begin + group_size] + with ops.device(avail_devices[group_1_main_device]): + group_1_agg_grads, _ = aggregate_single_gradient_using_copy( + group_1_device_grads, False, False) + + # Aggregate between the groups. + with ops.device(avail_devices[group_0_main_device]): + (agg_total_grads, _), _ = aggregate_single_gradient_using_copy( + [group_0_agg_grads, group_1_agg_grads], False, False) + + # Broadcast the result back into the root of each group. + with ops.device(avail_devices[group_0_main_device]): + group_0_agg_grads_bcast = array_ops.identity(agg_total_grads) + with ops.device(avail_devices[group_1_main_device]): + group_1_agg_grads_bcast = array_ops.identity(agg_total_grads) + + agg_grads_bcast = [] + for j in range(len(single_grads)): + with ops.device(avail_devices[j]): + # Broadcast the result back to each member in the group from the root. + if (group_0_main_device < group_size) == (j < group_size): + src_device_grad = group_0_agg_grads_bcast + else: + src_device_grad = group_1_agg_grads_bcast + agg_grads_bcast.append(array_ops.identity(src_device_grad)) + + agg_grads.append( + [(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)]) + + agg_grads = list(zip(*agg_grads)) + + return agg_grads + + +def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, + check_inf_nan): + """Calculate the average gradient for a shared variable across all towers. + + Note that this function provides a synchronization point across all towers. + + Args: + grad_and_vars: A list or tuple of (gradient, variable) tuples. Each + (gradient, variable) pair within the outer list represents the gradient + of the variable calculated for a single tower, and the number of pairs + equals the number of towers. + use_mean: if True, mean is taken, else sum of gradients is taken. + check_inf_nan: check grads for nans and infs. + + Returns: + The tuple ([(average_gradient, variable),], has_nan_or_inf) where the + gradient has been averaged across all towers. The variable is chosen from + the first tower. The has_nan_or_inf indicates the grads has nan or inf. + """ + grads = [g for g, _ in grad_and_vars] + grad = math_ops.add_n(grads) + + if use_mean and len(grads) > 1: + grad = array_ops.multiply(grad, 1.0 / len(grads)) + + v = grad_and_vars[0][1] + if check_inf_nan: + has_nan_or_inf = array_ops.logical_not( + array_ops.reduce_all(array_ops.is_finite(grads))) + return (grad, v), has_nan_or_inf + else: + return (grad, v), None + + +def group_device_names(devices, group_size): + """Group device names into groups of group_size. + + Args: + devices: a list of canonical device strings. + group_size: integer which is equal to or greater than 1. + + Returns: + list of lists of devices, where each inner list is group_size long, + and each device appears at least once in an inner list. If + len(devices) % group_size == 0 then each device will appear exactly once. + + Raises: + ValueError: if group_size > len(devices) + """ + num_devices = len(devices) + if group_size > num_devices: + raise ValueError( + 'only %d devices, but group_size=%d' % (num_devices, group_size)) + num_groups = ( + num_devices // group_size + (1 if (num_devices % group_size != 0) else 0)) + groups = [[] for i in range(num_groups)] + for i in range(num_groups * group_size): + groups[i % num_groups].append(devices[i % num_devices]) + return groups + + +def split_grads_by_size(threshold_size, device_grads): + """Break gradients into two sets according to tensor size. + + Args: + threshold_size: int size cutoff for small vs large tensor. + device_grads: List of lists of (gradient, variable) tuples. The outer + list is over devices. The inner list is over individual gradients. + + Returns: + small_grads: Subset of device_grads where shape is <= threshold_size + elements. + large_grads: Subset of device_grads where shape is > threshold_size + elements. + """ + small_grads = [] + large_grads = [] + for dl in device_grads: + small_dl = [] + large_dl = [] + for (g, v) in dl: + tensor_size = g.get_shape().num_elements() + if tensor_size <= threshold_size: + small_dl.append([g, v]) + else: + large_dl.append([g, v]) + if small_dl: + small_grads.append(small_dl) + if large_dl: + large_grads.append(large_dl) + return small_grads, large_grads + + +def sum_grad_and_var_all_reduce(grad_and_vars, + num_workers, + alg, + gpu_indices, + aux_devices=None, + num_shards=1): + """Apply all-reduce algorithm over specified gradient tensors.""" + with ops.name_scope('allreduce'): + # Note that each grad_and_vars looks like the following: + # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) + scaled_grads = [g for g, _ in grad_and_vars] + if alg == 'nccl': + summed_grads = nccl.all_sum(scaled_grads) + elif alg == 'xring': + summed_grads = all_reduce.build_ring_all_reduce( + scaled_grads, num_workers, num_shards, gpu_indices, math_ops.add) + elif alg == 'nccl/xring': + summed_grads = all_reduce.build_nccl_then_ring(scaled_grads, num_shards, + math_ops.add) + elif alg == 'nccl/rechd': + summed_grads = all_reduce.build_nccl_then_recursive_hd( + scaled_grads, math_ops.add) + elif alg == 'nccl/pscpu': + summed_grads = all_reduce.build_nccl_then_shuffle( + scaled_grads, aux_devices, math_ops.add, math_ops.add_n) + elif alg == 'pscpu/pscpu': + second_gather_devices = aux_devices[:num_shards] + summed_grads = all_reduce.build_shuffle_then_shuffle( + scaled_grads, aux_devices, second_gather_devices, math_ops.add_n) + elif alg in ['pscpu', 'psgpu']: + summed_grads = all_reduce.build_shuffle_all_reduce( + scaled_grads, aux_devices, math_ops.add_n) + else: + raise ValueError('unsupported all_reduce alg: ', alg) + + result = [] + for (_, v), g in zip(grad_and_vars, summed_grads): + result.append([g, v]) + return result + + +def sum_gradients_all_reduce(dev_prefixes, tower_grads, num_workers, alg, + num_shards, gpu_indices): + """Apply all-reduce algorithm over specified gradient tensors. + + Args: + dev_prefixes: list of prefix strings to use to generate PS device names. + tower_grads: the gradients to reduce. + num_workers: number of worker processes across entire job. + alg: the all-reduce algorithm to apply. + num_shards: alg-specific sharding factor. + gpu_indices: indices of local GPUs in order usable for ring-reduce. + + Returns: + list of reduced tensors + """ + alg_contains_shuffle = any([n in alg for n in ['pscpu', 'psgpu']]) + is_hierarchical = '/' in alg + if 'pscpu' in alg: + aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes] + elif 'psgpu' in alg: + aux_devices = [ + prefix + '/gpu:%d' % i + for i in range(len(gpu_indices)) + for prefix in dev_prefixes + ] + else: + aux_devices = ['/job:localhost/cpu:0'] + # Auxiliary devices for hierarchical all-reduces. + aux_device_groups = group_device_names( + aux_devices, num_shards if alg_contains_shuffle else 1) + group_index = 0 + reduced_gv_list = [] + for grad_and_vars in zip(*tower_grads): + reduced_gv_list.append( + sum_grad_and_var_all_reduce( + grad_and_vars, num_workers, alg, gpu_indices, aux_devices + if is_hierarchical else aux_device_groups[group_index], num_shards)) + group_index = (group_index + 1) % len(aux_device_groups) + new_tower_grads = [list(x) for x in zip(*reduced_gv_list)] + return new_tower_grads + + +def extract_ranges(index_list, range_size_limit=32): + """Extract consecutive ranges and singles from index_list. + + Args: + index_list: List of monotone increasing non-negative integers. + range_size_limit: Largest size range to return. If a larger + consecutive range exists, it will be returned as multiple + ranges. + + Returns: + (ranges, singles) where ranges is a list of [first, last] pairs of + consecutive elements in index_list, and singles is all of the + other elements, in original order. + """ + if not index_list: + return [], [] + first = index_list[0] + last = first + ranges = [] + singles = [] + for i in index_list[1:]: + if i == last + 1 and (last - first) <= range_size_limit: + last = i + else: + if last > first: + ranges.append([first, last]) + else: + singles.append(first) + first = i + last = i + if last > first: + ranges.append([first, last]) + else: + singles.append(first) + return ranges, singles + + +GradPackTuple = pycoll.namedtuple('GradPackTuple', 'indices vars shapes') + + +def pack_range(key, packing, grad_vars, rng): + """Form the concatenation of a specified range of gradient tensors. + + Args: + key: Value under which to store meta-data in packing that will be used + later to restore the grad_var list structure. + packing: Dict holding data describing packed ranges of small tensors. + grad_vars: List of (grad, var) pairs for one tower. + rng: A pair of integers giving the first, last indices of a consecutive + range of tensors to be packed. + + Returns: + A tensor that is the concatenation of all the specified small tensors. + """ + to_pack = grad_vars[rng[0]:rng[1] + 1] + members = [] + variables = [] + restore_shapes = [] + with ops.name_scope('pack'): + for g, v in to_pack: + variables.append(v) + restore_shapes.append(g.shape) + with ops.device(g.device): + members.append(array_ops.reshape(g, [-1])) + packing[key] = GradPackTuple( + indices=range(rng[0], rng[1] + 1), + vars=variables, + shapes=restore_shapes) + with ops.device(members[0].device): + return array_ops.concat(members, 0) + + +def unpack_grad_tuple(gv, gpt): + """Unpack a previously packed collection of gradient tensors. + + Args: + gv: A (grad, var) pair to be unpacked. + gpt: A GradPackTuple describing the packing operation that produced gv. + + Returns: + A list of (grad, var) pairs corresponding to the values that were + originally packed into gv, maybe following subsequent operations like + reduction. + """ + elt_widths = [x.num_elements() for x in gpt.shapes] + with ops.device(gv[0][0].device): + with ops.name_scope('unpack'): + splits = array_ops.split(gv[0], elt_widths) + unpacked_gv = [] + for idx, s in enumerate(splits): + unpacked_gv.append((array_ops.reshape(s, gpt.shapes[idx]), + gpt.vars[idx])) + return unpacked_gv + + +def pack_small_tensors(tower_grads, max_bytes=0, max_group=0): + """Concatenate small gradient tensors together for reduction. + + Args: + tower_grads: List of lists of (gradient, variable) tuples. + max_bytes: Int giving max number of bytes in a tensor that + may be considered small. + max_group: Int giving max number of small tensors that may be + concatenated into one new tensor. + + Returns: + new_tower_grads, packing where new_tower_grads is identical to + tower_grads except that all feasible small_tensors have been removed + from their places and concatenated into larger tensors that are + now in the front of the list for each tower, and packing contains + the data necessary to restore the tower_grads structure. + + Look through the first tower for gradients of the same type (float), + and small size, that are all sequential. For each such group, + replace by a new tensor that is a flattened concatenation. Note + that the corresponding variable will be absent, which doesn't matter + because it isn't used during all-reduce. + + Requires: + Every gv_list in towers must have isomorphic structure including identical + tensor sizes and types. + """ + small_indices = [] + large_indices = [] + for idx, (g, _) in enumerate(tower_grads[0]): + if g.dtype == dtypes.float32 and (4 * g.shape.num_elements()) <= max_bytes: + small_indices.append(idx) + else: + large_indices.append(idx) + small_ranges, small_singles = extract_ranges( + small_indices, range_size_limit=max_group) + large_indices = sorted(large_indices + small_singles) + num_gv = len(tower_grads[0]) + packing = {} + if small_ranges: + new_tower_grads = [] + for dev_idx, gv_list in enumerate(tower_grads): + assert len(gv_list) == num_gv + new_gv_list = [] + for r in small_ranges: + key = '%d:%d' % (dev_idx, len(new_gv_list)) + new_gv_list.append((pack_range(key, packing, gv_list, r), + 'packing_var_placeholder')) + for i in large_indices: + new_gv_list.append(gv_list[i]) + new_tower_grads.append(new_gv_list) + return new_tower_grads, packing + else: + return tower_grads, None + + +def unpack_small_tensors(tower_grads, packing): + """Undo the structure alterations to tower_grads done by pack_small_tensors. + + Args: + tower_grads: List of List of (grad, var) tuples. + packing: A dict generated by pack_small_tensors describing the changes + it made to tower_grads. + + Returns: + new_tower_grads: identical to tower_grads except that concatenations + of small tensors have been split apart and returned to their original + positions, paired with their original variables. + """ + if not packing: + return tower_grads + new_tower_grads = [] + num_devices = len(tower_grads) + num_packed = len(packing.keys()) // num_devices + for dev_idx, gv_list in enumerate(tower_grads): + gv_list = list(gv_list) + new_gv_list = gv_list[num_packed:] + for i in range(num_packed): + k = '%d:%d' % (dev_idx, i) + gpt = packing[k] + gv = unpack_grad_tuple(gv_list[i], gpt) + for gi, idx in enumerate(gpt.indices): + assert idx == gpt.indices[gi] + new_gv_list.insert(idx, gv[gi]) + new_tower_grads.append(new_gv_list) + return new_tower_grads + + +def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n): + """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat.""" + if any(isinstance(v, ops.IndexedSlices) for v in values): + return gradients_impl._AggregateIndexedSlicesGradients(values) # pylint: disable=protected-access + else: + return accumulation_fn(values) + + +def divide_by_n_tensors_or_indexed_slices(value, n): + if isinstance(value, ops.IndexedSlices): + value = gradients_impl._HandleNestedIndexedSlices(value) # pylint: disable=protected-access + return ops.IndexedSlices( + value.values / n, value.indices, value.dense_shape) + else: + return value / n + + +def copy_tensor_or_indexed_slices_to_device(value, device): + with ops.device(device): + if isinstance(value, ops.IndexedSlices): + copied_values = array_ops.identity(value.values) + copied_indices = array_ops.identity(value.indices) + copied_shape = array_ops.identity(value.dense_shape) + result = ops.IndexedSlices(copied_values, copied_indices, copied_shape) + else: + result = array_ops.identity(value) + return result + + +def contains_indexed_slices(value): + """Check whether the value is `IndexedSlices` or contains `IndexedSlices`.""" + if isinstance(value, ops.IndexedSlices): + return True + elif isinstance(value, (list, tuple)) and value: + return any(contains_indexed_slices(v) for v in value) + elif isinstance(value, value_lib.DistributedValues): + return contains_indexed_slices(list(value._index.values())) # pylint: disable=protected-access + elif isinstance(value, value_lib.MapOutput): + return contains_indexed_slices(value.get()) + else: + return False diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4ef8db681503dcef8c72f641455dbb999cef05cf --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py @@ -0,0 +1,152 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for cross_tower_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops +from tensorflow.python.training import device_util + + +class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): + + def _assert_values_equal(self, left, right): + self.assertAllEqual( + self.evaluate(ops.convert_to_tensor(left)), + self.evaluate(ops.convert_to_tensor(right))) + + @test_util.run_in_graph_and_eager_modes() + def testAggregateTensors(self): + t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]]) + total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) + result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + self._assert_values_equal(total, result) + + @test_util.run_in_graph_and_eager_modes() + def testAggregateIndexedSlices(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) + result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(total, result) + + @test_util.run_in_graph_and_eager_modes() + def testDivideTensor(self): + t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + n = 2 + expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) + result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) + self._assert_values_equal(expected, result) + + @test_util.run_in_graph_and_eager_modes() + def testDivideIndexedSlices(self): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + n = 2 + expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) + result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(expected, result) + + @test_util.run_in_graph_and_eager_modes() + def testIsIndexedSlices(self): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + self.assertTrue(cross_tower_utils.contains_indexed_slices(t)) + + @test_util.run_in_graph_and_eager_modes() + def testContainsIndexedSlices_List(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1])) + + @test_util.run_in_graph_and_eager_modes() + def testContainsIndexedSlices_Tuple(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1))) + + @test_util.run_in_graph_and_eager_modes() + def testContainsIndexedSlices_PerDevice(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1}) + self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) + + @test_util.run_in_graph_and_eager_modes() + def testContainsIndexedSlices_PerDeviceMapOutput(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + per_device = value_lib.PerDevice({ + "/gpu:0": value_lib.MapOutput([t0]), + "/cpu:0": value_lib.MapOutput([t1])}) + self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testCopyTensor(self): + with ops.device("/cpu:0"): + t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + destination = "/gpu:0" + result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + t, destination) + + self._assert_values_equal(t, result) + self.assertEqual(device_util.resolve(destination), + device_util.resolve(result.device)) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testCopyIndexedSlices(self): + with ops.device("/cpu:0"): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + destination = "/gpu:0" + result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + t, destination) + + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(t, result) + self.assertEqual(device_util.resolve(destination), + device_util.resolve(result.device)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py new file mode 100644 index 0000000000000000000000000000000000000000..34410a6470185ac2821bc6a59de9230ff478aeb6 --- /dev/null +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -0,0 +1,128 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests that show that DistributionStrategy works with canned Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile +from absl.testing import parameterized +import numpy as np +import six + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.optimizer_v2 import adagrad +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import test +from tensorflow.python.estimator import run_config +from tensorflow.python.estimator.canned import dnn_linear_combined +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import ops +from tensorflow.python.platform import gfile +from tensorflow.python.summary.writer import writer_cache + + +class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, + parameterized.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def dataset_input_fn(self, x, y, batch_size, shuffle): + + def input_fn(): + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + if shuffle: + dataset = dataset.shuffle(batch_size) + dataset = dataset.repeat(10).batch(batch_size) + return dataset + + return input_fn + + @combinations.generate( + combinations.combine( + mode=['graph'], + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus + ])) + def test_complete_flow_with_mode(self, distribution): + label_dimension = 2 + input_dimension = label_dimension + batch_size = 10 + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + data = data.reshape(batch_size, label_dimension) + train_input_fn = self.dataset_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size // len(distribution.worker_devices), + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, y=data, batch_size=batch_size, shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, batch_size=batch_size, shuffle=False) + + linear_feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,)) + ] + dnn_feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,)) + ] + feature_columns = linear_feature_columns + dnn_feature_columns + estimator = dnn_linear_combined.DNNLinearCombinedRegressor( + linear_feature_columns=linear_feature_columns, + dnn_hidden_units=(2, 2), + dnn_feature_columns=dnn_feature_columns, + label_dimension=label_dimension, + model_dir=self._model_dir, + # TODO(isaprykin): Work around the colocate_with error. + dnn_optimizer=adagrad.AdagradOptimizer(0.001), + linear_optimizer=adagrad.AdagradOptimizer(0.001), + config=run_config.RunConfig(train_distribute=distribution)) + + num_steps = 10 + estimator.train(train_input_fn, steps=num_steps) + + scores = estimator.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + predictions = np.array([ + x[prediction_keys.PredictionKeys.PREDICTIONS] + for x in estimator.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, label_dimension), predictions.shape) + + feature_spec = feature_column.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = estimator.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/examples/BUILD b/tensorflow/contrib/distribute/python/examples/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..cbfd17850212a1c007e2edb9dd3986b3109f040d --- /dev/null +++ b/tensorflow/contrib/distribute/python/examples/BUILD @@ -0,0 +1,30 @@ +# Example TensorFlow models that use DistributionStrategy for training. + +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_binary( + name = "simple_estimator_example", + srcs = ["simple_estimator_example.py"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "simple_tfkeras_example", + srcs = [ + "simple_tfkeras_example.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py new file mode 100644 index 0000000000000000000000000000000000000000..00c25c7a2482a559c8b94ff3be86c4961dfb439f --- /dev/null +++ b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py @@ -0,0 +1,87 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A simple example to test the a DistributionStrategy with Estimators. + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +def build_model_fn_optimizer(): + """Simple model_fn with optimizer.""" + # TODO(anjalisridhar): Move this inside the model_fn once OptimizerV2 is + # done? + optimizer = tf.train.GradientDescentOptimizer(0.2) + + def model_fn(features, labels, mode): # pylint: disable=unused-argument + """model_fn which uses a single unit Dense layer.""" + # You can also use the Flatten layer if you want to test a model without any + # weights. + layer = tf.layers.Dense(1, use_bias=True) + logits = layer(features) + + if mode == tf.estimator.ModeKeys.PREDICT: + predictions = {"logits": logits} + return tf.estimator.EstimatorSpec(mode, predictions=predictions) + + def loss_fn(): + y = tf.reshape(logits, []) - tf.constant(1.) + return y * y + + if mode == tf.estimator.ModeKeys.EVAL: + return tf.estimator.EstimatorSpec(mode, loss=loss_fn()) + + assert mode == tf.estimator.ModeKeys.TRAIN + + global_step = tf.train.get_global_step() + train_op = optimizer.minimize(loss_fn(), global_step=global_step) + return tf.estimator.EstimatorSpec(mode, loss=loss_fn(), train_op=train_op) + + return model_fn + + +def main(_): + distribution = tf.contrib.distribute.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]) + config = tf.estimator.RunConfig(train_distribute=distribution) + + def input_fn(): + features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) + labels = tf.data.Dataset.from_tensors([1.]).repeat(10) + return tf.data.Dataset.zip((features, labels)) + + estimator = tf.estimator.Estimator( + model_fn=build_model_fn_optimizer(), config=config) + estimator.train(input_fn=input_fn, steps=10) + + eval_result = estimator.evaluate(input_fn=input_fn) + print("Eval result: {}".format(eval_result)) + + def predict_input_fn(): + predict_features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) + return predict_features + + predictions = estimator.predict(input_fn=predict_input_fn) + # TODO(anjalsridhar): This returns a generator object, figure out how to get + # meaningful results here. + print("Prediction results: {}".format(predictions)) + + +if __name__ == "__main__": + tf.app.run() diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py new file mode 100644 index 0000000000000000000000000000000000000000..2b05884b9b93470ef9a764cbedbc91bd3912c611 --- /dev/null +++ b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py @@ -0,0 +1,71 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An example of training tf.keras Model using MirroredStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import numpy as np +import tensorflow as tf + + +def input_fn(): + x = np.random.random((1024, 10)) + y = np.random.randint(2, size=(1024, 1)) + x = tf.cast(x, tf.float32) + dataset = tf.data.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(10) + dataset = dataset.batch(32) + return dataset + + +def main(args): + if len(args) < 2: + print('You must specify model_dir for checkpoints such as' + ' /tmp/tfkeras_example/.') + return + + model_dir = args[1] + print('Using %s to store checkpoints.' % model_dir) + + # Define tf.keras Model. + model = tf.keras.Sequential() + model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,))) + model.add(tf.keras.layers.Dense(1, activation='sigmoid')) + + # Compile tf.keras Model. + optimizer = tf.train.GradientDescentOptimizer(0.2) + model.compile(loss='binary_crossentropy', optimizer=optimizer) + model.summary() + tf.keras.backend.set_learning_phase(True) + + # Define a DistributionStrategy and convert the tf.keras Model to a + # tf.Estimator that utilizes the DistributionStrategy. + strategy = tf.contrib.distribute.MirroredStrategy( + ['/device:GPU:0', '/device:GPU:1']) + config = tf.estimator.RunConfig(train_distribute=strategy) + keras_estimator = tf.keras.estimator.model_to_estimator( + keras_model=model, config=config, model_dir=model_dir) + + # Train and evaluate the tf.Estimator. + keras_estimator.train(input_fn=input_fn, steps=10) + eval_result = keras_estimator.evaluate(input_fn=input_fn) + print('Eval result: {}'.format(eval_result)) + + +if __name__ == '__main__': + tf.app.run(argv=sys.argv) diff --git a/tensorflow/contrib/distribute/python/input_ops.py b/tensorflow/contrib/distribute/python/input_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..1f24f629479b6ae93bbb8a6dfe0b33c4f6a7da35 --- /dev/null +++ b/tensorflow/contrib/distribute/python/input_ops.py @@ -0,0 +1,141 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Input-pipeline utilities for Distribution strategies.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import readers +from tensorflow.python.data.util import nest +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import tf_logging + +# TODO(priyag): Any other reader datasets to consider here? +_READER_DATASET_OPS = [ + "TextLineDataset", + "TFRecordDataset", + "FixedLengthRecordDataset" +] + + +# pylint: disable=protected-access +def auto_shard_dataset(dataset, num_shards, index): + """Shard the input pipeline by sharding the underlying list of files. + + Args: + dataset: A `tf.data.Dataset` instance, typically the result of a bunch of + dataset transformations. + num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of + shards operating in parallel. Same usage as in `Dataset.shard`. + index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. + Same usage as in `Dataset.shard`. + + Returns: + A modified `Dataset` obtained by updating the pipeline sharded by the + files. + + Raises: + NotImplementedError: If we cannot automatically determine a good way to + shard the input dataset. + """ + + # TODO(priyag): Clone datasets instead of updating in place, similar to the + # clone method for TFRecordDataset. + def _auto_shard_impl(dataset, found_reader_op): + """Recursive implementation of auto sharding.""" + + if not found_reader_op: + # TODO(priyag): Make this check more robust by enforcing some common + # property on reader datasets. + if (isinstance(dataset, readers.TextLineDataset) or + isinstance(dataset, readers.FixedLengthRecordDataset)): + filenames_tensor = dataset._filenames + num_files = array_ops.size(filenames_tensor) + sharded_filenames_tensor = array_ops.gather( + filenames_tensor, math_ops.range(index, num_files, num_shards)) + dataset._filenames = sharded_filenames_tensor + return dataset + elif isinstance(dataset, readers.TFRecordDataset): + # `TFRecordDataset` needs to be handled separately than other readers + # because it converts filenames to a dataset first. Also, we clone it + # instead of updating in place because it has special logic in the + # constructor. Eventually we will change all cases to clone datasets + # instead of updating in-place. + return dataset._clone( + filenames=dataset._filenames.shard(num_shards, index)) + elif hasattr(dataset, "_map_func"): + # TODO(priyag): Make this check more robust by enforcing some common + # property on all map/flatmap/interleave datasets. + map_func_def = dataset._map_func.definition + for node in map_func_def.node_def: + if node.op in _READER_DATASET_OPS: + found_reader_op = True + break + elif node.op == "FlatMapDataset": + # TODO(priyag): Should this check for other map datasets? Should it + # be recursive? It is too specific to implementation of + # TFRecordDataset right now. + nested_func_name = node.attr["f"].func.name + nested_func = ops.get_default_graph()._functions[nested_func_name] + for nested_node in nested_func.definition.node_def: + if nested_node.op in _READER_DATASET_OPS: + found_reader_op = True + break + if found_reader_op: + break + if found_reader_op: + dataset._input_dataset = _auto_shard_impl( + dataset._input_dataset, found_reader_op) + return dataset + + # TODO(priyag): Make _input_dataset(s) a common property of all datasets to + # make this check more robust. + if hasattr(dataset, "_input_dataset"): + dataset._input_dataset = _auto_shard_impl( + dataset._input_dataset, found_reader_op) + if hasattr(dataset, "_dataset_to_concatenate"): + # Special case for `ConcatentateDataset`. We want to shard all input + # datasets. + dataset._dataset_to_concatenate = _auto_shard_impl( + dataset._dataset_to_concatenate, found_reader_op) + return dataset + + if hasattr(dataset, "_datasets"): + # Special case for `ZipDataset`. + dataset._datasets = nest.pack_sequence_as(dataset._datasets, [ + _auto_shard_impl(ds, found_reader_op) + for ds in nest.flatten(dataset._datasets) + ]) + return dataset + + if not found_reader_op: + tf_logging.warn( + "Could not find a standard reader in the input pipeline" + "(one of TextLineDataset, TFRecordDataset, FixedLengthRecordDataset)." + "Falling back to sharding the dataset anyway. Please verify" + "correctness of auto-sharding for your input.") + + # TODO(priyag): What do we want to do if the number of filenames is + # uneven in the number of shards? By default, this will just return as + # many items it can before throwing OutOfRangeError. + # TODO(priyag): This will shard the filenames before any shuffling of the + # filename dataset. It might be desirable to shard after shuffling + # filenames? If so, how do we achieve that? + return dataset.shard(num_shards, index) + + return _auto_shard_impl(dataset=dataset, found_reader_op=False) diff --git a/tensorflow/contrib/distribute/python/input_ops_test.py b/tensorflow/contrib/distribute/python/input_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..16179c3a4903c8149800d411853af734c1633466 --- /dev/null +++ b/tensorflow/contrib/distribute/python/input_ops_test.py @@ -0,0 +1,265 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for input pipeline modifications for distribution strategies.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.contrib.distribute.python import input_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers +from tensorflow.python.framework import errors +from tensorflow.python.lib.io import python_io +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class AutoShardDatasetTest(test.TestCase): + + def setUp(self): + super(AutoShardDatasetTest, self).setUp() + self._num_files = 10 + self._num_records = 4 + self._num_shards = 2 + self._shard_index = 0 + self._record_bytes = 10 + + def _record(self, r, f): + return compat.as_bytes("Record %d of file %d" % (r, f)) + + def _text_line(self, r, f): + return compat.as_bytes("Text line %d of file %d" % (r, f)) + + def _fixed_length_record(self, r, f): + return compat.as_bytes(str((r * f) % 10) * self._record_bytes) + + def _createTFRecordFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) + filenames.append(fn) + writer = python_io.TFRecordWriter(fn) + for j in range(self._num_records): + record = self._record(j, i) + writer.write(record) + writer.close() + return filenames + + def _createTextFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i) + filenames.append(fn) + contents = [] + for j in range(self._num_records): + contents.append(self._text_line(j, i)) + if j + 1 != self._num_records or i == 0: + contents.append(b"\r\n") + contents = b"".join(contents) + + with open(fn, "wb") as f: + f.write(contents) + return filenames + + def _createFixedLengthRecordFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) + filenames.append(fn) + with open(fn, "wb") as f: + for j in range(self._num_records): + f.write(self._fixed_length_record(j, i)) + return filenames + + def _verifySimpleShardingOutput(self, dataset, record_fn): + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + with self.test_session() as sess: + for f in range(self._shard_index, self._num_files, self._num_shards): + for r in range(self._num_records): + self.assertAllEqual(record_fn(r, f), sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testTFRecordDataset(self): + dataset = readers.TFRecordDataset(self._createTFRecordFiles()) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._record) + + def testFlatMap(self): + dataset = dataset_ops.Dataset.from_tensor_slices( + self._createTFRecordFiles()) + dataset = dataset.flat_map(readers.TFRecordDataset) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._record) + + def testInterleave(self): + dataset = dataset_ops.Dataset.from_tensor_slices( + self._createTFRecordFiles()) + dataset = dataset.interleave( + readers.TFRecordDataset, cycle_length=4, block_length=self._num_records) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + # Since block_length == num records in each file, the output will still + # contain records in order of files. + self._verifySimpleShardingOutput(dataset, self._record) + + def testParallelInterleave(self): + dataset = dataset_ops.Dataset.from_tensor_slices( + self._createTFRecordFiles()) + dataset = dataset.apply(interleave_ops.parallel_interleave( + readers.TFRecordDataset, + cycle_length=4, + block_length=self._num_records)) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + # Since block_length == num records in each file, the output will still + # contain records in order of files. + self._verifySimpleShardingOutput(dataset, self._record) + + def testListfiles(self): + filenames = self._createTFRecordFiles() + file_pattern = filenames[0].rsplit("/", 1)[0] + "/tf_record.*.txt" + dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False) + dataset = dataset.flat_map(readers.TFRecordDataset) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + with self.test_session() as sess: + actual, expected = [], [] + for f in range(self._shard_index, self._num_files, self._num_shards): + for r in range(self._num_records): + actual.append(sess.run(next_element)) + expected.append(self._record(r, f)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self.assertAllEqual(expected, actual) + + def testComplexPipeline(self): + # Setup a complex input pipeline. + batch_size = 2 + num_epochs = 5 + dataset = dataset_ops.Dataset.from_tensor_slices( + self._createTFRecordFiles()) + dataset = dataset.shuffle(buffer_size=self._num_files) + dataset = dataset.flat_map(readers.TFRecordDataset) + dataset = dataset.prefetch(buffer_size=batch_size) + dataset = dataset.shuffle(2 * self._num_files * self._num_records) + dataset = dataset.repeat(num_epochs) + dataset = dataset.apply(batching.map_and_batch( + lambda x: x, batch_size=batch_size)) + dataset = dataset.prefetch(buffer_size=None) + + # Auto shard. + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + # Verify output. + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + with self.test_session() as sess: + actual = [] + num_iterations = (self._num_files * self._num_records * num_epochs) // ( + self._num_shards * batch_size) + for _ in range(num_iterations): + actual.extend(sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + expected = [] + for f in range(0, self._num_files, self._num_shards): + for r in range(self._num_records): + expected.append(self._record(r, f)) + expected *= num_epochs + + self.assertAllEqual(sorted(expected), sorted(actual)) + + def testZip(self): + dataset1 = readers.TFRecordDataset(self._createTFRecordFiles()) + dataset2 = readers.TextLineDataset(self._createTextFiles()) + dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f)) + self._verifySimpleShardingOutput(dataset, record_fn) + + def testConcat(self): + dataset1 = readers.TFRecordDataset(self._createTFRecordFiles()) + dataset2 = readers.TextLineDataset(self._createTextFiles()) + dataset = dataset1.concatenate(dataset2) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + with self.test_session() as sess: + for f in range(self._shard_index, self._num_files, self._num_shards): + for r in range(self._num_records): + self.assertAllEqual(self._record(r, f), sess.run(next_element)) + for f in range(self._shard_index, self._num_files, self._num_shards): + for r in range(self._num_records): + self.assertAllEqual(self._text_line(r, f), sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testTextLineReader(self): + dataset = readers.TextLineDataset(self._createTextFiles()) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._text_line) + + def testTextLineReaderWithFlatMap(self): + dataset = dataset_ops.Dataset.from_tensor_slices(self._createTextFiles()) + dataset = dataset.flat_map(readers.TextLineDataset) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._text_line) + + def testFixedLengthReader(self): + dataset = readers.FixedLengthRecordDataset( + self._createFixedLengthRecordFiles(), self._record_bytes) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._fixed_length_record) + + def testFixedLengthReaderWithFlatMap(self): + dataset = dataset_ops.Dataset.from_tensor_slices( + self._createFixedLengthRecordFiles()) + dataset = dataset.flat_map( + lambda f: readers.FixedLengthRecordDataset(f, self._record_bytes)) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._fixed_length_record) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py new file mode 100644 index 0000000000000000000000000000000000000000..75ecd90dcffa7a786b78238ef453c4c8e4346afa --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -0,0 +1,148 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Keras Sequential and Functional models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import keras as keras_lib +from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.framework import test_util +from tensorflow.python.keras import testing_utils +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import rmsprop + +_RANDOM_SEED = 1337 +_TRAIN_SIZE = 200 +_INPUT_SIZE = (10,) +_NUM_CLASS = 2 + + +def simple_sequential_model(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE)) + model.add(keras.layers.Dropout(0.1)) + model.add(keras.layers.Dense(_NUM_CLASS, activation='softmax')) + return model + + +def simple_functional_model(): + a = keras.layers.Input(shape=_INPUT_SIZE) + b = keras.layers.Dense(16, activation='relu')(a) + b = keras.layers.Dropout(0.1)(b) + b = keras.layers.Dense(_NUM_CLASS, activation='softmax')(b) + model = keras.models.Model(inputs=[a], outputs=[b]) + return model + + +def get_ds_train_input_fn(): + np.random.seed(_RANDOM_SEED) + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=_INPUT_SIZE, + num_classes=_NUM_CLASS) + y_train = keras.utils.to_categorical(y_train) + + dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + dataset = dataset.batch(32) + return dataset + + +def get_ds_test_input_fn(): + np.random.seed(_RANDOM_SEED) + _, (x_test, y_test) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=_INPUT_SIZE, + num_classes=_NUM_CLASS) + y_test = keras.utils.to_categorical(y_test) + + dataset = dataset_ops.Dataset.from_tensor_slices((x_test, y_test)) + dataset = dataset.batch(32) + return dataset + + +class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): + + def setUp(self): + self._base_dir = os.path.join(self.get_temp_dir(), + 'keras_mirrored_strategy_test') + gfile.MakeDirs(self._base_dir) + self._config = run_config_lib.RunConfig( + tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir) + + def tearDown(self): + writer_cache.FileWriterCache.clear() + if os.path.isdir(self._base_dir): + gfile.DeleteRecursively(self._base_dir) + + def test_train_functional_with_distribution_strategy(self): + dist = mirrored_strategy.MirroredStrategy( + devices=['/device:GPU:0', '/device:GPU:1']) + keras_model = simple_functional_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=dist) + with self.test_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, config=config) + before_eval_results = est_keras.evaluate( + input_fn=get_ds_test_input_fn, steps=1) + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn, + steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + + def test_train_sequential_with_distribution_strategy(self): + dist = mirrored_strategy.MirroredStrategy( + devices=['/device:GPU:0', '/device:GPU:1']) + keras_model = simple_sequential_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=dist) + with self.test_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, config=config) + before_eval_results = est_keras.evaluate( + input_fn=get_ds_test_input_fn, steps=1) + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn, + steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6c6bf143098c1bba64d47efce1bfface7682683d --- /dev/null +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -0,0 +1,438 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for V1 metrics.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics +from tensorflow.python.ops import variables + + +def _labeled_dataset_fn(): + # First four batches of x: labels, predictions -> (labels == predictions) + # 0: 0, 0 -> True; 1: 1, 1 -> True; 2: 2, 2 -> True; 3: 3, 0 -> False + # 4: 4, 1 -> False; 5: 0, 2 -> False; 6: 1, 0 -> False; 7: 2, 1 -> False + # 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False + # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True + return dataset_ops.Dataset.range(1000).map( + lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4) + + +def _boolean_dataset_fn(): + # First four batches of labels, predictions: {TP, FP, TN, FN} + # with a threshold of 0.5: + # T, T -> TP; F, T -> FP; T, F -> FN + # F, F -> TN; T, T -> TP; F, T -> FP + # T, F -> FN; F, F -> TN; T, T -> TP + # F, T -> FP; T, F -> FN; F, F -> TN + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [True, False, True, False], + "predictions": [True, True, False, False]}).repeat().batch(3) + + +def _threshold_dataset_fn(): + # First four batches of labels, predictions: {TP, FP, TN, FN} + # with a threshold of 0.5: + # True, 1.0 -> TP; False, .75 -> FP; True, .25 -> FN + # False, 0.0 -> TN; True, 1.0 -> TP; False, .75 -> FP + # True, .25 -> FN; False, 0.0 -> TN; True, 1.0 -> TP + # False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [True, False, True, False], + "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3) + + +def _regression_dataset_fn(): + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [1., .5, 1., 0.], + "predictions": [1., .75, .25, 0.]}).repeat() + + +def all_combinations(): + return combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus], + mode=["graph"]) + + +# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k, +# metrics.precision_at_k +class MetricsV1Test(test.TestCase, parameterized.TestCase): + + def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn): + with ops.Graph().as_default(), distribution.scope(): + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + value, update = distribution.call_for_each_tower( + metric_fn, iterator.get_next()) + update = distribution.group(update) + self.evaluate(variables.local_variables_initializer()) + # TODO(josh11b): Once we switch to using a global batch size for input, + # replace "distribution.num_towers" with "1". + batches_per_update = distribution.num_towers + + # Update variables using the first `num_towers` batches. + self.evaluate(update) + self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value), + 0.001, msg="After first update") + + # Update variables using the second `num_towers` batches. + self.evaluate(update) + self.assertAllClose(expected_fn(2 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After second update") + + if batches_per_update == 1: # Consume 4 input batches + self.evaluate(update) + self.assertAllClose(expected_fn(3 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After third update") + self.evaluate(update) + self.assertAllClose(expected_fn(4 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After fourth update") + + @combinations.generate(all_combinations()) + def testMean(self, distribution): + def _dataset_fn(): + return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4) + + def _expected_fn(num_batches): + # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc. + return num_batches * 2 - 0.5 + + self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn) + + @combinations.generate(all_combinations()) + def testAccuracy(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.accuracy(labels, predictions) + + def _expected_fn(num_batches): + return [3./4, 3./8, 3./12, 4./16][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanPerClassAccuracy(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_per_class_accuracy( + labels, predictions, num_classes=5) + + def _expected_fn(num_batches): + mean = lambda x: sum(x) / len(x) + return [mean([1., 1., 1., 0., 0.]), + mean([0.5, 0.5, 0.5, 0., 0.]), + mean([1./3, 1./3, 0.5, 0., 0.]), + mean([0.5, 1./3, 1./3, 0., 0.])][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanIOU(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_iou( + labels, predictions, num_classes=5) + + def _expected_fn(num_batches): + mean = lambda x: sum(x) / len(x) + return [mean([1./2, 1./1, 1./1, 0.]), # no class 4 in first batch + mean([1./4, 1./4, 1./3, 0., 0.]), + mean([1./6, 1./6, 1./5, 0., 0.]), + mean([2./8, 1./7, 1./7, 0., 0.])][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanTensor(self, distribution): + def _dataset_fn(): + dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float) + # Want to produce a fixed, known shape, so drop remainder when batching. + dataset = dataset.apply(batching.batch_and_drop_remainder(4)) + return dataset + + def _expected_fn(num_batches): + # Mean(0, 4, ..., 4 * num_batches - 4) == 2 * num_batches - 2 + # Mean(1, 5, ..., 4 * num_batches - 3) == 2 * num_batches - 1 + # Mean(2, 6, ..., 4 * num_batches - 2) == 2 * num_batches + # Mean(3, 7, ..., 4 * num_batches - 1) == 2 * num_batches + 1 + first = 2. * num_batches - 2. + return [first, first + 1., first + 2., first + 3.] + + self._test_metric( + distribution, _dataset_fn, metrics.mean_tensor, _expected_fn) + + @combinations.generate(all_combinations()) + def testAUCROC(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.auc(labels, predictions, num_thresholds=8, curve="ROC", + summation_method="careful_interpolation") + + def _expected_fn(num_batches): + return [0.5, 7./9, 0.8, 0.75][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testAUCPR(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.auc(labels, predictions, num_thresholds=8, curve="PR", + summation_method="careful_interpolation") + + def _expected_fn(num_batches): + return [0.797267, 0.851238, 0.865411, 0.797267][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalseNegatives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_negatives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 1., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalseNegativesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_negatives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [1.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTrueNegatives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_negatives(labels, predictions) + + def _expected_fn(num_batches): + return [0., 1., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTrueNegativesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_negatives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[0.], [1.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalsePositives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_positives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 2., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalsePositivesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_positives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [2.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTruePositives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_positives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 2., 3., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTruePositivesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_positives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [2.], [3.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testPrecision(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.precision(labels, predictions) + + def _expected_fn(num_batches): + return [0.5, 0.5, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testPrecisionAtThreshold(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.precision_at_thresholds(labels, predictions, [0.5]) + + def _expected_fn(num_batches): + return [[0.5], [0.5], [0.6], [0.5]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRecall(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.recall(labels, predictions) + + def _expected_fn(num_batches): + return [0.5, 2./3, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRecallAtThreshold(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.recall_at_thresholds(labels, predictions, [0.5]) + + def _expected_fn(num_batches): + return [[0.5], [2./3], [0.6], [0.5]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanSquaredError(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_squared_error(labels, predictions) + + def _expected_fn(num_batches): + return [0., 1./32, 0.208333, 0.15625][num_batches - 1] + + self._test_metric( + distribution, _regression_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRootMeanSquaredError(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.root_mean_squared_error(labels, predictions) + + def _expected_fn(num_batches): + return [0., 0.176777, 0.456435, 0.395285][num_batches - 1] + + self._test_metric( + distribution, _regression_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testSensitivityAtSpecificity(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.sensitivity_at_specificity(labels, predictions, 0.8) + + def _expected_fn(num_batches): + return [0.5, 2./3, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testSpecificityAtSensitivity(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.specificity_at_sensitivity(labels, predictions, 0.95) + + def _expected_fn(num_batches): + return [0., 1./3, 0.5, 0.5][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5c056a7c73def2f1fb4bbe0df4d3f82fdabda3df --- /dev/null +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -0,0 +1,362 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for running legacy optimizer code with DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example +from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example +from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.ops.losses import losses_impl + + +class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True]), + combinations.combine(is_tpu=[False])) + combinations.combine( + distribution=[combinations.tpu_strategy], + optimizer_fn=[ + combinations.adam_optimizer_v1_fn, + # TODO(isaprykin): Make Adam v2 work with while_loops + # and TPUs. + ], + mode=["graph"], + use_callable_loss=[False], + is_tpu=[True])) + def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss, + is_tpu): + with distribution.scope(): + model_fn, dataset_fn, layer = minimize_loss_example( + optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) + + # TODO(isaprykin): Eliminate `is_tpu`. Probably add a + # `DistributionStrategy.create_monitor` so that each DistributionStrategy + # could influence its training loop. That method would return an instance + # of Monitor. TPUMonitor would execute tpu.initialize_system() and + # tpu.shutdown_system(). + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + + def run_step(): + return distribution.group( + distribution.call_for_each_tower( + model_fn, iterator.get_next(), run_concurrently=layer.built)) + + if not context.executing_eagerly(): + with self.test_session() as sess: + if is_tpu: + sess.run(tpu.initialize_system()) + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + weights, biases = [], [] + for _ in range(10): + run_step() + + weights.append(self.evaluate(distribution.fetch(layer.kernel))) + biases.append(self.evaluate(distribution.fetch(layer.bias))) + + if is_tpu: + with self.test_session() as sess: + sess.run(tpu.shutdown_system()) + + error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) + is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) + self.assertTrue(is_not_increasing) + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers() + + combinations.distributions_and_v2_optimizers(), + combinations.combine(mode=["graph", "eager"], is_tpu=[False])) + + combinations.combine( + distribution=[combinations.tpu_strategy], + optimizer_fn=[ + combinations.adam_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn, + ], + mode=["graph"], + is_tpu=[True])) + + def testOptimizerInsideModelFn(self, distribution, optimizer_fn, is_tpu): + created_variables = [] + trainable_variables = [] + + def appending_creator(next_creator, *args, **kwargs): + v = next_creator(*args, **kwargs) + created_variables.append(v.name) + if "trainable" in kwargs and kwargs["trainable"]: + trainable_variables.append(v.name) + return v + + # Creator scope needs to be set before it's used inside + # `distribution.scope`. + with variable_scope.variable_creator_scope( + appending_creator), distribution.scope(): + model_fn, dataset_fn, layer = minimize_loss_example( + optimizer_fn, + use_bias=True, + use_callable_loss=True, + create_optimizer_inside_model_fn=True) + + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + + def run_step(): + return distribution.group( + distribution.call_for_each_tower( + model_fn, iterator.get_next(), run_concurrently=layer.built)) + + if not context.executing_eagerly(): + with self.test_session() as sess: + if is_tpu: + sess.run(tpu.initialize_system()) + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + run_step() + + if is_tpu: + with self.test_session() as sess: + sess.run(tpu.shutdown_system()) + + def get_expected_variables(optimizer_fn, num_parameter_devices): + variables_map = { + "GradientDescent": ["dense/kernel", "dense/bias"], + "Adam": [ + "dense/kernel", "dense/bias", "beta1_power", "beta2_power", + "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam", + "dense/bias/Adam_1" + ] + } + variables = variables_map[optimizer_fn().get_name()] + variables.extend([ + v + "/replica_{}".format(replica) + for v in variables + for replica in range(1, num_parameter_devices) + ]) + return set([v + ":0" for v in variables]) + + self.assertEqual( + get_expected_variables(optimizer_fn, + len(distribution.parameter_devices)), + set(created_variables)) + + @combinations.generate( + combinations.times( + combinations.combine(momentum=[0.8, 0.9, 0.99], renorm=[False, True]), + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine( + mode=["graph", "eager"], + is_tpu=[False], + # TODO(isaprykin): Allow False here. Currently subsequent + # towers will re-execute UPDATE_OPS of previous towers. + update_ops_in_cross_tower_mode=[True])) + + combinations.combine( + distribution=[combinations.tpu_strategy_single_iteration], + optimizer_fn=[ + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn + ], + mode=["graph"], + is_tpu=[True], + update_ops_in_cross_tower_mode=[False]))) + def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum, + renorm, is_tpu, + update_ops_in_cross_tower_mode): + """Verifies that moving mean updates are reduced across towers.""" + with distribution.scope(): + num_towers = len(distribution.worker_devices) + model_fn, dataset_fn, batchnorm = batchnorm_example( + optimizer_fn, + batch_per_epoch=num_towers, + momentum=momentum, + renorm=renorm, + update_ops_in_tower_mode=not update_ops_in_cross_tower_mode) + + # Make sure prefetching is disabled since that makes the + # specific input on each device to be non deterministic, and + # this test relies on specific input being on each device. + if isinstance(distribution, mirrored_strategy.MirroredStrategy): + self.assertFalse(distribution._prefetch_on_device) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + + def run_step(): + fetches = distribution.unwrap( + distribution.call_for_each_tower( + model_fn, iterator.get_next(), + run_concurrently=batchnorm.built)) + if update_ops_in_cross_tower_mode: + fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) + return control_flow_ops.group(fetches) + + if not context.executing_eagerly(): + with self.test_session() as sess: + if is_tpu: + sess.run(tpu.initialize_system()) + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + expected_moving_means = [0.] * 8 + + def averaged_batch_mean(i): + # Each batch has shape [16, 8] where the ith element in jth list is + # (8 * j + i + tower_id * 100). So the batch mean in each tower is + # (60 + i + tower_id * 100). So here comes its batch mean over all + # towers: + return 60. + i + (num_towers - 1.) / 2. * 100. + + for _ in range(10): + run_step() + moving_means = self.evaluate(distribution.fetch(batchnorm.moving_mean)) + + # We make sure that the moving_mean is updated as if the sample mean is + # calculated over all towers. + for i, expected_moving_mean in enumerate(expected_moving_means): + expected_moving_means[i] -= (( + expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum)) + self.assertNear(expected_moving_means[i], moving_means[i], 0.0001) + + if is_tpu: + with self.test_session() as sess: + sess.run(tpu.shutdown_system()) + + @combinations.generate( + combinations.times( + combinations.combine( + optimizer_fn=[ + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn + ], + loss_reduction=[ + losses_impl.Reduction.SUM, losses_impl.Reduction.MEAN, + losses_impl.Reduction.SUM_OVER_BATCH_SIZE, + losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS + ]), + combinations.times( + combinations.combine( + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus + ], + is_tpu=[False]), + combinations.combine( + mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True])) + + combinations.combine( + distribution=[combinations.tpu_strategy_single_iteration], + is_tpu=[True], + mode=["graph"], + use_callable_loss=[True, False]))) + def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction, + use_callable_loss, is_tpu): + with distribution.scope(): + all_vars = [] + + def model_fn(x, y): + + def loss_fn(): + # Use fixed initialization to make the steps deterministic. + w = variable_scope.get_variable("w", initializer=[[2.]]) + all_vars.append(w) + predict = math_ops.matmul(x, w) + return losses_impl.mean_squared_error( + y, predict, reduction=loss_reduction) + + optimizer = optimizer_fn() # GradientDescent with 0.2 learning rate + + if use_callable_loss: + return optimizer.minimize(loss_fn) + else: + return optimizer.minimize(loss_fn()) + + def dataset_fn(): + features = dataset_ops.Dataset.from_tensors([[2.], [7.]]) + labels = dataset_ops.Dataset.from_tensors([[6.], [21.]]) + return dataset_ops.Dataset.zip((features, labels)).repeat() + + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + + def run_step(): + return distribution.group( + distribution.call_for_each_tower( + model_fn, *iterator.get_next(), run_concurrently=False)) + + if not context.executing_eagerly(): + with self.test_session() as sess: + if is_tpu: + sess.run(tpu.initialize_system()) + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + run_step() + + v = all_vars[0] + self.assertTrue(all([v is vi for vi in all_vars[1:]])) + weight = numpy.squeeze(self.evaluate(distribution.fetch(v))) + # Our model is: + # predict = x * w + # loss = (predict - y)^2 + # dloss/dpredict = 2*(predict - y) + # dloss/dw = 2 * x^T @ (predict - y) + # For our batch size of 2, assuming sum loss reduction: + # x = [2, 7] + # y = [6, 21] + # w_initial = 2 + # predict = [4, 14] + # predict - y = [-2, -7] + # dloss/dw = 2 <[2, 7], [-2, -7]> = - 2(4 + 49) = -106 + # So unreplicated the update to w with lr=0.2 is -0.2 * -106 = 21.2 + # with sum loss reduction, or 10.6 with mean. + if loss_reduction == losses_impl.Reduction.SUM: + # Note that the "distribution.num_towers" factor will go away once + # we split the input across towers, instead of pulling a complete + # batch of input per tower. + self.assertNear(weight, 2 + 21.2 * distribution.num_towers, 0.0001) + else: + # One of the mean loss reductions. + self.assertNear(weight, 2 + 10.6, 0.0001) + + if is_tpu: + with self.test_session() as sess: + sess.run(tpu.shutdown_system()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..900aa10e93e8881aa236bac8a2873d5c5531c6f6 --- /dev/null +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -0,0 +1,535 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class MirroredStrategy implementing DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import threading +import six + +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import shared_variable_creator +from tensorflow.contrib.distribute.python import values +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.eager import context +from tensorflow.python.eager import tape +from tensorflow.python.framework import device as tf_device +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import coordinator +from tensorflow.python.training import device_util +from tensorflow.python.training import distribute as distribute_lib + + +# TODO(josh11b): Replace asserts in this file with if ...: raise ... + + +@contextlib.contextmanager +def _enter_graph(g): + if context.executing_eagerly(): + with g.as_default(), context.eager_mode(): + yield + else: + with g.as_default(): + yield + + +def _cpu_device(device): + cpu_device = tf_device.DeviceSpec.from_string(device) + cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0)) + return cpu_device.to_string() + + +class _RequestedStop(Exception): + pass + + +class MirroredStrategy(distribute_lib.DistributionStrategy): + """Mirrors vars to distribute across multiple devices on a single machine. + + This strategy uses one tower per device and sync replication. + """ + + def __init__(self, + devices=None, + num_gpus=None, + cross_tower_ops=None, + prefetch_on_device=None): + super(MirroredStrategy, self).__init__() + # Convert `num_gpus` into `devices`, shouldn't specify both. + if devices is None: + if num_gpus is None: + num_gpus = context.num_gpus() + devices = ["/device:GPU:%d" % d for d in range(num_gpus)] + elif num_gpus is not None: + raise ValueError("Must only specify one of `devices` and `num_gpus`.") + + assert devices, "Must specify at least one device." + assert len(set(devices)) == len(devices), ( + "No duplicates allowed in `devices` argument.") + # TODO(josh11b): Require at least 2 devices? + self._devices = [device_util.resolve(d) for d in devices] + self._canonical_device_set = set(self._devices) + self._device_index = values.PerDevice( + dict((d, i) for i, d in enumerate(devices))) + self._cross_tower_ops = cross_tower_ops + self._prefetch_on_device = prefetch_on_device + # TODO(yuefengz): consider setting the default device. + + def _create_variable(self, next_creator, *args, **kwargs): + """Create a mirrored variable. See `DistributionStrategy.scope`.""" + # Figure out what collections this variable should be added to. + # We'll add the MirroredVariable to those collections instead. + collections = kwargs.pop("collections", None) + if collections is None: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + kwargs["collections"] = [] + + colocate_with = kwargs.pop("colocate_with", None) + devices = self._get_devices_from(colocate_with) + + tower_local = kwargs.pop("tower_local_reduce_method", None) + if tower_local is not None: + kwargs["trainable"] = False + + # TODO(josh11b,apassos): It would be better if variable initialization + # was never recorded on the tape instead of having to do this manually + # here. + with tape.stop_recording(): + index = {} + for i, d in enumerate(devices): + with ops.device(d): + if i > 0: + # Give replicas meaningful distinct names: + var0name = index[devices[0]].name.split(":")[0] + # We append a / to variable names created on towers with id > 0 to + # ensure that we ignore the name scope and instead use the given + # name as the absolute name of the variable. + kwargs["name"] = "%s/replica_%d/" % (var0name, i) + # Initialize replicas with the same value: + if context.executing_eagerly(): + kwargs["initial_value"] = array_ops.identity( + index[devices[0]].value()) + else: + def initial_value_fn(device=d): + with ops.device(device): + return array_ops.identity(index[devices[0]].initial_value) + kwargs["initial_value"] = initial_value_fn + with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): + v = next_creator(*args, **kwargs) + assert not isinstance(v, values.DistributedVariable) + index[d] = v + + if tower_local is None: + result = values.MirroredVariable(index, index[devices[0]]) + else: + result = values.TowerLocalVariable( + index, index[devices[0]], tower_local) + + if not context.executing_eagerly(): + g = ops.get_default_graph() + # If "trainable" is True, next_creator() will add the member variables + # to the TRAINABLE_VARIABLES collection, so we manually remove + # them and replace with the MirroredVariable. We can't set + # "trainable" to False for next_creator() since that causes functions + # like implicit_gradients to skip those variables. + if kwargs.get("trainable", True): + collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) + l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) + for v in index.values(): + l.remove(v) + g.add_to_collections(collections, result) + return result + + def distribute_dataset(self, dataset_fn): + return values.PerDeviceDataset( + self._call_dataset_fn(dataset_fn), self._devices, + self._prefetch_on_device) + + def _broadcast(self, tensor, destinations): + # TODO(josh11b): In eager mode, use one thread per device, or async mode. + return self._get_cross_tower_ops().broadcast(tensor, destinations or + self._devices) + + def _call_for_each_tower(self, fn, *args, **kwargs): + """Run `fn` in separate threads, once per tower/worker device. + + Args: + fn: function to run (will be run once per device, each in its own thread). + *args: positional arguments for `fn` + **kwargs: keyword arguments for `fn`. + `"run_concurrently"`: Boolean indicating whether executions of `fn` + can be run concurrently (under eager execution only), defaults to + `True`. + + Returns: + Merged return value of `fn` across all towers. + + Raises: + RuntimeError: If fn() calls get_tower_context().merge_call() a different + number of times for when called for different devices. + """ + run_concurrently = kwargs.pop("run_concurrently", True) + if not context.executing_eagerly(): + # Lots of TF library code isn't thread-safe in graph mode, and + # there is little to be gained by turning on multithreading when + # constructing a graph. + run_concurrently = False + # Needed for per-thread device, etc. contexts in graph mode. + ops.get_default_graph().switch_to_thread_local() + elif run_concurrently is None: + run_concurrently = True + + coord = coordinator.Coordinator( + clean_stop_exception_types=(_RequestedStop,)) + + shared_variable_store = {} + + # TODO(isaprykin): Create these threads once instead of during every run() + # call. + threads = [] + for index, d in enumerate(self._devices): + variable_creator_fn = shared_variable_creator.make_fn( + shared_variable_store, index) + t = MirroredStrategy._MirroredTowerThread( + self, coord, d, variable_creator_fn, fn, + *values.select_device(d, args), **values.select_device(d, kwargs)) + threads.append(t) + + for t in threads: + t.start() + + # When `fn` starts `should_run` event is set on _MirroredTowerThread + # (`MTT`) threads. The execution waits until + # `MTT.has_paused` is set, which indicates that either `fn` is + # complete or a `get_tower_context().merge_call()` is called. If `fn` is + # complete, then `MTT.done` is set to True. Otherwise, arguments + # of `get_tower_context().merge_call` from all paused threads are grouped + # and the `merge_fn` is performed. Results of the + # `get_tower_context().merge_call` are then set to `MTT.merge_result`. + # Each such `get_tower_context().merge_call` call returns the + # `MTT.merge_result` for that thread when `MTT.should_run` event + # is reset again. Execution of `fn` resumes. + + try: + with coord.stop_on_exception(): + all_done = False + while not all_done and not coord.should_stop(): + done = [] + if run_concurrently: + for t in threads: + t.should_run.set() + for t in threads: + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + else: + for t in threads: + t.should_run.set() + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + if coord.should_stop(): + return None + all_done = all(done) + if not all_done: + if any(done): + raise RuntimeError("Some towers made a different number of " + "tower_context().merge_call() calls.") + # get_tower_context().merge_call() case + merge_args = values.regroup( + {t.device: t.merge_args for t in threads}) + merge_kwargs = values.regroup( + {t.device: t.merge_kwargs for t in threads}) + # We capture the name_scope of the MTT when we call merge_fn + # to ensure that if we have opened a name scope in the MTT, + # it will be respected when executing the merge function. We only + # capture the name_scope from the first MTT and assume it is + # the same for all other MTTs. + mtt_captured_name_scope = threads[0].captured_name_scope + with ops.name_scope(mtt_captured_name_scope): + merge_result = threads[0].merge_fn( + self, *merge_args, **merge_kwargs) + for t in threads: + t.merge_result = values.select_device(t.device, merge_result) + finally: + for t in threads: + t.should_run.set() + coord.join(threads) + + return values.regroup({t.device: t.main_result for t in threads}) + + def map(self, map_over, fn, *args, **kwargs): + # TODO(josh11b): In eager mode, use one thread per device. + index = {} + i = 0 + for m in map_over: + d = self._devices[i % len(self._devices)] + with ops.device(d): + l = index.get(d, []) + l.append(fn(m, + *values.select_device_mirrored(d, args), + **values.select_device_mirrored(d, kwargs))) + index[d] = l + # TODO(josh11b): Need a values.regroup equivalent that handles MapOutput + # in addition to PerDevice data. + return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()}) + + def configure(self, session_config=None): + if self._cross_tower_ops is None: + self._cross_tower_ops = cross_tower_ops_lib.choose_the_best( + self._devices, session_config=session_config) + + def _get_cross_tower_ops(self): + if self._cross_tower_ops is None: + self._cross_tower_ops = ( + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()) + return self._cross_tower_ops + + def _reduce(self, method_string, value, destinations): + if len(self._devices) == 1 and not isinstance(value, values.PerDevice): + value = values.PerDevice({self._devices[0]: value}) + assert isinstance(value, values.PerDevice) + + return self._get_cross_tower_ops().reduce( + method_string, value, destinations=destinations) + + def _batch_reduce(self, method_string, value_destination_pairs): + return self._get_cross_tower_ops().batch_reduce(method_string, + value_destination_pairs) + + def _update(self, var, fn, *args, **kwargs): + # TODO(josh11b): Also support TowerLocalVariables here? If so, args and + # kwargs don't need to be mirrored. + assert isinstance(var, values.MirroredVariable) + # TODO(josh11b): In eager mode, use one thread per device. + updates = {} + for d, v in var._index.items(): # pylint: disable=protected-access + name = "update_%d" % self._device_index.get(d) + with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): + updates[d] = fn(v, + *values.select_device_mirrored(d, args), + **values.select_device_mirrored(d, kwargs)) + return values.regroup(updates, values.Mirrored) + + def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + assert isinstance(colocate_with, list) + # TODO(josh11b): In eager mode, use one thread per device. + updates = {} + for d in colocate_with: + name = "update_%d" % self._device_index.get(d) + with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): + updates[d] = fn(*values.select_device_mirrored(d, args), + **values.select_device_mirrored(d, kwargs)) + return values.regroup(updates, values.Mirrored) + + def read_var(self, tower_local_var): + """Read the aggregate value of a tower-local variable.""" + if isinstance(tower_local_var, values.TowerLocalVariable): + return math_ops.add_n(self.unwrap(tower_local_var)) + assert isinstance(tower_local_var, values.Mirrored) + return array_ops.identity(tower_local_var.get()) + + def _fetch(self, val, destination, fn): + """Return a copy of `val` or `fn(val)` on `destination`.""" + if isinstance(val, values.TowerLocalVariable): + val = self.reduce(val.reduce_method, val, destinations=destination) + with ops.device(destination): + return fn(self.unwrap(val)[0]) + + assert isinstance(val, values.Mirrored), ( + "val = %s (type %s)" % (val, val.__class__.__name__)) + if val.on_device(destination): + with ops.device(destination): + # Use an identity here to make sure we are returning a tensor + # instead of e.g. a variable object. + return array_ops.identity(fn(val.get(destination))) + device = None + for d in self._devices: + if val.on_device(d): + device = d + break + assert device is not None, ( + "Could not find destination %s in list of devices %s." % + (destination, val.devices)) + with ops.device(device): + v = fn(val.get(device)) + with ops.device(destination): + return array_ops.identity(v) + + def _unwrap(self, val): + if isinstance(val, values.DistributedValues): + # Return in a deterministic order. + if set(val.devices) == self._canonical_device_set: + return [val.get(device=d) for d in self._devices] + return [val.get(device=d) for d in sorted(val.devices)] + return [val] + + @property + def is_single_tower(self): + return len(self._devices) == 1 + + @property + def num_towers(self): + return len(self._devices) + + def _worker_device_index(self): + return self._device_index + + @property + def worker_devices(self): + # Make a copy to prevent users from accidentally mutating our copy. + return list(self._devices) + + @property + def parameter_devices(self): + return list(self._devices) + + def non_slot_devices(self, var_list): + del var_list + return list(self._devices) + + def _get_devices_from(self, colocate_with=None): + if colocate_with is None: + return self._devices + elif isinstance(colocate_with, values.DistributedValues): + # pylint: disable=protected-access + return list(colocate_with._index.keys()) + elif isinstance(colocate_with, six.string_types): + return [device_util.resolve(colocate_with)] + elif isinstance(colocate_with, list): + return [device_util.resolve(d) for d in colocate_with] + else: + return colocate_with + + class _MirroredTowerThread(threading.Thread): + """A thread that runs() a function on a device.""" + + def __init__(self, dist, coord, device, variable_creator_fn, fn, *args, + **kwargs): + super(MirroredStrategy._MirroredTowerThread, self).__init__() # pylint: disable=protected-access + self.coord = coord + self.distribution = dist + self.device = device + self.tower_id = dist.worker_devices.index(device) + self.variable_creator_fn = variable_creator_fn + # State needed to run and return the results of `fn`. + self.main_fn = fn + self.main_args = args + self.main_kwargs = kwargs + self.main_result = None + self.done = False + # State needed to run the next merge_call() (if any) requested via + # TowerContext. + self.merge_fn = None + self.merge_args = None + self.merge_kwargs = None + self.merge_result = None + self.captured_name_scope = None + # We use a thread.Event for the main thread to signal when this + # thread should start running (`should_run`), and another for + # this thread to transfer control back to the main thread + # (`has_paused`, either when it gets to a + # `get_tower_context().merge_call` or when `fn` returns). In + # either case the event starts cleared, is signaled by calling + # set(). The receiving thread waits for the signal by calling + # wait() and then immediately clearing the event using clear(). + self.should_run = threading.Event() + self.has_paused = threading.Event() + # These fields have to do with inheriting various contexts from the + # parent thread: + # pylint: disable=protected-access + self.context_mode = context.context()._eager_context.mode + if not context.context()._context_handle: + context.context()._initialize_handle_and_devices() + self.context_device_policy = ( + pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy( + context.context()._context_handle)) + self.graph = ops.get_default_graph() + self._variable_creator_stack = self.graph._variable_creator_stack[:] + self._captured_var_scope = variable_scope.get_variable_scope() + # Adding a "/" at end lets us re-enter this scope later. + self._name_scope = self.graph.get_name_scope() + if self._name_scope: + self._name_scope += "/" + if self.tower_id > 0: + if not self._name_scope: + self._name_scope = "" + self._name_scope += "tower_%d/" % self.tower_id + + def run(self): + # pylint: disable=protected-access + self.graph._variable_creator_stack = self._variable_creator_stack + self.should_run.wait() + self.should_run.clear() + try: + if self.coord.should_stop(): + return + with self.coord.stop_on_exception(), \ + context.context()._mode(self.context_mode), \ + context.context().device_policy(self.context_device_policy), \ + _enter_graph(self.graph), \ + MirroredTowerContext(self.distribution, self.tower_id), \ + ops.device(self.device), \ + ops.name_scope(self._name_scope), \ + variable_scope.variable_scope( + self._captured_var_scope, reuse=self.tower_id > 0), \ + variable_scope.variable_creator_scope(self.variable_creator_fn): + self.main_result = self.main_fn(*self.main_args, **self.main_kwargs) + self.done = True + finally: + self.has_paused.set() + + +class MirroredTowerContext(distribute_lib.TowerContext): + """TowerContext used in MirroredStrategy.call_for_each_tower(). + + Opened in `_MirroredTowerThread`, to allow the user to invoke + `MirroredStrategy`'s specific implementation of `merge_call()`, + which works by delegating the function and its arguments to + the main thread (the one that invoked + `MirroredStrategy.call_for_each_tower()`). + """ + + def _merge_call(self, fn, *args, **kwargs): + """Delegate to the main thread to actually perform merge_call().""" + t = threading.current_thread() # a _MirroredTowerThread + t.merge_fn = fn + t.merge_args = args + t.merge_kwargs = kwargs + t.captured_name_scope = t.graph.get_name_scope() + # Adding a "/" at end lets us re-enter this scope later. + if t.captured_name_scope: + t.captured_name_scope += "/" + t.has_paused.set() + t.should_run.wait() + t.should_run.clear() + if t.coord.should_stop(): + raise _RequestedStop() + return t.merge_result + + @property + def device(self): + distribute_lib.require_tower_context(self) + return self._distribution_strategy.worker_devices[self._tower_id] diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bccd278847e3c87080af3cb15665e7a0d802d8fb --- /dev/null +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -0,0 +1,535 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Multi-GPU tests for MirroredStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.layers import core +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell_impl +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.training import distribute as distribute_lib + +GPU_TEST = "test_gpu" in sys.argv[0] + + +class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + devices = ["/device:CPU:0", "/device:GPU:0"] + if GPU_TEST: + self.assertGreater(context.num_gpus(), 0) + if context.num_gpus() > 1: + devices = ["/device:GPU:0", "/device:GPU:1"] + print(self.id().split(".")[-1], "devices:", ", ".join(devices)) + return mirrored_strategy.MirroredStrategy(devices) + + def testMinimizeLossEager(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_minimize_loss_eager(self._get_distribution_strategy()) + + def testMinimizeLossGraph(self): + soft_placement = not GPU_TEST + print("testMinimizeLossGraph soft_placement:", soft_placement) + self._test_minimize_loss_graph( + self._get_distribution_strategy(), soft_placement=soft_placement) + + def testMapReduce(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_map_reduce(self._get_distribution_strategy()) + + def testDeviceIndex(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_device_index(self._get_distribution_strategy()) + + def testTowerId(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_tower_id(self._get_distribution_strategy()) + + def testNumTowers(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self.assertEqual(2, self._get_distribution_strategy().num_towers) + + @test_util.run_in_graph_and_eager_modes() + def testCallAndMergeExceptions(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_call_and_merge_exceptions(self._get_distribution_strategy()) + + @test_util.run_in_graph_and_eager_modes() + def testRunRegroupError(self): + + def run_fn(device_id): + # Generates a list with different lengths on different devices. + # Will fail in _regroup() (if more than one device). + return list(range(device_id)) + + dist = self._get_distribution_strategy() + with dist.scope(), self.assertRaises(AssertionError): + dist.call_for_each_tower(run_fn, dist.worker_device_index) + + @test_util.run_in_graph_and_eager_modes() + def testReduceToCpu(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + + def run_fn(device_id): + return device_id + + dist = self._get_distribution_strategy() + with dist.scope(): + result = dist.call_for_each_tower(run_fn, dist.worker_device_index) + reduced = dist.reduce("sum", result, destinations="/device:CPU:0") + unwrapped = dist.unwrap(reduced) + self.assertEqual(1, len(unwrapped)) + expected = sum(range(len(dist.worker_devices))) + self.assertEqual(expected, self.evaluate(unwrapped[0])) + + +class MirroredStrategyVariableCreationTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def _skip_eager_if_gpus_less_than(self, num_gpus): + if context.num_gpus() < num_gpus and context.executing_eagerly(): + self.skipTest("Enough GPUs not available for this test in eager mode.") + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSingleVariable(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + # This variable should be created only once across the threads because of + # special variable_creator functions used by `dist.call_for_each_tower`. + v = variable_scope.variable(1.0, name="foo") + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertIsInstance(result, values.MirroredVariable) + self.assertEquals("foo:0", result.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testUnnamedVariable(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + v = variable_scope.variable(1.0) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertIsInstance(result, values.MirroredVariable) + # Default name of "Variable" will be used. + self.assertEquals("Variable:0", result.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testMultipleVariables(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + vs = [] + for i in range(5): + vs.append(variable_scope.variable(1.0, name="foo" + str(i))) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return vs + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + for i, v in enumerate(result): + self.assertIsInstance(v, values.MirroredVariable) + self.assertEquals("foo" + str(i) + ":0", v.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testMultipleVariablesWithSameCanonicalName(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + vs = [] + vs.append(variable_scope.variable(1.0, name="foo/bar")) + vs.append(variable_scope.variable(1.0, name="foo_1/bar")) + vs.append(variable_scope.variable(1.0, name="foo_1/bar_1")) + vs.append(variable_scope.variable(1.0, name="foo/bar_1")) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return vs + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + for v in result: + self.assertIsInstance(v, values.MirroredVariable) + self.assertEquals(4, len(result)) + self.assertEquals("foo/bar:0", result[0].name) + self.assertEquals("foo_1/bar:0", result[1].name) + self.assertEquals("foo_1/bar_1:0", result[2].name) + self.assertEquals("foo/bar_1:0", result[3].name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testVariableWithSameCanonicalNameAcrossThreads(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(device_id): + v = variable_scope.variable(1.0, name="foo_" + str(device_id)) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower( + model_fn, dist.worker_device_index, run_concurrently=False) + self.assertIsInstance(result, values.MirroredVariable) + # The resulting mirrored variable will use the name from the first device. + self.assertEquals("foo_0:0", result.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testWithLayers(self): + self._skip_eager_if_gpus_less_than(1) + def model_fn(features): + with variable_scope.variable_scope("common"): + layer1 = core.Dense(1) + layer1(features) + layer2 = core.Dense(1) + layer2(features) + # This will pause the current thread, and execute the other thread. + distribute_lib.get_tower_context().merge_call(lambda _: _) + layer3 = core.Dense(1) + layer3(features) + return [(layer1.kernel, layer1.bias), + (layer2.kernel, layer2.bias), + (layer3.kernel, layer3.bias)] + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + features = dist.distribute_dataset( + lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10) + ).make_one_shot_iterator().get_next() + + with dist.scope(): + result = dist.call_for_each_tower( + model_fn, features, run_concurrently=False) + suffixes = ["", "_1", "_2"] + for (kernel, bias), suffix in zip(result, suffixes): + self.assertIsInstance(kernel, values.MirroredVariable) + self.assertEquals("common/dense" + suffix + "/kernel:0", kernel.name) + self.assertIsInstance(bias, values.MirroredVariable) + self.assertEquals("common/dense" + suffix + "/bias:0", bias.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testWithGetVariableAndVariableScope(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + v0 = variable_scope.get_variable("var-thread0", [1]) + with variable_scope.variable_scope("common"): + v1 = variable_scope.get_variable("var-thread1", [1]) + # This will pause the current thread, and execute the other thread. + distribute_lib.get_tower_context().merge_call(lambda _: _) + v2 = variable_scope.get_variable("var-thread2", [1]) + + return v0, v1, v2 + + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with variable_scope.variable_scope("main"): + v = variable_scope.get_variable("var-main0", [1]) + self.assertEquals("main/var-main0:0", v.name) + + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertEquals(3, len(result)) + v0, v1, v2 = result + self.assertIsInstance(v0, values.MirroredVariable) + self.assertEquals("main/var-thread0:0", v0.name) + self.assertIsInstance(v1, values.MirroredVariable) + self.assertEquals("main/common/var-thread1:0", v1.name) + self.assertIsInstance(v2, values.MirroredVariable) + self.assertEquals("main/common/var-thread2:0", v2.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testThreeDevices(self): + self._skip_eager_if_gpus_less_than(2) + + def model_fn(): + v = variable_scope.variable(1.0, name="foo") + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertIsInstance(result, values.MirroredVariable) + self.assertEquals("foo:0", result.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testNonMatchingVariableCreation(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(name): + v = variable_scope.variable(1.0, name=name) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + names = values.DistributedValues({ + "/device:CPU:0": "foo", + "/device:GPU:0": "bar" + }) + with self.assertRaises(RuntimeError): + _ = dist.call_for_each_tower(model_fn, names, run_concurrently=False) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testTowerLocalVariable(self): + self._skip_eager_if_gpus_less_than(1) + + all_v_sum = {} + all_v_mean = {} + + def model_fn(device_id): + tower_context = distribute_lib.get_tower_context() + with tower_context.tower_local_var_scope("sum"): + v_sum = variable_scope.variable(1.0) + with tower_context.tower_local_var_scope("mean"): + v_mean = variable_scope.variable(4.0) + self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) + self.assertTrue(isinstance(v_mean, values.TowerLocalVariable)) + updates = [v_sum.assign_add(2.0 + device_id), + v_mean.assign(6.0 * device_id)] + all_v_sum[device_id] = v_sum + all_v_mean[device_id] = v_mean + return updates, v_sum, v_mean + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + # Create "sum" and "mean" versions of TowerLocalVariables. + ret_ops, ret_v_sum, ret_v_mean = dist.call_for_each_tower( + model_fn, dist.worker_device_index, run_concurrently=False) + # Should see the same wrapping instance in all towers. + self.assertIs(all_v_sum[0], ret_v_sum) + self.assertIs(all_v_mean[0], ret_v_mean) + for i in range(1, dist.num_towers): + self.assertIs(all_v_sum[0], all_v_sum[1]) + self.assertIs(all_v_mean[0], all_v_mean[1]) + + # Apply updates + self.evaluate(variables.global_variables_initializer()) + self.evaluate([y for x in ret_ops for y in dist.unwrap(x)]) + expected_sum = 0.0 + expected_mean = 0.0 + for i, d in enumerate(dist.worker_devices): + # Should see different values on different devices. + v_sum_value = self.evaluate(ret_v_sum.get(d).read_value()) + v_mean_value = self.evaluate(ret_v_mean.get(d).read_value()) + expected = i + 3.0 + self.assertEqual(expected, v_sum_value) + expected_sum += expected + expected = i * 6.0 + self.assertEqual(expected, v_mean_value) + expected_mean += expected + expected_mean /= len(dist.worker_devices) + + # Without get(device), should return the value you get by + # applying the reduction across all towers (whether you use + # fetch(), get(), or nothing). + self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum))) + self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean))) + self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get())) + self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get())) + if not context.executing_eagerly(): + self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) + self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) + + # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not + # testing this in eager mode. + + def testNameScope(self): + def model_fn(): + with ops.name_scope("foo"): + a = constant_op.constant(1.0, name="a") + distribute_lib.get_tower_context().merge_call(lambda _: _) + b = constant_op.constant(1.0, name="b") + return a, b + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertEquals(2, len(result)) + for v, name in zip(result, ["a", "b"]): + self.assertIsInstance(v, values.DistributedValues) + v0, v1 = dist.unwrap(v) + self.assertEquals("main/foo/" + name + ":0", v0.name) + self.assertEquals("main/tower_1/foo/" + name + ":0", v1.name) + + def testWithDefaultName(self): + def model_fn(): + with ops.name_scope(None, "foo"): + a = constant_op.constant(1.0, name="a") + distribute_lib.get_tower_context().merge_call(lambda _: _) + b = constant_op.constant(2.0, name="b") + return a, b + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertEquals(2, len(result)) + for v, name in zip(result, ["a", "b"]): + self.assertIsInstance(v, values.DistributedValues) + v0, v1 = dist.unwrap(v) + self.assertEquals("foo/" + name + ":0", v0.name) + self.assertEquals("tower_1/foo/" + name + ":0", v1.name) + + # variable_scope.variable() respects name scopes when creating + # variables. On the other hand variable_scope.get_variable() ignores name + # scopes when creating variables. We test both methods of creating variables + # to make sure that we have the same variable names in both cases. + def testNameScopeWithVariable(self): + def in_cross_tower(_): + c = variable_scope.variable(1.0, name="c") + return c + + def model_fn(): + b = variable_scope.variable(1.0, name="b") + with ops.name_scope("foo"): + c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + return b, c + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + a = variable_scope.variable(1.0, name="a") + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = dist.unwrap(a) + b0, b1 = dist.unwrap(result_b) + c0, c1 = dist.unwrap(result_c) + self.assertEquals("main/a:0", a0.name) + self.assertEquals("main/a/replica_1:0", a1.name) + self.assertEquals("main/b:0", b0.name) + self.assertEquals("main/b/replica_1:0", b1.name) + self.assertEquals("main/foo/c:0", c0.name) + self.assertEquals("main/foo/c/replica_1:0", c1.name) + + def testNameScopeWithGetVariable(self): + def in_cross_tower(_): + c = variable_scope.get_variable("c", [1]) + return c + + def model_fn(): + b = variable_scope.get_variable("b", [1]) + with ops.name_scope("foo"): + c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + return b, c + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + a = variable_scope.get_variable("a", [1]) + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = dist.unwrap(a) + b0, b1 = dist.unwrap(result_b) + c0, c1 = dist.unwrap(result_c) + self.assertEquals("a:0", a0.name) + self.assertEquals("a/replica_1:0", a1.name) + self.assertEquals("b:0", b0.name) + self.assertEquals("b/replica_1:0", b1.name) + self.assertEquals("c:0", c0.name) + self.assertEquals("c/replica_1:0", c1.name) + + def testDynamicRnnVariables(self): + def model_fn(): + inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) + cell_fw = rnn_cell_impl.LSTMCell(300) + cell_bw = rnn_cell_impl.LSTMCell(300) + (outputs, _) = rnn.bidirectional_dynamic_rnn( + cell_fw, + cell_bw, + inputs, + dtype=dtypes.float32) + return outputs + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + # Two variables are created by the RNN layer. + self.assertEquals(2, len(result)) + for v in result: + self.assertIsInstance(v, values.DistributedValues) + _, v1 = dist.unwrap(v) + self.assertStartsWith(v1.name, "tower_1/") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..61cbe6df813bb28bf8baa83d9e28ffafc4f0cbb8 --- /dev/null +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -0,0 +1,89 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class MirroredStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import distribute as distribute_lib + + +class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + return mirrored_strategy.MirroredStrategy(["/device:CPU:0"]) + + def testMinimizeLossEager(self): + self._test_minimize_loss_eager(self._get_distribution_strategy()) + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy()) + + def testMapReduce(self): + self._test_map_reduce(self._get_distribution_strategy()) + + def testDeviceIndex(self): + self._test_device_index(self._get_distribution_strategy()) + + def testTowerId(self): + self._test_tower_id(self._get_distribution_strategy()) + + @test_util.run_in_graph_and_eager_modes() + def testCallAndMergeExceptions(self): + self._test_call_and_merge_exceptions(self._get_distribution_strategy()) + + +class VariableCreatorStackTest(test.TestCase): + + def testCreatorStacksAreThreadLocal(self): + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + + def model_fn(device_id): + assert isinstance(device_id, int) + def thread_creator_fn(next_creator, *args, **kwargs): + return next_creator(*args, **kwargs) + ":thread_" + str(device_id) + + with variable_scope.variable_creator_scope(thread_creator_fn): + # Create a variable in this scope. + v = variable_scope.variable(1.0) + + # This will pause the current thread, and execute the other thread. + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + def main_thread_creator(next_creator, *args, **kwargs): + # We are not using the underlying next_creator for test purposes. + del next_creator, args, kwargs + return "main_thread" + + with context.graph_mode(), \ + dist.scope(), \ + variable_scope.variable_creator_scope(main_thread_creator): + result = dist.call_for_each_tower(model_fn, dist.worker_device_index) + result = dist.unwrap(result) + expected = ["main_thread:thread_0", "main_thread:thread_1"] + self.assertEquals(expected, result) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..7644acedc99361d7287a91832d76bc68cbc6ac0a --- /dev/null +++ b/tensorflow/contrib/distribute/python/monitor.py @@ -0,0 +1,64 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Monitor is responsible for training, checkpointing and recovery.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import context +from tensorflow.python.framework import errors +from tensorflow.python.ops import variables + + +class Monitor(object): + """Executes training steps, recovers and checkpoints. + + Note that this class is particularly preliminary, experimental, and + expected to change. + """ + # TODO(isaprykin): Support step functions that need multiple session calls. + # TODO(isaprykin): Support extra arguments to the step function. + # TODO(isaprykin): Support recovery, checkpointing and summaries. + + def __init__(self, step_callable, session=None): + """Initialize the Monitor with components for executing training steps. + + Args: + step_callable: a training `Step` that's capable of signaling when done. + session: a `Session` instance that's needed for graph mode. + + Raises: + ValueError: if `session` was provided for eager mode or not provided for + graph mode. + """ + if context.executing_eagerly(): + if session is not None: + raise ValueError("Should not provide a `session` in Eager mode.") + self._run_step = step_callable + else: + if session is None: + raise ValueError("Should provide a `session` in Graph mode.") + self._run_step = session.make_callable(step_callable()) + session.run(variables.global_variables_initializer()) + + def run_steps(self, num_steps=None): + step = 0 + while num_steps is None or step < num_steps: + try: + self._run_step() + step += 1 + except errors.OutOfRangeError: + break diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4fdb9bf69b4f6ad76b79fd298f5303f24a1bd455 --- /dev/null +++ b/tensorflow/contrib/distribute/python/monitor_test.py @@ -0,0 +1,85 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class Monitor.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import monitor as monitor_lib +from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example +from tensorflow.python.client import session +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.training import gradient_descent + + +class MonitorTest(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine(mode=combinations.graph_and_eager_modes))) + def testTrainNetwork(self, distribution, optimizer_fn): + with distribution.scope(): + single_loss_step, layer = single_loss_example(optimizer_fn, distribution) + + if context.executing_eagerly(): + monitor = monitor_lib.Monitor(single_loss_step, None) + else: + with self.test_session() as sess: + monitor = monitor_lib.Monitor(single_loss_step, sess) + + monitor.run_steps(1) + + self.assertEqual(1, len(layer.trainable_variables)) + mirrored_weight_variable = layer.trainable_variables[0] + start_error = self.evaluate(distribution.fetch(mirrored_weight_variable)) + start_error = abs(numpy.array(start_error) - 1) + + monitor.run_steps(9) + end_error = self.evaluate(distribution.fetch(mirrored_weight_variable)) + end_error = abs(numpy.array(end_error) - 1) + self.assertGreaterEqual(start_error, end_error) + + def testPassingASessionInEager(self): + distribution = one_device_strategy.OneDeviceStrategy( + "/device:CPU:0") + step_function, _ = single_loss_example( + lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution) + + with session.Session() as sess, context.eager_mode(): + with self.assertRaisesRegexp(ValueError, "Should not provide"): + _ = monitor_lib.Monitor(step_function, sess) + + def testNotPassingASessionInGraph(self): + distribution = one_device_strategy.OneDeviceStrategy( + "/device:CPU:0") + step_function, _ = single_loss_example( + lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution) + + with context.graph_mode(), ops.Graph().as_default(): + with self.assertRaisesRegexp(ValueError, "Should provide"): + _ = monitor_lib.Monitor(step_function, session=None) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..0f21a427320510635279f80c11711e81715ec37c --- /dev/null +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy.py @@ -0,0 +1,141 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Classes implementing a mirrored DistributionStrategy for multiple workers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from functools import partial + +from tensorflow.contrib.distribute.python import values +from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.training import device_util +from tensorflow.python.training import server_lib +from tensorflow.python.util import nest + + +# TODO(yuefengz): support between-graph replication. +# TODO(yuefengz): merge this class into its base class. +# TODO(yuefengz): in some cases, we probably want to use configure method to +# configure this class. +# TODO(yuefengz): MirroredStrategy.worker_devices may be confusing after the +# class is introduced. +class MultiWorkerMirroredStrategy(MirroredStrategy): + """Mirrored strategy that works on multiple workers with in-graph replication. + + There are several important concepts for distributed TensorFlow, e.g. + `client`, `job`, 'task', `cluster`, `in-graph replication` and + 'synchronous training' and they have already been defined in the + [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed). + The distribution strategy inherits these concepts as well and in addition to + that we also clarify several more concepts: + * **In-graph replication**: the `client` creates a single `tf.Graph` that + specifies tasks for devices on all workers. The `client` then creates a + client session which will talk to the `master` service of a `worker`. Then + the `master` will parition the graph and distribute the work to all + participating workers. + * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one + physical machine. We will have multiple `worker`s with different `task` + index. They all do similar things except for one worker checkpointing model + variables, writing summaries, etc. in addition to its ordinary work. + + This class maps one tower to one device on a worker. It mirrors all model + variables on all towers. For example, if you have two `worker`s and each + `worker` has 4 GPUs, it will create 8 copies of the model variables on these 8 + GPUs. Then like in MirroredStrategy, each tower performs their computation + with their own copy of variables unless in cross-tower model where variable or + tensor reduction happens. + """ + + def __init__(self, + num_gpus_per_worker=1, + worker_job_name=None, + num_workers=None, + cluster=None, + cross_tower_ops=None, + prefetch_on_device=None): + """Initialize the strategy object. + + Args: + num_gpus_per_worker: number of GPUs per work. If it is zero, the local + CPU will be used. + worker_job_name: the job name for `worker`, typically just 'worker'. + num_workers: the number of workers. If it is 0, it regenerates to + single-worker MirroredStrategy. + cluster: a `tf.train.ClusterSpec` object or a dict that can be used to + construct a `tf.train.ClusterSpec` object or a `tf.train.ClusterDef` + proto buffer. It is an alternative way to initialize this object. + cross_tower_ops: the cross tower ops to use. If None, a default one will + be used. If configure method is called, a best one for the configuration + will be chosen. + prefetch_on_device: a boolean to specify whether to prefetech input to + each worker's devices. + + Raises: + ValueError: if got an unexpected `cluster`. + """ + if cluster is None: + self._workers = [ + '/job:%s/task:%d' % (worker_job_name, task_index) + for task_index in range(num_workers) + ] + else: + if isinstance(cluster, (dict, cluster_pb2.ClusterDef)): + cluster_spec = server_lib.ClusterSpec(cluster) + elif isinstance(cluster, server_lib.ClusterSpec): + cluster_spec = cluster + else: + raise ValueError( + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + '`tf.train.ClusterDef` object') + + self._workers = [] + for job in sorted(cluster_spec.jobs): + for task in range(cluster_spec.num_tasks(job)): + self._workers.append('/job:%s/task:%d' % (job, task)) + + self._num_gpus_per_worker = num_gpus_per_worker + if num_gpus_per_worker > 0: + self._worker_device_map = { + worker: [ + device_util.canonicalize(worker + '/device:GPU:%d' % gpu) + for gpu in range(num_gpus_per_worker) + ] for worker in self._workers + } + else: + self._worker_device_map = { + worker: [device_util.canonicalize(worker, '/device:CPU:0')] + for worker in self._workers + } + self._devices = nest.flatten(self._worker_device_map) + + super(MultiWorkerMirroredStrategy, self).__init__( + devices=self._devices, prefetch_on_device=prefetch_on_device) + + # Setting `_default_device` will add a device scope in the + # distribution.scope. We set the default device to the first worker. When + # users specify device under distribution.scope by + # with tf.device("/cpu:0"): + # ... + # their ops will end up on the cpu device of its first worker, e.g. + # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode. + self._default_device = self._workers[0] + + def distribute_dataset(self, dataset_fn): + return values.MultiWorkerDataset( + partial(self._call_dataset_fn, dataset_fn), self._worker_device_map, + self._prefetch_on_device) diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..09c859b32a3150b95fbfcfa5b62b5eca426ddf18 --- /dev/null +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py @@ -0,0 +1,62 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MultiWorkerMirroredStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import multi_worker_strategy +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.training import server_lib + + +class MultiWorkerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, + strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + return multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster=server_lib.ClusterSpec({ + 'worker': ['/job:worker/task:0', '/job:worker/task:1'] + }), + num_gpus_per_worker=context.num_gpus()) + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy()) + + +class DeviceScopeTest(test.TestCase): + """Test the device scope of MultiWorkerMirroredStrategy.""" + + def testDeviceScope(self): + with context.graph_mode(): + strategy = multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={'worker': ['/job:worker/task:0', '/job:worker/task:1']}, + num_gpus_per_worker=context.num_gpus()) + with strategy.scope(): + a = constant_op.constant(1.) + with ops.device('/cpu:0'): + b = constant_op.constant(1.) + self.assertEqual(a.device, '/job:worker/task:0') + self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..f659be5f42594b275af06435cb0c228e5d594ac9 --- /dev/null +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -0,0 +1,90 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base testing class for strategies that require multiple nodes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import copy + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.client import session +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util + + +class MultiWorkerTestBase(test.TestCase): + """Base class for testing multi node strategy and dataset.""" + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + num_workers = 2 + # Leave some memory for cuda runtime. + gpu_mem_frac = 0.7 / num_workers + default_config = config_pb2.ConfigProto() + default_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac + + # The local cluster takes some portion of the local GPUs and there is no way + # for the cluster to terminate unless using multiple processes. Therefore, + # we have to only create only one cluster throughout a test process. + workers, _ = test_util.create_local_cluster( + num_workers, num_ps=0, worker_config=default_config) + cls._master_target = workers[0].target + + @contextlib.contextmanager + def test_session(self, graph=None, config=None): + """Create a test session with master target set to the testing cluster. + + This overrides the base class' method, removes arguments that are not needed + by the multi-node case and creates a test session that connects to the local + testing cluster. + + Args: + graph: Optional graph to use during the returned session. + config: An optional config_pb2.ConfigProto to use to configure the + session. + + Yields: + A Session object that should be used as a context manager to surround + the graph building and execution code in a test case. + """ + if self.id().endswith('.test_session'): + self.skipTest('Not a test.') + + if config is None: + config = config_pb2.ConfigProto(allow_soft_placement=True) + else: + config = copy.deepcopy(config) + # Don't perform optimizations for tests so we don't inadvertently run + # gpu ops on cpu + config.graph_options.optimizer_options.opt_level = -1 + config.graph_options.rewrite_options.constant_folding = ( + rewriter_config_pb2.RewriterConfig.OFF) + + if graph is None: + if self._cached_session is None: # pylint: disable=access-member-before-definition + self._cached_session = session.Session( + graph=None, config=config, target=self._master_target) + sess = self._cached_session + with sess.graph.as_default(), sess.as_default(): + yield sess + else: + with session.Session( + graph=graph, config=config, target=self._master_target) as sess: + yield sess diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..7f4bab9d93814eb70a2a1586fc291a16b2766b90 --- /dev/null +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -0,0 +1,151 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class OneDeviceStrategy implementing DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.distribute.python import values +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.training import distribute as distribute_lib + + +# TODO(josh11b): Replace asserts in this file with if ...: raise ... + + +class OneDeviceStrategy(distribute_lib.DistributionStrategy): + """A distribution strategy for running on a single device.""" + # TODO(josh11b): Do we wrap values in types to generate errors if you are + # doing something that won't work with other DistributionStrategy + # implementations? + + def __init__(self, device, prefetch_on_device=None): + super(OneDeviceStrategy, self).__init__() + self._device = device + self._prefetch_on_device = prefetch_on_device + self._default_device = device + + def _create_variable(self, next_creator, *args, **kwargs): + # No need to distinguish tower-local variables when not mirroring, + # we just enforce that they are not trainable. + if kwargs.pop("tower_local_reduce_method", None) is not None: + kwargs["trainable"] = False + + colocate_with = kwargs.pop("colocate_with", None) + if colocate_with is None: + with ops.device(self._device): + return next_creator(*args, **kwargs) + if isinstance(colocate_with, six.string_types): + with ops.device(colocate_with): + return next_creator(*args, **kwargs) + if (isinstance(colocate_with, list) and len(colocate_with) == 1 and + isinstance(colocate_with[0], six.string_types)): + with ops.device(colocate_with[0]): + return next_creator(*args, **kwargs) + with ops.colocate_with(colocate_with): + return next_creator(*args, **kwargs) + + def distribute_dataset(self, dataset_fn): + return values.PerDeviceDataset( + self._call_dataset_fn(dataset_fn), [self._device], + self._prefetch_on_device) + + def _broadcast(self, tensor, destinations): + return tensor + + def _call_for_each_tower(self, fn, *args, **kwargs): + # We don't run `fn` in multiple threads in OneDeviceStrategy. + kwargs.pop("run_concurrently", None) + with ops.device(self._device), _OneDeviceTowerContext(self): + return fn(*args, **kwargs) + + def map(self, map_over, fn, *args, **kwargs): + with ops.device(self._device): + return values.MapOutput([fn(m, *args, **kwargs) for m in map_over]) + + def _reduce(self, method_string, value, destinations): + if not isinstance(value, values.MapOutput): + return value + l = value.get() + assert l + with ops.device(self._device): + if method_string == "sum": + return math_ops.add_n(l) + elif method_string == "mean": + return math_ops.add_n(l) / len(l) + else: + assert False + + def _update(self, var, fn, *args, **kwargs): + with ops.device(self._device), distribute_lib.UpdateContext(self._device): + return fn(var, *args, **kwargs) + + def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + del colocate_with + with ops.device(self._device), distribute_lib.UpdateContext(self._device): + return fn(*args, **kwargs) + + def read_var(self, tower_local_var): + """Read the aggregate value of a tower-local variable.""" + return array_ops.identity(tower_local_var) + + def _fetch(self, val, destination, fn): + """Return a copy of `val` or `fn(val)` on `destination`.""" + with ops.device(self._device): + v = fn(val) + with ops.device(destination): + return array_ops.identity(v) + + def _unwrap(self, value): + return [value] + + @property + def is_single_tower(self): + return True + + @property + def num_towers(self): + return 1 + + @property + def worker_devices(self): + return [self._device] + + @property + def parameter_devices(self): + return [self._device] + + def non_slot_devices(self, var_list): + del var_list + return [self._device] + + def _worker_device_index(self): + return 0 + + +class _OneDeviceTowerContext(distribute_lib.TowerContext): + + def __init__(self, distribution_strategy): + distribute_lib.TowerContext.__init__( + self, distribution_strategy, tower_id=0) + + @property + def device(self): + return self._distribution_strategy.worker_devices[0] diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7aad8a953cbedd30b48739416e74b3dc164dc4cd --- /dev/null +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -0,0 +1,53 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class OneDeviceStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util + + +class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + return one_device_strategy.OneDeviceStrategy("/device:CPU:0") + + def testMinimizeLossEager(self): + self._test_minimize_loss_eager(self._get_distribution_strategy()) + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy()) + + def testMapReduce(self): + self._test_map_reduce(self._get_distribution_strategy()) + + def testDeviceIndex(self): + self._test_device_index(self._get_distribution_strategy()) + + def testTowerId(self): + self._test_tower_id(self._get_distribution_strategy()) + + @test_util.run_in_graph_and_eager_modes() + def testCallAndMergeExceptions(self): + self._test_call_and_merge_exceptions(self._get_distribution_strategy()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..abd3a65ac4e19ece6b69b9834f4218fde55b60c2 --- /dev/null +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -0,0 +1,71 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for running legacy optimizer code with DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variables + + +class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + combinations.distributions_and_v2_optimizers(), + combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True]))) + def testTrainNetwork(self, distribution, optimizer_fn, + use_callable_loss=True): + with distribution.scope(): + model_fn, dataset_fn, layer = minimize_loss_example( + optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) + + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + + def run_step(): + return control_flow_ops.group(distribution.unwrap( + distribution.call_for_each_tower( + model_fn, iterator.get_next(), run_concurrently=layer.built))) + + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables.global_variables_initializer()) + + weights, biases = [], [] + for _ in range(10): + run_step() + + weights.append(self.evaluate(distribution.fetch(layer.kernel))) + biases.append(self.evaluate(distribution.fetch(layer.bias))) + + error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) + is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) + self.assertTrue(is_not_increasing) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3670b45aba801cf8c18e04bfea03e23eb67184 --- /dev/null +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py @@ -0,0 +1,225 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Extension of prefetching_ops to support more than one device.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings + +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops +from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.util import nest as data_nest +from tensorflow.python.data.util import sparse +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.util import nest + + +# pylint: disable=protected-access +class _PrefetchToDeviceIterator(object): + """A replacement for @{tf.data.Iterator} that prefetches to another device. + + Args: + input_dataset: The input dataset. + one_shot: If true, we make a one shot iterator that's already initialized. + devices: Devices on which to prefetch. + buffer_size: Size of the prefetching buffer. + shared_name: (Optional.) If non-empty, the returned iterator will be + shared under the given name across multiple sessions that share the + same devices (e.g. when using a remote server). Only used if one_shot + is False. + + Returns: + An Iterator type object. + """ + + def __init__(self, + input_dataset, + one_shot, + devices, + buffer_size, + shared_name=None): + self._input_dataset = input_dataset + self._get_next_call_count = 0 + self._one_shot = one_shot + if shared_name is None: + shared_name = "" + self._devices = devices + + if self._one_shot: + self._input_iterator = input_dataset.make_one_shot_iterator() + else: + self._input_iterator = iterator_ops.Iterator.from_structure( + self._input_dataset.output_types, self._input_dataset.output_shapes, + shared_name, self._input_dataset.output_classes) + input_iterator_handle = self._input_iterator.string_handle() + + @function.Defun(dtypes.string) + def _prefetch_fn(handle): + """Prefetches one element from `input_iterator`.""" + remote_iterator = iterator_ops.Iterator.from_string_handle( + handle, self._input_iterator.output_types, + self._input_iterator.output_shapes, + self._input_iterator.output_classes) + ret = remote_iterator.get_next() + return nest.flatten(sparse.serialize_sparse_tensors(ret)) + + target_device = gen_dataset_ops.iterator_get_device( + self._input_iterator._iterator_resource) + self._buffering_resources = [] + for device in nest.flatten(self._devices): + with ops.device(device): + buffer_resource_handle = prefetching_ops.function_buffering_resource( + f=_prefetch_fn, + target_device=target_device, + string_arg=input_iterator_handle, + buffer_size=buffer_size, + shared_name=shared_name) + self._buffering_resources.append(buffer_resource_handle) + + if not self._one_shot: + reset_ops = [] + for buffer_resource in self._buffering_resources: + reset_ops.append( + prefetching_ops.function_buffering_resource_reset(buffer_resource)) + with ops.control_dependencies(reset_ops): + self._initializer = self._input_iterator.make_initializer( + self._input_dataset) + + def get_next(self, name=None): + """See @{tf.data.Iterator.get_next}.""" + self._get_next_call_count += 1 + if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: + warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) + + flat_result = [] + # TODO(priyag): This will fail if the input size (typically number of + # batches) is not divisible by number of devices. + # How do we handle that more gracefully / let the user know? + for buffer_resource in self._buffering_resources: + flat_ret = gen_dataset_ops.function_buffering_resource_get_next( + buffer_resource, + output_types=data_nest.flatten(sparse.as_dense_types( + self.output_types, self.output_classes)), name=name) + + ret = sparse.deserialize_sparse_tensors( + data_nest.pack_sequence_as(self.output_types, flat_ret), + self.output_types, self.output_shapes, self.output_classes) + + for tensor, shape in zip( + data_nest.flatten(ret), data_nest.flatten(self.output_shapes)): + if isinstance(tensor, ops.Tensor): + tensor.set_shape(shape) + flat_result.append(ret) + + return nest.pack_sequence_as(self._devices, flat_result) + + @property + def initializer(self): + if self._one_shot: + raise NotImplementedError("Can't initialize a one_shot_iterator") + return self._initializer + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types +# pylint: enable=protected-access + + +class _PrefetchToDeviceDataset(dataset_ops.Dataset): + """A `Dataset` whose iterator prefetches elements to other device(s).""" + + def __init__(self, input_dataset, devices, buffer_size): + self._input_dataset = input_dataset + self._devices = devices + self._buffer_size = buffer_size if buffer_size is not None else 1 + + def make_one_shot_iterator(self): + return _PrefetchToDeviceIterator( + self._input_dataset, + one_shot=True, + devices=self._devices, + buffer_size=self._buffer_size) + + def make_initializable_iterator(self, shared_name=None): + if context.executing_eagerly(): + raise RuntimeError( + "make_initializable_iterator is not supported when eager " + "execution is enabled.") + + return _PrefetchToDeviceIterator( + self._input_dataset, + one_shot=False, + devices=self._devices, + buffer_size=self._buffer_size, + shared_name=shared_name) + + def _as_variant_tensor(self): + # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset + # transformation methods is called. + # TODO(mrry): Investigate support for chaining further transformations after + # the prefetch, including GPU support. + raise NotImplementedError("`prefetch_to_devices()` must be the last " + "transformation in a dataset pipeline.") + + # TODO(priyag): Fix the output types, shapes and classes to match the result + # of get_next (which has the additional nesting layer of devices now). + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +def prefetch_to_devices(devices, buffer_size=None): + """A transformation that prefetches dataset values to the given `devices`. + + NOTE: Although the transformation creates a @{tf.data.Dataset}, the + transformation must be the final `Dataset` in the input pipeline. + + Args: + devices: A nested structure of devices on which to prefetch the data. It can + be a single device name, or a tuple or list of device names. + buffer_size: (Optional.) The number of elements to buffer on each device. + Defaults to an automatically chosen value. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + def _apply_fn(dataset): + return _PrefetchToDeviceDataset(dataset, devices, buffer_size) + + return _apply_fn diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a68dbce6c7d03f6a1695ebfcd00178e21ac1cda0 --- /dev/null +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py @@ -0,0 +1,90 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for prefetching_ops_v2.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import prefetching_ops_v2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class PrefetchingOpsV2Test(test.TestCase): + + def testPrefetchToOneDevice(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops_v2.prefetch_to_devices("/gpu:0")) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToTwoDevicesInAList(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"])) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + output = [] + with self.test_session() as sess: + for _ in range(5): + result = sess.run(next_element) + self.assertEqual(2, len(result)) + output.extend(result) + self.assertEquals(set(range(10)), set(output)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToTwoDevicesWithReinit(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"])) + + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for _ in range(5): + sess.run(next_element) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + sess.run(iterator.initializer) + for _ in range(5): + sess.run(next_element) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator.py b/tensorflow/contrib/distribute/python/shared_variable_creator.py new file mode 100644 index 0000000000000000000000000000000000000000..a7083e279f20803b227dcd52f6420ae832aa2df4 --- /dev/null +++ b/tensorflow/contrib/distribute/python/shared_variable_creator.py @@ -0,0 +1,97 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility to re-use variables created on first device on subsequent devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + +_VARIABLE_UNIQUIFYING_REGEX = re.compile(r"_\d/") +_VARIABLE_UNIQUIFYING_REGEX_AT_END = re.compile(r"_\d$") + + +def _canonicalize_variable_name(name): + # If no name is specified, uses default name "Variable". + if name is None: + return "Variable" + # Replace all instances of "_/" with "/" + name = _VARIABLE_UNIQUIFYING_REGEX.sub("/", name) + # Replace any instances of "_" at the end of the string with "" + name = _VARIABLE_UNIQUIFYING_REGEX_AT_END.sub("", name) + return name + + +def make_fn(shared_variable_store, device_id): + """Construct the variable creator function for device `device_id`. + + Constructs custom variable creator functions for the given device. + On first device (device_id == 0), it creates the variable using the + `next_creator`, and stores it in the provided `shared_variable_store`. + On all other devices (device_id > 0), it tries to re-use the variable + already created with the same name. If no such variable exists, it throws an + error. + Additionally, we de-uniquify variable names before checking for matches. This + helps re-use variables which are intended to be the same but have different + names due to variable uniquification happening upstream. Since this might + mean we may have multiple variables with the same canonical name, we store + them in a list per canonical name and return them in the same order as well. + + Args: + shared_variable_store: A dictionary that we will use to store variables + created on the first device, and re-used by creators for other devices. + device_id: Integer index of the device whose creator should be + constructed. + + Returns: + An appropriate creator function based on device_id. + + """ + variable_scope_access_index = {} + assert isinstance(device_id, int) + + def create_new_variable(next_creator, *args, **kwargs): + """Create the variable using `next_creator` and store it.""" + canonical_name = _canonicalize_variable_name(kwargs.get("name")) + v = next_creator(*args, **kwargs) + + if canonical_name not in shared_variable_store: + shared_variable_store[canonical_name] = [] + shared_variable_store[canonical_name].append(v) + return v + + def reuse_variable(next_creator, *args, **kwargs): + """Re-use existing variable from store with same name (in order).""" + del next_creator, args + name = kwargs.get("name") + canonical_name = _canonicalize_variable_name(name) + + try: + variable_index = variable_scope_access_index.get(canonical_name, 0) + v = shared_variable_store[canonical_name][variable_index] + # TODO(priyag): Make this variable re-use more robust by adding checks + # that the requested shape and dtype match the existing variable. + variable_scope_access_index[canonical_name] = variable_index + 1 + return v + except (KeyError, IndexError): + raise RuntimeError( + "Tried to create variable {} with mismatching name on device {}". + format(name, device_id)) + + if device_id == 0: + return create_new_variable + else: + return reuse_variable diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b452fc2d445d1cf7dbf5e8fe0e29edef516207 --- /dev/null +++ b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py @@ -0,0 +1,74 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SharedVariableCreator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import shared_variable_creator +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variable_scope + + +class CanonicalizeVariableNameTest(test.TestCase): + + def _canonicalize(self, name): + return shared_variable_creator._canonicalize_variable_name(name) + + def testNoName(self): + self.assertEquals("Variable", self._canonicalize(None)) + + def testPatternInMiddle(self): + self.assertEquals("foo/bar/baz", self._canonicalize("foo_1/bar_1/baz")) + + def testPatternAtEnd(self): + self.assertEquals("foo", self._canonicalize("foo_1")) + + def testWrongPatterns(self): + self.assertEquals("foo_1:0", self._canonicalize("foo_1:0")) + self.assertEquals("foo1", self._canonicalize("foo1")) + self.assertEquals("foo_a", self._canonicalize("foo_a")) + + +class SharedVariableCreatorTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testSharedVariable(self): + + shared_variable_store = {} + num_devices = 3 + creator_fns = [] + for i in range(num_devices): + creator_fn = shared_variable_creator.make_fn(shared_variable_store, i) + creator_fns.append(creator_fn) + + with variable_scope.variable_creator_scope(creator_fns[0]): + v0 = variable_scope.variable(1.0, name="foo") + + with variable_scope.variable_creator_scope(creator_fns[1]): + v1 = variable_scope.variable(1.0, name="foo") + + with variable_scope.variable_creator_scope(creator_fns[2]): + v2 = variable_scope.variable(1.0, name="foo") + + # v1 and v2 should be same as v0 + self.assertIs(v1, v0) + self.assertIs(v2, v0) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py new file mode 100644 index 0000000000000000000000000000000000000000..d1fdb3279cf2a7cba6e2282d58eedccf38bd38a3 --- /dev/null +++ b/tensorflow/contrib/distribute/python/single_loss_example.py @@ -0,0 +1,125 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A simple network to use in tests and examples.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.distribute.python import step_fn +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.layers import normalization +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def single_loss_example(optimizer_fn, distribution, use_bias=False): + """Build a very simple network to use in tests and examples.""" + + def dataset_fn(): + return dataset_ops.Dataset.from_tensors([[1.]]).repeat() + + optimizer = optimizer_fn() + layer = core.Dense(1, use_bias=use_bias) + + def loss_fn(x): + y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) + return y * y + + single_loss_step = step_fn.StandardSingleLossStep(dataset_fn, loss_fn, + optimizer, distribution) + + # Layer is returned for inspecting the kernels in tests. + return single_loss_step, layer + + +def minimize_loss_example(optimizer_fn, + use_bias=False, + use_callable_loss=True, + create_optimizer_inside_model_fn=False): + """Example of non-distribution-aware legacy code.""" + + def dataset_fn(): + dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() + # TODO(isaprykin): map_and_batch with drop_remainder causes shapes to be + # fully defined for TPU. Remove this when XLA supports dynamic shapes. + return dataset.apply( + batching.map_and_batch(lambda x: x, batch_size=1, drop_remainder=True)) + + # An Optimizer instance is created either outside or inside model_fn. + outer_optimizer = None + if not create_optimizer_inside_model_fn: + outer_optimizer = optimizer_fn() + + layer = core.Dense(1, use_bias=use_bias) + + def model_fn(x): + """A very simple model written by the user.""" + + def loss_fn(): + y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) + return y * y + + optimizer = outer_optimizer or optimizer_fn() + + if use_callable_loss: + return optimizer.minimize(loss_fn) + else: + return optimizer.minimize(loss_fn()) + + return model_fn, dataset_fn, layer + + +def batchnorm_example(optimizer_fn, + batch_per_epoch=1, + momentum=0.9, + renorm=False, + update_ops_in_tower_mode=False): + """Example of non-distribution-aware legacy code with batch normalization.""" + + def dataset_fn(): + # input shape is [16, 8], input values are increasing in both dimensions. + return dataset_ops.Dataset.from_tensor_slices( + [[[float(x * 8 + y + z * 100) + for y in range(8)] + for x in range(16)] + for z in range(batch_per_epoch)]).repeat() + + optimizer = optimizer_fn() + batchnorm = normalization.BatchNormalization( + renorm=renorm, momentum=momentum, fused=False) + layer = core.Dense(1, use_bias=False) + + def model_fn(x): + """A model that uses batchnorm.""" + + def loss_fn(): + y = batchnorm(x, training=True) + with ops.control_dependencies( + ops.get_collection(ops.GraphKeys.UPDATE_OPS) + if update_ops_in_tower_mode else []): + loss = math_ops.reduce_mean( + math_ops.reduce_sum(layer(y)) - constant_op.constant(1.)) + # `x` and `y` will be fetched by the gradient computation, but not `loss`. + return loss + + # Callable loss. + return optimizer.minimize(loss_fn) + + return model_fn, dataset_fn, batchnorm diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..d1910622b38c748fc5a814f9e83c2294850d5d12 --- /dev/null +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -0,0 +1,106 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The step function abstraction represents a single training step.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import backprop +from tensorflow.python.training import optimizer as optimizer_lib + + +class Step(object): + """Interface for performing each step of a training algorithm.""" + + def __init__(self, distribution): + self._distribution = distribution + + @property + def distribution(self): + return self._distribution + + def __call__(self): + """Perform one step of this training algorithm.""" + return self.step(self.inputs()) + + def inputs(self): + """For the generating the input to be passed to `step()`.""" + raise NotImplementedError("must be implemented in descendants") + + def step(self, inputs): + """Perform the main computation of this training algorithm.""" + raise NotImplementedError("must be implemented in descendants") + + +class StandardInputStep(Step): + """Step with a standard implementation of input handling. + + Args: + dataset_fn: a function that returns a tf.data Dataset that produces the + input for the model. + """ + + def __init__(self, dataset_fn, distribution): + Step.__init__(self, distribution) + self._distributed_input = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + + def inputs(self): + return self._distributed_input.get_next() + + +class StandardSingleLossStep(StandardInputStep): + """A step function that implements a training step for a feed forward network. + + An instance of this class is intended to be used as a callable: + + ```python + ... + step = step_fn.StandardSingleLossStep(dataset, loss_fn, optimizer) + step.initialize(distribution) + + # Run a single training step on a given DistributionStrategy: + step(distribution) + ... + ``` + + Args: + dataset_fn: a function that returns a tf.data Dataset that produces the + input for the model. + loss_fn: a function that returns loss. + optimizer: an optimizer that implements an update rule. + distribution: a `DistributionStrategy` object. + """ + + def __init__(self, dataset_fn, loss_fn, optimizer, distribution): + StandardInputStep.__init__(self, dataset_fn, distribution) + self._loss_fn = loss_fn + self._optimizer = optimizer + self._is_run_concurrently = False + + def step(self, inputs): + with self._distribution.scope(): + gradients_fn = backprop.implicit_grad(self._loss_fn) + gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) + + grads_and_vars = self.distribution.call_for_each_tower( + gradients_fn, inputs, run_concurrently=self._is_run_concurrently) + # If threads use layers, then we need to run the first step sequentially, + # so that layers.build() is not executed in parallel. Otherwise, multiple + # sets of mirrored variables are going to be created. + self._is_run_concurrently = True + return self._optimizer._distributed_apply( # pylint: disable=protected-access + self.distribution, grads_and_vars) diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..75c5ec9659d193e77d219ba79977615d58841d64 --- /dev/null +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -0,0 +1,62 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class Step.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.ops import variables + + +class SingleLossStepTest(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine(mode=combinations.graph_and_eager_modes))) + def testTrainNetwork(self, distribution, optimizer_fn): + with distribution.scope(): + single_loss_step, layer = single_loss_example( + optimizer_fn, distribution, use_bias=True) + + if context.executing_eagerly(): + run_step = single_loss_step + else: + with self.test_session() as sess: + run_step = sess.make_callable(single_loss_step()) + self.evaluate(variables.global_variables_initializer()) + + weights, biases = [], [] + for _ in range(10): + run_step() + + weights.append(self.evaluate(distribution.fetch(layer.kernel))) + biases.append(self.evaluate(distribution.fetch(layer.bias))) + + error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) + is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) + self.assertTrue(is_not_increasing) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..2b4ad9f146bc1d6a987fbeecbb05122946137154 --- /dev/null +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -0,0 +1,225 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Library for testing DistributionStrategy descendants.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import optimizer + + +class _TestException(Exception): + pass + + +# May be the argument to either distribution.call_for_each_tower() or +# get_tower_context().merge_call() +def _raise_exception_fn(_=None): + raise _TestException() + + +# Must be the argument to a distribution.call_for_each_tower() call, calls a +# get_tower_context().merge_call() that raises an exception. +def _merge_raises_fn(): + distribute_lib.get_tower_context().merge_call(_raise_exception_fn) + + +# Must be the argument to a get_tower_context().merge_call() call, calls +# dist.call_for_each_tower() with a function that raises an exception. +def _call_raises_fn(dist): + dist.call_for_each_tower(_raise_exception_fn) + + +# Must be the argument to a distribution.call_for_each_tower() call, +# calls a get_tower_context().merge_call() that calls a +# call_for_each_tower() that raises an exception. +def _merge_call_raises_fn(): + distribute_lib.get_tower_context().merge_call(_call_raises_fn) + + +# Must be the argument to a get_tower_context().merge_call() call, calls +# dist.call_for_each_tower() with a function that calls a +# get_tower_context().merge_call() that raises an exception. +def _call_merge_raises_fn(dist): + dist.call_for_each_tower(_merge_raises_fn) + + +# Must be the argument to a distribution.call_for_each_tower() call, calls a +# get_tower_context().merge_call() that calls a call_for_each_tower() that +# calls a get_tower_context().merge_call() that raises an exception. +def _merge_call_merge_raises_fn(): + distribute_lib.get_tower_context().merge_call(_call_merge_raises_fn) + + +class DistributionTestBase(test.TestCase): + """Some tests that should work with any DistributionStrategy.""" + + def _test_minimize_loss_eager(self, d): + with d.scope(): + l = core.Dense(1, use_bias=False) + + def loss(x): + # TODO(josh11b): What if this constant was instead a captured + # value? Would it need to be a value that has been passed + # through d.broadcast()? + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + return y * y + # TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a + # common `implicit_grad` function and put it in DistributionStrategy. + grad_fn = backprop.implicit_grad(loss) + grad_fn = optimizer.get_filtered_grad_fn(grad_fn) + + def update(v, g): + return v.assign_sub(0.2 * g) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one, run_concurrently=l.built) + + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.fetch(v) + before_list.append(fetched) + # control_dependencies irrelevant but harmless in eager execution + with ops.control_dependencies([fetched]): + g = d.reduce("sum", g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.fetch(v)) + return before_list, after_list + + for i in range(10): + b, a = step() + if i == 0: + before, = b # pylint: disable=unbalanced-tuple-unpacking + after, = a # pylint: disable=unbalanced-tuple-unpacking + + error_before = abs(before.numpy() - 1) + error_after = abs(after.numpy() - 1) + # Error should go down + self.assertLess(error_after, error_before) + + def _test_minimize_loss_graph(self, d, soft_placement=False): + config = config_pb2.ConfigProto() + config.allow_soft_placement = soft_placement + config.gpu_options.per_process_gpu_memory_fraction = 0.3 + with context.graph_mode(), \ + ops.Graph().as_default(), \ + self.test_session(config=config) as sess, \ + d.scope(): + l = core.Dense(1, use_bias=False) + + def loss(x): + # TODO(josh11b): What if this constant was instead a captured + # value? Would it need to be a value that has been passed + # through d.broadcast()? + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + return y * y + + grad_fn = backprop.implicit_grad(loss) + + def update(v, g): + return v.assign_sub(0.2 * g) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one) + + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.fetch(v) + before_list.append(fetched) + with ops.control_dependencies([fetched]): + g = d.reduce("sum", g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.fetch(v)) + return before_list, after_list + + before_out, after_out = step() + variables.global_variables_initializer().run() + for i in range(10): + b, a = sess.run((before_out, after_out)) + if i == 0: + before, = b + after, = a + + error_before = abs(before - 1) + error_after = abs(after - 1) + # Error should go down + self.assertLess(error_after, error_before) + + def _test_map_reduce(self, d, in_graph=None): + with d.scope(): + map_in = [constant_op.constant(i) for i in range(10)] + map_out = d.map(map_in, lambda x, y: x * y, 2) + observed = d.fetch(d.reduce("sum", map_out)) + expected = 90 # 2 * (0 + 1 + ... + 9) + self.assertEqual(expected, observed.numpy()) + + def _test_device_index(self, d): + with d.scope(): + expected_devices = [False] * len(d.worker_devices) + + def mark_devices_fn(device_id): + self.assertLess(device_id, len(d.worker_devices)) + self.assertFalse(expected_devices[device_id]) + expected_devices[device_id] = True + + d.call_for_each_tower(mark_devices_fn, d.worker_device_index) + self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) + + def _test_tower_id(self, d): + with d.scope(): + expected_devices = [False] * len(d.worker_devices) + + def mark_devices_fn(): + tower_id = distribute_lib.get_tower_context().tower_id + self.assertLess(tower_id, len(d.worker_devices)) + self.assertFalse(expected_devices[tower_id]) + expected_devices[tower_id] = True + + d.call_for_each_tower(mark_devices_fn) + self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) + + def _test_call_and_merge_exceptions(self, dist): + with dist.scope(): + with self.assertRaises(_TestException): + dist.call_for_each_tower(_raise_exception_fn) + with self.assertRaises(_TestException): + dist.call_for_each_tower(_merge_raises_fn) + with self.assertRaises(_TestException): + dist.call_for_each_tower(_merge_call_raises_fn) + with self.assertRaises(_TestException): + dist.call_for_each_tower(_merge_call_merge_raises_fn) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..75441786a615fc0d87b4c4b0b45b9384d678c1d3 --- /dev/null +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -0,0 +1,128 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TPU Distribution Strategy. + +This is experimental. It's not ready for general use. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from tensorflow.contrib import tpu +from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.distribute.python import values +from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.util import nest + + +class TPUStrategy(one_device_strategy.OneDeviceStrategy): + """Experimental TPU distribution strategy implementation.""" + + def __init__(self, + num_cores_per_host=2, + iterations_per_step=2): + # TODO(isaprykin): Generalize the defaults. They are currently tailored for + # the unit test. + super(TPUStrategy, self).__init__('/cpu:0') + # TODO(isaprykin): Auto-detect number of cores and hosts. + self._num_cores_per_host = num_cores_per_host + # TODO(isaprykin): This might have to be per-call. + self._iterations_per_step = iterations_per_step + + def distribute_dataset(self, dataset_fn): + return values.PerIterationDataset( + self._call_dataset_fn(dataset_fn), self._iterations_per_step, + self._num_cores_per_host) + + def _call_for_each_tower(self, fn, *args, **kwargs): + kwargs.pop('run_concurrently', None) + + inputs = {'args': args, 'kwargs': kwargs} + flat_inputs = nest.flatten(inputs) + + feed_mask = [isinstance(f, values.PerIteration) for f in flat_inputs] + + feeds = lambda: itertools.compress(flat_inputs, feed_mask) + shapes = [f.get_shape() for f in feeds()] + if any([not s.is_fully_defined() for s in shapes]): + raise ValueError( + 'TPU currently requires fully defined shapes. Either use ' + 'set_shape() on the input tensors or use ' + 'dataset.apply(map_and_batch(..., drop_remainder=True)).') + types = [f.get_dtype() for f in feeds()] + + def infeed_input(i): + """Get input, split it and then enqueue.""" + iteration_inputs = [f.get(i) for f in feeds()] + infeed_inputs = [[inputs_per_core[core_id] + for inputs_per_core in iteration_inputs] + for core_id in range(self._num_cores_per_host)] + + infeed_ops = [] + for core_id, infeed_input in enumerate(infeed_inputs): + infeed_ops.append( + tpu_ops.infeed_enqueue_tuple( + inputs=infeed_input, shapes=shapes, device_ordinal=core_id)) + + with ops.control_dependencies(infeed_ops): + return i + 1 + + with ops.device('/task:0/device:CPU:0'): + enqueue_ops = control_flow_ops.while_loop( + lambda i: i < self._iterations_per_step, + infeed_input, [constant_op.constant(0)], + parallel_iterations=1) + + def dequeueing_fn(*args, **kwargs): + """Dequeue input arguments and supply them to `fn`.""" + del args, kwargs + dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes) + dequeued = iter(dequeued) + + fn_inputs = [] + for inp, is_feed in zip(flat_inputs, feed_mask): + if is_feed: + fn_inputs.append(next(dequeued)) + else: + fn_inputs.append(inp) + + fn_inputs = nest.pack_sequence_as(inputs, fn_inputs) + return fn(*fn_inputs['args'], **fn_inputs['kwargs']) + + def iterate_on_tpu(): + return tpu.repeat(self._iterations_per_step, dequeueing_fn, []) + + with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access + tpu_result = tpu.batch_parallel( + iterate_on_tpu, [], num_shards=self._num_cores_per_host) + + return control_flow_ops.group(tpu_result, enqueue_ops) + + def _reduce(self, method_string, value, destinations): + del destinations # TPU is graph mode only. Rely on implicit Send/Recv. + if method_string == 'mean': + # TODO(jhseu): Revisit once we support model-parallelism. + value *= (1. / self._num_cores_per_host) + return tpu_ops.cross_replica_sum(value) + + @property + def num_towers(self): + return self._num_cores_per_host diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py new file mode 100644 index 0000000000000000000000000000000000000000..9572ade8e497fa13a7ca0746399d3e0237ee79fd --- /dev/null +++ b/tensorflow/contrib/distribute/python/values.py @@ -0,0 +1,807 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Various classes representing distributed values. + +See go/tf-distribution-strategy. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import weakref + +import six + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.distribute.python import input_ops +from tensorflow.contrib.distribute.python import prefetching_ops_v2 +from tensorflow.python.eager import context +from tensorflow.python.framework import device as tf_device +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.training import device_util +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.util import nest + + +# pylint: disable=line-too-long +# TODO(josh11b): Should device values be strings or DeviceSpec objects +# Not sure DeviceSpec objects are usable as a dict key. +class DistributedValues(object): + """Holds a map from device to values. Either PerDevice or Mirrored.""" + + def __init__(self, index): + self._index = {device_util.canonicalize(key): value + for key, value in six.iteritems(index)} + + def get(self, device=None): + """Returns the value for the current device or raises a ValueError.""" + if device is None: + tower_context = distribute_lib.get_tower_context() + if tower_context: + device = tower_context.device + else: + device = distribute_lib.get_update_device() + if device is None: + return self._get_cross_tower() + device = device_util.canonicalize(device) + try: + return self._index[device] + except KeyError as e: + six.raise_from( + ValueError("Device %s not found in %s (current device %s)" % + (device, self._index.keys(), device_util.current())), e) + + def on_device(self, device): + device = device_util.canonicalize(device) + return device in self._index + + @property + def devices(self): + return list(self._index.keys()) + + def __str__(self): + return "%s:%s" % (self.__class__.__name__, self._index) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._index) + + # TODO(josh11b): Possibly make an accessor for _index for use by + # DistributionStrategy implementations. + + +class DistributedDelegate(DistributedValues): + """A map from device to values; acts as the same type as the values.""" + + def __init__(self, index): + super(DistributedDelegate, self).__init__(index) + + def __getattr__(self, name): + return getattr(self.get(), name) + + # pylint: disable=multiple-statements + def __add__(self, o): return self.get() + o + def __radd__(self, o): return o + self.get() + def __sub__(self, o): return self.get() - o + def __rsub__(self, o): return o - self.get() + def __mul__(self, o): return self.get() * o + def __rmul__(self, o): return o * self.get() + def __truediv__(self, o): return self.get() / o + def __rtruediv__(self, o): return o / self.get() + def __floordiv__(self, o): return self.get() // o + def __rfloordiv__(self, o): return o // self.get() + def __mod__(self, o): return self.get() % o + def __rmod__(self, o): return o % self.get() + def __lt__(self, o): return self.get() < o + def __le__(self, o): return self.get() <= o + def __gt__(self, o): return self.get() > o + def __ge__(self, o): return self.get() >= o + def __and__(self, o): return self.get() & o + def __rand__(self, o): return o & self.get() + def __or__(self, o): return self.get() | o + def __ror__(self, o): return o | self.get() + def __xor__(self, o): return self.get() ^ o + def __rxor__(self, o): return o ^ self.get() + def __getitem__(self, o): return self.get()[o] + def __pow__(self, o, modulo=None): return pow(self.get(), o, modulo) + def __rpow__(self, o): return pow(o, self.get()) + def __invert__(self): return ~self.get() + def __neg__(self): return -self.get() + def __abs__(self): return abs(self.get()) + + def __div__(self, o): + try: + return self.get().__div__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __rdiv__(self, o): + try: + return self.get().__rdiv__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __matmul__(self, o): + try: + return self.get().__matmul__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __rmatmul__(self, o): + try: + return self.get().__rmatmul__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + # TODO(josh11b): Even more operator overloads. + + +class PerDevice(DistributedValues): + """Holds a map from device to unsynchronized values.""" + pass + + +class Mirrored(DistributedValues): + """Holds a map from device to values which are kept in sync.""" + pass + + +def _assign_on_device(device, variable, tensor): + with ops.device(device): + return variable.assign(array_ops.identity(tensor)) + + +DistributedVarOp = collections.namedtuple( + "DistributedVarOp", ["name", "graph", "type"]) + + +class DistributedVariable(DistributedDelegate): + """Holds a map from device to variables.""" + # TODO(josh11b): Support changing the set of variables if e.g. if new + # devices are joining or a device is to leave. + + def __init__(self, index): + # Child class must set self._primary_var before calling + # super(...).__init__(index). + self._common_name = self._primary_var.name.split(":")[0] + super(DistributedVariable, self).__init__(index) + + @property + def initializer(self): + return control_flow_ops.group([v.initializer for v in self._index.values()]) + + @property + def graph(self): + return self._primary_var.graph + + @property + def _shared_name(self): + return self._common_name + + @property + def _unique_id(self): + return self._primary_var._unique_id # pylint: disable=protected-access + + @property + def name(self): + return self._primary_var.name + + @property + def dtype(self): + return self._primary_var.dtype + + @property + def shape(self): + return self._primary_var.shape + + def get_shape(self): + return self._primary_var.get_shape() + + def to_proto(self, export_scope=None): + return self._primary_var.to_proto(export_scope=export_scope) + + @property + def op(self): + # We want cross-tower code that does some var.op.X calls + # to work (even if the current device isn't in self.devices), but + # other uses of var.op in a cross-tower context to fail. + if distribute_lib.get_cross_tower_context(): + return DistributedVarOp(self._primary_var.op.name, + self._primary_var.op.graph, + self._primary_var.op.type) + return self.get().op + + def _should_act_as_resource_variable(self): + """Pass resource_variable_ops.is_resource_variable check.""" + pass + + +# Register a conversion function which reads the value of the variable, +# allowing instances of the class to be used as tensors. +def _tensor_conversion(var, dtype=None, name=None, as_ref=False): + # Try to avoid assignments to and other mutations of MirroredVariable + # state except through a DistributionStrategy.update() call. + assert not as_ref + return ops.internal_convert_to_tensor( + var.get(), dtype=dtype, name=name, as_ref=as_ref) + + +ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion) +ops.register_dense_tensor_like_type(DistributedVariable) + + +class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable): + """Class for defining how to restore a MirroredVariable.""" + + def __init__(self, mirrored_variable, primary_variable, name): + self._mirrored_variable = mirrored_variable + super(_MirroredSaveable, self).__init__(primary_variable, "", name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into all variables.""" + tensor, = restored_tensors + return control_flow_ops.group([ + _assign_on_device(d, v, tensor) + for d, v in six.iteritems(self._mirrored_variable._index)]) # pylint: disable=protected-access + + +def _get_update_device(): + """Validate we are in update/update_non_slot() and return current device. + + This is used in MirroredVariable.assign* members, to make sure they + are only called via an update method, to make sure all components of the + variable are being updated in a consistent way. + + Returns: + A string device. + + Raises: + RuntimeError: If not in distribution.update()/.update_non_slot(). + """ + device = distribute_lib.get_update_device() + if device is None: + raise RuntimeError( + "Use DistributionStrategy.update() to modify a MirroredVariable.") + return device + + +class MirroredVariable(DistributedVariable, Mirrored, + checkpointable.CheckpointableBase): + """Holds a map from device to variables whose values are kept in sync.""" + + def __init__(self, index, primary_var): + # Use a weakref to make it easy to map from the contained values + # to the container without introducing a reference cycle. + for v in six.itervalues(index): + v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access + self._primary_var = primary_var + super(MirroredVariable, self).__init__(index) + + # We use _get_update_device() for the assign* methods to enforce + # that we are in an update() function. The arguments to update() are + # automatically unwrapped so the update() function would normally + # see regular variables, not MirroredVariables. However, the update + # function can still operate on wrapped MirroredVariables through + # object members, captured arguments, etc. This is more likely in an + # update_non_slot() function (like OptimizerV2._finish), which can + # update several non-slot variables in one call. + def assign_sub(self, *args, **kwargs): + return self.get(device=_get_update_device()).assign_sub(*args, **kwargs) + + def assign_add(self, *args, **kwargs): + return self.get(device=_get_update_device()).assign_add(*args, **kwargs) + + def assign(self, *args, **kwargs): + return self.get(device=_get_update_device()).assign(*args, **kwargs) + + def _get_cross_tower(self): + device = device_util.canonicalize(device_util.current()) + if device in self._index: + return array_ops.identity(self._index[device]) + return array_ops.identity(self._primary_var) + + def _as_graph_element(self): + # pylint: disable=protected-access + if distribute_lib.get_cross_tower_context(): + return self._primary_var._as_graph_element() + return self.get()._as_graph_element() + + def _gather_saveables_for_checkpoint(self): + """Overrides CheckpointableBase method. + + This allows both name-based and object-based save and restore of + MirroredVariables. + + Returns: + A dictionary mapping attribute names to `SaveableObject` factories. + """ + def _saveable_factory(name=self._common_name): + return _MirroredSaveable(self, self._primary_var, name) + return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + + +class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): + """Class for defining how to restore a TowerLocalVariable.""" + + def __init__(self, tower_local_variable, name): + self._tower_local_variable = tower_local_variable + # We use a callable so that we don't have to evaluate this expression + # in the case where we are trying to restore instead of save. + def tensor(): + return distribute_lib.get_distribution_strategy().fetch( + tower_local_variable) + spec = saver.BaseSaverBuilder.SaveSpec( + tensor=tensor, + slice_spec="", + name=name, + dtype=tower_local_variable.dtype) + super(_TowerLocalSaveable, self).__init__(tensor, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into all variables.""" + tensor, = restored_tensors + # To preserve the sum across save and restore, we have to divide the + # total across all devices when restoring a variable that was summed + # when saving. + if self._tower_local_variable.reduce_method == "sum": + tensor *= 1. / len(self._tower_local_variable.devices) + return control_flow_ops.group([ + _assign_on_device(d, v, tensor) + for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access + + +def _assert_tower_context(): + if not distribute_lib.get_tower_context(): + raise RuntimeError( + "Tower-local variables may only be assigned in a tower context.") + + +class TowerLocalVariable(DistributedVariable, PerDevice, + checkpointable.CheckpointableBase): + """Holds a map from device to variables whose values are reduced on save.""" + + def __init__(self, index, primary_var, reduce_method): + self._primary_var = primary_var + self._reduce_method = reduce_method + super(TowerLocalVariable, self).__init__(index) + + def assign_sub(self, *args, **kwargs): + _assert_tower_context() + return self.get().assign_sub(*args, **kwargs) + + def assign_add(self, *args, **kwargs): + _assert_tower_context() + return self.get().assign_add(*args, **kwargs) + + def assign(self, *args, **kwargs): + _assert_tower_context() + return self.get().assign(*args, **kwargs) + + @property + def reduce_method(self): + return self._reduce_method + + def _get_cross_tower(self): + all_components = tuple(self._index.values()) + # TODO(josh11b): Use a strategy-specific method. + total = math_ops.add_n(all_components) + if self._reduce_method == "mean": + return total * (1./ len(all_components)) + return total + + def _as_graph_element(self): + # pylint: disable=protected-access + if distribute_lib.get_cross_tower_context(): + return self._get_cross_tower() + return self.get()._as_graph_element() + + def _gather_saveables_for_checkpoint(self): + """Overrides CheckpointableBase method. + + This allows both name-based and object-based save and restore of + TowerLocalVariables. + + Returns: + A dictionary mapping attribute names to `SaveableObject` factories. + """ + def _saveable_factory(name=self._common_name): + return _TowerLocalSaveable(self, name) + return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + + +def _devices_match(d1, d2): + return device_util.canonicalize(d1) == device_util.canonicalize(d2) + + +def regroup(per_device, wrap_class=PerDevice): + """Makes device->nest map into a nest of PerDevice/Mirrored values.""" + items = list(per_device.items()) + assert items + v0 = items[0][1] # First value + + if isinstance(v0, list): + for _, v in items[1:]: + assert isinstance(v, list) + assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" % + (len(v), len(v0), v, v0)) + return [regroup({k: v[i] for k, v in items}, wrap_class) + for i in range(len(v0))] + + if isinstance(v0, tuple): + for _, v in items[1:]: + assert isinstance(v, tuple) + assert len(v) == len(v0) + regrouped_tuple = tuple(regroup({k: v[i] for k, v in items}, wrap_class) + for i in range(len(v0))) + if hasattr(v0, "_fields"): + # This tuple is in fact a namedtuple! Create a new namedtuple instance + # and initialize it with the regrouped values: + assert hasattr(type(v0), "_make") + return type(v0)._make(regrouped_tuple) + else: + return regrouped_tuple + + if isinstance(v0, dict): + v0keys = set(v0.keys()) + for _, v in items[1:]: + assert isinstance(v, dict) + assert set(v.keys()) == v0keys + return {key: regroup({k: v[key] for k, v in items}, wrap_class) + for key in v0keys} + + # If exactly the same object across all devices, return it unwrapped. + same_id = True + for _, v in items[1:]: + if v is not v0: + same_id = False + break + # Consider three cases where same_id is true: + # * If v0 is a MirroredVariable (and same_id means it is the same + # across all devices), we want to return it. We check + # MirroredVariable specifically since it can look like it + # has a _mirrored_container member since its members do. + # * If v0 is a member of a mirrored variable, in which case + # hasattr(v0, "_mirrored_container") is true, we want to + # return the MirroredVariable that contains it using the + # _mirrored_container logic below. This case can trigger + # same_id when there is only one device. + # * In any other situation, same_id means we return v0. + if same_id and (isinstance(v0, MirroredVariable) or + not hasattr(v0, "_mirrored_container")): + return v0 + + # Detect the case where each device has a parallel component of the + # same MirroredVariable. In this case we want to return the + # containing MirroredVariable, after a bunch of sanity checking. + # In particular, each component should have the same container, + # and the devices of the variables should match the keys of the + # per-device dictionary. + # TODO(josh11b): Do we need similar logic for TowerLocalVariables? + if hasattr(v0, "_mirrored_container"): + # pylint: disable=protected-access + assert not isinstance(v0, MirroredVariable), ( + "ids = %s, items = %s" % ([id(v[1]) for v in items], items)) + assert _devices_match(v0.device, items[0][0]), ( + "v0.device = %s, items = %s" % (v0.device, items)) + mirrored_container = v0._mirrored_container() + assert mirrored_container is not None + for d, v in items[1:]: + assert _devices_match(v.device, d), ( + "v.device = %s, d = %s, items = %s" % (v.device, d, items)) + assert mirrored_container is v._mirrored_container() + return mirrored_container + # pylint: enable=protected-access + + return wrap_class(per_device) + + +def select_device(device, structured): + """Specialize a nest of regular & per-device values for one device.""" + def _get(x): + return x.get(device) if isinstance(x, DistributedValues) else x + + return nest.map_structure(_get, structured) + + +def select_device_mirrored(device, structured): + """Specialize a nest of regular & mirrored values for one device.""" + def _get_mirrored(x): + if isinstance(x, DistributedValues): + if not isinstance(x, Mirrored): + raise TypeError( + "Expected value to be mirrored across towers: %s in %s." % + (x, structured)) + return x.get(device) + else: + return x + + return nest.map_structure(_get_mirrored, structured) + + +class PerDeviceDataIterator(object): + """An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`.""" + + def __init__(self, iterator, devices, prefetch_on_device=None): + self._iterator = iterator + self._devices = devices + self._prefetch_on_device = prefetch_on_device + + @property + def initializer(self): + return self._iterator.initializer + + def get_next(self, name=None): + """Scatter the input across devices.""" + if self._prefetch_on_device: + data_list = self._iterator.get_next(name=name) + index = dict(zip(self._devices, data_list)) + else: + batch = self._iterator.get_next(name=name) + index = {} + def get_ith(i): + return lambda x: x[i] + + for i, d in enumerate(self._devices): + index[d] = nest.map_structure(get_ith(i), batch) + if context.executing_eagerly(): + with ops.device(d): + index[d] = nest.map_structure(array_ops.identity, index[d]) + + return regroup(index) + + +class PerDeviceDataset(object): + """Like `tf.data.Dataset` split devices, producing `PerDevice` data.""" + + def __init__(self, dataset, devices, prefetch_on_device=None): + self._devices = devices + + # Default to using prefetching in graph mode, unless specified. + # TODO(priyag): Enable prefetching in eager mode. + self._prefetch_on_device = prefetch_on_device + if self._prefetch_on_device is None: + self._prefetch_on_device = not context.executing_eagerly() + assert not (self._prefetch_on_device and context.executing_eagerly()), ( + "Prefetching is only supported in graph mode currently") + + if self._prefetch_on_device: + self._dataset = dataset.apply( + prefetching_ops_v2.prefetch_to_devices(self._devices)) + else: + # TODO(priyag): If dropping remainder is not appropriate, find another + # approach to distributing the dataset when not possible to divide evenly. + # Possibly not an issue when we start using PartitionedDataset. + self._dataset = dataset.apply( + batching.batch_and_drop_remainder(len(devices))) + + def make_one_shot_iterator(self): + """Get a one time use iterator for the distributed PerDeviceDataset.""" + dataset_iterator = self._dataset.make_one_shot_iterator() + return PerDeviceDataIterator( + dataset_iterator, self._devices, self._prefetch_on_device) + + def make_initializable_iterator(self): + """Get an initializable iterator for the distributed PerDeviceDataset.""" + dataset_iterator = self._dataset.make_initializable_iterator() + return PerDeviceDataIterator( + dataset_iterator, self._devices, self._prefetch_on_device) + + +class MultiWorkerDataIterator(object): + """An iterator (like `tf.data.Iterator`) into a `MultiWorkerDataset`.""" + + def __init__(self, iterators, worker_device_map): + """Initialize the MultiWorkerDataIterator object. + + Args: + iterators: a dict mapping from each worker to an iterator for + that worker. + worker_device_map: a dict mapping from each worker's devices to a list of + devices that belong to this worker. + + Raises: + ValueError: if iterators and worker_device_map are not compatible. + """ + self._iterators = iterators + self._worker_device_map = worker_device_map + if set(self._iterators) != set(self._worker_device_map): + raise ValueError("iterators and worker_device_map are not compatible.") + + @property + def initializer(self): + return control_flow_ops.group( + [iterator.initializer for iterator in self._iterators.values()]) + + def get_next(self, name=None): + """Scatter the input across hosts and devices.""" + index = {} + for worker, iterator in six.iteritems(self._iterators): + if name is not None: + d = tf_device.DeviceSpec.from_string(worker) + new_name = "%s_%s_%d" % (name, d.job, d.task) + else: + new_name = None + with ops.device(worker): + data_per_worker = iterator.get_next(name=new_name) + + worker_devices = self._worker_device_map[worker] + # Ungroup these per-device value so as to get a flat map from devices to + # values. + for d in worker_devices: + v = select_device(d, data_per_worker) + if d in index: + raise ValueError("Duplicated devices in worker_device_map: %r" % v) + index[d] = v + + return regroup(index) + + +class MultiWorkerDataset(object): + """Like a `tf.data.Dataset` that distributes data to different workers. + + Each worker gets one shard of the input dataset. It is currently not working + in + eager mode. + """ + + def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None): + """Initialize the MultiWorkerDataset object. + + Args: + dataset_fn: a function that returns a `tf.data.Dataset`. + worker_device_map: a dict mapping from each worker to a list of devices + that belong to this worker. + prefetch_on_device: whether to prefetch to devices. + """ + self._worker_device_map = worker_device_map + self._datasets = {} + # TODO(yuefengz, priyag): support different set of jobs for input + # processing. + for i, (worker, worker_devices) in enumerate( + six.iteritems(worker_device_map)): + with ops.device(worker): + worker_input = dataset_fn() + worker_input = input_ops.auto_shard_dataset( + worker_input, len(worker_device_map), i) + self._datasets[worker] = PerDeviceDataset( + worker_input, worker_devices, prefetch_on_device=prefetch_on_device) + + def make_one_shot_iterator(self): + iterators = {} + for worker, dataset in six.iteritems(self._datasets): + with ops.device(worker): + iterators[worker] = dataset.make_one_shot_iterator() + return MultiWorkerDataIterator(iterators, self._worker_device_map) + + def make_initializable_iterator(self): + iterators = {} + for worker, dataset in six.iteritems(self._datasets): + with ops.device(worker): + iterators[worker] = dataset.make_initializable_iterator() + return MultiWorkerDataIterator(iterators, self._worker_device_map) + + +class _PerKey(object): + """Holds data associated by keys.""" + + def __init__(self, *index): + # pylint: disable=protected-access + self._index = list(index) + + def get(self, iteration): + return array_ops.gather(self._index, iteration) + + def get_shape(self): + return self._index[-1][-1].get_shape() + + def get_dtype(self): + return self._index[-1][-1].dtype + + def __str__(self): + return "%s:%s" % (self.__class__.__name__, self._index) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._index) + + +class PerIteration(_PerKey): + """Holds input for multiple iterations at once.""" + + def __init__(self, *index): + # pylint: disable=protected-access + super(PerIteration, self).__init__(*[batch._index for batch in index]) + + +class Batches(_PerKey): + pass + + +class MultiIterator(object): + """Iterator that returns results of multiple get_next()s.""" + + def __init__(self, dataset_iterator, iterations, batches_per_iteration): + self._dataset_iterator = dataset_iterator + self._iterations = iterations + self._batches_per_iteration = batches_per_iteration + + def get_next(self, name=None): + """Return PerIteration with `iterations x batches_per_iteration` inputs.""" + data = [] + for _ in range(self._batches_per_iteration): + batch = [] + for _ in range(self._iterations): + batch.append(self._dataset_iterator.get_next(name=name)) + data.append(batch) + + # Here is an example. Suppose each get_next returns a tuple of two tensors. + # For 3 `iterations` and 2 `batches_per_iteration`, the `data` is: + # [[(a,z), (b,y), (c,x)], [(A,Z), (B,Y), (C,X)]] + # + # After the first `map_structure` it gets transformed to: + # [(Batches(a, A), Batches(z, Z)), + # (Batches(b, B), Batches(y, Y)), + # (Batches(c, C), Batches(x, X))] + # + # After the second `map_structure` it gets transformed to a tuple of: + # (PerIteration([Batches(a, A), Batches(b, B), Batches(c, C)]), + # PerIteration([Batches(z, Z), Batches(y, Y), Batches(x, X)])) + + data = nest.map_structure(Batches, *data) + data = nest.map_structure(PerIteration, *data) + + return data + + @property + def initializer(self): + return self._dataset_iterator.initializer + + +class PerIterationDataset(object): + """A dataset that returns MultiIterators.""" + + def __init__(self, dataset, iterations, batches_per_iteration): + self._dataset = dataset + self._iterations = iterations + self._batches_per_iteration = batches_per_iteration + + def make_one_shot_iterator(self): + iterator = self._dataset.make_one_shot_iterator() + return MultiIterator(iterator, self._iterations, + self._batches_per_iteration) + + def make_initializable_iterator(self): + iterator = self._dataset.make_initializable_iterator() + return MultiIterator(iterator, self._iterations, + self._batches_per_iteration) + + +class MapOutput(object): + """Map can result in multiple outputs per device.""" + + def __init__(self, l): + self._l = l + + def get(self): + return self._l diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1c95758d96aba47e9581dde6411763e98b99a968 --- /dev/null +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -0,0 +1,971 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the distributed values library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.training import device_util +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.util import nest + + +class DistributedValuesTest(test.TestCase): + + def testGetEager(self): + with ops.device("/device:CPU:0"): + one = constant_op.constant(1) + two = constant_op.constant(2) + v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + self.assertEqual(two, v.get("/device:GPU:0")) + self.assertEqual(one, v.get()) + with self.assertRaises(ValueError): + self.assertIsNone(v.get("/device:GPU:2")) + + def testGetGraph(self): + with context.graph_mode(), \ + ops.Graph().as_default(), \ + ops.device("/device:CPU:0"): + one = constant_op.constant(1) + two = constant_op.constant(2) + v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + self.assertEqual(two, v.get("/device:GPU:0")) + self.assertEqual(one, v.get()) + with self.assertRaises(ValueError): + self.assertIsNone(v.get("/device:GPU:2")) + + def testCanonicalization(self): + canonical_cpu = ["/job:localhost/replica:0/task:0/device:CPU:0"] + v = values.DistributedValues({"": 42}) + self.assertEqual(canonical_cpu, list(v._index.keys())) + v = values.DistributedValues({"/device:CPU:0": 42}) + self.assertEqual(canonical_cpu, list(v._index.keys())) + v = values.DistributedValues({"/cpu:0": 42}) + self.assertEqual(canonical_cpu, list(v._index.keys())) + v = values.DistributedValues({"/CPU:0": 42}) + self.assertEqual(canonical_cpu, list(v._index.keys())) + with self.assertRaises(AssertionError): + v = values.DistributedValues({"/device:cpu:0": 42}) + + +class DistributedDelegateTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testGetAttr(self): + with ops.device("/device:CPU:0"): + + class Foo(object): + + def __init__(self, x): + self.x = x + + v = values.DistributedDelegate( + {"/device:CPU:0": Foo(7), "/device:GPU:0": Foo(8)}) + self.assertEqual(7, v.x) + with self.assertRaises(AttributeError): + _ = v.y + + @test_util.run_in_graph_and_eager_modes() + def testOperatorOverride(self): + with ops.device("/device:CPU:0"): + v = values.DistributedDelegate({"/device:CPU:0": 7, "/device:GPU:0": 8}) + # v should act like int(7). + self.assertEqual(8, v + 1) + self.assertEqual(10, 3 + v) + self.assertEqual(14, v + v) + self.assertEqual(5, v - 2) + self.assertEqual(6, 13 - v) + self.assertEqual(0, v - v) + self.assertEqual(14, v * 2) + self.assertEqual(21, 3 * v) + self.assertEqual(49, v * v) + self.assertEqual(3.5, v / 2) + self.assertEqual(1.5, 10.5 / v) + self.assertEqual(3, v // 2) + self.assertEqual(2, 15 // v) + self.assertEqual(1, v % 2) + self.assertEqual(2, 16 % v) + self.assertTrue(v < 12) + self.assertTrue(v <= 12) + self.assertFalse(v > 12) + self.assertFalse(v >= 12) + self.assertFalse(12 < v) + self.assertFalse(12 <= v) + self.assertTrue(12 > v) + self.assertTrue(12 >= v) + self.assertEqual(3, v & 3) + self.assertEqual(3, 11 & v) + self.assertEqual(15, v | 8) + self.assertEqual(23, 16 | v) + self.assertEqual(4, v ^ 3) + self.assertEqual(12, 11 ^ v) + self.assertEqual(343, pow(v, 3)) + self.assertEqual(3, pow(v, 3, 10)) + self.assertEqual(128, pow(2, v)) + self.assertEqual(-7, -v) + self.assertEqual(~7, ~v) + self.assertEqual(7, abs(v)) + with self.assertRaises(TypeError): + _ = v[2] + + +def _device_str(d): + return "/device:GPU:" + str(d) + + +def _nested_value(d): + return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d) + + +def _make_mirrored(): + v = [] + index = {} + devices = ["/device:GPU:0", "/device:CPU:0"] + for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]): + with ops.device(d): + v.append(variable_scope.get_variable( + name=n, initializer=init, use_resource=True)) + index[d] = v[-1] + mirrored = values.MirroredVariable(index, v[0]) + return v, devices, mirrored + + +class RegroupAndSelectDeviceTest(test.TestCase): + + def _is_per_device(self, result, expected, klass=values.PerDevice): + self.assertIsInstance(result, klass) + # We canonicalize the devices to match the device strings returned + # by PerDevice, which also does device string canonicalization. + devices = [device_util.canonicalize(_device_str(i)) + for i in range(len(expected))] + self.assertEqual(set(devices), set(result.devices)) + for i, d in enumerate(devices): + self.assertEqual(expected[i], result.get(d)) + self.assertEqual(expected[i], result.get(_device_str(i))) + + def testNested(self): + result = values.regroup({_device_str(0): _nested_value("1"), + _device_str(1): _nested_value("2")}) + self.assertIsInstance(result, tuple) + self.assertEqual(3, len(result)) + self._is_per_device(result[0], ["a1", "a2"]) + self._is_per_device(result[2], ["h1", "h2"]) + + self.assertIsInstance(result[1], list) + self.assertEqual(3, len(result[1])) + self._is_per_device(result[1][0], ["b1", "b2"]) + self._is_per_device(result[1][2], ["g1", "g2"]) + + self.assertIsInstance(result[1][1], dict) + self.assertEqual(set(["c", "e"]), set(result[1][1].keys())) + self._is_per_device(result[1][1]["c"], ["d1", "d2"]) + self._is_per_device(result[1][1]["e"], ["f1", "f2"]) + + # Also test that we can undo the merge using select_device() + self.assertEqual(_nested_value("1"), + values.select_device(_device_str(0), result)) + self.assertEqual(_nested_value("2"), + values.select_device(_device_str(1), result)) + # select_device_mirrored() should fail due to non-mirrored values + with self.assertRaises(TypeError): + values.select_device_mirrored(_device_str(0), result) + with self.assertRaises(TypeError): + values.select_device_mirrored(_device_str(1), result) + + def testWrapClass(self): + # Normally a mirrored value would be the same across devices, but + # for a test it is convenient to be able to tell the values apart. + result = values.regroup({_device_str(0): _nested_value("1"), + _device_str(1): _nested_value("2")}, + values.Mirrored) + self.assertIsInstance(result, tuple) + self.assertEqual(3, len(result)) + self._is_per_device(result[0], ["a1", "a2"], values.Mirrored) + self._is_per_device(result[2], ["h1", "h2"], values.Mirrored) + + self.assertIsInstance(result[1], list) + self.assertEqual(3, len(result[1])) + self._is_per_device(result[1][0], ["b1", "b2"], values.Mirrored) + self._is_per_device(result[1][2], ["g1", "g2"], values.Mirrored) + + self.assertIsInstance(result[1][1], dict) + self.assertEqual(set(["c", "e"]), set(result[1][1].keys())) + self._is_per_device(result[1][1]["c"], ["d1", "d2"], values.Mirrored) + self._is_per_device(result[1][1]["e"], ["f1", "f2"], values.Mirrored) + + # Also test that we can undo the merge using select_device() + self.assertEqual(_nested_value("1"), + values.select_device(_device_str(0), result)) + self.assertEqual(_nested_value("2"), + values.select_device(_device_str(1), result)) + # Values are marked as mirrored, so select_device_mirrored() is allowed. + self.assertEqual(_nested_value("1"), + values.select_device_mirrored(_device_str(0), result)) + self.assertEqual(_nested_value("2"), + values.select_device_mirrored(_device_str(1), result)) + + def testMirroredContainer(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + v, devices, mirrored = _make_mirrored() + result = values.regroup(dict(zip(devices, v))) + self.assertIs(mirrored, result) + + def testSameId(self): + foo = object() + result = values.regroup({_device_str(0): ("a", foo), + _device_str(1): ("b", foo)}) + self.assertIsInstance(result, tuple) + self.assertEqual(2, len(result)) + self._is_per_device(result[0], ["a", "b"]) + self.assertIs(foo, result[1]) + + # Test select_device(), should undo the merge done by regroup(). + result_0 = values.select_device(_device_str(0), result) + self.assertIsInstance(result_0, tuple) + self.assertEqual(2, len(result_0)) + self.assertEqual("a", result_0[0]) + self.assertIs(foo, result_0[1]) + result_1 = values.select_device(_device_str(1), result) + self.assertIsInstance(result_1, tuple) + self.assertEqual(2, len(result_1)) + self.assertEqual("b", result_1[0]) + self.assertIs(foo, result_1[1]) + + def testOneDevice(self): + result = values.regroup({_device_str(0): _nested_value("1")}) + # On one device regroup() and select_device() are basically identity. + self.assertEqual(_nested_value("1"), result) + self.assertEqual(_nested_value("1"), + values.select_device(_device_str(0), result)) + + # The one exception has to do with MirroredVariables. + d = "/device:CPU:0" + with ops.device(d): + v = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + index = {d: v} + mirrored = values.MirroredVariable(index, v) + result = values.regroup(index) + self.assertIs(mirrored, result) + + def testNamedTupleEstimatorSpec(self): + with context.graph_mode(), ops.Graph().as_default(): + created_estimator_specs = {} + to_regroup = {} + + for device_id in range(3): + spec = model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.TRAIN, + loss=constant_op.constant(device_id / 2), + train_op=array_ops.identity(constant_op.constant(device_id))) + created_estimator_specs[device_id] = spec + to_regroup[_device_str(device_id)] = spec + + merged_estimator_spec = values.regroup(to_regroup) + + self.assertTrue( + isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec)) + self.assertEquals(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode) + for device_id in range(3): + d = _device_str(device_id) + self.assertEquals(created_estimator_specs[device_id].loss, + merged_estimator_spec.loss.get(d)) + self.assertEquals(created_estimator_specs[device_id].train_op, + merged_estimator_spec.train_op.get(d)) + # Scaffold is populated by `EstimatorSpec.__new__`. + self.assertEquals(created_estimator_specs[device_id].scaffold, + merged_estimator_spec.scaffold.get(d)) + # Also test that we can undo the merge using select_device() + self.assertEquals(created_estimator_specs[device_id], + values.select_device(_device_str(device_id), + merged_estimator_spec)) + + +class PerDeviceDatasetTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def _test_iterator_no_prefetch(self, devices, dataset, expected_values): + per_device_dataset = values.PerDeviceDataset( + dataset, devices, prefetch_on_device=False) + iterator = per_device_dataset.make_one_shot_iterator() + + for expected_value in expected_values: + next_element = iterator.get_next() + actual = self.evaluate([ + values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, actual) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + self.evaluate([ + values.select_device(d, next_element) for d in devices]) + + def _test_iterator_with_prefetch(self, devices, dataset, expected_values): + if not context.executing_eagerly(): + per_device_dataset = values.PerDeviceDataset( + dataset, devices, prefetch_on_device=True) + iterator = per_device_dataset.make_one_shot_iterator() + + # With prefetching, we cannot guarantee which input ends up on which + # device, so we verify that the complete set seen on all devices is + # correct, and equal numbers are distributed to each device. + combined_actual = [] + combined_expected = [] + for expected_value in expected_values: + next_element = iterator.get_next() + combined_actual.extend(self.evaluate([ + values.select_device(d, next_element) for d in devices])) + combined_expected.extend(expected_value) + + self.assertEqual(set(combined_expected), set(combined_actual)) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + self.evaluate([ + values.select_device(d, next_element) for d in devices]) + + def _test_iterator(self, devices, dataset, expected_values): + self._test_iterator_no_prefetch(devices, dataset, expected_values) + self._test_iterator_with_prefetch(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes() + def testOneDevice(self): + devices = ["/device:CPU:0"] + dataset = dataset_ops.Dataset.range(10) + + expected_values = [[i] for i in range(10)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testMultipleDevices(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset = dataset_ops.Dataset.range(10) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testTupleDataset(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testUnevenDatasetBatches(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset = dataset_ops.Dataset.range(11) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + self._test_iterator(devices, dataset, expected_values) + + def testInitializableIterator(self): + with context.graph_mode(): + devices = ["/device:CPU:0"] + # Using random input since that is only allowed with initializable + # iterator. + dataset = dataset_ops.Dataset.from_tensor_slices( + random_ops.random_uniform((10,))) + + per_device_dataset = values.PerDeviceDataset( + dataset, devices, prefetch_on_device=False) + iterator = per_device_dataset.make_initializable_iterator() + + self.evaluate(iterator.initializer) + next_element = iterator.get_next() + for _ in range(10): + self.evaluate(next_element) + + # Should fail after the input is finished. + with self.assertRaises(errors.OutOfRangeError): + self.evaluate(next_element) + + # After re-initializing the iterator, should be able to iterate again. + self.evaluate(iterator.initializer) + for _ in range(10): + self.evaluate(next_element) + + +class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): + + def _test_iterator(self, iterator, devices, expected_values): + next_element = iterator.get_next() + for device in devices: + v = values.select_device(device, next_element) + # The `v` here can be a tuple. + for element in nest.flatten(v): + self.assertTrue(element.device in device) + + for expected_value in expected_values: + actual = self.evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, actual) + + with self.assertRaises(errors.OutOfRangeError): + self.evaluate([values.select_device(d, next_element) for d in devices]) + + def _test_dataset(self, dataset_fn, worker_device_map, devices, + expected_values): + multi_worker_dataset = values.MultiWorkerDataset( + dataset_fn, worker_device_map, prefetch_on_device=False) + multi_worker_iterator = multi_worker_dataset.make_one_shot_iterator() + self._test_iterator(multi_worker_iterator, devices, expected_values) + + def _cpu_devices(self): + worker_device_map = collections.OrderedDict( + [("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"])]) + devices = [ + "/job:worker/replica:0/task:0/device:CPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ] + return worker_device_map, devices + + def _cpu_and_one_gpu_devices(self): + # The worker_device_map doesn't have to be a OrderDict object, this is just + # to simplify the testing so that we can pass expected values as a list + # instead of a dict. + worker_device_map = collections.OrderedDict( + [("/job:worker/replica:0/task:0", [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0" + ]), ("/job:worker/replica:0/task:1", [ + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ])]) + devices = [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0", + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ] + return worker_device_map, devices + + def testDataDistributionOneDevicePerWorker(self): + worker_device_map, devices = self._cpu_devices() + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(8) + self._test_dataset(dataset_fn, worker_device_map, devices, + [[0, 1], [2, 3], [4, 5], [6, 7]]) + + def testDataDistributionTwoDevicePerWorker(self): + if context.num_gpus() < 1: + self.skipTest("A GPU is not available for this test.") + worker_device_map, devices = self._cpu_and_one_gpu_devices() + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(8) + self._test_dataset(dataset_fn, worker_device_map, devices, + [[0, 2, 1, 3], [4, 6, 5, 7]]) + + def testTupleDataset(self): + worker_device_map, devices = self._cpu_devices() + + with context.graph_mode(): + + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(8) + dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [ + [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 8, 2) + ] + self._test_dataset(dataset_fn, worker_device_map, devices, + expected_values) + + def testInitializableIterator(self): + worker_device_map, devices = self._cpu_devices() + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(8) + multi_worker_dataset = values.MultiWorkerDataset( + dataset_fn, worker_device_map, prefetch_on_device=False) + multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() + + self.evaluate(multi_worker_iterator.initializer) + self._test_iterator(multi_worker_iterator, devices, + [[0, 1], [2, 3], [4, 5], [6, 7]]) + + # After re-initializing the iterator, should be able to iterate again. + self.evaluate(multi_worker_iterator.initializer) + self._test_iterator(multi_worker_iterator, devices, + [[0, 1], [2, 3], [4, 5], [6, 7]]) + + def testValueErrorForIterator(self): + # Incompatiable arguments. + with self.assertRaises(ValueError): + values.MultiWorkerDataIterator({"w1": None}, {"w1": "d1", "w2": "d2"}) + + # Test duplicated devices under same worker. + worker_device_map, _ = self._cpu_devices() + worker_device_map["/job:worker/replica:0/task:0"].append( + "/job:worker/replica:0/task:0/device:CPU:0") + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(8) + multi_worker_dataset = values.MultiWorkerDataset( + dataset_fn, worker_device_map, prefetch_on_device=False) + multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() + with self.assertRaises(ValueError): + multi_worker_iterator.get_next() + + +class MirroredVariableTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + @test_util.run_in_graph_and_eager_modes(config=config) + def testProperties(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + v, _, mirrored = _make_mirrored() + + self.assertEquals(v[0].name, mirrored.name) + self.assertEquals(v[0].dtype, mirrored.dtype) + self.assertEquals(v[0].shape, mirrored.shape) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testVariableOnAnotherDevice(self): + v = variable_scope.get_variable( + name="v", initializer=[1.], use_resource=True) + index = {"/job:foo/device:CPU:0": v} + mirrored = values.MirroredVariable(index, v) + + self.assertEquals(v.name, mirrored.name) + self.assertEquals(v.dtype, mirrored.dtype) + self.assertEquals(v.shape, mirrored.shape) + + def _assign_mirrored(self, devices, v, new): + for d, var, n in zip(devices, v, new): + with ops.device(d): + self.evaluate(var.assign(n)) + + def _save_return_saver(self, sess, var): + saver = saver_lib.Saver(var_list=[var]) + test_dir = self.get_temp_dir() + prefix = os.path.join(test_dir, "ckpt") + return saver.save(sess, prefix), saver + + def _save(self, sess, var): + save_path, _ = self._save_return_saver(sess, var) + return save_path + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveAndRestoreMirroredOneGraph(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + with self.test_session() as sess: + v, devices, mirrored = _make_mirrored() + + # Overwrite the initial values. + self._assign_mirrored(devices, v, [3., 4.]) + + # Saves the current value of v[0], 3. + save_path, saver = self._save_return_saver(sess, mirrored) + + # Change the values between save and restore. + self._assign_mirrored(devices, v, [5., 6.]) + + # Restores the saved value of 3. to both variables. + saver.restore(sess, save_path) + self.assertEqual([3., 3.], self.evaluate([v[0], v[1]])) + + def _save_mirrored(self): + """Save variables with mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + v, devices, mirrored = _make_mirrored() + + # Overwrite the initial values. + self._assign_mirrored(devices, v, [3., 4.]) + + # Saves the current value of v[0], 3. + save_path = self._save(sess, mirrored) + + # Change the values between save and restore. + self._assign_mirrored(devices, v, [5., 6.]) + return save_path + + def _save_normal(self): + """Save variables without mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + var = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + + # Overwrite the initial value. + self.evaluate(var.assign(3.)) + + # Saves the current value of var, 3. + save_path = self._save(sess, var) + + # Change the values between save and restore. + self.evaluate(var.assign(5.)) + return save_path + + def _restore_normal(self, save_path): + """Restore to variables without mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + var = variable_scope.get_variable( + name="v", initializer=7., use_resource=True) + + # Overwrite the initial value. + self.evaluate(var.assign(8.)) + + # Restores the saved value of 3. to `var`. + saver = saver_lib.Saver(var_list=[var]) + saver.restore(sess, save_path) + self.assertEqual(3., self.evaluate(var)) + + def _restore_mirrored(self, save_path): + """Restore to variables with mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + v, devices, mirrored = _make_mirrored() + + # Overwrite the initial values. + self._assign_mirrored(devices, v, [7., 8.]) + + # Restores the saved value of 3. to both variables. + saver = saver_lib.Saver(var_list=[mirrored]) + saver.restore(sess, save_path) + self.assertEqual([3., 3.], self.evaluate([v[0], v[1]])) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveMirroredRestoreMirrored(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_mirrored() + self._restore_mirrored(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveMirroredRestoreNormal(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_mirrored() + self._restore_normal(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveNormalRestoreMirrored(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_normal() + self._restore_mirrored(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testFetchAMirroredVariable(self): + if context.num_gpus() < 1 or context.executing_eagerly(): + self.skipTest("A GPU is not available for this test or it's eager mode.") + + with self.test_session( + graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy( + ["/device:GPU:0"]).scope(): + with ops.device("/device:GPU:0"): + v = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + mirrored = values.MirroredVariable({"/device:GPU:0": v}, v) + sess.run(variables_lib.global_variables_initializer()) + sess.run({"complicated": mirrored}) + + +_devices = ["/device:GPU:0", "/device:CPU:0"] + + +def _make_tower_local(method): + v = [] + index = {} + for d, n, init in zip(_devices, ["v", "v/replica"], [1., 2.]): + with ops.device(d): + v.append(variable_scope.get_variable( + name=n, initializer=init, use_resource=True)) + index[d] = v[-1] + tower_local = values.TowerLocalVariable(index, v[0], method) + return v, tower_local + + +class TowerLocalVariableTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + @test_util.run_in_graph_and_eager_modes(config=config) + def testProperties(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + v, tower_local = _make_tower_local("sum") + + self.assertEquals(v[0].name, tower_local.name) + self.assertEquals(v[0].dtype, tower_local.dtype) + self.assertEquals(v[0].shape, tower_local.shape) + self.assertEquals("sum", tower_local.reduce_method) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testVariableOnAnotherDevice(self): + v = variable_scope.get_variable( + name="v", initializer=[1.], use_resource=True) + index = {"/job:foo/device:CPU:0": v} + tower_local = values.TowerLocalVariable(index, v, "mean") + + self.assertEquals(v.name, tower_local.name) + self.assertEquals(v.dtype, tower_local.dtype) + self.assertEquals(v.shape, tower_local.shape) + self.assertEquals("mean", tower_local.reduce_method) + + def _assign_tower_local(self, devices, v, new): + for d, var, n in zip(devices, v, new): + with ops.device(d): + self.evaluate(var.assign(n)) + + def _save_return_saver(self, sess, var): + saver = saver_lib.Saver(var_list=[var]) + test_dir = self.get_temp_dir() + prefix = os.path.join(test_dir, "ckpt") + return saver.save(sess, prefix), saver + + def _save(self, sess, var): + save_path, _ = self._save_return_saver(sess, var) + return save_path + + def _dist_scope(self): + return mirrored_strategy.MirroredStrategy(_devices).scope() + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveAndRestoreTowerLocalSumOneGraph(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + with self.test_session() as sess: + v, tower_local = _make_tower_local("sum") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [3., 4.]) + + with self._dist_scope(): + # Saves the current value of v[0] + v[1], 7. + save_path, saver = self._save_return_saver(sess, tower_local) + + # Change the values between save and restore. + self._assign_tower_local(_devices, v, [5., 6.]) + + # Restores the saved value of 7. which gets divided equally + # between the variables. + saver.restore(sess, save_path) + self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveAndRestoreTowerLocalMeanOneGraph(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + with self.test_session() as sess: + v, tower_local = _make_tower_local("mean") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [3., 4.]) + + with self._dist_scope(): + # Saves the current value of (v[0] + v[1])/2, 3.5. + save_path, saver = self._save_return_saver(sess, tower_local) + + # Change the values between save and restore. + self._assign_tower_local(_devices, v, [5., 6.]) + + # Restores the saved value of 3.5 to both variables. + saver.restore(sess, save_path) + self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) + + def _save_tower_local_mean(self): + """Save variables with mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + v, tower_local = _make_tower_local("mean") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [3., 4.]) + + with self._dist_scope(): + # Saves the current value of (v[0] + v[1])/2, 3.5 + save_path = self._save(sess, tower_local) + + # Change the values between save and restore. + self._assign_tower_local(_devices, v, [5., 6.]) + return save_path + + def _save_tower_local_sum(self): + """Save variables with mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + v, tower_local = _make_tower_local("sum") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [1.5, 2.]) + + with self._dist_scope(): + # Saves the current value of v[0] + v[1], 3.5 + save_path = self._save(sess, tower_local) + + # Change the values between save and restore. + self._assign_tower_local(_devices, v, [5., 6.]) + return save_path + + def _save_normal(self): + """Save variables without mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + var = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + + # Overwrite the initial value. + self.evaluate(var.assign(3.5)) + + # Saves the current value of var, 3.5. + save_path = self._save(sess, var) + + # Change the values between save and restore. + self.evaluate(var.assign(5.)) + return save_path + + def _restore_normal(self, save_path): + """Restore to variables without mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + var = variable_scope.get_variable( + name="v", initializer=7., use_resource=True) + + # Overwrite the initial value. + self.evaluate(var.assign(8.)) + + # Restores the saved value of 3.5 to `var`. + saver = saver_lib.Saver(var_list=[var]) + saver.restore(sess, save_path) + self.assertEqual(3.5, self.evaluate(var)) + + def _restore_tower_local_mean(self, save_path): + """Restore to variables with mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + v, tower_local = _make_tower_local("mean") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [7., 8.]) + + with self._dist_scope(): + # Restores the saved value of 3.5 to both variables. + saver = saver_lib.Saver(var_list=[tower_local]) + saver.restore(sess, save_path) + self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) + + def _restore_tower_local_sum(self, save_path): + """Restore to variables with mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + v, tower_local = _make_tower_local("sum") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [7., 8.]) + + with self._dist_scope(): + # Restores the saved value of 3.5 to both variables. + saver = saver_lib.Saver(var_list=[tower_local]) + saver.restore(sess, save_path) + self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]])) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveTowerLocalRestoreTowerLocalMean(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_tower_local_mean() + self._restore_tower_local_mean(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveTowerLocalRestoreTowerLocalSum(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_tower_local_sum() + self._restore_tower_local_sum(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveTowerLocalMeanRestoreNormal(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_tower_local_mean() + self._restore_normal(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveTowerLocalSumRestoreNormal(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_tower_local_sum() + self._restore_normal(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveNormalRestoreTowerLocalMean(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_normal() + self._restore_tower_local_mean(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveNormalRestoreTowerLocalSum(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_normal() + self._restore_tower_local_sum(save_path) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 7f510c42215f48a9e795eb81bd9f66b0a2108335..ad00d1734dd14ed846522a33d888a5387cb25cc6 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -16,6 +16,13 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( name = "bijectors_py", srcs = glob(["python/ops/bijectors/*.py"]), + deprecation = ("TensorFlow Distributions has migrated to " + + "TensorFlow Probability " + + "(https://github.com/tensorflow/probability). " + + "Deprecated copies remaining in tf.contrib.distributions " + + "are unmaintained, unsupported, and will be removed by " + + "late 2018. You should update all usage of " + + "`tf.contrib.distributions` to `tfp.distributions`."), srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/linalg:linalg_py", @@ -42,6 +49,13 @@ py_library( py_library( name = "distributions_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + deprecation = ("TensorFlow Distributions has migrated to " + + "TensorFlow Probability " + + "(https://github.com/tensorflow/probability). " + + "Deprecated copies remaining in tf.contrib.distributions " + + "are unmaintained, unsupported, and will be removed by " + + "late 2018. You should update all usage of " + + "`tf.contrib.distributions` to `tfp.distributions`."), srcs_version = "PY2AND3", deps = [ ":bijectors_py", @@ -94,7 +108,7 @@ cuda_py_test( cuda_py_test( name = "distribution_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/distribution_test.py"], additional_deps = [ ":distributions_py", @@ -251,6 +265,21 @@ cuda_py_test( ], ) +cuda_py_test( + name = "kumaraswamy_test", + srcs = ["python/kernel_tests/kumaraswamy_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "moving_stats_test", size = "small", @@ -322,7 +351,7 @@ cuda_py_test( cuda_py_test( name = "mvn_tril_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/mvn_tril_test.py"], additional_deps = [ ":distributions_py", @@ -335,6 +364,7 @@ cuda_py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:platform_test", ], + tags = ["nomsan"], ) cuda_py_test( @@ -356,6 +386,7 @@ cuda_py_test( "//tensorflow/python:random_ops", "//tensorflow/python:variables", ], + shard_count = 4, ) cuda_py_test( @@ -403,7 +434,7 @@ cuda_py_test( cuda_py_test( name = "poisson_lognormal_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/poisson_lognormal_test.py"], additional_deps = [ ":distributions_py", @@ -438,6 +469,21 @@ cuda_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], + tags = ["no_windows"], # TODO: needs investigation on Windows +) + +cuda_py_test( + name = "batch_reshape_test", + size = "medium", + srcs = ["python/kernel_tests/batch_reshape_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], ) cuda_py_test( @@ -459,6 +505,30 @@ cuda_py_test( tags = ["nomsan"], # disable to avoid false positives from scipy. ) +cuda_py_test( + name = "seed_stream_test", + size = "small", + srcs = ["python/kernel_tests/seed_stream_test.py"], + additional_deps = [ + ":distributions_py", + "//tensorflow/python:client_testlib", + ], +) + +cuda_py_test( + name = "statistical_testing_test", + size = "medium", + srcs = [ + "python/kernel_tests/statistical_testing_test.py", + ], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + ], + shard_count = 4, +) + cuda_py_test( name = "vector_sinh_arcsinh_diag_test", size = "medium", @@ -523,7 +593,7 @@ cuda_py_test( cuda_py_test( name = "wishart_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/wishart_test.py"], additional_deps = [ ":distributions_py", @@ -654,6 +724,8 @@ cuda_py_test( "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:client_testlib", ], + shard_count = 4, + tags = ["noasan"], # times out, http://b/78588814 ) cuda_py_test( @@ -710,18 +782,6 @@ cuda_py_test( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - # === Bijector Tests ========================================================== cuda_py_test( @@ -782,6 +842,25 @@ cuda_py_test( tags = ["noasan"], # times out b/63678675 ) +cuda_py_test( + name = "affine_scalar_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/affine_scalar_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "affine_linear_operator_test", size = "small", @@ -801,6 +880,23 @@ cuda_py_test( ], ) +cuda_py_test( + name = "batch_normalization_test", + size = "medium", + srcs = ["python/kernel_tests/bijectors/batch_normalization_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + tags = ["optonly"], +) + cuda_py_test( name = "chain_test", size = "small", @@ -858,6 +954,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "fill_triangular_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/fill_triangular_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "gumbel_test", size = "small", @@ -915,6 +1030,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "kumaraswamy_bijector_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/kumaraswamy_bijector_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "masked_autoregressive_test", size = "small", @@ -931,6 +1065,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "matrix_inverse_tril_test", + size = "medium", + srcs = ["python/kernel_tests/bijectors/matrix_inverse_tril_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "real_nvp_test", size = "small", @@ -984,7 +1137,7 @@ cuda_py_test( cuda_py_test( name = "reshape_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/bijectors/reshape_test.py"], additional_deps = [ ":bijectors_py", @@ -999,9 +1152,9 @@ cuda_py_test( ) cuda_py_test( - name = "sigmoid_test", + name = "scale_tril_test", size = "small", - srcs = ["python/kernel_tests/bijectors/sigmoid_test.py"], + srcs = ["python/kernel_tests/bijectors/scale_tril_test.py"], additional_deps = [ ":bijectors_py", ":distributions_py", @@ -1018,9 +1171,9 @@ cuda_py_test( ) cuda_py_test( - name = "sigmoid_centered_test", + name = "sigmoid_test", size = "small", - srcs = ["python/kernel_tests/bijectors/sigmoid_centered_test.py"], + srcs = ["python/kernel_tests/bijectors/sigmoid_test.py"], additional_deps = [ ":bijectors_py", ":distributions_py", @@ -1055,6 +1208,7 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], + tags = ["no_windows"], # TODO: needs investigation on Windows ) cuda_py_test( @@ -1095,6 +1249,63 @@ cuda_py_test( ], ) +cuda_py_test( + name = "softsign_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/softsign_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( + name = "square_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/square_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( + name = "transform_diagonal_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/transform_diagonal_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "weibull_test", size = "small", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 61c411271d0bb8d7b4cc3b14992b82ec1e5674ed..5cec93c4df2e970f203253be6342bb292f296eb0 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== """Classes representing statistical distributions and ops for working with them. - -See the @{$python/contrib.distributions} guide. """ from __future__ import absolute_import from __future__ import division @@ -24,6 +22,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops.autoregressive import * +from tensorflow.contrib.distributions.python.ops.batch_reshape import * from tensorflow.contrib.distributions.python.ops.binomial import * from tensorflow.contrib.distributions.python.ops.cauchy import * from tensorflow.contrib.distributions.python.ops.chi2 import * @@ -31,6 +30,7 @@ from tensorflow.contrib.distributions.python.ops.conditional_distribution import from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * from tensorflow.contrib.distributions.python.ops.deterministic import * from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular +from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse @@ -58,6 +58,7 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import * from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import * from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import * from tensorflow.contrib.distributions.python.ops.sample_stats import * +from tensorflow.contrib.distributions.python.ops.seed_stream import * from tensorflow.contrib.distributions.python.ops.sinh_arcsinh import * from tensorflow.contrib.distributions.python.ops.test_util import * from tensorflow.contrib.distributions.python.ops.vector_diffeomixture import * @@ -96,9 +97,10 @@ _allowed_symbols = [ 'ReparameterizationType', 'Distribution', 'Autoregressive', - 'Binomial', + 'BatchReshape', 'Bernoulli', 'Beta', + 'Binomial', 'BetaWithSoftplusConcentration', 'Categorical', 'Chi2', @@ -124,6 +126,7 @@ _allowed_symbols = [ 'NormalWithSoftplusScale', 'Poisson', 'PoissonLogNormalQuadratureCompound', + 'SeedStream', 'SinhArcsinh', 'StudentT', 'StudentTWithAbsDfSoftplusScale', @@ -152,6 +155,7 @@ _allowed_symbols = [ 'kl_divergence', 'RegisterKL', 'fill_triangular', + 'fill_triangular_inverse', 'matrix_diag_transform', 'reduce_weighted_logsumexp', 'softplus_inverse', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f2bb2d3325a7cc6ec5803860600149522752a4c0 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py @@ -0,0 +1,572 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for BatchReshape.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import batch_reshape as batch_reshape_lib +from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_lib +from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib +from tensorflow.contrib.distributions.python.ops import wishart as wishart_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.platform import test + + +class _BatchReshapeTest(object): + + def make_wishart(self, dims, new_batch_shape, old_batch_shape): + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = self.dtype([ + [[1., 0.5], + [0.5, 1.]], + [[0.5, 0.25], + [0.25, 0.75]], + ]) + scale = np.reshape(np.concatenate([scale, scale], axis=0), + old_batch_shape + [dims, dims]) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + wishart = wishart_lib.WishartFull(df=5, scale=scale_ph) + reshape_wishart = batch_reshape_lib.BatchReshape( + distribution=wishart, + batch_shape=new_batch_shape_ph, + validate_args=True) + + return wishart, reshape_wishart + + def test_matrix_variate_sample_and_log_prob(self): + dims = 2 + new_batch_shape = [4] + old_batch_shape = [2, 2] + wishart, reshape_wishart = self.make_wishart( + dims, new_batch_shape, old_batch_shape) + + batch_shape = reshape_wishart.batch_shape_tensor() + event_shape = reshape_wishart.event_shape_tensor() + + expected_sample_shape = [3, 1] + new_batch_shape + [dims, dims] + x = wishart.sample([3, 1], seed=42) + expected_sample = array_ops.reshape(x, expected_sample_shape) + actual_sample = reshape_wishart.sample([3, 1], seed=42) + + expected_log_prob_shape = [3, 1] + new_batch_shape + expected_log_prob = array_ops.reshape( + wishart.log_prob(x), expected_log_prob_shape) + actual_log_prob = reshape_wishart.log_prob(expected_sample) + + with self.test_session() as sess: + [ + batch_shape_, + event_shape_, + expected_sample_, actual_sample_, + expected_log_prob_, actual_log_prob_, + ] = sess.run([ + batch_shape, + event_shape, + expected_sample, actual_sample, + expected_log_prob, actual_log_prob, + ]) + + self.assertAllEqual(new_batch_shape, batch_shape_) + self.assertAllEqual([dims, dims], event_shape_) + self.assertAllClose(expected_sample_, actual_sample_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_log_prob_, actual_log_prob_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(new_batch_shape, reshape_wishart.batch_shape) + self.assertAllEqual([dims, dims], reshape_wishart.event_shape) + self.assertAllEqual(expected_sample_shape, actual_sample.shape) + self.assertAllEqual(expected_log_prob_shape, actual_log_prob.shape) + + def test_matrix_variate_stats(self): + dims = 2 + new_batch_shape = [4] + old_batch_shape = [2, 2] + wishart, reshape_wishart = self.make_wishart( + dims, new_batch_shape, old_batch_shape) + + expected_scalar_stat_shape = new_batch_shape + expected_matrix_stat_shape = new_batch_shape + [dims, dims] + + expected_entropy = array_ops.reshape( + wishart.entropy(), expected_scalar_stat_shape) + actual_entropy = reshape_wishart.entropy() + + expected_mean = array_ops.reshape( + wishart.mean(), expected_matrix_stat_shape) + actual_mean = reshape_wishart.mean() + + expected_mode = array_ops.reshape( + wishart.mode(), expected_matrix_stat_shape) + actual_mode = reshape_wishart.mode() + + expected_stddev = array_ops.reshape( + wishart.stddev(), expected_matrix_stat_shape) + actual_stddev = reshape_wishart.stddev() + + expected_variance = array_ops.reshape( + wishart.variance(), expected_matrix_stat_shape) + actual_variance = reshape_wishart.variance() + + with self.test_session() as sess: + [ + expected_entropy_, actual_entropy_, + expected_mean_, actual_mean_, + expected_mode_, actual_mode_, + expected_stddev_, actual_stddev_, + expected_variance_, actual_variance_, + ] = sess.run([ + expected_entropy, actual_entropy, + expected_mean, actual_mean, + expected_mode, actual_mode, + expected_stddev, actual_stddev, + expected_variance, actual_variance, + ]) + + self.assertAllClose(expected_entropy_, actual_entropy_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mean_, actual_mean_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mode_, actual_mode_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_stddev_, actual_stddev_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_variance_, actual_variance_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(expected_scalar_stat_shape, actual_entropy.shape) + self.assertAllEqual(expected_matrix_stat_shape, actual_mean.shape) + self.assertAllEqual(expected_matrix_stat_shape, actual_mode.shape) + self.assertAllEqual(expected_matrix_stat_shape, actual_stddev.shape) + self.assertAllEqual(expected_matrix_stat_shape, actual_variance.shape) + + def make_normal(self, new_batch_shape, old_batch_shape): + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = self.dtype(0.5 + np.arange( + np.prod(old_batch_shape)).reshape(old_batch_shape)) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + normal = normal_lib.Normal(loc=self.dtype(0), scale=scale_ph) + reshape_normal = batch_reshape_lib.BatchReshape( + distribution=normal, + batch_shape=new_batch_shape_ph, + validate_args=True) + return normal, reshape_normal + + def test_scalar_variate_sample_and_log_prob(self): + new_batch_shape = [2, 2] + old_batch_shape = [4] + + normal, reshape_normal = self.make_normal( + new_batch_shape, old_batch_shape) + + batch_shape = reshape_normal.batch_shape_tensor() + event_shape = reshape_normal.event_shape_tensor() + + expected_sample_shape = new_batch_shape + x = normal.sample(seed=52) + expected_sample = array_ops.reshape(x, expected_sample_shape) + actual_sample = reshape_normal.sample(seed=52) + + expected_log_prob_shape = new_batch_shape + expected_log_prob = array_ops.reshape( + normal.log_prob(x), expected_log_prob_shape) + actual_log_prob = reshape_normal.log_prob(expected_sample) + + with self.test_session() as sess: + [ + batch_shape_, + event_shape_, + expected_sample_, actual_sample_, + expected_log_prob_, actual_log_prob_, + ] = sess.run([ + batch_shape, + event_shape, + expected_sample, actual_sample, + expected_log_prob, actual_log_prob, + ]) + self.assertAllEqual(new_batch_shape, batch_shape_) + self.assertAllEqual([], event_shape_) + self.assertAllClose(expected_sample_, actual_sample_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_log_prob_, actual_log_prob_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(new_batch_shape, reshape_normal.batch_shape) + self.assertAllEqual([], reshape_normal.event_shape) + self.assertAllEqual(expected_sample_shape, actual_sample.shape) + self.assertAllEqual(expected_log_prob_shape, actual_log_prob.shape) + + def test_scalar_variate_stats(self): + new_batch_shape = [2, 2] + old_batch_shape = [4] + + normal, reshape_normal = self.make_normal(new_batch_shape, old_batch_shape) + + expected_scalar_stat_shape = new_batch_shape + + expected_entropy = array_ops.reshape( + normal.entropy(), expected_scalar_stat_shape) + actual_entropy = reshape_normal.entropy() + + expected_mean = array_ops.reshape( + normal.mean(), expected_scalar_stat_shape) + actual_mean = reshape_normal.mean() + + expected_mode = array_ops.reshape( + normal.mode(), expected_scalar_stat_shape) + actual_mode = reshape_normal.mode() + + expected_stddev = array_ops.reshape( + normal.stddev(), expected_scalar_stat_shape) + actual_stddev = reshape_normal.stddev() + + expected_variance = array_ops.reshape( + normal.variance(), expected_scalar_stat_shape) + actual_variance = reshape_normal.variance() + + with self.test_session() as sess: + [ + expected_entropy_, actual_entropy_, + expected_mean_, actual_mean_, + expected_mode_, actual_mode_, + expected_stddev_, actual_stddev_, + expected_variance_, actual_variance_, + ] = sess.run([ + expected_entropy, actual_entropy, + expected_mean, actual_mean, + expected_mode, actual_mode, + expected_stddev, actual_stddev, + expected_variance, actual_variance, + ]) + self.assertAllClose(expected_entropy_, actual_entropy_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mean_, actual_mean_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mode_, actual_mode_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_stddev_, actual_stddev_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_variance_, actual_variance_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(expected_scalar_stat_shape, actual_entropy.shape) + self.assertAllEqual(expected_scalar_stat_shape, actual_mean.shape) + self.assertAllEqual(expected_scalar_stat_shape, actual_mode.shape) + self.assertAllEqual(expected_scalar_stat_shape, actual_stddev.shape) + self.assertAllEqual(expected_scalar_stat_shape, actual_variance.shape) + + def make_mvn(self, dims, new_batch_shape, old_batch_shape): + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = np.ones(old_batch_shape + [dims], self.dtype) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) + reshape_mvn = batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True) + return mvn, reshape_mvn + + def test_vector_variate_sample_and_log_prob(self): + dims = 3 + new_batch_shape = [2, 1] + old_batch_shape = [2] + mvn, reshape_mvn = self.make_mvn( + dims, new_batch_shape, old_batch_shape) + + batch_shape = reshape_mvn.batch_shape_tensor() + event_shape = reshape_mvn.event_shape_tensor() + + expected_sample_shape = [3] + new_batch_shape + [dims] + x = mvn.sample(3, seed=62) + expected_sample = array_ops.reshape(x, expected_sample_shape) + actual_sample = reshape_mvn.sample(3, seed=62) + + expected_log_prob_shape = [3] + new_batch_shape + expected_log_prob = array_ops.reshape( + mvn.log_prob(x), expected_log_prob_shape) + actual_log_prob = reshape_mvn.log_prob(expected_sample) + + with self.test_session() as sess: + [ + batch_shape_, + event_shape_, + expected_sample_, actual_sample_, + expected_log_prob_, actual_log_prob_, + ] = sess.run([ + batch_shape, + event_shape, + expected_sample, actual_sample, + expected_log_prob, actual_log_prob, + ]) + self.assertAllEqual(new_batch_shape, batch_shape_) + self.assertAllEqual([dims], event_shape_) + self.assertAllClose(expected_sample_, actual_sample_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_log_prob_, actual_log_prob_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(new_batch_shape, reshape_mvn.batch_shape) + self.assertAllEqual([dims], reshape_mvn.event_shape) + self.assertAllEqual(expected_sample_shape, actual_sample.shape) + self.assertAllEqual(expected_log_prob_shape, actual_log_prob.shape) + + def test_vector_variate_stats(self): + dims = 3 + new_batch_shape = [2, 1] + old_batch_shape = [2] + mvn, reshape_mvn = self.make_mvn( + dims, new_batch_shape, old_batch_shape) + + expected_scalar_stat_shape = new_batch_shape + + expected_entropy = array_ops.reshape( + mvn.entropy(), expected_scalar_stat_shape) + actual_entropy = reshape_mvn.entropy() + + expected_vector_stat_shape = new_batch_shape + [dims] + + expected_mean = array_ops.reshape( + mvn.mean(), expected_vector_stat_shape) + actual_mean = reshape_mvn.mean() + + expected_mode = array_ops.reshape( + mvn.mode(), expected_vector_stat_shape) + actual_mode = reshape_mvn.mode() + + expected_stddev = array_ops.reshape( + mvn.stddev(), expected_vector_stat_shape) + actual_stddev = reshape_mvn.stddev() + + expected_variance = array_ops.reshape( + mvn.variance(), expected_vector_stat_shape) + actual_variance = reshape_mvn.variance() + + expected_matrix_stat_shape = new_batch_shape + [dims, dims] + + expected_covariance = array_ops.reshape( + mvn.covariance(), expected_matrix_stat_shape) + actual_covariance = reshape_mvn.covariance() + + with self.test_session() as sess: + [ + expected_entropy_, actual_entropy_, + expected_mean_, actual_mean_, + expected_mode_, actual_mode_, + expected_stddev_, actual_stddev_, + expected_variance_, actual_variance_, + expected_covariance_, actual_covariance_, + ] = sess.run([ + expected_entropy, actual_entropy, + expected_mean, actual_mean, + expected_mode, actual_mode, + expected_stddev, actual_stddev, + expected_variance, actual_variance, + expected_covariance, actual_covariance, + ]) + self.assertAllClose(expected_entropy_, actual_entropy_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mean_, actual_mean_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mode_, actual_mode_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_stddev_, actual_stddev_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_variance_, actual_variance_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_covariance_, actual_covariance_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(expected_scalar_stat_shape, actual_entropy.shape) + self.assertAllEqual(expected_vector_stat_shape, actual_mean.shape) + self.assertAllEqual(expected_vector_stat_shape, actual_mode.shape) + self.assertAllEqual(expected_vector_stat_shape, actual_stddev.shape) + self.assertAllEqual(expected_vector_stat_shape, actual_variance.shape) + self.assertAllEqual(expected_matrix_stat_shape, actual_covariance.shape) + + def test_bad_reshape_size(self): + dims = 2 + new_batch_shape = [2, 3] + old_batch_shape = [2] # 2 != 2*3 + + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = np.ones(old_batch_shape + [dims], self.dtype) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) + + if self.is_static_shape: + with self.assertRaisesRegexp( + ValueError, (r"`batch_shape` size \(6\) must match " + r"`distribution\.batch_shape` size \(2\)")): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True) + + else: + with self.test_session(): + with self.assertRaisesOpError(r"Shape sizes do not match."): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True).sample().eval() + + def test_non_positive_shape(self): + dims = 2 + old_batch_shape = [4] + if self.is_static_shape: + # Unknown first dimension does not trigger size check. Note that + # any dimension < 0 is treated statically as unknown. + new_batch_shape = [-1, 0] + else: + new_batch_shape = [-2, -2] # -2 * -2 = 4, same size as the old shape. + + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = np.ones(old_batch_shape + [dims], self.dtype) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) + + if self.is_static_shape: + with self.assertRaisesRegexp(ValueError, r".*must be >=-1.*"): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True) + + else: + with self.test_session(): + with self.assertRaisesOpError(r".*must be >=-1.*"): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True).sample().eval() + + def test_non_vector_shape(self): + dims = 2 + new_batch_shape = 2 + old_batch_shape = [2] + + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = np.ones(old_batch_shape + [dims], self.dtype) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) + + if self.is_static_shape: + with self.assertRaisesRegexp(ValueError, r".*must be a vector.*"): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True) + + else: + with self.test_session(): + with self.assertRaisesOpError(r".*must be a vector.*"): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True).sample().eval() + + def test_broadcasting_explicitly_unsupported(self): + old_batch_shape = [4] + new_batch_shape = [1, 4, 1] + rate_ = self.dtype([1, 10, 2, 20]) + + rate = array_ops.placeholder_with_default( + rate_, + shape=old_batch_shape if self.is_static_shape else None) + poisson_4 = poisson_lib.Poisson(rate) + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + poisson_141_reshaped = batch_reshape_lib.BatchReshape( + poisson_4, new_batch_shape_ph, validate_args=True) + + x_4 = self.dtype([2, 12, 3, 23]) + x_114 = self.dtype([2, 12, 3, 23]).reshape(1, 1, 4) + + if self.is_static_shape: + with self.assertRaisesRegexp(NotImplementedError, + "too few batch and event dims"): + poisson_141_reshaped.log_prob(x_4) + with self.assertRaisesRegexp(NotImplementedError, + "unexpected batch and event shape"): + poisson_141_reshaped.log_prob(x_114) + return + + with self.assertRaisesOpError("too few batch and event dims"): + with self.test_session(): + poisson_141_reshaped.log_prob(x_4).eval() + + with self.assertRaisesOpError("unexpected batch and event shape"): + with self.test_session(): + poisson_141_reshaped.log_prob(x_114).eval() + + +class BatchReshapeStaticTest(_BatchReshapeTest, test.TestCase): + + dtype = np.float32 + is_static_shape = True + + +class BatchReshapeDynamicTest(_BatchReshapeTest, test.TestCase): + + dtype = np.float64 + is_static_shape = False + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py index e0d65c79b2654c2949de161d6317f218d11cab43..042c8ebd51c47facfc5c942cae56bd56be9df7c5 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py @@ -18,11 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - # pylint: disable=g-importing-member from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value import AbsoluteValue -from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -35,50 +32,38 @@ class AbsoluteValueTest(test.TestCase): def testBijectorVersusNumpyRewriteOfBasicFunctionsEventNdims0(self): with self.test_session() as sess: - bijector = AbsoluteValue(event_ndims=0, validate_args=True) + bijector = AbsoluteValue(validate_args=True) self.assertEqual("absolute_value", bijector.name) x = array_ops.constant([[0., 1., -1], [0., -5., 3.]]) # Shape [2, 3] y = math_ops.abs(x) y_ = y.eval() - zeros = np.zeros((2, 3)) self.assertAllClose(y_, bijector.forward(x).eval()) self.assertAllClose((-y_, y_), sess.run(bijector.inverse(y))) - self.assertAllClose((zeros, zeros), - sess.run(bijector.inverse_log_det_jacobian(y))) + self.assertAllClose((0., 0.), + sess.run(bijector.inverse_log_det_jacobian( + y, event_ndims=0))) # Run things twice to make sure there are no issues in caching the tuples # returned by .inverse* self.assertAllClose(y_, bijector.forward(x).eval()) self.assertAllClose((-y_, y_), sess.run(bijector.inverse(y))) - self.assertAllClose((zeros, zeros), - sess.run(bijector.inverse_log_det_jacobian(y))) - - def testEventNdimsMustBeZeroOrRaiseStatic(self): - with self.test_session(): - with self.assertRaisesRegexp(ValueError, "event_ndims.*was not 0"): - AbsoluteValue(event_ndims=1) - - def testEventNdimsMustBeZeroOrRaiseDynamic(self): - with self.test_session() as sess: - event_ndims = array_ops.placeholder(dtypes.int32) - abs_bijector = AbsoluteValue(event_ndims=event_ndims, validate_args=True) - with self.assertRaisesOpError("event_ndims was not 0"): - sess.run(abs_bijector.inverse_log_det_jacobian([1.]), - feed_dict={event_ndims: 1}) + self.assertAllClose((0., 0.), + sess.run(bijector.inverse_log_det_jacobian( + y, event_ndims=0))) def testNegativeYRaisesForInverseIfValidateArgs(self): with self.test_session() as sess: - bijector = AbsoluteValue(event_ndims=0, validate_args=True) + bijector = AbsoluteValue(validate_args=True) with self.assertRaisesOpError("y was negative"): sess.run(bijector.inverse(-1.)) def testNegativeYRaisesForILDJIfValidateArgs(self): with self.test_session() as sess: - bijector = AbsoluteValue(event_ndims=0, validate_args=True) + bijector = AbsoluteValue(validate_args=True) with self.assertRaisesOpError("y was negative"): - sess.run(bijector.inverse_log_det_jacobian(-1.)) + sess.run(bijector.inverse_log_det_jacobian(-1., event_ndims=0)) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py index 405ddd292cacd8ace87d6caeebf3e8cfc347c22d..1e4ad724d00f751a55370ef9aa6dde0003a2098c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py @@ -38,9 +38,11 @@ class AffineLinearOperatorTest(test.TestCase): self.assertEqual(affine.name, "affine_linear_operator") self.assertAllClose(y, affine.forward(x).eval()) self.assertAllClose(x, affine.inverse(y).eval()) - self.assertAllClose(ildj, affine.inverse_log_det_jacobian(y).eval()) - self.assertAllClose(-affine.inverse_log_det_jacobian(y).eval(), - affine.forward_log_det_jacobian(x).eval()) + self.assertAllClose(ildj, affine.inverse_log_det_jacobian( + y, event_ndims=2).eval()) + self.assertAllClose( + -affine.inverse_log_det_jacobian(y, event_ndims=2).eval(), + affine.forward_log_det_jacobian(x, event_ndims=2).eval()) def testDiag(self): with self.test_session(): @@ -58,14 +60,16 @@ class AffineLinearOperatorTest(test.TestCase): self.assertEqual(affine.name, "affine_linear_operator") self.assertAllClose(y, affine.forward(x).eval()) self.assertAllClose(x, affine.inverse(y).eval()) - self.assertAllClose(ildj, affine.inverse_log_det_jacobian(y).eval()) - self.assertAllClose(-affine.inverse_log_det_jacobian(y).eval(), - affine.forward_log_det_jacobian(x).eval()) + self.assertAllClose( + ildj, affine.inverse_log_det_jacobian(y, event_ndims=1).eval()) + self.assertAllClose( + -affine.inverse_log_det_jacobian(y, event_ndims=1).eval(), + affine.forward_log_det_jacobian(x, event_ndims=1).eval()) def testTriL(self): with self.test_session(): shift = np.array([-1, 0, 1], dtype=np.float32) - tril = np.array([[[1, 0, 0], + tril = np.array([[[3, 0, 0], [2, -1, 0], [3, 2, 1]], [[2, 0, 0], @@ -85,15 +89,17 @@ class AffineLinearOperatorTest(test.TestCase): # y = np.matmul(x, tril) + shift. y = np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift ildj = -np.sum(np.log(np.abs(np.diagonal( - tril, axis1=-2, axis2=-1))), - axis=-1) + tril, axis1=-2, axis2=-1)))) self.assertEqual(affine.name, "affine_linear_operator") self.assertAllClose(y, affine.forward(x).eval()) self.assertAllClose(x, affine.inverse(y).eval()) - self.assertAllClose(ildj, affine.inverse_log_det_jacobian(y).eval()) - self.assertAllClose(-affine.inverse_log_det_jacobian(y).eval(), - affine.forward_log_det_jacobian(x).eval()) + self.assertAllClose( + ildj, affine.inverse_log_det_jacobian( + y, event_ndims=2).eval()) + self.assertAllClose( + -affine.inverse_log_det_jacobian(y, event_ndims=2).eval(), + affine.forward_log_det_jacobian(x, event_ndims=2).eval()) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d2533620bebeb0400b6d4a6346e8315c7e37c5c6 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py @@ -0,0 +1,160 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Affine Scalar Tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.affine_scalar import AffineScalar +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency +from tensorflow.python.platform import test + + +class AffineScalarBijectorTest(test.TestCase): + """Tests correctness of the Y = scale @ x + shift transformation.""" + + def testProperties(self): + with self.test_session(): + mu = -1. + # scale corresponds to 1. + bijector = AffineScalar(shift=mu) + self.assertEqual("affine_scalar", bijector.name) + + def testNoBatchScalar(self): + with self.test_session() as sess: + + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() + + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value) + x = array_ops.placeholder(dtypes.float32, name="x") + return sess.run(fun(x, **kwargs), feed_dict={x: x_value}) + + for run in (static_run, dynamic_run): + mu = -1. + # Corresponds to scale = 2 + bijector = AffineScalar(shift=mu, scale=2.) + x = [1., 2, 3] # Three scalar samples (no batches). + self.assertAllClose([1., 3, 5], run(bijector.forward, x)) + self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x)) + self.assertAllClose( + -np.log(2.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) + + def testOneBatchScalarViaIdentityIn64BitUserProvidesShiftOnly(self): + with self.test_session() as sess: + + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() + + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value).astype(np.float64) + x = array_ops.placeholder(dtypes.float64, name="x") + return sess.run(fun(x, **kwargs), feed_dict={x: x_value}) + + for run in (static_run, dynamic_run): + mu = np.float64([1.]) + # One batch, scalar. + # Corresponds to scale = 1. + bijector = AffineScalar(shift=mu) + x = np.float64([1.]) # One sample from one batches. + self.assertAllClose([2.], run(bijector.forward, x)) + self.assertAllClose([0.], run(bijector.inverse, x)) + self.assertAllClose( + 0., + run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) + + def testOneBatchScalarViaIdentityIn64BitUserProvidesScaleOnly(self): + with self.test_session() as sess: + + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() + + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value).astype(np.float64) + x = array_ops.placeholder(dtypes.float64, name="x") + return sess.run(fun(x, **kwargs), feed_dict={x: x_value}) + + for run in (static_run, dynamic_run): + multiplier = np.float64([2.]) + # One batch, scalar. + # Corresponds to scale = 2, shift = 0. + bijector = AffineScalar(scale=multiplier) + x = np.float64([1.]) # One sample from one batches. + self.assertAllClose([2.], run(bijector.forward, x)) + self.assertAllClose([0.5], run(bijector.inverse, x)) + self.assertAllClose( + [np.log(0.5)], + run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) + + def testTwoBatchScalarIdentityViaIdentity(self): + with self.test_session() as sess: + + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() + + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value).astype(np.float32) + x = array_ops.placeholder(dtypes.float32, name="x") + return sess.run(fun(x, **kwargs), feed_dict={x: x_value}) + + for run in (static_run, dynamic_run): + mu = [1., -1] + # Univariate, two batches. + # Corresponds to scale = 1. + bijector = AffineScalar(shift=mu) + x = [1., 1] # One sample from each of two batches. + self.assertAllClose([2., 0], run(bijector.forward, x)) + self.assertAllClose([0., 2], run(bijector.inverse, x)) + self.assertAllClose( + 0., + run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) + + def testTwoBatchScalarIdentityViaScale(self): + with self.test_session() as sess: + + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() + + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value).astype(np.float32) + x = array_ops.placeholder(dtypes.float32, name="x") + return sess.run(fun(x, **kwargs), feed_dict={x: x_value}) + + for run in (static_run, dynamic_run): + mu = [1., -1] + # Univariate, two batches. + # Corresponds to scale = 1. + bijector = AffineScalar(shift=mu, scale=[2., 1]) + x = [1., 1] # One sample from each of two batches. + self.assertAllClose([3., 0], run(bijector.forward, x)) + self.assertAllClose([0., 2], run(bijector.inverse, x)) + self.assertAllClose( + [-np.log(2), 0.], + run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) + + def testScalarCongruency(self): + with self.test_session(): + bijector = AffineScalar(shift=3.6, scale=0.42) + assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py index c9158117f7a982e37047e8dd2b534a30040a87d9..9e14b9a53e6c63876478d876030c476c5d77dbbb 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py @@ -25,7 +25,6 @@ import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops -from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test @@ -36,209 +35,26 @@ class AffineBijectorTest(test.TestCase): with self.test_session(): mu = -1. # scale corresponds to 1. - bijector = Affine(shift=mu, event_ndims=0) + bijector = Affine(shift=mu) self.assertEqual("affine", bijector.name) - def testNoBatchScalarViaIdentity(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = -1. - # Corresponds to scale = 2 - bijector = Affine( - shift=mu, scale_identity_multiplier=2., event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [1., 2, 3] # Three scalar samples (no batches). - self.assertAllClose([1., 3, 5], run(bijector.forward, x)) - self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), - run(bijector.inverse_log_det_jacobian, x)) - - def testNoBatchScalarViaDiag(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = -1. - # Corresponds to scale = 2 - bijector = Affine(shift=mu, scale_identity_multiplier=2., event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [1., 2, 3] # Three scalar samples (no batches). - self.assertAllClose([1., 3, 5], run(bijector.forward, x)) - self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), - run(bijector.inverse_log_det_jacobian, x)) - - def testWeirdSampleNoBatchScalarViaDiagMultiplier(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = -1. - # Corresponds to scale = 2. - bijector = Affine( - shift=mu, scale_identity_multiplier=2., event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [[1., 2, 3], [4, 5, 6]] # Weird sample shape. - self.assertAllClose([[1., 3, 5], - [7, 9, 11]], - run(bijector.forward, x)) - self.assertAllClose([[1., 1.5, 2.], - [2.5, 3, 3.5]], - run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), - run(bijector.inverse_log_det_jacobian, x)) - - def testOneBatchScalarViaIdentityIn64BitUserProvidesShiftOnly(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value).astype(np.float64) - x = array_ops.placeholder(dtypes.float64, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = np.float64([1.]) - # One batch, scalar. - # Corresponds to scale = 1. - bijector = Affine(shift=mu, event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = np.float64([1.]) # One sample from one batches. - self.assertAllClose([2.], run(bijector.forward, x)) - self.assertAllClose([0.], run(bijector.inverse, x)) - self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) - - def testOneBatchScalarViaIdentityIn64BitUserProvidesMultiplierOnly(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value).astype(np.float64) - x = array_ops.placeholder(dtypes.float64, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - multiplier = np.float64([2.]) - # One batch, scalar. - # Corresponds to scale = 2, shift = 0. - bijector = Affine(scale_identity_multiplier=multiplier, event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = np.float64([1.]) # One sample from one batches. - self.assertAllClose([2.], run(bijector.forward, x)) - self.assertAllClose([0.5], run(bijector.inverse, x)) - self.assertAllClose([np.log(0.5)], - run(bijector.inverse_log_det_jacobian, x)) - - def testOneBatchScalarViaDiagMultiplier(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = [1.] - # One batch, scalar. - # Corresponds to scale = 1. - bijector = Affine(shift=mu, scale_identity_multiplier=1., event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [1.] # One sample from one batches. - self.assertAllClose([2.], run(bijector.forward, x)) - self.assertAllClose([0.], run(bijector.inverse, x)) - self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) - - def testTwoBatchScalarIdentityViaIdentity(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = [1., -1] - # Univariate, two batches. - # Corresponds to scale = 1. - bijector = Affine(shift=mu, event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [1., 1] # One sample from each of two batches. - self.assertAllClose([2., 0], run(bijector.forward, x)) - self.assertAllClose([0., 2], run(bijector.inverse, x)) - self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) - - def testTwoBatchScalarIdentityViaDiagMultiplier(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = [1., -1] - # Univariate, two batches. - # Corresponds to scale = 1. - bijector = Affine(shift=mu, scale_identity_multiplier=1., event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [1., 1] # One sample from each of two batches. - self.assertAllClose([2., 0], run(bijector.forward, x)) - self.assertAllClose([0., 2], run(bijector.inverse, x)) - self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) - def testNoBatchMultivariateIdentity(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = [1., -1] # Multivariate # Corresponds to scale = [[1., 0], [0, 1.]] bijector = Affine(shift=mu) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 1] # matmul(sigma, x) + shift # = [-1, -1] + [1, -1] @@ -251,33 +67,37 @@ class AffineBijectorTest(test.TestCase): x = [[1., 1], [-1., -1]] self.assertAllClose([[2., 0], [0., -2]], run(bijector.forward, x)) self.assertAllClose([[0., 2], [-2., 0]], run(bijector.inverse, x)) - self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + 0., run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testNoBatchMultivariateDiag(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = [1., -1] # Multivariate # Corresponds to scale = [[2., 0], [0, 1.]] bijector = Affine(shift=mu, scale_diag=[2., 1]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 1] # matmul(sigma, x) + shift # = [-1, -1] + [1, -1] self.assertAllClose([3., 0], run(bijector.forward, x)) self.assertAllClose([0., 2], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(2.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) + # Reset bijector. + bijector = Affine(shift=mu, scale_diag=[2., 1]) # x is a 2-batch of 2-vectors. # The first vector is [1, 1], the second is [-1, -1]. # Each undergoes matmul(sigma, x) + shift. @@ -289,120 +109,116 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([[0., 2], [-1., 0]], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(2.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testNoBatchMultivariateFullDynamic(self): with self.test_session() as sess: x = array_ops.placeholder(dtypes.float32, name="x") mu = array_ops.placeholder(dtypes.float32, name="mu") scale_diag = array_ops.placeholder(dtypes.float32, name="scale_diag") - event_ndims = array_ops.placeholder(dtypes.int32, name="event_ndims") x_value = np.array([[1., 1]], dtype=np.float32) mu_value = np.array([1., -1], dtype=np.float32) scale_diag_value = np.array([2., 2], dtype=np.float32) - event_ndims_value = np.array(1, dtype=np.int32) feed_dict = { x: x_value, mu: mu_value, scale_diag: scale_diag_value, - event_ndims: event_ndims_value } - bijector = Affine( - shift=mu, scale_diag=scale_diag, event_ndims=event_ndims) - self.assertEqual(1, sess.run(bijector.event_ndims, feed_dict)) + bijector = Affine(shift=mu, scale_diag=scale_diag) self.assertAllClose([[3., 1]], sess.run(bijector.forward(x), feed_dict)) self.assertAllClose([[0., 1]], sess.run(bijector.inverse(x), feed_dict)) self.assertAllClose( -np.log(4), - sess.run(bijector.inverse_log_det_jacobian(x), feed_dict)) + sess.run(bijector.inverse_log_det_jacobian(x, event_ndims=1), + feed_dict)) def testBatchMultivariateIdentity(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): - x_value = np.array(x_value, dtype=np.float32) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = [[1., -1]] # Corresponds to 1 2x2 matrix, with twos on the diagonal. scale = 2. bijector = Affine(shift=mu, scale_identity_multiplier=scale) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [[[1., 1]]] self.assertAllClose([[[3., 1]]], run(bijector.forward, x)) self.assertAllClose([[[0., 1]]], run(bijector.inverse, x)) - self.assertAllClose(-np.log(4), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(4), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testBatchMultivariateDiag(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): - x_value = np.array(x_value, dtype=np.float32) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = [[1., -1]] # Corresponds to 1 2x2 matrix, with twos on the diagonal. scale_diag = [[2., 2]] bijector = Affine(shift=mu, scale_diag=scale_diag) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [[[1., 1]]] self.assertAllClose([[[3., 1]]], run(bijector.forward, x)) self.assertAllClose([[[0., 1]]], run(bijector.inverse, x)) - self.assertAllClose([-np.log(4)], - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + [-np.log(4)], + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testBatchMultivariateFullDynamic(self): with self.test_session() as sess: x = array_ops.placeholder(dtypes.float32, name="x") mu = array_ops.placeholder(dtypes.float32, name="mu") scale_diag = array_ops.placeholder(dtypes.float32, name="scale_diag") - event_ndims = array_ops.placeholder(dtypes.int32, name="event_ndims") x_value = np.array([[[1., 1]]], dtype=np.float32) mu_value = np.array([[1., -1]], dtype=np.float32) scale_diag_value = np.array([[2., 2]], dtype=np.float32) - event_ndims_value = 1 feed_dict = { x: x_value, mu: mu_value, scale_diag: scale_diag_value, - event_ndims: event_ndims_value } - bijector = Affine( - shift=mu, scale_diag=scale_diag, event_ndims=event_ndims) - self.assertEqual(1, sess.run(bijector.event_ndims, feed_dict)) + bijector = Affine(shift=mu, scale_diag=scale_diag) self.assertAllClose([[[3., 1]]], sess.run(bijector.forward(x), feed_dict)) self.assertAllClose([[[0., 1]]], sess.run(bijector.inverse(x), feed_dict)) - self.assertAllClose([-np.log(4)], - sess.run( - bijector.inverse_log_det_jacobian(x), feed_dict)) + self.assertAllClose( + [-np.log(4)], + sess.run(bijector.inverse_log_det_jacobian( + x, event_ndims=1), feed_dict)) def testIdentityWithDiagUpdate(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -410,25 +226,25 @@ class AffineBijectorTest(test.TestCase): bijector = Affine( shift=mu, scale_identity_multiplier=1., - scale_diag=[1., 1., 1.], - event_ndims=1) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" + scale_diag=[1., 1., 1.]) x = [1., 2, 3] # Three scalar samples (no batches). self.assertAllClose([1., 3, 5], run(bijector.forward, x)) self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.**3), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(2.**3), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testIdentityWithTriL(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -437,46 +253,48 @@ class AffineBijectorTest(test.TestCase): shift=mu, scale_identity_multiplier=1., scale_tril=[[1., 0], [2., 1]]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [[1., 2]] # One multivariate sample. self.assertAllClose([[1., 5]], run(bijector.forward, x)) self.assertAllClose([[1., 0.5]], run(bijector.inverse, x)) - self.assertAllClose(-np.log(4.), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(4.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testDiagWithTriL(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. # scale = [[2., 0], [2, 3]] bijector = Affine( shift=mu, scale_diag=[1., 2.], scale_tril=[[1., 0], [2., 1]]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [[1., 2]] # One multivariate sample. self.assertAllClose([[1., 7]], run(bijector.forward, x)) self.assertAllClose([[1., 1 / 3.]], run(bijector.inverse, x)) - self.assertAllClose(-np.log(6.), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(6.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testIdentityAndDiagWithTriL(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -486,23 +304,24 @@ class AffineBijectorTest(test.TestCase): scale_identity_multiplier=1.0, scale_diag=[1., 2.], scale_tril=[[1., 0], [2., 1]]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [[1., 2]] # One multivariate sample. self.assertAllClose([[2., 9]], run(bijector.forward, x)) self.assertAllClose([[2 / 3., 5 / 12.]], run(bijector.inverse, x)) - self.assertAllClose(-np.log(12.), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(12.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testIdentityWithVDVTUpdate(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -514,7 +333,6 @@ class AffineBijectorTest(test.TestCase): scale_perturb_factor=[[2., 0], [0., 0], [0, 1]]) bijector_ref = Affine(shift=mu, scale_diag=[10., 2, 3]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 2, 3] # Vector. self.assertAllClose([9., 3, 8], run(bijector.forward, x)) self.assertAllClose( @@ -523,22 +341,24 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([0.2, 1.5, 4 / 3.], run(bijector.inverse, x)) self.assertAllClose( run(bijector_ref.inverse, x), run(bijector.inverse, x)) - self.assertAllClose(-np.log(60.), - run(bijector.inverse_log_det_jacobian, x)) self.assertAllClose( - run(bijector.inverse_log_det_jacobian, x), - run(bijector_ref.inverse_log_det_jacobian, x)) + -np.log(60.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) + self.assertAllClose( + run(bijector.inverse_log_det_jacobian, x, event_ndims=1), + run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1)) def testDiagWithVDVTUpdate(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -550,7 +370,6 @@ class AffineBijectorTest(test.TestCase): scale_perturb_factor=[[2., 0], [0., 0], [0, 1]]) bijector_ref = Affine(shift=mu, scale_diag=[10., 3, 5]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 2, 3] # Vector. self.assertAllClose([9., 5, 14], run(bijector.forward, x)) self.assertAllClose( @@ -558,22 +377,24 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([0.2, 1., 0.8], run(bijector.inverse, x)) self.assertAllClose( run(bijector_ref.inverse, x), run(bijector.inverse, x)) - self.assertAllClose(-np.log(150.), - run(bijector.inverse_log_det_jacobian, x)) self.assertAllClose( - run(bijector.inverse_log_det_jacobian, x), - run(bijector_ref.inverse_log_det_jacobian, x)) + -np.log(150.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) + self.assertAllClose( + run(bijector.inverse_log_det_jacobian, x, event_ndims=1), + run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1)) def testTriLWithVDVTUpdate(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -586,7 +407,6 @@ class AffineBijectorTest(test.TestCase): bijector_ref = Affine( shift=mu, scale_tril=[[10., 0, 0], [1, 3, 0], [2, 3, 5]]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 2, 3] # Vector. self.assertAllClose([9., 6, 22], run(bijector.forward, x)) self.assertAllClose( @@ -594,22 +414,24 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([0.2, 14 / 15., 4 / 25.], run(bijector.inverse, x)) self.assertAllClose( run(bijector_ref.inverse, x), run(bijector.inverse, x)) - self.assertAllClose(-np.log(150.), - run(bijector.inverse_log_det_jacobian, x)) self.assertAllClose( - run(bijector.inverse_log_det_jacobian, x), - run(bijector_ref.inverse_log_det_jacobian, x)) + -np.log(150.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) + self.assertAllClose( + run(bijector.inverse_log_det_jacobian, x, event_ndims=1), + run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1)) def testTriLWithVDVTUpdateNoDiagonal(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -622,7 +444,6 @@ class AffineBijectorTest(test.TestCase): bijector_ref = Affine( shift=mu, scale_tril=[[6., 0, 0], [1, 3, 0], [2, 3, 5]]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 2, 3] # Vector. self.assertAllClose([5., 6, 22], run(bijector.forward, x)) self.assertAllClose( @@ -630,11 +451,12 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([1 / 3., 8 / 9., 4 / 30.], run(bijector.inverse, x)) self.assertAllClose( run(bijector_ref.inverse, x), run(bijector.inverse, x)) - self.assertAllClose(-np.log(90.), - run(bijector.inverse_log_det_jacobian, x)) self.assertAllClose( - run(bijector.inverse_log_det_jacobian, x), - run(bijector_ref.inverse_log_det_jacobian, x)) + -np.log(90.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) + self.assertAllClose( + run(bijector.inverse_log_det_jacobian, x, event_ndims=1), + run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1)) def testNoBatchMultivariateRaisesWhenSingular(self): with self.test_session(): @@ -647,38 +469,6 @@ class AffineBijectorTest(test.TestCase): with self.assertRaisesOpError("diagonal part must be non-zero"): bijector.forward([1., 1.]).eval() - def testEventNdimsLargerThanOneRaises(self): - with self.test_session(): - mu = [1., -1] - with self.assertRaisesRegexp( - ValueError, (r"event_ndims\(2\) was not 0 or 1")): - # Scale corresponds to 2x2 identity matrix. - bijector = Affine(shift=mu, event_ndims=2, validate_args=True) - bijector.forward([1., 1.]).eval() - - def testScaleZeroScalarRaises(self): - with self.test_session(): - mu = -1. - # Check Identity matrix with zero scaling. - bijector = Affine( - shift=mu, - scale_identity_multiplier=0., - event_ndims=0, - validate_args=True) - with self.assertRaisesOpError("identity_multiplier should be non-zero"): - bijector.forward(1.).eval() - - def testScaleDiagAndEventNdimsZeroRaises(self): - # Check Diag matrix with zero scaling. - with self.assertRaisesRegexp(ValueError, "only scale argument"): - Affine(shift=None, scale_diag=[0.0], event_ndims=0, validate_args=True) - - def testScalarCongruency(self): - with self.test_session(): - bijector = Affine( - shift=3.6, scale_identity_multiplier=0.42, event_ndims=0) - assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.) - def _makeScale(self, x, scale_identity_multiplier=None, @@ -747,14 +537,12 @@ class AffineBijectorTest(test.TestCase): scale_args = dict({"x": x}, **args) scale = self._makeScale(**scale_args) - bijector_args = dict({"event_ndims": 1}, **args) - # We haven't specified enough information for the scale. if scale is None: with self.assertRaisesRegexp(ValueError, ("must be specified.")): - bijector = Affine(shift=shift, **bijector_args) + bijector = Affine(shift=shift, **args) else: - bijector = Affine(shift=shift, **bijector_args) + bijector = Affine(shift=shift, **args) np_x = x # For the case a vector is passed in, we need to make the shape # match the matrix for matmul to work. @@ -771,6 +559,7 @@ class AffineBijectorTest(test.TestCase): backward = np.squeeze(backward, axis=-1) self.assertAllClose(backward, bijector.inverse(x).eval()) + scale *= np.ones(shape=x.shape[:-1], dtype=scale.dtype) ildj = -np.log(np.abs(np.linalg.det(scale))) # TODO(jvdillon): We need to make it so the scale_identity_multiplier # case does not deviate in expected shape. Fixing this will get rid of @@ -781,7 +570,8 @@ class AffineBijectorTest(test.TestCase): ildj = np.squeeze(ildj[0]) elif ildj.ndim < scale.ndim - 2: ildj = np.reshape(ildj, scale.shape[0:-2]) - self.assertAllClose(ildj, bijector.inverse_log_det_jacobian(x).eval()) + self.assertAllClose( + ildj, bijector.inverse_log_det_jacobian(x, event_ndims=1).eval()) def testLegalInputs(self): self._testLegalInputs( @@ -829,15 +619,5 @@ class AffineBijectorTest(test.TestCase): x=np.array( [1., 2], dtype=np.float32)) - def testScalarEventIdentityScale(self): - with self.test_session() as sess: - doubler = Affine( - scale_identity_multiplier=2., - event_ndims=0) - doubler2 = doubler.inverse_log_det_jacobian(2.) - doubler2_ildj_ = sess.run([doubler2]) - self.assertAllClose([-np.log(2.)], doubler2_ildj_) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c832fcaa686c92f83810e4f99ca3b23ae694b723 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py @@ -0,0 +1,237 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for BatchNorm Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib import distributions +from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.contrib.distributions.python.ops.bijectors.batch_normalization import BatchNormalization +from tensorflow.contrib.distributions.python.ops.bijectors.invert import Invert +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.layers import normalization +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib +from tensorflow.python.platform import test +from tensorflow.python.training import adam + + +class BatchNormTest(test_util.VectorDistributionTestHelpers, + test.TestCase): + + def _reduction_axes(self, input_shape, event_dims): + if isinstance(event_dims, int): + event_dims = [event_dims] + ndims = len(input_shape) + # Convert event_dims to non-negative indexing. + event_dims = list(event_dims) + for idx, x in enumerate(event_dims): + if x < 0: + event_dims[idx] = ndims + x + return tuple(i for i in range(ndims) if i not in event_dims) + + def testForwardInverse(self): + """Tests forward and backward passes with different event shapes. + + input_shape: Tuple of shapes for input tensor. + event_dims: Tuple of dimension indices that will be normalized. + training: Boolean of whether bijector runs in training or inference mode. + """ + params = [ + ((5*2, 4), [-1], False), + ((5, 2, 4), [-1], False), + ((5, 2, 4), [1, 2], False), + ((5, 2, 4), [0, 1], False), + ((5*2, 4), [-1], True), + ((5, 2, 4), [-1], True), + ((5, 2, 4), [1, 2], True), + ((5, 2, 4), [0, 1], True) + ] + for input_shape, event_dims, training in params: + x_ = np.arange(5 * 4 * 2).astype(np.float32).reshape(input_shape) + with self.test_session() as sess: + x = constant_op.constant(x_) + # When training, memorize the exact mean of the last + # minibatch that it normalized (instead of moving average assignment). + layer = normalization.BatchNormalization( + axis=event_dims, momentum=0., epsilon=0.) + batch_norm = BatchNormalization( + batchnorm_layer=layer, training=training) + # Minibatch statistics are saved only after norm_x has been computed. + norm_x = batch_norm.inverse(x) + with ops.control_dependencies(batch_norm.batchnorm.updates): + moving_mean = array_ops.identity(batch_norm.batchnorm.moving_mean) + moving_var = array_ops.identity(batch_norm.batchnorm.moving_variance) + denorm_x = batch_norm.forward(array_ops.identity(norm_x)) + fldj = batch_norm.forward_log_det_jacobian( + x, event_ndims=len(event_dims)) + # Use identity to invalidate cache. + ildj = batch_norm.inverse_log_det_jacobian( + array_ops.identity(denorm_x), event_ndims=len(event_dims)) + variables.global_variables_initializer().run() + # Update variables. + norm_x_ = sess.run(norm_x) + [ + norm_x_, + moving_mean_, + moving_var_, + denorm_x_, + ildj_, + fldj_, + ] = sess.run([ + norm_x, + moving_mean, + moving_var, + denorm_x, + ildj, + fldj, + ]) + self.assertEqual("batch_normalization", batch_norm.name) + + reduction_axes = self._reduction_axes(input_shape, event_dims) + keepdims = len(event_dims) > 1 + + expected_batch_mean = np.mean( + x_, axis=reduction_axes, keepdims=keepdims) + expected_batch_var = np.var(x_, axis=reduction_axes, keepdims=keepdims) + + if training: + # When training=True, values become normalized across batch dim and + # original values are recovered after de-normalizing. + zeros = np.zeros_like(norm_x_) + self.assertAllClose(np.mean(zeros, axis=reduction_axes), + np.mean(norm_x_, axis=reduction_axes)) + + self.assertAllClose(expected_batch_mean, moving_mean_) + self.assertAllClose(expected_batch_var, moving_var_) + self.assertAllClose(x_, denorm_x_, atol=1e-5) + # Since moving statistics are set to batch statistics after + # normalization, ildj and -fldj should match. + self.assertAllClose(ildj_, -fldj_) + # ildj is computed with minibatch statistics. + expected_ildj = np.sum(np.log(1.) - .5 * np.log( + expected_batch_var + batch_norm.batchnorm.epsilon)) + self.assertAllClose(expected_ildj, ildj_) + else: + # When training=False, moving_mean, moving_var remain at their + # initialized values (0., 1.), resulting in no scale/shift (a small + # shift occurs if epsilon > 0.) + self.assertAllClose(x_, norm_x_) + self.assertAllClose(x_, denorm_x_, atol=1e-5) + # ildj is computed with saved statistics. + expected_ildj = np.sum( + np.log(1.) - .5 * np.log(1. + batch_norm.batchnorm.epsilon)) + self.assertAllClose(expected_ildj, ildj_) + + def testMaximumLikelihoodTraining(self): + # Test Maximum Likelihood training with default bijector. + with self.test_session() as sess: + base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.]) + batch_norm = BatchNormalization(training=True) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=base_dist, + bijector=batch_norm) + target_dist = distributions.MultivariateNormalDiag(loc=[1., 2.]) + target_samples = target_dist.sample(100) + dist_samples = dist.sample(3000) + loss = -math_ops.reduce_mean(dist.log_prob(target_samples)) + with ops.control_dependencies(batch_norm.batchnorm.updates): + train_op = adam.AdamOptimizer(1e-2).minimize(loss) + moving_mean = array_ops.identity(batch_norm.batchnorm.moving_mean) + moving_var = array_ops.identity(batch_norm.batchnorm.moving_variance) + variables.global_variables_initializer().run() + for _ in range(3000): + sess.run(train_op) + [ + dist_samples_, + moving_mean_, + moving_var_ + ] = sess.run([ + dist_samples, + moving_mean, + moving_var + ]) + self.assertAllClose([1., 2.], np.mean(dist_samples_, axis=0), atol=5e-2) + self.assertAllClose([1., 2.], moving_mean_, atol=5e-2) + self.assertAllClose([1., 1.], moving_var_, atol=5e-2) + + def testLogProb(self): + with self.test_session() as sess: + layer = normalization.BatchNormalization(epsilon=0.) + batch_norm = BatchNormalization(batchnorm_layer=layer, training=False) + base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.]) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=base_dist, + bijector=batch_norm, + validate_args=True) + samples = dist.sample(int(1e5)) + # No volume distortion since training=False, bijector is initialized + # to the identity transformation. + base_log_prob = base_dist.log_prob(samples) + dist_log_prob = dist.log_prob(samples) + variables.global_variables_initializer().run() + base_log_prob_, dist_log_prob_ = sess.run([base_log_prob, dist_log_prob]) + self.assertAllClose(base_log_prob_, dist_log_prob_) + + def testMutuallyConsistent(self): + # BatchNorm bijector is only mutually consistent when training=False. + dims = 4 + with self.test_session() as sess: + layer = normalization.BatchNormalization(epsilon=0.) + batch_norm = BatchNormalization(batchnorm_layer=layer, training=False) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=0., scale=1.), + bijector=batch_norm, + event_shape=[dims], + validate_args=True) + self.run_test_sample_consistent_log_prob( + sess_run_fn=sess.run, + dist=dist, + num_samples=int(1e5), + radius=2., + center=0., + rtol=0.02) + + def testInvertMutuallyConsistent(self): + # BatchNorm bijector is only mutually consistent when training=False. + dims = 4 + with self.test_session() as sess: + layer = normalization.BatchNormalization(epsilon=0.) + batch_norm = Invert( + BatchNormalization(batchnorm_layer=layer, training=False)) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=0., scale=1.), + bijector=batch_norm, + event_shape=[dims], + validate_args=True) + self.run_test_sample_consistent_log_prob( + sess_run_fn=sess.run, + dist=dist, + num_samples=int(1e5), + radius=2., + center=0., + rtol=0.02) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py index 20e754308449af3f0399101f4ea1bb47b3356424..dc45114b1c23b5edb78d68ad4f38f5201d265170 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py @@ -20,21 +20,34 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine from tensorflow.contrib.distributions.python.ops.bijectors.chain import Chain from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test +class ShapeChanging(bijector.Bijector): + """Only used for op_ndims manipulation.""" + + def __init__(self, forward_min_event_ndims=0, inverse_min_event_ndims=3): + super(ShapeChanging, self).__init__( + forward_min_event_ndims=forward_min_event_ndims, + inverse_min_event_ndims=inverse_min_event_ndims, + validate_args=False, name="shape_changer") + + class ChainBijectorTest(test.TestCase): """Tests the correctness of the Y = Chain(bij1, bij2, bij3) transformation.""" def testBijector(self): with self.test_session(): - chain = Chain((Exp(event_ndims=1), Softplus(event_ndims=1))) + chain = Chain((Exp(), Softplus())) self.assertEqual("chain_of_exp_of_softplus", chain.name) x = np.asarray([[[1., 2.], [2., 3.]]]) @@ -42,9 +55,10 @@ class ChainBijectorTest(test.TestCase): self.assertAllClose(np.log(x - 1.), chain.inverse(x).eval()) self.assertAllClose( -np.sum(np.log(x - 1.), axis=2), - chain.inverse_log_det_jacobian(x).eval()) + chain.inverse_log_det_jacobian(x, event_ndims=1).eval()) self.assertAllClose( - np.sum(x, axis=2), chain.forward_log_det_jacobian(x).eval()) + np.sum(x, axis=2), + chain.forward_log_det_jacobian(x, event_ndims=1).eval()) def testBijectorIdentity(self): with self.test_session(): @@ -54,33 +68,135 @@ class ChainBijectorTest(test.TestCase): [2., 3.]]]) self.assertAllClose(x, chain.forward(x).eval()) self.assertAllClose(x, chain.inverse(x).eval()) - self.assertAllClose(0., chain.inverse_log_det_jacobian(x).eval()) - self.assertAllClose(0., chain.forward_log_det_jacobian(x).eval()) + self.assertAllClose( + 0., chain.inverse_log_det_jacobian(x, event_ndims=1).eval()) + self.assertAllClose( + 0., chain.forward_log_det_jacobian(x, event_ndims=1).eval()) def testScalarCongruency(self): with self.test_session(): - bijector = Chain((Exp(), Softplus())) + chain = Chain((Exp(), Softplus())) assert_scalar_congruency( - bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05) + chain, lower_x=1e-3, upper_x=1.5, rtol=0.05) def testShapeGetters(self): with self.test_session(): - bijector = Chain([ - SoftmaxCentered( - event_ndims=1, validate_args=True), - SoftmaxCentered( - event_ndims=0, validate_args=True) + chain = Chain([ + SoftmaxCentered(validate_args=True), + SoftmaxCentered(validate_args=True), ]) - x = tensor_shape.TensorShape([]) + x = tensor_shape.TensorShape([1]) y = tensor_shape.TensorShape([2 + 1]) - self.assertAllEqual(y, bijector.forward_event_shape(x)) + self.assertAllEqual(y, chain.forward_event_shape(x)) self.assertAllEqual( y.as_list(), - bijector.forward_event_shape_tensor(x.as_list()).eval()) - self.assertAllEqual(x, bijector.inverse_event_shape(y)) + chain.forward_event_shape_tensor(x.as_list()).eval()) + self.assertAllEqual(x, chain.inverse_event_shape(y)) self.assertAllEqual( x.as_list(), - bijector.inverse_event_shape_tensor(y.as_list()).eval()) + chain.inverse_event_shape_tensor(y.as_list()).eval()) + + def testMinEventNdimsChain(self): + chain = Chain([Exp(), Exp(), Exp()]) + self.assertEqual(0, chain.forward_min_event_ndims) + self.assertEqual(0, chain.inverse_min_event_ndims) + + chain = Chain([Affine(), Affine(), Affine()]) + self.assertEqual(1, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + chain = Chain([Exp(), Affine()]) + self.assertEqual(1, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + chain = Chain([Affine(), Exp()]) + self.assertEqual(1, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + chain = Chain([Affine(), Exp(), Softplus(), Affine()]) + self.assertEqual(1, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + def testMinEventNdimsShapeChangingAddDims(self): + chain = Chain([ShapeChanging()]) + self.assertEqual(0, chain.forward_min_event_ndims) + self.assertEqual(3, chain.inverse_min_event_ndims) + + chain = Chain([ShapeChanging(), Affine()]) + self.assertEqual(1, chain.forward_min_event_ndims) + self.assertEqual(4, chain.inverse_min_event_ndims) + + chain = Chain([Affine(), ShapeChanging()]) + self.assertEqual(0, chain.forward_min_event_ndims) + self.assertEqual(3, chain.inverse_min_event_ndims) + + chain = Chain([ShapeChanging(), ShapeChanging()]) + self.assertEqual(0, chain.forward_min_event_ndims) + self.assertEqual(6, chain.inverse_min_event_ndims) + + def testMinEventNdimsShapeChangingRemoveDims(self): + chain = Chain([ShapeChanging(3, 0)]) + self.assertEqual(3, chain.forward_min_event_ndims) + self.assertEqual(0, chain.inverse_min_event_ndims) + + chain = Chain([ShapeChanging(3, 0), Affine()]) + self.assertEqual(3, chain.forward_min_event_ndims) + self.assertEqual(0, chain.inverse_min_event_ndims) + + chain = Chain([Affine(), ShapeChanging(3, 0)]) + self.assertEqual(4, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + chain = Chain([ShapeChanging(3, 0), ShapeChanging(3, 0)]) + self.assertEqual(6, chain.forward_min_event_ndims) + self.assertEqual(0, chain.inverse_min_event_ndims) + + def testMinEventNdimsShapeChangingAddRemoveDims(self): + chain = Chain([ + ShapeChanging(2, 1), + ShapeChanging(3, 0), + ShapeChanging(1, 2)]) + self.assertEqual(4, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + def testChainExpAffine(self): + scale_diag = np.array([1., 2., 3.], dtype=np.float32) + chain = Chain([Exp(), Affine(scale_diag=scale_diag)]) + x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)] + y = [1., 4., 27.] + self.assertAllClose(y, self.evaluate(chain.forward(x))) + self.assertAllClose(x, self.evaluate(chain.inverse(y))) + self.assertAllClose( + np.log(6, dtype=np.float32) + np.sum(scale_diag * x), + self.evaluate(chain.forward_log_det_jacobian(x, event_ndims=1))) + + self.assertAllClose( + -np.log(6, dtype=np.float32) - np.sum(scale_diag * x), + self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1))) + + def testChainAffineExp(self): + scale_diag = np.array([1., 2., 3.], dtype=np.float32) + chain = Chain([Affine(scale_diag=scale_diag), Exp()]) + x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)] + y = [1., 4., 9.] + self.assertAllClose(y, self.evaluate(chain.forward(x))) + self.assertAllClose(x, self.evaluate(chain.inverse(y))) + self.assertAllClose( + np.log(6, dtype=np.float32) + np.sum(x), + self.evaluate(chain.forward_log_det_jacobian(x, event_ndims=1))) + + self.assertAllClose( + -np.log(6, dtype=np.float32) - np.sum(x), + self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1))) + + def testChainIldjWithPlaceholder(self): + chain = Chain((Exp(), Exp())) + samples = array_ops.placeholder( + dtype=np.float32, shape=[None, 10], name="samples") + ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0) + self.assertTrue(ildj is not None) + with self.test_session(): + ildj.eval({samples: np.zeros([2, 10], np.float32)}) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py index 0ff35304283fce9ce3f9e5d31b1258394e384d7b..e281e81bdf0698c1f7b2f60fb27783dd1351773f 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py @@ -18,70 +18,114 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.distributions.python.ops import bijectors -from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops -from tensorflow.python.ops.distributions import gamma as gamma_lib -from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib -from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test -class InvertBijectorTest(test.TestCase): - """Tests the correctness of the Y = Invert(bij) transformation.""" +class CholeskyOuterProductBijectorTest(test.TestCase): + """Tests the correctness of the Y = X @ X.T transformation.""" - def testBijector(self): + def testBijectorMatrix(self): with self.test_session(): - for fwd in [ - bijectors.Identity(), - bijectors.Exp(event_ndims=1), - bijectors.Affine( - shift=[0., 1.], scale_diag=[2., 3.], event_ndims=1), - bijectors.Softplus(event_ndims=1), - bijectors.SoftmaxCentered(event_ndims=1), - bijectors.SigmoidCentered(), - ]: - rev = bijectors.Invert(fwd) - self.assertEqual("_".join(["invert", fwd.name]), rev.name) - x = [[[1., 2.], - [2., 3.]]] - self.assertAllClose(fwd.inverse(x).eval(), rev.forward(x).eval()) - self.assertAllClose(fwd.forward(x).eval(), rev.inverse(x).eval()) - self.assertAllClose( - fwd.forward_log_det_jacobian(x).eval(), - rev.inverse_log_det_jacobian(x).eval()) - self.assertAllClose( - fwd.inverse_log_det_jacobian(x).eval(), - rev.forward_log_det_jacobian(x).eval()) + bijector = bijectors.CholeskyOuterProduct(validate_args=True) + self.assertEqual("cholesky_outer_product", bijector.name) + x = [[[1., 0], [2, 1]], [[np.sqrt(2.), 0], [np.sqrt(8.), 1]]] + y = np.matmul(x, np.transpose(x, axes=(0, 2, 1))) + # Fairly easy to compute differentials since we have 2x2. + dx_dy = [[[2. * 1, 0, 0], + [2, 1, 0], + [0, 2 * 2, 2 * 1]], + [[2 * np.sqrt(2.), 0, 0], + [np.sqrt(8.), np.sqrt(2.), 0], + [0, 2 * np.sqrt(8.), 2 * 1]]] + ildj = -np.sum( + np.log(np.asarray(dx_dy).diagonal( + offset=0, axis1=1, axis2=2)), + axis=1) + self.assertAllEqual((2, 2, 2), bijector.forward(x).get_shape()) + self.assertAllEqual((2, 2, 2), bijector.inverse(y).get_shape()) + self.assertAllClose(y, bijector.forward(x).eval()) + self.assertAllClose(x, bijector.inverse(y).eval()) + self.assertAllClose( + ildj, bijector.inverse_log_det_jacobian( + y, event_ndims=2).eval(), atol=0., rtol=1e-7) + self.assertAllClose( + -bijector.inverse_log_det_jacobian( + y, event_ndims=2).eval(), + bijector.forward_log_det_jacobian( + x, event_ndims=2).eval(), + atol=0., + rtol=1e-7) - def testScalarCongruency(self): - with self.test_session(): - bijector = bijectors.Invert(bijectors.Exp()) - assert_scalar_congruency( - bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05) + def testNoBatchStatic(self): + x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y) + y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T) + with self.test_session() as sess: + y_actual = bijectors.CholeskyOuterProduct().forward(x=x) + x_actual = bijectors.CholeskyOuterProduct().inverse(y=y) + [y_actual_, x_actual_] = sess.run([y_actual, x_actual]) + self.assertAllEqual([2, 2], y_actual.get_shape()) + self.assertAllEqual([2, 2], x_actual.get_shape()) + self.assertAllClose(y, y_actual_) + self.assertAllClose(x, x_actual_) - def testShapeGetters(self): - with self.test_session(): - bijector = bijectors.Invert(bijectors.SigmoidCentered(validate_args=True)) - x = tensor_shape.TensorShape([2]) - y = tensor_shape.TensorShape([]) - self.assertAllEqual(y, bijector.forward_event_shape(x)) - self.assertAllEqual( - y.as_list(), - bijector.forward_event_shape_tensor(x.as_list()).eval()) - self.assertAllEqual(x, bijector.inverse_event_shape(y)) - self.assertAllEqual( - x.as_list(), - bijector.inverse_event_shape_tensor(y.as_list()).eval()) + def testNoBatchDeferred(self): + x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y) + y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T) + with self.test_session() as sess: + x_pl = array_ops.placeholder(dtypes.float32) + y_pl = array_ops.placeholder(dtypes.float32) + y_actual = bijectors.CholeskyOuterProduct().forward(x=x_pl) + x_actual = bijectors.CholeskyOuterProduct().inverse(y=y_pl) + [y_actual_, x_actual_] = sess.run([y_actual, x_actual], + feed_dict={x_pl: x, y_pl: y}) + self.assertEqual(None, y_actual.get_shape()) + self.assertEqual(None, x_actual.get_shape()) + self.assertAllClose(y, y_actual_) + self.assertAllClose(x, x_actual_) - def testDocstringExample(self): - with self.test_session(): - exp_gamma_distribution = ( - transformed_distribution_lib.TransformedDistribution( - distribution=gamma_lib.Gamma(concentration=1., rate=2.), - bijector=bijectors.Invert(bijectors.Exp()))) - self.assertAllEqual( - [], array_ops.shape(exp_gamma_distribution.sample()).eval()) + def testBatchStatic(self): + x = np.array([[[1., 0], + [2, 1]], + [[3., 0], + [1, 2]]]) # np.linalg.cholesky(y) + y = np.array([[[1., 2], + [2, 5]], + [[9., 3], + [3, 5]]]) # np.matmul(x, x.T) + with self.test_session() as sess: + y_actual = bijectors.CholeskyOuterProduct().forward(x=x) + x_actual = bijectors.CholeskyOuterProduct().inverse(y=y) + [y_actual_, x_actual_] = sess.run([y_actual, x_actual]) + self.assertEqual([2, 2, 2], y_actual.get_shape()) + self.assertEqual([2, 2, 2], x_actual.get_shape()) + self.assertAllClose(y, y_actual_) + self.assertAllClose(x, x_actual_) + + def testBatchDeferred(self): + x = np.array([[[1., 0], + [2, 1]], + [[3., 0], + [1, 2]]]) # np.linalg.cholesky(y) + y = np.array([[[1., 2], + [2, 5]], + [[9., 3], + [3, 5]]]) # np.matmul(x, x.T) + with self.test_session() as sess: + x_pl = array_ops.placeholder(dtypes.float32) + y_pl = array_ops.placeholder(dtypes.float32) + y_actual = bijectors.CholeskyOuterProduct().forward(x=x_pl) + x_actual = bijectors.CholeskyOuterProduct().inverse(y=y_pl) + [y_actual_, x_actual_] = sess.run([y_actual, x_actual], + feed_dict={x_pl: x, y_pl: y}) + self.assertEqual(None, y_actual.get_shape()) + self.assertEqual(None, x_actual.get_shape()) + self.assertAllClose(y, y_actual_) + self.assertAllClose(x, x_actual_) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py index 26e0d2a539c78540603281ae0f361987a7bf8d90..f8a52615b0f3f5ad0c7e01e0f76c7d7a6b455ef7 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py @@ -27,7 +27,7 @@ class _TestBijector(ConditionalBijector): def __init__(self): super(_TestBijector, self).__init__( - event_ndims=0, + forward_min_event_ndims=0, graph_parents=[], is_constant_jacobian=True, validate_args=False, @@ -51,11 +51,15 @@ class ConditionalBijectorTest(test.TestCase): def testConditionalBijector(self): b = _TestBijector() - for name in ["forward", "inverse", "inverse_log_det_jacobian", - "forward_log_det_jacobian"]: + for name in ["forward", "inverse"]: method = getattr(b, name) with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"): - method(1.0, arg1="b1", arg2="b2") + method(1., arg1="b1", arg2="b2") + + for name in ["inverse_log_det_jacobian", "forward_log_det_jacobian"]: + method = getattr(b, name) + with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"): + method(1., event_ndims=0, arg1="b1", arg2="b2") if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py index 9970c0b4d86afda188d9401ebaf3c98d3fffbfdf..7be939cd274e6f0e33c9b01c82494755db2caa73 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py @@ -31,17 +31,21 @@ class ExpBijectorTest(test.TestCase): def testBijector(self): with self.test_session(): - bijector = Exp(event_ndims=1) + bijector = Exp() self.assertEqual("exp", bijector.name) x = [[[1.], [2.]]] y = np.exp(x) self.assertAllClose(y, bijector.forward(x).eval()) self.assertAllClose(x, bijector.inverse(y).eval()) self.assertAllClose( - -np.sum(np.log(y), axis=-1), - bijector.inverse_log_det_jacobian(y).eval()) - self.assertAllClose(-bijector.inverse_log_det_jacobian(np.exp(x)).eval(), - bijector.forward_log_det_jacobian(x).eval()) + -np.squeeze(np.log(y), axis=-1), + bijector.inverse_log_det_jacobian( + y, event_ndims=1).eval()) + self.assertAllClose( + -bijector.inverse_log_det_jacobian( + np.exp(x), event_ndims=1).eval(), + bijector.forward_log_det_jacobian( + x, event_ndims=1).eval()) def testScalarCongruency(self): with self.test_session(): @@ -51,10 +55,10 @@ class ExpBijectorTest(test.TestCase): def testBijectiveAndFinite(self): with self.test_session(): - bijector = Exp(event_ndims=0) + bijector = Exp() x = np.linspace(-10, 10, num=10).astype(np.float32) y = np.logspace(-10, 10, num=10).astype(np.float32) - assert_bijective_and_finite(bijector, x, y) + assert_bijective_and_finite(bijector, x, y, event_ndims=0) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py new file mode 100644 index 0000000000000000000000000000000000000000..caeaf2a0c6e4fff28c0edd82cb09ca0bcee85fc3 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for FillTriangular bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class FillTriangularBijectorTest(test.TestCase): + """Tests the correctness of the FillTriangular bijector.""" + + @test_util.run_in_graph_and_eager_modes() + def testBijector(self): + x = np.float32(np.array([1., 2., 3.])) + y = np.float32(np.array([[3., 0.], + [2., 1.]])) + + b = bijectors.FillTriangular() + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1)) + self.assertAllClose(fldj, 0.) + + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllClose(ildj, 0.) + + @test_util.run_in_graph_and_eager_modes() + def testShape(self): + x_shape = tensor_shape.TensorShape([5, 4, 6]) + y_shape = tensor_shape.TensorShape([5, 4, 3, 3]) + + b = bijectors.FillTriangular(validate_args=True) + + x = array_ops.ones(shape=x_shape, dtype=dtypes.float32) + y_ = b.forward(x) + self.assertAllEqual(y_.shape.as_list(), y_shape.as_list()) + x_ = b.inverse(y_) + self.assertAllEqual(x_.shape.as_list(), x_shape.as_list()) + + y_shape_ = b.forward_event_shape(x_shape) + self.assertAllEqual(y_shape_.as_list(), y_shape.as_list()) + x_shape_ = b.inverse_event_shape(y_shape) + self.assertAllEqual(x_shape_.as_list(), x_shape.as_list()) + + y_shape_tensor = self.evaluate( + b.forward_event_shape_tensor(x_shape.as_list())) + self.assertAllEqual(y_shape_tensor, y_shape.as_list()) + x_shape_tensor = self.evaluate( + b.inverse_event_shape_tensor(y_shape.as_list())) + self.assertAllEqual(x_shape_tensor, x_shape.as_list()) + + @test_util.run_in_graph_and_eager_modes() + def testShapeError(self): + + b = bijectors.FillTriangular(validate_args=True) + + x_shape_bad = tensor_shape.TensorShape([5, 4, 7]) + with self.assertRaisesRegexp(ValueError, "is not a triangular number"): + b.forward_event_shape(x_shape_bad) + with self.assertRaisesOpError("is not a triangular number"): + self.evaluate(b.forward_event_shape_tensor(x_shape_bad.as_list())) + + y_shape_bad = tensor_shape.TensorShape([5, 4, 3, 2]) + with self.assertRaisesRegexp(ValueError, "Matrix must be square"): + b.inverse_event_shape(y_shape_bad) + with self.assertRaisesOpError("Matrix must be square"): + self.evaluate(b.inverse_event_shape_tensor(y_shape_bad.as_list())) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py index 9a905980c7581a86bbcda8c6c726da57c09fe4f8..54e54c3296a89a4fe29a3cce971760502b65e784 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py @@ -34,7 +34,7 @@ class GumbelBijectorTest(test.TestCase): with self.test_session(): loc = 0.3 scale = 5. - bijector = Gumbel(loc=loc, scale=scale, event_ndims=1, validate_args=True) + bijector = Gumbel(loc=loc, scale=scale, validate_args=True) self.assertEqual("gumbel", bijector.name) x = np.array([[[-3.], [0.], [0.5], [4.2], [12.]]], dtype=np.float32) # Gumbel distribution @@ -43,13 +43,11 @@ class GumbelBijectorTest(test.TestCase): self.assertAllClose(y, bijector.forward(x).eval()) self.assertAllClose(x, bijector.inverse(y).eval()) self.assertAllClose( - # We should lose a dimension from calculating the determinant of the - # jacobian. - np.squeeze(gumbel_dist.logpdf(x), axis=2), - bijector.forward_log_det_jacobian(x).eval()) + np.squeeze(gumbel_dist.logpdf(x), axis=-1), + bijector.forward_log_det_jacobian(x, event_ndims=1).eval()) self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian(y, event_ndims=1).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=1).eval(), rtol=1e-4, atol=0.) @@ -60,10 +58,10 @@ class GumbelBijectorTest(test.TestCase): def testBijectiveAndFinite(self): with self.test_session(): - bijector = Gumbel(loc=0., scale=3.0, event_ndims=0, validate_args=True) + bijector = Gumbel(loc=0., scale=3.0, validate_args=True) x = np.linspace(-10., 10., num=10).astype(np.float32) y = np.linspace(0.01, 0.99, num=10).astype(np.float32) - assert_bijective_and_finite(bijector, x, y, rtol=1e-3) + assert_bijective_and_finite(bijector, x, y, event_ndims=0, rtol=1e-3) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py index 739fa6d439a8bce993ab1b4601489d9bbcd69bee..7d3bd758cd2db307f95d2d934923ea2133dc1217 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py @@ -33,15 +33,13 @@ class InlineBijectorTest(test.TestCase): def testBijector(self): with self.test_session(): - exp = Exp(event_ndims=1) + exp = Exp() inline = Inline( forward_fn=math_ops.exp, inverse_fn=math_ops.log, - inverse_log_det_jacobian_fn=( - lambda y: -math_ops.reduce_sum( # pylint: disable=g-long-lambda - math_ops.log(y), reduction_indices=-1)), - forward_log_det_jacobian_fn=( - lambda x: math_ops.reduce_sum(x, reduction_indices=-1)), + inverse_log_det_jacobian_fn=lambda y: -math_ops.log(y), + forward_log_det_jacobian_fn=lambda x: x, + forward_min_event_ndims=0, name="exp") self.assertEqual(exp.name, inline.name) @@ -51,9 +49,10 @@ class InlineBijectorTest(test.TestCase): self.assertAllClose(x, inline.inverse(y).eval()) self.assertAllClose( -np.sum(np.log(y), axis=-1), - inline.inverse_log_det_jacobian(y).eval()) - self.assertAllClose(-inline.inverse_log_det_jacobian(y).eval(), - inline.forward_log_det_jacobian(x).eval()) + inline.inverse_log_det_jacobian(y, event_ndims=1).eval()) + self.assertAllClose( + -inline.inverse_log_det_jacobian(y, event_ndims=1).eval(), + inline.forward_log_det_jacobian(x, event_ndims=1).eval()) def testShapeGetters(self): with self.test_session(): @@ -62,6 +61,7 @@ class InlineBijectorTest(test.TestCase): forward_event_shape_fn=lambda x: x.as_list() + [1], inverse_event_shape_tensor_fn=lambda x: x[:-1], inverse_event_shape_fn=lambda x: x[:-1], + forward_min_event_ndims=0, name="shape_only") x = tensor_shape.TensorShape([1, 2, 3]) y = tensor_shape.TensorShape([1, 2, 3, 1]) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py index 0ff35304283fce9ce3f9e5d31b1258394e384d7b..8b14c8327f08902044f50483f9f8dfe67b58cd70 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py @@ -34,12 +34,10 @@ class InvertBijectorTest(test.TestCase): with self.test_session(): for fwd in [ bijectors.Identity(), - bijectors.Exp(event_ndims=1), - bijectors.Affine( - shift=[0., 1.], scale_diag=[2., 3.], event_ndims=1), - bijectors.Softplus(event_ndims=1), - bijectors.SoftmaxCentered(event_ndims=1), - bijectors.SigmoidCentered(), + bijectors.Exp(), + bijectors.Affine(shift=[0., 1.], scale_diag=[2., 3.]), + bijectors.Softplus(), + bijectors.SoftmaxCentered(), ]: rev = bijectors.Invert(fwd) self.assertEqual("_".join(["invert", fwd.name]), rev.name) @@ -48,11 +46,11 @@ class InvertBijectorTest(test.TestCase): self.assertAllClose(fwd.inverse(x).eval(), rev.forward(x).eval()) self.assertAllClose(fwd.forward(x).eval(), rev.inverse(x).eval()) self.assertAllClose( - fwd.forward_log_det_jacobian(x).eval(), - rev.inverse_log_det_jacobian(x).eval()) + fwd.forward_log_det_jacobian(x, event_ndims=1).eval(), + rev.inverse_log_det_jacobian(x, event_ndims=1).eval()) self.assertAllClose( - fwd.inverse_log_det_jacobian(x).eval(), - rev.forward_log_det_jacobian(x).eval()) + fwd.inverse_log_det_jacobian(x, event_ndims=1).eval(), + rev.forward_log_det_jacobian(x, event_ndims=1).eval()) def testScalarCongruency(self): with self.test_session(): @@ -62,9 +60,9 @@ class InvertBijectorTest(test.TestCase): def testShapeGetters(self): with self.test_session(): - bijector = bijectors.Invert(bijectors.SigmoidCentered(validate_args=True)) + bijector = bijectors.Invert(bijectors.SoftmaxCentered(validate_args=True)) x = tensor_shape.TensorShape([2]) - y = tensor_shape.TensorShape([]) + y = tensor_shape.TensorShape([1]) self.assertAllEqual(y, bijector.forward_event_shape(x)) self.assertAllEqual( y.as_list(), diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a8089881f684db9f8876d6dd738e52bf2f1f7606 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py @@ -0,0 +1,77 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Kumaraswamy Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import Kumaraswamy +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency +from tensorflow.python.platform import test + + +class KumaraswamyBijectorTest(test.TestCase): + """Tests correctness of the Kumaraswamy bijector.""" + + def testBijector(self): + with self.test_session(): + a = 2. + b = 0.3 + bijector = Kumaraswamy( + concentration1=a, concentration0=b, validate_args=True) + self.assertEqual("kumaraswamy", bijector.name) + x = np.array([[[0.1], [0.2], [0.3], [0.4], [0.5]]], dtype=np.float32) + # Kumaraswamy cdf. This is the same as inverse(x). + y = 1. - (1. - x ** a) ** b + self.assertAllClose(y, bijector.inverse(x).eval()) + self.assertAllClose(x, bijector.forward(y).eval()) + kumaraswamy_log_pdf = (np.log(a) + np.log(b) + (a - 1) * np.log(x) + + (b - 1) * np.log1p(-x ** a)) + + self.assertAllClose( + np.squeeze(kumaraswamy_log_pdf, axis=-1), + bijector.inverse_log_det_jacobian(x, event_ndims=1).eval()) + self.assertAllClose( + -bijector.inverse_log_det_jacobian(x, event_ndims=1).eval(), + bijector.forward_log_det_jacobian(y, event_ndims=1).eval(), + rtol=1e-4, + atol=0.) + + def testScalarCongruency(self): + with self.test_session(): + assert_scalar_congruency( + Kumaraswamy(concentration1=0.5, concentration0=1.1), + lower_x=0., upper_x=1., n=int(10e3), rtol=0.02) + + def testBijectiveAndFinite(self): + with self.test_session(): + concentration1 = 1.2 + concentration0 = 2. + bijector = Kumaraswamy( + concentration1=concentration1, + concentration0=concentration0, validate_args=True) + # Omitting the endpoints 0 and 1, since idlj will be infinity at these + # endpoints. + y = np.linspace(.01, 0.99, num=10).astype(np.float32) + x = 1 - (1 - y ** concentration1) ** concentration0 + assert_bijective_and_finite(bijector, x, y, event_ndims=0, rtol=1e-3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py index dcfb0eb05185d36d96947905c2eb91b2201aece1..5ba5a2083bf11791d7d58146dc2e6283b524d241 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py @@ -79,9 +79,10 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers, forward_x = ma.forward(x) # Use identity to invalidate cache. inverse_y = ma.inverse(array_ops.identity(forward_x)) - fldj = ma.forward_log_det_jacobian(x) + fldj = ma.forward_log_det_jacobian(x, event_ndims=1) # Use identity to invalidate cache. - ildj = ma.inverse_log_det_jacobian(array_ops.identity(forward_x)) + ildj = ma.inverse_log_det_jacobian( + array_ops.identity(forward_x), event_ndims=1) variables.global_variables_initializer().run() [ forward_x_, diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py new file mode 100644 index 0000000000000000000000000000000000000000..18397035571561731698b06d90e20dc74e3cf83c --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MatrixInverseTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class MatrixInverseTriLBijectorTest(test.TestCase): + """Tests the correctness of the Y = inv(tril) transformation.""" + + @test_util.run_in_graph_and_eager_modes() + def testComputesCorrectValues(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + self.assertEqual("matrix_inverse_tril", inv.name) + x_ = np.array([[0.7, 0., 0.], + [0.1, -1., 0.], + [0.3, 0.25, 0.5]], dtype=np.float32) + x_inv_ = np.linalg.inv(x_) + expected_fldj_ = -6. * np.sum(np.log(np.abs(np.diag(x_)))) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testOneByOneMatrix(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[5.]], dtype=np.float32) + x_inv_ = np.array([[0.2]], dtype=np.float32) + expected_fldj_ = np.log(0.04) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testZeroByZeroMatrix(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.eye(0, dtype=np.float32) + x_inv_ = np.eye(0, dtype=np.float32) + expected_fldj_ = 0. + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testBatch(self): + # Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape + # (2, 1). + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[[[1., 0.], + [2., 3.]]], + [[[4., 0.], + [5., -6.]]]], dtype=np.float32) + x_inv_ = np.linalg.inv(x_) + expected_fldj_ = -4. * np.sum( + np.log(np.abs(np.diagonal(x_, axis1=-2, axis2=-1))), axis=-1) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3) + self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testErrorOnInputRankTooLow(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([0.1], dtype=np.float32) + rank_error_msg = "must have rank at least 2" + with self.test_session(): + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.forward(x_).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + # TODO(b/80481923): Figure out why these assertions fail, and fix them. + ## def testErrorOnInputNonSquare(self): + ## inv = bijectors.MatrixInverseTriL(validate_args=True) + ## x_ = np.array([[1., 2., 3.], + ## [4., 5., 6.]], dtype=np.float32) + ## square_error_msg = "must be a square matrix" + ## with self.test_session(): + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.forward(x_).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.inverse(x_).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + @test_util.run_in_graph_and_eager_modes() + def testErrorOnInputNotLowerTriangular(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[1., 2.], + [3., 4.]], dtype=np.float32) + triangular_error_msg = "must be lower triangular" + with self.test_session(): + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.forward(x_).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + @test_util.run_in_graph_and_eager_modes() + def testErrorOnInputSingular(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[1., 0.], + [0., 0.]], dtype=np.float32) + nonsingular_error_msg = "must have all diagonal entries nonzero" + with self.test_session(): + with self.assertRaisesOpError(nonsingular_error_msg): + inv.forward(x_).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a5f5219588fb3be67beb797ba68ed8148e9e9fd2 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py @@ -0,0 +1,109 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.ordered import Ordered +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.platform import test + + + +class OrderedBijectorTest(test.TestCase): + """Tests correctness of the ordered transformation.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + @test_util.run_in_graph_and_eager_modes() + def testBijectorVector(self): + with self.test_session(): + ordered = Ordered() + self.assertEqual("ordered", ordered.name) + x = np.asarray([[2., 3, 4], [4., 8, 13]]) + y = [[2., 0, 0], [4., np.log(4.), np.log(5.)]] + self.assertAllClose(y, self.evaluate(ordered.forward(x))) + self.assertAllClose(x, self.evaluate(ordered.inverse(y))) + self.assertAllClose( + np.sum(np.asarray(y)[..., 1:], axis=-1), + self.evaluate(ordered.inverse_log_det_jacobian(y, event_ndims=1)), + atol=0., + rtol=1e-7) + self.assertAllClose( + self.evaluate(-ordered.inverse_log_det_jacobian(y, event_ndims=1)), + self.evaluate(ordered.forward_log_det_jacobian(x, event_ndims=1)), + atol=0., + rtol=1e-7) + + def testBijectorUnknownShape(self): + with self.test_session(): + ordered = Ordered() + self.assertEqual("ordered", ordered.name) + x = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32) + real_x = np.asarray([[2., 3, 4], [4., 8, 13]]) + y = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32) + real_y = [[2., 0, 0], [4., np.log(4.), np.log(5.)]] + self.assertAllClose(real_y, ordered.forward(x).eval( + feed_dict={x: real_x})) + self.assertAllClose(real_x, ordered.inverse(y).eval( + feed_dict={y: real_y})) + self.assertAllClose( + np.sum(np.asarray(real_y)[..., 1:], axis=-1), + ordered.inverse_log_det_jacobian(y, event_ndims=1).eval( + feed_dict={y: real_y}), + atol=0., + rtol=1e-7) + self.assertAllClose( + -ordered.inverse_log_det_jacobian(y, event_ndims=1).eval( + feed_dict={y: real_y}), + ordered.forward_log_det_jacobian(x, event_ndims=1).eval( + feed_dict={x: real_x}), + atol=0., + rtol=1e-7) + + @test_util.run_in_graph_and_eager_modes() + def testShapeGetters(self): + with self.test_session(): + x = tensor_shape.TensorShape([4]) + y = tensor_shape.TensorShape([4]) + bijector = Ordered(validate_args=True) + self.assertAllEqual(y, bijector.forward_event_shape(x)) + self.assertAllEqual(y.as_list(), + self.evaluate(bijector.forward_event_shape_tensor( + x.as_list()))) + self.assertAllEqual(x, bijector.inverse_event_shape(y)) + self.assertAllEqual(x.as_list(), + self.evaluate(bijector.inverse_event_shape_tensor( + y.as_list()))) + + def testBijectiveAndFinite(self): + with self.test_session(): + ordered = Ordered() + x = np.sort(self._rng.randn(3, 10), axis=-1).astype(np.float32) + y = (self._rng.randn(3, 10)).astype(np.float32) + assert_bijective_and_finite(ordered, x, y, event_ndims=1) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py index 54590de373441c32cc3214cb04d45cfc2d1807ed..7eef4ab599951bbb624652f13a0091363b36b93d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py @@ -53,8 +53,8 @@ class PermuteBijectorTest(test.TestCase): bijector.permutation, bijector.inverse(expected_y), bijector.forward(expected_x), - bijector.forward_log_det_jacobian(expected_x), - bijector.inverse_log_det_jacobian(expected_y), + bijector.forward_log_det_jacobian(expected_x, event_ndims=1), + bijector.inverse_log_det_jacobian(expected_y, event_ndims=1), ], feed_dict={permutation_ph: expected_permutation}) self.assertEqual("permute", bijector.name) self.assertAllEqual(expected_permutation, permutation_) @@ -78,10 +78,9 @@ class PermuteBijectorTest(test.TestCase): x = np.random.randn(4, 2, 3) y = x[..., permutation] with self.test_session(): - bijector = Permute( - permutation=permutation, - validate_args=True) - assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + bijector = Permute(permutation=permutation, validate_args=True) + assert_bijective_and_finite( + bijector, x, y, event_ndims=1, rtol=1e-6, atol=0) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py index de1659aa9f4d0f7d19ec2e8185715573b78eaf2b..85d22830132816cd6c77cd0b07870f3a22ae9798 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py @@ -32,8 +32,7 @@ class PowerTransformBijectorTest(test.TestCase): def testBijector(self): with self.test_session(): c = 0.2 - bijector = PowerTransform( - power=c, event_ndims=1, validate_args=True) + bijector = PowerTransform(power=c, validate_args=True) self.assertEqual("power_transform", bijector.name) x = np.array([[[-1.], [2.], [-5. + 1e-4]]]) y = (1. + x * c)**(1. / c) @@ -41,27 +40,25 @@ class PowerTransformBijectorTest(test.TestCase): self.assertAllClose(x, bijector.inverse(y).eval()) self.assertAllClose( (c - 1.) * np.sum(np.log(y), axis=-1), - bijector.inverse_log_det_jacobian(y).eval()) + bijector.inverse_log_det_jacobian(y, event_ndims=1).eval()) self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian(y, event_ndims=1).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=1).eval(), rtol=1e-4, atol=0.) def testScalarCongruency(self): with self.test_session(): - bijector = PowerTransform( - power=0.2, validate_args=True) + bijector = PowerTransform(power=0.2, validate_args=True) assert_scalar_congruency( bijector, lower_x=-2., upper_x=1.5, rtol=0.05) def testBijectiveAndFinite(self): with self.test_session(): - bijector = PowerTransform( - power=0.2, event_ndims=0, validate_args=True) + bijector = PowerTransform(power=0.2, validate_args=True) x = np.linspace(-4.999, 10, num=10).astype(np.float32) y = np.logspace(0.001, 10, num=10).astype(np.float32) - assert_bijective_and_finite(bijector, x, y, rtol=1e-3) + assert_bijective_and_finite(bijector, x, y, event_ndims=0, rtol=1e-3) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py index 46fe7797419a9906ecdad60dd0dfe1e9d7c743ed..2d52895fbe0967cdd2260d6d298a291286858d09 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py @@ -52,24 +52,28 @@ class RealNVPTest(test_util.VectorDistributionTestHelpers, test.TestCase): forward_x = nvp.forward(x) # Use identity to invalidate cache. inverse_y = nvp.inverse(array_ops.identity(forward_x)) - fldj = nvp.forward_log_det_jacobian(x) + forward_inverse_y = nvp.forward(inverse_y) + fldj = nvp.forward_log_det_jacobian(x, event_ndims=1) # Use identity to invalidate cache. - ildj = nvp.inverse_log_det_jacobian(array_ops.identity(forward_x)) + ildj = nvp.inverse_log_det_jacobian( + array_ops.identity(forward_x), event_ndims=1) variables.global_variables_initializer().run() [ forward_x_, inverse_y_, + forward_inverse_y_, ildj_, fldj_, ] = sess.run([ forward_x, inverse_y, + forward_inverse_y, ildj, fldj, ]) self.assertEqual("real_nvp", nvp.name) - self.assertAllClose(forward_x_, forward_x_, rtol=1e-6, atol=0.) - self.assertAllClose(x_, inverse_y_, rtol=1e-5, atol=0.) + self.assertAllClose(forward_x_, forward_inverse_y_, rtol=1e-1, atol=0.) + self.assertAllClose(x_, inverse_y_, rtol=1e-1, atol=0.) self.assertAllClose(ildj_, -fldj_, rtol=1e-6, atol=0.) def testMutuallyConsistent(self): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py index e216d88cb190dc16fc0056186f80817d6f2d7c67..d44e49b4874a5b91f7633cd9c97dbb1a7da70f27 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -22,15 +22,12 @@ import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite from tensorflow.python.platform import test -@test_util.with_c_api class _ReshapeBijectorTest(object): """Base class for testing the reshape transformation. @@ -65,8 +62,8 @@ class _ReshapeBijectorTest(object): ildj_) = sess.run(( bijector.inverse(expected_y), bijector.forward(expected_x), - bijector.forward_log_det_jacobian(expected_x), - bijector.inverse_log_det_jacobian(expected_y), + bijector.forward_log_det_jacobian(expected_x, event_ndims=2), + bijector.inverse_log_det_jacobian(expected_y, event_ndims=2), ), feed_dict=feed_dict) self.assertEqual("reshape", bijector.name) self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) @@ -265,7 +262,6 @@ class _ReshapeBijectorTest(object): raise NotImplementedError("Subclass failed to implement `build_shapes`.") -@test_util.with_c_api class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): @@ -301,24 +297,17 @@ class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): event_shape_in=[2, 3], event_shape_out=[1, 2, 3], validate_args=True) - assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + assert_bijective_and_finite( + bijector, x, y, event_ndims=2, rtol=1e-6, atol=0) def testInvalidDimensionsOpError(self): - if ops._USE_C_API: - error_message = "Invalid value in tensor used for shape: -2" - else: - error_message = "elements must be either positive integers or `-1`." - self._testInvalidDimensionsOpError(error_message) + self._testInvalidDimensionsOpError( + "Invalid value in tensor used for shape: -2") def testInputOutputMismatchOpError(self): - if ops._USE_C_API: - error_message = "Cannot reshape a tensor with" - else: - error_message = "Input to reshape is a tensor with" - self._testInputOutputMismatchOpError(error_message) + self._testInputOutputMismatchOpError("Cannot reshape a tensor with") -@test_util.with_c_api class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): @@ -340,7 +329,6 @@ class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): self._testInputOutputMismatchOpError("Input to reshape is a tensor with") -@test_util.with_c_api class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py new file mode 100644 index 0000000000000000000000000000000000000000..566a7b3dff9b5d97a1cb143e0b32fc15984c3a02 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py @@ -0,0 +1,69 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ScaleTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class ScaleTriLBijectorTest(test.TestCase): + """Tests the correctness of the ScaleTriL bijector.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testComputesCorrectValues(self): + shift = 1.61803398875 + x = np.float32(np.array([-1, .5, 2])) + y = np.float32(np.array([[np.exp(2) + shift, 0.], + [.5, np.exp(-1) + shift]])) + + b = bijectors.ScaleTriL(diag_bijector=bijectors.Exp(), + diag_shift=shift) + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + @test_util.run_in_graph_and_eager_modes() + def testInvertible(self): + + # Generate random inputs from an unconstrained space, with + # event size 6 to specify 3x3 triangular matrices. + batch_shape = [2, 1] + x = np.float32(np.random.randn(*(batch_shape + [6]))) + b = bijectors.ScaleTriL(diag_bijector=bijectors.Softplus(), + diag_shift=3.14159) + y = self.evaluate(b.forward(x)) + self.assertAllEqual(y.shape, batch_shape + [3, 3]) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1)) + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllClose(fldj, -ildj) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_centered_test.py deleted file mode 100644 index 4ff3f334ccb59f1c117b3d35032d9e799cfd79bb..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_centered_test.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import SigmoidCentered -from tensorflow.python.platform import test - - -class SigmoidCenteredBijectorTest(test.TestCase): - """Tests correctness of the Y = g(X) = (1 + exp(-X))^-1 transformation.""" - - def testBijector(self): - with self.test_session(): - sigmoid = SigmoidCentered() - self.assertEqual("sigmoid_centered", sigmoid.name) - x = np.log([[2., 3, 4], - [4., 8, 12]]) - y = [[[2. / 3, 1. / 3], - [3. / 4, 1. / 4], - [4. / 5, 1. / 5]], - [[4. / 5, 1. / 5], - [8. / 9, 1. / 9], - [12. / 13, 1. / 13]]] - self.assertAllClose(y, sigmoid.forward(x).eval()) - self.assertAllClose(x, sigmoid.inverse(y).eval()) - self.assertAllClose( - -np.sum(np.log(y), axis=2), - sigmoid.inverse_log_det_jacobian(y).eval(), - atol=0., - rtol=1e-7) - self.assertAllClose( - -sigmoid.inverse_log_det_jacobian(y).eval(), - sigmoid.forward_log_det_jacobian(x).eval(), - atol=0., - rtol=1e-7) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py index e4f9d72785c301284812a48c0a67614ca439ffae..cea4a62c22af5d98d38ee881b29c773e6a27a4b4 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py @@ -36,12 +36,13 @@ class SigmoidBijectorTest(test.TestCase): x = np.linspace(-10., 10., 100).reshape([2, 5, 10]).astype(np.float32) y = special.expit(x) ildj = -np.log(y) - np.log1p(-y) - self.assertAllClose(y, Sigmoid().forward(x).eval(), atol=0., rtol=1e-2) - self.assertAllClose(x, Sigmoid().inverse(y).eval(), atol=0., rtol=1e-4) - self.assertAllClose(ildj, Sigmoid().inverse_log_det_jacobian(y).eval(), - atol=0., rtol=1e-6) - self.assertAllClose(-ildj, Sigmoid().forward_log_det_jacobian(x).eval(), - atol=0., rtol=1e-4) + bijector = Sigmoid() + self.assertAllClose(y, bijector.forward(x).eval(), atol=0., rtol=1e-2) + self.assertAllClose(x, bijector.inverse(y).eval(), atol=0., rtol=1e-4) + self.assertAllClose(ildj, bijector.inverse_log_det_jacobian( + y, event_ndims=0).eval(), atol=0., rtol=1e-6) + self.assertAllClose(-ildj, bijector.forward_log_det_jacobian( + x, event_ndims=0).eval(), atol=0., rtol=1e-4) def testScalarCongruency(self): with self.test_session(): @@ -52,7 +53,8 @@ class SigmoidBijectorTest(test.TestCase): x = np.linspace(-7., 7., 100).astype(np.float32) eps = 1e-3 y = np.linspace(eps, 1. - eps, 100).astype(np.float32) - assert_bijective_and_finite(Sigmoid(), x, y, atol=0., rtol=1e-4) + assert_bijective_and_finite( + Sigmoid(), x, y, event_ndims=0, atol=0., rtol=1e-4) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py index 172c180a44229089f06f250a872bc47a89991cf0..795f1993ba5c31bf5a26333f31f1bc73125bff07 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py @@ -39,7 +39,6 @@ class SinhArcsinhBijectorTest(test.TestCase): bijector = SinhArcsinh( skewness=skewness, tailweight=tailweight, - event_ndims=1, validate_args=True) self.assertEqual("SinhArcsinh", bijector.name) x = np.array([[[-2.01], [2.], [1e-4]]]).astype(np.float32) @@ -50,10 +49,11 @@ class SinhArcsinhBijectorTest(test.TestCase): np.sum( np.log(np.cosh(np.arcsinh(y) / tailweight - skewness)) - np.log(tailweight) - np.log(np.sqrt(y**2 + 1)), - axis=-1), bijector.inverse_log_det_jacobian(y).eval()) + axis=-1), + bijector.inverse_log_det_jacobian(y, event_ndims=1).eval()) self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian(y, event_ndims=1).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=1).eval(), rtol=1e-4, atol=0.) @@ -106,14 +106,15 @@ class SinhArcsinhBijectorTest(test.TestCase): bijector = SinhArcsinh(skewness=-1., tailweight=0.5, validate_args=True) x = np.concatenate((-np.logspace(-2, 10, 1000), [0], np.logspace( -2, 10, 1000))).astype(np.float32) - assert_bijective_and_finite(bijector, x, x, rtol=1e-3) + assert_bijective_and_finite(bijector, x, x, event_ndims=0, rtol=1e-3) def testBijectiveAndFiniteSkewness1Tailweight3(self): with self.test_session(): bijector = SinhArcsinh(skewness=1., tailweight=3., validate_args=True) x = np.concatenate((-np.logspace(-2, 5, 1000), [0], np.logspace( -2, 5, 1000))).astype(np.float32) - assert_bijective_and_finite(bijector, x, x, rtol=1e-3) + assert_bijective_and_finite( + bijector, x, x, event_ndims=0, rtol=1e-3) def testBijectorEndpoints(self): with self.test_session(): @@ -124,7 +125,8 @@ class SinhArcsinhBijectorTest(test.TestCase): [np.finfo(dtype).min, np.finfo(dtype).max], dtype=dtype) # Note that the above bijector is the identity bijector. Hence, the # log_det_jacobian will be 0. Because of this we use atol. - assert_bijective_and_finite(bijector, bounds, bounds, atol=2e-6) + assert_bijective_and_finite( + bijector, bounds, bounds, event_ndims=0, atol=2e-6) def testBijectorOverRange(self): with self.test_session(): @@ -149,19 +151,27 @@ class SinhArcsinhBijectorTest(test.TestCase): self.assertAllClose(y, bijector.forward(x).eval(), rtol=1e-4, atol=0.) self.assertAllClose(x, bijector.inverse(y).eval(), rtol=1e-4, atol=0.) - # Do the numpy calculation in float128 to avoid inf/nan. - y_float128 = np.float128(y) + # On IBM PPC systems, longdouble (np.float128) is same as double except that it can have more precision. + # Type double being of 8 bytes, can't hold square of max of float64 (which is also 8 bytes) and + # below test fails due to overflow error giving inf. So this check avoids that error by skipping square + # calculation and corresponding assert. + + if np.amax(y) <= np.sqrt(np.finfo(np.float128).max) and \ + np.fabs(np.amin(y)) <= np.sqrt(np.fabs(np.finfo(np.float128).min)): + + # Do the numpy calculation in float128 to avoid inf/nan. + y_float128 = np.float128(y) + self.assertAllClose( + np.log(np.cosh( + np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt( + y_float128**2 + 1)) - + np.log(tailweight), + bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), + rtol=1e-4, + atol=0.) self.assertAllClose( - np.log(np.cosh( - np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt( - y_float128**2 + 1)) - - np.log(tailweight), - bijector.inverse_log_det_jacobian(y).eval(), - rtol=1e-4, - atol=0.) - self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=0).eval(), rtol=1e-4, atol=0.) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py index 62e3869db090e9c9327bc552d10234ff76ba28fd..0f0a2fa531a0585a709df4c2c3e2631e5c275986 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py @@ -21,7 +21,9 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite from tensorflow.python.platform import test @@ -32,70 +34,68 @@ rng = np.random.RandomState(42) class SoftmaxCenteredBijectorTest(test.TestCase): """Tests correctness of the Y = g(X) = exp(X) / sum(exp(X)) transformation.""" - def testBijectorScalar(self): + def testBijectorVector(self): with self.test_session(): - softmax = SoftmaxCentered() # scalar by default + softmax = SoftmaxCentered() self.assertEqual("softmax_centered", softmax.name) - x = np.log([[2., 3, 4], - [4., 8, 12]]) - y = [[[2. / 3, 1. / 3], - [3. / 4, 1. / 4], - [4. / 5, 1. / 5]], - [[4. / 5, 1. / 5], - [8. / 9, 1. / 9], - [12. / 13, 1. / 13]]] + x = np.log([[2., 3, 4], [4., 8, 12]]) + y = [[0.2, 0.3, 0.4, 0.1], [0.16, 0.32, 0.48, 0.04]] self.assertAllClose(y, softmax.forward(x).eval()) self.assertAllClose(x, softmax.inverse(y).eval()) self.assertAllClose( - -np.sum(np.log(y), axis=2), - softmax.inverse_log_det_jacobian(y).eval(), + -np.sum(np.log(y), axis=1), + softmax.inverse_log_det_jacobian(y, event_ndims=1).eval(), atol=0., rtol=1e-7) self.assertAllClose( - -softmax.inverse_log_det_jacobian(y).eval(), - softmax.forward_log_det_jacobian(x).eval(), + -softmax.inverse_log_det_jacobian(y, event_ndims=1).eval(), + softmax.forward_log_det_jacobian(x, event_ndims=1).eval(), atol=0., rtol=1e-7) - def testBijectorVector(self): + def testBijectorUnknownShape(self): with self.test_session(): - softmax = SoftmaxCentered(event_ndims=1) + softmax = SoftmaxCentered() self.assertEqual("softmax_centered", softmax.name) - x = np.log([[2., 3, 4], [4., 8, 12]]) - y = [[0.2, 0.3, 0.4, 0.1], [0.16, 0.32, 0.48, 0.04]] - self.assertAllClose(y, softmax.forward(x).eval()) - self.assertAllClose(x, softmax.inverse(y).eval()) + x = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32) + real_x = np.log([[2., 3, 4], [4., 8, 12]]) + y = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32) + real_y = [[0.2, 0.3, 0.4, 0.1], [0.16, 0.32, 0.48, 0.04]] + self.assertAllClose(real_y, softmax.forward(x).eval( + feed_dict={x: real_x})) + self.assertAllClose(real_x, softmax.inverse(y).eval( + feed_dict={y: real_y})) self.assertAllClose( - -np.sum(np.log(y), axis=1), - softmax.inverse_log_det_jacobian(y).eval(), + -np.sum(np.log(real_y), axis=1), + softmax.inverse_log_det_jacobian(y, event_ndims=1).eval( + feed_dict={y: real_y}), atol=0., rtol=1e-7) self.assertAllClose( - -softmax.inverse_log_det_jacobian(y).eval(), - softmax.forward_log_det_jacobian(x).eval(), + -softmax.inverse_log_det_jacobian(y, event_ndims=1).eval( + feed_dict={y: real_y}), + softmax.forward_log_det_jacobian(x, event_ndims=1).eval( + feed_dict={x: real_x}), atol=0., rtol=1e-7) def testShapeGetters(self): with self.test_session(): - for x, y, b in ((tensor_shape.TensorShape([]), - tensor_shape.TensorShape([2]), - SoftmaxCentered( - event_ndims=0, validate_args=True)), - (tensor_shape.TensorShape([4]), - tensor_shape.TensorShape([5]), - SoftmaxCentered( - event_ndims=1, validate_args=True))): - self.assertAllEqual(y, b.forward_event_shape(x)) - self.assertAllEqual(y.as_list(), - b.forward_event_shape_tensor(x.as_list()).eval()) - self.assertAllEqual(x, b.inverse_event_shape(y)) - self.assertAllEqual(x.as_list(), - b.inverse_event_shape_tensor(y.as_list()).eval()) + x = tensor_shape.TensorShape([4]) + y = tensor_shape.TensorShape([5]) + bijector = SoftmaxCentered(validate_args=True) + self.assertAllEqual(y, bijector.forward_event_shape(x)) + self.assertAllEqual(y.as_list(), + bijector.forward_event_shape_tensor( + x.as_list()).eval()) + self.assertAllEqual(x, bijector.inverse_event_shape(y)) + self.assertAllEqual(x.as_list(), + bijector.inverse_event_shape_tensor( + y.as_list()).eval()) def testBijectiveAndFinite(self): with self.test_session(): - softmax = SoftmaxCentered(event_ndims=1) + softmax = SoftmaxCentered() x = np.linspace(-50, 50, num=10).reshape(5, 2).astype(np.float32) # Make y values on the simplex with a wide range. y_0 = np.ones(5).astype(np.float32) @@ -104,7 +104,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase): y = np.array([y_0, y_1, y_2]) y /= y.sum(axis=0) y = y.T # y.shape = [5, 3] - assert_bijective_and_finite(softmax, x, y) + assert_bijective_and_finite(softmax, x, y, event_ndims=1) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py index d9af9aec50d3d69bb10f69f2ffd6ca3a24c316f8..3d8a0a32bba3539f732140e8eb7ebeb532d73ff5 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py @@ -43,13 +43,13 @@ class SoftplusBijectorTest(test.TestCase): def testHingeSoftnessZeroRaises(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=0., validate_args=True) + bijector = Softplus(hinge_softness=0., validate_args=True) with self.assertRaisesOpError("must be non-zero"): bijector.forward([1., 1.]).eval() def testBijectorForwardInverseEventDimsZero(self): with self.test_session(): - bijector = Softplus(event_ndims=0) + bijector = Softplus() self.assertEqual("softplus", bijector.name) x = 2 * rng.randn(2, 10) y = self._softplus(x) @@ -59,7 +59,7 @@ class SoftplusBijectorTest(test.TestCase): def testBijectorForwardInverseWithHingeSoftnessEventDimsZero(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=1.5) + bijector = Softplus(hinge_softness=1.5) x = 2 * rng.randn(2, 10) y = 1.5 * self._softplus(x / 1.5) @@ -68,16 +68,17 @@ class SoftplusBijectorTest(test.TestCase): def testBijectorLogDetJacobianEventDimsZero(self): with self.test_session(): - bijector = Softplus(event_ndims=0) + bijector = Softplus() y = 2 * rng.rand(2, 10) # No reduction needed if event_dims = 0. ildj = self._softplus_ildj_before_reduction(y) - self.assertAllClose(ildj, bijector.inverse_log_det_jacobian(y).eval()) + self.assertAllClose(ildj, bijector.inverse_log_det_jacobian( + y, event_ndims=0).eval()) def testBijectorForwardInverseEventDimsOne(self): with self.test_session(): - bijector = Softplus(event_ndims=1) + bijector = Softplus() self.assertEqual("softplus", bijector.name) x = 2 * rng.randn(2, 10) y = self._softplus(x) @@ -87,58 +88,59 @@ class SoftplusBijectorTest(test.TestCase): def testBijectorLogDetJacobianEventDimsOne(self): with self.test_session(): - bijector = Softplus(event_ndims=1) + bijector = Softplus() y = 2 * rng.rand(2, 10) ildj_before = self._softplus_ildj_before_reduction(y) ildj = np.sum(ildj_before, axis=1) - self.assertAllClose(ildj, bijector.inverse_log_det_jacobian(y).eval()) + self.assertAllClose(ildj, bijector.inverse_log_det_jacobian( + y, event_ndims=1).eval()) def testScalarCongruency(self): with self.test_session(): - bijector = Softplus(event_ndims=0) + bijector = Softplus() assert_scalar_congruency( bijector, lower_x=-2., upper_x=2.) def testScalarCongruencyWithPositiveHingeSoftness(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=1.3) + bijector = Softplus(hinge_softness=1.3) assert_scalar_congruency( bijector, lower_x=-2., upper_x=2.) def testScalarCongruencyWithNegativeHingeSoftness(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=-1.3) + bijector = Softplus(hinge_softness=-1.3) assert_scalar_congruency( bijector, lower_x=-2., upper_x=2.) def testBijectiveAndFinite32bit(self): with self.test_session(): - bijector = Softplus(event_ndims=0) + bijector = Softplus() x = np.linspace(-20., 20., 100).astype(np.float32) y = np.logspace(-10, 10, 100).astype(np.float32) assert_bijective_and_finite( - bijector, x, y, rtol=1e-2, atol=1e-2) + bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2) def testBijectiveAndFiniteWithPositiveHingeSoftness32Bit(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=1.23) + bijector = Softplus(hinge_softness=1.23) x = np.linspace(-20., 20., 100).astype(np.float32) y = np.logspace(-10, 10, 100).astype(np.float32) assert_bijective_and_finite( - bijector, x, y, rtol=1e-2, atol=1e-2) + bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2) def testBijectiveAndFiniteWithNegativeHingeSoftness32Bit(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=-0.7) + bijector = Softplus(hinge_softness=-0.7) x = np.linspace(-20., 20., 100).astype(np.float32) y = -np.logspace(-10, 10, 100).astype(np.float32) assert_bijective_and_finite( - bijector, x, y, rtol=1e-2, atol=1e-2) + bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2) def testBijectiveAndFinite16bit(self): with self.test_session(): - bijector = Softplus(event_ndims=0) + bijector = Softplus() # softplus(-20) is zero, so we can't use such a large range as in 32bit. x = np.linspace(-10., 20., 100).astype(np.float16) # Note that float16 is only in the open set (0, inf) for a smaller @@ -146,7 +148,7 @@ class SoftplusBijectorTest(test.TestCase): # for the test. y = np.logspace(-6, 3, 100).astype(np.float16) assert_bijective_and_finite( - bijector, x, y, rtol=1e-1, atol=1e-3) + bijector, x, y, event_ndims=0, rtol=1e-1, atol=1e-3) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2ac06fce55b448a5f3da7ccb7f8766b5b1404ad7 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py @@ -0,0 +1,111 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.softsign import Softsign +from tensorflow.python.framework import test_util +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency +from tensorflow.python.platform import test + + +class SoftsignBijectorTest(test.TestCase): + """Tests the correctness of the Y = g(X) = X / (1 + |X|) transformation.""" + + def _softsign(self, x): + return x / (1. + np.abs(x)) + + def _softsign_ildj_before_reduction(self, y): + """Inverse log det jacobian, before being reduced.""" + return -2. * np.log1p(-np.abs(y)) + + def setUp(self): + self._rng = np.random.RandomState(42) + + @test_util.run_in_graph_and_eager_modes() + def testBijectorBounds(self): + bijector = Softsign(validate_args=True) + with self.test_session(): + with self.assertRaisesOpError("greater than -1"): + bijector.inverse(-3.).eval() + with self.assertRaisesOpError("greater than -1"): + bijector.inverse_log_det_jacobian(-3., event_ndims=0).eval() + + with self.assertRaisesOpError("less than 1"): + bijector.inverse(3.).eval() + with self.assertRaisesOpError("less than 1"): + bijector.inverse_log_det_jacobian(3., event_ndims=0).eval() + + @test_util.run_in_graph_and_eager_modes() + def testBijectorForwardInverse(self): + bijector = Softsign(validate_args=True) + self.assertEqual("softsign", bijector.name) + x = 2. * self._rng.randn(2, 10) + y = self._softsign(x) + + self.assertAllClose(y, self.evaluate(bijector.forward(x))) + self.assertAllClose(x, self.evaluate(bijector.inverse(y))) + + @test_util.run_in_graph_and_eager_modes() + def testBijectorLogDetJacobianEventDimsZero(self): + bijector = Softsign(validate_args=True) + y = self._rng.rand(2, 10) + # No reduction needed if event_dims = 0. + ildj = self._softsign_ildj_before_reduction(y) + + self.assertAllClose(ildj, self.evaluate( + bijector.inverse_log_det_jacobian(y, event_ndims=0))) + + @test_util.run_in_graph_and_eager_modes() + def testBijectorForwardInverseEventDimsOne(self): + bijector = Softsign(validate_args=True) + self.assertEqual("softsign", bijector.name) + x = 2. * self._rng.randn(2, 10) + y = self._softsign(x) + self.assertAllClose(y, self.evaluate(bijector.forward(x))) + self.assertAllClose(x, self.evaluate(bijector.inverse(y))) + + @test_util.run_in_graph_and_eager_modes() + def testBijectorLogDetJacobianEventDimsOne(self): + bijector = Softsign(validate_args=True) + y = self._rng.rand(2, 10) + ildj_before = self._softsign_ildj_before_reduction(y) + ildj = np.sum(ildj_before, axis=1) + self.assertAllClose( + ildj, self.evaluate( + bijector.inverse_log_det_jacobian(y, event_ndims=1))) + + def testScalarCongruency(self): + with self.test_session(): + bijector = Softsign(validate_args=True) + assert_scalar_congruency(bijector, lower_x=-20., upper_x=20.) + + def testBijectiveAndFinite(self): + with self.test_session(): + bijector = Softsign(validate_args=True) + x = np.linspace(-20., 20., 100).astype(np.float32) + y = np.linspace(-0.99, 0.99, 100).astype(np.float32) + assert_bijective_and_finite( + bijector, x, y, event_ndims=0, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py new file mode 100644 index 0000000000000000000000000000000000000000..30c7a738c320b609ce90685512e6b8344dffc9dc --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py @@ -0,0 +1,59 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency +from tensorflow.python.platform import test + + +class SquareBijectorTest(test.TestCase): + """Tests the correctness of the Y = X ** 2 transformation.""" + + def testBijectorScalar(self): + with self.test_session(): + bijector = bijectors.Square(validate_args=True) + self.assertEqual("square", bijector.name) + x = [[[1., 5], + [2, 1]], + [[np.sqrt(2.), 3], + [np.sqrt(8.), 1]]] + y = np.square(x) + ildj = -np.log(2.) - np.log(x) + self.assertAllClose(y, bijector.forward(x).eval()) + self.assertAllClose(x, bijector.inverse(y).eval()) + self.assertAllClose( + ildj, bijector.inverse_log_det_jacobian( + y, event_ndims=0).eval(), atol=0., rtol=1e-7) + self.assertAllClose( + -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=0).eval(), + atol=0., + rtol=1e-7) + + def testScalarCongruency(self): + with self.test_session(): + bijector = bijectors.Square(validate_args=True) + assert_scalar_congruency(bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6428a68702274fae384ae3de6d03f7ca126e2346 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for TransformDiagonal bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class TransformDiagonalBijectorTest(test.TestCase): + """Tests correctness of the TransformDiagonal bijector.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + @test_util.run_in_graph_and_eager_modes() + def testBijector(self): + x = np.float32(np.random.randn(3, 4, 4)) + + y = x.copy() + for i in range(x.shape[0]): + np.fill_diagonal(y[i, :, :], np.exp(np.diag(x[i, :, :]))) + + exp = bijectors.Exp() + b = bijectors.TransformDiagonal(diag_bijector=exp) + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=2)) + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllEqual( + fldj, + self.evaluate(exp.forward_log_det_jacobian( + np.array([np.diag(x_mat) for x_mat in x]), + event_ndims=1))) + self.assertAllEqual( + ildj, + self.evaluate(exp.inverse_log_det_jacobian( + np.array([np.diag(y_mat) for y_mat in y]), + event_ndims=1))) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py index 7a31228d1ade55ce32b511dca073657d3bab53ae..f57adcda898a1fdb18aacbb0804411db1bb4e4c8 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py @@ -36,7 +36,7 @@ class WeibullBijectorTest(test.TestCase): concentration = 0.3 bijector = Weibull( scale=scale, concentration=concentration, - event_ndims=1, validate_args=True) + validate_args=True) self.assertEqual("weibull", bijector.name) x = np.array([[[0.], [1.], [14.], [20.], [100.]]], dtype=np.float32) # Weibull distribution @@ -45,13 +45,11 @@ class WeibullBijectorTest(test.TestCase): self.assertAllClose(y, bijector.forward(x).eval()) self.assertAllClose(x, bijector.inverse(y).eval()) self.assertAllClose( - # We should lose a dimension from calculating the determinant of the - # jacobian. - np.squeeze(weibull_dist.logpdf(x), axis=2), - bijector.forward_log_det_jacobian(x).eval()) + weibull_dist.logpdf(x), + bijector.forward_log_det_jacobian(x, event_ndims=0).eval()) self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=0).eval(), rtol=1e-4, atol=0.) @@ -64,12 +62,12 @@ class WeibullBijectorTest(test.TestCase): def testBijectiveAndFinite(self): with self.test_session(): bijector = Weibull( - scale=20., concentration=2., event_ndims=0, validate_args=True) + scale=20., concentration=2., validate_args=True) x = np.linspace(1., 8., num=10).astype(np.float32) y = np.linspace( -np.expm1(-1 / 400.), -np.expm1(-16), num=10).astype(np.float32) - assert_bijective_and_finite(bijector, x, y, rtol=1e-3) + assert_bijective_and_finite(bijector, x, y, event_ndims=0, rtol=1e-3) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py index 545471907f1eabc822b3d28ea9c57e183a09ff50..4e8989b6c2f93560b1fccbc99491d7809f494263 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py @@ -44,6 +44,7 @@ class _ChooseLocation(ConditionalBijector): graph_parents=[self._loc], is_constant_jacobian=True, validate_args=False, + forward_min_event_ndims=0, name=name) def _forward(self, x, z): @@ -52,7 +53,7 @@ class _ChooseLocation(ConditionalBijector): def _inverse(self, x, z): return x - self._gather_loc(z) - def _inverse_log_det_jacobian(self, x, z=None): + def _inverse_log_det_jacobian(self, x, event_ndims, z=None): return 0. def _gather_loc(self, z): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py index 507ceb35853ebe0a996d789b3bdf8a5f2284549c..f42feae25d851eb9ae0bf48649fc3bbe2a221be0 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py @@ -16,6 +16,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib import distributions from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -25,23 +27,23 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -ds = distributions +tfd = distributions class DistributionTest(test.TestCase): def testParamShapesAndFromParams(self): classes = [ - ds.Normal, - ds.Bernoulli, - ds.Beta, - ds.Chi2, - ds.Exponential, - ds.Gamma, - ds.InverseGamma, - ds.Laplace, - ds.StudentT, - ds.Uniform, + tfd.Normal, + tfd.Bernoulli, + tfd.Beta, + tfd.Chi2, + tfd.Exponential, + tfd.Gamma, + tfd.InverseGamma, + tfd.Laplace, + tfd.StudentT, + tfd.Uniform, ] sample_shapes = [(), (10,), (10, 20, 30)] @@ -63,15 +65,15 @@ class DistributionTest(test.TestCase): with self.test_session(): # Note: we cannot easily test all distributions since each requires # different initialization arguments. We therefore spot test a few. - normal = ds.Normal(loc=1., scale=2., validate_args=True) + normal = tfd.Normal(loc=1., scale=2., validate_args=True) self.assertEqual(normal.parameters, normal.copy().parameters) - wishart = ds.WishartFull(df=2, scale=[[1., 2], [2, 5]], - validate_args=True) + wishart = tfd.WishartFull(df=2, scale=[[1., 2], [2, 5]], + validate_args=True) self.assertEqual(wishart.parameters, wishart.copy().parameters) def testCopyOverride(self): with self.test_session(): - normal = ds.Normal(loc=1., scale=2., validate_args=True) + normal = tfd.Normal(loc=1., scale=2., validate_args=True) unused_normal_copy = normal.copy(validate_args=False) base_params = normal.parameters.copy() copy_params = normal.copy(validate_args=False).parameters.copy() @@ -84,19 +86,19 @@ class DistributionTest(test.TestCase): mu = 1. sigma = 2. - normal = ds.Normal(mu, sigma, validate_args=True) + normal = tfd.Normal(mu, sigma, validate_args=True) self.assertTrue(tensor_util.constant_value(normal.is_scalar_event())) self.assertTrue(tensor_util.constant_value(normal.is_scalar_batch())) - normal = ds.Normal([mu], [sigma], validate_args=True) + normal = tfd.Normal([mu], [sigma], validate_args=True) self.assertTrue(tensor_util.constant_value(normal.is_scalar_event())) self.assertFalse(tensor_util.constant_value(normal.is_scalar_batch())) - mvn = ds.MultivariateNormalDiag([mu], [sigma], validate_args=True) + mvn = tfd.MultivariateNormalDiag([mu], [sigma], validate_args=True) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event())) self.assertTrue(tensor_util.constant_value(mvn.is_scalar_batch())) - mvn = ds.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True) + mvn = tfd.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event())) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_batch())) @@ -126,7 +128,7 @@ class DistributionTest(test.TestCase): self.assertFalse(is_scalar.eval(feed_dict={x: [1]})) def _GetFakeDistribution(self): - class FakeDistribution(ds.Distribution): + class FakeDistribution(tfd.Distribution): """Fake Distribution for testing _set_sample_static_shape.""" def __init__(self, batch_shape=None, event_shape=None): @@ -188,6 +190,124 @@ class DistributionTest(test.TestCase): y = dist._set_sample_static_shape(x, sample_shape) self.assertTrue(y.get_shape().ndims is None) + def testNameScopeWorksCorrectly(self): + x = tfd.Normal(loc=0., scale=1., name="x") + x_duplicate = tfd.Normal(loc=0., scale=1., name="x") + with ops.name_scope("y") as name: + y = tfd.Bernoulli(logits=0., name=name) + x_sample = x.sample(name="custom_sample") + x_sample_duplicate = x.sample(name="custom_sample") + x_log_prob = x.log_prob(0., name="custom_log_prob") + x_duplicate_sample = x_duplicate.sample(name="custom_sample") + + self.assertEqual(x.name, "x/") + self.assertEqual(x_duplicate.name, "x_1/") + self.assertEqual(y.name, "y/") + self.assertTrue(x_sample.name.startswith("x/custom_sample")) + self.assertTrue(x_sample_duplicate.name.startswith("x/custom_sample_1")) + self.assertTrue(x_log_prob.name.startswith("x/custom_log_prob")) + self.assertTrue(x_duplicate_sample.name.startswith( + "x_1/custom_sample")) + + def testStrWorksCorrectlyScalar(self): + normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1)) + self.assertEqual( + ("tf.distributions.Normal(" + "\"Normal/\", " + "batch_shape=(), " + "event_shape=(), " + "dtype=float16)"), # Got the dtype right. + str(normal)) + + chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly") + self.assertEqual( + ("tf.distributions.Chi2(" + "\"silly/\", " # What a silly name that is! + "batch_shape=(2,), " + "event_shape=(), " + "dtype=float32)"), + str(chi2)) + + exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32)) + self.assertEqual( + ("tf.distributions.Exponential(\"Exponential/\", " + # No batch shape. + "event_shape=(), " + "dtype=float32)"), + str(exp)) + + def testStrWorksCorrectlyMultivariate(self): + mvn_static = tfd.MultivariateNormalDiag( + loc=np.zeros([2, 2]), name="MVN") + self.assertEqual( + ("tf.distributions.MultivariateNormalDiag(" + "\"MVN/\", " + "batch_shape=(2,), " + "event_shape=(2,), " + "dtype=float64)"), + str(mvn_static)) + + mvn_dynamic = tfd.MultivariateNormalDiag( + loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32), + name="MVN2") + self.assertEqual( + ("tf.distributions.MultivariateNormalDiag(" + "\"MVN2/\", " + "batch_shape=(?,), " # Partially known. + "event_shape=(3,), " + "dtype=float32)"), + str(mvn_dynamic)) + + def testReprWorksCorrectlyScalar(self): + normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1)) + self.assertEqual( + (""), # Got the dtype right. + repr(normal)) + + chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly") + self.assertEqual( + (""), + repr(chi2)) + + exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32)) + self.assertEqual( + ("" + " event_shape=()" + " dtype=float32>"), + repr(exp)) + + def testReprWorksCorrectlyMultivariate(self): + mvn_static = tfd.MultivariateNormalDiag( + loc=np.zeros([2, 2]), name="MVN") + self.assertEqual( + (""), + repr(mvn_static)) + + mvn_dynamic = tfd.MultivariateNormalDiag( + loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32), + name="MVN2") + self.assertEqual( + (""), + repr(mvn_dynamic)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index 31d24aa9ea09007b8db40e4869371b1f62639ac7..bbbec2103aefd3f38a9b734bcd3f2e15fc8bb683 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -29,7 +29,9 @@ from tensorflow.contrib.distributions.python.ops import mvn_diag from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import categorical from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.linalg import linear_operator_diag @@ -540,5 +542,51 @@ class PadDynamicTest(_PadTest, test.TestCase): return False +class TestMoveDimension(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_move_dimension_static_shape(self): + + x = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 1, 1) + self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 0, 3) + self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 0, -2) + self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 4, 2) + self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1]) + + @test_util.run_in_graph_and_eager_modes() + def test_move_dimension_dynamic_shape(self): + + x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) + x = array_ops.placeholder_with_default(input=x_, shape=None) + + x_perm = distribution_util.move_dimension(x, 1, 1) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 0, 3) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 0, -2) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 4, 2) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 6, 4, 1]) + + x_perm = distribution_util.move_dimension(x, -1, 2) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 6, 4, 1]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py index 06318ca09dec851cf025fa35c83732b85824cbee..6a69f9e60b99a17c657f074597a075890265a93b 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bernoulli as bernoulli_lib +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -126,6 +127,100 @@ class ProductDistributionTest(test.TestCase): self.assertAllClose(sample_entropy_, actual_entropy_, rtol=0.01, atol=0.) self.assertAllClose(loc, actual_mode_, rtol=1e-6, atol=0.) + def testKLRaises(self): + ind1 = independent_lib.Independent( + distribution=normal_lib.Normal( + loc=np.float32([-1., 1]), + scale=np.float32([0.1, 0.5])), + reinterpreted_batch_ndims=1) + ind2 = independent_lib.Independent( + distribution=normal_lib.Normal( + loc=np.float32(-1), + scale=np.float32(0.5)), + reinterpreted_batch_ndims=0) + + with self.assertRaisesRegexp( + ValueError, "Event shapes do not match"): + kullback_leibler.kl_divergence(ind1, ind2) + + ind1 = independent_lib.Independent( + distribution=normal_lib.Normal( + loc=np.float32([-1., 1]), + scale=np.float32([0.1, 0.5])), + reinterpreted_batch_ndims=1) + ind2 = independent_lib.Independent( + distribution=mvn_diag_lib.MultivariateNormalDiag( + loc=np.float32([-1., 1]), + scale_diag=np.float32([0.1, 0.5])), + reinterpreted_batch_ndims=0) + + with self.assertRaisesRegexp( + NotImplementedError, "different event shapes"): + kullback_leibler.kl_divergence(ind1, ind2) + + def testKLScalarToMultivariate(self): + normal1 = normal_lib.Normal( + loc=np.float32([-1., 1]), + scale=np.float32([0.1, 0.5])) + ind1 = independent_lib.Independent( + distribution=normal1, reinterpreted_batch_ndims=1) + + normal2 = normal_lib.Normal( + loc=np.float32([-3., 3]), + scale=np.float32([0.3, 0.3])) + ind2 = independent_lib.Independent( + distribution=normal2, reinterpreted_batch_ndims=1) + + normal_kl = kullback_leibler.kl_divergence(normal1, normal2) + ind_kl = kullback_leibler.kl_divergence(ind1, ind2) + self.assertAllClose( + self.evaluate(math_ops.reduce_sum(normal_kl, axis=-1)), + self.evaluate(ind_kl)) + + def testKLIdentity(self): + normal1 = normal_lib.Normal( + loc=np.float32([-1., 1]), + scale=np.float32([0.1, 0.5])) + # This is functionally just a wrapper around normal1, + # and doesn't change any outputs. + ind1 = independent_lib.Independent( + distribution=normal1, reinterpreted_batch_ndims=0) + + normal2 = normal_lib.Normal( + loc=np.float32([-3., 3]), + scale=np.float32([0.3, 0.3])) + # This is functionally just a wrapper around normal2, + # and doesn't change any outputs. + ind2 = independent_lib.Independent( + distribution=normal2, reinterpreted_batch_ndims=0) + + normal_kl = kullback_leibler.kl_divergence(normal1, normal2) + ind_kl = kullback_leibler.kl_divergence(ind1, ind2) + self.assertAllClose( + self.evaluate(normal_kl), self.evaluate(ind_kl)) + + def testKLMultivariateToMultivariate(self): + # (1, 1, 2) batch of MVNDiag + mvn1 = mvn_diag_lib.MultivariateNormalDiag( + loc=np.float32([[[[-1., 1, 3.], [2., 4., 3.]]]]), + scale_diag=np.float32([[[0.2, 0.1, 5.], [2., 3., 4.]]])) + ind1 = independent_lib.Independent( + distribution=mvn1, reinterpreted_batch_ndims=2) + + # (1, 1, 2) batch of MVNDiag + mvn2 = mvn_diag_lib.MultivariateNormalDiag( + loc=np.float32([[[[-2., 3, 2.], [1., 3., 2.]]]]), + scale_diag=np.float32([[[0.1, 0.5, 3.], [1., 2., 1.]]])) + + ind2 = independent_lib.Independent( + distribution=mvn2, reinterpreted_batch_ndims=2) + + mvn_kl = kullback_leibler.kl_divergence(mvn1, mvn2) + ind_kl = kullback_leibler.kl_divergence(ind1, ind2) + self.assertAllClose( + self.evaluate(math_ops.reduce_sum(mvn_kl, axis=[-1, -2])), + self.evaluate(ind_kl)) + def _testMnistLike(self, static_shape): sample_shape = [4, 5] batch_shape = [10] diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py index ea3c86b5c0f42b64fc6e4e362cbcc162bccf74a2..2980e2bfe93b2e2aa01d38fc9fa4650a015efc06 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py @@ -130,10 +130,8 @@ class KumaraswamyTest(test.TestCase): dist.prob([.1, .3, .6]).eval() dist.prob([.2, .3, .5]).eval() # Either condition can trigger. - with self.assertRaisesOpError("sample must be positive"): + with self.assertRaisesOpError("sample must be non-negative"): dist.prob([-1., 0.1, 0.5]).eval() - with self.assertRaisesOpError("sample must be positive"): - dist.prob([0., 0.1, 0.5]).eval() with self.assertRaisesOpError("sample must be no larger than `1`"): dist.prob([.1, .2, 1.2]).eval() @@ -249,13 +247,13 @@ class KumaraswamyTest(test.TestCase): a = np.array([1., 2, 3]) b = np.array([2., 4, 1.2]) dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False) - with self.assertRaisesOpError("Condition x < y.*"): + with self.assertRaisesOpError("Mode undefined for concentration1 <= 1."): dist.mode().eval() a = np.array([2., 2, 3]) b = np.array([1., 4, 1.2]) dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False) - with self.assertRaisesOpError("Condition x < y.*"): + with self.assertRaisesOpError("Mode undefined for concentration0 <= 1."): dist.mode().eval() def testKumaraswamyModeEnableAllowNanStats(self): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py index 933756aa8e12cca4c42eb98d9193512bbf2ad585..9635134b08db47a47a17c869fe813e0376ae6f1e 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py @@ -68,7 +68,7 @@ class MultivariateNormalDiagTest(test.TestCase): dist = ds.TransformedDistribution( base_dist, validate_args=True, - bijector=bijectors.Softplus(event_ndims=1)) + bijector=bijectors.Softplus()) samps = dist.sample(5) # Shape [5, 1, 3]. self.assertAllEqual([5, 1], dist.log_prob(samps).get_shape()) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py index 1a02fbefb8e88599f5fedeb38fb06f5a09036439..b003526392709b61e9cc46e0ff8e5fa78edc0568 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py @@ -52,7 +52,7 @@ class MultivariateNormalFullCovarianceTest(test.TestCase): mu = [1., 2.] sigma = [[1., 0.], [0., 1.]] mvn = ds.MultivariateNormalFullCovariance(mu, sigma, name="Billy") - self.assertEqual(mvn.name, "Billy") + self.assertEqual(mvn.name, "Billy/") def testDoesNotRaiseIfInitializedWithSymmetricMatrix(self): with self.test_session(): @@ -131,8 +131,8 @@ class MultivariateNormalFullCovarianceTest(test.TestCase): return mu, sigma def testKLBatch(self): - batch_shape = (2,) - event_shape = (3,) + batch_shape = [2] + event_shape = [3] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) @@ -156,6 +156,33 @@ class MultivariateNormalFullCovarianceTest(test.TestCase): self.assertAllClose(expected_kl_0, kl_v[0]) self.assertAllClose(expected_kl_1, kl_v[1]) + def testKLBatchBroadcast(self): + batch_shape = [2] + event_shape = [3] + with self.test_session(): + mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) + # No batch shape. + mu_b, sigma_b = self._random_mu_and_sigma([], event_shape) + mvn_a = ds.MultivariateNormalFullCovariance( + loc=mu_a, + covariance_matrix=sigma_a, + validate_args=True) + mvn_b = ds.MultivariateNormalFullCovariance( + loc=mu_b, + covariance_matrix=sigma_b, + validate_args=True) + + kl = ds.kl_divergence(mvn_a, mvn_b) + self.assertEqual(batch_shape, kl.get_shape()) + + kl_v = kl.eval() + expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :], + mu_b, sigma_b) + expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :], + mu_b, sigma_b) + self.assertAllClose(expected_kl_0, kl_v[0]) + self.assertAllClose(expected_kl_1, kl_v[1]) + def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b): """Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b).""" diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py index 685f32883dae5b8513badeb05e1508cd611d6e93..b556d06123800f22f5d9a90dd18f3c745aec90a1 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py @@ -235,8 +235,8 @@ class MultivariateNormalTriLTest(test.TestCase): return mu, sigma def testKLNonBatch(self): - batch_shape = () - event_shape = (2,) + batch_shape = [] + event_shape = [2] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) @@ -257,8 +257,8 @@ class MultivariateNormalTriLTest(test.TestCase): self.assertAllClose(expected_kl, kl_v) def testKLBatch(self): - batch_shape = (2,) - event_shape = (3,) + batch_shape = [2] + event_shape = [3] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) @@ -282,9 +282,36 @@ class MultivariateNormalTriLTest(test.TestCase): self.assertAllClose(expected_kl_0, kl_v[0]) self.assertAllClose(expected_kl_1, kl_v[1]) + def testKLBatchBroadcast(self): + batch_shape = [2] + event_shape = [3] + with self.test_session(): + mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) + # No batch shape. + mu_b, sigma_b = self._random_mu_and_sigma([], event_shape) + mvn_a = ds.MultivariateNormalTriL( + loc=mu_a, + scale_tril=np.linalg.cholesky(sigma_a), + validate_args=True) + mvn_b = ds.MultivariateNormalTriL( + loc=mu_b, + scale_tril=np.linalg.cholesky(sigma_b), + validate_args=True) + + kl = ds.kl_divergence(mvn_a, mvn_b) + self.assertEqual(batch_shape, kl.get_shape()) + + kl_v = kl.eval() + expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :], + mu_b, sigma_b) + expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :], + mu_b, sigma_b) + self.assertAllClose(expected_kl_0, kl_v[0]) + self.assertAllClose(expected_kl_1, kl_v[1]) + def testKLTwoIdenticalDistributionsIsZero(self): - batch_shape = (2,) - event_shape = (3,) + batch_shape = [2] + event_shape = [3] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mvn_a = ds.MultivariateNormalTriL( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py index d9c9008417cdb20b62390630cf887d3bd888a0d3..19a7472d91758a2dbd00c4d918853d7bae33685d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import numpy as np +from scipy import special from scipy import stats from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib from tensorflow.python.framework import constant_op @@ -110,7 +111,7 @@ class PoissonTest(test.TestCase): batch_size = 6 lam = constant_op.constant([3.0] * batch_size) lam_v = 3.0 - x = [2.2, 3.1, 4., 5.5, 6., 7.] + x = [2., 3., 4., 5., 6., 7.] poisson = self._make_poisson(rate=lam) log_cdf = poisson.log_cdf(x) @@ -121,12 +122,31 @@ class PoissonTest(test.TestCase): self.assertEqual(cdf.get_shape(), (6,)) self.assertAllClose(cdf.eval(), stats.poisson.cdf(x, lam_v)) + def testPoissonCDFNonIntegerValues(self): + with self.test_session(): + batch_size = 6 + lam = constant_op.constant([3.0] * batch_size) + lam_v = 3.0 + x = np.array([2.2, 3.1, 4., 5.5, 6., 7.], dtype=np.float32) + + poisson = self._make_poisson(rate=lam) + cdf = poisson.cdf(x) + self.assertEqual(cdf.get_shape(), (6,)) + + # The Poisson CDF should be valid on these non-integer values, and + # equal to igammac(1 + x, rate). + self.assertAllClose(cdf.eval(), special.gammaincc(1. + x, lam_v)) + + with self.assertRaisesOpError("cannot contain fractional components"): + poisson_validate = self._make_poisson(rate=lam, validate_args=True) + poisson_validate.cdf(x).eval() + def testPoissonCdfMultidimensional(self): with self.test_session(): batch_size = 6 lam = constant_op.constant([[2.0, 4.0, 5.0]] * batch_size) lam_v = [2.0, 4.0, 5.0] - x = np.array([[2.2, 3.1, 4., 5.5, 6., 7.]], dtype=np.float32).T + x = np.array([[2., 3., 4., 5., 6., 7.]], dtype=np.float32).T poisson = self._make_poisson(rate=lam) log_cdf = poisson.log_cdf(x) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py b/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py index 4186cf129dbf31724c84133734da3f226817c71a..ea04e8c29a2c94d4939bad277afa380401067ff2 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import sample_stats from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import spectral_ops_test_util from tensorflow.python.platform import test @@ -455,6 +456,16 @@ class PercentileTestWithNearestInterpolation(test.TestCase): with self.assertRaisesOpError("rank"): pct.eval(feed_dict={q_ph: [0.5]}) + def test_finds_max_of_long_array(self): + # d - 1 == d in float32 and d = 3e7. + # So this test only passes if we use double for the percentile indices. + # If float is used, it fails with InvalidArgumentError about an index out of + # bounds. + x = math_ops.linspace(0., 3e7, num=int(3e7)) + with self.test_session(): + minval = sample_stats.percentile(x, q=0, validate_args=True) + self.assertAllEqual(0, minval.eval()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b91a610acf1a9094d612504d63030b3bffb873ac --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the SeedStream class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops import seed_stream +from tensorflow.python.platform import test + + +class SeedStreamTest(test.TestCase): + + def assertAllUnique(self, items): + self.assertEqual(len(items), len(set(items))) + + def testNonRepetition(self): + # The probability of repetitions in a short stream from a correct + # PRNG is negligible; this test catches bugs that prevent state + # updates. + strm = seed_stream.SeedStream(seed=4, salt="salt") + output = [strm() for _ in range(50)] + self.assertEqual(sorted(output), sorted(list(set(output)))) + + def testReproducibility(self): + strm1 = seed_stream.SeedStream(seed=4, salt="salt") + strm2 = seed_stream.SeedStream(seed=4, salt="salt") + strm3 = seed_stream.SeedStream(seed=4, salt="salt") + outputs = [strm1() for _ in range(50)] + self.assertEqual(outputs, [strm2() for _ in range(50)]) + self.assertEqual(outputs, [strm3() for _ in range(50)]) + + def testSeededDistinctness(self): + strm1 = seed_stream.SeedStream(seed=4, salt="salt") + strm2 = seed_stream.SeedStream(seed=5, salt="salt") + self.assertAllUnique( + [strm1() for _ in range(50)] + [strm2() for _ in range(50)]) + + def testSaltedDistinctness(self): + strm1 = seed_stream.SeedStream(seed=4, salt="salt") + strm2 = seed_stream.SeedStream(seed=4, salt="another salt") + self.assertAllUnique( + [strm1() for _ in range(50)] + [strm2() for _ in range(50)]) + + def testNestingRobustness(self): + # SeedStreams started from generated seeds should not collide with + # the master or with each other, even if the salts are the same. + strm1 = seed_stream.SeedStream(seed=4, salt="salt") + strm2 = seed_stream.SeedStream(strm1(), salt="salt") + strm3 = seed_stream.SeedStream(strm1(), salt="salt") + outputs = [strm1() for _ in range(50)] + self.assertAllUnique( + outputs + [strm2() for _ in range(50)] + [strm3() for _ in range(50)]) + + def testInitFromOtherSeedStream(self): + strm1 = seed_stream.SeedStream(seed=4, salt="salt") + strm2 = seed_stream.SeedStream(strm1, salt="salt") + strm3 = seed_stream.SeedStream(strm1, salt="another salt") + out1 = [strm1() for _ in range(50)] + out2 = [strm2() for _ in range(50)] + out3 = [strm3() for _ in range(50)] + self.assertAllEqual(out1, out2) + self.assertAllUnique(out1 + out3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py index c8d795c3f6afbec5b41755951174439f7703efb9..243b5a034859288b0e2e120f09258cfee77fbdea 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py @@ -584,7 +584,6 @@ class DistributionShapeTest(test.TestCase): def testDistributionShapeGetDimsStatic(self): with self.test_session(): - shaper = _DistributionShape(batch_ndims=0, event_ndims=0) shaper = _DistributionShape(batch_ndims=0, event_ndims=0) x = 1 self.assertAllEqual((_empty_shape, _empty_shape, _empty_shape), diff --git a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9c4dfed83631e9f0815fb674d650cac2e570b923 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py @@ -0,0 +1,259 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the statistical testing library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import statistical_testing as st +from tensorflow.python.framework import ops +from tensorflow.python.platform import test + + +class StatisticalTestingTest(test.TestCase): + + def test_dkwm_design_mean_one_sample_soundness(self): + thresholds = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10] + rates = [1e-6, 1e-3, 1e-2, 1.1e-1, 0.2, 0.5, 0.7, 1.] + false_fail_rates, false_pass_rates = np.meshgrid(rates, rates) + false_fail_rates = false_fail_rates.flatten().astype(np.float32) + false_pass_rates = false_pass_rates.flatten().astype(np.float32) + + detectable_discrepancies = [] + for false_pass_rate, false_fail_rate in zip( + false_pass_rates, false_fail_rates): + sufficient_n = st.min_num_samples_for_dkwm_mean_test( + thresholds, low=0., high=1., false_fail_rate=false_fail_rate, + false_pass_rate=false_pass_rate) + detectable_discrepancies.append( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + sufficient_n, low=0., high=1., false_fail_rate=false_fail_rate, + false_pass_rate=false_pass_rate)) + + detectable_discrepancies_ = self.evaluate(detectable_discrepancies) + for discrepancies, false_pass_rate, false_fail_rate in zip( + detectable_discrepancies_, false_pass_rates, false_fail_rates): + below_threshold = discrepancies <= thresholds + self.assertAllEqual( + np.ones_like(below_threshold, np.bool), below_threshold, + msg='false_pass_rate({}), false_fail_rate({})'.format( + false_pass_rate, false_fail_rate)) + + def test_dkwm_design_mean_two_sample_soundness(self): + thresholds = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10] + rates = [1e-6, 1e-3, 1e-2, 1.1e-1, 0.2, 0.5, 0.7, 1.] + false_fail_rates, false_pass_rates = np.meshgrid(rates, rates) + false_fail_rates = false_fail_rates.flatten().astype(np.float32) + false_pass_rates = false_pass_rates.flatten().astype(np.float32) + + detectable_discrepancies = [] + for false_pass_rate, false_fail_rate in zip( + false_pass_rates, false_fail_rates): + [ + sufficient_n1, + sufficient_n2 + ] = st.min_num_samples_for_dkwm_mean_two_sample_test( + thresholds, low1=0., high1=1., low2=0., high2=1., + false_fail_rate=false_fail_rate, + false_pass_rate=false_pass_rate) + + detectable_discrepancies.append( + st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( + n1=sufficient_n1, + low1=0., + high1=1., + n2=sufficient_n2, + low2=0., + high2=1., + false_fail_rate=false_fail_rate, + false_pass_rate=false_pass_rate)) + + detectable_discrepancies_ = self.evaluate(detectable_discrepancies) + for discrepancies, false_pass_rate, false_fail_rate in zip( + detectable_discrepancies_, false_pass_rates, false_fail_rates): + below_threshold = discrepancies <= thresholds + self.assertAllEqual( + np.ones_like(below_threshold, np.bool), below_threshold, + msg='false_pass_rate({}), false_fail_rate({})'.format( + false_pass_rate, false_fail_rate)) + + def test_true_mean_confidence_interval_by_dkwm_one_sample(self): + rng = np.random.RandomState(seed=0) + + num_samples = 5000 + # 5000 samples is chosen to be enough to find discrepancies of + # size 0.1 or more with assurance 1e-6, as confirmed here: + d = st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, 0., 1., false_fail_rate=1e-6, false_pass_rate=1e-6) + d = self.evaluate(d) + self.assertLess(d, 0.1) + + # Test that the confidence interval computed for the mean includes + # 0.5 and excludes 0.4 and 0.6. + samples = rng.uniform(size=num_samples).astype(np.float32) + (low, high) = st.true_mean_confidence_interval_by_dkwm( + samples, 0., 1., error_rate=1e-6) + low, high = self.evaluate([low, high]) + self.assertGreater(low, 0.4) + self.assertLess(low, 0.5) + self.assertGreater(high, 0.5) + self.assertLess(high, 0.6) + + def test_dkwm_mean_one_sample_assertion(self): + rng = np.random.RandomState(seed=0) + num_samples = 5000 + + # Test that the test assertion agrees that the mean of the standard + # uniform distribution is 0.5. + samples = rng.uniform(size=num_samples).astype(np.float32) + self.evaluate(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.5, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not 0.4. + with self.assertRaisesOpError("true mean greater than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.4, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not 0.6. + with self.assertRaisesOpError("true mean smaller than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.6, false_fail_rate=1e-6)) + + def test_dkwm_mean_in_interval_one_sample_assertion(self): + rng = np.random.RandomState(seed=0) + num_samples = 5000 + + # Test that the test assertion agrees that the mean of the standard + # uniform distribution is between 0.4 and 0.6. + samples = rng.uniform(size=num_samples).astype(np.float32) + self.evaluate(st.assert_true_mean_in_interval_by_dkwm( + samples, 0., 1., + expected_low=0.4, expected_high=0.6, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not between 0.2 and 0.4. + with self.assertRaisesOpError("true mean greater than expected"): + self.evaluate(st.assert_true_mean_in_interval_by_dkwm( + samples, 0., 1., + expected_low=0.2, expected_high=0.4, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not between 0.6 and 0.8. + with self.assertRaisesOpError("true mean smaller than expected"): + self.evaluate(st.assert_true_mean_in_interval_by_dkwm( + samples, 0., 1., + expected_low=0.6, expected_high=0.8, false_fail_rate=1e-6)) + + def test_dkwm_mean_two_sample_assertion(self): + rng = np.random.RandomState(seed=0) + num_samples = 4000 + + # 4000 samples is chosen to be enough to find discrepancies of + # size 0.2 or more with assurance 1e-6, as confirmed here: + d = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( + num_samples, 0., 1., num_samples, 0., 1., + false_fail_rate=1e-6, false_pass_rate=1e-6) + d = self.evaluate(d) + self.assertLess(d, 0.2) + + # Test that the test assertion agrees that the standard + # uniform distribution has the same mean as itself. + samples1 = rng.uniform(size=num_samples).astype(np.float32) + samples2 = rng.uniform(size=num_samples).astype(np.float32) + self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., samples2, 0., 1., false_fail_rate=1e-6)) + + def test_dkwm_mean_two_sample_assertion_beta_2_1_false(self): + rng = np.random.RandomState(seed=0) + num_samples = 4000 + samples1 = rng.uniform(size=num_samples).astype(np.float32) + + # As established above, 4000 samples is enough to find discrepancies + # of size 0.2 or more with assurance 1e-6. + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is different from the mean of beta(2, 1). + beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32) + with self.assertRaisesOpError("true mean smaller than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., + beta_high_samples, 0., 1., + false_fail_rate=1e-6)) + + def test_dkwm_mean_two_sample_assertion_beta_1_2_false(self): + rng = np.random.RandomState(seed=0) + num_samples = 4000 + samples1 = rng.uniform(size=num_samples).astype(np.float32) + + # As established above, 4000 samples is enough to find discrepancies + # of size 0.2 or more with assurance 1e-6. + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is different from the mean of beta(1, 2). + beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32) + with self.assertRaisesOpError("true mean greater than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., + beta_low_samples, 0., 1., + false_fail_rate=1e-6)) + + def test_dkwm_argument_validity_checking(self): + rng = np.random.RandomState(seed=0) + samples = rng.uniform( + low=[0., 1.], high=[1., 2.], size=(2500, 1, 2)).astype(np.float32) + + # Test that the test library complains if the given samples fall + # outside the purported bounds. + with self.assertRaisesOpError("maximum value exceeds expectations"): + self.evaluate(st.true_mean_confidence_interval_by_dkwm( + samples, [[0., 1.]], [[0.5, 1.5]], error_rate=0.5)) + with self.assertRaisesOpError("minimum value falls below expectations"): + self.evaluate(st.true_mean_confidence_interval_by_dkwm( + samples, [[0.5, 1.5]], [[1., 2.]], error_rate=0.5)) + + # But doesn't complain if they don't. + op = st.true_mean_confidence_interval_by_dkwm( + samples, [[0., 1.]], [[1., 2.]], error_rate=0.5) + _ = self.evaluate(op) + + def test_do_maximum_mean(self): + n = 117 + envelope = 0.02 # > 2 / n, but < 3 / n + rng = np.random.RandomState(seed=8) + samples = rng.uniform(size=n).astype(np.float32) + + # Compute the answer in TF using the code under test + envelope_t = ops.convert_to_tensor(envelope) + max_mean = st._do_maximum_mean(samples, envelope_t, 1) + max_mean = self.evaluate(max_mean) + + # Compute the correct answer for this case in numpy. In this + # example, `n` and `envelope` are such that `samples[2]` is the + # element that should be taken partially, regardless of the + # content of the `samples` array (see algorithm description in + # `../ops/statistical_testing.py`). + samples = sorted(samples) + weight = 1. / n - (envelope - 2. / n) + answer = samples[2] * weight + sum(samples[3:]) / n + envelope * 1. + self.assertAllClose(max_mean, answer, rtol=1e-9) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py index cbaf74d3f66253ae5727e1ba579e2d49235b748e..5fe1331d2c34612e980c7b376367cd63b627533d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -36,6 +37,35 @@ ds = distributions la = linalg +class DummyMatrixTransform(bs.Bijector): + """Tractable matrix transformation. + + This is a non-sensical bijector that has forward/inverse_min_event_ndims=2. + The main use is to check that transformed distribution calculations are done + appropriately. + """ + + def __init__(self): + super(DummyMatrixTransform, self).__init__( + forward_min_event_ndims=2, + is_constant_jacobian=False, + validate_args=False, + name="dummy") + + def _forward(self, x): + return x + + def _inverse(self, y): + return y + + # Note: These jacobians don't make sense. + def _forward_log_det_jacobian(self, x): + return -linalg_ops.matrix_determinant(x) + + def _inverse_log_det_jacobian(self, x): + return linalg_ops.matrix_determinant(x) + + class TransformedDistributionTest(test.TestCase): def _cls(self): @@ -55,7 +85,7 @@ class TransformedDistributionTest(test.TestCase): # you may or may not need a reduce_sum. log_normal = self._cls()( distribution=ds.Normal(loc=mu, scale=sigma), - bijector=bs.Exp(event_ndims=0)) + bijector=bs.Exp()) sp_dist = stats.lognorm(s=sigma, scale=np.exp(mu)) # sample @@ -87,7 +117,7 @@ class TransformedDistributionTest(test.TestCase): sigma = 2.0 abs_normal = self._cls()( distribution=ds.Normal(loc=mu, scale=sigma), - bijector=bs.AbsoluteValue(event_ndims=0)) + bijector=bs.AbsoluteValue()) sp_normal = stats.norm(mu, sigma) # sample @@ -129,7 +159,7 @@ class TransformedDistributionTest(test.TestCase): self.assertAllClose(grid, cdf_, rtol=1e-6, atol=0.) def testCachedSamples(self): - exp_forward_only = bs.Exp(event_ndims=0) + exp_forward_only = bs.Exp() exp_forward_only._inverse = self._make_unimplemented( "inverse") exp_forward_only._inverse_event_shape_tensor = self._make_unimplemented( @@ -153,7 +183,7 @@ class TransformedDistributionTest(test.TestCase): self.assertAllClose(expected_log_pdf, log_pdf_val, rtol=1e-4, atol=0.) def testCachedSamplesInvert(self): - exp_inverse_only = bs.Exp(event_ndims=0) + exp_inverse_only = bs.Exp() exp_inverse_only._forward = self._make_unimplemented( "forward") exp_inverse_only._forward_event_shape_tensor = self._make_unimplemented( @@ -186,12 +216,14 @@ class TransformedDistributionTest(test.TestCase): standard_normal = ds.Normal(loc=0., scale=1.) multi_logit_normal = self._cls()( distribution=standard_normal, - bijector=softmax) - x = [[-np.log(3.), 0.], - [np.log(3), np.log(5)]] + bijector=softmax, + event_shape=[1]) + x = [[[-np.log(3.)], [0.]], + [[np.log(3)], [np.log(5)]]] y = softmax.forward(x).eval() - expected_log_pdf = (stats.norm(loc=0., scale=1.).logpdf(x) - - np.sum(np.log(y), axis=-1)) + expected_log_pdf = ( + np.squeeze(stats.norm(loc=0., scale=1.).logpdf(x)) - + np.sum(np.log(y), axis=-1)) self.assertAllClose(expected_log_pdf, multi_logit_normal.log_prob(y).eval()) self.assertAllClose( @@ -208,8 +240,11 @@ class TransformedDistributionTest(test.TestCase): int_identity = bs.Inline( forward_fn=array_ops.identity, inverse_fn=array_ops.identity, - inverse_log_det_jacobian_fn=lambda x: math_ops.cast(0, dtypes.int32), - forward_log_det_jacobian_fn=lambda x: math_ops.cast(0, dtypes.int32), + inverse_log_det_jacobian_fn=( + lambda y: math_ops.cast(0, dtypes.int32)), + forward_log_det_jacobian_fn=( + lambda x: math_ops.cast(0, dtypes.int32)), + forward_min_event_ndims=0, is_constant_jacobian=True) normal = self._cls()( distribution=ds.Normal(loc=0., scale=1.), @@ -245,9 +280,8 @@ class TransformedDistributionTest(test.TestCase): with self.test_session() as sess: exp2 = self._cls()( ds.Exponential(rate=0.25), - bijector=ds.bijectors.Affine( - scale_identity_multiplier=2., - event_ndims=0)) + bijector=ds.bijectors.AffineScalar(scale=2.) + ) log_prob = exp2.log_prob(1.) log_prob_ = sess.run(log_prob) base_log_prob = -0.5 * 0.25 + np.log(0.25) @@ -434,6 +468,82 @@ class ScalarToMultiTest(test.TestCase): event_shape=[3], validate_args=True) + def testMatrixEvent(self): + with self.test_session() as sess: + batch_shape = [2] + event_shape = [2, 3, 3] + batch_shape_pl = array_ops.placeholder( + dtypes.int32, name="dynamic_batch_shape") + event_shape_pl = array_ops.placeholder( + dtypes.int32, name="dynamic_event_shape") + feed_dict = {batch_shape_pl: np.array(batch_shape, dtype=np.int32), + event_shape_pl: np.array(event_shape, dtype=np.int32)} + + scale = 2. + loc = 0. + fake_mvn_dynamic = self._cls()( + distribution=ds.Normal( + loc=loc, + scale=scale), + bijector=DummyMatrixTransform(), + batch_shape=batch_shape_pl, + event_shape=event_shape_pl, + validate_args=True) + + fake_mvn_static = self._cls()( + distribution=ds.Normal( + loc=loc, + scale=scale), + bijector=DummyMatrixTransform(), + batch_shape=batch_shape, + event_shape=event_shape, + validate_args=True) + + def actual_mvn_log_prob(x): + # This distribution is the normal PDF, reduced over the + # last 3 dimensions + a jacobian term which corresponds + # to the determinant of x. + return (np.sum( + stats.norm(loc, scale).logpdf(x), axis=(-1, -2, -3)) + + np.sum(np.linalg.det(x), axis=-1)) + + self.assertAllEqual([2, 3, 3], fake_mvn_static.event_shape) + self.assertAllEqual([2], fake_mvn_static.batch_shape) + + self.assertAllEqual(tensor_shape.TensorShape(None), + fake_mvn_dynamic.event_shape) + self.assertAllEqual(tensor_shape.TensorShape(None), + fake_mvn_dynamic.batch_shape) + + num_samples = 5e3 + for fake_mvn, feed_dict in ((fake_mvn_static, {}), + (fake_mvn_dynamic, feed_dict)): + # Ensure sample works by checking first, second moments. + y = fake_mvn.sample(int(num_samples), seed=0) + x = y[0:5, ...] + [ + x_, + fake_event_shape_, + fake_batch_shape_, + fake_log_prob_, + fake_prob_, + ] = sess.run([ + x, + fake_mvn.event_shape_tensor(), + fake_mvn.batch_shape_tensor(), + fake_mvn.log_prob(x), + fake_mvn.prob(x), + ], feed_dict=feed_dict) + + # Ensure all other functions work as intended. + self.assertAllEqual([5, 2, 2, 3, 3], x_.shape) + self.assertAllEqual([2, 3, 3], fake_event_shape_) + self.assertAllEqual([2], fake_batch_shape_) + self.assertAllClose(actual_mvn_log_prob(x_), fake_log_prob_, + atol=0., rtol=1e-6) + self.assertAllClose(np.exp(actual_mvn_log_prob(x_)), fake_prob_, + atol=0., rtol=1e-5) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..03e26b198ea02ad1bef8bcd2f6076078ecd7df0b --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD @@ -0,0 +1,48 @@ +# Description: +# Internal testing utilities, e.g., computing the correct answer to +# put in a unit test. + +licenses(["notice"]) # Apache 2.0 + +py_library( + name = "correlation_matrix_volumes_py", + srcs = [ + "correlation_matrix_volumes_lib.py", + ], + deps = [ + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//third_party/py/numpy", + ], +) + +py_binary( + name = "correlation_matrix_volumes", + srcs = [ + "correlation_matrix_volumes.py", + ], + deps = [ + ":correlation_matrix_volumes_py", + ], +) + +py_test( + name = "correlation_matrix_volumes_test", + size = "medium", + srcs = ["correlation_matrix_volumes_test.py"], + tags = ["no_pip"], + deps = [ + ":correlation_matrix_volumes_py", + # For statistical testing + "//tensorflow/contrib/distributions:distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + ], +) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py new file mode 100644 index 0000000000000000000000000000000000000000..2eab51cd3053ea55f2e03619fd002fbf48251fb1 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Executable to estimate the volume of various sets of correlation matrices. + +See correlation_matrix_volumes_lib.py for purpose and methodology. + +Invocation example: +``` +python correlation_matrix_volumes.py --num_samples 1e7 +``` + +This will compute 10,000,000-sample confidence intervals for the +volumes of several sets of correlation matrices. Which sets, and the +desired statistical significance, are hard-coded in this source file. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pprint + +from absl import app +from absl import flags + +from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr + +FLAGS = flags.FLAGS + +# Float to support giving the number of samples in scientific notation. +# The production run used for the LKJ test used 1e7 samples. +flags.DEFINE_float('num_samples', 1e4, 'Number of samples to use.') + + +def ctv_debatched(det_bounds, dim, num_samples, error_rate=1e-6, seed=42): + # This wrapper undoes the batching in compute_true_volumes, because + # apparently several 5x5x9x1e7 Tensors of float32 can strain RAM. + bounds = {} + for db in det_bounds: + bounds[db] = corr.compute_true_volumes( + [db], dim, num_samples, error_rate=error_rate, seed=seed)[db] + return bounds + + +# The particular bounds in all three of these functions were chosen by +# a somewhat arbitrary walk through an empirical tradeoff, for the +# purpose of testing the LKJ distribution. Setting the determinant +# bound lower +# - Covers more of the testee's sample space, and +# - Increases the probability that the rejection sampler will hit, thus +# - Decreases the relative error (at a fixed sample count) in the +# rejection-based volume estimate; +# but also +# - Increases the variance of the estimator used in the LKJ test. +# This latter variance is also affected by the dimension and the +# tested concentration parameter, and can be compensated for with more +# compute (expensive) or a looser discrepancy limit (unsatisfying). +# The values here are the projection of the points in that test design +# space that ended up getting chosen. +def compute_3x3_volumes(num_samples): + det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45] + return ctv_debatched( + det_bounds, 3, num_samples, error_rate=5e-7, seed=46) + + +def compute_4x4_volumes(num_samples): + det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45] + return ctv_debatched( + det_bounds, 4, num_samples, error_rate=5e-7, seed=47) + + +def compute_5x5_volumes(num_samples): + det_bounds = [0.01, 0.2, 0.25, 0.3, 0.35, 0.4] + return ctv_debatched( + det_bounds, 5, num_samples, error_rate=5e-7, seed=48) + + +def main(_): + full_bounds = {} + full_bounds[3] = compute_3x3_volumes(int(FLAGS.num_samples)) + full_bounds[4] = compute_4x4_volumes(int(FLAGS.num_samples)) + full_bounds[5] = compute_5x5_volumes(int(FLAGS.num_samples)) + pprint.pprint(full_bounds) + +if __name__ == '__main__': + app.run(main) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..455e71f00c96e799c4aaae25050c77a9ae36df06 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py @@ -0,0 +1,323 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Estimating the volume of the correlation matrices with bounded determinant. + +Why? Because lkj_test.py tests the sampler for the LKJ distribution +by estimating the same volume another way. + +How? Rejection sampling. Or, more precisely, importance sampling, +proposing from the uniform distribution on symmetric matrices with +diagonal 1s and entries in [-1, 1]. Such a matrix is a correlation +matrix if and only if it is also positive semi-definite. + +The samples can then be converted into a confidence interval on the +volume in question by the [Clopper-Pearson +method](https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval), +also implemented here. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import sys + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import uniform +from tensorflow.python.ops.distributions import util +from tensorflow.python.platform import tf_logging + +__all__ = [ + "correlation_matrix_volume_rejection_samples", + "compute_true_volumes", +] + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + +optimize = try_import("scipy.optimize") +stats = try_import("scipy.stats") + + +def _psd_mask(x): + """Computes whether each square matrix in the input is positive semi-definite. + + Args: + x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`. + + Returns: + mask: A floating-point `Tensor` of shape `[B1, ... Bn]`. Each + scalar is 1 if the corresponding matrix was PSD, otherwise 0. + """ + # Allegedly + # https://scicomp.stackexchange.com/questions/12979/testing-if-a-matrix-is-positive-semi-definite + # it is more efficient to test for positive semi-definiteness by + # trying to compute the Cholesky decomposition -- the matrix is PSD + # if you succeed and not PSD if you fail. However, TensorFlow's + # Cholesky raises an exception if _any_ of the input matrices are + # not PSD, from which I don't know how to extract _which ones_, so I + # proceed by explicitly computing all the eigenvalues and checking + # whether they are all positive or not. + # + # Also, as was discussed in the answer, it is somewhat dangerous to + # treat SPD-ness as binary in floating-point arithmetic. Cholesky + # factorization can complete and 'look' like everything is fine + # (e.g., O(1) entries and a diagonal of all ones) but the matrix can + # have an exponential condition number. + eigenvalues, _ = linalg_ops.self_adjoint_eig(x) + return math_ops.cast( + math_ops.reduce_min(eigenvalues, axis=-1) >= 0, dtype=x.dtype) + + +def _det_large_enough_mask(x, det_bounds): + """Returns whether the input matches the given determinant limit. + + Args: + x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`. + det_bounds: A floating-point `Tensor` that must broadcast to shape + `[B1, ..., Bn]`, giving the desired lower bound on the + determinants in `x`. + + Returns: + mask: A floating-point `Tensor` of shape [B1, ..., Bn]. Each + scalar is 1 if the corresponding matrix had determinant above + the corresponding bound, otherwise 0. + """ + # For the curious: I wonder whether it is possible and desirable to + # use a Cholesky decomposition-based algorithm for this, since the + # only matrices whose determinant this code cares about will be PSD. + # Didn't figure out how to code that in TensorFlow. + # + # Expert opinion is that it would be about twice as fast since + # Cholesky is roughly half the cost of Gaussian Elimination with + # Partial Pivoting. But this is less of an impact than the switch in + # _psd_mask. + return math_ops.cast( + linalg_ops.matrix_determinant(x) > det_bounds, dtype=x.dtype) + + +def _uniform_correlation_like_matrix(num_rows, batch_shape, dtype, seed): + """Returns a uniformly random `Tensor` of "correlation-like" matrices. + + A "correlation-like" matrix is a symmetric square matrix with all entries + between -1 and 1 (inclusive) and 1s on the main diagonal. Of these, + the ones that are positive semi-definite are exactly the correlation + matrices. + + Args: + num_rows: Python `int` dimension of the correlation-like matrices. + batch_shape: `Tensor` or Python `tuple` of `int` shape of the + batch to return. + dtype: `dtype` of the `Tensor` to return. + seed: Random seed. + + Returns: + matrices: A `Tensor` of shape `batch_shape + [num_rows, num_rows]` + and dtype `dtype`. Each entry is in [-1, 1], and each matrix + along the bottom two dimensions is symmetric and has 1s on the + main diagonal. + """ + num_entries = num_rows * (num_rows + 1) / 2 + ones = array_ops.ones(shape=[num_entries], dtype=dtype) + # It seems wasteful to generate random values for the diagonal since + # I am going to throw them away, but `fill_triangular` fills the + # diagonal, so I probably need them. + # It's not impossible that it would be more efficient to just fill + # the whole matrix with random values instead of messing with + # `fill_triangular`. Then would need to filter almost half out with + # `matrix_band_part`. + unifs = uniform.Uniform(-ones, ones).sample(batch_shape, seed=seed) + tril = util.fill_triangular(unifs) + symmetric = tril + array_ops.matrix_transpose(tril) + diagonal_ones = array_ops.ones( + shape=util.pad(batch_shape, axis=0, back=True, value=num_rows), + dtype=dtype) + return array_ops.matrix_set_diag(symmetric, diagonal_ones) + + +def correlation_matrix_volume_rejection_samples( + det_bounds, dim, sample_shape, dtype, seed): + """Returns rejection samples from trying to get good correlation matrices. + + The proposal being rejected from is the uniform distribution on + "correlation-like" matrices. We say a matrix is "correlation-like" + if it is a symmetric square matrix with all entries between -1 and 1 + (inclusive) and 1s on the main diagonal. Of these, the ones that + are positive semi-definite are exactly the correlation matrices. + + The rejection algorithm, then, is to sample a `Tensor` of + `sample_shape` correlation-like matrices of dimensions `dim` by + `dim`, and check each one for (i) being a correlation matrix (i.e., + PSD), and (ii) having determinant at least the corresponding entry + of `det_bounds`. + + Args: + det_bounds: A `Tensor` of lower bounds on the determinants of + acceptable matrices. The shape must broadcast with `sample_shape`. + dim: A Python `int` dimension of correlation matrices to sample. + sample_shape: Python `tuple` of `int` shape of the samples to + compute, excluding the two matrix dimensions. + dtype: The `dtype` in which to do the computation. + seed: Random seed. + + Returns: + weights: A `Tensor` of shape `sample_shape`. Each entry is 0 if the + corresponding matrix was not a correlation matrix, or had too + small of a determinant. Otherwise, the entry is the + multiplicative inverse of the density of proposing that matrix + uniformly, i.e., the volume of the set of `dim` by `dim` + correlation-like matrices. + volume: The volume of the set of `dim` by `dim` correlation-like + matrices. + """ + with ops.name_scope("rejection_sampler"): + rej_proposals = _uniform_correlation_like_matrix( + dim, sample_shape, dtype, seed=seed) + rej_proposal_volume = 2. ** (dim * (dim - 1) / 2.) + # The density of proposing any given point is 1 / rej_proposal_volume; + # The weight of that point should be scaled by + # 1 / density = rej_proposal_volume. + rej_weights = rej_proposal_volume * _psd_mask( + rej_proposals) * _det_large_enough_mask(rej_proposals, det_bounds) + return rej_weights, rej_proposal_volume + + +def _clopper_pearson_confidence_interval(samples, error_rate): + """Computes a confidence interval for the mean of the given 1-D distribution. + + Assumes (and checks) that the given distribution is Bernoulli, i.e., + takes only two values. This licenses using the CDF of the binomial + distribution for the confidence, which is tighter (for extreme + probabilities) than the DKWM inequality. The method is known as the + [Clopper-Pearson method] + (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval). + + Assumes: + + - The given samples were drawn iid from the distribution of interest. + + - The given distribution is a Bernoulli, i.e., supported only on + low and high. + + Guarantees: + + - The probability (over the randomness of drawing the given sample) + that the true mean is outside the returned interval is no more + than the given error_rate. + + Args: + samples: `np.ndarray` of samples drawn iid from the distribution + of interest. + error_rate: Python `float` admissible rate of mistakes. + + Returns: + low: Lower bound of confidence interval. + high: Upper bound of confidence interval. + + Raises: + ValueError: If `samples` has rank other than 1 (batch semantics + are not implemented), or if `samples` contains values other than + `low` or `high` (as that makes the distribution not Bernoulli). + """ + # TODO(b/78025336) Migrate this confidence interval function + # to statistical_testing.py. In order to do that + # - Get the binomial CDF from the Binomial distribution + # - Implement scalar root finding in TF. Batch bisection search + # shouldn't be too hard, and is definitely good enough for this + # problem. Batching the Brent algorithm (from scipy) that is used + # here may be more involved, but may also not be necessary---it's + # only used here because scipy made it convenient. In particular, + # robustness is more important than speed here, which may make + # bisection search actively better. + # - The rest is just a matter of rewriting in the appropriate style. + if optimize is None or stats is None: + raise ValueError( + "Scipy is required for computing Clopper-Pearson confidence intervals") + if len(samples.shape) != 1: + raise ValueError("Batch semantics not implemented") + n = len(samples) + low = np.amin(samples) + high = np.amax(samples) + successes = np.count_nonzero(samples - low) + failures = np.count_nonzero(samples - high) + if successes + failures != n: + uniques = np.unique(samples) + msg = ("Purportedly Bernoulli distribution had distinct samples" + " {}, {}, and {}".format(uniques[0], uniques[1], uniques[2])) + raise ValueError(msg) + def p_small_enough(p): + prob = stats.binom.logcdf(successes, n, p) + return prob - np.log(error_rate / 2.) + def p_big_enough(p): + prob = stats.binom.logsf(successes, n, p) + return prob - np.log(error_rate / 2.) + high_p = optimize.brentq( + p_small_enough, float(successes) / n, 1., rtol=1e-9) + low_p = optimize.brentq( + p_big_enough, 0., float(successes) / n, rtol=1e-9) + low_interval = low + (high - low) * low_p + high_interval = low + (high - low) * high_p + return (low_interval, high_interval) + + +def compute_true_volumes( + det_bounds, dim, num_samples, error_rate=1e-6, seed=42): + """Returns confidence intervals for the desired correlation matrix volumes. + + The confidence intervals are computed by the [Clopper-Pearson method] + (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval). + + Args: + det_bounds: A rank-1 numpy array of lower bounds on the + determinants of acceptable matrices. Entries must be unique. + dim: A Python `int` dimension of correlation matrices to sample. + num_samples: The number of samples to draw. + error_rate: The statistical significance of the returned + confidence intervals. The significance is broadcast: Each + returned interval separately may be incorrect with probability + (under the sample of correlation-like matrices drawn internally) + at most `error_rate`. + seed: Random seed. + + Returns: + bounds: A Python `dict` mapping each determinant bound to the low, high + tuple giving the confidence interval. + """ + bounds = {} + with session.Session() as sess: + rej_weights, _ = correlation_matrix_volume_rejection_samples( + det_bounds, dim, [num_samples, len(det_bounds)], np.float32, seed=seed) + rej_weights = sess.run(rej_weights) + for rw, det in zip(np.rollaxis(rej_weights, 1), det_bounds): + template = ("Estimating volume of {}x{} correlation " + "matrices with determinant >= {}.") + print(template.format(dim, dim, det)) + sys.stdout.flush() + bounds[det] = _clopper_pearson_confidence_interval( + rw, error_rate=error_rate) + return bounds diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8f99300e63871119800a42f122c8321e5986541a --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py @@ -0,0 +1,150 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for correlation_matrix_volumes_lib.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr +from tensorflow.contrib.distributions.python.ops import statistical_testing as st +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.platform import test + + +# NxN correlation matrices are determined by the N*(N-1)/2 +# lower-triangular entries. In addition to being between -1 and 1, +# they must also obey the constraint that the determinant of the +# resulting symmetric matrix is non-negative. In 2x2, we can even +# analytically compute the volume when the determinant is bounded to > +# epsilon, as that boils down to the one lower-triangular entry being +# less than 1 - epsilon in absolute value. +def two_by_two_volume(det_bound): + return 2 * np.sqrt(1.0 - det_bound) + + +# The post +# https://psychometroscar.com/the-volume-of-a-3-x-3-correlation-matrix/ +# derives (with elementary calculus) that the volume (with respect to +# Lebesgue^3 measure) of the set of 3x3 correlation matrices is +# pi^2/2. The same result is also obtained by [1]. +def three_by_three_volume(): + return np.pi**2 / 2. + + +# The volume of the unconstrained set of correlation matrices is also +# the normalization constant of the LKJ distribution from [2]. As +# part of defining the distribution, that reference a derives general +# formula for this volume for all dimensions. A TensorFlow +# computation thereof gave the below result for 4x4: +def four_by_four_volume(): + # This constant computed as math_ops.exp(lkj.log_norm_const(4, [1.0])) + return 11.6973076 + +# [1] Rousseeuw, P. J., & Molenberghs, G. (1994). "The shape of +# correlation matrices." The American Statistician, 48(4), 276-279. + +# [2] Daniel Lewandowski, Dorota Kurowicka, and Harry Joe, "Generating +# random correlation matrices based on vines and extended onion +# method," Journal of Multivariate Analysis 100 (2009), pp 1989-2001. + + +class CorrelationMatrixVolumesTest(test.TestCase): + + def testRejection2D(self): + num_samples = int(1e5) # Chosen for a small min detectable discrepancy + det_bounds = np.array( + [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32) + exact_volumes = two_by_two_volume(det_bounds) + (rej_weights, + rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( + det_bounds, 2, [num_samples, 9], dtype=np.float32, seed=43) + # shape of rej_weights: [num_samples, 9, 2, 2] + chk1 = st.assert_true_mean_equal_by_dkwm( + rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, + false_fail_rate=1e-6) + chk2 = check_ops.assert_less( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=rej_proposal_volume, + # Correct the false fail rate due to different broadcasting + false_fail_rate=1.1e-7, false_pass_rate=1e-6), + 0.036) + with ops.control_dependencies([chk1, chk2]): + rej_weights = array_ops.identity(rej_weights) + self.evaluate(rej_weights) + + def testRejection3D(self): + num_samples = int(1e5) # Chosen for a small min detectable discrepancy + det_bounds = np.array([0.0], dtype=np.float32) + exact_volumes = np.array([three_by_three_volume()], dtype=np.float32) + (rej_weights, + rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( + det_bounds, 3, [num_samples, 1], dtype=np.float32, seed=44) + # shape of rej_weights: [num_samples, 1, 3, 3] + chk1 = st.assert_true_mean_equal_by_dkwm( + rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, + false_fail_rate=1e-6) + chk2 = check_ops.assert_less( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=rej_proposal_volume, + false_fail_rate=1e-6, false_pass_rate=1e-6), + # Going for about a 3% relative error + 0.15) + with ops.control_dependencies([chk1, chk2]): + rej_weights = array_ops.identity(rej_weights) + self.evaluate(rej_weights) + + def testRejection4D(self): + num_samples = int(1e5) # Chosen for a small min detectable discrepancy + det_bounds = np.array([0.0], dtype=np.float32) + exact_volumes = [four_by_four_volume()] + (rej_weights, + rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( + det_bounds, 4, [num_samples, 1], dtype=np.float32, seed=45) + # shape of rej_weights: [num_samples, 1, 4, 4] + chk1 = st.assert_true_mean_equal_by_dkwm( + rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, + false_fail_rate=1e-6) + chk2 = check_ops.assert_less( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=rej_proposal_volume, + false_fail_rate=1e-6, false_pass_rate=1e-6), + # Going for about a 10% relative error + 1.1) + with ops.control_dependencies([chk1, chk2]): + rej_weights = array_ops.identity(rej_weights) + self.evaluate(rej_weights) + + def testVolumeEstimation2D(self): + # Test that the confidence intervals produced by + # corr.compte_true_volumes are sound, in the sense of containing + # the exact volume. + num_samples = int(1e5) # Chosen by symmetry with testRejection2D + det_bounds = np.array( + [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32) + volume_bounds = corr.compute_true_volumes( + det_bounds, 2, num_samples, error_rate=1e-6, seed=47) + exact_volumes = two_by_two_volume(det_bounds) + for det, volume in zip(det_bounds, exact_volumes): + computed_low, computed_high = volume_bounds[det] + self.assertLess(computed_low, volume) + self.assertGreater(computed_high, volume) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py index c355adeedbfff1072281a81de726ddb0ece07882..1226c66113ec4b43f57371abf4983aef1a529ec1 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py @@ -61,7 +61,7 @@ class VectorLaplaceDiagTest(test.TestCase): dist = ds.TransformedDistribution( base_dist, validate_args=True, - bijector=bijectors.Softplus(event_ndims=1)) + bijector=bijectors.Softplus()) samps = dist.sample(5) # Shape [5, 1, 3]. self.assertAllEqual([5, 1], dist.log_prob(samps).get_shape()) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py index 9044aa2850ae35f29cd48b0c5f54aa948bea0408..dcecce981f16a2d9e772d4e40062ff250725c3ac 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py @@ -390,6 +390,26 @@ class WishartCholeskyTest(test.TestCase): chol_scale, dtype=np.int32), validate_args=False) + def testSampleBroadcasts(self): + dims = 2 + batch_shape = [2, 3] + sample_shape = [2, 1] + scale = np.float32([ + [[1., 0.5], + [0.5, 1.]], + [[0.5, 0.25], + [0.25, 0.75]], + ]) + scale = np.reshape(np.concatenate([scale, scale, scale], axis=0), + batch_shape + [dims, dims]) + wishart = distributions.WishartFull(df=5, scale=scale) + x = wishart.sample(sample_shape, seed=42) + with self.test_session() as sess: + x_ = sess.run(x) + expected_shape = sample_shape + batch_shape + [dims, dims] + self.assertAllEqual(expected_shape, x.shape) + self.assertAllEqual(expected_shape, x_.shape) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py index 852298bf334666db003353d5fc8e172ffb738668..bb9b8043b2233b2109f51b5dde188d088fdb0d39 100644 --- a/tensorflow/contrib/distributions/python/ops/autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.framework import ops from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Autoregressive(distribution_lib.Distribution): @@ -36,7 +37,8 @@ class Autoregressive(distribution_lib.Distribution): "Autoregressive models decompose the joint density as a product of conditionals, and model each conditional in turn. Normalizing flows transform a base density (e.g. a standard Gaussian) into the target density - by an invertible transformation with tractable Jacobian." [1] + by an invertible transformation with tractable Jacobian." [(Papamakarios et + al., 2016)][1] In other words, the "autoregressive property" is equivalent to the decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided @@ -45,17 +47,18 @@ class Autoregressive(distribution_lib.Distribution): Practically speaking the autoregressive property means that there exists a permutation of the event coordinates such that each coordinate is a - diffeomorphic function of only preceding coordinates. [2] + diffeomorphic function of only preceding coordinates + [(van den Oord et al., 2016)][2]. #### Mathematical Details - The probability function is, + The probability function is ```none prob(x; fn, n) = fn(x).prob(x) ``` - And a sample is generated by, + And a sample is generated by ```none x = fn(...fn(fn(x0).sample()).sample()).sample() @@ -93,16 +96,26 @@ class Autoregressive(distribution_lib.Distribution): ``` - [1]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 + #### References - [2]: "Conditional Image Generation with PixelCNN Decoders." - Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, Alex - Graves, Koray Kavukcuoglu. Arxiv, 2016. + [1]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 + + [2]: Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, + Alex Graves, and Koray Kavukcuoglu. Conditional Image Generation with + PixelCNN Decoders. In _Neural Information Processing Systems_, 2016. https://arxiv.org/abs/1606.05328 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, distribution_fn, sample0=None, @@ -140,8 +153,8 @@ class Autoregressive(distribution_lib.Distribution): `distribution_fn(sample0).event_shape.num_elements()` are both `None`. ValueError: if `num_steps < 1`. """ - parameters = locals() - with ops.name_scope(name): + parameters = dict(locals()) + with ops.name_scope(name) as name: self._distribution_fn = distribution_fn self._sample0 = sample0 self._distribution0 = (distribution_fn() if sample0 is None diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..519077bc9ab1063a1135486cfae34656f3f68157 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -0,0 +1,432 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The BatchReshape distribution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.util import deprecation + + +__all__ = [ + "BatchReshape", +] + + +class BatchReshape(distribution_lib.Distribution): + """The Batch-Reshaping distribution. + + This "meta-distribution" reshapes the batch dimensions of another + distribution. + + #### Examples + + ```python + tfd = tf.contrib.distributions + + dtype = np.float32 + dims = 2 + new_batch_shape = [1, 2, -1] + old_batch_shape = [6] + + scale = np.ones(old_batch_shape + [dims], dtype) + mvn = tfd.MultivariateNormalDiag(scale_diag=scale) + reshape_mvn = tfd.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape, + validate_args=True) + + reshape_mvn.batch_shape + # ==> [1, 2, 3] + + x = reshape_mvn.sample(sample_shape=[4, 5]) + x.shape + # ==> [4, 5, 1, 2, 3, 2] == sample_shape + new_batch_shape + [dims] + + reshape_mvn.log_prob(x).shape + # ==> [4, 5, 1, 2, 3] == sample_shape + new_batch_shape + ``` + + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + distribution, + batch_shape, + validate_args=False, + allow_nan_stats=True, + name=None): + """Construct BatchReshape distribution. + + Args: + distribution: The base distribution instance to reshape. Typically an + instance of `Distribution`. + batch_shape: Positive `int`-like vector-shaped `Tensor` representing + the new shape of the batch dimensions. Up to one dimension may contain + `-1`, meaning the remainder of the batch size. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: The name to give Ops created by the initializer. + Default value: `"BatchReshape" + distribution.name`. + + Raises: + ValueError: if `batch_shape` is not a vector. + ValueError: if `batch_shape` has non-positive elements. + ValueError: if `batch_shape` size is not the same as a + `distribution.batch_shape` size. + """ + parameters = dict(locals()) + name = name or "BatchReshape" + distribution.name + with ops.name_scope(name, values=[batch_shape]) as name: + # The unexpanded batch shape may contain up to one dimension of -1. + self._batch_shape_unexpanded = ops.convert_to_tensor( + batch_shape, dtype=dtypes.int32, name="batch_shape") + validate_init_args_statically(distribution, self._batch_shape_unexpanded) + batch_shape, batch_shape_static, runtime_assertions = calculate_reshape( + distribution.batch_shape_tensor(), self._batch_shape_unexpanded, + validate_args) + self._distribution = distribution + self._batch_shape_ = batch_shape + self._batch_shape_static = batch_shape_static + self._runtime_assertions = runtime_assertions + super(BatchReshape, self).__init__( + dtype=distribution.dtype, + reparameterization_type=distribution.reparameterization_type, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=( + [self._batch_shape_unexpanded] + distribution._graph_parents), # pylint: disable=protected-access + name=name) + + @property + def distribution(self): + return self._distribution + + def _batch_shape_tensor(self): + with ops.control_dependencies(self._runtime_assertions): + return array_ops.identity(self._batch_shape_) + + def _batch_shape(self): + return self._batch_shape_static + + def _event_shape_tensor(self): + with ops.control_dependencies(self._runtime_assertions): + return array_ops.identity(self.distribution.event_shape_tensor()) + + def _event_shape(self): + return self.distribution.event_shape + + def _sample_n(self, n, seed=None): + with ops.control_dependencies(self._runtime_assertions): + x = self.distribution.sample(sample_shape=n, seed=seed) + new_shape = array_ops.concat( + [ + [n], + self._batch_shape_unexpanded, + self.event_shape_tensor(), + ], + axis=0) + return array_ops.reshape(x, new_shape) + + def _log_prob(self, x): + return self._call_reshape_input_output( + self.distribution.log_prob, x) + + def _prob(self, x): + return self._call_reshape_input_output( + self.distribution.prob, x) + + def _log_cdf(self, x): + return self._call_reshape_input_output( + self.distribution.log_cdf, x) + + def _cdf(self, x): + return self._call_reshape_input_output( + self.distribution.cdf, x) + + def _log_survival_function(self, x): + return self._call_reshape_input_output( + self.distribution.log_survival_function, x) + + def _survival_function(self, x): + return self._call_reshape_input_output( + self.distribution.survival_function, x) + + def _entropy(self): + return self._call_and_reshape_output( + self.distribution.entropy, + [], + [tensor_shape.scalar()]) + + def _mean(self): + return self._call_and_reshape_output(self.distribution.mean) + + def _mode(self): + return self._call_and_reshape_output(self.distribution.mode) + + def _stddev(self): + return self._call_and_reshape_output(self.distribution.stddev) + + def _variance(self): + return self._call_and_reshape_output(self.distribution.variance) + + def _covariance(self): + return self._call_and_reshape_output( + self.distribution.covariance, + [self.event_shape_tensor()]*2, + [self.event_shape]*2) + + def _sample_shape(self, x): + """Computes graph and static `sample_shape`.""" + x_ndims = (array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims) + event_ndims = (array_ops.size(self.event_shape_tensor()) + if self.event_shape.ndims is None + else self.event_shape.ndims) + batch_ndims = ( + array_ops.size(self._batch_shape_unexpanded) + if self.batch_shape.ndims is None else self.batch_shape.ndims) + sample_ndims = x_ndims - batch_ndims - event_ndims + if isinstance(sample_ndims, int): + static_sample_shape = x.shape[:sample_ndims] + else: + static_sample_shape = tensor_shape.TensorShape(None) + if static_sample_shape.is_fully_defined(): + sample_shape = np.int32(static_sample_shape.as_list()) + else: + sample_shape = array_ops.shape(x)[:sample_ndims] + return sample_shape, static_sample_shape + + def _call_reshape_input_output(self, fn, x): + """Calls `fn`, appropriately reshaping its input `x` and output.""" + with ops.control_dependencies( + self._runtime_assertions + self._validate_sample_arg(x)): + sample_shape, static_sample_shape = self._sample_shape(x) + old_shape = array_ops.concat([ + sample_shape, + self.distribution.batch_shape_tensor(), + self.event_shape_tensor(), + ], axis=0) + result = fn(array_ops.reshape(x, old_shape)) + new_shape = array_ops.concat( + [ + sample_shape, + self._batch_shape_unexpanded, + ], axis=0) + result = array_ops.reshape(result, new_shape) + if (static_sample_shape.ndims is not None and + self.batch_shape.ndims is not None): + new_shape = static_sample_shape.concatenate(self.batch_shape) + result.set_shape(result.shape.merge_with(new_shape)) + return result + + def _call_and_reshape_output( + self, + fn, + event_shape_list=None, + static_event_shape_list=None): + """Calls `fn` and appropriately reshapes its output.""" + with ops.control_dependencies(self._runtime_assertions): + if event_shape_list is None: + event_shape_list = [self._event_shape_tensor()] + if static_event_shape_list is None: + static_event_shape_list = [self.event_shape] + new_shape = array_ops.concat( + [self._batch_shape_unexpanded] + event_shape_list, axis=0) + result = array_ops.reshape(fn(), new_shape) + if (self.batch_shape.ndims is not None and + self.event_shape.ndims is not None): + event_shape = tensor_shape.TensorShape([]) + for rss in static_event_shape_list: + event_shape = event_shape.concatenate(rss) + static_shape = result.shape.merge_with( + self.batch_shape.concatenate(event_shape)) + result.set_shape(static_shape) + return result + + def _validate_sample_arg(self, x): + """Helper which validates sample arg, e.g., input to `log_prob`.""" + with ops.name_scope(name="validate_sample_arg", values=[x]): + x_ndims = (array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims) + event_ndims = (array_ops.size(self.event_shape_tensor()) + if self.event_shape.ndims is None + else self.event_shape.ndims) + batch_ndims = ( + array_ops.size(self._batch_shape_unexpanded) + if self.batch_shape.ndims is None else self.batch_shape.ndims) + expected_batch_event_ndims = batch_ndims + event_ndims + + if (isinstance(x_ndims, int) and + isinstance(expected_batch_event_ndims, int)): + if x_ndims < expected_batch_event_ndims: + raise NotImplementedError( + "Broadcasting is not supported; too few batch and event dims " + "(expected at least {}, saw {}).".format( + expected_batch_event_ndims, x_ndims)) + ndims_assertion = [] + elif self.validate_args: + ndims_assertion = [ + check_ops.assert_greater_equal( + x_ndims, + expected_batch_event_ndims, + message=("Broadcasting is not supported; too few " + "batch and event dims."), + name="assert_batch_and_event_ndims_large_enough"), + ] + + if (self.batch_shape.is_fully_defined() and + self.event_shape.is_fully_defined()): + expected_batch_event_shape = np.int32(self.batch_shape.concatenate( + self.event_shape).as_list()) + else: + expected_batch_event_shape = array_ops.concat([ + self.batch_shape_tensor(), + self.event_shape_tensor(), + ], axis=0) + + sample_ndims = x_ndims - expected_batch_event_ndims + if isinstance(sample_ndims, int): + sample_ndims = max(sample_ndims, 0) + if (isinstance(sample_ndims, int) and + x.shape[sample_ndims:].is_fully_defined()): + actual_batch_event_shape = np.int32(x.shape[sample_ndims:].as_list()) + else: + sample_ndims = math_ops.maximum(sample_ndims, 0) + actual_batch_event_shape = array_ops.shape(x)[sample_ndims:] + + if (isinstance(expected_batch_event_shape, np.ndarray) and + isinstance(actual_batch_event_shape, np.ndarray)): + if any(expected_batch_event_shape != actual_batch_event_shape): + raise NotImplementedError("Broadcasting is not supported; " + "unexpected batch and event shape " + "(expected {}, saw {}).".format( + expected_batch_event_shape, + actual_batch_event_shape)) + # We need to set the final runtime-assertions to `ndims_assertion` since + # its possible this assertion was created. We could add a condition to + # only do so if `self.validate_args == True`, however this is redundant + # as `ndims_assertion` already encodes this information. + runtime_assertions = ndims_assertion + elif self.validate_args: + # We need to make the `ndims_assertion` a control dep because otherwise + # TF itself might raise an exception owing to this assertion being + # ill-defined, ie, one cannot even compare different rank Tensors. + with ops.control_dependencies(ndims_assertion): + shape_assertion = check_ops.assert_equal( + expected_batch_event_shape, + actual_batch_event_shape, + message=("Broadcasting is not supported; " + "unexpected batch and event shape."), + name="assert_batch_and_event_shape_same") + runtime_assertions = [shape_assertion] + else: + runtime_assertions = [] + + return runtime_assertions + + +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) +def calculate_reshape(original_shape, new_shape, validate=False, name=None): + """Calculates the reshaped dimensions (replacing up to one -1 in reshape).""" + batch_shape_static = tensor_util.constant_value_as_shape(new_shape) + if batch_shape_static.is_fully_defined(): + return np.int32(batch_shape_static.as_list()), batch_shape_static, [] + with ops.name_scope(name, "calculate_reshape", [original_shape, new_shape]): + original_size = math_ops.reduce_prod(original_shape) + implicit_dim = math_ops.equal(new_shape, -1) + size_implicit_dim = ( + original_size // math_ops.maximum(1, -math_ops.reduce_prod(new_shape))) + new_ndims = array_ops.shape(new_shape) + expanded_new_shape = array_ops.where( # Assumes exactly one `-1`. + implicit_dim, array_ops.fill(new_ndims, size_implicit_dim), new_shape) + validations = [] if not validate else [ + check_ops.assert_rank( + original_shape, 1, message="Original shape must be a vector."), + check_ops.assert_rank( + new_shape, 1, message="New shape must be a vector."), + check_ops.assert_less_equal( + math_ops.count_nonzero(implicit_dim, dtype=dtypes.int32), + 1, + message="At most one dimension can be unknown."), + check_ops.assert_positive( + expanded_new_shape, message="Shape elements must be >=-1."), + check_ops.assert_equal( + math_ops.reduce_prod(expanded_new_shape), + original_size, + message="Shape sizes do not match."), + ] + return expanded_new_shape, batch_shape_static, validations + + +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) +def validate_init_args_statically(distribution, batch_shape): + """Helper to __init__ which makes or raises assertions.""" + if batch_shape.shape.ndims is not None: + if batch_shape.shape.ndims != 1: + raise ValueError("`batch_shape` must be a vector " + "(saw rank: {}).".format(batch_shape.shape.ndims)) + + batch_shape_static = tensor_util.constant_value_as_shape(batch_shape) + batch_size_static = batch_shape_static.num_elements() + dist_batch_size_static = distribution.batch_shape.num_elements() + + if batch_size_static is not None and dist_batch_size_static is not None: + if batch_size_static != dist_batch_size_static: + raise ValueError("`batch_shape` size ({}) must match " + "`distribution.batch_shape` size ({}).".format( + batch_size_static, dist_batch_size_static)) + + if batch_shape_static.dims is not None: + if any( + dim.value is not None and dim.value < 1 for dim in batch_shape_static): + raise ValueError("`batch_shape` elements must be >=-1.") diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 93923c3f083c7f5136b55e9021cbd6323684b976..e141f8b5c6423bd6cce4d09da6f49d55b3e25a24 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -17,25 +17,34 @@ @@AbsoluteValue @@Affine @@AffineLinearOperator +@@AffineScalar @@Bijector +@@BatchNormalization @@Chain @@CholeskyOuterProduct @@ConditionalBijector @@Exp +@@FillTriangular @@Gumbel @@Identity @@Inline @@Invert +@@Kumaraswamy @@MaskedAutoregressiveFlow +@@MatrixInverseTriL +@@Ordered @@Permute @@PowerTransform @@RealNVP @@Reshape +@@ScaleTriL @@Sigmoid -@@SigmoidCentered @@SinhArcsinh @@SoftmaxCentered @@Softplus +@@Softsign +@@Square +@@TransformDiagonal @@Weibull @@masked_autoregressive_default_template @@ -52,23 +61,32 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value import * from tensorflow.contrib.distributions.python.ops.bijectors.affine import * from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import * +from tensorflow.contrib.distributions.python.ops.bijectors.affine_scalar import * +from tensorflow.contrib.distributions.python.ops.bijectors.batch_normalization import * from tensorflow.contrib.distributions.python.ops.bijectors.chain import * from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import * from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import * from tensorflow.contrib.distributions.python.ops.bijectors.exp import * +from tensorflow.contrib.distributions.python.ops.bijectors.fill_triangular import * from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * +from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import * from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import * +from tensorflow.contrib.distributions.python.ops.bijectors.matrix_inverse_tril import * +from tensorflow.contrib.distributions.python.ops.bijectors.ordered import * from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import * from tensorflow.contrib.distributions.python.ops.bijectors.reshape import * +from tensorflow.contrib.distributions.python.ops.bijectors.scale_tril import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import * -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import * from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.softplus import * +from tensorflow.contrib.distributions.python.ops.bijectors.softsign import * +from tensorflow.contrib.distributions.python.ops.bijectors.square import * +from tensorflow.contrib.distributions.python.ops.bijectors.transform_diagonal import * from tensorflow.python.ops.distributions.bijector import * from tensorflow.python.ops.distributions.identity_bijector import Identity diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py index 0fe9f6aa78fbe845b99d0668f075b0162ec2a9f7..4d6a46e7358933fdf512f49eae2673f35953c90a 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py @@ -18,13 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops +from tensorflow.python.framework import constant_op from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "AbsoluteValue", @@ -72,38 +71,30 @@ class AbsoluteValue(bijector.Bijector): """ - def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"): + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, validate_args=False, name="absolute_value"): """Instantiates the `AbsoluteValue` bijector. Args: - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. Currently only zero is - supported. validate_args: Python `bool` indicating whether arguments should be checked for correctness, in particular whether inputs to `inverse` and `inverse_log_det_jacobian` are non-negative. name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: If `event_ndims` is not zero. """ self._graph_parents = [] self._name = name - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0,): - raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) - else: - if validate_args: - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_equal( - event_ndims, 0, message="event_ndims was not 0")], - event_ndims) - with self._name_scope("init"): super(AbsoluteValue, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=0, + is_constant_jacobian=True, validate_args=validate_args, name=name) @@ -121,8 +112,7 @@ class AbsoluteValue(bijector.Bijector): # If event_ndims = 2, # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1), # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0]. - batch_shape = array_ops.shape(y)[:array_ops.rank(y) - self.event_ndims] - zeros = array_ops.zeros(batch_shape, dtype=y.dtype) + zeros = constant_op.constant(0., dtype=y.dtype) if self.validate_args: zeros = control_flow_ops.with_dependencies( [check_ops.assert_non_negative(y, message="Argument y was negative")], diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index 05bb9c2f9bdf35e222c94db3491157893da64ebd..25f29452c3949600b8a4153a8585dd7269bd3b2b 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -36,6 +37,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _as_tensor(x, name): """Convenience to convert to `Tensor` or leave as `None`.""" return None if x is None else ops.convert_to_tensor(x, name=name) @@ -62,7 +71,7 @@ class Affine(bijector.Bijector): matrices, i.e., the matmul is [matrix-free]( https://en.wikipedia.org/wiki/Matrix-free_methods) when possible. - Examples: + #### Examples ```python # Y = X @@ -97,6 +106,14 @@ class Affine(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift=None, scale_identity_multiplier=None, @@ -104,7 +121,6 @@ class Affine(bijector.Bijector): scale_tril=None, scale_perturb_factor=None, scale_perturb_diag=None, - event_ndims=1, validate_args=False, name="affine"): """Instantiates the `Affine` bijector. @@ -157,8 +173,6 @@ class Affine(bijector.Bijector): matrix. `scale_perturb_diag` has shape [N1, N2, ... r], which represents an `r x r` diagonal matrix. When `None` low rank updates will take the form `scale_perturb_factor * scale_perturb_factor.T`. - event_ndims: Scalar `int` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. Must be 0 or 1. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. @@ -187,22 +201,6 @@ class Affine(bijector.Bijector): with self._name_scope("init", values=[ shift, scale_identity_multiplier, scale_diag, scale_tril, scale_perturb_diag, scale_perturb_factor]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0, 1): - raise ValueError("event_ndims(%s) was not 0 or 1" % event_ndims_const) - else: - if validate_args: - # Shape tool will catch if event_ndims is negative. - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_less( - event_ndims, 2, message="event_ndims must be 0 or 1")], - event_ndims) - - if event_ndims_const == 0 and not self._is_only_identity_multiplier: - raise ValueError( - "If event_ndims == 0, the only scale argument you can pass is " - "scale_identity_multiplier. All others operate on vectors.") # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. dtype = dtypes.float32 @@ -251,12 +249,11 @@ class Affine(bijector.Bijector): self._scale = scale self._shaper = _DistributionShape( batch_ndims=batch_ndims, - event_ndims=event_ndims, + event_ndims=1, validate_args=validate_args) super(Affine, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=1, graph_parents=( - [event_ndims] + [self._scale] if tensor_util.is_tensor(self._scale) else self._scale.graph_parents + [self._shift] if self._shift is not None else []), @@ -381,18 +378,17 @@ class Affine(bijector.Bijector): x, sample_shape, expand_batch_dim=False) return x - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(y) - def _forward_log_det_jacobian(self, x): + # is_constant_jacobian = True for this bijector, hence the + # `log_det_jacobian` need only be specified for a single input, as this will + # be tiled to match `event_ndims`. if self._is_only_identity_multiplier: # We don't pad in this case and instead let the fldj be applied # via broadcast. - event_size = distribution_util.pick_vector( - math_ops.equal(self._shaper.event_ndims, 0), - [1], array_ops.shape(x))[-1] + event_size = array_ops.shape(x)[-1] event_size = math_ops.cast(event_size, dtype=self._scale.dtype) return math_ops.log(math_ops.abs(self._scale)) * event_size + return self.scale.log_abs_determinant() def _maybe_check_scale(self): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py index 89043b1410370074f11f2cfa59b6b6663fa62521..91301f15ad87e133777371b346864ecf7b964f27 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py @@ -22,11 +22,9 @@ from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.util import deprecation __all__ = [ @@ -91,10 +89,17 @@ class AffineLinearOperator(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift=None, scale=None, - event_ndims=1, validate_args=False, name="affine_linear_operator"): """Instantiates the `AffineLinearOperator` bijector. @@ -103,14 +108,11 @@ class AffineLinearOperator(bijector.Bijector): shift: Floating-point `Tensor`. scale: Subclass of `LinearOperator`. Represents the (batch) positive definite matrix `M` in `R^{k x k}`. - event_ndims: Scalar `integer` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. Must be 0 or 1. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. Raises: - ValueError: if `event_ndims` is not 0 or 1. TypeError: if `scale` is not a `LinearOperator`. TypeError: if `shift.dtype` does not match `scale.dtype`. ValueError: if not `scale.is_non_singular`. @@ -120,20 +122,6 @@ class AffineLinearOperator(bijector.Bijector): self._validate_args = validate_args graph_parents = [] with self._name_scope("init", values=[shift]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - if tensor_util.constant_value(event_ndims) is not None: - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims not in (0, 1): - raise ValueError("event_ndims({}) was not 0 or 1".format(event_ndims)) - else: - if validate_args: - # Shape tool will catch if event_ndims is negative. - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_less( - event_ndims, 2, message="event_ndims must be 0 or 1")], - event_ndims) - graph_parents += [event_ndims] - # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. dtype = dtypes.float32 @@ -166,10 +154,10 @@ class AffineLinearOperator(bijector.Bijector): self._scale = scale self._shaper = _DistributionShape( batch_ndims=batch_ndims, - event_ndims=event_ndims, + event_ndims=1, validate_args=validate_args) super(AffineLinearOperator, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=1, graph_parents=graph_parents, is_constant_jacobian=True, dtype=dtype, @@ -213,12 +201,13 @@ class AffineLinearOperator(bijector.Bijector): x, sample_shape, expand_batch_dim=False) return x - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(y) - - def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument + def _forward_log_det_jacobian(self, x): + # is_constant_jacobian = True for this bijector, hence the + # `log_det_jacobian` need only be specified for a single input, as this will + # be tiled to match `event_ndims`. if self.scale is None: - return constant_op.constant(0, dtype=x.dtype.base_dtype) + return constant_op.constant(0., dtype=x.dtype.base_dtype) + with ops.control_dependencies(self._maybe_collect_assertions() if self.validate_args else []): return self.scale.log_abs_determinant() diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py new file mode 100644 index 0000000000000000000000000000000000000000..460d906231bd30f8cec4fe21d42afe7b2a05805e --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py @@ -0,0 +1,150 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Affine bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation + + +__all__ = [ + "AffineScalar", +] + + +class AffineScalar(bijector.Bijector): + """Compute `Y = g(X; shift, scale) = scale * X + shift`. + + Examples: + + ```python + # Y = X + b = AffineScalar() + + # Y = X + shift + b = AffineScalar(shift=[1., 2, 3]) + + # Y = 2 * X + shift + b = AffineScalar( + shift=[1., 2, 3], + scale=2.) + ``` + + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + shift=None, + scale=None, + validate_args=False, + name="affine_scalar"): + """Instantiates the `AffineScalar` bijector. + + This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments, + giving the forward operation: + + ```none + Y = g(X) = scale * X + shift + ``` + + if `scale` is not specified, then the bijector has the semantics of + `scale = 1.`. Similarly, if `shift` is not specified, then the bijector + has the semantics of `shift = 0.`. + + Args: + shift: Floating-point `Tensor`. If this is set to `None`, no shift is + applied. + scale: Floating-point `Tensor`. If this is set to `None`, no scale is + applied. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + + with self._name_scope("init", values=[scale, shift]): + self._shift = shift + self._scale = scale + + if self._shift is not None: + self._shift = ops.convert_to_tensor(shift, name="shift") + + if self._scale is not None: + self._scale = ops.convert_to_tensor(self._scale, name="scale") + if validate_args: + self._scale = control_flow_ops.with_dependencies( + [check_ops.assert_none_equal( + self._scale, + array_ops.zeros([], dtype=self._scale.dtype))], + self._scale) + + super(AffineScalar, self).__init__( + forward_min_event_ndims=0, + is_constant_jacobian=True, + validate_args=validate_args, + name=name) + + @property + def shift(self): + """The `shift` `Tensor` in `Y = scale @ X + shift`.""" + return self._shift + + @property + def scale(self): + """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" + return self._scale + + def _forward(self, x): + y = array_ops.identity(x) + if self.scale is not None: + y *= self.scale + if self.shift is not None: + y += self.shift + return y + + def _inverse(self, y): + x = array_ops.identity(y) + if self.shift is not None: + x -= self.shift + if self.scale is not None: + x /= self.scale + return x + + def _forward_log_det_jacobian(self, x): + # is_constant_jacobian = True for this bijector, hence the + # `log_det_jacobian` need only be specified for a single input, as this will + # be tiled to match `event_ndims`. + if self.scale is None: + return constant_op.constant(0., dtype=x.dtype.base_dtype) + + return math_ops.log(math_ops.abs(self.scale)) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..f19f147dd645b4f805f1905899b44293284d4225 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py @@ -0,0 +1,280 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Batch Norm bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.layers import normalization +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation + + +__all__ = [ + "BatchNormalization", +] + + +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) +def _undo_batch_normalization(x, + mean, + variance, + offset, + scale, + variance_epsilon, + name=None): + r"""Inverse of tf.nn.batch_normalization. + + Args: + x: Input `Tensor` of arbitrary dimensionality. + mean: A mean `Tensor`. + variance: A variance `Tensor`. + offset: An offset `Tensor`, often denoted `beta` in equations, or + None. If present, will be added to the normalized tensor. + scale: A scale `Tensor`, often denoted `gamma` in equations, or + `None`. If present, the scale is applied to the normalized tensor. + variance_epsilon: A small `float` added to the minibatch `variance` to + prevent dividing by zero. + name: A name for this operation (optional). + + Returns: + batch_unnormalized: The de-normalized, de-scaled, de-offset `Tensor`. + """ + with ops.name_scope( + name, "undo_batchnorm", [x, mean, variance, scale, offset]): + # inv = math_ops.rsqrt(variance + variance_epsilon) + # if scale is not None: + # inv *= scale + # return x * inv + ( + # offset - mean * inv if offset is not None else -mean * inv) + rescale = math_ops.sqrt(variance + variance_epsilon) + if scale is not None: + rescale /= scale + batch_unnormalized = x * rescale + ( + mean - offset * rescale if offset is not None else mean) + return batch_unnormalized + + +class BatchNormalization(bijector.Bijector): + """Compute `Y = g(X) s.t. X = g^-1(Y) = (Y - mean(Y)) / std(Y)`. + + Applies Batch Normalization [(Ioffe and Szegedy, 2015)][1] to samples from a + data distribution. This can be used to stabilize training of normalizing + flows ([Papamakarios et al., 2016][3]; [Dinh et al., 2017][2]) + + When training Deep Neural Networks (DNNs), it is common practice to + normalize or whiten features by shifting them to have zero mean and + scaling them to have unit variance. + + The `inverse()` method of the `BatchNormalization` bijector, which is used in + the log-likelihood computation of data samples, implements the normalization + procedure (shift-and-scale) using the mean and standard deviation of the + current minibatch. + + Conversely, the `forward()` method of the bijector de-normalizes samples (e.g. + `X*std(Y) + mean(Y)` with the running-average mean and standard deviation + computed at training-time. De-normalization is useful for sampling. + + ```python + + dist = tfd.TransformedDistribution( + distribution=tfd.Normal()), + bijector=tfb.BatchNorm()) + + y = tfd.MultivariateNormalDiag(loc=1., scale=2.).sample(100) # ~ N(1, 2) + x = dist.bijector.inverse(y) # ~ N(0, 1) + y = dist.sample() # ~ N(1, 2) + ``` + + During training time, `BatchNorm.inverse` and `BatchNorm.forward` are not + guaranteed to be inverses of each other because `inverse(y)` uses statistics + of the current minibatch, while `forward(x)` uses running-average statistics + accumulated from training. In other words, + `BatchNorm.inverse(BatchNorm.forward(...))` and + `BatchNorm.forward(BatchNorm.inverse(...))` will be identical when + `training=False` but may be different when `training=True`. + + #### References + + [1]: Sergey Ioffe and Christian Szegedy. Batch Normalization: Accelerating + Deep Network Training by Reducing Internal Covariate Shift. In + _International Conference on Machine Learning_, 2015. + https://arxiv.org/abs/1502.03167 + + [2]: Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio. Density Estimation + using Real NVP. In _International Conference on Learning + Representations_, 2017. https://arxiv.org/abs/1605.08803 + + [3]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + batchnorm_layer=None, + training=True, + validate_args=False, + name="batch_normalization"): + """Instantiates the `BatchNorm` bijector. + + Args: + batchnorm_layer: `tf.layers.BatchNormalization` layer object. If `None`, + defaults to + `tf.layers.BatchNormalization(gamma_constraint=nn_ops.relu(x) + 1e-6)`. + This ensures positivity of the scale variable. + + training: If True, updates running-average statistics during call to + `inverse()`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + Raises: + ValueError: If bn_layer is not an instance of + `tf.layers.BatchNormalization`, or if it is specified with `renorm=True` + or a virtual batch size. + """ + # Scale must be positive. + g_constraint = lambda x: nn.relu(x) + 1e-6 + self.batchnorm = batchnorm_layer or normalization.BatchNormalization( + gamma_constraint=g_constraint) + self._validate_bn_layer(self.batchnorm) + self._training = training + if isinstance(self.batchnorm.axis, int): + forward_min_event_ndims = 1 + else: + forward_min_event_ndims = len(self.batchnorm.axis) + super(BatchNormalization, self).__init__( + forward_min_event_ndims=forward_min_event_ndims, + validate_args=validate_args, name=name) + + def _validate_bn_layer(self, layer): + """Check for valid BatchNormalization layer. + + Args: + layer: Instance of `tf.layers.BatchNormalization`. + Raises: + ValueError: If batchnorm_layer argument is not an instance of + `tf.layers.BatchNormalization`, or if `batchnorm_layer.renorm=True` or + if `batchnorm_layer.virtual_batch_size` is specified. + """ + if not isinstance(layer, normalization.BatchNormalization): + raise ValueError( + "batchnorm_layer must be an instance of BatchNormalization layer.") + if layer.renorm: + raise ValueError("BatchNorm Bijector does not support renormalization.") + if layer.virtual_batch_size: + raise ValueError( + "BatchNorm Bijector does not support virtual batch sizes.") + + def _get_broadcast_fn(self, x): + # Compute shape to broadcast scale/shift parameters to. + if not x.shape.is_fully_defined(): + raise ValueError("Input must have shape known at graph construction.") + input_shape = np.int32(x.shape.as_list()) + + ndims = len(input_shape) + reduction_axes = [i for i in range(ndims) if i not in self.batchnorm.axis] + # Broadcasting only necessary for single-axis batch norm where the axis is + # not the last dimension + broadcast_shape = [1] * ndims + broadcast_shape[self.batchnorm.axis[0]] = ( + input_shape[self.batchnorm.axis[0]]) + def _broadcast(v): + if (v is not None and + len(v.get_shape()) != ndims and + reduction_axes != list(range(ndims - 1))): + return array_ops.reshape(v, broadcast_shape) + return v + return _broadcast + + def _normalize(self, y): + return self.batchnorm.apply(y, training=self._training) + + def _de_normalize(self, x): + # Uses the saved statistics. + if not self.batchnorm.built: + input_shape = x.get_shape() + self.batchnorm.build(input_shape) + broadcast_fn = self._get_broadcast_fn(x) + mean = broadcast_fn(self.batchnorm.moving_mean) + variance = broadcast_fn(self.batchnorm.moving_variance) + beta = broadcast_fn(self.batchnorm.beta) if self.batchnorm.center else None + gamma = broadcast_fn(self.batchnorm.gamma) if self.batchnorm.scale else None + return _undo_batch_normalization( + x, mean, variance, beta, gamma, self.batchnorm.epsilon) + + def _forward(self, x): + return self._de_normalize(x) + + def _inverse(self, y): + return self._normalize(y) + + def _forward_log_det_jacobian(self, x): + # Uses saved statistics to compute volume distortion. + return -self._inverse_log_det_jacobian(x, use_saved_statistics=True) + + def _inverse_log_det_jacobian(self, y, use_saved_statistics=False): + if not y.shape.is_fully_defined(): + raise ValueError("Input must have shape known at graph construction.") + input_shape = np.int32(y.shape.as_list()) + + if not self.batchnorm.built: + # Create variables. + self.batchnorm.build(input_shape) + + event_dims = self.batchnorm.axis + reduction_axes = [i for i in range(len(input_shape)) if i not in event_dims] + + if use_saved_statistics or not self._training: + log_variance = math_ops.log( + self.batchnorm.moving_variance + self.batchnorm.epsilon) + else: + # At training-time, ildj is computed from the mean and log-variance across + # the current minibatch. + _, v = nn.moments(y, axes=reduction_axes, keep_dims=True) + log_variance = math_ops.log(v + self.batchnorm.epsilon) + + # `gamma` and `log Var(y)` reductions over event_dims. + # Log(total change in area from gamma term). + log_total_gamma = math_ops.reduce_sum(math_ops.log(self.batchnorm.gamma)) + + # Log(total change in area from log-variance term). + log_total_variance = math_ops.reduce_sum(log_variance) + # The ildj is scalar, as it does not depend on the values of x and are + # constant across minibatch elements. + return log_total_gamma - 0.5 * log_total_variance diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index 3ce7c26213034c7345a20faa803c94a1bfa8d579..910774ea5bb4106a948567144c46c6db23a2c6e0 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -20,8 +20,11 @@ from __future__ import print_function import itertools -from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -29,6 +32,98 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) +def _use_static_shape(input_tensor, ndims): + return input_tensor.shape.is_fully_defined() and isinstance(ndims, int) + + +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) +def _compute_min_event_ndims(bijector_list, compute_forward=True): + """Computes the min_event_ndims associated with the give list of bijectors. + + Given a list `bijector_list` of bijectors, compute the min_event_ndims that is + associated with the composition of bijectors in that list. + + min_event_ndims is the # of right most dimensions for which the bijector has + done necessary computation on (i.e. the non-broadcastable part of the + computation). + + We can derive the min_event_ndims for a chain of bijectors as follows: + + In the case where there are no rank changing bijectors, this will simply be + `max(b.forward_min_event_ndims for b in bijector_list)`. This is because the + bijector with the most forward_min_event_ndims requires the most dimensions, + and hence the chain also requires operating on those dimensions. + + However in the case of rank changing, more care is needed in determining the + exact amount of dimensions. Padding dimensions causes subsequent bijectors to + operate on the padded dimensions, and Removing dimensions causes bijectors to + operate more left. + + Args: + bijector_list: List of bijectors to be composed by chain. + compute_forward: Boolean. If True, computes the min_event_ndims associated + with a forward call to Chain, and otherwise computes the min_event_ndims + associated with an inverse call to Chain. The latter is the same as the + min_event_ndims associated with a forward call to Invert(Chain(....)). + + Returns: + min_event_ndims + """ + min_event_ndims = 0 + # This is a mouthful, but what this encapsulates is that if not for rank + # changing bijectors, we'd only need to compute the largest of the min + # required ndims. Hence "max_min". Due to rank changing bijectors, we need to + # account for synthetic rank growth / synthetic rank decrease from a rank + # changing bijector. + rank_changed_adjusted_max_min_event_ndims = 0 + + if compute_forward: + bijector_list = reversed(bijector_list) + + for b in bijector_list: + if compute_forward: + current_min_event_ndims = b.forward_min_event_ndims + current_inverse_min_event_ndims = b.inverse_min_event_ndims + else: + current_min_event_ndims = b.inverse_min_event_ndims + current_inverse_min_event_ndims = b.forward_min_event_ndims + + # New dimensions were touched. + if rank_changed_adjusted_max_min_event_ndims < current_min_event_ndims: + min_event_ndims += ( + current_min_event_ndims - rank_changed_adjusted_max_min_event_ndims) + rank_changed_adjusted_max_min_event_ndims = max( + current_min_event_ndims, rank_changed_adjusted_max_min_event_ndims) + + # If the number of dimensions has increased via forward, then + # inverse_min_event_ndims > forward_min_event_ndims, and hence the + # dimensions we computed on, have moved left (so we have operated + # on additional dimensions). + # Conversely, if the number of dimensions has decreased via forward, + # then we have inverse_min_event_ndims < forward_min_event_ndims, + # and so we will have operated on fewer right most dimensions. + + number_of_changed_dimensions = ( + current_min_event_ndims - current_inverse_min_event_ndims) + rank_changed_adjusted_max_min_event_ndims -= number_of_changed_dimensions + return min_event_ndims + + class Chain(bijector.Bijector): """Bijector which applies a sequence of bijectors. @@ -64,6 +159,14 @@ class Chain(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, bijectors=None, validate_args=False, name=None): """Instantiates `Chain` bijector. @@ -93,21 +196,24 @@ class Chain(bijector.Bijector): raise ValueError("incompatible dtypes: %s" % dtype) elif len(dtype) == 2: dtype = dtype[1] if dtype[0] is None else dtype[0] - event_ndims = bijectors[0].event_ndims elif len(dtype) == 1: dtype = dtype[0] - event_ndims = bijectors[0].event_ndims else: dtype = None - event_ndims = None + + inverse_min_event_ndims = _compute_min_event_ndims( + bijectors, compute_forward=False) + forward_min_event_ndims = _compute_min_event_ndims( + bijectors, compute_forward=True) super(Chain, self).__init__( graph_parents=list(itertools.chain.from_iterable( b.graph_parents for b in bijectors)), + forward_min_event_ndims=forward_min_event_ndims, + inverse_min_event_ndims=inverse_min_event_ndims, is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors), validate_args=validate_args, dtype=dtype, - event_ndims=event_ndims, name=name or ("identity" if not bijectors else "_of_".join(["chain"] + [b.name for b in bijectors]))) @@ -147,10 +253,35 @@ class Chain(bijector.Bijector): return y def _inverse_log_det_jacobian(self, y, **kwargs): - ildj = constant_op.constant(0., dtype=y.dtype, - name="inverse_log_det_jacobian") + y = ops.convert_to_tensor(y, name="y") + ildj = math_ops.cast(0., dtype=y.dtype.base_dtype) + + if not self.bijectors: + return ildj + + event_ndims = self._maybe_get_static_event_ndims( + self.inverse_min_event_ndims) + + if _use_static_shape(y, event_ndims): + event_shape = y.shape[y.shape.ndims - event_ndims:] + else: + event_shape = array_ops.shape(y)[array_ops.rank(y) - event_ndims:] + for b in self.bijectors: - ildj += b.inverse_log_det_jacobian(y, **kwargs.get(b.name, {})) + ildj += b.inverse_log_det_jacobian( + y, event_ndims=event_ndims, **kwargs.get(b.name, {})) + + if _use_static_shape(y, event_ndims): + event_shape = b.inverse_event_shape(event_shape) + event_ndims = self._maybe_get_static_event_ndims( + event_shape.ndims) + else: + event_shape = b.inverse_event_shape_tensor(event_shape) + event_ndims = array_ops.size(event_shape) + event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) + if event_ndims_ is not None: + event_ndims = event_ndims_ + y = b.inverse(y, **kwargs.get(b.name, {})) return ildj @@ -160,9 +291,34 @@ class Chain(bijector.Bijector): return x def _forward_log_det_jacobian(self, x, **kwargs): - fldj = constant_op.constant(0., dtype=x.dtype, - name="forward_log_det_jacobian") + x = ops.convert_to_tensor(x, name="x") + + fldj = math_ops.cast(0., dtype=x.dtype.base_dtype) + + if not self.bijectors: + return fldj + + event_ndims = self._maybe_get_static_event_ndims( + self.forward_min_event_ndims) + + if _use_static_shape(x, event_ndims): + event_shape = x.shape[x.shape.ndims - event_ndims:] + else: + event_shape = array_ops.shape(x)[array_ops.rank(x) - event_ndims:] + for b in reversed(self.bijectors): - fldj += b.forward_log_det_jacobian(x, **kwargs.get(b.name, {})) + fldj += b.forward_log_det_jacobian( + x, event_ndims=event_ndims, **kwargs.get(b.name, {})) + if _use_static_shape(x, event_ndims): + event_shape = b.forward_event_shape(event_shape) + event_ndims = self._maybe_get_static_event_ndims(event_shape.ndims) + else: + event_shape = b.forward_event_shape_tensor(event_shape) + event_ndims = array_ops.size(event_shape) + event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) + if event_ndims_ is not None: + event_ndims = event_ndims_ + x = b.forward(x, **kwargs.get(b.name, {})) + return fldj diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index cbd60f92a60612c6cf791b2c7708a3310c6e2b6b..8267ee7df89f69f8d610e9507e0cca9f4a5d4323 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -20,8 +20,6 @@ from __future__ import print_function import numpy as np -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -29,6 +27,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -39,8 +38,6 @@ __all__ = [ class CholeskyOuterProduct(bijector.Bijector): """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix. - `event_ndims` must be 0 or 2, i.e., scalar or matrix. - Note: the upper-triangular part of X is ignored (whether or not its zero). The surjectivity of g as a map from the set of n x n positive-diagonal @@ -57,53 +54,46 @@ class CholeskyOuterProduct(bijector.Bijector): its spectrum), and that the product of two positive-diagonal lower-triangular matrices is another positive-diagonal lower-triangular matrix. - A simple inductive argument (proceding one column of L_3 at a time) shows + A simple inductive argument (proceeding one column of L_3 at a time) shows that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. - Examples: + #### Examples ```python - bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]]) + bijector.CholeskyOuterProduct().forward(x=[[1., 0], [2, 1]]) # Result: [[1., 2], [2, 5]], i.e., x @ x.T - bijector.CholeskyOuterProduct(event_ndims=2).inverse(y=[[1., 2], [2, 5]]) + bijector.CholeskyOuterProduct().inverse(y=[[1., 2], [2, 5]]) # Result: [[1., 0], [2, 1]], i.e., cholesky(y). ``` """ - def __init__(self, event_ndims=2, validate_args=False, - name="cholesky_outer_product"): + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, validate_args=False, name="cholesky_outer_product"): """Instantiates the `CholeskyOuterProduct` bijector. Args: - event_ndims: `constant` `int32` scalar `Tensor` indicating the number of - dimensions associated with a particular draw from the distribution. Must - be 0 or 2. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if event_ndims is neither 0 or 2. """ self._graph_parents = [] self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 2]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 2") - self._static_event_ndims = event_ndims super(CholeskyOuterProduct, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=2, validate_args=validate_args, name=name) def _forward(self, x): - if self._static_event_ndims == 0: - return math_ops.square(x) if self.validate_args: is_matrix = check_ops.assert_rank_at_least(x, 2) shape = array_ops.shape(x) @@ -114,11 +104,7 @@ class CholeskyOuterProduct(bijector.Bijector): return math_ops.matmul(x, x, adjoint_b=True) def _inverse(self, y): - return (math_ops.sqrt(y) if self._static_event_ndims == 0 - else linalg_ops.cholesky(y)) - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(x=self._inverse(y)) + return linalg_ops.cholesky(y) def _forward_log_det_jacobian(self, x): # Let Y be a symmetric, positive definite matrix and write: @@ -161,13 +147,6 @@ class CholeskyOuterProduct(bijector.Bijector): # Since there is a 2 X[j,j] term for every lower-triangular element of X we # conclude: # |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}. - if self._static_event_ndims == 0: - if self.validate_args: - is_positive = check_ops.assert_positive( - x, message="All elements must be positive.") - x = control_flow_ops.with_dependencies([is_positive], x) - return np.log(2.) + math_ops.log(x) - diag = array_ops.matrix_diag_part(x) # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output @@ -200,7 +179,7 @@ class CholeskyOuterProduct(bijector.Bijector): sum_weighted_log_diag = array_ops.squeeze( math_ops.matmul(math_ops.log(diag), exponents[..., array_ops.newaxis]), - squeeze_dims=-1) + axis=-1) fldj = p_float * np.log(2.) + sum_weighted_log_diag return fldj diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py index ccb1f029277bc07011df7be047a075274f2b3a27..e9e994f839ab2fe0a0f52f5f404fb2a0c8f9cd94 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py @@ -44,12 +44,16 @@ class ConditionalBijector(bijector.Bijector): "**condition_kwargs": "Named arguments forwarded to subclass implementation."}) def inverse_log_det_jacobian( - self, y, name="inverse_log_det_jacobian", **condition_kwargs): - return self._call_inverse_log_det_jacobian(y, name, **condition_kwargs) + self, y, event_ndims, name="inverse_log_det_jacobian", + **condition_kwargs): + return self._call_inverse_log_det_jacobian( + y, event_ndims, name, **condition_kwargs) @distribution_util.AppendDocstring(kwargs_dict={ "**condition_kwargs": "Named arguments forwarded to subclass implementation."}) def forward_log_det_jacobian( - self, x, name="forward_log_det_jacobian", **condition_kwargs): - return self._call_forward_log_det_jacobian(x, name, **condition_kwargs) + self, x, event_ndims, name="forward_log_det_jacobian", + **condition_kwargs): + return self._call_forward_log_det_jacobian( + x, event_ndims, name, **condition_kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py index b1ff840d62a73c941a4d67dec73b5c9f4d5353f9..07627e1e45eae6b63d830b2adf036bdc3b1d2895 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops.bijectors import power_transform +from tensorflow.python.util import deprecation __all__ = [ @@ -33,8 +34,8 @@ class Exp(power_transform.PowerTransform): ```python # Create the Y=g(X)=exp(X) transform which works only on Tensors with 1 - # batch ndim and 2 event ndims (i.e., vector of matrices). - exp = Exp(event_ndims=2) + # batch ndim 2. + exp = Exp() x = [[[1., 2], [3, 4]], [[5, 6], @@ -47,20 +48,26 @@ class Exp(power_transform.PowerTransform): over the event space. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, - event_ndims=0, validate_args=False, name="exp"): """Instantiates the `Exp` bijector. Args: - event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. """ + # forward_min_event_ndims = 0. + # No forward_min_event_ndims specified as this is done in PowerTransform. super(Exp, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py new file mode 100644 index 0000000000000000000000000000000000000000..31a9ca27e519bc312813668bf621a875838f12a0 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py @@ -0,0 +1,165 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FillTriangular bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as dist_util +from tensorflow.python.util import deprecation + + +__all__ = [ + "FillTriangular", +] + + +class FillTriangular(bijector.Bijector): + """Transforms vectors to triangular. + + Triangular matrix elements are filled in a clockwise spiral. + + Given input with shape `batch_shape + [d]`, produces output with + shape `batch_shape + [n, n]`, where + `n = (-1 + sqrt(1 + 8 * d))/2`. + This follows by solving the quadratic equation + `d = 1 + 2 + ... + n = n * (n + 1)/2`. + + #### Example + + ```python + b = tfb.FillTriangular(upper=False) + b.forward([1, 2, 3, 4, 5, 6]) + # ==> [[4, 0, 0], + # [6, 5, 0], + # [3, 2, 1]] + + b = tfb.FillTriangular(upper=True) + b.forward([1, 2, 3, 4, 5, 6]) + # ==> [[1, 2, 3], + # [0, 5, 6], + # [0, 0, 4]] + + ``` + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + upper=False, + validate_args=False, + name="fill_triangular"): + """Instantiates the `FillTriangular` bijector. + + Args: + upper: Python `bool` representing whether output matrix should be upper + triangular (`True`) or lower triangular (`False`, default). + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._upper = upper + super(FillTriangular, self).__init__( + forward_min_event_ndims=1, + inverse_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return dist_util.fill_triangular(x, upper=self._upper) + + def _inverse(self, y): + return dist_util.fill_triangular_inverse(y, upper=self._upper) + + def _forward_log_det_jacobian(self, x): + return array_ops.zeros_like(x[..., 0]) + + def _inverse_log_det_jacobian(self, y): + return array_ops.zeros_like(y[..., 0, 0]) + + def _forward_event_shape(self, input_shape): + batch_shape, d = input_shape[:-1], input_shape[-1].value + if d is None: + n = None + else: + n = vector_size_to_square_matrix_size(d, self.validate_args) + return batch_shape.concatenate([n, n]) + + def _inverse_event_shape(self, output_shape): + batch_shape, n1, n2 = (output_shape[:-2], + output_shape[-2].value, + output_shape[-1].value) + if n1 is None or n2 is None: + m = None + elif n1 != n2: + raise ValueError("Matrix must be square. (saw [{}, {}])".format(n1, n2)) + else: + m = n1 * (n1 + 1) / 2 + return batch_shape.concatenate([m]) + + def _forward_event_shape_tensor(self, input_shape_tensor): + batch_shape, d = input_shape_tensor[:-1], input_shape_tensor[-1] + n = vector_size_to_square_matrix_size(d, self.validate_args) + return array_ops.concat([batch_shape, [n, n]], axis=0) + + def _inverse_event_shape_tensor(self, output_shape_tensor): + batch_shape, n = output_shape_tensor[:-2], output_shape_tensor[-1] + if self.validate_args: + is_square_matrix = check_ops.assert_equal( + n, output_shape_tensor[-2], message="Matrix must be square.") + with ops.control_dependencies([is_square_matrix]): + n = array_ops.identity(n) + d = math_ops.cast(n * (n + 1) / 2, output_shape_tensor.dtype) + return array_ops.concat([batch_shape, [d]], axis=0) + + +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) +def vector_size_to_square_matrix_size(d, validate_args, name=None): + """Convert a vector size to a matrix size.""" + if isinstance(d, (float, int, np.generic, np.ndarray)): + n = (-1 + np.sqrt(1 + 8 * d)) / 2. + if float(int(n)) != n: + raise ValueError("Vector length is not a triangular number.") + return int(n) + else: + with ops.name_scope(name, "vector_size_to_square_matrix_size", [d]) as name: + n = (-1. + math_ops.sqrt(1 + 8. * math_ops.to_float(d))) / 2. + if validate_args: + with ops.control_dependencies([check_ops.assert_equal( + math_ops.to_float(math_ops.to_int32(n)), n, + message="Vector length is not a triangular number")]): + n = array_ops.identity(n) + return math_ops.cast(n, d.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py index 67f39785563255be0fe154aca3cbcf01c6a01e73..71e562a927a30a17d695b81c566f981db7553ad9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "Gumbel", @@ -45,10 +46,17 @@ class Gumbel(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=0., scale=1., - event_ndims=0, validate_args=False, name="gumbel"): """Instantiates the `Gumbel` bijector. @@ -60,8 +68,6 @@ class Gumbel(bijector.Bijector): scale: Positive Float-like `Tensor` that is the same dtype and is broadcastable with `loc`. This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. @@ -80,7 +86,9 @@ class Gumbel(bijector.Bijector): ], self._scale) super(Gumbel, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) + validate_args=validate_args, + forward_min_event_ndims=0, + name=name) @property def loc(self): @@ -102,15 +110,11 @@ class Gumbel(bijector.Bijector): def _inverse_log_det_jacobian(self, y): y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims) + return math_ops.log(self.scale / (-math_ops.log(y) * y)) def _forward_log_det_jacobian(self, x): - event_dims = self._event_dims_tensor(x) z = (x - self.loc) / self.scale - return math_ops.reduce_sum( - -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims) + return -z - math_ops.exp(-z) - math_ops.log(self.scale) def _maybe_assert_valid_y(self, y): if not self.validate_args: diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py index fab1b22fbf92e7b92a5ec86ec62d66bec71a8c94..1504bd27204f728c0cb519159230e945128c4740 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -40,9 +41,17 @@ class Inline(bijector.Bijector): name="exp") ``` - The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`. + The above example is equivalent to the `Bijector` `Exp()`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, forward_fn=None, inverse_fn=None, @@ -54,6 +63,8 @@ class Inline(bijector.Bijector): inverse_event_shape_tensor_fn=None, is_constant_jacobian=False, validate_args=False, + forward_min_event_ndims=None, + inverse_min_event_ndims=None, name="inline"): """Creates a `Bijector` from callables. @@ -76,10 +87,15 @@ class Inline(bijector.Bijector): constant for all input arguments. validate_args: Python `bool` indicating whether arguments should be checked for correctness. + forward_min_event_ndims: Python `int` indicating the minimal + dimensionality this bijector acts on. + inverse_min_event_ndims: Python `int` indicating the minimal + dimensionality this bijector acts on. name: Python `str`, name given to ops managed by this object. """ super(Inline, self).__init__( - event_ndims=0, + forward_min_event_ndims=forward_min_event_ndims, + inverse_min_event_ndims=inverse_min_event_ndims, is_constant_jacobian=is_constant_jacobian, validate_args=validate_args, name=name) @@ -134,8 +150,8 @@ class Inline(bijector.Bijector): "inverse_log_det_jacobian_fn is not a callable function.") return self._inverse_log_det_jacobian_fn(y, **kwargs) - def _forward_log_det_jacobian(self, y, **kwargs): + def _forward_log_det_jacobian(self, x, **kwargs): if not callable(self._forward_log_det_jacobian_fn): raise NotImplementedError( "forward_log_det_jacobian_fn is not a callable function.") - return self._forward_log_det_jacobian_fn(y, **kwargs) + return self._forward_log_det_jacobian_fn(x, **kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py index 2c603fe61f36dd27f4984fe6c13c11f2fb534321..a648676d4b1956e5c27f67a71e6bd93d0d7fc97d 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py @@ -18,14 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.ops.distributions import bijector as bijector_lib +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "Invert", ] -class Invert(bijector_lib.Bijector): +class Invert(bijector.Bijector): """Bijector which inverts another Bijector. Example Use: [ExpGammaDistribution (see Background & Context)]( @@ -40,6 +41,14 @@ class Invert(bijector_lib.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, bijector, validate_args=False, name=None): """Creates a `Bijector` which swaps the meaning of `inverse` and `forward`. @@ -66,8 +75,9 @@ class Invert(bijector_lib.Bijector): self._bijector = bijector super(Invert, self).__init__( - event_ndims=bijector.event_ndims, graph_parents=bijector.graph_parents, + forward_min_event_ndims=bijector.inverse_min_event_ndims, + inverse_min_event_ndims=bijector.forward_min_event_ndims, is_constant_jacobian=bijector.is_constant_jacobian, validate_args=validate_args, dtype=bijector.dtype, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py new file mode 100644 index 0000000000000000000000000000000000000000..33b75a04d34fdd01bc0d854d4e5b9c45a737b122 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py @@ -0,0 +1,141 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Kumaraswamy bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation + +__all__ = [ + "Kumaraswamy", +] + + +class Kumaraswamy(bijector.Bijector): + """Compute `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a), X in [0, 1]`. + + This bijector maps inputs from `[0, 1]` to [0, 1]`. The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the [Kumaraswamy distribution]( + https://en.wikipedia.org/wiki/Kumaraswamy_distribution): + + ```none + Y ~ Kumaraswamy(a, b) + pdf(y; a, b, 0 <= y <= 1) = a * b * y ** (a - 1) * (1 - y**a) ** (b - 1) + ``` + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + concentration1=None, + concentration0=None, + validate_args=False, + name="kumaraswamy"): + """Instantiates the `Kumaraswamy` bijector. + + Args: + concentration1: Python `float` scalar indicating the transform power, + i.e., `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)` where `a` is + `concentration1`. + concentration0: Python `float` scalar indicating the transform power, + i.e., `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)` where `b` is + `concentration0`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + + with self._name_scope("init", values=[concentration1, concentration0]): + concentration1 = self._maybe_assert_valid_concentration( + ops.convert_to_tensor(concentration1, name="concentration1"), + validate_args=validate_args) + concentration0 = self._maybe_assert_valid_concentration( + ops.convert_to_tensor(concentration0, name="concentration0"), + validate_args=validate_args) + + self._concentration1 = concentration1 + self._concentration0 = concentration0 + super(Kumaraswamy, self).__init__( + forward_min_event_ndims=0, + validate_args=validate_args, + name=name) + + @property + def concentration1(self): + """The `a` in: `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)`.""" + return self._concentration1 + + @property + def concentration0(self): + """The `b` in: `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)`.""" + return self._concentration0 + + def _forward(self, x): + x = self._maybe_assert_valid(x) + return math_ops.exp( + math_ops.log1p(-math_ops.exp(math_ops.log1p(-x) / self.concentration0)) + / self.concentration1) + + def _inverse(self, y): + y = self._maybe_assert_valid(y) + return math_ops.exp(math_ops.log1p( + -(1 - y**self.concentration1)**self.concentration0)) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid(y) + return ( + math_ops.log(self.concentration1) + math_ops.log(self.concentration0) + + (self.concentration1 - 1) * math_ops.log(y) + + (self.concentration0 - 1) * math_ops.log1p(-y**self.concentration1)) + + def _maybe_assert_valid_concentration(self, concentration, validate_args): + """Checks the validity of a concentration parameter.""" + if not validate_args: + return concentration + return control_flow_ops.with_dependencies([ + check_ops.assert_positive( + concentration, + message="Concentration parameter must be positive."), + ], concentration) + + def _maybe_assert_valid(self, x): + if not self.validate_args: + return x + return control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + x, + message="sample must be non-negative"), + check_ops.assert_less_equal( + x, array_ops.ones([], self.concentration0.dtype), + message="sample must be no larger than `1`."), + ], x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index 5251dbcb5748f75688aa43ce6e4e9dbd76be78bb..b8f2a4b2c731bdaee78692c036fb9f2fba4e3760 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -32,7 +32,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import template as template_ops from tensorflow.python.ops import variable_scope as variable_scope_lib -from tensorflow.python.ops.distributions import bijector as bijector_lib +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -42,17 +43,18 @@ __all__ = [ ] -class MaskedAutoregressiveFlow(bijector_lib.Bijector): +class MaskedAutoregressiveFlow(bijector.Bijector): """Affine MaskedAutoregressiveFlow bijector for vector-valued events. - The affine autoregressive flow [1] provides a relatively simple framework for - user-specified (deep) architectures to learn a distribution over vector-valued - events. Regarding terminology, + The affine autoregressive flow [(Papamakarios et al., 2016)][3] provides a + relatively simple framework for user-specified (deep) architectures to learn + a distribution over vector-valued events. Regarding terminology, "Autoregressive models decompose the joint density as a product of conditionals, and model each conditional in turn. Normalizing flows transform a base density (e.g. a standard Gaussian) into the target density - by an invertible transformation with tractable Jacobian." [1] + by an invertible transformation with tractable Jacobian." + [(Papamakarios et al., 2016)][3] In other words, the "autoregressive property" is equivalent to the decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided @@ -60,7 +62,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): this property by zeroing out weights in its `masked_dense` layers. In the `tf.distributions` framework, a "normalizing flow" is implemented as a - `tf.distributions.bijectors.Bijector`. The `forward` "autoregression" + `tf.contrib.distributions.bijectors.Bijector`. The `forward` "autoregression" is implemented using a `tf.while_loop` and a deep neural network (DNN) with masked weights such that the autoregressive property is automatically met in the `inverse`. @@ -75,26 +77,26 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): Given a `shift_and_log_scale_fn`, the forward and inverse transformations are (a sequence of) affine transformations. A "valid" `shift_and_log_scale_fn` - must compute each `shift` (aka `loc` or "mu" [2]) and `log(scale)` (aka - "alpha" [2]) such that each are broadcastable with the arguments to `forward` - and `inverse`, i.e., such that the calculations in `forward`, `inverse` - [below] are possible. + must compute each `shift` (aka `loc` or "mu" in [Germain et al. (2015)][1]) + and `log(scale)` (aka "alpha" in [Germain et al. (2015)][1]) such that each + are broadcastable with the arguments to `forward` and `inverse`, i.e., such + that the calculations in `forward`, `inverse` [below] are possible. For convenience, `masked_autoregressive_default_template` is offered as a possible `shift_and_log_scale_fn` function. It implements the MADE - architecture [2]. MADE is a feed-forward network that computes a `shift` and - `log(scale)` using `masked_dense` layers in a deep neural network. Weights are - masked to ensure the autoregressive property. It is possible that this - architecture is suboptimal for your task. To build alternative networks, - either change the arguments to `masked_autoregressive_default_template`, use - the `masked_dense` function to roll-out your own, or use some other - architecture, e.g., using `tf.layers`. + architecture [(Germain et al., 2015)][1]. MADE is a feed-forward network that + computes a `shift` and `log(scale)` using `masked_dense` layers in a deep + neural network. Weights are masked to ensure the autoregressive property. It + is possible that this architecture is suboptimal for your task. To build + alternative networks, either change the arguments to + `masked_autoregressive_default_template`, use the `masked_dense` function to + roll-out your own, or use some other architecture, e.g., using `tf.layers`. Warning: no attempt is made to validate that the `shift_and_log_scale_fn` enforces the "autoregressive property". Assuming `shift_and_log_scale_fn` has valid shape and autoregressive - semantics, the forward transformation is, + semantics, the forward transformation is ```python def forward(x): @@ -106,7 +108,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): return y ``` - and the inverse transformation is, + and the inverse transformation is ```python def inverse(y): @@ -121,7 +123,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): the "last" `y` used to compute `shift`, `log_scale`. (Roughly speaking, this also proves the transform is bijective.) - #### Example Use + #### Examples ```python tfd = tf.contrib.distributions @@ -142,7 +144,8 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): maf.log_prob(x) # Almost free; uses Bijector caching. maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. - # [1] also describes an "Inverse Autoregressive Flow", e.g., + # [Papamakarios et al. (2016)][3] also describe an Inverse Autoregressive + # Flow [(Kingma et al., 2016)][2]: iaf = tfd.TransformedDistribution( distribution=tfd.Normal(loc=0., scale=1.), bijector=tfb.Invert(tfb.MaskedAutoregressiveFlow( @@ -168,16 +171,30 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): event_shape=[dims]) ``` - [1]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 + #### References - [2]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 + [1]: Mathieu Germain, Karol Gregor, Iain Murray, and Hugo Larochelle. MADE: + Masked Autoencoder for Distribution Estimation. In _International + Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509 + [2]: Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya + Sutskever, and Max Welling. Improving Variational Inference with Inverse + Autoregressive Flow. In _Neural Information Processing Systems_, 2016. + https://arxiv.org/abs/1606.04934 + + [3]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift_and_log_scale_fn, is_constant_jacobian=False, @@ -212,6 +229,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): self._shift_and_log_scale_fn = shift_and_log_scale_fn self._unroll_loop = unroll_loop super(MaskedAutoregressiveFlow, self).__init__( + forward_min_event_ndims=1, is_constant_jacobian=is_constant_jacobian, validate_args=validate_args, name=name) @@ -287,6 +305,14 @@ MASK_INCLUSIVE = "inclusive" MASK_EXCLUSIVE = "exclusive" +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): """Generate the slices for building an autoregressive mask.""" # TODO(b/67594795): Better support of dynamic shape. @@ -304,6 +330,14 @@ def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): return slices +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _gen_mask(num_blocks, n_in, n_out, @@ -318,6 +352,14 @@ def _gen_mask(num_blocks, return mask +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def masked_dense(inputs, units, num_blocks=None, @@ -329,11 +371,7 @@ def masked_dense(inputs, **kwargs): """A autoregressively masked dense layer. Analogous to `tf.layers.dense`. - See [1] for detailed explanation. - - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 + See [Germain et al. (2015)][1] for detailed explanation. Arguments: inputs: Tensor input. @@ -358,6 +396,12 @@ def masked_dense(inputs, Raises: NotImplementedError: if rightmost dimension of `inputs` is unknown prior to graph execution. + + #### References + + [1]: Mathieu Germain, Karol Gregor, Iain Murray, and Hugo Larochelle. MADE: + Masked Autoencoder for Distribution Estimation. In _International + Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509 """ # TODO(b/67594795): Better support of dynamic shape. input_depth = inputs.shape.with_rank_at_least(1)[-1].value @@ -388,6 +432,14 @@ def masked_dense(inputs, return layer.apply(inputs) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def masked_autoregressive_default_template( hidden_layers, shift_only=False, @@ -398,23 +450,24 @@ def masked_autoregressive_default_template( name=None, *args, **kwargs): - """Build the MADE Model [1]. + """Build the Masked Autoregressive Density Estimator (Germain et al., 2015). This will be wrapped in a make_template to ensure the variables are only - created once. It takes the input and returns the `loc` ("mu" [1]) and - `log_scale` ("alpha" [1]) from the MADE network. + created once. It takes the input and returns the `loc` ("mu" in [Germain et + al. (2015)][1]) and `log_scale` ("alpha" in [Germain et al. (2015)][1]) from + the MADE network. Warning: This function uses `masked_dense` to create randomly initialized `tf.Variables`. It is presumed that these will be fit, just as you would any other neural architecture which uses `tf.layers.dense`. - #### About Hidden Layers: + #### About Hidden Layers Each element of `hidden_layers` should be greater than the `input_depth` (i.e., `input_depth = tf.shape(input)[-1]` where `input` is the input to the neural network). This is necessary to ensure the autoregressivity property. - #### About Clipping: + #### About Clipping This function also optionally clips the `log_scale` (but possibly not its gradient). This is useful because if `log_scale` is too small/large it might @@ -427,11 +480,7 @@ def masked_autoregressive_default_template( `grad[exp(clip(x))] = grad[x] exp(clip(x))` rather than the usual `grad[clip(x)] exp(clip(x))`. - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - Arguments: + Args: hidden_layers: Python `list`-like of non-negative integer, scalars indicating the number of units in each hidden layer. Default: `[512, 512]. shift_only: Python `bool` indicating if only the `shift` term shall be @@ -450,12 +499,20 @@ def masked_autoregressive_default_template( **kwargs: `tf.layers.dense` keyword arguments. Returns: - shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). - log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). + shift: `Float`-like `Tensor` of shift terms (the "mu" in + [Germain et al. (2015)][1]). + log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in + [Germain et al. (2015)][1]). Raises: NotImplementedError: if rightmost dimension of `inputs` is unknown prior to graph execution. + + #### References + + [1]: Mathieu Germain, Karol Gregor, Iain Murray, and Hugo Larochelle. MADE: + Masked Autoencoder for Distribution Estimation. In _International + Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509 """ with ops.name_scope(name, "masked_autoregressive_default_template", @@ -499,6 +556,14 @@ def masked_autoregressive_default_template( "masked_autoregressive_default_template", _fn) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): """Clips input while leaving gradient unaltered.""" with ops.name_scope(name, "clip_by_value_preserve_grad", diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py new file mode 100644 index 0000000000000000000000000000000000000000..49e6192f067edec4890dcfa107876a5104c14dd4 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py @@ -0,0 +1,154 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MatrixInverseTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation + + +__all__ = [ + "MatrixInverseTriL", +] + + +class MatrixInverseTriL(bijector.Bijector): + """Computes `g(L) = inv(L)`, where `L` is a lower-triangular matrix. + + `L` must be nonsingular; equivalently, all diagonal entries of `L` must be + nonzero. + + The input must have `rank >= 2`. The input is treated as a batch of matrices + with batch shape `input.shape[:-2]`, where each matrix has dimensions + `input.shape[-2]` by `input.shape[-1]` (hence `input.shape[-2]` must equal + `input.shape[-1]`). + + #### Examples + + ```python + tfd.bijectors.MatrixInverseTriL().forward(x=[[1., 0], [2, 1]]) + # Result: [[1., 0], [-2, 1]], i.e., inv(x) + + tfd.bijectors.MatrixInverseTriL().inverse(y=[[1., 0], [-2, 1]]) + # Result: [[1., 0], [2, 1]], i.e., inv(y). + ``` + + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, validate_args=False, name="matrix_inverse_tril"): + """Instantiates the `MatrixInverseTriL` bijector. + + Args: + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + super(MatrixInverseTriL, self).__init__( + forward_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + with ops.control_dependencies(self._assertions(x)): + shape = array_ops.shape(x) + return linalg_ops.matrix_triangular_solve( + x, linalg_ops.eye(shape[-1], batch_shape=shape[:-2]), lower=True) + + def _inverse(self, y): + return self._forward(y) + + def _forward_log_det_jacobian(self, x): + # Calculation of the Jacobian: + # + # Let X = (x_{ij}), 0 <= i,j < n, be a matrix of indeterminates. Let Z = + # X^{-1} where Z = (z_{ij}). Then + # + # dZ/dx_{ij} = (d/dt | t=0) Y(t)^{-1}, + # + # where Y(t) = X + t*E_{ij} and E_{ij} is the matrix with a 1 in the (i,j) + # entry and zeros elsewhere. By the product rule, + # + # 0 = d/dt [Identity matrix] + # = d/dt [Y Y^{-1}] + # = Y d/dt[Y^{-1}] + dY/dt Y^{-1} + # + # so + # + # d/dt[Y^{-1}] = -Y^{-1} dY/dt Y^{-1} + # = -Y^{-1} E_{ij} Y^{-1}. + # + # Evaluating at t=0, + # + # dZ/dx_{ij} = -Z E_{ij} Z. + # + # Taking the (r,s) entry of each side, + # + # dz_{rs}/dx_{ij} = -z_{ri}z_{sj}. + # + # Now, let J be the Jacobian dZ/dX, arranged as the n^2-by-n^2 matrix whose + # (r*n + s, i*n + j) entry is dz_{rs}/dx_{ij}. Considering J as an n-by-n + # block matrix with n-by-n blocks, the above expression for dz_{rs}/dx_{ij} + # shows that the block at position (r,i) is -z_{ri}Z. Hence + # + # J = -KroneckerProduct(Z, Z), + # det(J) = (-1)^(n^2) (det Z)^(2n) + # = (-1)^n (det X)^(-2n). + with ops.control_dependencies(self._assertions(x)): + return (-2. * math_ops.cast(array_ops.shape(x)[-1], x.dtype.base_dtype) * + math_ops.reduce_sum( + math_ops.log(math_ops.abs(array_ops.matrix_diag_part(x))), + axis=-1)) + + def _assertions(self, x): + if not self.validate_args: + return [] + shape = array_ops.shape(x) + is_matrix = check_ops.assert_rank_at_least( + x, 2, message="Input must have rank at least 2.") + is_square = check_ops.assert_equal( + shape[-2], shape[-1], message="Input must be a square matrix.") + above_diagonal = array_ops.matrix_band_part( + array_ops.matrix_set_diag( + x, array_ops.zeros(shape[:-1], dtype=dtypes.float32)), + 0, -1) + is_lower_triangular = check_ops.assert_equal( + above_diagonal, array_ops.zeros_like(above_diagonal), + message="Input must be lower triangular.") + # A lower triangular matrix is nonsingular iff all its diagonal entries are + # nonzero. + diag_part = array_ops.matrix_diag_part(x) + is_nonsingular = check_ops.assert_none_equal( + diag_part, array_ops.zeros_like(diag_part), + message="Input must have all diagonal entries nonzero.") + return [is_matrix, is_square, is_lower_triangular, is_nonsingular] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py new file mode 100644 index 0000000000000000000000000000000000000000..fb393218b6b47764f45b5055bbf15cc17aba219e --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py @@ -0,0 +1,134 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ordered bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation + + +__all__ = [ + "Ordered", +] + + +class Ordered(bijector.Bijector): + """Bijector which maps a tensor x_k that has increasing elements in the last + dimension to an unconstrained tensor y_k. + + Both the domain and the codomain of the mapping is [-inf, inf], however, + the input of the forward mapping must be strictly increasing. + The inverse of the bijector applied to a normal random vector `y ~ N(0, 1)` + gives back a sorted random vector with the same distribution `x ~ N(0, 1)` + where `x = sort(y)` + + On the last dimension of the tensor, Ordered bijector performs: + `y[0] = x[0]` + `y[1:] = math_ops.log(x[1:] - x[:-1])` + + #### Example Use: + + ```python + bijector.Ordered().forward([2, 3, 4]) + # Result: [2., 0., 0.] + + bijector.Ordered().inverse([0.06428002, -1.07774478, -0.71530371]) + # Result: [0.06428002, 0.40464228, 0.8936858] + ``` + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, validate_args=False, name="ordered"): + super(Ordered, self).__init__( + forward_min_event_ndims=1, + validate_args=validate_args, + name=name) + + def _forward_event_shape(self, input_shape): + if input_shape.ndims is None or input_shape[-1] is None: + return input_shape + return tensor_shape.TensorShape([input_shape[-1]]) + + def _forward_event_shape_tensor(self, input_shape): + return (input_shape[-1])[..., array_ops.newaxis] + + def _inverse_event_shape(self, output_shape): + if output_shape.ndims is None or output_shape[-1] is None: + return output_shape + if output_shape[-1] <= 1: + raise ValueError("output_shape[-1] = %d <= 1" % output_shape[-1]) + return tensor_shape.TensorShape([output_shape[-1]]) + + def _inverse_event_shape_tensor(self, output_shape): + if self.validate_args: + is_greater_one = check_ops.assert_greater( + output_shape[-1], 1, message="Need last dimension greater than 1.") + output_shape = control_flow_ops.with_dependencies( + [is_greater_one], output_shape) + return (output_shape[-1])[..., array_ops.newaxis] + + def _forward(self, x): + x = self._maybe_assert_valid_x(x) + y0 = x[..., 0, array_ops.newaxis] + yk = math_ops.log(x[..., 1:] - x[..., :-1]) + y = array_ops.concat([y0, yk], axis=-1) + return y + + def _inverse(self, y): + x0 = y[..., 0, array_ops.newaxis] + xk = math_ops.exp(y[..., 1:]) + x = array_ops.concat([x0, xk], axis=-1) + return math_ops.cumsum(x, axis=-1) + + def _inverse_log_det_jacobian(self, y): + # The Jacobian of the inverse mapping is lower + # triangular, with the diagonal elements being: + # J[i,i] = 1 if i=1, and + # exp(y_i) if 1 1`, use the `tfb.Reshape` bijector to flatten the event shape. - - Recall that the MAF bijector [2] implements a normalizing flow via an - autoregressive transformation. MAF and IAF have opposite computational - tradeoffs - MAF can train all units in parallel but must sample units - sequentially, while IAF must train units sequentially but can sample in - parallel. In contrast, Real NVP can compute both forward and inverse - computations in parallel. However, the lack of an autoregressive + channel-wise masking [(Papamakarios et al., 2016)[4], use the `tfb.Permute` + bijector to re-order desired masked units into the first `d` units. For base + distributions with `event_ndims > 1`, use the `tfb.Reshape` bijector to + flatten the event shape. + + Recall that the MAF bijector [(Papamakarios et al., 2016)][4] implements a + normalizing flow via an autoregressive transformation. MAF and IAF have + opposite computational tradeoffs - MAF can train all units in parallel but + must sample units sequentially, while IAF must train units sequentially but + can sample in parallel. In contrast, Real NVP can compute both forward and + inverse computations in parallel. However, the lack of an autoregressive transformations makes it less expressive on a per-bijector basis. A "valid" `shift_and_log_scale_fn` must compute each `shift` (aka `loc` or - "mu" [2]) and `log(scale)` (aka "alpha" [2]) such that each are broadcastable - with the arguments to `forward` and `inverse`, i.e., such that the - calculations in `forward`, `inverse` [below] are possible. For convenience, + "mu" in [Papamakarios et al. (2016)][4]) and `log(scale)` (aka "alpha" in + [Papamakarios et al. (2016)][4]) such that each are broadcastable with the + arguments to `forward` and `inverse`, i.e., such that the calculations in + `forward`, `inverse` [below] are possible. For convenience, `real_nvp_default_nvp` is offered as a possible `shift_and_log_scale_fn` function. - NICE [3] is a special case of the Real NVP bijector which discards the scale - transformation, resulting in a constant-time inverse-log-determinant-Jacobian. - To use a NICE bijector instead of Real NVP, `shift_and_log_scale_fn` should - return `(shift, None)`, and `is_constant_jacobian` should be set to `True` in - the `RealNVP` constructor. Calling `real_nvp_default_template` with - `shift_only=True` returns one such NICE-compatible `shift_and_log_scale_fn`. + NICE [(Dinh et al., 2014)][2] is a special case of the Real NVP bijector + which discards the scale transformation, resulting in a constant-time + inverse-log-determinant-Jacobian. To use a NICE bijector instead of Real + NVP, `shift_and_log_scale_fn` should return `(shift, None)`, and + `is_constant_jacobian` should be set to `True` in the `RealNVP` constructor. + Calling `real_nvp_default_template` with `shift_only=True` returns one such + NICE-compatible `shift_and_log_scale_fn`. Caching: the scalar input depth `D` of the base distribution is not known at construction time. The first call to any of `forward(x)`, `inverse(x)`, @@ -103,25 +107,34 @@ class RealNVP(bijector_lib.Bijector): nvp.log_prob(0.) ``` - For more examples, see [4]. + For more examples, see [Jang (2018)][3]. - [1]: "Density Estimation using Real NVP." - Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio. ICLR. 2017. - https://arxiv.org/abs/1605.08803 + #### References - [2]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 + [1]: Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio. Density Estimation + using Real NVP. In _International Conference on Learning + Representations_, 2017. https://arxiv.org/abs/1605.08803 - [3]: "NICE: Non-linear Independent Components Estimation." - Laurent Dinh, David Krueger, Yoshua Bengio. ICLR. 2015. - https://arxiv.org/abs/1410.8516 + [2]: Laurent Dinh, David Krueger, and Yoshua Bengio. NICE: Non-linear + Independent Components Estimation. _arXiv preprint arXiv:1410.8516_, + 2014. https://arxiv.org/abs/1410.8516 - [4]: "Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows." - Eric Jang. Blog post. January 2018. - http://blog.evjang.com/2018/01/nf2.html + [3]: Eric Jang. Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows. + _Technical Report_, 2018. http://blog.evjang.com/2018/01/nf2.html + + [4]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, num_masked, shift_and_log_scale_fn, @@ -162,7 +175,7 @@ class RealNVP(bijector_lib.Bijector): self._input_depth = None self._shift_and_log_scale_fn = shift_and_log_scale_fn super(RealNVP, self).__init__( - event_ndims=1, + forward_min_event_ndims=1, is_constant_jacobian=is_constant_jacobian, validate_args=validate_args, name=name) @@ -220,10 +233,18 @@ class RealNVP(bijector_lib.Bijector): _, log_scale = self._shift_and_log_scale_fn( x0, self._input_depth - self._num_masked) if log_scale is None: - return constant_op.constant(0., dtype=x.dtype, name="ildj") + return constant_op.constant(0., dtype=x.dtype, name="fldj") return math_ops.reduce_sum(log_scale, axis=-1) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def real_nvp_default_template( hidden_layers, shift_only=False, @@ -250,12 +271,20 @@ def real_nvp_default_template( **kwargs: `tf.layers.dense` keyword arguments. Returns: - shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). - log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). + shift: `Float`-like `Tensor` of shift terms ("mu" in + [Papamakarios et al. (2016)][1]). + log_scale: `Float`-like `Tensor` of log(scale) terms ("alpha" in + [Papamakarios et al. (2016)][1]). Raises: NotImplementedError: if rightmost dimension of `inputs` is unknown prior to graph execution. + + #### References + + [1]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ with ops.name_scope(name, "real_nvp_default_template"): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py index 55eca063126797d577653f0d6bcdfddf8192bdb5..c8282229a30fabff0c4c267d0bdfcdbce4f5f3d9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -28,7 +28,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector as bijector_lib +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -36,15 +37,31 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _static_ndims_from_shape(shape): return shape.shape.with_rank_at_least(1)[0].value +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _ndims_from_shape(shape): return array_ops.shape(shape)[0] -class Reshape(bijector_lib.Bijector): +class Reshape(bijector.Bijector): """Reshapes the `event_shape` of a `Tensor`. The semantics generally follow that of `tf.reshape()`, with @@ -86,6 +103,14 @@ class Reshape(bijector_lib.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, event_shape_out, event_shape_in=(-1,), validate_args=False, name=None): """Creates a `Reshape` bijector. @@ -128,15 +153,17 @@ class Reshape(bijector_lib.Bijector): self._event_shape_in = event_shape_in self._event_shape_out = event_shape_out - super(Reshape, self).__init__(is_constant_jacobian=True, - validate_args=validate_args, - name=name or "reshape") + super(Reshape, self).__init__( + forward_min_event_ndims=0, + is_constant_jacobian=True, + validate_args=validate_args, + name=name or "reshape") def _maybe_check_valid_shape(self, shape, validate_args): """Check that a shape Tensor is int-type and otherwise sane.""" if not shape.dtype.is_integer: raise TypeError("{} dtype ({}) should be `int`-like.".format( - shape.op.name, shape.dtype.name)) + shape, shape.dtype.name)) assertions = [] @@ -144,10 +171,10 @@ class Reshape(bijector_lib.Bijector): ndims_ = tensor_util.constant_value(ndims) if ndims_ is not None and ndims_ > 1: raise ValueError("`{}` rank ({}) should be <= 1.".format( - shape.op.name, ndims_)) + shape, ndims_)) elif validate_args: assertions.append(check_ops.assert_less_equal( - ndims, 1, message="`{}` rank should be <= 1.".format(shape.op.name))) + ndims, 1, message="`{}` rank should be <= 1.".format(shape))) shape_ = tensor_util.constant_value_as_shape(shape) if shape_.is_fully_defined(): @@ -155,12 +182,12 @@ class Reshape(bijector_lib.Bijector): if sum(es == -1) > 1: raise ValueError( "`{}` must have at most one `-1` (given {})" - .format(shape.op.name, es)) + .format(shape, es)) if np.any(es < -1): raise ValueError( "`{}` elements must be either positive integers or `-1`" "(given {})." - .format(shape.op.name, es)) + .format(shape, es)) elif validate_args: assertions.extend([ check_ops.assert_less_equal( @@ -168,11 +195,11 @@ class Reshape(bijector_lib.Bijector): math_ops.cast(math_ops.equal(shape, -1), dtypes.int32)), 1, message="`{}` elements must have at most one `-1`." - .format(shape.op.name)), + .format(shape)), check_ops.assert_greater_equal( shape, -1, message="`{}` elements must be either positive integers or `-1`." - .format(shape.op.name)), + .format(shape)), ]) return assertions diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py new file mode 100644 index 0000000000000000000000000000000000000000..6fbe8665781211ca803feb8bf5a8c04fb0b969e8 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ScaleTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops.bijectors import affine_scalar +from tensorflow.contrib.distributions.python.ops.bijectors import chain +from tensorflow.contrib.distributions.python.ops.bijectors import fill_triangular +from tensorflow.contrib.distributions.python.ops.bijectors import softplus +from tensorflow.contrib.distributions.python.ops.bijectors import transform_diagonal +from tensorflow.python.util import deprecation + +__all__ = [ + "ScaleTriL", +] + + +class ScaleTriL(chain.Chain): + """Transforms unconstrained vectors to TriL matrices with positive diagonal. + + This is implemented as a simple `tfb.Chain` of `tfb.FillTriangular` + followed by `tfb.TransformDiagonal`, and provided mostly as a + convenience. The default setup is somewhat opinionated, using a + Softplus transformation followed by a small shift (`1e-5`) which + attempts to avoid numerical issues from zeros on the diagonal. + + #### Examples + + ```python + tfb = tf.contrib.distributions.bijectors + b = tfb.ScaleTriL( + diag_bijector=tfb.Exp(), + diag_shift=None) + b.forward(x=[0., 0., 0.]) + # Result: [[1., 0.], + # [0., 1.]] + b.inverse(y=[[1., 0], + [.5, 2]]) + # Result: [log(2), .5, log(1)] + + # Define a distribution over PSD matrices of shape `[3, 3]`, + # with `1 + 2 + 3 = 6` degrees of freedom. + dist = tfd.TransformedDistribution( + tfd.Normal(tf.zeros(6), tf.ones(6)), + tfb.Chain([tfb.CholeskyOuterProduct(), tfb.ScaleTriL()])) + + # Using an identity transformation, ScaleTriL is equivalent to + # tfb.FillTriangular. + b = tfb.ScaleTriL( + diag_bijector=tfb.Identity(), + diag_shift=None) + + # For greater control over initialization, one can manually encode + # pre- and post- shifts inside of `diag_bijector`. + b = tfb.ScaleTriL( + diag_bijector=tfb.Chain([ + tfb.AffineScalar(shift=1e-3), + tfb.Softplus(), + tfb.AffineScalar(shift=0.5413)]), # softplus_inverse(1.) + # = log(expm1(1.)) = 0.5413 + diag_shift=None) + ``` + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + diag_bijector=None, + diag_shift=1e-5, + validate_args=False, + name="scale_tril"): + """Instantiates the `ScaleTriL` bijector. + + Args: + diag_bijector: `Bijector` instance, used to transform the output diagonal + to be positive. + Default value: `None` (i.e., `tfb.Softplus()`). + diag_shift: Float value broadcastable and added to all diagonal entries + after applying the `diag_bijector`. Setting a positive + value forces the output diagonal entries to be positive, but + prevents inverting the transformation for matrices with + diagonal entries less than this value. + Default value: `1e-5` (i.e., no shift is applied). + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + Default value: `False` (i.e., arguments are not validated). + name: Python `str` name given to ops managed by this object. + Default value: `scale_tril`. + """ + + if diag_bijector is None: + diag_bijector = softplus.Softplus(validate_args=validate_args) + + if diag_shift is not None: + diag_bijector = chain.Chain([affine_scalar.AffineScalar(shift=diag_shift), + diag_bijector]) + + super(ScaleTriL, self).__init__( + [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector), + fill_triangular.FillTriangular()], + validate_args=validate_args, + name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py index a640dfe7dfbcce96261589c7fc49107deaefdd54..194b318fce31a13f84e7b664b58cebb24fc9a264 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -31,9 +32,19 @@ __all__ = [ class Sigmoid(bijector.Bijector): """Bijector which computes `Y = g(X) = 1 / (1 + exp(-X))`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="sigmoid"): super(Sigmoid, self).__init__( - event_ndims=0, validate_args=validate_args, name=name) + forward_min_event_ndims=0, + validate_args=validate_args, + name=name) def _forward(self, x): return math_ops.sigmoid(x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py deleted file mode 100644 index 223bc9d042c69be05b0e578835a31ed6e83c0c97..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SigmoidCentered bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops.bijectors import softmax_centered - - -__all__ = [ - "SigmoidCentered", -] - - -class SigmoidCentered(softmax_centered.SoftmaxCentered): - """Bijector which computes Y = g(X) = exp([X 0]) / (1 + exp(-X)). - - Equivalent to: `bijector.SoftmaxCentered(event_ndims=0)`. - - See `bijector.SoftmaxCentered` for more details. - """ - - def __init__(self, validate_args=False, name="sigmoid_centered"): - super(SigmoidCentered, self).__init__( - event_ndims=0, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py index 3a75e4ae9495793901b0da91a5aa3982aab35852..241fba2cb7ec33b7b02c1ca79051f1b826d7d2aa 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py @@ -26,12 +26,21 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "SinhArcsinh", ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _sqrtx2p1(x): """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`.""" return array_ops.where( @@ -88,10 +97,17 @@ class SinhArcsinh(bijector.Bijector): `Y approx 0.5 X**tailweight e**(sign(X) skewness * tailweight)`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, skewness=None, tailweight=None, - event_ndims=0, validate_args=False, name="SinhArcsinh"): """Instantiates the `SinhArcsinh` bijector. @@ -101,8 +117,6 @@ class SinhArcsinh(bijector.Bijector): of type `float32`. tailweight: Tailweight parameter. Positive `Tensor` of same `dtype` as `skewness` and broadcastable `shape`. Default is `1` of type `float32`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. @@ -125,7 +139,9 @@ class SinhArcsinh(bijector.Bijector): message="Argument tailweight was not positive") ], self._tailweight) super(SinhArcsinh, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) + forward_min_event_ndims=0, + validate_args=validate_args, + name=name) @property def skewness(self): @@ -149,31 +165,29 @@ class SinhArcsinh(bijector.Bijector): # dx/dy # = cosh(arcsinh(y) / tailweight - skewness) # / (tailweight * sqrt(y**2 + 1)) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - # This is computed inside the log to avoid catastrophic cancellations - # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(x**2 + 1). + + # This is computed inside the log to avoid catastrophic cancellations + # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(x**2 + 1). + return ( math_ops.log(math_ops.cosh( math_ops.asinh(y) / self.tailweight - self.skewness) # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases # where (arcsinh(x) / tailweight) - skewness ~= arcsinh(x). / _sqrtx2p1(y)) - - math_ops.log(self.tailweight), - axis=event_dims) + - math_ops.log(self.tailweight)) def _forward_log_det_jacobian(self, x): # y = sinh((arcsinh(x) + skewness) * tailweight) # Using sinh' = cosh, arcsinh'(x) = 1 / sqrt(x**2 + 1), # dy/dx # = cosh((arcsinh(x) + skewness) * tailweight) * tailweight / sqrt(x**2 + 1) - event_dims = self._event_dims_tensor(x) - return math_ops.reduce_sum( - # This is computed inside the log to avoid catastrophic cancellations - # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1). + + # This is computed inside the log to avoid catastrophic cancellations + # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1). + return ( math_ops.log(math_ops.cosh( (math_ops.asinh(x) + self.skewness) * self.tailweight) # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases # where (arcsinh(x) + skewness) * tailweight ~= arcsinh(x). / _sqrtx2p1(x)) - + math_ops.log(self.tailweight), - axis=event_dims) + + math_ops.log(self.tailweight)) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index a9dcce6c526600f3b26c6bceb730417000917ce7..20ee0d340833d5c5275e2ab52a89dcdf7198add1 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -18,19 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - from tensorflow.contrib.distributions.python.ops import distribution_util -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -47,17 +43,14 @@ class SoftmaxCentered(bijector.Bijector): e.g., `softmax(x) = exp(x-c) / sum(exp(x-c))` where `c` is the implicit last coordinate. - Because we append a coordinate, this bijector only supports `event_ndim in [0, - 1]`, i.e., scalars and vectors. - Example Use: ```python - bijector.SoftmaxCentered(event_ndims=1).forward(tf.log([2, 3, 4])) + bijector.SoftmaxCentered().forward(tf.log([2, 3, 4])) # Result: [0.2, 0.3, 0.4, 0.1] # Extra result: 0.1 - bijector.SoftmaxCentered(event_ndims=1).inverse([0.2, 0.3, 0.4, 0.1]) + bijector.SoftmaxCentered().inverse([0.2, 0.3, 0.4, 0.1]) # Result: tf.log([2, 3, 4]) # Extra coordinate removed. ``` @@ -68,88 +61,59 @@ class SoftmaxCentered(bijector.Bijector): makes the (forward) image non-open and the theorem does not directly apply. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, - event_ndims=0, validate_args=False, name="softmax_centered"): self._graph_parents = [] self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 1]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 1") - self._static_event_ndims = event_ndims super(SoftmaxCentered, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=1, validate_args=validate_args, name=name) def _forward_event_shape(self, input_shape): - if input_shape.ndims is None: + if input_shape.ndims is None or input_shape[-1] is None: return input_shape - if input_shape.ndims != self._static_event_ndims: - raise ValueError("input_shape.dims = %d != %d" % - (input_shape.ndims, self._static_event_ndims)) - if input_shape.ndims == 0: - return tensor_shape.TensorShape([2]) - if input_shape.ndims == 1: - return tensor_shape.TensorShape(input_shape[0] + 1) - # Unreachable code: - raise ValueError("event_ndims = %d must be 0 or 1" % input_shape.ndims) + return tensor_shape.TensorShape([input_shape[-1] + 1]) def _forward_event_shape_tensor(self, input_shape): - ndims = array_ops.shape(input_shape) - if self.validate_args: - # It is not possible for a negative shape so we need only check <= 1. - is_zero_or_one = check_ops.assert_equal( - ndims, 0 if self._static_event_ndims == 0 else 1, - message="event_ndims must be 0 or 1") - ndims = control_flow_ops.with_dependencies([is_zero_or_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor( - [2], dtype=dtypes.int32, name="output_shape") - return input_shape + 1 + return (input_shape[-1] + 1)[..., array_ops.newaxis] def _inverse_event_shape(self, output_shape): - if output_shape.ndims is None: + if output_shape.ndims is None or output_shape[-1] is None: return output_shape - if output_shape.ndims != 1: - raise ValueError("output_shape.ndims = %d != 1" % output_shape.ndims) - if self._static_event_ndims == 0: - return tensor_shape.TensorShape([]) - return tensor_shape.TensorShape(output_shape[0] - 1) + if output_shape[-1] <= 1: + raise ValueError("output_shape[-1] = %d <= 1" % output_shape[-1]) + return tensor_shape.TensorShape([output_shape[-1] - 1]) def _inverse_event_shape_tensor(self, output_shape): - ndims = array_ops.shape(output_shape)[0] if self.validate_args: # It is not possible for a negative shape so we need only check <= 1. - is_one = check_ops.assert_equal( - ndims, 1, message="event_ndims must be 1") - ndims = control_flow_ops.with_dependencies([is_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor([], dtype=dtypes.int32, name="output_shape") - return array_ops.expand_dims(output_shape[0] - 1, dim=0) + is_greater_one = check_ops.assert_greater( + output_shape[-1], 1, message="Need last dimension greater than 1.") + output_shape = control_flow_ops.with_dependencies( + [is_greater_one], output_shape) + return (output_shape[-1] - 1)[..., array_ops.newaxis] def _forward(self, x): # Pad the last dim with a zeros vector. We need this because it lets us # infer the scale in the inverse function. - y = array_ops.expand_dims(x, dim=-1) if self._static_event_ndims == 0 else x - y = distribution_util.pad(y, axis=-1, back=True) + y = distribution_util.pad(x, axis=-1, back=True) # Set shape hints. if x.shape.ndims is not None: - shape = x.shape.as_list() - if self._static_event_ndims == 0: - shape += [2] - elif shape[-1] is not None: - shape[-1] += 1 - shape = tensor_shape.TensorShape(shape) + shape = x.shape[:-1].concatenate(x.shape[-1] + 1) y.shape.assert_is_compatible_with(shape) y.set_shape(shape) - # Since we only support event_ndims in [0, 1] and we do padding, we always - # reduce over the last dimension, i.e., dim=-1 (which is the default). return nn_ops.softmax(y) def _inverse(self, y): @@ -161,42 +125,17 @@ class SoftmaxCentered(bijector.Bijector): # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization) # = log(exp(x[i])/normalization) - log(y[end]) # = log(y[i]) - log(y[end]) - shape = (np.asarray(y.shape.as_list(), dtype=np.int32) - if y.shape.is_fully_defined() - else array_ops.shape(y, name="shape")) - ndims = distribution_util.prefer_static_rank(y) # Do this first to make sure CSE catches that it'll happen again in # _inverse_log_det_jacobian. x = math_ops.log(y) - # We now extract the last coordinate of the rightmost dimension. - # Our trick is to slice from [0,0,...,shape[-1]-1] to shape[:-1]+[1]. - begin = array_ops.one_hot(indices=ndims-1, - depth=ndims, - on_value=shape[-1]-np.array(1, dtype=shape.dtype), - dtype=shape.dtype) - size = array_ops.concat([shape[:-1], np.asarray([1], dtype=shape.dtype)], 0) - log_normalization = -array_ops.strided_slice(x, begin, begin + size) - - # Here we slice out all but the last coordinate; see above for idea. - begin = array_ops.zeros_like(shape) - size = array_ops.concat([shape[:-1], [shape[-1] - 1]], 0) - x = array_ops.strided_slice(x, begin, begin + size) - - x += log_normalization - - if self._static_event_ndims == 0: - x = array_ops.squeeze(x, squeeze_dims=[ndims-1]) + log_normalization = (-x[..., -1])[..., array_ops.newaxis] + x = x[..., :-1] + log_normalization # Set shape hints. if y.shape.ndims is not None: - shape = y.shape.as_list() - if self._static_event_ndims == 0: - shape = shape[:-1] - elif shape[-1] is not None: - shape[-1] -= 1 - shape = tensor_shape.TensorShape(shape) + shape = y.shape[:-1].concatenate(y.shape[-1] - 1) x.shape.assert_is_compatible_with(shape) x.set_shape(shape) @@ -222,19 +161,14 @@ class SoftmaxCentered(bijector.Bijector): return -math_ops.reduce_sum(math_ops.log(y), axis=-1) def _forward_log_det_jacobian(self, x): - if self._static_event_ndims == 0: - return x - 2. * nn_ops.softplus(x) - else: - # This code is similar to nn_ops.log_softmax but different because we have - # an implicit zero column to handle. I.e., instead of: - # reduce_sum(logits - reduce_sum(exp(logits), dim)) - # we must do: - # log_normalization = 1 + reduce_sum(exp(logits)) - # -log_normalization + reduce_sum(logits - log_normalization) - log_normalization = nn_ops.softplus( - math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) - fldj = (-log_normalization + - math_ops.reduce_sum(x - log_normalization, - axis=-1, - keep_dims=True)) - return array_ops.squeeze(fldj, squeeze_dims=-1) + # This code is similar to nn_ops.log_softmax but different because we have + # an implicit zero column to handle. I.e., instead of: + # reduce_sum(logits - reduce_sum(exp(logits), dim)) + # we must do: + # log_normalization = 1 + reduce_sum(exp(logits)) + # -log_normalization + reduce_sum(logits - log_normalization) + log_normalization = nn_ops.softplus( + math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) + return array_ops.squeeze( + (-log_normalization + math_ops.reduce_sum( + x - log_normalization, axis=-1, keepdims=True)), axis=-1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py index 81957fcf78922fa15fd20a25d144071f431161ae..3df84ef8b04c2c8f6be91ecd1c972ad1484b4285 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -62,7 +63,7 @@ class Softplus(bijector.Bijector): ```python # Create the Y=g(X)=softplus(X) transform which works only on Tensors with 1 # batch ndim and 2 event ndims (i.e., vector of matrices). - softplus = Softplus(event_ndims=2) + softplus = Softplus() x = [[[1., 2], [3, 4]], [[5, 6], @@ -80,8 +81,15 @@ class Softplus(bijector.Bijector): "hinge_softness": ( "Nonzero floating point `Tensor`. Controls the softness of what " "would otherwise be a kink at the origin. Default is 1.0")}) + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, - event_ndims=0, hinge_softness=None, validate_args=False, name="softplus"): @@ -101,7 +109,7 @@ class Softplus(bijector.Bijector): [nonzero_check], self.hinge_softness) super(Softplus, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=0, validate_args=validate_args, name=name) @@ -130,14 +138,12 @@ class Softplus(bijector.Bijector): # 1 - exp{-Y} approx Y. if self.hinge_softness is not None: y /= math_ops.cast(self.hinge_softness, y.dtype) - return -math_ops.reduce_sum(math_ops.log(-math_ops.expm1(-y)), - axis=self._event_dims_tensor(y)) + return -math_ops.log(-math_ops.expm1(-y)) def _forward_log_det_jacobian(self, x): if self.hinge_softness is not None: x /= math_ops.cast(self.hinge_softness, x.dtype) - return -math_ops.reduce_sum(nn_ops.softplus(-x), - axis=self._event_dims_tensor(x)) + return -nn_ops.softplus(-x) @property def hinge_softness(self): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py b/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py new file mode 100644 index 0000000000000000000000000000000000000000..f96a4bb01de59a21107b9e7c14f929e13e358ac9 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Softsign bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation + + +__all__ = [ + "Softsign", +] + + +class Softsign(bijector.Bijector): + """Bijector which computes `Y = g(X) = X / (1 + |X|)`. + + The softsign `Bijector` has the following two useful properties: + + * The domain is all real numbers + * `softsign(x) approx sgn(x)`, for large `|x|`. + + #### Examples + + ```python + # Create the Y = softsign(X) transform. + softsign = Softsign() + x = [[[1., 2], + [3, 4]], + [[5, 6], + [7, 8]]] + x / (1 + abs(x)) == softsign.forward(x) + x / (1 - abs(x)) == softsign.inverse(x) + ``` + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, validate_args=False, name="softsign"): + super(Softsign, self).__init__( + forward_min_event_ndims=0, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return x / (1. + math_ops.abs(x)) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + return y / (1. - math_ops.abs(y)) + + def _forward_log_det_jacobian(self, x): + return -2. * math_ops.log1p(math_ops.abs(x)) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + return -2. * math_ops.log1p(-math_ops.abs(y)) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_valid = [ + check_ops.assert_greater( + y, math_ops.cast(-1., dtype=y.dtype.base_dtype), + message="Inverse transformation input must be greater than -1."), + check_ops.assert_less( + y, math_ops.cast(1., dtype=y.dtype.base_dtype), + message="Inverse transformation input must be less than 1.") + ] + + return control_flow_ops.with_dependencies(is_valid, y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/square.py b/tensorflow/contrib/distributions/python/ops/bijectors/square.py new file mode 100644 index 0000000000000000000000000000000000000000..294460a80f6209797831ea361e64efe677f71e59 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/square.py @@ -0,0 +1,92 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Square bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation + + +__all__ = [ + "Square", +] + + +class Square(bijector.Bijector): + """Compute `g(X) = X^2`; X is a positive real number. + + g is a bijection between the non-negative real numbers (R_+) and the + non-negative real numbers. + + #### Examples + + ```python + bijector.Square().forward(x=[[1., 0], [2, 1]]) + # Result: [[1., 0], [4, 1]], i.e., x^2 + + bijector.Square().inverse(y=[[1., 4], [9, 1]]) + # Result: [[1., 2], [3, 1]], i.e., sqrt(y). + ``` + + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, validate_args=False, name="square"): + """Instantiates the `Square` bijector. + + Args: + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._name = name + super(Square, self).__init__( + forward_min_event_ndims=0, + validate_args=validate_args, + name=name) + + def _forward(self, x): + x = self._maybe_assert_valid(x) + return math_ops.square(x) + + def _inverse(self, y): + y = self._maybe_assert_valid(y) + return math_ops.sqrt(y) + + def _forward_log_det_jacobian(self, x): + x = self._maybe_assert_valid(x) + return np.log(2.) + math_ops.log(x) + + def _maybe_assert_valid(self, t): + if not self.validate_args: + return t + is_valid = check_ops.assert_non_negative( + t, message="All elements must be non-negative.") + return control_flow_ops.with_dependencies([is_valid], t) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py new file mode 100644 index 0000000000000000000000000000000000000000..9b7a3b026b8dcc31bed49c489d77b9c184f463cb --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py @@ -0,0 +1,111 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TransformDiagonal bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation + +__all__ = [ + "TransformDiagonal", +] + + +class TransformDiagonal(bijector.Bijector): + """Applies a Bijector to the diagonal of a matrix. + + #### Example + + ```python + b = tfb.TransformDiagonal(diag_bijector=tfb.Exp()) + + b.forward([[1., 0.], + [0., 1.]]) + # ==> [[2.718, 0.], + [0., 2.718]] + ``` + + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + diag_bijector, + validate_args=False, + name="transform_diagonal"): + """Instantiates the `TransformDiagonal` bijector. + + Args: + diag_bijector: `Bijector` instance used to transform the diagonal. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._diag_bijector = diag_bijector + super(TransformDiagonal, self).__init__( + forward_min_event_ndims=2, + inverse_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + diag = self._diag_bijector.forward(array_ops.matrix_diag_part(x)) + return array_ops.matrix_set_diag(x, diag) + + def _inverse(self, y): + diag = self._diag_bijector.inverse(array_ops.matrix_diag_part(y)) + return array_ops.matrix_set_diag(y, diag) + + def _forward_log_det_jacobian(self, x): + # We formulate the Jacobian with respect to the flattened matrices + # `vec(x)` and `vec(y)`. Suppose for notational convenience that + # the first `n` entries of `vec(x)` are the diagonal of `x`, and + # the remaining `n**2-n` entries are the off-diagonals in + # arbitrary order. Then the Jacobian is a block-diagonal matrix, + # with the Jacobian of the diagonal bijector in the first block, + # and the identity Jacobian for the remaining entries (since this + # bijector acts as the identity on non-diagonal entries): + # + # J_vec(x) (vec(y)) = + # ------------------------------- + # | J_diag(x) (diag(y)) 0 | n entries + # | | + # | 0 I | n**2-n entries + # ------------------------------- + # n n**2-n + # + # Since the log-det of the second (identity) block is zero, the + # overall log-det-jacobian is just the log-det of first block, + # from the diagonal bijector. + # + # Note that for elementwise operations (exp, softplus, etc) the + # first block of the Jacobian will itself be a diagonal matrix, + # but our implementation does not require this to be true. + return self._diag_bijector.forward_log_det_jacobian( + array_ops.matrix_diag_part(x), event_ndims=1) + + def _inverse_log_det_jacobian(self, y): + return self._diag_bijector.inverse_log_det_jacobian( + array_ops.matrix_diag_part(y), event_ndims=1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py index 00520bcda85e9527767e6342bf75f10667c264a8..8903a70d98ae144731b12047e5074d0450b59378 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -47,10 +48,17 @@ class Weibull(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, scale=1., concentration=1., - event_ndims=0, validate_args=False, name="weibull"): """Instantiates the `Weibull` bijector. @@ -62,8 +70,6 @@ class Weibull(bijector.Bijector): concentration: Positive Float-type `Tensor` that is the same dtype and is broadcastable with `scale`. This is `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. @@ -89,7 +95,7 @@ class Weibull(bijector.Bijector): ], self._concentration) super(Weibull, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=0, validate_args=validate_args, name=name) @@ -113,29 +119,25 @@ class Weibull(bijector.Bijector): def _inverse_log_det_jacobian(self, y): y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( + return ( -math_ops.log1p(-y) + (1 / self.concentration - 1) * math_ops.log(-math_ops.log1p(-y)) + - math_ops.log(self.scale / self.concentration), - axis=event_dims) + math_ops.log(self.scale / self.concentration)) def _forward_log_det_jacobian(self, x): x = self._maybe_assert_valid_x(x) - event_dims = self._event_dims_tensor(x) - return math_ops.reduce_sum( + return ( -(x / self.scale) ** self.concentration + (self.concentration - 1) * math_ops.log(x) + math_ops.log(self.concentration) + - -self.concentration * math_ops.log(self.scale), - axis=event_dims) + -self.concentration * math_ops.log(self.scale)) def _maybe_assert_valid_x(self, x): if not self.validate_args: return x is_valid = check_ops.assert_non_negative( x, - message="Forward transformation input must be at least {}.".format(0)) + message="Forward transformation input must be at least 0.") return control_flow_ops.with_dependencies([is_valid], x) def _maybe_assert_valid_y(self, y): diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py index 6a1bb39ab28218a411bdf4329965186bcf32bf30..b349e5966dd750fdf96c0b211dce02658c9400b7 100644 --- a/tensorflow/contrib/distributions/python/ops/binomial.py +++ b/tensorflow/contrib/distributions/python/ops/binomial.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation _binomial_sample_note = """ @@ -42,6 +43,14 @@ to integer values. """ +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _bdtr(k, n, p): """The binomial cumulative distribution function. @@ -130,6 +139,14 @@ class Binomial(distribution.Distribution): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, total_count, logits=None, @@ -163,8 +180,8 @@ class Binomial(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() - with ops.name_scope(name, values=[total_count, logits, probs]): + parameters = dict(locals()) + with ops.name_scope(name, values=[total_count, logits, probs]) as name: self._total_count = self._maybe_assert_valid_total_count( ops.convert_to_tensor(total_count, name="total_count"), validate_args) diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py index 6f5d724a2a945ed8f9c159d8314327c6f994d1db..cb5223b0557080e10bf24c3e1cb432f15fd5e7e3 100644 --- a/tensorflow/contrib/distributions/python/ops/cauchy.py +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation __all__ = [ "Cauchy", @@ -92,6 +93,14 @@ class Cauchy(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -120,8 +129,8 @@ class Cauchy(distribution.Distribution): Raises: TypeError: if `loc` and `scale` have different `dtype`. """ - parameters = locals() - with ops.name_scope(name, values=[loc, scale]): + parameters = dict(locals()) + with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): self._loc = array_ops.identity(loc, name="loc") diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py index bdd5571c966a74e58e4f9f8eed2628f131a1b92e..e9a7b39070f3d76693ad54852ed0847a0980d2a6 100644 --- a/tensorflow/contrib/distributions/python/ops/chi2.py +++ b/tensorflow/contrib/distributions/python/ops/chi2.py @@ -21,8 +21,11 @@ from __future__ import print_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import gamma +from tensorflow.python.util import deprecation __all__ = [ @@ -61,6 +64,14 @@ class Chi2(gamma.Gamma): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, validate_args=False, @@ -81,13 +92,17 @@ class Chi2(gamma.Gamma): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = dict(locals()) # Even though all stats of chi2 are defined for valid parameters, this is # not true in the parent class "gamma." therefore, passing # allow_nan_stats=True # through to the parent class results in unnecessary asserts. - with ops.name_scope(name, values=[df]): - self._df = ops.convert_to_tensor(df, name="df") + with ops.name_scope(name, values=[df]) as name: + with ops.control_dependencies([ + check_ops.assert_positive(df), + ] if validate_args else []): + self._df = array_ops.identity(df, name="df") + super(Chi2, self).__init__( concentration=0.5 * self._df, rate=constant_op.constant(0.5, dtype=self._df.dtype), @@ -108,13 +123,21 @@ class Chi2(gamma.Gamma): class Chi2WithAbsDf(Chi2): """Chi2 with parameter transform `df = floor(abs(df))`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, validate_args=False, allow_nan_stats=True, name="Chi2WithAbsDf"): - parameters = locals() - with ops.name_scope(name, values=[df]): + parameters = dict(locals()) + with ops.name_scope(name, values=[df]) as name: super(Chi2WithAbsDf, self).__init__( df=math_ops.floor( math_ops.abs(df, name="abs_df"), diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index 1d4c5660d8d73b7b6a7e758fc834ccfddeb5c8ea..3598c8d23ea9007fb359ae4931738fb61ede4ccc 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -105,7 +105,9 @@ class ConditionalTransformedDistribution( bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - ildj = self.bijector.inverse_log_det_jacobian(y, **bijector_kwargs) + event_ndims = self._maybe_get_static_event_ndims() + ildj = self.bijector.inverse_log_det_jacobian( + y, event_ndims=event_ndims, **bijector_kwargs) if self.bijector._is_injective: # pylint: disable=protected-access return self._finish_log_prob_for_one_fiber(y, x, ildj, distribution_kwargs) @@ -128,7 +130,9 @@ class ConditionalTransformedDistribution( bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - ildj = self.bijector.inverse_log_det_jacobian(y, **bijector_kwargs) + event_ndims = self._maybe_get_static_event_ndims() + ildj = self.bijector.inverse_log_det_jacobian( + y, event_ndims=event_ndims, **bijector_kwargs) if self.bijector._is_injective: # pylint: disable=protected-access return self._finish_prob_for_one_fiber(y, x, ildj, distribution_kwargs) @@ -214,3 +218,15 @@ class ConditionalTransformedDistribution( # implies the qth quantile of Y is g(x_q). inv_cdf = self.distribution.quantile(value, **distribution_kwargs) return self.bijector.forward(inv_cdf, **bijector_kwargs) + + def _maybe_get_static_event_ndims(self): + if self.event_shape.ndims is not None: + return self.event_shape.ndims + + event_ndims = array_ops.size(self.event_shape_tensor()) + event_ndims_ = distribution_util.maybe_get_static_value(event_ndims) + + if event_ndims_ is not None: + return event_ndims_ + + return event_ndims diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index 8049522e9f5dc26b244b7e710a9ae8b981efd6b6..ad853ee293f86565c1af601214522f53d936b70a 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation __all__ = [ "Deterministic", @@ -43,6 +44,14 @@ __all__ = [ class _BaseDeterministic(distribution.Distribution): """Base class for Deterministic distributions.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, atol=None, @@ -86,8 +95,8 @@ class _BaseDeterministic(distribution.Distribution): Raises: ValueError: If `loc` is a scalar. """ - parameters = locals() - with ops.name_scope(name, values=[loc, atol, rtol]): + parameters = dict(locals()) + with ops.name_scope(name, values=[loc, atol, rtol]) as name: loc = ops.convert_to_tensor(loc, name="loc") if is_vector and validate_args: msg = "Argument loc must be at least rank 1." @@ -203,6 +212,14 @@ class Deterministic(_BaseDeterministic): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, atol=None, @@ -308,6 +325,14 @@ class VectorDeterministic(_BaseDeterministic): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, atol=None, diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 289e1d50e1146a641c0cc433ece3465aed73b1c2..6959b3e8775d2dd488b4ee3252d143ef376d58f9 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -21,12 +21,19 @@ from __future__ import print_function from tensorflow.contrib import linalg from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib + +# The following two lines are redundant, in a sense. The first enables +# good coding practice *within* this file (`util.prefer_static_value` +# rather than `prefer_static_value`). The second ensures that users +# also get the core utils when they import this file. +from tensorflow.python.ops.distributions import util from tensorflow.python.ops.distributions.util import * # pylint: disable=wildcard-import @@ -484,3 +491,75 @@ def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution, def static_value(x): """Returns the static value of a `Tensor` or `None`.""" return tensor_util.constant_value(ops.convert_to_tensor(x)) + + +def move_dimension(x, source_idx, dest_idx): + """Move a single tensor dimension within its shape. + + This is a special case of `tf.transpose()`, which applies + arbitrary permutations to tensor dimensions. + + Args: + x: Tensor of rank `ndims`. + source_idx: Integer index into `x.shape` (negative indexing is + supported). + dest_idx: Integer index into `x.shape` (negative indexing is + supported). + + Returns: + x_perm: Tensor of rank `ndims`, in which the dimension at original + index `source_idx` has been moved to new index `dest_idx`, with + all other dimensions retained in their original order. + + Example: + + ```python + x = tf.placeholder(shape=[200, 30, 4, 1, 6]) + x_perm = _move_dimension(x, 1, 1) # no-op + x_perm = _move_dimension(x, 0, 3) # result shape [30, 4, 1, 200, 6] + x_perm = _move_dimension(x, 0, -2) # equivalent to previous + x_perm = _move_dimension(x, 4, 2) # result shape [200, 30, 6, 4, 1] + ``` + """ + ndims = util.prefer_static_rank(x) + if isinstance(source_idx, int): + dtype = dtypes.int32 + else: + dtype = dtypes.as_dtype(source_idx.dtype) + + # Handle negative indexing. Since ndims might be dynamic, this makes + # source_idx and dest_idx also possibly dynamic. + if source_idx < 0: + source_idx = ndims + source_idx + if dest_idx < 0: + dest_idx = ndims + dest_idx + + # Construct the appropriate permutation of dimensions, depending + # whether the source is before or after the destination. + def move_left_permutation(): + return util.prefer_static_value( + array_ops.concat([ + math_ops.range(0, dest_idx, dtype=dtype), + [source_idx], + math_ops.range(dest_idx, source_idx, dtype=dtype), + math_ops.range(source_idx+1, ndims, dtype=dtype)], axis=0)) + + def move_right_permutation(): + return util.prefer_static_value( + array_ops.concat([ + math_ops.range(0, source_idx, dtype=dtype), + math_ops.range(source_idx+1, dest_idx+1, dtype=dtype), + [source_idx], + math_ops.range(dest_idx+1, ndims, dtype=dtype)], axis=0)) + + def x_permuted(): + return array_ops.transpose( + x, perm=smart_cond.smart_cond(source_idx < dest_idx, + move_right_permutation, + move_left_permutation)) + + # One final conditional to handle the special case where source + # and destination indices are equal. + return smart_cond.smart_cond(math_ops.equal(source_idx, dest_idx), + lambda: x, + x_permuted) diff --git a/tensorflow/contrib/distributions/python/ops/estimator.py b/tensorflow/contrib/distributions/python/ops/estimator.py index 6b53338c4542c75d3977c075b7750c780080ac48..bdec6527d5378d6e86aa8e6279cc6ee672083e56 100644 --- a/tensorflow/contrib/distributions/python/ops/estimator.py +++ b/tensorflow/contrib/distributions/python/ops/estimator.py @@ -23,6 +23,7 @@ from tensorflow.contrib.learn.python.learn.estimators.head import _RegressionHea from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.util import deprecation __all__ = [ @@ -30,6 +31,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def estimator_head_distribution_regression(make_distribution_fn, label_dimension=1, logits_dimension=None, @@ -75,8 +84,16 @@ def estimator_head_distribution_regression(make_distribution_fn, class _DistributionRegressionHead(_RegressionHead): - """Creates a _RegressionHead instance from an arbitray `Distribution`.""" - + """Creates a _RegressionHead instance from an arbitrary `Distribution`.""" + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, make_distribution_fn, label_dimension, diff --git a/tensorflow/contrib/distributions/python/ops/geometric.py b/tensorflow/contrib/distributions/python/ops/geometric.py index 8f190e48a7148d84082d73771cba4660a1a0d221..d62f024aa2a081f0ec231015af1f26a8851518e9 100644 --- a/tensorflow/contrib/distributions/python/ops/geometric.py +++ b/tensorflow/contrib/distributions/python/ops/geometric.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Geometric(distribution.Distribution): @@ -55,6 +56,14 @@ class Geometric(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, logits=None, probs=None, @@ -85,8 +94,8 @@ class Geometric(distribution.Distribution): name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() - with ops.name_scope(name, values=[logits, probs]): + parameters = dict(locals()) + with ops.name_scope(name, values=[logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( logits, probs, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index d0efaefb8e78ddf4436e9e5a112d2c1cdddaf3b5..acdea4d61d3ada7e9f4f0aa7bc58c5643db2802b 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation class _Gumbel(distribution.Distribution): @@ -96,6 +97,14 @@ class _Gumbel(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -124,8 +133,8 @@ class _Gumbel(distribution.Distribution): Raises: TypeError: if loc and scale are different dtypes. """ - parameters = locals() - with ops.name_scope(name, values=[loc, scale]): + parameters = dict(locals()) + with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): self._loc = array_ops.identity(loc, name="loc") @@ -190,9 +199,6 @@ class _Gumbel(distribution.Distribution): def _log_prob(self, x): return self._log_unnormalized_prob(x) - self._log_normalization() - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - def _log_cdf(self, x): return -math_ops.exp(-self._z(x)) diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py index fc0751a6e0b78cb3d79bd3478e740bb05cd26428..b02c4031069191592b8acc1a90313450f98af6d7 100644 --- a/tensorflow/contrib/distributions/python/ops/half_normal.py +++ b/tensorflow/contrib/distributions/python/ops/half_normal.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import special_math +from tensorflow.python.util import deprecation __all__ = [ @@ -85,6 +86,14 @@ class HalfNormal(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, scale, validate_args=False, @@ -105,8 +114,8 @@ class HalfNormal(distribution.Distribution): if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() - with ops.name_scope(name, values=[scale]): + parameters = dict(locals()) + with ops.name_scope(name, values=[scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): self._scale = array_ops.identity(scale, name="scale") diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index cbce005013281ff3c58c94d525d5ce7a865d725a..0672702b96c1eb81c176774554df3f5922a0319e 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -28,6 +28,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.distributions import kullback_leibler +from tensorflow.python.util import deprecation class Independent(distribution_lib.Distribution): @@ -35,7 +37,7 @@ class Independent(distribution_lib.Distribution): This distribution is useful for regarding a collection of independent, non-identical distributions as a single random variable. For example, the - `Indpendent` distribution composed of a collection of `Bernoulli` + `Independent` distribution composed of a collection of `Bernoulli` distributions might define a distribution over an image (where each `Bernoulli` is a distribution over each pixel). @@ -93,6 +95,14 @@ class Independent(distribution_lib.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, distribution, reinterpreted_batch_ndims=None, validate_args=False, name=None): @@ -115,10 +125,10 @@ class Independent(distribution_lib.Distribution): ValueError: if `reinterpreted_batch_ndims` exceeds `distribution.batch_ndims` """ - parameters = locals() + parameters = dict(locals()) name = name or "Independent" + distribution.name self._distribution = distribution - with ops.name_scope(name): + with ops.name_scope(name) as name: if reinterpreted_batch_ndims is None: reinterpreted_batch_ndims = self._get_default_reinterpreted_batch_ndims( distribution) @@ -254,3 +264,66 @@ class Independent(distribution_lib.Distribution): else: which_maximum = np.maximum return which_maximum(0, ndims - 1) + + +@kullback_leibler.RegisterKL(Independent, Independent) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) +def _kl_independent(a, b, name="kl_independent"): + """Batched KL divergence `KL(a || b)` for Independent distributions. + + We can leverage the fact that + ``` + KL(Independent(a) || Independent(b)) = sum(KL(a || b)) + ``` + where the sum is over the `reinterpreted_batch_ndims`. + + Args: + a: Instance of `Independent`. + b: Instance of `Independent`. + name: (optional) name to use for created ops. Default "kl_independent". + + Returns: + Batchwise `KL(a || b)`. + + Raises: + ValueError: If the event space for `a` and `b`, or their underlying + distributions don't match. + """ + p = a.distribution + q = b.distribution + + # The KL between any two (non)-batched distributions is a scalar. + # Given that the KL between two factored distributions is the sum, i.e. + # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(q1 || q2), we compute + # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions. + if a.event_shape.is_fully_defined() and b.event_shape.is_fully_defined(): + if a.event_shape == b.event_shape: + if p.event_shape == q.event_shape: + num_reduce_dims = a.event_shape.ndims - p.event_shape.ndims + reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)] + + return math_ops.reduce_sum( + kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims) + else: + raise NotImplementedError("KL between Independents with different " + "event shapes not supported.") + else: + raise ValueError("Event shapes do not match.") + else: + with ops.control_dependencies([ + check_ops.assert_equal(a.event_shape_tensor(), b.event_shape_tensor()), + check_ops.assert_equal(p.event_shape_tensor(), q.event_shape_tensor()) + ]): + num_reduce_dims = ( + array_ops.shape(a.event_shape_tensor()[0]) - + array_ops.shape(p.event_shape_tensor()[0])) + reduce_dims = math_ops.range(-num_reduce_dims - 1, -1, 1) + return math_ops.reduce_sum( + kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims) diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index ee4d86867d48b20e97757bcec57d452085814b80..70d050d7a647b38928ddb1c788db0e6957ac0f03 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -95,6 +96,14 @@ class InverseGamma(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration, rate, @@ -125,8 +134,8 @@ class InverseGamma(distribution.Distribution): Raises: TypeError: if `concentration` and `rate` are different dtypes. """ - parameters = locals() - with ops.name_scope(name, values=[concentration, rate]): + parameters = dict(locals()) + with ops.name_scope(name, values=[concentration, rate]) as name: with ops.control_dependencies([ check_ops.assert_positive(concentration), check_ops.assert_positive(rate), @@ -192,12 +201,6 @@ class InverseGamma(distribution.Distribution): def _log_prob(self, x): return self._log_unnormalized_prob(x) - self._log_normalization() - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - - def _log_cdf(self, x): - return math_ops.log(self._cdf(x)) - def _cdf(self, x): x = self._maybe_assert_valid_sample(x) # Note that igammac returns the upper regularized incomplete gamma @@ -280,14 +283,22 @@ class InverseGamma(distribution.Distribution): class InverseGammaWithSoftplusConcentrationRate(InverseGamma): """`InverseGamma` with softplus of `concentration` and `rate`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration, rate, validate_args=False, allow_nan_stats=True, name="InverseGammaWithSoftplusConcentrationRate"): - parameters = locals() - with ops.name_scope(name, values=[concentration, rate]): + parameters = dict(locals()) + with ops.name_scope(name, values=[concentration, rate]) as name: super(InverseGammaWithSoftplusConcentrationRate, self).__init__( concentration=nn.softplus(concentration, name="softplus_concentration"), diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py index 74d5d8773cf3e69a52554c87d656fea2835c8354..e3712dd84e36609d6bba4a5a39866046c0c8d1d8 100644 --- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -20,16 +20,18 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops -from tensorflow.python.ops.distributions import beta from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.ops.distributions import uniform +from tensorflow.python.util import deprecation __all__ = [ "Kumaraswamy", @@ -39,28 +41,33 @@ _kumaraswamy_sample_note = """Note: `x` must have dtype `self.dtype` and be in `[0, 1].` It must have a shape compatible with `self.batch_shape()`.""" +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _harmonic_number(x): """Compute the harmonic number from its analytic continuation. - Derivation from [1] and Euler's constant [2]. - [1] - - https://en.wikipedia.org/wiki/Digamma_function#Relation_to_harmonic_numbers - [2] - https://en.wikipedia.org/wiki/Euler%E2%80%93Mascheroni_constant - + Derivation from [here]( + https://en.wikipedia.org/wiki/Digamma_function#Relation_to_harmonic_numbers) + and [Euler's constant]( + https://en.wikipedia.org/wiki/Euler%E2%80%93Mascheroni_constant). Args: x: input float. Returns: z: The analytic continuation of the harmonic number for the input. - """ one = array_ops.ones([], dtype=x.dtype) return math_ops.digamma(x + one) - math_ops.digamma(one) -@tf_export("distributions.Kumaraswamy") -class Kumaraswamy(beta.Beta): +class Kumaraswamy(transformed_distribution.TransformedDistribution): """Kumaraswamy distribution. The Kumaraswamy distribution is defined over the `(0, 1)` interval using @@ -125,6 +132,14 @@ class Kumaraswamy(beta.Beta): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration1=None, concentration0=None, @@ -151,59 +166,33 @@ class Kumaraswamy(beta.Beta): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ + with ops.name_scope(name, values=[concentration1, concentration0]) as name: + concentration1 = ops.convert_to_tensor( + concentration1, name="concentration1") + concentration0 = ops.convert_to_tensor( + concentration0, name="concentration0") super(Kumaraswamy, self).__init__( - concentration1=concentration1, - concentration0=concentration0, - validate_args=validate_args, - allow_nan_stats=allow_nan_stats, + distribution=uniform.Uniform( + low=array_ops.zeros([], dtype=concentration1.dtype), + high=array_ops.ones([], dtype=concentration1.dtype), + allow_nan_stats=allow_nan_stats), + bijector=bijectors.Kumaraswamy( + concentration1=concentration1, concentration0=concentration0, + validate_args=validate_args), + batch_shape=distribution_util.get_broadcast_shape( + concentration1, concentration0), name=name) self._reparameterization_type = distribution.FULLY_REPARAMETERIZED - def _sample_n(self, n, seed=None): - expanded_concentration1 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration1 - expanded_concentration0 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration0 - shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) - uniform_sample = random_ops.random_uniform( - shape=shape, minval=0.0, maxval=1.0, dtype=self.dtype, seed=seed) - - kumaraswamy_sample = (1 - uniform_sample**(1. / expanded_concentration0))**( - 1. / expanded_concentration1) - return kumaraswamy_sample - - @distribution_util.AppendDocstring(_kumaraswamy_sample_note) - def _log_cdf(self, x): - a = self.concentration1 - b = self.concentration0 - return math_ops.log1p(-(1 - x**a)**b) + @property + def concentration1(self): + """Concentration parameter associated with a `1` outcome.""" + return self.bijector.concentration1 - @distribution_util.AppendDocstring(_kumaraswamy_sample_note) - def _cdf(self, x): - a = self.concentration1 - b = self.concentration0 - return 1 - (1 - x**a)**b - - def _survival_function(self, x): - a = self.concentration1 - b = self.concentration0 - return (1 - x**a)**b - - def _log_survival_function(self, x): - a = self.concentration1 - b = self.concentration0 - return b * math_ops.log1p(-x**a) - - def _log_unnormalized_prob(self, x): - x = self._maybe_assert_valid_sample(x) - a = self.concentration1 - b = self.concentration0 - return (a - 1) * math_ops.log(x) + (b - 1) * math_ops.log1p(-x**a) - - def _log_normalization(self): - a = self.concentration1 - b = self.concentration0 - return -(math_ops.log(a) + math_ops.log(b)) + @property + def concentration0(self): + """Concentration parameter associated with a `0` outcome.""" + return self.bijector.concentration0 def _entropy(self): a = self.concentration1 @@ -213,10 +202,11 @@ class Kumaraswamy(beta.Beta): def _moment(self, n): """Compute the n'th (uncentered) moment.""" + total_concentration = self.concentration1 + self.concentration0 expanded_concentration1 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration1 + total_concentration, dtype=self.dtype) * self.concentration1 expanded_concentration0 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration0 + total_concentration, dtype=self.dtype) * self.concentration0 beta_arg0 = 1 + n / expanded_concentration1 beta_arg = array_ops.stack([beta_arg0, expanded_concentration0], -1) log_moment = math_ops.log(expanded_concentration0) + special_math_ops.lbeta( @@ -246,13 +236,14 @@ class Kumaraswamy(beta.Beta): name="nan") is_defined = (self.concentration1 > 1.) & (self.concentration0 > 1.) return array_ops.where(is_defined, mode, nan) + return control_flow_ops.with_dependencies([ check_ops.assert_less( - array_ops.ones([], dtype=self.dtype), + array_ops.ones([], dtype=self.concentration1.dtype), self.concentration1, message="Mode undefined for concentration1 <= 1."), check_ops.assert_less( - array_ops.ones([], dtype=self.dtype), + array_ops.ones([], dtype=self.concentration0.dtype), self.concentration0, message="Mode undefined for concentration0 <= 1.") ], mode) diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index 473677f8d91b184e029f345bb05f5c5d63df7a40..02e3bad51ee48188acf83cb09359861c9e6932c7 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation class Logistic(distribution.Distribution): @@ -91,6 +92,14 @@ class Logistic(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -119,8 +128,8 @@ class Logistic(distribution.Distribution): Raises: TypeError: if loc and scale are different dtypes. """ - parameters = locals() - with ops.name_scope(name, values=[loc, scale]): + parameters = dict(locals()) + with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): self._loc = array_ops.identity(loc, name="loc") @@ -185,9 +194,6 @@ class Logistic(distribution.Distribution): def _log_prob(self, x): return self._log_unnormalized_prob(x) - self._log_normalization() - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - def _log_cdf(self, x): return -nn_ops.softplus(-self._z(x)) diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index cef6a143fc615901315a3780bf4ed53b8c7cd177..3b7114ef067c0aaede23fff04c40d1dc6e830f1c 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import categorical from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Mixture(distribution.Distribution): @@ -66,6 +67,14 @@ class Mixture(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, cat, components, @@ -116,7 +125,7 @@ class Mixture(distribution.Distribution): matching static batch shapes, or all components do not have matching static event shapes. """ - parameters = locals() + parameters = dict(locals()) if not isinstance(cat, categorical.Categorical): raise TypeError("cat must be a Categorical distribution, but saw: %s" % cat) @@ -145,7 +154,7 @@ class Mixture(distribution.Distribution): "none of the components provide a static number of ndims") # Ensure that all batch and event ndims are consistent. - with ops.name_scope(name, values=[cat.logits]): + with ops.name_scope(name, values=[cat.logits]) as name: num_components = cat.event_size static_num_components = tensor_util.constant_value(num_components) if static_num_components is None: diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index b93bdc5ab4010663baddda1410b302644853648b..8ffee940d03c9a5204f2ac6f7acd9ea482adae1a 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class MixtureSameFamily(distribution.Distribution): @@ -95,6 +96,14 @@ class MixtureSameFamily(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, mixture_distribution, components_distribution, @@ -130,8 +139,8 @@ class MixtureSameFamily(distribution.Distribution): ValueError: if `mixture_distribution` categories does not equal `components_distribution` rightmost batch shape. """ - parameters = locals() - with ops.name_scope(name): + parameters = dict(locals()) + with ops.name_scope(name) as name: self._mixture_distribution = mixture_distribution self._components_distribution = components_distribution self._runtime_assertions = [] @@ -321,6 +330,14 @@ class MixtureSameFamily(distribution.Distribution): return x +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _outer_squared_difference(x, y): """Convenience function analogous to tf.squared_difference.""" z = x - y diff --git a/tensorflow/contrib/distributions/python/ops/moving_stats.py b/tensorflow/contrib/distributions/python/ops/moving_stats.py index 20f85643b9e7db61b4786dffe4115c7d3c00b046..87d40805a3c7a9c2871305af7f7182b7e2923530 100644 --- a/tensorflow/contrib/distributions/python/ops/moving_stats.py +++ b/tensorflow/contrib/distributions/python/ops/moving_stats.py @@ -47,9 +47,7 @@ def assign_moving_mean_variance( Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses the lag-1 mean. - For derivation justification, see equation 143 of: - T. Finch, Feb 2009. "Incremental calculation of weighted mean and variance". - http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf + For derivation justification, see [Finch (2009; Eq. 143)][1]. Args: mean_var: `float`-like `Variable` representing the exponentially weighted @@ -72,6 +70,12 @@ def assign_moving_mean_variance( TypeError: if `mean_var` does not have float type `dtype`. TypeError: if `mean_var`, `variance_var`, `value`, `decay` have different `base_dtype`. + + #### References + + [1]: Tony Finch. Incremental calculation of weighted mean and variance. + _Technical Report_, 2009. + http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf """ with ops.name_scope(name, "assign_moving_mean_variance", [variance_var, mean_var, value, decay]): @@ -183,9 +187,7 @@ def moving_mean_variance(value, decay, collections=None, name=None): Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses the lag-`1` mean. - For derivation justification, see equation 143 of: - T. Finch, Feb 2009. "Incremental calculation of weighted mean and variance". - http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf + For derivation justification, see [Finch (2009; Eq. 143)][1]. Unlike `assign_moving_mean_variance`, this function handles variable creation. @@ -208,6 +210,12 @@ def moving_mean_variance(value, decay, collections=None, name=None): Raises: TypeError: if `value_var` does not have float type `dtype`. TypeError: if `value`, `decay` have different `base_dtype`. + + #### References + + [1]: Tony Finch. Incremental calculation of weighted mean and variance. + _Technical Report_, 2009. + http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf """ if collections is None: collections = [ops.GraphKeys.GLOBAL_VARIABLES] diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py index e862552880f4073c8fa8e90134d0633e7484b0bf..cd0c282ba6cebf784261a4e821f36ce4eed98fe0 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py @@ -22,6 +22,7 @@ from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops from tensorflow.python.ops import nn +from tensorflow.python.util import deprecation __all__ = [ @@ -134,6 +135,14 @@ class MultivariateNormalDiag( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -193,8 +202,8 @@ class MultivariateNormalDiag( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() - with ops.name_scope(name): + parameters = dict(locals()) + with ops.name_scope(name) as name: with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): # No need to validate_args while making diag_scale. The returned @@ -218,14 +227,22 @@ class MultivariateNormalDiag( class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag): """MultivariateNormalDiag with `diag_stddev = softplus(diag_stddev)`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale_diag, validate_args=False, allow_nan_stats=True, name="MultivariateNormalDiagWithSoftplusScale"): - parameters = locals() - with ops.name_scope(name, values=[scale_diag]): + parameters = dict(locals()) + with ops.name_scope(name, values=[scale_diag]) as name: super(MultivariateNormalDiagWithSoftplusScale, self).__init__( loc=loc, scale_diag=nn.softplus(scale_diag), diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index 413e88f03ae0286c294f3404549a73e1a47dcff7..d8401801f21afbe8fd042053c6a38a31a2539438 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -22,6 +22,7 @@ from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation __all__ = [ @@ -141,6 +142,14 @@ class MultivariateNormalDiagPlusLowRank( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -215,10 +224,10 @@ class MultivariateNormalDiagPlusLowRank( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() + parameters = dict(locals()) def _convert_to_tensor(x, name): return None if x is None else ops.convert_to_tensor(x, name=name) - with ops.name_scope(name): + with ops.name_scope(name) as name: with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier, scale_perturb_factor, scale_perturb_diag]): diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index 4bea99fbb75349f97fde473cb5716fe6c426ce90..dbc4c1b3dc956641f3e38ffafe3a3410bd3e2097 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.util import deprecation __all__ = [ @@ -45,7 +46,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): The probability density function (pdf) is, with `@` as matrix multiplication, ```none - pdf(x; loc, covariance_matrix) = exp(-0.5 ||y||**2) / Z, + pdf(x; loc, covariance_matrix) = exp(-0.5 y) / Z, y = (x - loc)^T @ inv(covariance_matrix) @ (x - loc) Z = (2 pi)**(0.5 k) |det(covariance_matrix)|**(0.5). ``` @@ -54,8 +55,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): * `loc` is a vector in `R^k`, * `covariance_matrix` is an `R^{k x k}` symmetric positive definite matrix, - * `Z` denotes the normalization constant, and, - * `||y||**2` denotes the squared Euclidean norm of `y`. + * `Z` denotes the normalization constant. Additional leading dimensions (if any) in `loc` and `covariance_matrix` allow for batch dimensions. @@ -113,6 +113,14 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, covariance_matrix=None, @@ -156,10 +164,10 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): Raises: ValueError: if neither `loc` nor `covariance_matrix` are specified. """ - parameters = locals() + parameters = dict(locals()) # Convert the covariance_matrix up to a scale_tril and call MVNTriL. - with ops.name_scope(name): + with ops.name_scope(name) as name: with ops.name_scope("init", values=[loc, covariance_matrix]): if covariance_matrix is None: scale_tril = None diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index a7399792892f4c179c05168184d76ec95c168b51..efe5a6d0d99ca8fa9e0274049423bb3c4eef2d6f 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -27,6 +27,7 @@ from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.linalg import linalg +from tensorflow.python.util import deprecation __all__ = [ @@ -133,6 +134,14 @@ class MultivariateNormalLinearOperator( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale=None, @@ -170,13 +179,13 @@ class MultivariateNormalLinearOperator( ValueError: if `scale` is unspecified. TypeError: if not `scale.dtype.is_floating` """ - parameters = locals() + parameters = dict(locals()) if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: raise TypeError("`scale` parameter must have floating-point dtype.") - with ops.name_scope(name, values=[loc] + scale.graph_parents): + with ops.name_scope(name, values=[loc] + scale.graph_parents) as name: # Since expand_dims doesn't preserve constant-ness, we obtain the # non-dynamic value if possible. loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc @@ -266,6 +275,14 @@ class MultivariateNormalLinearOperator( @kullback_leibler.RegisterKL(MultivariateNormalLinearOperator, MultivariateNormalLinearOperator) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _kl_brute_force(a, b, name=None): """Batched KL divergence `KL(a || b)` for multivariate Normals. diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index 6c7dc4ca7aaf5b3a20b072e9360d15528ad10556..d9110947ecdbba1a63669573f46db17b02e512ab 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -22,6 +22,7 @@ from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -134,6 +135,14 @@ class MultivariateNormalTriL( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_tril=None, @@ -179,12 +188,12 @@ class MultivariateNormalTriL( Raises: ValueError: if neither `loc` nor `scale_tril` are specified. """ - parameters = locals() + parameters = dict(locals()) def _convert_to_tensor(x, name): return None if x is None else ops.convert_to_tensor(x, name=name) if loc is None and scale_tril is None: raise ValueError("Must specify one or both of `loc`, `scale_tril`.") - with ops.name_scope(name): + with ops.name_scope(name) as name: with ops.name_scope("init", values=[loc, scale_tril]): loc = _convert_to_tensor(loc, name="loc") scale_tril = _convert_to_tensor(scale_tril, name="scale_tril") diff --git a/tensorflow/contrib/distributions/python/ops/negative_binomial.py b/tensorflow/contrib/distributions/python/ops/negative_binomial.py index 3a58df80da6c02b056f5e5a63bf41de5fc6d44a4..6acfc5746a0cc20e916de81b71f90e08d8d91ad5 100644 --- a/tensorflow/contrib/distributions/python/ops/negative_binomial.py +++ b/tensorflow/contrib/distributions/python/ops/negative_binomial.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class NegativeBinomial(distribution.Distribution): @@ -51,6 +52,14 @@ class NegativeBinomial(distribution.Distribution): * `n!` is the factorial of `n`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, total_count, logits=None, @@ -90,8 +99,8 @@ class NegativeBinomial(distribution.Distribution): name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() - with ops.name_scope(name, values=[total_count, logits, probs]): + parameters = dict(locals()) + with ops.name_scope(name, values=[total_count, logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( logits, probs, validate_args=validate_args, name=name) with ops.control_dependencies( diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py index b76cebf79fad09ebec68f2459c6fe80794ea81c0..0c762f17c9b770ecada57b6ce60a4825ba374dd9 100644 --- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class OneHotCategorical(distribution.Distribution): @@ -52,7 +53,7 @@ class OneHotCategorical(distribution.Distribution): #### Examples - Creates a 3-class distiribution, with the 2nd class, the most likely to be + Creates a 3-class distribution, with the 2nd class, the most likely to be drawn from. ```python @@ -60,7 +61,7 @@ class OneHotCategorical(distribution.Distribution): dist = OneHotCategorical(probs=p) ``` - Creates a 3-class distiribution, with the 2nd class the most likely to be + Creates a 3-class distribution, with the 2nd class the most likely to be drawn from, using logits. ```python @@ -83,6 +84,14 @@ class OneHotCategorical(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, logits=None, @@ -115,8 +124,8 @@ class OneHotCategorical(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() - with ops.name_scope(name, values=[logits, probs]): + parameters = dict(locals()) + with ops.name_scope(name, values=[logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( name=name, logits=logits, probs=probs, validate_args=validate_args, multidimensional=True) @@ -203,9 +212,6 @@ class OneHotCategorical(distribution.Distribution): ret = array_ops.reshape(ret, logits_shape) return ret - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - def _entropy(self): return -math_ops.reduce_sum( nn_ops.log_softmax(self.logits) * self.probs, axis=-1) @@ -236,6 +242,14 @@ class OneHotCategorical(distribution.Distribution): @kullback_leibler.RegisterKL(OneHotCategorical, OneHotCategorical) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _kl_categorical_categorical(a, b, name=None): """Calculate the batched KL divergence KL(a || b) with a, b OneHotCategorical. diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py index e967dcc90d0712ffc346fb61ee67c44a6d9207cb..3d055085cc7386e57a71aa310458b7666bb9a396 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson.py +++ b/tensorflow/contrib/distributions/python/ops/poisson.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ "Poisson", @@ -35,9 +36,15 @@ __all__ = [ _poisson_sample_note = """ -Note that the input value must be a non-negative floating point tensor with -dtype `dtype` and whose shape can be broadcast with `self.rate`. `x` is only -legal if it is non-negative and its components are equal to integer values. +The Poisson distribution is technically only defined for non-negative integer +values. When `validate_args=False`, non-integral inputs trigger an assertion. + +When `validate_args=False` calculations are otherwise unchanged despite +integral or non-integral inputs. + +When `validate_args=False`, evaluating the pmf at non-integral values, +corresponds to evaluations of an unnormalized distribution, that does not +correspond to evaluations of the cdf. """ @@ -59,6 +66,14 @@ class Poisson(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, rate=None, log_rate=None, @@ -87,8 +102,8 @@ class Poisson(distribution.Distribution): TypeError: if `rate` is not a float-type. TypeError: if `log_rate` is not a float-type. """ - parameters = locals() - with ops.name_scope(name, values=[rate]): + parameters = dict(locals()) + with ops.name_scope(name, values=[rate]) as name: if (rate is None) == (log_rate is None): raise ValueError("Must specify exactly one of `rate` and `log_rate`.") elif log_rate is None: @@ -150,10 +165,6 @@ class Poisson(distribution.Distribution): def _cdf(self, x): if self.validate_args: x = distribution_util.embed_check_nonnegative_integer_form(x) - else: - # Whether or not x is integer-form, the following is well-defined. - # However, scipy takes the floor, so we do too. - x = math_ops.floor(x) return math_ops.igammac(1. + x, self.rate) def _log_normalization(self): @@ -162,9 +173,6 @@ class Poisson(distribution.Distribution): def _log_unnormalized_prob(self, x): if self.validate_args: x = distribution_util.embed_check_nonnegative_integer_form(x) - else: - # For consistency with cdf, we take the floor. - x = math_ops.floor(x) return x * self.log_rate - math_ops.lgamma(1. + x) def _mean(self): diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 92f2bba1828696248c9d9460566a08ba372c3358..7a7ad1be35b80ff0f000181ea0778ab282a8220f 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -33,6 +33,7 @@ from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.ops.distributions import transformed_distribution as transformed_lib +from tensorflow.python.util import deprecation __all__ = [ @@ -42,6 +43,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_lognormal_gauss_hermite( loc, scale, quadrature_size, validate_args=False, name=None): # pylint: disable=unused-argument @@ -85,6 +94,14 @@ def quadrature_scheme_lognormal_gauss_hermite( return grid, probs +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_lognormal_quantiles( loc, scale, quadrature_size, validate_args=False, name=None): @@ -114,7 +131,7 @@ def quadrature_scheme_lognormal_quantiles( # Create a LogNormal distribution. dist = transformed_lib.TransformedDistribution( distribution=normal_lib.Normal(loc=loc, scale=scale), - bijector=Exp(event_ndims=0), + bijector=Exp(), validate_args=validate_args) batch_ndims = dist.batch_shape.ndims if batch_ndims is None: @@ -214,6 +231,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): validate_args=True) """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -255,8 +280,8 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): TypeError: if `quadrature_grid` and `quadrature_probs` have different base `dtype`. """ - parameters = locals() - with ops.name_scope(name, values=[loc, scale]): + parameters = dict(locals()) + with ops.name_scope(name, values=[loc, scale]) as name: if loc is not None: loc = ops.convert_to_tensor(loc, name="loc") if scale is not None: @@ -417,6 +442,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): axis=[-2, -1]) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" args_ = [distribution_util.static_value(x) for x in args] diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py index 8aebb79b9138cce1373e6472d17cf9072d2bc285..ef3bdfa75fcaa8df17db1238ceadadf788601356 100644 --- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py @@ -27,10 +27,19 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distributions from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = ["QuantizedDistribution"] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _logsum_expbig_minus_expsmall(big, small): """Stable evaluation of `Log[exp{big} - exp{small}]`. @@ -128,7 +137,7 @@ The base distribution's `log_cdf` method must be defined on `y - 1`. class QuantizedDistribution(distributions.Distribution): """Distribution representing the quantization `Y = ceiling(X)`. - #### Definition in terms of sampling. + #### Definition in Terms of Sampling ``` 1. Draw X @@ -138,7 +147,7 @@ class QuantizedDistribution(distributions.Distribution): 5. Return Y ``` - #### Definition in terms of the probability mass function. + #### Definition in Terms of the Probability Mass Function Given scalar random variable `X`, we define a discrete random variable `Y` supported on the integers as follows: @@ -170,14 +179,72 @@ class QuantizedDistribution(distributions.Distribution): `P[Y = j]` is still the mass of `X` within the `jth` interval. - #### Caveats + #### Examples + + We illustrate a mixture of discretized logistic distributions + [(Salimans et al., 2017)][1]. This is used, for example, for capturing 16-bit + audio in WaveNet [(van den Oord et al., 2017)][2]. The values range in + a 1-D integer domain of `[0, 2**16-1]`, and the discretization captures + `P(x - 0.5 < X <= x + 0.5)` for all `x` in the domain excluding the endpoints. + The lowest value has probability `P(X <= 0.5)` and the highest value has + probability `P(2**16 - 1.5 < X)`. + + Below we assume a `wavenet` function. It takes as `input` right-shifted audio + samples of shape `[..., sequence_length]`. It returns a real-valued tensor of + shape `[..., num_mixtures * 3]`, i.e., each mixture component has a `loc` and + `scale` parameter belonging to the logistic distribution, and a `logits` + parameter determining the unnormalized probability of that component. + + ```python + tfd = tf.contrib.distributions + tfb = tfd.bijectors + + net = wavenet(inputs) + loc, unconstrained_scale, logits = tf.split(net, + num_or_size_splits=3, + axis=-1) + scale = tf.nn.softplus(unconstrained_scale) + + # Form mixture of discretized logistic distributions. Note we shift the + # logistic distribution by -0.5. This lets the quantization capture "rounding" + # intervals, `(x-0.5, x+0.5]`, and not "ceiling" intervals, `(x-1, x]`. + discretized_logistic_dist = tfd.QuantizedDistribution( + distribution=tfd.TransformedDistribution( + distribution=tfd.Logistic(loc=loc, scale=scale), + bijector=tfb.AffineScalar(shift=-0.5)), + low=0., + high=2**16 - 1.) + mixture_dist = tfd.MixtureSameFamily( + mixture_distribution=tfd.Categorical(logits=logits), + components_distribution=discretized_logistic_dist) + + neg_log_likelihood = -tf.reduce_sum(mixture_dist.log_prob(targets)) + train_op = tf.train.AdamOptimizer().minimize(neg_log_likelihood) + ``` + + After instantiating `mixture_dist`, we illustrate maximum likelihood by + calculating its log-probability of audio samples as `target` and optimizing. + + #### References - Since evaluation of each `P[Y = j]` involves a cdf evaluation (rather than - a closed form function such as for a Poisson), computations such as mean and - entropy are better done with samples or approximations, and are not - implemented by this class. + [1]: Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik P. Kingma. + PixelCNN++: Improving the PixelCNN with discretized logistic mixture + likelihood and other modifications. + _International Conference on Learning Representations_, 2017. + https://arxiv.org/abs/1701.05517 + [2]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech + Synthesis. _arXiv preprint arXiv:1711.10433_, 2017. + https://arxiv.org/abs/1711.10433 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, distribution, low=None, @@ -213,11 +280,11 @@ class QuantizedDistribution(distributions.Distribution): `Distribution` or continuous. NotImplementedError: If the base distribution does not implement `cdf`. """ - parameters = locals() + parameters = dict(locals()) values = ( list(distribution.parameters.values()) + [low, high]) - with ops.name_scope(name, values=values): + with ops.name_scope(name, values=values) as name: self._dist = distribution if low is not None: diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py index b525809015537ac8c7ee701c100fba6541fe2e92..7e1f64dc425e6a576bfbe1bb456901fddfac26e1 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py @@ -19,15 +19,16 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops import logistic +from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid # Bijectors must be directly imported because `remove_undocumented` prevents # individual file imports. -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class RelaxedBernoulli(transformed_distribution.TransformedDistribution): @@ -35,10 +36,10 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution): The RelaxedBernoulli is a distribution over the unit interval (0,1), which continuously approximates a Bernoulli. The degree of approximation is - controlled by a temperature: as the temperaturegoes to 0 the RelaxedBernoulli - becomes discrete with a distribution described by the `logits` or `probs` - parameters, as the temperature goes to infinity the RelaxedBernoulli - becomes the constant distribution that is identically 0.5. + controlled by a temperature: as the temperature goes to 0 the + RelaxedBernoulli becomes discrete with a distribution described by the + `logits` or `probs` parameters, as the temperature goes to infinity the + RelaxedBernoulli becomes the constant distribution that is identically 0.5. The RelaxedBernoulli distribution is a reparameterized continuous distribution that is the binary special case of the RelaxedOneHotCategorical @@ -131,6 +132,14 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution): Gumbel-Softmax. 2016. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, temperature, logits=None, @@ -165,8 +174,8 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution): Raises: ValueError: If both `probs` and `logits` are passed, or if neither. """ - parameters = locals() - with ops.name_scope(name, values=[logits, probs, temperature]): + parameters = dict(locals()) + with ops.name_scope(name, values=[logits, probs, temperature]) as name: with ops.control_dependencies([check_ops.assert_positive(temperature)] if validate_args else []): self._temperature = array_ops.identity(temperature, name="temperature") diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index 2aa771a71efe52c8d86d459f090ea8ee137c4487..9b5bd7576f2a3c364e21da76dd3905a8c6e35829 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class ExpRelaxedOneHotCategorical(distribution.Distribution): @@ -125,6 +126,14 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): A Continuous Relaxation of Discrete Random Variables. 2016. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, temperature, @@ -162,8 +171,8 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() - with ops.name_scope(name, values=[logits, probs, temperature]): + parameters = dict(locals()) + with ops.name_scope(name, values=[logits, probs, temperature]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( name=name, logits=logits, probs=probs, validate_args=validate_args, @@ -285,9 +294,6 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): ret = array_ops.reshape(log_prob, logits_shape) return ret - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - def _assert_valid_sample(self, x): if not self.validate_args: return x @@ -306,7 +312,7 @@ class RelaxedOneHotCategorical( The RelaxedOneHotCategorical is a distribution over random probability vectors, vectors of positive real values that sum to one, which continuously approximates a OneHotCategorical. The degree of approximation is controlled by - a temperature: as the temperaturegoes to 0 the RelaxedOneHotCategorical + a temperature: as the temperature goes to 0 the RelaxedOneHotCategorical becomes discrete with a distribution described by the `logits` or `probs` parameters, as the temperature goes to infinity the RelaxedOneHotCategorical becomes the constant distribution that is identically the constant vector of @@ -371,6 +377,14 @@ class RelaxedOneHotCategorical( A Continuous Relaxation of Discrete Random Variables. 2016. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, temperature, @@ -412,5 +426,5 @@ class RelaxedOneHotCategorical( validate_args=validate_args, allow_nan_stats=allow_nan_stats) super(RelaxedOneHotCategorical, self).__init__(dist, - bijectors.Exp(event_ndims=1), + bijectors.Exp(), name=name) diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index dfc813361977c159d8d48f9d5b9ff03db5b4acdc..f5aaa5cf34abde3ea4d25de1ecf3adaef3f2a770 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -301,13 +302,16 @@ def percentile(x, with ops.name_scope(name, [x, q]): x = ops.convert_to_tensor(x, name="x") - q = math_ops.to_float(q, name="q") + # Double is needed here and below, else we get the wrong index if the array + # is huge along axis. + q = math_ops.to_double(q, name="q") _get_static_ndims(q, expect_ndims=0) if validate_args: q = control_flow_ops.with_dependencies([ - check_ops.assert_rank(q, 0), check_ops.assert_greater_equal(q, 0.), - check_ops.assert_less_equal(q, 100.) + check_ops.assert_rank(q, 0), + check_ops.assert_greater_equal(q, math_ops.to_double(0.)), + check_ops.assert_less_equal(q, math_ops.to_double(100.)) ], q) if axis is None: @@ -332,7 +336,7 @@ def percentile(x, y = _move_dims_to_flat_end(x, axis, x_ndims) frac_at_q_or_above = 1. - q / 100. - d = math_ops.to_float(array_ops.shape(y)[-1]) + d = math_ops.to_double(array_ops.shape(y)[-1]) if interpolation == "lower": index = math_ops.ceil((d - 1) * frac_at_q_or_above) @@ -341,12 +345,18 @@ def percentile(x, elif interpolation == "nearest": index = math_ops.round((d - 1) * frac_at_q_or_above) + # If d is gigantic, then we would have d == d - 1, even in double... So + # let's use max/min to avoid out of bounds errors. + d = array_ops.shape(y)[-1] + # d - 1 will be distinct from d in int32. + index = clip_ops.clip_by_value(math_ops.to_int32(index), 0, d - 1) + # Sort everything, not just the top 'k' entries, which allows multiple calls # to sort only once (under the hood) and use CSE. sorted_y = _sort_tensor(y) # result.shape = B - result = sorted_y[..., math_ops.to_int32(index)] + result = sorted_y[..., index] result.set_shape(y.get_shape()[:-1]) if keep_dims: diff --git a/tensorflow/contrib/distributions/python/ops/seed_stream.py b/tensorflow/contrib/distributions/python/ops/seed_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..cf505ac627b62ae0a3d1ec1ce2a237c3c2ff1b74 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/seed_stream.py @@ -0,0 +1,228 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Local PRNG for amplifying seed entropy into seeds for base operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import hashlib + + +class SeedStream(object): + """Local PRNG for amplifying seed entropy into seeds for base operations. + + Writing sampling code which correctly sets the pseudo-random number + generator (PRNG) seed is surprisingly difficult. This class serves as + a helper for the TensorFlow Probability coding pattern designed to + avoid common mistakes. + + # Motivating Example + + A common first-cut implementation of a sampler for the beta + distribution is to compute the ratio of a gamma with itself plus + another gamma. This code snippet tries to do that, but contains a + surprisingly common error: + + ```python + def broken_beta(shape, alpha, beta, seed): + x = tf.random_gamma(shape, alpha, seed=seed) + y = tf.random_gamma(shape, beta, seed=seed) + return x / (x + y) + ``` + + The mistake is that the two gamma draws are seeded with the same + seed. This causes them to always produce the same results, which, + in turn, leads this code snippet to always return `0.5`. Because it + can happen across abstraction boundaries, this kind of error is + surprisingly easy to make when handling immutable seeds. + + # Goals + + TensorFlow Probability adopts a code style designed to eliminate the + above class of error, without exacerbating others. The goals of + this code style are: + + - Support reproducibility of results (by encouraging seeding of all + pseudo-random operations). + + - Avoid shared-write global state (by not relying on a global PRNG). + + - Prevent accidental seed reuse by TF Probability implementers. This + goal is served with the local pseudo-random seed generator provided + in this module. + + - Mitigate potential accidental seed reuse by TF Probability clients + (with a salting scheme). + + - Prevent accidental resonances with downstream PRNGs (by hashing the + output). + + ## Non-goals + + - Implementing a high-performance PRNG for generating large amounts of + entropy. That's the job of the underlying TensorFlow PRNG we are + seeding. + + - Avoiding random seed collisions, aka "birthday attacks". + + # Code pattern + + ```python + def random_beta(shape, alpha, beta, seed): # (a) + seed = SeedStream(seed, salt="random_beta") # (b) + x = tf.random_gamma(shape, alpha, seed=seed()) # (c) + y = tf.random_gamma(shape, beta, seed=seed()) # (c) + return x / (x + y) + ``` + + The elements of this pattern are: + + - Accept an explicit seed (line a) as an argument in all public + functions, and write the function to be deterministic (up to any + numerical issues) for fixed seed. + + - Rationale: This provides the client with the ability to reproduce + results. Accepting an immutable seed rather than a mutable PRNG + object reduces code coupling, permitting different sections to be + reproducible independently. + + - Use that seed only to initialize a local `SeedStream` instance (line b). + + - Rationale: Avoids accidental seed reuse. + + - Supply the name of the function being implemented as a salt to the + `SeedStream` instance (line b). This serves to keep the salts + unique; unique salts ensure that clients of TF Probability will see + different functions always produce independent results even if + called with the same seeds. + + - Seed each callee operation with the output of a unique call to the + `SeedStream` instance (lines c). This ensures reproducibility of + results while preventing seed reuse across callee invocations. + + # Why salt? + + Salting the `SeedStream` instances (with unique salts) is defensive + programming against a client accidentally committing a mistake + similar to our motivating example. Consider the following situation + that might arise without salting: + + ```python + def tfp_foo(seed): + seed = SeedStream(seed, salt="") + foo_stuff = tf.random_normal(seed=seed()) + ... + + def tfp_bar(seed): + seed = SeedStream(seed, salt="") + bar_stuff = tf.random_normal(seed=seed()) + ... + + def client_baz(seed): + foo = tfp_foo(seed=seed) + bar = tfp_bar(seed=seed) + ... + ``` + + The client should have used different seeds as inputs to `foo` and + `bar`. However, because they didn't, *and because `foo` and `bar` + both sample a Gaussian internally as their first action*, the + internal `foo_stuff` and `bar_stuff` will be the same, and the + returned `foo` and `bar` will not be independent, leading to subtly + incorrect answers from the client's simulation. This kind of bug is + particularly insidious for the client, because it depends on a + Distributions implementation detail, namely the order in which `foo` + and `bar` invoke the samplers they depend on. In particular, a + Bayesflow team member can introduce such a bug in previously + (accidentally) correct client code by performing an internal + refactoring that causes this operation order alignment. + + A salting discipline eliminates this problem by making sure that the + seeds seen by `foo`'s callees will differ from those seen by `bar`'s + callees, even if `foo` and `bar` are invoked with the same input + seed. + """ + + def __init__(self, seed, salt): + """Initializes a `SeedStream`. + + Args: + seed: Any Python object convertible to string, supplying the + initial entropy. If `None`, operations seeded with seeds + drawn from this `SeedStream` will follow TensorFlow semantics + for not being seeded. + salt: Any Python object convertible to string, supplying + auxiliary entropy. Must be unique across the Distributions + and TensorFlow Probability code base. See class docstring for + rationale. + """ + self._seed = seed.original_seed if isinstance(seed, SeedStream) else seed + self._salt = salt + self._counter = 0 + + def __call__(self): + """Returns a fresh integer usable as a seed in downstream operations. + + If this `SeedStream` was initialized with `seed=None`, returns + `None`. This has the effect that downstream operations (both + `SeedStream`s and primitive TensorFlow ops) will behave as though + they were unseeded. + + The returned integer is non-negative, and uniformly distributed in + the half-open interval `[0, 2**512)`. This is consistent with + TensorFlow, as TensorFlow operations internally use the residue of + the given seed modulo `2**31 - 1` (see + `tensorflow/python/framework/random_seed.py`). + + Returns: + seed: A fresh integer usable as a seed in downstream operations, + or `None`. + """ + self._counter += 1 + if self._seed is None: + return None + composite = str((self._seed, self._counter, self._salt)).encode("utf-8") + return int(hashlib.sha512(composite).hexdigest(), 16) + + @property + def original_seed(self): + return self._seed + + @property + def salt(self): + return self._salt + +# Design rationales for the SeedStream class +# +# - Salts are accepted for the reason given above to supply them. +# +# - A `None` seed propagates to downstream seeds, so they exhibit +# their "unseeded" behavior. +# +# - The return value is a Python int so it can be passed directly to +# TensorFlow operations as a seed. It is large to avoid losing seed +# space needlessly (TF will internally read only the last 31 bits). +# +# - The output is hashed with a crypto-grade hash function as a form +# of defensive programming: this reliably prevents all possible +# accidental resonances with all possible downstream PRNGs. The +# specific function used is not important; SHA512 was ready to hand. +# +# - The internal state update is a simple counter because (a) given +# that the output is hashed anyway, this is enough, and (b) letting +# it be this predictable permits a future "generate many seeds in +# parallel" operation whose results would agree with running +# sequentially. diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py index 5fb6f0c7eaa8c4734ea4c161b0eee6f24d4c9850..4f348be2806aa3ade7c1ea2a7bc68ca26db6447f 100644 --- a/tensorflow/contrib/distributions/python/ops/shape.py +++ b/tensorflow/contrib/distributions/python/ops/shape.py @@ -27,50 +27,56 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class _DistributionShape(object): """Manage and manipulate `Distribution` shape. - Terminology: - Recall that a `Tensor` has: - - `shape`: size of `Tensor` dimensions, - - `ndims`: size of `shape`; number of `Tensor` dimensions, - - `dims`: indexes into `shape`; useful for transpose, reduce. + #### Terminology - `Tensor`s sampled from a `Distribution` can be partitioned by `sample_dims`, - `batch_dims`, and `event_dims`. To understand the semantics of these - dimensions, consider when two of the three are fixed and the remaining - is varied: - - `sample_dims`: indexes independent draws from identical - parameterizations of the `Distribution`. - - `batch_dims`: indexes independent draws from non-identical - parameterizations of the `Distribution`. - - `event_dims`: indexes event coordinates from one sample. + Recall that a `Tensor` has: + - `shape`: size of `Tensor` dimensions, + - `ndims`: size of `shape`; number of `Tensor` dimensions, + - `dims`: indexes into `shape`; useful for transpose, reduce. - The `sample`, `batch`, and `event` dimensions constitute the entirety of a - `Distribution` `Tensor`'s shape. + `Tensor`s sampled from a `Distribution` can be partitioned by `sample_dims`, + `batch_dims`, and `event_dims`. To understand the semantics of these + dimensions, consider when two of the three are fixed and the remaining + is varied: + - `sample_dims`: indexes independent draws from identical + parameterizations of the `Distribution`. + - `batch_dims`: indexes independent draws from non-identical + parameterizations of the `Distribution`. + - `event_dims`: indexes event coordinates from one sample. - The dimensions are always in `sample`, `batch`, `event` order. + The `sample`, `batch`, and `event` dimensions constitute the entirety of a + `Distribution` `Tensor`'s shape. - Purpose: - This class partitions `Tensor` notions of `shape`, `ndims`, and `dims` into - `Distribution` notions of `sample,` `batch,` and `event` dimensions. That - is, it computes any of: + The dimensions are always in `sample`, `batch`, `event` order. - ``` - sample_shape batch_shape event_shape - sample_dims batch_dims event_dims - sample_ndims batch_ndims event_ndims - ``` + #### Purpose + + This class partitions `Tensor` notions of `shape`, `ndims`, and `dims` into + `Distribution` notions of `sample,` `batch,` and `event` dimensions. That + is, it computes any of: + + ``` + sample_shape batch_shape event_shape + sample_dims batch_dims event_dims + sample_ndims batch_ndims event_ndims + ``` + + for a given `Tensor`, e.g., the result of + `Distribution.sample(sample_shape=...)`. + + For a given `Tensor`, this class computes the above table using minimal + information: `batch_ndims` and `event_ndims`. - for a given `Tensor`, e.g., the result of - `Distribution.sample(sample_shape=...)`. + #### Examples - For a given `Tensor`, this class computes the above table using minimal - information: `batch_ndims` and `event_ndims`. + We show examples of distribution shape semantics. - Examples of `Distribution` `shape` semantics: - Sample dimensions: Computing summary statistics, i.e., the average is a reduction over sample dimensions. @@ -111,54 +117,64 @@ class _DistributionShape(object): tf.div(1., tf.reduce_prod(x, event_dims)) ``` - Examples using this class: - Write `S, B, E` for `sample_shape`, `batch_shape`, and `event_shape`. - - ```python - # 150 iid samples from one multivariate Normal with two degrees of freedom. - mu = [0., 0] - sigma = [[1., 0], - [0, 1]] - mvn = MultivariateNormal(mu, sigma) - rand_mvn = mvn.sample(sample_shape=[3, 50]) - shaper = DistributionShape(batch_ndims=0, event_ndims=1) - S, B, E = shaper.get_shape(rand_mvn) - # S = [3, 50] - # B = [] - # E = [2] - - # 12 iid samples from one Wishart with 2x2 events. - sigma = [[1., 0], - [2, 1]] - wishart = Wishart(df=5, scale=sigma) - rand_wishart = wishart.sample(sample_shape=[3, 4]) - shaper = DistributionShape(batch_ndims=0, event_ndims=2) - S, B, E = shaper.get_shape(rand_wishart) - # S = [3, 4] - # B = [] - # E = [2, 2] - - # 100 iid samples from two, non-identical trivariate Normal distributions. - mu = ... # shape(2, 3) - sigma = ... # shape(2, 3, 3) - X = MultivariateNormal(mu, sigma).sample(shape=[4, 25]) - # S = [4, 25] - # B = [2] - # E = [3] - ``` - - Argument Validation: - When `validate_args=False`, checks that cannot be done during - graph construction are performed at graph execution. This may result in a - performance degradation because data must be switched from GPU to CPU. - - For example, when `validate_args=False` and `event_ndims` is a - non-constant `Tensor`, it is checked to be a non-negative integer at graph - execution. (Same for `batch_ndims`). Constant `Tensor`s and non-`Tensor` - arguments are always checked for correctness since this can be done for - "free," i.e., during graph construction. + We show examples using this class. + + Write `S, B, E` for `sample_shape`, `batch_shape`, and `event_shape`. + + ```python + # 150 iid samples from one multivariate Normal with two degrees of freedom. + mu = [0., 0] + sigma = [[1., 0], + [0, 1]] + mvn = MultivariateNormal(mu, sigma) + rand_mvn = mvn.sample(sample_shape=[3, 50]) + shaper = DistributionShape(batch_ndims=0, event_ndims=1) + S, B, E = shaper.get_shape(rand_mvn) + # S = [3, 50] + # B = [] + # E = [2] + + # 12 iid samples from one Wishart with 2x2 events. + sigma = [[1., 0], + [2, 1]] + wishart = Wishart(df=5, scale=sigma) + rand_wishart = wishart.sample(sample_shape=[3, 4]) + shaper = DistributionShape(batch_ndims=0, event_ndims=2) + S, B, E = shaper.get_shape(rand_wishart) + # S = [3, 4] + # B = [] + # E = [2, 2] + + # 100 iid samples from two, non-identical trivariate Normal distributions. + mu = ... # shape(2, 3) + sigma = ... # shape(2, 3, 3) + X = MultivariateNormal(mu, sigma).sample(shape=[4, 25]) + # S = [4, 25] + # B = [2] + # E = [3] + ``` + + #### Argument Validation + + When `validate_args=False`, checks that cannot be done during + graph construction are performed at graph execution. This may result in a + performance degradation because data must be switched from GPU to CPU. + + For example, when `validate_args=False` and `event_ndims` is a + non-constant `Tensor`, it is checked to be a non-negative integer at graph + execution. (Same for `batch_ndims`). Constant `Tensor`s and non-`Tensor` + arguments are always checked for correctness since this can be done for + "free," i.e., during graph construction. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, batch_ndims=None, event_ndims=None, @@ -432,7 +448,7 @@ class _DistributionShape(object): if self._batch_ndims_is_0 and expand_batch_dim: squeeze_dims += [1] if squeeze_dims: - x = array_ops.squeeze(x, squeeze_dims=squeeze_dims) + x = array_ops.squeeze(x, axis=squeeze_dims) # x.shape: [prod(S)]+B+E _, batch_shape, event_shape = self.get_shape(x) else: diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index c4b8f055b7fbc3f0835b503eddd7617610326d8c..a9d0fb4ccfb1803873f7fe17089f3e7c7f10f4b7 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.util import deprecation __all__ = [ "SinhArcsinh", @@ -94,6 +95,14 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -132,9 +141,10 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = dict(locals()) - with ops.name_scope(name, values=[loc, scale, skewness, tailweight]): + with ops.name_scope(name, + values=[loc, scale, skewness, tailweight]) as name: loc = ops.convert_to_tensor(loc, name="loc") dtype = loc.dtype scale = ops.convert_to_tensor(scale, name="scale", dtype=dtype) @@ -166,21 +176,20 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): # Make the SAS bijector, 'F'. f = bijectors.SinhArcsinh( - skewness=skewness, tailweight=tailweight, event_ndims=0) + skewness=skewness, tailweight=tailweight) if has_default_skewness: f_noskew = f else: f_noskew = bijectors.SinhArcsinh( skewness=skewness.dtype.as_numpy_dtype(0.), - tailweight=tailweight, event_ndims=0) + tailweight=tailweight) - # Make the Affine bijector, Z --> loc + scale * Z (2 / F_0(2)) + # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2)) c = 2 * scale / f_noskew.forward(ops.convert_to_tensor(2, dtype=dtype)) - affine = bijectors.Affine( + affine = bijectors.AffineScalar( shift=loc, - scale_identity_multiplier=c, - validate_args=validate_args, - event_ndims=0) + scale=c, + validate_args=validate_args) bijector = bijectors.Chain([affine, f]) diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py new file mode 100644 index 0000000000000000000000000000000000000000..c25e8c51d7705b641699fb05623c7b0fb4950e1b --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py @@ -0,0 +1,910 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Statistical test assertions calibrated for their error rates. + +Statistical tests have an inescapable probability of error: a correct +sampler can still fail a test by chance, and an incorrect sampler can +still pass a test by chance. This library is about bounding both of +those error rates. This requires admitting a task-specific notion of +"discrepancy": Correct code will fail rarely, code that misbehaves by +more than the discrepancy will pass rarely, and nothing reliable can +be said about code that misbehaves, but misbehaves by less than the +discrepancy. + +# Example + +Consider testing that the mean of a scalar probability distribution P +is some expected constant. Suppose the support of P is the interval +`[0, 1]`. Then you might do this: + +```python +tfd = tf.contrib.distributions + +expected_mean = ... +num_samples = 5000 +samples = ... draw 5000 samples from P + +# Check that the mean looks right +check1 = tfd.assert_true_mean_equal_by_dkwm( + samples, low=0., high=1., expected=expected_mean, + false_fail_rate=1e-6) + +# Check that the difference in means detectable with 5000 samples is +# small enough +check2 = tf.assert_less( + tfd.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=1.0, + false_fail_rate=1e-6, false_pass_rate=1e-6), + 0.01) + +# Be sure to execute both assertion ops +sess.run([check1, check2]) +``` + +The second assertion is an instance of experiment design. It's a +deterministic computation (independent of the code under test) that +checks that `5000` samples is enough to reliably resolve mean +differences of `0.01` or more. Here "reliably" means that if the code +under test is correct, the probability of drawing an unlucky sample +that causes this test to fail is at most 1e-6; and if the code under +test is incorrect enough that its true mean is 0.01 more or less than +expected, then the probability of drawing a "lucky" sample that causes +the test to false-pass is also at most 1e-6. + +# Overview + +Every function in this library can be characterized in terms of: + +- The property being tested, such as the full density of the + distribution under test, or just its true mean, or a single + Bernoulli probability, etc. + +- The relation being asserted, e.g., whether the mean is less, more, + or equal to the given expected value. + +- The stochastic bound being relied upon, such as the + [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval) + or the CDF of the binomial distribution (for assertions about + Bernoulli probabilities). + +- The number of sample sets in the statistical test. For example, + testing equality of means has a one-sample variant, where the + expected mean is given exactly, and a two-sample variant, where the + expected mean is itself given by a set of samples (e.g., from an + alternative algorithm). + +- What operation(s) of the test are to be performed. Each test has + three of these: + + 1. `assert` executes the test. Specifically, it creates a TF op that + produces an error if it has enough evidence to prove that the + property under test is violated. These functions depend on the + desired false failure rate, because that determines the sizes of + appropriate confidence intervals, etc. + + 2. `min_discrepancy` computes the smallest difference reliably + detectable by that test, given the sample count and error rates. + What it's a difference of is test-specific. For example, a test + for equality of means would make detection guarantees about the + difference the true means. + + 3. `min_num_samples` computes the minimum number of samples needed + to reliably detect a given discrepancy with given error rates. + + The latter two are for experimental design, and are meant to be + usable either interactively or inline in the overall test method. + +This library follows a naming convention, to make room for every +combination of the above. A name mentions the operation first, then +the property, then the relation, then the bound, then, if the test +takes more than one set of samples, a token indicating this. For +example, `assert_true_mean_equal_by_dkwm` (which is implicitly +one-sample). Each name is a grammatically sound noun phrase (or verb +phrase, for the asserts). + +# Asymptotic properties + +The number of samples needed tends to scale as `O(1/discrepancy**2)` and +as `O(log(1/error_rate))`. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops + +__all__ = [ + "true_mean_confidence_interval_by_dkwm", + "assert_true_mean_equal_by_dkwm", + "min_discrepancy_of_true_means_detectable_by_dkwm", + "min_num_samples_for_dkwm_mean_test", + "assert_true_mean_in_interval_by_dkwm", + "assert_true_mean_equal_by_dkwm_two_sample", + "min_discrepancy_of_true_means_detectable_by_dkwm_two_sample", + "min_num_samples_for_dkwm_mean_two_sample_test", +] + + +def _batch_sort_vector(x, ascending=True, name=None): + with ops.name_scope(name, "_batch_sort_vector", [x]): + x = ops.convert_to_tensor(x, name="x") + n = array_ops.shape(x)[-1] + if ascending: + y, _ = nn_ops.top_k(-x, k=n, sorted=True) + y = -y + else: + y, _ = nn_ops.top_k(x, k=n, sorted=True) + y.set_shape(x.shape) + return y + + +def _do_maximum_mean(samples, envelope, high, name=None): + """Common code between maximum_mean and minimum_mean.""" + with ops.name_scope(name, "do_maximum_mean", [samples, envelope, high]): + n = array_ops.rank(samples) + # Move the batch dimension of `samples` to the rightmost position, + # where the _batch_sort_vector function wants it. + perm = array_ops.concat([math_ops.range(1, n), [0]], axis=0) + samples = array_ops.transpose(samples, perm) + + samples = _batch_sort_vector(samples) + + # The maximum mean is given by taking `envelope`-worth of + # probability from the smallest samples and moving it to the + # maximum value. This amounts to: + # - ignoring the smallest k samples, where `k/n < envelope` + # - taking a `1/n - (envelope - k/n)` part of the index k sample + # - taking all the other samples + # - and adding `envelope * high` at the end. + # The following is a vectorized and batched way of computing this. + # `max_mean_contrib` is a mask implementing the previous. + batch_size = array_ops.shape(samples)[-1] + batch_size = math_ops.cast(batch_size, dtype=samples.dtype.base_dtype) + step = 1. / batch_size + cum_steps = step * math_ops.range( + 1, batch_size + 1, dtype=samples.dtype.base_dtype) + max_mean_contrib = clip_ops.clip_by_value( + cum_steps - envelope[..., array_ops.newaxis], + clip_value_min=0., + clip_value_max=step) + return math_ops.reduce_sum( + samples * max_mean_contrib, axis=-1) + envelope * high + + +def _maximum_mean(samples, envelope, high, name=None): + """Returns a stochastic upper bound on the mean of a scalar distribution. + + The idea is that if the true CDF is within an `eps`-envelope of the + empirical CDF of the samples, and the support is bounded above, then + the mean is bounded above as well. In symbols, + + ```none + sup_x(|F_n(x) - F(x)|) < eps + ``` + + The 0th dimension of `samples` is interpreted as independent and + identically distributed samples. The remaining dimensions are + broadcast together with `envelope` and `high`, and operated on + separately. + + Args: + samples: Floating-point `Tensor` of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `envelope` and `high`. + envelope: Floating-point `Tensor` of sizes of admissible CDF + envelopes (i.e., the `eps` above). + high: Floating-point `Tensor` of upper bounds on the distributions' + supports. `samples <= high`. + name: A name for this operation (optional). + + Returns: + bound: Floating-point `Tensor` of upper bounds on the true means. + + Raises: + InvalidArgumentError: If some `sample` is found to be larger than + the corresponding `high`. + """ + with ops.name_scope(name, "maximum_mean", [samples, envelope, high]): + samples = ops.convert_to_tensor(samples, name="samples") + envelope = ops.convert_to_tensor(envelope, name="envelope") + high = ops.convert_to_tensor(high, name="high") + + xmax = math_ops.reduce_max(samples, axis=[0]) + msg = "Given sample maximum value exceeds expectations" + check_op = check_ops.assert_less_equal(xmax, high, message=msg) + with ops.control_dependencies([check_op]): + return array_ops.identity(_do_maximum_mean(samples, envelope, high)) + + +def _minimum_mean(samples, envelope, low, name=None): + """Returns a stochastic lower bound on the mean of a scalar distribution. + + The idea is that if the true CDF is within an `eps`-envelope of the + empirical CDF of the samples, and the support is bounded below, then + the mean is bounded below as well. In symbols, + + ```none + sup_x(|F_n(x) - F(x)|) < eps + ``` + + The 0th dimension of `samples` is interpreted as independent and + identically distributed samples. The remaining dimensions are + broadcast together with `envelope` and `low`, and operated on + separately. + + Args: + samples: Floating-point `Tensor` of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `envelope` and `low`. + envelope: Floating-point `Tensor` of sizes of admissible CDF + envelopes (i.e., the `eps` above). + low: Floating-point `Tensor` of lower bounds on the distributions' + supports. `samples >= low`. + name: A name for this operation (optional). + + Returns: + bound: Floating-point `Tensor` of lower bounds on the true means. + + Raises: + InvalidArgumentError: If some `sample` is found to be smaller than + the corresponding `low`. + """ + with ops.name_scope(name, "minimum_mean", [samples, envelope, low]): + samples = ops.convert_to_tensor(samples, name="samples") + envelope = ops.convert_to_tensor(envelope, name="envelope") + low = ops.convert_to_tensor(low, name="low") + + xmin = math_ops.reduce_min(samples, axis=[0]) + msg = "Given sample minimum value falls below expectations" + check_op = check_ops.assert_greater_equal(xmin, low, message=msg) + with ops.control_dependencies([check_op]): + return - _do_maximum_mean(-samples, envelope, -low) + + +def _dkwm_cdf_envelope(n, error_rate, name=None): + """Computes the CDF envelope that the DKWM inequality licenses. + + The [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval) + gives a stochastic bound on the distance between the true cumulative + distribution function (CDF) of any distribution and its empirical + CDF. To wit, for `n` iid samples from any distribution with CDF F, + + ```none + P(sup_x |F_n(x) - F(x)| > eps) < 2exp(-2n eps^2) + ``` + + This function computes the envelope size `eps` as a function of the + number of samples `n` and the desired limit on the left-hand + probability above. + + Args: + n: `Tensor` of numbers of samples drawn. + error_rate: Floating-point `Tensor` of admissible rates of mistakes. + name: A name for this operation (optional). + + Returns: + eps: `Tensor` of maximum distances the true CDF can be from the + empirical CDF. This scales as `O(sqrt(-log(error_rate)))` and + as `O(1 / sqrt(n))`. The shape is the broadcast of `n` and + `error_rate`. + """ + with ops.name_scope(name, "dkwm_cdf_envelope", [n, error_rate]): + n = math_ops.cast(n, dtype=error_rate.dtype) + return math_ops.sqrt(-gen_math_ops.log(error_rate / 2.) / (2. * n)) + + +def _check_shape_dominates(samples, parameters): + """Check that broadcasting `samples` against `parameters` does not expand it. + + Why? Because I want to be very sure that the samples tensor is not + accidentally enlarged by broadcasting against tensors that are + supposed to be describing the distribution(s) sampled from, lest the + sample counts end up inflated. + + Args: + samples: A `Tensor` whose shape is to be protected against broadcasting. + parameters: A list of `Tensor`s who are parameters for the statistical test. + + Returns: + samples: Return original `samples` with control dependencies attached + to ensure no broadcasting. + """ + def check(t): + samples_batch_shape = array_ops.shape(samples)[1:] + broadcasted_batch_shape = array_ops.broadcast_dynamic_shape( + samples_batch_shape, array_ops.shape(t)) + # This rank check ensures that I don't get a wrong answer from the + # _shapes_ broadcasting against each other. + samples_batch_ndims = array_ops.size(samples_batch_shape) + ge = check_ops.assert_greater_equal( + samples_batch_ndims, array_ops.rank(t)) + eq = check_ops.assert_equal(samples_batch_shape, broadcasted_batch_shape) + return ge, eq + checks = list(itertools.chain(*[check(t) for t in parameters])) + with ops.control_dependencies(checks): + return array_ops.identity(samples) + + +def true_mean_confidence_interval_by_dkwm( + samples, low, high, error_rate=1e-6, name=None): + """Computes a confidence interval for the mean of a scalar distribution. + + In batch mode, computes confidence intervals for all distributions + in the batch (which need not be identically distributed). + + Relies on the [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval). + + The probability (over the randomness of drawing the given samples) + that any true mean is outside the corresponding returned interval is + no more than the given `error_rate`. The size of the intervals + scale as + `O(1 / sqrt(#samples))`, as `O(high - low)`, and as `O(-log(error_rate))`. + + Note that `error_rate` is a total error rate for all the confidence + intervals in the batch. As such, if the batch is nontrivial, the + error rate is not broadcast but divided (evenly) among the batch + members. + + Args: + samples: Floating-point `Tensor` of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `low` and `high`. + The support is bounded: `low <= samples <= high`. + low: Floating-point `Tensor` of lower bounds on the distributions' + supports. + high: Floating-point `Tensor` of upper bounds on the distributions' + supports. + error_rate: *Scalar* floating-point `Tensor` admissible total rate + of mistakes. + name: A name for this operation (optional). + + Returns: + low: A floating-point `Tensor` of stochastic lower bounds on the + true means. + high: A floating-point `Tensor` of stochastic upper bounds on the + true means. + """ + with ops.name_scope( + name, "true_mean_confidence_interval_by_dkwm", + [samples, low, high, error_rate]): + samples = ops.convert_to_tensor(samples, name="samples") + low = ops.convert_to_tensor(low, name="low") + high = ops.convert_to_tensor(high, name="high") + error_rate = ops.convert_to_tensor(error_rate, name="error_rate") + samples = _check_shape_dominates(samples, [low, high]) + check_ops.assert_scalar(error_rate) # Static shape + error_rate = _itemwise_error_rate(error_rate, [low, high], samples) + n = array_ops.shape(samples)[0] + envelope = _dkwm_cdf_envelope(n, error_rate) + min_mean = _minimum_mean(samples, envelope, low) + max_mean = _maximum_mean(samples, envelope, high) + return min_mean, max_mean + + +def _itemwise_error_rate( + total_error_rate, param_tensors, sample_tensor=None, name=None): + with ops.name_scope( + name, "itemwise_error_rate", + [total_error_rate, param_tensors, sample_tensor]): + result_shape = [1] + for p_tensor in param_tensors: + result_shape = array_ops.broadcast_dynamic_shape( + array_ops.shape(p_tensor), result_shape) + if sample_tensor is not None: + result_shape = array_ops.broadcast_dynamic_shape( + array_ops.shape(sample_tensor)[1:], result_shape) + num_items = math_ops.reduce_prod(result_shape) + return total_error_rate / math_ops.cast( + num_items, dtype=total_error_rate.dtype) + + +def assert_true_mean_equal_by_dkwm( + samples, low, high, expected, false_fail_rate=1e-6, name=None): + """Asserts the mean of the given distribution is as expected. + + More precisely, fails if there is enough evidence (using the + [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval)) + that the true mean of some distribution from which the given samples are + drawn is _not_ the given expected mean with statistical significance + `false_fail_rate` or stronger, otherwise passes. If you also want to + check that you are gathering enough evidence that a pass is not + spurious, see `min_num_samples_for_dkwm_mean_test` and + `min_discrepancy_of_true_means_detectable_by_dkwm`. + + Note that `false_fail_rate` is a total false failure rate for all + the assertions in the batch. As such, if the batch is nontrivial, + the assertion will insist on stronger evidence to fail any one member. + + Args: + samples: Floating-point `Tensor` of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `low` and `high`. + The support is bounded: `low <= samples <= high`. + low: Floating-point `Tensor` of lower bounds on the distributions' + supports. + high: Floating-point `Tensor` of upper bounds on the distributions' + supports. + expected: Floating-point `Tensor` of expected true means. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of mistakes. + name: A name for this operation (optional). + + Returns: + check: Op that raises `InvalidArgumentError` if any expected mean is + outside the corresponding confidence interval. + """ + with ops.name_scope( + name, "assert_true_mean_equal_by_dkwm", + [samples, low, high, expected, false_fail_rate]): + return assert_true_mean_in_interval_by_dkwm( + samples, low, high, expected, expected, false_fail_rate) + + +def min_discrepancy_of_true_means_detectable_by_dkwm( + n, low, high, false_fail_rate, false_pass_rate, name=None): + """Returns the minimum mean discrepancy that a DKWM-based test can detect. + + DKWM is the [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval). + + Note that `false_fail_rate` is a total false failure rate for all + the tests in the batch. As such, if the batch is nontrivial, each + member will demand more samples. The `false_pass_rate` is also + interpreted as a total, but is treated asymmetrically: If each test + in the batch detects its corresponding discrepancy with probability + at least `1 - false_pass_rate`, then running all those tests and + failing if any one fails will jointly detect all those discrepancies + with the same `false_pass_rate`. + + Args: + n: `Tensor` of numbers of samples to be drawn from the distributions + of interest. + low: Floating-point `Tensor` of lower bounds on the distributions' + supports. + high: Floating-point `Tensor` of upper bounds on the distributions' + supports. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. + name: A name for this operation (optional). + + Returns: + discr: `Tensor` of lower bounds on the distances between true + means detectable by a DKWM-based test. + + For each batch member `i`, of `K` total, drawing `n[i]` samples from + some scalar distribution supported on `[low[i], high[i]]` is enough + to detect a difference in means of size `discr[i]` or more. + Specifically, we guarantee that (a) if the true mean is the expected + mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with + probability at most `false_fail_rate / K` (which amounts to + `false_fail_rate` if applied to the whole batch at once), and (b) if + the true mean differs from the expected mean (resp. falls outside + the expected interval) by at least `discr[i]`, + `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with + probability at most `false_pass_rate`. + + The detectable discrepancy scales as + + - `O(high[i] - low[i])`, + - `O(1 / sqrt(n[i]))`, + - `O(-log(false_fail_rate/K))`, and + - `O(-log(false_pass_rate))`. + """ + with ops.name_scope( + name, "min_discrepancy_of_true_means_detectable_by_dkwm", + [n, low, high, false_fail_rate, false_pass_rate]): + n = ops.convert_to_tensor(n, name="n") + low = ops.convert_to_tensor(low, name="low") + high = ops.convert_to_tensor(high, name="high") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + false_pass_rate = ops.convert_to_tensor( + false_pass_rate, name="false_pass_rate") + # Algorithm: Assume a true CDF F. The DKWM inequality gives a + # stochastic bound on how far the observed empirical CDF F_n can be. + # Then, using the DKWM inequality again gives a stochastic bound on + # the farthest candidate true CDF F' that + # true_mean_confidence_interval_by_dkwm might consider. At worst, these + # errors may go in the same direction, so the distance between F and + # F' is bounded by the sum. + # On batching: false fail rates sum, so I need to reduce + # the input to account for the batching. False pass rates + # max, so I don't. + sampling_envelope = _dkwm_cdf_envelope(n, false_pass_rate) + false_fail_rate = _itemwise_error_rate(false_fail_rate, [n, low, high]) + analysis_envelope = _dkwm_cdf_envelope(n, false_fail_rate) + return (high - low) * (sampling_envelope + analysis_envelope) + + +def min_num_samples_for_dkwm_mean_test( + discrepancy, low, high, + false_fail_rate=1e-6, false_pass_rate=1e-6, name=None): + """Returns how many samples suffice for a one-sample DKWM mean test. + + To wit, returns an upper bound on the number of samples necessary to + guarantee detecting a mean difference of at least the given + `discrepancy`, with the given `false_fail_rate` and `false_pass_rate`, + using the [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval) + on a scalar distribution supported on `[low, high]`. + + Args: + discrepancy: Floating-point `Tensor` of desired upper limits on mean + differences that may go undetected with probability higher than + `1 - false_pass_rate`. + low: `Tensor` of lower bounds on the distributions' support. + high: `Tensor` of upper bounds on the distributions' support. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. + name: A name for this operation (optional). + + Returns: + n: `Tensor` of numbers of samples to be drawn from the distributions + of interest. + + The `discrepancy`, `low`, and `high` tensors must have + broadcast-compatible shapes. + + For each batch member `i`, of `K` total, drawing `n[i]` samples from + some scalar distribution supported on `[low[i], high[i]]` is enough + to detect a difference in means of size `discrepancy[i]` or more. + Specifically, we guarantee that (a) if the true mean is the expected + mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with + probability at most `false_fail_rate / K` (which amounts to + `false_fail_rate` if applied to the whole batch at once), and (b) if + the true mean differs from the expected mean (resp. falls outside + the expected interval) by at least `discrepancy[i]`, + `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with + probability at most `false_pass_rate`. + + The required number of samples scales + as `O((high[i] - low[i])**2)`, `O(-log(false_fail_rate/K))`, + `O(-log(false_pass_rate))`, and `O(1 / discrepancy[i]**2)`. + """ + with ops.name_scope( + name, "min_num_samples_for_dkwm_mean_test", + [low, high, false_fail_rate, false_pass_rate, discrepancy]): + discrepancy = ops.convert_to_tensor( + discrepancy, name="discrepancy") + low = ops.convert_to_tensor(low, name="low") + high = ops.convert_to_tensor(high, name="high") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + false_pass_rate = ops.convert_to_tensor( + false_pass_rate, name="false_pass_rate") + # Could choose to cleverly allocate envelopes, but this is sound. + envelope1 = discrepancy / (2. * (high - low)) + envelope2 = envelope1 + false_fail_rate = _itemwise_error_rate( + false_fail_rate, [low, high, discrepancy]) + n1 = -math_ops.log(false_fail_rate / 2.) / (2. * envelope1**2) + n2 = -math_ops.log(false_pass_rate / 2.) / (2. * envelope2**2) + return math_ops.maximum(n1, n2) + + +def assert_true_mean_in_interval_by_dkwm( + samples, low, high, expected_low, expected_high, + false_fail_rate=1e-6, name=None): + """Asserts the mean of the given distribution is in the given interval. + + More precisely, fails if there is enough evidence (using the + [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval)) + that the mean of the distribution from which the given samples are + drawn is _outside_ the given interval with statistical significance + `false_fail_rate` or stronger, otherwise passes. If you also want + to check that you are gathering enough evidence that a pass is not + spurious, see `min_num_samples_for_dkwm_mean_test` and + `min_discrepancy_of_true_means_detectable_by_dkwm`. + + Note that `false_fail_rate` is a total false failure rate for all + the assertions in the batch. As such, if the batch is nontrivial, + the assertion will insist on stronger evidence to fail any one member. + + Args: + samples: Floating-point `Tensor` of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `low` and `high`. + The support is bounded: `low <= samples <= high`. + low: Floating-point `Tensor` of lower bounds on the distributions' + supports. + high: Floating-point `Tensor` of upper bounds on the distributions' + supports. + expected_low: Floating-point `Tensor` of lower bounds on the + expected true means. + expected_high: Floating-point `Tensor` of upper bounds on the + expected true means. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of mistakes. + name: A name for this operation (optional). + + Returns: + check: Op that raises `InvalidArgumentError` if any expected mean + interval does not overlap with the corresponding confidence + interval. + """ + with ops.name_scope( + name, "assert_true_mean_in_interval_by_dkwm", + [samples, low, high, expected_low, expected_high, false_fail_rate]): + samples = ops.convert_to_tensor(samples, name="samples") + low = ops.convert_to_tensor(low, name="low") + high = ops.convert_to_tensor(high, name="high") + expected_low = ops.convert_to_tensor(expected_low, name="expected_low") + expected_high = ops.convert_to_tensor(expected_high, name="expected_high") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + samples = _check_shape_dominates( + samples, [low, high, expected_low, expected_high]) + min_mean, max_mean = true_mean_confidence_interval_by_dkwm( + samples, low, high, false_fail_rate) + # Assert that the interval [min_mean, max_mean] intersects the + # interval [expected_low, expected_high]. This is true if + # max_mean >= expected_low and min_mean <= expected_high. + # By DeMorgan's law, that's also equivalent to + # not (max_mean < expected_low or min_mean > expected_high), + # which is a way of saying the two intervals are not disjoint. + check_confidence_interval_can_intersect = check_ops.assert_greater_equal( + max_mean, expected_low, message="Confidence interval does not " + "intersect: true mean smaller than expected") + with ops.control_dependencies([check_confidence_interval_can_intersect]): + return check_ops.assert_less_equal( + min_mean, expected_high, message="Confidence interval does not " + "intersect: true mean greater than expected") + + +def assert_true_mean_equal_by_dkwm_two_sample( + samples1, low1, high1, samples2, low2, high2, + false_fail_rate=1e-6, name=None): + """Asserts the means of the given distributions are equal. + + More precisely, fails if there is enough evidence (using the + [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval)) + that the means of the distributions from which the given samples are + drawn are _not_ equal with statistical significance `false_fail_rate` + or stronger, otherwise passes. If you also want to check that you + are gathering enough evidence that a pass is not spurious, see + `min_num_samples_for_dkwm_mean_two_sample_test` and + `min_discrepancy_of_true_means_detectable_by_dkwm_two_sample`. + + Note that `false_fail_rate` is a total false failure rate for all + the assertions in the batch. As such, if the batch is nontrivial, + the assertion will insist on stronger evidence to fail any one member. + + Args: + samples1: Floating-point `Tensor` of samples from the + distribution(s) A. Entries are assumed IID across the 0th + dimension. The other dimensions must broadcast with `low1`, + `high1`, `low2`, and `high2`. + The support is bounded: `low1 <= samples1 <= high1`. + low1: Floating-point `Tensor` of lower bounds on the supports of the + distributions A. + high1: Floating-point `Tensor` of upper bounds on the supports of + the distributions A. + samples2: Floating-point `Tensor` of samples from the + distribution(s) B. Entries are assumed IID across the 0th + dimension. The other dimensions must broadcast with `low1`, + `high1`, `low2`, and `high2`. + The support is bounded: `low2 <= samples2 <= high2`. + low2: Floating-point `Tensor` of lower bounds on the supports of the + distributions B. + high2: Floating-point `Tensor` of upper bounds on the supports of + the distributions B. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of mistakes. + name: A name for this operation (optional). + + Returns: + check: Op that raises `InvalidArgumentError` if any pair of confidence + intervals true for corresponding true means do not overlap. + """ + with ops.name_scope( + name, "assert_true_mean_equal_by_dkwm_two_sample", + [samples1, low1, high1, samples2, low2, high2, false_fail_rate]): + samples1 = ops.convert_to_tensor(samples1, name="samples1") + low1 = ops.convert_to_tensor(low1, name="low1") + high1 = ops.convert_to_tensor(high1, name="high1") + samples2 = ops.convert_to_tensor(samples2, name="samples2") + low2 = ops.convert_to_tensor(low2, name="low2") + high2 = ops.convert_to_tensor(high2, name="high2") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + samples1 = _check_shape_dominates(samples1, [low1, high1]) + samples2 = _check_shape_dominates(samples2, [low2, high2]) + compatible_samples = check_ops.assert_equal( + array_ops.shape(samples1)[1:], array_ops.shape(samples2)[1:]) + with ops.control_dependencies([compatible_samples]): + # Could in principle play games with cleverly allocating + # significance instead of the even split below. It may be possible + # to get tighter intervals, in order to obtain a higher power test. + # Any allocation strategy that depends only on the support bounds + # and sample counts should be valid; however, because the intervals + # scale as O(-log(false_fail_rate)), there doesn't seem to be much + # room to win. + min_mean_2, max_mean_2 = true_mean_confidence_interval_by_dkwm( + samples2, low2, high2, false_fail_rate / 2.) + return assert_true_mean_in_interval_by_dkwm( + samples1, low1, high1, min_mean_2, max_mean_2, false_fail_rate / 2.) + + +def min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( + n1, low1, high1, n2, low2, high2, + false_fail_rate, false_pass_rate, name=None): + """Returns the minimum mean discrepancy for a two-sample DKWM-based test. + + DKWM is the [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval). + + Note that `false_fail_rate` is a total false failure rate for all + the tests in the batch. As such, if the batch is nontrivial, each + member will demand more samples. The `false_pass_rate` is also + interpreted as a total, but is treated asymmetrically: If each test + in the batch detects its corresponding discrepancy with probability + at least `1 - false_pass_rate`, then running all those tests and + failing if any one fails will jointly detect all those discrepancies + with the same `false_pass_rate`. + + Args: + n1: `Tensor` of numbers of samples to be drawn from the distributions A. + low1: Floating-point `Tensor` of lower bounds on the supports of the + distributions A. + high1: Floating-point `Tensor` of upper bounds on the supports of + the distributions A. + n2: `Tensor` of numbers of samples to be drawn from the distributions B. + low2: Floating-point `Tensor` of lower bounds on the supports of the + distributions B. + high2: Floating-point `Tensor` of upper bounds on the supports of + the distributions B. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. + name: A name for this operation (optional). + + Returns: + discr: `Tensor` of lower bounds on the distances between true means + detectable by a two-sample DKWM-based test. + + For each batch member `i`, of `K` total, drawing `n1[i]` samples + from scalar distribution A supported on `[low1[i], high1[i]]` and `n2[i]` + samples from scalar distribution B supported on `[low2[i], high2[i]]` + is enough to detect a difference in their true means of size + `discr[i]` or more. Specifically, we guarantee that (a) if their + true means are equal, `assert_true_mean_equal_by_dkwm_two_sample` + will fail with probability at most `false_fail_rate/K` (which + amounts to `false_fail_rate` if applied to the whole batch at once), + and (b) if their true means differ by at least `discr[i]`, + `assert_true_mean_equal_by_dkwm_two_sample` will pass with + probability at most `false_pass_rate`. + + The detectable distribution scales as + + - `O(high1[i] - low1[i])`, `O(high2[i] - low2[i])`, + - `O(1 / sqrt(n1[i]))`, `O(1 / sqrt(n2[i]))`, + - `O(-log(false_fail_rate/K))`, and + - `O(-log(false_pass_rate))`. + """ + with ops.name_scope( + name, "min_discrepancy_of_true_means_detectable_by_dkwm_two_sample", + [n1, low1, high1, n2, low2, high2, false_fail_rate, false_pass_rate]): + n1 = ops.convert_to_tensor(n1, name="n1") + low1 = ops.convert_to_tensor(low1, name="low1") + high1 = ops.convert_to_tensor(high1, name="high1") + n2 = ops.convert_to_tensor(n2, name="n2") + low2 = ops.convert_to_tensor(low2, name="low2") + high2 = ops.convert_to_tensor(high2, name="high2") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + false_pass_rate = ops.convert_to_tensor( + false_pass_rate, name="false_pass_rate") + det_disc1 = min_discrepancy_of_true_means_detectable_by_dkwm( + n1, low1, high1, false_fail_rate / 2., false_pass_rate / 2.) + det_disc2 = min_discrepancy_of_true_means_detectable_by_dkwm( + n2, low2, high2, false_fail_rate / 2., false_pass_rate / 2.) + return det_disc1 + det_disc2 + + +def min_num_samples_for_dkwm_mean_two_sample_test( + discrepancy, low1, high1, low2, high2, + false_fail_rate=1e-6, false_pass_rate=1e-6, name=None): + """Returns how many samples suffice for a two-sample DKWM mean test. + + DKWM is the [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval). + + Args: + discrepancy: Floating-point `Tensor` of desired upper limits on mean + differences that may go undetected with probability higher than + `1 - false_pass_rate`. + low1: Floating-point `Tensor` of lower bounds on the supports of the + distributions A. + high1: Floating-point `Tensor` of upper bounds on the supports of + the distributions A. + low2: Floating-point `Tensor` of lower bounds on the supports of the + distributions B. + high2: Floating-point `Tensor` of upper bounds on the supports of + the distributions B. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. + name: A name for this operation (optional). + + Returns: + n1: `Tensor` of numbers of samples to be drawn from the distributions A. + n2: `Tensor` of numbers of samples to be drawn from the distributions B. + + For each batch member `i`, of `K` total, drawing `n1[i]` samples + from scalar distribution A supported on `[low1[i], high1[i]]` and `n2[i]` + samples from scalar distribution B supported on `[low2[i], high2[i]]` + is enough to detect a difference in their true means of size + `discr[i]` or more. Specifically, we guarantee that (a) if their + true means are equal, `assert_true_mean_equal_by_dkwm_two_sample` + will fail with probability at most `false_fail_rate/K` (which + amounts to `false_fail_rate` if applied to the whole batch at once), + and (b) if their true means differ by at least `discr[i]`, + `assert_true_mean_equal_by_dkwm_two_sample` will pass with + probability at most `false_pass_rate`. + + The required number of samples scales as + + - `O((high1[i] - low1[i])**2)`, `O((high2[i] - low2[i])**2)`, + - `O(-log(false_fail_rate/K))`, + - `O(-log(false_pass_rate))`, and + - `O(1 / discrepancy[i]**2)`. + """ + with ops.name_scope( + name, "min_num_samples_for_dkwm_mean_two_sample_test", + [low1, high1, low2, high2, + false_fail_rate, false_pass_rate, discrepancy]): + discrepancy = ops.convert_to_tensor(discrepancy, name="discrepancy") + low1 = ops.convert_to_tensor(low1, name="low1") + high1 = ops.convert_to_tensor(high1, name="high1") + low2 = ops.convert_to_tensor(low2, name="low2") + high2 = ops.convert_to_tensor(high2, name="high2") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + false_pass_rate = ops.convert_to_tensor( + false_pass_rate, name="false_pass_rate") + # Could choose to cleverly allocate discrepancy tolerances and + # failure probabilities, but this is sound. + n1 = min_num_samples_for_dkwm_mean_test( + discrepancy / 2., low1, high1, + false_fail_rate / 2., false_pass_rate / 2.) + n2 = min_num_samples_for_dkwm_mean_test( + discrepancy / 2., low2, high2, + false_fail_rate / 2., false_pass_rate / 2.) + return n1, n2 diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 0c747f8e68529484ae6f695b8500cde74857bb11..ece03fe4aab3cc3046e0958d883ca9388517b94b 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -40,6 +40,7 @@ from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib from tensorflow.python.ops.linalg import linear_operator_full_matrix as linop_full_lib from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib from tensorflow.python.ops.linalg import linear_operator_lower_triangular as linop_tril_lib +from tensorflow.python.util import deprecation __all__ = [ @@ -49,6 +50,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_softmaxnormal_gauss_hermite( normal_loc, normal_scale, quadrature_size, validate_args=False, name=None): @@ -111,6 +120,14 @@ def quadrature_scheme_softmaxnormal_gauss_hermite( return grid, probs +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_softmaxnormal_quantiles( normal_loc, normal_scale, quadrature_size, validate_args=False, name=None): @@ -181,7 +198,7 @@ def quadrature_scheme_softmaxnormal_quantiles( edges = array_ops.reshape(edges, shape=array_ops.concat([ [-1], array_ops.ones([batch_ndims], dtype=dtypes.int32)], axis=0)) quantiles = dist.quantile(edges) - quantiles = SoftmaxCentered(event_ndims=1).forward(quantiles) + quantiles = SoftmaxCentered().forward(quantiles) # Cyclically permute left by one. perm = array_ops.concat([ math_ops.range(1, 1 + batch_ndims), [0]], axis=0) @@ -248,11 +265,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): The default quadrature scheme chooses `z_{N, n}` as `N` midpoints of the quantiles of `p(z)` (generalized quantiles if `K > 2`). - See [1] for more details. - - [1]. "Quadrature Compound: An approximating family of distributions" - Joshua Dillon, Ian Langmore, arXiv preprints - https://arxiv.org/abs/1801.03080 + See [Dillon and Langmore (2018)][1] for more details. #### About `Vector` distributions in TensorFlow. @@ -313,8 +326,23 @@ class VectorDiffeomixture(distribution_lib.Distribution): is_positive_definite=True), ], validate_args=True) + ``` + + #### References + + [1]: Joshua Dillon and Ian Langmore. Quadrature Compound: An approximating + family of distributions. _arXiv preprint arXiv:1801.03080_, 2018. + https://arxiv.org/abs/1801.03080 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, mix_loc, temperature, @@ -392,8 +420,8 @@ class VectorDiffeomixture(distribution_lib.Distribution): ValueError: if `not distribution.is_scalar_batch`. ValueError: if `not distribution.is_scalar_event`. """ - parameters = locals() - with ops.name_scope(name, values=[mix_loc, temperature]): + parameters = dict(locals()) + with ops.name_scope(name, values=[mix_loc, temperature]) as name: if not scale or len(scale) < 2: raise ValueError("Must specify list (or list-like object) of scale " "LinearOperators, one for each component with " @@ -424,7 +452,6 @@ class VectorDiffeomixture(distribution_lib.Distribution): self._endpoint_affine = [ AffineLinearOperator(shift=loc_, scale=scale_, - event_ndims=1, validate_args=validate_args, name="endpoint_affine_{}".format(k)) for k, (loc_, scale_) in enumerate(zip(loc, scale))] @@ -464,7 +491,6 @@ class VectorDiffeomixture(distribution_lib.Distribution): self._interpolated_affine = [ AffineLinearOperator(shift=loc_, scale=scale_, - event_ndims=1, validate_args=validate_args, name="interpolated_affine_{}".format(k)) for k, (loc_, scale_) in enumerate(zip( @@ -618,9 +644,11 @@ class VectorDiffeomixture(distribution_lib.Distribution): log_prob = math_ops.reduce_sum(self.distribution.log_prob(y), axis=-2) # Because the affine transformation has a constant Jacobian, it is the case # that `affine.fldj(x) = -affine.ildj(x)`. This is not true in general. - fldj = array_ops.stack( - [aff.forward_log_det_jacobian(x) for aff in self.interpolated_affine], - axis=-1) + fldj = array_ops.stack([ + aff.forward_log_det_jacobian( + x, + event_ndims=array_ops.rank(self.event_shape_tensor()) + ) for aff in self.interpolated_affine], axis=-1) return math_ops.reduce_logsumexp( self.mixture_distribution.logits - fldj + log_prob, axis=-1) @@ -776,6 +804,14 @@ class VectorDiffeomixture(distribution_lib.Distribution): return array_ops.reshape(p, shape=expand_shape) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def maybe_check_quadrature_param(param, name, validate_args): """Helper which checks validity of `loc` and `scale` init args.""" with ops.name_scope(name="check_" + name, values=[param]): @@ -809,6 +845,14 @@ def maybe_check_quadrature_param(param, name, validate_args): return param +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def determine_batch_event_shapes(grid, endpoint_affine): """Helper to infer batch_shape and event_shape.""" with ops.name_scope(name="determine_batch_event_shapes"): @@ -847,6 +891,14 @@ def determine_batch_event_shapes(grid, endpoint_affine): return batch_shape, batch_shape_tensor, event_shape, event_shape_tensor +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def interpolate_loc(grid, loc): """Helper which interpolates between two locs.""" if len(loc) != 2: @@ -873,6 +925,14 @@ def interpolate_loc(grid, loc): return [x[..., k] for k in range(deg)] # list(shape:[B, e]) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def interpolate_scale(grid, scale): """Helper which interpolates between two scales.""" if len(scale) != 2: @@ -889,6 +949,14 @@ def interpolate_scale(grid, scale): ])[0] for q in range(deg)] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def linop_scale(w, op): # We assume w > 0. (This assumption only relates to the is_* attributes.) with ops.name_scope("linop_scale", values=[w]): @@ -924,6 +992,14 @@ def linop_scale(w, op): "Unsupported Linop type ({})".format(type(op).__name__)) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" args_ = [distribution_util.static_value(x) for x in args] @@ -932,6 +1008,14 @@ def concat_vectors(*args): return [val for vec in args_ for val in vec] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def add(x, y): """Adds inputs; interprets `None` as zero.""" if x is None: @@ -941,11 +1025,27 @@ def add(x, y): return x + y +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def vec_osquare(x): """Computes the outer-product of a (batch of) vector, i.e., x.T x.""" return x[..., :, array_ops.newaxis] * x[..., array_ops.newaxis, :] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def softmax(x, axis, name=None): """Equivalent to tf.nn.softmax but works around b/70297725.""" with ops.name_scope(name, "softmax", [x, axis]): diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index 526fe2d39aef9aed833b889de80e849c469435e7..73356a3625c9a1aa15af5b6c1cf2ccb0c514b39a 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import vector_exponential_linear_operator as vector_exponential_linop from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation __all__ = [ @@ -116,6 +117,14 @@ class VectorExponentialDiag( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -175,8 +184,8 @@ class VectorExponentialDiag( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() - with ops.name_scope(name): + parameters = dict(locals()) + with ops.name_scope(name) as name: with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): # No need to validate_args while making diag_scale. The returned diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index 9d5fd9ac4178a1ae29b1ce32f304b22fd3d234dc..9a47b4855763a25b484ad04a3415d191f19256f7 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import exponential from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.linalg import linalg +from tensorflow.python.util import deprecation __all__ = ["VectorExponentialLinearOperator"] @@ -138,6 +139,14 @@ class VectorExponentialLinearOperator( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale=None, @@ -175,13 +184,13 @@ class VectorExponentialLinearOperator( ValueError: if `scale` is unspecified. TypeError: if not `scale.dtype.is_floating` """ - parameters = locals() + parameters = dict(locals()) if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: raise TypeError("`scale` parameter must have floating-point dtype.") - with ops.name_scope(name, values=[loc] + scale.graph_parents): + with ops.name_scope(name, values=[loc] + scale.graph_parents) as name: # Since expand_dims doesn't preserve constant-ness, we obtain the # non-dynamic value if possible. loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py index 8dd983b750d9b39775e570800006011f4968f7f3..e68ddc569c95ff63760b4b2f6d7a92f17240a558 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import vector_laplace_linear_operator as vector_laplace_linop from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation __all__ = [ @@ -151,6 +152,14 @@ class VectorLaplaceDiag( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -210,7 +219,7 @@ class VectorLaplaceDiag( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name): with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index ec485c95c15da2794b67d2699d2bdd9db97bb6c4..3923161a332a77e4eaab8d65d96fd8c278c872ec 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import laplace from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.linalg import linalg +from tensorflow.python.util import deprecation __all__ = [ @@ -154,6 +155,14 @@ class VectorLaplaceLinearOperator( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale=None, @@ -191,7 +200,7 @@ class VectorLaplaceLinearOperator( ValueError: if `scale` is unspecified. TypeError: if not `scale.dtype.is_floating` """ - parameters = locals() + parameters = dict(locals()) if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index e1ccf116457a97261b9ce3965552764771d3bdd2..49ffff24caec8d6c525f65f06796d10548d5ec40 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.util import deprecation __all__ = [ "VectorSinhArcsinhDiag", @@ -95,6 +96,14 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -163,13 +172,13 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope( name, values=[ loc, scale_diag, scale_identity_multiplier, skewness, tailweight - ]): + ]) as name: loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc tailweight = 1. if tailweight is None else tailweight has_default_skewness = skewness is None @@ -215,19 +224,19 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): tailweight = ops.convert_to_tensor( tailweight, dtype=dtype, name="tailweight") f = bijectors.SinhArcsinh( - skewness=skewness, tailweight=tailweight, event_ndims=1) + skewness=skewness, tailweight=tailweight) if has_default_skewness: f_noskew = f else: f_noskew = bijectors.SinhArcsinh( skewness=skewness.dtype.as_numpy_dtype(0.), - tailweight=tailweight, event_ndims=0) + tailweight=tailweight) # Make the Affine bijector, Z --> loc + C * Z. c = 2 * scale_diag_part / f_noskew.forward( ops.convert_to_tensor(2, dtype=dtype)) affine = bijectors.Affine( - shift=loc, scale_diag=c, validate_args=validate_args, event_ndims=1) + shift=loc, scale_diag=c, validate_args=validate_args) bijector = bijectors.Chain([affine, f]) diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 8c67647a618d22a58428d78865c4ebf7d98bdf9e..f289b39e51aff36780541a0545ed9e6cfe21dd4e 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import student_t from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.util import deprecation class _VectorStudentT(transformed_distribution.TransformedDistribution): @@ -66,7 +67,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): This distribution is an Affine transformation of iid [Student's t-distributions]( https://en.wikipedia.org/wiki/Student%27s_t-distribution) - and should not be confused with the [Multivate Student's t-distribution]( + and should not be confused with the [Multivariate Student's t-distribution]( https://en.wikipedia.org/wiki/Multivariate_t-distribution). The traditional Multivariate Student's t-distribution is type of [elliptical distribution]( @@ -121,6 +122,14 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, loc=None, @@ -175,10 +184,10 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = dict(locals()) graph_parents = [df, loc, scale_identity_multiplier, scale_diag, scale_tril, scale_perturb_factor, scale_perturb_diag] - with ops.name_scope(name): + with ops.name_scope(name) as name: with ops.name_scope("init", values=graph_parents): # The shape of the _VectorStudentT distribution is governed by the # relationship between df.batch_shape and affine.batch_shape. In diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index e4ac65012b9c7e3ed5ada3ed75020f3905740156..f1accaaa4c920344608015c792a2c3606de1337f 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation __all__ = [ "WishartCholesky", @@ -73,6 +74,14 @@ class _WishartLinearOperator(distribution.Distribution): this class. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, scale_operator, @@ -107,9 +116,9 @@ class _WishartLinearOperator(distribution.Distribution): ValueError: if df < k, where scale operator event shape is `(k, k)` """ - parameters = locals() + parameters = dict(locals()) self._cholesky_input_output_matrices = cholesky_input_output_matrices - with ops.name_scope(name) as ns: + with ops.name_scope(name) as name: with ops.name_scope("init", values=[df, scale_operator]): if not scale_operator.dtype.is_floating: raise TypeError( @@ -163,7 +172,7 @@ class _WishartLinearOperator(distribution.Distribution): parameters=parameters, graph_parents=([self._df, self._dimension] + self._scale_operator.graph_parents), - name=ns) + name=name) @property def df(self): @@ -228,9 +237,12 @@ class _WishartLinearOperator(distribution.Distribution): # Complexity: O(nbk) # This parametrization is equivalent to Chi2, i.e., # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2) + expanded_df = self.df * array_ops.ones( + self.scale_operator.batch_shape_tensor(), + dtype=self.df.dtype.base_dtype) g = random_ops.random_gamma(shape=[n], alpha=self._multi_gamma_sequence( - 0.5 * self.df, self.dimension), + 0.5 * expanded_df, self.dimension), beta=0.5, dtype=self.dtype, seed=distribution_util.gen_new_seed( @@ -498,6 +510,14 @@ class WishartCholesky(_WishartLinearOperator): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, scale, @@ -527,8 +547,8 @@ class WishartCholesky(_WishartLinearOperator): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() - with ops.name_scope(name, values=[scale]): + parameters = dict(locals()) + with ops.name_scope(name, values=[scale]) as name: with ops.name_scope("init", values=[scale]): scale = ops.convert_to_tensor(scale) if validate_args: @@ -614,6 +634,14 @@ class WishartFull(_WishartLinearOperator): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, scale, @@ -643,8 +671,8 @@ class WishartFull(_WishartLinearOperator): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() - with ops.name_scope(name) as ns: + parameters = dict(locals()) + with ops.name_scope(name) as name: with ops.name_scope("init", values=[scale]): scale = ops.convert_to_tensor(scale) if validate_args: @@ -663,5 +691,5 @@ class WishartFull(_WishartLinearOperator): cholesky_input_output_matrices=cholesky_input_output_matrices, validate_args=validate_args, allow_nan_stats=allow_nan_stats, - name=ns) + name=name) self._parameters = parameters diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index 9d2ca07c3a25fa7acb9b0f5806b763d9a57b51fa..4384431e7b9c3e6ef259391fa9efa5a35d23c86a 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -1,12 +1,8 @@ # Eager Execution -> *WARNING*: This is a preview/pre-alpha version. The API and performance -> characteristics are subject to change. - -Eager execution is an experimental interface to TensorFlow that provides an -imperative programming style (à la [NumPy](http://www.numpy.org)). When you -enable eager execution, TensorFlow operations execute immediately; you do not -execute a pre-constructed graph with +Eager execution provides an imperative interface to TensorFlow (similar to +[NumPy](http://www.numpy.org)). When you enable eager execution, TensorFlow +operations execute immediately; you do not execute a pre-constructed graph with [`Session.run()`](https://www.tensorflow.org/api_docs/python/tf/Session). For example, consider a simple computation in TensorFlow: @@ -33,7 +29,7 @@ print(m) ## Caveats This feature is in early stages and work remains to be done in terms of smooth -support for distributed and multi-GPU training and CPU performance. +support for distributed and multi-GPU training and performance. - [Known issues](https://github.com/tensorflow/tensorflow/issues?q=is%3Aissue%20is%3Aopen%20label%3Acomp%3Aeager) - Feedback is welcome, please consider @@ -41,21 +37,14 @@ support for distributed and multi-GPU training and CPU performance. ## Installation -Eager execution is included in TensorFlow versions 1.5 and above. +For eager execution, we recommend using TensorFlow version 1.8 or newer. Installation instructions at https://www.tensorflow.org/install/ ## Documentation For an introduction to eager execution in TensorFlow, see: -- [User Guide](python/g3doc/guide.md) +- [User Guide](https://www.tensorflow.org/programmers_guide/eager) ([source](../../docs_src/programmers_guide/eager.md)) - Notebook: [Basic Usage](python/examples/notebooks/1_basics.ipynb) - Notebook: [Gradients](python/examples/notebooks/2_gradients.ipynb) - Notebook: [Importing Data](python/examples/notebooks/3_datasets.ipynb) - -## Changelog - -- 2017/10/31: Initial preview release. -- 2017/12/01: Example of dynamic neural network: - [SPINN: Stack-augmented Parser-Interpreter Neural Network](https://arxiv.org/abs/1603.06021). - See [README.md](python/examples/spinn/README.md) for details. diff --git a/tensorflow/contrib/eager/proto/BUILD b/tensorflow/contrib/eager/proto/BUILD deleted file mode 100644 index aedfec8924e7314addd22349c0576a84a58d9aa3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/proto/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - -tf_proto_library( - name = "checkpointable_object_graph_proto", - srcs = [ - "checkpointable_object_graph.proto", - ], - visibility = ["//tensorflow/contrib/eager/python:__subpackages__"], -) diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index ad40e55cb48aac08eca7022846a0bd07b8accb3f..0cc764d2208c5b061b7b836bdf57a035f52c6fcf 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -17,16 +17,15 @@ py_library( ":saver", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", - "//tensorflow/python:numerics", + "//tensorflow/python:gradients", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:script_ops", "//tensorflow/python:template", + "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", - "//tensorflow/python/eager:core", - "//tensorflow/python/eager:custom_gradient", "//tensorflow/python/eager:execution_callbacks", "//tensorflow/python/eager:function", ], @@ -69,14 +68,19 @@ cuda_py_test( srcs = ["datasets_test.py"], additional_deps = [ ":datasets", + "//tensorflow/contrib/data/python/ops:prefetching_ops", + "//tensorflow/contrib/data/python/ops:threadpool", + "//tensorflow/contrib/data/python/ops:unique", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:script_ops", + "//tensorflow/python:training", "//tensorflow/python/data", "//tensorflow/python/eager:test", ], + tags = ["noguitar"], ) py_library( @@ -115,17 +119,18 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/contrib/summary:summary_ops", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:summary_ops_v2", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", "//tensorflow/python/eager:function", + "//tensorflow/python/training/checkpointable:base", ], ) @@ -135,11 +140,11 @@ py_test( srcs_version = "PY2AND3", deps = [ ":metrics", - "//tensorflow/contrib/summary:summary_ops", "//tensorflow/contrib/summary:summary_test_util", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:summary_ops_v2", "//tensorflow/python:training", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", @@ -156,10 +161,10 @@ py_library( deps = [ ":datasets", ":metrics", - "//tensorflow/contrib/summary:summary_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:summary_ops_v2", "//tensorflow/python/eager:context", "//tensorflow/python/eager:function", "@six_archive//:six", @@ -218,62 +223,3 @@ py_test( "//tensorflow/python/eager:test", ], ) - -py_library( - name = "checkpointable_utils", - srcs = ["checkpointable_utils.py"], - srcs_version = "PY2AND3", - visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow/contrib/eager/proto:checkpointable_object_graph_proto_py", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:io_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - ], -) - -py_test( - name = "checkpointable_utils_test", - srcs = ["checkpointable_utils_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":checkpointable_utils", - ":network", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:layers_base", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", - "@six_archive//:six", - ], -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "g3doc/sitemap.md", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/eager/python/checkpointable_utils.py b/tensorflow/contrib/eager/python/checkpointable_utils.py deleted file mode 100644 index 0506af391cf22ec89d5174f17208c8eb393ddc54..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/checkpointable_utils.py +++ /dev/null @@ -1,435 +0,0 @@ -"""Utilities for working with Checkpointable objects.""" -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections - -from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 -from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import io_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.training import checkpointable as core_checkpointable -from tensorflow.python.training import optimizer as optimizer_lib -from tensorflow.python.training import saver as saver_lib - - -_ESCAPE_CHAR = "." # For avoiding conflicts with user-specified names. - -# Keyword for identifying that the next bit of a checkpoint variable name is a -# slot name. Checkpoint names for slot variables look like: -# -# /<_OPTIMIZER_SLOTS_NAME>// -# -# Where is a full path from the checkpoint root to the -# variable being slotted for. -_OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT" -# Keyword for separating the path to an object from the name of an -# attribute in checkpoint names. Used like: -# /<_OBJECT_ATTRIBUTES_NAME>/ -_OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES" -# Key where the object graph proto is saved in a TensorBundle -_OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH" - - -# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange -# or consolidating the implementation with get_variable. -def _default_getter(name, shape, dtype, initializer=None, - partition_info=None, **kwargs): - """A pared-down version of get_variable which does not reuse variables.""" - dtype = dtypes.as_dtype(dtype) - shape_object = tensor_shape.as_shape(shape) - with ops.init_scope(): - if initializer is None: - initializer, initializing_from_value = ( - variable_scope._get_default_variable_store()._get_default_initializer( # pylint: disable=protected-access - name=name, shape=shape_object, dtype=dtype)) - else: - initializing_from_value = not callable(initializer) - # Same logic as get_variable - variable_dtype = dtype.base_dtype - if initializing_from_value: - if shape is not None: - raise ValueError("If initializer is a constant, do not specify shape.") - initial_value = initializer - else: - # Instantiate initializer if provided initializer is a type object. - if isinstance(initializer, type(init_ops.Initializer)): - initializer = initializer(dtype=dtype) - def initial_value(): - return initializer( - shape_object.as_list(), dtype=dtype, partition_info=partition_info) - return resource_variable_ops.ResourceVariable( - initial_value=initial_value, - name=name, - dtype=variable_dtype, - **kwargs - ) - - -def add_variable(checkpointable, name, shape=None, dtype=dtypes.float32, - initializer=None): - """Add a variable to a Checkpointable with no scope influence.""" - return checkpointable._add_variable_with_custom_getter( # pylint: disable=protected-access - name=name, shape=shape, dtype=dtype, - initializer=initializer, getter=_default_getter) - - -def _breadth_first_checkpointable_traversal(root_checkpointable): - """Find shortest paths to all variables owned by dependencies of root.""" - bfs_sorted = [] - to_visit = collections.deque([root_checkpointable]) - path_to_root = {root_checkpointable: ()} - while to_visit: - current_checkpointable = to_visit.popleft() - current_checkpointable._maybe_initialize_checkpointable() # pylint: disable=protected-access - bfs_sorted.append(current_checkpointable) - for child_checkpointable in ( - current_checkpointable._checkpoint_dependencies): # pylint: disable=protected-access - if child_checkpointable.ref not in path_to_root: - path_to_root[child_checkpointable.ref] = ( - path_to_root[current_checkpointable] + (child_checkpointable,)) - to_visit.append(child_checkpointable.ref) - return bfs_sorted, path_to_root - - -def _escape_local_name(name): - # We need to support slashes in local names for compatibility, since this - # naming scheme is being patched in to things like Layer.add_variable where - # slashes were previously accepted. We also want to use slashes to indicate - # edges traversed to reach the variable, so we escape forward slashes in - # names. - return (name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR) - .replace(r"/", _ESCAPE_CHAR + "S")) - - -def _object_prefix_from_path(path_to_root): - return "/".join( - (_escape_local_name(checkpointable.name) - for checkpointable in path_to_root)) - - -def _slot_variable_naming_for_optimizer(optimizer_path): - """Make a function for naming slot variables in an optimizer.""" - # Name slot variables: - # - # /<_OPTIMIZER_SLOTS_NAME>// - # - # where is exactly the checkpoint name used for the original - # variable, including the path from the checkpoint root and the local name in - # the object which owns it. Note that we only save slot variables if the - # variable it's slotting for is also being saved. - - optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, optimizer_path) - - def _name_slot_variable(variable_path, slot_name): - """With an optimizer specified, name a slot variable.""" - return (variable_path - + optimizer_identifier - + _escape_local_name(slot_name)) - - return _name_slot_variable - - -def _serialize_slot_variables(checkpointable_objects, node_ids, object_names): - """Gather and name slot variables.""" - non_slot_objects = list(checkpointable_objects) - slot_variables = {} - for checkpointable in non_slot_objects: - if isinstance(checkpointable, optimizer_lib.Optimizer): - naming_scheme = _slot_variable_naming_for_optimizer( - optimizer_path=object_names[checkpointable]) - slot_names = checkpointable.get_slot_names() - for slot_name in slot_names: - for original_variable_node_id, original_variable in enumerate( - non_slot_objects): - try: - slot_variable = checkpointable.get_slot( - original_variable, slot_name) - except AttributeError: - slot_variable = None - if slot_variable is None: - continue - slot_variable._maybe_initialize_checkpointable() # pylint: disable=protected-access - if slot_variable._checkpoint_dependencies: # pylint: disable=protected-access - # TODO(allenl): Gather dependencies of slot variables. - raise NotImplementedError( - "Currently only variables with no dependencies can be saved as " - "slot variables. File a feature request if this limitation " - "bothers you.") - if slot_variable in node_ids: - raise NotImplementedError( - "A slot variable was re-used as a dependency of a " - "Checkpointable object. This is not currently allowed. File a " - "feature request if this limitation bothers you.") - checkpoint_name = naming_scheme( - variable_path=object_names[original_variable], - slot_name=slot_name) - object_names[slot_variable] = checkpoint_name - slot_variable_node_id = len(checkpointable_objects) - node_ids[slot_variable] = slot_variable_node_id - checkpointable_objects.append(slot_variable) - slot_variable_proto = ( - checkpointable_object_graph_pb2.CheckpointableObjectGraph - .Object.SlotVariableReference( - slot_name=slot_name, - original_variable_node_id=original_variable_node_id, - slot_variable_node_id=slot_variable_node_id)) - slot_variables.setdefault(checkpointable, []).append( - slot_variable_proto) - return slot_variables - - -def _serialize_checkpointables( - checkpointable_objects, node_ids, object_names, slot_variables): - """Name non-slot `Checkpointable`s and add them to `object_graph_proto`.""" - object_graph_proto = ( - checkpointable_object_graph_pb2.CheckpointableObjectGraph()) - named_saveables = {} - - for checkpoint_id, checkpointable in enumerate(checkpointable_objects): - assert node_ids[checkpointable] == checkpoint_id - object_proto = object_graph_proto.nodes.add() - object_proto.slot_variables.extend(slot_variables.get(checkpointable, ())) - object_name = object_names[checkpointable] - for name, saveable in ( - checkpointable._gather_tensors_for_checkpoint().items()): # pylint: disable=protected-access - attribute = object_proto.attributes.add() - attribute.name = name - attribute.checkpoint_key = "%s/%s/%s" % ( - object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name)) - # Figure out the name-based Saver's name for this variable. - saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( - [saveable], convert_variable_to_tensor=False) - attribute.full_name, = saver_dict.keys() - named_saveables[attribute.checkpoint_key] = saveable - - for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access - child_proto = object_proto.children.add() - child_proto.node_id = node_ids[child.ref] - child_proto.local_name = child.name - - return named_saveables, object_graph_proto - - -def _serialize_object_graph(root_checkpointable): - """Determine checkpoint keys for variables and build a serialized graph. - - Non-slot variables are keyed based on a shortest path from the root saveable - to the object which owns the variable (i.e. the one which called - `Checkpointable._add_variable` to create it). - - Slot variables are keyed based on a shortest path to the variable being - slotted for, a shortest path to their optimizer, and the slot name. - - Args: - root_checkpointable: A `Checkpointable` object whose variables (including - the variables of dependencies, recursively) should be saved. - - Returns: - A tuple of (named_variables, object_graph_proto): - named_variables: A dictionary mapping names to variable objects. - object_graph_proto: A CheckpointableObjectGraph protocol buffer containing - the serialized object graph and variable references. - - Raises: - ValueError: If there are invalid characters in an optimizer's slot names. - """ - checkpointable_objects, path_to_root = ( - _breadth_first_checkpointable_traversal(root_checkpointable)) - object_names = { - obj: _object_prefix_from_path(path) - for obj, path in path_to_root.items()} - node_ids = {node: node_id for node_id, node - in enumerate(checkpointable_objects)} - slot_variables = _serialize_slot_variables( - checkpointable_objects=checkpointable_objects, - node_ids=node_ids, - object_names=object_names) - return _serialize_checkpointables( - checkpointable_objects=checkpointable_objects, - node_ids=node_ids, - object_names=object_names, - slot_variables=slot_variables) - - -class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject): - - def __init__(self, tensor, name): - spec = saver_lib.BaseSaverBuilder.SaveSpec(tensor, "", name) - super(_NoRestoreSaveable, self).__init__(tensor, [spec], name) - - def restore(self, restored_tensors, restored_shapes): - return control_flow_ops.no_op() - - -def save(file_prefix, root_checkpointable, checkpoint_number=None, - session=None): - """Save a training checkpoint. - - Args: - file_prefix: A prefix to use for the checkpoint filenames - (/path/to/directory/and_a_prefix). Names are generated based on this - prefix and the global step, if provided. - root_checkpointable: A Checkpointable object to save. The checkpoint - includes variables created by this object and any Checkpointable objects - it depends on. - checkpoint_number: An integer variable or Tensor, used to number - checkpoints. Typically this value is saved along with other variables in - training checkpoints, which will happen automatically if it was created by - `root_checkpointable` or one of its dependencies (via - `Checkpointable._add_variable`). - session: The session to evaluate variables in. Ignored when executing - eagerly. If not provided when graph building, the default session is used. - - Returns: - The full path to the checkpoint. - """ - named_variables, serialized_graph = _serialize_object_graph( - root_checkpointable) - if context.in_graph_mode(): - if session is None: - session = ops.get_default_session() - else: - session = None - assert _OBJECT_GRAPH_PROTO_KEY not in named_variables - # TODO(allenl): Feed rather than embedding a constant. - named_variables[_OBJECT_GRAPH_PROTO_KEY] = _NoRestoreSaveable( - tensor=constant_op.constant( - serialized_graph.SerializeToString(), dtype=dtypes.string), - name=_OBJECT_GRAPH_PROTO_KEY) - with ops.device("/device:CPU:0"): - save_path = saver_lib.Saver(var_list=named_variables).save( - sess=session, - save_path=file_prefix, - write_meta_graph=False, - global_step=checkpoint_number) - return save_path - - -class CheckpointLoadStatus(object): - """Checks the status of checkpoint loading.""" - - def __init__(self, checkpoint): - self._checkpoint = checkpoint - - def assert_consumed(self): - """Asserts that all objects in the checkpoint have been created/matched.""" - for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes): - checkpointable = self._checkpoint.object_by_proto_id.get(node_id, None) - if checkpointable is None: - raise AssertionError("Unresolved object in checkpoint: %s" % (node,)) - if checkpointable._update_uid < self._checkpoint.restore_uid: # pylint: disable=protected-access - raise AssertionError( - "Object not assigned a value from checkpoint: %s" % (node,)) - if self._checkpoint.slot_restorations: - # Sanity check; this collection should be clear if everything has been - # restored. - raise AssertionError("Unresolved slot restorations: %s" % ( - self._checkpoint.slot_restorations,)) - return self - - @property - def restore_ops(self): - """Operations to restore objects in the dependency graph.""" - return self._checkpoint.restore_ops - - -def restore(save_path, root_checkpointable, session=None): - """Restore a training checkpoint. - - Restores the values of variables created with `Checkpointable._add_variable` - in `root_checkpointable` and any objects that it tracks (transitive). Either - assigns values immediately if variables to restore have been created already, - or defers restoration until the variables are created. Dependencies added to - `root_checkpointable` after this call will be matched if they have a - corresponding object in the checkpoint. - - When building a graph, restorations are added to the graph but not run. A - session is required to retrieve checkpoint metadata. - - To disallow deferred loading, assert immediately that all checkpointed - variables have been matched to variable objects: - - ```python - restore(path, root).assert_consumed() - ``` - - An exception will be raised unless every object was matched and its variables - already exist. - - When graph building, `assert_consumed()` indicates that all of the restore ops - which will be created for this checkpoint have been created. They are - available in the `restore_ops` property of the status object: - - ```python - session.run(restore(path, root).assert_consumed().restore_ops) - ``` - - If the checkpoint has not been consumed completely, then the list of - `restore_ops` will grow as more objects are added to the dependency graph. - - Args: - save_path: The path to the checkpoint, as returned by `save` or - `tf.train.latest_checkpoint`. If None (as when there is no latest - checkpoint for `tf.train.latest_checkpoint` to return), does nothing. - root_checkpointable: The root of the object graph to restore. Variables to - restore need not have been created yet, but all dependencies on other - `Checkpointable` objects should already be declared. Objects in the - dependency graph are matched to objects in the checkpointed graph, and - matching objects have their variables restored (or the checkpointed values - saved for eventual restoration when the variable is created). - session: The session to retrieve metadata with. Ignored when executing - eagerly. If not provided when graph building, the default session is used. - Returns: - A `CheckpointLoadStatus` object, which can be used to make assertions about - the status of checkpoint restoration and fetch restore ops. - """ - if save_path is None: - return - if context.in_graph_mode(): - if session is None: - session = ops.get_default_session() - else: - session = None - object_graph_string, = io_ops.restore_v2( - prefix=save_path, - tensor_names=[_OBJECT_GRAPH_PROTO_KEY], - shape_and_slices=[""], - dtypes=[dtypes.string], - name="object_graph_proto_read") - if session is not None: - object_graph_string = session.run(object_graph_string) - else: - object_graph_string = object_graph_string.numpy() - object_graph_proto = ( - checkpointable_object_graph_pb2.CheckpointableObjectGraph()) - object_graph_proto.ParseFromString(object_graph_string) - checkpoint = core_checkpointable._Checkpoint( # pylint: disable=protected-access - object_graph_proto=object_graph_proto, - save_path=save_path) - core_checkpointable._CheckpointPosition( # pylint: disable=protected-access - checkpoint=checkpoint, proto_id=0).restore(root_checkpointable) - load_status = CheckpointLoadStatus(checkpoint) - return load_status diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py deleted file mode 100644 index 21ba6adc6a26a11783264be2e217373453224e79..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/checkpointable_utils_test.py +++ /dev/null @@ -1,886 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools -import os -import unittest - -import six - -from tensorflow.contrib.eager.python import checkpointable_utils -from tensorflow.contrib.eager.python import network as network_lib -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.layers import base -from tensorflow.python.layers import core -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.training import adam -from tensorflow.python.training import checkpointable -from tensorflow.python.training import saver as core_saver -from tensorflow.python.training import training_util - - -class CheckpointableDenseLayer(core.Dense, checkpointable.Checkpointable): - - def __init__(self, *args, **kwargs): - checkpointable.Checkpointable.__init__(self) - core.Dense.__init__(self, *args, **kwargs) - - def add_variable(self, name, shape, **kwargs): - # Calls both Checkpointable._add_variable and Layer.add_variable. Eventually - # Layer.add_variable should inherit from Checkpointable and simply call - # super and then do post-processing. - return checkpointable.Checkpointable._add_variable_with_custom_getter( - self, - name=name, - shape=shape, - getter=functools.partial(core.Dense.add_variable, self), - **kwargs) - - -# pylint: disable=not-callable -class CheckpointableNetwork(network_lib.Network, checkpointable.Checkpointable): - - def __setattr__(self, name, value): - if isinstance(value, base.Layer): - self.track_layer(value, name=name) - # Checkpointable is next in the method resolution order, so this will catch - # Checkpointable objects which aren't Layers. - super(CheckpointableNetwork, self).__setattr__(name, value) - - def track_layer(self, layer, name): - self._track_checkpointable(layer, name=name) - return super(CheckpointableNetwork, self).track_layer(layer) - - -class CheckpointableAdam(adam.AdamOptimizer, checkpointable.Checkpointable): - - # NOTE: Copied from Optimizer with modifications to use add_variable - # for non-slot variables. These contortions are necessary to maintain - # checkpoint compatibility with variable.name based saving. - # TODO(allenl): Make this cleaner. - def _create_non_slot_variable(self, initial_value, name, colocate_with): - """Add an extra variable, not associated with a slot.""" - if context.in_graph_mode(): - graph = colocate_with.graph - else: - graph = None - - key = (name, graph) - v = self._non_slot_dict.get(key, None) - if v is None: - with ops.colocate_with(colocate_with): - def _variable_getter(name, shape, dtype, initializer): - del shape, dtype # not used, but there for compatibility - return variable_scope.variable( - name=name, initial_value=initializer, trainable=False) - - initial_value = ops.convert_to_tensor(initial_value) - v = self._add_variable_with_custom_getter( - name=name, - shape=initial_value.get_shape(), - initializer=initial_value, - getter=_variable_getter) - - self._non_slot_dict[key] = v - - return v - - -class NonLayerCheckpointable(checkpointable.Checkpointable): - - def __init__(self): - super(NonLayerCheckpointable, self).__init__() - self.a_variable = checkpointable_utils.add_variable( - self, name="a_variable", shape=[]) - - -class MyNetwork(CheckpointableNetwork): - """A concrete Network for testing.""" - - def __init__(self): - super(MyNetwork, self).__init__() - self._named_dense = CheckpointableDenseLayer(1, use_bias=True) - self._via_track_layer = self.track_layer( - CheckpointableDenseLayer(1, use_bias=False), name="via_track_layer") - # We can still track Checkpointables which aren't Layers. - self._non_layer = NonLayerCheckpointable() - - def call(self, values): - return self._via_track_layer(self._named_dense(values)) - - -class Checkpoint(checkpointable.Checkpointable): - """A utility class which groups `Checkpointable` objects.""" - - def __init__(self, **kwargs): - super(Checkpoint, self).__init__() - for k, v in sorted(kwargs.items(), key=lambda item: item[0]): - setattr(self, k, v) - self._save_counter = None - - @property - def save_counter(self): - """An integer variable which starts at zero and is incremented on save. - - Used to number checkpoints. - - Returns: - The save counter variable. - """ - if self._save_counter is None: - # Initialized to 0 and incremented before saving. - self._save_counter = checkpointable_utils.add_variable( - self, name="save_counter", initializer=0, dtype=dtypes.int64) - return self._save_counter - - def save(self, file_prefix, session=None): - assign_op = self.save_counter.assign_add(1) - if context.in_graph_mode(): - if session is None: - session = ops.get_default_session() - session.run(assign_op) - return checkpointable_utils.save( - file_prefix=file_prefix, - root_checkpointable=self, - checkpoint_number=self.save_counter, - session=session) - - def restore(self, save_path): - return checkpointable_utils.restore( - save_path=save_path, - root_checkpointable=self) - - -class InterfaceTests(test.TestCase): - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testAddVariable(self): - obj = NonLayerCheckpointable() - with self.assertRaisesRegexp(ValueError, "do not specify shape"): - checkpointable_utils.add_variable( - obj, name="shape_specified_twice", shape=[], initializer=1) - constant_initializer = checkpointable_utils.add_variable( - obj, name="constant_initializer", initializer=1) - with variable_scope.variable_scope("some_variable_scope"): - ones_initializer = checkpointable_utils.add_variable( - obj, - name="ones_initializer", - shape=[2], - initializer=init_ops.ones_initializer(dtype=dtypes.float32)) - bare_initializer = checkpointable_utils.add_variable( - obj, - name="bare_initializer", - shape=[2, 2], - dtype=dtypes.float64, - initializer=init_ops.zeros_initializer) - - # Even in graph mode, there are no naming conflicts between objects, only - # naming conflicts within an object. - other_duplicate = resource_variable_ops.ResourceVariable( - name="duplicate", initial_value=1.) - duplicate = checkpointable_utils.add_variable( - obj, name="duplicate", shape=[]) - with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"): - checkpointable_utils.add_variable(obj, name="duplicate", shape=[]) - - if context.in_graph_mode(): - self.evaluate(variables.global_variables_initializer()) - self.assertEqual("constant_initializer:0", constant_initializer.name) - self.assertEqual(1, self.evaluate(constant_initializer)) - self.assertEqual("some_variable_scope/ones_initializer:0", - ones_initializer.name) - self.assertAllEqual([1, 1], self.evaluate(ones_initializer)) - self.assertAllEqual([[0., 0.], - [0., 0.]], self.evaluate(bare_initializer)) - self.assertEqual("a_variable:0", obj.a_variable.name) - self.assertEqual("duplicate:0", other_duplicate.name) - if context.in_graph_mode(): - # The .name attribute may be globally influenced, but the checkpoint name - # won't be (tested below). - self.assertEqual("duplicate_1:0", duplicate.name) - else: - # When executing eagerly, there's no uniquification of variable names. The - # checkpoint name will be the same. - self.assertEqual("duplicate:0", duplicate.name) - named_variables, _ = checkpointable_utils._serialize_object_graph(obj) - expected_checkpoint_names = ( - "a_variable/.ATTRIBUTES/VARIABLE_VALUE", - "bare_initializer/.ATTRIBUTES/VARIABLE_VALUE", - "constant_initializer/.ATTRIBUTES/VARIABLE_VALUE", - "duplicate/.ATTRIBUTES/VARIABLE_VALUE", - "ones_initializer/.ATTRIBUTES/VARIABLE_VALUE", - ) - six.assertCountEqual( - self, expected_checkpoint_names, named_variables.keys()) - - def testInitNotCalled(self): - - class NoInit(checkpointable.Checkpointable): - - def __init__(self): - pass - - # __init__ for Checkpointable will be called implicitly. - checkpointable_utils.add_variable(NoInit(), "var", shape=[]) - - def testShapeDtype(self): - root = checkpointable.Checkpointable() - v1 = checkpointable_utils.add_variable( - root, name="v1", initializer=3., dtype=dtypes.float64) - self.assertEqual(dtypes.float64, v1.dtype) - v2 = checkpointable_utils.add_variable( - root, - name="v2", - shape=[3], - initializer=init_ops.ones_initializer, - dtype=dtypes.float64) - self.assertEqual(dtypes.float64, v2.dtype) - self.assertAllEqual([1., 1., 1.], self.evaluate(v2)) - - -class CheckpointingTests(test.TestCase): - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testNamingWithOptimizer(self): - input_value = constant_op.constant([[3.]]) - network = MyNetwork() - # A nuisance Network using the same optimizer. Its slot variables should not - # go in the checkpoint, since it is never depended on. - other_network = MyNetwork() - optimizer = CheckpointableAdam(0.001) - optimizer_step = training_util.get_or_create_global_step() - root_checkpointable = Checkpoint( - optimizer=optimizer, network=network, optimizer_step=optimizer_step) - if context.in_eager_mode(): - optimizer.minimize( - lambda: network(input_value), - global_step=optimizer_step) - optimizer.minimize( - lambda: other_network(input_value), - global_step=optimizer_step) - else: - train_op = optimizer.minimize( - network(input_value), global_step=optimizer_step) - optimizer.minimize( - other_network(input_value), - global_step=optimizer_step) - self.evaluate(variables.global_variables_initializer()) - self.evaluate(train_op) - named_variables, serialized_graph = ( - checkpointable_utils._serialize_object_graph(root_checkpointable)) - expected_checkpoint_names = ( - # Created in the root node, so no prefix. - "optimizer_step", - # No name provided to track_checkpointable(), so the position is used - # instead (one-based). - "network/via_track_layer/kernel", - # track_checkpointable() with a name provided, so that's used - "network/_named_dense/kernel", - "network/_named_dense/bias", - # non-Layer dependency of the network - "network/_non_layer/a_variable", - # The optimizer creates two non-slot variables - "optimizer/beta1_power", - "optimizer/beta2_power", - # Slot variables - "network/via_track_layer/kernel/.OPTIMIZER_SLOT/optimizer/m", - "network/via_track_layer/kernel/.OPTIMIZER_SLOT/optimizer/v", - "network/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m", - "network/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v", - "network/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m", - "network/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v", - ) - suffix = "/.ATTRIBUTES/VARIABLE_VALUE" - expected_checkpoint_names = [ - name + suffix for name in expected_checkpoint_names] - six.assertCountEqual(self, expected_checkpoint_names, - named_variables.keys()) - # Check that we've mapped to the right variable objects (not exhaustive) - self.assertEqual( - "global_step:0", - named_variables["optimizer_step" + suffix].name) - self.assertEqual( - "my_network/checkpointable_dense_layer_1/kernel:0", - named_variables["network/via_track_layer/kernel" + suffix].name) - self.assertEqual( - "my_network/checkpointable_dense_layer/kernel:0", - named_variables["network/_named_dense/kernel" + suffix].name) - self.assertEqual( - "beta1_power:0", - named_variables["optimizer/beta1_power" + suffix].name) - self.assertEqual( - "beta2_power:0", - named_variables["optimizer/beta2_power" + suffix].name) - # Spot check the generated protocol buffers. - self.assertEqual("optimizer", - serialized_graph.nodes[0].children[1].local_name) - optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[ - 1].node_id] - self.assertEqual("beta1_power", - optimizer_node.children[0].local_name) - self.assertEqual("beta1_power", - serialized_graph.nodes[optimizer_node.children[0].node_id] - .attributes[0].full_name) - self.assertEqual( - "my_network/checkpointable_dense_layer/kernel", - serialized_graph.nodes[optimizer_node.slot_variables[0] - .original_variable_node_id] - .attributes[0].full_name) - # We strip off the :0 suffix, as variable.name-based saving does. - self.assertEqual( - "my_network/checkpointable_dense_layer/kernel/Adam", - serialized_graph.nodes[optimizer_node.slot_variables[0] - .slot_variable_node_id] - .attributes[0].full_name) - self.assertEqual( - "my_network/checkpointable_dense_layer/kernel/Adam:0", - optimizer.get_slot( - var=named_variables["network/_named_dense/kernel" + suffix], - name="m").name) - self.assertEqual( - "network/_named_dense/kernel" + suffix, - serialized_graph.nodes[ - optimizer_node.slot_variables[0] - .original_variable_node_id].attributes[0].checkpoint_key) - self.assertEqual("m", optimizer_node.slot_variables[0].slot_name) - self.assertEqual( - "network/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix, - serialized_graph.nodes[ - optimizer_node.slot_variables[0] - .slot_variable_node_id].attributes[0].checkpoint_key) - - @test_util.run_in_graph_and_eager_modes() - def testSaveRestore(self): - network = MyNetwork() - optimizer = CheckpointableAdam(0.001) - root_checkpointable = Checkpoint(optimizer=optimizer, network=network) - input_value = constant_op.constant([[3.]]) - if context.in_eager_mode(): - optimizer.minimize( - lambda: network(input_value)) - else: - train_op = optimizer.minimize(network(input_value)) - # TODO(allenl): Make initialization more pleasant when graph building. - root_checkpointable.save_counter # pylint: disable=pointless-statement - self.evaluate(variables.global_variables_initializer()) - self.evaluate(train_op) - prefix = os.path.join(self.get_temp_dir(), "ckpt") - self.evaluate(state_ops.assign(network._named_dense.variables[1], [42.])) - m_bias_slot = optimizer.get_slot(network._named_dense.variables[1], "m") - self.evaluate(state_ops.assign(m_bias_slot, [1.5])) - save_path = root_checkpointable.save(file_prefix=prefix) - self.evaluate(state_ops.assign(network._named_dense.variables[1], [43.])) - self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3)) - optimizer_variables = self.evaluate(optimizer.variables()) - self.evaluate(state_ops.assign(m_bias_slot, [-2.])) - # Immediate restoration - status = root_checkpointable.restore(save_path=save_path).assert_consumed() - self.evaluate(status.restore_ops) - self.assertAllEqual([42.], self.evaluate(network._named_dense.variables[1])) - self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter)) - self.assertAllEqual([1.5], self.evaluate(m_bias_slot)) - if context.in_graph_mode(): - return # Restore-on-create is only supported when executing eagerly - on_create_network = MyNetwork() - on_create_optimizer = CheckpointableAdam(0.001) - on_create_root = Checkpoint( - optimizer=on_create_optimizer, network=on_create_network) - # Deferred restoration - status = on_create_root.restore(save_path=save_path) - on_create_network(constant_op.constant([[3.]])) # create variables - self.assertAllEqual(1, self.evaluate(on_create_root.save_counter)) - self.assertAllEqual([42.], - self.evaluate( - on_create_network._named_dense.variables[1])) - on_create_m_bias_slot = on_create_optimizer.get_slot( - on_create_network._named_dense.variables[1], "m") - # Optimizer slot variables are created when the original variable is - # restored. - self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot)) - self.assertAllEqual(optimizer_variables[2:], - self.evaluate(on_create_optimizer.variables())) - on_create_optimizer._create_slots( - [resource_variable_ops.ResourceVariable([1.])]) - status.assert_consumed() - beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators() - self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power)) - self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power)) - - def testDeferredRestorationUsageEager(self): - """An idiomatic eager execution example.""" - num_training_steps = 10 - checkpoint_directory = self.get_temp_dir() - checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - for training_continuation in range(3): - network = MyNetwork() - optimizer = CheckpointableAdam(0.001) - root = Checkpoint( - optimizer=optimizer, network=network, - optimizer_step=training_util.get_or_create_global_step()) - root.restore(core_saver.latest_checkpoint(checkpoint_directory)) - for _ in range(num_training_steps): - # TODO(allenl): Use a Dataset and serialize/checkpoint it. - input_value = constant_op.constant([[3.]]) - optimizer.minimize( - lambda: network(input_value), # pylint: disable=cell-var-from-loop - global_step=root.optimizer_step) - root.save(file_prefix=checkpoint_prefix) - self.assertEqual((training_continuation + 1) * num_training_steps, - root.optimizer_step.numpy()) - - def testUsageGraph(self): - """Expected usage when graph building.""" - with context.graph_mode(): - num_training_steps = 10 - checkpoint_directory = self.get_temp_dir() - checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - for training_continuation in range(3): - with ops.Graph().as_default(): - network = MyNetwork() - optimizer = CheckpointableAdam(0.001) - root = Checkpoint( - optimizer=optimizer, network=network, - global_step=training_util.get_or_create_global_step()) - input_value = constant_op.constant([[3.]]) - train_op = optimizer.minimize( - network(input_value), - global_step=root.global_step) - root.save_counter # pylint: disable=pointless-statement - init_op = variables.global_variables_initializer() - checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) - with self.test_session(graph=ops.get_default_graph()) as session: - if checkpoint_path is None: - self.assertEqual(0, training_continuation) - session.run(init_op) - # Another alternative would be to run initializers automatically - # if no checkpoint is being loaded. This would make deferred - # loading a bit more useful with graph execution. - else: - status = checkpointable_utils.restore( - save_path=checkpoint_path, - root_checkpointable=root, - session=session).assert_consumed() - session.run(status.restore_ops) - for _ in range(num_training_steps): - session.run(train_op) - root.save(file_prefix=checkpoint_prefix, - session=session) - self.assertEqual((training_continuation + 1) * num_training_steps, - session.run(root.global_step)) - self.assertEqual(training_continuation + 1, - session.run(root.save_counter)) - - def _get_checkpoint_name(self, name): - root = checkpointable.Checkpointable() - checkpointable_utils.add_variable( - root, name=name, shape=[1, 2], dtype=dtypes.float64) - named_variables, _ = checkpointable_utils._serialize_object_graph(root) - checkpoint_name, = named_variables.keys() - with ops.name_scope("root/" + checkpoint_name): - pass # Make sure we can use this as an op name if we prefix it. - return checkpoint_name - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testVariableNameEscaping(self): - suffix = "/.ATTRIBUTES/VARIABLE_VALUE" - self.assertEqual(r"a.Sb.Sc" + suffix, self._get_checkpoint_name(r"a/b/c")) - self.assertEqual(r"b" + suffix, self._get_checkpoint_name(r"b")) - self.assertEqual(r"c.S" + suffix, self._get_checkpoint_name(r"c/")) - self.assertEqual(r"d.S..S" + suffix, self._get_checkpoint_name(r"d/.S")) - self.assertEqual(r"d.S..ATTRIBUTES.Sf" + suffix, - self._get_checkpoint_name(r"d/.ATTRIBUTES/f")) - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testNumberedPath(self): - root = checkpointable.Checkpointable() - leaf = checkpointable.Checkpointable() - root.leaf = leaf - checkpointable_utils.add_variable(leaf, name="v", shape=[]) - named_variables, _ = checkpointable_utils._serialize_object_graph(root) - variable_name, = named_variables.keys() - self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", variable_name) - - @test_util.run_in_graph_and_eager_modes() - def testLocalNameValidation(self): - root = checkpointable.Checkpointable() - leaf = checkpointable.Checkpointable() - # Dots are escaped, which avoids conflicts with reserved names. - root._track_checkpointable(leaf, name=".ATTRIBUTES") - checkpointable_utils.add_variable(checkpointable=leaf, name="a", shape=[]) - named_variables, _ = checkpointable_utils._serialize_object_graph(root) - name, = named_variables.keys() - self.assertEqual(name, "..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE") - - @test_util.run_in_graph_and_eager_modes() - def testLateDependencyTracking(self): - - class Dependency(checkpointable.Checkpointable): - - def build(self): - self.var = checkpointable_utils.add_variable( - self, "var", initializer=0.) - - class LateDependencies(checkpointable.Checkpointable): - - def add_dep(self): - self.dep = Dependency() - self.dep.build() - - original = LateDependencies() - original.add_dep() - self.evaluate(state_ops.assign(original.dep.var, 123.)) - checkpoint_directory = self.get_temp_dir() - checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - save_path = checkpointable_utils.save(checkpoint_prefix, original) - load_into = LateDependencies() - status = checkpointable_utils.restore(save_path, load_into) - with self.assertRaises(AssertionError): - status.assert_consumed() - load_into.add_dep() - status.assert_consumed() - self.evaluate(status.restore_ops) - self.assertEqual(123., self.evaluate(load_into.dep.var)) - - @test_util.run_in_graph_and_eager_modes() - def testDepAfterVar(self): - - class Dependency(checkpointable.Checkpointable): - - def build(self): - self.var = checkpointable_utils.add_variable( - self, "var", initializer=0.) - - class DepAfterVar(checkpointable.Checkpointable): - - def add_dep(self): - dep = Dependency() - dep.build() - self.dep = dep - - dep_after_var = DepAfterVar() - dep_after_var.add_dep() - self.evaluate(state_ops.assign(dep_after_var.dep.var, -14.)) - checkpoint_directory = self.get_temp_dir() - checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - save_path = checkpointable_utils.save( - checkpoint_prefix, dep_after_var) - - loaded_dep_after_var = DepAfterVar() - status = checkpointable_utils.restore( - save_path, loaded_dep_after_var) - loaded_dep_after_var.add_dep() - status.assert_consumed() - self.evaluate(status.restore_ops) - self.assertEqual(-14., self.evaluate(loaded_dep_after_var.dep.var)) - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testDeferredSlotRestoration(self): - checkpoint_directory = self.get_temp_dir() - - root = checkpointable.Checkpointable() - root.var = checkpointable_utils.add_variable( - root, name="var", initializer=0.) - optimizer = CheckpointableAdam(0.1) - if context.in_graph_mode(): - train_op = optimizer.minimize(root.var) - self.evaluate(variables.global_variables_initializer()) - self.evaluate(train_op) - else: - optimizer.minimize(root.var.read_value) - self.evaluate(state_ops.assign(root.var, 12.)) - no_slots_path = checkpointable_utils.save( - os.path.join(checkpoint_directory, "no_slots"), root) - root.optimizer = optimizer - self.evaluate(state_ops.assign(root.var, 13.)) - self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var), - 14.)) - slots_path = checkpointable_utils.save( - os.path.join(checkpoint_directory, "with_slots"), root) - new_root = checkpointable.Checkpointable() - # Load the slot-containing checkpoint (deferred), then immediately overwrite - # the non-slot variable (also deferred). - slot_status = checkpointable_utils.restore( - slots_path, new_root) - no_slot_status = checkpointable_utils.restore( - no_slots_path, new_root) - with self.assertRaises(AssertionError): - no_slot_status.assert_consumed() - new_root.var = checkpointable_utils.add_variable( - new_root, name="var", shape=[]) - no_slot_status.assert_consumed() - self.evaluate(no_slot_status.restore_ops) - self.assertEqual(12., self.evaluate(new_root.var)) - new_root.optimizer = CheckpointableAdam(0.1) - with self.assertRaisesRegexp(AssertionError, "beta1_power"): - slot_status.assert_consumed() - self.assertEqual(12., self.evaluate(new_root.var)) - if context.in_eager_mode(): - # Slot variables are only created with restoring initializers when - # executing eagerly. - self.assertEqual(14., self.evaluate( - new_root.optimizer.get_slot(name="m", var=new_root.var))) - else: - self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var), - None) - if context.in_graph_mode(): - train_op = new_root.optimizer.minimize(new_root.var) - # The slot variable now exists; restore() didn't create it, but we should - # now have a restore op for it. - self.evaluate(slot_status.restore_ops) - self.assertEqual(14., self.evaluate( - new_root.optimizer.get_slot(name="m", var=new_root.var))) - self.evaluate(train_op) - else: - new_root.optimizer.minimize(new_root.var.read_value) - slot_status.assert_consumed() - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testOverlappingRestores(self): - checkpoint_directory = self.get_temp_dir() - save_root = checkpointable.Checkpointable() - save_root.dep = checkpointable.Checkpointable() - save_root.dep.var = checkpointable_utils.add_variable( - save_root.dep, name="var", initializer=0.) - self.evaluate(state_ops.assign(save_root.dep.var, 12.)) - first_path = checkpointable_utils.save( - os.path.join(checkpoint_directory, "first"), save_root) - self.evaluate(state_ops.assign(save_root.dep.var, 13.)) - second_path = checkpointable_utils.save( - os.path.join(checkpoint_directory, "second"), save_root) - - first_root = checkpointable.Checkpointable() - second_root = checkpointable.Checkpointable() - first_status = checkpointable_utils.restore( - first_path, first_root) - second_status = checkpointable_utils.restore( - second_path, second_root) - load_dep = checkpointable.Checkpointable() - load_dep.var = checkpointable_utils.add_variable( - load_dep, name="var", shape=[]) - first_root.dep = load_dep - first_status.assert_consumed() - self.evaluate(first_status.restore_ops) - self.assertEqual([], second_status.restore_ops) - self.assertEqual(12., self.evaluate(load_dep.var)) - second_root.dep = load_dep - second_status.assert_consumed() - self.evaluate(second_status.restore_ops) - self.assertEqual(13., self.evaluate(load_dep.var)) - - # Try again with the order of the restore() reversed. The last restore - # determines the final value. - first_root = checkpointable.Checkpointable() - second_root = checkpointable.Checkpointable() - second_status = checkpointable_utils.restore( - second_path, second_root) - first_status = checkpointable_utils.restore( - first_path, first_root) - load_dep = checkpointable.Checkpointable() - load_dep.var = checkpointable_utils.add_variable( - load_dep, name="var", shape=[]) - first_root.dep = load_dep - first_status.assert_consumed() - self.assertEqual([], second_status.restore_ops) - self.evaluate(first_status.restore_ops) - self.assertEqual(12., self.evaluate(load_dep.var)) - second_root.dep = load_dep - second_status.assert_consumed() - self.evaluate(second_status.restore_ops) - self.assertEqual(12., self.evaluate(load_dep.var)) - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testAmbiguousLoad(self): - # Not OK to split one checkpoint object into two - checkpoint_directory = self.get_temp_dir() - save_root = checkpointable.Checkpointable() - save_root.dep_one = checkpointable.Checkpointable() - save_root.dep_two = checkpointable.Checkpointable() - dep_three = checkpointable.Checkpointable() - save_root.dep_one.dep_three = dep_three - save_root.dep_two.dep_three = dep_three - checkpointable_utils.add_variable(dep_three, name="var", initializer=0.) - self.evaluate(variables.global_variables_initializer()) - save_path = checkpointable_utils.save( - os.path.join(checkpoint_directory, "ckpt"), save_root) - load_root = checkpointable.Checkpointable() - checkpointable_utils.restore(save_path, load_root) - load_root.dep_one = checkpointable.Checkpointable() - load_root.dep_two = checkpointable.Checkpointable() - load_root.dep_one.dep_three = checkpointable.Checkpointable() - with self.assertRaisesRegexp(AssertionError, - "resolved to different objects"): - load_root.dep_two.dep_three = checkpointable.Checkpointable() - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testObjectsCombined(self): - # Currently fine to load two checkpoint objects into one Python object - checkpoint_directory = self.get_temp_dir() - save_root = checkpointable.Checkpointable() - save_root.dep_one = checkpointable.Checkpointable() - save_root.dep_two = checkpointable.Checkpointable() - checkpointable_utils.add_variable( - save_root.dep_one, name="var1", initializer=32., dtype=dtypes.float64) - checkpointable_utils.add_variable( - save_root.dep_two, name="var2", initializer=64., dtype=dtypes.float64) - self.evaluate(variables.global_variables_initializer()) - save_path = checkpointable_utils.save( - os.path.join(checkpoint_directory, "ckpt"), save_root) - load_root = checkpointable.Checkpointable() - load_root.dep_one = checkpointable.Checkpointable() - load_root.dep_two = load_root.dep_one - v1 = checkpointable_utils.add_variable( - load_root.dep_one, name="var1", shape=[], dtype=dtypes.float64) - v2 = checkpointable_utils.add_variable( - load_root.dep_one, name="var2", shape=[], dtype=dtypes.float64) - status = checkpointable_utils.restore( - save_path, load_root).assert_consumed() - self.evaluate(status.restore_ops) - self.assertEqual(32., self.evaluate(v1)) - self.assertEqual(64., self.evaluate(v2)) - - @test_util.run_in_graph_and_eager_modes() - def testDependencyLoop(self): - # Note: this test creates garbage during eager execution because it - # purposefully creates a reference cycle. - first = checkpointable.Checkpointable() - second = checkpointable.Checkpointable() - first.second = second - second.first = first - first.v = checkpointable_utils.add_variable( - first, "v1", initializer=[3., 1., 4.]) - second.v = checkpointable_utils.add_variable( - second, "v2", initializer=[1., 1., 2., 3.]) - self.evaluate(variables.global_variables_initializer()) - checkpoint_directory = self.get_temp_dir() - save_path = checkpointable_utils.save( - os.path.join(checkpoint_directory, "ckpt"), first) - - # Test deferred loading - first_load = checkpointable.Checkpointable() - status = checkpointable_utils.restore(save_path, first_load) - second_load = checkpointable.Checkpointable() - first_load.second = second_load - second_load.first = first_load - with self.assertRaises(AssertionError): - status.assert_consumed() - first_load.v = checkpointable_utils.add_variable( - first_load, "v1", shape=[3]) - second_load.v = checkpointable_utils.add_variable( - second_load, "v2", shape=[4]) - status.assert_consumed() - self.evaluate(status.restore_ops) - self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v)) - self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v)) - - # Test loading when variables have already been created - self.evaluate(first_load.v.assign([2., 7., 1.])) - self.assertAllEqual([2., 7., 1.], self.evaluate(first_load.v)) - self.evaluate(second_load.v.assign([2., 7., 1., 8.])) - self.assertAllEqual([2., 7., 1., 8.], self.evaluate(second_load.v)) - status = checkpointable_utils.restore( - save_path, first_load).assert_consumed() - self.evaluate(status.restore_ops) - self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v)) - self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v)) - - @test_util.run_in_graph_and_eager_modes() - def testRestoreOnAssign(self): - checkpoint_directory = self.get_temp_dir() - checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - save_graph = ops.Graph() - with save_graph.as_default(), self.test_session(save_graph): - first = checkpointable.Checkpointable() - first.var1 = variable_scope.get_variable( - name="outside_var", initializer=0.) - first.var2 = variable_scope.get_variable( - name="blah", initializer=0.) - self.evaluate(first.var1.assign(4.)) - self.evaluate(first.var2.assign(8.)) - save_path = checkpointable_utils.save( - checkpoint_prefix, root_checkpointable=first) - restore_graph = ops.Graph() - with restore_graph.as_default(), self.test_session(restore_graph): - second = checkpointable.Checkpointable() - second.var2 = variable_scope.get_variable( - name="blah", initializer=0.) - status = checkpointable_utils.restore( - save_path, root_checkpointable=second) - recreated_var1 = variable_scope.get_variable( - name="outside_var", initializer=0.) - self.evaluate(status.restore_ops) - self.assertEqual(8., self.evaluate(second.var2)) - self.evaluate(recreated_var1.assign(-2.)) - self.assertEqual(-2., self.evaluate(recreated_var1)) - second.var1 = recreated_var1 - self.evaluate(status.restore_ops) - self.assertEqual(4., self.evaluate(recreated_var1)) - - # TODO(allenl): Saver class that doesn't pollute the graph with constants. - @unittest.skip("todo") - def testManySavesGraph(self): - """Saves after the first should not modify the graph.""" - with context.graph_mode(): - graph = ops.Graph() - with graph.as_default(), self.test_session(graph): - checkpoint_directory = self.get_temp_dir() - checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() - obj.var = variable_scope.get_variable(name="v", initializer=0.) - obj.opt = CheckpointableAdam(0.1) - obj.opt.minimize(obj.var.read_value()) - self.evaluate(variables.global_variables_initializer()) - checkpointable_utils.save( - checkpoint_prefix, root_checkpointable=obj) - before_ops = graph.get_operations() - checkpointable_utils.save( - checkpoint_prefix, root_checkpointable=obj) - self.assertEqual(before_ops, graph.get_operations()) - - @unittest.skip("todo") - def testManyRestoresGraph(self): - """Restores after the first should not modify the graph.""" - with context.graph_mode(): - graph = ops.Graph() - with graph.as_default(), self.test_session(graph): - checkpoint_directory = self.get_temp_dir() - checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() - obj.var = variable_scope.get_variable(name="v", initializer=0.) - obj.opt = CheckpointableAdam(0.1) - obj.opt.minimize(obj.var.read_value()) - self.evaluate(variables.global_variables_initializer()) - save_path = checkpointable_utils.save( - checkpoint_prefix, root_checkpointable=obj) - checkpointable_utils.restore( - save_path, root_checkpointable=obj) - before_ops = graph.get_operations() - checkpointable_utils.restore( - save_path, root_checkpointable=obj) - self.assertEqual(before_ops, graph.get_operations()) - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index d177bfeab2d1fdc05d7ced54df8723fae2c77fdb..adf92c27ea0a27c5741bcdd175b277462cb28d02 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -27,11 +27,12 @@ from tensorflow.python.data.util import sparse from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.saver import BaseSaverBuilder _uid_counter = 0 _uid_lock = threading.Lock() @@ -45,8 +46,13 @@ def _generate_shared_name(prefix): return "{}{}".format(prefix, uid) -class Iterator(object): - """An iterator producing tf.Tensor objects from a tf.data.Dataset.""" +class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): + """An iterator producing tf.Tensor objects from a tf.data.Dataset. + + NOTE: Unlike the iterator created by the + @{tf.data.Dataset.make_one_shot_iterator} method, this class enables + additional experimental functionality, such as prefetching to the GPU. + """ def __init__(self, dataset): """Creates a new iterator over the given dataset. @@ -65,39 +71,21 @@ class Iterator(object): dataset: A `tf.data.Dataset` object. Raises: + TypeError: If `dataset` is an unsupported type. RuntimeError: When invoked without eager execution enabled. """ + if isinstance(dataset, prefetching_ops._PrefetchToDeviceDataset): # pylint: disable=protected-access + raise TypeError( + "`tf.contrib.data.prefetch_to_device()` is not compatible with " + "`tf.contrib.eager.Iterator`. Use `for ... in dataset:` to iterate " + "over the dataset instead.") - if not context.in_eager_mode(): - raise RuntimeError( - "{} objects can only be used when eager execution is enabled, use " - "tf.data.Dataset.make_iterator or " - "tf.data.Dataset.make_one_shot_iterator for graph construction". - format(type(self))) - with ops.device("/device:CPU:0"): - ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access - self._output_classes = dataset.output_classes - self._output_types = dataset.output_types - self._output_shapes = dataset.output_shapes - self._flat_output_types = nest.flatten( - sparse.as_dense_types(self._output_types, self._output_classes)) - self._flat_output_shapes = nest.flatten( - sparse.as_dense_shapes(self._output_shapes, self._output_classes)) - self._resource = gen_dataset_ops.iterator( - shared_name="", - container=_generate_shared_name("eageriterator"), - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - gen_dataset_ops.make_iterator(ds_variant, self._resource) - # Delete the resource when this object is deleted - self._resource_deleter = resource_variable_ops.EagerResourceDeleter( - handle=self._resource, handle_device="/device:CPU:0") - self._device = context.context().device_name - self._buffer_resource_handle = None + super(Iterator, self).__init__(dataset) if not context.context().device_spec.device_type: is_remote_device = False else: is_remote_device = context.context().device_spec.device_type != "CPU" + self._buffer_resource_handle = None if is_remote_device: with ops.device("/device:CPU:0"): iter_string_handle = gen_dataset_ops.iterator_to_string_handle( @@ -106,7 +94,7 @@ class Iterator(object): @function.Defun(dtypes.string) def remote_fn(h): remote_iterator = iterator_ops.Iterator.from_string_handle( - h, self._output_types, self._output_shapes) + h, self.output_types, self.output_shapes, self.output_classes) return remote_iterator.get_next() remote_fn.add_to_graph(None) @@ -117,96 +105,54 @@ class Iterator(object): f=remote_fn, target_device=target, buffer_size=10, - thread_pool_size=1, container="", - shared_name=_generate_shared_name("function_buffer_resource")) + shared_name=_generate_shared_name( + "contrib_eager_iterator_function_buffer_resource")) self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( # pylint: disable=line-too-long handle=self._buffer_resource_handle, handle_device=self._device) - def __iter__(self): - return self - - def __next__(self): # For Python 3 compatibility - return self.next() - def _next_internal(self): """Returns a nested structure of `tf.Tensor`s containing the next element. """ - with ops.device(self._device): + # This runs in sync mode as iterators use an error status to communicate + # that there is no more data to iterate over. + # TODO(b/77291417): Fix + with context.execution_mode(context.SYNC): if self._buffer_resource_handle is not None: - ret = prefetching_ops.function_buffering_resource_get_next( - function_buffer_resource=self._buffer_resource_handle, - output_types=self._flat_output_types) + with ops.device(self._device): + ret = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=self._buffer_resource_handle, + output_types=self._flat_output_types) + return sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self._output_types, ret), self._output_types, + self._output_shapes, self._output_classes) else: - # TODO(ashankar): Consider removing this ops.device() contextmanager - # and instead mimic ops placement in graphs: Operations on resource - # handles execute on the same device as where the resource is placed. - # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next` - # because in eager mode this code will run synchronously on the calling - # thread. Therefore we do not need to make a defensive context switch - # to a background thread, and can achieve a small constant performance - # boost by invoking the iterator synchronously. - ret = gen_dataset_ops.iterator_get_next_sync( - self._resource, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - - return sparse.deserialize_sparse_tensors( - nest.pack_sequence_as(self._output_types, ret), self._output_types, - self._output_shapes, self._output_classes) - - def next(self): - """Returns a nested structure of `tf.Tensor`s containing the next element. - """ - try: - return self._next_internal() - except errors.OutOfRangeError: - raise StopIteration - - @property - def output_classes(self): - """Returns the class of each component of an element of this iterator. - - The expected values are `tf.Tensor` and `tf.SparseTensor`. + return super(Iterator, self)._next_internal() - Returns: - A nested structure of Python `type` objects corresponding to each - component of an element of this dataset. - """ - return self._output_classes + # TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset + # attributes(potential). - @property - def output_shapes(self): - """Returns the shape of each component of an element of this iterator. + class _Saveable(BaseSaverBuilder.SaveableObject): + """SaveableObject for saving/restoring iterator state.""" - Returns: - A nested structure of `tf.TensorShape` objects corresponding to each - component of an element of this dataset. - """ - return self._output_shapes + def __init__(self, iterator_resource, name): + serialized_iterator = gen_dataset_ops.serialize_iterator( + iterator_resource) + specs = [ + BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE") + ] + # pylint: disable=protected-access + super(Iterator._Saveable, self).__init__(iterator_resource, specs, name) - @property - def output_types(self): - """Returns the type of each component of an element of this iterator. + def restore(self, restored_tensors, restored_shapes): + with ops.colocate_with(self.op): + return gen_dataset_ops.deserialize_iterator(self.op, + restored_tensors[0]) - Returns: - A nested structure of `tf.DType` objects corresponding to each component - of an element of this dataset. - """ - return self._output_types + def _gather_saveables_for_checkpoint(self): - def get_next(self, name=None): - """Returns a nested structure of `tf.Tensor`s containing the next element. + def _saveable_factory(name): + return self._Saveable(self._resource, name) - Args: - name: (Optional.) A name for the created operation. Currently unused. - - Returns: - A nested structure of `tf.Tensor` objects. - - Raises: - `tf.errors.OutOfRangeError`: If the end of the dataset has been reached. - """ - del name - return self._next_internal() + return {"ITERATOR": _saveable_factory} diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index a1611e92b113839c2dd2a3b2560b0ba90c0a7ef0..68bec9aee894edd60a025ac1cf87ca3e010db842 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -16,11 +16,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + +import threading import time import numpy as np from tensorflow.contrib import lookup +from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.contrib.data.python.ops import threadpool +from tensorflow.contrib.data.python.ops import unique from tensorflow.contrib.eager.python import datasets from tensorflow.python.data import Dataset from tensorflow.python.eager import test @@ -31,6 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops +from tensorflow.python.training.checkpointable import util as checkpointable_utils class IteratorTest(test.TestCase): @@ -41,6 +48,18 @@ class IteratorTest(test.TestCase): got.append(t.numpy()) self.assertAllEqual([0, 1, 2, 3], got) + def testBasicOneShotIterator(self): + got = [] + for t in Dataset.range(4).make_one_shot_iterator(): + got.append(t.numpy()) + self.assertAllEqual([0, 1, 2, 3], got) + + def testBasicImplicitIterator(self): + got = [] + for t in Dataset.range(4): + got.append(t.numpy()) + self.assertAllEqual([0, 1, 2, 3], got) + def testGetNext(self): iterator = datasets.Iterator(Dataset.range(4)) self.assertEqual(0, iterator.get_next().numpy()) @@ -50,6 +69,15 @@ class IteratorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): iterator.get_next() + def testGetNextOneShotIterator(self): + iterator = Dataset.range(4).make_one_shot_iterator() + self.assertEqual(0, iterator.get_next().numpy()) + self.assertEqual(1, iterator.get_next().numpy()) + self.assertEqual(2, iterator.get_next().numpy()) + self.assertEqual(3, iterator.get_next().numpy()) + with self.assertRaises(errors.OutOfRangeError): + iterator.get_next() + def testMultipleIteratorsOnTheSameDataset(self): ds = Dataset.range(4) it1 = datasets.Iterator(ds) @@ -165,6 +193,105 @@ class IteratorTest(test.TestCase): x = math_ops.add(x, x) self.assertAllEqual([0., 2.], x.numpy()) + def testTensorsExplicitPrefetchToDevice(self): + ds = Dataset.from_tensor_slices([0., 1.]) + ds = ds.apply(prefetching_ops.prefetch_to_device(test.gpu_device_name())) + + with self.assertRaisesRegexp(TypeError, 'prefetch_to_device'): + datasets.Iterator(ds) + + for i, x in enumerate(ds): + with ops.device(test.gpu_device_name()): + x = math_ops.add(x, x) + self.assertEqual(float(i) + float(i), x.numpy()) + + def testOverrideThreadPool(self): + + def get_thread_id(_): + # Python creates a dummy thread object to represent the current + # thread when called from an "alien" thread (such as a + # `PrivateThreadPool` thread in this case). It does not include + # the TensorFlow-given display name, but it has a unique + # identifier that maps one-to-one with the underlying OS thread. + return np.array(threading.current_thread().ident).astype(np.int64) + + for num_threads in [1, 2, 4, 8, 16]: + + dataset = ( + Dataset.range(1000).map( + lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64), + num_parallel_calls=32).apply(unique.unique())) + + dataset = threadpool.override_threadpool( + dataset, + threadpool.PrivateThreadPool( + num_threads, display_name='private_thread_pool_%d' % num_threads)) + + thread_ids = [] + for next_element in datasets.Iterator(dataset): + thread_ids.append(next_element) + self.assertEqual(len(thread_ids), len(set(thread_ids))) + self.assertGreater(len(thread_ids), 0) + # NOTE(mrry): We don't control the thread pool scheduling, and + # so cannot guarantee that all of the threads in the pool will + # perform work. + self.assertLessEqual(len(thread_ids), num_threads) + + def testSaveRestore(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + dataset = dataset.map(math_ops.square).batch(2) + iterator = datasets.Iterator(dataset) + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + self.assertAllEqual([1, 4], iterator.get_next().numpy()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual([9, 16], iterator.get_next().numpy()) + self.assertAllEqual([25, 36], iterator.get_next().numpy()) + checkpoint.restore(save_path) + self.assertAllEqual([9, 16], iterator.get_next().numpy()) + self.assertAllEqual([25, 36], iterator.get_next().numpy()) + + def testSaveRestoreMultipleIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + dataset = dataset.map(math_ops.square).batch(2) + iterator_1 = datasets.Iterator(dataset) + iterator_2 = datasets.Iterator(dataset) + dataset_2 = Dataset.range(10) + iterator_3 = datasets.Iterator(dataset_2) + + checkpoint = checkpointable_utils.Checkpoint( + iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3) + self.assertAllEqual([1, 4], iterator_1.get_next().numpy()) + self.assertEqual(0, iterator_3.get_next().numpy()) + self.assertEqual(1, iterator_3.get_next().numpy()) + self.assertEqual(2, iterator_3.get_next().numpy()) + + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual([1, 4], iterator_2.get_next().numpy()) + self.assertAllEqual([9, 16], iterator_2.get_next().numpy()) + self.assertEqual(3, iterator_3.get_next().numpy()) + checkpoint.restore(save_path) + self.assertAllEqual([9, 16], iterator_1.get_next().numpy()) + self.assertAllEqual([1, 4], iterator_2.get_next().numpy()) + self.assertEqual(3, iterator_3.get_next().numpy()) + + def testRestoreExhaustedIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.range(3) + iterator = datasets.Iterator(dataset) + + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + self.assertEqual(0, iterator.get_next().numpy()) + self.assertEqual(1, iterator.get_next().numpy()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertEqual(2, iterator.get_next().numpy()) + checkpoint.restore(save_path) + self.assertEqual(2, iterator.get_next().numpy()) + class DatasetConstructorBenchmark(test.Benchmark): diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py index 68e7b5421fec7f73f10e381ca45f9d900de299d7..7949a3f6da293abdd85512209242bae76ab4d816 100644 --- a/tensorflow/contrib/eager/python/evaluator.py +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -22,12 +22,12 @@ import six from tensorflow.contrib.eager.python import datasets from tensorflow.contrib.eager.python import metrics -from tensorflow.contrib.summary import summary_ops from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import summary_ops_v2 as summary_ops class Evaluator(object): @@ -57,7 +57,7 @@ class Evaluator(object): self._model = model self._metrics = {} self._evaluators = {} - if context.in_graph_mode(): + if not context.executing_eagerly(): self.call = function.defun(self.call) # ---- API for users ---- @@ -90,7 +90,7 @@ class Evaluator(object): Only for graph execution. @end_compatibility """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("Evaluator.init_variables() not needed when " "eager execution is enabled.") return control_flow_ops.group([m.init_variables() for _, m in self.metrics]) @@ -113,7 +113,8 @@ class Evaluator(object): with summary_ops.create_file_writer( summary_logdir).as_default(), summary_ops.always_record_summaries(): return self._all_metric_results() - if context.in_eager_mode(): + + if context.executing_eagerly(): return f() else: return function.defun(f)() @@ -158,16 +159,16 @@ class Evaluator(object): @end_compatibility """ summary_logdir = kwargs.pop("summary_logdir", None) - if context.in_graph_mode(): - call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), - *args, **kwargs) - init_op = self.init_variables() - results_op = self.all_metric_results(summary_logdir) - return (init_op, call_op, results_op) - # Eager case - for example in datasets.Iterator(dataset): - self.__call__(example, *args, **kwargs) - return self.all_metric_results(summary_logdir) + if context.executing_eagerly(): + for example in datasets.Iterator(dataset): + self.__call__(example, *args, **kwargs) + return self.all_metric_results(summary_logdir) + # Graph construction + call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), *args, + **kwargs) + init_op = self.init_variables() + results_op = self.all_metric_results(summary_logdir) + return (init_op, call_op, results_op) @staticmethod def run_evaluation(init_op, call_op, results_op, sess=None): @@ -192,7 +193,7 @@ class Evaluator(object): Only for graph execution. @end_compatibility """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("Evaluator.run_evaluation() not supported when " "eager execution is enabled.") sess = sess or ops.get_default_session() diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index 15a21885f66eface291a39fa0ee1ff28bc297548..1d9371c7ac405dbf0ec40210270b90f2cf9b9a25 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -7,8 +7,9 @@ py_library( name = "examples_pip", deps = [ "//tensorflow/contrib/eager/python/examples/gan:mnist", + "//tensorflow/contrib/eager/python/examples/l2hmc", + "//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets", "//tensorflow/contrib/eager/python/examples/linear_regression", - "//tensorflow/contrib/eager/python/examples/mnist", "//tensorflow/contrib/eager/python/examples/resnet50", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py index b9ac79f46c83bb709918e3b72830b90ddcfd71b4..cc9cf53410f641cc3303b4450e9eaa1301904a64 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py @@ -32,10 +32,11 @@ import tensorflow as tf import tensorflow.contrib.eager as tfe from tensorflow.examples.tutorials.mnist import input_data +layers = tf.keras.layers FLAGS = None -class Discriminator(tfe.Network): +class Discriminator(tf.keras.Model): """GAN Discriminator. A network to differentiate between generated and real handwritten digits. @@ -56,19 +57,15 @@ class Discriminator(tfe.Network): else: assert data_format == 'channels_last' self._input_shape = [-1, 28, 28, 1] - self.conv1 = self.track_layer(tf.layers.Conv2D(64, 5, padding='SAME', - data_format=data_format, - activation=tf.tanh)) - self.pool1 = self.track_layer( - tf.layers.AveragePooling2D(2, 2, data_format=data_format)) - self.conv2 = self.track_layer(tf.layers.Conv2D(128, 5, - data_format=data_format, - activation=tf.tanh)) - self.pool2 = self.track_layer( - tf.layers.AveragePooling2D(2, 2, data_format=data_format)) - self.flatten = self.track_layer(tf.layers.Flatten()) - self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.tanh)) - self.fc2 = self.track_layer(tf.layers.Dense(1, activation=None)) + self.conv1 = layers.Conv2D( + 64, 5, padding='SAME', data_format=data_format, activation=tf.tanh) + self.pool1 = layers.AveragePooling2D(2, 2, data_format=data_format) + self.conv2 = layers.Conv2D( + 128, 5, data_format=data_format, activation=tf.tanh) + self.pool2 = layers.AveragePooling2D(2, 2, data_format=data_format) + self.flatten = layers.Flatten() + self.fc1 = layers.Dense(1024, activation=tf.tanh) + self.fc2 = layers.Dense(1, activation=None) def call(self, inputs): """Return two logits per image estimating input authenticity. @@ -95,7 +92,7 @@ class Discriminator(tfe.Network): return x -class Generator(tfe.Network): +class Generator(tf.keras.Model): """Generator of handwritten digits similar to the ones in the MNIST dataset. """ @@ -116,18 +113,17 @@ class Generator(tfe.Network): else: assert data_format == 'channels_last' self._pre_conv_shape = [-1, 6, 6, 128] - self.fc1 = self.track_layer(tf.layers.Dense(6 * 6 * 128, - activation=tf.tanh)) + self.fc1 = layers.Dense(6 * 6 * 128, activation=tf.tanh) # In call(), we reshape the output of fc1 to _pre_conv_shape # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64) - self.conv1 = self.track_layer(tf.layers.Conv2DTranspose( - 64, 4, strides=2, activation=None, data_format=data_format)) + self.conv1 = layers.Conv2DTranspose( + 64, 4, strides=2, activation=None, data_format=data_format) # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1) - self.conv2 = self.track_layer(tf.layers.Conv2DTranspose( - 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format)) + self.conv2 = layers.Conv2DTranspose( + 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format) def call(self, inputs): """Return a batch of generated images. @@ -168,7 +164,8 @@ def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs): """ loss_on_real = tf.losses.sigmoid_cross_entropy( - tf.ones_like(discriminator_real_outputs), discriminator_real_outputs, + tf.ones_like(discriminator_real_outputs), + discriminator_real_outputs, label_smoothing=0.25) loss_on_generated = tf.losses.sigmoid_cross_entropy( tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs) @@ -198,9 +195,9 @@ def generator_loss(discriminator_gen_outputs): return loss -def train_one_epoch(generator, discriminator, - generator_optimizer, discriminator_optimizer, - dataset, log_interval, noise_dim): +def train_one_epoch(generator, discriminator, generator_optimizer, + discriminator_optimizer, dataset, step_counter, + log_interval, noise_dim): """Trains `generator` and `discriminator` models on `dataset`. Args: @@ -209,7 +206,8 @@ def train_one_epoch(generator, discriminator, generator_optimizer: Optimizer to use for generator. discriminator_optimizer: Optimizer to use for discriminator. dataset: Dataset of images to train on. - log_interval: How many global steps to wait between logging and collecting + step_counter: An integer variable, used to write summaries regularly. + log_interval: How many steps to wait between logging and collecting summaries. noise_dim: Dimension of noise vector to use. """ @@ -218,18 +216,23 @@ def train_one_epoch(generator, discriminator, total_discriminator_loss = 0.0 for (batch_index, images) in enumerate(tfe.Iterator(dataset)): with tf.device('/cpu:0'): - tf.assign_add(tf.train.get_global_step(), 1) + tf.assign_add(step_counter, 1) - with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval): + with tf.contrib.summary.record_summaries_every_n_global_steps( + log_interval, global_step=step_counter): current_batch_size = images.shape[0] - noise = tf.random_uniform(shape=[current_batch_size, noise_dim], - minval=-1., maxval=1., seed=batch_index) + noise = tf.random_uniform( + shape=[current_batch_size, noise_dim], + minval=-1., + maxval=1., + seed=batch_index) - with tfe.GradientTape(persistent=True) as g: + with tf.GradientTape(persistent=True) as g: generated_images = generator(noise) - tf.contrib.summary.image('generated_images', - tf.reshape(generated_images, [-1, 28, 28, 1]), - max_images=10) + tf.contrib.summary.image( + 'generated_images', + tf.reshape(generated_images, [-1, 28, 28, 1]), + max_images=10) discriminator_gen_outputs = discriminator(generated_images) discriminator_real_outputs = discriminator(images) @@ -244,18 +247,16 @@ def train_one_epoch(generator, discriminator, discriminator_grad = g.gradient(discriminator_loss_val, discriminator.variables) - with tf.variable_scope('generator'): - generator_optimizer.apply_gradients(zip(generator_grad, - generator.variables)) - with tf.variable_scope('discriminator'): - discriminator_optimizer.apply_gradients(zip(discriminator_grad, - discriminator.variables)) + generator_optimizer.apply_gradients( + zip(generator_grad, generator.variables)) + discriminator_optimizer.apply_gradients( + zip(discriminator_grad, discriminator.variables)) if log_interval and batch_index > 0 and batch_index % log_interval == 0: print('Batch #%d\tAverage Generator Loss: %.6f\t' - 'Average Discriminator Loss: %.6f' % ( - batch_index, total_generator_loss/batch_index, - total_discriminator_loss/batch_index)) + 'Average Discriminator Loss: %.6f' % + (batch_index, total_generator_loss / batch_index, + total_discriminator_loss / batch_index)) def main(_): @@ -266,18 +267,18 @@ def main(_): # Load the datasets data = input_data.read_data_sets(FLAGS.data_dir) - dataset = (tf.data.Dataset - .from_tensor_slices(data.train.images) - .shuffle(60000) - .batch(FLAGS.batch_size)) - - # Create the models and optimizers - generator = Generator(data_format) - discriminator = Discriminator(data_format) - with tf.variable_scope('generator'): - generator_optimizer = tf.train.AdamOptimizer(FLAGS.lr) - with tf.variable_scope('discriminator'): - discriminator_optimizer = tf.train.AdamOptimizer(FLAGS.lr) + dataset = ( + tf.data.Dataset.from_tensor_slices(data.train.images).shuffle(60000) + .batch(FLAGS.batch_size)) + + # Create the models and optimizers. + model_objects = { + 'generator': Generator(data_format), + 'discriminator': Discriminator(data_format), + 'generator_optimizer': tf.train.AdamOptimizer(FLAGS.lr), + 'discriminator_optimizer': tf.train.AdamOptimizer(FLAGS.lr), + 'step_counter': tf.train.get_or_create_global_step(), + } # Prepare summary writer and checkpoint info summary_writer = tf.contrib.summary.create_summary_file_writer( @@ -286,32 +287,26 @@ def main(_): latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if latest_cpkt: print('Using latest checkpoint at ' + latest_cpkt) + checkpoint = tfe.Checkpoint(**model_objects) + # Restore variables on creation if a checkpoint exists. + checkpoint.restore(latest_cpkt) with tf.device(device): - for epoch in range(1, 101): - with tfe.restore_variables_on_create(latest_cpkt): - global_step = tf.train.get_or_create_global_step() - start = time.time() - with summary_writer.as_default(): - train_one_epoch(generator, discriminator, generator_optimizer, - discriminator_optimizer, - dataset, FLAGS.log_interval, FLAGS.noise) - end = time.time() - print('\nTrain time for epoch #%d (global step %d): %f' % ( - epoch, global_step.numpy(), end - start)) - - all_variables = ( - generator.variables - + discriminator.variables - + generator_optimizer.variables() - + discriminator_optimizer.variables() - + [global_step]) - tfe.Saver(all_variables).save( - checkpoint_prefix, global_step=global_step) + for _ in range(100): + start = time.time() + with summary_writer.as_default(): + train_one_epoch(dataset=dataset, log_interval=FLAGS.log_interval, + noise_dim=FLAGS.noise, **model_objects) + end = time.time() + checkpoint.save(checkpoint_prefix) + print('\nTrain time for epoch #%d (step %d): %f' % + (checkpoint.save_counter.numpy(), + checkpoint.step_counter.numpy(), + end - start)) if __name__ == '__main__': - tfe.enable_eager_execution() + tf.enable_eager_execution() parser = argparse.ArgumentParser() parser.add_argument( diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py index 4a3ca8d82bc2619b05a734f6d2e58431c1a45995..81ac05e26d23c2fc53f63d64bb28bdea6072e396 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py @@ -62,7 +62,7 @@ class MnistEagerGanBenchmark(tf.test.Benchmark): for _ in range(measure_batches)] measure_dataset = tf.data.Dataset.from_tensor_slices(measure_images) - tf.train.get_or_create_global_step() + step_counter = tf.train.get_or_create_global_step() with tf.device(device()): # Create the models and optimizers generator = mnist.Generator(data_format()) @@ -78,13 +78,15 @@ class MnistEagerGanBenchmark(tf.test.Benchmark): # warm up mnist.train_one_epoch(generator, discriminator, generator_optimizer, discriminator_optimizer, - burn_dataset, log_interval=SUMMARY_INTERVAL, + burn_dataset, step_counter, + log_interval=SUMMARY_INTERVAL, noise_dim=NOISE_DIM) # measure start = time.time() mnist.train_one_epoch(generator, discriminator, generator_optimizer, discriminator_optimizer, - measure_dataset, log_interval=SUMMARY_INTERVAL, + measure_dataset, step_counter, + log_interval=SUMMARY_INTERVAL, noise_dim=NOISE_DIM) self._report('train', start, measure_batches, batch_size) @@ -109,5 +111,5 @@ class MnistEagerGanBenchmark(tf.test.Benchmark): if __name__ == '__main__': - tfe.enable_eager_execution() + tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/mnist/BUILD b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD similarity index 57% rename from tensorflow/contrib/eager/python/examples/mnist/BUILD rename to tensorflow/contrib/eager/python/examples/l2hmc/BUILD index c61ec2dbae60a782c0e6589701554b045dcb92ae..7bdf9053de749af9d09b12ba7b848e21c1fdb8f0 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/BUILD +++ b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD @@ -4,33 +4,36 @@ package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "cuda_py_test") -py_binary( - name = "mnist", - srcs = ["mnist.py"], +py_library( + name = "neural_nets", + srcs = ["neural_nets.py"], srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", - "//tensorflow/examples/tutorials/mnist:input_data", ], ) -cuda_py_test( - name = "mnist_test", - srcs = ["mnist_test.py"], - additional_deps = [ - ":mnist", - "//tensorflow/contrib/eager/python:tfe", +py_library( + name = "l2hmc", + srcs = ["l2hmc.py"], + srcs_version = "PY2AND3", + deps = [ + ":neural_nets", "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + "//third_party/py/numpy", ], ) cuda_py_test( - name = "mnist_graph_test", - srcs = ["mnist_graph_test.py"], + name = "l2hmc_test", + size = "large", + srcs = ["l2hmc_test.py"], additional_deps = [ - ":mnist", - "//third_party/py/numpy", + ":l2hmc", "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py new file mode 100644 index 0000000000000000000000000000000000000000..729d8525fab31ee214178ca1bcb18dbd069f767a --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py @@ -0,0 +1,326 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""L2HMC compatible with TensorFlow's eager execution. + +Reference [Generalizing Hamiltonian Monte Carlo with Neural +Networks](https://arxiv.org/pdf/1711.09268.pdf) + +Code adapted from the released TensorFlow graph implementation by original +authors https://github.com/brain-research/l2hmc. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import numpy.random as npr +import tensorflow as tf +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.l2hmc import neural_nets + + +class Dynamics(tf.keras.Model): + """Dynamics engine of naive L2HMC sampler. + + Args: + x_dim: dimensionality of observed data + loglikelihood_fn: log-likelihood function of conditional probability + n_steps: number of leapfrog steps within each transition + eps: initial value learnable scale of step size + """ + + def __init__(self, x_dim, loglikelihood_fn, n_steps=25, eps=.1): + super(Dynamics, self).__init__() + + self.x_dim = x_dim + self.potential = loglikelihood_fn + self.n_steps = n_steps + + self._construct_time() + self._construct_masks() + + self.position_fn = neural_nets.GenericNet(x_dim, factor=2.) + self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.) + + self.eps = tfe.Variable( + initial_value=eps, name="eps", dtype=tf.float32, trainable=True) + + def apply_transition(self, position): + """Propose a new state and perform the accept or reject step.""" + + # Simulate dynamics both forward and backward; + # Use sampled Bernoulli masks to compute the actual solutions + position_f, momentum_f, accept_prob_f = self.transition_kernel( + position, forward=True) + position_b, momentum_b, accept_prob_b = self.transition_kernel( + position, forward=False) + + # Decide direction uniformly + forward_mask = tf.cast( + tf.random_uniform(shape=[tf.shape(position)[0]]) > .5, tf.float32) + backward_mask = 1. - forward_mask + + # Obtain proposed states + position_post = ( + forward_mask[:, None] * position_f + + backward_mask[:, None] * position_b) + momentum_post = ( + forward_mask[:, None] * momentum_f + + backward_mask[:, None] * momentum_b) + + # Probability of accepting the proposed states + accept_prob = forward_mask * accept_prob_f + backward_mask * accept_prob_b + + # Accept or reject step + accept_mask = tf.cast( + accept_prob > tf.random_uniform(tf.shape(accept_prob)), tf.float32) + reject_mask = 1. - accept_mask + + # Samples after accept/reject step + position_out = ( + accept_mask[:, None] * position_post + reject_mask[:, None] * position) + + return position_post, momentum_post, accept_prob, position_out + + def transition_kernel(self, position, forward=True): + """Transition kernel of augmented leapfrog integrator.""" + + lf_fn = self._forward_lf if forward else self._backward_lf + + # Resample momentum + momentum = tf.random_normal(tf.shape(position)) + position_post, momentum_post = position, momentum + sumlogdet = 0. + # Apply augmented leapfrog steps + for i in range(self.n_steps): + position_post, momentum_post, logdet = lf_fn(position_post, momentum_post, + i) + sumlogdet += logdet + + accept_prob = self._compute_accept_prob(position, momentum, position_post, + momentum_post, sumlogdet) + + return position_post, momentum_post, accept_prob + + def _forward_lf(self, position, momentum, i): + """One forward augmented leapfrog step. See eq (5-6) in paper.""" + + t = self._get_time(i) + mask, mask_inv = self._get_mask(i) + sumlogdet = 0. + + momentum, logdet = self._update_momentum_forward(position, momentum, t) + sumlogdet += logdet + + position, logdet = self._update_position_forward(position, momentum, t, + mask) + sumlogdet += logdet + + position, logdet = self._update_position_forward(position, momentum, t, + mask_inv) + sumlogdet += logdet + + momentum, logdet = self._update_momentum_forward(position, momentum, t) + sumlogdet += logdet + + return position, momentum, tf.reduce_sum(sumlogdet, axis=1) + + def _backward_lf(self, position, momentum, i): + """One backward augmented leapfrog step. See Appendix A in paper.""" + + # Reversed index/sinusoidal time + t = self._get_time(self.n_steps - i - 1) + mask, mask_inv = self._get_mask(self.n_steps - i - 1) + sumlogdet = 0. + + momentum, logdet = self._update_momentum_backward(position, momentum, t) + sumlogdet += logdet + + position, logdet = self._update_position_backward(position, momentum, t, + mask) + sumlogdet += logdet + + position, logdet = self._update_position_backward(position, momentum, t, + mask_inv) + sumlogdet += logdet + + momentum, logdet = self._update_momentum_backward(position, momentum, t) + sumlogdet += logdet + + return position, momentum, tf.reduce_sum(sumlogdet, axis=1) + + def _update_momentum_forward(self, position, momentum, t): + """Update v in the forward leapfrog step.""" + + grad = self.grad_potential(position) + scale, translation, transformed = self.momentum_fn([position, grad, t]) + scale *= .5 * self.eps + transformed *= self.eps + momentum = ( + momentum * tf.exp(scale) - + .5 * self.eps * (tf.exp(transformed) * grad - translation)) + + return momentum, scale + + def _update_position_forward(self, position, momentum, t, mask): + """Update x in the forward leapfrog step.""" + + mask_inv = 1. - mask + scale, translation, transformed = self.position_fn( + [momentum, mask * position, t]) + scale *= self.eps + transformed *= self.eps + position = ( + mask * position + + mask_inv * (position * tf.exp(scale) + self.eps * + (tf.exp(transformed) * momentum + translation))) + + return position, mask_inv * scale + + def _update_momentum_backward(self, position, momentum, t): + """Update v in the backward leapfrog step. Inverting the forward update.""" + + grad = self.grad_potential(position) + scale, translation, transformed = self.momentum_fn([position, grad, t]) + scale *= -.5 * self.eps + transformed *= self.eps + momentum = ( + tf.exp(scale) * (momentum + .5 * self.eps * + (tf.exp(transformed) * grad - translation))) + + return momentum, scale + + def _update_position_backward(self, position, momentum, t, mask): + """Update x in the backward leapfrog step. Inverting the forward update.""" + + mask_inv = 1. - mask + scale, translation, transformed = self.position_fn( + [momentum, mask_inv * position, t]) + scale *= -self.eps + transformed *= self.eps + position = ( + mask_inv * position + mask * tf.exp(scale) * + (position - self.eps * tf.exp(transformed) * momentum + translation)) + + return position, mask * scale + + def _compute_accept_prob(self, position, momentum, position_post, + momentum_post, sumlogdet): + """Compute the prob of accepting the proposed state given old state.""" + + old_hamil = self.hamiltonian(position, momentum) + new_hamil = self.hamiltonian(position_post, momentum_post) + + return tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.)) + + def _construct_time(self): + """Convert leapfrog step index into sinusoidal time.""" + + self.ts = [] + for i in range(self.n_steps): + t = tf.constant( + [ + np.cos(2 * np.pi * i / self.n_steps), + np.sin(2 * np.pi * i / self.n_steps) + ], + dtype=tf.float32) + self.ts.append(t[None, :]) + + def _get_time(self, i): + """Get sinusoidal time for i-th augmented leapfrog step.""" + + return self.ts[i] + + def _construct_masks(self): + """Construct different binary masks for different time steps.""" + + self.masks = [] + for _ in range(self.n_steps): + idx = npr.permutation(np.arange(self.x_dim))[:self.x_dim // 2] + mask = np.zeros((self.x_dim,)) + mask[idx] = 1. + mask = tf.constant(mask, dtype=tf.float32) + self.masks.append(mask[None, :]) + + def _get_mask(self, i): + """Get binary masks for i-th augmented leapfrog step.""" + + m = self.masks[i] + return m, 1. - m + + def kinetic(self, v): + """Compute the kinetic energy.""" + + return .5 * tf.reduce_sum(v**2, axis=1) + + def hamiltonian(self, position, momentum): + """Compute the overall Hamiltonian.""" + + return self.potential(position) + self.kinetic(momentum) + + def grad_potential(self, position, check_numerics=True): + """Get gradient of potential function at current location.""" + + if not tf.executing_eagerly(): + # TODO(lxuechen): Change this to tfe.gradients_function when it works + grad = tf.gradients(self.potential(position), position)[0] + else: + grad = tfe.gradients_function(self.potential)(position)[0] + + if check_numerics: + return tf.check_numerics(grad, message="gradient of potential") + + return grad + + +# Examples of unnormalized log density/probabilities +def get_scg_energy_fn(): + """Get energy function for 2d strongly correlated Gaussian.""" + + # Avoid recreating tf constants on each invocation of gradients + mu = tf.constant([0., 0.]) + sigma = tf.constant([[50.05, -49.95], [-49.95, 50.05]]) + sigma_inv = tf.matrix_inverse(sigma) + + def energy(x): + """Unnormalized log density/energy of 2d strongly correlated Gaussian.""" + + xmmu = x - mu + return .5 * tf.diag_part( + tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu))) + + return energy + + +def get_multivariate_gaussian_energy_fn(x_dim=2): + """Get energy function for 2d strongly correlated Gaussian.""" + + mu = tf.random_normal(shape=[x_dim]) + # Lower triangularize and positive diagonal + l = tf.sigmoid( + tf.matrix_band_part(tf.random_normal(shape=[x_dim, x_dim]), -1, 0)) + # Exploit Cholesky decomposition + sigma = tf.matmul(l, tf.transpose(l)) + sigma *= 100. # Small covariance causes extreme numerical instability + sigma_inv = tf.matrix_inverse(sigma) + + def energy(x): + """Unnormalized log density/energy of 2d strongly correlated Gaussian.""" + + xmmu = x - mu + return .5 * tf.diag_part( + tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu))) + + return energy diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e33b4cae4c73388dfd78542c9907953f137ad710 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py @@ -0,0 +1,264 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests l2hmc fit to 2D strongly correlated Gaussian executed eagerly.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy.random as npr +import tensorflow as tf +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.l2hmc import l2hmc + + +def get_default_hparams(): + return tf.contrib.training.HParams( + x_dim=2, + n_samples=200, + n_steps=10, + eps=.1, + n_iters=10, + learning_rate=.0003, + n_warmup_iters=3) + + +# Relevant functions for benchmarking +def compute_loss(dynamics, x, scale=.1, eps=1e-4): + """Compute loss defined in equation (8).""" + + z = tf.random_normal(tf.shape(x)) + x_, _, x_accept_prob, x_out = dynamics.apply_transition(x) + z_, _, z_accept_prob, _ = dynamics.apply_transition(z) + + # Add eps for numerical stability; following released impl + x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps + z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps + + loss = tf.reduce_mean( + (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0) + + return loss, x_out + + +def loss_and_grads(dynamics, x, loss_fn=compute_loss): + """Obtain loss value and gradients.""" + + with tf.GradientTape() as tape: + loss_val, x_out = loss_fn(dynamics, x) + grads = tape.gradient(loss_val, dynamics.variables) + + return loss_val, grads, x_out + + +def warmup(dynamics, optimizer, n_iters=1, n_samples=200, loss_fn=compute_loss): + """Warmup optimization to reduce overhead.""" + + samples = tf.random_normal( + shape=[n_samples, dynamics.x_dim], dtype=tf.float32) + + for _ in range(n_iters): + _, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn) + optimizer.apply_gradients(zip(grads, dynamics.variables)) + + +def fit(dynamics, + samples, + optimizer, + loss_fn=compute_loss, + n_iters=5000, + verbose=True, + logdir=None, + decay_lr=True): + """Fit L2HMC sampler with given log-likelihood function.""" + + if logdir: + summary_writer = tf.contrib.summary.create_file_writer(logdir) + + for i in range(n_iters): + loss, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn) + # TODO(lxuechen): Proper learning rate decay + if decay_lr: + grads = [grad * .96**(i // 1000) for grad in grads] + optimizer.apply_gradients(zip(grads, dynamics.variables)) + if verbose: + print("Iteration %d: loss %.4f" % (i, loss)) + + if logdir: + with summary_writer.as_default(): + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("loss", loss) + + +class L2hmcTest(tf.test.TestCase): + """Unit tests for l2hmc in both eager and graph mode.""" + + def test_apply_transition(self): + """Testing function `Dynamics.apply_transition` in graph and eager mode.""" + + # Eager mode testing + hparams = get_default_hparams() + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=l2hmc.get_scg_energy_fn(), + n_steps=hparams.n_steps, + eps=hparams.eps) + samples = tf.random_normal(shape=[hparams.n_samples, hparams.x_dim]) + x_, v_, x_accept_prob, x_out = dynamics.apply_transition(samples) + + self.assertEqual(x_.shape, v_.shape) + self.assertEqual(x_out.shape, samples.shape) + self.assertEqual(x_.shape, x_out.shape) + self.assertEqual(x_accept_prob.shape, (hparams.n_samples,)) + + # Graph mode testing + with tf.Graph().as_default(): + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=l2hmc.get_scg_energy_fn(), + n_steps=hparams.n_steps, + eps=hparams.eps) + x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) + x_, v_, x_accept_prob, x_out = dynamics.apply_transition(x) + samples = npr.normal(size=[hparams.n_samples, hparams.x_dim]) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + np_x_, np_v_, np_x_accept_prob, np_x_out = sess.run( + [x_, v_, x_accept_prob, x_out], feed_dict={x: samples}) + + self.assertEqual(np_x_.shape, np_v_.shape) + self.assertEqual(samples.shape, np_x_out.shape) + self.assertEqual(np_x_.shape, np_x_out.shape) + self.assertEqual(np_x_accept_prob.shape, (hparams.n_samples,)) + + +class L2hmcBenchmark(tf.test.Benchmark): + """Eager and graph benchmarks for l2hmc.""" + + def _get_energy_fn(self): + """Get specific energy function according to FLAGS.""" + + if FLAGS.energy_fn == "scg": + energy_fn = l2hmc.get_scg_energy_fn() + elif FLAGS.energy_fn == "multivariate_gaussian": + energy_fn = l2hmc.get_multivariate_gaussian_energy_fn(x_dim=FLAGS.x_dim) + else: + raise ValueError("No such energy function %s" % FLAGS.energy_fn) + + return energy_fn + + def benchmark_graph(self): + """Benchmark Graph performance.""" + + hparams = get_default_hparams() + tf.reset_default_graph() + with tf.Graph().as_default(): + energy_fn = self._get_energy_fn() + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=energy_fn, + n_steps=hparams.n_steps, + eps=hparams.eps) + x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) + loss, x_out = compute_loss(dynamics, x) + + global_step = tf.Variable(0., name="global_step", trainable=False) + learning_rate = tf.train.exponential_decay( + hparams.learning_rate, global_step, 1000, 0.96, staircase=True) + optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) + train_op = optimizer.minimize(loss, global_step=global_step) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + + # Warmup to reduce initialization effect when timing + samples = npr.normal(size=[hparams.n_samples, hparams.x_dim]) + for _ in range(hparams.n_warmup_iters): + _, _, _, _ = sess.run( + [x_out, loss, train_op, learning_rate], feed_dict={x: samples}) + + # Training + start_time = time.time() + for i in range(hparams.n_iters): + samples, loss_np, _, _ = sess.run( + [x_out, loss, train_op, learning_rate], feed_dict={x: samples}) + print("Iteration %d: loss %.4f" % (i, loss_np)) + wall_time = time.time() - start_time + examples_per_sec = hparams.n_samples / wall_time + + self.report_benchmark( + name="graph_train_%s" % ("gpu" + if tf.test.is_gpu_available() else "cpu"), + iters=hparams.n_iters, + extras={"examples_per_sec": examples_per_sec}, + wall_time=wall_time) + + def benchmark_eager(self): + self._benchmark_eager() + + def benchmark_eager_defun(self): + self._benchmark_eager(defun=True) + + def _benchmark_eager(self, defun=False): + """Benchmark Eager performance.""" + + hparams = get_default_hparams() + energy_fn = self._get_energy_fn() + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=energy_fn, + n_steps=hparams.n_steps, + eps=hparams.eps) + optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate) + loss_fn = tfe.defun(compute_loss) if defun else compute_loss + + # Warmup to reduce initialization effect when timing + warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, loss_fn=loss_fn) + + # Training + samples = tf.random_normal( + shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32) + start_time = time.time() + fit(dynamics, + samples, + optimizer, + loss_fn=loss_fn, + n_iters=hparams.n_iters, + decay_lr=True) + wall_time = time.time() - start_time + examples_per_sec = hparams.n_samples / wall_time + + self.report_benchmark( + name="eager_train_%s%s" % ("gpu" if tf.test.is_gpu_available() else + "cpu", "_defun" if defun else ""), + iters=hparams.n_iters, + extras={"examples_per_sec": examples_per_sec}, + wall_time=wall_time) + + del dynamics + del loss_fn + + +if __name__ == "__main__": + tf.flags.DEFINE_string("energy_fn", "scg", + ("The energy function/unnormalized log-probability. " + "Either be `scg` or `multivariate_gaussian`")) + tf.flags.DEFINE_integer("x_dim", 2, "Dimensionality of observation space.") + FLAGS = tf.flags.FLAGS + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py new file mode 100644 index 0000000000000000000000000000000000000000..e230ad5e259df5b450897bd815e901e3934cd293 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py @@ -0,0 +1,84 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Neural nets utility for L2HMC compatible with TensorFlow's eager execution. + +Reference [Generalizing Hamiltonian Monte Carlo with Neural +Networks](https://arxiv.org/pdf/1711.09268.pdf) + +Code adapted from the released TensorFlow graph implementation by original +authors https://github.com/brain-research/l2hmc. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import tensorflow.contrib.eager as tfe + + +class GenericNet(tf.keras.Model): + """Generic neural net with different initialization scale based on input. + + Args: + x_dim: dimensionality of observed data + factor: factor of variance scaling initializer + n_hidden: number of hidden units + """ + + def __init__(self, x_dim, factor, n_hidden=10): + super(GenericNet, self).__init__() + + self.v_layer = _custom_dense(n_hidden, 1. / 3.) + self.x_layer = _custom_dense(n_hidden, factor / 3.) + self.t_layer = _custom_dense(n_hidden, 1. / 3.) + self.h_layer = _custom_dense(n_hidden) + + # Scale + self.scale_layer = _custom_dense(x_dim, .001) + self.coeff_scale = tfe.Variable( + initial_value=tf.zeros([1, x_dim]), name='coeff_scale', trainable=True) + # Translation + self.translation_layer = _custom_dense(x_dim, factor=.001) + # Transformation + self.transformation_layer = _custom_dense(x_dim, .001) + self.coeff_transformation = tfe.Variable( + initial_value=tf.zeros([1, x_dim]), + name='coeff_transformation', + trainable=True) + + def call(self, inputs): + v, x, t = inputs + h = self.v_layer(v) + self.x_layer(x) + self.t_layer(t) + h = tf.nn.relu(h) + h = self.h_layer(h) + h = tf.nn.relu(h) + scale = tf.nn.tanh(self.scale_layer(h)) * tf.exp(self.coeff_scale) + translation = self.translation_layer(h) + transformation = ( + tf.nn.tanh(self.transformation_layer(h)) * tf.exp( + self.coeff_transformation)) + + return scale, translation, transformation + + +def _custom_dense(units, factor=1.): + """Custom dense layer with specified weight initialization.""" + + return tf.keras.layers.Dense( + units=units, + use_bias=True, + kernel_initializer=tf.contrib.layers.variance_scaling_initializer( + factor=factor * 2., mode='FAN_IN', uniform=False), + bias_initializer=tf.constant_initializer(0., dtype=tf.float32)) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD index f86331af6f7928f0f86c888e22706c6e0a5978b2..2f6cfdf31e852d5d69a7a87980c9a441da504cf2 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD +++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD @@ -22,6 +22,7 @@ cuda_py_test( ":linear_regression", "//tensorflow:tensorflow_py", ], + tags = ["no_windows"], # TODO: needs investigation on Windows ) cuda_py_test( diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index 6ce4de6ee0bf50400eff339ac04e132252a2b53e..099b712fc06d1d3eb9ab4095f8db7283690bda76 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -32,24 +32,16 @@ import tensorflow as tf import tensorflow.contrib.eager as tfe +layers = tf.keras.layers -class LinearModel(tfe.Network): - """A TensorFlow linear regression model. - Uses TensorFlow's eager execution. - - For those familiar with TensorFlow graphs, notice the absence of - `tf.Session`. The `forward()` method here immediately executes and - returns output values. The `loss()` method immediately compares the - output of `forward()` with the target and returns the MSE loss value. - The `fit()` performs gradient-descent training on the model's weights - and bias. - """ +class LinearModel(tf.keras.Model): + """A TensorFlow linear regression model.""" def __init__(self): """Constructs a LinearModel object.""" super(LinearModel, self).__init__() - self._hidden_layer = self.track_layer(tf.layers.Dense(1)) + self._hidden_layer = layers.Dense(1) def call(self, xs): """Invoke the linear model. @@ -64,7 +56,7 @@ class LinearModel(tfe.Network): def mean_square_loss(model, xs, ys): - return tf.reduce_mean(tf.square(model(xs) - ys)) + return tf.reduce_mean(tf.square(tf.subtract(model(xs), ys))) def fit(model, dataset, optimizer, verbose=False, logdir=None): @@ -83,7 +75,6 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): mse = lambda xs, ys: mean_square_loss(model, xs, ys) loss_and_grads = tfe.implicit_value_and_gradients(mse) - tf.train.get_or_create_global_step() if logdir: # Support for TensorBoard summaries. Once training has started, use: # tensorboard --logdir= @@ -95,12 +86,13 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): if verbose: print("Iteration %d: loss = %s" % (i, loss.numpy())) - optimizer.apply_gradients(grads, global_step=tf.train.get_global_step()) + optimizer.apply_gradients(grads) if logdir: with summary_writer.as_default(): with tf.contrib.summary.always_record_summaries(): - tf.contrib.summary.scalar("loss", loss) + tf.contrib.summary.scalar("loss", loss, step=i) + tf.contrib.summary.scalar("step", i, step=i) def synthetic_dataset(w, b, noise_level, batch_size, num_batches): @@ -127,7 +119,7 @@ def synthetic_dataset_helper(w, b, num_features, noise_level, batch_size, def main(_): - tfe.enable_eager_execution() + tf.enable_eager_execution() # Ground-truth constants. true_w = [[-2.0], [4.0], [1.0]] true_b = [0.5] diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py index e53234b51a7dccc11e548ac81a7ef070c628aa52..2bc2fc2aa9150a3181db612439d0c37c8e76d1e3 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py @@ -117,5 +117,5 @@ class EagerLinearRegressionBenchmark(tf.test.Benchmark): if __name__ == "__main__": - tfe.enable_eager_execution() + tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/mnist/README.md b/tensorflow/contrib/eager/python/examples/mnist/README.md index e987996b88ccf54a322749aadec4f9840760a90f..d1c079ff6b5cb187bbcfe2742293982b1bedd2d4 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/README.md +++ b/tensorflow/contrib/eager/python/examples/mnist/README.md @@ -1,10 +1 @@ -Classification model for the MNIST dataset using eager execution. - -To run: - -``` -python mnist.py -``` - -`mnist_graph_test.py` demonstrates that the same code that is executed eagerly -in `mnist.py` is used to construct a TensorFlow graph. +See https://github.com/tensorflow/models/tree/master/official/mnist/mnist_eager.py diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist.py b/tensorflow/contrib/eager/python/examples/mnist/mnist.py deleted file mode 100644 index 58b1e89d15895cf38331e6f7bd5a311a2f5f6467..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""A deep MNIST classifier using convolutional layers. - -Sample usage: - python mnist.py --help -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import os -import sys -import time - -import tensorflow as tf - -import tensorflow.contrib.eager as tfe -from tensorflow.examples.tutorials.mnist import input_data - -FLAGS = None - - -class MNISTModel(tf.keras.Model): - """MNIST Network. - - Network structure is equivalent to: - https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/examples/tutorials/mnist/mnist_deep.py - and - https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py - - But written using the tf.layers API. - """ - - def __init__(self, data_format): - """Creates a model for classifying a hand-written digit. - - Args: - data_format: Either 'channels_first' or 'channels_last'. - 'channels_first' is typically faster on GPUs while 'channels_last' is - typically faster on CPUs. See - https://www.tensorflow.org/performance/performance_guide#data_formats - """ - super(MNISTModel, self).__init__(name='') - if data_format == 'channels_first': - self._input_shape = [-1, 1, 28, 28] - else: - assert data_format == 'channels_last' - self._input_shape = [-1, 28, 28, 1] - self.conv1 = tf.layers.Conv2D( - 32, 5, data_format=data_format, activation=tf.nn.relu) - self.conv2 = tf.layers.Conv2D( - 64, 5, data_format=data_format, activation=tf.nn.relu) - self.fc1 = tf.layers.Dense(1024, activation=tf.nn.relu) - self.fc2 = tf.layers.Dense(10) - self.dropout = tf.layers.Dropout(0.5) - self.max_pool2d = tf.layers.MaxPooling2D( - (2, 2), (2, 2), padding='SAME', data_format=data_format) - - def call(self, inputs, training=False): - """Computes labels from inputs. - - Users should invoke __call__ to run the network, which delegates to this - method (and not call this method directly). - - Args: - inputs: A batch of images as a Tensor with shape [batch_size, 784]. - training: True if invoked in the context of training (causing dropout to - be applied). False otherwise. - - Returns: - A Tensor with shape [batch_size, 10] containing the predicted logits - for each image in the batch, for each of the 10 classes. - """ - - x = tf.reshape(inputs, self._input_shape) - x = self.conv1(x) - x = self.max_pool2d(x) - x = self.conv2(x) - x = self.max_pool2d(x) - x = tf.layers.flatten(x) - x = self.fc1(x) - x = self.dropout(x, training=training) - x = self.fc2(x) - return x - - -def loss(predictions, labels): - return tf.reduce_mean( - tf.nn.softmax_cross_entropy_with_logits( - logits=predictions, labels=labels)) - - -def compute_accuracy(predictions, labels): - return tf.reduce_sum( - tf.cast( - tf.equal( - tf.argmax(predictions, axis=1, - output_type=tf.int64), - tf.argmax(labels, axis=1, - output_type=tf.int64)), - dtype=tf.float32)) / float(predictions.shape[0].value) - - -def train_one_epoch(model, optimizer, dataset, log_interval=None): - """Trains model on `dataset` using `optimizer`.""" - - tf.train.get_or_create_global_step() - - for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)): - with tf.contrib.summary.record_summaries_every_n_global_steps(10): - with tfe.GradientTape() as tape: - prediction = model(images, training=True) - loss_value = loss(prediction, labels) - tf.contrib.summary.scalar('loss', loss_value) - tf.contrib.summary.scalar('accuracy', - compute_accuracy(prediction, labels)) - grads = tape.gradient(loss_value, model.variables) - optimizer.apply_gradients(zip(grads, model.variables)) - if log_interval and batch % log_interval == 0: - print('Batch #%d\tLoss: %.6f' % (batch, loss_value)) - - -def test(model, dataset): - """Perform an evaluation of `model` on the examples from `dataset`.""" - avg_loss = tfe.metrics.Mean('loss') - accuracy = tfe.metrics.Accuracy('accuracy') - - for (images, labels) in tfe.Iterator(dataset): - predictions = model(images, training=False) - avg_loss(loss(predictions, labels)) - accuracy(tf.argmax(predictions, axis=1, output_type=tf.int64), - tf.argmax(labels, axis=1, output_type=tf.int64)) - print('Test set: Average loss: %.4f, Accuracy: %4f%%\n' % - (avg_loss.result(), 100 * accuracy.result())) - with tf.contrib.summary.always_record_summaries(): - tf.contrib.summary.scalar('loss', avg_loss.result()) - tf.contrib.summary.scalar('accuracy', accuracy.result()) - - -def load_data(data_dir): - """Returns training and test tf.data.Dataset objects.""" - data = input_data.read_data_sets(data_dir, one_hot=True) - train_ds = tf.data.Dataset.from_tensor_slices((data.train.images, - data.train.labels)) - test_ds = tf.data.Dataset.from_tensors((data.test.images, data.test.labels)) - return (train_ds, test_ds) - - -def main(_): - tfe.enable_eager_execution() - - (device, data_format) = ('/gpu:0', 'channels_first') - if FLAGS.no_gpu or tfe.num_gpus() <= 0: - (device, data_format) = ('/cpu:0', 'channels_last') - print('Using device %s, and data format %s.' % (device, data_format)) - - # Load the datasets - (train_ds, test_ds) = load_data(FLAGS.data_dir) - train_ds = train_ds.shuffle(60000).batch(FLAGS.batch_size) - - # Create the model and optimizer - model = MNISTModel(data_format) - optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum) - - if FLAGS.output_dir: - train_dir = os.path.join(FLAGS.output_dir, 'train') - test_dir = os.path.join(FLAGS.output_dir, 'eval') - tf.gfile.MakeDirs(FLAGS.output_dir) - else: - train_dir = None - test_dir = None - summary_writer = tf.contrib.summary.create_file_writer( - train_dir, flush_millis=10000) - test_summary_writer = tf.contrib.summary.create_file_writer( - test_dir, flush_millis=10000, name='test') - checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt') - - with tf.device(device): - for epoch in range(1, 11): - with tfe.restore_variables_on_create( - tf.train.latest_checkpoint(FLAGS.checkpoint_dir)): - global_step = tf.train.get_or_create_global_step() - start = time.time() - with summary_writer.as_default(): - train_one_epoch(model, optimizer, train_ds, FLAGS.log_interval) - end = time.time() - print('\nTrain time for epoch #%d (global step %d): %f' % ( - epoch, global_step.numpy(), end - start)) - with test_summary_writer.as_default(): - test(model, test_ds) - all_variables = ( - model.variables - + optimizer.variables() - + [global_step]) - tfe.Saver(all_variables).save( - checkpoint_prefix, global_step=global_step) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument( - '--data-dir', - type=str, - default='/tmp/tensorflow/mnist/input_data', - help='Directory for storing input data') - parser.add_argument( - '--batch-size', - type=int, - default=64, - metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument( - '--log-interval', - type=int, - default=10, - metavar='N', - help='how many batches to wait before logging training status') - parser.add_argument( - '--output_dir', - type=str, - default=None, - metavar='N', - help='Directory to write TensorBoard summaries') - parser.add_argument( - '--checkpoint_dir', - type=str, - default='/tmp/tensorflow/mnist/checkpoints/', - metavar='N', - help='Directory to save checkpoints in (once per epoch)') - parser.add_argument( - '--lr', - type=float, - default=0.01, - metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument( - '--momentum', - type=float, - default=0.5, - metavar='M', - help='SGD momentum (default: 0.5)') - parser.add_argument( - '--no-gpu', - action='store_true', - default=False, - help='disables GPU usage even if a GPU is available') - - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py b/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py deleted file mode 100644 index 1af26553120b34d4682b17b1c29c81dc65e421d4..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf -from tensorflow.contrib.eager.python.examples.mnist import mnist - - -def data_format(): - return "channels_first" if tf.test.is_gpu_available() else "channels_last" - - -class MNISTGraphTest(tf.test.TestCase): - - def testTrainGraph(self): - # The MNISTModel class can be executed eagerly (as in mnist.py and - # mnist_test.py) and also be used to construct a TensorFlow graph, which is - # then trained in a session. - with tf.Graph().as_default(): - # Generate some random data. - batch_size = 64 - images = np.random.randn(batch_size, 784).astype(np.float32) - digits = np.random.randint(low=0, high=10, size=batch_size) - labels = np.zeros((batch_size, 10)) - labels[np.arange(batch_size), digits] = 1. - - # Create a model, optimizer, and dataset as would be done - # for eager execution as well. - model = mnist.MNISTModel(data_format()) - optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) - dataset = tf.data.Dataset.from_tensors((images, labels)) - - # Define the loss tensor (as opposed to a loss function when - # using eager execution). - (images, labels) = dataset.make_one_shot_iterator().get_next() - predictions = model(images, training=True) - loss = mnist.loss(predictions, labels) - - train_op = optimizer.minimize(loss) - init = tf.global_variables_initializer() - with tf.Session() as sess: - # Variables have to be initialized in the session. - sess.run(init) - # Train using the optimizer. - sess.run(train_op) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py deleted file mode 100644 index 136085eba21284a42282395e54f32c33bf63b5c3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - -import tensorflow.contrib.eager as tfe -from tensorflow.contrib.eager.python.examples.mnist import mnist - - -def device(): - return "/device:GPU:0" if tfe.num_gpus() else "/device:CPU:0" - - -def data_format(): - return "channels_first" if tfe.num_gpus() else "channels_last" - - -def random_dataset(): - batch_size = 64 - images = tf.random_normal([batch_size, 784]) - digits = tf.random_uniform([batch_size], minval=0, maxval=10, dtype=tf.int32) - labels = tf.one_hot(digits, 10) - return tf.data.Dataset.from_tensors((images, labels)) - - -def train_one_epoch(defun=False): - model = mnist.MNISTModel(data_format()) - if defun: - model.call = tfe.defun(model.call) - optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) - dataset = random_dataset() - with tf.device(device()): - tf.train.get_or_create_global_step() - mnist.train_one_epoch(model, optimizer, dataset) - - -def evaluate(defun=False): - model = mnist.MNISTModel(data_format()) - dataset = random_dataset() - if defun: - model.call = tfe.defun(model.call) - with tf.device(device()): - tf.train.get_or_create_global_step() - mnist.test(model, dataset) - - -class MNISTTest(tf.test.TestCase): - - def testTrainOneEpoch(self): - train_one_epoch(defun=False) - - def testTest(self): - evaluate(defun=False) - - def testTrainOneEpochWithDefunCall(self): - train_one_epoch(defun=True) - - def testTestWithDefunCall(self): - evaluate(defun=True) - - -if __name__ == "__main__": - tfe.enable_eager_execution() - tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb index 459f2f4a7d2afa153e77069bc3ce0c5360ddd7e2..51d10a778413cfbb574b4e22e8adcb18bd731dee 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb @@ -7,17 +7,13 @@ "id": "U9i2Dsh-ziXr" }, "source": [ - "# Eager Execution Tutorial: Basics\n", + "# An introduction to TensorFlow\n", "\n", - "This notebook introduces the basics of using TensorFlow's eager execution capabilities. It covers concepts such as:\n", + "This is an introductory tutorial for using TensorFlow. It will cover:\n", "\n", "* Importing required packages\n", - "* Enabling eager execution\n", - "* Creating and using TensorFlow Tensors and Variables\n", - "* Using TensorFlow interactively\n", - "* Using GPUs with eager execution enabled\n", - "\n", - "This notebook does *not* cover modeling topics, such as gradients." + "* Creating and using Tensors\n", + "* Using GPU acceleration\n" ] }, { @@ -27,9 +23,10 @@ "id": "z1JcS5iBXMRO" }, "source": [ - "# Step 1: Import Eager\n", + "## Import TensorFlow\n", "\n", - "The key imports for eager execution are the following:" + "To get started, import the `tensorflow` module and enable eager execution.\n", + "Eager execution enables a more interactive frontend to TensorFlow, the details of which we will discuss much later." ] }, { @@ -48,11 +45,9 @@ }, "outputs": [], "source": [ - "# Import TensorFlow.\n", "import tensorflow as tf\n", "\n", - "# Import TensorFlow eager execution support (subject to future changes).\n", - "import tensorflow.contrib.eager as tfe" + "tf.enable_eager_execution()" ] }, { @@ -62,10 +57,9 @@ "id": "H9UySOPLXdaw" }, "source": [ - "# Step 2: Enable eager execution\n", + "## Tensors\n", "\n", - "All future TensorFlow calls will execute the\n", - "underlying TensorFlow ops immediately:" + "A Tensor is a multi-dimensional array. Similar to NumPy `ndarray` objects, `Tensor` objects have a data type and a shape. Additionally, Tensors can reside in accelerator (like GPU) memory. TensorFlow offers a rich library of operations ([tf.add](https://www.tensorflow.org/api_docs/python/tf/add), [tf.matmul](https://www.tensorflow.org/api_docs/python/tf/matmul), [tf.linalg.inv](https://www.tensorflow.org/api_docs/python/tf/linalg/inv) etc.) that consume and produce Tensors. These operations automatically convert native Python types. For example:\n" ] }, { @@ -77,60 +71,47 @@ "autoexec": { "startup": false, "wait_interval": 0 - } + }, + "height": 125 }, "colab_type": "code", - "id": "WPTUfGq6kJ5w" - }, - "outputs": [], - "source": [ - "tfe.enable_eager_execution()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "twBfWd5xyu_d" - }, - "source": [ - "# Step 3: Interactively Use TensorFlow!\n", - "\n", - "Now you can call TensorFlow functions and get results, immediately! No more `tf.Sessions`!\n", - "\n", - "TensorFlow will automatically wrap native Python types for you with operator overloading for TensorFlow Tensors." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } + "executionInfo": { + "elapsed": 320, + "status": "ok", + "timestamp": 1526420535530, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 }, - "colab_type": "code", - "id": "ngUe237Wt48W" + "id": "ngUe237Wt48W", + "outputId": "b1a1cd60-4eb3-443d-cd6b-68406390784e" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(3, shape=(), dtype=int32)\n", + "tf.Tensor([4 6], shape=(2,), dtype=int32)\n", + "tf.Tensor(25, shape=(), dtype=int32)\n", + "tf.Tensor(6, shape=(), dtype=int32)\n", + "tf.Tensor(aGVsbG8gd29ybGQ, shape=(), dtype=string)\n", + "tf.Tensor(13, shape=(), dtype=int32)\n" + ] + } + ], "source": [ "print(tf.add(1, 2))\n", "print(tf.add([1, 2], [3, 4]))\n", "print(tf.square(5))\n", "print(tf.reduce_sum([1, 2, 3]))\n", "print(tf.encode_base64(\"hello world\"))\n", - "print(\"\")\n", - "\n", - "x = tf.constant(2)\n", - "y = tf.constant(3)\n", - "print(x * y + 1)\n", "\n", - "# Most TensorFlow ops are directly usable with eager execution, giving\n", - "# results immediately.\n", - "print(tf.contrib.signal.hamming_window(x * y + 1))" + "# Operator overloading is also supported\n", + "print(tf.square(2) + tf.square(3))" ] }, { @@ -140,7 +121,7 @@ "id": "IDY4WsYRhP81" }, "source": [ - "Numpy arrays are supported, too:" + "Each Tensor has a shape and a datatype" ] }, { @@ -151,178 +132,144 @@ "autoexec": { "startup": false, "wait_interval": 0 - } + }, + "height": 53 }, "colab_type": "code", - "id": "lCUWzso6mbqR" + "executionInfo": { + "elapsed": 215, + "status": "ok", + "timestamp": 1526420538162, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "srYWH1MdJNG7", + "outputId": "5e4ac41c-5115-4e50-eba0-42e249c16561" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1, 2)\n", + "\u003cdtype: 'int32'\u003e\n" + ] + } + ], "source": [ - "import numpy as np\n", - "\n", - "ones = np.ones([3, 3])\n", - "\n", - "print(\"numpy 3x3 matrix of 1s:\")\n", - "print(ones)\n", - "print(\"\")\n", - "\n", - "print(\"Multiplied by 42:\")\n", - "print(tf.multiply(ones, 42))" + "x = tf.matmul([[1]], [[2, 3]])\n", + "print(x.shape)\n", + "print(x.dtype)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "PBNP8yTRfu_X" + "id": "eBPw8e8vrsom" }, "source": [ - "# Step 4: Define and Print TensorFlow Variables\n", + "The most obvious differences between NumPy arrays and TensorFlow Tensors are:\n", "\n", - "To define TensorFlow variables, use the `get_variable()` function as follows:" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "3Twf_Rw-gQFM" - }, - "outputs": [], - "source": [ - "x = tf.get_variable(name=\"x\", shape=[], dtype=tf.float32, initializer=tf.zeros_initializer)" + "1. Tensors can be backed by accelerator memory (like GPU, TPU).\n", + "2. Tensors are immutable." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "45G7094TxsMb" + "id": "Dwi1tdW3JBw6" }, "source": [ - "## Printing TensorFlow Variables" + "### NumPy Compatibility\n", + "\n", + "Conversion between TensorFlow Tensors and NumPy ndarrays is quite simple as:\n", + "* TensorFlow operations automatically convert NumPy ndarrays to Tensors.\n", + "* NumPy operations automatically convert Tensors to NumPy ndarrays.\n", + "\n", + "Tensors can be explicitly converted to NumPy ndarrays by invoking the `.numpy()` method on them.\n", + "These conversions are typically cheap as the array and Tensor share the underlying memory representation if possible. However, sharing the underlying representation isn't always possible since the Tensor may be hosted in GPU memory while NumPy arrays are always backed by host memory, and the conversion will thus involve a copy from GPU to host memory." ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - } + }, + "height": 251 }, "colab_type": "code", - "id": "UJBJeZ5XxuwA" + "executionInfo": { + "elapsed": 238, + "status": "ok", + "timestamp": 1526420540562, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "lCUWzso6mbqR", + "outputId": "fd0a22bc-8249-49dd-fcbd-63161cc47e46" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TensorFlow operations convert numpy arrays to Tensors automatically\n", + "tf.Tensor(\n", + "[[ 42. 42. 42.]\n", + " [ 42. 42. 42.]\n", + " [ 42. 42. 42.]], shape=(3, 3), dtype=float64)\n", + "And NumPy operations convert Tensors to numpy arrays automatically\n", + "[[ 43. 43. 43.]\n", + " [ 43. 43. 43.]\n", + " [ 43. 43. 43.]]\n", + "The .numpy() method explicitly converts a Tensor to a numpy array\n", + "[[ 42. 42. 42.]\n", + " [ 42. 42. 42.]\n", + " [ 42. 42. 42.]]\n" + ] + } + ], "source": [ - "# This does NOT print the Variable's actual value:\n", - "print(\"Printing a TensorFlow Variable:\")\n", - "print(x)\n", - "print(\"\")\n", + "import numpy as np\n", "\n", - "# A TensorFlow variable represents a reference to a tensor.\n", - "# The `read_value()` method provides access to the current value of the\n", - "# variable. Tensorflow Variables are automatically initialized according to the\n", - "# semantics defined in tf.get_variable().\n", - "print(\"Printing a TensorFlow Variable's value using .read_value():\")\n", - "print(x.read_value())\n", - "print(\"\")\n", + "ndarray = np.ones([3, 3])\n", "\n", - "print(\"Printing a TensorFlow Variable's value using .read_value().numpy():\")\n", - "print(x.read_value().numpy())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "2njjWHcTpBEn" - }, - "source": [ - "## Changing a TensorFlow Variable's value\n", + "print(\"TensorFlow operations convert numpy arrays to Tensors automatically\")\n", + "tensor = tf.multiply(ndarray, 42)\n", + "print(tensor)\n", "\n", - "To change a TensorFlow Variable's value, use its `.assign()` or `.assign_add()` method:" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "v3wr6Erbo_hB" - }, - "outputs": [], - "source": [ - "x.assign(42)\n", - "print(x.read_value())\n", "\n", - "x.assign_add(3)\n", - "print(x.read_value())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "uhtynjHVpTB5" - }, - "source": [ - "## Use a Variable just like any other Tensor" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "7PbktdnHoehR" - }, - "outputs": [], - "source": [ - "print(x + 3)\n", + "print(\"And NumPy operations convert Tensors to numpy arrays automatically\")\n", + "print(np.add(tensor, 1))\n", "\n", - "# This code will broadcast the value across the list of numbers:\n", - "print(x * [1, 2, 4])" + "print(\"The .numpy() method explicitly converts a Tensor to a numpy array\")\n", + "print(tensor.numpy())" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "GVChqwlwy1SI" + "id": "PBNP8yTRfu_X" }, "source": [ - "# Step 5: Debug Errors with Instant Feedback\n", + "## GPU acceleration\n", "\n", - "TensorFlow's eager execution helps you identify and debug runtime issues through interactive exploration of code snippets.\n", - "\n", - "Below, we'll define a length-4 vector, and attempt two `tf.slice()` operations,\n", - "one being legal and the other being illegal, leading to a runtime error that is\n", - "raised immediately." + "Many TensorFlow operations can be accelerated by using the GPU for computation. Without any annotations, TensorFlow automatically decides whether to use the GPU or CPU for an operation (and copies the tensor between CPU and GPU memory if necessary). Tensors produced by an operation are typically backed by the memory of the device on which the operation executed. For example:" ] }, { @@ -334,125 +281,68 @@ "autoexec": { "startup": false, "wait_interval": 0 - } + }, + "height": 53 }, "colab_type": "code", - "id": "23ap04N0v4k0" - }, - "outputs": [], - "source": [ - "vector = tf.constant([10.0, 20.0, 30.0, 40.0])" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "FCUMsIYxxRRa" - }, - "outputs": [], - "source": [ - "# Works, because the values of `begin` and `size` (the 2nd and 3rd input\n", - "# arguments) are within the bound of `vector`.\n", - "print(tf.slice(vector, [1], [3]))" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } + "executionInfo": { + "elapsed": 340, + "status": "ok", + "timestamp": 1526420543562, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 }, - "colab_type": "code", - "id": "T8me2oCNxpFp" + "id": "3Twf_Rw-gQFM", + "outputId": "2239ae2b-adf3-4895-b1f3-464cf5361d1b" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Is there a GPU available: False\n", + "Is the Tensor on GPU #0: False\n" + ] + } + ], "source": [ - "# The following does NOT work, because the value of `size` (the 3rd\n", - "# argument) causes the indices to go out of the bounds of `vector`. The\n", - "# error is raised immediately.\n", - "try:\n", - " print(tf.slice(vector, [1], [4]))\n", - "except tf.OpError as e:\n", - " print(\"Caught error: %s\" % e)" + "x = tf.random_uniform([3, 3])\n", + "\n", + "print(\"Is there a GPU available: \"),\n", + "print(tf.test.is_gpu_available())\n", + "\n", + "print(\"Is the Tensor on GPU #0: \"),\n", + "print(x.device.endswith('GPU:0'))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "irxJhAgar84v" + "id": "vpgYzgVXW2Ud" }, "source": [ - "# Step 6: Using the GPU\n", + "### Device Names\n", "\n", - "You can place Tensors on the GPU by calling a Tensor's `.gpu()` method.\n", - "\n", - "The first operation executing on the GPU may be slow as TensorFlow initializes. Subsequent uses will be much faster." + "The `Tensor.device` property provides a fully qualified string name of the device hosting the contents of the Tensor. This name encodes a bunch of details, such as an identifier of the network address of the host on which this program is executing and the device within that host. This is required for distributed execution of TensorFlow programs, but we'll skip that for now. The string will end with `GPU:\u003cN\u003e` if the tensor is placed on the `N`-th tensor on the host." ] }, { - "cell_type": "code", - "execution_count": 0, + "cell_type": "markdown", "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "7J4N9baqaKCL" + "colab_type": "text", + "id": "ZWZQCimzuqyP" }, - "outputs": [], "source": [ - "# The example code from here on will work only if your notebook\n", - "# is running on a machine with a functional CUDA GPU. The following\n", - "# line checks that.\n", - "is_gpu_available = tfe.num_gpus() \u003e 0\n", "\n", - "# Create some Tensors\n", - "SIZE = 1000\n", - "cpu_tensor = tf.random_normal([SIZE, SIZE])\n", "\n", - "if is_gpu_available:\n", - " gpu_tensor = cpu_tensor.gpu()\n", - "else:\n", - " print(\"GPU not available.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "4E-2n7VbzY1n" - }, - "outputs": [], - "source": [ - "# Time a CPU-based matrix multiplication\n", + "### Explicit Device Placement\n", "\n", - "print(\"Time to conduct matmul on CPU:\")\n", - "%time tf.matmul(cpu_tensor, cpu_tensor)" + "The term \"placement\" in TensorFlow refers to how individual operations are assigned (placed on) a device for execution. As mentioned above, when there is no explicit guidance provided, TensorFlow automatically decides which device to execute an operation, and copies Tensors to that device if needed. However, TensorFlow operations can be explicitly placed on specific devices using the `tf.device` context manager. For example:" ] }, { @@ -463,65 +353,73 @@ "autoexec": { "startup": false, "wait_interval": 0 - } + }, + "height": 53 }, "colab_type": "code", - "id": "vbSFW-T5zhZF" + "executionInfo": { + "elapsed": 1762, + "status": "ok", + "timestamp": 1526420547562, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "RjkNZTuauy-Q", + "outputId": "2e613293-ccac-4db2-b793-8ceb5b5adcfd" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "On CPU:\n", + "10 loops, best of 3: 35.8 ms per loop\n" + ] + } + ], "source": [ - "# Time GPU-based matrix multiplications.\n", + "def time_matmul(x):\n", + " %timeit tf.matmul(x, x)\n", "\n", - "if is_gpu_available:\n", - " # First use of the GPU will be slow:\n", - " print(\"Time to conduct first matmul on GPU:\")\n", - " %time tf.matmul(gpu_tensor, gpu_tensor)\n", - " print()\n", + "# Force execution on CPU\n", + "print(\"On CPU:\")\n", + "with tf.device(\"CPU:0\"):\n", + " x = tf.random_uniform([1000, 1000])\n", + " assert x.device.endswith(\"CPU:0\")\n", + " time_matmul(x)\n", "\n", - " # Subsequent uses are much faster:\n", - " print(\"Time to conduct second matmul on GPU:\")\n", - " %time tf.matmul(gpu_tensor, gpu_tensor)" + "# Force execution on GPU #0 if available\n", + "if tf.test.is_gpu_available():\n", + " with tf.device(\"GPU:0\"): # Or GPU:1 for the 2nd GPU, GPU:2 for the 3rd etc.\n", + " x = tf.random_uniform([1000, 1000])\n", + " assert x.device.endswith(\"GPU:0\")\n", + " time_matmul(x)" ] }, { - "cell_type": "code", - "execution_count": 0, + "cell_type": "markdown", "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "E5pIOe3Rz7iW" + "colab_type": "text", + "id": "YEOJTNiOvnpQ" }, - "outputs": [], "source": [ - "# Second timing demo for GPUs, after it has been used once:\n", - "\n", - "cpu_tensor = tf.random_normal([SIZE, SIZE])\n", - "print(\"Time to conduct CPU matmul:\")\n", - "%time tf.matmul(cpu_tensor, cpu_tensor)\n", - "print()\n", + "## Next Steps\n", "\n", - "if is_gpu_available:\n", - " gpu_tensor = cpu_tensor.gpu()\n", - " print(\"Time to conduct GPU matmul:\")\n", - " %time tf.matmul(gpu_tensor, gpu_tensor)" + "In this tutorial we covered the most fundamental concepts in TensorFlow - `Tensor`s, operations, and devices.\n", + "In [the next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/2_gradients.ipynb) we will cover automatic differentiation - a building block required for training many machine learning models like neural networks." ] } ], "metadata": { "colab": { + "collapsed_sections": [], "default_view": {}, - "name": "Eager Execution Tutorial: Basics", - "provenance": [ - { - "file_id": "0B0kLcpwLFwKEVm9XNkFueGk4bTg", - "timestamp": 1504118841551 - } - ], + "name": "TensorFlow: An introduction", + "provenance": [], "version": "0.3.2", "views": {} } diff --git a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb index e6c7c117333e1e10aa571dae295e88747bd7d764..9c1af9c2084bac7ae6369babeaa13720e6199097 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb @@ -7,12 +7,9 @@ "id": "vDJ4XzMqodTy" }, "source": [ - "# Eager Execution: Working with Gradients\n", + "# Automatic Differentiation\n", "\n", - "This notebook demonstrates:\n", - "\n", - "* How to get gradients using TensorFlow's eager execution capabilities\n", - "* How to apply the gradients so you can update your variables" + "In the previous tutorial we introduced `Tensor`s and operations on them. In this tutorial we will cover [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), a key technique for optimizing machine learning models." ] }, { @@ -22,7 +19,7 @@ "id": "GQJysDM__Qb0" }, "source": [ - "# Setup: Import eager and enable eager execution.\n" + "## Setup\n" ] }, { @@ -40,14 +37,10 @@ }, "outputs": [], "source": [ - "# Import TensorFlow.\n", "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", "\n", - "# Import TensorFlow eager execution support (subject to future changes).\n", - "import tensorflow.contrib.eager as tfe\n", - "\n", - "# Enable eager execution.\n", - "tfe.enable_eager_execution()" + "tfe = tf.contrib.eager # Shorthand for some symbols" ] }, { @@ -57,28 +50,15 @@ "id": "1CLWJl0QliB0" }, "source": [ - "# Fitting a Simple Linear Model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "-39gouo7mtgu" - }, - "source": [ - "## Step 1: Synthesize some data\n", + "## Derivatives of a function\n", "\n", - "To demonstrate fitting a model with TensorFlow's eager execution, we'll fit a linear model to some synthesized data (which includes some noise).\n", - "\n", - "In the code, we use the variable names `w` and `b` to represent the single weight and bias we'll use to fit our model." + "TensorFlow provides APIs for automatic differentiation - computing the derivative of a function. The way that more closely mimics the math is to encapsulate the computation in a Python function, say `f`, and use `tfe.gradients_function` to create a function that computes the derivatives of `f` with respect to its arguments. If you're familiar with [autograd](https://github.com/HIPS/autograd) for differentiating numpy functions, this will be familiar. For example: " ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, @@ -86,170 +66,115 @@ } }, "colab_type": "code", - "id": "rQsdCg9PfIL-" + "id": "9FViq92UX7P8" }, "outputs": [], "source": [ - "# The constants we'll try to fit our variables to:\n", - "true_w = 3\n", - "true_b = 2\n", + "from math import pi\n", "\n", - "NUM_EXAMPLES = 1000\n", + "def f(x):\n", + " return tf.square(tf.sin(x))\n", "\n", - "# Our inputs:\n", - "inputs = tf.random_normal(shape=[NUM_EXAMPLES, 1])\n", + "assert f(pi/2).numpy() == 1.0\n", "\n", - "# Our labels, with noise:\n", - "noise = tf.random_normal(shape=[NUM_EXAMPLES, 1])\n", - "labels = inputs * true_w + true_b + noise" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 360, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 127, - "status": "ok", - "timestamp": 1505502830690, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "O4lsC4ckAcar", - "outputId": "2f760690-cafb-4777-b970-91d839f99faf" - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAesAAAFXCAYAAACC+2avAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXt8VPWd99+TK7kykxtJQIebqZfaqogtrhKNa1ooEKl9\nCrpVn9ZNW6x9VWsbCi7aVUt01NZ9tq21KVZlFey2YkQNohhj3QWK2liCF5RIBCc3yEwmIZnMTOY8\nf/zmzJwzSSBAYibh+369eIU5c87vXLh8zvdu0TRNQxAEQRCEmCVurC9AEARBEISjI2ItCIIgCDGO\niLUgCIIgxDgi1oIgCIIQ44hYC4IgCEKMI2ItCIIgCDHOiIj16tWrufjii1m8eHF4269//Wvmz5/P\n0qVLWbp0Ka+//vpInEoQBEEQTjksI1Fn/eabb5KWlkZFRQWbN28GlFinpaXx7W9/+6QvUhAEQRBO\nZUbEsr7wwgvJzMwcsF36rQiCIAjCyTOqMesnn3ySsrIybr/9drq6ukbzVIIgCIIwYRk1sb722mt5\n5ZVXqK6uJicnh8rKytE6lSAIgiBMaEZNrLOysrBYLAB885vfZPfu3cc8RtzmgiAIgjCQhJFaKFpo\n29vbyc3NBeDll1+mqKjomGtYLBba2yeuuzw3N0Pubxwzke9vIt8byP2Nd06F+zsWIyLWt912Gzt3\n7sTtdnPZZZfxwx/+kJ07d/Lee+8RFxfH1KlTueuuu0biVIIgCIJwyjEiYv3ggw8O2Hb11VePxNKC\nIAiCcMojHcwEQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfE\nWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYR\nsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGCdhrC9AEARBOHXo6HCzcmUt\nTU2Z2O2dOBwl2GzWsb6smEfEWhAEQfjMWLmylurq6wAL9fUasJ6qqqVjfVkxj7jBBUEQhM+MpqZM\nwBL6ZAl9Fo6FiLUgCILwmWG3dwJa6JOG3e4Zy8sZN4gbXBAEQfjMcDhKgPWhmLUHh+Pysb6kcYGI\ntSAIgvCZYbNZJUZ9AogbXBAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFr\nQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfE\nWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYR\nsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhx\nRKwFQRAEIcYRsRYEQRCEGCdhrC9AEARBODE6OtysXFmL02mjsLADh6MEm806rGOamjKx2zuHdYww\n9oyIWK9evZrXXnuN7OxsNm/eDEBnZye33norn376KdOmTeOhhx4iIyNjJE4nCIIgACtX1lJdfR1g\nATRgPVVVS037RIuzz9dDTc33AQv19YMfI8QeI+IG//rXv866detM237/+98zb948XnrpJb70pS/x\nyCOPjMSpBEEQhBBNTZkooQawhD6b0QW9vv4qqquvZ/v27mMeI8QeIyLWF154IZmZ5j/wbdu2sXSp\neltbunQpr7zyykicShAEQQhht3eiLGoADbvdM2CfaEGH7GMeI8Qeoxaz7ujoICcnB4Dc3FxcLtdo\nnUoQBOGUxOEoAdaHYtYuHI7LAbPru61tD1AM2AAXkyY5sVr/CBxi3rwMHI5FY3cDwrCJuQSz3NyJ\nHdeW+xvfTOT7m8j3BhPz/uLi+klOTgQgOTmBnJwMsrIyuPnm5w2x7DKmTbuPgoJzaG7ew8GDt6PH\nuDMyNpKdncFNNz3Pxx+nM2NGFw8/vJCsrNhLOJuIf37Hw6iJdXZ2NocOHSInJ4f29naysrKGdVx7\ne9doXdKYk5ubIfc3jpnI9zeR7w0m7v2Vlz8XFuVduzT6+lSy2N69KRhd3zk5Z/LCC5dRWtrPwYOR\n7Xv3pnDjjYOvEUtM1D8/neG8iIxYnbWmaabPJSUlPPPMMwBs2rSJK664YqROJQiCIDB0gtlQsezB\ntg8nSU0Ye0bEsr7tttvYuXMnbrebyy67jB/+8Id897vf5Uc/+hF/+ctfKCws5D/+4z9G4lSCIAhC\nCLu9M1R+pdzauijrsWxVruUJx7JXrZrDrl2VuFzTsNkOsnr1EtaufWvQNYTYYkTE+sEHHxx0+2OP\nPTYSywuCIAiDMFSCmc1mHdSVXVn5Nk7nKsBCb6/G2rXrhxR2IbaIuQQzQRAEYXjoojxYTHewTmWD\nubyHEnYhthCxFgRBmIAYu5vpncrsdk1c3uMUEWtBEIQYYai+3SfSz3swK/rpp+cgLu/xiYi1IAhC\njDCYNVxVtXTI7UdjsOQzcXmPX0SsBUEQYoShyqhOpLxKEscmFiLWgiAIMcJQpVjm7S7a2t6ltJSw\nS3ywphojYUXLOM3YQcRaEAQhRhjKGjZub2t7F6dzFU6nconX1T1AaelU7r770mEL6XBF+ETc78Lo\nIGItCIIQIwxlDRu3l5aC0xlxibvdZ/KnPy06rjahwxVh6W4WO4xYu1FBEARh9IluGQpqPvXxCOlw\nRXg4IziFzwaxrAVBEMYRuku8ttaPx5MCLAQ0CgoODXuNoWLjQ51LktTGHhFrQRCEcYTuEr/hhv+i\npiYBeBY4xNtvu3G53ANiz4PFp4crwlLqFTuIWAuCIIxDmpsLgF5gOWChtVWjomJg7Hmo+LSI8PhC\nxFoQBGEcEG0hFxT4qK+fwrFiz5IkNjGQBDNBEITPiI4ON+Xlmygt3UZ5+TO4XO5hf69byPX1V1Fd\nfT0QoLBwN8dKAJMksYmBWNaCIAifEdEu6V27KqmtvS4cZz5aSVW0hdzcXEBt7SIqKgaOyDQiSWIT\nAxFrQRCEz4howXU6P09FRe2Qgmx0WRcUNFNf/xSQAXgoKPAcdUSmznCTxKRbWWwjbnBBEIRBOJbL\n+kQwu6RdwLts3Up4/aO7rBOBa4DFwLWhzyNHtJu9oqJ2RNcXTg6xrAVBEAZhNFptOhwl7NpVidP5\neeBdYCW9vRaqq9X6DkcJfX3r2LEjDjiMz5cWLsdqbs7B7AbPOalriUYS0WIbsawFQRAG4XjFaziW\nuM1mpbb2OsrK3KSkFA5Y32azkpychNv9bdzun1JTsyJs4UZb3QUFLeHzLVv21Elb/pKIFtuIZS0I\ngjAIw+3ypROxxDupr3+RurqXKS6OHxD71WPI5eXPhCxq8/pDvSREJ4r5fAkmy/94eoMb0WPV+/Yl\nUFhYSXZ2ETNn9kgiWowhYi0IgjAIx5tFHRHZGmABbvcWqqvT2LXrCWprrx+QrOVwlODzPcL27V1A\nNj5ffzhuPdhLQnSiWGnpNoZr+R8teczo7geNuXNlslYsImItCIIwCMfbajMisunAFvTOYk7n4kE7\ni9lsVpKSUnG7vwdYqKnRSEpaf9SXBKPotrXtAcoYjuV/PCVhEquOTUSsBUEQBmEoa3So7brI1tW1\n4HafyXAEcDChPNpLgtkKLqawsJK8vLMpKurl7ruHtvyPJsjH6+4XxgYRa0EQhEEYyhodarsusi6X\nm8svfwKnczFDCaAu+Pv3t6CSuoYnlGbRtZGXdzZbt15Bbm4GH3xwgPLyTYO6uo8myNI0ZXwgYi0I\nQswylo06IsLoBmrC9dD79iUQbaV2dLi55ZaXQiVXh5gzJ5kvfnEdzc052O0eVq26wCSkPl8PNTXf\nBzqBDVitXoqLE44plEcT3aO5uo8myDJZa3wgYi0IQswyGrXOwyXSMcwJ3Bauhy4srCTaGl65spYt\nW24Mb9u2bQNlZQG2br0CgPLyTab7sFofCO1rBa5l+vRnqaq6Ilz+NdTLiS66+/bF09HRRGNjEeXl\nz/Doo2VHdXWLII9/RKwFQYhZxjb5Se8Y9rzpGrKzi5g7V1mpBQUt+HwJvPZaErABWIgS4AyamvrD\nK0XfB2SjOphtAdJoa9uDyzXnmC8nuuhef/3TNDSswum0sHu3xo03PoHdjsSeJzAi1oIgxCxjmfwU\n6RjWhdGSnjmzJyygRotZ7fMgUAD00NbWTmkphnGWkTXmzQvyzjsP43SuwpgxPtyXE+Vuj+xXV6ex\nY8cVSOx54iJiLQhCzDKWyU+RF4WFDBVXHmgxfw5YRHLy7TidP8XptFFfr7Fgwe8oKzPex1dYtuwt\nnM7IsVu3gs023HKsQxhfIOCQuLonOCLWgiDELCMtQEdLWIv+bvXqOUReFAI4HFcOSG6LtvyhO/T7\n01Eu7nSgiwMHMnn11SVHPba3N5He3psoLKwkK6uIjo697Ntnp7z8mQGx63nz0qmp2YCawNXF/Pky\nHWuiI2ItCMIpw2Ax4fvuu5yVK2upqwvgdicDl1FfP5nBktmGEvTaWj8eTwrKCteAg4BqdgIaHR2V\nA65F9xps3Qq9vX7Uf8dv0NOTwFlnfUJDw3SczgwaGjz4fM/z+OPfCl8DJGG1eoGDzJuXwaOPXkN/\n/4BTCBMIEWtBEE4ZBosJR7fbhI3ANYPGi4dKALvhhv+ipkYD/gDk4PMlocqyAGpwuQoHWMjmHuEp\nqGQ2C273Il5//U7g1vA1bd/+AKCEuqRkfTjWDarrWVaWdch51sLEQMRaEIRThsES1gbGndOJjhfr\nFvXWrRj27WTz5oMUFf03gcCh0PbbAQuapqGywy3ActMYzGhr3eEooa7uZdzuyDX090+PuqZsQL0s\nqPGa0h70VEPEWhCEU4bBEtYqKl41CbjV+j7FxS5TIlnEot5AJLHrRYLBVSGR1YDHMQusDzWF+OjC\narNZ+fKX+9myJXINOTkHaWszZ4+D7hno5ni6ngkTAxFrQRBOGWw2azhG3dSUSUXFq1GJZB4cjuUD\nEsn27YtHucctwL1YLFPQNLMQQzvmDG0L8Klp2/vvv0lJyRFmzQqYXOIWSwD1IqASxs49N430dHP2\nOOiegSWha0mjsLABh+O6UXteQuwgYi0IQkwQnby1atUcKivfHvFWoyfSFa2jowmIxImTk9fg9Z6F\nWZyt6CIKO1BW9W3AfcDZwBG83ttoaNhCQ8P1pvM2NxcAV4XPd/jws2zYcMWA61Cegc2hZ+LG4bgO\nTYNlyzawd2/KZ96SVfjsELEWBCEmiBbRXbsqw4lUx9NqdLDyrNzcjPD3J9IVbfLkqTidG9FLsU47\nrZDZsz1s3/4A3d2ZBAKTQmumAW8CPwXeAGzAOcBiw2rpA8473OYvg5WyRbcy/SxbsgqfHSLWgiDE\nBNEi6nJN40QSqQaznJ999vrw98cSxsHEvrPzU4yW9ZEjlfzqV9excmUtjY2pHD78AR5PBt3de1Ad\nzGxEOp8ZO6C5gJ1AO3v2OLnhBicPPbT4pJq/yDzqUwMRa0EQYoJoEbXZDtLbG/mcn38oPOSioKAZ\nSAxNtTK7fo8lXqtWzWHnzntoa8sjPv4Q3d3puFzu8PGDiX12dpGp21h2dtGAkq/k5DXARcAelCir\nzmeZmR34fHfg9Z4P7AXuBiz4/VqosckLJCWlnrC7X+ZRnxqIWAuCMGocz4jLaOty9eolrF0b+ezz\n+amuVpOt1DSsaxisucn+/QHgSeBrwOQB4lVZ+TYtLbOAawgGLWzbplFREXEdDyb2M2d2snu3uT94\n9H59fRcBS1Au7/tISSmktBQcjjKWLXuL+vqrgM2mYyCD7ds/xe3+HsNxYw/2PB2OEpKTN4Zi1tIT\nfKIiYi0IwqhxPMlcg8Vjq6rs4d+Xlm4jInQZGEVv61bYtesJnM6bUC5oNYayuHjKAPFSmd1O1DSt\nLmAhdXUBGhubqKx8m/37W4gujVq1ag67dlXick3DZjvA6tVl3HnndswJZkfC1wNTuOyyI1RVqa5j\nEeu3K+qYLlQN9fDc2EM9z6efvkaaokxwRKwFQRg1oq3PfftSB8xr1jSGZX2b3b0ejKKn+mqvRu8+\nBhamTz+DqqqBGdXRmd2wAbd7Epde+if8/n9HdR4zD+6oqKgNJ7v19mosXVpJd7debuUDmoHvh86g\nAclApP+ncQ71oUP30NMzlbi4w8yblw4khLqfHduNLfHpUxcRa0EQRo3oeGpHx14aGswZ3sCg1uJQ\nfbhVL20f8ERo3URgAZFsbDia6EXHn5XYXoXfHwh9tqLizVU0NZ1BRcWrNDamYRTJSBexxSjX9lVA\nDSrT+wPgX2lufi18zqMNJHG53CQlDS+5TOLTpy4i1oIgjBrRceh9++wGoeykrq6Vvr4pKAt1IWAN\nW4tDuXxVL20Vu1ax6eXo4lVY2EBeXvCoojdz5pFQ/LkTeDG09QXgI4zdydzun1Bfr85dWLiWgS5v\njYgrezLqheFFIAd4gYKCwYV0sLjzcEutxnJkqDC2iFgLgjBqRFuU5eXP0NBgFkTzAI3lYWtxKJev\nUbCUIK4LZYV7cDiuO2YmtX58bW0LHs9PDedfh8redtHb68bvj8S0s7KmM3euOmdb27s4nStQrvh7\ngQySklbS359Mf/9cVDvQBcBfBj3/8cTxT0bYhYmFiLUgCJ8ZRqHdv99rGl6RkuKntHR92FocyuUb\n/QJgFLSKilePWfqkH19auo36eqM73A/00tX1KZr2C4wx7Vmz+sOu+VtvPURPzyaOHPkYv//HgA2f\nL5Kdrr94NDfnDCq2xxN3PpFua8LERMRaEITPDKPQKnd2RIxLSzEJ0VAu32gB7Orq5NVXf4guaD7f\nOh5/fNmAc+vH7dsXT0dHE93d8Zhd25OBa9G05zCKqdXqxeG4ElDiWVNzI2ZvwDVEZ6dDGna7e1Cx\ntdu1YcedJaFM0BGxFgRhTDhW/HWopKxoAUxMXItR0LZvjxtwzOHDxjnQG1HZ4CrrOzPTS3d3C8Hg\nTaG9zVOtiosThmy4EkloM2enT5q0i9Wrl/G9731EtNg+/XT04BBJKBOOjYi1IAhjwtEypI9GtGD2\n9+dgtpAPDzjmpptqQhncnahJWJF4dFzcM+Tnazidk0N7LwDuwGqdQXFxAqtWXRAuN2tr2wMUo2q5\nXUyatAuLxU1m5l407S7a2s5HDez4MZdc8is0bSrwGCpbXDVoOZ77loQyQUfEWhCEMWG43c2i9yso\n8Jmszby8Vlpa9PGSrfT2urDbN2GzHWDTpjJmzLDz8cdqAIfK1r4NYzwaDvPHP17OkiVr6OubgcXy\nMZdcks4f/nAlmgaXXfY4LS0/ALYA55KYuJaUlCx6erLwes8EvkZv72Ss1gdQHcwUfv+Foc9DN2g5\nFif6QiNMPESsBUEYE4abPBW934IFv+OKKx6hrs5CMHiY/v4errjiEIcPp/L++014vSo5rLfXRXHx\nL5k9+4vs21cPfA7owezG9jFvXjq//e1H9PWpnt2appGVtR5Ng5KS9bS0fAEl1KpEzO/vxu83J5Op\nuHU2Q3U0G6pBiyAMFxFrQRDGhOEmT0Xv19xcQFvbuwQCqrlKe7vGe+9VUl9/RSimq++7Ba/3Lhoa\nLMDVKFFNxyioU6Y0AqezdStE13qvXFkbcp13o4+1VEQnk6k1580LkpS0nrq6AG53K8aOZhJrFk6W\nURfrkpIS0tPTiYuLIyEhgT//+c+jfUpBEMaI4xncEZ08ZZyqZTx2sCSrDz4wj89U4zTBZjsQmtTV\nCfQxUFQvJTPzfk4/fSYdHXvp7k4JZXfrDVKeBRIpKPDQ1FRApGb6d6huZQNbnVqt71Nc7MLh+Ao2\nmxWXy80ttzzP9u1/ALKZNy+Iw/GVkXvIwinJqIu1xWJh/fr1TJ48+dg7C4IwrhnKtR0t4qtWzaG7\n20Ni4lr6+3PIzm7i7bcTaWubA3RTX78E2ExV1VJWrDiDmprb8fnsQCtvvHGE9HSLaXympn1ISclL\n+P3dJCSsIRBIAaZjdkt3A5NJTw8wa1ZPqO3p86HvazDXSa8LvSQsAZ4DJmOxrCEjYzpz5/aQlGRs\nxLLc9EJis1l5/PFvfRaPWziFGHWx1jSNYDA42qcRBCEGGMq1HS3iu3ZV4nTmoeK8GbS3dwA/wxgH\nbmrKZN++JhYufJ5gMNKk5PDhDcTH/4NJk9agaTPw+/fh9f6UhgYbEXd3MlBCxPX9D1QG94N0d/tp\nbEwNradPwUrH7GrPCZVYbWbfvgQ6OtxkZ5/HzJlHcDiWhsW5o8NNRcXwPAmCcDJ8Jpb1jTfeiMVi\nYdmyZXzzm98c7VMKgjBGDFUXHC3iym3dBhgbjAxsKnL11c8RDH4u6rsM+vtn0t9fTmFhJU7nl1FC\nrH/vB3YDS1HWsga8AawGLHg8GocP672+F6Ji1fuARabr1jOxy8s30dCwCqfTEuopHvEWqNptFdeu\nr19CX99fSE5OEvEWRpxRF+uNGzeSm5tLR0cH3/72t5k5cyYXXnjhaJ9WEIQxYKi64GgRV7HlwtBn\nN7AntIKKERcWNrBq1RIuvvgg8CEDZ0C3As/jdPaihHmx4ftEYAZKhDOIWM+6ld1FZmYuUGmYnnUd\ncB9wNoWFDTgc14Xv6WjeAn1spr7+jh1xuN3SHlQYeUZdrHNzcwHIysriyiuvZPfu3UcV69zcjNG+\npDFF7m98M5Hv70Tv7fBhNzfdVMPHH6czY0YXjz66hKwsszX56KNlrFixMbRPN2vXfotLLnmMlhYN\nFS+OuMCnTbuPd965iRUraggGVwGfALejYtDtKCs6H7gUaADsqIEaBSj394LQmkYSME7n6u6+j6lT\nz8XpXBzeIzW1kEWLjvDwwzeZrr+oqMf0olFU1EtubgZOp41ob4DF8qlpm9Np+8z+zkzkv5sw8e/v\nWIyqWPf29hIMBklLS6Onp4c33niDm2+++ajHtLd3jeYljSm5uRlyf+OYiXx/J3Nv5eXPhePRu3Zp\n9PUNZk3G8+tfLzJtOf/8PGpqNgD6HGkACzk5Z9LfH8/evSmh7XagAvgD8AVU/PkHRIu8Emz98/6o\n78wtSW222WRkfAw8hbK+D5OW9iF7987lO9+pNrmvf/zjL/DGG5W4XNOw2Q5w221ltLd3UVjYgdHi\nLyxs4ItftFJTY9zm+kz+zkzkv5twatzfsRhVsT506BA333wzFouF/v5+Fi9ezCWXXDKapxQEYYQY\nbhnWiQ6baG4uQLXhfAyj6L3zzm7OO28PZ51lrImeDEwFFpGfX09Ly2Sik8JU0xPlyk5IsBIIGL/L\nMp1j5swedu7sBH4Y3tbevoH29qvCCXB5eWdjt3fi8/nD7u7eXo21a9dTVWUfxOWvXOdJSdIeVBh5\nRlWsTzvtNKqrq0fzFIIgjBLD7TA2VFLZYOValZVvG9qGHgkd14MxvqxpBTidNxIM3kNZ2XoaG1Np\nb3+Pnh6NuLg/cs45mZx//jq2b+/A7Y4khcFeVCMSK3APA/uFb8Bq9VJcnIDDcTnnnVdLdOKa/nun\n8/M4nUuor9ewWv/IYC8jQ7UClRi1MBpIBzNBEAZluBbzUEllt9zyElu2qGzv+nqNF164g0DgrvDn\nBQvWsWDBOmpqfIbzgJpkZaGz024Yp9kTfnHYts1Fbu6DdHUB3IPFkk1c3Kf09/8EJdQagUAPkYSy\nbiwWK0uWBHA4rgx7B1SSmwvVSjQNleR2KcqKj7QKhUMYhV+6kQljwcBZcoIgCCiLWYkUgMb+/R9S\nXv4MLpfbtJ/NZuW++y7Hbvewb18ql1/+BCUlz7FtWwuqMxiAhUDAjlH8X3stgaSkROLjm4GvEmnr\n+S7gQtP2hs9lfnF4hvb2NPr7LwJmoWnX0N9fBNRgtT5KYWElKhltOSpLfDmTJ3cDsGzZW+F72LSp\njEmTfhnabwnwMzIz/5NJk+5AJao9BbiYNy+DsrL1nHfes5SVrT8h13ZHh5vy8k2Ulm4b9BkKwrEQ\ny1oQhEFxOEro61vHK68ECAS6cLu7qK6eyvbtf+Svf/22KX5tdJmDhtO5EZXBvQG4FiX6jRgt1N7e\nZKqrl2Ox3INxUIYS2Pvwem+jokJ1MTO72l1EN1BRMenFTJv2JJ98YkH913Ynykp2091dSHV1PHBZ\nKCb9MHl5ZzNp0gy83sgLRFzcJLzefwuvXVhYyUMPXXfStdLDDSkIwlCIWAuCMCg2m5Xk5CQCAWPj\nko20ta3hllseISkpNRx/bmxMY2AfbjXVCjaj6qKTgcdR86SnAN9ATbmaiu76jhxfAPyOLVuslJc/\nw+rVc9Bd7Q0NGVHJY16U21qjo6MJj2cFSvwvBHYBd4X214UdnE7V5ASexBzbzjZdR17e2SPS1ORE\nk/AEQUfEWhCEIYkWGV2Et2/vwu3+HrqlOGXKHShhzkANuuhFiV8Lqq92I5oWaRmqLG5QrmY/8L+Y\nG5skAT+jr+9RqqsTqav7G8XF8Tz99Bxuuul5tm0zCmwLaWk9/PM/r6exsQin02ilw8Dr7yISz+4h\nM/NeZs48C7vdg8/Xbyq9Gqn49FBJeIIwXESsBeEURs/YdjptFBZ2DCjPihYZFVceaIEePpyAcRBG\nQsJdBAIbUNnZk5k82YXbbRRNN/BrlKtcubaTk9fg988kGExBNTbRXd7fwe22UF2t3MdJSQA/B+ag\nLOrvEx9fFWoN+gy7dycSEWM9acyGPiHL6+3E6707fK3p6ZVs3apmTbtc7lEpvRoqCU8QhouItSCc\nwkTHmqNjqQ5HCUeOPMJrr2kEAk7i4rK45JKH2Lv3CGoaVTdwMYFAPkbxPuusc5g5s4emptcGtVhV\n1na84RgbfX3TgY+BuahxlQuAHAa6jzNRLvUl4evs6ckIX++WLY/Q16eL8SLgDmy2WcyfH4fDsZxv\nfGMnu3dH1szOLgqvM1Q51skyWusKpw4i1oJwCjBUg5NjxVJtNitPPfUvpm3l5ZtoabkFXXgtltvR\ntHOIxH5d7NnzNh99VITNtodHHilD0+CVV+7E75+NillfS2Lievx+o4A3Ar8wrZuRkYHHE+0+1qiv\nbzadr7+8jtPXAAAgAElEQVT/AKWl27DbO5k58wzee89oxV/A7NkJVFVdBsDMmUdCAzkiDVIEIdYR\nsRaEcc5wOo0NlY18tFjqcAVe07JQFnEVqnd3F8FgJb29quPX0qWVzJ07Db//34kI8wbmz8/kf/7n\nDrzeuSh39udN606ePJudO6/kllseYfv2LiAbn6+fn/98Hps3dxAMRlzdmvYL6uvVveXn/wJz0th7\nfPRRIeXlz+BwlIhLWhiXiFgLwjgnWojr6h6guDjPJNpDWdC6cKmYtcskXMMVeNU0pNLw+bemc7lc\nBQPOHxfnYc8eNxbLLJQrfSHK9R1Zd9IkJwBJSanhZLaaGo2kpPV85Svp1NQsN5wzsnZPTz6RjmgN\nwApcLls45l1VtXTYLunhtlwVhNFGxFoQxjnRQuh2n0l19SKM8eehLGg9ljrYoITIum6ghq1bCZdR\n9fWtY8eOOI4c2Y/ff6bp/Kq1Z+RcmtaI3T47dP5O4EWCwSAtLbOAr6FqoTcCC4iLu51g8MvAEVpa\nfkBFxeZBXzSefnoO77xTidM5DeVWj2SS9/a2ohLXQL0IbEHPAt+3L/64BFjqo4VYQcRaEMY5g2ds\nm+PPJ+L6LShopr7+KVRJViK9vUuorp7MSy+t4Z/+yYbb/R3gEeAAZrezF2OrT6/XRm1tVyi2nQXc\nZth3I3ANKSl9XHbZ0/zP/2Tg8fhQXczcvPiik/nzC03rt7Q0cMstzbhc01D/hX0/tE4asAO/f7ph\n//0YG6h89NEdfPnLTtzunzAcAZb6aCFWELEWhHGOLsR1dQHc7kkol7Kyns1WpMbTT885DjduIsZy\nLF1Yvd6LeP31N1EW60qUtfwEaiBHO6qLsdFFvQGP59rQPtEzoPuA5wgEGnj11UmGLO6rgY34/d+n\noeEOpky5k9bWTCCHlpZp1NT0oCzq7xPp7f0ukAt8E3gUVfaVazqf13s+Xm8iwxVgqY8WYgURa0EY\n5+iubJfLTUVFbbhcyuG4nIqKod24RiEvKurh7rsvNQl5c7O5bEpZyhpwhP5+O5GuY1ZUE5PridRG\n3wecg5o9nQc0AR+iyq6Mk7KSgCX4/Xpf8IENWDyeWSQnt2O2yO8AZgPVqBeEbuAW8vP/MzQ+MxX4\nDip2bbT6+1CW//AEWJLRhFhBxFoQJgjRtbwdHW7q6gIMZUVGx2O3blWJafooy/37WzAL3QcoUfwq\nmnY/0EYkVmxsF2pDCfWi0P7LUeJ9F8oK3xD6eRi4OXSMGo9pPp9qwKJp+4AZmIXc+HKgERdXyeLF\nm1m9+uusXbuerVuht9eC8jJsJDXVj9V6EKdzRegY87jM4T5TQRgrRKwFYYKycmUtbncyRgFsa3sX\nl2vOoCVYbvdsqqt7eeGFZwkEfoBqevI4CQkHiY930dcHyjLdiKaBGtDxBMpSNSd5KYu6m0gnsjwi\nVvi1obUnh36BalGqhFUJ//+GjrkTTZvOpEmfmu7DYslG0yLXnpmZHxbVqio75eXPhLK/rcByFi3a\nyN13XxdOWLPbzeMyBSHWEbEWhAmKEuPLiCR7fYDTuWKISVYaKuY7hUAgDlV+dROwhUDgCwQC/wvM\nAv7VsP8TKAs3A9iHxXIP8fF5BAKTQts04K+AhylTPqa11Xiut4F+LJZ7yMgoJDHxQw4f/hg4HdUi\nVG9peit9fRZaWlwUFlaSl3c2druH7m6LqT/4vHlB071Hu68ffngJ/f3xYiUL4xYRa0GYYOix6P37\nA8ALRMqjGgAL+/bFU16+iX37EigsrKS7Ow+PJxXoR8V6VwHPM3Bs5YOYXdFeIq7opSQn34HFkkIg\ncD1qulYkOe3ccx8hGFxDe7sVSEFlmF+IpnnxeBaQmPhbIuVWoCzvFoyu9by8s009vCsqjLHkr5ie\nQbT7OitrYGmaIIwnRKwFYYKgi3RdXWu4NElZqA8CU1GZ0y/S0dFEQ8Oq8PcLFqwjI8PCn/6Ui7KI\nLaj4cXTCVzbmmHIbcC9wJtCL16u3En0KJeSRY1991UJcXAoqSWwjymqPZJn7/YVRax9BxbUHTwST\nWLJwqiFiLQgThEjC2POYRfZzKMsYrFYv2dlFoVnO6vvm5hxefPEqtmypxOPJRAnkQuBhzHHoD4B7\nUOVQn6Cs9dNQ/43obvR7UWIcB/wB1VAlm2CwjWBwNsYs78j1pQE+EhPvxO+/ECXUXwX+jB7DLixs\nwOG4bljPYbCmJ7m5GcN/kIIQg4hYC8I4xihMjY3NKGs0Oqv6g9C2i0lNbaGpqdXwvYuWlgYuuiie\n1FQfHs8hIsM0koG1wNmo+dR+4N8M696DuQ57PxExVo1U4EbD9/eGfkZf35vArcyf/0fee68Bl6uA\nYPBBEhPzSEhwMW9eBg89dJ0pGexoXcgG6zr27LPXj+RjF4TPHBFrQRjHDBxxuQFlFW/AYulE0yaj\nksImAz/B6ZyDsnZ19/X7tLTcTkuLGiepYtjxeDyRrl9qNvVUVDmWbhF3hn4+jxLfhahY9FMoV/jn\nQvsaLegzQ9fXgYpPzwQ+4owzTufsszfj82XidN4aPm9f30ZgOe+8U3nU+46uH5euY8JEJG6sL0AQ\nhBMnWpiURfssYCEpCVSZlJVIzPkaVLz4Z6i4snnSVV7e2cTFTQltawLuIxA4DTW+8gPUCwGooRv/\nhnKTXwO8SE5OV+j35aiMbo9hfw34O6rL2b+EzvuvQCVnn51OVdXSIZqwdOJ0JvGlL71MefkzuFzu\nQe/bKMh2ux7rVueVrmPCREAsa0EYx+Tnt2N2KX+KsliXk529FqfT+F02A8XQQ3QS1/79h4hY6SsN\nx98R+mVHZY4b1+rC69Vbieq11A+jeocfRjVKuZWMjN/j9f4Wv/8H4WN1oR28x/mLwG243Zbw1Kyf\n/ewC3n//TZStoWq5jYIsXceEiYiItSCMA4aK0VosAZRL+xxUYtZNJCb+ioUL13PTTZdTVnYHXu8Z\nKBHXk8f0lqC7gHxgLSkpUykuDuDz+QkGe1FCrTcyIfTzDOA6VH11E+aXhAz6+oyNS14GvoDKLs9A\nxbxtBAIF5OYewOmcjHLHv0hjo5fzzvt/ZGbmUlhYSUdHPl7vx8BZKE+B2YK++urn8HrvDp970qQ7\ncDi+G35WkikuTERErAVhHDBUjLa5uQBlyR5BWco1zJ49i6qqpZSXb8LrvQtdnJOSKklIuJ2enjhg\nGiqurGqws7PvIzm5kOrqG9HHWMI+zILsDH13AGVdr0G9JAAsJCnpMIHAGjStCJVsdhvKotbLxzR6\nexPp7b2JwsJKenoScbt/gsdjwePRcDo3AuUUFq7F6bwRlQneHzpeXdP+/V48Hv2zcu9bLGdIJzJh\nwiNiLQjjgH374ol0IusKfdZdx06MYyA7OytDx6QSsUq34PPdh893H2bX9qNAKocPF1BX10JEBK8F\nqlCZ4fkogc5CDc643XD8htC+GoFAK5p2t+E7NaVLfc5An1kNVvLyzgagvn7g4I7U1Azi4n5PMHgm\nyoJ/mMREF37/atzugee12Q6OwBMWhNhGxFoQxgEdHU2ozmJKrDo6lCA7HCVs2/Ys3d0RIXc6M7nh\nho20tzehRk0aB20UYnZtu4Dv0NtrobdXL69agcoeTw/9Mo67/H3U8T5SUp6gtBS2bp0V9V1a6Pca\neXlO2toy0NuPFhR4SEpKHSRGrXHwYDvB4C8M2+8jIeE0/P7I2gkJXSQmPoHNdpBNm5aMwBMWhNhG\nxFoQxgHRjUy6u/MpLd2G3d5JWlob3d03YxS3mpofkJFRScQa34PK3Nbjyrqr20qk3MsKTCUh4X7i\n4vz4fBehksOMAtwedbwPTfuE1auXs2tXdUjw9VjyLs48Mxjq5Z3Ntm3Gmux14USwxsZUDh/eS1aW\nnVmz1g8i+oXYbAdMa3/taykSlxZOKUSsBWEcMHPmEXbvjoiVxzOJ+vqrqK/XyMy8n4Edyyx0dWWh\nOoFtQcWoVwE5KDd2GrAas8t6OZBIIHAPSsD/GZXR/RyRCVr5qASzT9AbpHi9LoqLf8mMGbPp6FiD\nxTILm62ZTZuWMWOGHYDS0m2ma2xuzglN7oL4+ATmzp2KwzEfm83Keef9P5Mwx8W9z2OPLeI3v5EM\nb+HURcRaEGIUYwZ4QcERFixYR3NzDvv3f4jbXR7ay0JcXA4DO5Y9h7Ki70fVNH+CyuZuAaaj/ukb\nBb6XSExZjzHXYIyFq45lFtRkrFzD8Vvweu/ivffUfmVl66mq+qHpXqLLsvLzD1FSsh6n8/NAN/X1\nSwA1DWzTpjKKi+/A650LHCEY/Cm/+c1msaSFUxoRa0GIUaIzwBcseCRUB52NcZqWGg+5jr/+tZfu\n7k+Bi1CW8OmoxiMbMVvRG1ATuIwCvxeoNHzuIjLUg9DPPOC7od8/aTg+zbSfPtXLWGbmcJTQ17eO\nHTvigMP8/e+dtLYas8U3huutZ8ywc+aZc0ICrpAuZMKpjoi1IMQI0bXU+/aZrd/t27twu7+HLqiZ\nmfeTnh7gwAE7s2YFuPRSqKkxCq4+0jJ6cEYGkEpy8hr6+opQJVnXAveRklLIxRd3smePO9yCNLJe\nu2Gdr6EsbTvwIcaBH8apXvX1Gjt33oPXO4nu7kwCgRRUh7PJRJLZrEAaBQVt4WcRbYlLFzLhVEfE\nWhDGEKNAt7Xtwem8CbBRX69RWFiJ2fo1dyCLi8vB6VyK07mFhgYbCQnNmEVZH2kZPTiji4yMOI4c\nyUEN2zgHZWmfTmlpAJhMS8vNqCSyDajGJMmobmcuVAw8DeU670MN67gPOJ1Jk97D5TIniLW0nAbc\nYDi/XtJ1DsrVvhyVABeplT6RLmRHG+4hCOMdEWtBGEPMgzjK0OueIZ3u7ngWLPgdzc0F2O0efL5+\namoiohsMtqISwFTcNxBIxizKHwLrAB8JCXeQmmqnt7cVvz+Prq6bQsdGyrL0TmDLlr2FuW3oY6hE\ntXZUDFwvq1qMst5/HfrcH2rCsiHqOj7GPPAjncjMaj9KvFfQ3Pxa+LmcSBeyW255iS1b1JSv+noN\nn28djz++7LjWEIRYRcRaEMaQgYM4VN0zWPB4FvHOO5U888ylVFa+zYEDqRQWVpKdXcTMmT3s2NGD\nx6N3KFPlUCrTuwg4hEokawb6uPLKqTz++DJKS7dRX38VqtXnFNO5LZaZVFS8SkGBL6r+OQnV4/t7\nqKYo0Znnt6EEWo9xL0QJsB/1wvBjIrHpDSi3ezfqBaAGZWWfvKtbxcONYQOZUyRMHESsBWEM0F22\n+/cHUMlaKlksISGDQCAiOE7n5/n615/D6VyFXtvc3d3K4cOddHbOwFwjnQccxOxyvg/4N/7+919Q\nWrqNtrY9QDHKlW22xHt7J1Fd/VVycx8gLq6SYPBclKguBJ4hMfGXTJ6sceiQUcg7iMTBdXe7FWWx\nbwQuQAm1up+MjB4uuSSN5uYUCgr+Avhpbn52hMqx9AEk+rUdPsn1BCF2ELEWhDEgeg611foAxcVT\n8PniTK5u2EN7+ySU8H0K3IbHsxGP5ybMMWA97mu2llUi1ye0tEyipUUD4oiLewzoIRj8FuamKWnA\nb2lvn40S/UuIWMQp+P134fGsJGJFd6GsZz0uvjD0nQc129ofWidyPxkZbTz00HWjEkueNy+dmprI\ntc2blz7i5xCEsULEWhDGgGj39/TpZ1BVdQUul5va2kiNMXyf/v77gVtRcd9OlGgbY8CHgZ8Dp6Hi\nwy4iItsM/BfG0q1g8FGUqNeiXNyXoBLMsgFzJzRlraeg11/7fJ9HxbHdKBd2H6rZyixUL3Er8+f3\n8NFHHQZvQCRJzelcQUXF6NRMP/TQYpKSamlq6sduD+BwLBrxcwjCWCFiLQhjwGClSR0dbm699QW8\n3lTgXVQf799hsVhD+3Whz3c210w7iSR96f29P48S4FuBNxgqLq72vxPlEg9gdqufTWLiLvx+Y1xc\nb1eqZ3FvBCJWfn7+PVRV/V+WLXsr1B5VT1J7At3CHq2aaRmNKUxkRKwF4SQYrFxI0zhmCdGqVXPY\ntasSl2saNtsBVq8uY+XKWmpqMlGCGBHI/v5VKKH7J1Ss2Si8XSir1rgtK7R/AcrCPow5lpsVtX8i\ng7ce3cP8+Vbee68y1GnsCPA14uJuJxiczWA13IcO5bFs2VuG2Lhu4SeG1tyA3R44mUcuCKckItaC\ncBJEdxnbseNOLJZEWlq+SHQbTSOVlW+H3MRq2tXatetDFmc8oAshqCztIpYsWU9dXStudwCz8LpQ\n7m/jtkmoGHRC6LNuMetx5vej9j8Ns3gfAdYQF5cNTGLTpq+wdu3bNDVl8v77/43X+wsi5VnmGu5A\noIv6+u8BZRQWqpeR3t5EdDe61erF4bhyJB69IJxSiFgLwkkQHXtubc3E7KbeyL598dxww5Ns394F\nZDNvXj8HD9pMx+lWeH19AsqtHRHA5OSPqaqqoKTkJdzuuURiyZ+gBnTEoVzZXyAu7m0SE0/Dau0h\nP9/PO++sQpVyAVyKcksfQrnN1QuFwije7cDdBIMWtm1TLxL6y4YqrzKWZx1Cud3PRDVJ0T0IFvLy\nzmbu3E6qqyO13MXFCdKoRBBOABFrQTgJomPPaqqV0UpN49Chf9DQMBNVp2yhpkajsHAtRoFsa3uX\nRx5ZwvPPr6e/H2ANMAP4kOeeUz2yOzo+QM2n/hmq3KsIVaOs1igsrKS2dkVYDE8/3YHRnR5xbx9C\nJY3prURdpKTcyec+d0FoSMh0ol8kdDIzP6S39ymUlR4EWiksTCMvz0Jb236czhWhPbVQOdbxdyIT\nBGEgItaCcBI4HCXs2mWM6YJRhAsLG+juzid6KEZW1nSCwXtCrTgP4XTm8POf/5WMjM/hdn8ntJ+b\nxMTf8KMfHaCx8QX6+rJRTU+CqCEdlqg1i/jRj14KNQc5hNc7jYHu7TtQGdr5JCSsITGxCJvtIK+/\nfiOZmVmUl3dSXR003YOxWcm5506ltdU8lzovL4etW6/A5ZpDRcVmkzBL0pcgjAwi1oJwEthsVmpr\nr6OiQh9l6QbWceCAlY6OvWRl2Wlvfx9lyUYEsKOjiba2eIwNTF555U5SUuyoEqgkQMPvn857730F\n+CbKMs4Grg8d86RpzY8+eoeGBqMlvQqze/sQytJ+ArievLxK6uuVkObmZtDe3oXDUUJX1yb++te1\n9PfnkJfXyurVXw/fb0tLtOcgKyzmgwmz9OsWhJFBxFo45TlZQRlMpMrLN9HQsCpUvuQCfoXqo51D\ncvJenM6fAq9hFD6//zz8/q8DT2F0b0cGX6SjMrvNk6+s1qmkprbgdJ6FWUhPJ9L0pBs1IUvPFrdw\n6JCV8877T7KzizjrLB93330pNpuVjAwrfv8PUUM4VMz6vvsms3JlLR988AnGF4BJk/6Ow/HdIZ9N\ndAIerBdLWxBOABFr4ZTnZAVlMLGPTjxLTEwmISEPm+0AaWmz+fBDGwOzsveimo34MIuuPviiG5X8\npR8zGYsFtmy5iNLSF4nUQOvrdaJGUBpFXwM+ADz4fB04nbfjdFrYvVtj69YHKC7OGzCas6kp0/CM\nzE1OZs8+86gvNtHPQeZSC8KJIWItnPKcrKAMJvYFBUeor9cTsRrw+1fj96syrbi421GiOR01ZcuF\nSkzTUIMyEjGKrsXyDzTtDSAXaMNYhlVSMpnKyrfxeH6KEtInAC8WSxslJalYLI/wt78l0Nv7CX7/\n5NCx/4pqQ/p703273WdSXb1owGhOu91jeEZ6k5PNwCJmzVp/1Gcjc6kFYWQQsRZOeU5WUAYT+4IC\nH2ZXduT7YLAIZeUeRDUNKUSJbyJKjL9NxH39Ppr2A5S4bgD+D/AUcXFTyM9vZu3aMr73vY9QQl2D\ncnHXk5CQQnp6DqtWzaGy8m2ami6goaGVQOBaw5V7MFvi3YCFtjYbmZn3Ehc3hXnzgjgcX6Gi4lXT\nM7Ja36e42HXM7G7JBheEkUHEWjjlOVlBGUzsm5qMiVjdmEVxHzAXZSk3oTK0dSt6DZo2GX1spGpu\nYkW5x53AfwM/Ixi04HSqeLLdrlFf/yKq8cgW4Iv4/X+jujqVHTueprVVTzp7LOo6JqFqtnWL/VpU\nYxM3Hk8u8G2SktZjs1kHeUbLhxXXl2xwQRgZRKyFU56TFZTBxN5siS5g0iR9OEcD5vnOD2O0ujVt\nKuaksNND3+k9wfVhHjVAOi+++AlPPTWX6upGlFDrDUgWAxtob3cb1r8KPclNZZtnYh7c8SAwFfg+\najZ2JCQgoisIY8uoi/Xrr7/O2rVr0TSNq6++mu9+d+jMUUGINYzJY0VFPeGM6ejv7HaNp5+eg6ZB\nRUUtH3zQR1LSSgKByWhaFgkJCeTkvMGhQzMwzneO7tsdH3+Q/v7lKOFNA/4GrEe5rI3DPJSL3e9f\nxHXX3YHqIKYnkaWH9rMQF5dNMOgCnkPVWbtRZWT7UF3QjIlsn0OJPOgxdIkxC0JsMKpiHQwGufvu\nu3nsscfIy8vjG9/4BldccQWzZs0azdMKwogRnTzW1xfJFDd/52LXrofp6cnH7U5GtQA9D11Uu7s1\nurvvZaBLvBdju87MzG5crl+i3OTdKGv6d8TFHSYYfCp0nC7cABb6+magyrgexNyx7A4CgWyUq/si\nlBv9bsP390ZdS1doTY3MzGYuv3y9xJgFIUYYVbH+xz/+gd1uZ+rUqQB87WtfY9u2bSLWwrjhaJni\n5u+2hAdzRFzKUzBbrlNR85/10qckoAKYTGbm/Vx+eT61tfmodqLGcqtzCAZ3EElYM2drJyU10tc3\nGbgg6nznh873o9Dn56K+n0JcXCWZmflcfHEQTfPT3PxsyJX/LWleIggxxKiKdWtrKwUFBeHPU6ZM\nYffu3aN5SkE4LvQZ0sYhGw899NWwUB0tU9z8XRpmIcxhYLb1p6h48JbQfsbM7CyqqpYye/bTUesk\nomZbzyYya/paVO/wmUAj55+vMWXKeurqWnC7jefrwzzCMtqqn0QwuBq3WyM9fSO//vWiE3+QgiCM\nKqMq1pqmjebygnDSRGZIR4ZsJCVFXN3G5LGiol7uvjviFjZ+19a2B6fzUpT1qgGfEB9/iJSU/Xi9\nOaSmdjB3bhJJSX/hwAErDQ31GIWzp6cJgNTUZjweo6C+jZqQFT2M42z07O3333+A555bSmNjE5dd\npiey7UG9GNQYzrMAWE1CwukEgx0Egz8I3YmFl1/uw+VyizUtCDHKqIp1fn4+Tqcz/Lm1tZW8vLyj\nHpObmzGalzTmyP3FFk6nMdlL/Xz5Zbj55s08/PBCiopO49lnrx/02I8//pitWz/C651OUpKb5OT7\n6euLCGt//4P097t5//2vMmuWPXzcsmUbaGiwY8z61rQscnMzyM+fTUuLMRu8ELOl3YNyg98U3max\n5JKbm8HNN+8OCfUSYD5KqA9jjImXlc3i2Wf/lWXLnuJPf5ocWkPD5UpizZo3ePrpa07mccY04+3v\n5vEi9zexGVWxPvfcc/nkk0/49NNPyc3N5YUXXuCXv/zlUY9pb+866vfjGX1YwkTls7y/kRoQUVjY\ngbI8jVauxp/+dA11dXdywQWn09ycg93eyaOPltHfHx8+trj4v/F6VUJXX9/AMiz4HL29izjnnDWc\nddaF4evcuzcF1RAl0go0Le1+PvjgAG1tjcBqIpb0PZhd1wdRLvGI0H75ywHa27tC6+qubiuwHKv1\nAdzuVeFrfuGFRzj33Cc57bROMjPvx+M5K3TMQvbufW3C/v2Uf3vjm1Ph/o7FqIp1fHw8a9as4Tvf\n+Q6apvGNb3xDksuEESE6S9vne4SkpNTjFm+Ho4QdO35Pa2ukhSf4AQutrZnU1NwYPseKFea4rsrC\nNoqzLvzmjmB9fRdRX78k3IpUNTHRO5KloHqEZ1BS8gRO57+gLO40EhPfRNM6CQSM1xYAesOJYfPm\nBXnooa8Aegx9Sfj4KVPexGJJQLnmu4EFBAIZNDRYaGhYQWHhWjyeReHrLShoobx8k0zIEoQYZNTr\nrOfPn8/8+fNH+zTCKUZ0lvb27V243SrufDzDOGw2K7m5X6S19RuGrZtRYmseB/nxx+mmYxMT38Pn\n08up9gNWLJZVaNoUVCb4wtA6R8JrNDVl8vTTc/D5nmf79k85csSH378aj8cSilXrE7bgnHOCfPCB\nJ6pF6BPAdSxePPD+VAxdnyftxuc7Pfyyoa7j58A09KSzrKzpzJ0bicd3dSXIhCxBiFGkg5kwLonO\n0lZzno8+jGMo13lHxweYLeJ/oCxRv2n71KkdpvXmzTuNurprUAKryq1UUuUToWP+GlrrJlQzkhfZ\nv99LRcWr3HnnpVRWvs3WreD361neVlRWOeiZ521tB+jtjVxDYuJHLFwYqX8+WjigtHQbZsv/QmAR\nqu5aY9as/rAY5+ZmcP75zx7zGQqCMDaIWAvjkugWnz5fPzU1Rx/GMdQozKys6TidxqQuG5BGUtLf\n8fkiLmhN8wMRgfzb3zKJuLKNomhDJXlp5OfXc/75f2H7dhdu909wuy1UV2vs2lUZVZetZ3nvITPz\nAOnpnTQ2FnHWWekEg/fQ2WnHZjvIpk3fZMaMSLLa0cZ75ucbx2lG3PLJyZlkZ1fS2FhEefkzOBwl\n5OZmyIQsQYhhRKyFcUl0r2qXy01SkhLv/PxD+Hx+Skqeo6OjiezsImbOPGKY0+wGati6FcrLn+G0\n047Q0BA993k+gUAzxlpop3MzYBbIwTuBvQtYsFrfp67u/2KzWSkt3UZ9fUTQW1ryMQu8L3TeFfT2\n/gaPZzVOp1qvrGxod/TRmrZYLAHMDViUWz47243TuSo8xxrW8+yz18uELEGIYUSshQmBUbzLyzdR\nXX0jSvwiohSZ01wDLKe3V1m5Cxaso6xsPXV1AdzuScDngQcJBmcCT6JaeU5mxoxuYKBAKsv7DlQ8\nuAOYDnhITvawbNlb2O2dFBT4TFZrMNiIWeCdwCpAw++fynDd0UezhpubC1DDO9TLSUrKc5SWQmNj\nUagP95kAAB2VSURBVOhFwLy+DOsQhNhFxFqYcETE1Ni9y0J2dhFz565n61bo7Y1s37LFD8SRlLSP\nSy9NYceO9/H7jT221wCn8cYbbXz8cdMQ8fJ/Ae4HvoxyN/8Tra1NtLbGU1+vYbM1oAR9BtCIGk/5\nKOACckhI8HHmmU9y8KATtzsXo5C3tb2Ly6WGhETHp49mDUeuU5VxlZYqC728/JmQRS3ubkEYL4hY\nCxOOiEh1YU4QcwNJJCe3mJK21Pzoa+nr0/jb39bQ3z8Ts+V8EbAEp1Nj6dJKamuvo69vHVu39hMM\ndqFqnp/D3GnsTuDfw59drmYiPb9dwC9DP28DLAQCGtOmraOjw4/bnYkS9rMAC07nCioqlAteud87\nqa9/kc2bXyQ//xCbNpWZ4tg6Qwm5uLsFYfwhYi1MOHQx2rcvno6OylDMugefzx9yj3cCG7Bavbjd\nzcC3ULHddPr6JqHKsIyWc6T0yuWahs1mJTk5iWBQj1u7gD9hFvjZUZ+Nru0tqOlYz5v22bEjLtTA\nxAIsxVjGFXGFW1Bu/GsIBi3hF4j6+h8OeA5DubXF3S0I44+4sb4AQTgROjrclJdvorR0G+Xlz+By\nucPf6WL05z/PZ+7cacTHJwAaBw7o7nErcC3Tp2eRnNwN/A8qE/tS1HCM6cDtKOt3FZAMPAW4sNkO\nAkZXuxvVuSwdJewQGdox1Gd96EdX1D6HMQu8uYzLbu8M7Wd277tc007gCQqCMJ4Qy1oYM06mZejR\nSpaG2ieSYBaJ1WZm5vL66z6MFmvEoq4M/VKfU1LuZNOmbwJGV3sNKiFtPpFe3/XAdeidxPLz/8E5\n56Tw1lsPEAza6OvbT1/fYlR2trLwi4sT8PnSTOVn8CbgJjHxQ1avXoamESr5ysKY+Ka/QAiCMHER\nsRbGjOEI7lAcrWRpqH2ys4v4whfWsWNHHHAYny8Nl+t0Iv20zRYrmMurZs/+AmvXvk1T00cUFPhY\nsOB3vPZaGr293ai49TWhdeopK3s93EnM4bjB9BLicrmpqNBjxgEcjiux2azh8jOVld4K3ArY8Ps1\n1q5dD2CqzY6LqyQ/HzZtWjKsZyYIwvhFxFoYM4YjuDC4BR6dkd3W9i6NjbOprHw7FKtuort7CkYL\n9PDhvRw4kIjb/RNAjcMsLFwL5KFi1p+iOnzplu1HGC3xDz98h927fwxsob5+Cvn573DxxbBt23J0\nK1qNpnRTVXXLkPetu+n1+9LLuxyOEqqqluJyufnSl17G7Y5MBDPHrNXPL3zhbLZuveL4HrogCOMS\nEWthzBhux6zBLHCHoyTkEv48cASncwVf//rDIctT1Vfr61qtD5Ca6sfpXAG8gVHwsrKm09PTh9t9\nLSr+vBHoBeJJT7fS3R3pbOb1TkElhy0HOmlp6aalxQU4UAlk76EakMwIdwY7mls/+r5eeukOzjjj\ni8yceYR587yDdGTTpMOYIJyiiFgLY8ZwSog6OtzU1bWiMqe7gIXU1QUAyMs7G6cz4gJWiVadKAs5\nsj9kk5WVHJpdbSznctHR0YT6Z2BM9Ipj0qQP+dKXckNWs3EQhje0dgORUiy9i9mtqBj2tVRXR9z6\nQ8Xmoz0LXu9cdu9ewu7dkUYtA5+NlFwJwqmIiLUwZgynhGjlytqw21qJ4gbc7klUVNSGRk1GLE2b\n7QC9vS+i1y4b909N3R/6HEnqSk1tCVnincCDoTOqY71eDYvlEcrK9CYqicBpgHGKlTG+fQ7K6s4I\nb9Nd10PF5gc2V4mUiG3fHsfOnZcPsMyl5EoQTk2kdEuIaQa29vQBC2lqysThKKGsbD3nnfcsZWXr\n2bSpDKvVO+j+2dlFoX1fo6wswM6dV5KXdzaRUq5CoMh07JtvJlFVtZTSUg3l+p5i+F5PSoOI0Kah\nLHe1TXUecw8am+/ocOPz9ZCYeCeqocq9wFfDx+ovJIIgCCCWtRCj6K7j/ftbMDcoSQYmY7d7BrXM\ni4vfCrmgzfvPnNkzYF+zZbsA1S50cfjYI0eacbnchvi4RiQBbQEqLn4xSqi/Sn7+b9C0Plpbn0OP\no1dUbB7gAbDbPaxcWUtNzfdRVv2LZGZm0Nv7K/z+81Gu9oU0Nb02sg9VEIRxi4i1EJNEXMeq21hm\nppf09BaysuzMmrWeVasuoLx804A4sB4Hb2xM5fDhvUPuv2LFGezc+Qlxcb9H09rQtGyU5RwZien3\n9/GlL71McXE8mzYtYfHi/6Kt7V5UMtmnXHppOllZ7tCam3E4bmDZsrdobdXj6CrePm1aIYWFkU5q\nDsflLFv2FsYGLTNnPovdnkF19VVE9wQfbu25IAgTFxFrIWYwJmIpi7oTo5ht3fp/wvuqyVqROPCO\nHXdywQX/v717D66yvvM4/s4dSAI5QIBEuiGAEay2TC11YVxCsY0SwKBopXWkRZuV0sEx7Qw3124t\n3VBTrbZDhyJip1AqWNYkUAhVA4RWKcvWTTEqZYg0CLmS5DQJhlzI2T8eTs41yUlyDufJyef1jyR5\n8jy/x4if/G7f379QVTWelBQb+/bdicVyj5frjbra+/f/DZttKvZtXUbFskSMFd23AGeBWVitVyks\nXAgcYO7cmRQUrMAepnFxO/rorR/qPsMabMye7VhwVlv7IcYsVAuw8PqCMc8V7mvXHtA8tYgorMU8\nPM+Jfg3jPGnPbUru88A1NaMpKjIWf5WW2igpeZ709AleVl4bVcpsNvszXgVGYdTyjsGo8x2O8yEc\nsIeKitFERUW4PLOqarzHO2zYcAenTm2msXEyHR2tdHZ67iNft+6oS3GT5OTN5OU9isWS4LHCvbfj\nMUVk+NACMzEN9wBOSLjavXjMfZuSo0421/853uV7rdYZFBau6F6k1VNdbSOclwMP4Dhw4zxGr95+\nTSy1tR9y7twZl2c6/wJhr1V+773/Q2VlCq2t99HZGef1evf3nDDh1u6hbvf30l5qEQH1rMVE3Lcy\n/eu/dhET00RFxWjWrj3iUmTEfcjY4LywrAXn3qx9LrukpBqr1blKWTze64I7evUjRpyisvJx4G3g\nBSIj4/nqVyPIy3MMs3uOCuwBFpGQ8DyTJ6fS0HCW8vIUsrPfICmpvcfiJjq+UkS8UViLaTiOthxF\nQ8NZ3n13NE1NI4H5lJaOwbl2uMWSwNGjj/LUUwc5caKZrq6RjBr1X3z6adL178nEOQjtK8dd63I3\n0dLSRXGxZ4979OirhIe/CtTT1TWRq1dPYN9j3dlp429/20xj4z9Zu9Y+x96Ja489DhhDevpE4FPK\nyjZQWRlGWZmNhQt/1UPBEx1fKSLeKazFNOxBlZ2dT1mZY07Xfq6z+/ytxZJAdPQorNYngDCamowg\njI6OoqLimNeeqc3m8hG5uf9Gbu4ujh6toqnJ0eP+9NOP6ezcdP3j3TiOtQQIo7LyNpYuzae6ehoQ\nAdTg3LNPSDhDenqj28pv43urqpJU01tE+kVhLUHR2/GYnoVQjLlfb/O37td6C0LnZ9XWfkBl5SPA\nCUpLLZw6Vcirr36ZoqIyjCpmxtx3Z2eM030XERb2PDabYw82NFJdzfW2NQNfJyrqP/nsZ79w/ZeE\n5S7z0KrpLSKDobCWoOjteEz3cHPupdo5iqZ0Ar/AmKOezJkzZzl/fjqpqSlenwVZwHPAOowe8hKW\nLv0B7e3P4dqTHwf8DmNOu4nY2Gu0tPwEo6zoFYzKaP/h8j2xsVO89pg1Dy0ig6WwlqDo7XhMz3Bb\n7lIYpKHByoIFu64vLmsB/oF9q9XVqzbuv38zpaVrenyWUVrU8XFbW6rb12MxVoR/B8ee6k20tKzC\nqP8dS2Rkk8u2LIhlzhz7QjdXmocWkcFSWEvA+XIetfPQcF/h5r5PGTbjHLaNjZNdnlldfRqoAyYB\nTURHv097u+PZMTEfc/Wq4+Pw8L8QE3Mzra2OeyYm3sq8eYc5e3YkKSlW2tvDXY6wTE4u46WXHvXr\nvyNVLhMRO4W1BFxP51E79557Kh/qjWtP+Z/ANYzDMIxqYBbLRS9D369h1P22MW9eM7Gxjmd/97uZ\nfOtbRiETi+Ui+fnfIDfXtcb41KmfsnfvCurqjIM6GhutREc79/4fHVS49jYtICKisJaA8zbk7d57\ndi8f2ltYTZpUh2Pl9SGc545HjPgB+fkP88QT53Ad2o4HrEAR77wziowMG3v3Oupul5be7vKMvDxj\nq1hP88z+HtrubVpAREQVzCTgfKnK5R5Wb74J2dlv0Nho7b7GXiXs3XerMHrKBzAWejm+b8aMO0hN\nTfFS4awZo/DJclpbV7hUN3O/f0ZG8fUiLF9mz547AFi27CSf+cxmFizY79Euf1DlMhHpjXrWEnB9\nrYb2drBFa2sUhYXLgV0899yXWbfuKCUlnVitMcDNGNXGwFix7Riurq39kIwMSEq6wsKFO6iqGk9S\n0mWgg2PHYl3mod17r84nfZWWHqKk5C1Gjap2mR+/eHEPZWUr8PcwtVaMi0hvFNYyIN4WRCUmxvf6\n9Z7mdD0XjD0HrMIeqJ6lPH+CUdP7MAAjRjzDzTfPor7+LJWV36Gy0kJpqY2srF28+ebd3W2Jiemk\ntXU39pO2ej4cxCg9arWGYbXux3PPt/+HqbViXER6o7CWAfG2IMo4PrLnr/cURp5bq27FOBrTGA72\n/PoM4JcYx1oa27UmT95BRMStVFZauq9zPuXKOewTEp4nPX2i18NBjLY6lx5twbPmuIapReTGUljL\ngHhbEFVfbyU7e7+X86h774m6b+OaNOk0V69eBuppb48lKanNrUjKOVpaEl32OZ84EU56uvftYO5t\nnTLlZrZv77l4iethHwtJTt7MuHFpNDaeIyHhM0yb5jgFTFuuRORGUFjLgHjbJ716dZHP51E7y8tb\nQFvbDv7yl3CgHputDat1JRBGUZG3gy+Wc+edr2G1Ovd468nLM+a43ed9fS332dNhH/ZtWYmJD3Zv\n3bLTlisRuREU1jIg3vZJL1z4v7ifRz1lSkGfC6YslgRiYqKxWpdgzEN3YAR9JpDgtd73nDlxFBW9\nhrElq5k5c+KwWBJYv/4LLFu2n7//PYk//nEbqak3M2WKY7GZL4u3+jN/rC1XInIjKKxlQLztk25s\njMJ5fjc9PdLrcLOd8xCyMWy+H1iBo7e8B1jutSf80ktLiI4+SkXFNVJSOsnLWwzAsmX7XRarffTR\nHj76aEX3YrPBcB7m96USm4iIvyisxS+MHuV8jIAdQVTU/1FefgvZ2W/0OI9rDCHbe9MzgCqce6kj\nR3aQkbGrx+pm3nq/jY2TXe7hz9XbzsP8PVVi05YrEQkEhbX4hdHDHIOx//l3dHQ8S1lZGGVlrvO4\nnr3p/wYex3FutKOXmpFB9/nWvs4LWyyf0NoamNXb5887rxL3XolNRCQQFNbiF3l5CwgL28mxY9do\nauqgq8v7PK7nnukXcD43Oioql8jIz2CxXGTjxvsAKC8fhXNIfvzxKI/n238JGDNmOg0Nz2CzJREW\nVk1q6nTS0nb5pcebmtrMqVMa8haRG09hLT5raLDy1FN/vL5q+zJz5sTx0ktLsFgSsFgSiI6Oxmpd\njrE4zHuouS/IioyMp7PTfu0YOjpS6ej4Bq2tNnJzd7F9ewoNDX93uV99/VngHpe2uf4S8DWysnax\nffsK/Gnr1kza2jTkLSI3nsJafLZu3VEOH7YPWdsoKnqN6Oij3cPAjmHiTGDP9TlnXELNfUHWqFEN\nxMUZ+5g/+eQ8Vms29gM39u/v5NSpXxAbm4QxFx4HtDB2bIpH227EquyxYzXkLSLBobAWn3lWEoun\nouJa99cdw8QJwHIyMjznlh2FRzqxWkfQ1PQdmprGMHv2LqZOnUBh4Rjsq8BttjAqK22MGPEMsAl7\nwE+btsujbVqVLSKhTGEtPjMC0V6TOxb4gKQkxypvX4aJ7QuyMjKKKS1d2v35iorR7N17B7CLwsJm\nHD3pZq5ds7gVRfG8r1Zli0goU1iLz/LyFnDy5C+prjZqcsMSYEf31/szTOzeE66t/ZCHH4aUFBsx\nMVW0ta3u/lpExA/Yvv3fe72fVmWLSChTWIvPLJYEJk26jepqx1B4VdX4Ad3LuSdcW/uhy2lZ8fHb\naWtzPCM19TZ/NF9EZMhSWIvPvJ07PdC5YeeecEYGLqdlRURYcV79nZbWNui2i4gMZQpr8Zn7udPJ\nyZvJy3t00Pd1HxKfMyee6GjNP4uI2CmsxWfuq8EnTLgViyWhuyBJZaWF5OSGfh8T6bk4bLGOmRQR\ncaKwHqYGcg5zT9ujPKuS9e+YyIEsDtM50iIynCish6mBnMPc0/aoG3VMpHNA19Z+QGXlasCic6RF\nJOQprIcJ957oxx/H0t+A7akH3FOP29+9X9cefBbGXuyv+9x+EZGhSmE9TLj3pJOTc+mpfndf3EN4\n40ajmIkxZ93Y3eMeSO+9N54V1GKv/1kVy0QktAUsrLds2cLrr7/OuHHjAMjJyWHevHmBepz0wT3o\nxo6dwuzZA1tx3VMIJybGU1fX3OMzB9v7de/BJyeXMWFCl1aMi0jIC2jPeuXKlaxcuTKQjxAfuQfd\ntGnXBtzL9TWE+1uvu69hc88580e1qExEhoWAhrXNZgvk7aUf/Fk729cQ7u8z+xo2V0lRERmuAhrW\nu3fvprCwkNtuu43169cTHx8fyMdJL/wZdL6GcH+feaNWlYuIDDVhtkF0f1euXMnly5c9Pp+Tk8Os\nWbOwWCyEhYXx4osvUldXR25u7qAaKwNXX29l9eoizp+PIzW1ma1bMxk71lxDyA8//Dtef91Y3Q02\nvva1Pezd+/VgN0tEJOgGFda+unTpEqtWreLAgQN9Xuu8QCnUuC/AupGys/NdCpdkZfl/X/Jg36+x\n0cratUddeuxmmpMO5s8v0EL53UDvN9QNh/frS8CGwevq6khMTATgrbfeIi0tLVCPEh84hpitQBFv\nvgnZ2W+YqvKX5qRFRLwLWFj/9Kc/5aOPPiI8PJybbrqJH/3oR4F6lPjAsSisCFhOa2sYhYUD3/vs\nbeW2L78diohI/wUsrPPy8gJ1axmADRvu4NSpzVRVTcJmG/wiLm8rtwsKVviruSIi4kQVzIaJzZvf\nu3685Wv0VrnM3mMuL4+goaGCcePSmDr1isdwuVZui4jcOArrYcIRrpnAHkaO7CAjA49tV44e8x5g\nA5WVYbz/vudwube91vX1VrKz9+skLBERP1NYDxOOcE0AlpOR4Qhf5/nnf/yjEyOA43DuOZeXjyI7\nO9+jHrh95faGDV9g1qxfcfHiOvpbC1zHXYqI9E5hPUz0VsjE9TSr3RjD5M04D5dfvnyGsrKnsQdx\ne/sOfvObh7vvkZ2dz8WLtzKQoXF/H/ghIhJqFNbDRG/bolznnxeRkPA8kycn09Cw+fqc9accPZqA\ncxCfOBHu5R4tDOQkL81/i4j0TmEtbvPPY0hPn8j27fe5XJOWthXnIIZ6L/e4D2OuO5bk5DLy8h4d\nwPN13KWIiDuFtfhU63vOnDiKil4D4oFm5syJ87hHTMxhzp4dSUqKtV8nYvnzkBERkVB0Q8qN9keo\nl5Qbqu/nSynQofx+vgjl9wvldwO931A3HN6vL+pZB0ioVfhSKVARkeBRWAeIKnyJiIi/hPd9iQyE\nVjiLiIi/KKwDJCXlnxirpkErnEVEZDA0DB4gWuGsymQiIv6isA6Q/i7ICsVgU2UyERH/UFibRCgG\nm+btRUT8Q3PWJhGKwaZ5exER/1DP2iRCseSm5u1FRPxDYW0SoRhsKqQiIuIfCmuTULCJiEhPNGct\nIiJicgprERERk1NYi4iImJzCWkRExOQU1iIiIiansBYRETE5hbWIiIjJKaxFRERMTmEtIiJicgpr\nERERk1NYi4iImJzCWkRExOQU1iIiIiansBYRETE5hbWIiIjJKaxFRERMTmEtIiJicgprERERk1NY\ni4iImJzCWkRExOQU1iIiIiansBYRETE5hbWIiIjJKaxFRERMTmEtIiJicgprERERk1NYi4iImJzC\nWkRExOQU1iIiIiansBYRETE5hbWIiIjJKaxFRERMTmEtIiJicoMK68OHD7N48WJmzpzJBx984PK1\nbdu2kZGRwcKFC/nzn/88qEaKiIgMZ4MK67S0NLZs2cLs2bNdPl9eXk5RURGHDh1i+/btPPvss9hs\ntkE1VEREZLgaVFhPnTqVKVOmeARxcXExmZmZREZGMnnyZFJSUjh9+vSgGioiIjJcBWTOuqamhqSk\npO6PJ06cSE1NTSAeJSIiEvIi+7pg5cqVXL582ePzOTk5LFiwwOv3eBvyDgsLG0DzREREpM+w/vWv\nf93vm06aNImqqqruj6urq5kwYYJP35uYGN/v5w0ler+hLZTfL5TfDfR+Q12ov19f/DYM7tybXrBg\nAYcOHaK9vZ1PPvmECxcu8LnPfc5fjxIRERlWwmyDWKb99ttvs2nTJhobGxk9ejQzZszglVdeAYyt\nW/v27SMyMpKnn36au+66y2+NFhERGU4GFdYiIiISeKpgJiIiYnIKaxEREZNTWIuIiJicacN6x44d\nzJgxA6vVGuym+NXPf/5z7rvvPpYuXcrjjz9OXV1dsJvkV3l5eSxcuJCsrCzWrFlDS0tLsJvkN73V\nwh/Kjh8/zr333ss999zDyy+/HOzm+NXGjRuZO3cuS5YsCXZTAqK6upoVK1aQmZnJkiVL2LlzZ7Cb\n5Dft7e089NBDLF26lCVLlrBly5ZgNykgurq6uP/++1m1alWv15kyrKurq3n33XdJTk4OdlP87tvf\n/jb79++noKCA+fPnh9x/gHfddRcHDx6ksLCQlJQUtm3bFuwm+U1PtfCHsq6uLjZt2sSOHTv4wx/+\nwMGDBykvLw92s/zmgQceYMeOHcFuRsBERESwYcMGDh06xJ49e9i9e3fI/Pyio6PZuXMnBQUFFBQU\ncPz48ZAsW71z506mTZvW53WmDOvc3FzWrl0b7GYERGxsbPefW1tbCQ835Y9gwObOndv9TrNmzaK6\nujrILfKfnmrhD2WnT58mJSWFm266iaioKBYtWkRxcXGwm+U3X/ziFxk9enSwmxEwiYmJzJw5EzD+\n3zJt2jRqa2uD3Cr/GTlyJGD0sjs7O4PcGv+rrq6mpKSEhx56qM9r+6xgdqMdOXKEpKQkbrnllmA3\nJWBefPFFCgsLiY+PD6lhK3f79u1j0aJFwW6G9MJbHf/3338/iC2Sgbp48SJnzpwJqQJUXV1dPPDA\nA1y4cIFHHnkkpN4NHB3T5ubmPq8NSlj3VG/8qaeeYtu2bbz66qvdnxuKvZi+6qnn5OSQk5PDyy+/\nzG9/+1vWrFkThFYOnC/14rdu3UpUVNSQmyscSC38oWwo/v0ST1euXOHJJ59k48aNLqN3Q114eDgF\nBQW0tLSwevVqzp07x/Tp04PdLL84duwY48ePZ+bMmZw8ebLP64MS1j3VGz979iyXLl0iKysLm81G\nTU0Ny5Yt4/e//z3jxo27wa0cOF/rqS9evJgnnnhiyIV1X++Xn59PSUnJkBw1GEgt/KFs0qRJVFZW\ndn9cU1Pjcx1/MYfOzk6efPJJsrKy+MpXvhLs5gREXFwcX/rSl/jTn/4UMmH93nvvceTIEUpKSmhr\na+PKlSusXbuWvLw8r9ebasI0LS2Nd955h+LiYo4cOcLEiRPJz88fUkHdl4qKiu4/FxcXM3Xq1CC2\nxv+OHz/OK6+8wtatW4mOjg52cwImVHqkt99+OxcuXODSpUu0t7dz8OBB7r777mA3y69C5WfVk40b\nNzJ9+nS++c1vBrspftXQ0NA9PHz16lVOnDgRUv+//N73vsexY8coLi7mZz/7GXfeeWePQQ0mnLN2\nFhYWFnJ/0V544QXOnz9PeHg4ycnJPPvss8Fukl/9+Mc/pqOjg8ceewyAz3/+8/zwhz8MbqP8xLkW\n/qpVq1xq4Q9VERERPPPMMzz22GPYbDYefPBBn1amDhXf//73OXnyJFarlfnz57NmzRqWLVsW7Gb5\nzV//+lcOHDhAWloaS5cuJSwsjJycHObNmxfspg1aXV0d69evp6uri66uLjIzM0lPTw92s4JGtcFF\nRERMzlTD4CIiIuJJYS0iImJyCmsRERGTU1iLiIiYnMJaRETE5BTWIiIiJqewFhERMTmFtYiIiMn9\nPyQ+uNKCpR6MAAAAAElFTkSuQmCC\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0xa813090\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "# Plot the Data (Optional)\n", "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "plt.scatter(inputs.numpy(), labels.numpy())\n", - "plt.show()" + "# grad_f will return a list of derivatives of f\n", + "# with respect to its arguments. Since f() has a single argument,\n", + "# grad_f will return a list with a single element.\n", + "grad_f = tfe.gradients_function(f)\n", + "assert tf.abs(grad_f(pi/2)[0]).numpy() \u003c 1e-7" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "JaFHyAG9nDET" + "id": "v9fPs8RyopCf" }, "source": [ - "## Step 2: Define our TensorFlow variables\n", + "### Higher-order gradients\n", "\n", - "We'll use Keras's object-oriented [`Dense`](https://www.tensorflow.org/api_docs/python/tf/contrib/keras/layers/Dense) layer to create our variables. In this case, we'll create a `Dense` layer with a single weight and bias.\n", - "\n", - "(**Note**: We're using the implementation of `Dense` found in `tf.layers.Dense` though the documentation link is for `tf.contrib.keras.layers.Dense`. When TensorFlow 1.4 is released, the documentation will also be in `tf.layers.Dense`) " + "The same API can be used to differentiate as many times as you like:\n" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "height": 34, - "output_extras": [ - { - "item_id": 1 - } - ] + "height": 276 }, "colab_type": "code", "executionInfo": { - "elapsed": 22, + "elapsed": 730, "status": "ok", - "timestamp": 1505502830753, + "timestamp": 1527005655565, "user": { "displayName": "", "photoUrl": "", "userId": "" }, - "user_tz": 240 + "user_tz": 420 }, - "id": "z9r-ZeyrXu3A", - "outputId": "6230a7a3-29fe-4d08-f101-da80425bad82" + "id": "3D0ZvnGYo0rW", + "outputId": "e23f8cc6-6813-4944-f20f-825b8a03c2ff" }, "outputs": [ { "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAEDCAYAAAAhsS8XAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXd0HNX5sJ/ZXrTq3ZLV3IvcDdgGGwOm2WCbHhJa6C2B\nUBISQioBfoQPkjhACA4QCIQSDITQbGMbsHHvVbZ6s7q0vc18f4xmJVltJa0q+5zDOXhn9s7dqzvv\nfe/briBJkkSYMGHChBkxqAa7A2HChAkTJrSEBXuYMGHCjDDCgj1MmDBhRhhhwR4mTJgwI4ywYA8T\nJkyYEUZYsIcJEybMCCNkgl0URVasWMHtt98eqibDhAkTJkwvCJlgf+2118jJyQlVc2HChAkTppeE\nRLBXVlayceNGrrjiilA0FyZMmDBh+kBIBPvjjz/OQw89hCAIoWguTJgwYcL0gT4L9g0bNhAfH8/E\niRMJVycIEyZMmMFH6GutmGeeeYYPP/wQtVqN2+3Gbrdz3nnn8dRTT3X6HUmSwtp9CKittvH8UxsQ\nxZY/4aXXTGfa7PRB7NXAU1dj5y9PrIfmYUgeFcnya2aQmBI5uB0bYE5WNPHS/9uE6JcHYukVucw8\nPWOQezXw7NhcyCfvH0Bqfi+uumkO4ycnD3KvBpY+C/bWbNu2jdWrV/PCCy90e291tTVUj+03EhIs\nQ7qfWzfls2tzMTNPH01UrJEv/3eU5LRIVnx/5mB3rUP6azw3fnaMQ7vLOX1RNrVVNvIOVZGeFcPS\nq6YNmT6GmlP7KYoi/3ltF9WVNhacO4btXxfi9fi5+Mpc0jJjhkw/+5t9O0r5Zu1xDEYtpy/KZuOn\nR4mOM3HlTbNRqTo3UAynv3swhOPYhymSJJF3sAqtTs35l05mQm4K6VkxVJY2UVdtH+zuDRgOu4ej\n+yqIjDYwbW4a514yiYTkCMqKGnC7vIPdvQFjz9YSqittjJuSxNTZaVywcgoAX3xwCL9PHOTeDRyH\ndpej0ai47PqZTJyWwoTcFOprHBzdf3KwuzaghFSwz507NyhtPUzfOVnehLXRRdbYeLQ6DQATp6UC\ncGhv+WB2bUA5sLMMv19i2pz0gEaWNS4BUZQoOlE3yL0bGDxuHzu+LsRk1jH/nDEApI6OZtL0VFxO\nLyfLmwa5hwNDU4OT+loHozJiiIw2AjB7QSYajYrtXxXg9foHuYcDR1hjH6bkHawCYOzkxMBnmWPj\nMJq1HDtwEt93YBJ7PT4O7CrDYNQwPrfFhpo1Lh6AgmPVg9W1AaWyrBG/X2JCbjIGozbweXqWbIIp\nLawfrK4NKMX58kI+Oic28FmERc/UOWnYbR7yDn53tPawYB+GiKLI8SNVGEzaNvZTtVrFhKkpuF0+\n8o+OfKGWd7gKt8vHlJmj0GrVgc9j4kxExRopzq/7Tixw5cWNAKSkR7f5PHV0NIIApUXfEcHevEMb\nnR3b5vPxU5IAqChpHPA+DRZhwT4MKS2sx+XwMmZCYjuH0MRpKQAc3lsxGF0bUJQXNWdiYpvPBUEg\ne1w8Pq9IyXdAW60oaUAQ5Gig1uj0GhJTIqkqb8Lj9g1S7wYGn89PWXE90XGmgBlGITrWhN6gobIs\nLNjDDGE6MsMoRMUYSUiOoLKsacQ7zaoqrGh1amLiTO2uZY1LAKDgWM1Ad2tA8Xr9VFVYSUi2oNNr\n2l1Py4xBkqC8pGEQejdwVJQ04vOKZJyirYO80CeNiqSpwYXD5h6E3g08YcE+zJAkiZKCOswWHUmp\nHcdpJyRbEEWJupqRGx3jdvloqHWQmGLpMCciMcWCOUJH0fEaRHHkLnBV5U2IokRKelSH10dlyOaZ\nkW5nD5hhcuI6vJ48Sh6fyrLvhiM5LNiHGQ67B6fDS2JyZKdJXgnJcqxrdeXQj8vtLcpv6ywJSRAE\nMsfF43L6RvTLXF4sa+Kn2tcVkkdFodGoKCvqu8b+zjtv8f3vX8Fvf/ton9sKNUX5tWi0KlLSOl7g\nFDPVSJ4LrWm/dwszpKk5aQMgLimi03u+C4JdCeFLTOk8YSMlLYqDu8qpOWkjtRPBN9wpb/YzpHai\nsas1KlLSoygpqMdhc2OK0Pf6WWvWvMsf//hnkpNTet1Gf9BY76Sxzknm2DjUmo51VXlnBye/I3b2\nsMY+zKitkgV7fGLngj02wYxKLVBdaRuobg04VRWyYO/MHAUQlyCPkTJmIw2/X+RkeRNxCWb0Bm2n\n943KaA577IPW/vTTf6C8vIyHH76ft99+s9ft9AeKUzQto/MMW61OQ1xiBFWV1hHve4Kwxj7sUDT2\n+C40drVaRVyCmdpqG36/iFo9stZvSZKoKrditugwWzrXQKNijahUwojLxH17/XF25VXj8fhx+Hzo\nGh1s/+vmTu8X/SJ2RA59egTDxhMd3jNnQiJXLh7TaRsPPPAztm79lj//+UUiI4dWDZ76Zl9SXBfK\nDshmqZqTNqpPWgM295HKyHrjvwPUVNnQGzRERHa9pU5ItiD6pREn1ADsVjcOu6fbIl9qtYqYeBN1\nNfYRWXlU0Ty7W7hVzddFf181VYlApbUhhDLHYxPMXd6XnCbPl5PfATt7WGMfRng9PhrrnM2JJ11X\nx5Tt7BVUn7QGbO4jhZPliuO0+98VlxBBbZWdpgYnUTHtwyKHI1cuHsNdV83g1b9upuhELdf/cG63\ntvN/v7ydpgYnN99xxoirrFpX48Bk1rXJuu2IlsiYRqYxsiughjX2YURts2bSlRlGocWBOvLsy4p9\nPZiyvLGJshZXWzXydi71tXYMJm1QDtHYeBM+r4itaWTFcXs9PqyNLmLiu1+0IyL1mCN0VJY2jcgd\nXGvCgn0YEbCvd2NLBIiNN6NSCdSMwMiYqoqeaeww8hyoPq9fFmixwe1CouPkBa6+ti8L3NDT9Otq\nHED3ZhiQQ2ATUyNx2D3YbZ7+7tqgEhbsw4hgHKcKao2K2AQztVWyA3WkIIoS1ZVWYuJNHWZankpc\n8wtfO8J8DbLfAKI7yLrtiNhmjba+WRD2hnfe+YDIyKHldAzY1+O7F+xAIEu5sa734zAcCAv2YURt\nlQ2VWgj6ZU5ItuD3S4GogZFAU4MTr8dPQlJwfgNThA6DUTPinMi11fIi31E5hY5Q5kx97cgSaMrc\nDkZjB7luDEBDnbPf+jQUCAv2YYIoitRW24mNNwcdvjgS7eyN9fILGR1r7OZOGUEQiE2IoLFeXhBG\nCjXNpqXoYE0xMSYEoa+mmKGHUjYjJi44wR7VPG/CGnuYIUFDnRO/TwzKDKOg3DuS7MuNzZpWVJAC\nDVrMMSOpdk5AsAepsas1KiJjjNTXOEaU47Cuxo7ZokdvCC7AL6yxhxlS9MRxqqBotU0NI2cSN9bL\nmlZUTHAaO7Qkrijmi5GAYpazRBmC/k5snBm3y4fTMTKODHS7vNitnqDNMAAGoxaDUUNDWGMPMxRQ\nbMTdZde1Rm/QojdoaGxw9Ve3BhzFFNMTwa68+HUjJORRkiRqquxExciZtcESHXCgjoxxCETEBBHq\n2JqoWBNNDc4RFVRwKn0W7B6PhyuuuILly5ezbNky/vKXv4SiX2FOQdG6eyLQlPubGpyI4sjYfjfW\nOzGatEFFxCgoERMjJTLGYfPgcfuCdpwqxI4wB2rAcRpkRIxCdIwRSQJr48hReE6lz4Jdp9Px2muv\nsWbNGtasWcOmTZvYt29fKPoWphVNDU7UGhWmCF2PvhcZbUT0S9itwz8xxe8XsTa6Ag6wYNHq1ETF\nGKkbIaYYRTAHa19XiGkWgL0NeWxdtvebb77ijTdeDfq7lZUVfPHFp0Hd+/jjv2bjxvXd3te6lMCa\nNe/x2Wf/C6r9qICdXR6HTz75L9XVLUdJPvnk7ykqKgyqraFKSEoKGI3yi+bxePD5RvYRXINFU4OL\nyChDj9PBFQ2/sd7ZI3vsUMTa6EKS6FVpgMgYIyX5TjxuX4+0/aGIIpCCTU5SUByHvY2MObVs7/z5\nZ7a7x+/3o1ar231eXl7GF198xnnnXdCrZ3eE4gyPjDawfPllQX9PGQfFEf+//33EzJlTSUrKAODh\nh38esj4OFiGZ4aIosnLlSoqLi7n22mvJzc0NRbNhmnG7vLhdvnZnWgZDZLQszGVTTudlTYcDgYiY\nHpqjACKjlHFw9SiyaCjS0EuNXatTY4nU98oU07ps78UXX4LFYuHIkUPcd99DPP74r7FYIsnLO8r4\n8ROZP/9MnnvuaQRBQKvV8OyzL/Dii6soKirkppuu5YILlnLllde0af+ZZ55k9+6dpKSktonaOXr0\nCH/+8zO4XC6ioqL5+c8fIzY2jnvuuQ3RFUt1XSGxH5Zht9sxmUycccYCfve7x3jpJXk3UVlZwcMP\n38+rr77JK6/8nW+++QqHw4lWSmTS9HvYsGEdR44c5sEHH0Sj0fL886t54IF7ufvu+zh8+ADl5eXc\neee9gKzZHz16hB//+AE+//wT3nnnLfx+H5MmTeEnP/npkKrBExLBrlKpWLNmDTabjTvvvJPjx48z\nZkznJUDD9IymZufnqYf0BoMiBEdCZExDLyJiFJQFztroHPaC/Vv315ROK+RP+ZsRCnomTJxjPfh8\nInnfbGgjiGYkTmXlmKWdfu/Usr2ffPLfNt8vLS3mT396AYCHH76Pn/zkp0yZkktEhIamJg+33343\nb731Ok8++f/atb1x45eUlpbwz3++TU1NDd///hUsXXopPp+PZ599iieeeIaoqGjWrfuCF19cxc9+\n9kskScLpsHPdlT9l6VXTWL36bwBkZGTi9/uoqCgnJSWVdes+55xzzgPgssuu4oYbbsbn9XPj9+9k\n566t3P/Idbz33tv88pe/ICGhbWGwRYvO5fbbbwwI9nXrPuf6639IUVEh69Z9zgsvrEatVvPHPz7J\n559/wvnnX9Sjv0V/EtI9aUREBHPnzuWrr77qVrAnJAyPioNDoZ/VzdUMU9OiO+1PZ58b9HLFO5fD\nNyR+S1/64HHKCUaZ2fE9bidttLxb8fukbr87FMapK9xuH6oIAU0v6uyrNSp8PhEkUKtbBLPJqOv2\nd6tUEBdnJjragsViwNj8HYNBy8KFSwPfP/30uTz//HMsW7aMJUuWkJSURHS0CZ1O0+Ezjh07wIoV\nl5KQYCEhwcK8eWcQGWnEZquhoCCfBx+8F0mSEEWRxMREEhIsCAhkpE4nMTmShAQLZrMes9lAQoKF\npUsvZuvWTdxyyy1s2rSeZ599loQEC7t2bebll1/G6XRSVV9FWXkaCQkWtFo1ktQyL7RaNTExJsaO\nTSczM4OKigJGjx5NeXkpixcv4I033uD48WPccceNSJKE2+0mLS15SM2bPgv2uro6tFotFosFl8vF\nli1buPXWW7v9XnX10C9OlZBgGRL9LC2WDyJWaYQO+9NVPyVJQqNVUV1pHfTf0tfxPFkhn5QjIva4\nHalZhlWUNnb53aHyN+8Mr9dPbN5YZo45gwvPn9Lj7x/aU87GT49x9sUTmDA1uc217n63KErU1trw\netVYrS6cTg/V1VZcLi8+X8vcXLHiGqZNm8uWLV9z5ZVX8swzq2hocODx+Dp8htPpwWZzB6653V6a\nmpzU1dnIysrm+edXt+uny+VFY9Gh0amorrZit7uRJDXV1VZOO+0sHn30p8yaNQ+/X8JojKGsrJZf\n/erXrF79OvHxCTx8/2+wNjgpL6vH6/W3+f1er5/6egfV1VYWLDibd99dQ0ZGJvPnL6S62orV6mTJ\nkou47ba7ejR+oSDYxaPPUTHV1dVcd911XHrppVxxxRUsWLCAhQsX9rXZMK1QzCi9McUIgkBktJHG\nBuewzzhsqHNiMut65fxsbYoZziip8PGJPQvxU1Ac6LZ+DPUrKyslOzuHa6+9nilTplBcXIjJZMZu\n79hpO23aTNau/RxRFKmpqWHXrp0AjB6dSX19AwcO7AfA5/NRUJAPtBwy0tE7MWpUGmq1ilde+TuL\nF8tmGI/HgyBAZGQUDoeD44W7geY5ZTJhs3UcMbVw4WK++mpDG5POrFlz2bBhHfX1ssLV1NREZWVl\nr8aqv+izxj5+/Hjef//9UPQlTCcoSTmW6N5FtcihfnacDi8mc8/CJYcKfr+IrcnV6yPN9AY59r1p\nmMcuK6nwSjninqIIdmtTb8YhOHv+O++8ya5dO1Cr1YwfP47TT58PgFqt4cYbv8eFFy5r4zxduPBs\ndu3azvXXX016egYzZswCQKPR8LvfPcmzz/4fNpsNUfRz5ZXXkJWVjd8vtfk9p7J48RKef/5P3HLL\nnYBsJl62bAXXXXcVKSmpZGeNw14vv1sXXbSMxx57DK1Wx/PPr27jO7BYLGRmZlNcXMiECZMAyMzM\n4pZb7uT+++9CFCW0Wi333/8QycnJHfZlMBCkQVLjhvJ2V2GobMtff/5b/H6R6++e1+H17vq5ef0J\n9m4rYcX3Z5CcNnhlV/synvW1dt56aTsTpiZz9sUTetXGO//YQUOtg5t/cmanEQxD5W/eGbu/Lebb\nDflcddOcwCEiPcHn8/PS018xKiOaS66Z3g89bEt/jeen/zlAwbEarr9nXq+UleL8Wj5+ez9zFmQy\ne0HmkP+7KwyYKSZM/6JoqpG91NahVSz7MI6MCZQS6GFyUmssUQZ8PhGnffgesqBkS0b38pg/jUaN\nyawb9lmX1kYXGo0Ko6nr4/A6I1AMrH5kZOGeSliwD3FsTW4kCSKjei/QomLkRUERjsORvsSwKyj2\n2OFsjrE1m1D6Mg4RUfrmeTV8fS7WRheWXiTsKUREGlCpBJrqh+9c6IqwYB/iBBynoRBoI0Fj78OB\n1C3JWsP3ZbY2udHp1d0e3NwVkVEGRFEatsfDuV0+3C5fnzKpVSoBc4QOm3X4zoWuCAv2IU5LclLv\nJ7GinQxrjT0g2Hs/DgHH4TBd4CRJwtroIiKyb6UhlO/3Z2RMf6LsWvpaIiMi0oDd6hmRVR7Dgn2I\n05dQRwWVSsASbRjW205rkwuDSYtW1/tAroDGPkwFmsftw+vxY4nU96kdRSAO13FQ+t3bKDGFiCh5\nHEdCgbxTCQv2IU6LYO/bJI6KNuJyyjVnhhuSJGFvchNhCZFAG6amGGujLIAi+qipBmLZexXyOPhY\nlV1sCDR2kP1YI42wYB/iNDXI3v++xp8PZzu72+XD5xOJ6KOmqtGoMUcM34gQJfbc0kdTjPL94TYO\nu3fv5KGH7gv0uzNTzD333MbRo0e6bU/Z+diaXPzpT39i587tverX22+/idvdsjg89NCPsdsHt0R0\nWLAPYSRJoqnBiSW6995/BWXbaRuG207lRY6w9L3ssCXagK3JNSztqoqG3dcFztI8F4abYAcQBLoV\n7MGiaOyNDU7uvfdeZs2a06t23nnnTdzulrF86qlnMZsHt9Dc8C5MPcJxu3x43H5S0ntvX1dQzBjD\n0Z6oLEZ9FWggh41WljZht7r75LcYDBRTTF8FmlanwWDU9Eiwu1wufvnLn1JdXYUoilx//c0sXnxu\np2V1y8pK+b//exybrQlJEvjtb58gNXUUq1Y9x9atmxEEFddddxPnnHMeu3fvZPXqvxEVFU1BwQkm\nTJjIo4/+FoBvv93Mn//8DNHRMYwdOx6ApkYnGq0qEBnkdrt5/PFfU1RUSEZGBh5PS7TP9u3f8vLL\nf8Pr9TJqVBqPPPIYBoOBK664hLMXXcAXm7/Er1vGV9veZNas09HrDfzvfx/xm9/8AZB3Cf/+9xs8\n8cQzPP30Exw9egi3282iRedw00238u67b1FTU80999xOdHQ0zz33PFdccQkvv/xP3njjNZKTU1ix\n4nIAVq/+G2azmauuupZ//euffPnlF3i9Ps46axE33dR9fa2eEBbsQxjlxeurLRFaBPtwtCfam0In\n2C2tQh6Hm2BXNHb/hv+yY9WePu065tg8iH6J/IffBsAyew4JV1zd6f1bt24mPj6Bp556FgCHw95l\nWd1f//oXXHfdjaxYsZTy8jpEUWTjxvWcOJHHa6/9m/r6Om6++TpmzJgJQF7eMV5//R3i4uK4444f\nsn//XsaPn8hTT/2eP//5RUaNSuOXv/wZ0D6Gfc2adzEajbzyyr84ceI4N910LQCNjQ28+upqnnvu\nr+j1Bt5441Xeeut1brjhZvk3R5pZMu8uRqfHcrSkVB6XOafx9NN/wO12odcbWLfuCxYvXgLAbbfd\nhcViQRRFfvSjO8jPP87ll1/Nv//9ZqCcsYzcr3PPXcJzz/0xINjXr1/LM8/8me3bv6W0tJiXXnoN\nSZJ4+OH72bt3D9OmhS4TOCzYhzC2EAo087DW2BUTRN8XuMCBG43D7+ARa5MLlUpAq1XTVxe4oBKQ\n/CKSJJs3uiM7ewyrVj3HCy/8hTPOWMC0adPJzz9Bfv4J7rvvruayuhLx8Qk4HA5qaqpZsEAuBqjV\nypr1vn17OPfc8wGIiYllxoxZHD58CJPJxKRJk4mPjwdgzJhxVFRUYDAYSU0dxahRaQAsWXIhH3zw\nH3kXm9ayKO/Zs5srmhelnJwxjBkzDoCDBw9QWJjPHXf8EEmS8Pl8TJkyLfC9JUvO57//ymuj7KjV\nak477Qy+/vorFi1azJYtX3PXXT8CYN26z/jwwzX4/X7q6mopKCggO3sMIDX/pyD//9ix42loaKC2\ntob6+noiIyNJTEzinXfeYvv2bdx007VyXXmni9LS4rBg/66gCGFzH6NBWrcxHCMhrMoCF4JxaHEi\nD79xsDW6MVv0JF55NQl33dKn2ibfrDvOvu2lrLxuJkmp3Z/MlZ4+mpdffp0tW77hxRf/wty5p3PW\nWYvIzs5pV1bX4ei4iuOpma6t/60IfwC1WoXf3/HS5fPKu5RTzVGtfVBKu5IkMWfO6Tz22O86bMto\nNBIRaWj3TixefB7/+c/bREZamDhxMkajkYqKct566w1efvmfmM0RPP74r/F4uleSzj77HL78ci21\ntbWcc86SQL9+8IMbuOSSFd1+v7eEnadDmIBtOQQCTa2WD8Iejs5TW5MbQQCzpe+VKZXdj32YmaT8\nPhGH3ROyc2t7GvJYU1ODXq9nyZILuOaa73Ps2NFOy+qaTGYSE5P46qsNAHi9XtxuF9OmzWTdui8Q\nRZH6+nr27dvDpEmTO31mRkYmlZUVlJeXAbB27Wf4mmuntx6H6dNn8PnnnwCQn3+cEyfyAJg8eSr7\n9++lrEw2s7jdLkpKits8IyJSj8ftlw8faWbGjFkcO3aUDz9cEyjVa7fbMRqNmExm6upq+fbbzYH7\nuypJvHjxeaxb9zkbN67n7LPPAeC0007n448/xOl0No9tdaAEcKgIa+xDmFBq7CAvEDVVNiRJGlLn\nM3aHvcmFKUKHStV3PSSwcxlmC5xijuprcpKCEvIYbJJSfv5xVq16DpVKQKPR8sADP+uyrO4vfvFr\n/u//HueVV15CENT89rdPsHDh2Rw8uI8bbrgGQVBx5533EhMTS2FhQZtnKXNTp9Px4IOP8OCDPyI6\nOobc3OmcrKiT+99KsC9ffjmPP/5rbrjhe4wdO45Jk+QDSKKjo3nkkcf41a8ewePxIggCt9xyB+np\no1Hs4Ip5T1kwQD7qc968BXzyycf84he/BmDMmLGMHTueH/zgKlJTR5Gb22LSueSS5TzwwL3Exyfw\n3HPP07q8cVZWNg6Hg4SEJGJj4wCYM+d0iooKuf32GwEwmUw8+uhviYkJnWkwXLa3Cwa7lOcH/9pD\neXEDtz54FuoujkELtp99LXXaV3oznqIo8dLTm0hIsbDyBzND0o9X/vQNOr2G7912Wkj6OBCUFtbz\n0Vt7mTUvg7lnZfW5nzUnrbzzj51MmZnKmUvGhbCnbQn1eH79RR77d5Zx+Q2zSEju+1F0u7YUsXVj\nAVf/cC4xCb2vQzRQhMv2jgDsVjdGs7ZLod4ThmPIo8PuQRSlkJijFMwWPXbr8KpuGKr6KAqBujnD\nLJY9lKGvcjtKlNTwS9zrirBgH6JIkoTN2vc0+tZERA6/kMdQJeW0JsKix+cTh1V5hUCSVojGQT5R\nSh1wTA8X7FY3KrXQp+qWrVHGczifVdARYcE+RHG7fPh9Ysjs69CqNsYwKlVqDziQQ6OpwvAM/VQW\n41Bp7CDb2a2NrmG1c7Fb3Zgj9CHzEQV8DcO48mlH9FmwV1ZWct1113HRRRexbNkyXnvttVD06zuP\nLYQhfgrDWaCFUmMfjg5UpU5MSOdDpB6vx4/X4+/+5iGAKMqRQaGIjlIwRegQhJGnsfc5KkatVvOz\nn/2MiRMnYrfbWblyJfPnzycnJycU/fvOEuqIGBie2afWfjDFBBY42zAah0YXRpMWjVYdsjbNES0L\nvU4/9APkHHYvktTS71AghwHrh/VZBR3RZ409ISGBiRMnAmA2m8nJyaGqqqrPHfuuE8oYdgVThKzp\nDCfB3qKxh84EEXAiD5NxCPhbQjgGMPwWuP5QdkAOIW1qdCGKw8ck1R0htbGXlpZy5MgRcnNzQ9ls\nv2I/sA9nfv5gd6Md/TGJ1WpV83Fg7V9kT2UFjsOHQvasUKE4y3p7aHFHKFv51uMgSRKi14vo9SL5\nhpZT1enwIvqlkO5aAMzNC73d2lI0S3S7se3Zjd/WtuyszWbj/fffDfxbKaHbEU8++XuKigq7fX5X\nbbRGKcMbeCeC0NhffvnFoMvwRkQakEQJR/MC9/bbb+JyOrHu2IansmJIlOHtKSHbf9ntdu69914e\neeQRzGZzt/cHG4/Z3xT+8xU8dfWkLL2IjB9ci1rfdtIMVj+V1OnRmXHExoduPKNiTVSWNRIfFwGS\nSPlHH1P15QYchUUApF9zFaOvvrL3HQ9RPxUcNg9R0UYSE7tPew+WSItcVsDr8ZOQYEH0ejn8uz/Q\nsGcvxwFUKjK+/z3SLuu/lO+eUOlpBCA+IaLN+PV1bqamRcv/I8lt+Z1ODj3zJE2HDiOo1URNyyV1\n6UXEzJqJ293IRx/9h1tvlZNqoqNN6PWaDvvw9NNPtPm3co8oim2SzLpqozVarZqYGBP2etlhmjIq\nqsvviKLIT3/6QPcD0ExisoXjh6vQqNUkJFh49+03mJF/HOn4CZIvWMI//vFy0G0NFUIi2H0+H/fe\ney+XXnoGQVm+AAAgAElEQVQp5557blDfGSpJIEm33U3l6r9R8dHH1Gzbwagf/wRdQiIwuMkqNVXy\nc90eb7d96Ek/DUYNol+iuKgW19frqHn3bVCrMU+bjqesjJI3/43D4SFu2aV9/g196SfIafQ2q5vU\n9KiQ/x10ejX1tQ6qq61UvfUGDXv2ohuVhikhDmt+AUX/fANfXDLmyVNC+tzeUFoip5urNEJgHEIx\nN33N1SGrKps4WVpD2XPP4Dx2FOOEiYgOBw27dtOwZy8Zj/2Gx//2V4qLi1m27BJmzz6NM86YT0ND\nE7fddme7Urv33HMbd999H+PHT2DJkrO46qpr2bbtW+6++8fY7fY2ZXg9Hl+733FqGV673Ul9vYP6\nCh8V1cf42aOrEVRSuzK8F198Cdu3b2XlyivZunUz8+efGVQZ3sYGG3GWCZQUTeTdF/8fVSdP8ugX\nnxEdHcOq85exaNHZg16GVyHYxTwkgv2RRx5hzJgxXH/99aFobkAxZmeT8cvfUPOfd2hY+wU1775N\n6h13D3a3sFvdGIyhdZZBS9hgQ1k19o8+QB1hIePXv0MTFYW3tpbS/3uC2g/eR9DpiD3/wpA+u6co\ntt9Q25ahJUnJumsnDWu/QJeSyuhHHiUpLZ6SbXspfuL3VP79RTIe+y2a6OiQP78nKONgajZBbF5/\ngsK8GsQ+HhaimJSP7KvEsXsHWXlHiZg9h5RbbkdQq7Ht3kX5qj9R9cY/uf32uykszGf16jcAWUB2\nVGp36tRpbZ7hdDrJyRnDD394Gx6Ph6uvXtGuDO+pdFaGt7qqhgN5a3nxpb+RkBTdrgyvTqdn1aqX\nALnMMARXhvfEkZM89PC9HDt8mNOLi3lPq+WPj/6G1IVnN4dVDn4Z3p7SZxv7zp07+eijj/j2229Z\nvnw5K1asYNOmTaHo24Ch0ulIuOp76DOzsO3cgfuUQkEDTX8kJykoNvvyz9Yjud3EX34lmqgoALRx\ncaQ9+DDqqGhqP3i/nZ11oOmPUEeFCIset8tH+auvIOh0pNx+F6pmM5whK5uEK67Cb7VS8dILSOLg\nnrak2MAVm3ioUELBRb+Ir64Oc+40Um6+DUEtKxMRM2ZinjET57Gj2Hbvavd9pdSuIAiBUrunotFo\nWLhwMQBFRYXtyvB2xJ49uwPXWpfhPX7iCI22kzz007u48cbv8emnH3Py5MnA95SCXa1pXYbX7/ez\nZcvXnHmmXE543brPuOmm7/PL395Do/Ukx3ftQLTZEIwmLDNntYqVb1+G9/jxvEAZ3m3btgbK8N50\n07UUFxdRWjq4MqTPGvusWbM4fPhwKPoyqAiCQPzyFZQ9+wy1H35A6l33DFpfPG4fPm9ok5MUApl2\nReUk5owhct78Nte1cfHELDmfmnf+TeNXm4i98KKQ9yFY+iPrVEFxwDk9kPW9a9GPGtXmevQ55+E4\nchj7nt3Y9+4hYkZo6tT0BsWpp8yHeYtzuPSq6SExT73+/Ld4GxsZW7uDhB8/jqBpKxISr7qGwoMH\nqPv4w3YLXDCldnU6Xa+SiToqw+tyekhLnsA//vFih98xGjs+OKW7MrySX8Odt/4Ya1UtKrMZVSft\nwOCV4e0p4czTVpgmT8WQnYNt905cxUWD1g8lWsPcLwJN1vpcmggSr/0BQgcVE6POPAtBr6dh/dpB\njRCx2xRNNfTjYDLJWqkvbhSR889sd10QBOIvXQlA49eDuwPtL40dwKgRcUtajJNz0aWktruujU8g\n9qKlaB0ObLW1PW6/dVZrR2V4O6KzMrwW4yiq6gq6LMPbEd2V4XW6rZRXH8EraIg9/0LM5oghV4a3\np4QFeysEQSDuUnnVrf1wzaD1w94PMewK2qZqAPyJ6RhGZ3R4j9pkJmr+Anz1dR1uwQeK/opbBlDX\nyMJFGJ/b4eIGoE9PR5+ZhX3/PnwNDSHvQ7DYbW40GlW/JBFprDVIggrDWZ0HPcScfwGRlkhydDqu\nu+5q/vrXP7W7p7WG3dn/63Q6Hnro5zz44I+4665bSOlgIQG5DK/D4eCGG77Hm2++zqRJU/B6fKgF\nI8vOv5lf/eoRrr/+Gm677SaKAwpY57sCpQzv1q1bmDdPXsRbl+F96onfkBSdgU+tJ3rxOYEyvD/6\n0R3t2u6sDO95553P7bffyPXXX82jjz6M0+notD8DQbhs7ylIkkTJE7/HdeI4s/72PFbVwJ+LeWhv\nORs/OcbZF09gwtTkbu/vSYTEybfe5D8FSSTGaLns9vaaqoLnZCWFP/8phpwxjP7ZL4Lue6j6CfDZ\n+wfJP1rNdXefEXKtffsfVrFDmMycuUnMXjyx0z42bFhP1euvEb/ycmIvWhrSPgTLK3/+Bp2ubZnh\nkETFNNTz6ZNvURI1kcuun0liSuchpZWvrKbp602kPfAwpgkTO73vVEIVWVZfY+etv29n4rQUFl04\nvs/ttWl7/Vo+3tSI0xTLzQ8uGtJnFYTL9vYSQRCInLcAgLqt2walD/Z+qBMDIIki9p3b0IsunGLX\n2p8uKRlz7jRcJ47jzD8R0n4Ei8Mun5xkNIXWBOEuL0dVIv8mp6/rqCPL3NMRtFoav/lqUIpl+f0i\nTrs3kDUcShq+XI/eKzvIFbNXZ0SedjoA1m1bQ96PYLDb+m/3Ztu1E53fgU8Uhk3dnO4IC/YOiJg+\nAwSB2i3fDsrzbf1kgnDmHcNXX49Rr8Jh93QrqKIXy9tza6tjwAYSu9WDyaxDpQqtBtX09Sb0Pnvg\nGV2hNpmImDUb78mTOPOOhbQfweC094+fQZIkmrZuwaCSfSjdFYYzjp+AOioK687tg+J3USKkQlkA\nDMBvteI8dpSICNkRPFzKK3RHWLB3gCYqCuOYsTQdPoKvqWnAn99iYw/tJLZukxeqiPhI/H4Jj7vr\nF9Q0YSIqgwH7/n0Drq1KkoTD7gm5pir5fDRt+Qa9SYtaLQRV4TFqwVkANH61MaR9CYYWB3Jox8FT\nUY6vpoao0SmAnOHbFYJKhWX2XES7HfuhgyHtSzD0lyPdtncPiCIxaXJSYncL/XAhLNg7IWLGTJAk\n7Ht2D/izbVY3Or0arS50zjLJ58O6cwfqqCgik2SnT3eTWNBoME2egre6Gm9l+xjl/sTjluvRm0L8\nIjvzjuG3Womae7qcpBSEhmYcPwFNfDz23bsGXFvtLweyfd9eAGInjmnznK6wzJVt/IqCMJD0V0CB\nbfdOABInZAItoaXDnbBg74SIGbOAlj/8QOKweUKumdgPHUS02bDMnoup+eVw2LufxObmTEJbsyAY\nKPorxM9+8IDcbm4uZoseh82Dv5sMTkEQME/JRXS5cBUMbME4RZMO9c7Fvm8vCALxM+SSCcEscIbs\nHLTxCdh270Z0D6wA7I8FTnS5cBw8gG5UGjHpSfJzutm5DBfCgr0TtAkJmLMycRw+hN85cLWa/c1H\ntoX6Rbbtkhcoy9zTWqr6BTGJzVOnyvfu3xfS/nSHsuiEWmN3HDyAoNFgHDs+ICQUO3ZXmCdPBhhw\nM0TAaRjCcfA77DiP52HIysIQG41OrwlqLgiCgGXuaUhu14BXArXb3KjVAnpD6Hax9gP7kXw+ImbM\nDJykFLaxfweIPf00JJ8P+/6B01Zb6oKEVrA7jxxGZTJhyMoOtN2dXRVAExWNPjNLNmEM4ALXHxq7\nr6kJd0kxxrHjUOn1LQePBGGGMI6fCCoVjmaNf6DoD03VcfAgiGJgN2a26II+VcvUXBTNcWSABbvV\ng9kSuiPxoGU3HjFzVuDIwWDeieFAWLB3QdzpcwGwD2CSjqNZezSZQ/cie2uq8dZUYxw3HkGlajk5\nJ0jtxDw1F/x+HIcGTqgFxiGEgt1xWNa2TZNk4WQyB7/AqU0mDNk5uAry8XeSldgf2PvBFKPY1825\nzYI9Qq6b4/N2H+pnyM5B0GpxHDkSsv50h9/ffCReCHctkt+Pfd9eNHFx6NNHN5+jGjbFfCcwZWSg\njo7GcfTIgEWFOPohCkJ5CZXEkp4INICIZgFg3zdw5pieHKoQLIq2bWo2q/Rk5wLIJXwlaUC1VbtN\nPrZOG6Iqn5IoYj+wD3VUNPrmzOOWk5S6HweVVotxzFg8pSX4rAMTMRYI+QzhrsVdXITodGKeMhVB\nEFCpBExmXdh5+l1AEARM48bjb2rC26qKXH/SH84yx1G5SJsi2I1mbY+0E31GJmpLJPb9ewes0mGo\nw/wkScJ+8CBqiwV9Wnpz280CLQgnMoBpkrwgOAbQzu6whfbwZldhAX6rFfPU3IBZo+UkpeDGwdg8\nj5xHj4asX13RktcRwnfimNx347gJgc9MEXrstu7zO4YDYcHeDcaxcvqy89jATGJFyChadV+RJAnn\nkcOoLRZ0qXIFQ5VKhdEUvHYiqFSYp0zF39SEp6wsJP3qDiXr1BCirFNPeRn+xgZMkyYHasP0VGM3\nZGahMhqxHzwwIC+/z+vH7fKFdtfSvCgpTnHo+dmnioLgODIwVV0d/RDDrrzPxrHjAp+ZI3T4fWK3\n+R3DgbBg7wbjOFmwO/IGRrAHJnGItp3eqpNytun4CW2KXZkidEFlnyooL4DzeF5I+tUdoc46DZhh\nJrWciNRTk5SgVmOaOAlfTQ3eATiwvT+Sk5S/n6KwyO03C/Ygk3MMGZkIegPOgRLsIfa3SKKIM+8Y\n2oQEtLGxgc+VMOCRkKQUFuzdoEtJQRURMWAae8AUEyKNXdGqTi3cZI7Q4fOKeNzB1cYwjh0LDIxg\n74+sUyVMUQlbBNDpNWi0qh5FQgSiQgbAkRyIkArRIi+JIq4Tx9EmJaGJbCn4pZg4gtXY5XDRcXgq\nK/A19H952lC/E56yMkSHo83iBq1MUiPAzh4W7N0gqFQYx47DV1uLt7am35/nsHnQaENXotXZiWBX\n4sODSVIC0CY3L3An+l+whzrrVBJFXMfz0CYno4mOaXPNZNYFbWMHMI1vti/n9f84hNqR7ikvQ3Q6\nMeaMbfO5orH3xHFomiDbph1H+z865tSjAfuKsvtWduMKph7kdwx1woI9CEwBO3v/F4Gy290hsyVK\nkoTjyBHU0dFok9qW/+2xGUIQMOaMwVdT0+9aWqhj2D3lZYguF8bsMe2umSL0uBxeRDE4k5Q2KUle\n4PKPh6RvXRHqyKCAGWZMW8FuNMsFsHq0c5kwSf7OAJye5rSHVmMP2NfHnaqx93yBG6qEBXsQKBPA\n2c92dlFsLtEaqi1nRTl+axOm8RPbJXa0bL+Df5kVgdDf5phQZ50qZYcNOe0FuzlChySB09GDBS47\nR17gGvv38I1Ql6pV/m6GUwS77EzXYg8iA1dBP3o0KpMJ59H+F+x2m6f5oJG+h3xKkoTz2FFZ2UlI\naHOtJToorLED8MgjjzBv3jyWLVsWiuaGHPr0dFQGQyBEqr9w2r1A6JxErhOyVqnYx1ujJED1RDsZ\nKMEeao3ddUIW7MacnHbXerpzATlJB8DVz3XqQ+08dR0/jspsRpfc/vAWU4QuqNIKCoJKhTFnDN7q\n6n6vgKr4W0KRdeo9eRJ/UxOmcePbtWfqYeLeUCYkgn3lypW8/PLLoWhqSCKo1RjGjMVbWYmvsbHf\nnhNq779SsEoRRK3paagfgD4zE0GjwXm8f80QIR+H/BOoDIZAuGdrejMOxmbN33mifwW70idjCHZw\nvoYGOfs4Z0yHRwGazDo8bj/eILJPFQxZ2QD9WhhNFCWcdk/ozTCnOE4BjCYtKpUwIsoKhESwz549\nm8jIzo/VGgmYAuaY/rOzh7rgkzM/H0GnQz8qrd21nhQCU1BpdegzMuWsvX6s7hdK27LfbsdTUY4h\nK7tjgdbDJCUAQ1YWCEJgR9RfOOweDEYtanXfX9PO7OsKyjj0RGs3ZCuCvf8WOJfTiySFbpFX3l/j\nuHHtrgmCIIcBjwCNPfSn445QAtvvgnwss+cAYPPa2V65G5WgIjMyndSIFLSq4IZUkiRKqmwcKqzH\n6/Oj16pxVcs1SEKhnYhuN56yUoxjxiKo29smjQETRM8msXHMGFwnjuMqyA9E2tS7GjhQewS7186M\nxFySTAndtNKCy+OjqNJKQYUVm9NLXJSB6pPyGZmhMEEoQqejXUvrZ/RES1MZjOhSR+EqKkTy+RA0\nGlw+F2W2SmpddVg9NqbGTySxB+NQ3eCk+KSVmkYXjXYPybEmbFY3lsj+ta8rKHPObvMQGR3cOb+G\nTEWwFwQ+a/JYOVhzBLVKjU6tY5pxLALB/4ayGjvHShpweXx4vCKG5jyLUNVOchacQGU0ouvkIG1T\nhI6aShuSJA3ps0+7Y9AEe7CHsg42Sj995qmUCgL+smI0ESJrDn/G+vxvcPtbBIJereOHs65mUdYZ\nnbbncvt4d30ea7cXU9voanMtFRiFig0HK7GkRzNtbPCC4dTxbDxYDJJEzKTxnY61KUKH2+Xr0d9C\nNWsa9Z99iqqiGPuUFJ7f/k8K6ksC1z/K/4zxcdlcPuVipiVP6rSftY1O3vz8KGu3FeM/JSJlAgIR\nCHyxr4JLzxpDQkzvDxR3VpYCkDRzKrEd/E7RKz9b8kuBvgUzHo1TJnLys1JM9joKLV6e3foyTW5b\n4PoHJ/7HOTkLuHzyxUQbOt7NSpLEoYI63t9wnG2HKmmdKyYAs1FRWu9k8+EqLpqXiVbTdoHuyd+t\nvCgfQaMhbfZU1Pr2QjIxSW5Lq1YF326ChbKUZNyFBcTGGvmycAtv7H0fu7elCqjmoIZrc5dz4biz\nUQkd7zx8fpEvthbxxbZi8kraOqQjgfGo2Ha8moRJSSyYntprgeuz2zlWWUlU7lQSk6La/5wECzGx\nZqrKrUSY9CEvGT2QDJpgD8XJ5f3NqSes65JTsB4/zs8+e4I6dwMx+miWZi3BrDVT2FTCjpO7+eu2\n1yisKueirPPaTEBJktiTV8O/1h6jtsmN2aDh9MlJ5GbHYTHpcHv9HPy2GFu5lX2FdWx9YTNn5qZw\n9TljMXYT097RSfB1u+WEHCk5vdOxNpq0NDW4evS38CXIduqSndt5Ub0Zl8/FxNhxTImbiElrZGvF\nTo7WHucPm1Zx69TrmBrfItwTEixUVDby4TcFfLatBK9PJCnWxPQxcWSlRBJl1lHX5GbfF3l4PH4+\n2JTPf78uYMVZ2Vxw2mhUvXiha/fLBbs8cakd/k63V3ZY19bYqa62djiWHZI6GoAvv3iP12MLUQkq\nFqbNJ9mUiFqlYm3RRj4/vomvCrfx4xm3k2ZpqyHaXV5Wf3yY3XlybkRWSiRzJiQSH2Ug0qyjsKSB\nE5sK8UgSf//gAO9/mcfV54xj1viEwFgG+3cT3W5s+QUYMjKpa/IA7XcnIvKqUlHeSHxK8AuGdnQW\nrq1b+MM7v2efUIlBrWdZ9gVEaE04vE6+LPuKV/e8y9aivVw/+WoidW3bLqux8/f/HqKo0oogQG5O\nHLPGJWAx6dBpVRzeW0HV4WqqrW6een0H//06hmvPG0dKnDnoPiooNeRVqe3fCWU8NVp58SkuriMu\nIaLHz+hvgl10QybYR0LhnO7QjE7HU1GOVFXDBdPO56LMc1GrZC3qtJRZLEybx1/3ruZ/hWupdzdy\n7YTLEQQBUZR4Y+0xvtxVhlolcPEZGSw9IxO9rq0GdnJfJTbgnqum8eaXJ/hqXwWHCuu5fflkclLb\naxhdoURsKHbQjjBF6KmtsuP1+II+hk9jiUSKjcZZkI97ZgLXTbqKuckzA9fnJs8krz6fv+59mb/v\n/ye35t7A5DjZP9Fk9/D/3t7L4aJ6Yix6li/IYt7UZNStbN+SJLH/02MkJ0bww9mjeG/jCd7dcIJj\nJQ3cvHQSEUZt0GMgiSKu/BNok5JQR3T8khqMvXOYGZtNO1VH9hC5KI2bp/6A7KjMwPXTk2ezsWwz\n7+V9xPP7/sFDs+8hSi9r7gUVTTy/5gA1jS7GpUez8qxsxqZFtVEEIlUCJyhk/oxR5Khh3c4yVr2/\nn0vmZ3LJgqwe9dVdUgx+f9dzQTHN9cDGDqDPysK6dQuewgJy58zmqvHLida3zNWLpy7iua//wcHa\nI/xt32vcN/P2wDuzbmcp/15/HJ9fZP6UZFYuzCHmlNBOZ7mVqsPVfO+C8aw7Ws3+/FoeW72dW5dN\nYvaExB711VUom4wMWZ2PnzIOTrsHgt8wDzlC4jz9yU9+wtVXX01BQQGLFi3ivffeC0WzQwqv38s2\nXSUAC8VMlmYtCUxQhWRzIg/MvovRllFsqdjOloodeLx+Vr2/ny93lZGWEMGvb5rLZQtz2gl1kF8q\ntVpgfGYsj14/m6XzMqizunj6zT0cKqzrUX9dBfmoLZFoYuM6vcds7rkDtdZZzwmLG6Nb5Oa0S9oI\ndYWxMdncnnsjgiDw0v5XKWgspqzGzk+e28jhonpmjI3ndzefxpnTUtsIdWjJOjVb9MyfmsKvbpzL\n5KxY9p2o5TevbKemIfjDPjyVFXKmZQeJSQqCIGDsRbnWfK0Nl05gVK3Iw3N+3EaoA6hVahann8ml\n2RfS4G7kxX2v4vF72Hm0isf/uZPaRheXzM/koWtmMC49up15QVlooqMMXLV4LI/dMJuEaAMfflPI\nX98/gKsHhapcRYUAGDK6EGi98DUA7NLLO46JNjO3TP1BG6EOEG2I5I7cG5mdNJ2CpiI+yP8ESZJ4\nb+MJ3vjiGEa9mrtXTuWHSye1E+rQ4sxNSbLw4ytyuXP5FNRqgefXHGDtjpJ293dFQLBndqXs9G4c\nhhohEex//OMf+frrrzlw4AAbNmzgsssuC0WzQ4qPC77ggFGO1811xXRq54vUWbhl6nUY1HrezfuQ\nJ975ht15NUzMiOGn184kNb7zLaTdJod1CYKARq1i5Vk53L1iKn5R5Nl39rEnL7iSBr6GBnx1dRiy\ns7u0R5osPZvEkiTx5tH3qIyRp022tXPn5vjYMdw85Qd4RR+vHnybJ/+1g8paB8vmZXLXyqmdmpdO\nTaOPNOu478ppLJ2XSU2ji6fe3E1NY3DCPbBr6SB+vTXmCB32HhREs3nsvHbk31TGa7FYvZjdnX/v\nvIxFnJ48myJrCX/Z9i9e+OAgGo2K+66axvIzszstcnZqyOeohAgevX4OE0ZHs+tYNb//xza8vuBC\nE92FhYBcfrkzeqOx76s+yIeu3fhVkNOk79SGLggC14xfSaIpnnXFm1i1di0fbykiMcbIo9fPZua4\nzlXj1rH8giAwe0IiP/3eTCLNOv61No/3NgYfkeMqKGhWdmI7vUcJKuhJstZQJJx5GgQn7VWsL/kK\nX1IcqFS4iwq6vD/WEMOKMctw+92UGzczd1Ii9105DVMX5zVKUnO87ikOmxnjEvjRFdNQqWDV+/vZ\nd6K22/4G4tezOtdMAMzmniVkbKvcxeG6YwGNx11U1OX9U+InMit+FtWuKlxRedxxWS4rzsru0lau\nCJbWsdsqQWDlWdmsODNLFu7/Ck64K9Ea3Y2DyaxD9Eu4Xd1rwZIk8caRd2n0WIkeI0cFKZpgRwiC\nwDUTVhKvTeaE8xCaqHruv3IaU7I630lBx4WvIoxa7r9qOtPHxLMnr5oXPjiIr5uDuEHW2AW9ocPE\nJIWeFkRz+py8ceRdVFodmrQ0vKWliN7Ov2vQGPjh5O+jktQckjaQkqziZ9fOJD6qa8d4R+WbM5It\n/PwHs0iKMfLxliI+3VrcbX99TU346moxZGV1qewoCoUzrLGPbCRJ4p28D/FLflZOvBR9Wjru4mIk\nX+dCQJQkDuw04a9PQB1Vx4QZTWi6iUV2OeV6JR3F607OjOX+K6ejUgk8/8EBik927TQLVrD3ZNvZ\n5LHybt6H6NU6zjvjGvk5XQg0gEa7h6Nbk5G8OvTp+czO7d7x4+iiLsiy+Vksbxbu/+/tvTi6EcTu\n4iJQqzuM429NT8Zhb81B9tUcZGx0NhNzF7Y8pwsKym1U7pPNIElTCsgZ1X3OR2dJWhq1ijuWT2ba\n2Hh259Xwj/8d7nKnIbrdchz/6NEdxvG3xmTWBa2xf160AZvXzgWZ5xA1Zjz4/biLuxawhw77cBWN\nQ9B4mTCnhqggok4cNg9GU/vyzfHRRh64egbRETre/vI4Ww5UdtmOMle72rVA730NQ42wYO+GvTUH\nOVx3jImx45iWMAVDZhaSz4e7vPMDJ9798gTbDlUxyn0GBrWeTwq/aBMW2RHdnZw0Lj2aW5ZOwuPx\n8+w7e6lrcnV4H7QW7F072QICLYhJvOb4/3D4nFyacxEJcaloE5PkOO5OhIrXJ7Lq/f1U1/qZrJ+P\niI+Xd/272+d0V6L1kvlZLJmTTkWtg+c/OIC/kxOdJJ8Pd0kx+lFpCJquHcPBVroUJZGP8j9DQDYt\nKDZrxYbdETUNTv7yn/34bdGMi5hMtfsk31bs6PI50PU4aDVqfn7jaeSkRrLl4En+u7nz57uL5bBX\nfWb3DldThB6n3dNtQbQ6Vz1flnxFtD6KxekLMGQpOR6dL/Q7j1bx7/XHMTtyiNPHsa1qBycd1V0+\np7vyzXFRBu6/ajomvYbV/zvMwS78UO4gHKcARlNYsI94vH4v7+V9hFpQc8XYSxAEAUPzC9LZJN5y\nsJJPtxWTEmfivhWncXb6mdi8djaVbu7yWQFbYhfJSbMnJHLl4jE02Dw89+4+3B2kf0uShKuwAG1S\nMmpT1yFhwdZJqXLUsK1yF6nmZM4cdToAhsxMRLsdX017u78kSbzxxVGOlzYyd2Iid5x1PuNixrC7\n4gDHG7rW8oMpJ3Dl2WOYlhPHwYI63lrbcfanp6ICyefDkJnZ5fOgbXJOV+w4uYdK+0lOS5lFkjkR\nTXQ06sjITk1STreP597bh9Xh5drzxnL9tOXoVFo+PPEpTl/nCzO0ONI7K99s1Gu457Jc4iL1vP9V\nAbuPdSwkXUWKwzCzy+eBPA6SJO8eu+Kj/M/wij4uyb4AnVrXUlqgsOPSAkWVVv720SF0WjX3XT6D\n5WMvDCySXeH1+PF5xS7nQlpCBPdenosgwAtrDlDdiXM9GMcpgFqjQm/QhJ2nI5mvirZT56rnrLQz\nSN51YmQAACAASURBVDLLoVXKit/RJC6qtPLqJ0cw6tXcc1kuEUYti9PPxKgxsLZ4Iy5f5xqhI8ia\n00vmpLNoeiolVTZe+7T9IdvemmpEpxNDN1tOCH7b+VnReiQkLsg8J+AgU7a0rg78Det3lbFpbwUZ\nSRZuvGgiKpWKZdnny20Vru/yWcEcqqBSCdx6yWRGJZhZt6uUTXvL292jaNHKgc1dEczOxS/6+Tj/\nc9SCmosyzwVk+7l+dCa+ulr81rbmMUmS+Pt/D1FWbeecWWmcPTONaH0USzIWY/XaWF+8qcs+OZr9\nLV3ZgyPNOu65LBedVsXf/nuI0mpbu3sCAi2I+dCShdv5PC2xlrG9cjdpEanMSZ4BgDYxEUFv6NAU\nY3V4WPX+frw+kdsumUxGsoUZCVPJsKSzu2ofRU2dR7Z0ZZZrzbj0aK49bxx2l49V/9nfTuGRJAlX\nQQGa2Lg2B4x0hnK62HAmLNg7QZREPjwiv8jnjl4Y+FyXOgpBpwts7RRsTi+r3t+Pxydy89JJJMea\nADBpjUFp7cEWvhIEgWvOHUd28zb8y91tTUKK9qgfPbrb36jRqtHp1V1O4hpnHdsqd5FkSmRGYss5\nmYqgcDVHXCjklTbw5to8Ik1a7rlsKnqtHNaZHZXB5MRxHKo7SnFTaafPC/ZlNuo1/OiyXMwGDa9/\nfoyiyraC1V0s90s/OrPLdiC4sgJbKrZT46pjfuppxBlboioMGfLC4TrFzv7p1uJANNTV57SEWy4e\nfSZmjYmNZZvxdGKeCzjSgygtMTrJwg8vnoTb42fVf/bjPCUM0l1UhMpgQJuY1G1bxiAW+k8L5UV+\n+ZiLAou8oFJhGD0aT0V5mxpCoiTx9Bs7qWkO7Zw+Nl6+XxC4NOdCAD488Wmnz+rJwe4Lp49i4fRU\niqtsvHqKwuOrq8NvberWDKNgMssZ2X7fwBzc3h+EBXsn7Ks5RLn1JHOTZ7aJzRXUavTpo3GXlSF6\n5IknShIvfXQoMIFnnFIKYHH6AowaY7PW3vEWvCfHf2k1Ku5cPoUIo5Y31+ZxpJVtUXHkBaOhKc/r\n6kX+vOhLREnkgszFbcLZFE3Y3cq+3OTw8MIHB5GQuGP5FGIjDW3aWjHxAgA+K/qy0+c57B50ejUa\nbfe1t+Ojjdy8dBI+v8hf1+zH4WoxIbiKikClQp/eteMUujdJ+UU/nxauR6vSckHm4jbXlJ1L63E4\nWlzPuxtPEB2h47ZLJreJ1derdZyZdgZ2r6NTW3tXjvSOmDMhkQtPG83Jeif/+KRFqIkuJ57KCvSj\nM7p1nEL3C1yNs5a91QcYbRnFhJi2NWf0ozNAknCXtSzaH35dwK4jVUzJjm2XVDU+dgzjonM4Up9H\nqbX9jgtaFcULsk7M984dR05qJN8ePMmGVgpPixkmSMHeA9/TUCUs2DtAkiQ+L/oSAaGNtq5gyMgA\nUcRdKk/iT74tYn9+LZOz2k9gAKPGyDnpZ2L3Odhcsb3DZ/a0VG1spIHbL52MKEk8+c8d2Jrtoorm\nqE/vXmMHWai5HF78HYTN1bsa+LZiB4nGeGYlTmtzTW0yoU1KDjhQlcWt3upm5VnZjB8d0669qUkT\nyLCks7f6AJX2kx32x9HDEq3TxsSzdF4G1Q0u/v5fOUJEEkXcJcXoUkeh0nbfVncF0fbWHKTe3cAZ\nKbMD2aMKp2rsjTY3z39wEAGB2y+dQmQHv2Vh2jw0Kg3rSr5ClNqPe7C7ltasaM5e3XGkivW7ypr7\nJDtOgxVoxm58DV+WfI2ExOL0s9qZiJQdorJjPFhQx0ffFJIYY+TWZZM7DHFdPPpMud3Srzt8Xkeh\nr12h1ai4Q1F41uUFdnGKshOMWQ5GRmRMWLB3QF7DCYqaSpgzahrJ5vZpywFttaSIYyUNvL+pgBiL\nnluWTeo0RvvMUWegUWnYVLq5y5fZaAo+ZX5SZizLF2RR0+Dk7/89hF8UcRcVoYmL6zSF/lSUhcTl\naO8w21S2Bb/k57yMRe2ybEHeFYgOB97qaj7eXMjBgjpyc+K48PSOXyBBEDg/82wkJD4v2tDuut8v\n4nL0/ASp5QuymZgRw57jNXy2rQRPZQWSxxP0rkWtVmHo4gShDSXfALAwbX67a5rYOFQREbiLChFF\niRc/PEiT3cPli3IYlx7dYXuROgunJc9s1oAPtrvem8ObNWoVt186BYtJy1vr8sgvbwoqMak1gRju\nDsbB4ZWVkmh9FDMTc9tdNzSbvNwlRdRb3fzto4OoVAIPXzen0zIQk+MmkGiMZ0flbqye9v6B3pz5\nGhtpaN7FSc27OF/LLjZowa4cQhMW7COKdc2OrUsnLunwuiLYrScKeOED+bT62y6ZTKSp8wkYoTMz\nO3E61c5aDte1P4HIYfc0F/rv2Z/k4jMymT4ugX0nalm34SB+a1PQmgl0blf1ij42l2/DrDExO2lG\nh99VIi3yt+9nzdcFxEbquXlp54sbwNT4SSQa49lZtReb197mmtOhnCDVs6p6ijM1yqzjvY0nKN4j\nH9emzwh+HMzmjk8QKrGWcaKxgImx4zpc5AVBwDA6A+//Z++9oyS560PfT3WOk3ty3JyjNiqsJAQS\nCiRjHgbDRRhjHDg8Xb/jc1+wr6/TxX6PCxiuMRgso4vBZIQQKGu1knalzTnvTs6xezqHqvdHdfX0\nzHRPV3XXzG6P+nMO54jpqq7f/vpX39/3942jo/zqlYtc7pli++oaHtzdsuDz3tVyDwICL/W8Ns8B\nnm+jkUq3lc8+thFRlPjGL87jV8JeVUTEwMLRQW8OHCWaiHJv850ZN3lLQ4Ncvri7i2/98gLTwRgf\nuX8VazKc3BQMgoF7W+4iLiV4vf/IvM+12NjT2bKymkf2yae4J399iXBvD6bKKoxudQW0SqaYZch4\naIIL41foKGtldXXmI6y1sQmMRgbOX2HKH+VDB1Zk1c7SOdCyH4BDfW/O+yzfLjEGg8Cffmwn5S4L\nxw+eBtRrJpDdvnxq5Cz+WIC9jXdgMWbWuJQN5PQbZzAIAn/4/k05i3QZBAN3Ne0lLsbn2ZgLaVpc\n7rTw2ffJpqmzb+QxD65kB6HobOfjweRvdW8GbV1BmYdTh85QU27j04/M7zE7lzpnLZtq1tPl66Fr\nTmRIPqYYhY0dVTx2ZzvjvjCjF6/JjlOPumJZNocFQZgv0BJigoN9b2I1WrizcU/GewWTCUtzC6He\nPq71TLBzjYcHdub2b+yp34ndZONQ/xFi4uy5L2QePnB3B2tbKrh0sYfE1JSqYAIFRw7TXDFQEuxz\nODxwFAmJu5Lx2pkQTCZC5R5c02NsX1HFQ3vULZpWdzMdZW1cGL/CaHCmNEAsliAaSeTdJabCbeVz\n79tIXVj+zkRt5iYCmZhJzpn9Mh/qO4KAwN2N2WvLm5plrbQiMMZv37eKlU3qKlDubbgDs8HE6/1v\nzTJL5auhKaxvq+T9d3VQ4RtBQsDctLDWnI5ycvGnNTKejvo5PnyaWnsNG6rnt1JTiNfKpYwbouP8\n4Qc24bSpM6cdaJI3+jcH3p7190Ln4X13drCp2YUjMEmgok6V4xRkJcHumF8Q7ezYRaYiXvY27MJh\nzl4CIFBei0FMsMYa5vGH16mqm24zWdnfuJvpqJ+Tw2dmfabFkT4Xo8HAH7x/Ix2CbGcPVOSOClIo\n2diXGQkxwZuDR7Gb7OyY4yxM5/zNca7HnJilBJ/YVampTviB5v1ISBzqnwl9DGl0EmVibWslO8rk\nF/IHF0JZMzLnkmkR90730+nrZn31GjyO7DVNfnFsiCmTi6b4FA/snN9PNBtOs4OdtdsYC41zZWIm\nwUirsywTj+xppSE2yZiljGeOZY62yIQjJdhnopYODxwlLsY50Hxn1gJXsbjIDy7K9+yqiNHRoL5F\n5NqqVVTbqjgxfJpQfCaxphBNFWQB/cntZRiQuBiyc6l7UvW9mWK4lY3nrizaOsDIZJBDI7IA/vB6\nKw6VmxvIG5yAwBsDb836e9BfWK/TCpeVRzrkMT3fk8AXVCeoS6aYZcaZsQtMR/3srd+Z1fwwPBHk\nn5++wIhdFniG4eylBTKxvXYzZRa3XNI3IduU83GWZaJ8eoSIxcGZ4Rg/Oaiu6l0mU8yhPtneqWiU\nmXjr4hDPvd2D112DNRpE1Nip/u5m+USUblstVKABJMZGMSViTLk8PHO4S3VFzJR9OdlvVZREDg8c\nxWIws6dhfmlihe+/dJXzkxA3WamYHtE0VoNg4M7G3UTFGMeGTqX+rkcTa9PYIAAjtiq59rvKcscO\np4V4TCSajIcfD01weeIaK8rbaHRlLiIWiSX4p5+fp9con9hck5kjnrJRba9iXdVqbnq7GUxGSyUS\nIuFQrOAuRmU+OSP3pljGN35+XlXRNKvNJNfoLwn25cGb/UnNpCmzZhIMx/jqT84SjMTZcfc2gJyF\nj+ZiMpjY23AHoXiI06Pn5O/N01mWTsLvJz4+TvmqFdRVO3n+aC+vn82tsc7VTkLxMMeHT1Ftq8xq\nfugemubffn0Zm8XI6js2AvMTdHLR5m6hxd3E2bGLTIbldmip7NsCBFqkV/491u/ZjNlk4F9+dYHB\n8UCOu2bmwZ8U7NcmbzIWnmB77Rbspszmh4On+3nt9ACtdW6cHe3ERoY1N/ne27ALg2DgjYG3U05U\nPZpYK+ty54Ht+EMxvp4hIzMTc9fD4cFjSEhZbeuiJPHtZy7SM+Jn3a4NcvXTXm3vBMD+xt3y8waO\nAoX5W9KJ9HZjcDhZtbGdK71TfO+FqznLM880tS4J9qJnJDjG5clrrKrooN453x4nihL//MsLDE0E\neXB3C3fcK0eKaBVoAPsa5GbYR5LOQz00VeVlcrS384UPyxmZTz13Jecx3GY3z3KYnRw5Q1SMsa9h\nd0bzw4QvzNd/dpZoXOSzj22kZu2qWc9XiyAI3N20FwmJI8nYfj02OGUc9RtW86n3riMUSfA/fniG\nqRyOsJQpxidfd3hQFjCKwJnL2RtjfO/5q7jsZrm+fFurnKDTp635Q7nVzZaajfT7B1NO1KA/e+Er\ntUR65cqW++/blsrI/PavLuYs8JV+gkuICY4MHMNusmUMcQT46cEbnLg6yrrWCn7noY1Y6hsI9/Qg\nqTQFKmyp2YDL7OTtoRPExLgu74QYDhEbGcHa2spnHt1Ia62LQ2cGePlE9sxnBSVxr1g7w5UEexJF\nuNzVON9pKkoS//aby5y/OcHmFdX89r2rMNrtmGvr5BK+Gn/8WkcNqyo6uDp5nbHQuC6mmHBaEkZ9\nlYM/+ZCc/v8/f3aOgbHsGqviMFM0pCMDxxEQ2Nuwc961/lCM//GjM4z7IvzWgRVsW12TSoTKVbo2\nEztrt2IxmHlr8ASiJBIMROXa2xra380lPUFr38Z6Pnh3B+O+MF/58Zl56fbppNvYg7Egp0fPU+fw\nsHJOZySQW9v90y/OYzQKfOHDW/BU2LG2JHMbNJ7gYMZ2/cbAW8RjCaKReEFrQUomz1kbGzGYzXz8\n3WtY21LBiSujPPX8lQXXa7rGfmH8Mt6oj11127EY54/ntdP9/ObtHuqqHPzRBzdjMhqwtrUhRcLE\nRrSZY0wGE3sadhKIBTk7ekGnTb5PTtBqacVqkes3lTkt/ODlaxy9tPD4lBr9UQ2dqm4nSoId2Z76\n9uAJ7CYbWz2bZn0mSRLfe+Eqb5wbpL3ezR+8b2OqNrS1pQUxGCA+kbv5xVz2N8ia4JHB4/os4jkZ\np2tbK/nUe9cRjMT5hx+con8B4a5oJ0OBYTp93ayrWk2lbXb4Zjga58s/OsPAWID37Grh4WQSkqmq\nCoPTSaRXm6YKcvOFHbVbGQ9PcH3qplx72zm/9rYWIr09mKpmErQe3d/OPVsb6Rn28z9/fo5INLM5\nIt0Uc3T4FHExzr6GXfMiO/pH/Xz1x2eIxUU+976NqUggm5J5mYcZQnaiVnJy5CyTPjlRpxDBHh0a\nQopGU2vBZDTw+d/aQmudrLH+7FDmKozpzw0GoryZNItkMsO8fnaAp567gtNm4n//7S2pMFdbARuc\n8k4cHjiqi8Ye7p1dN6m63MYXPrwFq9nIvzxzMWtFTCj+FnklwQ5cmriKN+pjZ922WU5TUZT4wUvX\nOHiqn5Zal1z7Oa0LUioDNQ9tdXvtZmxGK28NHk/VAS/UFCPHLM/UqblzcwMff/cafIEo//D9k/SN\nzM/uA7C7LMSiCd7slU1DiqlIwReM8qUfnqZz0Medm+r5yP2rUgJPEASsLa3ERoZJhNT3I1XY23AH\nMLPBFTIHce8UCa93VsyyIAh84sE1bFtVw8WuSf6/H55KlV9Ix2I1YTAK+H0RDg8cxSAY2DPn1HKj\n38sX//0kvmCM333PWrantXSzNDSC0ZiXYDcIBvbU7ySaiHKm/zKgjzkqPVHNYTPxnz+yLdV16Iev\nXEPMoLkr8z/pnebixBVa3U00u2eHzx483c+Tv76Mw2bi//joduoqHanPlLkP5zEP9c5aVpZ3cHny\nGmNTXnk8eig7afPQ0VDGEx/Zislo4J9+cT6rcz1XeYXbHV0E+6FDh3jooYd48MEH+da3vqXHVy4p\niq17X1LIAATCMf76X9/mpRN9NNY4+dOPbpuXfKMkwITz0E4sRgs767YxFfEy4Z1esPZ2LhKRCNHB\nQawt87vkvGtnM598cC3TwRh///2TnL0xfyErNeBP9VzAaXKwxbMx9dnAWIC/+e5xbvT72Luxjk89\nvG5eeKcyD1GN9mWAVRUdeOzVnB68KNfeLmhzk58/t06O0WDgjz64ib0b67jR7+Pv//0k497ZxdgE\nQcDhtOD1Buj3D7K5ej1llplMxbM3xvh//+MUoUiC33tkPfdtnx3eKZhMWBubiPT1IiXU9SJNZ09y\n7V3ol7OSC5qHLPWCypwW/vSj22iodvD80V65xO2cE4wiSPvGhxElkb1pm3xCFPnF6zd56rkruB1m\n/uxjO2irn53NaU3mNuSzwQHsa5Sf1z0qO/4Lm4ceBLMZS33DrL+vbq7gCx/egtEg8LWfnuVXh7vm\n+R6cRR7yWLBgF0WRv/7rv+Y73/kOv/rVr3j22We5cUN9g9lbTSAW5NzoBeqddbS55UV5Y8DLX/3b\nMY5fGmZjRxX/5eM7MpYLSBU+ykNjB9ifXMTT06FUE+t8CPb0yl1yWjIn5Ny7vYnfe2Q9kViCr/z4\nLN9/8eqsRsj25CKOhBLsqt+O2WAiIYocPNXP3/6vmbKrv//ohlmVChUUAZKPI1kQBFlrj8jfq4um\nmqEAmslo4DOPbuCBO5rpHwvwF//6NgdP9c/SWh1Oi6yhSTMCJhiO893nLvOVH59FkuBPPrSZOzc3\nzPt+5blSLEZ0WJt9GaDGXsWaipWMTckRQos1DzXldv6vT+xkXWsFp66N8bf/6wRXe6dSnyuCdGzK\ni0kwckfdtuT/D/H3/36KX77ZRXWZjT/72A5aaufXIzK6XJiqqvMW7Ns9m7EYLYwm5yHfkE8pHic6\n0I+lqRnBOD/BaV1bJX/2sR1UuK387NBN/vt3j+JNc7Ar85CpzEQxkJ+KmMbZs2dpa2ujqUnWYB55\n5BFefvllVuboDH+7cGz4FHEpwd76ndwY8PGrw12phtEfeWAN79nRlNXmayqvwFhenpd9GeSQv3pH\nHVLEgLUy/58ikOzmtFBFxzs3N9BS6+Kbv7zASyf6OHVtjPt2NHHXlobUIjbFrGyr3s6JKyM8/UYn\nfaMBrBYjv//oBvZtyt4I2VqAfRnktPJXzsjOaz00VVuWeTAIAr/zrtU0e1z88JXrPPX8FY5cGOL+\nHc1sXVWN3WkGUaDcUEGdqY3nj/bw/NEepvxRmj1OHn94/YIJSNbWVjgsR6RYG9Vn/yrsbbiD586f\nBPKfB0mSiPT2YK7xYHQ4Ml7jtMlNsb//4lUOnh7gi/9+kl3rannPrhbaG9wYTQKJMGz2bGRiUuSn\np65w5PwQkViC3etr+eSDaxdMQLK2thI4fYq4dwo86uqzKNhMVnZ4tjB8TijIkR4dHJA7aC1QVmJF\nYxn/9VO7+Oenz/PW+SFOXh7hwLYmHtzdoqo2/e1MwYJ9eHiYhoYZDaauro5z584V+rVLxuHXbuK2\n1fLTX0SIhk4AsKa5nA/cvYK772hldHThxtHWllaC58+R8PtVV1RUEASBXVU76ZQgYgzm/W8I3OyS\nx5KjNkprnZu/+NQufn7oJgdP9/OTgzf4+aGbNDsEagFLqIIvfvs6kgQCcNeWBj50zwoqciSJWOrl\nAlD5OMwAKm0VtFrlscfN+dfnCPf2YLDbMdXUZL1GEATu2drI5hXVfO+FK5y6Nsa1Pi9mk4GV9ghu\nrIhDTfyXf5ZzGkxGgQ/e3cF797blbEg+43PpgT3ZSzFkY1vtZl6NyzZ2m4Yqn+nEp6ZITE9jX71m\nwetMRgOffGgd+zc38IOXrnHs8gjHLo9gtRhZb4hgilk5fzzB4SHZgVpdZuV337OG/Zvqc54srS2y\nYI/09sIq9WUdFPY27OTZ2GWwJPJ2pCvm0VzlqxXz1MkbE/zwxSu8eLyXF4/3Umkzsgq41jfIPopD\nSU2nYMGeb5ynR+NOvliU9zVisVThrK6mbUMZ79ndxuZVM4Ih1zgDa1cRPH8O2/QYFR2Zj+gLsT+4\nk05OMMl43nMyeLMTwWikactaDJbcmt7nP7qDx9+/mVeO9/DayT68kWvgr0OYqmJ9exWbV9Zw59ZG\nOhrV1X4BGGxvI9DVTXWFDYM5u1DK9m9cX76Wq/gYZgCPJ3Ps+EIkwmGuDg9TtnEDtbW50/o9Hjd/\n9bkauod8vHlmgDfODBCMD+CmkcRoLVtX13Dn1ib2b26gXGX2Y9yxnj5AGh7I+7f0mGqJAn7HOOs8\nC6+nTM+Y6L4KQNW61arG4PG42bOliWMXhzhxeYSzN4eJ+idwBMqxhl3csb6Sh/a2cceGeowqhaxh\n01omngHT+FDWcS5Edc0WXox3EbIFcFeYsZltuW+aw/SY/Oy6LesoU/H8h+vKeffuVl4+1suxi8Pc\nnOwkOi4QIXbbyCotFCzY6+vrGRiYyXAcHh6mtjZ3NblcmvBS4XY5KBMcfOKTM45TZWwejzvnOMUa\n+eUbOXeZWEO75uf7RuQ42QlxnDOd17KmbWdDEkUC3d2Y6xsY90YA9RrvvnW17F3r4YsHj8BwHfva\nV/LgYzPt77T8RoaGJqTrNxg4dzWrlrTQfNojZYCPs5PnGRrOXP99IUI3roMkYahv1DRuh1Hg3Tua\n2L3RzZd+fhTGG/n9d+1gzUY5SS0aijIaUn8cN9d4mL5xk5ERX14+E1vcSVgI8XLnm7Q5s5/Ass3l\n+DlZ449X1WmahxV1Lvl/66Z58RdhhEAl//UTu1MmoYnxzBFVmYiVy9FCE5ev0Yz2dz0WjSMkjMRM\nYV64eDjl79DC1JVrIAiEnFVEVDzf43EzNRlk56pqdq6q5t8vXeTwwDH+eNunbxtZBeo3yYKdp5s3\nb6anp4f+/n6i0SjPPvss73rXuwr92iXD6ZKTc/I9eaQch3nalxUbXswS4a2hzK3SFiI2MoIYDmsq\nS5pOr7+fgbhc7yYRzj/LrpAIIYBIQN7gvMIklyauar9/AYehGo4NnyJqliNlCnGYWVtaSUxPk/BO\n5b44A2JIQLLEOTt2nmBMe/io1m5BczkyeCxlDss3httUXYPBbi/4nYibI6mINS2k/Ax1dRhs2rX9\nSCLKyZEzVNrKWVe1OvcNtyEFC3aj0cif//mf8+lPf5pHH32URx55pGgcpyB73RMFZJjJHdqteduX\nlZfHZIWjQydJiNpC5RSBls1hmIu3Bk8gGuIYjIU5itK7SuVD+sv81tAJzfcXItglSeLI4HEkc2zW\nWPIhFcedx3qQJEmO5XdZiIlxToyc1vwdkd4ejC43psrsDS6yMRme4vLENdxuuTZOvvOQym0YHiYR\nztzjdyGUd6LM7eCGt5ORYPZEokzEx8cQQ6G834nTI+cIJyLsadiZtarn7Y4uo77nnnt4/vnneeGF\nF/jsZz+rx1cuGWo61C+EYDBgbW6RO7THtH+H8vKsbmhjOurn4sQVTfcXItBiYpzjQ6dwW1w4XbaC\nsuysTc0gCPlvcIEoJrOB2rIazo1eIBDT5kyO9PSA0Sg3QdFIl6+HocAwq+rlzamgeSigxEIkHEcU\nJWoqKhAQNGuriWSbQmtLa15moLcGTyAhsaJWbpBR8AYnSQS7ta8H5bntHvm31DoPhZ7elAYwe+vv\nyHHl7Utxbkc6okdYk7W1FUSRaL/6+t8KyrH/jhbZtq2kcatFrfc/E+fGLhKIB9lVvx2nq7CiRwab\nDXNdHZFe7bVzYKb29r6GO4hLCY4Pq9dWpUSCSF8v1sYmBJN2t5FSUXBfm1yeVw+NPZ/QT+W55W4H\nG6rX0u3rZcA/pPr+mYxT7WtBlETeGjyGxWBmbf0KoHCTFID/Zqfme5WNdVVdG3aTnbcHj2s6yabe\niTzMUWOhCa5O3ZAT5xboRXC7844X7Hp0S0lpaXmYIZTnrqxrodXdxIXxy0xFvKrvj/T2YPXUaA61\nhJkyxXc27sbutCBJEM6Qbq8WW2sbYihEbEzb0VkUJULBKA6XlV11OzAIBt5MK2Obi+jQEFIslteL\nHI6HOT5yhipbJRtqV2O1mQpaC6bKKowud14ae3oxOKWsw9z2gQtRiGC/PtWZKlNcWe6aNZ58UHwu\ngc4uzfcq819WZmdX3Ta80WlNJ9lCNPa3FW29QbvD9naiJNiz9PzUQiGOQ7n9lwmTycj+xt2pgmRq\niHu9JLxTODsy92ZdiLHQOJcnr7GyvJ16Z50uRY9mzBDa5iEciiFJ8m9RbnWzuWYD/f5BeqZzl1eV\nn9clP19D82qFkyNniSai7G24A4NgwOW2FiTYBUHA2tpKbHSURDB3Hfh00ovBba5Zj9Ps4O2hE6q1\n1ZlSAtrnQaluuq9hly6nWKV2TiAfjT2tAJgSEXNk4Jjq+yM9PRjLyzGVqw/XBbmD2uHBY1iNw/az\nyAAAIABJREFUFrZ7Nue+4TamJNh1qAlhaWzKu8FAeu3tO+q2YTaYOTxwdFYv0GwoJwRnR7vm5x5O\nvihK5T5dTi55OlDnNthQxvRG/1tZ70lH2VBteQi0wwPHEBBSdYJcZTbCwRgJFZ12sjErUUkDMxq7\nFZPBxO76HfhjAc6MXVB1f7inB8FiwVKvLWQ2FA9xauQcHns1qyo6sCeTowra4JK1c4Ld3Zpr56QL\n9hZXE02uBs6NX8IXzR12mPD7iU+M56WtX5y4wlTEy676HdhMhXVuutWUBLsOGrvBYsFS30Ckt1dT\ng4FU+6/kGOReq1sYC09wbTJ7aVWFcHdSsK9coWm8CTHBW4PHsJvsbE82ULiVGvvcssXrq1ZTZavk\n+PBpQvHcURWRnm4QhKy1crIxmFamuMomR5G43PILHQ7mb5KaqSFU2DwovQFe7zuS9R4FMRYjOjiA\ntblZdfNqhbcGTxATY+xv2I0gCBiNBmx287ym1lqxtrUhRqNEhwY13Rf0y450s8WIIAjsb1B/kk1F\nieVhlns9qUjcnaEnQ7FREuw6VXGztrbKDQZG1duXQ0nh4XDOZGoq2qrSwWchlKO3a4U2wX5+/BLe\n6DS767enyhTrobGbysowVlRoPrnMbTSS3gv0+PCphW6diVmu1R6zrPgY0rskKYK9kHlImeY0nlzm\ntoOrd9aypmIlV6duMBRYuLBYdKAfEgnNZhhJkni9/wgmwTgrEShTU2utKPMQ6dY+D+lF8XbXb8ds\nMPN6/5GcJ9l87eujgXEujl+ho6x1XpniYuQdL9hNJiMWa2EOM8jPgTpjgpg59q0ob6PeUcupkXN4\nIwsfPSM93Rhdbiw12rz3mRooOJNp84U2FrC1thGfnCQ+rb65daZGI/uUXqD9CztR42NjiMFgqtGF\nWsLxMEcGj1NucbOlZkPq704dBLu5tg7BastbY7enbfR3N8s1Z17PYZbK13F6ZfI6w8FRttduxW2Z\nccA7nBaikQRxFX1Ss2FtawcgnPSBqCEVy59WBM1hdrC7fjvj4UkujF9e8P5wlpLFuXj55htyb9em\n4tfWoSTYAXRpXGtTFnFXl+p7UppqmkATBIEDzXeSkBK80Z/9CJ4IBOSY5bY2TTHLI8HRlGbS5Jqp\nRTLjMCvw+J2HGSJTa8ByaxmbazbQ5x9I9QLNhCI0tEbEvDV0gnAizN1N+zAZZkIkXW7brDHlg2Aw\nYG1J5jZE1X9PwB9JOdIVttZspMzi5u2hE0QS2b8rX8fpoeQaO9A8u2iZLj6X5hbZ96RBY1cc6XPL\n9R5ovhOAg71vLnh/pLtbDr1VUdZEISEmePnmYewmOzuz9HYtNkqCHXkRh0M6Ocw0LOJsLfH2NOzE\nbrJzqP8IsURmW2+mLjlqeLVX1kzua7l71t9TDrMCN7h8en9mm4d7mmRh82rv61nvjeQRsyxKIq/1\nvYlJMHLXHA3NVVa4xg7JVnnJ3qNqCQWiqYQ5BaNBjpYKxcOcWCC2P9zTI/sZmptVP28yPMXZ0Qu0\nuBppL5ut4ephojRYrdibGjU1t86k7AA0uRpYVSF3VxoKjGS8VwyHiQ4NYm1t0+RnOD16Dm/Yx976\nnRl7uxYjJcGOPkX1jQ4H5to6wt1dquOvszWxthot3NW4B38skDVRJ9zdBYBNQ4ifPxbgyOBxqmyV\nbJvT29VoNGBzmAno4GsArSappAliTqnatZWraHY1cnLkLGOhiYz3ztRGUX/0vjRxjZHgGDvrts0y\nP0CaYC/UcagxQkh2pMczNpa4q3EPAgIH+97MuLYkUSTS24uloUFVdU+FN/rfQkLinub98059ejWa\ncK1ckWxunVkYz2WhXqeK1n4oy0k20tsjN5xJnp7VIEkSL/a8hoDAPc3aSy3frpQEO/o5UG1tbXJz\n67HMfRTnEsiiqQIcaN6PQTDwat8bGV/mGYHWrnp8b/S/TUyMcV/znRmrJzqdloJfZHONB4PDkYrY\nUUMwEMXuNGOYo2UJgsC7Wu9BQuKVLFp7uKcHU2UVJnfuUr0KB/veAODepKBIJ2WKKXiD09YPN7TA\nWqi0VbCzbiv9/kHOj1+a93lsZBgpEtZkhgnHw7ze/xYOkz3VJSkdvRpNOJOOfbV29oUE+9aajZRb\nynh78HjGaKl8lJ0rk9fpne5nT/N2ah2e3DcUCSXBjj72REhzFiUXWC5CSU3VmaHed6Wtgu2ezfT7\nB7k2Nb/VYKS7G4PdPqt59ULExDiv9b2JzWhjX2PmeucOl+wwixXgMBMEAVtbO7HhIRJBdfVeFmpi\nvbN2K5XWCo4MHMUfm53wIzevntKkrQ/4h7g4foUV5e20ls03WzidFgRBB5NUY5Pc3FqlSWohgQbw\nnrb7AHi+65V5G324S04CsmlIVDvUf4RAPMj9LXdnND/oEQYMssYO6k2UC82D0WDkQPN+wolIRlv7\njGBvVz2+F7pfBeD969+j+p5ioCTY0U+w2zQK9kAggsEgYLVlrm9yX8tdAPxmzssshsNEh4dkW6JK\nx+nx4dP4otNy+QBT5rBAvY7fyganRluNRePEogkcWZpZGA1G7mu5i6gY4/W+2ZEh+djXn+18AYD3\ntN2b8XPBIMz0Pi0AwWTC2tSsurl1LsHe5Gpgc80GOn098zZ6xWFva1Mn2COJKC/3HMJmtKXMG3PR\n6xSrJM+p3uCy2NgVDjTvx2ly8HLvoXlljSPd3QhWG+Y6dQla3b5erkxeZ23lKlZW5Vfm+HalJNjR\nJzkH0h2oXaquV7JOswnnjvI2NlSt5erkdS5PXEv9PdIrN69Wm4QRS8T4deeLmAQj97ZkfpFhZh4K\nFWqK5hjuzJ1OHgwosfzZbcPKZnSw7w3CaUfwlIamUmPv8fVxevQ87WWtbKpen/U6R4EF0RSsrW1y\nc+vB3MXhsvlb0nmw7X4Anu96ddbfw12dsuNU5Ty80f8W/liA+1ruxGG2Z7xGL43d5HTKvqcedb6n\nXPNgM9l4oPUAoXiIV5MmNQAxEiE6OICttVW14/TF7oPAzGloOVES7OinsRudTswejyoHaqZ43Uy8\nb+V7AXj6xq9TyRlhjbVRXus/zER4kgPNd6YyLDOhxNMXHPrZnhTs3SoE+5xyAhm/z2Tj/pa78ccC\nPN89I9RmTBDqErSeufk8AI+teHDBk47DaSURF4lG8jdJyePqmDXOhcgWGZROR3kraytXcXnyGlfH\n5MxkKZEg0tONpbEJgzV3Gnw0EePFnoNYjZZ5kVHpWG0mDEZBl2bO1tY2xECA+MR4zmuV9ZDJiaxw\nT/N+XGYnr/a+ntLatTpOu329nB49T6u7ibWVq1TdU0yUBDv6aewgmyHEQID4+MIO1Eg4jpiQcgr2\nFncjd9Rto9c/wMmRs/K93eodp/5YgOe6XsZhsvNQ+/0LXjtz/C4sIsRUVS1XOFQR069GoAE80HqA\nSmsFr/S+zlhoAkmSCHfexFhRgakid1OJ61OdXJy4wpqKlTm74ug1D6kNrjN3eYhcphiFhzveDcC/\nnvwhoiQSHRpEikZTz8rFq72vMx31c6D5TpxmR9brBEE2Sekh2BVnphqHeiAQxeYwY1ygcbjNZE1q\n7WFe6T2U/O6u5LPacz5DlER+dPVpJCQ+uOqRvGrX3+6UBDtgs5sRhMJty6Dezp7LlpjOYysexCgY\neebm88TFOOGebtXFnp7rfJlQPMx729+FY4EXGfQ7uQiCgLW9ndjYKAn/wr0y1ZggACxGCx9Y9TBx\nMc7Prz9LfHKShNerSlsXJZFfXP81AI+tfDDn9XqZIaxNzQgmkzqTlMr1sKqig931O7g52cOhviMz\np5b29pzPGA6O8uuul3CbXTzQeiDn9Ypg18MkBRBRc3LxR3HmWAsga+1ui4sXe15jKDCcMn+q0djf\nGjxOl6+HnbVbWbMMtXUoCXZgRjsp1LYMaY7DHNrJjKaa+/hcY6/m7qa9jIXGefbys0T7+7C1tee0\nJfb7BznUf4QaWxV3N+/P+Rw9Ty6KoMm5wanUVEGOkFlR3s7p0XN0npdjme0qBPvzXa/S6etmR+0W\nVpS357xeL1+DYDJhbWsn0tebMwM1FIgiCLKSkYsPrXoUp8XBMzefw3dD7g9rzeE4FSWRf7/0E+Ji\nnI+s/cCC2rqCw2VBTEhEwvm1jVRIKTs5NrhYNJF0pOdeC1ajhY+u+SBxMc5TF39EuKsLwWrNqewE\nY0GevvEbLEYLH1z1iOp/Q7FREuxJHAU2tVZIFYDKqbHnti2n89iKB6m113DhzKuy4zRH4a9ALMi3\nzn6XhJTgw2veh9mQu7OQXho7gK09Gb+cQ0tTa4oBeQP+8OrHEBA4d+ol+Tk5BHunt5tfd71IhbWc\nj679kJqhF9wuMR1be4ecgZqjMJrib1FjFnBbXHx8ywcJJyIMXz0jtwTMUdnyzYG3ueHtZKtnk+pa\n44rSESgwWcvocmGuqyfcdXPBDFTF9KVG2QHYVruZXXU76J/sITI4ILcEXEDZkSSJH1/7Jf5YgIfb\nH6DSVqHtH1JEFCTYn3vuOR599FHWr1/PhQvqakbfrjicFuJxkVi0MIeZ0eXCVFOT04G6UHJSJmwm\nG7+36XdpnJDHF2/OrpmIksi/XfgBY+EJHmq7n81pRa4WwmwxYjIb9NXYcwl2laYYhbayFt634iEq\nRmQTj9CcvRJfOB7m3y78AEmS+E8bPqpKS4UZwVKojR3SI4Sy29klSSLojy7oMJzL/Sv2s8rVimPE\nR6imbMGWgDemuvj59Wexm2z8b2s+oNqmrOcGZ1+xEjEUWrCEb0CDeVLhI2veR7vfgiBJRBqqFrz2\nmZvPc3ToJK3uplQo8XKlIMG+Zs0avv71r7NrV3G3kQL9Mu1A1lZFv3/BEr4zyUnqF3Gzu5GdIbmS\n43+E3s7YvT0hJvjZtV9xceIKG6rX8sgK9YkXejrMTBWVGMsrcjpQlSbWFqv6XqUPtNxDw6TERJmR\n73U9k7HD0Hhokq+c+iZj4Qne3XYvaypXqv5+XU8uyRPFQoI9GkkQj4ua1oJBMPCJqvswiXDdHeLZ\nzhczXnd54hpfP/0vxMQ4v7vutym3qs/Q1cskBaROmOGb2edB2UDU2NgVHGYHDxnWAfBi4irnxi5m\nvO5g75s83/0KHns1f7T192YVfluOFCTYV6xYQXt7e8Hmi9sBPe3L9lWyQyZ841rWawIabMsKkiTh\nGJwk6rJxnXH++7Gv8ubA28QSMSRJotPbzd8f/0de7XsDj72axzf8DgZB20+smKREsfDf1NbeTnxy\ngrh3Kus1akI+5xIfGcYUjROsr+T06Dn+7uiXOTt6QY6UiYc5N3aRvz/+VXqn+9nXsItHO7RlFerl\nPAW5hK/B4Vjw5JIyy6k0QSiYBuSNPVBXwW+6XuKpiz+k0ys3E58IT/JKzyG+ceZfEZH47OZPsq1W\nW7s3p1OfujkAthXyxhq+OT+LWkFLQEE65UNyiejBWgvfPPtdXuh+lfHQJJIk0Tc9wJMXvs9Prv0S\nt8XFn2z7zLz6QMuR5b1taSC1iHXQ0uwrZcEeun6dsn2ZE4JSha80CLX4xAQJr5eqHTt5fOPd/MeV\nn/H9yz/l+5d/ikEwpOLc72zczftXPpwzCiYTDqc11dRaq8Cdi629g8CZ04S7unBtnV+PRBQlQoEo\ndU3qtUiYMe9s2v4uhhoDHB44xjfPfReb0Uo4IQsho2Dkd9Z+iDsb92gOZzOaDHJTax0EuyAI2No7\nCF68QMLvz9h0PJDH6Q1InYYevPPjdE78hreHTvD20AncZhfTMdlUZTGY+YMtn8oZ4pkJXcOAm5oR\nzGbCndkFeyCPDU6SJELXr2EsL+f37v5j/vncv/H0jd/w9I3f4DQ5CMTlshaNznr+04aPUmPX1rug\nWMkp2B9//HHGMhS1euKJJ7j//oXjohfC43Hnfe9iUN8oCxdBmj22fMYpVmygz2Ih1n0z6/3RcByH\n00J9vfqGu2NXzwFQvXkDWzfdza6Ojfzowq+YCE4RSUSxGM18eOPDrPdof4kVqmuc3LwyitVsKvg3\nMm3byPjTP8cw1IvnATkZJv07/dMRJAkqq5yanjU9JJfCbb1jO19Ys5rf8j3ED889Q59vkFpnNR5H\nNfd27GNVdXte4/Z43JRV2Jn2hnVZp8GN6whevIB1apjKjoZ5nw/2eAGoayjT9LxY900MFgsb9uzm\nHw17OTt8iYOdR7gwcpXtDRvZ0bCZXU1bqXLk5yS0W+UInXhMLGgelHuHV6/Cd/kKVS4TRvv8jFcx\nLp8SW1orqax2qvruyOgoiakpqvbuYf2qjaxs/L95vfsoNya6uTnZTXtVM4+tfTfbGzbm3OBvN5lU\nCDkF+5NPPrkoDx4dzd2YdimJJ731I8PTqbF5PO68x2ltayd4/RpDPSMZF7HPG8JVZtP0/aOnzgOQ\nqGtO3mfmtzs+OG+chcytYJQXf3/fJEZLYUFTiepGEATGz5zH8eD0vHGODcv/bTQZNI158uIVMBoJ\nuqoJj05jxcUn1/zO7IvE/OZBGaPVZmJ0KMbgwBQm8/xKmFoQa5sAGD59gXjzfFv/0IAs2EVJUj3m\nSrtAsKcX+9p1jE/K2ZdNplY+vroV0vb1RABGA/mtB1GUEASYnAjkvabSf3NjcxtcvETf8XM41s0v\n6TAxLhd5C0diqp83ffQMAIaW9uQ9RvbX7GN/zewSvGNjC+dTFPKuLyVqNx/dwh2L3c7u1Cm0S8G2\nchUksyPnEo8liEYSmk0doZs3wGDQVL1OK3ral40OB9bmFsKdNxFj8xuGaAl1VJDicSK9PVhbWjGY\nc8d858tSOlDzsS37Ll8BScK+Kv/TWS4MSkG06cLnANLs7FnmIdVBSsNGGrpxHZgxf5aQKUiwv/TS\nSxw4cIAzZ87wuc99js985jN6jWvJ0VOgAakXLpxceOnkLdB6urE2NauqCZIvelX1U7CvXo0Ui2Us\njKY11BFk+7oUj2PX2MBbK3rOg6miAlNVNaEb1zPGcSvKRKbyzdnwXZCjP+yr1xQ8voXQqyAazETG\nhLI4UIP++R2kchG6cT2ZCLa8qjMWSkHO0wceeIAHHnhAr7HcUowmAza7WZfQLgDbSlk7CV2fHxmT\nj0CL9PUixWIprWex0H2DW72WqVdeJnTtKuzbMeuzlNPQrX4eglfkZsb2Net0GV829HQcAtjXrmX6\nyGGiA/1yL9A0gn456zS9iXUufJcugyBgX7nI68FlZXTITzQSx2or7IRkqqzCWFFB+OYNJEmaZfNO\nxEUi4Tg1deojVsRIhEhPN7aOFRjMy6OlnV6UMk/TcLosuoR2AZjcZZjr6uRFPEdLyycRQ9FycmWc\nFopTx9hlkDV2QBbsc8hHUw1dvSJ/75q1OowuO8qY9BLsjrXyRhRMjj+dgD+C3WGZ10EqG2Isiv/a\ndaytbRhsmcvu6oWe60EQBOwdK0l4vfMqPeZzig13dYIolswwGSgJ9jQcbqvcQShaWG0MBfuKVXK2\n3eDsbDslo1GTQFM01UW0qQLYHMkOQjpkXYKcqGT2eAhdn2+GCE5re5mleJzQtatYGpswlWkLkdSK\ncnIJ6DQP9qRgV35HBSXrVJNA60yao1Yv7lqAxTjBJTf6K7M3uFSoo1P9O6GYOW2LfGopRkqCPQ29\ntVVbMlEpNCdRSUvhK5CbFQcvX8JUVY25tk6XsWXDYBCwOy26vcgA9lVrEIMBgr19s/4e8EcwGAVV\nha9Arr8jRaPY1y6utg76m2LMNR5MlVWErlyZZa/OJ+s0nDTvLbZ9HcDp1i9JCcCxXi5vEbw0O0M0\nmEcsf8lxmp2SYE9D7+O3suDC1+YIdo2mmEhPD2IggGPDhiWpHe10yZUu9Yp0UgSQ7+Lslzngj+J0\nWVX/mxRtVzFrLCZ6a6qCIGBfu5aEf5rowExHpXyyToNXZbOWfdXiC/bUyUWnebA0NWN0uwlcujBr\nfWk1xUiiSOjGdUzV1arq8b/TKAn2NGZqY+ijnVgam+RFfPH87EWs0XmqaDeKtrPYOF3WlDNLD5Tj\nt+/ijBlCFCWC/ogmDW2pHKdAMuzOoFt0EMxsSKErl1J/05p1Koki4RvXsDU2YCpXn9yWL3qfXASD\nAce69SSmpoilFQTT+k5EeroR/f4leyeKjZJgTyNlitEpblcwGHBs3ETC6yXa15v6e9AvF74yW9TF\n6wYvyZUzHeuWSLAnj9+BaX02OHN9A0aXG9/FGYEWDkaRJPWaqhSPE7p+DUtD46Lb1xWcLqu+Jqk1\n8x2oWjX2aH8fYihE2frsPVv1xKljpUsFx/qNAATSzDFaywkEzstZ2M5N2urfvFMoCfY0UuVaddLY\nYWbhKQsRwO+PqDZBiLGoLNCampdEQwP9fQ2CIGBfs4bo2BjRoaFZ36021DHc3YUUiaSckEuBw2kh\nFNSnIBqAubYWU2XlLDu7Vo1dOb2VbVwawa6EYOql7EBmO7tWG3vg/DkQhNQmUWI2JcGeht4CDcCx\ncRMIQkqwJ+Ii4WAspRXnInzjBlI0uqRHTr01dgDnlq0A+M+ckr9bY6ijEua4FPZ1BYfLgiRBKKjn\nBreOxLQvFSml1d/iP30KBIHKnTtyX6wDBoMBu9Osq0nK7PHIkVKXLyEl5JLLWk6xiUCA8I3r2Fas\nxOhUV1PmnUZJsKeRqsmuo8ZucpdhbWsndP0aYjg0I9BUaqpLbV8H/TrnpOPcvFXe4M6clr97WqOm\nelk24yx2/Ho6etuXIS2e/bL8u2rZ4BJ+P6Hr17CtWImlYum6/zhdVgL+iK5lQxzrNyCGQqkG14FA\nRHUHqeClCyBJODdv0W08y42SYE/DaJS1Ez01dkiaYxIJgpcupR291WmqwUsXwGDAsQQhfgrKpqPn\nPJjKy3GvWU3o+jUSfr8mm2oiECB4+RLW1rYlM0dBWv0gHU8ujqRpzn/yhPzdGrJOA+fPgihmLIG8\nmDhcFuKxwruLzfrOpAkleOkCoigSCsRK9nUdKQn2OSyGdpJuZ1eEhBpTTCIYINzZKadML3KGYTqu\nRTDFAFTt3gWiSODc2Rmbqop58J8+CYkE7juWtlNXyiSl48nFXFWFbeUqQlcuE/f5CGrIOvWflk87\nzq3bdRuPGmYK5OnoSF6XPLlcukgoEEs+J/fpTZIkAufPYXS5sbaW6sNkoyTY5+BcBO3E1rECg8NB\n4MI5/Elh6VIj0E6evCVHTovVhNFk0NUkBVC1+w5AtrPPmCByv8z+48cAcO1cWsGu/EZ+nTc4985d\nIElMnzyhOutUiscJnj+L2ePB0pi9z+ti4FgkE6WtYwWhq1fwDcvlBdTMQ7S/j8TUFI6NmxZsXP1O\npzQzc1gM+7JgNOLYsJH42BjTQxOAOk3Vd+RNAMr27Mtxpb4IgiAnKekYCQFgb2nB7PEQPH+OgC+C\n2WLM2es0EQwQuHgBa0srlrrFzbqdy4wTWd95cN0hb3BTx0+qzjoNXrmMGA7j3Lp9SZLU0tGz92k6\n7n37QRQZPymH86oxTwbOlcwwaigJ9jnoHcuu4Eoen729Q7Oek43Y+BihK5exr1mL2ePRdSxqcLqt\nBANREon5ZWbzRRAEnFu3I4bDBLxBVRqa/9QpSCRwLbEZBtJ8DTpr7OaqamwrVjJ1U85tUGNbDiSj\niVzbltYMA/pnZCuU7doDRiMTV+X67K6yhedBkiR8bx0GoxHHpk26jmW5URLsc9C7NoaCa+cdGBxO\npsenEYTcx07fW0cAKNu7X9dxqEV5mUM6hrmBLJhEDISjkioNzX9CNsMstX0dwGQyYrObdDfFgLwe\nIkbZb5Jrk5dEEf/p0xgcjkUvApeJmdr0+s6D0e3GuXkLfp86v1P4+jWi/X24tu/E5F6aJLVipSTY\n57BYx06DxULZnXcRESzYzCzoLJMkCd+RNxFMpluiqcLiRMaAXJ0yXlkLgMO+cMxyIhggcOE81pYW\nLHX1uo5DLU63VXeNHeSNShHsuTT2wNkzxCfGcW3fiWBa+v7zi3WKBSjbt5+ISW66nsvvNHXwFQAq\n7r1P93EsN0qCfQ56t8hLp/yee4kYnViiC/dfjHR1EhsawrV9B0aHQ/dxqGExQv0ABJMJy54DAJgm\nBhe8dvrYMdkMs8RO03ScbiuxaIJoRJ+6OQrm6hrE2mYArGSfY0mSmPj1rwCofM9Duo5BLQ6XFUEA\n/3RY9+92btlGxCr38XQ4sod8xqd9+E8cx9LQuKTZx8VKSbDPYTGSUhSkihpEgxGzf4LIQH/W6xSn\nqXvfrTHDwOKE+ikIa2THl3TzEmI08zyLkQgTv3oawWymbP9duo9BLYsV+gkgJRtbx46+kfWa0LWr\nhG/ewLltO9amJt3HoAaDQcDhshLw6T8HBrOZmKMKczxE5NrlrNf53ngdKR6n/MB9S+48LkZKgn0O\n9mSjCb1NEEDKlmiNB/AefDXjNdGhQbyvH8JYXoFzw61zEC3m8Tuc/EqzfwLf4cxCbfKlF4hPTlL5\n7gcxV1XpPga1KCeXxbCzx9zVAMRPHSHS25PxmolfPwtA1Xsf0f35WnCVWQn49auboyBJEiEs2OIB\nxn72E6T4/JORJIp4XzuIYLFQtv/WKTvFREGC/R/+4R9473vfy/vf/34+//nP4/cvbGIoBpTO7Ho7\nT2FG+7WbJbxvHCLS2zvrc0kUGXryO0ixGLUf+/gtsacqLKbGrmyaNqJMPv+bVL0QhbjPx+RvnsXo\nclP50MO6P18Li1E3R8E/HUEQwBIPMfqTH837PNLbQ/D8Wexr1t7yZhIutxVRlHR3pkfCcRIJCWeZ\njUh3FxPP/XreNVMvv0hsbBT37r0YHaXaMGooSLDfddddPPvsszz99NO0tbXxzW9+U69x3VIcLquu\njSYUFOHg2b0NKRql/2tfJj41lfp88sXnCd+4jnvXbjmJ5RaSciIvgkBTNovqbZuIjY4y+sMfzGqb\nN/7M04jhMFXve/8t8zEoLKZgD/giuMpsODdsIHjh/KwKoHHvFEPffRK49do6LF6yljJpmH0fAAAa\nGklEQVSvVWtXYKyoYPyZp4mklbj2nznN6I/+A2N5OdXve7+uz17OFCTY9+/fn4ru2LZtG0PJkqzF\njtNlIREXCQVjun6vsohrNq+j5kMfJj4xQf/XvkLg/Dkmnv8N47/4GUa3m9qPfULX5+aDEuq3GCYp\nZR4aH3svloZGpl55iYGvfYXg1Sv0f+0reF99GXNdHRX33Kv7s7WSEmg6z0MiIRLwR3G5rdR8+CMg\nCAz809cY/fF/ELx0kZ6//SsiXZ2U7b8zVV/mVuJMxpj7dbazKxuFu8pJ3Sc/BYkEg//yTbyHXmP6\nxDEGv/UNBLOZpj/5Auaqal2fvZzR7az/k5/8hEceufWahR64ymwA+KZCGC36uSHSC4BVvPcRosPD\n+N58nf6vfEm+QBCo/cSnMLrduj2zEBwuK36f/pEQQX8Um92EraaKlv/z/2Hwm/9E4NxZAufOAnIr\nvdqPfeKWmqIUUmGfOgs0xTnvKrNia22j/vd+n7Gf/pjJ559j8vnnAKj+4G9R9fCjt4WzcLGcyOm1\nk1ybtlF+zwG8h15j+KknU9c0/OGfYOtYoetzlzs535zHH3+csbGxeX9/4oknuP/++wH4xje+gdls\n5rHHHlP9YI/n9hBemahvLOP8yX68UyHWbtQvfjoWkW3JbR3VWG1map74Y3qb5DR5Z1srrlUrsdXn\n97zFmM/KagcTowHKy+w5U//V4vG4CQailFfak2N2U/fXf0H3975P4GYnTR98P+Vbt9xSYZY+l5Ik\nYbYYiYRius5xKOmU9tSV4fG48Tz2IB0P3sfwiy8x+tobNH3wfVTv26t6nItNJCg7NRNxUfNzF7pe\nTMjmzqaWSjweNzX/+fP4H32IYE8Pwd4+XCtX4jlwd/4D12mcxUbOt/XJJ59c8POf//znvPbaazz1\n1FOaHjw6Oq3p+qVEMMpCxTcZ0nWck+MBzBYjvukwJGOCHe95FAAJmAam83iex+NelPlUmh50dY5T\nWV24rdvjcTPQP0UkHMdqM80as/PhD+AEYsDY2K1zwmeaS4fLwtSUvmuht2cSAKNZmPW9pt1307D7\nbkQWfkcW6zfPRizp4B4dntb03FzjHB2SP4snEjPXVTVgqGrAtW2PfM0S/DuXej7zRe3mU5Cd4dCh\nQ3z729/mG9/4BhaL+qbEtzvKsdM7FdL1ewMamzffapyL0CpwOmnaUcxdxYDLbSUcjJGI61c3J6Ch\nyuftgMNpwWAQdDfF+DWUsS6hnoLO13/zN39DLBbj05/+NABbt27lL//yL/UY1y1FKUbkndRPsMfj\nCcKhONW1Lt2+c7FZjIgQxWbvLi8ewZ6ejVxWoU9dfH9qgysOgSYnKVkWJSrGZjdhNqtr7F5CHQUJ\n9hdeeEGvcdxWKCnUemrsWhpL3C4ojkM9X+Zpb1JTLRKBBmkRIdN6CnZlHopng3O5rQwP+BBFCYNB\nHx+IPKfFMwfFQinzNAMGg4DLbcWno8ZejEdOd1Lo6Bnipphi3MUk0Bahbo5/OoLJbMBqu/WRP2px\nlVmRJHRrbB2NxIlFE0VjjiomSoI9C64yG9O+sG71yFM2VZV9HW8HFG1y2qtfyGNRmmIWySTlcltv\ni1BGtSjzoFcIbDEqO8VCSbBnwVWe1E50SkyZidctHuep1WbCYjWltGw9mPZGVNWjv53Q2yQVi8n+\nlmIywwC43PJ49drgis2BXEyUBHsWUtqqTkJN0XqLSVMFKCu3Me0N61Zewe8L43RbMRqLZ+nNJOfo\nu8kXm0Bz6Zx9qrbBRgntFM/btcS4dV7ExSrYXeVW4jGRcKjw8gpiQiQwHSk6TdWuc6hfSqAVkQMZ\n0kwxemvsRTYPxUBJsGfBlXIc6qOx+7xhLFYjVlv2ZgK3I8pGpMcG5/OGkaSZTbNYUJp769Vowl+E\nDmSYEcC6bXAlG/uiURLsWdDz2ClJEtPeMGXl+oTKLSXuVN2cwoWaEj7qKrJTC8gbXGA6qkuS0kyo\nY3EJNLtDPrnodYpN+Z2KKKCgWCgJ9iwojiI9NPZwKEY8JhadGQbSNXYdBHsyfLTYNHYgFb+uh8/F\nX6Q2doNB55PLdASL1ahbHaISM5QEexasNhNWm4lpHbSTYrWvw8yY9Qh59E4GgeJKylFwV+h3cim2\nrNN0nGU2gv4ooljYyUWSJDnkswjXQjFQEuwLUF5h10VTTQn2Isyw01ewh2Z9ZzFRVj5TyrlQ/NMR\nrDYTZkvxaaoutz5hwJFwnGgkUco6XSRKgn0ByirtRCMJIuHCOtQrWl4xCjRZABl1MUEUsynGrZhi\nCtzgZE01UnRmGAW9fE/KWtCrREOJ2ZQE+wKUJxddoTZFRRiUFaFgFwQBV5lVN429WDXVMp1MMak0\n+iLc3CDNmV7gelBOPiWNfXEoCfYFKK9MCvYCtZNitrGDvCEVenKRJAnvVKho58DhtGA0GZj2FmaK\nmYlhL855KEu+E4XWUVI2yJLGvjiUBPsCpDT2As0QPm84lZ5fjLh0iIwJh2JFrakKgoC73Fawxp4K\ndSxSU0x5pbwWCq18OqOxlwT7YlAS7AugaCeFRMYoMezFqqmCPsdvRaAVW1JOOmXlNiLheEEnF8W2\nrJwGiw1XmQ1B0FFjL+L34namJNgXQA+NPZTsvFPMtsRULHsBgl0xRxVzeJvyGxZijil2wW40GnCX\n23TR2F1lVoymkghaDEqzugDu8qR2UsDxu9jt66BPyGOqDnt5cZogANzJzOFC1oMSy1+sgh3ksYcC\nMaKR/E4uibiI3xcpaeuLSEmwL4DRaKCswo53In/tRLElLgvBXsDJxZ/snFTM86BHZIx3MoTdaS5a\nfwvM2MXznQdlHZXs64tHSbDnoKLKTjgUy7u64XLQ2O0OczIiJH9fQzE2sZ7LzMklv40+kRCZ9oYp\nr3ToOawlRzlt5NsTuBTquPgUpDZ89atf5eWXX8ZgMFBdXc0Xv/hFPB6PXmO7LSivcsCNCbyTIWx2\n7ZUZZ2LYi1c7EQQBd4Gx7FMTQSxWE3ZHcVW3TCelqeY5D9PJ6pbFbIaBdI09T8E+mXwninwebmcK\n0tg/85nP8Mtf/pJf/OIX3HvvvXz961/Xa1y3DRVV8uKbGg/mdf+Mxl68tmWQtVUlZFEroijhnQxR\nU+sqqlZwc0nVD8rTBKGY9IpesCshjwVr7MU9D7czBQl2p9OZ+u9QKITBsPwsOxVV8rF5ajI/we7z\nhrHZzUWZbZmOklKfz8s87Q0jJiSqa525L77NcZfbknXltXeUUtaQoiwUKwVr7KnkpJIpZrEoWNp8\n+ctf5umnn8btdvPUU0/pMabbivKkYM/HgSpJEn5vmOpal97DWnIqq+V5mBwPUFOn7d+jnHZqlsE8\nlFXYGBv2EwxENdcR9y2T+ihmsxGny5J3LLtvKoTZYszLtFlCHTkF++OPP87Y2Ni8vz/xxBPcf//9\nPPHEEzzxxBN861vf4nvf+x6f//znVT3Y43FrH+0toL2jGrPFiN8X0TxmnzdEIiFRU+ta9H/vYn9/\n+4oa3uQ60VBC87OuXxgBWJJ50IOFxljXUM7NK2MYMWj+twT9sgN+5eparLbCT3C3ci6ra130dE5Q\nWenAZDIueG36OCVJwucNU1XjpLa2bLGHqYliWJtqybm6nnzySVVf9Oijj/IHf/AHqgX76Oi0qutu\nJR6Pm7ExP+UVdsZH/YyM+DTZiHs7JwCwuyyL+u/1eNyLPp8Gs/zv7uuZ1Pysvp5JAKprF3+chZJr\nLk0W2dzY0zWOzaVN4xwdnsbhtOCbDkGB07AUv/lCOJwWkODm9bHUaS4Tc8cZDESJRRM4Fvmd0Mqt\nnk+1qN18CjKKd3d3p/775ZdfZsWKFYV83W1LeZWdeEzU3OtxYjQAQLWn+G3LTpcFs8XI5HhA871T\n40EEAapqijvMD/KPZU8kRPy+cNE7ThXyLQZWcpwuDQWdB7/0pS/R2dmJwWCgsbGR//bf/pte47qt\nSDlQJ0Ka4rAnxmQhWFlT/IJdEAQqaxyMDfkRRVGTo3xyIoi73JbzyF4MKGtB6wbnm1oeoY4KqVh2\njQ7UkuN0aShIsP/jP/6jXuO4rSlXQh4ngjS3V6q+b2IsgMEgLJuXubLaycjANN7J8ILH73TCoRjh\nYIy6huVhv3SX2zBbjIyPaBPsqVICRR4Ro1Be0thva5ZffOIiUJFHZIwkSUyOBamodmA0Lo9pTkXG\njKkXalMTyRA/lRvB7Y4gCFTXOpmaCBKPq4/pXy4x7AqKxq1VY1fWTrGHfN7uLA+Js8ikkpQ0xLL7\nfRFi0QRVy8AMo1BZo5gh1M+DEuqobI7LgSqPC0mCyTH186AIwOUi2K02M1abSXNew9iwH4vVVNQl\nNoqBkmBXgdVmxuYwa9LYFcfpcnAYKlRWy5uUFvuysgksF40dZpzhym+shuWmsYN8gvNNhojH1J1c\nYtEEUxMhauqKOwO5GCgJdpVUVNnxTYVIJERV1yuO06plEBGj4C63YTQZNGmqisau1iZfDCgJZ+Oj\nftX3eCdDOFyWos9ATsdT70aSYFzlBqfM13JIVLvdKQl2lVRUOpAk9WFuyykiRsFgEKiosjM1HlSd\nUj81EcRqMy2rLEPFvKbWgRoJx5n2qnc4Fws19bJDfHRIXfz32LAs2Ks1Zi6X0E5JsKskPTJGDROj\nAYwmw7Lz/lfWOInHRVWVHhMJEd9UmIpqx7I6elttJtxlVtWmGEXw1TbcXpmWheKplwW0WsE+PlLS\n2JeKkmBXiWJSGVOxiEVRYmo8SGW1A4Nh+Qg0SK8Zk3uD802FEEWJymXkOFWoqnURDEQJBaM5rx0Z\n9AFQu0xCPhUqqx2YTAZNGrvBIKSc8CUWj5JgV0ldo6xtDfX7cl477Q0Rj4vLKiJGIeVAVWFnV65Z\nTo5TBcWBqsYcMzwgrxllDS0XDAYD1XUuJsdyh36Kosj4aICqGueyCf+9nSnNsErsDgsVVXaGB3yI\n4sL25YlRWaAtJ8epwkzIo3qBphzZlxNqHaiSJDEyMI3TbcHpLu6a/Jnw1LkRRSnnBuedCJGIiyX7\n+hJREuwaqG8qJxZN5EzQmXGcLj9NtbzSjsEgpOylCzHQO4XBIFDXWL4EI1taUiGPOQRaYDpCMBBd\ndvZ1BbV29jHFvl4S7EtCSbBroK5ZMcd4F7wuFeq4DE0xRqOB2sYyxob9RMLZ+8DGonHGhvx46t2Y\nLcVfI2Yu5VV2jEYhZ6jfyKDiOF1e9nUFj8rIGCUipuQ4XRpKgl0D9U2y5jnUl93OLkkSw33eZZ1d\n19xWgSTBQM9U1muG+mWTVUPL8tPWQbYvV9Y4mRgLLGiaW672dYXKGtmBOja08AkuFepYEuxLQkmw\na6Cy2oHFalpQY58cDzLti9C6onJZhfiloxRC6+uazHrNYK88R40tFUsypltBtcdJIi4u6G9QNHZF\ns11uGAwGqmtdTIwFsjpQJUlibMSPu9ymS4ORErkpCXYNCIJAfVMZvqkwwUDmMLeeG+MAtKyoXsqh\nLSm1jWWYzAb6urNr7AO98mf1zctTYwdobJM3uO7r4xk/F0WJ0aFpKmtkhWC54ql3IYpS1rj+oD9K\nOBgr2deXkJJg10h9k3ykHs4S9thzU+6a1LqiasnGtNQYjQYaWyuYGg/iz9B8JB5PMDLgo6bOtaw1\ntPZV1QgCdF6d3zoS5HIKsWhi2TpOFXLZ2ZV3Yrmao25HSoJdI3WKnT2DOSYaiTPY68VT75Jbhy1j\nmpPaan8Gc8zIwDSJxPK1ryvY7GYaWysYGZzOuMEp9vXl6jhV8CT/fdlMc9cuDgOwan3tko3pnU5J\nsGukrtGNIGROVOrrmkQUJVqXsRlGoSkp2Pu657/Mg0kzzHK2ryt0rKkBoCuD1t6f7PW63DXVqhon\nVR4nXdfG55kofd4Q/d1T1DeXL9tggtuRkmDXiNliorrWxeigj3BodrhfygyzcvmaYRSqa53YHGb6\nuybnFQQbSDpOl7vGDtCxWhbsN6+Ozvq7fzrCjUujVFTZl71tWRAENmxrQBQlrpwbmvXZhdMDAKze\nUNLWlxJdBPt3vvMd1q1bx9RUdmfacmLNxjoSCYmTR2aaeUuSRM/NcWx207K3qYL8Mje3VRDwR2cV\nRgsGogz1e6msdmB3LG9zFICrzEZtg5uBnqlZG/25432IosTWPS3LNjoqnTUb6zCaDFw6Mzhroz9/\nsh+DQWDlOs8tHN07j4IF+9DQEIcPH6axsVGP8RQFm3Y04S6zcu5Ef6rK4cRogMB0lJaOqmVX+Csb\nTcmwx7PH+1N/e/2Fa8RjIhu3v3PWQ8eaGiRpJjomEo5z8fQADqeFNRvrbvHolgarzcyqdR68k7Lp\nBeTQ38E+Ly0dle+ITf52omDB/nd/93f82Z/9mR5jKRqMJgO77ulATEgcfb0T72SIF56+CEB78mj+\nTmD1+lqqPE4unhrg3Ik+blwe4eaVUeqby9m0s+lWD2/J6Fgja6PnT8kb/cUzA0QjCTbf0YTJtPyy\nbrOxYZu8mV86M5A0ywwCsGrDO2Nzu50oKBbtlVdeoaGhgbVr1+o1nqJhzca6/7+9u4tpMkvjAP6v\ntIDDOKaK06DD6CwOG4gFRhPdgURtbeSjVlFRboymDUZvrCB+hKJGA8aAqJekxAjRZDTK2myI0Wym\nWiEIIsYFN6Q6bHAcjAVRMhSj9OvZC9dO2NJqzOgp5fndnSYn+acfT09P3/c56Or4DY/+PYBfe19g\n7I0H6Uu/mVI/OWXRUuQVKPH3c/fQ+nMvZNFSREmnQZX31ymx/fCOfPYXSPxOjt/6hvGT+Q6ipNMg\ni46aUr9aAEAx7yvI47/Af+zP8fiXFng8Psiio/Dd95F/MUG4eW9h1+v1GBoK/Me/uLgYZrMZZ8+e\n9T/2oafqRAKJRIK/rfwLrl56ALfLixU5yf4Vy1QyY2Yscjcq8Y+f/gXXmAc/qpIi6uDqD5W3KQ29\nPQPobP0Vvw+/RvrSRMTERs6pUR9CIpFg8Y/z0fLPX/DVzFjMmhOHH5Z+G1HHAU4WEvrIavzo0SPo\n9XrExsa+7Y8yMACFQoHLly9j9mz+hmaMMVE+urD/P7VaDYvFgpkzI/8SN8YYC2d/2nXsEolkSm3F\nMMZYuPrTVuyMMcbCA995yhhjEYYLO2OMRRgu7IwxFmGEFXa73Y7CwkLk5+ejoKAADx48EBXlvc6f\nP4+cnBzodDrU1NSIjhNUuPfsqa6uRm5uLtatW4ddu3ZhdPT9B2J/Ts3NzcjJyUF2djbq6upEx5mQ\nw+HA1q1bkZeXB51Oh3PnzomOFJTP58P69euxc+dO0VGCcjqdMBqNyM3NhVarRVdXl+hIE2poaMCa\nNWug0+lQWloKl2vig378SBCDwUAtLS1ERGSz2WjLli2iooTU3t5Oer2e3G43ERG9ePFCcKKJPXv2\njAwGA6lUKhoeHhYdZ0Ktra3k9XqJiOjEiRNUU1MjONEfvF4vaTQa6u/vJ5fLRWvXrqXe3l7RsQIM\nDg5ST08PERGNjo7S6tWrwzInEVF9fT2VlpbSjh07REcJ6sCBA9TY2EhERG63m5xOp+BEgRwOB6nV\nahobGyMiot27d5PFYgk5R9iKXSKRwOl8e+KK0+mEQhGe/SQuXLiA7du3Qyp9e/fcrFnh2ZJ3MvTs\nyczMxLRpb99yGRkZcDgc75nx+XR3d2P+/PmYN28eZDIZtFotrFar6FgB5syZg5SUFABAXFwckpKS\nMDg4KDhVIIfDgVu3bmHTpk2iowQ1OjqKzs5ObNy4EQAglUrx5Zfh2WLZ5/Ph9evX8Hg8ePPmDb7+\nOnQbZGH3+paVlaGoqAhVVVUgIly8eFFUlJAeP36Mzs5OnD59GjExMdi/fz+USqXoWONMxp49jY2N\n0Gq1omP4DQwMICEhwT9WKBRhvT0IAP39/bDb7UhLSxMdJcC7hca7xVs46u/vh1wuR1lZGex2OxYt\nWoTy8nLExobXgSAKhQJ6vR4rV67E9OnTkZWVhczMzJBzPmlhD9ZnpqSkBLdv30Z5eTk0Gg2uX78O\nk8mE+vr6TxknqFD9cLxeL0ZGRnDp0iV0d3ejuLhYyEpusvTsCfWaq9VqAEBtbS1kMhl0Ot3njheU\nyOfsY7x69QpGoxEmkwlxcXGi44xjs9kQHx+PlJQU3LlzR3ScoDweD3p6enD48GEolUocO3YMdXV1\nMBqNoqONMzIyAqvVips3b2LGjBkwGo1oamoK/fn55BtEQSxZsmTcePHixYKShFZUVEQdHR3+sUaj\noZcvXwpMNN7Dhw8pMzOT1Go1qVQqSk1NJZVKRUNDQ6KjTejKlStUWFjo3y8MF/fv3yeDweAfm81m\nMpvNAhMF53a7yWAwUENDg+goEzp58iStWLGC1Go1ZWVlUUZGBu3bt090rADPnz8ntVrtH9+9ezcs\n/w+4du0alZeX+8cWi4WOHj0aco6wPXaFQoGOjg4AQFtbGxYsWCAqSkgajQZtbW0AgL6+Png8Hsjl\ncsGp/pCcnIzW1lZYrVbcuHEDCoUCFoslLBuxNTc348yZM6itrUV0dHgdvKBUKvHkyRM8ffoULpcL\nV69exapVq0THmpDJZMLChQuxbds20VEmtGfPHthsNlitVpw6dQrLli1DdXW16FgB4uPjkZCQgL6+\nPgBAe3s7kpKSBKcKNHfuXHR1dWFsbAxE9EE5he2xV1RUoLKyEj6fDzExMaioqBAVJaQNGzbAZDJB\np9NBJpOhqqpKdKSQwrlnT2VlJdxuNwwGAwAgPT0dR44cERvqf6KionDo0CEYDAYQEQoKCsLyQ37v\n3j00NTUhOTkZ+fn5kEgkKCkpwfLly0VHm5QOHjyIvXv3wuPxIDExEcePHxcdKUBaWhqys7ORn58P\nqVSK1NRUbN68OeQc7hXDGGMRhu88ZYyxCMOFnTHGIgwXdsYYizBc2BljLMJwYWeMsQjDhZ0xxiIM\nF3bGGIswXNgZYyzC/Be68EGj7hfMcwAAAABJRU5ErkJggg==\n", "text/plain": [ - "[]" + "\u003cmatplotlib.figure.Figure at 0x7f385e198650\u003e" ] }, - "execution_count": 4, "metadata": { "tags": [] }, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "# Create TensorFlow Variables using Keras's Dense layer.\n", + "def f(x):\n", + " return tf.square(tf.sin(x))\n", + "\n", + "def grad(f):\n", + " return lambda x: tfe.gradients_function(f)(x)[0]\n", "\n", - "wb = tf.layers.Dense(units=1, use_bias=True)\n", + "x = tf.lin_space(-2*pi, 2*pi, 100) # 100 points between -2π and +2π\n", "\n", - "# We can access the underlying TensorFlow variables using wb.variables.\n", - "# However, the variables won't exist until the dimensions of the input\n", - "# tensors are known. Once the dimensions of the input tensors are known,\n", - "# Keras can create and initialize the variables. Until then, Keras will\n", - "# report the variables as an empty list: [].\n", + "import matplotlib.pyplot as plt\n", "\n", - "wb.variables" + "plt.plot(x, f(x), label=\"f\")\n", + "plt.plot(x, grad(f)(x), label=\"first derivative\")\n", + "plt.plot(x, grad(grad(f))(x), label=\"second derivative\")\n", + "plt.plot(x, grad(grad(grad(f)))(x), label=\"third derivative\")\n", + "plt.legend()\n", + "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "docKLUaonYG_" + "id": "-39gouo7mtgu" }, "source": [ - "## Step 3: Define our loss function\n", + "## Gradient tapes\n", "\n", - "Our loss function is the standard L2 loss (where we reduce the loss to its mean across its inputs)." + "Every differentiable TensorFlow operation has an associated gradient function. For example, the gradient function of `tf.square(x)` would be a function that returns `2.0 * x`. To compute the gradient of a user-defined function (like `f(x)` in the example above), TensorFlow first \"records\" all the operations applied to compute the output of the function. We call this record a \"tape\". It then uses that tape and the gradients functions associated with each primitive operation to compute the gradients of the user-defined function using [reverse mode differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).\n", + "\n", + "Since operations are recorded as they are executed, Python control flow (using `if`s and `while`s for example) is naturally handled:\n", + "\n" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, @@ -257,145 +182,42 @@ } }, "colab_type": "code", - "id": "0_w8ZJSCtuY7" + "id": "MH0UfjympWf7" }, "outputs": [], "source": [ - "def loss_fn(inputs, labels, wb):\n", - " \"\"\"Calculates the mean L2 loss for our linear model.\"\"\"\n", - " predictions = wb(inputs)\n", - " return tf.reduce_mean(tf.square(predictions - labels))" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 34, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 24, - "status": "ok", - "timestamp": 1505502830875, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "RkNbXoXkpjVH", - "outputId": "c36fc98d-3a57-4074-901d-c10ae017ae3f" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\u003ctf.Tensor: id=40, shape=(), dtype=float32, numpy=7.3549819\u003e" - ] - }, - "execution_count": 6, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "# Test loss function (optional).\n", - "\n", - "loss_fn(inputs, labels, wb)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 51, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 57, - "status": "ok", - "timestamp": 1505502830981, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "K_7beXoHOU7t", - "outputId": "1ad0856a-02ec-4117-a6c0-b41030981d87" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "w: tf.Tensor([[ 1.56891453]], shape=(1, 1), dtype=float32)\n", - "b: tf.Tensor([ 0.], shape=(1,), dtype=float32)\n" - ] - } - ], - "source": [ - "# At this point, the variables exist, and can now be queried:\n", - "\n", - "w, b = wb.variables\n", - "print(\"w: \" + str(w.read_value()))\n", - "print(\"b: \" + str(b.read_value()))" + "def f(x, y):\n", + " output = 1\n", + " for i in range(y):\n", + " output = tf.multiply(output, x)\n", + " return output\n", + "\n", + "def g(x, y):\n", + " # Return the gradient of `f` with respect to it's first parameter\n", + " return tfe.gradients_function(f)(x, y)[0]\n", + "\n", + "assert f(3.0, 2).numpy() == 9.0 # f(x, 2) is essentially x * x\n", + "assert g(3.0, 2).numpy() == 6.0 # And its gradient will be 2 * x\n", + "assert f(4.0, 3).numpy() == 64.0 # f(x, 3) is essentially x * x * x\n", + "assert g(4.0, 3).numpy() == 48.0 # And its gradient will be 3 * x * x" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "YIlebeb_qYtC" + "id": "aNmR5-jhpX2t" }, "source": [ - "## Step 4: Create our gradients function using `implicit_value_and_gradients()`\n", - "\n", - "With a loss function defined, we can calculate gradients and apply them to our variables to update them.\n", - "\n", - "To calculate the gradients, we wrap our loss function using the `implicit_value_and_gradients()` function.\n", - "\n", - "`implicit_value_and_gradients()` returns a function that accepts the same inputs as the function passed in, and returns a tuple consisting of:\n", + "At times it may be inconvenient to encapsulate computation of interest into a function. For example, if you want the gradient of the output with respect to intermediate values computed in the function. In such cases, the slightly more verbose but explicit [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context is useful. All computation inside the context of a `tf.GradientTape` is \"recorded\".\n", "\n", - "1. the value returned by the function passed in (in this case, the loss calculated by `loss_fn()`), and\n", - "1. a list of tuples consisting of:\n", - " 1. The value of the gradient (a `tf.Tensor`) with respect to a given variable\n", - " 1. The corresponding variable (`tf.Variable`)\n", - "\n", - "Test it out below to get a feel for what it does. Notice how the first value of the returned tuple (the loss) is the same as the value returned in the cell above that tests our loss function." + "For example:" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, @@ -403,94 +225,48 @@ } }, "colab_type": "code", - "id": "v1spZQ4NwW1U" + "id": "bAFeIE8EuVIq" }, "outputs": [], "source": [ - "# Produce our gradients function. See description above for details about\n", - "# the returned function's signature.\n", - "\n", - "value_and_gradients_fn = tfe.implicit_value_and_gradients(loss_fn)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 153, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 46, - "status": "ok", - "timestamp": 1505502831114, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "21WMcpsmFFLd", - "outputId": "f51b3171-33f5-4f87-8bf7-0be2dc8edc8a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Outputs of value_and_gradients_fn:\n", - "Loss: tf.Tensor(7.35498, shape=(), dtype=float32)\n", - "\n", - "Gradient: tf.Tensor([[-3.00773573]], shape=(1, 1), dtype=float32)\n", - "Variable: \u003ctf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32\u003e\n", - "\n", - "Gradient: tf.Tensor([-4.06519032], shape=(1,), dtype=float32)\n", - "Variable: \u003ctf.Variable 'dense/bias:0' shape=(1,) dtype=float32\u003e\n" - ] - } - ], - "source": [ - "# Show outputs of value_and_gradients_fn.\n", - "\n", - "print(\"Outputs of value_and_gradients_fn:\")\n", - "\n", - "value, grads_and_vars = value_and_gradients_fn(inputs, labels, wb)\n", - "\n", - "print('Loss: {}'.format(value))\n", - "for (grad, var) in grads_and_vars:\n", - " print(\"\")\n", - " print('Gradient: {}\\nVariable: {}'.format(grad, var))" + "x = tf.ones((2, 2))\n", + " \n", + "# TODO(b/78880779): Remove the 'persistent=True' argument and use\n", + "# a single t.gradient() call when the bug is resolved.\n", + "with tf.GradientTape(persistent=True) as t:\n", + " # TODO(ashankar): Explain with \"watch\" argument better?\n", + " t.watch(x)\n", + " y = tf.reduce_sum(x)\n", + " z = tf.multiply(y, y)\n", + "\n", + "# Use the same tape to compute the derivative of z with respect to the\n", + "# intermediate value y.\n", + "dz_dy = t.gradient(z, y)\n", + "assert dz_dy.numpy() == 8.0\n", + "\n", + "# Derivative of z with respect to the original input tensor x\n", + "dz_dx = t.gradient(z, x)\n", + "for i in [0, 1]:\n", + " for j in [0, 1]:\n", + " assert dz_dx[i][j].numpy() == 8.0" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "JVDWpL9VYWdP" + "id": "DK05KXrAAld3" }, "source": [ - "## Step 5: Create an optimizer\n", + "### Higher-order gradients\n", "\n", - "We'll use a `GradientDescentOptimizer` to fit our model." + "Operations inside of the `GradientTape` context manager are recorded for automatic differentiation. If gradients are computed in that context, then the gradient computation is recorded as well. As a result, the exact same API works for higher-order gradients as well. For example:" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, @@ -498,362 +274,45 @@ } }, "colab_type": "code", - "id": "DudNEebMKDWN" + "id": "cPQgthZ7ugRJ" }, "outputs": [], "source": [ - "optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "YBeJYxY8YaiO" - }, - "source": [ - "### Step 5a: Test Our Optimizer\n", - "\n", - "Now we have everything needed to start fitting our variables to the data!\n", + "# TODO(ashankar): Should we use the persistent tape here instead? Follow up on Tom and Alex's discussion\n", "\n", - "In the next cell, we'll demo these capabilities. We'll:\n", + "x = tf.constant(1.0) # Convert the Python 1.0 to a Tensor object\n", "\n", - "1. Print the current values of `w` and `b`\n", - "1. Calculate the loss and gradients\n", - "1. Apply the gradients\n", - "1. Print out the new values of `w` and `b`\n", + "with tf.GradientTape() as t:\n", + " with tf.GradientTape() as t2:\n", + " t2.watch(x)\n", + " y = x * x * x\n", + " # Compute the gradient inside the 't' context manager\n", + " # which means the gradient computation is differentiable as well.\n", + " dy_dx = t2.gradient(y, x)\n", + "d2y_dx2 = t.gradient(dy_dx, x)\n", "\n", - "You can run the cell multiple times. Each time, you should see the values of `w` and `b` get closer to their true values of 3 and 2." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 102, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 103, - "status": "ok", - "timestamp": 1505502831285, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "diDZfrMJM3OC", - "outputId": "d585fff0-ecb3-4e98-9b33-bbae07a95d8c" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Values of w, b, BEFORE applying gradients:\n", - "(array([[ 1.56891453]], dtype=float32), array([ 0.], dtype=float32))\n", - "()\n", - "Values of w, b, AFTER applying gradients:\n", - "(array([[ 1.86968815]], dtype=float32), array([ 0.40651903], dtype=float32))\n" - ] - } - ], - "source": [ - "# Test the optimizer.\n", - "\n", - "print(\"Values of w, b, BEFORE applying gradients:\")\n", - "w, b = wb.variables\n", - "print(w.read_value().numpy(), b.read_value().numpy())\n", - "print()\n", - "\n", - "# Calculate the gradients:\n", - "empirical_loss, gradients_and_variables = value_and_gradients_fn(\n", - " inputs, labels, wb)\n", - "optimizer.apply_gradients(gradients_and_variables)\n", - "\n", - "print(\"Values of w, b, AFTER applying gradients:\")\n", - "print(w.read_value().numpy(), b.read_value().numpy())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "61TgeLVlKEQp" - }, - "source": [ - "## Step 6: Create a training loop\n", - "\n", - "Of course, now we can simply turn all of this code into a self-standing training loop. We'll also capture our loss and approximations of `w` and `b` and plot them over time." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 397, - "output_extras": [ - { - "item_id": 1 - }, - { - "item_id": 2 - } - ] - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 225, - "status": "ok", - "timestamp": 1505502831550, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "VukGe-huNaJ4", - "outputId": "f0a8d665-1910-477c-d8ab-c94ccdc4afcd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[2.111051321029663, 2.3047544956207275, 2.4602210521698, 2.5850086212158203, 2.6851789951324463, 2.7655951976776123, 2.830157995223999, 2.8819968700408936, 2.9236228466033936, 2.9570505619049072]\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd0AAAFXCAYAAADnFpTQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xd4FFUbBfAzu+m9koSShBQCSC+igIAgRRGkChJEiggo\nHURAEBQBQeADRcWCha50ULFLk6IivYRQQwskhPS6O/P9sckmm4Rkk2x2difn9zz7bLuZvC8JHO7M\n7FxBkiQJREREVOlUchdARERUVTB0iYiIzIShS0REZCYMXSIiIjNh6BIREZkJQ5eIiMhMjArdlJQU\njB8/Hk8//TS6d++OkydPVnZdREREiiMY8znd6dOno2XLlujbty80Gg0yMzPh4uJijvqIiIgUo9TQ\nTU1NRa9evfDbb7+ZqyYiIiJFKnX38s2bN+Hp6YkZM2agd+/emD17NjIzM81RGxERkaKUGroajQbn\nzp3DoEGDsH37djg4OOCzzz4zR21ERESKUmro+vv7w9/fHw0bNgQAdO3aFefOnSvxa3g5ZyIioqJs\nShvg4+ODgIAAXL16FbVr18aRI0cQGhpa4tcIgoC4uBSTFSkHX19Xq+8BUEYfSugBYB+WRAk9AMro\nQwk9ALo+jFFq6ALArFmzMHXqVGg0GtSqVQsLFy6sUHFERERVkVGhW7duXWzdurWyayEiIlI0XpGK\niIjITBi6REREZsLQJSIiMhOGLhERkZkwdImIiMyEoUtERCbRuXM7uUuweAxdIiIyCUEQ5C7B4hn1\nOV0iIqKy+OijFTh69BAEQYUhQ4ajU6fOuH8/HnPmzER6ehq0Wi2mTJmOJ59sgwUL3kZU1HkAArp3\n74nnn39B7vIrDUOXiEhh5s6dhd27d5h0mz169MLcue8aNXbv3t9x+XI01qz5Fg8eJODll4egadNm\n+PXXn9Cq1eN48cVhkCQJmZmZOH/+POLi7uGbbzYBANLSUk1at6Xh7mUiIjKp06dP4qmnugIAPD29\n0LRpc5w/fw716j2CH37Yha+++hyXLkXD0dERtWrVwp07t7F8+RIcPXoYTk7OMldfuTjTJSJSmLlz\n3zV6VloZCq80l/e8ceOm+Oijz3H48EEsWDAXAwcOxuDBA/D11xtx9Ohh7Ny5DX/88StmzHhLjrLN\ngjNdIiIyifxwbYbff/8VoijiwYMHOHXqBOrXfwSxsbHw8PDEs8/2wrPP9sLFixeQmJgIUdSiffsn\n8fLLoxEdHSVzF5WLM10iIjKJvLOX27d/EmfPnsbQoS9AEFR49dXx8PT0wp4932PjxrWwsbGBk5Mz\nZs16G7GxsXj99TcgSSIEQcDo0eNk7qJyCVIlrThv7esjKmmNR2vvQwk9AOzDkiihB0AZfSihB8D4\n9XS5e5mIiMhMGLpERERmwtAlIiIyE4YuERGRmTB0iYiIzIShS0REZCYMXSIismjHjx/DmTOn9M93\n7NiKn3/+0STbXrv2K5Nsx1gMXSIismjHjx/D6dP5odurV1907fqMSba9Zo15Q5dXpCIiogrbsGEN\n7O3t0bfvAHzwwVJcvnwJK1Z8gmPH/sGPP+7C7NnzDMZHRV3Ahx8ug0aTDWdnN7z55hx4eXlj8+ZN\n2LlzG2xsbBAcXBujR4/Fzp1boVbb4Ndf92DixNfx779/w8nJCQMHDsa4caNQp04ETp48gczMTMya\nNRdr136FK1cuo2PHzhg5cgwAYMaMqYiLu4fs7Cz07/8CevTohVWrViI7OwvDh0eidu0QzJ49D7/8\nsgebN2+CVqtB/foNMGXKdJOuE8zQJSJSGOe5s2Bv4qX9snr0QloJiyg0btwM3367Hn37DkBU1AXk\n5ORAq9Xi1KkTaNy4mcFYjUaD5csX4733liEsrBY2bdqGTz/9CDNmvIX167/Bli27YWNjg7S0VDg7\nu+C55/rqQxYA/v33b4Pt2dra4Ysv1mDz5k2YPn0KvvpqPVxcXDFgQC8MGBAJNzc3zJw5B66ursjK\nysLIkUPQvn1HjB49Ftu2bcaXX64HAFy/fg2///4LVq36Emq1GkuXLsIvv+wx2awaYOgSEZEJRETU\nRVTUeaSnp8PW1hYREXVx/vw5nDx5HJMmTTMYGxNzHVeuXMakSa9BrVYhO1sDHx9fAEBYWDjmzn0T\n7dp1wBNPdDDqe7dt2w4AEBoahpCQUHh6egEAqlevgXv37sLNzQ3ffbcBBw7sAwDcu3cPN2/GoH79\nBgYrIv3779+4eDEKI0cOgSRJyM7OhpeXV0X/aAwwdImIFCZt7rslzkorg42NDfz9A/Djj7vQsGFj\nhIWF4/jxf3H79i0EBQUXGi0hJCQUn3zyZZFrL7///gqcOPEfDh7cjzVrvsSaNd+W+r1tbe0A6BZc\nsLW11b8uCAK0Wi2OHz+G//77F5999jXs7OwwbtwoZGdnF7MlCd26dceoUa+V40/AODyRioiITKJx\n46bYuHEdmjRphkaNmmDHjq0ID69TZFxgYDAePEjEmTOnAeh2N1+9egUAcPduLJo2bY4xY8YhLS0N\nGRnpcHJyQlpaWrnrSktLhaurK+zs7HD9+jWcPXtG/56trS20Wi0AoHnzR7F37+948OABACA5ORmx\nsbHl/r7F4UyXiIhMonHjpli79is0aNAQ9vYOsLe3L3I8F9DNit99dxGWL38fy5cvQnZ2Dp5//gXU\nqhWId96ZnRuwEvr3HwhnZxe0adMOs2a9gb/+2o+JE183OLGppJOc8t5r1ao1duzYisGDn0dgYBAa\nNGioH9OzZ2+89NJARETUxezZ8/Dyy2MwefJrEEUJtra2mDx5Gvz9/U32Z8Sl/R5CSctNWXsfSugB\nYB+WRAk9AMroQwk9AFzaj4iIyOIwdImIiMyEoUtERGQmDF0iIiIzYegSERGZCUOXiIjITBi6RERk\ndt99txFZWVlyl2F2DF0iIjK7zZs3Iisrs9j3RFE0czXmw9AlIqIK27BhDbZu1V0n+YMPlmLCBN2S\neseO/YN582YbjN2yZRPi4+MwbtxovPTSSwCAzp3bYeXK5Rg2bBDOnDmF/v17Ijk5CQBw4cJ5jBs3\nCgCQmZmJhQvfwciRL2H48ME4eHC/uVo0CV4GkohIgbyaNyj29YRjZ4p9vazjCyvL0n79+g3Et99u\nxIcfforQ0BqIi0tBZmYGGjRoiLFjJ+aOMry8Y94lHb/5ZjWaN38UM2a8hdTUVIwcOQQtWz4Ke3sH\no+qUG0OXiIgqrCxL++lIuTcdtVqN9u07Fnq/qH/+OYpDhw5g48Y1AHSLJdy9G4vAwGCT9VKZGLpE\nRApk7Ay1vOMLK9vSfkXZ2dkbLF6gVqshirrgzc7OP+FKkiS8++5i1KoVWKF65cJjukREZBLGLu0H\nAE5OzgbL9RVeeycgoDqios4DAPbt+0P/+qOPPoYtWzbpn0dHR5myhUpn1Ey3Y8eOcHFxgUqlgo2N\nDbZs2VLZdRERkZUxdmk/AOjZsxemTh2PgAB/LFmyssgSfUOHjsR7770DFxcXNG3avMDrL+ODD5bi\npZcGAgD8/QOwaNH/Kq8pEzNqab9OnTph27ZtcHd3N2qjFy9ehKdnQIWLk5OSlpuy9j6U0APAPiyJ\nEnoAlNGHEnoATLy0nyRJZfrc1IABA5CTk2P0eCIioqrAqNAVBAEjRoxA37598d1335U6/sSJE/jw\nQ+uZ7hMREZmDUcd0N23aBF9fXyQkJGDYsGEICQlBixYtHjq+Ro0aWLp0Ebp164769R8xWbFERETW\nzKhjugWtXLkSzs7OGDZs2EPH/PDDD3j22WfRvHlzHDlyBDY2/GQSERFRqWmYkZEBURTh7OyM9PR0\nHDx4EGPHji3xa7p3747nn38B3323EXPnvosJE6aYrGBzUdLBfWvvQwk9AOzDkiihB0AZfSihB8D4\nE6lKDd34+HiMHTsWgiBAq9WiR48eaNu2bakbfvfd97Bv3594//2F6NatOyIi6hpVEBERkVKVeiJV\nrVq1sHPnTuzYsQO7d+/GK6+8YtSGPTw88f77y5GdnY0JE8ZAo9FUuFgiIrJMsbF3MGTIAJNuMzr6\nIg4f/kv//ODB/Vi//huTbFuupQUr9YpU3bo9g759n8d//x3DqlUfVea3IiIimRW+wEVFXbp0EUeO\n5Idu27btEBn5kkm2XdLSgpWp0s9wmj9/Efbv34tFi95F165PP/SSYEREZN00Gg3eeWc2Ll68gNq1\nQzFr1tuwt7c3GHPr1k0sW7YYSUmJcHBwwHvvLYCLiw/++OM3fP3151Cr1XB2dsHy5R/jiy9WITs7\nG6dPn8TgwcOQlZWJCxfOYdKkaViw4G3Y2dkjOjoKiYkPMGPGW9iz53ucPXsa9es3wMyZcwAAS5a8\nh6ioc8jKykKHDp0wfPgrBksLenh4YMWKT/D330fw5ZefIScnBzVq1MTMmXPg4GD6lYsqPXS9vLyx\nePH/MGxYJCZMeBW7d/8MtVpd2d+WiKjKmjvXHrt3m/af9x49NJg7t+TdsTEx1zFjxhw0aNAQCxe+\ng+3bN2PgwMEGYxYvXoBp02aiRo2aOHfuDObOnYslS1bim2++wLJlH8HHxwdpaamwsbHByy+PRlTU\neUyc+DoAYM+e7w1m06mpKfj0069w8OA+vPHGJKxa9RVq1w7BiBEv4tKlaISFhWPUqNfg6uoKURQx\nYcIYXLlyyWBpQTc3NyQlJWLNmi+xYsXHsLd3wPr132DTpnUYOvRlk/4ZAmZaZah79x7o1asPduzY\nhs8//wSjR5d89jMREVkfPz9/NGjQEADQtesz2LLlW4PQzcjIwJkzJzF79hsFFjjQ3Tds2Bjz589B\nx46d0b79k0Z9vzZtngAAhISEwcvLG7VrhwAAatcOQWzsbYSFheP333/Grl07oNVqkZBwH1evXkVI\nSBgKLi149uwZXLt2BWPGjIAkSdBoNGjQoFHF/0CKYbYP0C5YsAQHD+7HggXvoEuXbrlNExGRqc2d\nm1XqrLQyFD6mW/gQrySJcHV1w5dfrte/lveRoalTZ+D8+bM4dOggRox4EatXryv1+9nZ2QEAVCqV\n/nHec61Wizt3bmPTpvVYvXotnJ1dsGDB2wbLBObXJaFly8cwZ867ZWm3XMy2tJ+Pjw/ee28pMjMz\nMWHCa2W6ljMREVm+2Ng7OHtWty7vr7/+jEaNmhi87+TkjICA6vjzz9/0r124cAGA7lhvvXqPYMSI\nUfDw8MS9e3fh5ORksPxfSYq7zlNaWhocHR3h5OSMhIT7OHLkkEEtedt+5JGGOH36JG7dugkAyMrK\nxI0bMWXo3HhmvVRUz5690aPHduzevQOrV3+KkSPHmPPbExFRJQoKCsa2bd9h4cK3ERwcgl69+hUZ\nM2fOu3j//YX45psvodVq0LNnD/Tv/yI+/ngFbt68AQBo3rwlwsLCUa2aH9at+xrDh0di8OCHXwUR\nKP7M6bCwcISHRyAysh+qVfNDo0aN9e/lLS3o4+OLFSs+wcyZczB37kxkZ+dAEASMHDkGtWoFVvBP\npJg6y3oZSGM97AojcXFxeOKJlsjMzMSffx7S74O3NEq6Soq196GEHgD2YUmU0AOgjD6U0ANg4qX9\nTMnX1xcLFy5Beno6Jk0ay93MRERUZZg9dAGgV6++ePrpZ3Ho0EF8/fVqOUogIiIyO1lCVxAELF78\nP3h4eOCdd97C9evX5CiDiIjIrGQJXQDw8/PD/PmLkZ6ehsmTxxV75hkREZGSyBa6ANCv3wB06dIN\nBw7sw5o1X8lZChERUaWTNXQFQcCSJSvg7u6BuXNnVdrnooiIiCyBrKELAP7+AZg3byHS0lK5m5mI\nyEoZu7Tfnj3f4/79eDNUZJlkD10AGDBgEDp16ox9+/7Ehg1r5S6HiIjKwZil/X78cTfi4uKKfa8q\nfITUIkJXEAQsXfoBXF3d8NZbM3H79i25SyIiojLKW9pv8OD+mD17epFF4vfu/R0XLpzHvHmzMXx4\nJLKystCxY0d88smHGDHiRfz5528YN24UoqJ0l4ZMSkpE//49AegC+eOPV2DkyJcwdOgg7Nq13ez9\nmYJFhC4AVK9eA++8swApKcmYMmU8dzMTEVVA8+bOxd5MNb44MTHX0afP81i3bjOcnJywfftmg/c7\ndOiEevXqY86cd/Hll+v1a+26u3tg9eq16NSpSzFb1c2ev/9+J1xcXPH559/g88+/wa5d2xEbe6dM\n9VkCiwldABg06EV06NARv//+K779doPc5RARURkUXtrv1KmTRcZIkoTCc6pOnTqXuu2//z6Cn376\nAcOGDcIrr7yE5OQkqzz51qwLHpRGEAQsW/Yh2rV7DLNnz0CHDh3h7x8gd1lERFbn2DHjVucp7/ji\nlLa038M4OjrqH6vVakiS7thudnZ2gVESJk16HS1bPlbRMmVlUTNdAKhZsxbmzJmHpKRETJ06gbuZ\niYisRGlL+wGAs7Mz0tJSH7qNgIAauHDhHAAYLAH46KOPY9u2LdBoNACAGzdikJWVacryzcLiQhcA\nhgwZhieeaI9ffvkJW7Z8K3c5RERkhLyl/QYP7o+UlORil/Z7+ulnsWTJQv2JVIVnxy+8EInt27di\n+PDBSE5O1r/eo0cvBAfXxogRgzFkyAAsWbIQWq220nsyNbMv7WesmJjraNfuMdjZ2eLAgX/g5+dn\nosqMo6Tlpqy9DyX0ALAPS6KEHgBl9KGEHgALXtrPWIGBQZg9+20kJiZi2rRJ3M1MRERWz2JDFwCG\nDXsZrVu3xZ4932PHjq1yl0NERFQhFh26KpUK//vfSjg5OWHGjKm4d++e3CURERGVm0WHLgDUrh2C\nN9+cg4SEBMyYMVXucoiIiMrN4kMXAEaMGIVWrR7H7t07rPbSX0RERFYRuiqVCitWfAQHBwdMnz4F\n8fFVd4UKIiKyXlYRugAQEhKGGTPeQnx8PGbO5G5mIiKyPlYTugDwyitj0KLFo9ixYxt++GG33OUQ\nERGViVWFrlqtxooVH8Pe3h7Tpk1CQsJ9uUsiIiIymlWFLgCEh9fBG2/MQlzcPbz55htyl0NERGQ0\nqwtdABgzZiyaNWuOrVu/w08//Sh3OUREREaxytDV7Wb+BHZ2dnj99YlITHwgd0lERESlssrQBYCI\niLp4/fUZuHs3FrNnz5C7HCIiolJZbegCwGuvTUDjxk3x7bcb8OuvP8ldDhERUYmsOnRtbGzwwQef\nwNbWFlOnTkRSUqLcJRERET2UVYcuANSrVx+TJ0/DnTu3MWfOm3KXQ0RE9FBWH7oAMH78ZDRo0Agb\nNqzFH3/8Jnc5RERExVJE6Nra2uKDDz6BjY0NJk8eh5SUZLlLIiIiKkIRoQsADRo0xMSJU3H79i3M\nnTtb7nKIiIiKUEzoAsDEiVNRv34DrF37Ffbt+1PucoiIiAwYHbqiKKJ3794YPXp0ZdZTIXZ2dvjg\ng4+hVqsxefI4pKamyF0SERGRntGhu2bNGoSGhlZmLSbRqFETjB8/CTduxGDevDlyl0NERKRnVOjG\nxsZi37596N+/f2XXYxKTJ7+BunXr4auvvsDBg/vlLoeIiAiAkaG7YMECTJs2DYIgVHY9JmFvb48V\nKz6GSqXCxIljkZaWJndJREREsCltwN69e+Hj44N69erh6NGjRm/Y19e1QoVVVJcuHTBt2jS89957\nWLZsAT744IMyb0PuHkxFCX0ooQeAfVgSJfQAKKMPJfRgLEGSJKmkAcuWLcOuXbugVquRlZWFtLQ0\ndO7cGYsXLy5xw3Fx8p/ElJmZiaeeegIXL0Zh5849ePzxNkZ/ra+vq0X0UFFK6EMJPQDsw5IooQdA\nGX0ooQfA+P84lLp7efLkydi7dy9+//13LFu2DK1atSo1cC2Fg4MDli//CCqVChMmvIr09HS5SyIi\noipMUZ/TLU6LFo9i9OixuHbtKhYunCd3OUREVIWVKXQfffRRrFq1qrJqqTRvvPEmQkPD8NlnH+Po\n0SNyl0NERFWU4me6AODo6Ijlyz8GAEyc+CoyMjJkroiIiKqiKhG6ANCq1WN45ZUxuHz5EhYtmi93\nOUREVAVVmdAFgBkz3kJwcG2sWrUS//77t9zlEBFRFVOlQtfJyQkrVnwMURQxYcKryMzMlLskIiKq\nQqpU6ALA44+3wcsvj0J09EUsWfKe3OUQEVEVUuVCFwDefHMuAgODsXLlchw/fkzucoiIqIqokqHr\n7OyM5ctX6nczZ2VlyV0SERFVAVUydAGgbdt2GDp0BC5cOI///c86rrBFRETWrcqGLgC89dY7qFUr\nECtWLMOpUyfkLoeIiBSuSoeui4srli37EFqtFuPHv4rs7Gy5SyIiIgWr0qELAO3bP4kXXxyGc+fO\nYPnyJXKXQ0REClblQxcA5s6dhxo1amL58iU4c+a03OUQEZFCMXQBuLq6YenSD6DRaDB+/Bjk5OTI\nXRIRESkQQzdXx45PYdCgF3HmzCl8+OH/5C6HiIgUiKFbwNtvz4e/fwCWLl2E06e5m5mIiEyLoVuA\nu7sHli5dgZycHAwdOhSpqalyl0RERArC0C2kc+duiIwcgv/++w8DBvRGcnKS3CUREZFCMHSL8f77\nyzFo0CD8889R9O3bEwkJ9+UuiYiIFIChWwwbGxusWbMGgwa9iJMnj6N372cRFxcnd1lERGTlGLoP\noVarsWzZhxg+fCTOnz+LXr2exp07t+Uui4iIrBhDtwQqlQoLFy7Bq6+OR3T0RfTs2Q03bsTIXRYR\nEVkphm4pBEHAnDnzMGXKG7h+/Rp69uyGK1cuy10WERFZIYauEQRBwBtvvIlZs+bi1q2beO65pxEV\ndUHusoiIyMowdMtg/PjJmD9/Ee7ejUWvXk/j9OlTcpdERERWhKFbRiNHjsGSJSuQkJCAPn2exfHj\nx+QuiYiIrARDtxyGDBmGDz9chZSUZPTt2xNHjhyWuyQiIrICDN1yev75F/DZZ18hMzMDAwf2xoED\n++QuiYiILBxDtwJ69uyNr75aD41Gg0GD+uG3336WuyQiIrJgDN0K6tr1aaxb9x1UKhVeemkQfvhh\nt9wlERGRhWLomkCHDh2xceNW2NnZ4+WXh2Dbts1yl0RERBaIoWsirVu3xebNO+Ds7IIxY17Ghg1r\n5S6JiIgsDEPXhFq0eBTbtu2Gp6cnJk58DatXfyZ3SUREZEEYuibWqFETbN/+I3x9q2HGjKn4+OMP\n5S6JiIgsBEO3EtSrVx87d+5BQEB1zJ37JpYuXQRJkuQui4iIZMbQrSRhYeHYuXMPAgODsGjRfCxY\n8A6Dl4ioimPoVqLg4NrYuXMPQkJCsWLFUsyePZ3BS0RUhTF0K1mNGjWxc+dPqFu3Hj777BNMnToR\noijKXRYREcmAoWsGfn5+2L79RzRo0Ahr136FceNGQ6PRyF0WERGZGUPXTLy9vbFt2240b94Cmzdv\nwujRI5CTkyN3WUREZEYMXTPy8PDE5s078fjjbbBr13YMHz4YmZmZcpdFRERmwtA1MxcXV2zcuBXt\n2z+Jn3/egyFDBiI9PV3usoiIyAwYujJwcnLC2rXfokuXbti79w8MGtQPqakpcpdFRESVrNTQzc7O\nRv/+/dGrVy/06NEDK1euNEddiufg4IAvv1yHHj164dChg+jfvxeSkhLlLouIiCqRTWkD7OzssGbN\nGjg6OkKr1eKFF15Au3bt0KhRI3PUp2h2dnb49NMvYW9vjy1bvkWfPj3w3Xc74O3tLXdpRERUCYza\nvezo6AhAN+vlR11My8bGBitXfooXXxyK06dPok+f7rh7967cZRERUSUodaYLAKIook+fPoiJiUFk\nZGTps9zgYHiJRa+8lHDsTLHDvZo3KPZ1WcerhCI9VGY9XwFwGDkan3++Cr16PY2tW3ejevUaFd9+\ngT6s6s+/oNweLKaeco5HzHWLqofjOd4SxisiL4CH/v0uzKjQValU2LFjB1JTU/Hqq6/i0qVLCAsL\nK/Fr1CqhyGu+vq4P+QZFx1rC+MI9VHY9n376Mby83LFo0SL07v0M/vjjDwQHB1d4+3l9yP3nWZHx\napVgUfWUZ/xDv8ZK6i843uBrLaCe8ozXP7eQeso7vrh/a+Wsp8zjoYy8MJYglfFiwCtXroSzszOG\nDRtW4ri4OOs+G9fX11WWHiRJwtKli7B48QJUr14D27btRkhIyf/BKYlcfZiSEnoA2IclUUIPgDL6\nsPgeRBHIzISQmQEh9x4ZmRCyMiFkZgKZGRAyMuE+dJBRmyt1ppuQkABbW1u4uroiMzMThw8fxiuv\nvFLhPqh4giBg6tTpcHBwxDvvzEbPnk9jy5ZdqFu3ntylERHJy8gAzHtfN7bg89z3szLzt5M7Hpm5\n28nIHZf3fna2cbWZKnTj4uIwffp0iKIIURTxzDPPoH379sYVQeU2duwEODo6YMaM19G79zP47rsd\naNiwsdxlEREVJUlAdjaE9DQI6em5tzT9PdLTIaQ95D1JA9fEFMMAzMrSPy9XAJa1fEEAHB0hOThA\ncnCE5OICyccXkoM9JAdHIO91BwdIjo6Avb3hcwcHuBj5vUoN3YiICGzfvr2CLVF5jBgxCg4Ojpg8\neRz69OmBTZu2onnzlnKXRUTWSJKAjIwioWd4nw4UfC2taEgWHZf7ulZb7tIcCpZZXAB6+0BydCga\ngA4ORQPRwQGSfe57jo4FxjoCDvaGz/O2aWsLCGU7NluYyUKX5BUZOQQODg4YO3YU+vV7Dhs2bMbj\nj7eRuywiqkyiqAuylBQIqakQUpINH6emQJWaCkg5cI5/UCQQ9Y/T8h8jIx2CCdbzllQqSE7OkJyc\nACcniN4+kJyc9K9JTk6QnAs8dnIGDN43fM+rpi/i00VdANo7AHZ2FQ5AS8bQtQJ9+z4POzt7jB49\nHAMH9sGaNZvQvv2TcpdFRAVJku64YEpKbiimFB+aqbrHKv17KbrX8h6npEBISzU6IJ2KK8XWVh9u\nors7pIDqucH38PArGJYlhSTs7U0bir6ukCz5RCoTY+haiR49noODw3oMH/4iBg9+HqtXr0GXLk/L\nXRaR9cvJyZ095oeeKi0lPwALhmZaam5gFgzRlPyvL+fFgyQ7O0iurpBcXCEGBUN0dc197gLJxS3/\nsasrJFfc0+i4AAAgAElEQVQ3iC4ukFxc4FHTDwlZAJxzw9HRUReMtram/TMik2HoWpHOnbth/frN\nGDJkIIYOjcSnn36JHj16yV0WkbxEEUJyEoTERKiSEiEkJkJISoQqMbH415ISgdRkeCcl6YKynMtr\nSmo1JBddOIoB1SE560JRdHXLD0gXV/2Y/OB0g+iS/1hycdHNHsvD1xXaKjRLVAKGrpVp164DNm3a\nhkGD+mPkyKH48MNV6N9/oNxlEVWMkcGpSnxQJECF5KQyHauUnJwAd3eInl6QagUWmUnqQzMvLAuG\npqsrRGfdPRwdFX3skSoHQ9cKPfZYa2zZshMDBvTB2LGjkJWVhcGDX5K7LKrqSgzOB/qQzAvSigan\n6O4BsXp1iPXqQ/LwgOTuAbHQveThAdHDE5KHJ0R3D0ju7oC9PXx9XfGAM0SSAUPXSjVr1gLbtn2P\n559/DpMnj0NmZgZefnm03GWRUmRkQBUfB9X9eKjux0OIj4fq/n2o7scDmalwi40zTXB6eEKsXgNi\n/UfyQ1Iflh6FXvPUvwc7u0psnqjyMHStWMOGjbBjxx707dsDM2dOQ0ZGJsaNmyh3WWSJ0tL0AaoP\n0fgCz+/H54bsfaji43UXLShB3hFIyckZoocHg5PISAxdKxcRURe7du1B3749MW/eW8jISMfrr8+A\nwGNNyiVJhiEaHwchNyzzn+cFqm52KqSnl75Ze3uI3j7QhIVD8vaG6O2ju/n4QPLxzX3uDc/QWojX\n2up21TI4icqEoasAISFh2LlTN+NdsuQ9ZGZmYvbstxm81kKSdB9FiYszDMr4eMNdvPfv658bc8at\n5OCgC9HwiEIh6gvJx0cfonnPJWcX404MqmKfqyQyJYauQgQGBmHXrp/Qt28PrFy5HBkZ6Zg/f7Hc\nZVVdkqQ7eSg2Fqo7t6G6GwukJcL5+q1Cx0lzH2dllb5JR0eIPr7Q1K2nuwpQboDqZ6O5AZoXrnB2\n5tm1RBaGoasgAQHVsWPHHvTv/xxWr/4MWVlZ+Prr1XKXpTzp6VDF3oH6bm6g6oP1DtR37kAVeweq\nu7HFzkYLXj1IcnKG6OMDTf1HdCFaIDAfGqJEZNUYugpTrVo1bN/+PQYM6IN1677BvXt3MG/eYtSu\nHSJ3aZZPo4Hq3l1daBYIT/Wd27rHsXd0AZuU+NBNSCoVRN9qutmof4D+pg2oDrewIDywdc4PUafi\nLuBHRErG0FUgLy9vbN26C6+8Mgy//PIL9u/fjwkTpmDs2ImwL++Vb6yZJEF4kKAL0oKz0dhYqGIL\nzFTj7pX4kRfRwwNiQAA0TZvlBmkARL8AiAHVIfr76+59fAGbh/y18nWFhsdCiao0hq5Cubm5Y+PG\nrfjzzz2YMGEiFi2ajy1bvsWiRcvQrl0HucsznbQ0qO8WmJnmBqsqNm+GGgvV3TslHjOVHBwg+gcg\np9XjEIsJUq2fP0T/AN0ViIiIKoChq2CCIGDAgAFo0aINFi2aj9WrP0O/fj3Rp08/vP32Qvj5+cld\n4sPlnoikvhEDJMXB4eKV/BlqgWBVJSc9fBMqFUQ/f90xU78AXaDm7uoV/fz1wSq5e/CEIyIyC4Zu\nFeDm5o758xdjwIBBmDZtErZt24Jff/0FM2fOxtChL0OtVstTWFoa1DdioI65BlXMdaivX4c6RndT\nxVyHKiVZP9S10JeKnp4Qa9SEpnkLaP0Dip2hij6+gFy9EREVg6FbhTRq1AQ//PAb1q79GvPnv40Z\nM17Hpk0b8P77/0OTJs1M/w2zs6G6dVMfpLowvaZ7fP06VPFxxX6Z5OQEbWAQcgJbQxsYBKd6dZDs\n6gWtf26g+gcADg6mr5eIqJIxdKsYtVqNoUNH4JlneuDtt2dh8+ZN6Nr1SQwdOgIzZ74Fd3cP4zcm\nirqPzsRch+r6NYNZqjrmOlR3bkMQxSJfJtnaQluzFjSPNIA2MAjawCCIuffawGBIPj4Gu3udfF2R\nxROQiEgBGLpVVLVq1fDRR59h0KAXMW3aJHz11Rf4/vtdePvt+ejb93nd1awkCUJCAtS5s1OVfvdv\n7u7gmzcgZGcX2bYkCBADqiPn0ccKhGkQxKBg3b1/AHf7ElGVxNCt4to2boIDH32OXz/7GCd3bkPW\nqyNxcdZ0NPX0hGNsLFRpqcV+nejjkztTDS4UrEHQ1qhV/kW5iYgUjKGrdFlZUF+OLjBLzdv9mzt7\nTUgAAAzOvQEAEu4jOeE+Yn184dGmLVA7JDdYdTNVba1AwMVFro6IiKwWQ1cJJAlCfDxsoqOgvhgF\ndXQUbC5GQX0pGrh9C17FXPBBsreHtlYgNI2b5odpkC5Qf4m+iCnz38btO7cReOEC3hs6Ak891VWG\nxoiIlIWha01EEapbN2Fz8QLUFy/mh2t0FFQPHhQZrq1eA2jfHhkBNQ1OVBKDgiBW8wNUqmK/Taem\nzXHwmR5YunQRPv30Iwwa1B/du/fEu+++hxo1alZ2l0REisXQtUQ5OVBfvQL1xagCs9eLsLl0sci6\nqJJKBW1wbeS0ehza8Aho6kRAWycC2vA6kFxc4evritRynPnr4uKCOXPm4fnnX8C0aZPwww+78Oef\nv2PatJkYOXI0bG1tTdUtEVGVwdCVU1oabC5dzA/V3Fmr+uoVCBqNwVDJwQHa0HBo6tTJD9fwCGhD\nQiv1pKV69epj5849+PbbDXj77VmYO/dNfPvtBixe/D+0avVYpX1fIiIlYuiagXD/ftHjrdEXob55\no8hY0d0DmibN8kO1Th1owiMg1gqU7WM2KpUKL7wwGF27Po13352Ldeu+QY8eXRAZOQSzZ78NLy9v\nWeoiIrI2DF1TkSSobt8qsEv4ItQXL8AmOgqq+/eLDNf6+SP7iQ76UNXWiYAmPAJStWoWex1gLy9v\nLFv2IQYMiMS0aZOwfv0a7NnzPd56ax4GDoyE6iHHiImISIehW1YaDdTXrhaatUZBHR1d5DOtkkoF\nMTAIWc1bFtglXAfaOhGQ3NxlaqDiWrV6DL/9th9ffPEpFi2aj4kTX8OGDWuxePH/UL/+I3KXR0Rk\nsRi6D5OeDpvTJwuEq+5sYfWVyxBycgyGSnZ20IaGI7tAqGrCI6ANDVPsNYJtbW0xZsxYPPdcb8ya\nNR3ff78TnTq1xahRr2Hq1Olw4ed4iYiKYOgCEFKSYXPqJGxOnoDNqeOwOXkCuHIZnoU+3yq6uELT\nsBG0deoW2CVcB2JQcJW9rGH16jXw5Zdr8dtvP2P69Nfx8ccfYMeOrZg/fzGeeeZZ3eUkiYgIQBUM\nXSE5CTanT+kC9uR/uvsrlw3GiG7uQLt2yKgdVuCEpgjdNYMZIsV66qmuOHCgHVasWIIPP1yOYcMi\n0blzVyxY8D6CgoLlLo+IyCIoOnSF5KQiM9giAevugewn2kPTqAk0TZoip1ETiMG14VvNrVyfb63K\nHB0dMX36bPTtOwBvvDEZv/76Mw4e3I9Jk17Hq6+Oh52dndwlEhHJSjGhW6aAbdwUmsZN9AHL2atp\nhYfXwdatu7F163eYM+dNLFjwDjZv3oRFi5ahbdt2cpdHRCQbqwzdIgF74jhsrl4xGKML2A7QNG7C\ngJWBIAjo128AOnfuioUL5+Grr75Anz7Pol+/AZg7dz6qVasmd4lERGZn8aGrD9gTx/NnsKUFbOOm\nupObGLCyc3f3wHvvLcWAAYMwbdpkbNnyLX755Se8+eYcDBkyDOoqegIaEVVNFhW6QlJi0V3EJQRs\nTpOm0DRqwoC1Ak2bNsdPP/2Br79ejQUL3sEbb0zGpk3r8P77y9GoURO5yyMiMgvZQteogPXwQHa7\nJ3Nnr00YsFZOrVZjxIhX8OyzPTFnzkxs27YFXbp0wPDhIzF9+iy4WfEFQ4iIjGGW0DUI2JPHYXvy\nONTXrhqMYcBWHX5+/li16ku88MKLmD59Cr744lPs2rUD8+YtRK9effnZXiJSrMoJ3T/+gOPev2Bz\n6oRxAdu4KcTAIAZsFdO+/ZPYu/cwVq5cjuXLl2DUqOFYv34tFi1agtDQcLnLIyIyucoJ3U6dkHcR\nQIOAzTsGy4ClXPb29pgy5Q306dMfM2ZMxR9//Ib27R/HuHGTMGHCFDgo9DKaRFQ1VU7oTp+OpPD6\nDFgyWu3aIdi4cSu+/34XZs16A0uXLsLWrd/lnvncW+7yiIhMotS12GJjYzFkyBA888wz6NGjB9as\nWVP6VhcuRHaPXjwmS2UiCAJ69HgOf/31D0aNeg03bsRg4MA+6NevH44d+wdSoWthExFZm1JDV61W\nY8aMGfjxxx+xadMmrF+/HpcvXy7ty4jKzcXFFfPmLcSvv+5HixaPYuvWrXj66U7o0OFxfPbZx0hI\nKLo+MRGRNSg1dH19fVGvXj0AgLOzM0JDQ3Hv3r1KL4yoQYOG+P77X/DTTz+hZ8/euHQpGrNmTUej\nRhEYNWoY9u/fC1EU5S6TiMhoZTqme/PmTVy4cAGNGjWqrHqIDKhUKnTt2hXNmrVGfHw8Nm/ehHXr\nvsb27VuxfftWBAUFIzJyCAYOjIS/f4Dc5RIRlUiQjDxQlpaWhhdffBGvvvoqnnrqqRLHBgej2BnI\nsWNpxY5v3ty52NflHK9SqYr0YE315ynYhyXUU57xeT3kjZckCX//fRTr13+DXbu2Iz39LADAwcER\nLi4ucHBwgCAIFlN/npgYFeKKWbnK0v/8C4/39XU16EPuesozvmAPllBPecf7+roiMLD4vT3WUD8A\ntGzpavV5Aej+fhvDqJmuRqPB+PHj8dxzz5UauHlUqqIF+Pq6PmRs8duQe3zhHuSup7zj8/qwlHrK\nM16lUhmMf/bZznj22c5ISkpCSIgKqakpyMzMQGZmBtRqNVxcXJCUdB9hYWEWUX9JX2MNf/6Fxxd8\nbAn1lGd83nNLqaf844v/AmupX/c11p8XxjJqpjtt2jR4enpixowZRm+4uP/RW5PC/5u3Vkrow9ge\nTp8+hQ0b1mDLlu+QlJQIAGjbth0iI4ege/eesn/mVwk/C0AZfSihB0AZfSihB6Dk/1QUVGpGHzt2\nDLt378aRI0fQq1cv9O7dG/v3769wgUSm1rBhIyxcuASnTkXh448/R5s2T+Dgwf0YM+ZlNGpUBzNn\nvo6zZ8/IXSYRVWFGH9MtK2v/n4uS/vdl7X1UpIcrVy5hw4Z12LhxHeLidGfdN2vWHJGRL6F3775w\ncTHuf6emoISfBaCMPpTQA6CMPpTQA2DCmS6RNQsJCcOsWXNx4sR5fPPNRnTp0g0nThzHlCnj0aBB\nHUyc+Br++ecoL7xBRGbB0KUqwdbWFk8/3R3r1n2H48fPYcaM2fDx8cWGDWvRvXtntGvXCqtWrcT9\n+7zwBhFVHoYuVTkBAdUxadLr+PvvE9i8eSd69eqDq1ev4K23ZqJRozoYOXIo9u79gxfeICKTk20R\neyK5qVQqtG//JNq3fxL379/Hli2bsH79GuzcuQ07d25DYGAQXnhhMF54YTCqV68hd7lEVMmys4H0\ndCA9XShwr3uclpb/WkZG0TEbNxr3PXgi1UMo6eC+tfdhzh4kScKxY/9g/fo12L59K9LT06BSqdCx\n41OIjHwJXbp0g62tbbm2rYSfBaCMPpTQA6CMPsrSgyiimMDLv8/IKDksC79XOEA1mvIv0GNsknKm\nS1SAIAho0eJRtGjxKObNW4gdO7Zh/fpv8Ntvv+C3336Br281DBgwCIMHD0FISNELbxCRjiQBaWlA\naqqAlBQBKSmGj9PSdI+1WiA+3v6hQVg4LE1BpZLg5AQ4OenuvbzEAs91rzk7S3B0zB9jeF/0NehX\nkS8ZZ7oPoYT/QQLK6MMSejh37iw2bFiDzZs34cGDBwCA1q3bIjJyCJ599jk4OjqWug1L6MMUlNCH\nEnoATN+HJAFZWXnhmB+SqanIDUvd4/zwzH+vuK+RpPKHpL19ySFX8N7R0XCMs/PDw9LREbC3N/2q\ns8Z+ZIih+xD8S2k5LKmHzMxM7NnzPdatW4MDB/YCANzc3NGv3/OIjHwJDRs+fDEQS+qjIpTQhxJ6\nAPL70GhQKAx1j4ubZRYOybzHea/n5JQvjezsJLi6SnBxAVxcdI9dXQFXVwnOzvmPdWN0z11cJNSs\n6YTs7LQiM0y12sR/WJWMoVtBSvtLac0stYdr165i48a12LhxPWJj7wAAGjduisjIIejTpx/c3NwN\nxltqH2WlhD4srQdJ0h2rTEwUkJgoICkp7x548CD/eeH309JUSE6Wyr3bVRAMw9DZuWAwokBA5j/P\nC1NdkOaHp719+Xq3tJ9FeTF0K0hJvwjW3oel96DRaPD7779i/fpv8OuvP0Or1cLR0RE9e/ZGZORL\naNXqMQiCYPF9GEsJfVRWD5mZKBSQMAjJwqFZ8P3sbOOD09ZWgru7BC8vFRwdtUVmjwXDMO/1ggGa\n956Tk+l3s5aVEn6fAIZuhSnpF8Ha+7CmHmJj7+Dbbzdg/fo1uHbtKgAgLCwckZEvYeTIobCzc5O5\nwoqzpp/Hw5TUQ04ODELz4YFZ9P3MTOMTTK2W4OEhwd0d8PCQ9Dd3d6nQ86Lv54Wl0n8W1oShW0FK\n+kWw9j6ssQdRFHHo0EGsW/cNfvhhF7KysgAAoaFhaN26rf4WEFBd5krLzlp+Hnlnz8bHC7h/X0B8\nvID4eBXu3xeQnm6PO3dy9KFZcBduWXbVCoIuFN3dJXh65gem4fOi73t46HbXVnSWaS0/i5IooQeA\noVthSvpFsPY+rL2HBw8SsG3bZhw48Cf27z+A1NT8XkJCQvUB3KbNE1YRwnL+PDIzDUM0Li7vsapQ\nuOoeZ2QYl2pubg+bZepC82GzUFfXsq+nakrW/ncDUEYPAEO3wpT0i2DtfSihB0DXx507D3DmzCn8\n9ddBHDp0AEeOHEZKSrJ+TO3aIWjT5gk8/ngbtGnzhEVeCcuUP4+cHCAhQReexYWmLlhV+sepqaWH\nqL29BB8fw5u3twQfH1H/PDTUCZKUCg8PCW5ugI2VXrFACX83lNADwNCtMCX9Ilh7H0roASi+D61W\naxDChw8fMgjh4ODaBiFco0ZNc5ddREk/D61Wd7Zt4QDNn5EWfE+FxMTSQ9TGJi808wPU17domOa9\n7uxc+m5bJf9OWRsl9AAwdCtMSb8I1t6HEnoAjOtDq9Xi7NnTBiGcnJykfz8oKNgghGvWrFXZZUOj\nAe7dE3DnjoDYWBUyMx1x7VpWkRlpfLyAhAQBolhy4qlUEry8Cs9Ciz7OC1N398q5kEFV+Z2ydEro\nAWDoVpiSfhGsvQ8l9ACUrw+tVotz587gr78O4NChgzh8+BCSkhL17wcGBqNNm/wTs2rVCizT9tPS\ngNhYAbdvq/SheueOgNu38x/fu1d6kHp46EKy5Bmp7ubpKcl+4YOq/DtlaZTQA8DQrTAl/SJYex9K\n6AEwTR+6ED6LQ4cO4K+/DuLw4b8KhXAQWrdui8cea4v69dtDrQ7MDVEVYmMF3LmjC1LdTYXk5IeH\nqZ2dBH9/CQEBIgICJAQESPD3FxEW5gBb23T4+OhC1ctLQjnXgJANf6cshxJ6AIwPXSs9fYCoalKr\n1ahTpxHc3BqjcePx6NVLwokT93DqVDyuXMnErVu22LTJD5s21QBg99DtuLtLqFFDRPPmulD195dQ\nvXr+44AA3ey0uN26vr4OiIvTVl6TRArG0CWyEJIEJCejwK5e3Wy04K7e2FjdCUiGaufedMdLfXyy\n4eAQj5yca0hMPI2srCsAbgG4CT8/EW3ahKB9+1Zo3botAgODIMh9SSKiKoShS2QGWi1w6xZw5oyq\nwK7egrt7da+VdGEGJyfdDLRuXU3uzFTM3eWrm6FWr67b3as7XuoKoCFE8RGcP38Ohw8fxF9/peDw\n4YPYtu0Atm37BgBQo0ZN/WeEW7dui6CgYIYwUSXiMd2HUNJxBmvvw1p6SEkBrl9X4do1Fa5fF3D9\nukr//ObNkldv8fExPG4aEKAL1bxdvQEBItzcKn4WryiKuHDhfG4IH8Thwwdx//59/fs1atTUnxnd\nunVbBAfXLhLC1vLzKIkSegCU0YcSegB4IlWFKekXwdr7sJQetFrdmb7Fher16wLu3y/+0kQ+PiKC\ngiSEhqrh5ZVtcGJSQIAIP7/yr9BSUaIoIirqAg4dOoBDh/7CoUMHDEK4evUaBiFcu3YIqlVzs4if\nR0VYyu9URSmhDyX0ADB0K0xJvwjW3oc5e0hNhT5M84JVF6oq3LhR/EowtrYSAgMlBAWJ+ltwcP5z\nFxfz91FekiQhKuoC/vrrAA4f1oVwfHy8/n1//wA0bdoEgYEhqFMnAuHhEQgPrwNvb28Zqy47a/hZ\nGEMJfSihB4BnLxMVSxR1s9W8UL12LT9Ur18v7iQlHW9vEQ0aFAxV3ew1KEg3a5X7c6emIggC6tat\nh7p162HEiFcgSRIuXozSh/CRI4ewZ8+eIl/n7e2tD+D8WwRq1qwFlZwXJyayMAxdUpz0dBjMVAvu\nAo6JUSErq+hs1cZGQq1aEho00BQJ1aAg3fHUqkgQBERE1EVERF0MHz4SAGBjo8GRI/8hOvoiLl6M\nwqVLuvu//z6CI0cOGXy9o6MjQkPDUadOnQKhHIGQkFDYy7VPnUhGDF2yOpIE3L1reGy14Gz13r3i\nZ1aenhLq1Ss6Uw0K0p35a60XvTc3T09PtGjxKFq0eNTg9czMTFy9egXR0VEFwvgiLl+OxpkzpwzG\nqlQqBAUFo06dCISF1cndVa2bIbu7e5izHSKz4j8zZJEkSbcb+MIFFWJjgTNn7PWhGhOjKnbJNrVa\nQs2aEtq10+hDVXevu7m7y9BIFeLg4IB69eqjXr36Bq+LooibN2/khvFF/cw4OjoKP/+8Bz//bLi7\nulo1v9wwDjc4bhwQUJ0fZyKrx9AlWUmS7mL6Fy6oEBWlu124oMbFiyokJRX8B1Z3dSU3Nwnh4WKB\nMJX0M9caNThbtUQqlQqBgUEIDAxCp05dDN67f/8+oqOj9Luqo6OjcOlSNA4e3I+DB/cbjHV2dkF4\neDjCwyMMZsjBwbVha23XoaQqi/9EkdnExeWHa37Iqoss76ZWSwgJEfHEEyIiIkS0bGkPb+80BAWJ\n8OCeR0Xx9vaGt3drPPZYa4PX09PTcflydIEw1s2Qz507ixMnjhuMtbGxQe3aIQXCOFx/7+Ji3Bml\nRObC0CWTu39fKBSsulvhz7GqVBJq15bQurUGdevqAjYiQkRoqGjwuVVfX3vExYlm7oLk5OTkhIYN\nG6Nhw8YGr2s0GsTEXEN0dLTBSVzR0RcRHX0RP/6422B89eo1DM6mzpsh+/i4mLMdIj2GLpXbgwdA\nVJS60K5hVZGP3QiChKAgCS1b5hiEa1iYCAcHmYonq2RjY4OQkDCEhISha9en9a9LkoR79+4VOYkr\nOjoK+/b9iX37/jTYjouLC/z9A+DvHwA/P//cez+D1/z8/OHk5GTuFknhGLpUqqQk4MIFtUGwRkWp\nij1LODBQRJcuGkREaBERIaJuXV248t8uqkyCIMDPzw9+fn5o27adwXupqSkFPt6kmyHfvHkdt2/f\nxqVL0SVu193dA/7+/vDzC4C/v39uKPvnhnL+Y378iYzF0CW9lBTkBqraIFxjY4uGa61aIp56SpM7\na9Wibl0R4eEinJ1lKJyoBC4urmjatDmaNm2ufy3vKkhZWVm4d+8u7t6NRWxsLO7evYPY2FjExt5B\nbOyd3NfvICrqQonfw8vLq5hgDjAI6WrV/HjCFzF0q6LUVODixfwzhfPC9fbtouFao4aIjh01ubNW\n3ey1Tp38SxsSWTN7e3vUqhWIWrUCSxyXkZGBe/fuFgjm/HDOC+Zbt27i/PmzD92GIAjw9vbRB3HB\nXdsFw9nHxxc2PA1fsfiTVbD0dODff4HDh230s9eoKBVu3CgargEBIjp00Oh3CeftHnblyZ9EcHR0\nRFBQMIKCgkscl5aWhrt3Y/VBnB/M+Y+vXLlc5GIhBalUKvj6Vis0Yy46g7a2612TDkNXITQa3a7h\n48fVOH5chf/+081gRREAHPXjqlUT8cQT+WcL581eeeEIoopzdnZGSEgoQkJCSxyXmppisBu74K7t\n/F3a53Hy5PGHbsPGxgZeXl5wc3OHu7s73N09Ctx76J97eHjAza3ovVopFwy3MgxdKyRJQEyMgOPH\n1fjvP13InjqlNrhKk5OThJYttWjZ0gaBgZn62aunp4yFExEA3XHmsDBXhIWFP3SMJElITk4q9hjz\n3bt3ERt7B8nJiUhIeICYmOvIzs4uUw2urm7FhLW7QVjnv+ZpEOCOjo68Olg5MXStQEICcOJEXsDq\nQrbgx3JUKgl164po1kyLZs1ENG2qm73a2OSdMJIjY/VEVB6CIOhnrBERdYsdk3dCmCRJyMjIQHJy\nEhITE5GUlISkpAe597rniYmJ+vcL3sfEXEdKSnKZarOzs9PPmjnLLhuGroXJyABOn87bTawL2mvX\nDI/BBgaKeO65HDRtqgvZhg21PGuYqAoTBAFOTk5wcnKCv39Amb9eq9UiOTnJIKSTkhILBHhigZth\nkF+/fg05OWX7j72rq5s+gH18vGBraw9HR139jo6OcHJy1t87ORV87lRgXP69s7Pu3hrCnKErI60W\niI5W4b//VPpZ7PnzKmg0+bttPD0ldOyoyQ1YLZo0EeHrK8lYNREpjVqthqenFzw9vcr8tWWdZRcM\n7piY6zh79rTJ+rC3ty8S2oXDOu/2sJA3DHvDcLezs6vwbnWGrplIEnD7tqA/Bnv8uBonTqiRlpb/\nA7S3l9CkiW43cdOmulvt2hJ46ISILFVFZ9ne3s6IibmH9PR0ZGSkF3ufd8vIyEB6elqh+4LjdK+l\npaUjOTkZsbGxyMhIhyia5jKyarW6SFjnzcT3799r1DZKDd2ZM2di79698Pb2xu7du0sbTrmSkqDf\nRZx3NnHBKzgJgoSICBFNm4r6WWzduiLs7GQsmojIzFQqFZydneFcScfIJElCVlZWgSDXBXZ6enEB\nXtmry4cAAAsRSURBVFyQFwx9w/vExESkp6eVafd6qaHbp08fvPjii5g2bVqFGleyrCzg7FmVwdnE\nly4ZHluoXl1E9+45aNpUN5Nt3FjLz8ASEVUyQRDg4OAABweHcu0+N4ZJQ7dFixa4detWhQpSElEE\nLl/WHYfNm8meOaNCTk7+PmBXV91C6rrdxLqZrL8/j8MSESlRWS7vyWO6pbh7V3ccNu9kpxMn1EhJ\nyQ9YW1sJDRqI+mOwzZrplqZTFb3oExERVXEM3QIkCTh/XoUDB9Q4fhw4csS5yPWIw8K06NYt/2Sn\nRx4xXPuViIjoYSotdH19reOA5Y0bwG+/6W6//w7cvZv/np+fCj17Ao8+CrRqBbRoAXh4qAGoAVjP\naiHW8rMoiRJ6ANiHJVFCD4Ay+lBCD8YyKnQlqezHI+PiUsr8NeaQlAQcPGiD/fvV2L/fBpcv589k\nq1UT0a+fFu3aadCzpyMcHVMMPq6TkwPExclQdAXkXbHGmimhB4B9WBIl9AAoow8l9AAY/x+HUkN3\nypQpOHr0KBITE9GhQweMGzcOffv2rXCB5pKVBfzzj1ofsidOqCCKuiR1dpbQpYsG7dpp0K6d7tKJ\neSHr62t9AUtERJat1NBdunSpOeowGVHUfXxn3z5dyB49mr8QgI2NbhGAdu10t2bNtOCa0kREZC6K\nOJHq+nUB+/frdhkfOKBGQkL+LuN69fJCVoPHH9dy8XUiIpKNVYZuQoLuuGzebPb69fyQrV5dxMCB\nOWjXToMnntDCz4+fjyUiIstgFaGbkQEcPZp/XPb0aRUkSbfL2M1NwjPP5Ohns6GhvFYxERFZJosM\nXa0WOHVKpd9l/PffamRl6ZLUzk5Cmzb5u4wbNdKtG0tERGTpLCKuJAm4elXAvn26kD140AZJSfnT\n1YYN80O2VSstnJxkLJaIiKicZAvde/cEHDyYv8v45s3847KBgSJ69tTtMm7TRgsfHx6XJSIi62e2\n0E1NBY4cUetns+fP56/C4+kp6UO2XTsNgoMZskREpDyVFro5OcDx4/nHZf/9Vw2NRrfL2MFBQvv2\nugtStG+vQYMGXCCAiIiUr1JCt2dP4M8/XZCaqgtZQZDQpImov/JTy5ZaODhUxncmIiKyXJUSurt3\nAyEhEvr10+0ybttWAw+PyvhORERE1qNSQvfaNcDJKa0yNk1ERGS1KuVIalBQZWyViIjIuvH0JSIi\nIjNh6BIREZkJQ5eIiMhMGLpERERmwtAlIiIyE4YuERGRmTB0iYiIzIShS0REZCYMXSIiIjNh6BIR\nEZkJQ5eIiMhMGLpERERmwtAlIiIyE4YuERGRmTB0iYiIzIShS0REZCYMXSIiIjNh6BIREZkJQ5eI\niMhMGLpERERmwtAlIiIyE4YuERGRmTB0iYiIzIShS0REZCYMXSIiIjNh6BIREZkJQ5eIiMhMGLpE\nRERmwtAlIiIyE4YuERGRmRgVuvv370e3bt3QtWtXfPbZZ5VdExERkSKVGrqiKGLevHlYvXo1vv/+\ne/zwww+4fPmyOWojIiJSlFJD99SpUwgKCkKNGjVga2uL7t274/fffzdHbURERIpSaujevXsXAQEB\n+ud+fn64d+9epRZFRESkRKWGriRJ5qiDiIhI8WxKG+Dv74/bt2/rn9+9exfVqlUrdcO+vq4Vq8wC\nKKEHQBl9KKEHgH1YEiX0ACijDyX0YKxSZ7oNGzZETEwMbt26hezsbPzwww/o1KmTOWojIiJSlFJn\numq1GrNnz8bw4cMhSRL69euH0NBQc9RGRESkKILEg7ZERERmwStSERERmQlDl4iIyEwYukRERGZS\n6olUZbF//34sWLAAkiShb9++eOWVV0y5ebOYOXMm9u7dC29vb+zevVvucsolNjYW06ZNQ3x8PNRq\nNfr3748hQ4bIXVaZZWdnIzIyEjk5OdBqtejatSvGjh0rd1nlIooi+vbtCz8/P6xatUrucsqlY8eO\ncHFxgUqlgo2NDbZs2SJ3SeWSkpKCN998E9HR0VCpVFiwYAEaN24sd1lGu3r1KiZNmgRBECBJEm7c\nuIEJEyZY5d/xr7/+Glu2bIEgCKhTpw4W/r+9u3mJag8DOP6dHKRQexElCyzIjCySFr1AEyamSTXV\nxGCLNiVRbdIow14oghYJLfoHWkREEBEaRG1EszGmQiuGYIgwIhhMKkRT5yXPnOcu4l64G+89x7nz\na7rPZz1n+A6HmYcznHmmo4P8/HzTWY7cunXrr/fCv/qslQxJp9NSX18vsVhMfvz4IXv37pWhoaFM\nPX3WDAwMSDQaFb/fbzrFtS9fvkg0GhURkcnJSdmxY0dOngsRkXg8LiIilmVJU1OTRCIRw0Xu3Lx5\nU9ra2uT48eOmU1yrq6uTsbEx0xmzdvbsWbl//76IiExPT8vExIThIvfS6bT4fD4ZHh42neLYyMiI\n1NXVSSqVEhGRkydPSldXl+EqZ96/fy9+v19SqZRYliWHDx+WT58+zXhMxr5e/l12NG/YsIH58+eb\nzpiV0tJSqqqqACgoKKCioiJnV3fOmzcP+HnVa1mW4Rp3RkZGePr0KU1NTaZTZkVEsG3bdMasTE5O\nMjg4SDAYBMDr9VJYWGi4yr1wOMyyZcv+tqo3l9i2TSKRwLIsksnkv1q89Cv58OED69evJz8/n7y8\nPDZu3Eh3d/eMx2Rs6OqO5l9TLBbj3bt3VFdXm05xxbZtAoEAPp8Pn8+Xk6/j6tWrtLe34/F4TKfM\nisfj4ciRIwSDQe7du2c6x5VYLMaiRYs4f/48+/fv59KlSySTSdNZrj1+/Jjdu3ebznBl8eLFNDc3\nU1tbS01NDUVFRWzZssV0liOVlZUMDAwwPj5OIpEgFArx+fPnGY/J2NAV/bnvL2dqaorW1lYuXLhA\nQUGB6RxX5syZw4MHDwiFQkQiEYaGhkwnOdLX10dJSQlVVVU5/x65e/cunZ2d3Lhxgzt37jA4OGg6\nyTHLsohGoxw8eJCuri7mzp2bs/8RPj09TW9vLzt37jSd4sr379/p6enhyZMn9Pf3E4/Hc+4+moqK\nCo4ePUpzczPHjh1j9erVeL0z3yqVsaHrdkez+m9YlkVrayv79u2jvr7edM6sFRYWsmnTJvr7+02n\nOPL69Wt6e3vZvn07bW1tvHz5kvb2dtNZrpSWlgJQXFxMQ0MDb9++NVzkXFlZGWVlZaxbtw6AxsZG\notGo4Sp3QqEQa9eupbi42HSKK+FwmPLychYuXEheXh4NDQ28efPGdJZjwWCQzs5Obt++zYIFC1i+\nfPmMj8/Y0P2ddjTn+hUJ/LwLe+XKlRw6dMh0imujo6NMTEwAkEwmef78OStWrDBc5czp06fp6+uj\np6eH69evs3nzZq5du2Y6y7FEIsHU1BQA8XicZ8+eUVlZabjKuZKSEpYsWcLHjx8BePHiRc6utX30\n6BF+v990hmtLly4lEomQSqUQkZw9F6OjowAMDw/T3d39j+ckYz8Z+l12NP95NTI2NkZtbS0tLS1/\n3XSRK169esXDhw9ZtWoVgUAAj8fDqVOnqKmpMZ3myNevXzl37hy2bWPbNrt27WLbtm2ms/6Xvn37\nxokTJ/B4PKTTafbs2cPWrVtNZ7ly8eJFzpw5g2VZlJeX09HRYTrJsWQySTgc5sqVK6ZTXKuurqax\nsZFAIIDX62XNmjUcOHDAdJZjLS0tjI+P4/V6uXz5MkVFM/9jku5eVkoppbJEN1IppZRSWaJDVyml\nlMoSHbpKKaVUlujQVUoppbJEh65SSimVJTp0lVJKqSzRoauUUkpliQ5dpZRSKkv+AO2e4yf8wTuC\nAAAAAElFTkSuQmCC\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0xc1dc310\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "# Train our variables.\n", - "\n", - "# numpy is used for its asscalar() function.\n", - "import numpy as np\n", - "\n", - "num_training_steps = 10\n", - "\n", - "def train_model(inputs, labels, wb, optimizer, num_training_steps):\n", - " loss_at_step = []\n", - " w_at_step = []\n", - " b_at_step = []\n", - " for step_num in range(num_training_steps):\n", - " loss, gradients_and_variables = value_and_gradients_fn(inputs, labels, wb)\n", - " loss_at_step.append(np.asscalar(loss.numpy()))\n", - " \n", - " optimizer.apply_gradients(gradients_and_variables)\n", - " w, b = wb.variables\n", - " w_at_step.append(np.asscalar(w.read_value().numpy()))\n", - " b_at_step.append(np.asscalar(b.read_value().numpy()))\n", - "\n", - " print(w_at_step)\n", - " t = range(0, num_training_steps)\n", - " plt.plot(t, loss_at_step, 'k',\n", - " t, w_at_step, 'r',\n", - " t, [true_w] * num_training_steps, 'r--',\n", - " t, b_at_step, 'b',\n", - " t, [true_b] * num_training_steps, 'b--')\n", - " plt.legend(['loss', 'w estimate', 'w true', 'b estimate', 'b true'])\n", - " plt.show()\n", - "\n", - "train_model(inputs, labels, wb, optimizer, num_training_steps)" + "assert dy_dx.numpy() == 3.0\n", + "assert d2y_dx2.numpy() == 6.0" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "UNurY9VJ-hpH" - }, - "source": [ - "## Other Ways to Compute Gradients\n", - "\n", - "Using our loss function as an example (`loss_fn()`), there are several other ways we could compute gradients:\n", - "\n", - "1. `tfe.implicit_gradients()`\n", - "1. `tfe.gradients_function()`\n", - "1. `tfe.implicit_value_and_gradients()`\n", - "1. `tfe.value_and_gradients_function()`\n", - "\n", - "Each of these functions does the following:\n", - "* Wraps a function.\n", - "* Returns a function with the same input signature as the wrapped function.\n", - "\n", - "They differ only in what information they return.\n", - "\n", - "### Gradients-only functions\n", - "\n", - "The following two functions return a function that returns only the variables' gradients:\n", - "\n", - "1. `tfe.gradients_function()`: Returns the partial derivatives of the function `f()` with respect to the parameters of `f()`.\n", - "1. `tfe.implicit_gradients()`: Returns the partial derivatives of the function `f()` with respect to the trainable parameters (`tf.Variable`) used by `f()`.\n", - "\n", - "In our example above, the `tf.layers.Dense` object encapsulates the trainable parameters.\n", - "\n", - "### Value and gradients functions\n", - "\n", - "The following two functions are identical to their counterparts above, except that they also return the value of the wrapped function.\n", - "\n", - "1. `tfe.implicit_value_and_gradients()`\n", - "1. `tfe.value_and_gradients_function()`\n", - "\n", - "### Gradient demos\n", - "\n", - "In the demos below, we show examples for the `implicit_*` functions, since our existing loss function works seamlessly with these versions. (The other versions require that your parameters are tensors and tensors only; in our example, we're using a `Dense` layer.)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 85, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 100, - "status": "ok", - "timestamp": 1505502831671, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "aEoCftnfAIH5", - "outputId": "72f1c1dc-a574-463f-f860-c4e5f48fcdaa" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[(\u003ctf.Tensor: id=673, shape=(1, 1), dtype=float32, numpy=array([[-0.26846504]], dtype=float32)\u003e,\n", - " \u003ctf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32\u003e),\n", - " (\u003ctf.Tensor: id=671, shape=(1,), dtype=float32, numpy=array([-0.32890949], dtype=float32)\u003e,\n", - " \u003ctf.Variable 'dense/bias:0' shape=(1,) dtype=float32\u003e)]" - ] - }, - "execution_count": 13, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "# tfe.implicit_gradients() demo\n", - "gradients_fn = tfe.implicit_gradients(loss_fn)\n", - "\n", - "# Returns only gradients and variables:\n", - "gradients_fn(inputs, labels, wb)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 102, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 88, - "status": "ok", - "timestamp": 1505502831785, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "bbgCUdCzAVhH", - "outputId": "152aa9b6-9e42-4b7e-848a-9423c0b1929c" + "id": "4U1KKzUpNl58" }, - "outputs": [ - { - "data": { - "text/plain": [ - "(\u003ctf.Tensor: id=688, shape=(), dtype=float32, numpy=1.0623235\u003e,\n", - " [(\u003ctf.Tensor: id=720, shape=(1, 1), dtype=float32, numpy=array([[-0.26846504]], dtype=float32)\u003e,\n", - " \u003ctf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32\u003e),\n", - " (\u003ctf.Tensor: id=718, shape=(1,), dtype=float32, numpy=array([-0.32890949], dtype=float32)\u003e,\n", - " \u003ctf.Variable 'dense/bias:0' shape=(1,) dtype=float32\u003e)])" - ] - }, - "execution_count": 14, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], "source": [ - "# tfe.implicit_value_and_gradients() demo\n", - "value_gradients_fn = tfe.implicit_value_and_gradients(loss_fn)\n", + "## Next Steps\n", "\n", - "# Returns the value returned by the function passed in, gradients, and variables:\n", - "value_gradients_fn(inputs, labels, wb)" + "In this tutorial we covered gradient computation in TensorFlow. With that we have enough of the primitives required to build an train neural networks, which we will cover in the [next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/3_neural_networks.ipynb)." ] } ], "metadata": { "colab": { + "collapsed_sections": [], "default_view": {}, - "last_runtime": { - "build_target": "", - "kind": "local" - }, - "name": "Eager Execution Tutorial: Working with Gradients", + "name": "Automatic Differentiation", "provenance": [], "version": "0.3.2", "views": {} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb index 0088da5c4b583dd13251de5839235de666fe8b78..bfcc7feb075c403d024772e0d715339d58877a51 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb @@ -16,7 +16,9 @@ "\n", "We recommend using the `Dataset`s API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.\n", "\n", - "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly different. You will use a Pythonic `Iterator()` class instead of using `make_one_shot_iterator()` and `get_next()`. As a result, the discussion on iterators in the [Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is not relevant when eager execution is enabled." + "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly simpler.\n", + "You can use Python iteration over the `tf.data.Dataset` object and do not need to explicitly create an `tf.data.Iterator` object.\n", + "As a result, the discussion on iterators in the [Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is not relevant when eager execution is enabled." ] }, { @@ -48,11 +50,8 @@ "# Import TensorFlow.\n", "import tensorflow as tf\n", "\n", - "# Import TensorFlow eager execution support (subject to future changes).\n", - "import tensorflow.contrib.eager as tfe\n", - "\n", "# Enable eager execution\n", - "tfe.enable_eager_execution()" + "tf.enable_eager_execution()" ] }, { @@ -137,32 +136,27 @@ "source": [ "# Step 3: Iterate\n", "\n", - "Use `tfe.Iterator` on the `Dataset` object to get a Python iterator over the contents of the dataset.\n", - "\n", - "If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that this process of iteration is different. Here there are no calls to `Dataset.make_one_shot_iterator()` and no `get_next()` calls." + "When eager execution is enabled `Dataset` objects support iteration.\n", + "If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that there is no need for calls to `Dataset.make_one_shot_iterator()` or `get_next()` calls." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 0, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "height": 153, - "output_extras": [ - { - "item_id": 1 - } - ] + "base_uri": "https://localhost:8080/", + "height": 153 }, "colab_type": "code", "executionInfo": { - "elapsed": 201, + "elapsed": 388, "status": "ok", - "timestamp": 1505952405928, + "timestamp": 1525154629129, "user": { "displayName": "", "photoUrl": "", @@ -171,7 +165,7 @@ "user_tz": 420 }, "id": "lCUWzso6mbqR", - "outputId": "ec027d30-96c6-4ea4-9ee1-ef74ec1ae29a" + "outputId": "8e4b0298-d27d-4ac7-e26a-ef94af0594ec" }, "outputs": [ { @@ -179,9 +173,9 @@ "output_type": "stream", "text": [ "Elements of ds_tensors:\n", - "tf.Tensor([4 9], shape=(2,), dtype=int32)\n", + "tf.Tensor([1 9], shape=(2,), dtype=int32)\n", "tf.Tensor([16 25], shape=(2,), dtype=int32)\n", - "tf.Tensor([36 1], shape=(2,), dtype=int32)\n", + "tf.Tensor([ 4 36], shape=(2,), dtype=int32)\n", "\n", "Elements in ds_file:\n", "tf.Tensor(['Line 1' 'Line 2'], shape=(2,), dtype=string)\n", @@ -191,22 +185,19 @@ ], "source": [ "print('Elements of ds_tensors:')\n", - "for x in tfe.Iterator(ds_tensors):\n", + "for x in ds_tensors:\n", " print(x)\n", "\n", "print('\\nElements in ds_file:')\n", - "for x in tfe.Iterator(ds_file):\n", + "for x in ds_file:\n", " print(x)" ] } ], "metadata": { "colab": { + "collapsed_sections": [], "default_view": {}, - "last_runtime": { - "build_target": "", - "kind": "local" - }, "name": "Eager Execution Tutorial: Importing Data", "provenance": [], "version": "0.3.2", diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..84f1d031d40604ae029e8a8347474950ee01b38a --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb @@ -0,0 +1,485 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "k2o3TTG4TFpt" + }, + "source": [ + "# Training Models\n", + "\n", + "In the previous tutorial we covered the TensorFlow APIs for automatic differentiation, a basic building block for machine learning.\n", + "In this tutorial we will use the TensorFlow primitives introduced in the prior tutorials to do some simple machine learning.\n", + "\n", + "TensorFlow also includes a higher-level neural networks API (`tf.keras`) which provides useful abstractions to reduce boilerplate. We strongly recommend those higher level APIs for people working with neural networks. However, in this short tutorial we cover neural network training from first principles to establish a strong foundation." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "3LXMVuV0VhDr" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "PJ64L90aVir3" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "tfe = tf.contrib.eager # Shorthand for some symbols" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eMAWbDJFVmMk" + }, + "source": [ + "## Variables\n", + "\n", + "Tensors in TensorFlow are immutable stateless objects. Machine learning models, however, need to have changing state: as your model trains, the same code to compute predictions should behave differently over time (hopefully with a lower loss!). To represent this state which needs to change over the course of your computation, you can choose to rely on the fact that Python is a stateful programming language:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "VkJwtLS_Jbn8" + }, + "outputs": [], + "source": [ + "# Using python state\n", + "x = tf.zeros([10, 10])\n", + "x += 2 # This is equivalent to x = x + 2, which does not mutate the original\n", + " # value of x\n", + "print(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "wfneTXy7JcUz" + }, + "source": [ + "TensorFlow, however, has stateful operations built in, and these are often more pleasant to use than low-level Python representations of your state. To represent weights in a model, for example, it's often convenient and efficient to use TensorFlow variables.\n", + "\n", + "A Variable is an object which stores a value and, when used in a TensorFlow computation, will implicitly read from this stored value. There are operations (`tf.assign_sub`, `tf.scatter_update`, etc) which manipulate the value stored in a TensorFlow variable." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "itxmrMil6DQi" + }, + "outputs": [], + "source": [ + "v = tfe.Variable(1.0)\n", + "assert v.numpy() == 1.0\n", + "\n", + "# Re-assign the value\n", + "v.assign(3.0)\n", + "assert v.numpy() == 3.0\n", + "\n", + "# Use `v` in a TensorFlow operation like tf.square() and reassign\n", + "v.assign(tf.square(v))\n", + "assert v.numpy() == 9.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-paSaeq1JzwC" + }, + "source": [ + "Computations using Variables are automatically traced when computing gradients. For Variables representing embeddings TensorFlow will do sparse updates by default, which are more computation and memory efficient.\n", + "\n", + "Using Variables is also a way to quickly let a reader of your code know that this piece of state is mutable." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BMiFcDzE7Qu3" + }, + "source": [ + "## Example: Fitting a linear model\n", + "\n", + "Let's now put the few concepts we have so far ---`Tensor`, `GradientTape`, `Variable` --- to build and train a simple model. This typically involves a few steps:\n", + "\n", + "1. Define the model.\n", + "2. Define a loss function.\n", + "3. Obtain training data.\n", + "4. Run through the training data and use an \"optimizer\" to adjust the variables to fit the data.\n", + "\n", + "In this tutorial, we'll walk through a trivial example of a simple linear model: `f(x) = x * W + b`, which has two variables - `W` and `b`. Furthermore, we'll synthesize data such that a well trained model would have `W = 3.0` and `b = 2.0`." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "gFzH64Jn9PIm" + }, + "source": [ + "### Define the model\n", + "\n", + "Let's define a simple class to encapsulate the variables and the computation." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "_WRu7Pze7wk8" + }, + "outputs": [], + "source": [ + "class Model(object):\n", + " def __init__(self):\n", + " # Initialize variable to (5.0, 0.0)\n", + " # In practice, these should be initialized to random values.\n", + " self.W = tfe.Variable(5.0)\n", + " self.b = tfe.Variable(0.0)\n", + " \n", + " def __call__(self, x):\n", + " return self.W * x + self.b\n", + " \n", + "model = Model()\n", + "\n", + "assert model(3.0).numpy() == 15.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xa6j_yXa-j79" + }, + "source": [ + "### Define a loss function\n", + "\n", + "A loss function measures how well the output of a model for a given input matches the desired output. Let's use the standard L2 loss." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "Y0ysUFGY924U" + }, + "outputs": [], + "source": [ + "def loss(predicted_y, desired_y):\n", + " return tf.reduce_mean(tf.square(predicted_y - desired_y))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qutT_fkl_CBc" + }, + "source": [ + "### Obtain training data\n", + "\n", + "Let's synthesize the training data with some noise." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "gxPTb-kt_N5m" + }, + "outputs": [], + "source": [ + "TRUE_W = 3.0\n", + "TRUE_b = 2.0\n", + "NUM_EXAMPLES = 1000\n", + "\n", + "inputs = tf.random_normal(shape=[NUM_EXAMPLES])\n", + "noise = tf.random_normal(shape=[NUM_EXAMPLES])\n", + "outputs = inputs * TRUE_W + TRUE_b + noise" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-50nq-wPBsAW" + }, + "source": [ + "Before we train the model let's visualize where the model stands right now. We'll plot the model's predictions in red and the training data in blue." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 293 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 1210, + "status": "ok", + "timestamp": 1527005898290, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "_eb83LtrB4nt", + "outputId": "3873f508-72fb-41e7-a7f5-3f513deefe38" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEDCAYAAAA2k7/eAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXlgU1X2xz/pAhRautCWUsCwWVlcUHHGBUFQcSg7uM8P\nFLUICo4VpygObihI3UdmUHBB0IGZQbEgFNGCqKgMolV2pKylCy1pukDp+n5/3LxmaUsDTUjSns8/\nbZKXd09C+b7zvvfccw2apmkIgiAITR4/TwcgCIIgnB9E8AVBEJoJIviCIAjNBBF8QRCEZoIIviAI\nQjNBBF8QBKGZENDYE+Tk5JCUlER+fj7+/v7cdtttTJgwgcLCQhITEzl27BidOnXijTfeICQkxBUx\nC4IgCOeAobF1+Hl5eeTn59OrVy9OnjzJ2LFj+ec//8mnn35KWFgYCQkJLFy4kKKiIh5//HFXxS0I\ngiCcJY22dKKioujVqxcAbdq0oXv37uTm5pKWlsaYMWMAGDNmDF999VVjhxIEQRAagUs9/MzMTPbs\n2cNll13GiRMniIyMBNRFoaCgwJVDCYIgCGeJywT/5MmTPPLII8ycOZM2bdpgMBhcdWpBEATBBbhE\n8CsrK3nkkUcYNWoUN910EwDt2rUjPz8fUD5/REREg+eRtj6CIAjuo9FVOgAzZ86kR48e3HPPPTXP\nDR48mE8//ZRJkyaxcuVKbrzxxgbPYzAYyMsrdkVIbiUqKkTidCESp2vxhTh9IUbwrTidodGCv23b\nNlavXk1cXByjR4/GYDCQmJhIQkICjz76KJ988gmxsbG8+eabjR1KEARBaASNFvwrr7yS3bt31/na\n4sWLG3t6QRAEwUXISltBEIRmggi+IAhCM0EEXxAEoZkggi8IgtBMEMEXBEFoJojgC4IgNBNE8AVB\nEJoJIviCIAjNBBF8QRCEZoIIviAIQjNBBF8QBKGZIIIvCILQTBDBFwRBaCaI4AuCIDQTRPAFQRCa\nCSL4giAIzQQRfEEQhLOk0GTi84R7+XbIDXyecA+FBSZPh+QULtnTVhAEoTnx7YzHuDflUwyAlv4z\nizEwfNFiT4fVIJLhC4IgnCWhhw9hsPxusDz2BVwi+DNnzuTaa69lxIgRNc/Nnz+fAQMGMGbMGMaM\nGcM333zjiqEEQRA8TqHRiGb5XQMKjV08GI3zuMTSGTt2LOPHjycpKcnu+YkTJzJx4kRXDCEIguA1\nXJ/8OosxEHr4EIXGLlyf/JqnQ3IKlwh+v379OHbsWK3nNU2r42hBEATfJjQ8wic8e0fc6uF//PHH\njBo1iqeeeori4mJ3DiUIgiA0gNsE/+677+arr74iJSWFyMhI5s6d666hBEEQXMLRjAwW9u3FWmN7\nFvbtxeGMDE+H5FLcVpYZERFR8/vtt9/O5MmTnXpfVFSIu0JyKRKna5E4XYsvxOmNMb53xQhmZh1T\n5Zalx5h3ww08cfSop8NyGS4TfEe/Pi8vj6ioKAC+/PJL4uLinDpPXp73Wz9RUSESpwuROF2LL8Tp\nTTEWmkx8O+MxQg8fIjory67cMtZk8po4z4SzF0+XCP706dPZsmULZrOZG264gWnTprFlyxZ2796N\nn58fHTt25Pnnn3fFUIIgCC7FdhHVXFSZpcHyM8vGqWgKuETwX3311VrPjRs3zhWnFgRBcCu2i6ju\nBp4JDKR7QACZ4RH839dfezAy1yMrbQVBaNbYLqK6AOgaP4L4w7lMSt+NsXt3T4bmcqSXjiAIzRpf\nXUR1LojgC4LQrPHVRVTnglg6giA0WXy1jbG7kAxfEIQmi6+2MXYXIviCIDQZbGvqC41Ggg9k+GQb\nY3chgi8IQpPBMaOfE9vRrq7eV9oYuwsRfEEQfB49s/dbn2qX0V8QEcHiq/7odAWOyWRmxoyNHD7c\nFqOxkPffHwX4uzv884YIviAIPk2hycS/B1/HJVnH2In9StnK7heelWc/Y8ZGUlLGAwbS0zWmTFnO\n/PnD3RK3JxDBFwTBJzmakUHquOFE5mTTpbqaAcD1wDygQ1AQ1UOGnnVN/eHDbcHmHuHgwWDXBu1h\nRPAFQfBJUscNt3a2BJYDdwG9gRNDhp5TNY7RWEh6uvUeoWvXEhdG7HlE8AVB8BkKTSY2PDqVgB+/\nI8ZstvPrg1HCvz22I3ec42rZ5OTBwFKLh1/EggUjqapyTezegAi+IAhejz4pq23aQBuzmWHAAuz9\n+t8CA8mPH8Edya8RGn5uXS7tu7w3vS1aRfAFQfBq9ElZR/vmbuBZwGgwkN0hlqEr19C5a7dGjdXU\nJ22ltYIgCF7NtzMe4xKL2IPVvrkAuAgwjBzDpPTdjRZ7aPqTtiL4giB4NaGHD1GC1WDR7ZvktqFs\nbH85r2eMJiHhUwoKzI0ey2gstBtJJm0FQRDciF5u2anARGZ4BC169eIBlI3TBsuk7MbNPJ60Sdkv\nuQa279CApSxaNOasx7NdbNWhw0mGDn2P7OxImbQVBEFwF/rEbNba1cysqKjZSPyF6mo+GzWW0MOH\nOGHsUjMp62i/HD7cttZK2eTkwYSHh51xXEffftSopaxffyMAERHes/euKxDBFwTBK9D74HwO9u0R\nCs3E11FT71gzbzQW1RJvZ7L+ui4cTRWXCP7MmTP5+uuvadeuHatXrwagsLCQxMREjh07RqdOnXjj\njTcICXFuZ3VBEJo+u7Zt48vRQ+ladpqDBgNRrVtjAIqxL7fMrKfE0rFmPjl5EHfcsY2zFe+6LhxN\nFZcI/tixYxk/fjxJSUk1zy1cuJBrrrmGhIQEFi5cyDvvvMPjjz/uiuEEQWgCfDkmntllp5XMahpP\nnzyJBsQDy4BC/MhoFcawxcvqfH94eFit7P1cxLuuC0dTxSWC369fP44dO2b3XFpaGh999BEAY8aM\nYfz48SL4giCwa9s20sbG0+10qZ110x14JSwMf8LZZO7HKt6G0+Hs/8dSFi3q69S5z0W867pwNMS5\nzBV4A27z8E0mE5GRkQBERUVRUFDgrqEEQfAQZyN8+qTs6VUruUjTOIS9dZMNxAwczN8Pjyc9fXTN\n+87GUz8X8T4XzmWuwBvwuknbqCjf8PklTtcicbqW8xXn1Kmf2wlfy5bL+fe/76p1nPnECVbc1J/e\nmZmUAEOBpajOltHAAYOBsBtvZMz7i1g3ZZ2dLRMXV4qfXxUPPZTKwYPBdO1azIIF8UREnJ+Muq7v\nMisrHNu5gqyscJ/423Cb4Ldr1478/HwiIyPJy8sjIsK53ha+UAIVFeUbpVoSp2uROGuzb18QtsK3\nb18QeXnFtTL/+PIVzMjMtGuN0BUYDsxqFcRfjuQCUFEFs2dfT1mZ1ZaZPXsQ99+/qubCsnWrRlnZ\nUubNG+R2W6W+7zI21oTt/UlsbIFH/zacvdi4TPA1+65DDB48mE8//ZRJkyaxcuVKbrzxRlcNJQiC\nl1DfJKlueRgwcUH6FPBbb+fXtwF+A7a0CuLmVevszlmXLVNX6aQnbRVfneh1ieBPnz6dLVu2YDab\nueGGG5g2bRqTJk3iL3/5C5988gmxsbG8+eabrhhKEAQvoi7hKzSZCN34D17hUfIoYS4VLKu29+t3\nderEHWnfOd3Vsq4Liyfr58/XXIGrcYngv/rqq3U+v3jxYlecXhAEL8VW+ApNJj57YAKl337NDSgp\n7mD5GY+yccotO1FNfn8RuXknSUhY6ZQlU9eFJSlpQ7Opn3cVXjdpKwiCb/LtjMeI/fZr7sKayb9k\n+RkG3AkstuxEFRYRwr33fe60JaNfWPS5gTvu2Far742v2CqeRARfEASnqasMs9BUwOJxCVyanU4+\nUIgSeAOqffFLQPuwMAwDB3N98muYTGamTv2c9evB1pLJyPBvMOOvr++NyWQmKcn36uLPNyL4giA4\nja3g/pqeT+DqP3BJ9UHmY83ql6E2J9GAn4Hc9pezLCqRblRzHX4251iGrbNvMh1mx44nOVPGX59v\n76t18ecbEXxBEJxGF1wDJ5jM5fyjOrNWs7NC4G3gd/zZenUS3/74ol0LY6toK2c/KKiCIUPgwIE4\nsrLOPAnrOHl7/PguDhzowaZNucDnqE488U26AVpjEMEXBKEW9a2g7djhGMb0eG5mHa3R6mx2to6r\nWcUPwCrC9uRbXjEDqaxfD+HhO4GBNe8oKytl69Z8evUKtjuTPglr36++nPbtnyY39yrgJFlZUxg7\ndgFm85PY3mMYjZVn/BzNFRF8QRBqoSySEcA60tPD2bp1CY/cW0nv1GfpDeQBuWDX7Kwc2A2s4l+W\nV04C+ZbfU4E7KS01UFqqERT0DKWlLYGZVFcbyMrSqK5+gVGjate2O9o1YWGvACNrYi0o6ITtPUZY\n2GmSk2+u873N3eoRwReEZsDZZrrKElkH3Ik/WxmSNYaCOdX0B0qAB4BVwNNAJ9QFIB9Y1vZeKNoO\n/Aj8iWuuWU6LFktZvx5KS62iXFraDyXS1ucKC411irGjbw/tsL0TCA8/Smmp9fHAgQE1n6059bp3\nBhF8QWgG1JXpnqk1gdFYyK/pp4nnEv7ATiKBUGCA5edyIALoDOwliNfIJCzsM7744g/MmfOz5Zyr\nSU4eTnh4GAkJn5KSYmv8nLT8tBXuzDpjd/Ttr7mmmhYtrHcCM2eOYs6cule9Nqde984ggi8ITRDH\njP7AgTY01JrgxImX+emnYsrKuhKifcxf2EAw0BdqGp6lAnehWiMUA7tpxRtsB8Ixm1vx7LPf0aJF\na8s41nYrtgunjh/fRVbWFEs8y/DzKyYm5gQrV1ptGltqL7q6pdbdyaJFRiff27xr9UXwBaEJ4ijm\nsbFzcJwQdbQ7Nm8uAm0ioxjFzeykEHgCa06+HNCnVf8H/MKlrGUssMvyTDybNy+kqOhBHD1z2xW5\nBQVXMmvWOvbtC8JorCQ5Of6M9lJj2hj4agsEdyGCLwhNEEcxj4jowlVXnbk1gb92kkR6MM/yzCrs\nnfM2wK/Atxh4mXuB14A1qJ6X6hynToXSkGceHh7Gv/99l090Hm1qiOALQhNEedcFqInXlvz++06O\nHGmFn18nOnSoAuztjuh2+7k07Q36Y5XrEuzLLb8DXmYDMAhQ1TLV1aUUFS1BOfoltG5toqiocZ65\nlFK6DxF8QfBRMjIOM27cKgoKOhEefpSVK0fRtavyspOTB7N16wKyslR9elnZGMrKlgHxpKauZfPm\nLwgOziW8bTsuP/4SxvTddMZe5IcCs4COwC78mc8e4D+Wo4rp1CmW7t0rSUmZgC7w/fq9xZ49cy0x\nZTJzZm1fXm+toCyd2oIupZTuQwRfEFzI+cxOx41bVSPopaUaI0c+w9VX9yArK5zYWBMREUa7lasQ\nhFoDO4PiIhN9i/7ENVk/0Qb4G/AKcDvKq28DbAPKgHXczKqaupyLgRGAxv79s3jvvTuxnRQtLw+0\ni2nOnKW1JlQbEnQppXQfIviC4EKczU6dvTDUd5zJZCYnR8O2nUBeXsuasdUuTHOxN2X2AH3wYz/j\nuYgYNHoDpZYjwlBVOCGWM5qBv/Moyqu3LacEMHD6tCrBtP18Q4akoZorpALBbNqUQ0GB2e6z2Qt6\nIZs25TJkSFrN55NSSvchgi8ILsTZ7NTZC4PjcWVl7wGQlpZLdfVMbNsJaFqE3dgFBbG0ajWL06fb\no0Q4GH/SmcooWgNXo8wZ/Qy3AWuBTGB/y1DSus7B//dDVFUtAyqBY8Bky/mV+Dt+PiXWa8HSJNls\nHk5S0lK71saHDlUCHwPDgLWYzY+Tnm79HqSU0n2I4AuCC3E2O63vwmCb0cfE5PH99/arUX/80Q+z\neSLUallWTnCwieJi69iqdcFsYmJe4FTO10xhA3HAPuBFrEK/BJiLMmwOGAL5vPsLxPVpz6fJg3n0\n0dWkpgLkoMR+Hcrw2QU8SEzMJ3YtjWfOvJJNm/6H2Xzmjpb6pC+0q3WslFK6DxF8QXAhzmanHTpk\nk57+L5SBUkSHDrZ7waoeNtAeVd9uFfGiohzL744ty/IpKTlOy5azKC/vhqYFoaZdDbSqzOMeNjCX\nusstw1E9cF7yv4viqo9hv4Hd+zVSU5+ksjIEg8FAy5a5lJXNQ9PiwLLQKiZmPgZDJCkp92N7pzJw\noL/dqlr9oud4kevS5UKMxsI6jxXcgwi+ILgQ57PTQLDbG+o9TCazpc1viuX1AcD1wDwgBmhBdXWU\n5fjrUHl5NKqTzd1o2mbKyu5CtTK7k0BWkMjtdM2HFtRfbvk9MI+HoeoKm6MKKS9vA1wClHD69KXA\nBJt3LeP06WNkZ3fA8U7l3/++krouenXd/Yh9c35xu+APHjyY4OBg/Pz8CAgIYMWKFe4eUhC8nuzs\nSGyFMjs7khkzNmI2P255vgBVUdMH8ENJdjzqYvAekAHMwX4dbIjl8XX48xBTeJvLLM9uwb7c8kng\nAiCdIBaRALyB/YYka1G1O/r5P8T+viAEaFeniNd30bMV97i4UmbPHiT2zXnG7YJvMBhYunQpoaGh\n7h5KEHwG+4VRbTh+fCdVVRdhFdV1wAzL4+G0aJFEefkBVG/KAqAH9gIcDBThzxb+j6uJRTnt+j1E\nf+ApoCfKfd9LIPN4CUjEKub6VuOlqEla2/PnYX9fUMw111Rb2hA7l6HbintUVIistPUAbhd8TdOo\nrq529zCC4HHOpgbfcWFUVtYITKZZwDisjQysgtuyZSTl5UlYBVffHlx/vJcW/EoiH2EE2qLuC/Qz\nhANGlNjPYxgwGnXXsAw4hf1W48ss77I9fy7qziIAP78sbrklnDfeGC4Zuo9xXjL8+++/H4PBwB13\n3MHtt9/u7iEFwSMkJq4hNbUt4E96egDl5Z/z4Yf/V+ex4eFhREf3tlsYdfp0F5RfHw38jvLvwwGN\n4uJg7DPuzsALQCcCWMOdfEJHqJmYreuSsA94jf0og8d2/mA2yh5S7REgwTKOyvZbtjxAWdlUoAug\nMWKErHz1Vdwu+MuXLycqKgqTycTEiRPp1q0b/fr1q/f4qKgQd4fkEiRO1+LNcZ44Yeahh1I5eDCY\nrl2LWbAgnoiIsFrHfPVVNqA6RYLGjz++WvO5Tpww88ADKWzapAF5DBgQRocO2PnfaguRGTaP56E8\n/BIgG3v5Pgp0pxWf8hc+IRi4FPtLQk9Urm4GdhDCAn4BuqPyfNsjL0c1QHseqEB1vDcAd9Kp0zx+\n/fVxpkxJZd++X8jP38vhw0amTl1d5/dwNnjzv7ktvhKnM7hd8KOiogCIiIjg5ptvZvv27WcUfF/w\n9XzFf5Q4XUNCwqqaUsmtW4P57rt/sHHjBMLDw2r62eTktKO6ugu2QlpcHMK+fUctG4Cssus5k5Ky\nBPgNeBWIRAl4H+yFuDd6GwNYgLVBcQkBHGIqM7kQ5ei3pnb1zS6gCEjmJ9Qq226Wcxc5HKkvv7rC\n8rt1nLCwzlRV+TN//nASElaSnj6DzEx9Edi5Z/re/m+u40txOoNbBb+0tJTq6mratGnDqVOn+O67\n75g6dao7hxSEs8IZ3z0jwx94B1XXspOsrF4MGrSEjRsn2PWzUatHrUJaWdmKQYOWEh3d27K61FbM\nNVT3Guvyp8DAX6ioGIO9ZBuAdJSFcyd+7Gc0/ehOUU0bYw1l7tyLtQ/O98AxgviI7aisvrvlqF7A\nXtQU7oVAS9RkrS781cDdNWfu3n1pzfcgPW58H7cKfn5+PlOnTsVgMFBVVcWIESPo37+/O4cUhLPC\nmRYHJtNhVCHjcvQtQbKyNJKSllJQYFuHPgyVscehWo/dR1bWf8nK8kdNehage/JwANs+OFBJUFB7\n2rV7kZycSNS0613AZtQdwAECuZ1EVnARqqmZ7eWjI+oeIBz4BQMvk4Ta+1XP6kNRtf0VwHPAEWAp\nEAXMx8+vkN69+9K5cxHwHtnZkbJdYBPErYLfuXNnUlJS3DmEIDQKZ7LWdu3iLJOr9hOnq1ZVomm7\nsWb1eulxgeXnRqADavJ1OJCEqoTRd4PNRU3QLgBKKSp63tJL/hmU4P8XmI4BE4MZSD920hvV0cYP\ne1PGhDJqnuIOVOb+PKp/zjKgHFWRU2zzGb5HZfnqDDExc9mwwdrKWL/zueOObTV3PrJIyveRlbZC\ns8aZrLVbt5Ns365qz21lVrUviEOJahCqQUEgyjL5K9a+MwuAKajFSvYNz2AkyqefZxnNgLJdDgNt\n8eN1pjKDi6gkFHUPEYqa2l2GtbNlJvAmM1C1OXrzhDCUPbPaMsYCwsJ2YzYPx/Hi1a5dnN1nru/O\nR6pzfBsRfKFZo2etGRmtMZn2kZFhJCHhUzsv33qMH/v3P83p00bUQqRhwHpU24NiVLZ+P8qqWYeq\naTegxHYJyj5xXK2q/x6MkvA2qJLMSIJYwiNs4VpqbyLeFdiBWob1Bd1YxR2oOwioPX2rHgcGmtiy\nZQJJSUvZtCnHIvzqmG7dTtl9L+LXN01E8IVmjb5w6J57PmbHji5kZYWwY0cOP/zwHuXlFwD5XHNN\nMG+8MYJZs75jx47nsbY+eB3ohxL7oSj/Xq+jz0ZZKmGWn0eAVjiuVlVoKKPmYcBAC/L4Cw/QAlUh\nb9s8Qd9E/DCQg4G5bEU1MzuFaocQgrJwnkDdfRxH1c8vY8CAkJrPW1BgJimpfntG/PqmiQi+0OQ4\n212nTCYzX32VhbWG/l8cP/4Mutilpi6jRYuNZGWFY93cIwNV6a7L8XxUhYttHf0ylKWi96XRPfVi\nlOtegrJgwlGZfSEB/I9HeYCXUFOqtvcDbVDSvhmYxzxURq8Bn6CqbabYjP0CMBbdVoqN3cE//zm+\n5jM3tEJW/PqmiQi+0OQ42z1RZ8zYSEVFP6zyGoKj9ZKSchz4BiW5T6KyedvVqi+gJktt31cIvI+q\njLH11FehBD8Y/QLhzxbGE04HrF1yjlG7q2UJLfgHPwArLec5iTJ42tuN3bZtB6677hNLtY2Z5OTx\ntS56Z7owSsuEpokIvtDkcPSfMzL87TbpePLJK5k792ebTUayUTX2+i5MjguTTKiKmj4o774QVSrp\n2Opgj8P7TqKmWG1ragpQ9fVRwDEMHORmRtGXHXRB1ebonW3uRuX/eqOFDYRyqtfrxBauo23bzhQV\n7aBduzgOHy6nqGgnaq5AjT1oUIsGBVs2C29+iOALTQ6r/1wIrGXv3kPs2KGqY9LTNdasmUVl5eya\n15V4ZwMXofZvLUJZNsrDV4+fw75LDdiLeybKdHkS1aYsFHjA8vM9VK+aXqhFVOpcLXmTKXSnDfZe\n/RKsPStPAj8Bb/MhsbGZpG+6tdbn7dv3LYqKpqAvuwoK+onk5IRaxzkiE7PNDz9PByAIrsRkMlNe\nXkFY2AcEBLwCDKWiwr7LTGVlZ8vjVNRkaxFqknMsSozboiY6W6BE+yrss/k+KL9cL4FcjppwDUDZ\nQS1R+XmY5fho1H+1/6F8/+W04s88yqNcBfzB4ewRqPqeA0ApLXibn4AJhIZ2r/Mzq5LKcJTFNJKe\nPS8/45yFjtFYiLrEgEzMNg8kwxe8nrOZhH300S9Yt05tuWetbdFw3A5Q/QxGTWr2xl5y+6Hq4/X3\nV1PbqgkDLkbZKDp6q4IdDsf/BDwGrCKA9fyFJfRE4xDKvsHh6L3Ad0AyM7Dtf1lYmFHnZ7auE1DH\nXXjh6TN9nTXIxGzzQwRf8HrOxmv+8UfbLvB6bcsAVHVMIcq6Kbc8PoaycRzr1k/avL8Y5bnvRmX9\nB1CLqqC21/+75Zi7gVmoydTjwP34cZw7uI8LqKpZLfsAyux5DGsPnP8BvxDLWraj7gqWoyZ9Aykp\naUtBgbnWxc5RuBcsGElVVcPfq0zMNj9E8AWv5+y8Zj1710V4p817v0cJcg9U1h2E6lLZBiW90ZZj\nJlse630oo4CHULZJAaqRWm/LuZdg7SMfClwLfInK9A1AJa14iGmspSWq4YFt82Mj8E/LmX/AwFv8\nRHT0ajiul4BqqN2n/CkqaklS0sZaIu0o3BERvtHhUTj/iOALXo+zi4BMJjMtWxZjbTlcSfv2p+jQ\noYqYmFOsWxeNmjgNQVXbPIG1cuYN1H+HauADVNOx6Vjl+VUgFtXorA/KytmL/cbeT6IuAN2Bv+HP\nVhKYRDBVhKHW49ree8SiMv1iYBVhFF74MqN676CkJJy0tGVAlkMMS2RiVWgUIviC1+Ho2c+ceSWO\nXrPtMR06ZAOB/PCDH2ZzT/SOM4GBc7jiigt4440refTRNahM3LZ2XpffdcCzNs/PcXjdgLqABKP6\n4kRZXg/HWk+Ti6rq6QQYaM10pvI6ccAh1KXAcQeqXagcftfVf+XzVbNqPv+QIWmoLQhXO8QQjtFo\nbvwXLDRbRPAFr8Pesy9g69YFREf3tpuwTUhYaXPMv7AX8uXAXVRUXEpqan9++WW+peVwFUpiQdkx\noGrsq7AX1hhqL3tqgbXR2WzUHMCtKBtHb5u8kAC+ZzjzuAhqvPo4y1nuxtp4YR/wHZEcaD+Jrz+c\nbPf5rXc09s3aYmN3kJw8HkE4V0TwBa/D3rNfR1bWk2RlWSds580bxKZNucBnqLr2AOBDy/GjUR0r\n56JWn75ETs5L2Fe5t8Bq59S1+2suKovXNwjMBfoC/0JZOlGoOv3lqHmAUYCBEN5iCjsJwbbxMDxt\n+WlEratNAl7hCoYOfYCvLRuB22Jt1uaPyTSXdu3i6NbtVJ2rZQXhbBDBF7wG3aZRu0Ppq17bYJt9\n792rcdllb1NW9kfURGknVL2LrdeeidqnNQJrWwMsP09TO6PvgrU12V7L43hUnX4R9nbPMlRWfxKV\nvz+PH/uJJ5o+VDAX1SvT9uzdUReAjqgan9dYQlBQFR9+OK7O70GqZwR3IYIveA22Vg5otG37EuXl\nJzl9Wm8ZUMC+fXuprn4Re4G3ldeLUNOhQ1HevGPVTjFqktb2Ob2RgV4FvxtlEd2Ftbe8fv5y1F3E\nt0B/gunNFPbQBWtdTrHD2fcAJ4C5PI1a2KURGvqi6744QXASEXzBa3AsvywpiaK6uhxYCORjMBRS\nXd0fewFuR+3e7yFY+9G/i/1WIUUoF/0pVO69H2XLrEb5+WGoCdqXUP89CrHtUaNkPRQD2VxPF66h\niDhUBX6X0UVuAAAgAElEQVRLyxHxWHtiHgR+IJhv+NYS0xLgd/r0EWtGOP9IawXB45w4YSYhYSUH\nDuzFdql/dfUhVOVLS+AhNK071kVSYF3s9CyqNn4Z1lYJuhV0G9bSS/189wDXAPcB/ijhH2455/2o\n7cCfQOXl01F2zyrURSKHAJ5kEg9yDUX0Rjn8D6LuJf6Gala8C7WI6oer/8qivbsIC/sSVc4ZCEzn\nxIm62yQIgjtxe4b/zTffMGfOHDRNY9y4cUyaNMndQwpegG3ZZExMHgZDJdnZHepsjfDQQ6kWK8ex\nX/x0rJt+L0fVzt+OypL1TUNOowTeH+XNL0CJ/SFUZh6GyvR160evrNmCEnRQUv0qtdsi2/aoAX/2\ncieP0Qkl246LqK6wRP078D3h7G8/hc/evIvw8DAGDowmJcW6w5T0rRE8gVsFv7q6mtmzZ7N48WKi\no6O59dZbufHGG+neXbKbpo6jH6+EfDTp6Rrl5e/QokXrmjr7I0d0K0fvF78ENXG6DjWRql8AilCZ\nfABqQrYUJdIXUbsscwLwIsryse1cqS+gMqIyeX3B1KWoUk1be2h/zWMD+dxOEp1Qa2mzqb2Iapcl\nor/zHPA05GrMmbOURYuM0rdG8ArcKvi//fYbRqORjh07AjBs2DDS0tJE8JsBGRn+WCtfirGVx82b\n8ygq6g74k54eQIcOv6AmQnWhPWY51rZ0ch5KmF8GrkZZO9NRG4E4ZubBKHHvbvndtsGZXrnTEmuZ\nZXeUXD+OtavNVmASBhYxhme4kZyada/hqBoix0VUR4BlrETdbahY9JWxUnkjeANuFfzc3Fw6dOhQ\n87h9+/Zs377dnUMKHka3cvbu3U99PeSLiqqwzchPnnyRsLBXMJs7oCpkOlN7pWs0ag9Z2wqd5aiL\nQ0vs5fc3lGUzHXVn8aHl+TxUM7OHgR9Qwr4AlZf/EVv7BvJoxTKmMZN5DiPehSoYnYOq9P8dSOav\nQDLWuxn1WcW6EbwJtwq+pmkNH+RAVFSIGyJxPRJn3Uyd+rnFyvkMW8E2GNqiaUtQAh2DdW/YYIqK\nqhk6tA2pqX6orQIN1M6hg1Btix07YZajBFvvn3MUlbFnoCyhLOy3F1mG6pXzrOW5EahGadZiSj/2\n8Sce4BKUfeM4Iqj7h2LgZ+C+las5tKyYgwdX07GjCU2rICtrNV27lrBgwUgiIs7/34ov/H36Qozg\nO3E6g1sFPyYmhqysrJrHubm5REdHn/E9vtDlLyrKN7oReiLOffuCUNKod3pUQqtpLVFTnX1Q2fda\nrFn+cNLSnqBt21YUFenyOgwl4hEosR9qeY/tRWAr1lYJlUCZ5fl01DKnO1HzAbaSHYK6IDjePagW\nyi1ZwZ9ZSRRK7B0bJ+9EXbIOA/P4KzCPqsX1t2uuqjr/f9O+8PfpCzGCb8XpDG4V/EsuuYQjR45w\n7NgxoqKiWLNmDa+99po7hxRsUOWOq5zaOMRVqD4wBajVrh+ibJTTqHLIO1HSeT3wH2xFt7z8QgwG\n6ySpyqFjUcuWdGtoKMoaCkNV1kSiBL81+mbg1sVYRcBbqAzfceGVfZ8cP78f8av+HxN5kWiUQXQZ\nSuyHYu/qlwDbCWAZe1AXDqSDpeAzuFXw/f39mTVrFvfddx+apnHrrbfKhO15xFrueP42qU5OHszW\nrQvIyrLtJvMSyh/XbZxAVAXMIpS9UwSUU1b2V1Stew+sE7ftUNXtF6JEvhR1Z6B78Ccs4ziutu1v\nGddoOacR5d/HoCqBVJ+cmBgThTmnmcrrNVO8rbGK/TqsG5OUAIb7JnHqxLWQ0s0ynvj0gu/g9jr8\nAQMGMGDAAHcPI9TBwYPBOL9xyJlxdpvB8PAwoqN7k5VlK8DtgV9RkqnbOONQojsCdVF4EXVR6Im6\nIPwN6wVjBipTvxh157Ac1YuyBEgE3qb2att1KMG3nW69HVXW+RtgwJ9D9Mx5mb5Qk9lXAL9YzqqL\n/fdArrErr//8ExVVgRQUmJESS8EXkdYKTZiuXYvZurXhjUOcwXGbQb2WPiOjNSbTXiIiutC9eyXJ\nyYOJicnDXoBboeri11HbT9d/jwIWo5oRnLa8pxRVNtkVtSh8JKoLpm255nLUBWW25Ryhlvd84zBW\nBaq0cwYQTkve5EFeJghVgW9bxT8Pa9u0X4G+H3zMjcNGEGbZSUpKLAVfRQS/CbNgQTxlZWfORPXM\nvS7hts3gHfvc/PBDMWbzg+gymZX1Pjt2BLFmzVpURfoLqJbCe1GLnlRFTm0/HcvvIVgbmC1Bib6+\n4YgJlYNrqDsAx7qZwyg//3eU7/8ZyhKy9sDx89tD9+4XcOD3KQzj31yE2tMqHzUlbHvGCNQ9QC6w\ns/f/MX2YbR2/IPguIvhNmIiIhjNRxxWxWVnL2bFjZK1NRxy3GSwpsb0AFKJE/lkqKx27wFdYfgaj\nJmuXo7L3LagFSgtQ3vpfLOcyoKpt9K0D9bJJtayp9sYkO1Fi74eaatVQ62BNGAwLMBgKCAgopLz8\nSTJ/f5XH+DctsC/UdOyGvxdY3fmv9L7iYj4Su0ZoQojgN3McM3clzKvsNh3ZsuUFIiO70bLlLMrK\nugIFVFaWohqSrUMJdBeH8/RC9YzvgzJJ/FENyu5CyeoW1MpWfd1qqOW9GsqnL0RV46iWxMHBwbRu\nncHJkyGcPDkLuBJ1FzAFa0Y/E1sZ17TOaFoorQK2MLG8KxEUcrHlXbaRhqIWUYUDu6KiefS7//FE\neIQLvl1B8C5E8Jsp1s1Gcqg94Wm/yjUnpx05OX7AH1AZ9RSsbvdc6l4odQRrqeQIm2MvRpVq9gJS\nULtP9QdmWc5/EjVluhblxa8FWtO2bQGXXRZJaupk9L48+lixsVmUlMTY1PAb0Dcab8F7TDr1Fn1Q\nS7JKLKPbRhqGarV22YrV3DZgoAu+XUHwTkTwmxG2lTbHj+8kK+shlOwtIyTkFOXlBygr80ctWtJ3\nnApFudlTsIq33mCgN9YLg75QKhrlpV+OfR6t7yk7EiXYek2+vvq1o+U1nWLgH+hZe1aWRnb2U8BS\nlGe/gKCgIMLDs4mIMFJdfYCiIpvaevZzEy25iHIuR80QBAKnLL+/iJrizQUyW7Zk8jdb6Ny1G4LQ\nlBHBb0bY+/WjgPctrxRQXNwG5YPbNv3VO0teQG3bR29yZrtQqhglpzEoJ9w2j24DVGP1823PV4i6\nSNger0u09ThNuxp1UVCWTXi4ucZ6ggJiY+cSHd2bnN8/5s8nV9ADlc3bVuC8iqrSH44ylHq/9TZT\n7ri7MV+rIPgMIvhNGMeVthkZAdgLbQFK0O+zPLbfzi8oKJLQ0AxycsB+Zep2AgK+oby8M0pC26EE\nXpU8qvLKu7GuUf0NdRFohVoEFYSSXF2GQ1H5ti7HJSg7Zz72F4GdqBYIYfj5taekRP8cAOFEtA1j\n4P4JtDhZXNPw7ANq32f8AmwA/mgptxSE5oIIfhPmgQdSSElR1S7p6RrR0S9gK6ABAcFUVtq2Frbv\nHBMenkV09CXk5NyAEu8yoAXV1X+mvPxfqIlafU3qfJTYg8r030bZNDstj5+ynONFVEbvKO6foTJ6\n2wtBEQbDU5bM/iQwGVXeeSctWhykqKhnTbxBPMWf9syhB8qmsZ3yrVXT87fnmPlIoku+Y0HwJUTw\nmzCbNtlPvppMkSi/3AAcoqoqFNiBVWSHoiZXewO7aN26DXv2bANyUPZNN5QL/jFK7HeiRPtt7Gvs\ny7BfHDUHqxVkQElxLPbibsDffzdVVXrXSwNt215ARUUwpaW23n4psbFzadu2M3v2DMOfF3iAp4lE\nTfmWoNbTrkXdY4xCFYgagX1Az7feZqRYOEIzRQTfx6ivxUFdzzvWo1RXF6AmX5cBT6BpytYJCHgG\n6EBl5WGUtbINuICMjN/RtBmo0kt9kdW/UBuRLMde1Gdh3SzcfkMSgyHC0iq72CaeocTEvMjp07EY\nDCauvroNYCQ19YGacw4atJStW49SWmr9DLGxOaSnTyMh4VP279nCgzxNR2qvvT2NMpbyLaMa3nqb\nv4rQC80cEXwfo74WB5s2VWI2twRuID09FFjK1Ve3JDX1JZS1coyIiALy8x0nTcMJDu6C2TwRtcI1\nEHgMNUmql17G2hyvi7njxGtfVG/6LNRCKqtI33ijgV275pKV1cVyvjhiY/ewceM9hIeH1bSgLSgw\n06KF/cpgs7mQMWPmUlDQifDwTFauHMnRjAx6bnyEKyiqKcB0XHt7CDUzkN82lAlfbpIKHEFABN/n\naKjFgV4yefhwW7p00YBpNa/17fsOO3a8YKmpt9opRUU5qOy8LepPwlY+9SZluoAXYW2LYOuOH0Jd\nWKpQ3S5fom3bKAYNakFy8jAAkpI2cvhwT4uYj6/VfK2uHjXh4WGkp0+refzh669wdO7zRKNaIDiu\nHNCAHwEzEP3W20yXrF4QahDB9zEcWxyoCpnaJZNGYxHHjkXYvZafH0OfPhXk5LRCTZqGAK2orn4I\nlQ8/i6qksfXWT6ImVZfTqlUZoaEZ5OYuQS2YeslyjmKUp1+NukNQq2mvu+49OwFvTMOxXdu28Wn8\nYDprGiGWT51nGd22Z/1m4GhQax7/+nvJ6gXBARF8H0H36A8caENs7BxLk7MqysurSE21XgDCwvYw\ncGABTz55Bbfeuhpb8TYai9i06TQwFaugv4+qfAFlyVyCtY/8LtS+sGHAnUREzKWg4EJUnxtFQMBz\nVFY+jaqLWYtqoaA2B8/OjnTJZ1/98VL2JD7MGyhhn24T/VxUfVBHVGfLuLfe5nHJ6gWhTkTwfQTH\nJmeXXvoe0IKjR1sTGzuXdu3i6NbtFMnJdxIeHkZCwkoyMyej574xMb9SXh5JUVE7lH0TjxLyAlQd\n/nKs1TS6NdQH+CcBAZFERBwnK2sq6uKgoQt8dXVny/kqsDY8U6tnjcZKwPle+nXxzovPU/LmK/S0\njOLY2bIT6rK0u20od4lXLwhnRATfwzgrho7e/Y8/+mE2Wy8AV11l3c3KZDKzaVMlqi7+LgCOH99h\n6UNjK+h3Uv8kbAVwjKFDI/jww7sZMiSN48f15z9E7Vg1nerqcMv5PrR7f1jYaZKTbwZqTzQ7s/NW\nWspnbEmYQBuse8sOpfZWJzuB61es5o/SA0cQGkQE38M4K4a1vXt9az8A+92sZszYiNlchbJWQoAi\nqqvD7I4PCqrA3382JSX+qBW2O7H37gNRxY7v2Yy/FvssXu+pY8DP7xjV1db4Bg4MqLlwOV6sGtp5\n6z8LF3D0bzO4Cvu2CMtRl6VZqLqhA35+jFi3kd59Lz/j+QRBUIjgexhnxVDV1VtLFsvL29h590Zj\nUc3dwvr1oNab2u4rOwfb3HjIEPjii3KsneGvB55BbczdApVPG2p8+OTkwWza9CVms2MBJIDGLbdE\n1Cqp1HG8WNW389aWDRv45s7RdEXtcWVfza9G247aB6vlW28zQ7x6QTgr3Cb48+fP5z//+Q/t2rUD\nIDExUfa2rQNnxdCxZLGumvWkJFuf374vjiqvXEZY2GkGDgwgOXkQ69dX2RwTDlyFqrixdrLU4wkP\nD2PgQH9SUmwXQe0gOrqamJh8gHptKceLVV07b+kWzlUood9B7f2xvgcKWrfmwY1SgSMI54JbM/yJ\nEycyceJEdw7h8zgjhnWhXwD0rP6OO7ZZetvrXWRKsG5Q0gb4BX//Mq65xkhy8gjCw8MID8+yW8Wq\nO+V610nHeGrHOr5mgjgl5X7qs6XOtAfsrm3bWDlyCK0qKojCauH0x9pBPwzVmu1SaYsgCI3CrYKv\nVmoKZ+JMYujMhO6jj37BunVKbK37wd4DDMVgeBlNe9Hy2giqql4lNXUKLVooQV65cpRlFWsHNO0A\nXbv2IC5udZ2LovRY580bVBNTUtIGkpMHn7VHD8q++e6uMcRpGq1QbdG+xv5+oyewBwiZ+wp/u39S\ng+cUBOHMuFXwP/74Y1JSUrj44ot54oknCAkJcedwTQZd6Otql+B4cfjxRz9sxTYg4DQXX/yZpea+\nu4PnrpqS6YLctavRbhWrM9Q1yWw0ak7ZUjqrP17KvsSHa/bK0uvpY7G3cHYAwdNncKeIvSC4hEYJ\n/sSJE8nPz6/1fGJiInfffTcPP/wwBoOB119/nblz5zJnzpwGzxkV5RsXBXfFeeKEmZtu+pjMTH17\nQGs1TFZWeK1xDYYT2MpkSEgxv/zyIACjRy+289z1n3FxpWcV/4kTZh56KJWDB4PZv78a2wtMVlY4\n69Zdz5Qpyzl4MJiuXUtYsGAkERG1z3/49995+7rr0PLyuBb7GYYY1KaFy1BtEQ4FBjLhhx+49Mor\nnY7zfNDc/z5diS/ECL4TpzM0SvA/+OADp467/fbbmTx5slPH5uUVNyak84Le7MsdJCSsIjPTdutA\na7uE2NiCWuNefXUbUlP1LpXFXHGFH6NHL7HYQBUMHfoemZlhnDixj4gII927L2X27EHk5RU7vQYg\nIWGVzWSw/d61sbEFVFX5M3/+8Jrjq6pq/ztu2bCBNXeOph2qyfJOVF2QXsV/AHVZy42MYsSaL7nN\nMinrTX8P7vx3dyW+EKcvxAi+FaczuM3SycvLIyoqCoAvv/ySuLg4dw3VpFB2i307ML1dgu0Eqi7W\nmZnRxMbutWm10NbOchk1ailpabcAt9Qay9k1APYe/TDCwl6hS5cLnZpkLjSZ+PD2MZz67Rcisd9A\nUe+8vxXVxrjrW2/zkEzKCoLbcJvgv/zyy+zevRs/Pz86duzI888/766hmhQxMXnArVhbIvzGpk33\n1Mq8HVst6CtthwxJw9kJVGcnW+1LR0MZOLA9ixbd2OBnKTSZePe6fgSeyKcDEIf9fUsUkI5aQnaD\nbDcoCG7HbYKfnJzsrlM3aQyGSlS/GmXRXH55O6daLehi7Wxd/9kce7alo4UmEx/eOoLAHdvpgloC\n1prabYx/B4au38TAmwf4xG2zIPg6stLWy8jO7oCavtQff1bncfWJta04x8WVMnt2/eLsrJCfqXTU\nkS0bNrDqztG0R7VAsF3nexfWNsbfA/1XrJa2CIJwHhHB9zIam3XbinNDE05nI+QNUWgy8cn/3U7Z\nT/8jGrVm19a+CQcWoBZRHbj8Sh5Y/gmh4REuGVsQBOcQwT+POFMV8+STV7J1q76l31FmzhxV57lc\nKdaNpdBk4tUr+hB56iRdgQyUjWNr32QBtGzFdau/kKxeEDyECP55xJmqmLlzfyYr60nAQGmpxpw5\nS1m0yOiJcJ1CX0TVDvsKnGewt286zn1FFlAJgodp8oJfV1ataZzzhhyNwZmqmHNpU+AJ0lI+Iz1h\nAj1QmyN2wN7CuQDV1fJXoK9U4AiCV9DkBb+urBo46w05XIEz/vzZVNl4gkKTifWJD3MkdY1da4SZ\n2Fs4+1FCP12EXhC8hiYv+PVnzOc/i3amKuZcu2eeD45mZLDwun6EVVfVqqmPRO2EG4US+/C/PSdZ\nvSB4GU1e8OvOmM+u2ZercGai1ZsmY3UKTSa+eHgSpWnrCUNtObgT1XxZb41QBJQBmZf2JfG/n0kF\njiB4IU1e8OvPmL0zi/Y2jmZk8J8Bf2RuRTnLgenozZZVa4RoVEYfcNUfeOCj/4jQC4IX0+QFv76M\n2duyaG9k17ZtpA4dRFfq3ua8N/CjwY/79hwQoRcEH6DJC75w9hzNyODzUX/C73guc4FXULZNMbW3\nHBz6xUYRe0HwEUTwBTv0rL43qtfNEdTq2GUooX8JaAvkRbfn9tVfyN6yguBDiOALgEXoRw7hgooK\nLgGGoerrXwKmAGtRk7SFLVpy7efrZbWsIPggIvjNHL0C50Taeru6erXHFrRHZfeHUZ0tbxOhFwSf\nRQTfDTi7k5SnSUv5jG8SJhAFGKlrjy3YB1RFRnHXmi/FvhEEH0cEvwHqEu+GthNzdicpT1JoMvFz\nwgQ6A0+gsnjbCdm9wGZUVi/2jSA0DUTwG6Au8f7sswlnfI8398PRWyPkfLWei4BAVKTxKBvnJJCN\n2nLwZulXLwhNCj9PB+DtnIt4G42FqDwZvKkfztGMDN699CJOpK7huYoKgoBjqEjDgDtRi6hCbhzC\ntL2H+OOAgZ4MVxAEFyMZfgOcSzMzb+yHU2gy8emga7m2vAwT1qx+CfA00BXYHxjI7d9tFa9eEJoo\njRL8devWMX/+fDIyMlixYgV9+vSpee2dd97hk08+wd/fn6eeeor+/fs3OlhPcC7i7S39cMwnTvDf\ne+6hcPO3VBYXM1vTMAAfY83qpwFzAgOpvOkW7ntjviyiEoQmTKMEPy4ujvnz5/P000/bPZ+RkUFq\naipr164lJyeHiRMnsn79egwGQz1n8l68RbzPlqMZGSy89goiNI0eqEVU24FLUTX2r6I6XO5v2Yp7\nf9sjQi8IzYBGCX63burWX9M0u+fT0tKIj48nICCATp06YTQa+e2337jssssaM5zgJLp9E6VpdrtQ\nPY0S/FDADJyIiua2z9eL2AtCM8EtHn5ubi59+/atedy+fXtyc3PdMZTgQKHJxL8HX0e306VUYF9b\n3xVYDByL7ci9GzeL0AtCM6NBwZ84cSL5+fm1nk9MTGTw4MF1vscx4wectnMaqnH3FrwtTvOJE6Q8\n8ACZa9Yws6LCzqvXM/x9QI9Ro3j4/fcJi/Ausfe277M+JE7X4Qsxgu/E6QwNCv4HH3xw1ieNiYkh\nOzu75nFOTg7R0dFOvTcvr/isxzvfREWFeE2cu7Zt48sx8XQ9Xcpx7FfMDgNeADqixF7fW7aiyru+\nZ2/6Ps+ExOk6fCFG8K04ncFldfi2Wf3gwYNZu3Yt5eXlHD16lCNHjnDppZe6aijBhi/HxDP7dCn3\no1bM7sW6AiAU8IvtyIC9h5h+vEi2HBSEZk6jPPyvvvqK2bNnU1BQwOTJk+nZsyfvvvsuPXr0YOjQ\noQwbNoyAgACeeeYZn6zQ8WaOZmSQOm443U6X2vn03VFtEsqBnE6duCPtO/HqBUEAwKDVZbh7EF+5\nffJUnIUmE9/OeIyDa1fzXEUFy1BdLXWffhYQGhZG62uu488fLaGiKtAjcZ4NvnTbLHG6Bl+IEXwr\nTmeQlbY+gi702qYNtDSb6YZ9D5xS4ECrIG5eta6m/01YhG/8sQqCcH4QwfcRvp3xGPemfGqXydv2\nwJkT25G/pO/2ZIiCIHg5IvhejJ7Vhx4+hHbogJ1X3xO1G1V7Pz+yYzowdOUazwUqCIJPIILvxdhm\n9Y419dlhYcQMHMz1ya/JpKwgCE4hgu9l7Nq2jS9G/wljWRm5wALgblRN/SthYXTv0o1CYxfGiNAL\ngnCWiOB7GV+OiefFsrKaTH4ZkIry6SMHDub6RYs9GZ4gCD6MCL6X0a3stJ1XHwKYgoJYPGQo1ye/\n5sHIBEHwdUTwPYztxGyh0cjuwBZo5dYMvxioHjKU4ZLZC4LQSETwPYxduWX6z/x94CCe+vF7jGVl\nHDcYaHX9QMZIZi8IggsQwfcwoYcP2Vk4nQsLuftonidDEgShiSKbmJ9HCk0mPk+4l2+H3MDnCfdQ\nWGCi0Gi02e4cCo1dPBihIAhNGcnwzyOO9s1iDFyf/DqLMVg8/C4yMSsIgtsQwT+PONo3oYcPERoe\nIROygiCcF8TSOY+IfSMIgieRDN8NOJZaXp/8OqHhEWLfCILgUUTw3UBdXv3wRYvFvhEEwaOIpeMG\n6vLqBUEQPI0IvhsQr14QBG9ELB03IF69IAjeSKMEf926dcyfP5+MjAxWrFhBnz59ADh27Bjx8fF0\n69YNgMsuu4xnn3220cH6CuLVC4LgjTRK8OPi4pg/fz5PP/10rdcuuOACVq5c2ZjTC4IgCC6kUYKv\nZ/CapjVwpCAIguBp3DZpm5mZydixYxk/fjw//fSTu4YRBEEQnKTBDH/ixInk5+fXej4xMZHBgwfX\n+Z7o6Gi+/vprQkND2blzJw8//DBr1qyhTZs2DQYUFRXiRNjnD/OJE6Q+9BDBBw9S3LUr8QsWAN4X\nZ31InK5F4nQdvhAj+E6cztCg4H/wwQdnfdLAwEBCQ0MB6NOnD507d+bQoUM1k7pnIi+v+KzHcyef\nJ0yyLqLaupXFZZVM/OwTr4uzLqKiQiROFyJxug5fiBF8K05ncJmlY+vjm0wmqqurATh69ChHjhyh\nc+fOrhrqvCKLqARBaCo0atL2q6++Yvbs2RQUFDB58mR69uzJu+++y08//cTf//53AgIC8PPz4/nn\nn6dt27auivm8Umg0oqX/XLPloCyiEgTBV2mU4N90003cdNNNtZ4fMmQIQ4YMacypvQZZRCUIQlNB\nVto2gCyiEgShqSC9dARBEJoJzVLw69pbVhAEoanTLC2d+vrVC4IgNGWaZYYvpZaCIDRHmqXgS796\nQRCaI03e0qlrf1kptRQEoTnS5AW/Pr9ePHtBEJobTd7SEb9eEARB0eQFX/x6QRAERZO3dMSvFwRB\nUDR5wZfWCIIgCIomb+kIgiAIChF8QRCEZoIIviAIQjNBBF8QBKGZIIIvCILQTBDBFwRBaCY0SvCT\nk5MZOnQoo0aNYtq0aZSUlNS89s477zBkyBCGDh3Kd9991+hABUEQhMbRKMHv378/a9asISUlBaPR\nyDvvvAPA/v37SU1NZe3atSxatIjnnnsOTdMaOJsgCILgThol+Ndeey1+fuoUffv2JScnB4ANGzYQ\nHx9PQEAAnTp1wmg08ttvvzU+WkEQBOGccZmHv2LFCgYOHAhAbm4uHTp0qHmtffv25ObmumooQRAE\n4RxosLXCxIkTyc/Pr/V8YmIigwcPBmDBggUEBgYyfPhwgDrtG4PBUOs5QRAE4fzRoOB/8MEHZ3x9\n5cqVbNq0iSVLltQ8FxMTQ3Z2ds3jnJwcoqOjnQooKirEqeM8jcTpWiRO1+ILcfpCjOA7cTpDoyyd\nb775hnfffZcFCxbQokWLmucHDx7M2rVrKS8v5+jRoxw5coRLL7200cEKgiAI545Ba0T5zJAhQ6io\nqIMzjrUAAATvSURBVCAsLAyAyy67jGeffRZQZZkrVqwgICCAp556iv79+7skYEEQBOHcaJTgC4Ig\nCL6DrLQVBEFoJojgC4IgNBNE8AVBEJoJXiv47733Hj179sRsNns6lDp58803GTlyJKNHj+b+++8n\nLy/P0yHVyZn6HXkT69atY/jw4fTq1YudO3d6Ohw7vvnmG/70pz9xyy23sHDhQk+HUy8zZ87k2muv\nZcSIEZ4OpV5ycnKYMGEC8fHxjBgxwq6c25soLy/ntttuY/To0YwYMYL58+d7OqR6qa6uZsyYMUye\nPLnhgzUvJDs7W7vvvvu0QYMGaQUFBZ4Op05KSkpqfl+yZIn29NNPezCa+tm8ebNWVVWlaZqmvfzy\ny9orr7zi4YjqJiMjQzt48KA2fvx4bceOHZ4Op4aqqirtpptu0jIzM7Xy8nJt5MiR2v79+z0dVp1s\n3bpV27VrlzZ8+HBPh1Ivx48f13bt2qVpmvo/NGTIEK/9Pk+dOqVpmqZVVlZqt912m/brr796OKK6\n+eCDD7Tp06drDz74YIPHemWGP2fOHJKSkjwdxhlp06ZNze+lpaU1PYW8jfr6HXkb3bp1o0uXLl7X\nZO+3337DaDTSsWNHAgMDGTZsGGlpaZ4Oq0769etH27ZtPR3GGYmKiqJXr16A+j/UvXt3jh8/7uGo\n6iYoKAhQ2X5lZaWHo6mbnJwcNm3axG233ebU8Q2utD3fbNiwgQ4dOnDRRRd5OpQGef3110lJSSEk\nJMRrb01tWbFiBcOGDfN0GD5FXX2htm/f7sGImg6ZmZns2bPHaxdlVldXM3bsWI4cOcKf//xnr4xT\nT46Li4udOt4jgl9ff55HH32Ud955h/fff7/mOU9mfA31EUpMTCQxMZGFCxfy0UcfMW3aNA9EeXb9\njjzp7zoTp7fhbXccTYWTJ0/yyCOPMHPmTLu7ZW/Cz8+Pzz77jJKSEh566CH2799Pjx49PB1WDV9/\n/TWRkZH06tWLLVu2OPUejwh+ff159u3bx7Fjxxg1ahSappGbm8u4ceP473//S7t27c5zlA33EdIZ\nPnw4Dz74oMcE/1z6HXkCZ79PbyImJoasrKyax7m5uU73hRLqprKykkceeYRRo0Zx0003eTqcBgkO\nDuYPf/gD3377rVcJ/s8//8yGDRvYtGkTZWVlnDx5kqSkJJKTk+t9j1cZz3FxcWzevJm0tDQ2bNhA\n+/btWblypUfEviEOHz5c83taWhrdunXzYDT1U1+/I2/Gm7LqSy65hCNHjnDs2DHKy8tZs2YNN954\no6fDqhdv+u7qY+bMmfTo0YN77rnH06HUi8lkqrFJTp8+zQ8//OB1/8cfe+wxvv76a9LS0njttdf4\n4x//eEaxBy/08G0xGAxe+wf86quvcvDgQfz8/IiNjeW5557zdEh18sILL1BRUcF9990H2Pc78ia+\n+uorZs+eTUFBAZMnT6Znz568++67ng4Lf39/Zs2axX333Yemadx66610797d02HVyfTp09myZQtm\ns5kbbriBadOmMW7cOE+HZce2bdtYvXo1cXFxjB49GoPBQGJiIgMGDPB0aHbk5eXxxBNPUF1dTXV1\nNfHx8TX7ffgy0ktHEAShmeBVlo4gCILgPkTwBUEQmgki+IIgCM0EEXxBEIRmggi+IAhCM0EEXxAE\noZkggi8IgtBMEMEXBEFoJvw//5K32R/vBHAAAAAASUVORK5CYII=\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f5be3c99f50\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Current loss: 9.48636\n" + ] + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.scatter(inputs, outputs, c='b')\n", + "plt.scatter(inputs, model(inputs), c='r')\n", + "plt.show()\n", + "\n", + "print('Current loss: '),\n", + "print(loss(model(inputs), outputs).numpy())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "sSDP-yeq_4jE" + }, + "source": [ + "### Define a training loop\n", + "\n", + "We now have our network and our training data. Let's train it, i.e., use the training data to update the model's variables (`W` and `b`) so that the loss goes down using [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent). There are many variants of the gradient descent scheme that are captured in `tf.train.Optimizer` implementations. We'd highly recommend using those implementations, but in the spirit of building from first principles, in this particular example we will implement the basic math ourselves." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "MBIACgdnA55X" + }, + "outputs": [], + "source": [ + "def train(model, inputs, outputs, learning_rate):\n", + " with tf.GradientTape() as t:\n", + " current_loss = loss(model(inputs), outputs)\n", + " dW, db = t.gradient(current_loss, [model.W, model.b])\n", + " model.W.assign_sub(learning_rate * dW)\n", + " model.b.assign_sub(learning_rate * db)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "RwWPaJryD2aN" + }, + "source": [ + "Finally, let's repeatedly run through the training data and see how `W` and `b` evolve." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 446 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 569, + "status": "ok", + "timestamp": 1527005915434, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "XdfkR223D9dW", + "outputId": "c43591ae-d5ac-4f2b-a8e7-bfce607e0919" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0: W=5.00 b=0.00, loss=9.48636\n", + "Epoch 1: W=4.58 b=0.42, loss=6.28101\n", + "Epoch 2: W=4.24 b=0.76, loss=4.29357\n", + "Epoch 3: W=3.98 b=1.02, loss=3.06128\n", + "Epoch 4: W=3.78 b=1.23, loss=2.29721\n", + "Epoch 5: W=3.61 b=1.39, loss=1.82345\n", + "Epoch 6: W=3.49 b=1.52, loss=1.52970\n", + "Epoch 7: W=3.38 b=1.62, loss=1.34756\n", + "Epoch 8: W=3.30 b=1.70, loss=1.23463\n", + "Epoch 9: W=3.24 b=1.76, loss=1.16460\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW0AAAEDCAYAAAD+/1UIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl4VOXdPvD7zJZ9XwmELQkQIAELsiTsi6xiEBGXAiIW\nbV8WBY2K0tLa4lbsr283qxURtIoioAi8SpFNg6whi0FJKAoJBgLZt5k5c87vj5OZLIRkgEnOGXJ/\nritXJsmZyT0sN1+enPOMIMuyDCIicgs6tQMQEZHzWNpERG6EpU1E5EZY2kREboSlTUTkRljaRERu\nxODMQePGjYOvry90Oh0MBgM2b97c1rmIiKgZTpW2IAjYuHEjAgIC2joPERG1wKnlEVmWIUlSW2ch\nIqJWCM5cETl+/HgEBARAEATMmTMH9957b3tkIyKiJpxaHvnggw8QFhaG4uJiLFiwAD179sTgwYPb\nOhsRETXh1PJIWFgYACA4OBgTJ05EVlZWi8fL3t6AIADdugFvvglYrTeflIiIWl8eqampgSRJ8PHx\nQXV1NR5++GEsXrwYI0aMuPadCgtRvfoFeL2zDkJtLWxdu6NqRSrMs+8DDE4N9y4XFuaHoqIKVb73\ntTCTc7SYCdBmLmZyjlYzOaPVSfvy5ct44IEHkJKSgjlz5mDcuHEtFzYAREai6oWXUHwkA9WPPApd\n4QX4L/sVgpMGwWPTvwFRdCocERE15tQPIm9Ew3/FdBcK4P3ntfB89x0IVivEmFhUP/kMzCmzAL2+\nLb79VbT6LysztU6LmQBt5mIm52g1kzPa5YpIKaozKl9+DcWHT6Jm7gLof/wB/r98BEGjh8Fj28cA\nTyckInJKu17GLnWJRuXaP6P40AnUPDgP+jN58F+0AEFjhsO0fRvLm4ioFarsPSJ1647KP/0VxWnH\nUTvnAehPf4+AhfMQNG4ETDs/A/hiOkREzVJ1wyipR09U/OV1lHx9FLX3zIH+uxwEPPQAAieMgunz\nXSxvIqImNLHLny0mDhV/fxMlB4+g9u57YMjORMDcOQicNAamPV+wvImI6miitO1scb1Q8fo6lOz/\nBrUzZsJ4Mh0B99+DwKnjYdy7h+VNRNftL395DR999IHj4+XLl2DVqlWOj//61/+HDz/8txrRboim\nStvO1iceFf96B8V702CeNgPG48cQOGcmAu+cBOOBfSxvInJa//6JyM7OAKBsfldWVorc3FzH17Oz\nM5GQMECteNdNk6VtZ+vXH+Vvv4uSPQdhnjwVxiPfIPCeGQhImQpj2ldqxyMiN5CQMBBZWZkAgLNn\nz6Bnzxj4+PigsrISVqsVP/74A+Liequc0nnqXFN+ncSEASjf8AEMJ0/A+9UX4bH7c5hSpsIycjSq\nnloJcdhwtSMSkRN8Vj8Pj+3bXPqY5jtTULX699f8emhoKPR6Ay5duoisrEz075+I6uoyZGdnwsfH\nBzExsTCotL3GjdD0pN2UOPBnKH/vI5Ts2gPL2PEwHdyPoBmTEDD7LhiOHlY7HhFpVGJiIrKyMpCd\nrZT2gAEDkJWVgaws91oaAdxk0m5KHHQ7yjZtheHIYfi8sgam/Xth2r8X5vETUZ26EuJtg9SOSETN\nqFr9+xan4rbSr18isrIy8d//KssjHh4y/vnPf8HX1wfTpt3V7nluhltN2k2JQ4aibPMnKP1kFyzJ\nI+GxZzeCJo2F/8/vhSHzpNrxiEgjEhIGIC3tIPz9/SEIAgICAlBZWYHs7Cz075+gdrzr4talbWcd\nnoyyrTtQuuUzWIYlweOL/0PQhFHwn3c/9HU/gCCijismJhbl5WXo3z+x0ef8/Pzg7+9er33bLrv8\ntStZhvHAPvi8/AcYjx0BAJin3wWPXz+Hom69lRdn0Ait7jTGTM7RYi5mco5WMznjlpi0GxEEWEeP\nRemO3Sj9YAusPxsEj88+AYYMQeCEUfBc/xaEinK1UxIR3ZBbr7TtBAHWcRNQuutLlG7aCsycCUNO\nNvxSn0BIQm/4Ll8CQ/pxXqhDRG7l1i1tO0GAdex4YMsWFJ88hapnV0EKCYHXu+8gaNJYBI4fCc+3\n/wWhvEztpERErbr1S7sBKSIS1U88heKjmSj9YAvM02bAcOpb+D29HCGJveH7xGIYThzj9E1EmtWh\nSttBp4N13ASUv/2uMn2v/DWk0DB4vbcBQZPHIWjcCHiue5PTNxFpTscs7QakiEhUP/4kio9koHTT\nVpinzYD++1Pwe2aFMn0//j8wHD/K6ZuINKHDl7aDTgfr2PHK9J2eg8rnfgMpNBxe/96IoCnjETQ2\nmdM3kZsqLPwJ8+bNUTuGS7C0myFFRKJm2QoUHzmpTN/T74L+9HfK9J3QC77LfgXDsSOcvonciKCh\nazRuBku7Jfbpe91GXEk/hcrnV0MKj4DX++8iaOoEZfp+6w0IZaVqJyWiVoiiiD/8YTXmz78fy5Yt\ng9lsVjvSDbn1roi8BpddASVJMB7YB6+N62Ha9RkEUYTs5QXzXXejZu5DEAcPcfqqS61elcVMztFi\nLq1nWr3aA9u3u3afujvvFLF6dcsFXFj4E2bPnoF//GMd+vdPwJ/+9CI6dYrGfff93KVZbkbHvSKy\nrel0sI4Zh/K3NuDKye9Q+fxvIYVHwPOD9xA0bSKCxiTB861/cvom0piIiEjH5lAzZsxAZmaGyolu\njFtuzaoVcng4apY+gZrFy2A8uB+eG9fDY+d2+D37FHx/92uYZ8xEzbwF1zV9E93KVq82tzoVt5Wm\na9ru+leSk7Yr6HSwjh6Lin+9o0zfq34HKSISnpv+XTd9D4fnv16HUFqidlKiDquw8Cd8+202AGDH\njh1ITByocqIbw9J2MTk8HDVLHkfxN+ko3fwpau+6G/q8XPitTEVIYm/4LXkMhiOHeeYJUTvr3r0H\ndu36DPPn34+ysjKkpNyjdqQbwh9EtgOhqAiem/4Nz41vw3D2vwAAsU88DAsfxpVREyH16KlKruZo\n/QdZWqLFXMzkHK1mcgYn7XYgh4WhZvEylBw6gdKPt6M25W7oz+QBTz2FkKEDETR6OLxf/gMMWRmc\nwImoRfxBZHvS6WAdORrWkaNReeUKQr/eA/Omj2A6sA8+a1+Gz9qXYYvuCvOUabBMvRPWIcMAN3qV\naCJqe2wElcghIcDChSifcS+EygoYv/wPPHZ+BtPuz+H9xj/g/cY/IAUHwzxpKixTpsMyeizg5aV2\nbCJSGUtbA2RfP1hmzIRlxkzAYoHx64NKgf/fDni9/y683n8Xsrc3LGMnwDx1OiwTJ0EODFI7NhGp\ngKWtNSYTrGPHKy/c8PJaGE4cg8euHTDt3A6PHZ/CY8enkA0GWJNGKgU+ZRqkTlFqpyaidsLS1jKd\nDuLgIRAHD0HV86uhP/09PHZ9BtPO7TAd2AvTgb3AMytg/dkgmKdMh2XqnbDF9VI7NRG1IZ494i4E\nAbbefVD9+JMo/WI/rqTnoOLFV2EZOQaGjJPw/cNvEZw8GEFJg+Dz+9XKHuCSpHZqItVVVlZi69bN\nbfb406dPQGVlJQDgypXLGDnydmRlZTT4+kSUl7vuxcSdLm1JkjBz5kw89thjLvvmdOOkzl1Qu/BR\nlH38Ka7knEH5X/8J89Q7oS/Ih/f/voagKeMRPDAevqlPwLjvS8BiUTsykSoqKsqxdetHzX5NcsFg\n07dvArKzMwEA2dmZ6NWrD7KylI/PnfsRgYFB8Pf3v+nvY+d0aW/YsAExMTEu+8bkOnJQMMz33o/y\n9e/h8qmzKHvnfdTOeQCCuRZe699C4L0pCOkbA79fPgLT9m1A3VRA1BG8/vpfceFCAR5++EH8/e//\ni/T045g3bx5++9vnMX/+fVe9QML777+Lt99+EwBQUJCPFSuW4pFH5mHx4kU4d+7Hqx4/ISHRUdpZ\nWZmYM+dBfPttfYknJCS69Pk4taZdWFiI/fv347HHHsPbb7/t0gDkYt7esEyZBsuUaYAowvhNGky7\nPoPHzs/g+fGH8Pz4Q8geHrCMGQfLlOkw3zEFcmio2qmpAwke1L/Zzxcfz3bJ8U398pdL8MMP/8W6\nde8BANLTjyMrKwsbNnyIyMhIFBb+dM0XSHjllTVITV2Jzp27ICcnG2vXvoQ///kfjY7p3z8R69e/\nBQA4depbPPLIY/joo38DUEo8IWGAUzmd5VRpr1mzBqmpqaio0NZln9QKgwHWEaNgHTEKVb9/GYas\nDOUslF074PH5Lnh8vgu+Oh2sQ4fDMnU6zFOmA2HN/wUhupUkJiYiMjKyxWNqamqQnZ2BVauehn23\nD1EUrzqub99+yM39HrW1tbDZbPD09ERUVGcUFOQjOzsD99/v2j27Wy3tffv2ITQ0FPHx8Th8+LDT\nD+zsdfTtqcNnGj9SeVv7CpCbC3zyCYStW2E6lAbToa/hu+pZICEBYWPGAGPGAKNGARqZwrX4ewdo\nM5fmMzWzxAAAYde68/Ue34TFUg69XufIEBjoDS8vL8fHklQNQajPaDQCOp0JwcHeCAgIwPbtn7by\nHfzQvXs37N//OQYMSEBYmB+GDBmMrKxjKC8vw6Br/E/hRrVa2idOnMCXX36J/fv3w2w2o6qqCqmp\nqXjllVdavJ8WN2NhpgYCI4H5jwLzH4Vw8SI8Pt8Jj53bYUr7CsjKAv7yFwCAGN8X1uHJsCSPhHVY\nMuQwZ/+quI4Wf+8AbeZipqvV1sqoqKh0ZCgtrQZQ31GSZMLly1dw5kwBPD09sXv3HgwbloSaGhkR\nEZ3w4YdbMXbsBABAXl4uYmPjrvoeffr0w7p1b2PhwkdRVFSBbt164YUXViE+vp/Tz93Zf2xbLe3l\ny5dj+fLlAIAjR45g3bp1rRY2uRc5IgK18xagdt4ChPmbUPrFPhi/Pghj2tcwHjsMw6kceK1TfjAj\n9u4D6/BkWJNHwjJ8BOTwcJXTE7XM3z8ACQkDMH/+fRg6NAnDhyc3+rrBYMCCBY9g0aL5iIrqjG7d\nuju+9utfv4A//vElvPPOOthsIsaPv6PZ0k5IGIDNmzehXz/llXF69+6DoqIizJgx0+XP57q2ZrWX\n9uuvv97qsfzXvnVukcligSH9BEyHvlKK/OhhCNXVji+Lcb1gHT4C1uQRsCaNgBTR8jqhSzJphBZz\nMZNztJrJGdxPW0VumclqheHkCRgPfQ3T1wdhOHIYuqr6UwjFmFhYk0Y43lxxib0Wf50AbeZiJudo\nNZMzeBk7XR+jEeLtQyHePhQ1S5crJZ55UllKSTsI4+Fv4LVxPbw2rgcA2Lr3UNbD65ZUpM5d1M1P\n5OZY2nRzjEaIg26HOOh21Cx5HBBFGLIylBI/9BWMh9Lg9d4GeL23AQBg69odluQR9SUe3VXlJ0Dk\nXlja5FoGA8TbBkG8bRBq/mcpYLPB8G0WjF9/VV/iddvNAoAtuiusSSNgsS+ndO3mvi+TTdQOWNrU\ntvR6iIkDISYORM0vFwM2G/Q538KUdtAxjXtu+jc8NylXkNk6d3Gsh1uSRkDq3kPlJ0CkLSxtal96\nPWwJiahJSETNo/8DSBL0p3Ial/hHH8Dzow8AALZOUcCY0fDqkwAxcQDEhETI/gEqPwki9bC0SV06\nHWz9+qOmX3/U/OKXSol//x2MaV/BlKYsqeD99+GL9x13EXv0VKb3hAF1RT5Aefk2og6ApU3aotPB\nFt8Xtvi+qF24CJBlhJUWonx/GgyZGcpb1kl4frIF+GSL4262LtH1JZ44AGLiwDY5Z5zcT2VlJXbv\n/j/MnHlPm32PNWt+i+TkkRg9elybfQ87ljZpmyAAvXrBHNQJ5pRZyudkGbr8844CN2RmwJhxEh67\nPoPHrs8cd7WFR9SXeMJAiIkDIHWJ5g86Oxj7ftpNS1uSJOh07vc6MCxtcj+CACm6KyzRXWGZdqfj\n07qLhTBknmwwkWfA4z9fwOM/XziOkYKCHAVuf7N17wm44V9edzVokE+znz9+vMolxzfVcD9tvV4P\nLy9vREVF4ttvc/Dqq39Gaurj2LBhEwBlL+3a2hosWPALFBTk47XXXkFZWSk8PT2Rmvocunbtds3v\nc/ToYXz44fsoKSnG4sVPIClphFP5rhdLm24ZUkQkLBMnwzJxsuNzwpUrMGTVl7gh82T962va7+fr\nBzEh0bE+LiYOhC02DjDwr8etoOF+2unpx5Ga+gTWrn0VRqPfTe+l3VBh4U/429/eRH7+eSxd+hg2\nbdoGo9Ho8ufDP5V0S5NDQmAdMw7WMfVrjUJ5GQzZWfVTeVYGjIcPwXTo6/r7eXlB7NvfsT4uJg6A\n2DseMJnUeBq3FGcn5Bs9vjV9+/ZDVFRUi5exO7uXdkPjxk0EAHTpEo2oqM748ccfmt1c6maxtKnD\nkf0DHOeCO1RVwZCT3WAiz4AhIx3G40fr72c0QozvpxR4/0Rg2CAIIZ2VnQ65Tu42PD09Hbf1ej1s\ntvrXibRYzAAAWZbg5+fveLUbZzSd2K81wd8sljYRAPj4OPZUcTCbYfgup9FZK4Zvs2HMPOk4JBSA\n5B8AW2wsbDFxsMX1glj33tajJ+Dh0f7PhRrx9vZGdd3OlE33xwsKCkZpaQnKy8vh6emJtLSvMGxY\nEry9fdCpUxT27v1Pq3tp2+3d+x9MnjwNFy4U4MKFghbXv28GS5voWjw8IA64DeKA2+o/Z7VCn3sa\nhqwM+F/4EeaMbOjP5MKQlQnjieON7i7rdJC6doMYG+codFtsHMTYXsqLSXA6bxcN99M2mTwQHBzs\n+Jor9tK2i47uhsWLF6GkpBhPPbWyTdazAW7Nqipmco4WMwFNcokidOd+hOFMLvS5udCfyVXKPS8X\nustFV93XMZ3H1he5LTbupqdzLf5aMZNzuDUrUXsyGCD1jIGlZwzQ4OwVABBKS6DPy4U+LxeGuvf6\nvNOtT+f2Iq9bcuF0TgBLm6jNyYFBEAcPgTh4CMwNvyCK0J/7oa7E86DPO11X7KeVc8sbnF8O1E3n\nccpSixjXS1lyccF0Ts7bsGEd9u79DwRBgCzLEAQBY8dOwNy5C9otA5dHVMRMztFiJqBtc101neee\nVpZczv4XgtXa6FjHdB7XCx59eqEyJBK26GhIXaJh69IVcmioqhO6Fn//tJrJGZy0iTTIqem8bu3c\nUFfoHrs/B3Z/Dt+mj+XlBVvnLkqJR3eFFN0VtrpCl6KjIUV2AvT6dnx2dDNY2kTuxGCArWcsbD1j\ngTumNPqSUFKM0MorKMv8Dvr8c9Dln4f+/Pm69z/CkJfb7EPKBgOkqM6wdbFP59GOYpeio2HrHM3l\nFw1haRPdIuSgYKBXN1iir3FaWmUl9PnnlUI/fx76/PPQ5Z9zFLvx0NcQrrFaaguPUAo8uiukLg0K\nvW5al32d+6893TyWNlFH4esLW5942PrEN/91sxm6CwV1ZX4e+vPn6m+fOwdDxkkYjx9r9q5SYKBS\n4F2i69bT64sdiX0A2YNLMC7C0iYihYcHpB49IfXo2fzXbTboLhbWTen1yy/224b/5kHIzmz2rqF6\nPaTQMEjhEZAiIhq/D49s9DG8vdvwSbo/ljYROUevhxTVGVJUZ4hDh139dVmGUFzcYPlFKXPv4iKI\n5wuUrXPP5ELIymjx20h+/pDCwyFFRCrvHcVu/1wEpIhIyMHBHXJLXZY2EbmGIEAOCYEYEgI0uPTf\nO8wPpQ1OrxMqK6C7dBG6ixfr3hdCd+lS3fv6z+v/e+aaa+xA3Q9Qw8KbTO0RjlJvWPJosEmUu2Np\nE1G7kn39YPP1U86AaYnVCt2Vy1eVeePCvwjD96cgZKS3+FBSQGD91B4RAXTtAm9PX0jBIZCCgyEH\nBUMKCoYcEgIpKFjTJc/SJiJtMhohRXZSziNviSxDqChvMq03md7r3gy5px13a/71cOoe0ttbKfSg\nukIPaVDswcH1X6u7LQcHQ/bxbZeLmFjaROTeBAGyfwBs/gHKKw61xGKB7nIRQsQqlJ45D11JMYSS\nYuiuXGl0Wygpga6kGIYzeRCqnXsRBtlobDSty0H1hS4FBSsTfXDj4pcDAq97XZ6lTUQdh8kEKaoz\nEOYHa9dezt3HbFYKvbgYuuIrSrHbbxcX15e9/eOfLsBwKseph5Z1OsiBgZCCQ4AG/wtoCUubiKgl\nHh7KEk1kJ9icvY8oQigtVQq9boq/ZvHX3XYWS5uIyNUMBsihobCFhgJOvkxkmJMP3fFOciQicmMs\nbSIiN8LSJiJyIyxtIiI30uoPIi0WCx588EFYrVbYbDZMmjQJixcvbo9sRETURKulbTKZsGHDBnh5\necFms+H+++/HqFGjkJiY2B75iIioAaeWR7y8vAAoU7coim0aiIiIrs2p0pYkCSkpKUhOTkZycjKn\nbCIilTh1cY1Op8O2bdtQWVmJX/3qV8jLy0NsbAs7dHXvjmDp6i0Vi49nN3t48KD+zX7epcfrhKsy\nqZoHuCqT6nmaZNJEngaZNJPH7tyPmsrD42+N41tzXVdE+vr6YsiQITh48GDLpQ1Ar7t6t6trvkR8\nM8e2xfFNM6mdp2kmLeRpmEkreeyZtJSnxfuolMd+/FX3UznPVffVQJ5GH2skj7MEWW5hl3EAxcXF\nMBqN8PPzQ21tLRYuXIhFixZh9OjRLT5wUYNNz7UgLMyPmZzATM7TYi5mco5WMzmj1Um7qKgIzzzz\nDCRJgiRJmDp1aquFTUREbaPV0u7duze2bt3aHlmIiKgVvCKSiMiNsLSJiNwIS5uIyI2wtImI3AhL\nm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uI\nyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiN\nsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0\niYjciKG1AwoLC5GamorLly9Dr9dj9uzZmDdvXntkIyKiJlotbb1ej2effRbx8fGoqqrC3XffjeTk\nZMTExLRHPiIiaqDV5ZGwsDDEx8cDAHx8fBATE4NLly61eTAiIrrada1p5+fn47vvvkNiYmJb5SEi\noha0ujxiV1VVhaVLl2LlypXw8fFp8dju3QFJuvqY48ermj1+0KDmH8+Vx+t0V2dSMw+AqzKpnadp\nJi3kaZhJK3nszp1r9tOq5eHxt8bxrXGqtEVRxNKlS3HXXXdhwoQJTj2wTnf1EB8W5neNY5t/DFcf\n3zST2nmaZtJCnoaZtJLHnklLeVq6j1p57Mc3vZ/aeZre1kKehh9rJY+zBFmW5dYOSk1NRVBQEJ59\n9lmnH7ioqOKGArWVsDA/ZnICMzlPi7mYyTlazeSMVte0jx8/ju3bt+Obb75BSkoKZs6ciQMHDtx0\nQCIiun6tLo8MGjQIp06dao8sRETUCl4RSUTkRljaRERuhKVNRORGWNpERG6EpU1E5EacviKSiIiu\nnyQBZWVASYmA4uL6t5ISwfG5khIBn37q3OOxtImInGSxoFHRNizgpkWsfAyUlgqQJMFlGVjaRNTh\nyDJQWYmrCvdaU7D981VVzpWvXi8jKEhGaKiMuDgJQUEygoOVt6Ag1L2XHe+DgmQAvk49NkubiG4Z\nNTVAUZGAixcFXLqkw6VLyu2iIuVj5fMCLl8GLBbnLhv39lZKtUcPqVHR1pdw4/INCZHh5wcIrhuu\nG2FpE5GmSZIyEV+6JDhK2F7IDd8uXtShvLzlpvTwkBERIWPgQMDfX2yxfO23vbza6Yk6iaVNRKqo\nqUGjwm1cwvVTcVGRAFFsuYxDQiR07izhtttkhIfLiIiQEB5uvy3X3Zbg769MwMqGUTXt9Exdi6VN\nRC4likBhoYD8fB3OnxdQUQGcPevRYEpWStn5qVhCeLjUqIAblnJYmAyjsZ2enAawtInoutTUABcu\nCDh/Xof8fB3y8+23laK+cEGAzda0kE2OW9eaiusnYuVzbbku7M5Y2kTUSFkZGpVw49sCLl9u/po8\nQZARGSnjZz+T0KWL/U1GfLwnPD2rEBGhnE3RkabitsDSJupAZFlZR25Ywsq0XH+7oqL58dZolNG5\ns4z4eBFdusjo0kVCdLTkuB0VJcNkuvp+YWGeKCqS2viZdRwsbaJbiNUKnDvXtJDrlzIKCgSYzc2X\nso+P3KiEu3SxfywhOlpZtmjppdeofbC0idyMLAM//SQgL0+H3FwdzpzRIS9PeV9QAEhS8xdphIZK\niI9X1pPrC7m+mAMDuYbsDljaRBpVXQ2cOaOUccNyzsvTobr66naNiJCQlARERFgbTczR0TI6d5bg\n7a3CkyCXY2kTqcg+Nefm1heyfWrOz796LcLTU0bPnhJiYxu/xcQoZ1so5x/XqvBMqL2wtInagX1q\nbljK9um5uak5MlLCyJEiYmIal3OXLlxX7uhY2kQuIsvK+csNJ2b7W0HBtafmuDjJUc72277O7R1E\nHRBLm+g6WSzA99/rcOkScOKEqdH03NzU3KmTMjU3XMqIi5PQuTOnZrp+LG2iFtTUADk5OmRm6pGV\npbw/dUoHq9Vezh4AAC+va681c2omV2JpE9WprASys/XIzKwv6dOndY0uyfbwkJGQIKF/fxsGDzYh\nIqIasbGcmqn9sLSpQyopAbKylIJW3utx5kzj1vX2ljF4sA2JiRISEpT3cXGS4zLssDATiopsKqSn\njoylTbe8S5cEx9KGvaTPnWtc0AEBMkaOFJGQICEx0YbERBt69uT0TNrD0qZbhv3sjYblnJmpQ2Fh\n4+YNDZUwbpyIxESbo6S7dpV5NSC5BZY2uSVZBn74QXAUs30N+sqVxgUdFSVh8mRrgwlaQmQkC5rc\nF0ubNM9mA06f1jUq56ws/VWb6HfrJiEpyepYg05IkBAWJquUmqhtsLRJc8xmID1dj6+/1iMtTY/j\nx4Hqah/H1wVBeYXrCRPqp+f+/W0IDFQxNFE7YWmT6mprgRMnlIJOS9Pj2DE9amvrp+h+/YCEBKtj\nDbpfPxvPfaYOi6VN7a6mBjh+vL6kjx/XO/Z4FgQZfftKSE62YfhwG4YPF9G7NzdBIrJjaVObq64G\njh2rL+lQSYIFAAANpklEQVQTJ/SwWOpLun9/CUlJNiQl2TBsmIigIJUDE2kYS5tcrqrq6pK2X/at\n09WXdHKyiKFDuRZNdD1Y2nTTKiuBo0ftJW1AeroOolhf0omJ9klaKemAAJUDE7kxljZdt8pK4MgR\n+9kdBmRk1Je0Xi9jwAAJw4crk/SQITb4+6scmOgW0mppr1y5Evv27UNISAi2b9/eHplIYyoqgMOH\n6yfpjIz6TZT0ehkDB0pIShKRnGzDkCE8s4OoLbVa2nfffTfmzp2L1NTU9shDGlBeDnzzjVLQaWnK\nFYeSpJS0wSDjttskJCeLGD6cJU3U3lot7cGDB6OgoKA9spBKZBk4eVKHXbsMOHgQSE/3dZS00ajs\ndGc/u+P2223w8WnlAYmozXBNu4OyWoFDh/TYtcuAXbsMuHBB2bPDaARuv92G5GSlpAcPtvFVvIk0\npM1KOyzMr60e+oZ19EzV1cAXXwBbtwLbtyt7SgNAYCAwdy6QkgJMmgT4+BigtX/Ptfh7B2gzFzM5\nR4uZnNFmfzOLiira6qFvSFiYX4fMVFICfPGFATt3GrBvnwE1NcqyR2SkhAULREydKiIpyebY2N/H\np2P+Ot0ILeZiJudoNZMznCptWeZOae7kwgUBu3YpRZ2Wpnec6REba8PUqUpRDxwocYN/IjfUammv\nWLEChw8fRmlpKcaMGYMlS5Zg1qxZ7ZGNrkNurg47dypFnZ6ud3z+ttuUop4yRUSvXpKKCYnIFVot\n7bVr17ZHDrpOkqSc8WEv6rw8paj1euVls+xFHRXF/yUR3Uq09dMmapHVCqSl1Z/x8dNPyvqGl5eM\nKVOsmDpVxB13cMMlolsZS1vjqquBvXuVaXr3bgNKS5X16cBAGffeqxT1mDEiT8sj6iBY2hpUUgJ8\n/rlS1Pv315/xERUlYdYspaiHDas/44OIOg6WtkYUFAiOZY+GZ3z06lX/g8SBAyW+IC1RB8fSVtGp\nU8C775qwc6cBJ0/Wn/Hxs5/ZT82zIjaWP0gkonos7XZWWgp89JER775rxKlTAOABg0HGqFH1Z3x0\n6sSiJqLmsbTbgSwDR4/qsGGDCZ9+akBtrQCjUcbMmcCECTWYOFHkq7cQkVNY2m2orEyZqjduNOLU\nKWX5o0cPCXPnmjFnjoi+fX1RVCSqnJKI3AlL28VkGTh2TIeNG0345BPlzA+jUcZdd1kxd64VI0bY\nePk4Ed0wlraLlJUBmzcbsWFD/VTdrZuEuXMtuP9+K8LCuE5NRDePpX0TZBk4cUJZq962TZmqDQYZ\nM2YoU/XIkZyqici1WNo3oLy8fqrOyWk8Vd93nxXh4ZyqiahtsLSdJMtAeroOGzYYsW2bEdXVylQ9\nfboV8+ZZMWoUp2oianss7VZUVChT9caNRmRnK1N11671U3VEBKdqImo/LO1m2F/oduNGI7ZsUaZq\nvV7GtGnKVD16NKdqIlIHS7uBysr6qTorq36q/vnPlTNAOFUTkdpY2gAyMpS16o8/rp+qp05Vpuox\nYzhVE5F2dNjSrqwEtmxRzgDJzFSm6i5dJCxdasEDD1gRGcmpmoi0p8OVdmamDu+8o6xVV1UpU/Xk\nyVbMn69M1Xp9649BRKSWDlHalZXAtm3A3//u7dgCtXNnCYsXK1M1d9UjIndxS5d2eTnwxhsmvP66\nCeXlgE6nw+TJylr12LGcqonI/dySpV1ZCbz1lgl/+5sJpaUCQkIkrF4tICWliq9OTkRu7ZYq7epq\nYN06I/72NxOuXNEhMFDGc8+ZsXChBT16+KGoiIVNRO7tlijt2lpgwwYj/vxnE4qKdPD3l5Gaasai\nRRb4+6udjojIddy6tM1m4N13lbIuLNTBx0fG8uVmPPaYha8EQ0S3JLcsbasVeP99I/70JxMKCnTw\n9paxZIkZv/qVFSEhXAIholuXW5W2KAIffWTA2rUeOHdOB09PGY89ZsGSJRa+yAARdQhuUdo2G7Bl\niwF//KMHzp7VwWSS8cgjFixbZuF+IETUoWi6tCUJ+PRTA1591YTcXD2MRhkPPWTB449beOoeEXVI\nmixtSQJ27lTK+tQpPfR6GT//uVLWXbuyrImo49JUacsy8MUXerz8sgeys/XQ6WTMmWPF8uVm9OjB\nsiYi0kRpyzKwd69S1unpegiCjLvvtuLJJ82IjWVZExHZqVrasgwcPKiU9dGjykYgM2ZY8eSTFvTp\nI6kZjYhIk1Qr7UOH9HjpJRMOHVIiTJlixVNPWdC/P8uaiOha2r20jx7V4aWXPHDwoPKtJ04UkZpq\nxoABLGsiotY49UJaBw4cwOTJkzFp0iS88cYbN/SNTpzQ4b77vDBtmg8OHjRgzBgRu3ZV4b33aljY\nREROanXSliQJL7zwAtavX4/w8HDcc889GD9+PGJiYpz6BllZOrzyigc+/1z5ViNGiEhNtWDYMNvN\nJSci6oBaLe3MzEx069YNnTt3BgBMmzYNe/bsabW0c3J0ePVVE3bsMAIAhg4V8fTTFowYwbImIrpR\nrZb2xYsX0alTJ8fHERERyMrKavE+990HfPihN2RZwKBBNjz9tBmjR9sgCDcfmIioI2u1tGX5+s+T\n3rQJGDBAwtNPmzF+PMuaiMhVWi3tyMhIXLhwwfHxxYsXER4e3uJ9lJ7XA/C+yXiuFRbmp3aEqzCT\nc7SYCdBmLmZyjhYzOaPVs0cSEhJw7tw5FBQUwGKxYMeOHRg/fnx7ZCMioiZanbT1ej1WrVqFhx9+\nGLIs45577nH6zBEiInItQb6RRWsiIlKFUxfXEBGRNrC0iYjcCEubiMiNuHTDqAMHDmDNmjWQZRmz\nZs3CokWLXPnwN2TlypXYt28fQkJCsH37drXjAAAKCwuRmpqKy5cvQ6/XY/bs2Zg3b56qmSwWCx58\n8EFYrVbYbDZMmjQJixcvVjWTnSRJmDVrFiIiIvD666+rHQfjxo2Dr68vdDodDAYDNm/erHYkVFRU\n4LnnnkNubi50Oh3WrFmDAQMGqJrp7NmzeOKJJyAIAmRZxvnz57Fs2TLV/6yvX78emzdvhiAI6NWr\nF1588UWYTCZVM73zzjuOP0et9oHsIjabTZ4wYYKcn58vWywWecaMGXJeXp6rHv6GHT16VM7JyZGn\nT5+udhSHS5cuyTk5ObIsy3JlZaV8xx13aOLXqrq6WpZlWRZFUZ49e7ackZGhciLF22+/La9YsUJ+\n9NFH1Y4iy7Isjxs3Ti4tLVU7RiNPP/20vHnzZlmWZdlqtcoVFRUqJ2rMZrPJycnJ8oULF1TNUVhY\nKI8bN042m82yLMvysmXL5K1bt6qa6fTp0/L06dNls9ksi6IoP/TQQ/KPP/54zeNdtjzScI8So9Ho\n2KNEbYMHD4a/v7/aMRoJCwtDfHw8AMDHxwcxMTG4dOmSyqkALy8vAMrULYqiymkUhYWF2L9/P2bP\nnq12FAdZliFJ2tmZsrKyEseOHcOsWbMAAAaDAb6+viqnaiwtLQ1du3ZttCWGWiRJQk1NDURRRG1t\nbasXC7a1M2fOYODAgTCZTNDr9bj99tuxe/fuax7vstJubo8SLRSR1uXn5+O7775DYmKi2lEgSRJS\nUlKQnJyM5ORkTWRas2YNUlNTIWhoLwRBELBw4ULMmjULH374odpxkJ+fj6CgIDz77LOYOXMmVq1a\nhdraWrVjNbJz505MmzZN7RiIiIjAggULMGbMGIwaNQp+fn5ISkpSNVNcXByOHj2KsrIy1NTU4MCB\nA/jpp5+uebzLSlvm6d7XraqqCkuXLsXKlSvh4+OjdhzodDps27YNBw4cQEZGBvLy8lTNs2/fPoSG\nhiI+Pl5Tf74++OADbNmyBW+++Sbee+89HDt2TNU8oigiJycHDzzwALZu3QpPT88b3ve+LVitVnz5\n5ZeYMmWK2lFQXl6OPXv2YO/evTh48CCqq6tV/1lXTEwMfvGLX2DBggVYtGgR+vTpA4Ph2j9udFlp\n38geJR2ZKIpYunQp7rrrLkyYMEHtOI34+vpiyJAhOHjwoKo5Tpw4gS+//BLjx4/HihUrcPjwYaSm\npqqaCVCWtwAgODgYEydObHXXy7YWGRmJyMhIJCQkAAAmTZqEnJwcVTM1dODAAfTr1w/BwcFqR0Fa\nWhqio6MRGBgIvV6PiRMnIj09Xe1YmDVrFrZs2YKNGzciICAA3bp1u+axLittLe9RoqUpzW7lypWI\njY3F/Pnz1Y4CACguLkZFRQUAoLa2FocOHULPnj1VzbR8+XLs27cPe/bswWuvvYahQ4filVdeUTVT\nTU0NqqqqAADV1dX46quvEBcXp2qm0NBQdOrUCWfPngUAfPPNN5raamLHjh2YPn262jEAAFFRUcjI\nyIDZbIYsy5r5tSouLgYAXLhwAbt3727x18tlp/xpdY8S+4RWWlqKMWPGYMmSJY4f2Kjl+PHj2L59\nO3r16oWUlBQIgoAnnngCo0aNUi1TUVERnnnmGUiSBEmSMHXqVIwePVq1PFp1+fJlLF68GIIgwGaz\n4c4778SIESPUjoXnn38eTz75JERRRHR0NF588UW1IwFQBoC0tDT87ne/UzsKACAxMRGTJk1CSkoK\nDAYD+vbti3vvvVftWFiyZAnKyspgMBjwm9/8Bn5+196BkHuPEBG5EV4RSUTkRljaRERuhKVNRORG\nWNpERG6EpU1E5EZY2kREboSlTUTkRljaRERu5P8D+7Wym3BFpegAAAAASUVORK5CYII=\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f5be4b8ec50\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "model = Model()\n", + "\n", + "# Collect the history of W-values and b-values to plot later\n", + "Ws, bs = [], []\n", + "epochs = range(10)\n", + "for epoch in epochs:\n", + " Ws.append(model.W.numpy())\n", + " bs.append(model.b.numpy())\n", + " current_loss = loss(model(inputs), outputs)\n", + "\n", + " train(model, inputs, outputs, learning_rate=0.1)\n", + " print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %\n", + " (epoch, Ws[-1], bs[-1], current_loss))\n", + "\n", + "# Let's plot it all\n", + "plt.plot(epochs, Ws, 'r',\n", + " epochs, bs, 'b')\n", + "plt.plot([TRUE_W] * len(epochs), 'r--',\n", + " [TRUE_b] * len(epochs), 'b--')\n", + "plt.legend(['W', 'b', 'true W', 'true_b'])\n", + "plt.show()\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "vPnIVuaSJwWz" + }, + "source": [ + "## Next Steps\n", + "\n", + "In this tutorial we covered `Variable`s and built and trained a simple linear model using the TensorFlow primitives discussed so far.\n", + "\n", + "In theory, this is pretty much all you need to use TensorFlow for your machine learning research.\n", + "In practice, particularly for neural networks, the higher level APIs like `tf.keras` will be much more convenient since it provides higher level building blocks (called \"layers\"), utilities to save and restore state, a suite of loss functions, a suite of optimization strategies etc. \n", + "\n", + "The [next tutorial](TODO) will cover these higher level APIs." + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "Training Models", + "provenance": [], + "version": "0.3.2", + "views": {} + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..5749f22ac58e0a012ed7e3fec4dfe2913d3f8273 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb @@ -0,0 +1,551 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "pwX7Fii1rwsJ" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "tfe = tf.contrib.eager\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "UEu3q4jmpKVT" + }, + "source": [ + "# High level API\n", + "\n", + "We recommend using `tf.keras` as a high-level API for building neural networks. That said, most TensorFlow APIs are usable with eager execution.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "zSFfVVjkrrsI" + }, + "source": [ + "## Layers: common sets of useful operations\n", + "\n", + "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n", + "\n", + "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n", + "\n", + "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "8PyXlPl-4TzQ" + }, + "outputs": [], + "source": [ + "# In the tf.keras.layers package, layers are objects. To construct a layer,\n", + "# simply construct the object. Most layers take as a first argument the number\n", + "# of output dimensions / channels.\n", + "layer = tf.keras.layers.Dense(100)\n", + "# The number of input dimensions is often unnecessary, as it can be inferred\n", + "# the first time the layer is used, but it can be provided if you want to \n", + "# specify it manually, which is useful in some complex models.\n", + "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Fn69xxPO5Psr" + }, + "source": [ + "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n", + "Conv2D, LSTM, BatchNormalization, Dropout, and many others." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 204 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 244, + "status": "ok", + "timestamp": 1527783641557, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "E3XKNknP5Mhb", + "outputId": "c5d52434-d980-4488-efa7-5660819d0207" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\u003ctf.Tensor: id=30, shape=(10, 10), dtype=float32, numpy=\n", + "array([[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)\u003e" + ] + }, + "execution_count": 3, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# To use a layer, simply call it.\n", + "layer(tf.zeros([10, 5]))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 221 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 320, + "status": "ok", + "timestamp": 1527783642457, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "Wt_Nsv-L5t2s", + "outputId": "f0d96dce-0128-4080-bfe2-0ee6fbc0ad90" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[\u003ctf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", + " array([[ 0.43788117, -0.62099844, -0.30525017, -0.59352523, 0.1783089 ,\n", + " 0.47078604, -0.23620895, -0.30482283, 0.01366901, -0.1288507 ],\n", + " [ 0.18407935, -0.56550485, 0.54180616, -0.42254075, 0.3702994 ,\n", + " 0.36705834, -0.29678228, 0.36660975, 0.36717761, 0.46269661],\n", + " [ 0.1709305 , -0.11529458, 0.32710236, 0.46300393, -0.62802851,\n", + " 0.51641601, 0.39624029, 0.26918125, -0.25196898, 0.21353298],\n", + " [ 0.35752094, 0.44161648, 0.61500639, -0.12653333, 0.41629118,\n", + " 0.36193585, 0.066082 , -0.59253877, 0.47318751, 0.17115968],\n", + " [-0.22554061, -0.17727301, 0.5525015 , 0.3678053 , -0.00454676,\n", + " 0.24066836, -0.53640735, 0.13792562, -0.10727292, 0.59708995]], dtype=float32)\u003e,\n", + " \u003ctf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\u003e]" + ] + }, + "execution_count": 4, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# Layers have many useful methods. For example, you can inspect all variables\n", + "# in a layer by calling layer.variables. In this case a fully-connected layer\n", + "# will have variables for weights and biases.\n", + "layer.variables" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 221 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 226, + "status": "ok", + "timestamp": 1527783643252, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "6ilvKjz8_4MQ", + "outputId": "f647fced-c2d7-41a3-c237-242036784665" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(\u003ctf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", + " array([[ 0.43788117, -0.62099844, -0.30525017, -0.59352523, 0.1783089 ,\n", + " 0.47078604, -0.23620895, -0.30482283, 0.01366901, -0.1288507 ],\n", + " [ 0.18407935, -0.56550485, 0.54180616, -0.42254075, 0.3702994 ,\n", + " 0.36705834, -0.29678228, 0.36660975, 0.36717761, 0.46269661],\n", + " [ 0.1709305 , -0.11529458, 0.32710236, 0.46300393, -0.62802851,\n", + " 0.51641601, 0.39624029, 0.26918125, -0.25196898, 0.21353298],\n", + " [ 0.35752094, 0.44161648, 0.61500639, -0.12653333, 0.41629118,\n", + " 0.36193585, 0.066082 , -0.59253877, 0.47318751, 0.17115968],\n", + " [-0.22554061, -0.17727301, 0.5525015 , 0.3678053 , -0.00454676,\n", + " 0.24066836, -0.53640735, 0.13792562, -0.10727292, 0.59708995]], dtype=float32)\u003e,\n", + " \u003ctf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\u003e)" + ] + }, + "execution_count": 5, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# The variables are also accessible through nice accessors\n", + "layer.kernel, layer.bias" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "O0kDbE54-5VS" + }, + "source": [ + "## Implementing custom layers\n", + "The best way to implement your own layer is extending the tf.keras.Layer class and implementing:\n", + " * `__init__` , where you can do all input-independent initialization\n", + " * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n", + " * `call`, where you do the forward computation\n", + "\n", + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes required to create the variables will need to be explicitly specified." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 391 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 251, + "status": "ok", + "timestamp": 1527783661512, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "5Byl3n1k5kIy", + "outputId": "6e7f9285-649a-4132-82ce-73ea92f15862" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(\n", + "[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(10, 10), dtype=float32)\n", + "[\u003ctf.Variable 'my_dense_layer_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", + "array([[-0.4011991 , 0.22458655, -0.33237562, -0.25117266, 0.33528614,\n", + " -0.01392961, 0.58580834, -0.16346583, 0.28465688, -0.47191954],\n", + " [-0.52922136, 0.22416979, -0.58209574, -0.60914612, 0.05226624,\n", + " -0.18325993, 0.5591442 , -0.24718609, 0.37148207, 0.40475875],\n", + " [ 0.16912812, -0.47618777, -0.38989353, 0.30105609, -0.08085585,\n", + " 0.44758242, 0.545829 , 0.51421839, 0.11063248, 0.20159996],\n", + " [ 0.34073615, -0.59835428, 0.06498981, -0.44489855, -0.34302285,\n", + " 0.20969599, 0.35527444, -0.03173476, -0.22227573, 0.09303057],\n", + " [ 0.41764337, -0.06435019, -0.52509922, -0.39957345, 0.56811184,\n", + " 0.23481232, -0.61666459, 0.31144124, -0.11532354, -0.42421889]], dtype=float32)\u003e]\n" + ] + } + ], + "source": [ + "class MyDenseLayer(tf.keras.layers.Layer):\n", + " def __init__(self, num_outputs):\n", + " super(MyDenseLayer, self).__init__()\n", + " self.num_outputs = num_outputs\n", + " \n", + " def build(self, input_shape):\n", + " self.kernel = self.add_variable(\"kernel\", \n", + " shape=[input_shape[-1].value, \n", + " self.num_outputs])\n", + " \n", + " def call(self, input):\n", + " return tf.matmul(input, self.kernel)\n", + " \n", + "layer = MyDenseLayer(10)\n", + "print(layer(tf.zeros([10, 5])))\n", + "print(layer.variables)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "tk8E2vY0-z4Z" + }, + "source": [ + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`.\n", + "\n", + "Overall code is easier to read and maintain if it uses standard layers whenever possible, as other readers will be familiar with the behavior of standard layers. If you want to use a layer which is not present in tf.keras.layers or tf.contrib.layers, consider filing a [github issue](http://github.com/tensorflow/tensorflow/issues/new) or, even better, sending us a pull request!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Qhg4KlbKrs3G" + }, + "source": [ + "## Models: composing layers\n", + "\n", + "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n", + "\n", + "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 190 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 420, + "status": "ok", + "timestamp": 1527783698512, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "N30DTXiRASlb", + "outputId": "a8b23a8e-5cf9-4bbf-f93b-6c763d74e2b3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(\n", + "[[[[ 0. 0. 0.]\n", + " [ 0. 0. 0.]\n", + " [ 0. 0. 0.]]\n", + "\n", + " [[ 0. 0. 0.]\n", + " [ 0. 0. 0.]\n", + " [ 0. 0. 0.]]]], shape=(1, 2, 3, 3), dtype=float32)\n", + "['resnet_identity_block_1/conv2d_3/kernel:0', 'resnet_identity_block_1/conv2d_3/bias:0', 'resnet_identity_block_1/batch_normalization_3/gamma:0', 'resnet_identity_block_1/batch_normalization_3/beta:0', 'resnet_identity_block_1/conv2d_4/kernel:0', 'resnet_identity_block_1/conv2d_4/bias:0', 'resnet_identity_block_1/batch_normalization_4/gamma:0', 'resnet_identity_block_1/batch_normalization_4/beta:0', 'resnet_identity_block_1/conv2d_5/kernel:0', 'resnet_identity_block_1/conv2d_5/bias:0', 'resnet_identity_block_1/batch_normalization_5/gamma:0', 'resnet_identity_block_1/batch_normalization_5/beta:0', 'resnet_identity_block_1/batch_normalization_3/moving_mean:0', 'resnet_identity_block_1/batch_normalization_3/moving_variance:0', 'resnet_identity_block_1/batch_normalization_4/moving_mean:0', 'resnet_identity_block_1/batch_normalization_4/moving_variance:0', 'resnet_identity_block_1/batch_normalization_5/moving_mean:0', 'resnet_identity_block_1/batch_normalization_5/moving_variance:0']\n" + ] + } + ], + "source": [ + "class ResnetIdentityBlock(tf.keras.Model):\n", + " def __init__(self, kernel_size, filters):\n", + " super(ResnetIdentityBlock, self).__init__(name='')\n", + " filters1, filters2, filters3 = filters\n", + "\n", + " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n", + " self.bn2a = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n", + " self.bn2b = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n", + " self.bn2c = tf.keras.layers.BatchNormalization()\n", + "\n", + " def call(self, input_tensor, training=False):\n", + " x = self.conv2a(input_tensor)\n", + " x = self.bn2a(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2b(x)\n", + " x = self.bn2b(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2c(x)\n", + " x = self.bn2c(x, training=training)\n", + "\n", + " x += input_tensor\n", + " return tf.nn.relu(x)\n", + "\n", + " \n", + "block = ResnetIdentityBlock(1, [1, 2, 3])\n", + "print(block(tf.zeros([1, 2, 3, 3])))\n", + "print([x.name for x in block.variables])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "wYfucVw65PMj" + }, + "source": [ + "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "base_uri": "https://localhost:8080/", + "height": 153 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 361, + "status": "ok", + "timestamp": 1526674830777, + "user": { + "displayName": "Alexandre Passos", + "photoUrl": "//lh4.googleusercontent.com/-kmTTWXEgAPw/AAAAAAAAAAI/AAAAAAAAAC0/q_DoOzKGwds/s50-c-k-no/photo.jpg", + "userId": "108023195365833072773" + }, + "user_tz": 420 + }, + "id": "L9frk7Ur4uvJ", + "outputId": "882e9076-b6d9-4380-bb1e-7c6b57d54c39" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\u003ctf.Tensor: id=1423, shape=(1, 2, 3, 3), dtype=float32, numpy=\n", + "array([[[[0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.]],\n", + "\n", + " [[0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.]]]], dtype=float32)\u003e" + ] + }, + "execution_count": 26, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(2, 1, \n", + " padding='same'),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(3, (1, 1)),\n", + " tf.keras.layers.BatchNormalization()])\n", + "my_seq(tf.zeros([1, 2, 3, 3]))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "c5YwYcnuK-wc" + }, + "source": [ + "# Next steps\n", + "\n", + "Now you can go back to the previous notebook and adapt the linear regression example to use layers and models to be better structured." + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "4 - High level API - TensorFlow Eager.ipynb", + "provenance": [], + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/resnet50/BUILD b/tensorflow/contrib/eager/python/examples/resnet50/BUILD index 536cad998d94e45187d30fce3be0d7a57178e0c1..0c0e28dd95c68dc300384a128eb5aa2208f63a0d 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/BUILD +++ b/tensorflow/contrib/eager/python/examples/resnet50/BUILD @@ -14,6 +14,17 @@ py_library( ], ) +py_library( + name = "resnet50_test_lib", + srcs = ["resnet50_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":resnet50", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + ], +) + cuda_py_test( name = "resnet50_test", size = "large", diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py index 9982fdb07eefa665379e7be095f4f8017d92cf97..a28bc8a43d7c90737c9baf9a634d736e9de52948 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py @@ -27,10 +27,11 @@ from __future__ import print_function import functools import tensorflow as tf -import tensorflow.contrib.eager as tfe +layers = tf.keras.layers -class _IdentityBlock(tfe.Network): + +class _IdentityBlock(tf.keras.Model): """_IdentityBlock is the block that has no conv layer at shortcut. Args: @@ -50,31 +51,24 @@ class _IdentityBlock(tfe.Network): bn_name_base = 'bn' + str(stage) + block + '_branch' bn_axis = 1 if data_format == 'channels_first' else 3 - self.conv2a = self.track_layer( - tf.layers.Conv2D( - filters1, (1, 1), - name=conv_name_base + '2a', - data_format=data_format)) - self.bn2a = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')) - - self.conv2b = self.track_layer( - tf.layers.Conv2D( - filters2, - kernel_size, - padding='same', - data_format=data_format, - name=conv_name_base + '2b')) - self.bn2b = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')) - - self.conv2c = self.track_layer( - tf.layers.Conv2D( - filters3, (1, 1), - name=conv_name_base + '2c', - data_format=data_format)) - self.bn2c = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')) + self.conv2a = layers.Conv2D( + filters1, (1, 1), name=conv_name_base + '2a', data_format=data_format) + self.bn2a = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2a') + + self.conv2b = layers.Conv2D( + filters2, + kernel_size, + padding='same', + data_format=data_format, + name=conv_name_base + '2b') + self.bn2b = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2b') + + self.conv2c = layers.Conv2D( + filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format) + self.bn2c = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2c') def call(self, input_tensor, training=False): x = self.conv2a(input_tensor) @@ -92,7 +86,7 @@ class _IdentityBlock(tfe.Network): return tf.nn.relu(x) -class _ConvBlock(tfe.Network): +class _ConvBlock(tf.keras.Model): """_ConvBlock is the block that has a conv layer at shortcut. Args: @@ -121,41 +115,35 @@ class _ConvBlock(tfe.Network): bn_name_base = 'bn' + str(stage) + block + '_branch' bn_axis = 1 if data_format == 'channels_first' else 3 - self.conv2a = self.track_layer( - tf.layers.Conv2D( - filters1, (1, 1), - strides=strides, - name=conv_name_base + '2a', - data_format=data_format)) - self.bn2a = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')) - - self.conv2b = self.track_layer( - tf.layers.Conv2D( - filters2, - kernel_size, - padding='same', - name=conv_name_base + '2b', - data_format=data_format)) - self.bn2b = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')) - - self.conv2c = self.track_layer( - tf.layers.Conv2D( - filters3, (1, 1), - name=conv_name_base + '2c', - data_format=data_format)) - self.bn2c = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')) - - self.conv_shortcut = self.track_layer( - tf.layers.Conv2D( - filters3, (1, 1), - strides=strides, - name=conv_name_base + '1', - data_format=data_format)) - self.bn_shortcut = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '1')) + self.conv2a = layers.Conv2D( + filters1, (1, 1), + strides=strides, + name=conv_name_base + '2a', + data_format=data_format) + self.bn2a = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2a') + + self.conv2b = layers.Conv2D( + filters2, + kernel_size, + padding='same', + name=conv_name_base + '2b', + data_format=data_format) + self.bn2b = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2b') + + self.conv2c = layers.Conv2D( + filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format) + self.bn2c = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2c') + + self.conv_shortcut = layers.Conv2D( + filters3, (1, 1), + strides=strides, + name=conv_name_base + '1', + data_format=data_format) + self.bn_shortcut = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '1') def call(self, input_tensor, training=False): x = self.conv2a(input_tensor) @@ -176,7 +164,8 @@ class _ConvBlock(tfe.Network): return tf.nn.relu(x) -class ResNet50(tfe.Network): +# pylint: disable=not-callable +class ResNet50(tf.keras.Model): """Instantiates the ResNet50 architecture. Args: @@ -220,32 +209,28 @@ class ResNet50(tfe.Network): self.include_top = include_top def conv_block(filters, stage, block, strides=(2, 2)): - l = _ConvBlock( + return _ConvBlock( 3, filters, stage=stage, block=block, data_format=data_format, strides=strides) - return self.track_layer(l) def id_block(filters, stage, block): - l = _IdentityBlock( + return _IdentityBlock( 3, filters, stage=stage, block=block, data_format=data_format) - return self.track_layer(l) - - self.conv1 = self.track_layer( - tf.layers.Conv2D( - 64, (7, 7), - strides=(2, 2), - data_format=data_format, - padding='same', - name='conv1')) + + self.conv1 = layers.Conv2D( + 64, (7, 7), + strides=(2, 2), + data_format=data_format, + padding='same', + name='conv1') bn_axis = 1 if data_format == 'channels_first' else 3 - self.bn_conv1 = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name='bn_conv1')) - self.max_pool = self.track_layer( - tf.layers.MaxPooling2D((3, 3), strides=(2, 2), data_format=data_format)) + self.bn_conv1 = layers.BatchNormalization(axis=bn_axis, name='bn_conv1') + self.max_pool = layers.MaxPooling2D( + (3, 3), strides=(2, 2), data_format=data_format) self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1)) self.l2b = id_block([64, 64, 256], stage=2, block='b') @@ -267,13 +252,12 @@ class ResNet50(tfe.Network): self.l5b = id_block([512, 512, 2048], stage=5, block='b') self.l5c = id_block([512, 512, 2048], stage=5, block='c') - self.avg_pool = self.track_layer( - tf.layers.AveragePooling2D( - (7, 7), strides=(7, 7), data_format=data_format)) + self.avg_pool = layers.AveragePooling2D( + (7, 7), strides=(7, 7), data_format=data_format) if self.include_top: - self.fc1000 = self.track_layer( - tf.layers.Dense(classes, name='fc1000')) + self.flatten = layers.Flatten() + self.fc1000 = layers.Dense(classes, name='fc1000') else: reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3] reduction_indices = tf.constant(reduction_indices) @@ -288,7 +272,7 @@ class ResNet50(tfe.Network): else: self.global_pooling = None - def call(self, input_tensor, training=False): + def call(self, input_tensor, training): x = self.conv1(input_tensor) x = self.bn_conv1(x, training=training) x = tf.nn.relu(x) @@ -317,7 +301,7 @@ class ResNet50(tfe.Network): x = self.avg_pool(x) if self.include_top: - return self.fc1000(tf.layers.flatten(x)) + return self.fc1000(self.flatten(x)) elif self.global_pooling: return self.global_pooling(x) else: diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index 23317886e712323f4b520000e0fd372734fc53a1..551c76b0df71c88919df9cd6d81b4176b23b0ba3 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -55,7 +55,7 @@ class ResNet50GraphTest(tf.test.TestCase): with tf.Graph().as_default(): images = tf.placeholder(tf.float32, image_shape(None)) model = resnet50.ResNet50(data_format()) - predictions = model(images) + predictions = model(images, training=False) init = tf.global_variables_initializer() @@ -114,7 +114,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): with tf.Graph().as_default(): images = tf.placeholder(tf.float32, image_shape(None)) model = resnet50.ResNet50(data_format()) - predictions = model(images) + predictions = model(images, training=False) init = tf.global_variables_initializer() diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index 0ff8746884c288f824f5f22ab4c550370d0e0302..b14ef1df8ff4c660b9b6f2abfd5df6572d10b1e8 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -36,9 +36,7 @@ def device_and_data_format(): 'channels_last') -def random_batch(batch_size): - _, data_format = device_and_data_format() - +def random_batch(batch_size, data_format): shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3) shape = (batch_size,) + shape @@ -51,41 +49,50 @@ def random_batch(batch_size): return images, one_hot -def train_one_step(model, images, labels, optimizer): - - with tfe.GradientTape() as tape: +def compute_gradients(model, images, labels): + with tf.GradientTape() as tape: logits = model(images, training=True) loss = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=labels) tf.contrib.summary.scalar(name='loss', tensor=loss) - grads = tape.gradient(loss, model.variables) - optimizer.apply_gradients(zip(grads, model.variables)) + return tape.gradient(loss, model.variables) + + +def apply_gradients(model, optimizer, gradients): + optimizer.apply_gradients(zip(gradients, model.variables)) class ResNet50Test(tf.test.TestCase): - def _apply(self, defun=False): + def _apply(self, defun=False, execution_mode=None): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) if defun: model.call = tfe.defun(model.call) - with tf.device(device): - images, _ = random_batch(2) - output = model(images) + with tf.device(device), tfe.execution_mode(execution_mode): + images, _ = random_batch(2, data_format) + output = model(images, training=False) + tfe.async_wait() self.assertEqual((2, 1000), output.shape) def test_apply(self): self._apply(defun=False) + def test_apply_async(self): + self._apply(defun=False, execution_mode=tfe.ASYNC) + def test_apply_with_defun(self): self._apply(defun=True) + def test_apply_with_defun_async(self): + self._apply(defun=True, execution_mode=tfe.ASYNC) + def test_apply_no_top(self): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False) with tf.device(device): - images, _ = random_batch(2) - output = model(images) + images, _ = random_batch(2, data_format) + output = model(images, training=False) output_shape = ((2, 2048, 1, 1) if data_format == 'channels_first' else (2, 1, 1, 2048)) self.assertEqual(output_shape, output.shape) @@ -94,11 +101,11 @@ class ResNet50Test(tf.test.TestCase): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False, pooling='avg') with tf.device(device): - images, _ = random_batch(2) - output = model(images) + images, _ = random_batch(2, data_format) + output = model(images, training=False) self.assertEqual((2, 2048), output.shape) - def test_train(self): + def _test_train(self, execution_mode=None): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) tf.train.get_or_create_global_step() @@ -106,34 +113,44 @@ class ResNet50Test(tf.test.TestCase): with tf.contrib.summary.create_file_writer( logdir, max_queue=0, name='t0').as_default(), tf.contrib.summary.always_record_summaries(): - with tf.device(device): + with tf.device(device), tfe.execution_mode(execution_mode): optimizer = tf.train.GradientDescentOptimizer(0.1) - images, labels = random_batch(2) - train_one_step(model, images, labels, optimizer) + images, labels = random_batch(2, data_format) + apply_gradients(model, optimizer, + compute_gradients(model, images, labels)) self.assertEqual(320, len(model.variables)) + tfe.async_wait() events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'loss') + def test_train(self): + self._test_train() + + def test_train_async(self): + self._test_train(execution_mode=tfe.ASYNC) + def test_no_garbage(self): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) optimizer = tf.train.GradientDescentOptimizer(0.1) with tf.device(device): - images, labels = random_batch(2) + images, labels = random_batch(2, data_format) gc.disable() # Warm up. Note that this first run does create significant amounts of # garbage to be collected. The hope is that this is a build-only effect, # and a subsequent training loop will create nothing which needs to be # collected. - train_one_step(model, images, labels, optimizer) + apply_gradients(model, optimizer, + compute_gradients(model, images, labels)) gc.collect() previous_gc_debug_flags = gc.get_debug() gc.set_debug(gc.DEBUG_SAVEALL) for _ in range(2): # Run twice to ensure that garbage that is created on the first # iteration is no longer accessible. - train_one_step(model, images, labels, optimizer) + apply_gradients(model, optimizer, + compute_gradients(model, images, labels)) gc.collect() # There should be no garbage requiring collection. self.assertEqual(0, len(gc.garbage)) @@ -155,7 +172,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): def _train_batch_sizes(self): """Choose batch sizes based on GPU capability.""" for device in device_lib.list_local_devices(): - if 'GPU:0' in device.name: + if tf.DeviceSpec.from_string(device.name).device_type == 'GPU': # Avoid OOM errors with larger batch sizes, which seem to cause errors # later on even if caught. # @@ -166,79 +183,120 @@ class ResNet50Benchmarks(tf.test.Benchmark): return (16,) if 'P100' in device.physical_device_desc: return (16, 32, 64) + + if tf.DeviceSpec.from_string(device.name).device_type == 'TPU': + return (32,) return (16, 32) def _report(self, label, start, num_iters, device, batch_size, data_format): avg_time = (time.time() - start) / num_iters - dev = 'cpu' if 'cpu' in device else 'gpu' + dev = tf.DeviceSpec.from_string(device).device_type.lower() name = '%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format) extras = {'examples_per_sec': batch_size / avg_time} self.report_benchmark( iters=num_iters, wall_time=avg_time, name=name, extras=extras) - def _force_gpu_sync(self): - # If this function is called in the context of a GPU device + def _force_device_sync(self): + # If this function is called in the context of a non-CPU device # (e.g., inside a 'with tf.device("/gpu:0")' block) - # then this will force a copy from CPU->GPU->CPU, which forces - # a sync. This is a roundabout way, yes. + # then this will force a copy from CPU->NON_CPU_DEVICE->CPU, + # which forces a sync. This is a roundabout way, yes. tf.constant(1.).cpu() - def _benchmark_eager_apply(self, label, defun=False): - device, data_format = device_and_data_format() - model = resnet50.ResNet50(data_format) - if defun: - model.call = tfe.defun(model.call) - batch_size = 64 - num_burn = 5 - num_iters = 30 - with tf.device(device): - images, _ = random_batch(batch_size) - for _ in xrange(num_burn): - model(images).cpu() - gc.collect() - start = time.time() - for _ in xrange(num_iters): - model(images).cpu() - self._report(label, start, num_iters, device, batch_size, data_format) - - def benchmark_eager_apply(self): - self._benchmark_eager_apply('eager_apply', defun=False) - - def benchmark_eager_apply_with_defun(self): - self._benchmark_eager_apply('eager_apply_with_defun', defun=True) - - def _benchmark_eager_train(self, label, make_iterator, defun=False): - device, data_format = device_and_data_format() - for batch_size in self._train_batch_sizes(): - (images, labels) = random_batch(batch_size) - num_burn = 3 - num_iters = 10 + def _benchmark_eager_apply(self, label, device_and_format, defun=False, + execution_mode=None, compiled=False): + with tfe.execution_mode(execution_mode): + device, data_format = device_and_format model = resnet50.ResNet50(data_format) if defun: - model.call = tfe.defun(model.call) - optimizer = tf.train.GradientDescentOptimizer(0.1) - + model.call = tfe.defun(model.call, compiled=compiled) + batch_size = 64 + num_burn = 5 + num_iters = 30 with tf.device(device): - iterator = make_iterator((images, labels)) + images, _ = random_batch(batch_size, data_format) for _ in xrange(num_burn): - (images, labels) = iterator.next() - train_one_step(model, images, labels, optimizer) - self._force_gpu_sync() + model(images, training=False).cpu() + if execution_mode: + tfe.async_wait() gc.collect() - start = time.time() for _ in xrange(num_iters): - (images, labels) = iterator.next() - train_one_step(model, images, labels, optimizer) - self._force_gpu_sync() + model(images, training=False).cpu() + if execution_mode: + tfe.async_wait() self._report(label, start, num_iters, device, batch_size, data_format) - def benchmark_eager_train(self): - self._benchmark_eager_train('eager_train', MockIterator, defun=False) + def benchmark_eager_apply_sync(self): + self._benchmark_eager_apply('eager_apply', device_and_data_format(), + defun=False) + + def benchmark_eager_apply_async(self): + self._benchmark_eager_apply( + 'eager_apply_async', device_and_data_format(), defun=False, + execution_mode=tfe.ASYNC) + + def benchmark_eager_apply_with_defun(self): + self._benchmark_eager_apply('eager_apply_with_defun', + device_and_data_format(), defun=True) + + def _benchmark_eager_train(self, + label, + make_iterator, + device_and_format, + defun=False, + execution_mode=None, + compiled=False): + with tfe.execution_mode(execution_mode): + device, data_format = device_and_format + for batch_size in self._train_batch_sizes(): + (images, labels) = random_batch(batch_size, data_format) + model = resnet50.ResNet50(data_format) + optimizer = tf.train.GradientDescentOptimizer(0.1) + apply_grads = apply_gradients + if defun: + model.call = tfe.defun(model.call, compiled=compiled) + apply_grads = tfe.defun(apply_gradients, compiled=compiled) + + num_burn = 3 + num_iters = 10 + with tf.device(device): + iterator = make_iterator((images, labels)) + for _ in xrange(num_burn): + (images, labels) = iterator.next() + apply_grads(model, optimizer, + compute_gradients(model, images, labels)) + if execution_mode: + tfe.async_wait() + self._force_device_sync() + gc.collect() + + start = time.time() + for _ in xrange(num_iters): + (images, labels) = iterator.next() + apply_grads(model, optimizer, + compute_gradients(model, images, labels)) + if execution_mode: + tfe.async_wait() + self._force_device_sync() + self._report(label, start, num_iters, device, batch_size, data_format) + + def benchmark_eager_train_sync(self): + self._benchmark_eager_train('eager_train', MockIterator, + device_and_data_format(), defun=False) + + def benchmark_eager_train_async(self): + self._benchmark_eager_train( + 'eager_train_async', + MockIterator, + device_and_data_format(), + defun=False, + execution_mode=tfe.ASYNC) def benchmark_eager_train_with_defun(self): self._benchmark_eager_train( - 'eager_train_with_defun', MockIterator, defun=True) + 'eager_train_with_defun', MockIterator, + device_and_data_format(), defun=True) def benchmark_eager_train_datasets(self): @@ -248,7 +306,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): return tfe.Iterator(ds) self._benchmark_eager_train( - 'eager_train_dataset', make_iterator, defun=False) + 'eager_train_dataset', make_iterator, + device_and_data_format(), defun=False) def benchmark_eager_train_datasets_with_defun(self): @@ -258,7 +317,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): return tfe.Iterator(ds) self._benchmark_eager_train( - 'eager_train_dataset_with_defun', make_iterator, defun=True) + 'eager_train_dataset_with_defun', make_iterator, + device_and_data_format(), defun=True) if __name__ == '__main__': diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index aa87b94e7b0876e65405f6bcb2d6aabde36582bf..5ee2176154ec7011dcb3d7b384a86213e778014f 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -60,6 +60,7 @@ import functools import os import sys import time +import urllib import six import tensorflow as tf @@ -72,6 +73,8 @@ try: except ImportError: HAS_MATPLOTLIB = False +layers = tf.keras.layers + def parse(line): """Parse a line from the colors dataset.""" @@ -89,13 +92,35 @@ def parse(line): return rgb, chars, length +def maybe_download(filename, work_directory, source_url): + """Download the data from source url, unless it's already here. + + Args: + filename: string, name of the file in the directory. + work_directory: string, path to working directory. + source_url: url to download from if file doesn't exist. + + Returns: + Path to resulting file. + """ + if not tf.gfile.Exists(work_directory): + tf.gfile.MakeDirs(work_directory) + filepath = os.path.join(work_directory, filename) + if not tf.gfile.Exists(filepath): + temp_file_name, _ = urllib.request.urlretrieve(source_url) + tf.gfile.Copy(temp_file_name, filepath) + with tf.gfile.GFile(filepath) as f: + size = f.size() + print("Successfully downloaded", filename, size, "bytes.") + return filepath + + def load_dataset(data_dir, url, batch_size): """Loads the colors data at path into a PaddedDataset.""" # Downloads data at url into data_dir/basename(url). The dataset has a header # row (color_name, r, g, b) followed by comma-separated lines. - path = tf.contrib.learn.datasets.base.maybe_download( - os.path.basename(url), data_dir, url) + path = maybe_download(os.path.basename(url), data_dir, url) # This chain of commands loads our data by: # 1. skipping the header; (.skip(1)) @@ -109,7 +134,7 @@ def load_dataset(data_dir, url, batch_size): # pylint: disable=not-callable -class RNNColorbot(tfe.Network): +class RNNColorbot(tf.keras.Model): """Multi-layer (LSTM) RNN that regresses on real-valued vector labels. """ @@ -127,23 +152,20 @@ class RNNColorbot(tfe.Network): self.label_dimension = label_dimension self.keep_prob = keep_prob - # Note the calls to `track_layer` below; these calls register the layers as - # network components that house trainable variables. - self.cells = [ - self.track_layer(tf.nn.rnn_cell.BasicLSTMCell(size)) - for size in rnn_cell_sizes - ] - self.relu = self.track_layer( - tf.layers.Dense(label_dimension, activation=tf.nn.relu, name="relu")) + self.cells = tf.contrib.checkpoint.List( + [tf.nn.rnn_cell.BasicLSTMCell(size) for size in rnn_cell_sizes]) + self.relu = layers.Dense( + label_dimension, activation=tf.nn.relu, name="relu") - def call(self, chars, sequence_length, training=False): + def call(self, inputs, training=False): """Implements the RNN logic and prediction generation. Args: - chars: a Tensor of dimension [batch_size, time_steps, 256] holding a - batch of one-hot encoded color names - sequence_length: a Tensor of dimension [batch_size] holding the length - of each character sequence (i.e., color name) + inputs: A tuple (chars, sequence_length), where chars is a batch of + one-hot encoded color names represented as a Tensor with dimensions + [batch_size, time_steps, 256] and sequence_length holds the length + of each character sequence (color name) as a Tensor with dimension + [batch_size]. training: whether the invocation is happening during training Returns: @@ -151,6 +173,7 @@ class RNNColorbot(tfe.Network): passing chars through a multi-layer RNN and applying a ReLU to the final hidden state. """ + (chars, sequence_length) = inputs # Transpose the first and second dimensions so that chars is of shape # [time_steps, batch_size, dimension]. chars = tf.transpose(chars, [1, 0, 2]) @@ -191,7 +214,7 @@ def test(model, eval_data): """Computes the average loss on eval_data, which should be a Dataset.""" avg_loss = tfe.metrics.Mean("loss") for (labels, chars, sequence_length) in tfe.Iterator(eval_data): - predictions = model(chars, sequence_length, training=False) + predictions = model((chars, sequence_length), training=False) avg_loss(loss(labels, predictions)) print("eval/loss: %.6f\n" % avg_loss.result()) with tf.contrib.summary.always_record_summaries(): @@ -204,7 +227,7 @@ def train_one_epoch(model, optimizer, train_data, log_interval=10): tf.train.get_or_create_global_step() def model_loss(labels, chars, sequence_length): - predictions = model(chars, sequence_length, training=True) + predictions = model((chars, sequence_length), training=True) loss_value = loss(labels, predictions) tf.contrib.summary.scalar("loss", loss_value) return loss_value @@ -277,7 +300,7 @@ def main(_): (chars, length) = (tf.identity(chars), tf.identity(length)) chars = tf.expand_dims(chars, 0) length = tf.expand_dims(length, 0) - preds = tf.unstack(model(chars, length, training=False)[0]) + preds = tf.unstack(model((chars, length), training=False)[0]) # Predictions cannot be negative, as they are generated by a ReLU layer; # they may, however, be greater than 1. diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py index 75b342ba78bd5de5c2827296f6fba01ffa86d560..b7d8395e277b526ba40ccafa323ba453a8667b62 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py @@ -67,5 +67,5 @@ class RNNColorbotTest(tf.test.TestCase): if __name__ == "__main__": - tfe.enable_eager_execution() + tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index 5c5c59c87744f4ffa6db90e5d8d3aa3bc8132756..c2340a293a80924f2dfa90e2fb23134b0f1feb6b 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -38,22 +38,26 @@ import tensorflow as tf from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn from tensorflow.contrib.eager.python import tfe +layers = tf.keras.layers -class RNN(tfe.Network): + +class RNN(tf.keras.Model): """A static RNN. - Similar to tf.nn.static_rnn, implemented as a tf.layer.Layer. + Similar to tf.nn.static_rnn, implemented as a class. """ def __init__(self, hidden_dim, num_layers, keep_ratio): super(RNN, self).__init__() self.keep_ratio = keep_ratio - for _ in range(num_layers): - self.track_layer(tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim)) + self.cells = tf.contrib.checkpoint.List([ + tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim) + for _ in range(num_layers) + ]) def call(self, input_seq, training): batch_size = int(input_seq.shape[1]) - for c in self.layers: + for c in self.cells: state = c.zero_state(batch_size, tf.float32) outputs = [] input_seq = tf.unstack(input_seq, num=int(input_seq.shape[0]), axis=0) @@ -64,10 +68,14 @@ class RNN(tfe.Network): input_seq = tf.stack(outputs, axis=0) if training: input_seq = tf.nn.dropout(input_seq, self.keep_ratio) - return input_seq, None + # Returning a list instead of a single tensor so that the line: + # y = self.rnn(y, ...)[0] + # in PTBModel.call works for both this RNN and CudnnLSTM (which returns a + # tuple (output, output_states). + return [input_seq] -class Embedding(tf.layers.Layer): +class Embedding(layers.Layer): """An Embedding layer.""" def __init__(self, vocab_size, embedding_dim, **kwargs): @@ -87,7 +95,8 @@ class Embedding(tf.layers.Layer): return tf.nn.embedding_lookup(self.embedding, x) -class PTBModel(tfe.Network): +# pylint: disable=not-callable +class PTBModel(tf.keras.Model): """LSTM for word language modeling. Model described in: @@ -109,19 +118,16 @@ class PTBModel(tfe.Network): self.keep_ratio = 1 - dropout_ratio self.use_cudnn_rnn = use_cudnn_rnn - self.embedding = self.track_layer(Embedding(vocab_size, embedding_dim)) + self.embedding = Embedding(vocab_size, embedding_dim) if self.use_cudnn_rnn: self.rnn = cudnn_rnn.CudnnLSTM( num_layers, hidden_dim, dropout=dropout_ratio) else: self.rnn = RNN(hidden_dim, num_layers, self.keep_ratio) - self.track_layer(self.rnn) - self.linear = self.track_layer( - tf.layers.Dense( - vocab_size, - kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1))) + self.linear = layers.Dense( + vocab_size, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1)) self._output_shape = [-1, embedding_dim] def call(self, input_seq, training): @@ -136,7 +142,7 @@ class PTBModel(tfe.Network): y = self.embedding(input_seq) if training: y = tf.nn.dropout(y, self.keep_ratio) - y, _ = self.rnn(y, training=training) + y = self.rnn(y, training=training)[0] return self.linear(tf.reshape(y, self._output_shape)) @@ -148,7 +154,7 @@ def clip_gradients(grads_and_vars, clip_ratio): def loss_fn(model, inputs, targets, training): labels = tf.reshape(targets, [-1]) - outputs = model(inputs, training) + outputs = model(inputs, training=training) return tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( labels=labels, logits=outputs)) @@ -290,7 +296,7 @@ def test_model(use_cudnn_rnn): def main(_): - tfe.enable_eager_execution() + tf.enable_eager_execution() if not FLAGS.data_path: raise ValueError("Must specify --data-path") @@ -301,32 +307,37 @@ def main(_): have_gpu = tfe.num_gpus() > 0 use_cudnn_rnn = not FLAGS.no_use_cudnn_rnn and have_gpu - with tfe.restore_variables_on_create( - tf.train.latest_checkpoint(FLAGS.logdir)): - with tf.device("/device:GPU:0" if have_gpu else None): - # Make learning_rate a Variable so it can be included in the checkpoint - # and we can resume training with the last saved learning_rate. - learning_rate = tfe.Variable(20.0, name="learning_rate") - sys.stderr.write("learning_rate=%f\n" % learning_rate.numpy()) - model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim, - FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout, - use_cudnn_rnn) - optimizer = tf.train.GradientDescentOptimizer(learning_rate) - - best_loss = None - for _ in range(FLAGS.epoch): - train(model, optimizer, train_data, FLAGS.seq_len, FLAGS.clip) - eval_loss = evaluate(model, eval_data) - if not best_loss or eval_loss < best_loss: - if FLAGS.logdir: - tfe.Saver(model.trainable_weights + [learning_rate]).save( - os.path.join(FLAGS.logdir, "ckpt")) - best_loss = eval_loss - else: - learning_rate.assign(learning_rate / 4.0) - sys.stderr.write("eval_loss did not reduce in this epoch, " - "changing learning rate to %f for the next epoch\n" % - learning_rate.numpy()) + with tf.device("/device:GPU:0" if have_gpu else None): + # Make learning_rate a Variable so it can be included in the checkpoint + # and we can resume training with the last saved learning_rate. + learning_rate = tfe.Variable(20.0, name="learning_rate") + model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim, + FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout, + use_cudnn_rnn) + optimizer = tf.train.GradientDescentOptimizer(learning_rate) + checkpoint = tfe.Checkpoint( + learning_rate=learning_rate, model=model, + # GradientDescentOptimizer has no state to checkpoint, but noting it + # here lets us swap in an optimizer that does. + optimizer=optimizer) + # Restore existing variables now (learning_rate), and restore new variables + # on creation if a checkpoint exists. + checkpoint.restore(tf.train.latest_checkpoint(FLAGS.logdir)) + sys.stderr.write("learning_rate=%f\n" % learning_rate.numpy()) + + best_loss = None + for _ in range(FLAGS.epoch): + train(model, optimizer, train_data, FLAGS.seq_len, FLAGS.clip) + eval_loss = evaluate(model, eval_data) + if not best_loss or eval_loss < best_loss: + if FLAGS.logdir: + checkpoint.save(os.path.join(FLAGS.logdir, "ckpt")) + best_loss = eval_loss + else: + learning_rate.assign(learning_rate / 4.0) + sys.stderr.write("eval_loss did not reduce in this epoch, " + "changing learning rate to %f for the next epoch\n" % + learning_rate.numpy()) if __name__ == "__main__": diff --git a/tensorflow/contrib/eager/python/examples/scan/BUILD b/tensorflow/contrib/eager/python/examples/scan/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..638c57d1c92c1dce0ef9e73e9a6ac2369358080b --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/scan/BUILD @@ -0,0 +1,25 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +cuda_py_test( + name = "scan_test", + size = "small", + srcs = ["scan_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "scan_graph_test", + size = "small", + srcs = ["scan_graph_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d4b8c8941ec411912f3089315d038fc4bcd049ae --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py @@ -0,0 +1,54 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit test for tf.scan under graph mode execution.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np +import tensorflow as tf + + +class ScanBenchmark(tf.test.Benchmark): + + def runScan(self, n): + elems = np.arange(n) + start_time = time.time() + sum_op = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) + with tf.Session() as sess: + sess.run(sum_op) + wall_time = time.time() - start_time + + self.report_benchmark( + name='scan', + iters=n, + wall_time=wall_time) + + def benchmarkScan16000(self): + self.runScan(16000) + + def benchmarkScan32000(self): + self.runScan(32000) + + def benchmarkScan64000(self): + self.runScan(64000) + + def benchmarkScan128000(self): + self.runScan(128000) + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a02fc24c79dae6c2565db8b138b1d7391d169ed8 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/scan/scan_test.py @@ -0,0 +1,54 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit test for tf.scan under eager execution.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np +import tensorflow as tf + + +class ScanBenchmark(tf.test.Benchmark): + + def runScan(self, n): + elems = np.arange(n) + start_time = time.time() + _ = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) + wall_time = time.time() - start_time + + self.report_benchmark( + name='scan', + iters=n, + wall_time=wall_time) + + def benchmarkScan16000(self): + self.runScan(16000) + + def benchmarkScan32000(self): + self.runScan(32000) + + def benchmarkScan64000(self): + self.runScan(64000) + + def benchmarkScan128000(self): + self.runScan(128000) + + +if __name__ == '__main__': + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD index a1f8a759e2a556bc219f0aa13942f293c4f34cfa..5966f1d4873e8e77b3ad5914da7bfc7e69d4e341 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/BUILD +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -38,5 +38,9 @@ cuda_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", ], - tags = ["no_pip"], # because spinn.py is under third_party/. + tags = [ + "no-internal-py3", # flaky + "no_cuda_on_cpu_tap", + "no_pip", # because spinn.py is under third_party/. + ], ) diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 081b0af14fcc983a3f85d2a50e2bb04d2f2493b3..8ac553e0ae71382966d03d9ef4429adf5137b369 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -36,7 +36,8 @@ from third_party.examples.eager.spinn import spinn from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util -from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: enable=g-bad-import-order @@ -417,12 +418,15 @@ class SpinnTest(test_util.TensorFlowTestCase): if event.summary.value and event.summary.value[0].tag == "train/loss"] self.assertEqual(config.epochs, len(train_losses)) - self.assertLess(train_losses[-1], train_losses[0]) # 5. Verify that checkpoints exist and contains all the expected variables. self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) - ckpt_variable_names = [ - item[0] for item in checkpoint_utils.list_variables(config.logdir)] + object_graph = checkpointable_utils.object_metadata( + saver.latest_checkpoint(config.logdir)) + ckpt_variable_names = set() + for node in object_graph.nodes: + for attribute in node.attributes: + ckpt_variable_names.add(attribute.full_name) self.assertIn("global_step", ckpt_variable_names) for v in trainer.variables: variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md index d97ff6b74cf033617154f7cbbd00cb6492a1d2f4..2d2aba6908b168e0bf63f4706b6344cbb4ca82bd 100644 --- a/tensorflow/contrib/eager/python/g3doc/guide.md +++ b/tensorflow/contrib/eager/python/g3doc/guide.md @@ -1,903 +1,18 @@ -# TensorFlow Eager Execution - -## What is this? +# Eager execution Eager execution is a feature that makes TensorFlow execute operations -immediately: concrete values are returned, instead of a computational graph to -be executed later. - -As a result, enabling eager execution provides: - -- A [NumPy](http://www.numpy.org/)-like library for numerical computation with - support for GPU acceleration and automatic differentiation. -- A flexible platform for machine learning research and experimentation. - -Eager execution is under active development. This guide walks through an -alpha/preview release. In particular, not all TensorFlow APIs currently work -with eager execution enabled, and some models may be slow to execute, compared -to models defined without using eager execution. - -## Installation - -Eager execution is included in TensorFlow versions 1.5 and above. -Installation instructions at https://www.tensorflow.org/install/ - -The contents of this guide are compatible with TensorFlow 1.5. -However, if you run into bugs that are fixed in source but not the -release, you may want to either [build from -source](https://www.tensorflow.org/install/install_sources) -or try a nightly build. The nightly builds are available as: - -- [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and - -- [docker](https://hub.docker.com/r/tensorflow/tensorflow/) images. - -For example, to run the latest nightly docker image: - -```sh -# If you have a GPU, use https://github.com/NVIDIA/nvidia-docker -docker pull tensorflow/tensorflow:nightly-gpu -docker run --runtime=nvidia -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu - -# If you do not have a GPU, use the CPU-only image -docker pull tensorflow/tensorflow:nightly -docker run -it -p 8888:8888 tensorflow/tensorflow:nightly -``` - -And then visit http://localhost:8888 in your browser for a Jupyter notebook -environment. - -## Getting Started - -With TensorFlow installed, eager execution is enabled via a single call: - -```python -import tensorflow as tf - -import tensorflow.contrib.eager as tfe - -tfe.enable_eager_execution() -``` - -Enabling eager execution changes how TensorFlow functions behave (in particular, -`Tensor` objects will reference concrete values instead of being symbolic -handles to nodes in a computational graph). As a result, eager execution should -be enabled at the beginning of a program and cannot be disabled afterwards in -the same program. - -Code examples in the rest of this guide assume that eager execution has been -enabled. - -## A library for numerical computation - -A significant fraction of the [TensorFlow -API](https://www.tensorflow.org/api_docs/python/) consists of numerical -operations: -[arithmetic operations](https://www.tensorflow.org/api_guides/python/math_ops#Arithmetic_Operators), -[matrix operations](https://www.tensorflow.org/api_guides/python/math_ops#Matrix_Math_Functions), -[linear algebra operations](https://www.tensorflow.org/versions/master/api_docs/python/tf/linalg), -etc. - -With eager execution enabled, these operations consume and return -multi-dimensional arrays as `Tensor` objects, similar to NumPy -[`ndarray`s](https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.ndarray.html). -For example: - -```python -# Multiply two 2x2 matrices -x = tf.matmul([[1, 2], - [3, 4]], - [[4, 5], - [6, 7]]) -# Add one to each element -# (tf.add supports broadcasting) -y = tf.add(x, 1) - -# Create a random random 5x3 matrix -z = tf.random_uniform([5, 3]) - -print(x) -print(y) -print(z) -``` - -Output: - -``` -tf.Tensor( -[[16 19] - [36 43]], shape=(2, 2), dtype=int32) -tf.Tensor( -[[17 20] - [37 44]], shape=(2, 2), dtype=int32) -tf.Tensor( -[[ 0.25058532 0.0929395 0.54113817] - [ 0.3108716 0.93350542 0.84909797] - [ 0.53081679 0.12788558 0.01767385] - [ 0.29725885 0.33540785 0.83588314] - [ 0.38877153 0.39720535 0.78914213]], shape=(5, 3), dtype=float32) -``` - -For convenience, these operations can also be triggered via operator overloading -of the `Tensor` object. For example, the `+` operator is equivalent to `tf.add`, -`-` to `tf.subtract`, `*` to `tf.multiply`, etc.: - -```python -x = (tf.ones([1], dtype=tf.float32) + 1) * 2 - 1 -print(x) -``` - -Output: - -``` -tf.Tensor([ 3.], shape=(1,), dtype=float32) -``` - -### Converting to and from NumPy - -The operations above automatically convert Python objects (like lists of -numbers) and NumPy arrays to `Tensor` objects. `Tensor` objects can also be used -as NumPy arrays by numpy operations. - -```python -import numpy as np - -x = tf.add(1, 1) # tf.Tensor with a value of 2 -y = tf.add(np.array(1), np.array(1)) # tf.Tensor with a value of 2 -z = np.multiply(x, y) # numpy.int64 with a value of 4 -``` - -Alternatively, they can be explicitly converted using -[`tf.constant`](https://www.tensorflow.org/api_docs/python/tf/constant), as -shown in the next example. - -Conversely, you can call the `numpy()` method of a `Tensor` object' to obtain -its NumPy `ndarray` value. For example: - -```python -import numpy as np - -np_x = np.array(2., dtype=np.float32) -x = tf.constant(np_x) - -py_y = 3. -y = tf.constant(py_y) - -z = x + y + 1 - -print(z) -print(z.numpy()) -``` - -Output: - -``` -tf.Tensor(6.0, shape=(), dtype=float32) -6.0 -``` - -### GPU acceleration - -Many TensorFlow operations support GPU acceleration. With eager execution -enabled, [computation is *not* automatically -offloaded](https://www.tensorflow.org/tutorials/using_gpu) to GPUs. Instead, you -must explicitly specify when GPUs should be used. - -The simplest way to do this is to enclose your computation in a `with -tf.device('/gpu:0')` block. Also of interest is the `tfe.num_gpus()` function, -which returns the number of available GPUs. - -For example, consider this snippet to measure the time to multiply two 1000x1000 -matrices on CPU: - -```python -import time - -def measure(x): - # The very first time a GPU is used by TensorFlow, it is initialized. - # So exclude the first run from timing. - tf.matmul(x, x) - - start = time.time() - for i in range(10): - tf.matmul(x, x) - end = time.time() - - return "Took %s seconds to multiply a %s matrix by itself 10 times" % (end - start, x.shape) - -# Run on CPU: -with tf.device("/cpu:0"): - print("CPU: %s" % measure(tf.random_normal([1000, 1000]))) - -# If a GPU is available, run on GPU: -if tfe.num_gpus() > 0: - with tf.device("/gpu:0"): - print("GPU: %s" % measure(tf.random_normal([1000, 1000]))) -``` - -Output (exact numbers will depend on the characteristics of the hardware): - -```python -CPU: Took 0.145531892776 seconds to multiply a (1000, 1000) matrix by itself 10 times -GPU: Took 0.000458955764771 seconds to multiply a (1000, 1000) matrix by itself 10 times -``` - -Alternatively, methods on the `Tensor` object can be used to explicitly copy the -`Tensor` to a different device. Operations are typically executed on the device -on which the inputs are placed. For example: - -```python -x = tf.random_normal([10, 10]) - -x_gpu0 = x.gpu() -x_cpu = x.cpu() - -_ = tf.matmul(x_cpu, x_cpu) # Runs on CPU -_ = tf.matmul(x_gpu0, x_gpu0) # Runs on GPU:0 - -if tfe.num_gpus() > 1: - x_gpu1 = x.gpu(1) - _ = tf.matmul(x_gpu1, x_gpu1) # Runs on GPU:1 -``` - -### Automatic Differentiation - -[Automatic -differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) is -very useful when implementing many machine learning algorithms (e.g., -[backpropagation](https://en.wikipedia.org/wiki/Backpropagation) for training -neural networks). For this purpose, TensorFlow eager execution provides an -[autograd](https://github.com/HIPS/autograd)-style API for automatic -differentiation. Specifically, the functions: - -- `tfe.gradients_function(f)`: Returns a Python function that computes the - derivatives of the Python function `f` with respect to its arguments. `f` - must return a scalar value. When the returned function is invoked, it - returns a list of `Tensor` objects (one element for each argument of `f`). -- `tfe.value_and_gradients_function(f)`: Similar to `tfe.gradients_function`, - except that when the returned function is invoked, it returns the value of - `f` in addition to the list of derivatives of `f` with respect to its - arguments. - -These functions naturally apply to higher order differentiation as well. For -example: - -```python -def f(x): - return tf.multiply(x, x) # Or x * x -assert 9 == f(3.).numpy() - -df = tfe.gradients_function(f) -assert 6 == df(3.)[0].numpy() - -# Second order deriviative. -d2f = tfe.gradients_function(lambda x: df(x)[0]) -assert 2 == d2f(3.)[0].numpy() - -# Third order derivative. -d3f = tfe.gradients_function(lambda x : d2f(x)[0]) -assert 0 == d3f(3.)[0].numpy() -``` - -These functions can be used to train models. For example, consider the following -simple linear regression model: - -```python -def prediction(input, weight, bias): - return input * weight + bias - -# A toy dataset of points around 3 * x + 2 -NUM_EXAMPLES = 1000 -training_inputs = tf.random_normal([NUM_EXAMPLES]) -noise = tf.random_normal([NUM_EXAMPLES]) -training_outputs = training_inputs * 3 + 2 + noise - -# A loss function: Mean-squared error -def loss(weight, bias): - error = prediction(training_inputs, weight, bias) - training_outputs - return tf.reduce_mean(tf.square(error)) - -# Function that returns the derivative of loss with respect to -# weight and bias -grad = tfe.gradients_function(loss) - -# Train for 200 steps (starting from some random choice for W and B, on the same -# batch of data). -W = 5. -B = 10. -learning_rate = 0.01 -print("Initial loss: %f" % loss(W, B).numpy()) -for i in range(200): - (dW, dB) = grad(W, B) - W -= dW * learning_rate - B -= dB * learning_rate - if i % 20 == 0: - print("Loss at step %d: %f" % (i, loss(W, B).numpy())) -print("Final loss: %f" % loss(W, B).numpy()) -print("W, B = %f, %f" % (W.numpy(), B.numpy())) -``` - -Output: (the exact numbers may vary depending on the randomness in noise) - -``` -Initial loss: 66.730003 -Loss at step 0: 64.200096 -Loss at step 20: 29.872814 -Loss at step 40: 14.233772 -Loss at step 60: 7.090570 -Loss at step 80: 3.819887 -Loss at step 100: 2.318821 -Loss at step 120: 1.628385 -Loss at step 140: 1.310142 -Loss at step 160: 1.163167 -Loss at step 180: 1.095162 -Final loss: 1.064711 -W, B = 3.094944, 2.161383 -``` - -To utilize the GPU, place the code above within a `with tf.device("/gpu:0"):` -block. (However, this particular model, with only two floating point parameters, -is unlikely to benefit from GPU acceleration.) - -### Customizing gradients - -One may want to define custom gradients for an operation, or for a function. -This may be useful for multiple reasons, including providing a more efficient -or more [numerically stable](https://en.wikipedia.org/wiki/Numerical_stability) -gradient for a sequence of operations. - -For example, consider the function `log(1 + e^x)`, which commonly occurs in the -computation of cross entropy and log likelihoods. - -```python -def log1pexp(x): -  return tf.log(1 + tf.exp(x)) -grad_log1pexp = tfe.gradients_function(log1pexp) - -# Works fine at x = 0. -assert 0.5 == float(grad_log1pexp(0.)[0]) - -# Returns a `nan` at x = 100 due to numerical instability. -import math -assert math.isnan(float(grad_log1pexp(100.)[0])) -``` - -We can define a custom gradient for the above function that analytically -simplifies the gradient expression. - -```python -@tfe.custom_gradient -def log1pexp(x): -  e = tf.exp(x) -  def grad(dy): -    return dy * (1 - 1 / (1 + e)) -  return tf.log(1 + e), grad -grad_log1pexp = tfe.gradients_function(log1pexp) - -# Works as before at x = 0. -assert 0.5 == float(grad_log1pexp(0.)[0]) - -# But now works at x = 100 as well. -assert 1.0 == float(grad_log1pexp(100.)[0]) -``` -Also notice how the gradient function implementation reuses an expression -(`tf.exp(x)`) computed during the forward pass, hence making the gradient -computation more efficient by avoiding redundant computation. - -## Building and training models - -In practice, your computation may have many parameters to be optimized (by -computing derivatives). Encapsulating them into re-usable classes/objects -makes the code easier to follow than writing a single top-level function with -many arguments. - -In fact, eager execution encourages use of the [Keras](https://keras.io)-style -"Layer" classes in the -[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers) -module. - -Furthermore, you may want to apply more sophisticated techniques to compute -parameter updates, such as those in -[`tf.train.Optimizer`](https://www.tensorflow.org/api_guides/python/train#Optimizers) -implementations. - -This next section walks through using the same `Optimizer` and `Layer` APIs used -to build trainable TensorFlow graphs in an environment where eager execution is -enabled. - -### Variables and Optimizers - -`tfe.Variable` objects store mutable `Tensor` values that can be accessed during -training, making automatic differentiation easier. In particular, parameters of -a model can be encapsulated in Python classes as variables. - -`tfe.gradients_function(f)` introduced earlier computes the derivatives of `f` -with respect to its arguments. However, it requires all parameters of interest -to be arguments of `f`, which becomes cumbersome when `f` depends on a large -number of trainable parameters. - -`tfe.implicit_gradients` is an alternative function with some useful properties: - -- It computes the derivatives of `f` with respect to all the `tfe.Variable`s - used by `f`. -- When the returned function is invoked, it returns a list of - (gradient value, Variable object) tuples. - -Representing model parameters as `Variable` objects, along with the use of -`tfe.implicit_gradients`, typically results in better encapsulation. For -example, the linear regression model described above can be written into a -class: - -```python -class Model(object): - def __init__(self): - self.W = tfe.Variable(5., name='weight') - self.B = tfe.Variable(10., name='bias') - - def predict(self, inputs): - return inputs * self.W + self.B - - -# The loss function to be optimized -def loss(model, inputs, targets): - error = model.predict(inputs) - targets - return tf.reduce_mean(tf.square(error)) - -# A toy dataset of points around 3 * x + 2 -NUM_EXAMPLES = 1000 -training_inputs = tf.random_normal([NUM_EXAMPLES]) -noise = tf.random_normal([NUM_EXAMPLES]) -training_outputs = training_inputs * 3 + 2 + noise - -# Define: -# 1. A model -# 2. Derivatives of a loss function with respect to model parameters -# 3. A strategy for updating the variables based on the derivatives -model = Model() -grad = tfe.implicit_gradients(loss) -optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) - -# The training loop -print("Initial loss: %f" % - loss(model, training_inputs, training_outputs).numpy()) -for i in range(201): - optimizer.apply_gradients(grad(model, training_inputs, training_outputs)) - if i % 20 == 0: - print("Loss at step %d: %f" % - (i, loss(model, training_inputs, training_outputs).numpy())) -print("Final loss: %f" % loss(model, training_inputs, training_outputs).numpy()) -print("W, B = %s, %s" % (model.W.numpy(), model.B.numpy())) -``` - -Output: - -``` -Initial loss: 69.693184 -Loss at step 0: 66.987854 -Loss at step 20: 30.553387 -Loss at step 40: 14.250237 -Loss at step 60: 6.955020 -Loss at step 80: 3.690550 -Loss at step 100: 2.229739 -Loss at step 120: 1.576032 -Loss at step 140: 1.283496 -Loss at step 160: 1.152584 -Loss at step 180: 1.093999 -Final loss: 1.067780 -W, B = 3.0114281, 2.0865183 -``` - -Using `implicit_gradients` avoids the need to provide all the trainable -parameters of the model as arguments to the `loss` function. - -### Using Keras and the Layers API - -[Keras](https://keras.io) is a popular API for defining model structures. The -[`tf.keras.layers`](https://www.tensorflow.org/api_docs/python/tf/keras/layers) -module provides a set of building blocks for models and is implemented using the -`tf.layers.Layer` subclasses in the -[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers) -module. We encourage the use of these same building blocks when using -TensorFlow's eager execution feature. For example, the very same linear -regression model can be built using `tf.layers.Dense`: - -```python -class Model(object): - def __init__(self): - self.layer = tf.layers.Dense(1) - - def predict(self, inputs): - return self.layer(inputs) -``` - -The `tf.layers` API makes it more convenient to define more sophisticated -models. For example, the following will train an MNIST model: - -```python -class MNISTModel(object): - def __init__(self, data_format): - # 'channels_first' is typically faster on GPUs - # while 'channels_last' is typically faster on CPUs. - # See: https://www.tensorflow.org/performance/performance_guide#data_formats - if data_format == 'channels_first': - self._input_shape = [-1, 1, 28, 28] - else: - self._input_shape = [-1, 28, 28, 1] - self.conv1 = tf.layers.Conv2D(32, 5, - padding='same', - activation=tf.nn.relu, - data_format=data_format) - self.max_pool2d = tf.layers.MaxPooling2D( - (2, 2), (2, 2), padding='same', data_format=data_format) - self.conv2 = tf.layers.Conv2D(64, 5, - padding='same', - activation=tf.nn.relu, - data_format=data_format) - self.dense1 = tf.layers.Dense(1024, activation=tf.nn.relu) - self.dropout = tf.layers.Dropout(0.5) - self.dense2 = tf.layers.Dense(10) - - def predict(self, inputs): - x = tf.reshape(inputs, self._input_shape) - x = self.max_pool2d(self.conv1(x)) - x = self.max_pool2d(self.conv2(x)) - x = tf.layers.flatten(x) - x = self.dropout(self.dense1(x)) - return self.dense2(x) - -def loss(model, inputs, targets): - return tf.reduce_mean( - tf.nn.softmax_cross_entropy_with_logits( - logits=model.predict(inputs), labels=targets)) - - -# Load the training and validation data -from tensorflow.examples.tutorials.mnist import input_data -data = input_data.read_data_sets("./mnist_data", one_hot=True) - -# Train -device = "gpu:0" if tfe.num_gpus() else "cpu:0" -model = MNISTModel('channels_first' if tfe.num_gpus() else 'channels_last') -optimizer = tf.train.AdamOptimizer(learning_rate=1e-4) -grad = tfe.implicit_gradients(loss) -for i in range(20001): - with tf.device(device): - (inputs, targets) = data.train.next_batch(50) - optimizer.apply_gradients(grad(model, inputs, targets)) - if i % 100 == 0: - print("Step %d: Loss on training set : %f" % - (i, loss(model, inputs, targets).numpy())) -print("Loss on test set: %f" % loss(model, data.test.images, data.test.labels).numpy()) -``` - -For a more complete example, see -[`tensorflow/contrib/eager/python/examples/mnist.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist.py) - -### Checkpointing trained variables - -TensorFlow Variables (`tfe.Variable`) provides a way to represent shared, -persistent state of your model. The `tfe.Saver` class (which is a thin wrapper -over the -[`tf.train.Saver`](https://www.tensorflow.org/api_docs/python/tf/train/Saver) -class) provides a means to save and restore variables to and from _checkpoints_. - -For example: - -```python -# Create variables. -x = tfe.Variable(10., name='x') -y = tfe.Variable(5., name='y') - -# Create a Saver. -saver = tfe.Saver([x, y]) - -# Assign new values to the variables and save. -x.assign(2.) -saver.save('/tmp/ckpt') - -# Change the variable after saving. -x.assign(11.) -assert 16. == (x + y).numpy() # 11 + 5 - -# Restore the values in the checkpoint. -saver.restore('/tmp/ckpt') - -assert 7. == (x + y).numpy() # 2 + 5 -``` - -### `tfe.Network` - -You may often want to organize your models using classes, like the `MNISTModel` -class described above. We recommend inheriting from the `tfe.Network` class as -it provides conveniences like keeping track of all model variables and methods -to save and restore from checkpoints. - -Sub-classes of `tfe.Network` may register `Layer`s (like classes in -[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers), -or [Keras -layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers)) -using a call to `self.track_layer()` and define the computation in an -implementation of `call()`. - -Note that `tf.layers.Layer` objects (like `tf.layers.Dense`) create variables -lazily, when the first input is encountered. - -For example, consider the following two-layer neural network: - -```python -class TwoLayerNet(tfe.Network): - def __init__(self): - super(TwoLayerNet, self).__init__() - self.layer1 = self.track_layer( - tf.layers.Dense(2, activation=tf.nn.relu, use_bias=False)) - self.layer2 = self.track_layer(tf.layers.Dense(3, use_bias=False)) - - def call(self, x): - return self.layer2(self.layer1(x)) - -net = TwoLayerNet() - -# No variables created yet -assert 0 == len(net.variables) - -# They are created on first input: -inp = tf.constant([[1.]]) - -# Since input is a 1x1 matrix, net.l1 has 2 units and net.l2 has 3 units, -# the output is the product of a 1x1 matrix with a 1x2 matrix with a 2x3 -# matrix. -assert [1, 3] == net(inp).shape.as_list() # Invoke net; get output shape. -assert 1 == len(net.layer1.variables) -assert 1 == len(net.layer2.variables) -assert 2 == len(net.variables) # weights for each layer. -assert [1, 2] == net.variables[0].shape.as_list() # weights of layer1. -assert [2, 3] == net.variables[1].shape.as_list() # weights of layer2. -``` - -The `tfe.Network` class is itself a sub-class of `tf.layers.Layer`. This allows -instances of `tfe.Network` to be embedded in other networks. For example: - -```python -class ThreeLayerNet(tfe.Network): - def __init__(self): - super(ThreeLayerNet, self).__init__() - self.a = self.track_layer(TwoLayerNet()) - self.b = self.track_layer(tf.layers.Dense(4, use_bias=False)) - - def call(self, x): - return self.b(self.a(x)) - -net = ThreeLayerNet() - -assert [1, 4] == net(inp).shape.as_list() -assert 3 == len(net.variables) -assert [1, 2] == net.variables[0].shape.as_list() -assert [2, 3] == net.variables[1].shape.as_list() -assert [3, 4] == net.variables[2].shape.as_list() -``` - -See more examples in -[`tensorflow/contrib/eager/python/examples`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples). - -`tfe.Saver` in combination with `tfe.restore_variables_on_create` provides a -convenient way to save and load checkpoints without changing the program once -the checkpoint has been created. For example, we can set an objective for the -output of our network, choose an optimizer, and a location for the checkpoint: - -```python -objective = tf.constant([[2., 3., 4., 5.]]) -optimizer = tf.train.AdamOptimizer(0.01) -checkpoint_directory = '/tmp/tfe_example' -checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') -net = ThreeLayerNet() -``` - -Note that variables have not been created yet. We want them to be restored from -a checkpoint, if one exists, so we create them inside a -`tfe.restore_variables_on_create` context manager. Then our training loop is the -same whether starting training or resuming from a previous checkpoint: - -```python -with tfe.restore_variables_on_create( - tf.train.latest_checkpoint(checkpoint_directory)): - global_step = tf.train.get_or_create_global_step() - for _ in range(100): - loss_fn = lambda: tf.norm(net(inp) - objective) - optimizer.minimize(loss_fn, global_step=global_step) - if tf.equal(global_step % 20, 0): - print("Step %d, output %s" % (global_step.numpy(), - net(inp).numpy())) - all_variables = ( - net.variables - + optimizer.variables() - + [global_step]) - # Save the checkpoint. - tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step) -``` - -The first time it runs, `Network` variables are initialized randomly. Then the -output is trained to match the objective we've set: - -``` -Step 20, output [[ 0.03575622 0.29863232 0.03474367 0.24735749]] -Step 40, output [[ 0.40646029 0.9856872 0.46851286 0.95358551]] -Step 60, output [[ 1.74541104 2.800704 1.79055595 2.74783421]] -Step 80, output [[ 2.14977384 3.44340849 3.96120024 5.16242075]] -Step 100, output [[ 1.99943113 3.02364397 3.93500996 4.9610076 ]] -``` - -In subsequent iterations, variables are initialized with the values read from -the latest checkpoint. Running the same code again, we continue from where we -left off: - -``` -Step 120, output [[ 1.99234128 3.0271616 3.98732996 4.96401167]] -Step 140, output [[ 2.00133467 3.01270437 4.00616646 5.00406504]] -Step 160, output [[ 1.99647415 2.9956708 3.99064088 4.99632359]] -Step 180, output [[ 2.00699997 3.00904822 4.00706148 5.01193142]] -Step 200, output [[ 1.98334622 2.98249531 3.97375059 4.97123432]] -``` - - -### Summaries, metrics and TensorBoard - -[TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard) -is a popular tool for understanding, debugging and optimizing the model training -process. To benefit from the visualizations offered by TensorBoard, summary -events need to be written during the course of execution of your program. You -might find many Tensorflow programs that include the -[`tf.summary`](https://www.tensorflow.org/api_guides/python/summary) operations -during graph construction. - -`tf.summary` operations are *not* compatible with eager execution, but an -equivalent alternative exists in -[`tf.contrib.summary`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/summary) -that is compatible with both eager execution and graph construction. - -During model construction simply insert summary operations like -`tf.contrib.summary.scalar`. These operations do nothing by default, unless a -summary writer is currently active and a writing policy is set. - -For example, to record summaries once every 100 global steps, use: - -```python -tf.train.get_or_create_global_step() # Ensuring the global step variable exists -writer = tf.contrib.summary.create_file_writer(logdir) - -for _ in range(iterations): - with writer.as_default(): - with tf.contrib.summary.record_summaries_every_n_global_steps(100): - # your model code goes here - tf.contrib.summary.scalar('loss', loss) - # ... -``` - -See the full mnist example in -[`tensorflow/contrib/eager/python/examples/mnist`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist) -for a full model using `tf.contrib.summary`. - -Similarly to summaries, the metrics in `tf.metrics` are currently not compatible -with eager execution. We instead provide object-oriented metrics in the -`tfe.metrics` package, which are compatible with graph construction as well. - -Metrics in the `tfe.metrics`, such as `tfe.metrics.Mean` and -`tfe.Metrics.Accuracy`, all implement an intuitive object-oriented -interface. Here's an example of how to use the `tfe.metrics.Mean` metric: - -```python -# Metrics are objects, which can be created and destroyed. -my_mean = tfe.metrics.Mean(name='my_mean') -# While a metric is active, you can call it as a function to accumulate into its -# internal state. -my_mean(0.0) -my_mean(10.0) -# Once you've finished updating the metric, you can get its result. In this case -# a simple average over all the calls to it. If a summary writer is active the -# metric will write the appropriate summaries using the metric name. -assert 5.0 == my_mean.result().numpy() -``` - -For a full example of a model using metrics for evaluation, see the mnist -example in -[`tensorflow/contrib/eager/python/examples/mnist`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist). - -### Input Pipelines - -The discussion above has been centered around the computation executed by your -model. The -[`tf.data`](https://www.tensorflow.org/api_docs/python/tf/data) -module provides APIs to build complex input pipelines from simple, reusable -pieces. - -If you're familiar with constructing `tf.data.Dataset` objects when building -TensorFlow graphs, the same API calls are used when eager execution is enabled. -However, the process of iterating over elements of the dataset differs between -eager execution and graph construction. When eager execution is enabled, the -discussion on iterator creation using `make_one_shot_iterator()` and -`get_next()` in the -[Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is -*not* applicable. Instead, a more Pythonic `Iterator` class is available. - -For example: - -```python -# Create a source Dataset from in-memory numpy arrays. -# For reading from files on disk, you may want to use other Dataset classes -# like the TextLineDataset or the TFRecordDataset. -dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]) - -# Apply transformations, shuffling, batching etc. -dataset = dataset.map(tf.square).shuffle(2).batch(2) - -# Use tfe.Iterator to iterate over the dataset. -for x in tfe.Iterator(dataset): - print(x) -``` - -Output: - -``` -tf.Tensor([4 9], shape=(2,), dtype=int32) -tf.Tensor([16 25], shape=(2,), dtype=int32) -tf.Tensor([36 1], shape=(2,), dtype=int32) -``` - -## Interoperating with Graphs - -Eager execution improves the process of model development in Python; however, -because it is in its earliest stages, it does not yet support some features -available to [TensorFlow -graphs](https://www.tensorflow.org/get_started/get_started#the_computational_graph) -that are desirable when deploying models in production. In particular, eager -execution does not yet support distributed training, exporting models (to other -[programming languages](https://www.tensorflow.org/api_docs/), [TensorFlow -serving](https://www.tensorflow.org/serving/), and mobile applications), and -various memory and computation optimizations that are applied to TensorFlow's -dataflow graphs. - -That said, the APIs used to build modes are exactly the same whether executing -eagerly or constructing graphs. This means that you can iteratively develop your -model with eager execution enabled and later, if needed, use the same code to -reap the benefits of representing models as computational graphs. - -For example, -[`mnist.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist.py) -defines a model that is eagerly executed. That same code is used to construct -and execute a graph in -[`mnist_graph_test.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py). - -Other models in the [examples -directory](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/) -demonstrate this as well. - -Some differences worth noting: - -- There is no notion of a `tf.placeholder` or a `tf.Session` when eager - execution is enabled. -- Many properties on the `tf.Tensor` object, like `tf.Tensor.name`, - `tf.Tensor.op`, `tf.Tensor.inputs` are not meaningful when eager execution - is enabled and their use will raise an `AttributeError`. -- To use `tfe.implicit_gradients` in graph construction, variables must be - created with [`use_resource=True`] provided to - [`tf.get_variable()`](https://www.tensorflow.org/api_docs/python/tf/get_variable) - or - [`tf.variable_scope()`](https://www.tensorflow.org/api_docs/python/tf/variable_scope). -- Some API calls (such as the functional-style `tf.layers.dense`, - `tf.layers.conv2d`) are not compatible with eager execution. Use of such - methods should raise an error indicating the alternative (e.g., the - `tf.layers.Dense` and `tf.layers.Conv2D` classes). - -## What next? +immediately: concrete values are returned, instead of creating a computational +graph that is executed later. -Please give eager execution a spin. This feature is in early stages and is -evolving, so we welcome your feedback via issues on GitHub (see [known -issues](https://github.com/tensorflow/tensorflow/labels/comp:eager)). +A user guide is available: https://www.tensorflow.org/programmers_guide/eager +([source file](../../../../docs_src/programmers_guide/eager.md)) -You may want to browse through some sample code, including benchmarks for some: +We welcome feedback through [GitHub issues](https://github.com/tensorflow/tensorflow/labels/comp:eager). -- [Linear Regression](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/linear_regression) -- [MNIST handwritten digit classifier](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist) -- [ResNet50 image classification](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/resnet50) -- [RNN to generate colors](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_colorbot) -- [RNN language model](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_ptb) +Sample code is available, including benchmarks for some: +- [Linear Regression](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/linear_regression) +- [MNIST handwritten digit classifier](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist) +- [ResNet50 image classification](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/resnet50) +- [RNN to generate colors](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_colorbot) +- [RNN language model](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_ptb) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index ea8dbf2b46ea4bd0e33645ae3c590c4dd13f7a52..c947ed9dcc415670a820f8a5cd9eaaf07334cfc3 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -20,22 +20,23 @@ from __future__ import print_function import re -from tensorflow.contrib.summary import summary_ops from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.ops import variable_scope - +from tensorflow.python.training.checkpointable import base as checkpointable _to_replace = re.compile("[^A-Za-z0-9.]") -class Metric(object): +class Metric(checkpointable.CheckpointableBase): """A metric holds state for aggregating statistics over an evaluation run. Example use with eager execution: @@ -93,11 +94,12 @@ class Metric(object): `aggregate()`, it is for use by TensorFlow infrastructure. """ - def __init__(self, name=None): + def __init__(self, name=None, use_global_variables=False): self._built = False self._vars = [] self._initial_values = {} self._updates = [] + self._use_global_variables = use_global_variables name = name or self.__class__.__name__ # Replace things like spaces in name to create a valid scope name. scope_name = _to_replace.sub("_", name) @@ -108,13 +110,25 @@ class Metric(object): pos = scope.name.rfind(scope_name) self._name = name + scope.name[pos + len(scope_name):] self._scope = scope - if context.in_graph_mode(): + + # Ensures that if the user calls build directly we still set self._built to + # True to prevent variables from being recreated. + self._build = self.build + + def actual_build(*args, **kwargs): + self._build(*args, **kwargs) + self._built = True + self.build = actual_build + self.build.__doc__ = self._build.__doc__ + + # Captures construction scope for proper initialization. + if context.executing_eagerly(): + self._construction_scope = context.eager_mode + else: # We make self.call() into a graph callable here, so that we can # return a single op that performs all of the variable updates. self._construction_scope = ops.get_default_graph().as_default self.call = function.defun(self.call) - else: - self._construction_scope = context.eager_mode # ---- API for users ---- def __call__(self, *args, **kwargs): @@ -155,10 +169,11 @@ class Metric(object): initialization. Under eager execution, the variables are reset to their initial values as a side effect and this function returns None. """ - if context.in_graph_mode(): + if context.executing_eagerly(): + for v in self._vars: + v.assign(self._initial_values[v]) + else: return control_flow_ops.group([v.initializer for v in self._vars]) - for v in self._vars: - v.assign(self._initial_values[v]) # ---- To be implemented by descendants --- def build(self, *args, **kwargs): @@ -200,10 +215,10 @@ class Metric(object): def value(self): """In graph mode returns the result Tensor while in eager the callable.""" - if context.in_graph_mode(): - return self.result() - else: + if context.executing_eagerly(): return self.result + else: + return self.result() # We can support two different strategies of for doing data-parallel # distributed metric computations: @@ -245,19 +260,31 @@ class Metric(object): """***Only for use by descendants of Metric***.""" if self._built: raise RuntimeError("Can't call add_variable() except in build().") - collections = None if context.in_eager_mode() else [ - ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES - ] - v = variable_scope.get_variable( - name, - shape, - dtype, - initializer, + if context.executing_eagerly(): + collections = None + else: + if self._use_global_variables: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + else: + collections = [ops.GraphKeys.LOCAL_VARIABLES] + collections += [ops.GraphKeys.METRIC_VARIABLES] + # Variables are Checkpointable dependencies of Metrics regardless of the + # global/local distinction. Users can avoid saving variables by not adding a + # dependency on the Metric. + v = self._add_variable_with_custom_getter( + name=name, + shape=shape, + dtype=dtype, + initializer=initializer, trainable=False, collections=collections, - use_resource=True) + use_resource=True, + getter=variable_scope.get_variable, + # Raise duplicate variable exceptions from get_variable rather than + # Checkpointable. + overwrite=True) self._vars.append(v) - if context.in_eager_mode(): + if context.executing_eagerly(): self._initial_values[v] = v.value() return v @@ -267,8 +294,10 @@ class Mean(Metric): # TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64? # Or defaults to type of the input if it is tf.float32, else tf.float64? - def __init__(self, name=None, dtype=dtypes.float64): - super(Mean, self).__init__(name=name) + def __init__(self, name=None, dtype=dtypes.float64, + use_global_variables=False): + super(Mean, self).__init__(name=name, + use_global_variables=use_global_variables) self.dtype = dtype def build(self, *args, **kwargs): @@ -339,6 +368,9 @@ class Accuracy(Mean): Returns: The arguments, for easy chaining. """ + check_ops.assert_equal( + array_ops.shape(labels), array_ops.shape(predictions), + message="Shapes of labels and predictions are unequal") matches = math_ops.equal(labels, predictions) matches = math_ops.cast(matches, dtypes.float64) super(Accuracy, self).call(matches, weights=weights) diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index a9ecaa3f8bced3043ea0eb0ac3aa8bfa65e9e1ff..02ee05487515b81bfae70d02c1dfdb6d816b77c7 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -18,18 +18,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import tempfile from tensorflow.contrib.eager.python import metrics -from tensorflow.contrib.summary import summary_ops from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.training import training_util +from tensorflow.python.training.checkpointable import util as checkpointable_utils class MetricsTest(test.TestCase): @@ -50,6 +53,19 @@ class MetricsTest(test.TestCase): self.assertEqual( set(m.variables), set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))) + self.assertEqual(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES), []) + self.assertEqual( + set(m.variables), + set(ops.get_collection(ops.GraphKeys.METRIC_VARIABLES))) + + def testUseGlobalVariablesCollections(self): + with context.graph_mode(), ops.Graph().as_default(): + m = metrics.Mean(use_global_variables=True) + m(1000) + self.assertEqual( + set(m.variables), + set(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) + self.assertEqual(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES), []) self.assertEqual( set(m.variables), set(ops.get_collection(ops.GraphKeys.METRIC_VARIABLES))) @@ -102,6 +118,11 @@ class MetricsTest(test.TestCase): self.assertEqual(dtypes.float64, m.dtype) self.assertEqual(dtypes.float64, m.result().dtype) + def testAccuracyDifferentShapes(self): + m = metrics.Accuracy() + with self.assertRaises(errors.InvalidArgumentError): + m([[0], [0]], [0, 1]) + def testWeightedAccuracy(self): m = metrics.Accuracy() # 1 correct, total weight of 2 @@ -131,8 +152,6 @@ class MetricsTest(test.TestCase): self.assertAllEqual(2.0, m2.result()) def testNamesWithSpaces(self): - # Verify two metrics with the same class and name don't - # accidentally share state. m1 = metrics.Mean("has space") m1(0) self.assertEqual(m1.name, "has space") @@ -171,8 +190,8 @@ class MetricsTest(test.TestCase): self.assertEqual(self.evaluate(value), 2.5) def testTwoMeansGraph(self): - # Verify two metrics with the same class and name don't - # accidentally share state. + # Verify two metrics with the same name in the same graph raises a + # ValueError. with context.graph_mode(): m1 = metrics.Mean() m1(0) @@ -180,6 +199,15 @@ class MetricsTest(test.TestCase): m2 = metrics.Mean() m2(2) + def testBuildMean(self): + # Verify that calling build() on Mean and then calling it won't recreate + # variables. + m = metrics.Mean() + m.build() + old_numer = m.numer + m(0.0) + self.assertTrue(old_numer is m.numer) + def testMetricsChain(self): with context.graph_mode(), self.test_session(): m1 = metrics.Mean() @@ -193,6 +221,31 @@ class MetricsTest(test.TestCase): self.assertAllEqual(m2.result().eval(), 2.0) self.assertAllEqual(m1.result().eval(), 1.0) + @test_util.run_in_graph_and_eager_modes() + def testSaveRestore(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + mean = metrics.Mean() + checkpoint = checkpointable_utils.Checkpoint(mean=mean) + mean.build() + mean._built = True + self.evaluate(mean.init_variables()) + self.evaluate(mean(100.)) + self.evaluate(mean(200.)) + save_path = checkpoint.save(checkpoint_prefix) + self.evaluate(mean(1000.)) + checkpoint.restore(save_path).assert_consumed().run_restore_ops() + self.evaluate(mean(300.)) + self.assertAllEqual(200., self.evaluate(mean.value())) + + restore_mean = metrics.Mean() + restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean) + status = restore_checkpoint.restore(save_path) + restore_update = restore_mean(300.) + status.assert_consumed().run_restore_ops() + self.evaluate(restore_update) + self.assertAllEqual(200., self.evaluate(restore_mean.value())) + self.assertEqual(3, self.evaluate(restore_mean.denom)) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index e3c13cbd2e8ccd2ab79da74e0e97905c6ed5c02d..f801d9a47b2f831a48d9b6335c69612c1356d800 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -23,13 +23,16 @@ import os import weakref from tensorflow.python.eager import context -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops +from tensorflow.python.keras.engine import base_layer as keras_base_layer from tensorflow.python.layers import base from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_utils from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util +from tensorflow.python.util import deprecation +from tensorflow.python.util import function_utils # pylint: disable=protected-access # Explanation for protected-access disable: Network has lots of same-class and @@ -51,9 +54,40 @@ def _network_name_scope_naming(current_variable_scope): return current_variable_scope.name + "/" +_NETWORK_DEPRECATION_MESSAGE = ( + "Please inherit from `tf.keras.Model`, and see its documentation for " + "details. `tf.keras.Model` should be a drop-in replacement for " + "`tfe.Network` in most cases, but note that `track_layer` is no longer " + "necessary or supported. Instead, `Layer` instances are tracked on " + "attribute assignment (see the section of `tf.keras.Model`'s documentation " + "on subclassing). Since the output of `track_layer` is often assigned to " + "an attribute anyway, most code can be ported by simply removing the " + "`track_layer` calls.\n\n`tf.keras.Model` works with all TensorFlow " + "`Layer` instances, including those from `tf.layers`, but switching to " + "the `tf.keras.layers` versions along with the migration to " + "`tf.keras.Model` is recommended, since it will preserve variable names. " + "Feel free to import it with an alias to avoid excess typing :)." +) + + class Network(base.Layer): """Represents the composition of a set of Layers. + *Deprecated*. Please inherit from `tf.keras.Model`, and see its documentation + for details. `tf.keras.Model` should be a drop-in replacement for + `tfe.Network` in most cases, but note that `track_layer` is no longer + necessary or supported. Instead, `Layer` instances are tracked on attribute + assignment (see the section of `tf.keras.Model`'s documentation on + subclassing). Since the output of `track_layer` is often assigned to an + attribute anyway, most code can be ported by simply removing the `track_layer` + calls. + + `tf.keras.Model` works with all TensorFlow `Layer` instances, including those + from `tf.layers`, but switching to the `tf.keras.layers` versions along with + the migration to `tf.keras.Model` is recommended, since it will preserve + variable names. Feel free to import it with an alias to avoid excess typing + :). + `Network` implements the `Layer` interface and adds convenience methods for managing sub-`Layer`s, such as listing variables. @@ -111,6 +145,7 @@ class Network(base.Layer): # - Detect layers used in __call__ that weren't registered with track_layer. # - Convert inputs to __call__ to tensors. + @deprecation.deprecated(date=None, instructions=_NETWORK_DEPRECATION_MESSAGE) def __init__(self, name=None): """Configure the `Network`. @@ -129,6 +164,10 @@ class Network(base.Layer): ValueError: If `name` is not valid. Note that some naming errors will instead be raised when the `Network` is called. """ + if context.executing_eagerly(): + logging.warning( + ("** tfe.Network is deprecated and will be removed in a future " + "version.\n\n%s") % _NETWORK_DEPRECATION_MESSAGE) if isinstance(name, variable_scope.VariableScope): raise ValueError("VariableScopes are not valid Network names.") if name is not None and "/" in name: @@ -149,7 +188,12 @@ class Network(base.Layer): # check we might have name collisions if the parent scope on init gets # closed before build is called. self._variable_scope_counts_on_init = ( - variable_scope._get_default_variable_store().variable_scopes_count) + variable_scope.get_variable_scope_store().variable_scopes_count) + + def _gather_saveables_for_checkpoint(self): + raise NotImplementedError( + "tfe.Network does not support object-based checkpointing.\n\n%s" + % _NETWORK_DEPRECATION_MESSAGE) def _name_scope_name(self, current_variable_scope): """Overrides Layer op naming to match variable naming.""" @@ -176,7 +220,7 @@ class Network(base.Layer): avoid_names = parent_network._owned_layers name_uid_map = parent_network._sub_layer_name_uids else: - name_uid_map = base._get_default_graph_uid_map() + name_uid_map = keras_base_layer.get_default_graph_uid_map() # Figure out which names we have to avoid based on which variable scope # we're nested in. strip_name = self._default_parent_variable_scope.name @@ -326,6 +370,8 @@ class Network(base.Layer): raise TypeError( "Network.track_layer() passed type %s, not a tf.layers.Layer" % (type(layer),)) + # Always use `ResourceVariable` with legacy layers. + layer._use_resource_variables = True if isinstance(layer, Network): layer._finalize_name(parent_network=self) else: @@ -499,10 +545,10 @@ class Sequential(Network): def add(self, layer_func): if isinstance(layer_func, base.Layer): - args = estimator_util.fn_args(layer_func.call) + args = function_utils.fn_args(layer_func.call) self.track_layer(layer_func) elif callable(layer_func): - args = estimator_util.fn_args(layer_func) + args = function_utils.fn_args(layer_func) else: raise TypeError( "Sequential.add() takes only tf.layers.Layer objects or callables; " @@ -639,7 +685,7 @@ def _make_custom_getter_for_deferred_restorations(): # Mark as already restored from this checkpoint. delayed_restoration.checkpointed_variables_to_restore[ checkpoint_name] = None - if context.in_graph_mode(): + if not context.executing_eagerly(): delayed_restoration.session.run(variable.initializer) if found_value: # Error checking should run even if we've already restored a value. @@ -703,6 +749,9 @@ def _make_prefix_stripping_map_fn(scope_name): return _strip_variable_prefix +@deprecation.deprecated(date=None, instructions=( + "Please inherit from tf.keras.Model instead of tfe.Network, and use " + "tf.keras.Model.save_weights.")) def save_network_checkpoint( network, save_path, global_step=None, map_func=None): """Save variables from the Network to a checkpoint. @@ -772,7 +821,7 @@ def save_network_checkpoint( variable_map[mapped_name]._shared_name, variable._shared_name, network.scope_name)) - if context.in_eager_mode(): + if context.executing_eagerly(): sess = None else: sess = ops.get_default_session() @@ -853,7 +902,7 @@ def _restore_existing_variables(network, save_path, map_func, user_map_func): network_name=network.name, network_scope_name=network.scope_name)) if existing_variables_by_checkpoint_name: - if context.in_eager_mode(): + if context.executing_eagerly(): sess = None else: sess = ops.get_default_session() @@ -880,7 +929,7 @@ def _set_restore_on_create(network, save_path, map_func, user_map_func, # _DeferredRestoration objects once a Network has been built (so that # restoring in a loop does not take increasing amounts of memory). if checkpointed_variables_to_restore: - if context.in_eager_mode(): + if context.executing_eagerly(): sess = None else: sess = ops.get_default_session() @@ -902,6 +951,9 @@ def _set_restore_on_create(network, save_path, map_func, user_map_func, _add_deferred_restoration(network, deferred_restoration) +@deprecation.deprecated(date=None, instructions=( + "Please inherit from tf.keras.Model instead of tfe.Network, and use " + "tf.keras.Model.load_weights.")) def restore_network_checkpoint(network, save_path, map_func=None): """Restore the Network from a checkpoint. diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index 3329fc6c513265deff41a368f5688dd605209c14..c92bd15b253b67a3301cd562046a4467e1bf877d 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -20,12 +20,10 @@ import gc from tensorflow.contrib.eager.python import network from tensorflow.contrib.layers.python.layers import regularizers -from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl -from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.layers import core from tensorflow.python.ops import math_ops @@ -33,6 +31,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import training_util +from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: disable=not-callable @@ -64,6 +63,12 @@ class RegularizedNetwork(network.Network): class NetworkTest(test.TestCase): + def test_checkpointing_not_implemented(self): + checkpoint_directory = self.get_temp_dir() + checkpoint = checkpointable_utils.Checkpoint(net=MyNetwork()) + with self.assertRaises(NotImplementedError): + checkpoint.save(checkpoint_directory) + def _save_modify_load_network_built(self, net, global_step=None): checkpoint_directory = self.get_temp_dir() checkpoint_path = network.save_network_checkpoint( @@ -469,36 +474,6 @@ class NetworkTest(test.TestCase): self.assertIsInstance(net.trainable_weights[0], resource_variable_ops.ResourceVariable) - def testGraphOpNames(self): - """Network operation names should match variable naming.""" - - def _check_op_prefixes(expected_prefix, checked_ops): - for operation in ops.get_default_graph().get_operations(): - if operation.name == "ignore": - continue - if operation.name in checked_ops: - continue - checked_ops.add(operation.name) - self.assertStartsWith(expected_start=expected_prefix, - actual=operation.name) - self.assertNotIn("my_network", operation.name[len(expected_prefix):]) - self.assertNotIn("dense", operation.name[len(expected_prefix):]) - - with context.graph_mode(): - net = MyNetwork() - zero = constant_op.constant([[0.]], name="ignore") - net(zero) - checked_ops = set() - _check_op_prefixes(expected_prefix="my_network/dense/", - checked_ops=checked_ops) - net.net2 = net.track_layer(MyNetwork()) - net.net2(zero) - _check_op_prefixes(expected_prefix="my_network/my_network/dense/", - checked_ops=checked_ops) - MyNetwork()(zero) - _check_op_prefixes(expected_prefix="my_network_1/dense/", - checked_ops=checked_ops) - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testVariableRegularizers(self): net = RegularizedNetwork() diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index 62421849c766a1124c726812428985c913c653a3..fdaca90fd13576e6ca8a3408aaf528dbc2384b0c 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -73,7 +73,7 @@ def restore_variables_on_create(save_path, map_func=None): NotFoundError: If the variable is not found in checkpoint. ValueError: If not used in eager mode or map_func is not callable. """ - if context.in_graph_mode(): + if not context.executing_eagerly(): raise ValueError( "Currently, restore_variables_on_create can only be used with " "eager execution enabled.") @@ -131,7 +131,7 @@ class Saver(object): Raises: RuntimeError: if invoked when eager execution has not been enabled. """ - if context.in_graph_mode(): + if not context.executing_eagerly(): raise RuntimeError("tfe.Saver can only be used when eager " "execution is enabled. Use tf.train.Saver when " "building graphs.") diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index 1a7f7b85e688e80e3cf482f2754462888187d311..90a3711475719a7f991473c6c9067da1e76ab9f2 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -60,15 +60,9 @@ class SaverTest(test.TestCase): def testSameNameNoClobbering(self): with ops.device(self._dev()): - # Note that this test purposefully uses Graphs rather than - # IsolateTest. Users are more likely to accidentally create the same - # variable name this way. - first_graph = ops.Graph() - with first_graph.as_default(): - v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1') - with ops.Graph().as_default(): - v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1') - saver = _saver.Saver([v1_first_graph, v1_second_graph]) + v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') + v2 = resource_variable_ops.ResourceVariable(2.0, name='v1') + saver = _saver.Saver([v1, v2]) ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') with self.assertRaisesRegexp(ValueError, 'v1'): saver.save(ckpt_prefix) @@ -102,7 +96,6 @@ class SaverTest(test.TestCase): # Can still restore it. saver.restore(ckpt_prefix) self.assertEqual(v1.read_value().numpy(), 1.0) - self.assertEqual(v1.read_value().numpy(), 1.0) # However, cannot restore it with default name. with self.assertRaisesOpError('not found in checkpoint'): saver = _saver.Saver([v1, v2]).restore(ckpt_prefix) @@ -127,12 +120,11 @@ class SaverTest(test.TestCase): saver = _saver.Saver([v1]) saver.save(ckpt_prefix) - with ops.Graph().as_default(): - saver = _saver.Saver([v1]) - with _saver.restore_variables_on_create(ckpt_prefix): - # Value is from checkpoint, but not from argument. - ret, _ = model(2.0) - self.assertEqual(ret.numpy(), 1.0) + saver = _saver.Saver([v1]) + with _saver.restore_variables_on_create(ckpt_prefix): + # Value is from checkpoint, but not from argument. + ret, _ = model(2.0) + self.assertEqual(ret.numpy(), 1.0) def testRestoreNotFound(self): with ops.device(self._dev()): @@ -185,17 +177,17 @@ class SaverTest(test.TestCase): 4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) # reset the graph and reload on create, so that 1 + 2 = 3 - with ops.Graph().as_default(): - with _saver.restore_variables_on_create(ckpt_prefix): - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) - def model2(x): - v = variable_scope.get_variable( - 'v', initializer=init_ops.zeros_initializer(), shape=()) - return v + x - - self.assertEqual( - 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + ops.reset_default_graph() + with _saver.restore_variables_on_create(ckpt_prefix): + @graph_callable.graph_callable( + [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) + def model2(x): + v = variable_scope.get_variable( + 'v', initializer=init_ops.zeros_initializer(), shape=()) + return v + x + + self.assertEqual( + 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy()) class GetOptimizerTests(test.TestCase): diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index d32bebf90c1e768d1efec26b3b78bf1a522a8f00..fee9db46fa4f79d7dd613436726e8ddad51faf1c 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -56,14 +56,24 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@save_network_checkpoint @@restore_network_checkpoint +@@Checkpoint +@@Checkpointable +@@CheckpointableSaver + +@@executing_eagerly @@in_eager_mode -@@in_graph_mode +@@set_execution_mode +@@execution_mode +@@async_wait +@@async_clear_error @@run_test_in_graph_and_eager_modes @@DEVICE_PLACEMENT_EXPLICIT @@DEVICE_PLACEMENT_WARN @@DEVICE_PLACEMENT_SILENT +@@SYNC +@@ASYNC """ from __future__ import absolute_import @@ -87,11 +97,15 @@ from tensorflow.python.eager import function from tensorflow.python.eager.context import DEVICE_PLACEMENT_EXPLICIT from tensorflow.python.eager.context import DEVICE_PLACEMENT_WARN from tensorflow.python.eager.context import DEVICE_PLACEMENT_SILENT -from tensorflow.python.eager.context import in_eager_mode -from tensorflow.python.eager.context import in_graph_mode +from tensorflow.python.eager.context import executing_eagerly from tensorflow.python.eager.context import list_devices +from tensorflow.python.eager.context import set_execution_mode +from tensorflow.python.eager.context import execution_mode +from tensorflow.python.eager.context import async_wait +from tensorflow.python.eager.context import async_clear_error +from tensorflow.python.eager.context import SYNC +from tensorflow.python.eager.context import ASYNC from tensorflow.python.eager.context import num_gpus -from tensorflow.python.eager.custom_gradient import custom_gradient from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks from tensorflow.python.eager.execution_callbacks import inf_callback @@ -101,10 +115,15 @@ from tensorflow.python.eager.execution_callbacks import seterr from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import eager_run as run from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes +from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes as run_all_tests_in_graph_and_eager_modes +from tensorflow.python.ops.custom_gradient import custom_gradient from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template +from tensorflow.python.training.checkpointable.base import Checkpointable +from tensorflow.python.training.checkpointable.util import CheckpointableSaver +from tensorflow.python.training.checkpointable.util import Checkpoint from tensorflow.python.util.all_util import remove_undocumented py_func = script_ops.eager_py_func @@ -115,5 +134,6 @@ implicit_value_and_gradients = backprop.implicit_val_and_grad gradients_function = backprop.gradients_function value_and_gradients_function = backprop.val_and_grad_function GradientTape = backprop.GradientTape # pylint: disable=invalid-name +in_eager_mode = executing_eagerly remove_undocumented(__name__) diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index b6659c2a1797feab261d756e78b45231dbea5a02..db50b33af2e4f1cc6575d4b0d416d6d2669b5c35 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -47,7 +47,8 @@ class TFETest(test_util.TensorFlowTestCase): def testVariableError(self): with self.assertRaisesRegexp( - RuntimeError, r'Variable not supported in Eager mode'): + RuntimeError, + r'Variable not supported when eager execution is enabled'): variables.Variable(initial_value=1.0) def testGradients(self): @@ -56,7 +57,7 @@ class TFETest(test_util.TensorFlowTestCase): return math_ops.multiply(x, x) grad = tfe.gradients_function(square) - self.assertEquals([6], [x.numpy() for x in grad(3)]) + self.assertEquals([6], [x.numpy() for x in grad(3.)]) def testGradOfGrad(self): @@ -65,7 +66,7 @@ class TFETest(test_util.TensorFlowTestCase): grad = tfe.gradients_function(square) gradgrad = tfe.gradients_function(lambda x: grad(x)[0]) - self.assertEquals([2], [x.numpy() for x in gradgrad(3)]) + self.assertEquals([2], [x.numpy() for x in gradgrad(3.)]) def testCustomGrad(self): @@ -79,7 +80,7 @@ class TFETest(test_util.TensorFlowTestCase): return y, grad_fn grad = tfe.gradients_function(f) - self.assertEquals([12], [x.numpy() for x in grad(3)]) + self.assertEquals([12], [x.numpy() for x in grad(3.)]) def testGPU(self): if tfe.num_gpus() <= 0: diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 6cdbed5b896577f5622b1bd0123c289c798bc0a5..1937ffb583bc727df76470d072b35fb3c9acaa88 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -9,35 +9,101 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - py_library( name = "estimator_py", srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ + ":baseline", + ":boosted_trees", ":dnn", ":dnn_linear_combined", + ":export", ":extenders", ":head", + ":hooks", ":linear", ":logit_fns", ":multi_head", ":replicate_model_fn", + ":rnn", "//tensorflow/python:util", ], ) +py_library( + name = "baseline", + srcs = ["python/estimator/baseline.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:baseline", + ], +) + +py_test( + name = "baseline_test", + size = "small", + srcs = ["python/estimator/baseline_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + ":baseline", + ":head", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:session", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:metric_keys", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_library( + name = "boosted_trees", + srcs = ["python/estimator/boosted_trees.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:boosted_trees", + ], +) + +py_test( + name = "boosted_trees_test", + size = "medium", + srcs = ["python/estimator/boosted_trees_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + ":boosted_trees", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:training", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + ], +) + py_library( name = "dnn", srcs = ["python/estimator/dnn.py"], @@ -57,6 +123,7 @@ py_test( tags = [ "no_pip", "notsan", + "optonly", # times out http://b/79220679 ], deps = [ ":dnn", @@ -70,6 +137,7 @@ py_test( "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -110,6 +178,7 @@ py_test( "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -138,9 +207,11 @@ py_test( size = "medium", srcs = ["python/estimator/extenders_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], # b/62863147 deps = [ ":extenders", "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/predictor", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", @@ -155,6 +226,43 @@ py_test( ], ) +py_library( + name = "export", + srcs = [ + "python/estimator/export.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator:model_fn", + ], +) + +py_test( + name = "export_test", + size = "medium", + srcs = ["python/estimator/export_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], # b/62863147 + deps = [ + ":export", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:metrics", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:session", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/saved_model:tag_constants", + ], +) + py_library( name = "head", srcs = [ @@ -169,9 +277,11 @@ py_library( "//tensorflow/python:lookup_ops", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", + "//tensorflow/python:nn", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:summary", + "//tensorflow/python:training", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:metric_keys", @@ -184,13 +294,14 @@ py_library( py_test( name = "head_test", - size = "small", + size = "medium", srcs = ["python/estimator/head_test.py"], srcs_version = "PY2AND3", deps = [ ":head", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", @@ -201,6 +312,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:training", + "//tensorflow/python:variables", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", @@ -211,6 +323,37 @@ py_test( ], ) +py_library( + name = "hooks", + srcs = [ + "python/estimator/hooks.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/estimator:estimator_py", + ], +) + +py_test( + name = "hooks_test", + size = "medium", + srcs = ["python/estimator/hooks_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":hooks", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator:estimator_py", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "linear", srcs = ["python/estimator/linear.py"], @@ -223,7 +366,7 @@ py_library( py_test( name = "linear_test", - size = "small", + size = "medium", srcs = ["python/estimator/linear_test.py"], srcs_version = "PY2AND3", tags = [ @@ -242,6 +385,7 @@ py_test( "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -255,9 +399,9 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:framework_ops", + "//tensorflow/python:util", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:linear", - "//tensorflow/python/estimator:util", ], ) @@ -288,6 +432,8 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:metrics", "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", @@ -337,6 +483,7 @@ py_library( "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:training", + "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:model_fn", @@ -351,6 +498,7 @@ cuda_py_test( size = "medium", srcs = ["python/estimator/replicate_model_fn_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", "//tensorflow/python/estimator", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:export_export", @@ -382,3 +530,63 @@ cuda_py_test( "notap", ], ) + +py_library( + name = "rnn", + srcs = ["python/estimator/rnn.py"], + srcs_version = "PY2AND3", + deps = [ + ":extenders", + "//tensorflow/contrib/feature_column:feature_column_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:rnn", + "//tensorflow/python:rnn_cell", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:optimizers", + "//tensorflow/python/feature_column", + "@six_archive//:six", + ], +) + +py_test( + name = "rnn_test", + size = "medium", + srcs = ["python/estimator/rnn_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "noasan", # times out + "notsan", + "optonly", # times out http://b/79220679 + ], + deps = [ + ":head", + ":rnn", + "//tensorflow/contrib/data", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:parsing_utils", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 0f75b77050b0ba4c752a6a74fdc7024170b6f318..788ac5ca7046d6dd30a3d5520b243944532622fa 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -19,14 +19,19 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.estimator.python.estimator.baseline import * +from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * +from tensorflow.contrib.estimator.python.estimator.export import * from tensorflow.contrib.estimator.python.estimator.extenders import * from tensorflow.contrib.estimator.python.estimator.head import * +from tensorflow.contrib.estimator.python.estimator.hooks import * from tensorflow.contrib.estimator.python.estimator.linear import * from tensorflow.contrib.estimator.python.estimator.logit_fns import * from tensorflow.contrib.estimator.python.estimator.multi_head import * from tensorflow.contrib.estimator.python.estimator.replicate_model_fn import * +from tensorflow.contrib.estimator.python.estimator.rnn import * from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import @@ -36,18 +41,28 @@ _allowed_symbols = [ 'binary_classification_head', 'clip_gradients_by_norm', 'forward_features', + 'InMemoryEvaluatorHook', + 'logistic_regression_head', 'multi_class_head', 'multi_head', 'multi_label_head', + 'poisson_regression_head', 'regression_head', + 'BaselineEstimator', 'DNNEstimator', 'DNNLinearCombinedEstimator', 'LinearEstimator', + 'boosted_trees_classifier_train_in_memory', + 'boosted_trees_regressor_train_in_memory', 'call_logit_fn', 'dnn_logit_fn_builder', 'linear_logit_fn_builder', 'replicate_model_fn', 'TowerOptimizer', + 'RNNClassifier', + 'RNNEstimator', + 'export_saved_model_for_mode', + 'export_all_saved_models', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/baseline.py b/tensorflow/contrib/estimator/python/estimator/baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..beffbee73064b9ef425b115317c43e29477b19af --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/baseline.py @@ -0,0 +1,98 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Baseline estimators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import baseline + + +class BaselineEstimator(estimator.Estimator): + """An estimator that can establish a simple baseline. + + The estimator uses a user-specified head. + + This estimator ignores feature values and will learn to predict the average + value of each label. E.g. for single-label classification problems, this will + predict the probability distribution of the classes as seen in the labels. + For multi-label classification problems, it will predict the ratio of examples + that contain each class. + + Example: + + ```python + + # Build baseline multi-label classifier. + estimator = BaselineEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3)) + + # Input builders + def input_fn_train: # returns x, y (where y represents label's class index). + pass + + def input_fn_eval: # returns x, y (where y represents label's class index). + pass + + # Fit model. + estimator.train(input_fn=input_fn_train) + + # Evaluates cross entropy between the test and train labels. + loss = classifier.evaluate(input_fn=input_fn_eval)["loss"] + + # For each class, predicts the ratio of training examples that contain the + # class. + predictions = classifier.predict(new_samples) + + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * if `weight_column` passed to the `head` constructor is not `None`, a feature + with `key=weight_column` whose value is a `Tensor`. + """ + + def __init__(self, + head, + model_dir=None, + optimizer='Ftrl', + config=None): + """Initializes a BaselineEstimator instance. + + Args: + head: A `_Head` instance constructed with a method such as + `tf.contrib.estimator.multi_label_head`. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + optimizer: String, `tf.Optimizer` object, or callable that creates the + optimizer to use for training. If not specified, will use + `FtrlOptimizer` with a default learning rate of 0.3. + config: `RunConfig` object to configure the runtime settings. + """ + def _model_fn(features, labels, mode, config): + return baseline._baseline_model_fn( # pylint: disable=protected-access + features=features, + labels=labels, + mode=mode, + head=head, + optimizer=optimizer, + config=config) + super(BaselineEstimator, self).__init__( + model_fn=_model_fn, + model_dir=model_dir, + config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d0e3e670f7332811c1bfdaea65b0308ce59ade59 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py @@ -0,0 +1,430 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for baseline.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import baseline +from tensorflow.contrib.estimator.python.estimator import head as head_lib +from tensorflow.python.client import session as tf_session +from tensorflow.python.estimator.canned import metric_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import optimizer +from tensorflow.python.training import saver + +# Names of variables created by model. +BIAS_NAME = 'baseline/bias' + + +def assert_close(expected, actual, rtol=1e-04, name='assert_close'): + with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope: + expected = ops.convert_to_tensor(expected, name='expected') + actual = ops.convert_to_tensor(actual, name='actual') + rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected) + rtol = ops.convert_to_tensor(rtol, name='rtol') + return check_ops.assert_less( + rdiff, + rtol, + data=('Condition expected =~ actual did not hold element-wise:' + 'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff, + 'rtol = ', rtol,), + name=scope) + + +def save_variables_to_ckpt(model_dir): + init_all_op = [variables.global_variables_initializer()] + with tf_session.Session() as sess: + sess.run(init_all_op) + saver.Saver().save(sess, os.path.join(model_dir, 'model.ckpt')) + + +def _baseline_estimator_fn( + weight_column=None, label_dimension=1, *args, **kwargs): + """Returns a BaselineEstimator that uses regression_head.""" + return baseline.BaselineEstimator( + head=head_lib.regression_head( + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), + *args, **kwargs) + + +class BaselineEstimatorEvaluationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def test_evaluation_batch(self): + """Tests evaluation for batch_size==2.""" + with ops.Graph().as_default(): + variables.Variable([13.0], name=BIAS_NAME) + variables.Variable( + 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir) + eval_metrics = baseline_estimator.evaluate( + input_fn=lambda: ({'age': ((1,), (1,))}, ((10.,), (10.,))), steps=1) + + # Logit is bias = 13, while label is 10. + # Loss per example is 3**2 = 9. + # Training loss is the sum over batch = 9 + 9 = 18 + # Average loss is the average over batch = 9 + self.assertDictEqual({ + metric_keys.MetricKeys.LOSS: 18., + metric_keys.MetricKeys.LOSS_MEAN: 9., + ops.GraphKeys.GLOBAL_STEP: 100 + }, eval_metrics) + + def test_evaluation_weights(self): + """Tests evaluation with weights.""" + with ops.Graph().as_default(): + variables.Variable([13.0], name=BIAS_NAME) + variables.Variable( + 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + def _input_fn(): + features = {'age': ((1,), (1,)), 'weights': ((1.,), (2.,))} + labels = ((10.,), (10.,)) + return features, labels + + baseline_estimator = _baseline_estimator_fn( + weight_column='weights', + model_dir=self._model_dir) + eval_metrics = baseline_estimator.evaluate(input_fn=_input_fn, steps=1) + + # Logit is bias = 13, while label is 10. + # Loss per example is 3**2 = 9. + # Training loss is the weighted sum over batch = 9 + 2*9 = 27 + # average loss is the weighted average = 9 + 2*9 / (1 + 2) = 9 + self.assertDictEqual({ + metric_keys.MetricKeys.LOSS: 27., + metric_keys.MetricKeys.LOSS_MEAN: 9., + ops.GraphKeys.GLOBAL_STEP: 100 + }, eval_metrics) + + def test_evaluation_for_multi_dimensions(self): + label_dim = 2 + with ops.Graph().as_default(): + variables.Variable([46.0, 58.0], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn( + label_dimension=label_dim, + model_dir=self._model_dir) + input_fn = numpy_io.numpy_input_fn( + x={ + 'age': np.array([[2., 4., 5.]]), + }, + y=np.array([[46., 58.]]), + batch_size=1, + num_epochs=None, + shuffle=False) + eval_metrics = baseline_estimator.evaluate(input_fn=input_fn, steps=1) + + self.assertItemsEqual( + (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN, + ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys()) + + # Logit is bias which is [46, 58] + self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS]) + + +class BaselineEstimatorPredictTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def test_1d(self): + """Tests predict when all variables are one-dimensional.""" + with ops.Graph().as_default(): + variables.Variable([.2], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir) + + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': np.array([[2.]])}, + y=None, + batch_size=1, + num_epochs=1, + shuffle=False) + predictions = baseline_estimator.predict(input_fn=predict_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + # x * weight + bias = 2. * 10. + .2 = 20.2 + self.assertAllClose([[.2]], predicted_scores) + + def testMultiDim(self): + """Tests predict when all variables are multi-dimenstional.""" + batch_size = 2 + label_dimension = 3 + with ops.Graph().as_default(): + variables.Variable( # shape=[label_dimension] + [.2, .4, .6], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn( + label_dimension=label_dimension, + model_dir=self._model_dir) + + predict_input_fn = numpy_io.numpy_input_fn( + # x shape=[batch_size, x_dim] + x={'x': np.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])}, + y=None, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + predictions = baseline_estimator.predict(input_fn=predict_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + # score = bias, shape=[batch_size, label_dimension] + self.assertAllClose([[0.2, 0.4, 0.6], [0.2, 0.4, 0.6]], + predicted_scores) + + +class BaselineEstimatorIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn, + input_dimension, label_dimension, prediction_length): + feature_columns = [ + feature_column_lib.numeric_column('x', shape=(input_dimension,)) + ] + est = _baseline_estimator_fn( + label_dimension=label_dimension, + model_dir=self._model_dir) + + # TRAIN + # learn y = x + est.train(train_input_fn, steps=200) + + # EVALUTE + scores = est.evaluate(eval_input_fn) + self.assertEqual(200, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores)) + + # PREDICT + predictions = np.array( + [x['predictions'] for x in est.predict(predict_input_fn)]) + self.assertAllEqual((prediction_length, label_dimension), predictions.shape) + + # EXPORT + feature_spec = feature_column_lib.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = est.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def test_numpy_input_fn(self): + """Tests complete flow with numpy_input_fn.""" + label_dimension = 2 + input_dimension = label_dimension + batch_size = 10 + prediction_length = batch_size + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + data = data.reshape(batch_size, label_dimension) + + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=None, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + input_dimension=input_dimension, + label_dimension=label_dimension, + prediction_length=prediction_length) + + +class BaselineEstimatorTrainingTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _mock_optimizer(self, expected_loss=None): + expected_var_names = [ + '%s:0' % BIAS_NAME + ] + + def _minimize(loss, global_step=None, var_list=None): + trainable_vars = var_list or ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual(expected_var_names, + [var.name for var in trainable_vars]) + + # Verify loss. We can't check the value directly, so we add an assert op. + self.assertEquals(0, loss.shape.ndims) + if expected_loss is None: + if global_step is not None: + return distribute_lib.increment_var(global_step) + return control_flow_ops.no_op() + assert_loss = assert_close( + math_ops.to_float(expected_loss, name='expected'), + loss, + name='assert_loss') + with ops.control_dependencies((assert_loss,)): + if global_step is not None: + return distribute_lib.increment_var(global_step) + return control_flow_ops.no_op() + + mock_optimizer = test.mock.NonCallableMock( + spec=optimizer.Optimizer, + wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) + mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize) + + # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks. + # So, return mock_optimizer itself for deepcopy. + mock_optimizer.__deepcopy__ = lambda _: mock_optimizer + return mock_optimizer + + def _assert_checkpoint(self, + label_dimension, + expected_global_step, + expected_bias=None): + shapes = { + name: shape + for (name, shape) in checkpoint_utils.list_variables(self._model_dir) + } + + self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP]) + self.assertEqual(expected_global_step, + checkpoint_utils.load_variable(self._model_dir, + ops.GraphKeys.GLOBAL_STEP)) + + self.assertEqual([label_dimension], shapes[BIAS_NAME]) + if expected_bias is not None: + self.assertEqual(expected_bias, + checkpoint_utils.load_variable(self._model_dir, + BIAS_NAME)) + + def testFromScratch(self): + # Create BaselineRegressor. + label = 5. + age = 17 + # loss = (logits - label)^2 = (0 - 5.)^2 = 25. + mock_optimizer = self._mock_optimizer(expected_loss=25.) + baseline_estimator = _baseline_estimator_fn( + model_dir=self._model_dir, + optimizer=mock_optimizer) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + baseline_estimator.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + label_dimension=1, + expected_global_step=num_steps, + expected_bias=[0.]) + + def testFromCheckpoint(self): + # Create initial checkpoint. + bias = 7.0 + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable([bias], name=BIAS_NAME) + variables.Variable( + initial_global_step, + name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + # logits = bias = 6. + # loss = (logits - label)^2 = (7 - 5)^2 = 4 + mock_optimizer = self._mock_optimizer(expected_loss=4.) + baseline_estimator = _baseline_estimator_fn( + model_dir=self._model_dir, + optimizer=mock_optimizer) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + baseline_estimator.train( + input_fn=lambda: ({'age': ((17,),)}, ((5.,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + label_dimension=1, + expected_global_step=initial_global_step + num_steps, + expected_bias=[bias]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py new file mode 100644 index 0000000000000000000000000000000000000000..bd641014e9eec6623d66574bccd08ff03ebc28ac --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py @@ -0,0 +1,357 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Boosted Trees estimators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees + + +def _validate_input_fn_and_repeat_dataset(train_input_fn): + """Validates whether the input_fn is valid, and repeat() if tf.Dataset.""" + def _input_fn(): + result_input_fn = train_input_fn() + if isinstance(result_input_fn, dataset_ops.Dataset): + return result_input_fn.repeat() + return result_input_fn + + return _input_fn + + +class _BoostedTreesEstimator(estimator.Estimator): + """An Estimator for Tensorflow Boosted Trees models.""" + + def __init__(self, + feature_columns, + n_batches_per_layer, + head, + model_dir=None, + weight_column=None, + n_trees=100, + max_depth=6, + learning_rate=0.1, + l1_regularization=0., + l2_regularization=0., + tree_complexity=0., + min_node_weight=0., + config=None): + """Initializes a `BoostedTreesEstimator` instance. + + Args: + feature_columns: An iterable containing all the feature columns used by + the model. All items in the set should be instances of classes derived + from `FeatureColumn`. + n_batches_per_layer: the number of batches to collect statistics per + layer. + head: the `Head` instance defined for Estimator. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to downweight or boost examples during training. It + will be multiplied by the loss of the example. If it is a string, it is + used as a key to fetch weight tensor from the `features`. If it is a + `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, + then weight_column.normalizer_fn is applied on it to get weight tensor. + n_trees: number trees to be created. + max_depth: maximum depth of the tree to grow. + learning_rate: shrinkage parameter to be used when a tree added to the + model. + l1_regularization: regularization multiplier applied to the absolute + weights of the tree leafs. + l2_regularization: regularization multiplier applied to the square weights + of the tree leafs. + tree_complexity: regularization factor to penalize trees with more leaves. + min_node_weight: minimum hessian a node must have for a split to be + considered. The value will be compared with sum(leaf_hessian)/ + (batch_size * n_batches_per_layer). + config: `RunConfig` object to configure the runtime settings. + """ + # pylint:disable=protected-access + # HParams for the model. + tree_hparams = canned_boosted_trees._TreeHParams( + n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, + tree_complexity, min_node_weight) + + def _model_fn(features, labels, mode, config): + return canned_boosted_trees._bt_model_fn( + features, labels, mode, head, feature_columns, tree_hparams, + n_batches_per_layer, config) + + super(_BoostedTreesEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) + # pylint:enable=protected-access + + +def boosted_trees_classifier_train_in_memory( + train_input_fn, + feature_columns, + model_dir=None, + n_classes=canned_boosted_trees._HOLD_FOR_MULTI_CLASS_SUPPORT, + weight_column=None, + label_vocabulary=None, + n_trees=100, + max_depth=6, + learning_rate=0.1, + l1_regularization=0., + l2_regularization=0., + tree_complexity=0., + min_node_weight=0., + config=None, + train_hooks=None): + """Trains a boosted tree classifier with in memory dataset. + + Example: + + ```python + bucketized_feature_1 = bucketized_column( + numeric_column('feature_1'), BUCKET_BOUNDARIES_1) + bucketized_feature_2 = bucketized_column( + numeric_column('feature_2'), BUCKET_BOUNDARIES_2) + + def train_input_fn(): + dataset = create-dataset-from-training-data + # This is tf.data.Dataset of a tuple of feature dict and label. + # e.g. Dataset.zip((Dataset.from_tensors({'f1': f1_array, ...}), + # Dataset.from_tensors(label_array))) + # The returned Dataset shouldn't be batched. + # If Dataset repeats, only the first repetition would be used for training. + return dataset + + classifier = boosted_trees_classifier_train_in_memory( + train_input_fn, + feature_columns=[bucketized_feature_1, bucketized_feature_2], + n_trees=100, + ... + ) + + def input_fn_eval(): + ... + return dataset + + metrics = classifier.evaluate(input_fn=input_fn_eval, steps=10) + ``` + + Args: + train_input_fn: the input function returns a dataset containing a single + epoch of *unbatched* features and labels. + feature_columns: An iterable containing all the feature columns used by + the model. All items in the set should be instances of classes derived + from `FeatureColumn`. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + n_classes: number of label classes. Default is binary classification. + Multiclass support is not yet implemented. + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to downweight or boost examples during training. It + will be multiplied by the loss of the example. If it is a string, it is + used as a key to fetch weight tensor from the `features`. If it is a + `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, + then weight_column.normalizer_fn is applied on it to get weight tensor. + label_vocabulary: A list of strings represents possible label values. If + given, labels must be string type and have any value in + `label_vocabulary`. If it is not given, that means labels are + already encoded as integer or float within [0, 1] for `n_classes=2` and + encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . + Also there will be errors if vocabulary is not provided and labels are + string. + n_trees: number trees to be created. + max_depth: maximum depth of the tree to grow. + learning_rate: shrinkage parameter to be used when a tree added to the + model. + l1_regularization: regularization multiplier applied to the absolute + weights of the tree leafs. + l2_regularization: regularization multiplier applied to the square weights + of the tree leafs. + tree_complexity: regularization factor to penalize trees with more leaves. + min_node_weight: minimum hessian a node must have for a split to be + considered. The value will be compared with sum(leaf_hessian)/ + (batch_size * n_batches_per_layer). + config: `RunConfig` object to configure the runtime settings. + train_hooks: a list of Hook instances to be passed to estimator.train(). + + Returns: + a `BoostedTreesClassifier` instance created with the given arguments and + trained with the data loaded up on memory from the input_fn. + + Raises: + ValueError: when wrong arguments are given or unsupported functionalities + are requested. + """ + # pylint: disable=protected-access + # TODO(nponomareva): Support multi-class cases. + if n_classes == canned_boosted_trees._HOLD_FOR_MULTI_CLASS_SUPPORT: + n_classes = 2 + head, closed_form = ( + canned_boosted_trees._create_classification_head_and_closed_form( + n_classes, weight_column, label_vocabulary=label_vocabulary)) + + # HParams for the model. + tree_hparams = canned_boosted_trees._TreeHParams( + n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, + tree_complexity, min_node_weight) + + def _model_fn(features, labels, mode, config): + return canned_boosted_trees._bt_model_fn( + features, + labels, + mode, + head, + feature_columns, + tree_hparams, + n_batches_per_layer=1, + config=config, + closed_form_grad_and_hess_fn=closed_form, + train_in_memory=True) + + in_memory_classifier = estimator.Estimator( + model_fn=_model_fn, model_dir=model_dir, config=config) + + in_memory_classifier.train( + input_fn=_validate_input_fn_and_repeat_dataset(train_input_fn), + hooks=train_hooks) + + return in_memory_classifier + # pylint: enable=protected-access + + +def boosted_trees_regressor_train_in_memory( + train_input_fn, + feature_columns, + model_dir=None, + label_dimension=canned_boosted_trees._HOLD_FOR_MULTI_DIM_SUPPORT, + weight_column=None, + n_trees=100, + max_depth=6, + learning_rate=0.1, + l1_regularization=0., + l2_regularization=0., + tree_complexity=0., + min_node_weight=0., + config=None, + train_hooks=None): + """Trains a boosted tree regressor with in memory dataset. + + Example: + + ```python + bucketized_feature_1 = bucketized_column( + numeric_column('feature_1'), BUCKET_BOUNDARIES_1) + bucketized_feature_2 = bucketized_column( + numeric_column('feature_2'), BUCKET_BOUNDARIES_2) + + def train_input_fn(): + dataset = create-dataset-from-training-data + # This is tf.data.Dataset of a tuple of feature dict and label. + # e.g. Dataset.zip((Dataset.from_tensors({'f1': f1_array, ...}), + # Dataset.from_tensors(label_array))) + # The returned Dataset shouldn't be batched. + # If Dataset repeats, only the first repetition would be used for training. + return dataset + + regressor = boosted_trees_regressor_train_in_memory( + train_input_fn, + feature_columns=[bucketized_feature_1, bucketized_feature_2], + n_trees=100, + ... + ) + + def input_fn_eval(): + ... + return dataset + + metrics = regressor.evaluate(input_fn=input_fn_eval, steps=10) + ``` + + Args: + train_input_fn: the input function returns a dataset containing a single + epoch of *unbatched* features and labels. + feature_columns: An iterable containing all the feature columns used by + the model. All items in the set should be instances of classes derived + from `FeatureColumn`. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + label_dimension: Number of regression targets per example. + Multi-dimensional support is not yet implemented. + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to downweight or boost examples during training. It + will be multiplied by the loss of the example. If it is a string, it is + used as a key to fetch weight tensor from the `features`. If it is a + `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, + then weight_column.normalizer_fn is applied on it to get weight tensor. + n_trees: number trees to be created. + max_depth: maximum depth of the tree to grow. + learning_rate: shrinkage parameter to be used when a tree added to the + model. + l1_regularization: regularization multiplier applied to the absolute + weights of the tree leafs. + l2_regularization: regularization multiplier applied to the square weights + of the tree leafs. + tree_complexity: regularization factor to penalize trees with more leaves. + min_node_weight: minimum hessian a node must have for a split to be + considered. The value will be compared with sum(leaf_hessian)/ + (batch_size * n_batches_per_layer). + config: `RunConfig` object to configure the runtime settings. + train_hooks: a list of Hook instances to be passed to estimator.train(). + + Returns: + a `BoostedTreesClassifier` instance created with the given arguments and + trained with the data loaded up on memory from the input_fn. + + Raises: + ValueError: when wrong arguments are given or unsupported functionalities + are requested. + """ + # pylint: disable=protected-access + # TODO(nponomareva): Extend it to multi-dimension cases. + if label_dimension == canned_boosted_trees._HOLD_FOR_MULTI_DIM_SUPPORT: + label_dimension = 1 + head = canned_boosted_trees._create_regression_head(label_dimension, + weight_column) + + # HParams for the model. + tree_hparams = canned_boosted_trees._TreeHParams( + n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, + tree_complexity, min_node_weight) + + def _model_fn(features, labels, mode, config): + return canned_boosted_trees._bt_model_fn( + features, + labels, + mode, + head, + feature_columns, + tree_hparams, + n_batches_per_layer=1, + config=config, + train_in_memory=True) + + in_memory_regressor = estimator.Estimator( + model_fn=_model_fn, model_dir=model_dir, config=config) + + in_memory_regressor.train( + input_fn=_validate_input_fn_and_repeat_dataset(train_input_fn), + hooks=train_hooks) + + return in_memory_regressor + # pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py new file mode 100644 index 0000000000000000000000000000000000000000..76cbefe5e94502188388df6fc2816d130ac896d5 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py @@ -0,0 +1,222 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests boosted_trees estimators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.estimator.python.estimator import boosted_trees +from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest +from tensorflow.python.training import checkpoint_utils + +NUM_FEATURES = 3 + +BUCKET_BOUNDARIES = [-2., .5, 12.] # Boundaries for all the features. +INPUT_FEATURES = np.array( + [ + [12.5, 1.0, -2.001, -2.0001, -1.999], # feature_0 quantized:[3,2,0,0,1] + [2.0, -3.0, 0.5, 0.0, 0.4995], # feature_1 quantized:[2,0,2,1,1] + [3.0, 20.0, 50.0, -100.0, 102.75], # feature_2 quantized:[2,3,3,0,3] + ], + dtype=np.float32) +CLASSIFICATION_LABELS = [[0.], [1.], [1.], [0.], [0.]] +REGRESSION_LABELS = [[1.5], [0.3], [0.2], [2.], [5.]] +FEATURES_DICT = {'f_%d' % i: INPUT_FEATURES[i] for i in range(NUM_FEATURES)} + + +def _make_train_input_fn(is_classification): + """Makes train input_fn for classification/regression.""" + + def _input_fn(): + features_dict = dict(FEATURES_DICT) + labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS + return features_dict, labels + + return _input_fn + + +def _make_train_input_fn_dataset(is_classification): + """Makes input_fn using Dataset.""" + + def _input_fn(): + features_dict = dict(FEATURES_DICT) + labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS + ds = dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensors(features_dict), + dataset_ops.Dataset.from_tensors(labels) + )) + return ds + + return _input_fn + + +class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): + + def setUp(self): + self._head = canned_boosted_trees._create_regression_head(label_dimension=1) + self._feature_columns = { + feature_column.bucketized_column( + feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32), + BUCKET_BOUNDARIES) + for i in range(NUM_FEATURES) + } + + def _assert_checkpoint(self, model_dir, global_step, finalized_trees, + attempted_layers): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + serialized = reader.get_tensor('boosted_trees:0_serialized') + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertEqual( + finalized_trees, + sum([1 for t in ensemble_proto.tree_metadata if t.is_finalized])) + self.assertEqual(attempted_layers, + ensemble_proto.growing_metadata.num_layers_attempted) + + def testTrainAndEvaluateEstimator(self): + input_fn = _make_train_input_fn(is_classification=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=2, + head=self._head, + max_depth=5) + + # It will stop after 10 steps because of the max depth and num trees. + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(input_fn, steps=num_steps) + self._assert_checkpoint( + est.model_dir, global_step=10, finalized_trees=2, attempted_layers=10) + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 1.008551) + + def testInferEstimator(self): + train_input_fn = _make_train_input_fn(is_classification=False) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=1, + max_depth=5, + head=self._head) + + # It will stop after 5 steps because of the max depth and num trees. + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(train_input_fn, steps=num_steps) + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) + # Validate predictions. + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose( + [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], + [pred['predictions'] for pred in predictions]) + + def testBinaryClassifierTrainInMemoryAndEvalAndInfer(self): + train_input_fn = _make_train_input_fn(is_classification=True) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_classifier_train_in_memory( + train_input_fn=train_input_fn, feature_columns=self._feature_columns, + n_trees=1, max_depth=5) + # It will stop after 5 steps because of the max depth and num trees. + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) + + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['accuracy'], 1.0) + # Validate predictions. + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0], [1], [1], [0], [0]], + [pred['class_ids'] for pred in predictions]) + + def testBinaryClassifierTrainInMemoryWithDataset(self): + train_input_fn = _make_train_input_fn_dataset(is_classification=True) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_classifier_train_in_memory( + train_input_fn=train_input_fn, feature_columns=self._feature_columns, + n_trees=1, max_depth=5) + # It will stop after 5 steps because of the max depth and num trees. + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) + + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['accuracy'], 1.0) + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0], [1], [1], [0], [0]], + [pred['class_ids'] for pred in predictions]) + + def testRegressorTrainInMemoryAndEvalAndInfer(self): + train_input_fn = _make_train_input_fn(is_classification=False) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_regressor_train_in_memory( + train_input_fn=train_input_fn, feature_columns=self._feature_columns, + n_trees=1, max_depth=5) + # It will stop after 5 steps because of the max depth and num trees. + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) + + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 2.478283) + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose( + [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], + [pred['predictions'] for pred in predictions]) + + def testRegressorTrainInMemoryWithDataset(self): + train_input_fn = _make_train_input_fn_dataset(is_classification=False) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_regressor_train_in_memory( + train_input_fn=train_input_fn, feature_columns=self._feature_columns, + n_trees=1, max_depth=5) + # It will stop after 5 steps because of the max depth and num trees. + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 2.478283) + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose( + [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], + [pred['predictions'] for pred in predictions]) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py index cf6e3329d2e27735d8759cc2ab3726e8c624c6ae..7ff25b95c079c7e06d29e874bcaa0d2c13e7167e 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn.py @@ -93,7 +93,7 @@ class DNNEstimator(estimator.Estimator): dropout=None, input_layer_partitioner=None, config=None): - """Initializes a `DNNClassifier` instance. + """Initializes a `DNNEstimator` instance. Args: head: A `_Head` instance constructed with a method such as diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py index b5e4d34dc70ccaa4806ae8b8ed5001bd971ee7b4..dd009a6753f3231638f93e50fc8f19eae8820139 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py @@ -34,6 +34,7 @@ from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import ops from tensorflow.python.ops import nn +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -52,7 +53,9 @@ def _dnn_only_estimator_fn( config=None): return dnn_linear_combined.DNNLinearCombinedEstimator( head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension), + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), model_dir=model_dir, dnn_feature_columns=feature_columns, dnn_optimizer=optimizer, @@ -100,7 +103,9 @@ def _linear_only_estimator_fn( partitioner=None): return dnn_linear_combined.DNNLinearCombinedEstimator( head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension), + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), model_dir=model_dir, linear_feature_columns=feature_columns, linear_optimizer=optimizer, diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_test.py index 71f810acec856d42d389260e7b9fea32123348b4..75e3107670d658e55ce23d983e47311f1c180104 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_test.py @@ -32,6 +32,7 @@ from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import ops +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -41,7 +42,9 @@ def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): """Returns a DNNEstimator that uses regression_head.""" return dnn.DNNEstimator( head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension), + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), *args, **kwargs) diff --git a/tensorflow/contrib/estimator/python/estimator/export.py b/tensorflow/contrib/estimator/python/estimator/export.py new file mode 100644 index 0000000000000000000000000000000000000000..03cf6f107c1c5589522d7be4946562a466740b0e --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/export.py @@ -0,0 +1,223 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wrapper for methods to export train/eval graphs from Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import model_fn as model_fn_lib + + +def export_saved_model_for_mode( + estimator, export_dir_base, input_receiver_fn, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False, + mode=model_fn_lib.ModeKeys.PREDICT): + # pylint: disable=line-too-long + """Exports a single train/eval/predict graph as a SavedModel. + + For a detailed guide, see + @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}. + + Sample usage: + ```python + classifier = tf.estimator.LinearClassifier( + feature_columns=[age, language]) + classifier.train(input_fn=input_fn, steps=1000) + + feature_spec = { + 'age': tf.placeholder(dtype=tf.int64), + 'language': array_ops.placeholder(dtype=tf.string) + } + label_spec = tf.placeholder(dtype=dtypes.int64) + + train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + export_dir = tf.contrib.estimator.export_saved_model_for_mode( + classifier, + export_dir_base='my_model/', + input_receiver_fn=train_rcvr_fn, + mode=model_fn_lib.ModeKeys.TRAIN) + + # export_dir is a timestamped directory with the SavedModel, which + # can be used for serving, analysis with TFMA, or directly loaded in. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + weights = graph.get_tensor_by_name(''linear/linear_model/age/weights') + ... + ``` + + This method is a wrapper for _export_all_saved_models, and wraps a raw + input_receiver_fn in a dictionary to pass in to that function. + See _export_all_saved_models for full docs. + + See tf.contrib.estimator.export_saved_model_for_mode for the currently + exposed version of this function. + + Args: + estimator: an instance of tf.estimator.Estimator + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn: a function that takes no argument and + returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + mode: tf.estimator.ModeKeys value indicating with mode will be exported. + + Returns: + The string path to the exported directory. + + Raises: + ValueError: if input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + + # pylint: disable=protected-access + return estimator._export_saved_model_for_mode( + export_dir_base, input_receiver_fn, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs, + mode=mode) + # pylint: enable=protected-access + + +def export_all_saved_models( + estimator, export_dir_base, input_receiver_fn_map, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False): + # pylint: disable=line-too-long + """Exports requested train/eval/predict graphs as separate SavedModels. + + See tf.contrib.estimator.export_all_saved_models for the currently + exposed version of this function. + + For each mode passed in via the input_receiver_fn_map, + this method builds a new graph by calling the input_receiver_fn to obtain + feature and label `Tensor`s. Next, this method calls the `Estimator`'s + model_fn in the passed mode to generate the model graph based on + those features and labels, and restores the given checkpoint + (or, lacking that, the most recent checkpoint) into the graph. + Only one of the modes is used for saving variables to the SavedModel + (order of preference: TRAIN, EVAL, then PREDICT), such that up to three + MetaGraphDefs are saved with a single set of variables in a single + SavedModel directory. + + For prediction, the exported `MetaGraphDef` will provide one `SignatureDef` + for each element of the export_outputs dict returned from the model_fn, + named using the same keys. One of these keys is always + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which + signature will be served when a serving request does not specify one. + For each signature, the outputs are provided by the corresponding + `ExportOutput`s, and the inputs are always the input receivers provided by + the serving_input_receiver_fn. + + For training and evaluation, the train_op is stored in an extra collection, + and loss, metrics, and predictions are included in a SignatureDef for the + mode in question. + + Extra assets may be written into the SavedModel via the assets_extra + argument. This should be a dict, where each key gives a destination path + (including the filename) relative to the assets.extra directory. The + corresponding value gives the full path of the source file to be copied. + For example, the simple case of copying a single file without renaming it + is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + + Sample usage: + ```python + classifier = tf.estimator.LinearClassifier( + feature_columns=[age, language]) + classifier.train(input_fn=input_fn) + + feature_spec = { + 'age': tf.placeholder(dtype=tf.int64), + 'language': array_ops.placeholder(dtype=tf.string) + } + label_spec = tf.placeholder(dtype=dtypes.int64) + + train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + serve_rcvr_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn( + feature_spec) + + rcvr_fn_map = { + model_fn_lib.ModeKeys.TRAIN: train_rcvr_fn, + model_fn_lib.ModeKeys.PREDICT: serve_rcvr_fn, + } + + export_dir = tf.contrib.estimator.export_all_saved_models( + classifier, + export_dir_base='my_model/', + input_receiver_fn_map=rcvr_fn_map) + + # export_dirs is a dict of directories with SavedModels, which + # can be used for serving, analysis with TFMA, or directly loaded in. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + weights = graph.get_tensor_by_name('linear/linear_model/age/weights') + ... + ``` + + Args: + estimator: an instance of tf.estimator.Estimator + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn + mappings, where the input_receiver_fn is a function that takes no + argument and returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + + Returns: + A dict of tf.estimator.ModeKeys value to string path for each exported + directory. + + Raises: + ValueError: if any input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + + # pylint: disable=protected-access + return estimator._export_all_saved_models( + export_dir_base, input_receiver_fn_map, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs) + # pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/export_test.py b/tensorflow/contrib/estimator/python/estimator/export_test.py new file mode 100644 index 0000000000000000000000000000000000000000..050821ee672f30a6926c4a0a0e48915515d9afd7 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/export_test.py @@ -0,0 +1,373 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib wrapping of export_saved_model_for_mode functionality. + +These are direct copies of the tests included in core, with import locations +changed. These should be removed when the functionality in core is part of the +public API. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +from tensorflow.contrib.estimator.python.estimator import export as contrib_export +from tensorflow.python.client import session +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import tag_constants +from tensorflow.python.training import training +from tensorflow.python.util import compat + + +def _model_fn_for_export_tests(features, labels, mode): + _, _ = features, labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + update_global_step = state_ops.assign_add(training.get_global_step(), 1) + with ops.control_dependencies([update_global_step]): + train_op = constant_op.constant(2.) + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + loss=constant_op.constant(1.), + train_op=train_op, + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + + +def _x_y_input_fn(): + return ({'x': constant_op.constant([[1], [1]]), + 'y': constant_op.constant([[2], [2]])}, + constant_op.constant([[1], [1]])) + + +def _model_fn_with_x_y(features, labels, mode): + _ = labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + if mode == model_fn_lib.ModeKeys.PREDICT: + variables.Variable(36., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + else: + prefix = 'eval_' if mode == model_fn_lib.ModeKeys.EVAL else '' + + multiplied = math_ops.multiply( + features['x'], features['y'], name='{}multiplied'.format(prefix)) + metrics = {'mean': metrics_lib.mean(features['x'] - features['y'], + name='{}mean'.format(prefix))} + variables.Variable(1., name='later_var') + variables.Variable(3., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=multiplied, + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + eval_metric_ops=metrics) + + +def _get_serving_input_receiver_fn(): + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + return export.build_parsing_serving_input_receiver_fn(feature_spec) + + +def _get_supervised_input_receiver_fn(): + feature_spec = { + 'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'), + 'y': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_y') + } + label_spec = array_ops.placeholder( + dtype=dtypes.float32, shape=[1], name='truth') + + return export.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + +class EstimatorExportTest(test.TestCase): + + def test_export_saved_model_train(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.TRAIN) + + def test_export_saved_model_eval(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.EVAL) + + def test_export_saved_model_predict(self): + self._test_export_saved_model_for_mode( + _get_serving_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT) + + def _test_export_saved_model_for_mode(self, input_receiver_fn, mode): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = contrib_export.export_saved_model_for_mode( + est, export_dir_base, input_receiver_fn, mode=mode) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + tag_set = model_fn_lib.EXPORT_TAG_MAP[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('name_collision_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_receiver_map(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertFalse('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_train_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertTrue('mean/update_op' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_eval_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertTrue('eval_mean/value' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_no_serving(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + # TODO(karmel): is this the desired behavior when names are shared? + self.assertTrue('feature_x_1' in graph_ops) + self.assertTrue('feature_y_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_three_defs(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + # Restore, to validate that the export was well-formed. + for tag_set in model_fn_lib.EXPORT_TAG_MAP.values(): + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('global_step/Assign' in graph_ops) + self.assertTrue('global_step/Initializer/zeros' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_all_vars(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_name_collision(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertEqual(3, collection_vars[-1].eval()) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # This is a non-obvious detail: when we load the estimator spec + # for predict, name_collision gets set to 36. However, we then restore + # from checkpoint, which should overwrite that var and make it the 3 + # from training. In practice, this would not be a good way to write + # a model_fn, but leaving this check in for now to ensure consistency + # with what would happen given our current order of spec, then + # checkpoint. + self.assertEqual(3, collection_vars[-1].eval()) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def _test_export_all_saved_models(self, input_receiver_fn_map): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_with_x_y) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = contrib_export.export_all_saved_models( + est, export_dir_base, input_receiver_fn_map) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + + self._validate_exported_files(export_dir) + + return export_dir, tmpdir + + def _validate_exported_files(self, export_dir): + self.assertTrue(gfile.Exists(export_dir)) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('saved_model.pb')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.index')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.data-00000-of-00001')))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py index c99bf8badb35e6fffb7cae8761db9d402b8b3a8f..bf08be09e7baf63e507a6a4db6a91e7b6bb20b74 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -22,18 +22,19 @@ import six from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator import util as estimator_util +from tensorflow.python.estimator.export.export_output import PredictOutput from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.ops import clip_ops from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.util import function_utils _VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config']) def add_metrics(estimator, metric_fn): - """Creates a new ${tf.estimator.Estimator} which has given metrics. + """Creates a new @{tf.estimator.Estimator} which has given metrics. Example: @@ -60,7 +61,7 @@ def add_metrics(estimator, metric_fn): ``` Args: - estimator: A ${tf.estimator.Estimator} object. + estimator: A @{tf.estimator.Estimator} object. metric_fn: A function which should obey the following signature: - Args: can only have following four arguments in any order: * predictions: Predictions `Tensor` or dict of `Tensor` created by given @@ -78,7 +79,7 @@ def add_metrics(estimator, metric_fn): function, namely a `(metric_tensor, update_op)` tuple. Returns: - A new ${tf.estimator.Estimator} which has a union of original metrics with + A new @{tf.estimator.Estimator} which has a union of original metrics with given ones. """ _verify_metric_fn_args(metric_fn) @@ -96,7 +97,10 @@ def add_metrics(estimator, metric_fn): return estimator_lib.Estimator( model_fn=new_model_fn, model_dir=estimator.model_dir, - config=estimator.config) + config=estimator.config, + # pylint: disable=protected-access + warm_start_from=estimator._warm_start_settings) + # pylint: enable=protected-access def clip_gradients_by_norm(optimizer, clip_norm): @@ -161,14 +165,14 @@ def forward_features(estimator, keys=None): ``` Args: - estimator: A ${tf.estimator.Estimator} object. + estimator: A @{tf.estimator.Estimator} object. keys: a `string` or a `list` of `string`. If it is `None`, all of the `features` in `dict` is forwarded to the `predictions`. If it is a `string`, only given key is forwarded. If it is a `list` of strings, all the given `keys` are forwarded. Returns: - A new ${tf.estimator.Estimator} which forwards features to predictions. + A new @{tf.estimator.Estimator} which forwards features to predictions. Raises: ValueError: @@ -233,7 +237,17 @@ def forward_features(estimator, keys=None): 'argument of forward_features to filter unwanted features. Type of ' 'features[{}] is {}.'.format(key, key, type(feature))) predictions[key] = feature - return spec._replace(predictions=predictions) + spec = spec._replace(predictions=predictions) + if spec.export_outputs: + for ekey in ['predict', 'serving_default']: + if (ekey in spec.export_outputs and + isinstance(spec.export_outputs[ekey], + PredictOutput)): + export_outputs = spec.export_outputs[ekey].outputs + for key in get_keys(features): + export_outputs[key] = predictions[key] + + return spec return estimator_lib.Estimator( model_fn=new_model_fn, @@ -316,7 +330,7 @@ class _TransformGradients(optimizer_lib.Optimizer): def _verify_metric_fn_args(metric_fn): - args = set(estimator_util.fn_args(metric_fn)) + args = set(function_utils.fn_args(metric_fn)) invalid_args = list(args - _VALID_METRIC_FN_ARGS) if invalid_args: raise ValueError('metric_fn (%s) has following not expected args: %s' % @@ -325,7 +339,7 @@ def _verify_metric_fn_args(metric_fn): def _call_metric_fn(metric_fn, features, labels, predictions, config): """Calls metric fn with proper arguments.""" - metric_fn_args = estimator_util.fn_args(metric_fn) + metric_fn_args = function_utils.fn_args(metric_fn) kwargs = {} if 'features' in metric_fn_args: kwargs['features'] = features diff --git a/tensorflow/contrib/estimator/python/estimator/extenders_test.py b/tensorflow/contrib/estimator/python/estimator/extenders_test.py index ad1a8ef152b07ecbab33d9eb3184a2ae89def27d..407af2deaf0928361a4f0b0e44e842b7750118cb 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders_test.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders_test.py @@ -18,20 +18,27 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os +import tempfile import numpy as np from tensorflow.contrib.estimator.python.estimator import extenders +from tensorflow.contrib.predictor import from_saved_model from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.canned import linear from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.training import training +from tensorflow.python.util import compat def get_input_fn(x, y): @@ -177,6 +184,44 @@ class ForwardFeaturesTest(test.TestCase): self.assertIn('id', predictions) self.assertEqual(101, predictions['id']) + def test_forward_in_exported(self): + + def serving_input_fn(): + features_ph = { + 'x': array_ops.placeholder(dtypes.float32, [None]), + 'id': array_ops.placeholder(dtypes.int32, [None]) + } + features = { + key: array_ops.expand_dims(tensor, -1) + for key, tensor in features_ph.items() + } + return estimator_lib.export.ServingInputReceiver(features, features_ph) + def input_fn(): + return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]] + # create estimator + feature_columns = [fc.numeric_column('x')] + estimator = linear.LinearRegressor(feature_columns) + estimator.train(input_fn=input_fn, steps=1) + estimator = extenders.forward_features(estimator, 'id') + + # export saved model + tmpdir = tempfile.mkdtemp() + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = estimator.export_savedmodel(export_dir_base, serving_input_fn) + self.assertTrue(gfile.Exists(export_dir)) + + # restore model + predict_fn = from_saved_model(export_dir, signature_def_key='predict') + predictions = predict_fn({'x': [3], 'id': [101]}) + + # verify that 'id' exists in predictions + self.assertIn('id', predictions) + self.assertEqual(101, predictions['id']) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + def test_forward_list(self): def input_fn(): diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 238cf287b768eee28b20202084eb244c085c8b75..9594e5132fd20dadea118fd1dd6768feb7fd7fff 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import six + from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys @@ -31,10 +33,12 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import nn from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.losses import losses from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary +from tensorflow.python.training import training_util _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -42,7 +46,7 @@ _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY def multi_class_head(n_classes, weight_column=None, label_vocabulary=None, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, name=None): """Creates a `_Head` for multi class classification. @@ -70,6 +74,33 @@ def multi_class_head(n_classes, shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to the input labels before passing them to `loss_fn`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.multi_class_head(n_classes=3) + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.multi_class_head(n_classes=3) + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: n_classes: Number of classes, must be greater than 2 (for 2 classes, use `binary_classification_head`). @@ -83,7 +114,8 @@ def multi_class_head(n_classes, have any value in `label_vocabulary`. Note that errors will be raised if `label_vocabulary` is not provided but labels are strings. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to - reduce training loss over batch. Defaults to `SUM`. + reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely + weighted sum of losses divided by batch size. See `tf.losses.Reduction`. loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -108,7 +140,7 @@ def binary_classification_head( weight_column=None, thresholds=None, label_vocabulary=None, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, name=None): """Creates a `_Head` for single label binary classification. @@ -136,6 +168,33 @@ def binary_classification_head( shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to the input labels before passing them to `loss_fn`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.binary_classification_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.binary_classification_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -152,7 +211,8 @@ def binary_classification_head( `label_vocabulary`. Note that errors will be raised if `label_vocabulary` is not provided but labels are strings. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to - reduce training loss over batch. Defaults to `SUM`. + reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely + weighted sum of losses divided by batch size. See `tf.losses.Reduction`. loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -175,8 +235,9 @@ def binary_classification_head( def regression_head(weight_column=None, label_dimension=1, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, + inverse_link_fn=None, name=None): """Creates a `_Head` for regression using the `mean_squared_error` loss. @@ -195,10 +256,44 @@ def regression_head(weight_column=None, `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN, label_dimension]`. - Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or + Supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or `(labels, logits, features)` as arguments and returns unreduced loss with shape `[D0, D1, ... DN, label_dimension]`. + Also supports custom `inverse_link_fn`, also known as 'mean function'. + `inverse_link_fn` is only used in `PREDICT` mode. It takes `logits` as + argument and returns predicted values. This function is the inverse of the + link function defined in + https://en.wikipedia.org/wiki/Generalized_linear_model#Link_function + Namely, for poisson regression, set `inverse_link_fn=tf.exp`. + + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.regression_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.regression_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -208,8 +303,12 @@ def regression_head(weight_column=None, of the last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to - reduce training loss over batch. Defaults to `SUM`. - loss_fn: Optional loss function. + reduce training loss over batch and label dimension. Defaults to + `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by + `batch size * label_dimension`. See `tf.losses.Reduction`. + loss_fn: Optional loss function. Defaults to `mean_squared_error`. + inverse_link_fn: Optional inverse link function, also known as 'mean + function'. Defaults to identity. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -219,11 +318,191 @@ def regression_head(weight_column=None, Raises: ValueError: If `label_dimension` or `loss_reduction` is invalid. """ - return head_lib._regression_head_with_mean_squared_error_loss( # pylint:disable=protected-access + return head_lib._regression_head( # pylint:disable=protected-access weight_column=weight_column, label_dimension=label_dimension, loss_reduction=loss_reduction, loss_fn=loss_fn, + inverse_link_fn=inverse_link_fn, + name=name) + + +def poisson_regression_head( + weight_column=None, + label_dimension=1, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, + compute_full_loss=True, + name=None): + """Creates a `_Head` for poisson regression using `tf.nn.log_poisson_loss`. + + The loss is the weighted sum over all input dimensions. Namely, if the input + labels have shape `[batch_size, label_dimension]`, the loss is the weighted + sum over both `batch_size` and `label_dimension`. + + The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`. + In many applications, the shape is `[batch_size, label_dimension]`. + + The `labels` shape must match `logits`, namely + `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape + `[D0, D1, ... DN]` is also supported. + + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or + `[D0, D1, ... DN, label_dimension]`. + + This is implemented as a generalized linear model, see + https://en.wikipedia.org/wiki/Generalized_linear_model. + + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.poisson_regression_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.poisson_regression_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + + Args: + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to down weight or boost examples during training. It + will be multiplied by the loss of the example. + label_dimension: Number of regression labels per example. This is the size + of the last dimension of the labels `Tensor` (typically, this has shape + `[batch_size, label_dimension]`). + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to + reduce training loss over batch and label dimension. Defaults to + `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by + `batch size * label_dimension`. See `tf.losses.Reduction`. + compute_full_loss: Whether to include the constant `log(z!)` term in + computing the poisson loss. See `tf.nn.log_poisson_loss` for the full + documentation. + name: name of the head. If provided, summary and metrics keys will be + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. + + Returns: + An instance of `_Head` for poisson regression. + + Raises: + ValueError: If `label_dimension` or `loss_reduction` is invalid. + """ + def _poisson_loss(labels, logits): + return nn.log_poisson_loss( + targets=labels, log_input=logits, compute_full_loss=compute_full_loss) + return head_lib._regression_head( # pylint:disable=protected-access + weight_column=weight_column, + label_dimension=label_dimension, + loss_reduction=loss_reduction, + loss_fn=_poisson_loss, + inverse_link_fn=math_ops.exp, + name=name) + + +def logistic_regression_head( + weight_column=None, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, + name=None): + """Creates a `_Head` for logistic regression. + + Uses `sigmoid_cross_entropy_with_logits` loss, which is the same as + `binary_classification_head`. The differences compared to + `binary_classification_head` are: + + * Does not support `label_vocabulary`. Instead, labels must be float in the + range [0, 1]. + * Does not calculate some metrics that do not make sense, such as AUC. + * In `PREDICT` mode, only returns logits and predictions + (`=tf.sigmoid(logits)`), whereas `binary_classification_head` also returns + probabilities, classes, and class_ids. + * Export output defaults to `RegressionOutput`, whereas + `binary_classification_head` defaults to `PredictOutput`. + + The head expects `logits` with shape `[D0, D1, ... DN, 1]`. + In many applications, the shape is `[batch_size, 1]`. + + The `labels` shape must match `logits`, namely + `[D0, D1, ... DN]` or `[D0, D1, ... DN, 1]`. + + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]` or `[D0, D1, ... DN, 1]`. + + This is implemented as a generalized linear model, see + https://en.wikipedia.org/wiki/Generalized_linear_model. + + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.logistic_regression_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.logistic_regression_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + + Args: + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to down weight or boost examples during training. It + will be multiplied by the loss of the example. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to + reduce training loss over batch and label dimension. Defaults to + `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by + `batch size * label_dimension`. See `tf.losses.Reduction`. + name: name of the head. If provided, summary and metrics keys will be + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. + + Returns: + An instance of `_Head` for logistic regression. + + Raises: + ValueError: If `loss_reduction` is invalid. + """ + def _logistic_loss(labels, logits): + labels = head_lib._assert_range( # pylint:disable=protected-access + labels, n_classes=2, message='Labels must be in range [0, 1]') + return nn.sigmoid_cross_entropy_with_logits( + labels=labels, logits=logits) + return head_lib._regression_head( # pylint:disable=protected-access + weight_column=weight_column, + label_dimension=1, + loss_reduction=loss_reduction, + loss_fn=_logistic_loss, + inverse_link_fn=math_ops.sigmoid, name=name) @@ -231,8 +510,9 @@ def multi_label_head(n_classes, weight_column=None, thresholds=None, label_vocabulary=None, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, + classes_for_class_based_metrics=None, name=None): """Creates a `_Head` for multi-label classification. @@ -249,6 +529,7 @@ def multi_label_head(n_classes, applications, the shape is `[batch_size, n_classes]`. Labels can be: + * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]` * An integer `SparseTensor` of class indices. The `dense_shape` must be `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`. @@ -264,6 +545,33 @@ def multi_label_head(n_classes, shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies `label_vocabulary` to the input labels before passing them to `loss_fn`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.multi_label_head(n_classes=3) + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.multi_label_head(n_classes=3) + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: n_classes: Number of classes, must be greater than 1 (for 1 class, use `binary_classification_head`). @@ -282,8 +590,13 @@ def multi_label_head(n_classes, string type and have any value in `label_vocabulary`. Also there will be errors if vocabulary is not provided and labels are string. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to - reduce training loss over batch. Defaults to `SUM`. + reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely + weighted sum of losses divided by batch size. See `tf.losses.Reduction`. loss_fn: Optional loss function. + classes_for_class_based_metrics: List of integer class IDs or string class + names for which per-class metrics are evaluated. If integers, all must be + in the range `[0, n_classes - 1]`. If strings, all must be in + `label_vocabulary`. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -291,8 +604,8 @@ def multi_label_head(n_classes, An instance of `_Head` for multi-label classification. Raises: - ValueError: if `n_classes`, `thresholds`, `loss_reduction` or `loss_fn` is - invalid. + ValueError: if `n_classes`, `thresholds`, `loss_reduction`, `loss_fn` or + `metric_class_ids` is invalid. """ thresholds = tuple(thresholds) if thresholds else tuple() if n_classes is None or n_classes < 2: @@ -317,10 +630,31 @@ def multi_label_head(n_classes, if (loss_reduction not in losses.Reduction.all() or loss_reduction == losses.Reduction.NONE): raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction)) + classes_for_class_based_metrics = tuple( + [] if classes_for_class_based_metrics is None + else classes_for_class_based_metrics) + if classes_for_class_based_metrics: + if isinstance(classes_for_class_based_metrics[0], six.string_types): + if not label_vocabulary: + raise ValueError( + 'label_vocabulary must be provided when ' + 'classes_for_class_based_metrics are sting.') + class_ids = [] + for class_string in classes_for_class_based_metrics: + class_ids.append(label_vocabulary.index(class_string)) + classes_for_class_based_metrics = tuple(class_ids) + else: + for class_id in classes_for_class_based_metrics: + if (class_id < 0) or (class_id >= n_classes): + raise ValueError( + 'All classes_for_class_based_metrics must be in range [0, {}]. ' + 'Given: {}'.format(n_classes - 1, class_id)) return _MultiLabelHead( n_classes=n_classes, weight_column=weight_column, thresholds=thresholds, label_vocabulary=label_vocabulary, loss_reduction=loss_reduction, - loss_fn=loss_fn, name=name) + loss_fn=loss_fn, + classes_for_class_based_metrics=classes_for_class_based_metrics, + name=name) class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access @@ -331,8 +665,9 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weight_column=None, thresholds=None, label_vocabulary=None, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, + classes_for_class_based_metrics=None, name=None): self._n_classes = n_classes self._weight_column = weight_column @@ -340,6 +675,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access self._label_vocabulary = label_vocabulary self._loss_reduction = loss_reduction self._loss_fn = loss_fn + self._classes_for_class_based_metrics = classes_for_class_based_metrics self._name = name @property @@ -406,7 +742,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access reduction=losses.Reduction.NONE) # Averages loss over classes. unweighted_loss = math_ops.reduce_mean( - unweighted_loss, axis=-1, keep_dims=True) + unweighted_loss, axis=-1, keepdims=True) weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access, features=features, weight_column=self._weight_column, logits=logits) training_loss = losses.compute_weighted_loss( @@ -417,10 +753,10 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weights=weights, processed_labels=processed_labels) - def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None, - regularization_losses=None): - """Returns an `EstimatorSpec`. + def _create_tpu_estimator_spec( + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, regularization_losses=None): + """Returns an `model_fn._TPUEstimatorSpec`. Args: features: Input `dict` of `Tensor` or `SparseTensor` objects. @@ -431,8 +767,11 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access with shape `[D0, D1, ... DN, n_classes]` or `SparseTensor` with `dense_shape` `[D0, D1, ... DN, ?]`. `labels` is required argument when `mode` equals `TRAIN` or `EVAL`. + optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. + Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which + updates variables and increments `global_step`. train_op_fn: Function that takes a scalar loss `Tensor` and returns - `train_op`. Required in TRAIN mode. + `train_op`. Used if `optimizer` is `None`. regularization_losses: A list of additional scalar losses to be added to the training loss, such as regularization losses. These losses are usually expressed as a batch average, so for best results users need to @@ -440,9 +779,10 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to avoid scaling errors. Returns: - `EstimatorSpec`. + `model_fn._TPUEstimatorSpec`. Raises: - ValueError: If `train_op_fn` is `None` in TRAIN mode. + ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN + mode, or if both are set. """ with ops.name_scope(self._name, 'head'): logits = head_lib._check_logits_final_dim(logits, self.logits_dimension) # pylint:disable=protected-access @@ -459,7 +799,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access classifier_output = head_lib._classification_output( # pylint:disable=protected-access scores=probabilities, n_classes=self._n_classes, label_vocabulary=self._label_vocabulary) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs={ @@ -482,20 +822,31 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access # Eval. if mode == model_fn.ModeKeys.EVAL: - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access mode=model_fn.ModeKeys.EVAL, predictions=predictions, loss=regularized_training_loss, - eval_metric_ops=self._eval_metric_ops( - labels=processed_labels, - probabilities=probabilities, - weights=weights, - unreduced_loss=unreduced_loss, - regularization_loss=regularization_loss)) + eval_metrics=head_lib._create_eval_metrics_tuple( # pylint:disable=protected-access + self._eval_metric_ops, { + 'labels': processed_labels, + 'probabilities': probabilities, + 'weights': weights, + 'unreduced_loss': unreduced_loss, + 'regularization_loss': regularization_loss, + })) # Train. - if train_op_fn is None: - raise ValueError('train_op_fn can not be None.') + if optimizer is not None: + if train_op_fn is not None: + raise ValueError('train_op_fn and optimizer cannot both be set.') + train_op = optimizer.minimize( + regularized_training_loss, + global_step=training_util.get_global_step()) + elif train_op_fn is not None: + train_op = train_op_fn(regularized_training_loss) + else: + raise ValueError('train_op_fn and optimizer cannot both be None.') + train_op = head_lib._append_update_ops(train_op) # pylint:disable=protected-access # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -517,11 +868,11 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access summary.scalar( head_lib._summary_key(self._name, keys.LOSS_REGULARIZATION), # pylint:disable=protected-access regularization_loss) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, - train_op=train_op_fn(regularized_training_loss)) + train_op=train_op) def _eval_metric_ops( self, labels, probabilities, weights, unreduced_loss, @@ -580,4 +931,36 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weights=weights, threshold=threshold, name=recall_key)) + for class_id in self._classes_for_class_based_metrics: + batch_rank = array_ops.rank(probabilities) - 1 + begin = array_ops.concat( + [array_ops.zeros([batch_rank], dtype=dtypes.int32), [class_id]], + axis=0) + size = array_ops.concat( + [-1 * array_ops.ones([batch_rank], dtype=dtypes.int32), [1]], + axis=0) + class_probabilities = array_ops.slice( + probabilities, begin=begin, size=size) + class_labels = array_ops.slice(labels, begin=begin, size=size) + prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id + metric_ops[head_lib._summary_key(self._name, prob_key)] = ( # pylint:disable=protected-access + head_lib._predictions_mean( # pylint:disable=protected-access + predictions=class_probabilities, + weights=weights, + name=prob_key)) + auc_key = keys.AUC_AT_CLASS % class_id + metric_ops[head_lib._summary_key(self._name, auc_key)] = ( # pylint:disable=protected-access + head_lib._auc( # pylint:disable=protected-access + labels=class_labels, + predictions=class_probabilities, + weights=weights, + name=auc_key)) + auc_pr_key = keys.AUC_PR_AT_CLASS % class_id + metric_ops[head_lib._summary_key(self._name, auc_pr_key)] = ( # pylint:disable=protected-access + head_lib._auc( # pylint:disable=protected-access + labels=class_labels, + predictions=class_probabilities, + weights=weights, + curve='PR', + name=auc_pr_key)) return metric_ops diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 1411635228457218578c0297d4d901e9c86ca91a..b2b57fa06ba818d4455871fe57dde5ce287b39a2 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -32,9 +32,11 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops +from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -174,6 +176,21 @@ class MultiLabelHead(test.TestCase): r'loss_fn has unexpected args: \[\'name\'\]'): head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) + def test_classes_for_class_based_metrics_invalid(self): + with self.assertRaisesRegexp( + ValueError, + r'All classes_for_class_based_metrics must be in range \[0, 2\]\. ' + r'Given: -1'): + head_lib.multi_label_head( + n_classes=3, classes_for_class_based_metrics=[2, -1]) + + def test_classes_for_class_based_metrics_string_invalid(self): + with self.assertRaisesRegexp( + ValueError, r'\'z\' is not in list'): + head_lib.multi_label_head( + n_classes=3, label_vocabulary=['a', 'b', 'c'], + classes_for_class_based_metrics=['c', 'z']) + def test_name(self): head = head_lib.multi_label_head(n_classes=4, name='foo') self.assertEqual('foo', head.name) @@ -271,9 +288,9 @@ class MultiLabelHead(test.TestCase): logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32) labels = np.array([[1, 0], [1, 1]], dtype=np.int64) - # loss = labels * -log(sigmoid(logits)) + - # (1 - labels) * -log(1 - sigmoid(logits)) - expected_training_loss = np.sum( + # loss = (labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits))) / 2 + expected_training_loss = 0.5 * np.sum( _sigmoid_cross_entropy(labels=labels, logits=logits)) actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, @@ -297,7 +314,7 @@ class MultiLabelHead(test.TestCase): # For large logits, this is approximated as: # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits - expected_training_loss = np.sum( + expected_training_loss = 0.5 * np.sum( np.array([[(10. + 10.) / 2.], [(15. + 0.) / 2.]], dtype=np.float32)) actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, @@ -360,7 +377,7 @@ class MultiLabelHead(test.TestCase): labels=labels_input)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) - self.assertAllClose(np.sum(loss), actual_training_loss.eval()) + self.assertAllClose(np.sum(loss) / 2., actual_training_loss.eval()) def test_eval_create_loss_loss_fn_wrong_shape(self): """Tests custom loss_fn that returns Tensor of unexpected shape.""" @@ -437,16 +454,17 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.5972, + keys.AUC_PR: 0.7639, } self._test_eval( head=head, @@ -467,18 +485,17 @@ class MultiLabelHead(test.TestCase): labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) - ) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.5972, + keys.AUC_PR: 0.7639, } self._test_eval( head=head, @@ -509,7 +526,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.5972, + keys.AUC_PR: 0.7639, } self._test_eval( head=head, @@ -532,18 +549,17 @@ class MultiLabelHead(test.TestCase): labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) - ) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.5972, + keys.AUC_PR: 0.7639, } self._test_eval( head=head, @@ -561,19 +577,18 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) - ) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.5972, + keys.AUC_PR: 0.7639, keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 2. / 4., keys.PRECISION_AT_THRESHOLD % thresholds[0]: 2. / 3., keys.RECALL_AT_THRESHOLD % thresholds[0]: 2. / 3., @@ -592,6 +607,81 @@ class MultiLabelHead(test.TestCase): expected_loss=expected_loss, expected_metrics=expected_metrics) + def test_eval_with_classes_for_class_based_metrics(self): + head = head_lib.multi_label_head( + n_classes=2, classes_for_class_based_metrics=[0, 1]) + + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) + + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. + keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2., + keys.AUC_AT_CLASS % 0: 0., + keys.AUC_PR_AT_CLASS % 0: 1., + keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2., + keys.AUC_AT_CLASS % 1: 1., + keys.AUC_PR_AT_CLASS % 1: 1., + } + + self._test_eval( + head=head, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + + def test_eval_with_classes_for_class_based_metrics_string(self): + head = head_lib.multi_label_head( + n_classes=2, label_vocabulary=['a', 'b'], + classes_for_class_based_metrics=['a', 'b']) + + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels = sparse_tensor.SparseTensor( + values=['a', 'a', 'b'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + labels_onehot = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_onehot, logits=logits)) + + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. + keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2., + keys.AUC_AT_CLASS % 0: 0., + keys.AUC_PR_AT_CLASS % 0: 1., + keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2., + keys.AUC_AT_CLASS % 1: 1., + keys.AUC_PR_AT_CLASS % 1: 1., + } + + self._test_eval( + head=head, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + def test_eval_with_weights(self): n_classes = 2 head = head_lib.multi_label_head(n_classes, weight_column='example_weights') @@ -602,8 +692,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, weighted sum over examples. - expected_loss = 25. + # Average over classes, weighted sum over examples, divide by batch_size. + # loss = ( 1 * (10 + 10) / 2 + 2 * (15 + 0) / 2) / 2 + expected_loss = 12.5 spec = head.create_estimator_spec( features={ @@ -616,12 +707,12 @@ class MultiLabelHead(test.TestCase): keys = metric_keys.MetricKeys expected_metrics = { - # Average loss over weighted examples. - keys.LOSS_MEAN: expected_loss / 3, + # Average loss over weighted examples (denominator is sum(weights)). + keys.LOSS_MEAN: expected_loss * (2. / 3.), # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.2000, - keys.AUC_PR: 0.5833, + keys.AUC_PR: 0.7833, } # Assert spec contains expected tensors. @@ -662,7 +753,7 @@ class MultiLabelHead(test.TestCase): # (1 - labels) * (logits > 0) * logits expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]] expected_weights = [[1.], [2.]] - expected_training_loss = 1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2. + expected_training_loss = (1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2.) / 2. training_loss, unreduced_loss, actual_weights, _ = head.create_loss( features={ 'x': np.array(((42,),), dtype=np.int32), @@ -808,11 +899,8 @@ class MultiLabelHead(test.TestCase): self.assertEqual( six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) - _assert_simple_summaries(self, { - metric_keys.MetricKeys.LOSS: expected_loss, - # Average loss over examples. - metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2, - }, summary_str, tol) + _assert_simple_summaries( + self, {metric_keys.MetricKeys.LOSS: expected_loss}, summary_str, tol) def test_train(self): head = head_lib.multi_label_head(n_classes=2) @@ -822,8 +910,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) @@ -839,8 +928,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) @@ -857,11 +947,77 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) + def test_train_with_optimizer(self): + head = head_lib.multi_label_head(n_classes=2) + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + # For large logits, sigmoid cross entropy loss is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits => + # expected_unweighted_loss = [[10., 10.], [15., 0.]] + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 + expected_train_result = 'my_train_op' + + class _Optimizer(object): + + def minimize(self, loss, global_step): + del global_step + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=3)]) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + optimizer=_Optimizer()) + + tol = 1e-3 + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) + self.assertEqual( + six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), + train_result) + + def test_train_with_update_ops(self): + head = head_lib.multi_label_head(n_classes=2) + + with ops.Graph().as_default(): + w = variables.Variable(1) + update_op = w.assign_add(1) + ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op) + + t = variables.Variable('') + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return t.assign(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), + labels=np.array([[1, 0], [1, 1]], dtype=np.int64), + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + sess.run(spec.train_op) + w_value, t_value = sess.run([w, t]) + self.assertEqual(2, w_value) + self.assertEqual(expected_train_result, t_value) + def test_train_with_regularization_losses(self): head = head_lib.multi_label_head( n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) @@ -915,8 +1071,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, weighted sum over examples. - expected_loss = 25. + # Average over classes, weighted sum over examples, divide by batch_size. + # loss = ( 1 * (10 + 10) / 2 + 2 * (15 + 0) / 2 ) / 2 + expected_loss = 12.5 expected_train_result = 'my_train_op' def _train_op_fn(loss): return string_ops.string_join( @@ -950,11 +1107,8 @@ class MultiLabelHead(test.TestCase): self.assertEqual( six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) - _assert_simple_summaries(self, { - metric_keys.MetricKeys.LOSS: expected_loss, - # Average loss over weighted examples. - metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3, - }, summary_str, tol) + _assert_simple_summaries( + self, {metric_keys.MetricKeys.LOSS: expected_loss,}, summary_str, tol) def test_multi_dim_weighted_train_create_loss(self): """Logits and labels of shape [2, 2, 3], weights [2, 2].""" @@ -971,8 +1125,8 @@ class MultiLabelHead(test.TestCase): expected_unreduced_loss = [[[20./3.], [10./3.]], [[4.], [8.]]] # weights are reshaped to [2, 2, 1] to match logits. expected_weights = [[[1.], [1.5]], [[2.], [2.5]]] - # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_training_loss = 39.6667 + # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167 + expected_training_loss = 9.9167 training_loss, unreduced_loss, actual_weights, _ = head.create_loss( features={'weights': weights}, mode=model_fn.ModeKeys.TRAIN, @@ -998,8 +1152,8 @@ class MultiLabelHead(test.TestCase): weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 # = [[20/3, 10/3], [4, 8]] - # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_loss = 39.6667 + # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167 + expected_loss = 9.9167 expected_train_result = 'my_train_op' def _train_op_fn(loss): return string_ops.string_join( @@ -1087,15 +1241,15 @@ class MultiLabelHead(test.TestCase): weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 # = [[20/3, 10/3], [4, 8]] - # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_loss = 39.6667 + # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167 + expected_loss = 9.9167 keys = metric_keys.MetricKeys expected_metrics = { - keys.LOSS_MEAN: expected_loss / np.sum(weights), + keys.LOSS_MEAN: expected_loss * (4. / np.sum(weights)), # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.4977, - keys.AUC_PR: 0.4037, + keys.AUC_PR: 0.6645, } self._test_eval( head=head, @@ -1106,5 +1260,194 @@ class MultiLabelHead(test.TestCase): expected_metrics=expected_metrics) +class PoissonRegressionHead(test.TestCase): + + def setUp(self): + ops.reset_default_graph() + + def test_train(self): + head = head_lib.poisson_regression_head() + + # Create estimator spec. + logits = np.array([[0], [-1], [1]], dtype=np.float32) + labels = np.array([[1], [2], [3]], dtype=np.int32) + # With x = exp(logits), z = labels. + # loss = -ln(exp(-x) * (x^z) / z!) + # = x - z * ln(x) + ln(z!) + # = exp(logits) - labels * logits - ln(labels!) + # But for ln(z!) and z > 1, the Stirling approximation is used + # ln(z!) = z*ln(z) - z + 0.5*ln(2*pi*z) + # loss = [exp(0) - 1 * 0 + ln(1!), + # exp(-1) - 2 * (-1) + 2*ln(2) - 2 + 0.5*ln(2*pi*2), + # exp(1) - 3 * 1 + 3*ln(3) - 3 + 0.5*ln(2*pi*3)] + # = [1.0, 3.020, 1.482] + # training_loss = (1.0 + 3.020 + 1.482) / 3 + expected_loss = 1.834 + atol = 0.001 + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + with ops.control_dependencies((check_ops.assert_near( + math_ops.to_float(expected_loss), math_ops.to_float(loss), + atol=atol, name='assert_loss'),)): + return constant_op.constant(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run([spec.loss, spec.train_op]) + self.assertAlmostEqual(expected_loss, loss, delta=atol) + self.assertEqual(expected_train_result, train_result) + + def test_predict(self): + head = head_lib.poisson_regression_head() + + # Create estimator spec. + logits = np.array([[0], [-1], [1]], dtype=np.float32) + expected_predictions = np.exp(logits) + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) + + # Assert spec contains expected tensors. + keys = prediction_keys.PredictionKeys + self.assertItemsEqual( + (keys.PREDICTIONS, keys.LOGITS), spec.predictions.keys()) + self.assertEqual(dtypes.float32, spec.predictions[keys.PREDICTIONS].dtype) + self.assertEqual(dtypes.float32, spec.predictions[keys.LOGITS].dtype) + + # Assert predictions. + with self.test_session(): + _initialize_variables(self, spec.scaffold) + self.assertAllClose( + expected_predictions, spec.predictions[keys.PREDICTIONS].eval()) + self.assertAllClose(logits, spec.predictions[keys.LOGITS].eval()) + + +class LogisticRegressionHead(test.TestCase): + + def setUp(self): + ops.reset_default_graph() + + def test_train(self): + head = head_lib.logistic_regression_head() + + # Create estimator spec. + logits = np.array([[0], [-1], [1]], dtype=np.float32) + labels = np.array([[.4], [.6], [.8]], dtype=np.float32) + # Following the documentation in + # tf.nn.sigmoid_cross_entropy_with_logits: + # With x = logits, z = labels. + # loss = max(x, 0) - x * z + log(1 + exp(-abs(x))) + # loss = [0 - 0 * 0.4 + ln(1 + exp(-0)), + # 0 + 1 * 0.6 + ln(1 + exp(-1)), + # 1 - 1 * 0.8 + ln(1 + exp(-1))] + # = [0.6931, 0.9133, 0.5133] + # training_loss = (0.6931 + 0.9133 + 0.5133) / 3 + expected_loss = 0.7066 + atol = 0.001 + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + with ops.control_dependencies((check_ops.assert_near( + math_ops.to_float(expected_loss), math_ops.to_float(loss), + atol=atol, name='assert_loss'),)): + return constant_op.constant(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run([spec.loss, spec.train_op]) + self.assertAlmostEqual(expected_loss, loss, delta=atol) + self.assertEqual(expected_train_result, train_result) + + def test_train_labels_too_large(self): + head = head_lib.logistic_regression_head() + + # Create estimator spec. + logits = np.array([[0], [-1], [1]], dtype=np.float32) + labels = np.array([[.4], [1.2], [.8]], dtype=np.float32) + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return constant_op.constant(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[Labels must be in range \[0, 1\]\] .* \[\[0.4\]\[1.2\]\[0.8\]\]'): + _ = sess.run(spec.loss) + + def test_train_labels_negative(self): + head = head_lib.logistic_regression_head() + + # Create estimator spec. + logits = np.array([[0], [-1], [1]], dtype=np.float32) + labels = np.array([[.4], [-0.2], [.8]], dtype=np.float32) + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return constant_op.constant(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[Labels must be in range \[0, 1\]\] .* \[\[0.4\]\[-0.2\]\[0.8\]\]' + ): + _ = sess.run(spec.loss) + + def test_predict(self): + head = head_lib.logistic_regression_head() + + # Create estimator spec. + logits = np.array([[0], [-1], [1]], dtype=np.float32) + expected_predictions = 1. / (1. + np.exp(-logits)) + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) + + # Assert spec contains expected tensors. + keys = prediction_keys.PredictionKeys + self.assertItemsEqual( + (keys.PREDICTIONS, keys.LOGITS), spec.predictions.keys()) + self.assertEqual(dtypes.float32, spec.predictions[keys.PREDICTIONS].dtype) + self.assertEqual(dtypes.float32, spec.predictions[keys.LOGITS].dtype) + + # Assert predictions. + with self.test_session(): + _initialize_variables(self, spec.scaffold) + self.assertAllClose( + expected_predictions, spec.predictions[keys.PREDICTIONS].eval()) + self.assertAllClose(logits, spec.predictions[keys.LOGITS].eval()) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..ddd6aa442f82bad2d4714dbcdc85b20b34773068 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/hooks.py @@ -0,0 +1,213 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Some useful session run hooks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import training + + +# pylint: disable=protected-access +class InMemoryEvaluatorHook(training.SessionRunHook): + """Hook to run evaluation in training without a checkpoint. + + Example: + + ```python + def train_input_fn(): + ... + return train_dataset + + def eval_input_fn(): + ... + return eval_dataset + + estimator = tf.estimator.DNNClassifier(...) + + evaluator = tf.contrib.estimator.InMemoryEvaluatorHook( + estimator, eval_input_fn) + estimator.train(train_input_fn, hooks=[evaluator]) + ``` + + Current limitations of this approach are: + * It doesn't support multi-node distributed mode. + * It doesn't support saveable objects other than variables (such as boosted + tree support) + * It doesn't support custom saver logic (such as ExponentialMovingAverage + support) + + """ + + def __init__(self, + estimator, + input_fn, + steps=None, + hooks=None, + name=None, + every_n_iter=100): + """Initializes a `InMemoryEvaluatorHook`. + + Args: + estimator: A `tf.estimator.Estimator` instance to call evaluate. + input_fn: Equivalent to the `input_fn` arg to `estimator.evaluate`. A + function that constructs the input data for evaluation. + See @{$premade_estimators#create_input_functions} for more + information. The function should construct and return one of + the following: + + * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a + `Tensor` or a dictionary of string label name to `Tensor`. Both + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. + + steps: Equivalent to the `steps` arg to `estimator.evaluate`. Number of + steps for which to evaluate model. If `None`, evaluates until `input_fn` + raises an end-of-input exception. + hooks: Equivalent to the `hooks` arg to `estimator.evaluate`. List of + `SessionRunHook` subclass instances. Used for callbacks inside the + evaluation call. + name: Equivalent to the `name` arg to `estimator.evaluate`. Name of the + evaluation if user needs to run multiple evaluations on different data + sets, such as on training data vs test data. Metrics for different + evaluations are saved in separate folders, and appear separately in + tensorboard. + every_n_iter: `int`, runs the evaluator once every N training iteration. + + Raises: + ValueError: if `every_n_iter` is non-positive or it's not a single machine + training + """ + if every_n_iter is None or every_n_iter <= 0: + raise ValueError('invalid every_n_iter=%s.' % every_n_iter) + if (estimator.config.num_ps_replicas > 0 or + estimator.config.num_worker_replicas > 1): + raise ValueError( + 'InMemoryEvaluator supports only single machine (aka Local) setting.') + self._estimator = estimator + self._input_fn = input_fn + self._steps = steps + self._name = name + self._every_n_iter = every_n_iter + self._eval_dir = os.path.join(self._estimator.model_dir, 'eval' + if not name else 'eval_' + name) + + self._graph = None + self._hooks = estimator_lib._check_hooks_type(hooks) + self._hooks.extend(self._estimator._convert_eval_steps_to_hooks(steps)) + self._timer = training.SecondOrStepTimer(every_steps=every_n_iter) + + def begin(self): + """Build eval graph and restoring op.""" + self._timer.reset() + self._iter_count = 0 + self._graph = ops.Graph() + with self._graph.as_default(): + (self._scaffold, self._update_op, self._eval_dict, + self._all_hooks) = self._estimator._evaluate_build_graph( + self._input_fn, self._hooks, checkpoint_path=None) + + if self._scaffold.saver is not None: + raise ValueError('InMemoryEvaluator does not support custom saver') + if self._scaffold.init_fn is not None: + raise ValueError('InMemoryEvaluator does not support custom init_fn') + + self._var_name_to_eval_var = { + v.name: v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + self._var_name_to_placeholder = { + v.name: array_ops.placeholder(v.dtype) + for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + + def after_create_session(self, session, coord): # pylint: disable=unused-argument + """Does first run which shows the eval metrics before training.""" + if ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS): + raise ValueError( + 'InMemoryEvaluator does not support saveables other than global ' + 'variables.') + self._var_name_to_train_var = { + v.name: v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + var_names_to_transfer = set(self._var_name_to_placeholder.keys()) & set( + self._var_name_to_train_var.keys()) + # Filter training var names that are not exist in evaluation + self._var_name_to_train_var = { + v_name: self._var_name_to_train_var[v_name] + for v_name in var_names_to_transfer + } + # Filter eval var names that are not exist in training + self._var_name_to_eval_var = { + v_name: self._var_name_to_eval_var[v_name] + for v_name in var_names_to_transfer + } + + with self._graph.as_default(): + self._var_feed_op = control_flow_ops.group([ + state_ops.assign(self._var_name_to_eval_var[v_name], + self._var_name_to_placeholder[v_name]) + for v_name in var_names_to_transfer + ]) + + self._evaluate(session) + + def _evaluate(self, train_session): + var_name_to_value = train_session.run(self._var_name_to_train_var) + placeholder_to_value = { + self._var_name_to_placeholder[v_name]: var_name_to_value[v_name] + for v_name in var_name_to_value + } + + def feed_variables(scaffold, session): + del scaffold + session.run(self._var_feed_op, feed_dict=placeholder_to_value) + + scaffold = training.Scaffold( + init_fn=feed_variables, copy_from_scaffold=self._scaffold) + + with self._graph.as_default(): + return self._estimator._evaluate_run( + checkpoint_path=None, + scaffold=scaffold, + update_op=self._update_op, + eval_dict=self._eval_dict, + all_hooks=self._all_hooks, + output_dir=self._eval_dir) + + self._timer.update_last_triggered_step(self._iter_count) + + def after_run(self, run_context, run_values): # pylint: disable=unused-argument + """Runs evaluator.""" + self._iter_count += 1 + if self._timer.should_trigger_for_step(self._iter_count): + self._evaluate(run_context.session) + + def end(self, session): # pylint: disable=unused-argument + """Runs evaluator for final model.""" + self._evaluate(session) + + +# pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py new file mode 100644 index 0000000000000000000000000000000000000000..95ae971852ee6dffb6174fc243686721c30ef685 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py @@ -0,0 +1,318 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for hooks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import json +import os + +from tensorflow.contrib.estimator.python.estimator import hooks as hooks_lib +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator_lib +from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.summary import summary_iterator +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import training + + +def summary_step_keyword_to_value_mapping(dir_): + writer_cache.FileWriterCache.clear() + + # Get last Event written. + event_paths = glob.glob(os.path.join(dir_, 'events*')) + step_keyword_to_value = {} + for last_event in summary_iterator.summary_iterator(event_paths[-1]): + if last_event.step not in step_keyword_to_value: + step_keyword_to_value[last_event.step] = {} + if last_event.summary is not None: + for value in last_event.summary.value: + step_keyword_to_value[last_event.step][value.tag] = value.simple_value + + return step_keyword_to_value + + +def get_summary_value(dir_, step, keyword): + """Get summary value for given step and keyword.""" + + writer_cache.FileWriterCache.clear() + # Get last Event written. + event_paths = glob.glob(os.path.join(dir_, 'events*')) + print('XXX', event_paths) + for last_event in summary_iterator.summary_iterator(event_paths[-1]): + if last_event.step == step and last_event.summary is not None: + for value in last_event.summary.value: + if keyword in value.tag: + return value.simple_value + return None + + +class InMemoryEvaluatorHookTest(test.TestCase): + + def test_runs_eval_metrics(self): + + def model_fn(features, labels, mode): + _ = labels + if estimator_lib.ModeKeys.TRAIN == mode: + with ops.control_dependencies([features]): + train_op = state_ops.assign_add(training.get_global_step(), 1) + return estimator_lib.EstimatorSpec( + mode, loss=constant_op.constant(3.), train_op=train_op) + if estimator_lib.ModeKeys.EVAL == mode: + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(5.), + eval_metric_ops={'mean_of_features': metrics_lib.mean(features)}) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + estimator, input_fn, every_n_iter=4) + estimator.train(input_fn, hooks=[evaluator]) + + self.assertTrue(os.path.isdir(estimator.eval_dir())) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + estimator.eval_dir()) + # 4.5 = sum(range(10))/10 + # before training + self.assertEqual(4.5, step_keyword_to_value[0]['mean_of_features']) + # intervals (every_n_iter=4) + self.assertEqual(4.5, step_keyword_to_value[4]['mean_of_features']) + self.assertEqual(4.5, step_keyword_to_value[8]['mean_of_features']) + # end + self.assertEqual(4.5, step_keyword_to_value[10]['mean_of_features']) + + def test_uses_latest_variable_value(self): + + def model_fn(features, labels, mode): + _ = labels + step = training.get_global_step() + w = variable_scope.get_variable( + 'w', + shape=[], + initializer=init_ops.zeros_initializer(), + dtype=dtypes.int64) + if estimator_lib.ModeKeys.TRAIN == mode: + # to consume features, we have control dependency + with ops.control_dependencies([features]): + step_inc = state_ops.assign_add(training.get_global_step(), 1) + with ops.control_dependencies([step_inc]): + assign_w_to_step_plus_2 = w.assign(step + 2) + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + train_op=assign_w_to_step_plus_2) + if estimator_lib.ModeKeys.EVAL == mode: + # to consume features, we have control dependency + with ops.control_dependencies([features]): + loss = constant_op.constant(5.) + return estimator_lib.EstimatorSpec( + mode, + loss=loss, + # w is constant in each step, so the mean. + # w = 0 if step==0 else step+2 + eval_metric_ops={'mean_of_const': metrics_lib.mean(w)}) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + estimator, input_fn, every_n_iter=4) + estimator.train(input_fn, hooks=[evaluator]) + + self.assertTrue(os.path.isdir(estimator.eval_dir())) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + estimator.eval_dir()) + # w = 0 if step==0 else step+2 + self.assertEqual(0, step_keyword_to_value[0]['mean_of_const']) + self.assertEqual(6, step_keyword_to_value[4]['mean_of_const']) + self.assertEqual(12, step_keyword_to_value[10]['mean_of_const']) + + def test_dnn_classifier(self): + embedding = feature_column_lib.embedding_column( + feature_column_lib.categorical_column_with_vocabulary_list( + 'wire_cast', ['kima', 'omar', 'stringer']), 8) + dnn = estimator_lib.DNNClassifier( + feature_columns=[embedding], hidden_units=[3, 1]) + + def train_input_fn(): + return dataset_ops.Dataset.from_tensors(({ + 'wire_cast': [['omar'], ['kima']] + }, [[0], [1]])).repeat(3) + + def eval_input_fn(): + return dataset_ops.Dataset.from_tensors(({ + 'wire_cast': [['stringer'], ['kima']] + }, [[0], [1]])).repeat(2) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + dnn, eval_input_fn, name='in-memory') + dnn.train(train_input_fn, hooks=[evaluator]) + self.assertTrue(os.path.isdir(dnn.eval_dir('in-memory'))) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + dnn.eval_dir('in-memory')) + + final_metrics = dnn.evaluate(eval_input_fn) + step = final_metrics[ops.GraphKeys.GLOBAL_STEP] + for summary_tag in final_metrics: + if summary_tag == ops.GraphKeys.GLOBAL_STEP: + continue + self.assertEqual(final_metrics[summary_tag], + step_keyword_to_value[step][summary_tag]) + + def test_raise_error_with_multi_worker(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + dnn = estimator_lib.DNNClassifier( + feature_columns=[feature_column_lib.numeric_column('x')], + hidden_units=[3, 1]) + + def eval_input_fn(): + pass + + with self.assertRaisesRegexp(ValueError, 'supports only single machine'): + hooks_lib.InMemoryEvaluatorHook(dnn, eval_input_fn) + + def test_raise_error_with_ps(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1'], + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + dnn = estimator_lib.DNNClassifier( + feature_columns=[feature_column_lib.numeric_column('x')], + hidden_units=[3, 1]) + + def eval_input_fn(): + pass + + with self.assertRaisesRegexp(ValueError, 'supports only single machine'): + hooks_lib.InMemoryEvaluatorHook(dnn, eval_input_fn) + + def test_raise_error_with_custom_saver_in_eval(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(saver=training.Saver()), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support custom saver'): + evaluator.begin() + + def test_raise_error_with_custom_init_fn_in_eval(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + + def init_fn(scaffold, session): + _, _ = scaffold, session + + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(init_fn=init_fn), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support custom init_fn'): + evaluator.begin() + + def test_raise_error_with_saveables_other_than_global_variables(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + w = variables.Variable( + initial_value=[0.], + trainable=False, + collections=[ops.GraphKeys.SAVEABLE_OBJECTS]) + init_op = control_flow_ops.group( + [w.initializer, training.get_global_step().initializer]) + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(init_op=init_op), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support saveables'): + estimator.train(input_fn, hooks=[evaluator]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/linear_test.py b/tensorflow/contrib/estimator/python/estimator/linear_test.py index c63514eb688af48577f0a3b7ce9e7478309f2c30..c41996b9c6871d294f157411662f2eb9d4c09e5c 100644 --- a/tensorflow/contrib/estimator/python/estimator/linear_test.py +++ b/tensorflow/contrib/estimator/python/estimator/linear_test.py @@ -32,6 +32,7 @@ from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import ops +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -42,7 +43,9 @@ def _linear_estimator_fn( """Returns a LinearEstimator that uses regression_head.""" return linear.LinearEstimator( head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension), + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), *args, **kwargs) diff --git a/tensorflow/contrib/estimator/python/estimator/logit_fns.py b/tensorflow/contrib/estimator/python/estimator/logit_fns.py index 09c2862ccd3f90de4153a2095afc9c3d3f9476c1..c8b0dd62970e341a3c6b176278fe1c2adfcd8d20 100644 --- a/tensorflow/contrib/estimator/python/estimator/logit_fns.py +++ b/tensorflow/contrib/estimator/python/estimator/logit_fns.py @@ -41,10 +41,10 @@ from __future__ import print_function import six -from tensorflow.python.estimator import util from tensorflow.python.estimator.canned import dnn as dnn_core from tensorflow.python.estimator.canned import linear as linear_core from tensorflow.python.framework import ops +from tensorflow.python.util import function_utils # pylint: disable=protected-access dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder @@ -72,7 +72,7 @@ def call_logit_fn(logit_fn, features, mode, params, config): ValueError: if logit_fn does not return a Tensor or a dictionary mapping strings to Tensors. """ - logit_fn_args = util.fn_args(logit_fn) + logit_fn_args = function_utils.fn_args(logit_fn) kwargs = {} if 'mode' in logit_fn_args: kwargs['mode'] = mode diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index 0346ddc24bffd61068177f4622bd03be4acd53d9..ce758992140d43529037b14cbbf958d5aa763fb4 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -23,6 +23,7 @@ import six from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys +from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -30,6 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary +from tensorflow.python.training import training_util _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -226,8 +228,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access weights=example_weights_by_head, processed_labels=labels_by_head) + # TODO(b/65403806): Support regularization_losses arg. def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None): """See `_Head`.""" if isinstance(logits, dict): logits_dict = logits @@ -248,9 +252,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access train_op_fn=_no_op_train_fn)) if mode == model_fn.ModeKeys.TRAIN: - if train_op_fn is None: - raise ValueError('train_op_fn can not be None in TRAIN mode.') - spec = self._merge_train(all_estimator_spec, train_op_fn) + spec = self._merge_train( + all_estimator_spec=all_estimator_spec, + optimizer=optimizer, + train_op_fn=train_op_fn) with ops.name_scope(''): summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss) return spec @@ -279,16 +284,21 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access begin_idx += head.logits_dimension return logits_dict - def _merge_train(self, all_estimator_spec, train_op_fn): + def _merge_train(self, all_estimator_spec, optimizer, train_op_fn): """Merges list of `EstimatorSpec` for training. Args: all_estimator_spec: list of `EstimatorSpec` for the individual heads. - train_op_fn: Function to create train op. See `create_estimator_spec` - documentation for more details. + optimizer: `Optimizer` instance to create train op. See + `create_estimator_spec` documentation for more details. + train_op_fn: Function to create train op. Used if `optimizer` is `None`. Returns: `EstimatorSpec` that merges all heads for TRAIN. + + Raises: + ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN + mode. """ losses = [] metrics = {} @@ -297,11 +307,20 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access # Metric keys already contain head.name. metrics.update(spec.eval_metric_ops or {}) loss = _merge_losses(losses, self._head_weights) + if optimizer is not None: + if train_op_fn is not None: + raise ValueError('train_op_fn and optimizer cannot both be set.') + train_op = optimizer.minimize( + loss, global_step=training_util.get_global_step()) + elif train_op_fn is not None: + train_op = train_op_fn(loss) + else: + raise ValueError('train_op_fn and optimizer cannot both be None.') return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, loss=loss, - train_op=train_op_fn(loss), + train_op=train_op, eval_metric_ops=metrics) def _merge_predict(self, all_estimator_spec): @@ -319,16 +338,24 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access all_estimator_spec[0].export_outputs, self._heads[0].name), } + merged_predict_outputs = {} for head, spec in zip(self._heads, all_estimator_spec): head_name = head.name for k, v in six.iteritems(spec.export_outputs): if k == _DEFAULT_SERVING_KEY: key = head_name else: - key = '%s/%s' % (k, head_name) + key = '%s/%s' % (head_name, k) export_outputs[key] = v + if (k == head_lib._PREDICT_SERVING_KEY and # pylint:disable=protected-access + isinstance(v, export_output_lib.PredictOutput)): + for kp, vp in six.iteritems(v.outputs): + key = '%s/%s' % (head_name, kp) + merged_predict_outputs[key] = vp for k, v in six.iteritems(spec.predictions): predictions[(head_name, k)] = v + export_outputs[head_lib._PREDICT_SERVING_KEY] = ( # pylint:disable=protected-access + export_output_lib.PredictOutput(merged_predict_outputs)) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.PREDICT, diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index e47a6788f3b5440c4906b9f0430c802cf73237e3..3d6fccb1180c435f64552667306be004437f62ba 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -127,8 +127,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1', - 'head2', 'classification/head2', 'predict/head2'), + (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification', + 'head1/predict', 'head2', 'head2/classification', 'head2/predict'), spec.export_outputs.keys()) # Assert predictions and export_outputs. @@ -158,6 +158,22 @@ class MultiHeadTest(test.TestCase): self.assertAllClose( expected_probabilities['head2'], sess.run(spec.export_outputs['head2'].scores)) + self.assertAllClose( + expected_probabilities['head1'], + sess.run( + spec.export_outputs['predict'].outputs['head1/probabilities'])) + self.assertAllClose( + expected_probabilities['head2'], + sess.run( + spec.export_outputs['predict'].outputs['head2/probabilities'])) + self.assertAllClose( + expected_probabilities['head1'], + sess.run( + spec.export_outputs['head1/predict'].outputs['probabilities'])) + self.assertAllClose( + expected_probabilities['head2'], + sess.run( + spec.export_outputs['head2/predict'].outputs['probabilities'])) def test_predict_two_heads_logits_tensor(self): """Tests predict with logits as Tensor.""" @@ -181,8 +197,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1', - 'head2', 'classification/head2', 'predict/head2'), + (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification', + 'head1/predict', 'head2', 'head2/classification', 'head2/predict'), spec.export_outputs.keys()) # Assert predictions and export_outputs. @@ -238,8 +254,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, 'head1', 'regression/head1', 'predict/head1', - 'head2', 'regression/head2', 'predict/head2'), + (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/regression', + 'head1/predict', 'head2', 'head2/regression', 'head2/predict'), spec.export_outputs.keys()) # Assert predictions and export_outputs. @@ -283,10 +299,11 @@ class MultiHeadTest(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]] + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]] - # Average over classes, weighted sum over batch and heads. - expected_loss_head1 = 17.5 - expected_loss_head2 = 30.0 + # loss = ( (20 + 20 + 20) / 3 + (30 + 0 + 0) / 3 ) / 2 = 15 + expected_loss_head1 = 8.75 + expected_loss_head2 = 15. expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2 spec = multi_head.create_estimator_spec( @@ -300,14 +317,14 @@ class MultiHeadTest(test.TestCase): keys.LOSS + '/head1': expected_loss_head1, keys.LOSS + '/head2': expected_loss_head2, # Average loss over examples. - keys.LOSS_MEAN + '/head1': expected_loss_head1 / 2, - keys.LOSS_MEAN + '/head2': expected_loss_head2 / 2, + keys.LOSS_MEAN + '/head1': expected_loss_head1, + keys.LOSS_MEAN + '/head2': expected_loss_head2, # auc and auc_pr cannot be reliably calculated for only 4-6 samples, but # this assert tests that the algorithm remains consistent. keys.AUC + '/head1': 0.1667, keys.AUC + '/head2': 0.3333, - keys.AUC_PR + '/head1': 0.49999964, - keys.AUC_PR + '/head2': 0.33333313, + keys.AUC_PR + '/head1': 0.6667, + keys.AUC_PR + '/head2': 0.5000, } # Assert spec contains expected tensors. @@ -347,8 +364,8 @@ class MultiHeadTest(test.TestCase): tol = 1e-3 with self.test_session(): # Unreduced loss of the head is [[(10 + 10) / 2], (15 + 0) / 2] - # (averaged over classes, sum-reduced over examples). - self.assertAllClose(17.5, loss.eval(), rtol=tol, atol=tol) + # (averaged over classes, averaged over examples). + self.assertAllClose(8.75, loss.eval(), rtol=tol, atol=tol) def test_train_create_loss_two_heads_with_weights(self): # Use different example weighting for each head weighting. @@ -383,18 +400,18 @@ class MultiHeadTest(test.TestCase): with self.test_session(): # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] # = [10, 7.5] - # training_loss = 1 * 10 + 2 * 7.5 = 25 + # training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5 # head-weighted unreduced_loss = 1 * [10, 7.5] self.assertAllClose( [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol) # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] # = [20, 10] - # training_loss = 2 * 20 + 3 * 10 = 70 + # training_loss = (2 * 20 + 3 * 10) / 2 = 35 # head-weighted unreduced_loss = 2 * [20, 10] self.assertAllClose( [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol) - # head-weighted training_loss = 1 * 25 + 2 * 70 = 165 - self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol) + # head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5 + self.assertAllClose(82.5, training_loss.eval(), rtol=tol, atol=tol) # head-weighted example weights self.assertAllClose( [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol) @@ -431,18 +448,18 @@ class MultiHeadTest(test.TestCase): with self.test_session(): # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] # = [10, 7.5] - # training_loss = 1 * 10 + 2 * 7.5 = 25 + # training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5 # head-weighted unreduced_loss = 1 * [10, 7.5] self.assertAllClose( [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol) # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] # = [20, 10] - # training_loss = 2 * 20 + 3 * 10 = 70 + # training_loss = (2 * 20 + 3 * 10) / 2 = 35 # head-weighted unreduced_loss = 2 * [20, 10] self.assertAllClose( [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol) - # head-weighted training_loss = 1 * 25 + 2 * 70 = 165 - self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol) + # head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5 + self.assertAllClose(82.5, training_loss.eval(), rtol=tol, atol=tol) # head-weighted example weights self.assertAllClose( [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol) @@ -466,14 +483,14 @@ class MultiHeadTest(test.TestCase): [[2., 2., 0.], [2., 2., 0.]]], dtype=np.float32), } # Loss for the first head: - # loss1 = (1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 + - # (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2 - # = 28 + # loss1 = ((1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 + + # (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2) / 8 + # = 3.5 # Loss for the second head: - # loss2 = (0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 + - # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2 - # = 74 - expected_training_loss = 28. + 74. + # loss2 = ((0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 + + # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2) / 12 + # = 6.167 + expected_training_loss = 3.5 + 6.167 training_loss = multi_head.create_loss( features={}, @@ -495,8 +512,8 @@ class MultiHeadTest(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 + expected_loss = 8.75 expected_train_result = 'my_train_op' def _train_op_fn(loss): return string_ops.string_join( @@ -530,10 +547,46 @@ class MultiHeadTest(test.TestCase): _assert_simple_summaries(self, { metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS + '/head1': expected_loss, - # Average loss over examples. - metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss / 2, }, summary_str, tol) + def test_train_one_head_with_optimizer(self): + head1 = head_lib.multi_label_head(n_classes=2, name='head1') + multi_head = multi_head_lib.multi_head([head1]) + + logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)} + labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)} + # For large logits, sigmoid cross entropy loss is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits => + # expected_unweighted_loss = [[10., 10.], [15., 0.]] + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 + expected_loss = 8.75 + expected_train_result = 'my_train_op' + + class _Optimizer(object): + + def minimize(self, loss, global_step): + del global_step + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=3)]) + + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + optimizer=_Optimizer()) + + tol = 1e-3 + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) + self.assertEqual( + six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), + train_result) + def test_train_two_heads_with_weights(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') @@ -553,10 +606,12 @@ class MultiHeadTest(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]] + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]] + # loss = ( (20 + 20 + 20) / 3 + (30 + 0 + 0) / 3 ) / 2 = 15 # Average over classes, weighted sum over batch and heads. - expected_loss_head1 = 17.5 - expected_loss_head2 = 30.0 + expected_loss_head1 = 8.75 + expected_loss_head2 = 15.0 expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2 expected_train_result = 'my_train_op' def _train_op_fn(loss): @@ -592,9 +647,6 @@ class MultiHeadTest(test.TestCase): metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS + '/head1': expected_loss_head1, metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2, - # Average loss over examples. - metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss_head1 / 2, - metric_keys.MetricKeys.LOSS_MEAN + '/head2': expected_loss_head2 / 2, }, summary_str, tol) diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index e0fae2c99292385c6dd32cc6002cee2076a2bb20..cda23aa437f954700b74dcb9294550eb9a8a8c5c 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -32,7 +32,6 @@ import six from tensorflow.core.framework import node_def_pb2 from tensorflow.python.client import device_lib from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator import util from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import device as framework_device from tensorflow.python.framework import ops as ops_lib @@ -47,8 +46,13 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import tf_logging from tensorflow.python.training import device_setter as device_setter_lib from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.util import deprecation +from tensorflow.python.util import function_utils +@deprecation.deprecated( + '2018-05-31', + 'Please use `tf.contrib.distribute.MirroredStrategy` instead.') def replicate_model_fn(model_fn, loss_reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, devices=None): @@ -136,7 +140,7 @@ def replicate_model_fn(model_fn, the train_op argument of `EstimatorSpec`. loss_reduction: controls whether losses are summed or averaged. devices: Optional list of devices to replicate the model across. This - argument can be used to replice only on the subset of available GPUs. + argument can be used to replicate only on the subset of available GPUs. If `None`, then all available GPUs are going to be used for replication. If no GPUs are available, then the model is going to be placed on the CPU. @@ -255,6 +259,9 @@ class TowerOptimizer(optimizer_lib.Optimizer): COLLECTION_FOR_GRAPH_STATES = 'replicate_model_fn_graph_states' + @deprecation.deprecated( + '2018-05-31', + 'Please use `tf.contrib.distribute.MirroredStrategy` instead.') def __init__(self, optimizer_or_optimizer_fn): """Wrap an existing optimizer for gathering gradients across towers. @@ -456,7 +463,7 @@ def _get_local_devices(device_type): def _split_batch(features, labels, number_of_shards, device): - """Split input features and labes into batches.""" + """Split input features and labels into batches.""" def ensure_divisible_by_shards(sequence): batch_size = ops_lib.convert_to_tensor(sequence).get_shape()[0] @@ -514,7 +521,7 @@ def _get_loss_towers(model_fn, """Replicate the loss computation across devices.""" tower_specs = [] - model_fn_args = util.fn_args(model_fn) + model_fn_args = function_utils.fn_args(model_fn) optional_params = {} if 'params' in model_fn_args: optional_params['params'] = copy.deepcopy(params) @@ -602,7 +609,7 @@ def _local_device_setter(worker_device, ps_devices, ps_strategy): def _scale_tower_loss(tower_spec, loss_reduction, number_of_towers): - """Produce an EstimatorSpec with approproriately scaled loss.""" + """Produce an EstimatorSpec with appropriately scaled loss.""" if tower_spec.loss is None: return tower_spec diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py index d46a18aacfcd911c56a9f22dc9581060c7b458a6..dd8a3a95f1b83bfd29e8a38ec1512f90e22968d9 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import re import shutil import tempfile +from absl.testing import parameterized import numpy as np import six @@ -57,26 +58,19 @@ from tensorflow.python.training import gradient_descent from tensorflow.python.training import training -# TODO(isaprykin): Parametrize all the tests on -# replicate_model_fn._VariableDistributionMode when it's supported. -class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): +class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase, + parameterized.TestCase): def setUp(self): self._model_dir = tempfile.mkdtemp() - def test_complete_flow_with_public_version(self): - return self._complete_flow_with_mode(mode=None) - - def test_complete_flow_with_mode_local_ps_server(self): - return self._complete_flow_with_mode( - replicate_model_fn._VariableDistributionMode. - SHARED_LOCAL_PARAMETER_SERVER) - - def test_complete_flow_with_mode_round_robin(self): - return self._complete_flow_with_mode( - replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN) - - def _complete_flow_with_mode(self, mode): + @parameterized.named_parameters( + ('PublicInterface', None), + ('ParameterServerMode', replicate_model_fn._VariableDistributionMode. + SHARED_LOCAL_PARAMETER_SERVER), + ('RoundRobinMode', + replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN)) + def test_complete_flow_with_mode(self, mode): n_classes = 3 input_dimension = 2 batch_size = 12 @@ -546,59 +540,6 @@ class ReplicateAcrossASingleDeviceWithoutTowerOptimizer( self.assertEqual(7.0, session.run(c)) -class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase): - - def model_fn(self, mode, features, labels, params): - c = variable_scope.get_variable( - 'c', - initializer=constant_op.constant(10, dtype=dtypes.float64), - dtype=dtypes.float64) - - features = features['features'] - predictions = math_ops.multiply(features, c) - - loss = losses.absolute_difference( - labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) - loss = math_ops.reduce_sum(loss) - - metrics = { - 'accuracy': metrics_lib.accuracy(labels, predictions), - 'auc': metrics_lib.auc(labels, predictions) - } - - optimizer = replicate_model_fn.TowerOptimizer( - gradient_descent.GradientDescentOptimizer(params['learning_rate'])) - - return model_fn_lib.EstimatorSpec( - mode=mode, - loss=loss, - eval_metric_ops=metrics, - predictions={'probabilities': predictions}, - train_op=optimizer.minimize(loss)) - - @property - def params(self): - params = {} - params['learning_rate'] = 1.0 - return params - - def test_train_single_tower(self): - features = np.array([[1.0], [2.0]]) - labels = np.array([[1.0], [2.0]]) - - train_input_fn = numpy_io.numpy_input_fn( - x={'features': features}, y=labels, batch_size=2, shuffle=False) - - with self.test_session(): - estimator = estimator_lib.Estimator( - model_fn=self.model_fn, - model_dir=tempfile.mkdtemp(), - params=self.params) - estimator.train(train_input_fn, steps=1) - - self.assertEqual(7.0, estimator.get_variable_value('c')) - - class MakeSureSyncReplicasOptimizerWorks(test_util.TensorFlowTestCase): def model_fn(self, mode, features, labels, params): diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..7c49cd00d16777872ad1211dfa1d1a3ac9ac1cee --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/rnn.py @@ -0,0 +1,622 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Recurrent Neural Network estimators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.estimator.python.estimator import extenders +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.canned import optimizers +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import core as core_layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import variable_scope +from tensorflow.python.summary import summary +from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.training import training_util + + +# The defaults are historical artifacts of the initial implementation, but seem +# reasonable choices. +_DEFAULT_LEARNING_RATE = 0.05 +_DEFAULT_CLIP_NORM = 5.0 + +_CELL_TYPES = {'basic_rnn': rnn_cell.BasicRNNCell, + 'lstm': rnn_cell.BasicLSTMCell, + 'gru': rnn_cell.GRUCell} + +# Indicates no value was provided by the user to a kwarg. +USE_DEFAULT = object() + + +def _single_rnn_cell(num_units, cell_type): + cell_type = _CELL_TYPES.get(cell_type, cell_type) + if not cell_type or not issubclass(cell_type, rnn_cell.RNNCell): + raise ValueError('Supported cell types are {}; got {}'.format( + list(_CELL_TYPES.keys()), cell_type)) + return cell_type(num_units=num_units) + + +def _make_rnn_cell_fn(num_units, cell_type='basic_rnn'): + """Convenience function to create `rnn_cell_fn` for canned RNN Estimators. + + Args: + num_units: Iterable of integer number of hidden units per RNN layer. + cell_type: A subclass of `tf.nn.rnn_cell.RNNCell` or a string specifying + the cell type. Supported strings are: `'basic_rnn'`, `'lstm'`, and + `'gru'`. + + Returns: + A function that takes a single argument, an instance of + `tf.estimator.ModeKeys`, and returns an instance derived from + `tf.nn.rnn_cell.RNNCell`. + + Raises: + ValueError: If cell_type is not supported. + """ + def rnn_cell_fn(mode): + # Unused. Part of the rnn_cell_fn interface since user specified functions + # may need different behavior across modes (e.g. dropout). + del mode + cells = [_single_rnn_cell(n, cell_type) for n in num_units] + if len(cells) == 1: + return cells[0] + return rnn_cell.MultiRNNCell(cells) + return rnn_cell_fn + + +def _concatenate_context_input(sequence_input, context_input): + """Replicates `context_input` across all timesteps of `sequence_input`. + + Expands dimension 1 of `context_input` then tiles it `sequence_length` times. + This value is appended to `sequence_input` on dimension 2 and the result is + returned. + + Args: + sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size, + padded_length, d0]`. + context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`. + + Returns: + A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, + d0 + d1]`. + + Raises: + ValueError: If `sequence_input` does not have rank 3 or `context_input` does + not have rank 2. + """ + seq_rank_check = check_ops.assert_rank( + sequence_input, + 3, + message='sequence_input must have rank 3', + data=[array_ops.shape(sequence_input)]) + seq_type_check = check_ops.assert_type( + sequence_input, + dtypes.float32, + message='sequence_input must have dtype float32; got {}.'.format( + sequence_input.dtype)) + ctx_rank_check = check_ops.assert_rank( + context_input, + 2, + message='context_input must have rank 2', + data=[array_ops.shape(context_input)]) + ctx_type_check = check_ops.assert_type( + context_input, + dtypes.float32, + message='context_input must have dtype float32; got {}.'.format( + context_input.dtype)) + with ops.control_dependencies( + [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]): + padded_length = array_ops.shape(sequence_input)[1] + tiled_context_input = array_ops.tile( + array_ops.expand_dims(context_input, 1), + array_ops.concat([[1], [padded_length], [1]], 0)) + return array_ops.concat([sequence_input, tiled_context_input], 2) + + +def _select_last_activations(activations, sequence_lengths): + """Selects the nth set of activations for each n in `sequence_length`. + + Returns a `Tensor` of shape `[batch_size, k]`. If `sequence_length` is not + `None`, then `output[i, :] = activations[i, sequence_length[i] - 1, :]`. If + `sequence_length` is `None`, then `output[i, :] = activations[i, -1, :]`. + + Args: + activations: A `Tensor` with shape `[batch_size, padded_length, k]`. + sequence_lengths: A `Tensor` with shape `[batch_size]` or `None`. + Returns: + A `Tensor` of shape `[batch_size, k]`. + """ + with ops.name_scope( + 'select_last_activations', values=[activations, sequence_lengths]): + activations_shape = array_ops.shape(activations) + batch_size = activations_shape[0] + padded_length = activations_shape[1] + output_units = activations_shape[2] + if sequence_lengths is None: + sequence_lengths = padded_length + start_indices = math_ops.to_int64( + math_ops.range(batch_size) * padded_length) + last_indices = start_indices + sequence_lengths - 1 + reshaped_activations = array_ops.reshape( + activations, [batch_size * padded_length, output_units]) + + last_activations = array_ops.gather(reshaped_activations, last_indices) + last_activations.set_shape([activations.shape[0], activations.shape[2]]) + return last_activations + + +def _rnn_logit_fn_builder(output_units, rnn_cell_fn, sequence_feature_columns, + context_feature_columns, input_layer_partitioner): + """Function builder for a rnn logit_fn. + + Args: + output_units: An int indicating the dimension of the logit layer. + rnn_cell_fn: A function with one argument, a `tf.estimator.ModeKeys`, and + returns an object of type `tf.nn.rnn_cell.RNNCell`. + sequence_feature_columns: An iterable containing the `FeatureColumn`s + that represent sequential input. + context_feature_columns: An iterable containing the `FeatureColumn`s + that represent contextual input. + input_layer_partitioner: Partitioner for input layer. + + Returns: + A logit_fn (see below). + + Raises: + ValueError: If output_units is not an int. + """ + if not isinstance(output_units, int): + raise ValueError('output_units must be an int. Given type: {}'.format( + type(output_units))) + + def rnn_logit_fn(features, mode): + """Recurrent Neural Network logit_fn. + + Args: + features: This is the first item returned from the `input_fn` + passed to `train`, `evaluate`, and `predict`. This should be a + single `Tensor` or `dict` of same. + mode: Optional. Specifies if this training, evaluation or prediction. See + `ModeKeys`. + + Returns: + A `Tensor` representing the logits. + """ + with variable_scope.variable_scope( + 'sequence_input_layer', + values=tuple(six.itervalues(features)), + partitioner=input_layer_partitioner): + sequence_input, sequence_length = seq_fc.sequence_input_layer( + features=features, feature_columns=sequence_feature_columns) + summary.histogram('sequence_length', sequence_length) + + if context_feature_columns: + context_input = feature_column_lib.input_layer( + features=features, + feature_columns=context_feature_columns) + sequence_input = _concatenate_context_input(sequence_input, + context_input) + + cell = rnn_cell_fn(mode) + # Ignore output state. + rnn_outputs, _ = rnn.dynamic_rnn( + cell=cell, + inputs=sequence_input, + sequence_length=sequence_length, + dtype=dtypes.float32, + time_major=False) + last_activations = _select_last_activations(rnn_outputs, sequence_length) + + with variable_scope.variable_scope('logits', values=(rnn_outputs,)): + logits = core_layers.dense( + last_activations, + units=output_units, + activation=None, + kernel_initializer=init_ops.glorot_uniform_initializer()) + return logits + + return rnn_logit_fn + + +def _rnn_model_fn(features, + labels, + mode, + head, + rnn_cell_fn, + sequence_feature_columns, + context_feature_columns, + optimizer='Adagrad', + input_layer_partitioner=None, + config=None): + """Recurrent Neural Net model_fn. + + Args: + features: dict of `Tensor` and `SparseTensor` objects returned from + `input_fn`. + labels: `Tensor` of shape [batch_size, 1] or [batch_size] with labels. + mode: Defines whether this is training, evaluation or prediction. + See `ModeKeys`. + head: A `head_lib._Head` instance. + rnn_cell_fn: A function with one argument, a `tf.estimator.ModeKeys`, and + returns an object of type `tf.nn.rnn_cell.RNNCell`. + sequence_feature_columns: Iterable containing `FeatureColumn`s that + represent sequential model inputs. + context_feature_columns: Iterable containing `FeatureColumn`s that + represent model inputs not associated with a specific timestep. + optimizer: String, `tf.Optimizer` object, or callable that creates the + optimizer to use for training. If not specified, will use the Adagrad + optimizer with a default learning rate of 0.05 and gradient clip norm of + 5.0. + input_layer_partitioner: Partitioner for input layer. Defaults + to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + config: `RunConfig` object to configure the runtime settings. + + Returns: + An `EstimatorSpec` instance. + + Raises: + ValueError: If mode or optimizer is invalid, or features has the wrong type. + """ + if not isinstance(features, dict): + raise ValueError('features should be a dictionary of `Tensor`s. ' + 'Given type: {}'.format(type(features))) + + # If user does not provide an optimizer instance, use the optimizer specified + # by the string with default learning rate and gradient clipping. + if not isinstance(optimizer, optimizer_lib.Optimizer): + optimizer = optimizers.get_optimizer_instance( + optimizer, learning_rate=_DEFAULT_LEARNING_RATE) + optimizer = extenders.clip_gradients_by_norm(optimizer, _DEFAULT_CLIP_NORM) + + num_ps_replicas = config.num_ps_replicas if config else 0 + partitioner = partitioned_variables.min_max_variable_partitioner( + max_partitions=num_ps_replicas) + with variable_scope.variable_scope( + 'rnn', + values=tuple(six.itervalues(features)), + partitioner=partitioner): + input_layer_partitioner = input_layer_partitioner or ( + partitioned_variables.min_max_variable_partitioner( + max_partitions=num_ps_replicas, + min_slice_size=64 << 20)) + + logit_fn = _rnn_logit_fn_builder( + output_units=head.logits_dimension, + rnn_cell_fn=rnn_cell_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + input_layer_partitioner=input_layer_partitioner) + logits = logit_fn(features=features, mode=mode) + + def _train_op_fn(loss): + """Returns the op to optimize the loss.""" + return optimizer.minimize( + loss, + global_step=training_util.get_global_step()) + + return head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + +def _assert_rnn_cell_fn(rnn_cell_fn, num_units, cell_type): + """Assert arguments are valid and return rnn_cell_fn.""" + if rnn_cell_fn and (num_units or cell_type != USE_DEFAULT): + raise ValueError( + 'num_units and cell_type must not be specified when using rnn_cell_fn' + ) + if not rnn_cell_fn: + if cell_type == USE_DEFAULT: + cell_type = 'basic_rnn' + rnn_cell_fn = _make_rnn_cell_fn(num_units, cell_type) + return rnn_cell_fn + + +class RNNClassifier(estimator.Estimator): + """A classifier for TensorFlow RNN models. + + Trains a recurrent neural network model to classify instances into one of + multiple classes. + + Example: + + ```python + token_sequence = sequence_categorical_column_with_hash_bucket(...) + token_emb = embedding_column(categorical_column=token_sequence, ...) + + estimator = RNNClassifier( + sequence_feature_columns=[token_emb], + num_units=[32, 16], cell_type='lstm') + + # Input builders + def input_fn_train: # returns x, y + pass + estimator.train(input_fn=input_fn_train, steps=100) + + def input_fn_eval: # returns x, y + pass + metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) + def input_fn_predict: # returns x, None + pass + predictions = estimator.predict(input_fn=input_fn_predict) + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * if `weight_column` is not `None`, a feature with + `key=weight_column` whose value is a `Tensor`. + * for each `column` in `sequence_feature_columns`: + - a feature with `key=column.name` whose `value` is a `SparseTensor`. + * for each `column` in `context_feature_columns`: + - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` + whose `value` is a `SparseTensor`. + - if `column` is a `_WeightedCategoricalColumn`, two features: the first + with `key` the id column name, the second with `key` the weight column + name. Both features' `value` must be a `SparseTensor`. + - if `column` is a `_DenseColumn`, a feature with `key=column.name` + whose `value` is a `Tensor`. + + Loss is calculated by using softmax cross entropy. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility + """ + + def __init__(self, + sequence_feature_columns, + context_feature_columns=None, + num_units=None, + cell_type=USE_DEFAULT, + rnn_cell_fn=None, + model_dir=None, + n_classes=2, + weight_column=None, + label_vocabulary=None, + optimizer='Adagrad', + input_layer_partitioner=None, + config=None): + """Initializes a `RNNClassifier` instance. + + Args: + sequence_feature_columns: An iterable containing the `FeatureColumn`s + that represent sequential input. All items in the set should either be + sequence columns (e.g. `sequence_numeric_column`) or constructed from + one (e.g. `embedding_column` with `sequence_categorical_column_*` as + input). + context_feature_columns: An iterable containing the `FeatureColumn`s + for contextual input. The data represented by these columns will be + replicated and given to the RNN at each timestep. These columns must be + instances of classes derived from `_DenseColumn` such as + `numeric_column`, not the sequential variants. + num_units: Iterable of integer number of hidden units per RNN layer. If + set, `cell_type` must also be specified and `rnn_cell_fn` must be + `None`. + cell_type: A subclass of `tf.nn.rnn_cell.RNNCell` or a string specifying + the cell type. Supported strings are: `'basic_rnn'`, `'lstm'`, and + `'gru'`. If set, `num_units` must also be specified and `rnn_cell_fn` + must be `None`. + rnn_cell_fn: A function with one argument, a `tf.estimator.ModeKeys`, and + returns an object of type `tf.nn.rnn_cell.RNNCell` that will be used to + construct the RNN. If set, `num_units` and `cell_type` cannot be set. + This is for advanced users who need additional customization beyond + `num_units` and `cell_type`. Note that `tf.nn.rnn_cell.MultiRNNCell` is + needed for stacked RNNs. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + n_classes: Number of label classes. Defaults to 2, namely binary + classification. Must be > 1. + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to down weight or boost examples during training. It + will be multiplied by the loss of the example. If it is a string, it is + used as a key to fetch weight tensor from the `features`. If it is a + `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, + then weight_column.normalizer_fn is applied on it to get weight tensor. + label_vocabulary: A list of strings represents possible label values. If + given, labels must be string type and have any value in + `label_vocabulary`. If it is not given, that means labels are + already encoded as integer or float within [0, 1] for `n_classes=2` and + encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . + Also there will be errors if vocabulary is not provided and labels are + string. + optimizer: An instance of `tf.Optimizer` or string specifying optimizer + type. Defaults to Adagrad optimizer. + input_layer_partitioner: Optional. Partitioner for input layer. Defaults + to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + config: `RunConfig` object to configure the runtime settings. + + Raises: + ValueError: If `num_units`, `cell_type`, and `rnn_cell_fn` are not + compatible. + """ + rnn_cell_fn = _assert_rnn_cell_fn(rnn_cell_fn, num_units, cell_type) + + if n_classes == 2: + head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access + weight_column=weight_column, + label_vocabulary=label_vocabulary) + else: + head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access + n_classes, weight_column=weight_column, + label_vocabulary=label_vocabulary) + def _model_fn(features, labels, mode, config): + return _rnn_model_fn( + features=features, + labels=labels, + mode=mode, + head=head, + rnn_cell_fn=rnn_cell_fn, + sequence_feature_columns=tuple(sequence_feature_columns or []), + context_feature_columns=tuple(context_feature_columns or []), + optimizer=optimizer, + input_layer_partitioner=input_layer_partitioner, + config=config) + super(RNNClassifier, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) + + +class RNNEstimator(estimator.Estimator): + """An Estimator for TensorFlow RNN models with user-specified head. + + Example: + + ```python + token_sequence = sequence_categorical_column_with_hash_bucket(...) + token_emb = embedding_column(categorical_column=token_sequence, ...) + + estimator = RNNEstimator( + head=tf.contrib.estimator.regression_head(), + sequence_feature_columns=[token_emb], + num_units=[32, 16], cell_type='lstm') + + # Or with custom RNN cell: + def rnn_cell_fn(mode): + cells = [ tf.contrib.rnn.LSTMCell(size) for size in [32, 16] ] + if mode == tf.estimator.ModeKeys.TRAIN: + cells = [ tf.contrib.rnn.DropoutWrapper(cell, input_keep_prob=0.5) + for cell in cells ] + return tf.contrib.rnn.MultiRNNCell(cells) + + estimator = RNNEstimator( + head=tf.contrib.estimator.regression_head(), + sequence_feature_columns=[token_emb], + rnn_cell_fn=rnn_cell_fn) + + # Input builders + def input_fn_train: # returns x, y + pass + estimator.train(input_fn=input_fn_train, steps=100) + + def input_fn_eval: # returns x, y + pass + metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) + def input_fn_predict: # returns x, None + pass + predictions = estimator.predict(input_fn=input_fn_predict) + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * if the head's `weight_column` is not `None`, a feature with + `key=weight_column` whose value is a `Tensor`. + * for each `column` in `sequence_feature_columns`: + - a feature with `key=column.name` whose `value` is a `SparseTensor`. + * for each `column` in `context_feature_columns`: + - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` + whose `value` is a `SparseTensor`. + - if `column` is a `_WeightedCategoricalColumn`, two features: the first + with `key` the id column name, the second with `key` the weight column + name. Both features' `value` must be a `SparseTensor`. + - if `column` is a `_DenseColumn`, a feature with `key=column.name` + whose `value` is a `Tensor`. + + Loss and predicted output are determined by the specified head. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility + """ + + def __init__(self, + head, + sequence_feature_columns, + context_feature_columns=None, + num_units=None, + cell_type=USE_DEFAULT, + rnn_cell_fn=None, + model_dir=None, + optimizer='Adagrad', + input_layer_partitioner=None, + config=None): + """Initializes a `RNNClassifier` instance. + + Args: + head: A `_Head` instance constructed with a method such as + `tf.contrib.estimator.multi_label_head`. This specifies the model's + output and loss function to be optimized. + sequence_feature_columns: An iterable containing the `FeatureColumn`s + that represent sequential input. All items in the set should either be + sequence columns (e.g. `sequence_numeric_column`) or constructed from + one (e.g. `embedding_column` with `sequence_categorical_column_*` as + input). + context_feature_columns: An iterable containing the `FeatureColumn`s + for contextual input. The data represented by these columns will be + replicated and given to the RNN at each timestep. These columns must be + instances of classes derived from `_DenseColumn` such as + `numeric_column`, not the sequential variants. + num_units: Iterable of integer number of hidden units per RNN layer. If + set, `cell_type` must also be specified and `rnn_cell_fn` must be + `None`. + cell_type: A subclass of `tf.nn.rnn_cell.RNNCell` or a string specifying + the cell type. Supported strings are: `'basic_rnn'`, `'lstm'`, and + `'gru'`. If set, `num_units` must also be specified and `rnn_cell_fn` + must be `None`. + rnn_cell_fn: A function with one argument, a `tf.estimator.ModeKeys`, and + returns an object of type `tf.nn.rnn_cell.RNNCell` that will be used to + construct the RNN. If set, `num_units` and `cell_type` cannot be set. + This is for advanced users who need additional customization beyond + `num_units` and `cell_type`. Note that `tf.nn.rnn_cell.MultiRNNCell` is + needed for stacked RNNs. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + optimizer: An instance of `tf.Optimizer` or string specifying optimizer + type. Defaults to Adagrad optimizer. + input_layer_partitioner: Optional. Partitioner for input layer. Defaults + to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + config: `RunConfig` object to configure the runtime settings. + + Raises: + ValueError: If `num_units`, `cell_type`, and `rnn_cell_fn` are not + compatible. + """ + rnn_cell_fn = _assert_rnn_cell_fn(rnn_cell_fn, num_units, cell_type) + + def _model_fn(features, labels, mode, config): + return _rnn_model_fn( + features=features, + labels=labels, + mode=mode, + head=head, + rnn_cell_fn=rnn_cell_fn, + sequence_feature_columns=tuple(sequence_feature_columns or []), + context_feature_columns=tuple(context_feature_columns or []), + optimizer=optimizer, + input_layer_partitioner=input_layer_partitioner, + config=config) + super(RNNEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..959b40371aa5fa83a40af999cffade18e5b502e5 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py @@ -0,0 +1,1172 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for rnn.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib.data.python.ops import readers +from tensorflow.contrib.estimator.python.estimator import head as head_lib +from tensorflow.contrib.estimator.python.estimator import rnn +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.estimator import model_fn +from tensorflow.python.estimator.canned import metric_keys +from tensorflow.python.estimator.canned import parsing_utils +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.lib.io import python_io +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import monitored_session +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_util + + +# Names of variables created by BasicRNNCell model. +TOKEN_EMBEDDING_NAME = 'rnn/sequence_input_layer/input_layer/tokens_sequential_embedding/embedding_weights' +CELL_WEIGHTS_NAME = 'rnn/rnn/basic_rnn_cell/kernel' +CELL_BIAS_NAME = 'rnn/rnn/basic_rnn_cell/bias' +MULTI_CELL_WEIGHTS_NAME_PATTERN = 'rnn/rnn/multi_rnn_cell/cell_%d/basic_rnn_cell/kernel' +MULTI_CELL_BIAS_NAME_PATTERN = 'rnn/rnn/multi_rnn_cell/cell_%d/basic_rnn_cell/bias' +LOGITS_WEIGHTS_NAME = 'rnn/logits/dense/kernel' +LOGITS_BIAS_NAME = 'rnn/logits/dense/bias' + + +def _assert_close(expected, actual, rtol=1e-04, name='assert_close'): + with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope: + expected = ops.convert_to_tensor(expected, name='expected') + actual = ops.convert_to_tensor(actual, name='actual') + rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected) + rtol = ops.convert_to_tensor(rtol, name='rtol') + return check_ops.assert_less( + rdiff, + rtol, + data=('Condition expected =~ actual did not hold element-wise:' + 'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff, + 'rtol = ', rtol,), + name=scope) + + +def create_checkpoint(rnn_weights, rnn_biases, logits_weights, logits_biases, + global_step, model_dir): + """Create checkpoint file with provided model weights. + + Args: + rnn_weights: Iterable of values of weights for the RNN cell. + rnn_biases: Iterable of values of biases for the RNN cell. + logits_weights: Iterable of values for matrix connecting RNN output to + logits. + logits_biases: Iterable of values for logits bias term. + global_step: Initial global step to save in checkpoint. + model_dir: Directory into which checkpoint is saved. + """ + model_weights = {} + model_weights[CELL_WEIGHTS_NAME] = rnn_weights + model_weights[CELL_BIAS_NAME] = rnn_biases + model_weights[LOGITS_WEIGHTS_NAME] = logits_weights + model_weights[LOGITS_BIAS_NAME] = logits_biases + + with ops.Graph().as_default(): + # Create model variables. + for k, v in six.iteritems(model_weights): + variables_lib.Variable(v, name=k, dtype=dtypes.float32) + + # Create non-model variables. + global_step_var = training_util.create_global_step() + assign_op = global_step_var.assign(global_step) + + # Initialize vars and save checkpoint. + with monitored_session.MonitoredTrainingSession( + checkpoint_dir=model_dir) as sess: + sess.run(assign_op) + + +class RNNLogitFnTest(test.TestCase): + """Tests correctness of logits calculated from _rnn_logit_fn_builder.""" + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_logits(self, mode, rnn_units, logits_dimension, features_fn, + sequence_feature_columns, context_feature_columns, + expected_logits): + """Tests that the expected logits are calculated.""" + with ops.Graph().as_default(): + # Global step needed for MonitoredSession, which is in turn used to + # explicitly set variable weights through a checkpoint. + training_util.create_global_step() + # Use a variable scope here with 'rnn', emulating the rnn model_fn, so + # the checkpoint naming is shared. + with variable_scope.variable_scope('rnn'): + input_layer_partitioner = ( + partitioned_variables.min_max_variable_partitioner( + max_partitions=0, min_slice_size=64 << 20)) + logit_fn = rnn._rnn_logit_fn_builder( + output_units=logits_dimension, + rnn_cell_fn=rnn._make_rnn_cell_fn(rnn_units), + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + input_layer_partitioner=input_layer_partitioner) + # Features are constructed within this function, otherwise the Tensors + # containing the features would be defined outside this graph. + logits = logit_fn(features=features_fn(), mode=mode) + with monitored_session.MonitoredTrainingSession( + checkpoint_dir=self._model_dir) as sess: + self.assertAllClose(expected_logits, sess.run(logits), atol=1e-4) + + def testOneDimLogits(self): + """Tests one-dimensional logits. + + Intermediate values are rounded for ease in reading. + input_layer = [[[10]], [[5]]] + initial_state = [0, 0] + rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2), + tanh(-.2*10 - .3*0 - .4*0 +.5)]] + = [[0.83, -0.91]] + rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2), + tanh(-.2*5 - .3*.83 + .4*.91 +.5)]] + = [[0.53, -0.37]] + logits = [[-1*0.53 - 1*0.37 + 0.3]] = [[-0.6033]] + """ + base_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5.], + indices=[[0, 0], [0, 1]], + dense_shape=[1, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + context_feature_columns = [] + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=1, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-0.6033]]) + + def testMultiDimLogits(self): + """Tests multi-dimensional logits. + + Intermediate values are rounded for ease in reading. + input_layer = [[[10]], [[5]]] + initial_state = [0, 0] + rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2), + tanh(-.2*10 - .3*0 - .4*0 +.5)]] + = [[0.83, -0.91]] + rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2), + tanh(-.2*5 - .3*.83 + .4*.91 +.5)]] + = [[0.53, -0.37]] + logits = [[-1*0.53 - 1*0.37 + 0.3], + [0.5*0.53 + 0.3*0.37 + 0.4], + [0.2*0.53 - 0.1*0.37 + 0.5] + = [[-0.6033, 0.7777, 0.5698]] + """ + base_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]], + logits_biases=[0.3, 0.4, 0.5], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5.], + indices=[[0, 0], [0, 1]], + dense_shape=[1, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + context_feature_columns = [] + + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=3, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-0.6033, 0.7777, 0.5698]]) + + def testMultiExampleMultiDim(self): + """Tests multiple examples and multi-dimensional logits. + + Intermediate values are rounded for ease in reading. + input_layer = [[[10], [5]], [[2], [7]]] + initial_state = [[0, 0], [0, 0]] + rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2), + tanh(-.2*10 - .3*0 - .4*0 +.5)], + [tanh(.1*2 + .2*0 + .3*0 +.2), + tanh(-.2*2 - .3*0 - .4*0 +.5)]] + = [[0.83, -0.91], [0.38, 0.10]] + rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2), + tanh(-.2*5 - .3*.83 + .4*.91 +.5)], + [tanh(.1*7 + .2*.38 + .3*.10 +.2), + tanh(-.2*7 - .3*.38 - .4*.10 +.5)]] + = [[0.53, -0.37], [0.76, -0.78] + logits = [[-1*0.53 - 1*0.37 + 0.3, + 0.5*0.53 + 0.3*0.37 + 0.4, + 0.2*0.53 - 0.1*0.37 + 0.5], + [-1*0.76 - 1*0.78 + 0.3, + 0.5*0.76 +0.3*0.78 + 0.4, + 0.2*0.76 -0.1*0.78 + 0.5]] + = [[-0.6033, 0.7777, 0.5698], [-1.2473, 1.0170, 0.5745]] + """ + base_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]], + logits_biases=[0.3, 0.4, 0.5], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2., 7.], + indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + dense_shape=[2, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,)) + ] + context_feature_columns = [] + + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=3, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-0.6033, 0.7777, 0.5698], + [-1.2473, 1.0170, 0.5745]]) + + def testMultiExamplesDifferentLength(self): + """Tests multiple examples with different lengths. + + Intermediate values are rounded for ease in reading. + input_layer = [[[10], [5]], [[2], [0]]] + initial_state = [[0, 0], [0, 0]] + rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2), + tanh(-.2*10 - .3*0 - .4*0 +.5)], + [tanh(.1*2 + .2*0 + .3*0 +.2), + tanh(-.2*2 - .3*0 - .4*0 +.5)]] + = [[0.83, -0.91], [0.38, 0.10]] + rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2), + tanh(-.2*5 - .3*.83 + .4*.91 +.5)], + []] + = [[0.53, -0.37], []] + logits = [[-1*0.53 - 1*0.37 + 0.3], + [-1*0.38 + 1*0.10 + 0.3]] + = [[-0.6033], [0.0197]] + """ + base_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2.], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + context_feature_columns = [] + + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=1, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-0.6033], [0.0197]]) + + def testMultiExamplesWithContext(self): + """Tests multiple examples with context features. + + Intermediate values are rounded for ease in reading. + input_layer = [[[10, -0.5], [5, -0.5]], [[2, 0.8], [0, 0]]] + initial_state = [[0, 0], [0, 0]] + rnn_output_timestep_1 = [[tanh(.1*10 - 1*.5 + .2*0 + .3*0 +.2), + tanh(-.2*10 - 0.9*.5 - .3*0 - .4*0 +.5)], + [tanh(.1*2 + 1*.8 + .2*0 + .3*0 +.2), + tanh(-.2*2 + .9*.8 - .3*0 - .4*0 +.5)]] + = [[0.60, -0.96], [0.83, 0.68]] + rnn_output_timestep_2 = [[tanh(.1*5 - 1*.5 + .2*.60 - .3*.96 +.2), + tanh(-.2*5 - .9*.5 - .3*.60 + .4*.96 +.5)], + []] + = [[0.03, -0.63], []] + logits = [[-1*0.03 - 1*0.63 + 0.3], + [-1*0.83 + 1*0.68 + 0.3]] + = [[-0.3662], [0.1414]] + """ + base_global_step = 100 + create_checkpoint( + # Context features weights are inserted between input and state weights. + rnn_weights=[[.1, -.2], [1., 0.9], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2.], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + 'context': [[-0.5], [0.8]], + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + context_feature_columns = [fc.numeric_column('context', shape=(1,))] + + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=1, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-0.3662], [0.1414]]) + + def testMultiExamplesMultiFeatures(self): + """Tests examples with multiple sequential feature columns. + + Intermediate values are rounded for ease in reading. + input_layer = [[[1, 0, 10], [0, 1, 5]], [[1, 0, 2], [0, 0, 0]]] + initial_state = [[0, 0], [0, 0]] + rnn_output_timestep_1 = [[tanh(.5*1 + 1*0 + .1*10 + .2*0 + .3*0 +.2), + tanh(-.5*1 - 1*0 - .2*10 - .3*0 - .4*0 +.5)], + [tanh(.5*1 + 1*0 + .1*2 + .2*0 + .3*0 +.2), + tanh(-.5*1 - 1*0 - .2*2 - .3*0 - .4*0 +.5)]] + = [[0.94, -0.96], [0.72, -0.38]] + rnn_output_timestep_2 = [[tanh(.5*0 + 1*1 + .1*5 + .2*.94 - .3*.96 +.2), + tanh(-.5*0 - 1*1 - .2*5 - .3*.94 + .4*.96 +.5)], + []] + = [[0.92, -0.88], []] + logits = [[-1*0.92 - 1*0.88 + 0.3], + [-1*0.72 - 1*0.38 + 0.3]] + = [[-1.5056], [-0.7962]] + """ + base_global_step = 100 + create_checkpoint( + # FeatureColumns are sorted alphabetically, so on_sale weights are + # inserted before price. + rnn_weights=[[.5, -.5], [1., -1.], [.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2.], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + 'on_sale': + sparse_tensor.SparseTensor( + values=[0, 1, 0], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + } + + price_column = seq_fc.sequence_numeric_column('price', shape=(1,)) + on_sale_column = fc.indicator_column( + seq_fc.sequence_categorical_column_with_identity( + 'on_sale', num_buckets=2)) + sequence_feature_columns = [price_column, on_sale_column] + context_feature_columns = [] + + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=1, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-1.5056], [-0.7962]]) + + +class RNNClassifierTrainingTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _assert_checkpoint( + self, n_classes, input_units, cell_units, expected_global_step): + + shapes = { + name: shape for (name, shape) in + checkpoint_utils.list_variables(self._model_dir) + } + + self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP]) + self.assertEqual( + expected_global_step, + checkpoint_utils.load_variable( + self._model_dir, ops.GraphKeys.GLOBAL_STEP)) + + # RNN Cell variables. + if len(cell_units) > 1: + for i, cell_unit in enumerate(cell_units): + self.assertEqual([input_units + cell_unit, cell_unit], + shapes[MULTI_CELL_WEIGHTS_NAME_PATTERN % i]) + self.assertEqual([cell_unit], + shapes[MULTI_CELL_BIAS_NAME_PATTERN % i]) + input_units = cell_unit + elif len(cell_units) == 1: + self.assertEqual([input_units + cell_unit, cell_unit], + shapes[CELL_WEIGHTS_NAME]) + self.assertEqual([cell_unit], shapes[CELL_BIAS_NAME]) + + # Logits variables. + logits_dimension = n_classes if n_classes > 2 else 1 + self.assertEqual([cell_units[-1], logits_dimension], + shapes[LOGITS_WEIGHTS_NAME]) + self.assertEqual([logits_dimension], shapes[LOGITS_BIAS_NAME]) + + def _mock_optimizer(self, expected_loss=None): + expected_var_names = [ + '%s/part_0:0' % CELL_BIAS_NAME, + '%s/part_0:0' % CELL_WEIGHTS_NAME, + '%s/part_0:0' % LOGITS_BIAS_NAME, + '%s/part_0:0' % LOGITS_WEIGHTS_NAME, + ] + + def _minimize(loss, global_step): + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual( + expected_var_names, + [var.name for var in trainable_vars]) + + # Verify loss. We can't check the value directly, so we add an assert op. + self.assertEquals(0, loss.shape.ndims) + if expected_loss is None: + return state_ops.assign_add(global_step, 1).op + assert_loss = _assert_close( + math_ops.to_float(expected_loss, name='expected'), + loss, + name='assert_loss') + with ops.control_dependencies((assert_loss,)): + return state_ops.assign_add(global_step, 1).op + + mock_optimizer = test.mock.NonCallableMock( + spec=optimizer.Optimizer, + wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) + mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize) + + # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks. + # So, return mock_optimizer itself for deepcopy. + mock_optimizer.__deepcopy__ = lambda _: mock_optimizer + return mock_optimizer + + def testConflictingRNNCellFn(self): + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + cell_units = [4, 2] + + with self.assertRaisesRegexp( + ValueError, + 'num_units and cell_type must not be specified when using rnn_cell_fn'): + rnn.RNNClassifier( + sequence_feature_columns=[embed], + rnn_cell_fn=lambda x: x, + num_units=cell_units) + + with self.assertRaisesRegexp( + ValueError, + 'num_units and cell_type must not be specified when using rnn_cell_fn'): + rnn.RNNClassifier( + sequence_feature_columns=[embed], + rnn_cell_fn=lambda x: x, + cell_type='lstm') + + def _testFromScratchWithDefaultOptimizer(self, n_classes): + def train_input_fn(): + return { + 'tokens': + sparse_tensor.SparseTensor( + values=['the', 'cat', 'sat'], + indices=[[0, 0], [0, 1], [0, 2]], + dense_shape=[1, 3]), + }, [[1]] + + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + input_units = 2 + + cell_units = [4, 2] + est = rnn.RNNClassifier( + sequence_feature_columns=[embed], + num_units=cell_units, + n_classes=n_classes, + model_dir=self._model_dir) + + # Train for a few steps, and validate final checkpoint. + num_steps = 10 + est.train(input_fn=train_input_fn, steps=num_steps) + self._assert_checkpoint(n_classes, input_units, cell_units, num_steps) + + def testBinaryClassFromScratchWithDefaultOptimizer(self): + self._testFromScratchWithDefaultOptimizer(n_classes=2) + + def testMultiClassFromScratchWithDefaultOptimizer(self): + self._testFromScratchWithDefaultOptimizer(n_classes=4) + + def testFromScratchWithCustomRNNCellFn(self): + def train_input_fn(): + return { + 'tokens': + sparse_tensor.SparseTensor( + values=['the', 'cat', 'sat'], + indices=[[0, 0], [0, 1], [0, 2]], + dense_shape=[1, 3]), + }, [[1]] + + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + input_units = 2 + cell_units = [4, 2] + n_classes = 2 + + def rnn_cell_fn(mode): + del mode # unused + cells = [rnn_cell.BasicRNNCell(num_units=n) for n in cell_units] + return rnn_cell.MultiRNNCell(cells) + + est = rnn.RNNClassifier( + sequence_feature_columns=[embed], + rnn_cell_fn=rnn_cell_fn, + n_classes=n_classes, + model_dir=self._model_dir) + + # Train for a few steps, and validate final checkpoint. + num_steps = 10 + est.train(input_fn=train_input_fn, steps=num_steps) + self._assert_checkpoint(n_classes, input_units, cell_units, num_steps) + + def _testExampleWeight(self, n_classes): + def train_input_fn(): + return { + 'tokens': + sparse_tensor.SparseTensor( + values=['the', 'cat', 'sat', 'dog', 'barked'], + indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], + dense_shape=[2, 3]), + 'w': [[1], [2]], + }, [[1], [0]] + + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + input_units = 2 + + cell_units = [4, 2] + est = rnn.RNNClassifier( + num_units=cell_units, + sequence_feature_columns=[embed], + n_classes=n_classes, + weight_column='w', + model_dir=self._model_dir) + + # Train for a few steps, and validate final checkpoint. + num_steps = 10 + est.train(input_fn=train_input_fn, steps=num_steps) + self._assert_checkpoint(n_classes, input_units, cell_units, num_steps) + + def testBinaryClassWithExampleWeight(self): + self._testExampleWeight(n_classes=2) + + def testMultiClassWithExampleWeight(self): + self._testExampleWeight(n_classes=4) + + def testBinaryClassFromCheckpoint(self): + initial_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=initial_global_step, + model_dir=self._model_dir) + + def train_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2.], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + }, [[0], [1]] + + # Uses same checkpoint and examples as testBinaryClassEvaluationMetrics. + # See that test for loss calculation. + mock_optimizer = self._mock_optimizer(expected_loss=1.119661) + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=2, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + est.train(input_fn=train_input_fn, steps=10) + self.assertEqual(1, mock_optimizer.minimize.call_count) + + def testMultiClassFromCheckpoint(self): + initial_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]], + logits_biases=[0.3, 0.4, 0.5], + global_step=initial_global_step, + model_dir=self._model_dir) + + def train_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2., 7.], + indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + dense_shape=[2, 2]), + }, [[0], [1]] + + # Uses same checkpoint and examples as testMultiClassEvaluationMetrics. + # See that test for loss calculation. + mock_optimizer = self._mock_optimizer(expected_loss=2.662932) + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=3, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + est.train(input_fn=train_input_fn, steps=10) + self.assertEqual(1, mock_optimizer.minimize.call_count) + + +def sorted_key_dict(unsorted_dict): + return {k: unsorted_dict[k] for k in sorted(unsorted_dict)} + + +class RNNClassifierEvaluationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def testBinaryClassEvaluationMetrics(self): + global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=global_step, + model_dir=self._model_dir) + + def eval_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2.], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + }, [[0], [1]] + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=2, + model_dir=self._model_dir) + eval_metrics = est.evaluate(eval_input_fn, steps=1) + + # Uses identical numbers to testMultiExamplesWithDifferentLength. + # See that test for logits calculation. + # logits = [[-0.603282], [0.019719]] + # probability = exp(logits) / (1 + exp(logits)) = [[0.353593], [0.504930]] + # loss = -label * ln(p) - (1 - label) * ln(1 - p) + # = [[0.436326], [0.683335]] + expected_metrics = { + ops.GraphKeys.GLOBAL_STEP: global_step, + metric_keys.MetricKeys.LOSS: 1.119661, + metric_keys.MetricKeys.LOSS_MEAN: 0.559831, + metric_keys.MetricKeys.ACCURACY: 1.0, + metric_keys.MetricKeys.PREDICTION_MEAN: 0.429262, + metric_keys.MetricKeys.LABEL_MEAN: 0.5, + metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5, + # With default threshold of 0.5, the model is a perfect classifier. + metric_keys.MetricKeys.RECALL: 1.0, + metric_keys.MetricKeys.PRECISION: 1.0, + # Positive example is scored above negative, so AUC = 1.0. + metric_keys.MetricKeys.AUC: 1.0, + metric_keys.MetricKeys.AUC_PR: 1.0, + } + self.assertAllClose( + sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics)) + + def testMultiClassEvaluationMetrics(self): + global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]], + logits_biases=[0.3, 0.4, 0.5], + global_step=global_step, + model_dir=self._model_dir) + + def eval_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2., 7.], + indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + dense_shape=[2, 2]), + }, [[0], [1]] + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=3, + model_dir=self._model_dir) + eval_metrics = est.evaluate(eval_input_fn, steps=1) + + # Uses identical numbers to testMultiExampleMultiDim. + # See that test for logits calculation. + # logits = [[-0.603282, 0.777708, 0.569756], + # [-1.247356, 1.017018, 0.574481]] + # logits_exp = exp(logits) / (1 + exp(logits)) + # = [[0.547013, 2.176468, 1.767836], + # [0.287263, 2.764937, 1.776208]] + # softmax_probabilities = logits_exp / logits_exp.sum() + # = [[0.121793, 0.484596, 0.393611], + # [0.059494, 0.572639, 0.367866]] + # loss = -1. * log(softmax[label]) + # = [[2.105432], [0.557500]] + expected_metrics = { + ops.GraphKeys.GLOBAL_STEP: global_step, + metric_keys.MetricKeys.LOSS: 2.662932, + metric_keys.MetricKeys.LOSS_MEAN: 1.331466, + metric_keys.MetricKeys.ACCURACY: 0.5, + } + + self.assertAllClose( + sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics)) + + +class RNNClassifierPredictionTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def testBinaryClassPredictions(self): + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=0, + model_dir=self._model_dir) + + def predict_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5.], + indices=[[0, 0], [0, 1]], + dense_shape=[1, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + label_vocabulary = ['class_0', 'class_1'] + + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=2, + label_vocabulary=label_vocabulary, + model_dir=self._model_dir) + # Uses identical numbers to testOneDimLogits. + # See that test for logits calculation. + # logits = [-0.603282] + # logistic = exp(-0.6033) / (1 + exp(-0.6033)) = [0.353593] + # probabilities = [0.646407, 0.353593] + # class_ids = argmax(probabilities) = [0] + predictions = next(est.predict(predict_input_fn)) + self.assertAllClose([-0.603282], + predictions[prediction_keys.PredictionKeys.LOGITS]) + self.assertAllClose([0.353593], + predictions[prediction_keys.PredictionKeys.LOGISTIC]) + self.assertAllClose( + [0.646407, 0.353593], + predictions[prediction_keys.PredictionKeys.PROBABILITIES]) + self.assertAllClose([0], + predictions[prediction_keys.PredictionKeys.CLASS_IDS]) + self.assertEqual([b'class_0'], + predictions[prediction_keys.PredictionKeys.CLASSES]) + + def testMultiClassPredictions(self): + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]], + logits_biases=[0.3, 0.4, 0.5], + global_step=0, + model_dir=self._model_dir) + + def predict_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5.], + indices=[[0, 0], [0, 1]], + dense_shape=[1, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + label_vocabulary = ['class_0', 'class_1', 'class_2'] + + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=3, + label_vocabulary=label_vocabulary, + model_dir=self._model_dir) + # Uses identical numbers to testMultiDimLogits. + # See that test for logits calculation. + # logits = [-0.603282, 0.777708, 0.569756] + # logits_exp = exp(logits) = [0.547013, 2.176468, 1.767836] + # softmax_probabilities = logits_exp / logits_exp.sum() + # = [0.121793, 0.484596, 0.393611] + # class_ids = argmax(probabilities) = [1] + predictions = next(est.predict(predict_input_fn)) + self.assertAllClose([-0.603282, 0.777708, 0.569756], + predictions[prediction_keys.PredictionKeys.LOGITS]) + self.assertAllClose( + [0.121793, 0.484596, 0.393611], + predictions[prediction_keys.PredictionKeys.PROBABILITIES]) + self.assertAllClose([1], + predictions[prediction_keys.PredictionKeys.CLASS_IDS]) + self.assertEqual([b'class_1'], + predictions[prediction_keys.PredictionKeys.CLASSES]) + + +class BaseRNNClassificationIntegrationTest(object): + + def __init__(self, _create_estimator_fn): + self._create_estimator_fn = _create_estimator_fn + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow(self, feature_columns, train_input_fn, eval_input_fn, + predict_input_fn, n_classes, batch_size): + cell_units = [4, 2] + est = self._create_estimator_fn(feature_columns, n_classes, cell_units, + self._model_dir) + + # TRAIN + num_steps = 10 + est.train(train_input_fn, steps=num_steps) + + # EVALUATE + scores = est.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + # PREDICT + predicted_proba = np.array([ + x[prediction_keys.PredictionKeys.PROBABILITIES] + for x in est.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, n_classes), predicted_proba.shape) + + # EXPORT + feature_spec = parsing_utils.classifier_parse_example_spec( + feature_columns, + label_key='label', + label_dtype=dtypes.int64) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = est.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def testNumpyInputFn(self): + """Tests complete flow with numpy_input_fn.""" + n_classes = 3 + batch_size = 10 + words = ['dog', 'cat', 'bird', 'the', 'a', 'sat', 'flew', 'slept'] + # Numpy only supports dense input, so all examples will have same length. + # TODO(b/73160931): Update test when support for prepadded data exists. + sequence_length = 3 + + features = [] + for _ in range(batch_size): + sentence = random.sample(words, sequence_length) + features.append(sentence) + + x_data = np.array(features) + y_data = np.random.randint(n_classes, size=batch_size) + + train_input_fn = numpy_io.numpy_input_fn( + x={'tokens': x_data}, + y=y_data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'tokens': x_data}, + y=y_data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'tokens': x_data}, + batch_size=batch_size, + shuffle=False) + + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + feature_columns = [embed] + + self._test_complete_flow( + feature_columns=feature_columns, + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + n_classes=n_classes, + batch_size=batch_size) + + def testParseExampleInputFn(self): + """Tests complete flow with input_fn constructed from parse_example.""" + n_classes = 3 + batch_size = 10 + words = [b'dog', b'cat', b'bird', b'the', b'a', b'sat', b'flew', b'slept'] + + _, examples_file = tempfile.mkstemp() + writer = python_io.TFRecordWriter(examples_file) + for _ in range(batch_size): + sequence_length = random.randint(1, len(words)) + sentence = random.sample(words, sequence_length) + label = random.randint(0, n_classes - 1) + example = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'tokens': + feature_pb2.Feature(bytes_list=feature_pb2.BytesList( + value=sentence)), + 'label': + feature_pb2.Feature(int64_list=feature_pb2.Int64List( + value=[label])), + })) + writer.write(example.SerializeToString()) + writer.close() + + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + feature_columns = [embed] + feature_spec = parsing_utils.classifier_parse_example_spec( + feature_columns, + label_key='label', + label_dtype=dtypes.int64) + + def _train_input_fn(): + dataset = readers.make_batched_features_dataset( + examples_file, batch_size, feature_spec) + return dataset.map(lambda features: (features, features.pop('label'))) + def _eval_input_fn(): + dataset = readers.make_batched_features_dataset( + examples_file, batch_size, feature_spec, num_epochs=1) + return dataset.map(lambda features: (features, features.pop('label'))) + def _predict_input_fn(): + dataset = readers.make_batched_features_dataset( + examples_file, batch_size, feature_spec, num_epochs=1) + def features_fn(features): + features.pop('label') + return features + return dataset.map(features_fn) + + self._test_complete_flow( + feature_columns=feature_columns, + train_input_fn=_train_input_fn, + eval_input_fn=_eval_input_fn, + predict_input_fn=_predict_input_fn, + n_classes=n_classes, + batch_size=batch_size) + + +def _rnn_classifier_fn(feature_columns, n_classes, cell_units, model_dir): + return rnn.RNNClassifier( + num_units=cell_units, + sequence_feature_columns=feature_columns, + n_classes=n_classes, + model_dir=model_dir) + + +class RNNClassifierIntegrationTest(BaseRNNClassificationIntegrationTest, + test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + BaseRNNClassificationIntegrationTest.__init__(self, _rnn_classifier_fn) + + +def _rnn_estimator_fn(feature_columns, n_classes, cell_units, model_dir): + return rnn.RNNEstimator( + head=head_lib.multi_class_head(n_classes=n_classes), + num_units=cell_units, + sequence_feature_columns=feature_columns, + model_dir=model_dir) + + +class RNNEstimatorIntegrationTest(BaseRNNClassificationIntegrationTest, + test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + BaseRNNClassificationIntegrationTest.__init__(self, _rnn_estimator_fn) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index 180f1b68f3b56113dfbbfc100bd04efc3bb8b31f..effec42f028fe472593a8d06e15a0831346d6f50 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -66,6 +66,7 @@ tf_custom_op_py_library( "//tensorflow/python:variables", "//tensorflow/python/estimator", "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], ) @@ -214,6 +215,7 @@ tf_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:sparse_tensor", ], + shard_count = 4, ) # Estimators tests @@ -223,7 +225,10 @@ py_test( srcs = ["python/ops/kmeans_test.py"], shard_count = 4, srcs_version = "PY2AND3", - tags = ["notsan"], # b/67512932 + tags = [ + "nomac", # b/73741358 + "notsan", # b/67512932 + ], deps = [ ":factorization_py", ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", @@ -238,6 +243,7 @@ py_test( "//tensorflow/python:random_ops", "//tensorflow/python:training", "//tensorflow/python/estimator:run_config", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], ) @@ -342,16 +348,3 @@ cuda_py_test( ], main = "python/kernel_tests/masked_matmul_benchmark.py", ) - -# All files -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/factorization/examples/BUILD b/tensorflow/contrib/factorization/examples/BUILD index bbe842bd5ccc7357805adda1df42ba8799fcd8f2..363baa121ab3854a802ca3606e35597d31b35a57 100644 --- a/tensorflow/contrib/factorization/examples/BUILD +++ b/tensorflow/contrib/factorization/examples/BUILD @@ -21,14 +21,3 @@ tf_py_test( ], tags = ["notsan"], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/factorization/kernels/BUILD b/tensorflow/contrib/factorization/kernels/BUILD index 44eab56011dad2f6fbe843b3569b4acc5c5e542a..ea8b9a17a27093cb57564861815edd6ecb18a014 100644 --- a/tensorflow/contrib/factorization/kernels/BUILD +++ b/tensorflow/contrib/factorization/kernels/BUILD @@ -67,14 +67,3 @@ tf_cc_test( "//tensorflow/core:testlib", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops.cc b/tensorflow/contrib/factorization/kernels/clustering_ops.cc index dd61f59585aee2e0245cfd6797b313b972c19bc5..025534d540bb82cdb87bb2977d08dfa4f02f1bc8 100644 --- a/tensorflow/contrib/factorization/kernels/clustering_ops.cc +++ b/tensorflow/contrib/factorization/kernels/clustering_ops.cc @@ -32,6 +32,7 @@ #include "tensorflow/core/lib/gtl/top_n.h" #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/logging.h" @@ -353,7 +354,7 @@ class NearestNeighborsOp : public OpKernel { auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); const int64 num_threads = worker_threads.num_threads; // This kernel might be configured to use fewer than the total number of - // available CPUs on the host machine. To avoid descructive interference + // available CPUs on the host machine. To avoid destructive interference // with other jobs running on the host machine, we must only use a fraction // of total available L3 cache. Unfortunately, we cannot query the host // machine to get the number of physical CPUs. So, we use a fixed per-CPU diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index 23137e0a973c0bdd2cdbd97159f7fd310178bf54..84e80791f4991ad2b67d0a00ee1e00cf0d0daadc 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -41,11 +41,12 @@ from tensorflow.python.platform import resource_loader _clustering_ops = loader.load_op_library( resource_loader.get_path_to_datafile('_clustering_ops.so')) -# Euclidean distance between vectors U and V is defined as ||U - V||_F which is -# the square root of the sum of the absolute squares of the elements difference. +# Euclidean distance between vectors U and V is defined as \\(||U - V||_F\\) +# which is the square root of the sum of the absolute squares of the elements +# difference. SQUARED_EUCLIDEAN_DISTANCE = 'squared_euclidean' # Cosine distance between vectors U and V is defined as -# 1 - (U \dot V) / (||U||_F ||V||_F) +# \\(1 - (U \dot V) / (||U||_F ||V||_F)\\) COSINE_DISTANCE = 'cosine' RANDOM_INIT = 'random' @@ -472,8 +473,8 @@ class KMeans(object): # Locally compute the sum of inputs mapped to each id. # For a cluster with old cluster value x, old count n, and with data # d_1,...d_k newly assigned to it, we recompute the new value as - # x += (sum_i(d_i) - k * x) / (n + k). - # Compute sum_i(d_i), see comment above. + # \\(x += (sum_i(d_i) - k * x) / (n + k)\\). + # Compute \\(sum_i(d_i)\\), see comment above. cluster_center_updates = math_ops.unsorted_segment_sum( inp, unique_idx, num_unique_cluster_idx) # Shape to enable broadcasting count_updates and learning_rate to inp. diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py index 054888e734086c153f7af59f4548d4d20abab813..8f73274c2a0ebbdc41ce6a647a8a5650694c9a23 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py @@ -51,9 +51,9 @@ class WALSModel(object): r"""A model for Weighted Alternating Least Squares matrix factorization. It minimizes the following loss function over U, V: - \\( - \|\sqrt W \odot (A - U V^T) \|_F^2 + \lambda (\|U\|_F^2 + \|V\|_F^2) - )\\ + $$ + \|\sqrt W \odot (A - U V^T)\|_F^2 + \lambda (\|U\|_F^2 + \|V\|_F^2) + $$ where, A: input matrix, W: weight matrix. Note that the (element-wise) square root of the weights @@ -61,12 +61,12 @@ class WALSModel(object): U, V: row_factors and column_factors matrices, \\(\lambda)\\: regularization. Also we assume that W is of the following special form: - \\( W_{ij} = W_0 + R_i * C_j )\\ if \\(A_{ij} \ne 0)\\, - \\(W_{ij} = W_0)\\ otherwise. + \\( W_{ij} = W_0 + R_i * C_j \\) if \\(A_{ij} \ne 0\\), + \\(W_{ij} = W_0\\) otherwise. where, - \\(W_0)\\: unobserved_weight, - \\(R_i)\\: row_weights, - \\(C_j)\\: col_weights. + \\(W_0\\): unobserved_weight, + \\(R_i\\): row_weights, + \\(C_j\\): col_weights. Note that the current implementation supports two operation modes: The default mode is for the condition where row_factors and col_factors can individually @@ -82,14 +82,15 @@ class WALSModel(object): normalized as follows: _, _, unregularized_loss, regularization, sum_weights = update_row_factors(sp_input) - if sp_input contains the rows {A_i, i \in I}, and the input matrix A has n - total rows, then the minibatch loss = unregularized_loss + regularization is - \\( + if sp_input contains the rows \\({A_i, i \in I}\\), and the input matrix A + has n total rows, then the minibatch loss = unregularized_loss + + regularization is + $$ (\|\sqrt W_I \odot (A_I - U_I V^T)\|_F^2 + \lambda \|U_I\|_F^2) * n / |I| + \lambda \|V\|_F^2 - )\\ + $$ The sum_weights tensor contains the normalized sum of weights - sum(W_I) * n / |I|. + \\(sum(W_I) * n / |I|\\). A typical usage example (pseudocode): @@ -106,7 +107,7 @@ class WALSModel(object): # the prep_gramian_op for row(column) can be run. worker_init_op = model.worker_init - # To be run once per interation sweep before the row(column) update + # To be run once per iteration sweep before the row(column) update # initialize ops can be run. Note that in the distributed training # situations, this should only be run by the chief trainer. All other # trainers need to block until this is done. @@ -118,9 +119,9 @@ class WALSModel(object): init_row_update_op = model.initialize_row_update_op init_col_update_op = model.initialize_col_update_op - # Ops to upate row(column). This can either take the entire sparse tensor - # or slices of sparse tensor. For distributed trainer, each trainer - # handles just part of the matrix. + # Ops to update row(column). This can either take the entire sparse + # tensor or slices of sparse tensor. For distributed trainer, each + # trainer handles just part of the matrix. _, row_update_op, unreg_row_loss, row_reg, _ = model.update_row_factors( sp_input=matrix_slices_from_queue_for_worker_shard) row_loss = unreg_row_loss + row_reg @@ -196,7 +197,8 @@ class WALSModel(object): row_weights=1, col_weights=1, use_factors_weights_cache=True, - use_gramian_cache=True): + use_gramian_cache=True, + use_scoped_vars=False): """Creates model for WALS matrix factorization. Args: @@ -220,10 +222,10 @@ class WALSModel(object): in the form of [[w_0, w_1, ...], [w_k, ... ], [...]], with the number of inner lists matching the number of row factor shards and the elements in each inner list are the weights for the rows of the corresponding row - factor shard. In this case, w_ij = unonbserved_weight + + factor shard. In this case, w_ij = unobserved_weight + row_weights[i] * col_weights[j]. - If this is a single non-negative real number, this value is used for - all row weights and w_ij = unobserved_weight + row_weights * + all row weights and \\(w_ij\\) = unobserved_weight + row_weights * col_weights[j]. Note that it is allowed to have row_weights as a list while col_weights a single number or vice versa. @@ -238,6 +240,8 @@ class WALSModel(object): weights cache to take effect. use_gramian_cache: When True, the Gramians will be cached on the workers before the updates start. Defaults to True. + use_scoped_vars: When True, the factor and weight vars will also be nested + in a tf.name_scope. """ self._input_rows = input_rows self._input_cols = input_cols @@ -250,25 +254,46 @@ class WALSModel(object): regularization * linalg_ops.eye(self._n_components) if regularization is not None else None) assert (row_weights is None) == (col_weights is None) - self._row_weights = WALSModel._create_weights( - row_weights, self._input_rows, self._num_row_shards, "row_weights") - self._col_weights = WALSModel._create_weights( - col_weights, self._input_cols, self._num_col_shards, "col_weights") self._use_factors_weights_cache = use_factors_weights_cache self._use_gramian_cache = use_gramian_cache - self._row_factors = self._create_factors( - self._input_rows, self._n_components, self._num_row_shards, row_init, - "row_factors") - self._col_factors = self._create_factors( - self._input_cols, self._n_components, self._num_col_shards, col_init, - "col_factors") + + if use_scoped_vars: + with ops.name_scope("row_weights"): + self._row_weights = WALSModel._create_weights( + row_weights, self._input_rows, self._num_row_shards, "row_weights") + with ops.name_scope("col_weights"): + self._col_weights = WALSModel._create_weights( + col_weights, self._input_cols, self._num_col_shards, "col_weights") + with ops.name_scope("row_factors"): + self._row_factors = self._create_factors( + self._input_rows, self._n_components, self._num_row_shards, + row_init, "row_factors") + with ops.name_scope("col_factors"): + self._col_factors = self._create_factors( + self._input_cols, self._n_components, self._num_col_shards, + col_init, "col_factors") + else: + self._row_weights = WALSModel._create_weights( + row_weights, self._input_rows, self._num_row_shards, "row_weights") + self._col_weights = WALSModel._create_weights( + col_weights, self._input_cols, self._num_col_shards, "col_weights") + self._row_factors = self._create_factors( + self._input_rows, self._n_components, self._num_row_shards, row_init, + "row_factors") + self._col_factors = self._create_factors( + self._input_cols, self._n_components, self._num_col_shards, col_init, + "col_factors") + self._row_gramian = self._create_gramian(self._n_components, "row_gramian") self._col_gramian = self._create_gramian(self._n_components, "col_gramian") - self._row_update_prep_gramian = self._prepare_gramian( - self._col_factors, self._col_gramian) - self._col_update_prep_gramian = self._prepare_gramian( - self._row_factors, self._row_gramian) - self._create_transient_vars() + with ops.name_scope("row_prepare_gramian"): + self._row_update_prep_gramian = self._prepare_gramian( + self._col_factors, self._col_gramian) + with ops.name_scope("col_prepare_gramian"): + self._col_update_prep_gramian = self._prepare_gramian( + self._row_factors, self._row_gramian) + with ops.name_scope("transient_vars"): + self._create_transient_vars() @property def row_factors(self): @@ -435,7 +460,7 @@ class WALSModel(object): gramian: Variable storing the gramian calculated from the factors. Returns: - A op that updates the gramian with the calcuated value from the factors. + An op that updates the gramian with the calculated value from the factors. """ partial_gramians = [] for f in factors: @@ -564,7 +589,7 @@ class WALSModel(object): Note that specifically this initializes the cache of the row and column weights on workers when `use_factors_weights_cache` is True. In this case, - if these weights are being calcualted and reset after the object is created, + if these weights are being calculated and reset after the object is created, it is important to ensure this ops is run afterwards so the cache reflects the correct values. """ @@ -665,18 +690,18 @@ class WALSModel(object): factors. unregularized_loss: A tensor (scalar) that contains the normalized minibatch loss corresponding to sp_input, without the regularization - term. If sp_input contains the rows {A_{i, :}, i \in I}, and the input - matrix A has n total rows, then the unregularized loss is: - (\|\sqrt W_I \odot (A_I - U_I V^T)\|_F^2 * n / |I| + term. If sp_input contains the rows \\({A_{i, :}, i \in I}\\), and the + input matrix A has n total rows, then the unregularized loss is: + \\(\|\sqrt W_I \odot (A_I - U_I V^T)\|_F^2 * n / |I|\\) The total loss is unregularized_loss + regularization. regularization: A tensor (scalar) that contains the normalized regularization term for the minibatch loss corresponding to sp_input. - If sp_input contains the rows {A_{i, :}, i \in I}, and the input matrix - A has n total rows, then the regularization term is: - \lambda \|U_I\|_F^2) * n / |I| + \lambda \|V\|_F^2. + If sp_input contains the rows \\({A_{i, :}, i \in I}\\), and the input + matrix A has n total rows, then the regularization term is: + \\(\lambda \|U_I\|_F^2) * n / |I| + \lambda \|V\|_F^2\\). sum_weights: The sum of the weights W_I corresponding to sp_input, - normalized by a factor of n / |I|. The root weighted squared error is: - \sqrt(unregularized_loss / sum_weights). + normalized by a factor of \\(n / |I|\\). The root weighted squared + error is: \sqrt(unregularized_loss / sum_weights). """ return self._process_input_helper( True, sp_input=sp_input, transpose_input=transpose_input) @@ -698,18 +723,18 @@ class WALSModel(object): factors. unregularized_loss: A tensor (scalar) that contains the normalized minibatch loss corresponding to sp_input, without the regularization - term. If sp_input contains the columns {A_{:, j}, j \in J}, and the - input matrix A has m total columns, then the unregularized loss is: - (\|\sqrt W_J \odot (A_J - U V_J^T)\|_F^2 * m / |I| + term. If sp_input contains the columns \\({A_{:, j}, j \in J}\\), and + the input matrix A has m total columns, then the unregularized loss is: + \\(\|\sqrt W_J \odot (A_J - U V_J^T)\|_F^2 * m / |I|\\) The total loss is unregularized_loss + regularization. regularization: A tensor (scalar) that contains the normalized regularization term for the minibatch loss corresponding to sp_input. - If sp_input contains the columns {A_{:, j}, j \in J}, and the input - matrix A has m total columns, then the regularization term is: - \lambda \|V_J\|_F^2) * m / |J| + \lambda \|U\|_F^2. + If sp_input contains the columns \\({A_{:, j}, j \in J}\\), and the + input matrix A has m total columns, then the regularization term is: + \\(\lambda \|V_J\|_F^2) * m / |J| + \lambda \|U\|_F^2\\). sum_weights: The sum of the weights W_J corresponding to sp_input, - normalized by a factor of m / |J|. The root weighted squared error is: - \sqrt(unregularized_loss / sum_weights). + normalized by a factor of \\(m / |J|\\). The root weighted squared + error is: \sqrt(unregularized_loss / sum_weights). """ return self._process_input_helper( False, sp_input=sp_input, transpose_input=transpose_input) @@ -720,8 +745,8 @@ class WALSModel(object): projection_weights=None): """Projects the row factors. - This computes the row embedding u_i for an observed row a_i by solving - one iteration of the update equations. + This computes the row embedding \\(u_i\\) for an observed row \\(a_i\\) by + solving one iteration of the update equations. Args: sp_input: A SparseTensor representing a set of rows. Please note that the @@ -753,8 +778,8 @@ class WALSModel(object): projection_weights=None): """Projects the column factors. - This computes the column embedding v_j for an observed column a_j by solving - one iteration of the update equations. + This computes the column embedding \\(v_j\\) for an observed column + \\(a_j\\) by solving one iteration of the update equations. Args: sp_input: A SparseTensor representing a set of columns. Please note that @@ -938,7 +963,7 @@ class WALSModel(object): loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input) if transpose_input else new_sp_input) # sp_approx is the low rank estimate of the input matrix, formed by - # computing the product for (i, j) in loss_sp_input.indices. + # computing the product <\\(u_i, v_j\\)> for (i, j) in loss_sp_input.indices. sp_approx_vals = gen_factorization_ops.masked_matmul( new_left_values, right, diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py index c8137339155ef1da8ee53967eea84a550f12ecbc..bb5140aeb3bf0238ca7cb52067ea6328dd1736d5 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py @@ -210,7 +210,7 @@ class WalsModelTest(test.TestCase): # Test row projection. # Using the specified projection weights for the 2 row feature vectors. - # This is expected to reprodue the same row factors in the model as the + # This is expected to reproduce the same row factors in the model as the # weights and feature vectors are identical to that used in model # training. projected_rows = wals_model.project_row_factors( @@ -283,8 +283,8 @@ class WalsModelTest(test.TestCase): # Test column projection. # Using the specified projection weights for the 3 column feature vectors. - # This is expected to reprodue the same column factors in the model as the - # weights and feature vectors are identical to that used in model + # This is expected to reproduce the same column factors in the model as + # the weights and feature vectors are identical to that used in model # training. projected_cols = wals_model.project_col_factors( sp_input=sp_feeder, @@ -385,7 +385,7 @@ class WalsModelTest(test.TestCase): # Test row projection. # Using the specified projection weights for the 2 row feature vectors. - # This is expected to reprodue the same row factors in the model as the + # This is expected to reproduce the same row factors in the model as the # weights and feature vectors are identical to that used in model # training. projected_rows = wals_model.project_row_factors( @@ -462,8 +462,8 @@ class WalsModelTest(test.TestCase): # Test column projection. # Using the specified projection weights for the 2 column feature vectors. - # This is expected to reprodue the same column factors in the model as the - # weights and feature vectors are identical to that used in model + # This is expected to reproduce the same column factors in the model as + # the weights and feature vectors are identical to that used in model # training. projected_cols = wals_model.project_col_factors( sp_input=sp_feeder, diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index 98d6434f4752b224201e38bed05ccd14428a758b..e076631bc16fd379a2ad31af9055a7388d98c7ca 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -54,10 +54,10 @@ def _covariance(x, diag): diagonal matrix just the diagonal is returned. """ num_points = math_ops.to_float(array_ops.shape(x)[0]) - x -= math_ops.reduce_mean(x, 0, keep_dims=True) + x -= math_ops.reduce_mean(x, 0, keepdims=True) if diag: cov = math_ops.reduce_sum( - math_ops.square(x), 0, keep_dims=True) / (num_points - 1) + math_ops.square(x), 0, keepdims=True) / (num_points - 1) else: cov = math_ops.matmul(x, x, transpose_a=True) / (num_points - 1) return cov @@ -280,7 +280,7 @@ class GmmAlgorithm(object): self._define_score_samples() def _define_full_covariance_probs(self, shard_id, shard): - """Defines the full covariance probabilties per example in a class. + """Defines the full covariance probabilities per example in a class. Updates a matrix with dimension num_examples X num_classes. @@ -313,7 +313,7 @@ class GmmAlgorithm(object): # TODO(xavigonzalvo): look into alternatives to log for # reparametrization of variance parameters. det_expanded = math_ops.reduce_sum( - math_ops.log(self._covs + 1e-3), 1, keep_dims=True) + math_ops.log(self._covs + 1e-3), 1, keepdims=True) diff = shard - self._means x2 = math_ops.square(diff) cov_expanded = array_ops.expand_dims(1.0 / (self._covs + 1e-3), 2) @@ -344,21 +344,21 @@ class GmmAlgorithm(object): def _define_prior_log_prob_operation(self, shard_id): """Computes the prior probability of all samples. - Updates a vector where each item is the prior probabibility of an + Updates a vector where each item is the prior probability of an input example. Args: shard_id: id of current shard_id. """ self._prior_probs[shard_id] = math_ops.reduce_logsumexp( - self._probs[shard_id], axis=1, keep_dims=True) + self._probs[shard_id], axis=1, keepdims=True) def _define_expectation_operation(self, shard_id): # Shape broadcasting. probs = array_ops.expand_dims(self._probs[shard_id], 0) # Membership weights are computed as: - # w_{ik} = \frac{\alpha_k f(\mathbf{y_i}|\mathbf{\theta}_k)} - # {\sum_{m=1}^{K}\alpha_mf(\mathbf{y_i}|\mathbf{\theta}_m)} + # $$w_{ik} = \frac{\alpha_k f(\mathbf{y_i}|\mathbf{\theta}_k)}$$ + # $$ {\sum_{m=1}^{K}\alpha_mf(\mathbf{y_i}|\mathbf{\theta}_m)}$$ # where "i" is the i-th example, "k" is the k-th mixture, theta are # the model parameters and y_i the observations. # These are defined for each shard. @@ -375,7 +375,7 @@ class GmmAlgorithm(object): """ # Soft assignment of each data point to each of the two clusters. self._points_in_k[shard_id] = math_ops.reduce_sum( - self._w[shard_id], 0, keep_dims=True) + self._w[shard_id], 0, keepdims=True) # Partial means. w_mul_x = array_ops.expand_dims( math_ops.matmul( @@ -397,7 +397,7 @@ class GmmAlgorithm(object): # Compute the effective number of data points assigned to component k. with ops.control_dependencies(self._w): points_in_k = array_ops.squeeze( - math_ops.add_n(self._points_in_k), squeeze_dims=[0]) + math_ops.add_n(self._points_in_k), axis=[0]) # Update alpha. if 'w' in self._params: final_points_in_k = points_in_k / num_batches @@ -454,7 +454,7 @@ class GmmAlgorithm(object): for shard_id, prior_probs in enumerate(self._prior_probs): op.append(prior_probs + math_ops.log(self._w[shard_id])) self._scores = array_ops.squeeze( - math_ops.reduce_logsumexp(op, axis=2, keep_dims=True), axis=0) + math_ops.reduce_logsumexp(op, axis=2, keepdims=True), axis=0) def gmm(inp, diff --git a/tensorflow/contrib/factorization/python/ops/gmm_test.py b/tensorflow/contrib/factorization/python/ops/gmm_test.py index 00a4734eb6d89cd02484f1c5161366377cc71208..4fc9c96e9d0a317ef757d5e1bb6563ed7c8832af 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_test.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_test.py @@ -210,7 +210,7 @@ class GMMTestQueues(test.TestCase): return _fn # This test makes sure that there are no deadlocks when using a QueueRunner. - # Note that since cluster initialization is dependendent on inputs, if input + # Note that since cluster initialization is dependent on inputs, if input # is generated using a QueueRunner, one has to make sure that these runners # are started before the initialization. def test_queues(self): diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index c861cfff544a78617aa1ace730b50c094cf16330..9ffdd3ba5e8ac496533d0207f2b6846dbc92bc89 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -26,6 +26,7 @@ from tensorflow.contrib.factorization.python.ops import clustering_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.export import export_output +from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -61,8 +62,8 @@ class _LossRelativeChangeHook(session_run_hook.SessionRunHook): loss = run_values.results assert loss is not None if self._prev_loss: - relative_change = (abs(loss - self._prev_loss) / - (1 + abs(self._prev_loss))) + relative_change = ( + abs(loss - self._prev_loss) / (1 + abs(self._prev_loss))) if relative_change < self._tolerance: run_context.request_stop() self._prev_loss = loss @@ -105,24 +106,32 @@ class _InitializeClustersHook(session_run_hook.SessionRunHook): logging.info(e) -def _parse_tensor_or_dict(features): +def _parse_features_if_necessary(features, feature_columns): """Helper function to convert the input points into a usable format. Args: - features: The input points. + features: The input features. + feature_columns: An optionable iterable containing all the feature columns + used by the model. All items in the set should be feature column instances + that can be passed to `tf.feature_column.input_layer`. If this is None, + all features will be used. Returns: - If `features` is a dict of `k` features, each of which is a vector of `n` - scalars, the return value is a Tensor of shape `(n, k)` representing `n` - input points, where the items in the `k` dimension are sorted - lexicographically by `features` key. If `features` is not a dict, it is - returned unmodified. + If `features` is a dict of `k` features (optionally filtered by + `feature_columns`), each of which is a vector of `n` scalars, the return + value is a Tensor of shape `(n, k)` representing `n` input points, where the + items in the `k` dimension are sorted lexicographically by `features` key. + If `features` is not a dict, it is returned unmodified. """ - if isinstance(features, dict): - keys = sorted(features.keys()) - with ops.colocate_with(features[keys[0]]): - features = array_ops.concat([features[k] for k in keys], axis=1) - return features + if not isinstance(features, dict): + return features + + if feature_columns: + return fc.input_layer(features, feature_columns) + + keys = sorted(features.keys()) + with ops.colocate_with(features[keys[0]]): + return array_ops.concat([features[k] for k in keys], axis=1) class _ModelFn(object): @@ -130,7 +139,8 @@ class _ModelFn(object): def __init__(self, num_clusters, initial_clusters, distance_metric, random_seed, use_mini_batch, mini_batch_steps_per_iteration, - kmeans_plus_plus_num_retries, relative_tolerance): + kmeans_plus_plus_num_retries, relative_tolerance, + feature_columns): self._num_clusters = num_clusters self._initial_clusters = initial_clusters self._distance_metric = distance_metric @@ -139,6 +149,7 @@ class _ModelFn(object): self._mini_batch_steps_per_iteration = mini_batch_steps_per_iteration self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries self._relative_tolerance = relative_tolerance + self._feature_columns = feature_columns def model_fn(self, features, mode, config): """Model function for the estimator. @@ -166,7 +177,7 @@ class _ModelFn(object): # input_points is a single Tensor. Therefore, the sharding functionality # in clustering_ops is unused, and some of the values below are lists of a # single item. - input_points = _parse_tensor_or_dict(features) + input_points = _parse_features_if_necessary(features, self._feature_columns) # Let N = the number of input_points. # all_distances: A list of one matrix of shape (N, num_clusters). Each value @@ -233,7 +244,57 @@ class _ModelFn(object): # TODO(agarwal,ands): support sharded input. class KMeansClustering(estimator.Estimator): - """An Estimator for K-Means clustering.""" + """An Estimator for K-Means clustering. + + Example: + ``` + import numpy as np + import tensorflow as tf + + num_points = 100 + dimensions = 2 + points = np.random.uniform(0, 1000, [num_points, dimensions]) + + def input_fn(): + return tf.train.limit_epochs( + tf.convert_to_tensor(points, dtype=tf.float32), num_epochs=1) + + num_clusters = 5 + kmeans = tf.contrib.factorization.KMeansClustering( + num_clusters=num_clusters, use_mini_batch=False) + + # train + num_iterations = 10 + previous_centers = None + for _ in xrange(num_iterations): + kmeans.train(input_fn) + cluster_centers = kmeans.cluster_centers() + if previous_centers is not None: + print 'delta:', cluster_centers - previous_centers + previous_centers = cluster_centers + print 'score:', kmeans.score(input_fn) + print 'cluster centers:', cluster_centers + + # map the input points to their clusters + cluster_indices = list(kmeans.predict_cluster_index(input_fn)) + for i, point in enumerate(points): + cluster_index = cluster_indices[i] + center = cluster_centers[cluster_index] + print 'point:', point, 'is in cluster', cluster_index, 'centered at', center + ``` + + The `SavedModel` saved by the `export_savedmodel` method does not include the + cluster centers. However, the cluster centers may be retrieved by the + latest checkpoint saved during training. Specifically, + ``` + kmeans.cluster_centers() + ``` + is equivalent to + ``` + tf.train.load_variable( + kmeans.model_dir, KMeansClustering.CLUSTER_CENTERS_VAR_NAME) + ``` + """ # Valid values for the distance_metric constructor argument. SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE @@ -253,6 +314,9 @@ class KMeansClustering(estimator.Estimator): CLUSTER_INDEX = 'cluster_index' ALL_DISTANCES = 'all_distances' + # Variable name used by cluster_centers(). + CLUSTER_CENTERS_VAR_NAME = clustering_ops.CLUSTERS_VAR_NAME + def __init__(self, num_clusters, model_dir=None, @@ -263,7 +327,8 @@ class KMeansClustering(estimator.Estimator): mini_batch_steps_per_iteration=1, kmeans_plus_plus_num_retries=2, relative_tolerance=None, - config=None): + config=None, + feature_columns=None): """Creates an Estimator for running KMeans training and inference. This Estimator implements the following variants of the K-means algorithm: @@ -309,11 +374,11 @@ class KMeansClustering(estimator.Estimator): than `num_clusters`, a TensorFlow runtime error occurs. distance_metric: The distance metric used for clustering. One of: * `KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE`: Euclidean distance - between vectors `u` and `v` is defined as `||u - v||_2` which is - the square root of the sum of the absolute squares of the elements' - difference. + between vectors `u` and `v` is defined as \\(||u - v||_2\\) + which is the square root of the sum of the absolute squares of + the elements' difference. * `KMeansClustering.COSINE_DISTANCE`: Cosine distance between vectors - `u` and `v` is defined as `1 - (u . v) / (||u||_2 ||v||_2)`. + `u` and `v` is defined as \\(1 - (u . v) / (||u||_2 ||v||_2)\\). random_seed: Python integer. Seed for PRNG used to initialize centers. use_mini_batch: A boolean specifying whether to use the mini-batch k-means algorithm. See explanation above. @@ -330,6 +395,10 @@ class KMeansClustering(estimator.Estimator): iterations. Stops learning if the loss changes less than this amount. This may not work correctly if `use_mini_batch=True`. config: See @{tf.estimator.Estimator}. + feature_columns: An optionable iterable containing all the feature columns + used by the model. All items in the set should be feature column + instances that can be passed to `tf.feature_column.input_layer`. If this + is None, all features will be used. Raises: ValueError: An invalid argument was passed to `initial_clusters` or @@ -349,7 +418,8 @@ class KMeansClustering(estimator.Estimator): model_fn=_ModelFn( num_clusters, initial_clusters, distance_metric, random_seed, use_mini_batch, mini_batch_steps_per_iteration, - kmeans_plus_plus_num_retries, relative_tolerance).model_fn, + kmeans_plus_plus_num_retries, relative_tolerance, + feature_columns).model_fn, model_dir=model_dir, config=config) @@ -406,4 +476,4 @@ class KMeansClustering(estimator.Estimator): def cluster_centers(self): """Returns the cluster centers.""" - return self.get_variable_value(clustering_ops.CLUSTERS_VAR_NAME) + return self.get_variable_value(KMeansClustering.CLUSTER_CENTERS_VAR_NAME) diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py index f9598bfc08c05ea3bba88b3135da0cf2e6bb0c95..88eb9cf692992fe2e1fc4f060ac98dd721c22307 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -27,6 +27,7 @@ from sklearn.cluster import KMeans as SklearnKMeans # pylint: disable=g-import-not-at-top from tensorflow.contrib.factorization.python.ops import kmeans as kmeans_lib from tensorflow.python.estimator import run_config +from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -226,6 +227,44 @@ class KMeansTest(KMeansTestBase): self._infer_helper(kmeans, clusters, 10) self._infer_helper(kmeans, clusters, 1) + def _parse_feature_dict_helper(self, features, parsed_feature_dict): + # Perform a sanity check. + self.assertEqual(features.shape, parsed_feature_dict.shape) + self.assertEqual(features.dtype, parsed_feature_dict.dtype) + # Then check that running the tensor yields the original list of points. + with self.test_session() as sess: + parsed_points = sess.run(parsed_feature_dict) + self.assertAllEqual(self.points, parsed_points) + + def test_parse_features(self): + """Tests the various behaviours of kmeans._parse_features_if_necessary.""" + + # No-op if a tensor is passed in. + features = constant_op.constant(self.points) + parsed_features = kmeans_lib._parse_features_if_necessary(features, None) + self.assertAllEqual(features, parsed_features) + + # All values from a feature dict are transformed into a tensor. + feature_dict = { + 'x': [[point[0]] for point in self.points], + 'y': [[point[1]] for point in self.points] + } + parsed_feature_dict = kmeans_lib._parse_features_if_necessary( + feature_dict, None) + self._parse_feature_dict_helper(features, parsed_feature_dict) + + # Only the feature_columns of a feature dict are transformed into a tensor. + feature_dict_with_extras = { + 'foo': 'bar', + 'x': [[point[0]] for point in self.points], + 'baz': {'fizz': 'buzz'}, + 'y': [[point[1]] for point in self.points] + } + feature_columns = [fc.numeric_column(key='x'), fc.numeric_column(key='y')] + parsed_feature_dict = kmeans_lib._parse_features_if_necessary( + feature_dict_with_extras, feature_columns) + self._parse_feature_dict_helper(features, parsed_feature_dict) + class KMeansTestMultiStageInit(KMeansTestBase): @@ -374,7 +413,7 @@ class KMeansCosineDistanceTest(KMeansTestBase): self.assertAllClose(score, self.true_score, atol=1e-2) def test_predict_kmeans_plus_plus(self): - # Most points are concetrated near one center. KMeans++ is likely to find + # Most points are concentrated near one center. KMeans++ is likely to find # the less populated centers. points = np.array( [[2.5, 3.5], [2.5, 3.5], [-2, 3], [-2, 3], [-3, -3], [-3.1, -3.2], @@ -394,7 +433,6 @@ class KMeansCosineDistanceTest(KMeansTestBase): true_assignments = [0] * 2 + [1] * 2 + [2] * 8 true_score = len(points) - np.tensordot( normalize(points), true_centers[true_assignments]) - kmeans = kmeans_lib.KMeansClustering( 3, initial_clusters=self.initial_clusters, @@ -566,7 +604,7 @@ class KMeansTestQueues(test.TestCase): return _fn # This test makes sure that there are no deadlocks when using a QueueRunner. - # Note that since cluster initialization is dependendent on inputs, if input + # Note that since cluster initialization is dependent on inputs, if input # is generated using a QueueRunner, one has to make sure that these runners # are started before the initialization. def test_queues(self): diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index 4fe22ea26ec5f5a43f1c99d1fee518b1d326c5c9..ca46c39baa16a7fddb96121e0402fc35d24ce1c2 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -216,7 +216,7 @@ def _wals_factorization_model_function(features, labels, mode, params): name=WALSMatrixFactorization.LOSS, collections=[ops.GraphKeys.GLOBAL_VARIABLES]) # The root weighted squared error = - # \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij ) + # \\(\sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij )\\) rwse_var = variable_scope.variable( 0., trainable=False, @@ -235,7 +235,7 @@ def _wals_factorization_model_function(features, labels, mode, params): num_items: An integer, the total number of items of this axis. update_fn: A function that takes one argument (`sp_input`), and that returns a tuple of - * new_factors: A flot Tensor of the factor values after update. + * new_factors: A float Tensor of the factor values after update. * update_op: a TensorFlow op which updates the factors. * loss: A float Tensor, the unregularized loss. * reg_loss: A float Tensor, the regularization loss. @@ -490,11 +490,11 @@ class WALSMatrixFactorization(estimator.Estimator): and the problem simplifies to ALS. Note that, in this case, col_weights must also be set to "None". - List of lists of non-negative scalars, of the form - [[w_0, w_1, ...], [w_k, ... ], [...]], + \\([[w_0, w_1, ...], [w_k, ... ], [...]]\\), where the number of inner lists equal to the number of row factor shards and the elements in each inner list are the weights for the rows of that shard. In this case, - w_ij = unonbserved_weight + row_weights[i] * col_weights[j]. + \\(w_ij = unonbserved_weight + row_weights[i] * col_weights[j]\\). - A non-negative scalar: This value is used for all row weights. Note that it is allowed to have row_weights as a list and col_weights as a scalar, or vice-versa. diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 6fc053759c58d30c24657dd22e7d12be46fc7a7e..aab7d0c9e8874269bfa5f33193b0dc0ba4bbc9cd 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -8,30 +8,47 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "py_test") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - py_library( name = "feature_column_py", srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ - ":sequential_feature_column", + ":sequence_feature_column", + "//tensorflow/python:util", ], ) py_library( - name = "sequential_feature_column", - srcs = ["python/feature_column/sequential_feature_column.py"], + name = "sequence_feature_column", + srcs = ["python/feature_column/sequence_feature_column.py"], srcs_version = "PY2AND3", - deps = [], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:variable_scope", + "//tensorflow/python/feature_column", + ], +) + +py_test( + name = "sequence_feature_column_test", + srcs = ["python/feature_column/sequence_feature_column_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":sequence_feature_column", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + ], ) diff --git a/tensorflow/contrib/feature_column/__init__.py b/tensorflow/contrib/feature_column/__init__.py index 6da7b126931effae9cc97091a27070d7013450d4..baa8c1567a5aeb39976ab04c54ae2728ba050a7c 100644 --- a/tensorflow/contrib/feature_column/__init__.py +++ b/tensorflow/contrib/feature_column/__init__.py @@ -19,12 +19,18 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.feature_column.python.feature_column.sequential_feature_column import * +from tensorflow.contrib.feature_column.python.feature_column.sequence_feature_column import * from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import _allowed_symbols = [ + 'sequence_categorical_column_with_hash_bucket', + 'sequence_categorical_column_with_identity', + 'sequence_categorical_column_with_vocabulary_list', + 'sequence_categorical_column_with_vocabulary_file', + 'sequence_input_layer', + 'sequence_numeric_column', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py new file mode 100644 index 0000000000000000000000000000000000000000..b588f75efe9d0bbf8213a89978a627c0a0ccf554 --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -0,0 +1,461 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental methods for tf.feature_column sequence input.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import collections + + +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variable_scope + +# pylint: disable=protected-access +# TODO(b/73827486): Support SequenceExample. + + +def sequence_input_layer( + features, + feature_columns, + weight_collections=None, + trainable=True): + """"Builds input layer for sequence input. + + All `feature_columns` must be sequence dense columns with the same + `sequence_length`. The output of this method can be fed into sequence + networks, such as RNN. + + The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ from + batch to batch. + + If multiple `feature_columns` are given with `Di` `num_elements` each, their + outputs are concatenated. So, the final `Tensor` has shape + `[batch_size, T, D0 + D1 + ... + Dn]`. + + Example: + + ```python + rating = sequence_numeric_column('rating') + watches = sequence_categorical_column_with_identity( + 'watches', num_buckets=1000) + watches_embedding = embedding_column(watches, dimension=10) + columns = [rating, watches] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + features: A dict mapping keys to tensors. + feature_columns: An iterable of dense sequence columns. Valid columns are + - `embedding_column` that wraps a `sequence_categorical_column_with_*` + - `sequence_numeric_column`. + weight_collections: A list of collection names to which the Variable will be + added. Note that variables will also be added to collections + `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`. + trainable: If `True` also add the variable to the graph collection + `GraphKeys.TRAINABLE_VARIABLES`. + + Returns: + An `(input_layer, sequence_length)` tuple where: + - input_layer: A float `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ + from batch to batch. `D` is the sum of `num_elements` for all + `feature_columns`. + - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence + length for each example. + + Raises: + ValueError: If any of the `feature_columns` is the wrong type. + """ + feature_columns = fc._clean_feature_columns(feature_columns) + for c in feature_columns: + if not isinstance(c, fc._SequenceDenseColumn): + raise ValueError( + 'All feature_columns must be of type _SequenceDenseColumn. ' + 'You can wrap a sequence_categorical_column with an embedding_column ' + 'or indicator_column. ' + 'Given (type {}): {}'.format(type(c), c)) + + with variable_scope.variable_scope( + None, default_name='sequence_input_layer', values=features.values()): + builder = fc._LazyBuilder(features) + output_tensors = [] + sequence_lengths = [] + ordered_columns = [] + for column in sorted(feature_columns, key=lambda x: x.name): + ordered_columns.append(column) + with variable_scope.variable_scope( + None, default_name=column._var_scope_name): + dense_tensor, sequence_length = column._get_sequence_dense_tensor( + builder, + weight_collections=weight_collections, + trainable=trainable) + # Flattens the final dimension to produce a 3D Tensor. + num_elements = column._variable_shape.num_elements() + shape = array_ops.shape(dense_tensor) + output_tensors.append( + array_ops.reshape( + dense_tensor, + shape=array_ops.concat([shape[:2], [num_elements]], axis=0))) + sequence_lengths.append(sequence_length) + fc._verify_static_batch_size_equality(output_tensors, ordered_columns) + fc._verify_static_batch_size_equality(sequence_lengths, ordered_columns) + sequence_length = _assert_all_equal_and_return(sequence_lengths) + return array_ops.concat(output_tensors, -1), sequence_length + + +def sequence_categorical_column_with_identity( + key, num_buckets, default_value=None): + """Returns a feature column that represents sequences of integers. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + watches = sequence_categorical_column_with_identity( + 'watches', num_buckets=1000) + watches_embedding = embedding_column(watches, dimension=10) + columns = [watches_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + num_buckets: Range of inputs. Namely, inputs are expected to be in the + range `[0, num_buckets)`. + default_value: If `None`, this column's graph operations will fail for + out-of-range inputs. Otherwise, this value must be in the range + `[0, num_buckets)`, and will replace out-of-range inputs. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: if `num_buckets` is less than one. + ValueError: if `default_value` is not in range `[0, num_buckets)`. + """ + return fc._SequenceCategoricalColumn( + fc.categorical_column_with_identity( + key=key, + num_buckets=num_buckets, + default_value=default_value)) + + +def sequence_categorical_column_with_hash_bucket( + key, hash_bucket_size, dtype=dtypes.string): + """A sequence of categorical terms where ids are set by hashing. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + tokens = sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=1000) + tokens_embedding = embedding_column(tokens, dimension=10) + columns = [tokens_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + hash_bucket_size: An int > 1. The number of buckets. + dtype: The type of features. Only string and integer types are supported. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: `hash_bucket_size` is not greater than 1. + ValueError: `dtype` is neither string nor integer. + """ + return fc._SequenceCategoricalColumn( + fc.categorical_column_with_hash_bucket( + key=key, + hash_bucket_size=hash_bucket_size, + dtype=dtype)) + + +def sequence_categorical_column_with_vocabulary_file( + key, vocabulary_file, vocabulary_size=None, num_oov_buckets=0, + default_value=None, dtype=dtypes.string): + """A sequence of categorical terms where ids use a vocabulary file. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + states = sequence_categorical_column_with_vocabulary_file( + key='states', vocabulary_file='/us/states.txt', vocabulary_size=50, + num_oov_buckets=5) + states_embedding = embedding_column(states, dimension=10) + columns = [states_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + vocabulary_file: The vocabulary file name. + vocabulary_size: Number of the elements in the vocabulary. This must be no + greater than length of `vocabulary_file`, if less than length, later + values are ignored. If None, it is set to the length of `vocabulary_file`. + num_oov_buckets: Non-negative integer, the number of out-of-vocabulary + buckets. All out-of-vocabulary inputs will be assigned IDs in the range + `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of + the input value. A positive `num_oov_buckets` can not be specified with + `default_value`. + default_value: The integer ID value to return for out-of-vocabulary feature + values, defaults to `-1`. This can not be specified with a positive + `num_oov_buckets`. + dtype: The type of features. Only string and integer types are supported. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: `vocabulary_file` is missing or cannot be opened. + ValueError: `vocabulary_size` is missing or < 1. + ValueError: `num_oov_buckets` is a negative integer. + ValueError: `num_oov_buckets` and `default_value` are both specified. + ValueError: `dtype` is neither string nor integer. + """ + return fc._SequenceCategoricalColumn( + fc.categorical_column_with_vocabulary_file( + key=key, + vocabulary_file=vocabulary_file, + vocabulary_size=vocabulary_size, + num_oov_buckets=num_oov_buckets, + default_value=default_value, + dtype=dtype)) + + +def sequence_categorical_column_with_vocabulary_list( + key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0): + """A sequence of categorical terms where ids use an in-memory list. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + colors = sequence_categorical_column_with_vocabulary_list( + key='colors', vocabulary_list=('R', 'G', 'B', 'Y'), + num_oov_buckets=2) + colors_embedding = embedding_column(colors, dimension=3) + columns = [colors_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + vocabulary_list: An ordered iterable defining the vocabulary. Each feature + is mapped to the index of its value (if present) in `vocabulary_list`. + Must be castable to `dtype`. + dtype: The type of features. Only string and integer types are supported. + If `None`, it will be inferred from `vocabulary_list`. + default_value: The integer ID value to return for out-of-vocabulary feature + values, defaults to `-1`. This can not be specified with a positive + `num_oov_buckets`. + num_oov_buckets: Non-negative integer, the number of out-of-vocabulary + buckets. All out-of-vocabulary inputs will be assigned IDs in the range + `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a + hash of the input value. A positive `num_oov_buckets` can not be specified + with `default_value`. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: if `vocabulary_list` is empty, or contains duplicate keys. + ValueError: `num_oov_buckets` is a negative integer. + ValueError: `num_oov_buckets` and `default_value` are both specified. + ValueError: if `dtype` is not integer or string. + """ + return fc._SequenceCategoricalColumn( + fc.categorical_column_with_vocabulary_list( + key=key, + vocabulary_list=vocabulary_list, + dtype=dtype, + default_value=default_value, + num_oov_buckets=num_oov_buckets)) + + +def sequence_numeric_column( + key, + shape=(1,), + default_value=0., + dtype=dtypes.float32, + normalizer_fn=None): + """Returns a feature column that represents sequences of numeric data. + + Example: + + ```python + temperature = sequence_numeric_column('temperature') + columns = [temperature] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input features. + shape: The shape of the input data per sequence id. E.g. if `shape=(2,)`, + each example must contain `2 * sequence_length` values. + default_value: A single value compatible with `dtype` that is used for + padding the sparse data into a dense `Tensor`. + dtype: The type of values. + normalizer_fn: If not `None`, a function that can be used to normalize the + value of the tensor after `default_value` is applied for parsing. + Normalizer function takes the input `Tensor` as its argument, and returns + the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that + even though the most common use case of this function is normalization, it + can be used for any kind of Tensorflow transformations. + + Returns: + A `_SequenceNumericColumn`. + + Raises: + TypeError: if any dimension in shape is not an int. + ValueError: if any dimension in shape is not a positive integer. + ValueError: if `dtype` is not convertible to `tf.float32`. + """ + shape = fc._check_shape(shape=shape, key=key) + if not (dtype.is_integer or dtype.is_floating): + raise ValueError('dtype must be convertible to float. ' + 'dtype: {}, key: {}'.format(dtype, key)) + if normalizer_fn is not None and not callable(normalizer_fn): + raise TypeError( + 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn)) + + return _SequenceNumericColumn( + key, + shape=shape, + default_value=default_value, + dtype=dtype, + normalizer_fn=normalizer_fn) + + +def _assert_all_equal_and_return(tensors, name=None): + """Asserts that all tensors are equal and returns the first one.""" + with ops.name_scope(name, 'assert_all_equal', values=tensors): + if len(tensors) == 1: + return tensors[0] + assert_equal_ops = [] + for t in tensors[1:]: + assert_equal_ops.append(check_ops.assert_equal(tensors[0], t)) + with ops.control_dependencies(assert_equal_ops): + return array_ops.identity(tensors[0]) + + +class _SequenceNumericColumn( + fc._SequenceDenseColumn, + collections.namedtuple( + '_SequenceNumericColumn', + ['key', 'shape', 'default_value', 'dtype', 'normalizer_fn'])): + """Represents sequences of numeric data.""" + + @property + def name(self): + return self.key + + @property + def _parse_example_spec(self): + return {self.key: parsing_ops.VarLenFeature(self.dtype)} + + def _transform_feature(self, inputs): + input_tensor = inputs.get(self.key) + if self.normalizer_fn is not None: + input_tensor = self.normalizer_fn(input_tensor) + return input_tensor + + @property + def _variable_shape(self): + return tensor_shape.TensorShape(self.shape) + + def _get_sequence_dense_tensor( + self, inputs, weight_collections=None, trainable=None): + # Do nothing with weight_collections and trainable since no variables are + # created in this function. + del weight_collections + del trainable + sp_tensor = inputs.get(self) + dense_tensor = sparse_ops.sparse_tensor_to_dense( + sp_tensor, default_value=self.default_value) + # Reshape into [batch_size, T, variable_shape]. + dense_shape = array_ops.concat( + [array_ops.shape(dense_tensor)[:1], [-1], self._variable_shape], + axis=0) + dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape) + sequence_length = fc._sequence_length_from_sparse_tensor( + sp_tensor, num_elements=self._variable_shape.num_elements()) + return fc._SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + +# pylint: enable=protected-access diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py new file mode 100644 index 0000000000000000000000000000000000000000..89b5f4c4137f6c42417f539a578fd8b11f8b235d --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -0,0 +1,857 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sequential_feature_column.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np + +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column.feature_column import _LazyBuilder +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test +from tensorflow.python.training import monitored_session + + +class SequenceInputLayerTest(test.TestCase): + + def test_embedding_column(self): + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [2, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + + embedding_dimension_a = 2 + embedding_values_a = ( + (1., 2.), # id 0 + (3., 4.), # id 1 + (5., 6.) # id 2 + ) + embedding_dimension_b = 3 + embedding_values_b = ( + (11., 12., 13.), # id 0 + (14., 15., 16.), # id 1 + (17., 18., 19.) # id 2 + ) + def _get_initializer(embedding_dimension, embedding_values): + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + return _initializer + + expected_input_layer = [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]], + ] + expected_sequence_length = [1, 2] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = fc.embedding_column( + categorical_column_a, dimension=embedding_dimension_a, + initializer=_get_initializer(embedding_dimension_a, embedding_values_a)) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_b = fc.embedding_column( + categorical_column_b, dimension=embedding_dimension_b, + initializer=_get_initializer(embedding_dimension_b, embedding_values_b)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + # Test that columns are reordered alphabetically. + feature_columns=[embedding_column_b, embedding_column_a]) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('sequence_input_layer/aaa_embedding/embedding_weights:0', + 'sequence_input_layer/bbb_embedding/embedding_weights:0'), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values_a, global_vars[0].eval(session=sess)) + self.assertAllEqual(embedding_values_b, global_vars[1].eval(session=sess)) + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_embedding_column_with_non_sequence_categorical(self): + """Tests that error is raised for non-sequence categorical column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = fc.embedding_column( + categorical_column_a, dimension=2) + + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_embedding\. categorical_column must be of ' + r'type _SequenceCategoricalColumn to use sequence_input_layer\.'): + _, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[embedding_column_a]) + + def test_indicator_column(self): + vocabulary_size_a = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + vocabulary_size_b = 2 + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [1, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 1, 0), + dense_shape=(2, 2)) + + expected_input_layer = [ + # example 0, ids_a [2], ids_b [1] + [[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [1, 0] + [[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]], + ] + expected_sequence_length = [1, 2] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size_a) + indicator_column_a = fc.indicator_column(categorical_column_a) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size_b) + indicator_column_b = fc.indicator_column(categorical_column_b) + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + # Test that columns are reordered alphabetically. + feature_columns=[indicator_column_b, indicator_column_a]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_indicator_column_with_non_sequence_categorical(self): + """Tests that error is raised for non-sequence categorical column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column_a = fc.indicator_column(categorical_column_a) + + with self.assertRaisesRegexp( + ValueError, + r'In indicator_column: aaa_indicator\. categorical_column must be of ' + r'type _SequenceCategoricalColumn to use sequence_input_layer\.'): + _, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[indicator_column_a]) + + def test_numeric_column(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_input_layer = [ + [[0.], [1.]], + [[10.], [0.]], + ] + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa') + + input_layer, sequence_length = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_numeric_column_multi_dim(self): + """Tests sequence_input_layer for multi-dimensional numeric_column.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), + (1, 0), (1, 1), (1, 2), (1, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 8)) + # The output of numeric_column._get_dense_tensor should be flattened. + expected_input_layer = [ + [[0., 1., 2., 3.], [4., 5., 6., 7.]], + [[10., 11., 12., 13.], [0., 0., 0., 0.]], + ] + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_not_equal(self): + """Tests that an error is raised when sequence lengths are not equal.""" + # Input a with sequence_length = [2, 1] + sparse_input_a = sparse_tensor.SparseTensorValue( + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + # Input b with sequence_length = [1, 1] + sparse_input_b = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0)), + values=(1., 10.), + dense_shape=(2, 2)) + numeric_column_a = sfc.sequence_numeric_column('aaa') + numeric_column_b = sfc.sequence_numeric_column('bbb') + + _, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + feature_columns=[numeric_column_a, numeric_column_b]) + + with monitored_session.MonitoredSession() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[Condition x == y did not hold element-wise:\] ' + r'\[x \(sequence_input_layer/aaa/sequence_length:0\) = \] \[2 1\] ' + r'\[y \(sequence_input_layer/bbb/sequence_length:0\) = \] \[1 1\]'): + sess.run(sequence_length) + + +class InputLayerTest(test.TestCase): + """Tests input_layer with sequence feature columns.""" + + def test_embedding_column(self): + """Tests that error is raised for sequence embedding column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = fc.embedding_column( + categorical_column_a, dimension=2) + + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_embedding\. categorical_column must not be ' + r'of type _SequenceCategoricalColumn\.'): + _ = fc.input_layer( + features={'aaa': sparse_input}, + feature_columns=[embedding_column_a]) + + def test_indicator_column(self): + """Tests that error is raised for sequence indicator column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column_a = fc.indicator_column(categorical_column_a) + + with self.assertRaisesRegexp( + ValueError, + r'In indicator_column: aaa_indicator\. categorical_column must not be ' + r'of type _SequenceCategoricalColumn\.'): + _ = fc.input_layer( + features={'aaa': sparse_input}, + feature_columns=[indicator_column_a]) + + +def _assert_sparse_tensor_value(test_case, expected, actual): + _assert_sparse_tensor_indices_shape(test_case, expected, actual) + + test_case.assertEqual( + np.array(expected.values).dtype, np.array(actual.values).dtype) + test_case.assertAllEqual(expected.values, actual.values) + + +def _assert_sparse_tensor_indices_shape(test_case, expected, actual): + test_case.assertEqual(np.int64, np.array(actual.indices).dtype) + test_case.assertAllEqual(expected.indices, actual.indices) + + test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype) + test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) + + +class SequenceCategoricalColumnWithIdentityTest(test.TestCase): + + def test_get_sparse_tensors(self): + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + expected_sparse_ids = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=np.array((1, 2, 0), dtype=np.int64), + dense_shape=(2, 2, 1)) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, + expected_sparse_ids, + id_weight_pair.id_tensor.eval(session=sess)) + + def test_get_sparse_tensors_inputs3d(self): + """Tests _get_sparse_tensors when the input is already 3D Tensor.""" + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=(1, 2, 0), + dense_shape=(2, 2, 1)) + + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'Column aaa expected ID tensor of rank 2\.\s*' + r'id_tensor shape:\s*\[2 2 1\]'): + id_weight_pair = column._get_sparse_tensors( + _LazyBuilder({'aaa': inputs})) + with monitored_session.MonitoredSession() as sess: + id_weight_pair.id_tensor.eval(session=sess) + + +class SequenceCategoricalColumnWithHashBucketTest(test.TestCase): + + def test_get_sparse_tensors(self): + column = sfc.sequence_categorical_column_with_hash_bucket( + 'aaa', hash_bucket_size=10) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2)) + + expected_sparse_ids = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + # Ignored to avoid hash dependence in test. + values=np.array((0, 0, 0), dtype=np.int64), + dense_shape=(2, 2, 1)) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_indices_shape( + self, + expected_sparse_ids, + id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase): + + def _write_vocab(self, vocab_strings, file_name): + vocab_file = os.path.join(self.get_temp_dir(), file_name) + with open(vocab_file, 'w') as f: + f.write('\n'.join(vocab_strings)) + return vocab_file + + def setUp(self): + super(SequenceCategoricalColumnWithVocabularyFileTest, self).setUp() + + vocab_strings = ['omar', 'stringer', 'marlo'] + self._wire_vocabulary_file_name = self._write_vocab(vocab_strings, + 'wire_vocabulary.txt') + self._wire_vocabulary_size = 3 + + def test_get_sparse_tensors(self): + column = sfc.sequence_categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + expected_sparse_ids = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=(2, 2, 1)) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, + expected_sparse_ids, + id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceCategoricalColumnWithVocabularyListTest(test.TestCase): + + def test_get_sparse_tensors(self): + column = sfc.sequence_categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + expected_sparse_ids = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=(2, 2, 1)) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, + expected_sparse_ids, + id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceEmbeddingColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)) + + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + expected_lookups = [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]], + ] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, dimension=embedding_dimension, + initializer=_initializer) + + embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('embedding_weights:0',), tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual(expected_lookups, embedding_lookup.eval(session=sess)) + + def test_sequence_length(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + expected_sequence_length = [1, 2] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, dimension=2) + + _, sequence_length = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length = [0, 1, 2, 0, 1, 0] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, dimension=2) + + _, sequence_length = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +class SequenceIndicatorColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)) + + expected_lookups = [ + # example 0, ids [2] + [[0., 0., 1.], [0., 0., 0.]], + # example 1, ids [0, 1] + [[1., 0., 0.], [0., 1., 0.]], + # example 2, ids [] + [[0., 0., 0.], [0., 0., 0.]], + # example 3, ids [1] + [[0., 1., 0.], [0., 0., 0.]], + ] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column = fc.indicator_column(categorical_column) + + indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_lookups, indicator_tensor.eval(session=sess)) + + def test_sequence_length(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + expected_sequence_length = [1, 2] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column = fc.indicator_column(categorical_column) + + _, sequence_length = indicator_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length = [0, 1, 2, 0, 1, 0] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column = fc.indicator_column(categorical_column) + + _, sequence_length = indicator_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +class SequenceNumericColumnTest(test.TestCase): + + def test_defaults(self): + a = sfc.sequence_numeric_column('aaa') + self.assertEqual('aaa', a.key) + self.assertEqual('aaa', a.name) + self.assertEqual('aaa', a._var_scope_name) + self.assertEqual((1,), a.shape) + self.assertEqual(0., a.default_value) + self.assertEqual(dtypes.float32, a.dtype) + self.assertIsNone(a.normalizer_fn) + + def test_shape_saved_as_tuple(self): + a = sfc.sequence_numeric_column('aaa', shape=[1, 2]) + self.assertEqual((1, 2), a.shape) + + def test_shape_must_be_positive_integer(self): + with self.assertRaisesRegexp(TypeError, 'shape dimensions must be integer'): + sfc.sequence_numeric_column('aaa', shape=[1.0]) + + with self.assertRaisesRegexp( + ValueError, 'shape dimensions must be greater than 0'): + sfc.sequence_numeric_column('aaa', shape=[0]) + + def test_dtype_is_convertible_to_float(self): + with self.assertRaisesRegexp( + ValueError, 'dtype must be convertible to float'): + sfc.sequence_numeric_column('aaa', dtype=dtypes.string) + + def test_normalizer_fn_must_be_callable(self): + with self.assertRaisesRegexp(TypeError, 'must be a callable'): + sfc.sequence_numeric_column('aaa', normalizer_fn='NotACallable') + + def test_get_sequence_dense_tensor(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_dense_tensor = [ + [[0.], [1.]], + [[10.], [0.]], + ] + numeric_column = sfc.sequence_numeric_column('aaa') + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_get_sequence_dense_tensor_with_normalizer_fn(self): + + def _increment_two(input_sparse_tensor): + return sparse_ops.sparse_add( + input_sparse_tensor, + sparse_tensor.SparseTensor(((0, 0), (1, 1)), (2.0, 2.0), (2, 2)) + ) + + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + + # Before _increment_two: + # [[0.], [1.]], + # [[10.], [0.]], + # After _increment_two: + # [[2.], [1.]], + # [[10.], [2.]], + expected_dense_tensor = [ + [[2.], [1.]], + [[10.], [2.]], + ] + numeric_column = sfc.sequence_numeric_column( + 'aaa', normalizer_fn=_increment_two) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_get_sequence_dense_tensor_with_shape(self): + """Tests get_sequence_dense_tensor with shape !=(1,).""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0., 1., 2.], [3., 4., 5.]] + # example 1, [[10., 11., 12.]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), + (1, 0), (1, 1), (1, 2)), + values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), + dense_shape=(2, 6)) + expected_dense_tensor = [ + [[0., 1., 2.], [3., 4., 5.]], + [[10., 11., 12.], [0., 0., 0.]], + ] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_get_dense_tensor_multi_dim(self): + """Tests get_sequence_dense_tensor for multi-dim numeric_column.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), + (1, 0), (1, 1), (1, 2), (1, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 8)) + expected_dense_tensor = [ + [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]], + [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]], + ] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_sequence_length(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0., 1., 2.], [3., 4., 5.]] + # example 1, [[10., 11., 12.]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), + (1, 0), (1, 1), (1, 2)), + values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), + dense_shape=(2, 6)) + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) + + def test_sequence_length_with_shape(self): + """Tests _sequence_length with shape !=(1,).""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa') + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [] + # example 1, values [[0.], [1.]] + # example 2, [[2.]] + # example 3, values [] + # example 4, [[3.]] + # example 5, values [] + indices=((1, 0), (1, 1), (2, 0), (4, 0)), + values=(0., 1., 2., 3.), + dense_shape=(6, 2)) + expected_sequence_length = [0, 2, 1, 0, 1, 0] + numeric_column = sfc.sequence_numeric_column('aaa') + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD index eccce99071dc1477cf4f3bb152f3304b3b0fc35a..f7b3273a4d35eadb9fad49399b7bf18d4bd33503 100644 --- a/tensorflow/contrib/ffmpeg/BUILD +++ b/tensorflow/contrib/ffmpeg/BUILD @@ -180,15 +180,3 @@ py_library( "//tensorflow/python:util", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py index daba965a98893b992abdc598ec713f13020d6e91..484ffee3e7afe55c63cab2a463454353b2663e18 100644 --- a/tensorflow/contrib/ffmpeg/__init__.py +++ b/tensorflow/contrib/ffmpeg/__init__.py @@ -28,7 +28,6 @@ from __future__ import print_function from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio -from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/ffmpeg/default/BUILD b/tensorflow/contrib/ffmpeg/default/BUILD index 6b455567d766dbe6d380a498bd7f521db27e077b..59bad8982dd163f89f37e1a0a9d5017d0c495de3 100644 --- a/tensorflow/contrib/ffmpeg/default/BUILD +++ b/tensorflow/contrib/ffmpeg/default/BUILD @@ -74,15 +74,3 @@ tf_cc_test( "//tensorflow/core:test", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index e61221a6b0d34373279a379f356c99c379488182..cca1a054193815793846a8753678f75bdfd72a6c 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -28,7 +28,7 @@ #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/env.h" using tensorflow::strings::StrCat; @@ -256,6 +256,9 @@ Status ReadInfoFile(const string& filename, uint32* width, uint32* height, if (p != std::string::npos) { string rgb24 = line.substr(p + 9, line.find(" ", p + 9)); rgb24 = rgb24.substr(0, rgb24.find(",")); + // Strip anything after " ", in case the format is + // `640x360 [SAR 1:1 DAR 16:9]` + rgb24 = rgb24.substr(0, rgb24.find(" ")); string rgb24_width = rgb24.substr(0, rgb24.find("x")); string rgb24_height = rgb24.substr(rgb24_width.length() + 1); if (strings::safe_strtou32(rgb24_width, &width_value) && @@ -270,8 +273,10 @@ Status ReadInfoFile(const string& filename, uint32* width, uint32* height, // We only look for the first stream mapping to have the number of the // frames. // Once processed we will not further process stream mapping section. - if (line.find("frame= ") == 0) { - string number = line.substr(8, line.find(" ", 8)); + if (line.find("frame=") == 0) { + // The format might be `frame= 166 ` or `frame=12488 ` + string number = line.substr(6); + number = number.substr(number.find_first_not_of(" ")); number = number.substr(0, number.find(" ")); if (strings::safe_strtou32(number, &frames_value)) { in_mapping = false; diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h index a8d5a0dd83fb504b5e6671c3e82dc7d2dd3e6a9b..bf2aa75545813f7da88ed503798572474c7c2eb8 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h +++ b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h @@ -53,7 +53,7 @@ Status CreateAudioFile(const string& audio_format_id, int32 bits_per_second, int32 samples_per_second, int32 channel_count, const std::vector& samples, string* output_data); -// Reads an video file using ffmpeg adn converts it into a RGB24 in uint8 +// Reads an video file using ffmpeg and converts it into a RGB24 in uint8 // [frames, height, width, 3]. The w, h, and frames are obtained from ffmpeg. Status ReadVideoFile(const string& filename, std::vector* output_data, uint32* width, uint32* height, uint32* frames); diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index 020b5c99c61019254bef0b1dff6bc5901c92758a..b1b5126d9e9e5196a1733b80e0778e53cef7f774 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -21,7 +21,6 @@ from __future__ import print_function from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py -from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 9e5f54f0973eae899ca65e4098358107053cb7d4..249debbdf6dff412a5be6cb1032fc4a3567c7d0b 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -28,7 +28,6 @@ tf_custom_op_py_library( "python/framework/graph_util.py", "python/framework/tensor_util.py", "python/ops/__init__.py", - "python/ops/accumulate_n_v2.py", "python/ops/arg_scope.py", "python/ops/audio_ops.py", "python/ops/checkpoint_ops.py", @@ -63,7 +62,9 @@ tf_custom_op_py_library( "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:script_ops", + "//tensorflow/python:smart_cond", "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:state_ops_gen", @@ -92,6 +93,7 @@ tf_kernel_library( ], deps = [ "//tensorflow/core:framework", + "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", ], alwayslink = 1, @@ -161,23 +163,6 @@ py_test( ], ) -py_test( - name = "accumulate_n_v2_test", - size = "small", - srcs = ["python/ops/accumulate_n_v2_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":framework_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:platform_test", - "//tensorflow/python:variables", - "//third_party/py/numpy", - ], -) - cuda_py_test( name = "critical_section_test", size = "medium", @@ -185,31 +170,16 @@ cuda_py_test( additional_deps = [ "//tensorflow/python:client_testlib", ":framework_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:gradients", "//tensorflow/python:platform_test", "//tensorflow/python:resource_variable_ops", - ], -) - -py_test( - name = "accumulate_n_v2_eager_test", - size = "small", - srcs = ["python/ops/accumulate_n_v2_eager_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":framework_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/eager:backprop", + "//tensorflow/python:tensor_array_ops", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:context", - "//tensorflow/python/eager:tape", - "//third_party/py/numpy", ], ) @@ -354,15 +324,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 4746cfe0720cb20f530dc919fe062db17a1dfe84..dc49383c5c300e82839c478e097074b3e8776b3b 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -71,6 +71,10 @@ See the @{$python/contrib.framework} guide. @@model_variable @@variable @@VariableDeviceChooser +@@convolutional_delta_orthogonal +@@convolutional_orthogonal_1d +@@convolutional_orthogonal_2d +@@convolutional_orthogonal_3d @@zero_initializer @@load_checkpoint @@ -82,11 +86,16 @@ See the @{$python/contrib.framework} guide. @@load_linear_multiclass_bias_initializer @@load_variable_slot_initializer +@@argsort @@py_func @@sort @@get_placeholders +@@smart_cond +@@smart_constant_value +@@smart_case + @@CriticalSection @@BoundedTensorSpec @@ -99,19 +108,38 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.framework.python.framework import * +from tensorflow.contrib.framework.python.framework import nest from tensorflow.contrib.framework.python.ops import * # pylint: enable=unused-import,wildcard-import from tensorflow.python.framework.ops import prepend_name_scope from tensorflow.python.framework.ops import strip_name_scope -from tensorflow.python.ops.control_flow_ops import smart_cond -from tensorflow.python.ops.control_flow_ops import smart_constant_value - +from tensorflow.python.framework.smart_cond import smart_case +from tensorflow.python.framework.smart_cond import smart_cond +from tensorflow.python.framework.smart_cond import smart_constant_value from tensorflow.python.framework.tensor_spec import BoundedTensorSpec from tensorflow.python.framework.tensor_spec import TensorSpec - +from tensorflow.python.ops.init_ops import convolutional_delta_orthogonal +from tensorflow.python.ops.init_ops import convolutional_orthogonal_1d +from tensorflow.python.ops.init_ops import convolutional_orthogonal_2d +from tensorflow.python.ops.init_ops import convolutional_orthogonal_3d from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = ['nest'] - +_nest_allowed_symbols = [ + 'assert_same_structure', + 'is_sequence', + 'flatten', + 'flatten_dict_items', + 'pack_sequence_as', + 'map_structure', + 'assert_shallow_structure', + 'flatten_up_to', + 'map_structure_up_to', + 'get_traverse_shallow_structure', + 'yield_flat_paths', + 'flatten_with_joined_string_paths', +] + +remove_undocumented(nest.__name__, allowed_exception_list=_nest_allowed_symbols) remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/framework/kernels/zero_initializer_op.cc b/tensorflow/contrib/framework/kernels/zero_initializer_op.cc index 5bf6b67529579e71a615c27e035111a58d5c02e0..6ab3f460b36d5dd632daee1af68d62529df9cb09 100644 --- a/tensorflow/contrib/framework/kernels/zero_initializer_op.cc +++ b/tensorflow/contrib/framework/kernels/zero_initializer_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/resource_var.h" namespace tensorflow { @@ -85,4 +86,74 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_KERNELS +template +class ZeroVarInitializer : public OpKernel { + public: + explicit ZeroVarInitializer(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_)); + } + + void Compute(OpKernelContext* ctx) override { + Var* variable = nullptr; + OP_REQUIRES_OK(ctx, LookupOrCreateResource( + ctx, HandleFromInput(ctx, 0), &variable, + [this, ctx](Var** var_ptr) { + *var_ptr = new Var(dtype_); + PersistentTensor unused; + Tensor* var_tensor = nullptr; + AllocatorAttributes attr; + attr.set_gpu_compatible(true); + attr.set_nic_compatible(true); + TF_RETURN_IF_ERROR(ctx->allocate_persistent( + dtype_, shape_, &unused, &var_tensor, attr)); + + functor::TensorSetZero()( + ctx->eigen_device(), + var_tensor->flat()); + + *(*var_ptr)->tensor() = *var_tensor; + + return Status::OK(); + })); + + core::ScopedUnref scoped(variable); + mutex_lock ml(*variable->mu()); + + OP_REQUIRES(ctx, !variable->is_initialized, + errors::InvalidArgument("input is already initialized")); + + variable->is_initialized = true; + + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); + output->scalar()() = HandleFromInput(ctx, 0); + } + + private: + DataType dtype_; + TensorShape shape_; +}; + +#define REGISTER_CPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("ZeroVarInitializer") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype"), \ + ZeroVarInitializer); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS); +#undef REGISTER_CPU_KERNELS + +#if GOOGLE_CUDA +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("ZeroVarInitializer") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("dtype") \ + .HostMemory("var"), \ + ZeroVarInitializer); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +#undef REGISTER_GPU_KERNELS +#endif // GOOGLE_CUDA + } // namespace tensorflow diff --git a/tensorflow/contrib/framework/ops/variable_ops.cc b/tensorflow/contrib/framework/ops/variable_ops.cc index 706134ba9a51de6253ba7463b17ff662ea740ed0..f6ee6cdb5713c113aff2228db58244ac73536d9a 100644 --- a/tensorflow/contrib/framework/ops/variable_ops.cc +++ b/tensorflow/contrib/framework/ops/variable_ops.cc @@ -39,4 +39,33 @@ ref: Should be from a `Variable` node. output_ref:= Same as "ref". )doc"); +REGISTER_OP("ZeroVarInitializer") + .Input("var: resource") + .Output("output_var: resource") + .Attr("dtype: type") + .Attr("shape: shape") + .SetAllowsUninitializedInput() + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Scalar()); + DataType t; + TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t)); + PartialTensorShape p; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &p)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s)); + c->set_output_handle_shapes_and_types( + 0, std::vector{{s, t}}); + + return Status::OK(); + }) + .Doc(R"doc( +Initialize 'var' with all zeros. This op requires that the resource var is not +initialized. The var will first be allocated memory, then be filled with all +zeros. This op is intended to save memory during initialization, +if you use this op, you should not run initializer of the var. + +var: Should be a ResourceVariable. +output_var:= Same as "var". +)doc"); + } // namespace tensorflow diff --git a/tensorflow/contrib/framework/python/framework/experimental_test.py b/tensorflow/contrib/framework/python/framework/experimental_test.py index 8e54e09e04ee3c0ddbd4fa84cc0912cb70c93e62..cfdc7df7d8fd4c1406bf447a79038ac33b11e047 100644 --- a/tensorflow/contrib/framework/python/framework/experimental_test.py +++ b/tensorflow/contrib/framework/python/framework/experimental_test.py @@ -49,7 +49,6 @@ class ExperimentalTest(test.TestCase): "\nTHIS FUNCTION IS EXPERIMENTAL. It may change or " "be removed at any time, and without warning." "\n" - "\n" "\nArgs:" "\n arg0: Arg 0." "\n arg1: Arg 1." diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py index 49eec3a3f1a0f357ea3adfade51e71cb0f89942d..2703224b1bf62831b6088558d4f93950fe938c10 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util.py +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -85,14 +85,19 @@ def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, if n not in reachable_by_input and n not in output_nodes_set: # n is between input and output, i.e., part of the fused op next_to_visit = [n] + visited = set() while next_to_visit: cur_node = next_to_visit[0] + visited.add(cur_node) del next_to_visit[0] if cur_node in reachable_by_input and cur_node not in input_nodes_set: raise TypeError("Node %s uses input %s not in input_nodes." % (n, cur_node)) if cur_node not in input_nodes_set: - next_to_visit += name_to_input_name[cur_node] + next_to_visit += [ + input_node for input_node in name_to_input_name[cur_node] + if input_node not in visited + ] elif n not in reachable_by_input: nodes_post_output.append(n) diff --git a/tensorflow/contrib/framework/python/framework/graph_util_test.py b/tensorflow/contrib/framework/python/framework/graph_util_test.py index b8a6d109e19211d271c2b15bac66ddacd38fe395..812c5fbd8cb759aef6eb1aad532c03794b2ceaf4 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util_test.py +++ b/tensorflow/contrib/framework/python/framework/graph_util_test.py @@ -42,7 +42,8 @@ class GraphUtilTest(test.TestCase): graph_def = graph_pb2.GraphDef() node_a = GetNewNode('A', 'Placeholder', []) node_b = GetNewNode('B', 'Op1', ['A']) - node_c = GetNewNode('C', 'Op1', ['B']) + # A loop in the part that will be fused. + node_c = GetNewNode('C', 'Op1', ['B', 'C']) node_d = GetNewNode('D', 'Op1', ['C']) node_e = GetNewNode('E', 'Op1', ['D']) graph_def.node.extend([node_a, node_b, node_c, node_d, node_e]) diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py index 8cdb340f2ddd9b3a7f55c1937ef045f4627e99be..af1b404cb51bf5d8f8350481f2301d9653895e85 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py @@ -48,7 +48,7 @@ class LocalVariabletest(test.TestCase): variables = variables_lib.local_variables() self.assertEquals(2, len(variables)) self.assertRaises(errors_impl.OpError, sess.run, variables) - variables_lib.initialize_variables(variables).run() + variables_lib.variables_initializer(variables).run() self.assertAllEqual(set([value0, value1]), set(sess.run(variables))) @@ -78,7 +78,6 @@ class AssertScalarIntTest(test.TestCase): [3, 4], dtype=dtypes.int32)) -@test_util.with_c_api class WithShapeTest(test.TestCase): def _assert_with_shape(self, tensor, expected_value, expected_shape, @@ -209,31 +208,25 @@ class WithShapeTest(test.TestCase): self.assertRaisesRegexp(errors_impl.OpError, "Wrong shape", tensor_2x2.eval, {tensor_no_shape: [42.0]}) + @test_util.enable_c_shapes def test_with_shape_partial(self): with self.test_session(): tensor_partial_shape = array_ops.placeholder(dtypes.float32) tensor_partial_shape.set_shape([None, 2]) for incompatible_shape in [[0], [1]]: - if ops._USE_C_API: - error_message = "Shapes must be equal rank, but are 2 and 1" - else: - error_message = r"Shapes \(\?, 2\) and \([01],\) are not compatible" self.assertRaisesRegexp( - ValueError, error_message, + ValueError, "Shapes must be equal rank, but are 2 and 1", tensor_util.with_shape, incompatible_shape, tensor_partial_shape) for incompatible_shape in [[1, 2, 1]]: self.assertRaisesRegexp(ValueError, "Dimensions must be equal", tensor_util.with_shape, incompatible_shape, tensor_partial_shape) for incompatible_shape in [[2, 1]]: - if ops._USE_C_API: - error_message = (r"Dimension 1 in both shapes must be equal, but are " - r"2 and 1. Shapes are \[\?,2\] and \[2,1\].") - else: - error_message = r"Shapes \(\?, 2\) and \(2, 1\) are not compatible" self.assertRaisesRegexp( - ValueError, error_message, + ValueError, + r"Dimension 1 in both shapes must be equal, but are 2 and 1. " + r"Shapes are \[\?,2\] and \[2,1\].", tensor_util.with_shape, incompatible_shape, tensor_partial_shape) compatible_shape = [2, 2] diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py deleted file mode 100644 index 476528b0dd3df05239d5dc402b466e06dd789985..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Ops that will eventually be folded into tensorflow/python/ops/math_ops.py -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -from tensorflow.python.eager import context -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_math_ops -from tensorflow.python.ops import math_ops - - - -def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): - """Returns the element-wise sum of a list of tensors. - - Optionally, pass `shape` and `tensor_dtype` for shape and type checking, - otherwise, these are inferred. - - `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not - wait for all of its inputs to be ready before beginning to sum. This can - save memory if inputs are ready at different times, since minimum temporary - storage is proportional to the output size rather than the inputs size. - - Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. - - For example: - - ```python - a = tf.constant([[1, 2], [3, 4]]) - b = tf.constant([[5, 0], [0, 6]]) - tf.accumulate_n_v2([a, b, a]) # [[7, 4], [6, 14]] - - # Explicitly pass shape and type - tf.accumulate_n_v2([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) - # [[7, 4], - # [6, 14]] - ``` - - Args: - inputs: A list of `Tensor` objects, each with same shape and type. - shape: Shape of elements of `inputs`. - tensor_dtype: The type of `inputs`. - name: A name for the operation (optional). - - Returns: - A `Tensor` of same shape and type as the elements of `inputs`. - - Raises: - ValueError: If `inputs` don't all have same shape and dtype or the shape - cannot be inferred. - """ - _INPUTS_ERR_MSG = ValueError("inputs must be a list of at least one Tensor" - "with the same dtype and shape") - if not inputs or not isinstance(inputs, (list, tuple)): - raise _INPUTS_ERR_MSG - inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs) - if not all(isinstance(x, ops.Tensor) for x in inputs): - raise _INPUTS_ERR_MSG - if not all(x.dtype == inputs[0].dtype for x in inputs): - raise _INPUTS_ERR_MSG - if shape is not None: - shape = tensor_shape.as_shape(shape) - else: - shape = tensor_shape.unknown_shape() - for input_tensor in inputs: - if isinstance(input_tensor, ops.Tensor): - shape = shape.merge_with(input_tensor.get_shape()) - - # tensor_dtype is for safety only; operator's output type computed in C++ - if tensor_dtype is not None and tensor_dtype != inputs[0].dtype: - raise TypeError("tensor_dtype is {}, but input is of type {}" - .format(tensor_dtype, inputs[0].dtype)) - - if len(inputs) == 1 and name is None: - return inputs[0] - elif len(inputs) == 1 and name is not None: - return array_ops.identity(inputs[0], name=name) - elif context.in_eager_mode(): - # TemporaryVariable not currently supported in eager mode; fall back - # onto AddN for now. - # TODO(frreiss) remove this once the lifetime of eager variables gets - # addressed - return math_ops.add_n(inputs, name=name) - else: - return gen_math_ops._accumulate_nv2(inputs, name=name, shape=shape) - -# The following code should eventually be merged into -# tensorflow/python/ops/math_grad.py -@ops.RegisterGradient("AccumulateNV2") -def _AddNGrad(op, grad): - """Same as gradient for AddN. Copies the gradient to all inputs.""" - # Not broadcasting. - return [grad] * len(op.inputs) diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py index 409657fe1da0e5540cd2ad6070d86737c039e91f..5b150339953f961c756c0909dd1795341159b9cd 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope.py @@ -68,7 +68,7 @@ from tensorflow.python.util import tf_decorator __all__ = [ 'arg_scope', 'add_arg_scope', 'current_arg_scope', 'has_arg_scope', - 'arg_scoped_arguments' + 'arg_scoped_arguments', 'arg_scope_func_key' ] _ARGSTACK = [{}] @@ -89,7 +89,7 @@ def current_arg_scope(): return stack[-1] -def _key_op(op): +def arg_scope_func_key(op): return getattr(op, '_key_op', str(op)) @@ -103,9 +103,9 @@ def _kwarg_names(func): def _add_op(op): - key_op = _key_op(op) - if key_op not in _DECORATED_OPS: - _DECORATED_OPS[key_op] = _kwarg_names(op) + key = arg_scope_func_key(op) + if key not in _DECORATED_OPS: + _DECORATED_OPS[key] = _kwarg_names(op) @tf_contextlib.contextmanager @@ -142,21 +142,21 @@ def arg_scope(list_ops_or_scope, **kwargs): else: # Assumes that list_ops_or_scope is a list/tuple of ops with kwargs. if not isinstance(list_ops_or_scope, (list, tuple)): - raise TypeError('list_ops_or_scope must either be a list/tuple or reused' + raise TypeError('list_ops_or_scope must either be a list/tuple or reused ' 'scope (i.e. dict)') try: current_scope = current_arg_scope().copy() for op in list_ops_or_scope: - key_op = _key_op(op) + key = arg_scope_func_key(op) if not has_arg_scope(op): raise ValueError('%s is not decorated with @add_arg_scope', _name_op(op)) - if key_op in current_scope: - current_kwargs = current_scope[key_op].copy() + if key in current_scope: + current_kwargs = current_scope[key].copy() current_kwargs.update(kwargs) - current_scope[key_op] = current_kwargs + current_scope[key] = current_kwargs else: - current_scope[key_op] = kwargs.copy() + current_scope[key] = kwargs.copy() _get_arg_stack().append(current_scope) yield current_scope finally: @@ -176,14 +176,14 @@ def add_arg_scope(func): def func_with_args(*args, **kwargs): current_scope = current_arg_scope() current_args = kwargs - key_func = _key_op(func) + key_func = arg_scope_func_key(func) if key_func in current_scope: current_args = current_scope[key_func].copy() current_args.update(kwargs) return func(*args, **current_args) _add_op(func) - setattr(func_with_args, '_key_op', _key_op(func)) + setattr(func_with_args, '_key_op', arg_scope_func_key(func)) return tf_decorator.make_decorator(func, func_with_args) @@ -196,7 +196,7 @@ def has_arg_scope(func): Returns: a boolean. """ - return _key_op(func) in _DECORATED_OPS + return arg_scope_func_key(func) in _DECORATED_OPS def arg_scoped_arguments(func): @@ -209,4 +209,4 @@ def arg_scoped_arguments(func): a list of kwargs names. """ assert has_arg_scope(func) - return _DECORATED_OPS[_key_op(func)] + return _DECORATED_OPS[arg_scope_func_key(func)] diff --git a/tensorflow/contrib/framework/python/ops/arg_scope_test.py b/tensorflow/contrib/framework/python/ops/arg_scope_test.py index 7ba9d4ffa90f6860629b15a2ea91e0c573bf6368..4c3879d4fc08b53ea8be5f1256a830a64fb39af6 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope_test.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope_test.py @@ -170,6 +170,30 @@ class ArgScopeTest(test.TestCase): self.assertTupleEqual(args, func1_args) self.assertDictEqual(kwargs, func1_kwargs) + def testNestedArgScopeObjectCreatedOutsideScopeOverridesArgScope(self): + + def get_scope_object(): + with arg_scope([func1], a=1, b=None, c=[1]) as sc: + return sc + + scope_object = get_scope_object() + with arg_scope([func1], b=2, d=10): + with arg_scope(scope_object): + args, kwargs = func1(0) + self.assertTupleEqual(args, (0,)) + self.assertDictEqual(kwargs, {'a': 1, 'b': None, 'c': [1]}) + + def testArgScopeObjectCreatedWithinScopeInheritsArgScope(self): + def get_scope_object(): + with arg_scope([func1], a=1, b=None, c=[1]) as sc: + return sc + + with arg_scope([func1], b=2, d=10): + with arg_scope(get_scope_object()): + args, kwargs = func1(0) + self.assertTupleEqual(args, (0,)) + self.assertDictEqual(kwargs, {'a': 1, 'b': None, 'c': [1], 'd': 10}) + def testSharedArgScope(self): func1_args = (0,) func1_kwargs = {'a': 1, 'b': None, 'c': [1]} diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/contrib/framework/python/ops/critical_section_ops.py index 182fec924febb74a23b82b1664d137f033f3b1b4..72835c3ad86e6321eb30324c7dd0751034759ce4 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_ops.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_ops.py @@ -24,10 +24,12 @@ import collections # from tensorflow.core.protobuf import critical_section_pb2 from tensorflow.python.eager import context -from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_resource_variable_ops +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.util import nest @@ -38,11 +40,32 @@ CRITICAL_SECTION_EXECUTIONS = "critical_section_executions" class _ExecutionSignature( collections.namedtuple("_ExecutionSignature", - ("op", "exclusive_resource_access"))): + ("op", "handle", + "resources", "exclusive_resource_access"))): """A class storing an `ExecuteInCriticalResource` op and associated attrs.""" pass +def _identity(x): + """Identity op that recognizes `TensorArray`, `Operation`, and `Tensor`.""" + if isinstance(x, tensor_array_ops.TensorArray): + return x.identity() + elif isinstance(x, ops.Operation): + return control_flow_ops.group(x) + elif context.executing_eagerly() and x is None: + return None + else: + return array_ops.identity(x) + + +def _get_colocation(op): + """Get colocation symbol from op, if any.""" + try: + return op.get_attr("_class") + except ValueError: + return None + + class CriticalSection(object): """Critical section. @@ -112,16 +135,18 @@ class CriticalSection(object): ``` """ - def __init__(self, name=None, critical_section_def=None, import_scope=None): + def __init__(self, name=None, shared_name=None, + critical_section_def=None, import_scope=None): """Creates a critical section.""" if critical_section_def and name is not None: - raise ValueError("critical_section_def and name are mutually exclusive.") + raise ValueError("critical_section_def and shared_name are " + "mutually exclusive.") if critical_section_def: self._init_from_proto(critical_section_def, import_scope=import_scope) else: - self._init_from_args(name) + self._init_from_args(name, shared_name) - def _init_from_proto(self, critical_section_def, import_scope): + def _init_from_proto(self, critical_section_def, import_scope): # pylint: disable=invalid-name raise NotImplementedError("Not yet implemented") # TODO(ebrevdo): Re-enable once CriticalSection is in core. # assert isinstance( @@ -133,19 +158,21 @@ class CriticalSection(object): # critical_section_def.critical_section_name, # import_scope=import_scope)) - def _init_from_args(self, name): + def _init_from_args(self, name, shared_name): # pylint: disable=invalid-name """Initialize the CriticalSection from constructor arguments.""" with ops.name_scope(name, "CriticalSection", []) as name: - with ops.control_dependencies(None): + with ops.init_scope(): # pylint: disable=protected-access - handle_name = ops._name_from_scope_name(name) container = ops.get_default_graph()._container # pylint: enable=protected-access + if shared_name is None: + shared_name = name if container is None: container = "" - self._handle = gen_resource_variable_ops.critical_section_op( - shared_name=handle_name, name=name) - if context.in_graph_mode(): + self._handle = gen_resource_variable_ops.mutex_v2( + shared_name=shared_name, container=container, name=name) + + if not context.executing_eagerly(): ops.add_to_collections(CRITICAL_SECTIONS, self) @property @@ -171,11 +198,11 @@ class CriticalSection(object): The tensors returned from `fn(*args, **kwargs)`. Raises: - ValueError: If `fn` attempts to use this `CriticalSection` in any nested - way. + ValueError: If `fn` attempts to lock this `CriticalSection` in any nested + or lazy way that may cause a deadlock. ValueError: If `exclusive_resource_access` is not provided (is `True`) and another `CriticalSection` has an execution requesting the same - resources as in `*args`, `**kwargs`, and any additionaly captured + resources as in `*args`, `**kwargs`, and any additionally captured inputs in `fn`. Note, even if `exclusive_resource_access` is `True`, if another execution in another `CriticalSection` was created without `exclusive_resource_access=True`, a `ValueError` will be raised. @@ -183,68 +210,163 @@ class CriticalSection(object): name = kwargs.pop("name", None) exclusive_resource_access = kwargs.pop("exclusive_resource_access", True) - args = nest.map_structure(ops.convert_to_tensor, args) with ops.name_scope(name, "critical_section_execute", []): - fn_op = function.make_defun_op(fn, *args, **kwargs) - flat_dtypes = nest.flatten(fn_op.output_dtypes) - flat_shapes = nest.flatten(fn_op.output_shapes) - all_inputs = nest.flatten(args) + fn_op.captured_inputs - if self._handle in all_inputs: - raise ValueError("The function fn attempts to access the " - "CriticalSection in which it would be running. This " - "is illegal and would cause deadlocks. " - "CriticalSection: %s." % self._handle) - - if context.in_graph_mode(): - # Collections and op introspection does not work in eager - # mode. This is generally ok; since eager mode (as of - # writing) executes sequentially anyway. - all_input_resources = [ - x for x in all_inputs if x.dtype == dtypes.resource] - for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS): - if sg.op.inputs[0].name == self._handle.name: - # Other executions in the same critical section are allowed. - continue - if not (exclusive_resource_access or sg.exclusive_resource_access): - # Neither execution requested exclusive access. - continue - sg_input_names = [y.name for y in sg.op.inputs[1:]] - for res in all_input_resources: - if res.name in sg_input_names: - raise ValueError( - "This execution would access resource %s; but either this " - "execution (CriticalSection: %s) or Execution '%s' " - "(CriticalSection: %s) requested exclusive resource access " - "of this resource for their critical section. Did you mean " - "to call execute with keyword argument " - "exclusive_resource_access=False?" - % (res.name, - self.name, - sg.op.name, - sg.op.inputs[0].op.name)) - - flat_outputs = gen_resource_variable_ops.execute_in_critical_section( - critical_section=self._handle, - arguments=all_inputs, - f=fn_op, - output_types=flat_dtypes, - output_shapes=flat_shapes) - - if context.in_graph_mode(): - if isinstance(flat_outputs, ops.Operation): - flat_outputs = [flat_outputs] - op = (flat_outputs[0].op if isinstance(flat_outputs[0], ops.Tensor) - else flat_outputs[0]) + + # Ensure that mutex locking only happens *after* all args and + # kwargs have been executed. This avoids certain types of deadlocks. + lock = gen_resource_variable_ops.mutex_lock(self._handle) + + if not context.executing_eagerly(): + # NOTE(ebrevdo): This is to ensure we don't pick up spurious + # Operations created by other threads. + with ops.get_default_graph()._lock: # pylint: disable=protected-access + existing_ops = ops.get_default_graph().get_operations() + with ops.control_dependencies([lock]): + r = fn(*args, **kwargs) + # TODO(ebrevdo): If creating critical sections in a python loop, this + # makes graph creation time quadratic. Revisit if this + # becomes a problem. + created_ops = (set(ops.get_default_graph().get_operations()) + .difference(existing_ops)) + else: + with ops.control_dependencies([lock]): + r = fn(*args, **kwargs) + + if not context.executing_eagerly(): + self._add_control_dependencies_to_lock(created_ops, lock.op) + + # captured_resources is a list of resources that are directly + # accessed only by ops created during fn(), not by any + # ancestors of those ops in the graph. + captured_resources = set([ + input_ for op in created_ops + for input_ in op.inputs + if input_.dtype == dtypes.resource + ]) + + # NOTE(ebrevdo): The only time self._is_self_handle() is True + # in this call is if one of the recently created ops, within + # the execute(), themselves attempt to access the + # CriticalSection. This will cause a deadlock. + if any(self._is_self_handle(x) for x in captured_resources): + raise ValueError("The function fn attempts to directly access the " + "CriticalSection in which it would be running. " + "This is illegal and would cause deadlocks.") + + self._check_multiple_access_to_resources( + captured_resources, exclusive_resource_access) + + r_flat = [_identity(x) for x in nest.flatten(r)] + + with ops.control_dependencies(r_flat): + # The identity must run on the same machine as self._handle + with ops.colocate_with(self._handle): + # Do not use array_ops.identity as there are special + # optimizations within TensorFlow which seem to elide it + # even when optimizations are disabled(!). + ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock( + lock) + + # Make sure that if any element of r is accessed, all of + # them are executed together. + r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r))) + + with ops.control_dependencies([ensure_lock_exists]): + outputs = nest.map_structure(_identity, r) + + if not context.executing_eagerly(): signature = _ExecutionSignature( - op=op, + op=lock.op, + handle=self._handle, + resources=list(captured_resources), exclusive_resource_access=exclusive_resource_access) ops.add_to_collections( CRITICAL_SECTION_EXECUTIONS, signature) - return (flat_outputs[0] - if (len(flat_outputs) == 1 - and isinstance(flat_outputs[0], ops.Operation)) - else nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs)) + return outputs + + def _add_control_dependencies_to_lock(self, created_ops, lock_op): + """To avoid deadlocks, all args must be executed before lock_op.""" + # Get all arguments (explicit and captured) of all ops created by fn(). + all_args = set([input_.op for op in created_ops for input_ in op.inputs]) + all_args.update( + input_op for op in created_ops for input_op in op.control_inputs) + # Unfortunately, we can't use sets throughout because TF seems to + # create new Operation objects for the same op sometimes; and we + # can't rely on id(op). + + # pylint: disable=protected-access + all_args_dict = dict((op._id, op) for op in all_args) + + # Remove ops created within fn, or that lock_op already has a + # control dependency on. Also remove a possible self-loop. + for op in created_ops: + all_args_dict.pop(op._id, None) + for op in lock_op.control_inputs: + all_args_dict.pop(op._id, None) + for input_ in lock_op.inputs: + all_args_dict.pop(input_.op._id, None) + all_args_dict.pop(lock_op._id, None) + + all_args = all_args_dict.values() + + if not all_args: + # No control dependencies to add; return early. + return + + # This group is important: it ensures that any ops in all_args + # outside the control context of the lock_op (and this fn, which + # runs in the same context) are added to this context before + # being added to the control dependencies of lock_op. + all_args = control_flow_ops.group(*all_args) + + lock_op._add_control_input(all_args) + # pylint: enable=protected-access + + def _is_self_handle(self, x): + """Check if the tensor `x` is the same Mutex as `self._handle`.""" + return (x.op.type == "MutexV2" + # blank shared_name means the op will create a unique one. + and x.op.get_attr("shared_name") + and (x.op.get_attr("shared_name") == + self._handle.op.get_attr("shared_name")) + and (x.op.device == self._handle.op.device + or _get_colocation(x.op) == _get_colocation(self._handle.op))) + + def _check_multiple_access_to_resources( + self, captured_resources, exclusive_resource_access): + """Raise if captured_resources are accessed by another CriticalSection. + + Args: + captured_resources: Set of tensors of type resource. + exclusive_resource_access: Whether this execution requires exclusive + resource access. + + Raises: + ValueError: If any tensors in `captured_resources` are also accessed + by another `CriticalSection`, and at least one of them requires + exclusive resource access. + """ + # Collections and op introspection does not work in eager + # mode. This is generally ok; since eager mode (as of + # writing) executes sequentially anyway. + for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS): + if self._is_self_handle(sg.handle): + # Other executions in the same critical section are allowed. + continue + if not (exclusive_resource_access or sg.exclusive_resource_access): + # Neither execution requested exclusive access. + continue + resource_intersection = captured_resources.intersection(sg.resources) + if resource_intersection: + raise ValueError( + "This execution would access resources: %s. Either this " + "lock (CriticalSection: %s) or lock '%s' " + "(CriticalSection: %s) requested exclusive resource access " + "of this resource. Did you mean to call execute with keyword " + "argument exclusive_resource_access=False?" % + (list(resource_intersection), self._handle.name, + sg.op.name, sg.handle.name)) # TODO(ebrevdo): Re-enable once CriticalSection is in core. @@ -276,6 +398,7 @@ class CriticalSection(object): # def _execution_to_proto_fn(execution_signature, export_scope=None): # """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`. +# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list. # Args: # execution_signature: Instance of `_ExecutionSignature`. @@ -298,6 +421,7 @@ class CriticalSection(object): # def _execution_from_proto_fn(op_def, import_scope=None): # """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`.""" +# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list. # assert isinstance( # op_def, critical_section_pb2.CriticalSectionExecutionDef) diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py index a416724d3ba1719471d70667e140f9cd2daf86c7..df7d7e9dae80722569efccbc9cc0d1b75e90cf03 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_test.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py @@ -19,14 +19,15 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework.python.ops import critical_section_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging # TODO(ebrevdo): Re-enable once CriticalSection is in core. # from tensorflow.python.training import saver as saver_lib @@ -35,26 +36,82 @@ class CriticalSectionTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testCreateCriticalSection(self): - cs = critical_section_ops.CriticalSection(name="cs") + cs = critical_section_ops.CriticalSection(shared_name="cs") v = resource_variable_ops.ResourceVariable(0.0, name="v") def fn(a, b): - c = v.read_value() + c = v.value() with ops.control_dependencies([c]): nv = v.assign_add(a * b) with ops.control_dependencies([nv]): return array_ops.identity(c) - num_concurrent = 1000 + num_concurrent = 100 r = [cs.execute(fn, 1.0, 2.0) for _ in range(num_concurrent)] self.evaluate(v.initializer) r_value = self.evaluate(r) self.assertAllClose([2.0 * i for i in range(num_concurrent)], sorted(r_value)) + @test_util.run_in_graph_and_eager_modes() + def testCriticalSectionWithControlFlow(self): + for outer_cond in [False, True]: + for inner_cond in [False, True]: + cs = critical_section_ops.CriticalSection(shared_name="cs") + v = resource_variable_ops.ResourceVariable(0.0, name="v") + num_concurrent = 100 + + # pylint: disable=cell-var-from-loop + def fn(a, b): + c = v.read_value() + def true_fn(): + with ops.control_dependencies([c]): + nv = v.assign_add(a * b) + with ops.control_dependencies([nv]): + return array_ops.identity(c) + return control_flow_ops.cond( + array_ops.identity(inner_cond), true_fn, lambda: c) + + def execute(): + return cs.execute(fn, 1.0, 2.0) + + r = [ + control_flow_ops.cond(array_ops.identity(outer_cond), + execute, + v.read_value) + for _ in range(num_concurrent) + ] + # pylint: enable=cell-var-from-loop + + self.evaluate(v.initializer) + r_value = self.evaluate(r) + if inner_cond and outer_cond: + self.assertAllClose([2.0 * i for i in range(num_concurrent)], + sorted(r_value)) + else: + self.assertAllClose([0] * num_concurrent, r_value) + + def testCriticalSectionInParallelDoesntDeadlockOnError(self): + # No eager mode execution of this test because eager does not + # run fn() in parallel, which is where the deadlock could + # potentially occur (in graph mode). + cs = critical_section_ops.CriticalSection(shared_name="cs") + v = resource_variable_ops.ResourceVariable(0.0, name="v") + + def fn(i): + error = control_flow_ops.Assert((i % 2) == 1, ["Error"]) + with ops.control_dependencies([error]): + return v.read_value() + num_concurrent = 2 + r = [cs.execute(fn, i) for i in range(num_concurrent)] + self.evaluate(v.initializer) + for _ in range(100): + with self.assertRaisesOpError("Error"): + self.evaluate(r) + @test_util.run_in_graph_and_eager_modes() def testCreateCriticalSectionFnReturnsOp(self): - cs = critical_section_ops.CriticalSection(name="cs") + cs = critical_section_ops.CriticalSection(shared_name="cs") v = resource_variable_ops.ResourceVariable(0.0, name="v") def fn_return_op(a, b): @@ -62,7 +119,7 @@ class CriticalSectionTest(test.TestCase): with ops.control_dependencies([c]): nv = v.assign_add(a * b) with ops.control_dependencies([nv]): - return () + return control_flow_ops.no_op() num_concurrent = 100 r = [cs.execute(fn_return_op, 1.0, 2.0) for _ in range(num_concurrent)] @@ -71,52 +128,166 @@ class CriticalSectionTest(test.TestCase): final_v = self.evaluate(v) self.assertAllClose(2.0 * num_concurrent, final_v) - def testCreateCriticalSectionRaw(self): - cs = critical_section_ops.CriticalSection(name="cs") - v = resource_variable_ops.ResourceVariable(0.0, name="v") - - @function.Defun(dtypes.float32, dtypes.float32) - def fn(a, b): - c = v.read_value() - with ops.control_dependencies([c]): - nv = v.assign_add(a * b) - with ops.control_dependencies([nv]): - return array_ops.identity(c) - - def execute(fn, *args): - output_args = fn.definition.signature.output_arg - return resource_variable_ops.execute_in_critical_section( - critical_section=cs._handle, - arguments=list(args) + fn.captured_inputs, - f=fn, - output_types=[out.type for out in output_args], - output_shapes=[tensor_shape.TensorShape(None) for _ in output_args]) - - num_concurrent = 1000 - r = [execute(fn, 1.0, 2.0)[0] for _ in range(num_concurrent)] - self.evaluate(v.initializer) - r_value = self.evaluate(r) - self.assertAllClose([2.0 * i for i in range(num_concurrent)], - sorted(r_value)) - def testCollection(self): - cs = critical_section_ops.CriticalSection(name="cs") + cs = critical_section_ops.CriticalSection(shared_name="cs") self.assertIn( cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS)) - execute_op = cs.execute(lambda x: x + 1, 1.0).op + execute = cs.execute(lambda x: x + 1, 1.0, name="my_execute") + execute_op = [ + x for x in execute.graph.get_operations() + if "my_execute" in x.name and "MutexLock" in x.type + ][0] self.assertIn( execute_op, [signature.op for signature in ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)]) - @test_util.run_in_graph_and_eager_modes() def testRecursiveCriticalSectionAccessIsIllegal(self): - cs = critical_section_ops.CriticalSection(name="cs") + # This does not work properly in eager mode. Eager users will + # just hit a deadlock if they do this. But at least it'll be easier + # to debug. + cs = critical_section_ops.CriticalSection() def fn(x): - return cs.execute(lambda x: x+1, x) + return cs.execute(lambda y: y + 1, x) with self.assertRaisesRegexp( ValueError, - r"attempts to access the CriticalSection in which it would be running"): + r"attempts to directly access the CriticalSection in which it " + r"would be running"): + cs.execute(fn, 1.0) + + def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self): + # This one is subtle; and we're being overly cautious here. The + # deadlock we are ensuring we catch is: + # + # to_capture = CS[lambda x: x + 1](1.0) + # deadlocked = CS[lambda x: x + to_capture](1.0) + # + # This would have caused a deadlock because executing `deadlocked` will + # lock the mutex on CS; but then due to dependencies, will attempt + # to compute `to_capture`. This computation requires locking CS, + # but that is not possible now because CS is already locked by + # `deadlocked`. + # + # We check that CriticalSection.execute properly inserts new + # control dependencies to its lock to ensure all captured + # operations are finished before anything runs within the critical section. + cs = critical_section_ops.CriticalSection(shared_name="cs") + fn = array_ops.identity + to_capture = cs.execute(fn, 1.0) + fn_captures = lambda x: x + to_capture + to_capture_too = array_ops.identity(to_capture) + + ex_0 = cs.execute(fn_captures, 1.0) + + with ops.control_dependencies([to_capture]): + # This is OK because to_capture will execute before this next call + ex_1 = cs.execute(fn_captures, 1.0) + + dependency = array_ops.identity(to_capture) + + fn_captures_dependency = lambda x: x + dependency + + ex_2 = cs.execute(fn_captures_dependency, 1.0) + + with ops.control_dependencies([to_capture_too]): + ex_3 = cs.execute(fn_captures_dependency, 1.0) + + # Ensure there's no actual deadlock on to_execute. + self.assertEquals(2.0, self.evaluate(ex_0)) + self.assertEquals(2.0, self.evaluate(ex_1)) + self.assertEquals(2.0, self.evaluate(ex_2)) + self.assertEquals(2.0, self.evaluate(ex_3)) + + def testRecursiveCriticalSectionAccessWithinLoopIsProtected(self): + cs = critical_section_ops.CriticalSection(shared_name="cs") + + def body_implicit_capture(i, j): + # This would have caused a deadlock if not for logic in execute + # that inserts additional control dependencies onto the lock op: + # * Loop body argument j is captured by fn() + # * i is running in parallel to move forward the execution + # * j is not being checked by the predicate function + # * output of cs.execute() is returned as next j. + fn = lambda: j + 1 + return (i + 1, cs.execute(fn)) + + (i_n, j_n) = control_flow_ops.while_loop( + lambda i, _: i < 1000, + body_implicit_capture, + [0, 0], + parallel_iterations=25) + logging.warn( + "\n==============\nRunning " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_implicit_capture'\n" + "==============\n") + self.assertEquals((1000, 1000), self.evaluate((i_n, j_n))) + logging.warn( + "\n==============\nSuccessfully finished running " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_implicit_capture'\n" + "==============\n") + + def body_implicit_capture_protected(i, j): + # This version is ok because we manually add a control + # dependency on j, which is an argument to the while_loop body + # and captured by fn. + fn = lambda: j + 1 + with ops.control_dependencies([j]): + return (i + 1, cs.execute(fn)) + + (i_n, j_n) = control_flow_ops.while_loop( + lambda i, _: i < 1000, + body_implicit_capture_protected, + [0, 0], + parallel_iterations=25) + logging.warn( + "\n==============\nRunning " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_implicit_capture_protected'\n" + "==============\n") + self.assertEquals((1000, 1000), self.evaluate((i_n, j_n))) + logging.warn( + "\n==============\nSuccessfully finished running " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_implicit_capture_protected'\n" + "==============\n") + + def body_args_capture(i, j): + # This version is ok because j is an argument to fn and we can + # ensure there's a control dependency on j. + fn = lambda x: x + 1 + return (i + 1, cs.execute(fn, j)) + + (i_n, j_n) = control_flow_ops.while_loop( + lambda i, _: i < 1000, + body_args_capture, + [0, 0], + parallel_iterations=25) + logging.warn( + "\n==============\nRunning " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_args_capture'\n" + "==============\n") + self.assertEquals((1000, 1000), self.evaluate((i_n, j_n))) + logging.warn( + "\n==============\nSuccessfully finished running " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_args_capture'\n" + "==============\n") + + def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self): + # This does not work properly in eager mode. Eager users will + # just hit a deadlock if they do this. But at least it'll be easier + # to debug. + cs = critical_section_ops.CriticalSection(shared_name="cs") + cs_same = critical_section_ops.CriticalSection(shared_name="cs") + def fn(x): + return cs_same.execute(lambda x: x+1, x) + with self.assertRaisesRegexp( + ValueError, + r"attempts to directly access the CriticalSection in which it " + r"would be running"): cs.execute(fn, 1.0) def testMultipleCSExecutionsRequestSameResource(self): @@ -147,6 +318,39 @@ class CriticalSectionTest(test.TestCase): ValueError, "requested exclusive resource access"): cs1.execute(lambda: v2 + 1) + def testControlDependencyFromOutsideWhileLoopMixedWithInsideLoop(self): + cs = critical_section_ops.CriticalSection() + v = resource_variable_ops.ResourceVariable(0, name="v") + # Make sure that the control dependencies on v do not cause issues + # in the lock_op's automatic control dependency adder. + # + # Note, here v must be a resource variable (or something similar), + # otherwise it gets hoisted into the while_loop by the time we add + # control dependencies to the lock_op. + out = control_flow_ops.while_loop( + lambda i: i < 10, lambda i: cs.execute(lambda j: v + j + 1, i), [0]) + self.evaluate(v.initializer) + self.assertEqual(10, self.evaluate(out)) + + @test_util.run_in_graph_and_eager_modes() + def testInsideFunction(self): + cs = critical_section_ops.CriticalSection() + v = resource_variable_ops.ResourceVariable(1) + def fn(): + return v.read_value() + + # map() creates a TensorFlow function. + ds = dataset_ops.Dataset.range(1).map(lambda _: cs.execute(fn)) + + def get_first(): + if context.executing_eagerly(): + return self.evaluate(ds.make_one_shot_iterator().get_next()) + itr = ds.make_initializable_iterator() + self.evaluate([v.initializer, itr.initializer]) + return self.evaluate(itr.get_next()) + + self.assertEqual(1, get_first()) + # TODO(ebrevdo): Re-enable once CriticalSection is in core. # # def testCriticalSectionAndExecuteOpSaverRoundTrip(self): @@ -167,7 +371,7 @@ class CriticalSectionTest(test.TestCase): # self.assertEqual(restored_exec[0].op.name, "imported/%s" % r.op.name) # def testToProto(self): - # cs = critical_section_ops.CriticalSection(name="cs") + # cs = critical_section_ops.CriticalSection(shared_name="cs") # proto = cs.to_proto() # self.assertEqual(proto.critical_section_name, cs._handle.name) # cs_copy = critical_section_ops.CriticalSection.from_proto(proto) diff --git a/tensorflow/contrib/framework/python/ops/sort_ops.py b/tensorflow/contrib/framework/python/ops/sort_ops.py index 8f62f0ea7b9b561f235b9496ffda97a9f378d530..1921a77c1e96ee3531d1ed0f98e41c27c9d427ac 100644 --- a/tensorflow/contrib/framework/python/ops/sort_ops.py +++ b/tensorflow/contrib/framework/python/ops/sort_ops.py @@ -14,6 +14,7 @@ # ============================================================================== """Support for sorting tensors. +@@argsort @@sort """ @@ -21,6 +22,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops as framework_ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -47,64 +51,141 @@ def sort(values, axis=-1, direction='ASCENDING', name=None): ValueError: If axis is not a constant scalar, or the direction is invalid. """ with framework_ops.name_scope(name, 'sort'): - if direction not in _SORT_IMPL: - raise ValueError('%s should be one of %s' % - (direction, ', '.join(sorted(_SORT_IMPL.keys())))) - # Axis must be an integer, not a Tensor. - axis = framework_ops.convert_to_tensor(axis, name='axis') - axis_static = tensor_util.constant_value(axis) - if axis.shape.ndims != 0 or axis_static is None: - raise ValueError('axis must be a constant scalar') - axis_static = int(axis_static) # Avoids NumPy casting error + return _sort_or_argsort(values, axis, direction, return_argsort=False) + + +def argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None): + """Returns the indices of a tensor that give its sorted order along an axis. + + For a 1D tensor, `tf.gather(values, tf.argsort(values))` is equivalent to + `tf.sort(values)`. For higher dimensions, the output has the same shape as + `values`, but along the given axis, values represent the index of the sorted + element in that slice of the tensor at the given position. + + Args: + values: 1-D or higher numeric `Tensor`. + axis: The axis along which to sort. The default is -1, which sorts the last + axis. + direction: The direction in which to sort the values (`'ASCENDING'` or + `'DESCENDING'`). + stable: If True, equal elements in the original tensor will not be + re-ordered in the returned order. Unstable sort is not yet implemented, + but will eventually be the default for performance reasons. If you + require a stable order, pass `stable=True` for forwards compatibility. + name: Optional name for the operation. + + Returns: + An int32 `Tensor` with the same shape as `values`. The indices that would + sort each slice of the given `values` along the given `axis`. + + Raises: + ValueError: If axis is not a constant scalar, or the direction is invalid. + """ + del stable # Unused. + with framework_ops.name_scope(name, 'argsort'): + return _sort_or_argsort(values, axis, direction, return_argsort=True) + + +def _sort_or_argsort(values, axis, direction, return_argsort): + """Internal sort/argsort implementation. + + Args: + values: The input values. + axis: The axis along which to sort. + direction: 'ASCENDING' or 'DESCENDING'. + return_argsort: Whether to return the argsort result. + + Returns: + Either the sorted values, or the indices of the sorted values in the + original tensor. See the `sort` and `argsort` docstrings. + + Raises: + ValueError: If axis is not a constant scalar, or the direction is invalid. + """ + if direction not in _SORT_IMPL: + raise ValueError('%s should be one of %s' % + (direction, ', '.join(sorted(_SORT_IMPL.keys())))) + # Axis must be an integer, not a Tensor. + axis = framework_ops.convert_to_tensor(axis, name='axis') + axis_static = tensor_util.constant_value(axis) + if axis.shape.ndims != 0 or axis_static is None: + raise ValueError('axis must be a constant scalar') + axis_static = int(axis_static) # Avoids NumPy casting error - values = framework_ops.convert_to_tensor(values, name='values') + values = framework_ops.convert_to_tensor(values, name='values') - return _SORT_IMPL[direction](values, axis_static) + return _SORT_IMPL[direction](values, axis_static, return_argsort) -def _descending_sort(values, axis): +def _descending_sort(values, axis, return_argsort=False): """Sorts values in reverse using `top_k`. Args: values: Tensor of numeric values. axis: Index of the axis which values should be sorted along. + return_argsort: If False, return the sorted values. If True, return the + indices that would sort the values. Returns: The sorted values. """ k = array_ops.shape(values)[axis] rank = array_ops.rank(values) + static_rank = values.shape.ndims # Fast path: sorting the last axis. if axis == -1 or axis + 1 == values.get_shape().ndims: - return nn_ops.top_k(values, k)[0] - - # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`. - if axis < 0: - # Make axis a Tensor with the real axis index if needed. - axis += rank - transposition = array_ops.concat( - [ - # Axes up to axis are unchanged. - math_ops.range(axis), - # Swap axis and rank - 1. - [rank - 1], - # Axes in [axis + 1, rank - 1) are unchanged. - math_ops.range(axis + 1, rank - 1), - # Swap axis and rank - 1. - [axis] - ], - axis=0) - top_k_input = array_ops.transpose(values, transposition) - values, unused_indices = nn_ops.top_k(top_k_input, k) - # transposition contains a single cycle of length 2 (swapping 2 elements), - # so it is an involution (it is its own inverse). - return array_ops.transpose(values, transposition) - - -def _ascending_sort(values, axis): + top_k_input = values + transposition = None + else: + # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`. + if axis < 0: + # Calculate the actual axis index if counting from the end. Use the static + # rank if available, or else make the axis back into a tensor. + axis += static_rank or rank + if static_rank is not None: + # Prefer to calculate the transposition array in NumPy and make it a + # constant. + transposition = constant_op.constant( + np.r_[ + # Axes up to axis are unchanged. + np.arange(axis), + # Swap axis and rank - 1. + [static_rank - 1], + # Axes in [axis + 1, rank - 1) are unchanged. + np.arange(axis + 1, static_rank - 1), + # Swap axis and rank - 1. + [axis]], + name='transposition') + else: + # Generate the transposition array from the tensors. + transposition = array_ops.concat( + [ + # Axes up to axis are unchanged. + math_ops.range(axis), + # Swap axis and rank - 1. + [rank - 1], + # Axes in [axis + 1, rank - 1) are unchanged. + math_ops.range(axis + 1, rank - 1), + # Swap axis and rank - 1. + [axis] + ], + axis=0) + top_k_input = array_ops.transpose(values, transposition) + + values, indices = nn_ops.top_k(top_k_input, k) + return_value = indices if return_argsort else values + if transposition is not None: + # transposition contains a single cycle of length 2 (swapping 2 elements), + # so it is an involution (it is its own inverse). + return_value = array_ops.transpose(return_value, transposition) + return return_value + + +def _ascending_sort(values, axis, return_argsort=False): # Negate the values to get the ascending order from descending sort. - values_or_indices = _descending_sort(-values, axis) - return -values_or_indices + values_or_indices = _descending_sort(-values, axis, return_argsort) + # If not argsort, negate the values again. + return values_or_indices if return_argsort else -values_or_indices _SORT_IMPL = { diff --git a/tensorflow/contrib/framework/python/ops/sort_ops_test.py b/tensorflow/contrib/framework/python/ops/sort_ops_test.py index d08ae502f10d98ee14d8bea2f76b18bedb935cea..a8fb94b245dccc8c7cf0e94cef9b436f881fe408 100644 --- a/tensorflow/contrib/framework/python/ops/sort_ops_test.py +++ b/tensorflow/contrib/framework/python/ops/sort_ops_test.py @@ -24,6 +24,8 @@ from tensorflow.contrib.framework.python.ops import sort_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -90,6 +92,38 @@ class SortTest(test.TestCase): axis=0, direction='DESCENDING').eval()) + def testSort_staticallyKnownRank_constantTransposition(self): + # The transposition array should be a constant if the rank of "values" is + # statically known. + tensor = random_ops.random_uniform( + # Rank is statically known to be 5, but the dimension lengths are not + # known. + random_ops.random_uniform( + shape=(5,), minval=0, maxval=10, dtype=dtypes.int32)) + sort_ops.sort(tensor, axis=1) + transposition = ( + ops.get_default_graph().get_tensor_by_name('sort/transposition:0')) + self.assertFalse(tensor_util.constant_value(transposition) is None) + self.assertAllEqual( + # Swaps "1" and "4" to put "1" at the end. + tensor_util.constant_value(transposition), + [0, 4, 2, 3, 1]) + + def testArgsort_1d(self): + arr = np.random.random(42) + with self.test_session(): + self.assertAllEqual( + np.sort(arr), + array_ops.gather(arr, sort_ops.argsort(arr)).eval()) + + def testArgsort(self): + arr = np.random.random((5, 6, 7, 8)) + for axis in range(4): + with self.test_session(): + self.assertAllEqual( + np.argsort(arr, axis=axis), + sort_ops.argsort(arr, axis=axis).eval()) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 0754c3e0e30a340910a43a3ce86f6ca10afe848e..40ae01bfcce1dde580e6a5f6d9c8ec1aa1abb83f 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import resource_loader from tensorflow.python.platform import tf_logging as logging @@ -82,7 +83,12 @@ def zero_initializer(ref, use_locking=True, name="zero_initializer"): """ loader.load_op_library( resource_loader.get_path_to_datafile("_variable_ops.so")) - return gen_variable_ops.zero_initializer(ref, name=name) + if resource_variable_ops.is_resource_variable(ref): + return gen_variable_ops.zero_var_initializer( + ref.handle, shape=ref.shape, dtype=ref.dtype, name=name) + else: + return gen_variable_ops.zero_initializer(ref, name=name) + @deprecated(None, "Please switch to tf.train.assert_global_step") def assert_global_step(global_step_tensor): diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index 2f06df93acb0a4c0b36c68839ff531e3c22c5ee3..37ea6eb12aba7d25656f19cbbc86475c1228d916 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -1284,6 +1284,32 @@ class ZeroInitializerOpTest(test.TestCase): [10, 20], dtype=dtype), use_init) +class ZeroVarInitializerOpTest(test.TestCase): + + def _testZeroVarInitializer(self, shape, initializer, use_init): + var = resource_variable_ops.ResourceVariable(initializer) + var_zero = variables_lib2.zero_initializer(var) + + with self.test_session() as sess: + with self.assertRaisesOpError('Error while reading resource variable'): + var.eval() + if use_init: + sess.run(var.initializer) + with self.assertRaisesOpError('input is already initialized'): + var_zero.eval() + self.assertAllClose(np.ones(shape), var.eval()) + else: + var_zero.eval() + self.assertAllClose(np.zeros(shape), var.eval()) + + def testZeroVarInitializer(self): + for dtype in (dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64): + for use_init in (False, True): + self._testZeroVarInitializer([10, 20], + array_ops.ones([10, 20], dtype=dtype), + use_init) + + class FilterVariablesTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index ce37672895b37275770d2f5410f662e9acf1de9d..0f0813c07f8bd330b089780064e02f8dfe7d49f6 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -75,6 +75,7 @@ tf_kernel_library( "//tensorflow/core/kernels:gpu_util_hdrs", "//tensorflow/core/kernels:ops_util_hdrs", "//third_party/eigen3", + "@local_config_cuda//cuda:cudnn_header", ], alwayslink = 1, ) @@ -94,6 +95,7 @@ tf_custom_op_library( "//tensorflow/core/kernels:conv_ops_gpu_hdrs", "//tensorflow/core/kernels:gpu_util_hdrs", "//tensorflow/core/kernels:ops_util_hdrs", + "@local_config_cuda//cuda:cudnn_header", ], ) @@ -157,15 +159,3 @@ cuda_py_test( "requires_cudnn6", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 0e06575d96f9b9538f0245b12d48cfd7c0e8d981..2458f7554afdc12709571c551a8323cda7fa5c17 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -247,7 +247,7 @@ class FusedConv2DBiasActivationOp : public OpKernel { }; #if GOOGLE_CUDA -namespace dnn = ::perftools::gputools::dnn; +namespace dnn = se::dnn; // A dummy type to group forward convolution autotune results together. struct ConvBiasActivationAutoTuneGroup { @@ -543,7 +543,8 @@ void LaunchFusedConv2DBiasActivationOp:: fused_conv_parameters, &algorithm_config)) { std::vector algorithms; CHECK(stream->parent()->GetConvolveAlgorithms( - fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo(), + fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo( + stream->parent()), &algorithms)); dnn::ProfileResult best_result; dnn::ProfileResult best_result_no_scratch; diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py index a97adf622e6e576f8b5ce2babe004cb3a46d80a5..983b6dc8e5a1512ba81ecbc8d5ca5adaea09afe4 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py @@ -65,7 +65,7 @@ def fused_conv2d_bias_activation(conv_input, side_input_scale: A scalar `float32` that will be multiplied by side_input. This is optional and defaults to 0. side_input: A `Tensor` of the format specified by `data_format`. - This is useful for imlementing ResNet blocks. + This is useful for implementing ResNet blocks. activation_mode: (optional) currently must be the default "Relu". Note that in qint8 mode, it also clips to 127, so acts like ReluX. data_format: Specifies the data format. diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index bb155aa2496cbafd9f0630d3dffb2ba69395186c..a955e21b72e765f751318c7927f9644481fe7933 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -21,6 +21,8 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -33,6 +35,13 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging +def NoMemoryOptimizationConfig(): + config = config_pb2.ConfigProto() + config.graph_options.rewrite_options.memory_optimization = ( + rewriter_config_pb2.RewriterConfig.OFF) + return config + + def GetShrunkInceptionShapes(shrink=10): """Iterator for smaller versions of convolution shapes in 2015 Inception. @@ -193,7 +202,8 @@ class FusedConv2DBiasActivationTest(test.TestCase): # This is to guarantee that there is always negative values after # bias add so that we can test whether relu works correctly. x3 = bias - with self.test_session(use_gpu=True): + # TODO(b/79323979): re-enable memory optimization after this bug is fixed. + with self.test_session(use_gpu=True, config=NoMemoryOptimizationConfig()): t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype) t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype) fused_t2 = t2 @@ -241,7 +251,9 @@ class FusedConv2DBiasActivationTest(test.TestCase): x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32) def _SetupVal(data_format, use_gpu): - with self.test_session(use_gpu=use_gpu): + # TODO(b/79323979): re-enable memory optimization after this bug is fixed. + with self.test_session( + use_gpu=use_gpu, config=NoMemoryOptimizationConfig()): t1 = constant_op.constant(x1, shape=tensor_in_sizes) t2 = constant_op.constant(x2, shape=filter_in_sizes) t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]]) @@ -289,8 +301,8 @@ class FusedConv2DBiasActivationTest(test.TestCase): conv = tensors[i] value = values[i] ref_value = ref_values[i] - print("expected = ", ref_value) - print("actual = ", value) + tf_logging.info("expected = ", ref_value) + tf_logging.info("actual = ", value) tol = 1e-5 if value.dtype == np.float16: tol = 1e-3 @@ -566,7 +578,7 @@ def GetInceptionFwdTest(input_size, filter_size, stride, padding, return Test -def CalculateCovolvedOutputDim(input_dim, filter_dim, stride, padding_type): +def CalculateConvolvedOutputDim(input_dim, filter_dim, stride, padding_type): """Calculates the size of an output dimension of a strided convolution. Given the sizes of the corresponding dimension of the input and filter shapes, @@ -827,11 +839,12 @@ class FusedConvInt8Tests(test.TestCase): maxval=1.0, dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8) - output_height = CalculateCovolvedOutputDim(input_height, filter_height, - vertical_stride, padding_type) - output_width = CalculateCovolvedOutputDim(input_width, filter_width, - horizontal_stride, padding_type) - print("output_height=", output_height, ", output_width=", output_width) + output_height = CalculateConvolvedOutputDim(input_height, filter_height, + vertical_stride, padding_type) + output_width = CalculateConvolvedOutputDim(input_width, filter_width, + horizontal_stride, padding_type) + tf_logging.info("output_height=", output_height, ", output_width=", + output_width) side_input, _, _ = gen_array_ops.quantize_v2( random_ops.random_uniform( @@ -864,10 +877,12 @@ class FusedConvInt8Tests(test.TestCase): conv_input_scale, conv_input, kernel, padding_type, strides, side_input_scale, side_input, biases) - with self.test_session(use_gpu=True) as sess: + # TODO(b/79323979): re-enable memory optimization after this bug is fixed. + with self.test_session( + use_gpu=True, config=NoMemoryOptimizationConfig()) as sess: actual_y, expected_y = sess.run([actual, expected]) - print("actual_y = ", actual_y) - print("expected_y = ", expected_y) + tf_logging.info("actual_y = ", actual_y) + tf_logging.info("expected_y = ", expected_y) self.assertTrue(np.array_equal(actual_y, expected_y)) def testFusedConvInt8(self): diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 0eb0e3cbe20f5804db5476c08167d4e1c9080cfa..b305f37791d71f5a6edeada2bb710a2e5f23087d 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -354,6 +354,7 @@ py_test( name = "classifier_metrics_test", srcs = ["python/eval/python/classifier_metrics_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":classifier_metrics", "//tensorflow/core:protos_all_py", @@ -363,6 +364,7 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:variables", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -544,15 +546,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 082c42eba180917e732bb7890129dfa94bf00fec..4092b320042162e4eb4c5f4879c2c3ea5dc14fc9 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -88,8 +88,8 @@ class GANEstimator(estimator.Estimator): discriminator_fn=discriminator_fn, generator_loss_fn=tfgan.losses.wasserstein_generator_loss, discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, - generator_optimizer=tf.train.AdamOptimizier(0.1, 0.5), - discriminator_optimizer=tf.train.AdamOptimizier(0.1, 0.5)) + generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5), + discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5)) # Train estimator. gan_estimator.train(train_input_fn, steps) @@ -112,6 +112,7 @@ class GANEstimator(estimator.Estimator): generator_optimizer=None, discriminator_optimizer=None, get_hooks_fn=None, + get_eval_metric_ops_fn=None, add_summaries=None, use_loss_summaries=True, config=None): @@ -146,6 +147,9 @@ class GANEstimator(estimator.Estimator): list of hooks. These hooks are run on the generator and discriminator train ops, and can be used to implement the GAN training scheme. Defaults to `train.get_sequential_train_hooks()`. + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. @@ -160,7 +164,8 @@ class GANEstimator(estimator.Estimator): else discriminator_optimizer) gan_head = head_lib.gan_head( generator_loss_fn, discriminator_loss_fn, gopt, dopt, - use_loss_summaries, get_hooks_fn=get_hooks_fn) + use_loss_summaries, get_hooks_fn=get_hooks_fn, + get_eval_metric_ops_fn=get_eval_metric_ops_fn) return _gan_model_fn( features, labels, mode, generator_fn, discriminator_fn, gan_head, add_summaries) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 387a62bd741bd42c03dc1bf70592060c29ccd7a8..955482599b372be3f0d0cbc81451c514958d0eb1 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -194,6 +195,12 @@ class GANEstimatorIntegrationTest(test.TestCase): lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) return training.GradientDescentOptimizer(lr) + def get_metrics(gan_model): + return { + 'mse_custom_metric': metrics_lib.mean_squared_error( + gan_model.real_data, gan_model.generated_data) + } + gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) est = estimator.GANEstimator( @@ -203,6 +210,7 @@ class GANEstimatorIntegrationTest(test.TestCase): discriminator_loss_fn=losses.wasserstein_discriminator_loss, generator_optimizer=gopt, discriminator_optimizer=dopt, + get_eval_metric_ops_fn=get_metrics, model_dir=self._model_dir) # TRAIN @@ -213,6 +221,9 @@ class GANEstimatorIntegrationTest(test.TestCase): scores = est.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) self.assertIn('loss', six.iterkeys(scores)) + self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], + scores['loss']) + self.assertIn('mse_custom_metric', six.iterkeys(scores)) # PREDICT predictions = np.array([x for x in est.predict(predict_input_fn)]) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index a21358c50bbdb4a1a929b0c5bc322cec4c9923b5..ff903a78cc36c1965b7655aa902501b1943637a8 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -25,17 +25,21 @@ from tensorflow.contrib.gan.python import train as tfgan_train from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.canned import head from tensorflow.python.framework import ops +from tensorflow.python.ops import metrics as metrics_lib __all__ = [ 'GANHead', 'gan_head', ] +def _summary_key(head_name, val): + return '%s/%s' % (val, head_name) if head_name else val + def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, get_hooks_fn=tfgan_train.get_sequential_train_hooks(), - name=None): + get_eval_metric_ops_fn=None, name=None): """Creates a `GANHead`. Args: @@ -47,9 +51,12 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer: Same as `generator_optimizer`, but for the discriminator updates. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. - If `None`, uses defaults. - get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list - of hooks. + If `None`, uses defaults. + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. @@ -62,6 +69,7 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer=discriminator_optimizer, use_loss_summaries=use_loss_summaries, get_hooks_fn=get_hooks_fn, + get_eval_metric_ops_fn=get_eval_metric_ops_fn, name=name) @@ -72,6 +80,7 @@ class GANHead(head._Head): # pylint: disable=protected-access generator_optimizer, discriminator_optimizer, use_loss_summaries=True, get_hooks_fn=None, + get_eval_metric_ops_fn=None, name=None): """`Head` for GAN training. @@ -85,8 +94,11 @@ class GANHead(head._Head): # pylint: disable=protected-access discriminator updates. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. - get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list - of hooks. Defaults to `train.get_sequential_train_hooks()` + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. Defaults to `train.get_sequential_train_hooks()` + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. """ @@ -104,6 +116,8 @@ class GANHead(head._Head): # pylint: disable=protected-access self._generator_optimizer = generator_optimizer self._discriminator_optimizer = discriminator_optimizer self._get_hooks_fn = get_hooks_fn + self._get_eval_metric_ops_fn = get_eval_metric_ops_fn + self._name = name @property def name(self): @@ -173,13 +187,26 @@ class GANHead(head._Head): # pylint: disable=protected-access gan_loss = self.create_loss( features=None, mode=mode, logits=gan_model, labels=None) scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + with ops.name_scope(None, 'metrics', + [gan_loss.generator_loss, + gan_loss.discriminator_loss]): + eval_metric_ops = { + _summary_key(self._name, 'generator_loss'): + metrics_lib.mean(gan_loss.generator_loss), + _summary_key(self._name, 'discriminator_loss'): + metrics_lib.mean(gan_loss.discriminator_loss) + } + if self._get_eval_metric_ops_fn is not None: + custom_eval_metric_ops = self._get_eval_metric_ops_fn(gan_model) + if not isinstance(custom_eval_metric_ops, dict): + raise TypeError('get_eval_metric_ops_fn must return a dict, ' + 'received: {}'.format(custom_eval_metric_ops)) + eval_metric_ops.update(custom_eval_metric_ops) return model_fn_lib.EstimatorSpec( mode=model_fn_lib.ModeKeys.EVAL, predictions=gan_model.generated_data, loss=scalar_loss, - # TODO(joelshor): Add metrics. If head name provided, append it to - # metric keys. - eval_metric_ops={}) + eval_metric_ops=eval_metric_ops) elif mode == model_fn_lib.ModeKeys.TRAIN: if train_op_fn is None: raise ValueError('train_op_fn can not be None.') diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py index 8168f005cd1105886390a2384a936663c83fa5f5..6587f1fc600b94d27f7c12b44ca2136d0be5a8c5 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -62,9 +62,14 @@ class GANHeadTest(test.TestCase): generator_loss_fn=dummy_loss, discriminator_loss_fn=dummy_loss, generator_optimizer=training.GradientDescentOptimizer(1.0), - discriminator_optimizer=training.GradientDescentOptimizer(1.0)) + discriminator_optimizer=training.GradientDescentOptimizer(1.0), + get_eval_metric_ops_fn=self.get_metrics) self.assertTrue(isinstance(self.gan_head, head.GANHead)) + def get_metrics(self, gan_model): + self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel)) + return {} + def _test_modes_helper(self, mode): self.gan_head.create_estimator_spec( features=None, diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index fdfabd07c13f689d075ecbb8786d725fa8a62d01..d914f549457a1e893ed43a3b8bc1ae5be7bb4303 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -44,11 +44,11 @@ from tensorflow.python.ops import functional_ops from tensorflow.python.ops import image_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_impl from tensorflow.python.ops import nn_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import resource_loader - __all__ = [ 'get_graph_def_from_disk', 'get_graph_def_from_resource', @@ -62,10 +62,11 @@ __all__ = [ 'frechet_inception_distance', 'frechet_classifier_distance', 'frechet_classifier_distance_from_activations', + 'mean_only_frechet_classifier_distance_from_activations', + 'diagonal_only_frechet_classifier_distance_from_activations', 'INCEPTION_DEFAULT_IMAGE_SIZE', ] - INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz' INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb' INCEPTION_INPUT = 'Mul:0' @@ -77,8 +78,7 @@ INCEPTION_DEFAULT_IMAGE_SIZE = 299 def _validate_images(images, image_size): images = ops.convert_to_tensor(images) images.shape.with_rank(4) - images.shape.assert_is_compatible_with( - [None, image_size, image_size, None]) + images.shape.assert_is_compatible_with([None, image_size, image_size, None]) return images @@ -109,9 +109,10 @@ def _symmetric_matrix_square_root(mat, eps=1e-10): math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True) -def preprocess_image( - images, height=INCEPTION_DEFAULT_IMAGE_SIZE, - width=INCEPTION_DEFAULT_IMAGE_SIZE, scope=None): +def preprocess_image(images, + height=INCEPTION_DEFAULT_IMAGE_SIZE, + width=INCEPTION_DEFAULT_IMAGE_SIZE, + scope=None): """Prepare a batch of images for evaluation. This is the preprocessing portion of the graph from @@ -272,8 +273,11 @@ def run_inception(images, return activations -def run_image_classifier(tensor, graph_def, input_tensor, - output_tensor, scope='RunClassifier'): +def run_image_classifier(tensor, + graph_def, + input_tensor, + output_tensor, + scope='RunClassifier'): """Runs a network from a frozen graph. Args: @@ -317,7 +321,7 @@ def classifier_score(images, classifier_fn, num_batches=1): NOTE: This function consumes images, computes their logits, and then computes the classifier score. If you would like to precompute many logits for - large batches, use clasifier_score_from_logits(), which this method also + large batches, use classifier_score_from_logits(), which this method also uses. Args: @@ -433,8 +437,8 @@ def trace_sqrt_product(sigma, sigma_v): sqrt_sigma = _symmetric_matrix_square_root(sigma) # This is sqrt(A sigma_v A) above - sqrt_a_sigmav_a = math_ops.matmul( - sqrt_sigma, math_ops.matmul(sigma_v, sqrt_sigma)) + sqrt_a_sigmav_a = math_ops.matmul(sqrt_sigma, + math_ops.matmul(sigma_v, sqrt_sigma)) return math_ops.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) @@ -450,9 +454,9 @@ def frechet_classifier_distance(real_images, This technique is described in detail in https://arxiv.org/abs/1706.08500. Given two Gaussian distribution with means m and m_w and covariance matrices - C and C_w, this function calcuates + C and C_w, this function calculates - |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) + |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) which captures how different the distributions of real images and generated images (or more accurately, their visual features) are. Note that unlike the @@ -463,7 +467,7 @@ def frechet_classifier_distance(real_images, Frechet distance is biased. It is more biased for small sample sizes. (e.g. even if the two distributions are the same, for a small sample size, the expected Frechet distance is large). It is important to use the same - sample size to compute frechet classifier distance when comparing two + sample size to compute Frechet classifier distance when comparing two generative models. NOTE: This function consumes images, computes their activations, and then @@ -484,25 +488,25 @@ def frechet_classifier_distance(real_images, The Frechet Inception distance. A floating-point scalar of the same type as the output of `classifier_fn`. """ - real_images_list = array_ops.split( real_images, num_or_size_splits=num_batches) generated_images_list = array_ops.split( generated_images, num_or_size_splits=num_batches) - imgs = array_ops.stack(real_images_list + generated_images_list) + real_imgs = array_ops.stack(real_images_list) + generated_imgs = array_ops.stack(generated_images_list) # Compute the activations using the memory-efficient `map_fn`. - activations = functional_ops.map_fn( - fn=classifier_fn, - elems=imgs, - parallel_iterations=1, - back_prop=False, - swap_memory=True, - name='RunClassifier') + def compute_activations(elems): + return functional_ops.map_fn(fn=classifier_fn, + elems=elems, + parallel_iterations=1, + back_prop=False, + swap_memory=True, + name='RunClassifier') - # Split the activations by the real and generated images. - real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0) + real_a = compute_activations(real_imgs) + gen_a = compute_activations(generated_imgs) # Ensure the activations have the right shapes. real_a = array_ops.concat(array_ops.unstack(real_a), 0) @@ -511,10 +515,142 @@ def frechet_classifier_distance(real_images, return frechet_classifier_distance_from_activations(real_a, gen_a) -def frechet_classifier_distance_from_activations( +def mean_only_frechet_classifier_distance_from_activations( real_activations, generated_activations): """Classifier distance for evaluating a generative model from activations. + Given two Gaussian distribution with means m and m_w and covariance matrices + C and C_w, this function calcuates + + |m - m_w|^2 + + which captures how different the distributions of real images and generated + images (or more accurately, their visual features) are. Note that unlike the + Inception score, this is a true distance and utilizes information about real + world images. + + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + + In this variant, we only compute the difference between the means of the + fitted Gaussians. The computation leads to O(n) vs. O(n^2) memory usage, yet + still retains much of the same information as FID. + + Args: + real_activations: 2D array of activations of real images of size + [num_images, num_dims] to use to compute Frechet Inception distance. + generated_activations: 2D array of activations of generated images of size + [num_images, num_dims] to use to compute Frechet Inception distance. + + Returns: + The mean-only Frechet Inception distance. A floating-point scalar of the + same type as the output of the activations. + """ + real_activations.shape.assert_has_rank(2) + generated_activations.shape.assert_has_rank(2) + + activations_dtype = real_activations.dtype + if activations_dtype != dtypes.float64: + real_activations = math_ops.to_double(real_activations) + generated_activations = math_ops.to_double(generated_activations) + + # Compute means of activations. + m = math_ops.reduce_mean(real_activations, 0) + m_w = math_ops.reduce_mean(generated_activations, 0) + + # Next the distance between means. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. + mofid = mean + if activations_dtype != dtypes.float64: + mofid = math_ops.cast(mofid, activations_dtype) + + return mofid + + +def diagonal_only_frechet_classifier_distance_from_activations( + real_activations, generated_activations): + """Classifier distance for evaluating a generative model. + + This is based on the Frechet Inception distance, but for an arbitrary + classifier. + + This technique is described in detail in https://arxiv.org/abs/1706.08500. + Given two Gaussian distribution with means m and m_w and covariance matrices + C and C_w, this function calcuates + + |m - m_w|^2 + (sigma + sigma_w - 2(sigma x sigma_w)^(1/2)) + + which captures how different the distributions of real images and generated + images (or more accurately, their visual features) are. Note that unlike the + Inception score, this is a true distance and utilizes information about real + world images. In this variant, we compute diagonal-only covariance matrices. + As a result, instead of computing an expensive matrix square root, we can do + something much simpler, and has O(n) vs O(n^2) space complexity. + + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + + Args: + real_activations: Real images to use to compute Frechet Inception distance. + generated_activations: Generated images to use to compute Frechet Inception + distance. + + Returns: + The diagonal-only Frechet Inception distance. A floating-point scalar of + the same type as the output of the activations. + + Raises: + ValueError: If the shape of the variance and mean vectors are not equal. + """ + real_activations.shape.assert_has_rank(2) + generated_activations.shape.assert_has_rank(2) + + activations_dtype = real_activations.dtype + if activations_dtype != dtypes.float64: + real_activations = math_ops.to_double(real_activations) + generated_activations = math_ops.to_double(generated_activations) + + # Compute mean and covariance matrices of activations. + m, var = nn_impl.moments(real_activations, axes=[0]) + m_w, var_w = nn_impl.moments(generated_activations, axes=[0]) + + actual_shape = var.get_shape() + expected_shape = m.get_shape() + + if actual_shape != expected_shape: + raise ValueError('shape: {} must match expected shape: {}'.format( + actual_shape, expected_shape)) + + # Compute the two components of FID. + + # First the covariance component. + # Here, note that trace(A + B) = trace(A) + trace(B) + trace = math_ops.reduce_sum( + (var + var_w) - 2.0 * math_ops.sqrt(math_ops.multiply(var, var_w))) + + # Next the distance between means. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. + dofid = trace + mean + if activations_dtype != dtypes.float64: + dofid = math_ops.cast(dofid, activations_dtype) + + return dofid + + +def frechet_classifier_distance_from_activations(real_activations, + generated_activations): + """Classifier distance for evaluating a generative model. + This methods computes the Frechet classifier distance from activations of real images and generated images. This can be used independently of the frechet_classifier_distance() method, especially in the case of using large @@ -523,15 +659,22 @@ def frechet_classifier_distance_from_activations( This technique is described in detail in https://arxiv.org/abs/1706.08500. Given two Gaussian distribution with means m and m_w and covariance matrices - C and C_w, this function calcuates + C and C_w, this function calculates - |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) + |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) which captures how different the distributions of real images and generated images (or more accurately, their visual features) are. Note that unlike the Inception score, this is a true distance and utilizes information about real world images. + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + Args: real_activations: 2D Tensor containing activations of real data. Shape is [batch_size, activation_size]. @@ -553,36 +696,40 @@ def frechet_classifier_distance_from_activations( # Compute mean and covariance matrices of activations. m = math_ops.reduce_mean(real_activations, 0) - m_v = math_ops.reduce_mean(generated_activations, 0) - num_examples = math_ops.to_double(array_ops.shape(real_activations)[0]) + m_w = math_ops.reduce_mean(generated_activations, 0) + num_examples_real = math_ops.to_double(array_ops.shape(real_activations)[0]) + num_examples_generated = math_ops.to_double( + array_ops.shape(generated_activations)[0]) # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T real_centered = real_activations - m sigma = math_ops.matmul( - real_centered, real_centered, transpose_a=True) / (num_examples - 1) + real_centered, real_centered, transpose_a=True) / ( + num_examples_real - 1) - gen_centered = generated_activations - m_v - sigma_v = math_ops.matmul( - gen_centered, gen_centered, transpose_a=True) / (num_examples - 1) + gen_centered = generated_activations - m_w + sigma_w = math_ops.matmul( + gen_centered, gen_centered, transpose_a=True) / ( + num_examples_generated - 1) - # Find the Tr(sqrt(sigma sigma_v)) component of FID - sqrt_trace_component = trace_sqrt_product(sigma, sigma_v) + # Find the Tr(sqrt(sigma sigma_w)) component of FID + sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) # Compute the two components of FID. # First the covariance component. # Here, note that trace(A + B) = trace(A) + trace(B) - trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component + trace = math_ops.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component # Next the distance between means. - mean = math_ops.square(linalg_ops.norm(m - m_v)) # This uses the L2 norm. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. fid = trace + mean if activations_dtype != dtypes.float64: fid = math_ops.cast(fid, activations_dtype) return fid - frechet_inception_distance = functools.partial( frechet_classifier_distance, classifier_fn=functools.partial( diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index 61dc8646ddc10605561ae6b19e90f4739c346608..4fb8d58bc9125664d42260de72b83b2362eff9ba 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -22,6 +22,7 @@ import os import tarfile import tempfile +from absl.testing import parameterized import numpy as np from scipy import linalg as scp_linalg @@ -50,6 +51,26 @@ def _expected_inception_score(logits): return np.exp(np.mean(per_example_logincscore)) +def _expected_mean_only_fid(real_imgs, gen_imgs): + m = np.mean(real_imgs, axis=0) + m_v = np.mean(gen_imgs, axis=0) + mean = np.square(m - m_v).sum() + mofid = mean + return mofid + + +def _expected_diagonal_only_fid(real_imgs, gen_imgs): + m = np.mean(real_imgs, axis=0) + m_v = np.mean(gen_imgs, axis=0) + var = np.var(real_imgs, axis=0) + var_v = np.var(gen_imgs, axis=0) + sqcc = np.sqrt(var * var_v) + mean = (np.square(m - m_v)).sum() + trace = (var + var_v - 2 * sqcc).sum() + dofid = mean + trace + return dofid + + def _expected_fid(real_imgs, gen_imgs): m = np.mean(real_imgs, axis=0) m_v = np.mean(gen_imgs, axis=0) @@ -162,13 +183,20 @@ def _run_with_mock(function, *args, **kwargs): return function(*args, **kwargs) -class ClassifierMetricsTest(test.TestCase): +class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): - def test_run_inception_graph(self): + @parameterized.named_parameters( + ('GraphDef', False), + ('DefaultGraphDefFn', True)) + def test_run_inception_graph(self, use_default_graph_def): """Test `run_inception` graph construction.""" batch_size = 7 img = array_ops.ones([batch_size, 299, 299, 3]) - logits = _run_with_mock(classifier_metrics.run_inception, img) + + if use_default_graph_def: + logits = _run_with_mock(classifier_metrics.run_inception, img) + else: + logits = classifier_metrics.run_inception(img, _get_dummy_graphdef()) self.assertTrue(isinstance(logits, ops.Tensor)) logits.shape.assert_is_compatible_with([batch_size, 1001]) @@ -176,14 +204,23 @@ class ClassifierMetricsTest(test.TestCase): # Check that none of the model variables are trainable. self.assertListEqual([], variables.trainable_variables()) - def test_run_inception_graph_pool_output(self): + @parameterized.named_parameters( + ('GraphDef', False), + ('DefaultGraphDefFn', True)) + def test_run_inception_graph_pool_output(self, use_default_graph_def): """Test `run_inception` graph construction with pool output.""" batch_size = 3 img = array_ops.ones([batch_size, 299, 299, 3]) - pool = _run_with_mock( - classifier_metrics.run_inception, - img, - output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) + + if use_default_graph_def: + pool = _run_with_mock( + classifier_metrics.run_inception, + img, + output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) + else: + pool = classifier_metrics.run_inception( + img, _get_dummy_graphdef(), + output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) self.assertTrue(isinstance(pool, ops.Tensor)) pool.shape.assert_is_compatible_with([batch_size, 2048]) @@ -285,6 +322,46 @@ class ClassifierMetricsTest(test.TestCase): self.assertAllClose(_expected_inception_score(logits), incscore_np) + def test_mean_only_frechet_classifier_distance_value(self): + """Test that `frechet_classifier_distance` gives the correct value.""" + np.random.seed(0) + + pool_real_a = np.float32(np.random.randn(256, 2048)) + pool_gen_a = np.float32(np.random.randn(256, 2048)) + + tf_pool_real_a = array_ops.constant(pool_real_a) + tf_pool_gen_a = array_ops.constant(pool_gen_a) + + mofid_op = classifier_metrics.mean_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long + tf_pool_real_a, tf_pool_gen_a) + + with self.test_session() as sess: + actual_mofid = sess.run(mofid_op) + + expected_mofid = _expected_mean_only_fid(pool_real_a, pool_gen_a) + + self.assertAllClose(expected_mofid, actual_mofid, 0.0001) + + def test_diagonal_only_frechet_classifier_distance_value(self): + """Test that `frechet_classifier_distance` gives the correct value.""" + np.random.seed(0) + + pool_real_a = np.float32(np.random.randn(256, 2048)) + pool_gen_a = np.float32(np.random.randn(256, 2048)) + + tf_pool_real_a = array_ops.constant(pool_real_a) + tf_pool_gen_a = array_ops.constant(pool_gen_a) + + dofid_op = classifier_metrics.diagonal_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long + tf_pool_real_a, tf_pool_gen_a) + + with self.test_session() as sess: + actual_dofid = sess.run(dofid_op) + + expected_dofid = _expected_diagonal_only_fid(pool_real_a, pool_gen_a) + + self.assertAllClose(expected_dofid, actual_dofid, 0.0001) + def test_frechet_classifier_distance_value(self): """Test that `frechet_classifier_distance` gives the correct value.""" np.random.seed(0) diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py index 9bebcacbe46d85fc4226c4275b71b3ecbde57a97..4b1105f6bd4f21a0da02338b0fc9db87a41b145f 100644 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py @@ -161,7 +161,7 @@ def _sliced_wasserstein(a, b, random_sampling_count, random_projection_dim): proj = random_ops.random_normal( [array_ops.shape(a)[1], random_projection_dim]) proj *= math_ops.rsqrt( - math_ops.reduce_sum(math_ops.square(proj), 0, keep_dims=True)) + math_ops.reduce_sum(math_ops.square(proj), 0, keepdims=True)) # Project both distributions and sort them. proj_a = math_ops.matmul(a, proj) proj_b = math_ops.matmul(b, proj) @@ -212,7 +212,7 @@ def sliced_wasserstein_distance(real_images, Args: real_images: (tensor) Real images (batch, height, width, channels). fake_images: (tensor) Fake images (batch, height, width, channels). - resolution_min: (int) Minimum resolution for the Laplacion pyramid. + resolution_min: (int) Minimum resolution for the Laplacian pyramid. patches_per_image: (int) Number of patches to extract per image per Laplacian level. patch_size: (int) Width of a square patch. @@ -221,7 +221,7 @@ def sliced_wasserstein_distance(real_images, use_svd: experimental method to compute a more accurate distance. Returns: List of tuples (distance_real, distance_fake) for each level of the - Laplacian pyramid from the highest resoluion to the lowest. + Laplacian pyramid from the highest resolution to the lowest. distance_real is the Wasserstein distance between real images distance_fake is the Wasserstein distance between real and fake images. Raises: diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index 0d1afad72da8a8e087239868e25ddebe23490d1e..508f487722fba89cc8391a340f73673a526e86c4 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -31,6 +31,7 @@ __all__ = [ 'add_image_comparison_summaries', 'add_gan_model_summaries', 'add_regularization_loss_summaries', + 'add_cyclegan_image_summaries', ] @@ -51,14 +52,9 @@ def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True): ValueError: If real and generated data aren't images. """ if isinstance(gan_model, namedtuples.CycleGANModel): - saved_params = locals() - saved_params.pop('gan_model', None) - with ops.name_scope('cyclegan_x2y_image_summaries'): - add_gan_model_image_summaries(gan_model.model_x2y, **saved_params) - with ops.name_scope('cyclegan_y2x_image_summaries'): - add_gan_model_image_summaries(gan_model.model_y2x, **saved_params) - return - + raise ValueError( + '`add_gan_model_image_summaries` does not take CycleGANModels. Please ' + 'use `add_cyclegan_image_summaries` instead.') _assert_is_image(gan_model.real_data) _assert_is_image(gan_model.generated_data) @@ -89,6 +85,49 @@ def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True): add_gan_model_summaries(gan_model) +def add_cyclegan_image_summaries(cyclegan_model): + """Adds image summaries for CycleGAN. + + There are two summaries, one for each generator. The first image is the + generator input, the second is the generator output, and the third is G(F(x)). + + Args: + cyclegan_model: A CycleGANModel tuple. + + Raises: + ValueError: If `cyclegan_model` isn't a CycleGANModel. + ValueError: If generated data, generator inputs, and reconstructions aren't + images. + ValueError: If the generator input, generated data, and reconstructions + aren't all the same size. + """ + if not isinstance(cyclegan_model, namedtuples.CycleGANModel): + raise ValueError('`cyclegan_model` was not a CycleGANModel. Instead, was ' + '%s' % type(cyclegan_model)) + + _assert_is_image(cyclegan_model.model_x2y.generator_inputs) + _assert_is_image(cyclegan_model.model_x2y.generated_data) + _assert_is_image(cyclegan_model.reconstructed_x) + _assert_is_image(cyclegan_model.model_y2x.generator_inputs) + _assert_is_image(cyclegan_model.model_y2x.generated_data) + _assert_is_image(cyclegan_model.reconstructed_y) + + def _add_comparison_summary(gan_model, reconstructions): + image_list = (array_ops.unstack(gan_model.generator_inputs[:1]) + + array_ops.unstack(gan_model.generated_data[:1]) + + array_ops.unstack(reconstructions[:1])) + summary.image( + 'image_comparison', eval_utils.image_reshaper( + image_list, num_cols=len(image_list)), max_outputs=1) + + with ops.name_scope('x2y_image_comparison_summaries'): + _add_comparison_summary( + cyclegan_model.model_x2y, cyclegan_model.reconstructed_x) + with ops.name_scope('y2x_image_comparison_summaries'): + _add_comparison_summary( + cyclegan_model.model_y2x, cyclegan_model.reconstructed_y) + + def add_image_comparison_summaries(gan_model, num_comparisons=2, display_diffs=False): """Adds image summaries to compare triplets of images. @@ -109,15 +148,6 @@ def add_image_comparison_summaries(gan_model, num_comparisons=2, ValueError: If the generator input, real, and generated data aren't all the same size. """ - if isinstance(gan_model, namedtuples.CycleGANModel): - saved_params = locals() - saved_params.pop('gan_model', None) - with ops.name_scope('cyclegan_x2y_image_comparison_summaries'): - add_image_comparison_summaries(gan_model.model_x2y, **saved_params) - with ops.name_scope('cyclegan_y2x_image_comparison_summaries'): - add_image_comparison_summaries(gan_model.model_y2x, **saved_params) - return - _assert_is_image(gan_model.generator_inputs) _assert_is_image(gan_model.generated_data) _assert_is_image(gan_model.real_data) diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py index 45eb108586bed07434ac29595164745eac6054c1..33d51bfc218ab93fb52439b1eefed98a4568c4a1 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py @@ -65,15 +65,14 @@ def get_cyclegan_model(): return namedtuples.CycleGANModel( model_x2y=model_x2y, model_y2x=model_y2x, - reconstructed_x=array_ops.zeros([3, 30, 35, 6]), - reconstructed_y=array_ops.zeros([3, 30, 35, 6])) + reconstructed_x=array_ops.zeros([4, 32, 32, 3]), + reconstructed_y=array_ops.zeros([4, 32, 32, 3])) class SummariesTest(test.TestCase): - def _test_add_gan_model_image_summaries_impl(self, get_model_fn, - expected_num_summary_ops, - model_summaries): + def _test_add_gan_model_image_summaries_impl( + self, get_model_fn, expected_num_summary_ops, model_summaries): summaries.add_gan_model_image_summaries(get_model_fn(), grid_size=2, model_summaries=model_summaries) @@ -89,8 +88,9 @@ class SummariesTest(test.TestCase): def test_add_gan_model_image_summaries_no_model(self): self._test_add_gan_model_image_summaries_impl(get_gan_model, 2, False) - def test_add_gan_model_image_summaries_for_cyclegan(self): - self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10, True) + def test_cyclegan_image_summaries_dont_work(self): + with self.assertRaises(ValueError): + summaries.add_gan_model_image_summaries(get_cyclegan_model()) def _test_add_gan_model_summaries_impl(self, get_model_fn, expected_num_summary_ops): @@ -137,7 +137,11 @@ class SummariesTest(test.TestCase): self._test_add_image_comparison_summaries_impl(get_gan_model, 1) def test_add_image_comparison_summaries_for_cyclegan(self): - self._test_add_image_comparison_summaries_impl(get_cyclegan_model, 2) + summaries.add_cyclegan_image_summaries(get_cyclegan_model()) + + self.assertEquals(2, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + with self.test_session(use_gpu=True): + summary.merge_all().eval() if __name__ == '__main__': diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils.py index df71187fbd98c8ce1372bb89c83656dd666ce677..a9b8faa7126253126a3bc3c30e831b26b8326996 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Miscellanous utilities for TFGAN code and examples.""" +"""Miscellaneous utilities for TFGAN code and examples.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py index cd31c62667fc048b1003d334377405b284f32af5..e2594faf85bcf91cbe09f266e4d4211d20bdee17 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Miscellanous utilities for TFGAN code and examples. +"""Miscellaneous utilities for TFGAN code and examples. Includes: 1) Conditioning the value of a Tensor, based on techniques from diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py index 4cfae0de4451880cf8229903b0eb74b1c6e2e04d..9e4ec59e7098443efc53506a4ba159e84b5c1618 100644 --- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py @@ -17,7 +17,7 @@ We use this to keep a history of values created by a generator, such that a discriminator can randomly be trained on some older samples, not just the current one. This can help to not let the discriminator get too far ahead of the -generator and also to keep the system from oscilating, if the discriminator +generator and also to keep the system from oscillating, if the discriminator forgets too fast what past samples from the generator looked like. See the following papers for more details. @@ -97,7 +97,7 @@ def tensor_pool(input_values, dtypes=[v.dtype for v in input_values], shapes=None) - # In pseudeo code this code does the following: + # In pseudo code this code does the following: # if not pool_full: # enqueue(input_values) # return input_values diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py index f8b372546b60ec8fa5fd1d72b57adaf67596c059..650eab97a3952e9aec2b489fffcc83c3bc49f2dd 100644 --- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py +++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py @@ -64,11 +64,11 @@ def _statistics(x, axes): y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x # Compute true mean while keeping the dims for proper broadcasting. - shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keep_dims=True)) + shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keepdims=True)) - shifted_mean = math_ops.reduce_mean(y - shift, axes, keep_dims=True) + shifted_mean = math_ops.reduce_mean(y - shift, axes, keepdims=True) mean = shifted_mean + shift - mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keep_dims=True) + mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keepdims=True) mean = array_ops.squeeze(mean, axes) mean_squared = array_ops.squeeze(mean_squared, axes) diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py index 845f89827b6e60eda41a55a80671f43460247b05..2fe06a287284ff994326d5a977a2e4d4634268ae 100644 --- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py +++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py @@ -148,7 +148,7 @@ class VirtualBatchnormTest(test.TestCase): self.assertAllClose(bn_np[i, ...], vb_np) def test_minibatch_independent(self): - """Test that virtual batch normalized exampels are independent. + """Test that virtual batch normalized examples are independent. Unlike batch normalization, virtual batch normalization has the property that the virtual batch normalized value of an example is independent of the diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index 39588b7219ebac1cc4855532be3fcc38e6381134..1ba3a641671c7f2a411a0c5f99228ca16eee1080 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -306,6 +306,7 @@ def wasserstein_gradient_penalty( discriminator_scope, epsilon=1e-10, target=1.0, + one_sided=False, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -327,6 +328,8 @@ def wasserstein_gradient_penalty( computing the gradient norm. target: Optional Python number or `Tensor` indicating the target value of gradient norm. Defaults to 1.0. + one_sided: If `True`, penalty proposed in https://arxiv.org/abs/1709.08894 + is used. Defaults to `False`. weights: Optional `Tensor` whose rank is either 0, or the same rank as `real_data` and `generated_data`, and must be broadcastable to them (i.e., all dimensions must be either `1`, or the same as the @@ -377,10 +380,13 @@ def wasserstein_gradient_penalty( # For numerical stability, add epsilon to the sum before taking the square # root. Note tf.norm does not add epsilon. slopes = math_ops.sqrt(gradient_squares + epsilon) - penalties = math_ops.square(slopes / target - 1.0) + penalties = slopes / target - 1.0 + if one_sided: + penalties = math_ops.maximum(0., penalties) + penalties_squared = math_ops.square(penalties) penalty = losses.compute_weighted_loss( - penalties, weights, scope=scope, loss_collection=loss_collection, - reduction=reduction) + penalties_squared, weights, scope=scope, + loss_collection=loss_collection, reduction=reduction) if add_summaries: summary.scalar('gradient_penalty_loss', penalty) @@ -665,7 +671,7 @@ def least_squares_discriminator_loss( loss_collection=ops.GraphKeys.LOSSES, reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, add_summaries=False): - """Least squares generator loss. + """Least squares discriminator loss. This loss comes from `Least Squares Generative Adversarial Networks` (https://arxiv.org/abs/1611.04076). diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index dbaa624ae9d6a5a5949db692e52c0c1deb18b8df..9f5fee45422e0b9bcbc73674e55ae395ea8533d5 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -481,6 +481,28 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): }) self.assertAlmostEqual(self._expected_loss, loss, 5) + def test_loss_using_one_sided_mode(self): + generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) + real_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) + + loss = tfgan_losses.wasserstein_gradient_penalty( + generated_data, + real_data, + self._kwargs['generator_inputs'], + self._kwargs['discriminator_fn'], + self._kwargs['discriminator_scope'], + one_sided=True) + self.assertEqual(generated_data.dtype, loss.dtype) + + with self.test_session() as sess: + variables.global_variables_initializer().run() + loss = sess.run(loss, + feed_dict={ + generated_data: self._generated_data_np, + real_data: self._real_data_np, + }) + self.assertAlmostEqual(self._expected_loss, loss, 5) + def test_loss_with_gradient_norm_target(self): """Test loss value with non default gradient norm target.""" generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) @@ -548,7 +570,7 @@ class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest): 'predicted_distributions': self._predicted_distributions, } self._expected_loss = 1.61610 - self._expected_op_name = 'mutual_information_loss/mul' + self._expected_op_name = 'mutual_information_loss/mul_1' self._batch_size = 2 diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 776eb11ecb1624544d24611d8fe6ca19768b8313..6fa43059f3125daea080f780210223363d0a89f9 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -461,6 +461,7 @@ def gan_loss( gradient_penalty_weight=None, gradient_penalty_epsilon=1e-10, gradient_penalty_target=1.0, + gradient_penalty_one_sided=False, mutual_information_penalty_weight=None, aux_cond_generator_weight=None, aux_cond_discriminator_weight=None, @@ -485,6 +486,8 @@ def gan_loss( gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python number or `Tensor` indicating the target value of gradient norm. See the CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0. + gradient_penalty_one_sided: If `True`, penalty proposed in + https://arxiv.org/abs/1709.08894 is used. Defaults to `False`. mutual_information_penalty_weight: If not `None`, must be a non-negative Python number or Tensor indicating how much to weight the mutual information penalty. See https://arxiv.org/abs/1606.03657 for more @@ -546,6 +549,7 @@ def gan_loss( model, epsilon=gradient_penalty_epsilon, target=gradient_penalty_target, + one_sided=gradient_penalty_one_sided, add_summaries=add_summaries) dis_loss += gradient_penalty_weight * gp_loss if _use_aux_loss(mutual_information_penalty_weight): @@ -706,7 +710,10 @@ def gan_train_ops( be used to train a generator/discriminator pair. """ if isinstance(model, namedtuples.CycleGANModel): - saved_params = locals() + # Get and store all arguments other than model and loss from locals. + # Contents of locals should not be modified, may not affect values. So make + # a copy. https://docs.python.org/2/library/functions.html#locals. + saved_params = dict(locals()) saved_params.pop('model', None) saved_params.pop('loss', None) kwargs = saved_params.pop('kwargs', {}) diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index f9bdaa74c948ecee11d5cfd89f06087924f8dace..3ebbe55d059e5e72607bc4efdbf95a6c96d99f11 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -359,10 +359,12 @@ class GANLossTest(test.TestCase): self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) # Test gradient penalty option. - def _test_grad_penalty_helper(self, create_gan_model_fn): + def _test_grad_penalty_helper(self, create_gan_model_fn, one_sided=False): model = create_gan_model_fn() loss = train.gan_loss(model) - loss_gp = train.gan_loss(model, gradient_penalty_weight=1.0) + loss_gp = train.gan_loss(model, + gradient_penalty_weight=1.0, + gradient_penalty_one_sided=one_sided) self.assertTrue(isinstance(loss_gp, namedtuples.GANLoss)) # Check values. @@ -394,6 +396,25 @@ class GANLossTest(test.TestCase): def test_grad_penalty_callable_acgan(self): self._test_grad_penalty_helper(create_callable_acgan_model) + def test_grad_penalty_one_sided_gan(self): + self._test_grad_penalty_helper(create_gan_model, one_sided=True) + + def test_grad_penalty_one_sided_callable_gan(self): + self._test_grad_penalty_helper(create_callable_gan_model, one_sided=True) + + def test_grad_penalty_one_sided_infogan(self): + self._test_grad_penalty_helper(create_infogan_model, one_sided=True) + + def test_grad_penalty_one_sided_callable_infogan(self): + self._test_grad_penalty_helper( + create_callable_infogan_model, one_sided=True) + + def test_grad_penalty_one_sided_acgan(self): + self._test_grad_penalty_helper(create_acgan_model, one_sided=True) + + def test_grad_penalty_one_sided_callable_acgan(self): + self._test_grad_penalty_helper(create_callable_acgan_model, one_sided=True) + # Test mutual information penalty option. def _test_mutual_info_penalty_helper(self, create_gan_model_fn): train.gan_loss(create_gan_model_fn(), diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index 707ae25d485c64f15694ee0e357f32b619d3cd33..e534fdc17749974ebe713c2730682bea6d7a85e4 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -9,18 +9,6 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - filegroup( name = "c_srcs", data = glob([ diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc index 28f68cec8cce126f1b177a73e197ccd7ab749f4a..94f522c04e5a09ed2d9355fa675125c340407923 100644 --- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc @@ -155,7 +155,7 @@ class GdrRemoteRendezvous : public BaseRemoteRendezvous { } Device* dst_device; - Status s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device); + Status s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device); if (!s.ok()) { sess->worker_cache->ReleaseWorker(src_worker, rwi); done(s, Args(), recv_args, Tensor{}, false); diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index 1f9dd0decb84cf9b7b703f18c061d3c0c7a1cb25..9025c992a4467f521d6d8d514e6a5e92f5492947 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -57,7 +57,7 @@ Status GdrServer::Init() { new GdrWorker(env, remote_memory_manager_.get())); }; TF_RETURN_IF_ERROR( - GrpcServer::Init(nullptr, rendezvous_mgr_func, worker_func)); + GrpcServer::Init(nullptr, rendezvous_mgr_func, nullptr, worker_func)); return remote_memory_manager_->Init(); } diff --git a/tensorflow/contrib/graph_editor/BUILD b/tensorflow/contrib/graph_editor/BUILD index 967ad2fc090906e93f22c777816eede37f9a1b04..1711100e3a857dba0d15c5b4f6c96cddc568e800 100644 --- a/tensorflow/contrib/graph_editor/BUILD +++ b/tensorflow/contrib/graph_editor/BUILD @@ -39,18 +39,6 @@ py_library( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - py_library( name = "match", srcs = ["tests/match.py"], diff --git a/tensorflow/contrib/graph_editor/reroute.py b/tensorflow/contrib/graph_editor/reroute.py index 7ffdbb7139281734917fdb715601b317eb58b82f..95c02a64d47c26e731ef2628fb551529e9bc3f4d 100644 --- a/tensorflow/contrib/graph_editor/reroute.py +++ b/tensorflow/contrib/graph_editor/reroute.py @@ -471,9 +471,10 @@ def remove_control_inputs(op, cops): if cop not in op.control_inputs: raise ValueError("{} is not a control_input of {}".format(op.name, cop.name)) + control_inputs = [cop for cop in op.control_inputs if cop not in cops] # pylint: disable=protected-access - op._control_inputs = [cop for cop in op._control_inputs if cop not in cops] - op._recompute_node_def() + op._remove_all_control_inputs() + op._add_control_inputs(control_inputs) # pylint: enable=protected-access @@ -496,9 +497,6 @@ def add_control_inputs(op, cops): if cop in op.control_inputs: raise ValueError("{} is already a control_input of {}".format(cop.name, op.name)) - # pylint: disable=protected-access - op._control_inputs += cops - op._recompute_node_def() - # pylint: enable=protected-access + op._add_control_inputs(cops) # pylint: disable=protected-access remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/graph_editor/select.py b/tensorflow/contrib/graph_editor/select.py index 3ea6ff4d6163b107ca0daaf3b9ad1daf0ccc1f6f..d700e6e1a7523622f845acbbc353eb0f438c9bc2 100644 --- a/tensorflow/contrib/graph_editor/select.py +++ b/tensorflow/contrib/graph_editor/select.py @@ -383,6 +383,7 @@ def get_within_boundary_ops(ops, def get_forward_walk_ops(seed_ops, inclusive=True, within_ops=None, + within_ops_fn=None, stop_at_ts=(), control_outputs=None): """Do a forward graph walk and return all the visited ops. @@ -395,6 +396,9 @@ def get_forward_walk_ops(seed_ops, within_ops: an iterable of `tf.Operation` within which the search is restricted. If `within_ops` is `None`, the search is performed within the whole graph. + within_ops_fn: if provided, a function on ops that should return True iff + the op is within the graph traversal. This can be used along within_ops, + in which case an op is within if it is also in within_ops. stop_at_ts: an iterable of tensors at which the graph walk stops. control_outputs: a `util.ControlOutputs` instance or None. If not `None`, it will be used while walking the graph forward. @@ -423,7 +427,8 @@ def get_forward_walk_ops(seed_ops, seed_ops &= within_ops def is_within(op): - return within_ops is None or op in within_ops + return (within_ops is None or op in within_ops) and ( + within_ops_fn is None or within_ops_fn(op)) result = list(seed_ops) wave = set(seed_ops) @@ -450,6 +455,7 @@ def get_forward_walk_ops(seed_ops, def get_backward_walk_ops(seed_ops, inclusive=True, within_ops=None, + within_ops_fn=None, stop_at_ts=(), control_inputs=False): """Do a backward graph walk and return all the visited ops. @@ -462,6 +468,9 @@ def get_backward_walk_ops(seed_ops, within_ops: an iterable of `tf.Operation` within which the search is restricted. If `within_ops` is `None`, the search is performed within the whole graph. + within_ops_fn: if provided, a function on ops that should return True iff + the op is within the graph traversal. This can be used along within_ops, + in which case an op is within if it is also in within_ops. stop_at_ts: an iterable of tensors at which the graph walk stops. control_inputs: if True, control inputs will be used while moving backward. Returns: @@ -488,7 +497,8 @@ def get_backward_walk_ops(seed_ops, seed_ops &= within_ops def is_within(op): - return within_ops is None or op in within_ops + return (within_ops is None or op in within_ops) and ( + within_ops_fn is None or within_ops_fn(op)) result = list(seed_ops) wave = set(seed_ops) @@ -516,6 +526,7 @@ def get_walks_intersection_ops(forward_seed_ops, forward_inclusive=True, backward_inclusive=True, within_ops=None, + within_ops_fn=None, control_inputs=False, control_outputs=None, control_ios=None): @@ -535,6 +546,9 @@ def get_walks_intersection_ops(forward_seed_ops, within_ops: an iterable of tf.Operation within which the search is restricted. If within_ops is None, the search is performed within the whole graph. + within_ops_fn: if provided, a function on ops that should return True iff + the op is within the graph traversal. This can be used along within_ops, + in which case an op is within if it is also in within_ops. control_inputs: A boolean indicating whether control inputs are enabled. control_outputs: An instance of util.ControlOutputs or None. If not None, control outputs are enabled. @@ -555,11 +569,13 @@ def get_walks_intersection_ops(forward_seed_ops, forward_seed_ops, inclusive=forward_inclusive, within_ops=within_ops, + within_ops_fn=within_ops_fn, control_outputs=control_outputs) backward_ops = get_backward_walk_ops( backward_seed_ops, inclusive=backward_inclusive, within_ops=within_ops, + within_ops_fn=within_ops_fn, control_inputs=control_inputs) return [op for op in forward_ops if op in backward_ops] @@ -569,6 +585,7 @@ def get_walks_union_ops(forward_seed_ops, forward_inclusive=True, backward_inclusive=True, within_ops=None, + within_ops_fn=None, control_inputs=False, control_outputs=None, control_ios=None): @@ -587,6 +604,9 @@ def get_walks_union_ops(forward_seed_ops, resulting set. within_ops: restrict the search within those operations. If within_ops is None, the search is done within the whole graph. + within_ops_fn: if provided, a function on ops that should return True iff + the op is within the graph traversal. This can be used along within_ops, + in which case an op is within if it is also in within_ops. control_inputs: A boolean indicating whether control inputs are enabled. control_outputs: An instance of util.ControlOutputs or None. If not None, control outputs are enabled. @@ -607,11 +627,13 @@ def get_walks_union_ops(forward_seed_ops, forward_seed_ops, inclusive=forward_inclusive, within_ops=within_ops, + within_ops_fn=within_ops_fn, control_outputs=control_outputs) backward_ops = get_backward_walk_ops( backward_seed_ops, inclusive=backward_inclusive, within_ops=within_ops, + within_ops_fn=within_ops_fn, control_inputs=control_inputs) return util.concatenate_unique(forward_ops, backward_ops) diff --git a/tensorflow/contrib/graph_editor/tests/select_test.py b/tensorflow/contrib/graph_editor/tests/select_test.py index 82f999637d0c1866a5a329974f021fe2e30fd33f..d12c6d3cbd11dde2b609a59154297a8907b0cadc 100644 --- a/tensorflow/contrib/graph_editor/tests/select_test.py +++ b/tensorflow/contrib/graph_editor/tests/select_test.py @@ -77,12 +77,10 @@ class SelectTest(test.TestCase): """Test for ge.get_ops_ios.""" control_outputs = ge.util.ControlOutputs(self.graph) self.assertEqual( - len(ge.get_ops_ios( - self.h.op, control_ios=control_outputs)), 3) + len(ge.get_ops_ios(self.h.op, control_ios=control_outputs)), 3) self.assertEqual(len(ge.get_ops_ios(self.h.op)), 2) self.assertEqual( - len(ge.get_ops_ios( - self.c.op, control_ios=control_outputs)), 6) + len(ge.get_ops_ios(self.c.op, control_ios=control_outputs)), 6) self.assertEqual(len(ge.get_ops_ios(self.c.op)), 5) def test_compute_boundary_ts_0(self): @@ -135,16 +133,49 @@ class SelectTest(test.TestCase): ops = ge.get_walks_intersection_ops([self.c.op], [self.g.op]) self.assertEqual(len(ops), 2) + ops = ge.get_walks_intersection_ops([self.a.op], [self.f.op]) + self.assertEqual(len(ops), 3) + self.assertTrue(self.a.op in ops) + self.assertTrue(self.c.op in ops) + self.assertTrue(self.f.op in ops) + + within_ops = [self.a.op, self.f.op] + ops = ge.get_walks_intersection_ops( + [self.a.op], [self.f.op], within_ops=within_ops) + self.assertEqual(len(ops), 0) + + within_ops_fn = lambda op: op in [self.a.op, self.f.op] + ops = ge.get_walks_intersection_ops( + [self.a.op], [self.f.op], within_ops_fn=within_ops_fn) + self.assertEqual(len(ops), 0) + def test_get_walks_union(self): """Test for ge.get_walks_union_ops.""" ops = ge.get_walks_union_ops([self.f.op], [self.g.op]) self.assertEqual(len(ops), 6) + ops = ge.get_walks_union_ops([self.a.op], [self.f.op]) + self.assertEqual(len(ops), 8) + + within_ops = [self.a.op, self.c.op, self.d.op, self.f.op] + ops = ge.get_walks_union_ops([self.a.op], [self.f.op], + within_ops=within_ops) + self.assertEqual(len(ops), 4) + self.assertTrue(self.b.op not in ops) + + within_ops_fn = lambda op: op in [self.a.op, self.c.op, self.f.op] + ops = ge.get_walks_union_ops([self.a.op], [self.f.op], + within_ops_fn=within_ops_fn) + self.assertEqual(len(ops), 3) + self.assertTrue(self.b.op not in ops) + self.assertTrue(self.d.op not in ops) + def test_select_ops(self): parameters = ( (("^foo/",), 7), (("^foo/bar/",), 4), - (("^foo/bar/", "a"), 5),) + (("^foo/bar/", "a"), 5), + ) for param, length in parameters: ops = ge.select_ops(*param, graph=self.graph) self.assertEqual(len(ops), length) @@ -152,7 +183,8 @@ class SelectTest(test.TestCase): def test_select_ts(self): parameters = ( (".*:0", 8), - (r".*/bar/\w+:0", 4),) + (r".*/bar/\w+:0", 4), + ) for regex, length in parameters: ts = ge.select_ts(regex, graph=self.graph) self.assertEqual(len(ts), length) @@ -160,12 +192,121 @@ class SelectTest(test.TestCase): def test_select_ops_and_ts(self): parameters = ( (("^foo/.*",), 7, 0), - (("^foo/.*", "(?#ts)^foo/bar/.*"), 7, 4),) + (("^foo/.*", "(?#ts)^foo/bar/.*"), 7, 4), + ) for param, l0, l1 in parameters: ops, ts = ge.select_ops_and_ts(*param, graph=self.graph) self.assertEqual(len(ops), l0) self.assertEqual(len(ts), l1) + def test_forward_walk_ops(self): + seed_ops = [self.a.op, self.d.op] + # Include all ops except for self.g.op + within_ops = [ + x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h] + ] + # For the fn, exclude self.e.op. + within_ops_fn = lambda op: op not in (self.e.op,) + stop_at_ts = (self.f,) + + with self.graph.as_default(): + # No b.op since it's an independent source node. + # No g.op from within_ops. + # No e.op from within_ops_fn. + # No h.op from stop_at_ts and within_ops. + ops = ge.select.get_forward_walk_ops( + seed_ops, + inclusive=True, + within_ops=within_ops, + within_ops_fn=within_ops_fn, + stop_at_ts=stop_at_ts) + self.assertEqual( + set(ops), set([self.a.op, self.c.op, self.d.op, self.f.op])) + + # Also no a.op and d.op when inclusive=False + ops = ge.select.get_forward_walk_ops( + seed_ops, + inclusive=False, + within_ops=within_ops, + within_ops_fn=within_ops_fn, + stop_at_ts=stop_at_ts) + self.assertEqual(set(ops), set([self.c.op, self.f.op])) + + # Not using within_ops_fn adds e.op. + ops = ge.select.get_forward_walk_ops( + seed_ops, + inclusive=False, + within_ops=within_ops, + stop_at_ts=stop_at_ts) + self.assertEqual(set(ops), set([self.c.op, self.e.op, self.f.op])) + + # Not using stop_at_ts adds back h.op. + ops = ge.select.get_forward_walk_ops( + seed_ops, inclusive=False, within_ops=within_ops) + self.assertEqual( + set(ops), set([self.c.op, self.e.op, self.f.op, self.h.op])) + + # Starting just form a (the tensor, not op) omits a, b, d. + ops = ge.select.get_forward_walk_ops([self.a], inclusive=True) + self.assertEqual( + set(ops), set([self.c.op, self.e.op, self.f.op, self.g.op, + self.h.op])) + + def test_backward_walk_ops(self): + seed_ops = [self.h.op] + # Include all ops except for self.g.op + within_ops = [ + x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h] + ] + # For the fn, exclude self.c.op. + within_ops_fn = lambda op: op not in (self.c.op,) + stop_at_ts = (self.f,) + + with self.graph.as_default(): + # Backward walk only includes h since we stop at f and g is not within. + ops = ge.select.get_backward_walk_ops( + seed_ops, + inclusive=True, + within_ops=within_ops, + within_ops_fn=within_ops_fn, + stop_at_ts=stop_at_ts) + self.assertEqual(set(ops), set([self.h.op])) + + # If we do inclusive=False, the result is empty. + ops = ge.select.get_backward_walk_ops( + seed_ops, + inclusive=False, + within_ops=within_ops, + within_ops_fn=within_ops_fn, + stop_at_ts=stop_at_ts) + self.assertEqual(set(ops), set()) + + # Removing stop_at_fs adds f.op, d.op. + ops = ge.select.get_backward_walk_ops( + seed_ops, + inclusive=True, + within_ops=within_ops, + within_ops_fn=within_ops_fn) + self.assertEqual(set(ops), set([self.d.op, self.f.op, self.h.op])) + + # Not using within_ops_fn adds back ops for a, b, c. + ops = ge.select.get_backward_walk_ops( + seed_ops, inclusive=True, within_ops=within_ops) + self.assertEqual( + set(ops), + set([ + self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.h.op + ])) + + # Vanially backward search via self.h.op includes everything excpet e.op. + ops = ge.select.get_backward_walk_ops(seed_ops, inclusive=True) + self.assertEqual( + set(ops), + set([ + self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.g.op, + self.h.op + ])) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index ca00394388f67e2ed9508684a47b23c3ee9e79e8..97f38c923f4a19cedf3e16203ca1e66b7e5e45d2 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -18,11 +18,14 @@ from __future__ import division from __future__ import print_function import collections +import functools import numpy as np from tensorflow.contrib import graph_editor as ge from tensorflow.contrib.graph_editor.tests import match +from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.client import session from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -41,6 +44,7 @@ class TransformTest(test.TestCase): self.graph = ops.Graph() with self.graph.as_default(): c0 = constant_op.constant(1.0, shape=[10], name="Const") + c0.op._set_attr("_foo", attr_value_pb2.AttrValue(s=b"foo")) c1 = constant_op.constant(1.0, shape=[10], name="Const") c2 = constant_op.constant(1.0, shape=[10], name="Const") i = constant_op.constant(1.0, shape=[10], name="Input") @@ -84,9 +88,9 @@ class TransformTest(test.TestCase): def test_transform(self): transformer = ge.Transformer() - def my_transform_op_handler(info, op): + def my_transform_op_handler(info, op, new_inputs): add_noise = op.name.startswith("Add") - op_, op_outputs_ = ge.transform.copy_op_handler(info, op) + op_, op_outputs_ = ge.transform.copy_op_handler(info, op, new_inputs) if not add_noise: return op_, op_outputs_ # add some noise to op @@ -111,6 +115,32 @@ class TransformTest(test.TestCase): top = ge.select_ops("^AddNoise_2$", graph=graph)[0] self.assertTrue(matcher2(top)) + def test_transform_nodedef_fn(self): + transformer = ge.Transformer() + + def nodedef_fn(node_def): + if "_foo" in node_def.attr: + del node_def.attr["_foo"] + node_def.attr["_bar"].s = b"bar" + return node_def + + my_copy_op_handler = functools.partial( + ge.transform.copy_op_handler, nodedef_fn=nodedef_fn) + transformer.transform_op_handler = my_copy_op_handler + + graph = ops.Graph() + transformer(self.graph, graph, "", "") + + c0_before = self.graph.get_operation_by_name("Const") + c0_after = graph.get_operation_by_name("Const") + self.assertEquals(c0_before.get_attr("_foo"), b"foo") + with self.assertRaises(ValueError): + c0_after.get_attr("_foo") + + all_ops = graph.get_operations() + for op in all_ops: + self.assertEquals(op.get_attr("_bar"), b"bar") + def test_copy_with_input_replacements(self): with self.graph.as_default(): ten = constant_op.constant(10.0, shape=[10], name="Input") @@ -201,15 +231,56 @@ class TransformTest(test.TestCase): get_operation_by_name("res/grad/mul1_grad/Mul_1")) # Make sure _original_ops are as expected. - self.assertEquals(original_mul1_grad._original_op.name, u"mul1") - self.assertEquals(result_mul1_grad._original_op.name, u"res/mul1") - self.assertNotEquals(res.name, g.name) + self.assertEqual(original_mul1_grad._original_op.name, u"mul1") + self.assertEqual(result_mul1_grad._original_op.name, u"res/mul1") + self.assertNotEqual(res.name, g.name) with session.Session() as sess: sess.run(variables.global_variables_initializer()) g_val, res_val = sess.run([g, res]) self.assertNear(g_val, 0.0, ERROR_TOLERANCE) self.assertNear(res_val, 0.0, ERROR_TOLERANCE) + def test_graph_while_loop(self): + graph = ops.Graph() + with graph.as_default(): + max_index = array_ops.placeholder(dtype=dtypes.int32, shape=tuple()) + index_start = constant_op.constant(1) + sum_start = constant_op.constant(0) + _, result = control_flow_ops.while_loop( + cond=lambda i, unused_s: i <= max_index, + body=lambda i, s: (i + 1, s + i), + loop_vars=[index_start, sum_start]) + copied_graph = ops.Graph() + _, copy_info = ge.copy( + graph, dst_graph=copied_graph, dst_scope="imported") + copied_result = copy_info.transformed(result) + copied_max_index = copy_info.transformed(max_index) + with copied_graph.as_default(): + with session.Session() as sess: + n = 10 + sum_val = sess.run(copied_result, feed_dict={copied_max_index: n}) + self.assertEqual(sum_val, 55) + + def test_graph_cond(self): + graph = ops.Graph() + with graph.as_default(): + choice = array_ops.placeholder(shape=(), dtype=dtypes.bool) + result = control_flow_ops.cond( + choice, + lambda: constant_op.constant(1), + lambda: constant_op.constant(2)) + copied_graph = ops.Graph() + _, copy_info = ge.copy( + graph, dst_graph=copied_graph, dst_scope="imported") + copied_result = copy_info.transformed(result) + copied_choice = copy_info.transformed(choice) + with copied_graph.as_default(): + with session.Session() as sess: + res = sess.run(copied_result, feed_dict={copied_choice: True}) + self.assertEqual(res, 1) + res = sess.run(copied_result, feed_dict={copied_choice: False}) + self.assertEqual(res, 2) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 14ac5296657d48c7f9e94d220c9e7e28af4d4353..026a3d1200033400472c4fd763a244c04b284a9b 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -129,36 +129,51 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True): return None -def copy_op_handler(info, op, copy_shape=True): +def copy_op_handler(info, op, new_inputs, copy_shape=True, nodedef_fn=None): """Copy a `tf.Operation`. Args: info: Transform._TmpInfo instance. op: the `tf.Operation` to be copied. + new_inputs: The new inputs for this op. copy_shape: also copy the shape of the tensor + nodedef_fn: If provided, a function that will be run on the NodeDef + and should return a mutated NodeDef before a new Operation is created. + This is useful as certain features cannot be set on the Operation and + must be modified in NodeDef. + Returns: A `(op, op_outputs)` tuple containing the transformed op and its outputs. """ + # The `new_inputs` was added to this function. For compatibility reason, + # let's raise an error if `new_inputs` is a boolean. + if isinstance(new_inputs, bool): + raise TypeError("the `new_inputs` argument must be an iterable.") + # pylint: disable=protected-access # Clone the node def: - node_def_ = deepcopy(op._node_def) + node_def_ = deepcopy(op.node_def) # Transform name: name_ = info.new_name(op.name) name_ = info.graph_.unique_name(name_) node_def_.name = name_ + # Mutate NodeDef if requested: + if nodedef_fn is not None: + node_def_ = nodedef_fn(node_def_) + # Copy the other inputs needed for initialization output_types_ = op._output_types[:] input_types_ = op._input_types[:] # Make a copy of the op_def too. # Its unique to every _type_ of Operation. - op_def_ = deepcopy(op._op_def) + op_def_ = deepcopy(op.op_def) # Initialize a new Operation instance - op_ = tf_ops.Operation(node_def_, info.graph_, [], output_types_, + op_ = tf_ops.Operation(node_def_, info.graph_, new_inputs, output_types_, [], input_types_, None, op_def_) # copy the shape over @@ -170,12 +185,10 @@ def copy_op_handler(info, op, copy_shape=True): # attribute to exist, we will create a dummy original_op first and then # later finalise it with the actual original_op when all the ops have # been copied. + # TODO(fkp): Stop worrying about _original_op and remove this code? if op._original_op: op_._original_op = op._original_op - # Add op to the graph - info.graph_._add_op(op_) - return op_, op_.outputs @@ -328,6 +341,14 @@ class _TmpInfo(object): for key in self.graph.get_all_collection_keys()) self.cyclic_ops = [] self.transform_original_op_handler = transform_op_if_inside_handler + # The graph is transformed op by op, in the same order the original ops + # were created. However, this is sometimes not possible due to cycles + # (i.e. while loops). So when the transformer creates a new op whose + # inputs do not exist yet, temporary placeholders are created and stored + # in this `tmp_cyclic_ts` container. During a second pass, + # those temporary tensors are replaced by the proper transformed tensors + # (see the function `_finalize_cycles`). + self.tmp_cyclic_ts = [] def new_name(self, name): """Compute a destination name from a source name. @@ -428,10 +449,10 @@ class Transformer(object): # Create temporary info used during this transform call info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope) - info.transform_original_op_handler = self.transform_original_op_handler self._copy_ops(info) - self._connect_ops(info) + self._finalize_cycles(info) + self._connect_control_inputs(info) # Compute information about the transformation res_info = TransformerInfo(info) @@ -440,10 +461,10 @@ class Transformer(object): def _copy_ops(self, info): """Copy ops without connecting them.""" - for op in info.sgv.ops: - logging.debug("Copying op: %s", op.name) - # TODO(fkp): return a subgraph? - op_, op_outputs_ = self.transform_op_handler(info, op) + sorted_ops = sorted(info.sgv.ops, key=lambda op: op._id) # pylint: disable=protected-access + for op in sorted_ops: + new_inputs = [self._transformed_t(info, t, op) for t in op.inputs] + op_, op_outputs_ = self.transform_op_handler(info, op, new_inputs) if op is op_: raise ValueError("In-place transformation not allowed.") @@ -456,27 +477,36 @@ class Transformer(object): info.transformed_ts[op_output] = op_output_ self.assign_collections_handler(info, op_output, op_output_) - def _connect_ops(self, info): + def _finalize_cycles(self, info): + """Reconnects the cyclic tensors.""" + for t, tmp_t_, consumer_op in info.tmp_cyclic_ts: + if t not in info.transformed_ts: + raise ValueError("The tensor {} should be transformed by now.".format( + t.name)) + if consumer_op not in info.transformed_ops: + raise ValueError("The op {} should be transformed by now.".format( + consumer_op.name)) + t_ = info.transformed_ts[t] + consumer_op_ = info.transformed_ops[consumer_op] + t_index_ = list(consumer_op_.inputs).index(tmp_t_) + consumer_op_._update_input(t_index_, t_) # pylint: disable=protected-access + + def _connect_control_inputs(self, info): """Connect the previously copied ops.""" for op in info.sgv.ops: - logging.debug("Finalizing op: %s", op.name) + logging.debug("Connecting control inputs of op: %s", op.name) op_ = info.transformed_ops[op] - # pylint: disable=protected-access - if op_.inputs: - raise ValueError("The newly transformed op should not have " - "any inputs yet: {}".format(op_.name)) - inputs_ = [self._transformed_t(info, t) for t in op.inputs] - for t in inputs_: - op_._add_input(t) - # Finalize original op. + # TODO(fkp): Stop worrying about _original_op and remove this code? + # pylint: disable=protected-access if op._original_op: - original_op = info.transform_original_op_handler(info, op._original_op) + original_op = self.transform_original_op_handler(info, op._original_op) if original_op is None: logging.debug("Could not find original op for: %s", op_.name) else: op_._original_op = original_op + # pylint: enable=protected-access # Finalize control inputs: control_inputs_ = [self.transform_control_input_handler(info, ci) @@ -525,19 +555,38 @@ class Transformer(object): return sgv_.remap(input_map_, output_map_) - def _transformed_t(self, info, t): + def _transformed_t(self, info, t, consumer_op): """Return tre transformed tensor of `t`.""" - if t not in info.transformed_ts: - # If op is not in the subgraph. - if t in info.sgv_inputs_set: - # t is an input of the subgraph. - return self.transform_external_input_handler(info, t) + if t in info.transformed_ts: + # If op is in the subgraph, just return its transformed counterpart. + return info.transformed_ts[t] + + if t in info.sgv_inputs_set: + # `t` is an input of the subgraph. + return self.transform_external_input_handler(info, t) + elif t.op in info.ops: + # `t` is an internal tensor but is not transformed yet because it + # belongs to a graph cycle. + logging.debug("Cyclic tensor: t.name = %s", t.name) + # Try to find an existing tensor we can use for now, + # otherwise create one. We'll rewire this later. + if consumer_op.type == "Merge": + first_input = consumer_op.inputs[0] + tmp_t_ = self._transformed_t(info, first_input, consumer_op) + elif t.op.type == "Enter": + enter_input = t.op.inputs[0] + tmp_t_ = self._transformed_t(info, enter_input, consumer_op) else: - # t is a hidden input of the subgraph. - return self.transform_external_hidden_input_handler(info, t) + with info.graph_.as_default(): + tmp_t_ = util.make_placeholder_from_tensor(t, scope=info.scope_, + prefix="geph_tmp") + logging.debug("Created temporary placeholder: %s.", tmp_t_.name) + # Register as temporary and return. + info.tmp_cyclic_ts.append((t, tmp_t_, consumer_op)) + return tmp_t_ else: - # If op is in the subgraph, just return its transformed. - return info.transformed_ts[t] + # `t` is a hidden input of the subgraph. + return self.transform_external_hidden_input_handler(info, t) def copy(sgv, dst_graph=None, dst_scope="", src_scope="", @@ -624,6 +673,40 @@ def copy_with_input_replacements(sgv, replacement_ts, sgv, dst_graph, dst_scope, src_scope, reuse_dst_scope=reuse_dst_scope) +def _add_control_flow_ops(ops, control_ios): + """Complete `ops` so that the transformed graph is valid. + + Partially copying a graph can lead to a malformed graph. For instance, + copying half of a while construct is likely to result in an invalid graph. + This function attempts to add missing ops so that the transformation result + in a valid graph. + + Args: + ops: list of ops (modifed in-place). + control_ios: object created by a call to `util.ControlOutputs`. + """ + # Find while contexts. + control_flow_contexts = set() + for op in ops: + cfc = op._control_flow_context # pylint: disable=protected-access + if cfc: + control_flow_contexts.add(cfc) + # Find new ops. + new_ops = [] + for cfc in control_flow_contexts: + if cfc.IsWhileContext(): + new_ops += select.get_walks_intersection_ops( + [enter_t.op for enter_t in cfc.loop_enters], + [exit_t.op for exit_t in cfc.loop_exits], + control_ios=control_ios) + # Add new ops. + new_ops_set = set(new_ops) + ops_set = frozenset(ops) + for op in new_ops_set: + if op not in ops_set: + ops.append(op) + + def graph_replace(target_ts, replacement_ts, dst_scope="", src_scope="", reuse_dst_scope=False): """Create a new graph which compute the targets from the replaced Tensors. @@ -657,8 +740,13 @@ def graph_replace(target_ts, replacement_ts, dst_scope="", control_ios=control_ios) if not ops: raise ValueError("Targets and replacements are not connected!") + + # Complete ops to avoid malformed control flow. + # TODO(fkp): Consider moving this function deeper (in the transformer?). + _add_control_flow_ops(ops, control_ios) + # Create a copy of the relevant subgraph - _, info = copy_with_input_replacements( + unused_sgv_, info = copy_with_input_replacements( ops, replacement_ts, None, dst_scope, src_scope, reuse_dst_scope) # Return the transformed targets but keep the original if the transformed # counterpart cannot be found diff --git a/tensorflow/contrib/graph_editor/util.py b/tensorflow/contrib/graph_editor/util.py index 30bc33b9ee42ba78bc7307c67c0fc0af9f3356ef..584f4509ccc0aab30edc2be3bad7a9cb938d6e6a 100644 --- a/tensorflow/contrib/graph_editor/util.py +++ b/tensorflow/contrib/graph_editor/util.py @@ -38,6 +38,11 @@ __all__ = [ ] +# The graph editor sometimes need to create placeholders, they are named +# "geph_*". "geph" stands for Graph-Editor PlaceHolder. +_DEFAULT_PLACEHOLDER_PREFIX = "geph" + + def concatenate_unique(la, lb): """Add all the elements of `lb` to `la` if they are not there already. @@ -405,7 +410,7 @@ def scope_basename(scope): return scope[slash + 1:] -def placeholder_name(t=None, scope=None): +def placeholder_name(t=None, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX): """Create placeholder name for the graph editor. Args: @@ -413,6 +418,7 @@ def placeholder_name(t=None, scope=None): on scope: absolute scope with which to prefix the placeholder's name. None means that the scope of t is preserved. "" means the root scope. + prefix: placeholder name prefix. Returns: A new placeholder name prefixed by "geph". Note that "geph" stands for Graph Editor PlaceHolder. This convention allows to quickly identify the @@ -430,19 +436,20 @@ def placeholder_name(t=None, scope=None): if scope is None: scope = op_dirname - if op_basename.startswith("geph__"): + if op_basename.startswith("{}__".format(prefix)): ph_name = op_basename else: - ph_name = "geph__{}_{}".format(op_basename, t.value_index) + ph_name = "{}__{}_{}".format(prefix, op_basename, t.value_index) return scope + ph_name else: if scope is None: scope = "" - return scope + "geph" + return "{}{}".format(scope, prefix) -def make_placeholder_from_tensor(t, scope=None): +def make_placeholder_from_tensor(t, scope=None, + prefix=_DEFAULT_PLACEHOLDER_PREFIX): """Create a `tf.placeholder` for the Graph Editor. Note that the correct graph scope must be set by the calling function. @@ -452,17 +459,19 @@ def make_placeholder_from_tensor(t, scope=None): (see function placeholder_name). scope: absolute scope within which to create the placeholder. None means that the scope of `t` is preserved. `""` means the root scope. + prefix: placeholder name prefix. Returns: A newly created `tf.placeholder`. Raises: TypeError: if `t` is not `None` or a `tf.Tensor`. """ return tf_array_ops.placeholder( - dtype=t.dtype, shape=t.get_shape(), name=placeholder_name( - t, scope=scope)) + dtype=t.dtype, shape=t.get_shape(), + name=placeholder_name(t, scope=scope, prefix=prefix)) -def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None): +def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None, + prefix=_DEFAULT_PLACEHOLDER_PREFIX): """Create a tf.placeholder for the Graph Editor. Note that the correct graph scope must be set by the calling function. @@ -474,11 +483,13 @@ def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None): shape: the tensor shape (optional). scope: absolute scope within which to create the placeholder. None means that the scope of t is preserved. "" means the root scope. + prefix: placeholder name prefix. Returns: A newly created tf.placeholder. """ return tf_array_ops.placeholder( - dtype=dtype, shape=shape, name=placeholder_name(scope=scope)) + dtype=dtype, shape=shape, + name=placeholder_name(scope=scope, prefix=prefix)) _INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$") diff --git a/tensorflow/contrib/grid_rnn/BUILD b/tensorflow/contrib/grid_rnn/BUILD index d601a1ec6f7a219bcd461d819ab2dfc64135a3ae..d0b44640667010b58c017d933d50ae5f87e8b275 100644 --- a/tensorflow/contrib/grid_rnn/BUILD +++ b/tensorflow/contrib/grid_rnn/BUILD @@ -41,15 +41,3 @@ cuda_py_tests( "//tensorflow/python:variables", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py index 252788140f8c1906718c150574b963385b6ecfa1..bcd2a34c4e791a2ab66a439109145d6b78c14e22 100644 --- a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py +++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py @@ -110,7 +110,7 @@ class GridRNNCell(rnn.RNNCell): logging.warning('%s: Using a concatenated state is slower and will ' 'soon be deprecated. Use state_is_tuple=True.', self) if not output_is_tuple: - logging.warning('%s: Using a concatenated output is slower and will' + logging.warning('%s: Using a concatenated output is slower and will ' 'soon be deprecated. Use output_is_tuple=True.', self) if num_dims < 1: diff --git a/tensorflow/contrib/hooks/BUILD b/tensorflow/contrib/hooks/BUILD index 1b528d7afc1112f5dc0667ae299ade02bc8fd04b..d65b2d6026dd89959aa62b57e07b073eef84572c 100644 --- a/tensorflow/contrib/hooks/BUILD +++ b/tensorflow/contrib/hooks/BUILD @@ -23,14 +23,3 @@ py_library( "//tensorflow/python:util", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/hvx/README.md b/tensorflow/contrib/hvx/README.md index 163993a3f6bb1bedcdffb32944a98c7cc846878e..68e34f3b0938f795c8ad4c8c75226f6b0afe188d 100644 --- a/tensorflow/contrib/hvx/README.md +++ b/tensorflow/contrib/hvx/README.md @@ -42,11 +42,12 @@ If you've finished walking through the quick start guide, you may want to try bu ### Build libhexagon\_nn\_skel.so -Download Hexagon NN library from codeaurora.org and build it. +Download Hexagon NN library from codeaurora.org and build it. For Hexagon SDK 3.0, we need use the compatible version([721b2d58f](https://source.codeaurora.org/quic/hexagon_nn/nnlib/commit/?id=721b2d58f0f4e2d5b182f41e6b7c4db5356bf0fb)) of nnlib. ```shell git clone https://source.codeaurora.org/quic/hexagon_nn/nnlib cd nnlib +git reset 721b2d58f --hard ``` Just follow the instructions in `README.HOW_TO_BUILD`. You can find the file `libhexagon_nn_skel.so` in `hexagon_Release_dynamic_toolv72_v60/ship`. diff --git a/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD b/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD index 324035100df366b80f57af9052c4bd935655b248..e39c60b252a1b49a68d51302fff47734869dddfe 100644 --- a/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD +++ b/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD @@ -13,18 +13,6 @@ exports_files(["LICENSE"]) package(default_visibility = ["//visibility:public"]) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - tf_cc_binary( name = "clock_cycle_profiling", testonly = 1, diff --git a/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c b/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c index 6a5d982dc8514d69277b8f042ac1256e28715d9e..2e5c84704f8464ab46d740ea3c1eef0548826e8d 100644 --- a/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c +++ b/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c @@ -19,7 +19,7 @@ limitations under the License. #include "hexagon_controller.h" -#include +#include #include #include "adspmsgd.h" diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD index 909dc396a33b6fef1b2d51c3f52fab7782fc8ea5..0081fb61770075a2c36e92f65e01126f657edeb4 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD @@ -10,17 +10,6 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) - tf_cc_binary( name = "hvx_ops_support_checker", testonly = 1, diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc index 60281951dda94008cad3a164be67d6fe8b59a916..66939fbb0f0d3bb5d2181e38428c038f661d3772 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc @@ -115,7 +115,7 @@ static void CheckOpsSupport(const GraphDef& graph_def, HexagonOpsDefinitions::getInstance(); LOG(INFO) << "Checking " << graph_def.node_size() << " nodes"; LOG(INFO) << "dump_all_nodes = " << dump_all_nodes - << ", dump_shape_and_tpye = " << dump_shape_and_type; + << ", dump_shape_and_type = " << dump_shape_and_type; std::unordered_set unsupported_ops; bool all_supported = true; diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index 3ff02e085ee63fabf42b3cc4389f4605455f3800..da450480b30b548484e69c61c85667d6dd390417 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -78,7 +78,10 @@ tf_custom_op_py_library( ], srcs_version = "PY2AND3", deps = [ + ":dense_image_warp_py", ":image_ops", + ":interpolate_spline_py", + ":sparse_image_warp_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:common_shapes", @@ -194,6 +197,117 @@ cuda_py_test( ], ) +py_library( + name = "dense_image_warp_py", + srcs = [ + "python/ops/dense_image_warp.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//third_party/py/numpy", + ], +) + +py_library( + name = "interpolate_spline_py", + srcs = [ + "python/ops/interpolate_spline.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], +) + +py_library( + name = "sparse_image_warp_py", + srcs = [ + "python/ops/sparse_image_warp.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":dense_image_warp_py", + ":interpolate_spline_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], +) + +cuda_py_test( + name = "sparse_image_warp_test", + size = "medium", + srcs = ["python/kernel_tests/sparse_image_warp_test.py"], + additional_deps = [ + ":sparse_image_warp_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:clip_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:image_ops", + "//tensorflow/python:variables", + "//tensorflow/core:protos_all_py", + ], + data = [":sparse_image_warp_test_data"], + tags = ["no_pip"], +) + +filegroup( + name = "sparse_image_warp_test_data", + srcs = glob(["python/kernel_tests/test_data/*.png"]), +) + +cuda_py_test( + name = "dense_image_warp_test", + size = "medium", + srcs = ["python/kernel_tests/dense_image_warp_test.py"], + additional_deps = [ + ":dense_image_warp_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:clip_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:image_ops", + "//tensorflow/python:variables", + "//tensorflow/core:protos_all_py", + ], +) + +cuda_py_test( + name = "interpolate_spline_test", + size = "medium", + srcs = ["python/kernel_tests/interpolate_spline_test.py"], + additional_deps = [ + ":interpolate_spline_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:clip_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:image_ops", + "//tensorflow/python:variables", + "//tensorflow/core:protos_all_py", + ], +) + tf_py_test( name = "segmentation_test", size = "medium", @@ -270,15 +384,3 @@ cuda_py_test( "//tensorflow/python:platform_test", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py index cc8ed117ba2edcc7a53e609381166f17a2fbb45e..f230d93da4a9c01e8dee47aa258d9c28499469f1 100755 --- a/tensorflow/contrib/image/__init__.py +++ b/tensorflow/contrib/image/__init__.py @@ -17,7 +17,7 @@ ### API This module provides functions for image manipulation; currently, chrominance -transformas (including changing saturation and hue) in YIQ space and +transforms (including changing saturation and hue) in YIQ space and projective transforms (including rotation) are supported. ## Image Transformation `Ops` @@ -25,11 +25,16 @@ projective transforms (including rotation) are supported. @@angles_to_projective_transforms @@compose_transforms @@adjust_yiq_hsv +@@flat_transforms_to_matrices +@@matrices_to_flat_transforms @@random_yiq_hsv @@rotate @@transform @@translate @@translations_to_projective_transforms +@@dense_image_warp +@@interpolate_spline +@@sparse_image_warp ## Image Segmentation `Ops` @@ -47,17 +52,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.image.python.ops.dense_image_warp import dense_image_warp + from tensorflow.contrib.image.python.ops.distort_image_ops import adjust_hsv_in_yiq from tensorflow.contrib.image.python.ops.distort_image_ops import random_hsv_in_yiq from tensorflow.contrib.image.python.ops.image_ops import angles_to_projective_transforms from tensorflow.contrib.image.python.ops.image_ops import compose_transforms from tensorflow.contrib.image.python.ops.image_ops import connected_components +from tensorflow.contrib.image.python.ops.image_ops import flat_transforms_to_matrices +from tensorflow.contrib.image.python.ops.image_ops import matrices_to_flat_transforms from tensorflow.contrib.image.python.ops.image_ops import rotate from tensorflow.contrib.image.python.ops.image_ops import transform from tensorflow.contrib.image.python.ops.image_ops import translate from tensorflow.contrib.image.python.ops.image_ops import translations_to_projective_transforms +from tensorflow.contrib.image.python.ops.interpolate_spline import interpolate_spline from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms import single_image_random_dot_stereograms +from tensorflow.contrib.image.python.ops.sparse_image_warp import sparse_image_warp from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc index b71ff9cd507faac66b3a33d3c02ec9b5901d814a..bbb3a3b18fd7bfdc68e8b8532568985245154794 100644 --- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc @@ -53,13 +53,13 @@ void AdjustHsvInYiqGPU::operator()(OpKernelContext* ctx, int channel_count, OP_REQUIRES_OK(ctx, ctx->allocate_temp( DT_FLOAT, TensorShape({kChannelSize * kChannelSize}), &tranformation_matrix)); - // TODO(huangyp): It takes about 3.5 us to comute tranformation_matrix + // TODO(huangyp): It takes about 3.5 us to compute tranformation_matrix // with one thread. Improve its performance if necessary. internal::compute_tranformation_matrix_cuda<<<1, 1, 0, cu_stream>>>( delta_h, scale_s, scale_v, tranformation_matrix.flat().data(), tranformation_matrix.flat().size()); // Call cuBlas C = A * B directly. - auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; + auto no_transpose = se::blas::Transpose::kNoTranspose; auto a_ptr = AsDeviceMemory(input->flat().data(), input->flat().size()); auto b_ptr = AsDeviceMemory(tranformation_matrix.flat().data(), diff --git a/tensorflow/contrib/image/kernels/segmentation_ops.cc b/tensorflow/contrib/image/kernels/segmentation_ops.cc index fe8bf6e21c7b7310527668324571774e8bc50893..93722896233f0278c6cbb44af7203345e58c3172 100644 --- a/tensorflow/contrib/image/kernels/segmentation_ops.cc +++ b/tensorflow/contrib/image/kernels/segmentation_ops.cc @@ -101,8 +101,8 @@ struct ImageConnectedComponentsFunctor { int cost = (union_find.block_height() + union_find.block_width()) * 20; Shard(worker_threads->num_threads, worker_threads->workers, num_images * num_blocks_vertically * num_blocks_horizontally, cost, - [&union_find, num_images, num_blocks_vertically, - num_blocks_horizontally](int64 start_block, int64 limit_block) { + [&union_find, num_blocks_vertically, num_blocks_horizontally]( + int64 start_block, int64 limit_block) { for (int64 i = start_block; i < limit_block; i++) { int64 block_x = i % num_blocks_horizontally; int64 block_y = diff --git a/tensorflow/contrib/image/ops/distort_image_ops.cc b/tensorflow/contrib/image/ops/distort_image_ops.cc index b169b0b2b22ad6432baed2cc96711da5ca995875..ca49635d5d0bc7bb84b19508a74be74362d96ddf 100644 --- a/tensorflow/contrib/image/ops/distort_image_ops.cc +++ b/tensorflow/contrib/image/ops/distort_image_ops.cc @@ -36,9 +36,9 @@ REGISTER_OP("AdjustHsvInYiq") Adjust the YIQ hue of one or more images. `images` is a tensor of at least 3 dimensions. The last dimension is -interpretted as channels, and must be three. +interpreted as channels, and must be three. -We used linear transfomation described in: +We used linear transformation described in: beesbuzz.biz/code/hsv_color_transforms.php The input image is considered in the RGB colorspace. Conceptually, the RGB colors are first mapped into YIQ space, rotated around the Y channel by diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc index 68771b3d054a64ba94141c092e20df1ed6b2339b..ebdcaea7abae2a967786831b62b331897aa3f6a3 100644 --- a/tensorflow/contrib/image/ops/image_ops.cc +++ b/tensorflow/contrib/image/ops/image_ops.cc @@ -93,7 +93,7 @@ row_to_col_match_indices: A vector of length num_rows, which is the number of If `row_to_col_match_indices[i]` is not -1, row i is matched to column `row_to_col_match_indices[i]`. col_to_row_match_indices: A vector of length num_columns, which is the number - of columns of the input ditance matrix. + of columns of the input distance matrix. If `col_to_row_match_indices[j]` is not -1, column j is matched to row `col_to_row_match_indices[j]`. )doc"); diff --git a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc index 8139d4272d6950815bd39a64e86e0f7422e6f799..bd784c6bda0344c092c1ae0af2c60be50fdff102 100755 --- a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc +++ b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc @@ -69,7 +69,7 @@ Outputs a single image random dot stereogram for export via encode_PNG/JPG OP. Given the 2-D tensor 'depth_values' with encoded Z values, this operation will encode 3-D data into a 2-D image. The output of this Op is suitable for the encode_PNG/JPG ops. Be careful with image compression as this may corrupt the -encode 3-D data witin the image. +encode 3-D data within the image. This Op is based upon: 'http://www.learningace.com/doc/4331582/b6ab058d1e206d68ab60e4e1ead2fe6e/sirds-paper' @@ -111,7 +111,7 @@ output_image_shape: Output size of returned image in X,Y, Channels 1-grayscale, output_data_window: Size of "DATA" window, must be equal to or smaller than 'output_image_shape', will be centered and use 'convergence_dots_size' for best fit to avoid overlap if possible -image:= A tensor of size 'output_image_shape' with the encloded 'depth_values' +image:= A tensor of size 'output_image_shape' with the encoded 'depth_values' )doc"); } // namespace tensorflow diff --git a/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a58b6a247ed6ae252db25a12f1e47c08c9a5c147 --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py @@ -0,0 +1,267 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for dense_image_warp.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np + +from tensorflow.contrib.image.python.ops import dense_image_warp + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes + +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + +from tensorflow.python.training import adam + + +class DenseImageWarpTest(test_util.TensorFlowTestCase): + + def setUp(self): + np.random.seed(0) + + def test_interpolate_small_grid_ij(self): + grid = constant_op.constant( + [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], shape=[1, 3, 3, 1]) + query_points = constant_op.constant( + [[0., 0.], [1., 0.], [2., 0.5], [1.5, 1.5]], shape=[1, 4, 2]) + expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1]) + + interp = dense_image_warp._interpolate_bilinear(grid, query_points) + + with self.test_session() as sess: + predicted = sess.run(interp) + self.assertAllClose(expected_results, predicted) + + def test_interpolate_small_grid_xy(self): + grid = constant_op.constant( + [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], shape=[1, 3, 3, 1]) + query_points = constant_op.constant( + [[0., 0.], [0., 1.], [0.5, 2.0], [1.5, 1.5]], shape=[1, 4, 2]) + expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1]) + + interp = dense_image_warp._interpolate_bilinear( + grid, query_points, indexing='xy') + + with self.test_session() as sess: + predicted = sess.run(interp) + self.assertAllClose(expected_results, predicted) + + def test_interpolate_small_grid_batched(self): + grid = constant_op.constant( + [[[0., 1.], [3., 4.]], [[5., 6.], [7., 8.]]], shape=[2, 2, 2, 1]) + query_points = constant_op.constant([[[0., 0.], [1., 0.], [0.5, 0.5]], + [[0.5, 0.], [1., 0.], [1., 1.]]]) + expected_results = np.reshape( + np.array([[0., 3., 2.], [6., 7., 8.]]), [2, 3, 1]) + + interp = dense_image_warp._interpolate_bilinear(grid, query_points) + + with self.test_session() as sess: + predicted = sess.run(interp) + self.assertAllClose(expected_results, predicted) + + def get_image_and_flow_placeholders(self, shape, image_type, flow_type): + batch_size, height, width, numchannels = shape + image_shape = [batch_size, height, width, numchannels] + flow_shape = [batch_size, height, width, 2] + + tf_type = { + 'float16': dtypes.half, + 'float32': dtypes.float32, + 'float64': dtypes.float64 + } + + image = array_ops.placeholder(dtype=tf_type[image_type], shape=image_shape) + + flows = array_ops.placeholder(dtype=tf_type[flow_type], shape=flow_shape) + return image, flows + + def get_random_image_and_flows(self, shape, image_type, flow_type): + batch_size, height, width, numchannels = shape + image_shape = [batch_size, height, width, numchannels] + image = np.random.normal(size=image_shape) + flow_shape = [batch_size, height, width, 2] + flows = np.random.normal(size=flow_shape) * 3 + return image.astype(image_type), flows.astype(flow_type) + + def assert_correct_interpolation_value(self, + image, + flows, + pred_interpolation, + batch_index, + y_index, + x_index, + low_precision=False): + """Assert that the tf interpolation matches hand-computed value.""" + + height = image.shape[1] + width = image.shape[2] + displacement = flows[batch_index, y_index, x_index, :] + float_y = y_index - displacement[0] + float_x = x_index - displacement[1] + floor_y = max(min(height - 2, math.floor(float_y)), 0) + floor_x = max(min(width - 2, math.floor(float_x)), 0) + ceil_y = floor_y + 1 + ceil_x = floor_x + 1 + + alpha_y = min(max(0.0, float_y - floor_y), 1.0) + alpha_x = min(max(0.0, float_x - floor_x), 1.0) + + floor_y = int(floor_y) + floor_x = int(floor_x) + ceil_y = int(ceil_y) + ceil_x = int(ceil_x) + + top_left = image[batch_index, floor_y, floor_x, :] + top_right = image[batch_index, floor_y, ceil_x, :] + bottom_left = image[batch_index, ceil_y, floor_x, :] + bottom_right = image[batch_index, ceil_y, ceil_x, :] + + interp_top = alpha_x * (top_right - top_left) + top_left + interp_bottom = alpha_x * (bottom_right - bottom_left) + bottom_left + interp = alpha_y * (interp_bottom - interp_top) + interp_top + atol = 1e-6 + rtol = 1e-6 + if low_precision: + atol = 1e-2 + rtol = 1e-3 + self.assertAllClose( + interp, + pred_interpolation[batch_index, y_index, x_index, :], + atol=atol, + rtol=rtol) + + def check_zero_flow_correctness(self, shape, image_type, flow_type): + """Assert using zero flows doesn't change the input image.""" + + image, flows = self.get_image_and_flow_placeholders(shape, image_type, + flow_type) + interp = dense_image_warp.dense_image_warp(image, flows) + + with self.test_session() as sess: + rand_image, rand_flows = self.get_random_image_and_flows( + shape, image_type, flow_type) + rand_flows *= 0 + + predicted_interpolation = sess.run( + interp, feed_dict={ + image: rand_image, + flows: rand_flows + }) + self.assertAllClose(rand_image, predicted_interpolation) + + def test_zero_flows(self): + """Apply check_zero_flow_correctness() for a few sizes and types.""" + + shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]] + for shape in shapes_to_try: + self.check_zero_flow_correctness( + shape, image_type='float32', flow_type='float32') + + def check_interpolation_correctness(self, + shape, + image_type, + flow_type, + num_probes=5): + """Interpolate, and then assert correctness for a few query locations.""" + + image, flows = self.get_image_and_flow_placeholders(shape, image_type, + flow_type) + interp = dense_image_warp.dense_image_warp(image, flows) + low_precision = image_type == 'float16' or flow_type == 'float16' + with self.test_session() as sess: + rand_image, rand_flows = self.get_random_image_and_flows( + shape, image_type, flow_type) + + pred_interpolation = sess.run( + interp, feed_dict={ + image: rand_image, + flows: rand_flows + }) + + for _ in range(num_probes): + batch_index = np.random.randint(0, shape[0]) + y_index = np.random.randint(0, shape[1]) + x_index = np.random.randint(0, shape[2]) + + self.assert_correct_interpolation_value( + rand_image, + rand_flows, + pred_interpolation, + batch_index, + y_index, + x_index, + low_precision=low_precision) + + def test_interpolation(self): + """Apply check_interpolation_correctness() for a few sizes and types.""" + + shapes_to_try = [[3, 4, 5, 6], [1, 5, 5, 3], [1, 2, 2, 1]] + for im_type in ['float32', 'float64', 'float16']: + for flow_type in ['float32', 'float64', 'float16']: + for shape in shapes_to_try: + self.check_interpolation_correctness(shape, im_type, flow_type) + + def test_gradients_exist(self): + """Check that backprop can run. + + The correctness of the gradients is assumed, since the forward propagation + is tested to be correct and we only use built-in tf ops. + However, we perform a simple test to make sure that backprop can actually + run. We treat the flows as a tf.Variable and optimize them to minimize + the difference between the interpolated image and the input image. + """ + + batch_size, height, width, numchannels = [4, 5, 6, 7] + image_shape = [batch_size, height, width, numchannels] + image = random_ops.random_normal(image_shape) + flow_shape = [batch_size, height, width, 2] + init_flows = np.float32(np.random.normal(size=flow_shape) * 0.25) + flows = variables.Variable(init_flows) + + interp = dense_image_warp.dense_image_warp(image, flows) + loss = math_ops.reduce_mean(math_ops.square(interp - image)) + + optimizer = adam.AdamOptimizer(1.0) + grad = gradients.gradients(loss, [flows]) + opt_func = optimizer.apply_gradients(zip(grad, [flows])) + init_op = variables.global_variables_initializer() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(10): + sess.run(opt_func) + + def test_size_exception(self): + """Make sure it throws an exception for images that are too small.""" + + shape = [1, 2, 1, 1] + msg = 'Should have raised an exception for invalid image size' + with self.assertRaises(ValueError, msg=msg): + self.check_interpolation_correctness(shape, 'float32', 'float32') + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1939caaa2d8586413cf9ecba6ce73cf64910d6fc --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py @@ -0,0 +1,264 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for interpolate_spline.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import interpolate as sc_interpolate + +from tensorflow.contrib.image.python.ops import interpolate_spline + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util + +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + +from tensorflow.python.training import momentum + + +class _InterpolationProblem(object): + """Abstract class for interpolation problem descriptions.""" + + def get_problem(self, optimizable=False, extrapolate=True, dtype='float32'): + """Make data for an interpolation problem where all x vectors are n-d. + + Args: + optimizable: If True, then make train_points a tf.Variable. + extrapolate: If False, then clamp the query_points values to be within + the max and min of train_points. + dtype: The data type to use. + + Returns: + query_points, query_values, train_points, train_values: training and + test tensors for interpolation problem + """ + + # The values generated here depend on a seed of 0. + np.random.seed(0) + + batch_size = 1 + num_training_points = 10 + num_query_points = 4 + + init_points = np.random.uniform( + size=[batch_size, num_training_points, self.DATA_DIM]) + + init_points = init_points.astype(dtype) + train_points = ( + variables.Variable(init_points) + if optimizable else constant_op.constant(init_points)) + train_values = self.tf_function(train_points) + + query_points_np = np.random.uniform( + size=[batch_size, num_query_points, self.DATA_DIM]) + query_points_np = query_points_np.astype(dtype) + if not extrapolate: + query_points_np = np.clip(query_points_np, np.min(init_points), + np.max(init_points)) + + query_points = constant_op.constant(query_points_np) + query_values = self.np_function(query_points_np) + + return query_points, query_values, train_points, train_values + + +class _QuadraticPlusSinProblem1D(_InterpolationProblem): + """1D interpolation problem used for regression testing.""" + DATA_DIM = 1 + HARDCODED_QUERY_VALUES = { + (1.0, 0.0): [6.2647187603, -7.84362604077, -5.63690142322, 1.42928896387], + (1.0, + 0.01): [6.77688289946, -8.02163669853, -5.79491157027, 1.4063285693], + (2.0, + 0.0): [8.67110264937, -8.41281390883, -5.80190044693, 1.50155606059], + (2.0, + 0.01): [6.70797816797, -7.49709587663, -5.28965776238, 1.52284731741], + (3.0, + 0.0): [9.37691802935, -8.50390141515, -5.80786417426, 1.63467762122], + (3.0, + 0.01): [4.47106304758, -5.71266128361, -3.92529303296, 1.86755293857], + (4.0, + 0.0): [9.58172461111, -8.51432104771, -5.80967675388, 1.63361164256], + (4.0, 0.01): [ + -3.87902711352, -0.0253462273846, 1.79857618022, -0.769339675725 + ] + } + + def np_function(self, x): + """Takes np array, evaluates the test function, and returns np array.""" + return np.sum( + np.power((x - 0.5), 3) - 0.25 * x + 10 * np.sin(x * 10), + axis=2, + keepdims=True) + + def tf_function(self, x): + """Takes tf tensor, evaluates the test function, and returns tf tensor.""" + return math_ops.reduce_mean( + math_ops.pow((x - 0.5), 3) - 0.25 * x + 10 * math_ops.sin(x * 10), + 2, + keepdims=True) + + +class _QuadraticPlusSinProblemND(_InterpolationProblem): + """3D interpolation problem used for regression testing.""" + + DATA_DIM = 3 + HARDCODED_QUERY_VALUES = { + (1.0, 0.0): [1.06609663962, 1.28894849357, 1.10882405595, 1.63966936885], + (1.0, 0.01): [1.03123780748, 1.2952930985, 1.10366822954, 1.65265118569], + (2.0, 0.0): [0.627787735064, 1.43802857251, 1.00194632358, 1.91667538215], + (2.0, 0.01): [0.730159985046, 1.41702471595, 1.0065827217, 1.85758519312], + (3.0, 0.0): [0.350460417862, 1.67223539464, 1.00475331246, 2.31580322491], + (3.0, + 0.01): [0.624557250556, 1.63138876667, 0.976588193162, 2.12511237866], + (4.0, + 0.0): [0.898129669986, 1.24434133638, -0.938056116931, 1.59910338833], + (4.0, + 0.01): [0.0930360338179, -3.38791305538, -1.00969032567, 0.745535080382], + } + + def np_function(self, x): + """Takes np array, evaluates the test function, and returns np array.""" + return np.sum( + np.square(x - 0.5) + 0.25 * x + 1 * np.sin(x * 15), + axis=2, + keepdims=True) + + def tf_function(self, x): + """Takes tf tensor, evaluates the test function, and returns tf tensor.""" + return math_ops.reduce_sum( + math_ops.square(x - 0.5) + 0.25 * x + 1 * math_ops.sin(x * 15), + 2, + keepdims=True) + + +class InterpolateSplineTest(test_util.TensorFlowTestCase): + + def test_1d_linear_interpolation(self): + """For 1d linear interpolation, we can compare directly to scipy.""" + + tp = _QuadraticPlusSinProblem1D() + (query_points, _, train_points, train_values) = tp.get_problem( + extrapolate=False, dtype='float64') + interpolation_order = 1 + + with ops.name_scope('interpolator'): + interpolator = interpolate_spline.interpolate_spline( + train_points, train_values, query_points, interpolation_order) + with self.test_session() as sess: + fetches = [query_points, train_points, train_values, interpolator] + query_points_, train_points_, train_values_, interp_ = sess.run(fetches) + + # Just look at the first element of the minibatch. + # Also, trim the final singleton dimension. + interp_ = interp_[0, :, 0] + query_points_ = query_points_[0, :, 0] + train_points_ = train_points_[0, :, 0] + train_values_ = train_values_[0, :, 0] + + # Compute scipy interpolation. + scipy_interp_function = sc_interpolate.interp1d( + train_points_, train_values_, kind='linear') + + scipy_interpolation = scipy_interp_function(query_points_) + scipy_interpolation_on_train = scipy_interp_function(train_points_) + + # Even with float64 precision, the interpolants disagree with scipy a + # bit due to the fact that we add the EPSILON to prevent sqrt(0), etc. + tol = 1e-3 + + self.assertAllClose( + train_values_, scipy_interpolation_on_train, atol=tol, rtol=tol) + self.assertAllClose(interp_, scipy_interpolation, atol=tol, rtol=tol) + + def test_1d_interpolation(self): + """Regression test for interpolation with 1-D points.""" + + tp = _QuadraticPlusSinProblem1D() + (query_points, _, train_points, + train_values) = tp.get_problem(dtype='float64') + + for order in (1, 2, 3): + for reg_weight in (0, 0.01): + interpolator = interpolate_spline.interpolate_spline( + train_points, train_values, query_points, order, reg_weight) + + target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)] + target_interpolation = np.array(target_interpolation) + with self.test_session() as sess: + interp_val = sess.run(interpolator) + self.assertAllClose(interp_val[0, :, 0], target_interpolation) + + def test_nd_linear_interpolation(self): + """Regression test for interpolation with N-D points.""" + + tp = _QuadraticPlusSinProblemND() + (query_points, _, train_points, + train_values) = tp.get_problem(dtype='float64') + + for order in (1, 2, 3): + for reg_weight in (0, 0.01): + interpolator = interpolate_spline.interpolate_spline( + train_points, train_values, query_points, order, reg_weight) + + target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)] + target_interpolation = np.array(target_interpolation) + with self.test_session() as sess: + interp_val = sess.run(interpolator) + self.assertAllClose(interp_val[0, :, 0], target_interpolation) + + def test_interpolation_gradient(self): + """Make sure that backprop can run. Correctness of gradients is assumed. + + Here, we create a use a small 'training' set and a more densely-sampled + set of query points, for which we know the true value in advance. The goal + is to choose x locations for the training data such that interpolating using + this training data yields the best reconstruction for the function + values at the query points. The training data locations are optimized + iteratively using gradient descent. + """ + tp = _QuadraticPlusSinProblemND() + (query_points, query_values, train_points, + train_values) = tp.get_problem(optimizable=True) + + regularization = 0.001 + for interpolation_order in (1, 2, 3, 4): + interpolator = interpolate_spline.interpolate_spline( + train_points, train_values, query_points, interpolation_order, + regularization) + + loss = math_ops.reduce_mean(math_ops.square(query_values - interpolator)) + + optimizer = momentum.MomentumOptimizer(0.001, 0.9) + grad = gradients.gradients(loss, [train_points]) + grad, _ = clip_ops.clip_by_global_norm(grad, 1.0) + opt_func = optimizer.apply_gradients(zip(grad, [train_points])) + init_op = variables.global_variables_initializer() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(100): + sess.run([loss, opt_func]) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py b/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0135c66e293693345c3da7fdb21e28ca6d160154 --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py @@ -0,0 +1,254 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sparse_image_warp.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.image.python.ops import sparse_image_warp + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import image_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test + +from tensorflow.python.training import momentum + + +class SparseImageWarpTest(test_util.TensorFlowTestCase): + + def setUp(self): + np.random.seed(0) + + def testGetBoundaryLocations(self): + image_height = 11 + image_width = 11 + num_points_per_edge = 4 + locs = sparse_image_warp._get_boundary_locations(image_height, image_width, + num_points_per_edge) + num_points = locs.shape[0] + self.assertEqual(num_points, 4 + 4 * num_points_per_edge) + locs = [(locs[i, 0], locs[i, 1]) for i in range(num_points)] + for i in (0, image_height - 1): + for j in (0, image_width - 1): + self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j)) + + for i in (2, 4, 6, 8): + for j in (0, image_width - 1): + self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j)) + + for i in (0, image_height - 1): + for j in (2, 4, 6, 8): + self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j)) + + def testGetGridLocations(self): + image_height = 5 + image_width = 3 + grid = sparse_image_warp._get_grid_locations(image_height, image_width) + for i in range(image_height): + for j in range(image_width): + self.assertEqual(grid[i, j, 0], i) + self.assertEqual(grid[i, j, 1], j) + + def testZeroShift(self): + """Run assertZeroShift for various hyperparameters.""" + for order in (1, 2): + for regularization in (0, 0.01): + for num_boundary_points in (0, 1): + self.assertZeroShift(order, regularization, num_boundary_points) + + def assertZeroShift(self, order, regularization, num_boundary_points): + """Check that warping with zero displacements doesn't change the image.""" + batch_size = 1 + image_height = 4 + image_width = 4 + channels = 3 + + image = np.random.uniform( + size=[batch_size, image_height, image_width, channels]) + + input_image_op = constant_op.constant(np.float32(image)) + + control_point_locations = [[1., 1.], [2., 2.], [2., 1.]] + control_point_locations = constant_op.constant( + np.float32(np.expand_dims(control_point_locations, 0))) + + control_point_displacements = np.zeros( + control_point_locations.shape.as_list()) + control_point_displacements = constant_op.constant( + np.float32(control_point_displacements)) + + (warped_image_op, flow_field) = sparse_image_warp.sparse_image_warp( + input_image_op, + control_point_locations, + control_point_locations + control_point_displacements, + interpolation_order=order, + regularization_weight=regularization, + num_boundary_points=num_boundary_points) + + with self.test_session() as sess: + warped_image, input_image, _ = sess.run( + [warped_image_op, input_image_op, flow_field]) + + self.assertAllClose(warped_image, input_image) + + def testMoveSinglePixel(self): + """Run assertMoveSinglePixel for various hyperparameters and data types.""" + for order in (1, 2): + for num_boundary_points in (1, 2): + for type_to_use in (dtypes.float32, dtypes.float64): + self.assertMoveSinglePixel(order, num_boundary_points, type_to_use) + + def assertMoveSinglePixel(self, order, num_boundary_points, type_to_use): + """Move a single block in a small grid using warping.""" + batch_size = 1 + image_height = 7 + image_width = 7 + channels = 3 + + image = np.zeros([batch_size, image_height, image_width, channels]) + image[:, 3, 3, :] = 1.0 + input_image_op = constant_op.constant(image, dtype=type_to_use) + + # Place a control point at the one white pixel. + control_point_locations = [[3., 3.]] + control_point_locations = constant_op.constant( + np.float32(np.expand_dims(control_point_locations, 0)), + dtype=type_to_use) + # Shift it one pixel to the right. + control_point_displacements = [[0., 1.0]] + control_point_displacements = constant_op.constant( + np.float32(np.expand_dims(control_point_displacements, 0)), + dtype=type_to_use) + + (warped_image_op, flow_field) = sparse_image_warp.sparse_image_warp( + input_image_op, + control_point_locations, + control_point_locations + control_point_displacements, + interpolation_order=order, + num_boundary_points=num_boundary_points) + + with self.test_session() as sess: + warped_image, input_image, flow = sess.run( + [warped_image_op, input_image_op, flow_field]) + # Check that it moved the pixel correctly. + self.assertAllClose( + warped_image[0, 4, 5, :], + input_image[0, 4, 4, :], + atol=1e-5, + rtol=1e-5) + + # Test that there is no flow at the corners. + for i in (0, image_height - 1): + for j in (0, image_width - 1): + self.assertAllClose( + flow[0, i, j, :], np.zeros([2]), atol=1e-5, rtol=1e-5) + + def load_image(self, image_file, sess): + image_op = image_ops.decode_png( + io_ops.read_file(image_file), dtype=dtypes.uint8, channels=4)[:, :, 0:3] + return sess.run(image_op) + + def testSmileyFace(self): + """Check warping accuracy by comparing to hardcoded warped images.""" + + test_data_dir = test.test_src_dir_path('contrib/image/python/' + 'kernel_tests/test_data/') + input_file = test_data_dir + 'Yellow_Smiley_Face.png' + with self.test_session() as sess: + input_image = self.load_image(input_file, sess) + control_points = np.asarray([[64, 59], [180 - 64, 59], [39, 111], + [180 - 39, 111], [90, 143], [58, 134], + [180 - 58, 134]]) # pyformat: disable + control_point_displacements = np.asarray( + [[-10.5, 10.5], [10.5, 10.5], [0, 0], [0, 0], [0, -10], [-20, 10.25], + [10, 10.75]]) + control_points_op = constant_op.constant( + np.expand_dims(np.float32(control_points[:, [1, 0]]), 0)) + control_point_displacements_op = constant_op.constant( + np.expand_dims(np.float32(control_point_displacements[:, [1, 0]]), 0)) + float_image = np.expand_dims(np.float32(input_image) / 255, 0) + input_image_op = constant_op.constant(float_image) + + for interpolation_order in (1, 2, 3): + for num_boundary_points in (0, 1, 4): + warp_op, _ = sparse_image_warp.sparse_image_warp( + input_image_op, + control_points_op, + control_points_op + control_point_displacements_op, + interpolation_order=interpolation_order, + num_boundary_points=num_boundary_points) + with self.test_session() as sess: + warped_image = sess.run(warp_op) + out_image = np.uint8(warped_image[0, :, :, :] * 255) + target_file = ( + test_data_dir + + 'Yellow_Smiley_Face_Warp-interp' + '-{}-clamp-{}.png'.format( + interpolation_order, num_boundary_points)) + + target_image = self.load_image(target_file, sess) + + # Check that the target_image and out_image difference is no + # bigger than 2 (on a scale of 0-255). Due to differences in + # floating point computation on different devices, the float + # output in warped_image may get rounded to a different int + # than that in the saved png file loaded into target_image. + self.assertAllClose(target_image, out_image, atol=2, rtol=1e-3) + + def testThatBackpropRuns(self): + """Run optimization to ensure that gradients can be computed.""" + + batch_size = 1 + image_height = 9 + image_width = 12 + image = variables.Variable( + np.float32( + np.random.uniform(size=[batch_size, image_height, image_width, 3]))) + control_point_locations = [[3., 3.]] + control_point_locations = constant_op.constant( + np.float32(np.expand_dims(control_point_locations, 0))) + control_point_displacements = [[0.25, -0.5]] + control_point_displacements = constant_op.constant( + np.float32(np.expand_dims(control_point_displacements, 0))) + warped_image, _ = sparse_image_warp.sparse_image_warp( + image, + control_point_locations, + control_point_locations + control_point_displacements, + num_boundary_points=3) + + loss = math_ops.reduce_mean(math_ops.abs(warped_image - image)) + optimizer = momentum.MomentumOptimizer(0.001, 0.9) + grad = gradients.gradients(loss, [image]) + grad, _ = clip_ops.clip_by_global_norm(grad, 1.0) + opt_func = optimizer.apply_gradients(zip(grad, [image])) + init_op = variables.global_variables_initializer() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run([loss, opt_func]) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png new file mode 100644 index 0000000000000000000000000000000000000000..7e303881e213a82e412d18de9d9d86f368726f06 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png new file mode 100644 index 0000000000000000000000000000000000000000..7fd9e4e6d69f3120428d1d778846d495cea1a989 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png new file mode 100644 index 0000000000000000000000000000000000000000..86d225e5d2158804f88dca881f69ed3ab287d866 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png new file mode 100644 index 0000000000000000000000000000000000000000..37e8ffae114625d0cc6a07ab2b8dbbb7413a3829 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png new file mode 100644 index 0000000000000000000000000000000000000000..e49b5816120d43a669264915f1b6747606e080e0 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png new file mode 100644 index 0000000000000000000000000000000000000000..df3cf2004312ed0ed0ebf1f0340cbfec7fd9ac46 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png new file mode 100644 index 0000000000000000000000000000000000000000..e1799a87c8542d7e515b6185d7e8f6f75fe73f3e Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png new file mode 100644 index 0000000000000000000000000000000000000000..2c346e0ce5487e21d41aa4e6306fd83a7b4ffdb4 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png new file mode 100644 index 0000000000000000000000000000000000000000..6f8b65451cc08a463e4305ddc4be0dbe2879fae9 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png new file mode 100644 index 0000000000000000000000000000000000000000..8e78146d955ae8f02230121e6314f3285e87611e Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png differ diff --git a/tensorflow/contrib/image/python/ops/dense_image_warp.py b/tensorflow/contrib/image/python/ops/dense_image_warp.py new file mode 100644 index 0000000000000000000000000000000000000000..f9b219ada492466919c615d8978e462e6c619d33 --- /dev/null +++ b/tensorflow/contrib/image/python/ops/dense_image_warp.py @@ -0,0 +1,201 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Image warping using per-pixel flow vectors.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def _interpolate_bilinear(grid, + query_points, + name='interpolate_bilinear', + indexing='ij'): + """Similar to Matlab's interp2 function. + + Finds values for query points on a grid using bilinear interpolation. + + Args: + grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`. + query_points: a 3-D float `Tensor` of N points with shape `[batch, N, 2]`. + name: a name for the operation (optional). + indexing: whether the query points are specified as row and column (ij), + or Cartesian coordinates (xy). + + Returns: + values: a 3-D `Tensor` with shape `[batch, N, channels]` + + Raises: + ValueError: if the indexing mode is invalid, or if the shape of the inputs + invalid. + """ + if indexing != 'ij' and indexing != 'xy': + raise ValueError('Indexing mode must be \'ij\' or \'xy\'') + + with ops.name_scope(name): + grid = ops.convert_to_tensor(grid) + query_points = ops.convert_to_tensor(query_points) + shape = grid.get_shape().as_list() + if len(shape) != 4: + msg = 'Grid must be 4 dimensional. Received size: ' + raise ValueError(msg + str(grid.get_shape())) + + batch_size, height, width, channels = shape + query_type = query_points.dtype + grid_type = grid.dtype + + if (len(query_points.get_shape()) != 3 or + query_points.get_shape()[2].value != 2): + msg = ('Query points must be 3 dimensional and size 2 in dim 2. Received ' + 'size: ') + raise ValueError(msg + str(query_points.get_shape())) + + _, num_queries, _ = query_points.get_shape().as_list() + + if height < 2 or width < 2: + msg = 'Grid must be at least batch_size x 2 x 2 in size. Received size: ' + raise ValueError(msg + str(grid.get_shape())) + + alphas = [] + floors = [] + ceils = [] + + index_order = [0, 1] if indexing == 'ij' else [1, 0] + unstacked_query_points = array_ops.unstack(query_points, axis=2) + + for dim in index_order: + with ops.name_scope('dim-' + str(dim)): + queries = unstacked_query_points[dim] + + size_in_indexing_dimension = shape[dim + 1] + + # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1 + # is still a valid index into the grid. + max_floor = math_ops.cast(size_in_indexing_dimension - 2, query_type) + min_floor = constant_op.constant(0.0, dtype=query_type) + floor = math_ops.minimum( + math_ops.maximum(min_floor, math_ops.floor(queries)), max_floor) + int_floor = math_ops.cast(floor, dtypes.int32) + floors.append(int_floor) + ceil = int_floor + 1 + ceils.append(ceil) + + # alpha has the same type as the grid, as we will directly use alpha + # when taking linear combinations of pixel values from the image. + alpha = math_ops.cast(queries - floor, grid_type) + min_alpha = constant_op.constant(0.0, dtype=grid_type) + max_alpha = constant_op.constant(1.0, dtype=grid_type) + alpha = math_ops.minimum(math_ops.maximum(min_alpha, alpha), max_alpha) + + # Expand alpha to [b, n, 1] so we can use broadcasting + # (since the alpha values don't depend on the channel). + alpha = array_ops.expand_dims(alpha, 2) + alphas.append(alpha) + + if batch_size * height * width > np.iinfo(np.int32).max / 8: + error_msg = """The image size or batch size is sufficiently large + that the linearized addresses used by array_ops.gather + may exceed the int32 limit.""" + raise ValueError(error_msg) + + flattened_grid = array_ops.reshape(grid, + [batch_size * height * width, channels]) + batch_offsets = array_ops.reshape( + math_ops.range(batch_size) * height * width, [batch_size, 1]) + + # This wraps array_ops.gather. We reshape the image data such that the + # batch, y, and x coordinates are pulled into the first dimension. + # Then we gather. Finally, we reshape the output back. It's possible this + # code would be made simpler by using array_ops.gather_nd. + def gather(y_coords, x_coords, name): + with ops.name_scope('gather-' + name): + linear_coordinates = batch_offsets + y_coords * width + x_coords + gathered_values = array_ops.gather(flattened_grid, linear_coordinates) + return array_ops.reshape(gathered_values, + [batch_size, num_queries, channels]) + + # grab the pixel values in the 4 corners around each query point + top_left = gather(floors[0], floors[1], 'top_left') + top_right = gather(floors[0], ceils[1], 'top_right') + bottom_left = gather(ceils[0], floors[1], 'bottom_left') + bottom_right = gather(ceils[0], ceils[1], 'bottom_right') + + # now, do the actual interpolation + with ops.name_scope('interpolate'): + interp_top = alphas[1] * (top_right - top_left) + top_left + interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left + interp = alphas[0] * (interp_bottom - interp_top) + interp_top + + return interp + + +def dense_image_warp(image, flow, name='dense_image_warp'): + """Image warping using per-pixel flow vectors. + + Apply a non-linear warp to the image, where the warp is specified by a dense + flow field of offset vectors that define the correspondences of pixel values + in the output image back to locations in the source image. Specifically, the + pixel value at output[b, j, i, c] is + images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c]. + + The locations specified by this formula do not necessarily map to an int + index. Therefore, the pixel value is obtained by bilinear + interpolation of the 4 nearest pixels around + (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside + of the image, we use the nearest pixel values at the image boundary. + + + Args: + image: 4-D float `Tensor` with shape `[batch, height, width, channels]`. + flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`. + name: A name for the operation (optional). + + Note that image and flow can be of type tf.half, tf.float32, or tf.float64, + and do not necessarily have to be the same type. + + Returns: + A 4-D float `Tensor` with shape`[batch, height, width, channels]` + and same type as input image. + + Raises: + ValueError: if height < 2 or width < 2 or the inputs have the wrong number + of dimensions. + """ + with ops.name_scope(name): + batch_size, height, width, channels = image.get_shape().as_list() + # The flow is defined on the image grid. Turn the flow into a list of query + # points in the grid space. + grid_x, grid_y = array_ops.meshgrid( + math_ops.range(width), math_ops.range(height)) + stacked_grid = math_ops.cast( + array_ops.stack([grid_y, grid_x], axis=2), flow.dtype) + batched_grid = array_ops.expand_dims(stacked_grid, axis=0) + query_points_on_grid = batched_grid - flow + query_points_flattened = array_ops.reshape(query_points_on_grid, + [batch_size, height * width, 2]) + # Compute values at the query points, then reshape the result back to the + # image grid. + interpolated = _interpolate_bilinear(image, query_points_flattened) + interpolated = array_ops.reshape(interpolated, + [batch_size, height, width, channels]) + return interpolated diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index c139ae89d8d682d6b87813c3a21703ffa762f28e..cd984c80543886be1f682933e2e003bd3374e425 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -433,7 +433,7 @@ def bipartite_match(distance_mat, of rows of the input `distance_matrix`. If `row_to_col_match_indices[i]` is not -1, row i is matched to column `row_to_col_match_indices[i]`. col_to_row_match_indices: A vector of length num_columns, which is the - number of columns of the input ditance matrix. + number of columns of the input distance matrix. If `col_to_row_match_indices[j]` is not -1, column j is matched to row `col_to_row_match_indices[j]`. """ diff --git a/tensorflow/contrib/image/python/ops/interpolate_spline.py b/tensorflow/contrib/image/python/ops/interpolate_spline.py new file mode 100644 index 0000000000000000000000000000000000000000..daf8c56456327f102f1409296a91f9f7b68ec799 --- /dev/null +++ b/tensorflow/contrib/image/python/ops/interpolate_spline.py @@ -0,0 +1,291 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Polyharmonic spline interpolation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops + +EPSILON = 0.0000000001 + + +def _cross_squared_distance_matrix(x, y): + """Pairwise squared distance between two (batch) matrices' rows (2nd dim). + + Computes the pairwise distances between rows of x and rows of y + Args: + x: [batch_size, n, d] float `Tensor` + y: [batch_size, m, d] float `Tensor` + + Returns: + squared_dists: [batch_size, n, m] float `Tensor`, where + squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2 + """ + x_norm_squared = math_ops.reduce_sum(math_ops.square(x), 2) + y_norm_squared = math_ops.reduce_sum(math_ops.square(y), 2) + + # Expand so that we can broadcast. + x_norm_squared_tile = array_ops.expand_dims(x_norm_squared, 2) + y_norm_squared_tile = array_ops.expand_dims(y_norm_squared, 1) + + x_y_transpose = math_ops.matmul(x, y, adjoint_b=True) + + # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj + squared_dists = x_norm_squared_tile - 2 * x_y_transpose + y_norm_squared_tile + + return squared_dists + + +def _pairwise_squared_distance_matrix(x): + """Pairwise squared distance among a (batch) matrix's rows (2nd dim). + + This saves a bit of computation vs. using _cross_squared_distance_matrix(x,x) + + Args: + x: `[batch_size, n, d]` float `Tensor` + + Returns: + squared_dists: `[batch_size, n, n]` float `Tensor`, where + squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2 + """ + + x_x_transpose = math_ops.matmul(x, x, adjoint_b=True) + x_norm_squared = array_ops.matrix_diag_part(x_x_transpose) + x_norm_squared_tile = array_ops.expand_dims(x_norm_squared, 2) + + # squared_dists[b,i,j] = ||x_bi - x_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj + squared_dists = x_norm_squared_tile - 2 * x_x_transpose + array_ops.transpose( + x_norm_squared_tile, [0, 2, 1]) + + return squared_dists + + +def _solve_interpolation(train_points, train_values, order, + regularization_weight): + """Solve for interpolation coefficients. + + Computes the coefficients of the polyharmonic interpolant for the 'training' + data defined by (train_points, train_values) using the kernel phi. + + Args: + train_points: `[b, n, d]` interpolation centers + train_values: `[b, n, k]` function values + order: order of the interpolation + regularization_weight: weight to place on smoothness regularization term + + Returns: + w: `[b, n, k]` weights on each interpolation center + v: `[b, d, k]` weights on each input dimension + """ + + b, n, d = train_points.get_shape().as_list() + _, _, k = train_values.get_shape().as_list() + + # First, rename variables so that the notation (c, f, w, v, A, B, etc.) + # follows https://en.wikipedia.org/wiki/Polyharmonic_spline. + # To account for python style guidelines we use + # matrix_a for A and matrix_b for B. + + c = train_points + f = train_values + + # Next, construct the linear system. + with ops.name_scope('construct_linear_system'): + + matrix_a = _phi(_pairwise_squared_distance_matrix(c), order) # [b, n, n] + if regularization_weight > 0: + batch_identity_matrix = np.expand_dims(np.eye(n), 0) + batch_identity_matrix = constant_op.constant( + batch_identity_matrix, dtype=train_points.dtype) + + matrix_a += regularization_weight * batch_identity_matrix + + # Append ones to the feature values for the bias term in the linear model. + ones = array_ops.ones([b, n, 1], train_points.dtype) + matrix_b = array_ops.concat([c, ones], 2) # [b, n, d + 1] + + # [b, n + d + 1, n] + left_block = array_ops.concat( + [matrix_a, array_ops.transpose(matrix_b, [0, 2, 1])], 1) + + num_b_cols = matrix_b.get_shape()[2] # d + 1 + lhs_zeros = array_ops.zeros([b, num_b_cols, num_b_cols], train_points.dtype) + right_block = array_ops.concat([matrix_b, lhs_zeros], + 1) # [b, n + d + 1, d + 1] + lhs = array_ops.concat([left_block, right_block], + 2) # [b, n + d + 1, n + d + 1] + + rhs_zeros = array_ops.zeros([b, d + 1, k], train_points.dtype) + rhs = array_ops.concat([f, rhs_zeros], 1) # [b, n + d + 1, k] + + # Then, solve the linear system and unpack the results. + with ops.name_scope('solve_linear_system'): + w_v = linalg_ops.matrix_solve(lhs, rhs) + w = w_v[:, :n, :] + v = w_v[:, n:, :] + + return w, v + + +def _apply_interpolation(query_points, train_points, w, v, order): + """Apply polyharmonic interpolation model to data. + + Given coefficients w and v for the interpolation model, we evaluate + interpolated function values at query_points. + + Args: + query_points: `[b, m, d]` x values to evaluate the interpolation at + train_points: `[b, n, d]` x values that act as the interpolation centers + ( the c variables in the wikipedia article) + w: `[b, n, k]` weights on each interpolation center + v: `[b, d, k]` weights on each input dimension + order: order of the interpolation + + Returns: + Polyharmonic interpolation evaluated at points defined in query_points. + """ + + batch_size = train_points.get_shape()[0].value + num_query_points = query_points.get_shape()[1].value + + # First, compute the contribution from the rbf term. + pairwise_dists = _cross_squared_distance_matrix(query_points, train_points) + phi_pairwise_dists = _phi(pairwise_dists, order) + + rbf_term = math_ops.matmul(phi_pairwise_dists, w) + + # Then, compute the contribution from the linear term. + # Pad query_points with ones, for the bias term in the linear model. + query_points_pad = array_ops.concat([ + query_points, + array_ops.ones([batch_size, num_query_points, 1], train_points.dtype) + ], 2) + linear_term = math_ops.matmul(query_points_pad, v) + + return rbf_term + linear_term + + +def _phi(r, order): + """Coordinate-wise nonlinearity used to define the order of the interpolation. + + See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition. + + Args: + r: input op + order: interpolation order + + Returns: + phi_k evaluated coordinate-wise on r, for k = r + """ + + # using EPSILON prevents log(0), sqrt0), etc. + # sqrt(0) is well-defined, but its gradient is not + with ops.name_scope('phi'): + if order == 1: + r = math_ops.maximum(r, EPSILON) + r = math_ops.sqrt(r) + return r + elif order == 2: + return 0.5 * r * math_ops.log(math_ops.maximum(r, EPSILON)) + elif order == 4: + return 0.5 * math_ops.square(r) * math_ops.log( + math_ops.maximum(r, EPSILON)) + elif order % 2 == 0: + r = math_ops.maximum(r, EPSILON) + return 0.5 * math_ops.pow(r, 0.5 * order) * math_ops.log(r) + else: + r = math_ops.maximum(r, EPSILON) + return math_ops.pow(r, 0.5 * order) + + +def interpolate_spline(train_points, + train_values, + query_points, + order, + regularization_weight=0.0, + name='interpolate_spline'): + r"""Interpolate signal using polyharmonic interpolation. + + The interpolant has the form + $$f(x) = \sum_{i = 1}^n w_i \phi(||x - c_i||) + v^T x + b.$$ + + This is a sum of two terms: (1) a weighted sum of radial basis function (RBF) + terms, with the centers \\(c_1, ... c_n\\), and (2) a linear term with a bias. + The \\(c_i\\) vectors are 'training' points. In the code, b is absorbed into v + by appending 1 as a final dimension to x. The coefficients w and v are + estimated such that the interpolant exactly fits the value of the function at + the \\(c_i\\) points, the vector w is orthogonal to each \\(c_i\\), and the + vector w sums to 0. With these constraints, the coefficients can be obtained + by solving a linear system. + + \\(\phi\\) is an RBF, parametrized by an interpolation + order. Using order=2 produces the well-known thin-plate spline. + + We also provide the option to perform regularized interpolation. Here, the + interpolant is selected to trade off between the squared loss on the training + data and a certain measure of its curvature + ([details](https://en.wikipedia.org/wiki/Polyharmonic_spline)). + Using a regularization weight greater than zero has the effect that the + interpolant will no longer exactly fit the training data. However, it may be + less vulnerable to overfitting, particularly for high-order interpolation. + + Note the interpolation procedure is differentiable with respect to all inputs + besides the order parameter. + + Args: + train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional + locations. These do not need to be regularly-spaced. + train_values: `[batch_size, n, k]` float `Tensor` of n c-dimensional values + evaluated at train_points. + query_points: `[batch_size, m, d]` `Tensor` of m d-dimensional locations + where we will output the interpolant's values. + order: order of the interpolation. Common values are 1 for + \\(\phi(r) = r\\), 2 for \\(\phi(r) = r^2 * log(r)\\) (thin-plate spline), + or 3 for \\(\phi(r) = r^3\\). + regularization_weight: weight placed on the regularization term. + This will depend substantially on the problem, and it should always be + tuned. For many problems, it is reasonable to use no regularization. + If using a non-zero value, we recommend a small value like 0.001. + name: name prefix for ops created by this function + + Returns: + `[b, m, k]` float `Tensor` of query values. We use train_points and + train_values to perform polyharmonic interpolation. The query values are + the values of the interpolant evaluated at the locations specified in + query_points. + """ + with ops.name_scope(name): + train_points = ops.convert_to_tensor(train_points) + train_values = ops.convert_to_tensor(train_values) + query_points = ops.convert_to_tensor(query_points) + + # First, fit the spline to the observed data. + with ops.name_scope('solve'): + w, v = _solve_interpolation(train_points, train_values, order, + regularization_weight) + + # Then, evaluate the spline at the query locations. + with ops.name_scope('predict'): + query_values = _apply_interpolation(query_points, train_points, w, v, + order) + + return query_values diff --git a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py index d4a6a5bcbb52511d4093587814100b2a0e8b2420..0ceb683ff4c6965a5ee4bcb04846a69d4d8ea0a5 100755 --- a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py +++ b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py @@ -45,7 +45,7 @@ def single_image_random_dot_stereograms(depth_values, Given the 2-D tensor 'depth_values' with encoded Z values, this operation will encode 3-D data into a 2-D image. The output of this Op is suitable for the encode_PNG/JPG ops. Be careful with image compression as this may - corrupt the encode 3-D data witin the image. + corrupt the encode 3-D data within the image. Based upon [this paper](http://www.learningace.com/doc/4331582/b6ab058d1e206d68ab60e4e1ead2fe6e/sirds-paper). diff --git a/tensorflow/contrib/image/python/ops/sparse_image_warp.py b/tensorflow/contrib/image/python/ops/sparse_image_warp.py new file mode 100644 index 0000000000000000000000000000000000000000..54a215d6db6ded56a1a4a018a7e176f35fe6397e --- /dev/null +++ b/tensorflow/contrib/image/python/ops/sparse_image_warp.py @@ -0,0 +1,201 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Image warping using sparse flow defined at control points.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.image.python.ops import dense_image_warp +from tensorflow.contrib.image.python.ops import interpolate_spline + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops + + +def _get_grid_locations(image_height, image_width): + """Wrapper for np.meshgrid.""" + + y_range = np.linspace(0, image_height - 1, image_height) + x_range = np.linspace(0, image_width - 1, image_width) + y_grid, x_grid = np.meshgrid(y_range, x_range, indexing='ij') + return np.stack((y_grid, x_grid), -1) + + +def _expand_to_minibatch(np_array, batch_size): + """Tile arbitrarily-sized np_array to include new batch dimension.""" + tiles = [batch_size] + [1] * np_array.ndim + return np.tile(np.expand_dims(np_array, 0), tiles) + + +def _get_boundary_locations(image_height, image_width, num_points_per_edge): + """Compute evenly-spaced indices along edge of image.""" + y_range = np.linspace(0, image_height - 1, num_points_per_edge + 2) + x_range = np.linspace(0, image_width - 1, num_points_per_edge + 2) + ys, xs = np.meshgrid(y_range, x_range, indexing='ij') + is_boundary = np.logical_or( + np.logical_or(xs == 0, xs == image_width - 1), + np.logical_or(ys == 0, ys == image_height - 1)) + return np.stack([ys[is_boundary], xs[is_boundary]], axis=-1) + + +def _add_zero_flow_controls_at_boundary(control_point_locations, + control_point_flows, image_height, + image_width, boundary_points_per_edge): + """Add control points for zero-flow boundary conditions. + + Augment the set of control points with extra points on the + boundary of the image that have zero flow. + + Args: + control_point_locations: input control points + control_point_flows: their flows + image_height: image height + image_width: image width + boundary_points_per_edge: number of points to add in the middle of each + edge (not including the corners). + The total number of points added is + 4 + 4*(boundary_points_per_edge). + + Returns: + merged_control_point_locations: augmented set of control point locations + merged_control_point_flows: augmented set of control point flows + """ + + batch_size = control_point_locations.get_shape()[0].value + + boundary_point_locations = _get_boundary_locations(image_height, image_width, + boundary_points_per_edge) + + boundary_point_flows = np.zeros([boundary_point_locations.shape[0], 2]) + + type_to_use = control_point_locations.dtype + boundary_point_locations = constant_op.constant( + _expand_to_minibatch(boundary_point_locations, batch_size), + dtype=type_to_use) + + boundary_point_flows = constant_op.constant( + _expand_to_minibatch(boundary_point_flows, batch_size), dtype=type_to_use) + + merged_control_point_locations = array_ops.concat( + [control_point_locations, boundary_point_locations], 1) + + merged_control_point_flows = array_ops.concat( + [control_point_flows, boundary_point_flows], 1) + + return merged_control_point_locations, merged_control_point_flows + + +def sparse_image_warp(image, + source_control_point_locations, + dest_control_point_locations, + interpolation_order=2, + regularization_weight=0.0, + num_boundary_points=0, + name='sparse_image_warp'): + """Image warping using correspondences between sparse control points. + + Apply a non-linear warp to the image, where the warp is specified by + the source and destination locations of a (potentially small) number of + control points. First, we use a polyharmonic spline + (@{tf.contrib.image.interpolate_spline}) to interpolate the displacements + between the corresponding control points to a dense flow field. + Then, we warp the image using this dense flow field + (@{tf.contrib.image.dense_image_warp}). + + Let t index our control points. For regularization_weight=0, we have: + warped_image[b, dest_control_point_locations[b, t, 0], + dest_control_point_locations[b, t, 1], :] = + image[b, source_control_point_locations[b, t, 0], + source_control_point_locations[b, t, 1], :]. + + For regularization_weight > 0, this condition is met approximately, since + regularized interpolation trades off smoothness of the interpolant vs. + reconstruction of the interpolant at the control points. + See @{tf.contrib.image.interpolate_spline} for further documentation of the + interpolation_order and regularization_weight arguments. + + + Args: + image: `[batch, height, width, channels]` float `Tensor` + source_control_point_locations: `[batch, num_control_points, 2]` float + `Tensor` + dest_control_point_locations: `[batch, num_control_points, 2]` float + `Tensor` + interpolation_order: polynomial order used by the spline interpolation + regularization_weight: weight on smoothness regularizer in interpolation + num_boundary_points: How many zero-flow boundary points to include at + each image edge.Usage: + num_boundary_points=0: don't add zero-flow points + num_boundary_points=1: 4 corners of the image + num_boundary_points=2: 4 corners and one in the middle of each edge + (8 points total) + num_boundary_points=n: 4 corners and n-1 along each edge + name: A name for the operation (optional). + + Note that image and offsets can be of type tf.half, tf.float32, or + tf.float64, and do not necessarily have to be the same type. + + Returns: + warped_image: `[batch, height, width, channels]` float `Tensor` with same + type as input image. + flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense + flow field produced by the interpolation. + """ + + image = ops.convert_to_tensor(image) + source_control_point_locations = ops.convert_to_tensor( + source_control_point_locations) + dest_control_point_locations = ops.convert_to_tensor( + dest_control_point_locations) + + control_point_flows = ( + dest_control_point_locations - source_control_point_locations) + + clamp_boundaries = num_boundary_points > 0 + boundary_points_per_edge = num_boundary_points - 1 + + with ops.name_scope(name): + + batch_size, image_height, image_width, _ = image.get_shape().as_list() + + # This generates the dense locations where the interpolant + # will be evaluated. + grid_locations = _get_grid_locations(image_height, image_width) + + flattened_grid_locations = np.reshape(grid_locations, + [image_height * image_width, 2]) + + flattened_grid_locations = constant_op.constant( + _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype) + + if clamp_boundaries: + (dest_control_point_locations, + control_point_flows) = _add_zero_flow_controls_at_boundary( + dest_control_point_locations, control_point_flows, image_height, + image_width, boundary_points_per_edge) + + flattened_flows = interpolate_spline.interpolate_spline( + dest_control_point_locations, control_point_flows, + flattened_grid_locations, interpolation_order, regularization_weight) + + dense_flows = array_ops.reshape(flattened_flows, + [batch_size, image_height, image_width, 2]) + + warped_image = dense_image_warp.dense_image_warp(image, dense_flows) + + return warped_image, dense_flows diff --git a/tensorflow/contrib/input_pipeline/BUILD b/tensorflow/contrib/input_pipeline/BUILD index 9d6b4d5d87e24d72b29ab33ee805fe0d068cc30a..0e34315db45d61282af1882631dc769a72965c3e 100644 --- a/tensorflow/contrib/input_pipeline/BUILD +++ b/tensorflow/contrib/input_pipeline/BUILD @@ -114,14 +114,3 @@ tf_cc_tests( "//tensorflow/core:testlib", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/input_pipeline/kernels/BUILD b/tensorflow/contrib/input_pipeline/kernels/BUILD index f20a6e38d4e80f869e9274d6fc49338a95fc6788..797605b8fe66e8375edcc70668a07a8d2a6d73f3 100644 --- a/tensorflow/contrib/input_pipeline/kernels/BUILD +++ b/tensorflow/contrib/input_pipeline/kernels/BUILD @@ -17,14 +17,3 @@ cc_library( ], alwayslink = 1, ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/integrate/BUILD b/tensorflow/contrib/integrate/BUILD index 66948c1ea1f3f239d3f43a57626f8c229fe24ad9..0b7d64f4edd7587000ca5b9ecae257fe8fedd4a1 100644 --- a/tensorflow/contrib/integrate/BUILD +++ b/tensorflow/contrib/integrate/BUILD @@ -42,14 +42,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/integrate/__init__.py b/tensorflow/contrib/integrate/__init__.py index 68bf511099ab473d158108df6ff07827819297d9..694f0c14bd4e74535c70fab76c5f7ac58f452559 100644 --- a/tensorflow/contrib/integrate/__init__.py +++ b/tensorflow/contrib/integrate/__init__.py @@ -18,6 +18,7 @@ See the @{$python/contrib.integrate} guide. @@odeint +@@odeint_fixed """ from __future__ import absolute_import diff --git a/tensorflow/contrib/integrate/python/ops/odes.py b/tensorflow/contrib/integrate/python/ops/odes.py index b4a99867ed46897f60be3f230838c3f576d5455e..61f78febfc07bb4e677259366a81c16b2b585244 100644 --- a/tensorflow/contrib/integrate/python/ops/odes.py +++ b/tensorflow/contrib/integrate/python/ops/odes.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops @@ -279,13 +278,27 @@ def _assert_increasing(t): return ops.control_dependencies([assert_increasing]) -def _check_input_types(t, y0): +def _check_input_types(y0, t, dt=None): if not (y0.dtype.is_floating or y0.dtype.is_complex): raise TypeError('`y0` must have a floating point or complex floating ' 'point dtype') if not t.dtype.is_floating: raise TypeError('`t` must have a floating point dtype') + if dt is not None and not dt.dtype.is_floating: + raise TypeError('`dt` must have a floating point dtype') + + +def _check_input_sizes(t, dt): + if len(t.get_shape().as_list()) > 1: + raise ValueError('t must be a 1D tensor') + + if len(dt.get_shape().as_list()) > 1: + raise ValueError('t must be a 1D tensor') + + if t.get_shape()[0] != dt.get_shape()[0] + 1: + raise ValueError('t and dt have incompatible lengths, must be N and N-1') + def _dopri5(func, y0, @@ -510,7 +523,7 @@ def odeint(func, # avoiding the need to pack/unpack in user functions. y0 = ops.convert_to_tensor(y0, name='y0') t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t') - _check_input_types(t, y0) + _check_input_types(y0, t) error_dtype = abs(y0).dtype rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol') @@ -530,24 +543,74 @@ def odeint(func, class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)): """Base class for fixed-grid ODE integrators.""" - def integrate(self, evol_func, y0, time_grid): - time_delta_grid = time_grid[1:] - time_grid[:-1] - - scan_func = self._make_scan_func(evol_func) + def integrate(self, evol_func, y0, time_grid, dt_grid, steps_on_intervals): + """Returns integrated values of differential equation on the `time grid`. + + Numerically integrates differential equation defined via time derivative + evaluator `evol_func` using fixed time steps specified in dt_grid. + + Args: + evol_func: Callable, evaluates time derivative of y at a given time. + y0: N-D Tensor holds initial values of the solution. + time_grid: 1-D Tensor holding the time points at which the solution + will be recorded, must have a floating dtype. + dt_grid: 1-D Tensor holds fixed time steps to be used on time_grid + intervals. Must be a floating dtype and have one less element than that + of the time_grid. + steps_on_intervals: 1-D Tensor of integer dtype, must have the same size + as dt_grid. Specifies number of steps needed for every interval. Assumes + steps_on_intervals * dt_grid == time intervals. + + Returns: + (N+1)-D tensor, where the first dimension corresponds to different + time points. Contains the solved value of y for each desired time point in + `t`, with the initial value `y0` being the first element along the first + dimension. + """ - y_grid = functional_ops.scan(scan_func, (time_grid[:-1], time_delta_grid), - y0) - return array_ops.concat([[y0], y_grid], axis=0) + iteration_func = self._make_iteration_func(evol_func, dt_grid) + integrate_interval = self._make_interval_integrator(iteration_func, + steps_on_intervals) - def _make_scan_func(self, evol_func): + num_times = array_ops.size(time_grid) + current_time = time_grid[0] + solution_array = tensor_array_ops.TensorArray(y0.dtype, num_times) + solution_array = solution_array.write(0, y0) - def scan_func(y, t_and_dt): - t, dt = t_and_dt + solution_array, _, _, _ = control_flow_ops.while_loop( + lambda _, __, ___, i: i < num_times, + integrate_interval, + (solution_array, y0, current_time, 1) + ) + solution_array = solution_array.stack() + solution_array.set_shape(time_grid.get_shape().concatenate(y0.get_shape())) + return solution_array + + def _make_iteration_func(self, evol_func, dt_grid): + """Returns a function that builds operations of a single time step.""" + + def iteration_func(y, t, dt_step, interval_step): + """Performs a single time step advance.""" + dt = dt_grid[interval_step - 1] dy = self._step_func(evol_func, t, dt, y) dy = math_ops.cast(dy, dtype=y.dtype) - return y + dy + return y + dy, t + dt, dt_step + 1, interval_step + + return iteration_func + + def _make_interval_integrator(self, iteration_func, interval_sizes): + """Returns a function that builds operations for interval integration.""" - return scan_func + def integrate_interval(solution_array, y, t, interval_num): + """Integrates y with fixed time step on interval `interval_num`.""" + y, t, _, _ = control_flow_ops.while_loop( + lambda _, __, j, interval_num: j < interval_sizes[interval_num - 1], + iteration_func, + (y, t, 0, interval_num) + ) + return solution_array.write(interval_num, y), y, t, interval_num + 1 + + return integrate_interval @abc.abstractmethod def _step_func(self, evol_func, t, dt, y): @@ -555,6 +618,7 @@ class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)): class _MidpointFixedGridIntegrator(_FixedGridIntegrator): + """Fixed grid integrator implementing midpoint scheme.""" def _step_func(self, evol_func, t, dt, y): dt_cast = math_ops.cast(dt, y.dtype) @@ -563,6 +627,7 @@ class _MidpointFixedGridIntegrator(_FixedGridIntegrator): class _RK4FixedGridIntegrator(_FixedGridIntegrator): + """Fixed grid integrator implementing RK4 scheme.""" def _step_func(self, evol_func, t, dt, y): k1 = evol_func(y, t) @@ -575,7 +640,7 @@ class _RK4FixedGridIntegrator(_FixedGridIntegrator): return math_ops.add_n([k1, 2 * k2, 2 * k3, k4]) * (dt_cast / 6) -def odeint_fixed(func, y0, t, method='rk4', name=None): +def odeint_fixed(func, y0, t, dt=None, method='rk4', name=None): """ODE integration on a fixed grid (with no step size control). Useful in certain scenarios to avoid the overhead of adaptive step size @@ -590,6 +655,14 @@ def odeint_fixed(func, y0, t, method='rk4', name=None): `y`. The initial time point should be the first element of this sequence, and each time must be larger than the previous time. May have any floating point dtype. + dt: 0-D or 1-D Tensor providing time step suggestion to be used on time + integration intervals in `t`. 1-D Tensor should provide values + for all intervals, must have 1 less element than that of `t`. + If given a 0-D Tensor, the value is interpreted as time step suggestion + same for all intervals. If passed None, then time step is set to be the + t[1:] - t[:-1]. Defaults to None. The actual step size is obtained by + insuring an integer number of steps per interval, potentially reducing the + time step. method: One of 'midpoint' or 'rk4'. name: Optional name for the resulting operation. @@ -602,16 +675,29 @@ def odeint_fixed(func, y0, t, method='rk4', name=None): Raises: ValueError: Upon caller errors. """ - with ops.name_scope(name, 'odeint_fixed', [y0, t]): + with ops.name_scope(name, 'odeint_fixed', [y0, t, dt]): t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t') y0 = ops.convert_to_tensor(y0, name='y0') - _check_input_types(t, y0) + + intervals = t[1:] - t[:-1] + if dt is None: + dt = intervals + dt = ops.convert_to_tensor(dt, preferred_dtype=dtypes.float64, name='dt') + + steps_on_intervals = math_ops.ceil(intervals / dt) + dt = intervals / steps_on_intervals + steps_on_intervals = math_ops.cast(steps_on_intervals, dtype=dtypes.int32) + + _check_input_types(y0, t, dt) + _check_input_sizes(t, dt) with _assert_increasing(t): with ops.name_scope(method): if method == 'midpoint': - return _MidpointFixedGridIntegrator().integrate(func, y0, t) + return _MidpointFixedGridIntegrator().integrate(func, y0, t, dt, + steps_on_intervals) elif method == 'rk4': - return _RK4FixedGridIntegrator().integrate(func, y0, t) + return _RK4FixedGridIntegrator().integrate(func, y0, t, dt, + steps_on_intervals) else: raise ValueError('method not supported: {!s}'.format(method)) diff --git a/tensorflow/contrib/integrate/python/ops/odes_test.py b/tensorflow/contrib/integrate/python/ops/odes_test.py index 3ec01212d25ca8dc6e13f340177a5e85138868d5..c7b4e2faa84e1a87cb1904b22eb0008ab1ee4be6 100644 --- a/tensorflow/contrib/integrate/python/ops/odes_test.py +++ b/tensorflow/contrib/integrate/python/ops/odes_test.py @@ -242,40 +242,56 @@ class InterpolationTest(test.TestCase): class OdeIntFixedTest(test.TestCase): - def _test_integrate_sine(self, method): + def _test_integrate_sine(self, method, t, dt=None): def evol_func(y, t): del t return array_ops.stack([y[1], -y[0]]) y0 = [0., 1.] - time_grid = np.linspace(0., 10., 200) - y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method) + y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method) with self.test_session() as sess: y_grid_array = sess.run(y_grid) np.testing.assert_allclose( - y_grid_array[:, 0], np.sin(time_grid), rtol=1e-2, atol=1e-2) + y_grid_array[:, 0], np.sin(t), rtol=1e-2, atol=1e-2) - def _test_integrate_gaussian(self, method): + def _test_integrate_gaussian(self, method, t, dt=None): def evol_func(y, t): return -math_ops.cast(t, dtype=y.dtype) * y[0] y0 = [1.] - time_grid = np.linspace(0., 2., 100) - y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method) + y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method) with self.test_session() as sess: y_grid_array = sess.run(y_grid) np.testing.assert_allclose( - y_grid_array[:, 0], np.exp(-time_grid**2 / 2), rtol=1e-2, atol=1e-2) + y_grid_array[:, 0], np.exp(-t**2 / 2), rtol=1e-2, atol=1e-2) + + def _test_integrate_sine_all(self, method): + uniform_time_grid = np.linspace(0., 10., 200) + non_uniform_time_grid = np.asarray([0.0, 0.4, 4.7, 5.2, 7.0]) + uniform_dt = 0.02 + non_uniform_dt = np.asarray([0.01, 0.001, 0.05, 0.03]) + self._test_integrate_sine(method, uniform_time_grid) + self._test_integrate_sine(method, non_uniform_time_grid, uniform_dt) + self._test_integrate_sine(method, non_uniform_time_grid, non_uniform_dt) + + def _test_integrate_gaussian_all(self, method): + uniform_time_grid = np.linspace(0., 2., 100) + non_uniform_time_grid = np.asarray([0.0, 0.1, 0.7, 1.2, 2.0]) + uniform_dt = 0.01 + non_uniform_dt = np.asarray([0.01, 0.001, 0.1, 0.03]) + self._test_integrate_gaussian(method, uniform_time_grid) + self._test_integrate_gaussian(method, non_uniform_time_grid, uniform_dt) + self._test_integrate_gaussian(method, non_uniform_time_grid, non_uniform_dt) def _test_everything(self, method): - self._test_integrate_sine(method) - self._test_integrate_gaussian(method) + self._test_integrate_sine_all(method) + self._test_integrate_gaussian_all(method) def test_midpoint(self): self._test_everything('midpoint') @@ -283,6 +299,21 @@ class OdeIntFixedTest(test.TestCase): def test_rk4(self): self._test_everything('rk4') + def test_dt_size_exceptions(self): + times = np.linspace(0., 2., 100) + dt = np.ones(99) * 0.01 + dt_wrong_length = np.asarray([0.01, 0.001, 0.1, 0.03]) + dt_wrong_dim = np.expand_dims(np.linspace(0., 2., 99), axis=0) + times_wrong_dim = np.expand_dims(np.linspace(0., 2., 100), axis=0) + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times, dt_wrong_length) + + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times, dt_wrong_dim) + + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times_wrong_dim, dt) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kafka/BUILD b/tensorflow/contrib/kafka/BUILD index efb403462a6e5df5b69ac0735ffc03f40d4a252c..3913c9dc7abfba2829bde5e86fe2927e8fc29a9d 100644 --- a/tensorflow/contrib/kafka/BUILD +++ b/tensorflow/contrib/kafka/BUILD @@ -1,66 +1,93 @@ -package( - default_visibility = ["//visibility:private"], -) +package(default_visibility = ["//tensorflow:internal"]) licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") -load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") -load("//tensorflow:tensorflow.bzl", "tf_py_test") +load( + "//tensorflow:tensorflow.bzl", + "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_custom_op_library", + "tf_custom_op_py_library", + "tf_gen_op_libs", + "tf_py_test", +) -tf_kernel_library( - name = "kafka_kernels", +py_library( + name = "kafka", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_ops", + ], +) + +tf_custom_op_library( + name = "_dataset_ops.so", + srcs = ["ops/dataset_ops.cc"], + deps = [":dataset_kernels"], +) + +tf_gen_op_libs( + op_lib_names = ["dataset_ops"], +) + +cc_library( + name = "dataset_kernels", srcs = ["kernels/kafka_dataset_ops.cc"], - visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/kernels:bounds_check_lib", - "//tensorflow/core/kernels:dataset", + "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", "@kafka", + "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, ) -tf_gen_op_libs( - op_lib_names = ["kafka_ops"], +py_library( + name = "dataset_ops", + srcs = [ + "python/ops/kafka_dataset_ops.py", + ], + srcs_version = "PY2AND3", deps = [ - "//tensorflow/core:lib", + ":kafka_op_loader", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", ], ) tf_gen_op_wrapper_py( - name = "gen_kafka_ops", - out = "python/ops/gen_kafka_ops.py", - require_shape_functions = True, - deps = [":kafka_ops_op_lib"], + name = "gen_dataset_ops", + out = "python/ops/gen_dataset_ops.py", + deps = ["//tensorflow/contrib/kafka:dataset_ops_op_lib"], ) -py_library( - name = "kafka", - srcs = [ - "__init__.py", - "python/ops/kafka_dataset_ops.py", +tf_kernel_library( + name = "dataset_ops_kernels", + deps = [ + ":dataset_kernels", + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + +tf_custom_op_py_library( + name = "kafka_op_loader", + srcs = ["python/ops/kafka_op_loader.py"], + dso = ["//tensorflow/contrib/kafka:_dataset_ops.so"], + kernels = [ + ":dataset_ops_kernels", + "//tensorflow/contrib/kafka:dataset_ops_op_lib", ], srcs_version = "PY2AND3", - visibility = ["//visibility:public"], deps = [ - ":gen_kafka_ops", + ":gen_dataset_ops", "//tensorflow/contrib/util:util_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python/data/ops:readers", ], ) @@ -88,18 +115,7 @@ tf_py_test( ], tags = [ "manual", + "no_windows", "notap", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc index 88ef5f357113372b0a2d0cb13382ac980a61252d..2638b25ec424b5b4ef556ff769e94e64da32fec2 100644 --- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc +++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc @@ -13,9 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/dataset.h" - -#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/dataset.h" #include "src-cpp/rdkafkacpp.h" @@ -66,7 +64,7 @@ class KafkaDatasetOp : public DatasetOpKernel { eof_(eof), timeout_(timeout) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Kafka")})); @@ -83,7 +81,7 @@ class KafkaDatasetOp : public DatasetOpKernel { return *shapes; } - string DebugString() override { return "KafkaDatasetOp::Dataset"; } + string DebugString() const override { return "KafkaDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, diff --git a/tensorflow/contrib/kafka/ops/kafka_ops.cc b/tensorflow/contrib/kafka/ops/dataset_ops.cc similarity index 100% rename from tensorflow/contrib/kafka/ops/kafka_ops.cc rename to tensorflow/contrib/kafka/ops/dataset_ops.cc diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py index 8e51d27a342359881de072c3979a2b5a7fc034ea..a1624614d1ab1be31463c5cdc0b4cfb653165a0c 100644 --- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -17,8 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.kafka.python.ops import gen_kafka_ops -from tensorflow.python.data.ops.readers import Dataset +from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import +from tensorflow.contrib.kafka.python.ops import gen_dataset_ops +from tensorflow.python.data.ops.dataset_ops import Dataset from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -58,8 +59,8 @@ class KafkaDataset(Dataset): timeout, dtype=dtypes.int64, name="timeout") def _as_variant_tensor(self): - return gen_kafka_ops.kafka_dataset(self._topics, self._servers, self._group, - self._eof, self._timeout) + return gen_dataset_ops.kafka_dataset(self._topics, self._servers, + self._group, self._eof, self._timeout) @property def output_classes(self): diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/__init__.py b/tensorflow/contrib/kafka/python/ops/kafka_op_loader.py similarity index 69% rename from tensorflow/python/keras/_impl/keras/preprocessing/__init__.py rename to tensorflow/contrib/kafka/python/ops/kafka_op_loader.py index 2ca48cdbf9c066194f4f4ed448fd621167db7ba9..ec2fdea962ef946d3f8f32b9e630b92649d612fe 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/__init__.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_op_loader.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Data preprocessing module. -""" +"""Python helper for loading kafka ops and kernels.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing import image -from tensorflow.python.keras._impl.keras.preprocessing import sequence -from tensorflow.python.keras._impl.keras.preprocessing import text +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader +_dataset_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("../../_dataset_ops.so")) diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD index 7e0019ce4ad6c96e09ac9e222e2f4e2840273983..7a4cab20d1a3471af2a2a402a6d1443a90fa7f9b 100644 --- a/tensorflow/contrib/keras/BUILD +++ b/tensorflow/contrib/keras/BUILD @@ -52,15 +52,3 @@ py_library( "//tensorflow/python/keras", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/keras/api/keras/activations/__init__.py b/tensorflow/contrib/keras/api/keras/activations/__init__.py index d04838c218d6643a703723a1d163c88547c14da7..3f0184276f6b903be63f7b35459e4ad57044eb2c 100644 --- a/tensorflow/contrib/keras/api/keras/activations/__init__.py +++ b/tensorflow/contrib/keras/api/keras/activations/__init__.py @@ -19,22 +19,22 @@ from __future__ import division from __future__ import print_function # Activation functions. -from tensorflow.python.keras._impl.keras.activations import elu -from tensorflow.python.keras._impl.keras.activations import hard_sigmoid -from tensorflow.python.keras._impl.keras.activations import linear -from tensorflow.python.keras._impl.keras.activations import relu -from tensorflow.python.keras._impl.keras.activations import selu -from tensorflow.python.keras._impl.keras.activations import sigmoid -from tensorflow.python.keras._impl.keras.activations import softmax -from tensorflow.python.keras._impl.keras.activations import softplus -from tensorflow.python.keras._impl.keras.activations import softsign -from tensorflow.python.keras._impl.keras.activations import tanh +from tensorflow.python.keras.activations import elu +from tensorflow.python.keras.activations import hard_sigmoid +from tensorflow.python.keras.activations import linear +from tensorflow.python.keras.activations import relu +from tensorflow.python.keras.activations import selu +from tensorflow.python.keras.activations import sigmoid +from tensorflow.python.keras.activations import softmax +from tensorflow.python.keras.activations import softplus +from tensorflow.python.keras.activations import softsign +from tensorflow.python.keras.activations import tanh # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.activations import deserialize -from tensorflow.python.keras._impl.keras.activations import serialize -from tensorflow.python.keras._impl.keras.activations import get +from tensorflow.python.keras.activations import deserialize +from tensorflow.python.keras.activations import serialize +from tensorflow.python.keras.activations import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py b/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py index abf8393ae45d71dc0cb746706abb72f77b82d199..6dfb5cab17c088bfab8ed806adeabd793ced4d12 100644 --- a/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.inception_v3 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.inception_v3 import InceptionV3 -from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input +from tensorflow.python.keras.applications.inception_v3 import decode_predictions +from tensorflow.python.keras.applications.inception_v3 import InceptionV3 +from tensorflow.python.keras.applications.inception_v3 import preprocess_input del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py b/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py index b809e91193b459a46906443796344c092e1d2a6b..67306cc51e1927cfbc2db424b1f4165dabfa22f9 100644 --- a/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.mobilenet import decode_predictions -from tensorflow.python.keras._impl.keras.applications.mobilenet import MobileNet -from tensorflow.python.keras._impl.keras.applications.mobilenet import preprocess_input +from tensorflow.python.keras.applications.mobilenet import decode_predictions +from tensorflow.python.keras.applications.mobilenet import MobileNet +from tensorflow.python.keras.applications.mobilenet import preprocess_input del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py b/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py index 530805d150bfe32c5b81d7d7d3f92e203b83b602..a25ff48b593a9a9ea56fd427a932bb64c10f7b7b 100644 --- a/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.resnet50 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.resnet50 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.resnet50 import ResNet50 +from tensorflow.python.keras.applications.resnet50 import decode_predictions +from tensorflow.python.keras.applications.resnet50 import preprocess_input +from tensorflow.python.keras.applications.resnet50 import ResNet50 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py b/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py index 118361604bbc7e0a88ed34243c0d5ea98856a301..4964b1b7deb56fe0025e9a8d8cb45d18e0209fea 100644 --- a/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.vgg16 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.vgg16 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.vgg16 import VGG16 +from tensorflow.python.keras.applications.vgg16 import decode_predictions +from tensorflow.python.keras.applications.vgg16 import preprocess_input +from tensorflow.python.keras.applications.vgg16 import VGG16 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py b/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py index cda52628f3c10d65fdbe70b2f86cc12c771870a9..afb3abebdd6735e6f17bc94c1fcd15a31b74f983 100644 --- a/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.vgg19 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.vgg19 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.vgg19 import VGG19 +from tensorflow.python.keras.applications.vgg19 import decode_predictions +from tensorflow.python.keras.applications.vgg19 import preprocess_input +from tensorflow.python.keras.applications.vgg19 import VGG19 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py b/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py index ae9cd9cd18c5ccc5ec37c8cd1bf36f8aabd9929c..2e3335d02aff0fff805fc2dac614b14e0593d40d 100644 --- a/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.xception import decode_predictions -from tensorflow.python.keras._impl.keras.applications.xception import preprocess_input -from tensorflow.python.keras._impl.keras.applications.xception import Xception +from tensorflow.python.keras.applications.xception import decode_predictions +from tensorflow.python.keras.applications.xception import preprocess_input +from tensorflow.python.keras.applications.xception import Xception del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/backend/__init__.py b/tensorflow/contrib/keras/api/keras/backend/__init__.py index 10ef5a75852deb6595bced2703d7c5f29b0efac3..a755364014206e92289eec0b9c8e510251862e0e 100644 --- a/tensorflow/contrib/keras/api/keras/backend/__init__.py +++ b/tensorflow/contrib/keras/api/keras/backend/__init__.py @@ -19,144 +19,144 @@ from __future__ import division from __future__ import print_function # pylint: disable=redefined-builtin -from tensorflow.python.keras._impl.keras.backend import abs -from tensorflow.python.keras._impl.keras.backend import all -from tensorflow.python.keras._impl.keras.backend import any -from tensorflow.python.keras._impl.keras.backend import arange -from tensorflow.python.keras._impl.keras.backend import argmax -from tensorflow.python.keras._impl.keras.backend import argmin -from tensorflow.python.keras._impl.keras.backend import backend -from tensorflow.python.keras._impl.keras.backend import batch_dot -from tensorflow.python.keras._impl.keras.backend import batch_flatten -from tensorflow.python.keras._impl.keras.backend import batch_get_value -from tensorflow.python.keras._impl.keras.backend import batch_normalization -from tensorflow.python.keras._impl.keras.backend import batch_set_value -from tensorflow.python.keras._impl.keras.backend import bias_add -from tensorflow.python.keras._impl.keras.backend import binary_crossentropy -from tensorflow.python.keras._impl.keras.backend import cast -from tensorflow.python.keras._impl.keras.backend import cast_to_floatx -from tensorflow.python.keras._impl.keras.backend import categorical_crossentropy -from tensorflow.python.keras._impl.keras.backend import clear_session -from tensorflow.python.keras._impl.keras.backend import clip -from tensorflow.python.keras._impl.keras.backend import concatenate -from tensorflow.python.keras._impl.keras.backend import constant -from tensorflow.python.keras._impl.keras.backend import conv1d -from tensorflow.python.keras._impl.keras.backend import conv2d -from tensorflow.python.keras._impl.keras.backend import conv2d_transpose -from tensorflow.python.keras._impl.keras.backend import conv3d -from tensorflow.python.keras._impl.keras.backend import cos -from tensorflow.python.keras._impl.keras.backend import count_params -from tensorflow.python.keras._impl.keras.backend import ctc_batch_cost -from tensorflow.python.keras._impl.keras.backend import ctc_decode -from tensorflow.python.keras._impl.keras.backend import ctc_label_dense_to_sparse -from tensorflow.python.keras._impl.keras.backend import dot -from tensorflow.python.keras._impl.keras.backend import dropout -from tensorflow.python.keras._impl.keras.backend import dtype -from tensorflow.python.keras._impl.keras.backend import elu -from tensorflow.python.keras._impl.keras.backend import epsilon -from tensorflow.python.keras._impl.keras.backend import equal -from tensorflow.python.keras._impl.keras.backend import eval -from tensorflow.python.keras._impl.keras.backend import exp -from tensorflow.python.keras._impl.keras.backend import expand_dims -from tensorflow.python.keras._impl.keras.backend import eye -from tensorflow.python.keras._impl.keras.backend import flatten -from tensorflow.python.keras._impl.keras.backend import floatx -from tensorflow.python.keras._impl.keras.backend import foldl -from tensorflow.python.keras._impl.keras.backend import foldr -from tensorflow.python.keras._impl.keras.backend import function -from tensorflow.python.keras._impl.keras.backend import gather -from tensorflow.python.keras._impl.keras.backend import get_session -from tensorflow.python.keras._impl.keras.backend import get_uid -from tensorflow.python.keras._impl.keras.backend import get_value -from tensorflow.python.keras._impl.keras.backend import gradients -from tensorflow.python.keras._impl.keras.backend import greater -from tensorflow.python.keras._impl.keras.backend import greater_equal -from tensorflow.python.keras._impl.keras.backend import hard_sigmoid -from tensorflow.python.keras._impl.keras.backend import image_data_format -from tensorflow.python.keras._impl.keras.backend import in_test_phase -from tensorflow.python.keras._impl.keras.backend import in_top_k -from tensorflow.python.keras._impl.keras.backend import in_train_phase -from tensorflow.python.keras._impl.keras.backend import int_shape -from tensorflow.python.keras._impl.keras.backend import is_sparse -from tensorflow.python.keras._impl.keras.backend import l2_normalize -from tensorflow.python.keras._impl.keras.backend import learning_phase -from tensorflow.python.keras._impl.keras.backend import less -from tensorflow.python.keras._impl.keras.backend import less_equal -from tensorflow.python.keras._impl.keras.backend import log -from tensorflow.python.keras._impl.keras.backend import manual_variable_initialization -from tensorflow.python.keras._impl.keras.backend import map_fn -from tensorflow.python.keras._impl.keras.backend import max -from tensorflow.python.keras._impl.keras.backend import maximum -from tensorflow.python.keras._impl.keras.backend import mean -from tensorflow.python.keras._impl.keras.backend import min -from tensorflow.python.keras._impl.keras.backend import minimum -from tensorflow.python.keras._impl.keras.backend import moving_average_update -from tensorflow.python.keras._impl.keras.backend import name_scope -from tensorflow.python.keras._impl.keras.backend import ndim -from tensorflow.python.keras._impl.keras.backend import normalize_batch_in_training -from tensorflow.python.keras._impl.keras.backend import not_equal -from tensorflow.python.keras._impl.keras.backend import one_hot -from tensorflow.python.keras._impl.keras.backend import ones -from tensorflow.python.keras._impl.keras.backend import ones_like -from tensorflow.python.keras._impl.keras.backend import permute_dimensions -from tensorflow.python.keras._impl.keras.backend import placeholder -from tensorflow.python.keras._impl.keras.backend import pool2d -from tensorflow.python.keras._impl.keras.backend import pool3d -from tensorflow.python.keras._impl.keras.backend import pow -from tensorflow.python.keras._impl.keras.backend import print_tensor -from tensorflow.python.keras._impl.keras.backend import prod -from tensorflow.python.keras._impl.keras.backend import random_binomial -from tensorflow.python.keras._impl.keras.backend import random_normal -from tensorflow.python.keras._impl.keras.backend import random_normal_variable -from tensorflow.python.keras._impl.keras.backend import random_uniform -from tensorflow.python.keras._impl.keras.backend import random_uniform_variable -from tensorflow.python.keras._impl.keras.backend import relu -from tensorflow.python.keras._impl.keras.backend import repeat -from tensorflow.python.keras._impl.keras.backend import repeat_elements -from tensorflow.python.keras._impl.keras.backend import reset_uids -from tensorflow.python.keras._impl.keras.backend import reshape -from tensorflow.python.keras._impl.keras.backend import resize_images -from tensorflow.python.keras._impl.keras.backend import resize_volumes -from tensorflow.python.keras._impl.keras.backend import reverse -from tensorflow.python.keras._impl.keras.backend import rnn -from tensorflow.python.keras._impl.keras.backend import round -from tensorflow.python.keras._impl.keras.backend import separable_conv2d -from tensorflow.python.keras._impl.keras.backend import set_epsilon -from tensorflow.python.keras._impl.keras.backend import set_floatx -from tensorflow.python.keras._impl.keras.backend import set_image_data_format -from tensorflow.python.keras._impl.keras.backend import set_learning_phase -from tensorflow.python.keras._impl.keras.backend import set_session -from tensorflow.python.keras._impl.keras.backend import set_value -from tensorflow.python.keras._impl.keras.backend import shape -from tensorflow.python.keras._impl.keras.backend import sigmoid -from tensorflow.python.keras._impl.keras.backend import sign -from tensorflow.python.keras._impl.keras.backend import sin -from tensorflow.python.keras._impl.keras.backend import softmax -from tensorflow.python.keras._impl.keras.backend import softplus -from tensorflow.python.keras._impl.keras.backend import softsign -from tensorflow.python.keras._impl.keras.backend import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.backend import spatial_2d_padding -from tensorflow.python.keras._impl.keras.backend import spatial_3d_padding -from tensorflow.python.keras._impl.keras.backend import sqrt -from tensorflow.python.keras._impl.keras.backend import square -from tensorflow.python.keras._impl.keras.backend import squeeze -from tensorflow.python.keras._impl.keras.backend import stack -from tensorflow.python.keras._impl.keras.backend import std -from tensorflow.python.keras._impl.keras.backend import stop_gradient -from tensorflow.python.keras._impl.keras.backend import sum -from tensorflow.python.keras._impl.keras.backend import switch -from tensorflow.python.keras._impl.keras.backend import tanh -from tensorflow.python.keras._impl.keras.backend import temporal_padding -from tensorflow.python.keras._impl.keras.backend import to_dense -from tensorflow.python.keras._impl.keras.backend import transpose -from tensorflow.python.keras._impl.keras.backend import truncated_normal -from tensorflow.python.keras._impl.keras.backend import update -from tensorflow.python.keras._impl.keras.backend import update_add -from tensorflow.python.keras._impl.keras.backend import update_sub -from tensorflow.python.keras._impl.keras.backend import var -from tensorflow.python.keras._impl.keras.backend import variable -from tensorflow.python.keras._impl.keras.backend import zeros -from tensorflow.python.keras._impl.keras.backend import zeros_like +from tensorflow.python.keras.backend import abs +from tensorflow.python.keras.backend import all +from tensorflow.python.keras.backend import any +from tensorflow.python.keras.backend import arange +from tensorflow.python.keras.backend import argmax +from tensorflow.python.keras.backend import argmin +from tensorflow.python.keras.backend import backend +from tensorflow.python.keras.backend import batch_dot +from tensorflow.python.keras.backend import batch_flatten +from tensorflow.python.keras.backend import batch_get_value +from tensorflow.python.keras.backend import batch_normalization +from tensorflow.python.keras.backend import batch_set_value +from tensorflow.python.keras.backend import bias_add +from tensorflow.python.keras.backend import binary_crossentropy +from tensorflow.python.keras.backend import cast +from tensorflow.python.keras.backend import cast_to_floatx +from tensorflow.python.keras.backend import categorical_crossentropy +from tensorflow.python.keras.backend import clear_session +from tensorflow.python.keras.backend import clip +from tensorflow.python.keras.backend import concatenate +from tensorflow.python.keras.backend import constant +from tensorflow.python.keras.backend import conv1d +from tensorflow.python.keras.backend import conv2d +from tensorflow.python.keras.backend import conv2d_transpose +from tensorflow.python.keras.backend import conv3d +from tensorflow.python.keras.backend import cos +from tensorflow.python.keras.backend import count_params +from tensorflow.python.keras.backend import ctc_batch_cost +from tensorflow.python.keras.backend import ctc_decode +from tensorflow.python.keras.backend import ctc_label_dense_to_sparse +from tensorflow.python.keras.backend import dot +from tensorflow.python.keras.backend import dropout +from tensorflow.python.keras.backend import dtype +from tensorflow.python.keras.backend import elu +from tensorflow.python.keras.backend import epsilon +from tensorflow.python.keras.backend import equal +from tensorflow.python.keras.backend import eval +from tensorflow.python.keras.backend import exp +from tensorflow.python.keras.backend import expand_dims +from tensorflow.python.keras.backend import eye +from tensorflow.python.keras.backend import flatten +from tensorflow.python.keras.backend import floatx +from tensorflow.python.keras.backend import foldl +from tensorflow.python.keras.backend import foldr +from tensorflow.python.keras.backend import function +from tensorflow.python.keras.backend import gather +from tensorflow.python.keras.backend import get_session +from tensorflow.python.keras.backend import get_uid +from tensorflow.python.keras.backend import get_value +from tensorflow.python.keras.backend import gradients +from tensorflow.python.keras.backend import greater +from tensorflow.python.keras.backend import greater_equal +from tensorflow.python.keras.backend import hard_sigmoid +from tensorflow.python.keras.backend import image_data_format +from tensorflow.python.keras.backend import in_test_phase +from tensorflow.python.keras.backend import in_top_k +from tensorflow.python.keras.backend import in_train_phase +from tensorflow.python.keras.backend import int_shape +from tensorflow.python.keras.backend import is_sparse +from tensorflow.python.keras.backend import l2_normalize +from tensorflow.python.keras.backend import learning_phase +from tensorflow.python.keras.backend import less +from tensorflow.python.keras.backend import less_equal +from tensorflow.python.keras.backend import log +from tensorflow.python.keras.backend import manual_variable_initialization +from tensorflow.python.keras.backend import map_fn +from tensorflow.python.keras.backend import max +from tensorflow.python.keras.backend import maximum +from tensorflow.python.keras.backend import mean +from tensorflow.python.keras.backend import min +from tensorflow.python.keras.backend import minimum +from tensorflow.python.keras.backend import moving_average_update +from tensorflow.python.keras.backend import name_scope +from tensorflow.python.keras.backend import ndim +from tensorflow.python.keras.backend import normalize_batch_in_training +from tensorflow.python.keras.backend import not_equal +from tensorflow.python.keras.backend import one_hot +from tensorflow.python.keras.backend import ones +from tensorflow.python.keras.backend import ones_like +from tensorflow.python.keras.backend import permute_dimensions +from tensorflow.python.keras.backend import placeholder +from tensorflow.python.keras.backend import pool2d +from tensorflow.python.keras.backend import pool3d +from tensorflow.python.keras.backend import pow +from tensorflow.python.keras.backend import print_tensor +from tensorflow.python.keras.backend import prod +from tensorflow.python.keras.backend import random_binomial +from tensorflow.python.keras.backend import random_normal +from tensorflow.python.keras.backend import random_normal_variable +from tensorflow.python.keras.backend import random_uniform +from tensorflow.python.keras.backend import random_uniform_variable +from tensorflow.python.keras.backend import relu +from tensorflow.python.keras.backend import repeat +from tensorflow.python.keras.backend import repeat_elements +from tensorflow.python.keras.backend import reset_uids +from tensorflow.python.keras.backend import reshape +from tensorflow.python.keras.backend import resize_images +from tensorflow.python.keras.backend import resize_volumes +from tensorflow.python.keras.backend import reverse +from tensorflow.python.keras.backend import rnn +from tensorflow.python.keras.backend import round +from tensorflow.python.keras.backend import separable_conv2d +from tensorflow.python.keras.backend import set_epsilon +from tensorflow.python.keras.backend import set_floatx +from tensorflow.python.keras.backend import set_image_data_format +from tensorflow.python.keras.backend import set_learning_phase +from tensorflow.python.keras.backend import set_session +from tensorflow.python.keras.backend import set_value +from tensorflow.python.keras.backend import shape +from tensorflow.python.keras.backend import sigmoid +from tensorflow.python.keras.backend import sign +from tensorflow.python.keras.backend import sin +from tensorflow.python.keras.backend import softmax +from tensorflow.python.keras.backend import softplus +from tensorflow.python.keras.backend import softsign +from tensorflow.python.keras.backend import sparse_categorical_crossentropy +from tensorflow.python.keras.backend import spatial_2d_padding +from tensorflow.python.keras.backend import spatial_3d_padding +from tensorflow.python.keras.backend import sqrt +from tensorflow.python.keras.backend import square +from tensorflow.python.keras.backend import squeeze +from tensorflow.python.keras.backend import stack +from tensorflow.python.keras.backend import std +from tensorflow.python.keras.backend import stop_gradient +from tensorflow.python.keras.backend import sum +from tensorflow.python.keras.backend import switch +from tensorflow.python.keras.backend import tanh +from tensorflow.python.keras.backend import temporal_padding +from tensorflow.python.keras.backend import to_dense +from tensorflow.python.keras.backend import transpose +from tensorflow.python.keras.backend import truncated_normal +from tensorflow.python.keras.backend import update +from tensorflow.python.keras.backend import update_add +from tensorflow.python.keras.backend import update_sub +from tensorflow.python.keras.backend import var +from tensorflow.python.keras.backend import variable +from tensorflow.python.keras.backend import zeros +from tensorflow.python.keras.backend import zeros_like del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py index 2d884790ddb9ccf49649c6af4cfd40cddbc38cb3..10e05f2969bc404d4cf3a9b7a999510cd40e3c17 100644 --- a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py +++ b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py @@ -18,19 +18,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.callbacks import BaseLogger -from tensorflow.python.keras._impl.keras.callbacks import Callback -from tensorflow.python.keras._impl.keras.callbacks import CSVLogger -from tensorflow.python.keras._impl.keras.callbacks import EarlyStopping -from tensorflow.python.keras._impl.keras.callbacks import History -from tensorflow.python.keras._impl.keras.callbacks import LambdaCallback -from tensorflow.python.keras._impl.keras.callbacks import LearningRateScheduler -from tensorflow.python.keras._impl.keras.callbacks import ModelCheckpoint -from tensorflow.python.keras._impl.keras.callbacks import ProgbarLogger -from tensorflow.python.keras._impl.keras.callbacks import ReduceLROnPlateau -from tensorflow.python.keras._impl.keras.callbacks import RemoteMonitor -from tensorflow.python.keras._impl.keras.callbacks import TensorBoard -from tensorflow.python.keras._impl.keras.callbacks import TerminateOnNaN +from tensorflow.python.keras.callbacks import BaseLogger +from tensorflow.python.keras.callbacks import Callback +from tensorflow.python.keras.callbacks import CSVLogger +from tensorflow.python.keras.callbacks import EarlyStopping +from tensorflow.python.keras.callbacks import History +from tensorflow.python.keras.callbacks import LambdaCallback +from tensorflow.python.keras.callbacks import LearningRateScheduler +from tensorflow.python.keras.callbacks import ModelCheckpoint +from tensorflow.python.keras.callbacks import ProgbarLogger +from tensorflow.python.keras.callbacks import ReduceLROnPlateau +from tensorflow.python.keras.callbacks import RemoteMonitor +from tensorflow.python.keras.callbacks import TensorBoard +from tensorflow.python.keras.callbacks import TerminateOnNaN del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/constraints/__init__.py b/tensorflow/contrib/keras/api/keras/constraints/__init__.py index 152606d8ebbcadf57d971d508e15283da65e4aa3..08debf974ec3a36174c353ecaf9e425a9afc3f36 100644 --- a/tensorflow/contrib/keras/api/keras/constraints/__init__.py +++ b/tensorflow/contrib/keras/api/keras/constraints/__init__.py @@ -19,21 +19,21 @@ from __future__ import division from __future__ import print_function # Constraints functions / callable classes. -from tensorflow.python.keras._impl.keras.constraints import Constraint -from tensorflow.python.keras._impl.keras.constraints import max_norm -from tensorflow.python.keras._impl.keras.constraints import MaxNorm -from tensorflow.python.keras._impl.keras.constraints import min_max_norm -from tensorflow.python.keras._impl.keras.constraints import MinMaxNorm -from tensorflow.python.keras._impl.keras.constraints import non_neg -from tensorflow.python.keras._impl.keras.constraints import NonNeg -from tensorflow.python.keras._impl.keras.constraints import unit_norm -from tensorflow.python.keras._impl.keras.constraints import UnitNorm +from tensorflow.python.keras.constraints import Constraint +from tensorflow.python.keras.constraints import max_norm +from tensorflow.python.keras.constraints import MaxNorm +from tensorflow.python.keras.constraints import min_max_norm +from tensorflow.python.keras.constraints import MinMaxNorm +from tensorflow.python.keras.constraints import non_neg +from tensorflow.python.keras.constraints import NonNeg +from tensorflow.python.keras.constraints import unit_norm +from tensorflow.python.keras.constraints import UnitNorm # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.constraints import deserialize -from tensorflow.python.keras._impl.keras.constraints import serialize -from tensorflow.python.keras._impl.keras.constraints import get +from tensorflow.python.keras.constraints import deserialize +from tensorflow.python.keras.constraints import serialize +from tensorflow.python.keras.constraints import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py index b5371a03fd5f5755ba8844415276113c565f52db..a5a6fdab445d2d5328f203b6a704f89e9bb4ce67 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.boston_housing import load_data +from tensorflow.python.keras.datasets.boston_housing import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py index 68d3eb789ea2c410095c0c75e0b79a9b07d209a3..e74e5f347df2eeb626cd781c54c9a7b76561d4e9 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.cifar10 import load_data +from tensorflow.python.keras.datasets.cifar10 import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py index ca93742673341660ba69712feb59c5dd32ea3252..8f5753a6360dfbddb5678c4f2c02adff86b5f0cb 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.cifar100 import load_data +from tensorflow.python.keras.datasets.cifar100 import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py index 1c6396d2d32b88eaa900a5af4e62c7484fceab63..bd6ec4b8dfb0344ad0b89956939607ef51bb0889 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.imdb import get_word_index -from tensorflow.python.keras._impl.keras.datasets.imdb import load_data +from tensorflow.python.keras.datasets.imdb import get_word_index +from tensorflow.python.keras.datasets.imdb import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py index 364255f3387b59a419c010db9b93cdfbcba36186..f61145655bd5d98965e15fecd387d538e9bc642b 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.mnist import load_data +from tensorflow.python.keras.datasets.mnist import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py index bb6791a344ad0c372ac60cd4a332f5632841dd46..ade31f4ea9c33204a4350e6bc3a5a2469e54fd61 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.reuters import get_word_index -from tensorflow.python.keras._impl.keras.datasets.reuters import load_data +from tensorflow.python.keras.datasets.reuters import get_word_index +from tensorflow.python.keras.datasets.reuters import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/initializers/__init__.py b/tensorflow/contrib/keras/api/keras/initializers/__init__.py index 6b1fcfd2d9585d19ae3fd9705e128b19b1ec40e7..c6bdc4f0dac3f446238dc4cbc72fe4be278a5ff6 100644 --- a/tensorflow/contrib/keras/api/keras/initializers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/initializers/__init__.py @@ -19,30 +19,30 @@ from __future__ import division from __future__ import print_function # Initializer functions / callable classes. -from tensorflow.python.keras._impl.keras.initializers import Constant -from tensorflow.python.keras._impl.keras.initializers import Identity -from tensorflow.python.keras._impl.keras.initializers import Initializer -from tensorflow.python.keras._impl.keras.initializers import Ones -from tensorflow.python.keras._impl.keras.initializers import Orthogonal -from tensorflow.python.keras._impl.keras.initializers import RandomNormal -from tensorflow.python.keras._impl.keras.initializers import RandomUniform -from tensorflow.python.keras._impl.keras.initializers import TruncatedNormal -from tensorflow.python.keras._impl.keras.initializers import VarianceScaling -from tensorflow.python.keras._impl.keras.initializers import Zeros +from tensorflow.python.keras.initializers import Constant +from tensorflow.python.keras.initializers import Identity +from tensorflow.python.keras.initializers import Initializer +from tensorflow.python.keras.initializers import Ones +from tensorflow.python.keras.initializers import Orthogonal +from tensorflow.python.keras.initializers import RandomNormal +from tensorflow.python.keras.initializers import RandomUniform +from tensorflow.python.keras.initializers import TruncatedNormal +from tensorflow.python.keras.initializers import VarianceScaling +from tensorflow.python.keras.initializers import Zeros # Functional interface. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.initializers import glorot_normal -from tensorflow.python.keras._impl.keras.initializers import glorot_uniform -from tensorflow.python.keras._impl.keras.initializers import he_normal -from tensorflow.python.keras._impl.keras.initializers import he_uniform -from tensorflow.python.keras._impl.keras.initializers import lecun_normal -from tensorflow.python.keras._impl.keras.initializers import lecun_uniform +from tensorflow.python.keras.initializers import glorot_normal +from tensorflow.python.keras.initializers import glorot_uniform +from tensorflow.python.keras.initializers import he_normal +from tensorflow.python.keras.initializers import he_uniform +from tensorflow.python.keras.initializers import lecun_normal +from tensorflow.python.keras.initializers import lecun_uniform # Auxiliary utils. -from tensorflow.python.keras._impl.keras.initializers import deserialize -from tensorflow.python.keras._impl.keras.initializers import serialize -from tensorflow.python.keras._impl.keras.initializers import get +from tensorflow.python.keras.initializers import deserialize +from tensorflow.python.keras.initializers import serialize +from tensorflow.python.keras.initializers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/layers/__init__.py b/tensorflow/contrib/keras/api/keras/layers/__init__.py index acf0a5e1799b7c57dfd82861c9ccc1f132c34375..938c881fcbe18623fa18c21c112375f9914f887b 100644 --- a/tensorflow/contrib/keras/api/keras/layers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/layers/__init__.py @@ -20,128 +20,128 @@ from __future__ import print_function # Generic layers. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.engine import Input -from tensorflow.python.keras._impl.keras.engine import InputLayer -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras.engine import Input +from tensorflow.python.keras.engine import InputLayer +from tensorflow.python.keras.engine import InputSpec +from tensorflow.python.keras.engine import Layer # Advanced activations. -from tensorflow.python.keras._impl.keras.layers.advanced_activations import LeakyReLU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import PReLU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import ELU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import ThresholdedReLU +from tensorflow.python.keras.layers.advanced_activations import LeakyReLU +from tensorflow.python.keras.layers.advanced_activations import PReLU +from tensorflow.python.keras.layers.advanced_activations import ELU +from tensorflow.python.keras.layers.advanced_activations import ThresholdedReLU # Convolution layers. -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv2D +from tensorflow.python.keras.layers.convolutional import Conv1D +from tensorflow.python.keras.layers.convolutional import Conv2D +from tensorflow.python.keras.layers.convolutional import Conv3D +from tensorflow.python.keras.layers.convolutional import Conv2DTranspose +from tensorflow.python.keras.layers.convolutional import Conv3DTranspose +from tensorflow.python.keras.layers.convolutional import SeparableConv2D # Convolution layer aliases. -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution2D +from tensorflow.python.keras.layers.convolutional import Convolution1D +from tensorflow.python.keras.layers.convolutional import Convolution2D +from tensorflow.python.keras.layers.convolutional import Convolution3D +from tensorflow.python.keras.layers.convolutional import Convolution2DTranspose +from tensorflow.python.keras.layers.convolutional import Convolution3DTranspose +from tensorflow.python.keras.layers.convolutional import SeparableConvolution2D # Image processing layers. -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling1D -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling2D -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling3D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding1D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding2D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping3D +from tensorflow.python.keras.layers.convolutional import UpSampling1D +from tensorflow.python.keras.layers.convolutional import UpSampling2D +from tensorflow.python.keras.layers.convolutional import UpSampling3D +from tensorflow.python.keras.layers.convolutional import ZeroPadding1D +from tensorflow.python.keras.layers.convolutional import ZeroPadding2D +from tensorflow.python.keras.layers.convolutional import ZeroPadding3D +from tensorflow.python.keras.layers.convolutional import Cropping1D +from tensorflow.python.keras.layers.convolutional import Cropping2D +from tensorflow.python.keras.layers.convolutional import Cropping3D # Convolutional-recurrent layers. -from tensorflow.python.keras._impl.keras.layers.convolutional_recurrent import ConvLSTM2D +from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D # Core layers. -from tensorflow.python.keras._impl.keras.layers.core import Masking -from tensorflow.python.keras._impl.keras.layers.core import Dropout -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout1D -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout2D -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout3D -from tensorflow.python.keras._impl.keras.layers.core import Activation -from tensorflow.python.keras._impl.keras.layers.core import Reshape -from tensorflow.python.keras._impl.keras.layers.core import Permute -from tensorflow.python.keras._impl.keras.layers.core import Flatten -from tensorflow.python.keras._impl.keras.layers.core import RepeatVector -from tensorflow.python.keras._impl.keras.layers.core import Lambda -from tensorflow.python.keras._impl.keras.layers.core import Dense -from tensorflow.python.keras._impl.keras.layers.core import ActivityRegularization +from tensorflow.python.keras.layers.core import Masking +from tensorflow.python.keras.layers.core import Dropout +from tensorflow.python.keras.layers.core import SpatialDropout1D +from tensorflow.python.keras.layers.core import SpatialDropout2D +from tensorflow.python.keras.layers.core import SpatialDropout3D +from tensorflow.python.keras.layers.core import Activation +from tensorflow.python.keras.layers.core import Reshape +from tensorflow.python.keras.layers.core import Permute +from tensorflow.python.keras.layers.core import Flatten +from tensorflow.python.keras.layers.core import RepeatVector +from tensorflow.python.keras.layers.core import Lambda +from tensorflow.python.keras.layers.core import Dense +from tensorflow.python.keras.layers.core import ActivityRegularization # Embedding layers. -from tensorflow.python.keras._impl.keras.layers.embeddings import Embedding +from tensorflow.python.keras.layers.embeddings import Embedding # Locally-connected layers. -from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected1D -from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected2D +from tensorflow.python.keras.layers.local import LocallyConnected1D +from tensorflow.python.keras.layers.local import LocallyConnected2D # Merge layers. -from tensorflow.python.keras._impl.keras.layers.merge import Add -from tensorflow.python.keras._impl.keras.layers.merge import Multiply -from tensorflow.python.keras._impl.keras.layers.merge import Average -from tensorflow.python.keras._impl.keras.layers.merge import Maximum -from tensorflow.python.keras._impl.keras.layers.merge import Concatenate -from tensorflow.python.keras._impl.keras.layers.merge import Dot -from tensorflow.python.keras._impl.keras.layers.merge import add -from tensorflow.python.keras._impl.keras.layers.merge import multiply -from tensorflow.python.keras._impl.keras.layers.merge import average -from tensorflow.python.keras._impl.keras.layers.merge import maximum -from tensorflow.python.keras._impl.keras.layers.merge import concatenate -from tensorflow.python.keras._impl.keras.layers.merge import dot +from tensorflow.python.keras.layers.merge import Add +from tensorflow.python.keras.layers.merge import Multiply +from tensorflow.python.keras.layers.merge import Average +from tensorflow.python.keras.layers.merge import Maximum +from tensorflow.python.keras.layers.merge import Concatenate +from tensorflow.python.keras.layers.merge import Dot +from tensorflow.python.keras.layers.merge import add +from tensorflow.python.keras.layers.merge import multiply +from tensorflow.python.keras.layers.merge import average +from tensorflow.python.keras.layers.merge import maximum +from tensorflow.python.keras.layers.merge import concatenate +from tensorflow.python.keras.layers.merge import dot # Noise layers. -from tensorflow.python.keras._impl.keras.layers.noise import AlphaDropout -from tensorflow.python.keras._impl.keras.layers.noise import GaussianNoise -from tensorflow.python.keras._impl.keras.layers.noise import GaussianDropout +from tensorflow.python.keras.layers.noise import AlphaDropout +from tensorflow.python.keras.layers.noise import GaussianNoise +from tensorflow.python.keras.layers.noise import GaussianDropout # Normalization layers. -from tensorflow.python.keras._impl.keras.layers.normalization import BatchNormalization +from tensorflow.python.keras.layers.normalization import BatchNormalization # Pooling layers. -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling3D +from tensorflow.python.keras.layers.pooling import MaxPooling1D +from tensorflow.python.keras.layers.pooling import MaxPooling2D +from tensorflow.python.keras.layers.pooling import MaxPooling3D +from tensorflow.python.keras.layers.pooling import AveragePooling1D +from tensorflow.python.keras.layers.pooling import AveragePooling2D +from tensorflow.python.keras.layers.pooling import AveragePooling3D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling1D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling2D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling3D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling1D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling2D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling3D # Pooling layer aliases. -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool3D +from tensorflow.python.keras.layers.pooling import MaxPool1D +from tensorflow.python.keras.layers.pooling import MaxPool2D +from tensorflow.python.keras.layers.pooling import MaxPool3D +from tensorflow.python.keras.layers.pooling import AvgPool1D +from tensorflow.python.keras.layers.pooling import AvgPool2D +from tensorflow.python.keras.layers.pooling import AvgPool3D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool1D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool2D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool3D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool1D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool2D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool3D # Recurrent layers. -from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNN -from tensorflow.python.keras._impl.keras.layers.recurrent import GRU -from tensorflow.python.keras._impl.keras.layers.recurrent import LSTM +from tensorflow.python.keras.layers.recurrent import SimpleRNN +from tensorflow.python.keras.layers.recurrent import GRU +from tensorflow.python.keras.layers.recurrent import LSTM # Wrapper functions -from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper -from tensorflow.python.keras._impl.keras.layers.wrappers import Bidirectional -from tensorflow.python.keras._impl.keras.layers.wrappers import TimeDistributed +from tensorflow.python.keras.layers.wrappers import Wrapper +from tensorflow.python.keras.layers.wrappers import Bidirectional +from tensorflow.python.keras.layers.wrappers import TimeDistributed del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/losses/__init__.py b/tensorflow/contrib/keras/api/keras/losses/__init__.py index 66721b694f5fd5fae7ca521ff56d4c6c6bce79b5..c4476a7bbd5056fa898468a46031bf3d8b1e44cf 100644 --- a/tensorflow/contrib/keras/api/keras/losses/__init__.py +++ b/tensorflow/contrib/keras/api/keras/losses/__init__.py @@ -19,26 +19,26 @@ from __future__ import division from __future__ import print_function # Loss functions. -from tensorflow.python.keras._impl.keras.losses import binary_crossentropy -from tensorflow.python.keras._impl.keras.losses import categorical_crossentropy -from tensorflow.python.keras._impl.keras.losses import categorical_hinge -from tensorflow.python.keras._impl.keras.losses import cosine_proximity -from tensorflow.python.keras._impl.keras.losses import hinge -from tensorflow.python.keras._impl.keras.losses import kullback_leibler_divergence -from tensorflow.python.keras._impl.keras.losses import logcosh -from tensorflow.python.keras._impl.keras.losses import mean_absolute_error -from tensorflow.python.keras._impl.keras.losses import mean_absolute_percentage_error -from tensorflow.python.keras._impl.keras.losses import mean_squared_error -from tensorflow.python.keras._impl.keras.losses import mean_squared_logarithmic_error -from tensorflow.python.keras._impl.keras.losses import poisson -from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.losses import squared_hinge +from tensorflow.python.keras.losses import binary_crossentropy +from tensorflow.python.keras.losses import categorical_crossentropy +from tensorflow.python.keras.losses import categorical_hinge +from tensorflow.python.keras.losses import cosine_proximity +from tensorflow.python.keras.losses import hinge +from tensorflow.python.keras.losses import kullback_leibler_divergence +from tensorflow.python.keras.losses import logcosh +from tensorflow.python.keras.losses import mean_absolute_error +from tensorflow.python.keras.losses import mean_absolute_percentage_error +from tensorflow.python.keras.losses import mean_squared_error +from tensorflow.python.keras.losses import mean_squared_logarithmic_error +from tensorflow.python.keras.losses import poisson +from tensorflow.python.keras.losses import sparse_categorical_crossentropy +from tensorflow.python.keras.losses import squared_hinge # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.losses import deserialize -from tensorflow.python.keras._impl.keras.losses import serialize -from tensorflow.python.keras._impl.keras.losses import get +from tensorflow.python.keras.losses import deserialize +from tensorflow.python.keras.losses import serialize +from tensorflow.python.keras.losses import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/metrics/__init__.py b/tensorflow/contrib/keras/api/keras/metrics/__init__.py index 59faf037bce0f087d244a2faaeb52713bdc3b772..7317fdb52c5b79e787a49d71be49f5261d6b1fff 100644 --- a/tensorflow/contrib/keras/api/keras/metrics/__init__.py +++ b/tensorflow/contrib/keras/api/keras/metrics/__init__.py @@ -19,28 +19,28 @@ from __future__ import division from __future__ import print_function # Metrics functions. -from tensorflow.python.keras._impl.keras.metrics import binary_accuracy -from tensorflow.python.keras._impl.keras.metrics import binary_crossentropy -from tensorflow.python.keras._impl.keras.metrics import categorical_accuracy -from tensorflow.python.keras._impl.keras.metrics import categorical_crossentropy -from tensorflow.python.keras._impl.keras.metrics import cosine_proximity -from tensorflow.python.keras._impl.keras.metrics import hinge -from tensorflow.python.keras._impl.keras.metrics import kullback_leibler_divergence -from tensorflow.python.keras._impl.keras.metrics import mean_absolute_error -from tensorflow.python.keras._impl.keras.metrics import mean_absolute_percentage_error -from tensorflow.python.keras._impl.keras.metrics import mean_squared_error -from tensorflow.python.keras._impl.keras.metrics import mean_squared_logarithmic_error -from tensorflow.python.keras._impl.keras.metrics import poisson -from tensorflow.python.keras._impl.keras.metrics import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.metrics import sparse_top_k_categorical_accuracy -from tensorflow.python.keras._impl.keras.metrics import squared_hinge -from tensorflow.python.keras._impl.keras.metrics import top_k_categorical_accuracy +from tensorflow.python.keras.metrics import binary_accuracy +from tensorflow.python.keras.metrics import binary_crossentropy +from tensorflow.python.keras.metrics import categorical_accuracy +from tensorflow.python.keras.metrics import categorical_crossentropy +from tensorflow.python.keras.metrics import cosine_proximity +from tensorflow.python.keras.metrics import hinge +from tensorflow.python.keras.metrics import kullback_leibler_divergence +from tensorflow.python.keras.metrics import mean_absolute_error +from tensorflow.python.keras.metrics import mean_absolute_percentage_error +from tensorflow.python.keras.metrics import mean_squared_error +from tensorflow.python.keras.metrics import mean_squared_logarithmic_error +from tensorflow.python.keras.metrics import poisson +from tensorflow.python.keras.metrics import sparse_categorical_crossentropy +from tensorflow.python.keras.metrics import sparse_top_k_categorical_accuracy +from tensorflow.python.keras.metrics import squared_hinge +from tensorflow.python.keras.metrics import top_k_categorical_accuracy # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.metrics import deserialize -from tensorflow.python.keras._impl.keras.metrics import serialize -from tensorflow.python.keras._impl.keras.metrics import get +from tensorflow.python.keras.metrics import deserialize +from tensorflow.python.keras.metrics import serialize +from tensorflow.python.keras.metrics import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/models/__init__.py b/tensorflow/contrib/keras/api/keras/models/__init__.py index 2fb4ac0960d38f28a1c9c897a0f1aedf57e048ac..3a196984cd88cb60fbc2a9db306ce8fecf0febc0 100644 --- a/tensorflow/contrib/keras/api/keras/models/__init__.py +++ b/tensorflow/contrib/keras/api/keras/models/__init__.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.models import load_model -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.models import model_from_config -from tensorflow.python.keras._impl.keras.models import model_from_json -from tensorflow.python.keras._impl.keras.models import model_from_yaml -from tensorflow.python.keras._impl.keras.models import save_model -from tensorflow.python.keras._impl.keras.models import Sequential +from tensorflow.python.keras.models import load_model +from tensorflow.python.keras.models import Model +from tensorflow.python.keras.models import model_from_config +from tensorflow.python.keras.models import model_from_json +from tensorflow.python.keras.models import model_from_yaml +from tensorflow.python.keras.models import save_model +from tensorflow.python.keras.models import Sequential del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/optimizers/__init__.py b/tensorflow/contrib/keras/api/keras/optimizers/__init__.py index 44f47bc47f4a0e31aaf2ac8f67cfdbef410d8c44..4849a06747958ab41b8b6309fa848aff3da3f633 100644 --- a/tensorflow/contrib/keras/api/keras/optimizers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/optimizers/__init__.py @@ -19,20 +19,20 @@ from __future__ import division from __future__ import print_function # Optimizer classes. -from tensorflow.python.keras._impl.keras.optimizers import Adadelta -from tensorflow.python.keras._impl.keras.optimizers import Adagrad -from tensorflow.python.keras._impl.keras.optimizers import Adam -from tensorflow.python.keras._impl.keras.optimizers import Adamax -from tensorflow.python.keras._impl.keras.optimizers import Nadam -from tensorflow.python.keras._impl.keras.optimizers import Optimizer -from tensorflow.python.keras._impl.keras.optimizers import RMSprop -from tensorflow.python.keras._impl.keras.optimizers import SGD +from tensorflow.python.keras.optimizers import Adadelta +from tensorflow.python.keras.optimizers import Adagrad +from tensorflow.python.keras.optimizers import Adam +from tensorflow.python.keras.optimizers import Adamax +from tensorflow.python.keras.optimizers import Nadam +from tensorflow.python.keras.optimizers import Optimizer +from tensorflow.python.keras.optimizers import RMSprop +from tensorflow.python.keras.optimizers import SGD # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.optimizers import deserialize -from tensorflow.python.keras._impl.keras.optimizers import serialize -from tensorflow.python.keras._impl.keras.optimizers import get +from tensorflow.python.keras.optimizers import deserialize +from tensorflow.python.keras.optimizers import serialize +from tensorflow.python.keras.optimizers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py index b96e7675527041d3952b049f5f431d3df36eea4c..1f9e82b41bf09b235e93fa512a50ea4c3047c01b 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py @@ -18,20 +18,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.image import apply_transform -from tensorflow.python.keras._impl.keras.preprocessing.image import array_to_img -from tensorflow.python.keras._impl.keras.preprocessing.image import DirectoryIterator -from tensorflow.python.keras._impl.keras.preprocessing.image import flip_axis -from tensorflow.python.keras._impl.keras.preprocessing.image import ImageDataGenerator -from tensorflow.python.keras._impl.keras.preprocessing.image import img_to_array -from tensorflow.python.keras._impl.keras.preprocessing.image import Iterator -from tensorflow.python.keras._impl.keras.preprocessing.image import load_img -from tensorflow.python.keras._impl.keras.preprocessing.image import NumpyArrayIterator -from tensorflow.python.keras._impl.keras.preprocessing.image import random_channel_shift -from tensorflow.python.keras._impl.keras.preprocessing.image import random_rotation -from tensorflow.python.keras._impl.keras.preprocessing.image import random_shear -from tensorflow.python.keras._impl.keras.preprocessing.image import random_shift -from tensorflow.python.keras._impl.keras.preprocessing.image import random_zoom +from tensorflow.python.keras.preprocessing.image import apply_transform +from tensorflow.python.keras.preprocessing.image import array_to_img +from tensorflow.python.keras.preprocessing.image import DirectoryIterator +from tensorflow.python.keras.preprocessing.image import flip_axis +from tensorflow.python.keras.preprocessing.image import ImageDataGenerator +from tensorflow.python.keras.preprocessing.image import img_to_array +from tensorflow.python.keras.preprocessing.image import Iterator +from tensorflow.python.keras.preprocessing.image import load_img +from tensorflow.python.keras.preprocessing.image import NumpyArrayIterator +from tensorflow.python.keras.preprocessing.image import random_channel_shift +from tensorflow.python.keras.preprocessing.image import random_rotation +from tensorflow.python.keras.preprocessing.image import random_shear +from tensorflow.python.keras.preprocessing.image import random_shift +from tensorflow.python.keras.preprocessing.image import random_zoom del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py index 112f6af5e588bcb2e85fdbecea86f402742d44e7..9a93b6fb57ff5aaab25f2b606249a6022814b5e4 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.sequence import make_sampling_table -from tensorflow.python.keras._impl.keras.preprocessing.sequence import pad_sequences -from tensorflow.python.keras._impl.keras.preprocessing.sequence import skipgrams +from tensorflow.python.keras.preprocessing.sequence import make_sampling_table +from tensorflow.python.keras.preprocessing.sequence import pad_sequences +from tensorflow.python.keras.preprocessing.sequence import skipgrams del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py index 5bf1a2fb21dc27f7aa10cd08b1496e3991c61d2f..86386a9b6762d1c5cb3915ace64686cc25367e0f 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.text import one_hot -from tensorflow.python.keras._impl.keras.preprocessing.text import text_to_word_sequence -from tensorflow.python.keras._impl.keras.preprocessing.text import Tokenizer +from tensorflow.python.keras.preprocessing.text import one_hot +from tensorflow.python.keras.preprocessing.text import text_to_word_sequence +from tensorflow.python.keras.preprocessing.text import Tokenizer del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/regularizers/__init__.py b/tensorflow/contrib/keras/api/keras/regularizers/__init__.py index 3e707ccab577b5e28febd83d91f84d7b1c0d5d82..d668e39c09ca28239e56763f111fb01939bedc69 100644 --- a/tensorflow/contrib/keras/api/keras/regularizers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/regularizers/__init__.py @@ -19,19 +19,19 @@ from __future__ import division from __future__ import print_function # Regularizer functions / callable classes. -from tensorflow.python.keras._impl.keras.regularizers import L1L2 -from tensorflow.python.keras._impl.keras.regularizers import Regularizer +from tensorflow.python.keras.regularizers import L1L2 +from tensorflow.python.keras.regularizers import Regularizer # Functional interface. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.regularizers import l1 -from tensorflow.python.keras._impl.keras.regularizers import l2 -from tensorflow.python.keras._impl.keras.regularizers import l1_l2 +from tensorflow.python.keras.regularizers import l1 +from tensorflow.python.keras.regularizers import l2 +from tensorflow.python.keras.regularizers import l1_l2 # Auxiliary utils. -from tensorflow.python.keras._impl.keras.regularizers import deserialize -from tensorflow.python.keras._impl.keras.regularizers import serialize -from tensorflow.python.keras._impl.keras.regularizers import get +from tensorflow.python.keras.regularizers import deserialize +from tensorflow.python.keras.regularizers import serialize +from tensorflow.python.keras.regularizers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/utils/__init__.py b/tensorflow/contrib/keras/api/keras/utils/__init__.py index a7c2179fe7ad434356921a5fb8709aa5b1f33498..47cd01b924fb43e8a83836c58f8ced61e9e88268 100644 --- a/tensorflow/contrib/keras/api/keras/utils/__init__.py +++ b/tensorflow/contrib/keras/api/keras/utils/__init__.py @@ -18,21 +18,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file -from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence -from tensorflow.python.keras._impl.keras.utils.data_utils import SequenceEnqueuer -from tensorflow.python.keras._impl.keras.utils.generic_utils import custom_object_scope -from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras._impl.keras.utils.generic_utils import get_custom_objects -from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar -from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object -from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix -from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model -from tensorflow.python.keras._impl.keras.utils.np_utils import normalize -from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical -from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model +from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer +from tensorflow.python.keras.utils.data_utils import get_file +from tensorflow.python.keras.utils.data_utils import Sequence +from tensorflow.python.keras.utils.data_utils import SequenceEnqueuer +from tensorflow.python.keras.utils.generic_utils import custom_object_scope +from tensorflow.python.keras.utils.generic_utils import CustomObjectScope +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras.utils.generic_utils import get_custom_objects +from tensorflow.python.keras.utils.generic_utils import Progbar +from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras.utils.io_utils import HDF5Matrix +from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras.utils.np_utils import normalize +from tensorflow.python.keras.utils.np_utils import to_categorical +from tensorflow.python.keras.utils.vis_utils import plot_model del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py b/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py index a46f859273ea0117e29a403057f9f81bc758dd52..c4b7aa765c26bafbfcfe45df02e58d1cf1064b4b 100644 --- a/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py +++ b/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasClassifier -from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasRegressor +from tensorflow.python.keras.wrappers.scikit_learn import KerasClassifier +from tensorflow.python.keras.wrappers.scikit_learn import KerasRegressor del absolute_import del division diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD index eff7dfeb4c1117e40f4faf43c5e92a52cffd6528..87c2dcd89b63fa9f92d93c87abce91fd3460d44e 100644 --- a/tensorflow/contrib/kernel_methods/BUILD +++ b/tensorflow/contrib/kernel_methods/BUILD @@ -90,15 +90,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kernel_methods/python/losses.py b/tensorflow/contrib/kernel_methods/python/losses.py index f182fef067b7f523bc5ca63227265be40528b171..4ef0a66a52429233c6e6f70667a451466493629c 100644 --- a/tensorflow/contrib/kernel_methods/python/losses.py +++ b/tensorflow/contrib/kernel_methods/python/losses.py @@ -43,10 +43,10 @@ def sparse_multiclass_hinge_loss( This is a generalization of standard (binary) hinge loss. For a given instance with correct label c*, the loss is given by: - loss = max_{c != c*} logits_c - logits_{c*} + 1. + $$loss = max_{c != c*} logits_c - logits_{c*} + 1.$$ or equivalently - loss = max_c { logits_c - logits_{c*} + I_{c != c*} } - where I_{c != c*} = 1 if c != c* and 0 otherwise. + $$loss = max_c { logits_c - logits_{c*} + I_{c != c*} }$$ + where \\(I_{c != c*} = 1\ \text{if}\ c != c*\\) and 0 otherwise. Args: labels: `Tensor` of shape [batch_size] or [batch_size, 1]. Corresponds to diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py index 9dc01124ab195ae17b8795a11e4ebefe3f2c746b..9a721a9d440e66eb30bb94daf2b6878318f1e75f 100644 --- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py +++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py @@ -34,33 +34,31 @@ class RandomFourierFeatureMapper(dkm.DenseKernelMapper): r"""Class that implements Random Fourier Feature Mapping (RFFM) in TensorFlow. The RFFM mapping is used to approximate the Gaussian (RBF) kernel: - ``` - exp(-||x-y||_2^2 / (2 * sigma^2)) - ``` + $$(exp(-||x-y||_2^2 / (2 * \sigma^2))$$ The implementation of RFFM is based on the following paper: "Random Features for Large-Scale Kernel Machines" by Ali Rahimi and Ben Recht. (link: https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf) - The mapping uses a matrix `Omega \in R^{d x D}` and a bias vector `b \in R^D` - where `d` is the input dimension (number of dense input features) and `D` is - the output dimension (i.e., dimension of the feature space the input is mapped - to). Each entry of `Omega` is sampled i.i.d. from a (scaled) Gaussian - distribution and each entry of `b` is sampled independently and uniformly from - [0, 2 * pi]. - - For a single input feature vector x in R^d, its RFFM is defined as: - ``` - sqrt(2/D) * cos(x * Omega + b) - ``` - where `cos` is the element-wise cosine function and `x, b` are represented as - row vectors. The aforementioned paper shows that the linear kernel of - RFFM-mapped vectors approximates the Gaussian kernel of the initial vectors. + The mapping uses a matrix \\(\Omega \in R^{d x D}\\) and a bias vector + \\(b \in R^D\\) where \\(d\\) is the input dimension (number of dense input + features) and \\(D\\) is the output dimension (i.e., dimension of the feature + space the input is mapped to). Each entry of \\(\Omega\\) is sampled i.i.d. + from a (scaled) Gaussian distribution and each entry of \\(b\\) is sampled + independently and uniformly from [0, \\(2 * \pi\\)]. + + For a single input feature vector \\(x \in R^d\\), its RFFM is defined as: + $$\sqrt(2/D) * cos(x * \Omega + b)$$ + + where \\(cos\\) is the element-wise cosine function and \\(x, b\\) are + represented as row vectors. The aforementioned paper shows that the linear + kernel of RFFM-mapped vectors approximates the Gaussian kernel of the initial + vectors. """ def __init__(self, input_dim, output_dim, stddev=1.0, seed=1, name=None): - """Constructs a RandomFourierFeatureMapper instance. + r"""Constructs a RandomFourierFeatureMapper instance. Args: input_dim: The dimension (number of features) of the tensors to be mapped. @@ -68,11 +66,11 @@ class RandomFourierFeatureMapper(dkm.DenseKernelMapper): stddev: The standard deviation of the Gaussian kernel to be approximated. The error of the classifier trained using this approximation is very sensitive to this parameter. - seed: An integer used to initialize the parameters (`Omega` and `b`) of - the mapper. For repeatable sequences across different invocations of the - mapper object (for instance, to ensure consistent mapping both at - training and eval/inference if these happen in different invocations), - set this to the same integer. + seed: An integer used to initialize the parameters (\\(\Omega\\) and + \\(b\\)) of the mapper. For repeatable sequences across different + invocations of the mapper object (for instance, to ensure consistent + mapping both at training and eval/inference if these happen in + different invocations), set this to the same integer. name: name for the mapper object. """ # TODO(sibyl-vie3Poto): Maybe infer input_dim and/or output_dim (if not explicitly diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py index 6f4a264485993ab737723171409042b4a9673669..2ff4d41d75fe59fb765a83e1b6a5b3eaad9d9163 100644 --- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py +++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py @@ -31,10 +31,10 @@ from tensorflow.python.platform import googletest def _inner_product(x, y): - """Inner product between tensors x and y. + r"""Inner product between tensors x and y. The input tensors are assumed to be in ROW representation, that is, the method - returns x * y^T. + returns \\(x * y^T\\). Args: x: input tensor in row format @@ -131,10 +131,6 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase): mapped_dim = 5000 stddev = 5.0 - # TODO(sibyl-vie3Poto): Reduce test's running time before moving to third_party. One - # possible way to speed the test up is to compute both the approximate and - # the exact kernel matrix directly using matrix operations instead of - # computing the values for each pair of points separately. points_shape = [1, input_dim] points = [ random_ops.random_uniform(shape=points_shape, maxval=1.0) diff --git a/tensorflow/contrib/kfac/BUILD b/tensorflow/contrib/kfac/BUILD index 9a5759bf14f753bbc50d3ef8f54ceab7daf745ab..b719046b37ac761d56e8d5aa34772103be691cd6 100644 --- a/tensorflow/contrib/kfac/BUILD +++ b/tensorflow/contrib/kfac/BUILD @@ -24,15 +24,3 @@ py_library( "//tensorflow/python:util", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md index 762a2f0b57e95e2fef3dd177070701afb410e93a..102626925db560e47cdc73eb1e25e08836cb4fba 100644 --- a/tensorflow/contrib/kfac/README.md +++ b/tensorflow/contrib/kfac/README.md @@ -1,5 +1,10 @@ # K-FAC: Kronecker-Factored Approximate Curvature +# WARNING: +# ==third_party/tensorflow/contrib/kfac is deprecated. This will be== +# ==removed on 15-07-2018. Please import third_party/tensorflow_kfac.== +# ==== + **K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an approximate second-order optimization method, in TensorFlow. When applied to feedforward and convolutional neural networks, K-FAC can converge `>3.5x` diff --git a/tensorflow/contrib/kfac/examples/BUILD b/tensorflow/contrib/kfac/examples/BUILD index 89965eda374b2b403f680fc77eb923d0e660d1e2..8186fa1c62cb952f86614a96c3965bcddae1686e 100644 --- a/tensorflow/contrib/kfac/examples/BUILD +++ b/tensorflow/contrib/kfac/examples/BUILD @@ -28,8 +28,28 @@ py_library( ) py_binary( - name = "convnet_mnist_main", - srcs = ["convnet_mnist_main.py"], + name = "convnet_mnist_single_main", + srcs = ["convnet_mnist_single_main.py"], + srcs_version = "PY2AND3", + deps = [ + ":convnet", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "convnet_mnist_multi_tower_main", + srcs = ["convnet_mnist_multi_tower_main.py"], + srcs_version = "PY2AND3", + deps = [ + ":convnet", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "convnet_mnist_distributed_main", + srcs = ["convnet_mnist_distributed_main.py"], srcs_version = "PY2AND3", deps = [ ":convnet", @@ -58,15 +78,3 @@ py_library( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py index 39d80addaac1fe855a37255b32bf4412b99df46a..d6b1a61b716ab7412f6b09ba2cfbc4325f790637 100644 --- a/tensorflow/contrib/kfac/examples/convnet.py +++ b/tensorflow/contrib/kfac/examples/convnet.py @@ -37,6 +37,8 @@ import tensorflow as tf from tensorflow.contrib.kfac.examples import mlp from tensorflow.contrib.kfac.examples import mnist +from tensorflow.contrib.kfac.python.ops import optimizer as opt + lc = tf.contrib.kfac.layer_collection oq = tf.contrib.kfac.op_queue @@ -48,12 +50,18 @@ __all__ = [ "linear_layer", "build_model", "minimize_loss_single_machine", - "minimize_loss_distributed", + "distributed_grads_only_and_ops_chief_worker", + "distributed_grads_and_ops_dedicated_workers", "train_mnist_single_machine", - "train_mnist_distributed", + "train_mnist_distributed_sync_replicas", + "train_mnist_multitower" ] +# Inverse update ops will be run every _INVERT_EVRY iterations. +_INVERT_EVERY = 10 + + def conv_layer(layer_id, inputs, kernel_size, out_channels): """Builds a convolutional layer with ReLU non-linearity. @@ -161,8 +169,9 @@ def build_model(examples, labels, num_labels, layer_collection): accuracy = tf.reduce_mean( tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32)) - tf.summary.scalar("loss", loss) - tf.summary.scalar("accuracy", accuracy) + with tf.device("/cpu:0"): + tf.summary.scalar("loss", loss) + tf.summary.scalar("accuracy", accuracy) # Register parameters. K-FAC needs to know about the inputs, outputs, and # parameters of each conv/fully connected layer and the logits powering the @@ -181,41 +190,59 @@ def build_model(examples, labels, num_labels, layer_collection): def minimize_loss_single_machine(loss, accuracy, layer_collection, + device="/gpu:0", session_config=None): """Minimize loss with K-FAC on a single machine. - A single Session is responsible for running all of K-FAC's ops. + A single Session is responsible for running all of K-FAC's ops. The covariance + and inverse update ops are placed on `device`. All model variables are on CPU. Args: loss: 0-D Tensor. Loss to be minimized. accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. layer_collection: LayerCollection instance describing model architecture. Used by K-FAC to construct preconditioner. + device: string, Either '/cpu:0' or '/gpu:0'. The covaraince and invserse + update ops are run on this device. session_config: None or tf.ConfigProto. Configuration for tf.Session(). Returns: final value for 'accuracy'. """ # Train with K-FAC. - global_step = tf.train.get_or_create_global_step() + g_step = tf.train.get_or_create_global_step() optimizer = opt.KfacOptimizer( learning_rate=0.0001, cov_ema_decay=0.95, damping=0.001, layer_collection=layer_collection, + placement_strategy="round_robin", + cov_devices=[device], + inv_devices=[device], momentum=0.9) - train_op = optimizer.minimize(loss, global_step=global_step) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + def make_update_op(update_thunks): + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) + + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([cov_update_op]): + inverse_op = tf.cond( + tf.equal(tf.mod(g_step, _INVERT_EVERY), 0), + lambda: make_update_op(inv_update_thunks), tf.no_op) + with tf.control_dependencies([inverse_op]): + with tf.device(device): + train_op = optimizer.minimize(loss, global_step=g_step) tf.logging.info("Starting training.") with tf.train.MonitoredTrainingSession(config=session_config) as sess: while not sess.should_stop(): - global_step_, loss_, accuracy_, _, _ = sess.run( - [global_step, loss, accuracy, train_op, optimizer.cov_update_op]) - - if global_step_ % 100 == 0: - sess.run(optimizer.inv_update_op) + global_step_, loss_, accuracy_, _ = sess.run( + [g_step, loss, accuracy, train_op]) - if global_step_ % 100 == 0: + if global_step_ % _INVERT_EVERY == 0: tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, loss_, accuracy_) @@ -250,16 +277,62 @@ def _num_gradient_tasks(num_tasks): return int(np.ceil(0.6 * num_tasks)) -def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, - checkpoint_dir, loss, accuracy, layer_collection): - """Minimize loss with an synchronous implementation of K-FAC. +def _make_distributed_train_op( + task_id, + num_worker_tasks, + num_ps_tasks, + layer_collection +): + """Creates optimizer and distributed training op. + + Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes + the train op. + + Args: + task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + num_worker_tasks: int. Number of workers in this distributed training setup. + num_ps_tasks: int. Number of parameter servers holding variables. If 0, + parameter servers are not used. + layer_collection: LayerCollection instance describing model architecture. + Used by K-FAC to construct preconditioner. + + Returns: + sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC + optimizer. + optimizer: Instance of `opt.KfacOptimizer`. + global_step: `tensor`, Global step. + """ + tf.logging.info("Task id : %d", task_id) + with tf.device(tf.train.replica_device_setter(num_ps_tasks)): + global_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=0.0001, + cov_ema_decay=0.95, + damping=0.001, + layer_collection=layer_collection, + momentum=0.9) + sync_optimizer = tf.train.SyncReplicasOptimizer( + opt=optimizer, + replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks), + total_num_replicas=num_worker_tasks) + return sync_optimizer, optimizer, global_step + + +def distributed_grads_only_and_ops_chief_worker( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, + loss, accuracy, layer_collection, invert_every=10): + """Minimize loss with a synchronous implementation of K-FAC. - Different tasks are responsible for different parts of K-FAC's Ops. The first - 60% of tasks update weights; the next 20% accumulate covariance statistics; - the last 20% invert the matrices used to precondition gradients. + All workers perform gradient computation. Chief worker applies gradient after + averaging the gradients obtained from all the workers. All workers block + execution until the update is applied. Chief worker runs covariance and + inverse update ops. Covariance and inverse matrices are placed on parameter + servers in a round robin manner. For further details on synchronous + distributed optimization check `tf.train.SyncReplicasOptimizer`. Args: task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. num_worker_tasks: int. Number of workers in this distributed training setup. num_ps_tasks: int. Number of parameter servers holding variables. If 0, parameter servers are not used. @@ -271,6 +344,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, run with each step. layer_collection: LayerCollection instance describing model architecture. Used by K-FAC to construct preconditioner. + invert_every: `int`, Number of steps between update the inverse. Returns: final value for 'accuracy'. @@ -278,20 +352,82 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, Raises: ValueError: if task_id >= num_worker_tasks. """ - with tf.device(tf.train.replica_device_setter(num_ps_tasks)): - global_step = tf.train.get_or_create_global_step() - optimizer = opt.KfacOptimizer( - learning_rate=0.0001, - cov_ema_decay=0.95, - damping=0.001, - layer_collection=layer_collection, - momentum=0.9) - inv_update_queue = oq.OpQueue(optimizer.inv_update_ops) - sync_optimizer = tf.train.SyncReplicasOptimizer( - opt=optimizer, - replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks)) + + sync_optimizer, optimizer, global_step = _make_distributed_train_op( + task_id, num_worker_tasks, num_ps_tasks, layer_collection) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + tf.logging.info("Starting training.") + hooks = [sync_optimizer.make_session_run_hook(is_chief)] + + def make_update_op(update_thunks): + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) + + if is_chief: + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([cov_update_op]): + inverse_op = tf.cond( + tf.equal(tf.mod(global_step, invert_every), 0), + lambda: make_update_op(inv_update_thunks), + tf.no_op) + with tf.control_dependencies([inverse_op]): + train_op = sync_optimizer.minimize(loss, global_step=global_step) + else: train_op = sync_optimizer.minimize(loss, global_step=global_step) + with tf.train.MonitoredTrainingSession( + master=master, + is_chief=is_chief, + checkpoint_dir=checkpoint_dir, + hooks=hooks, + stop_grace_period_secs=0) as sess: + while not sess.should_stop(): + global_step_, loss_, accuracy_, _ = sess.run( + [global_step, loss, accuracy, train_op]) + tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, + loss_, accuracy_) + return accuracy_ + + +def distributed_grads_and_ops_dedicated_workers( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, + loss, accuracy, layer_collection): + """Minimize loss with a synchronous implementation of K-FAC. + + Different workers are responsible for different parts of K-FAC's Ops. The + first 60% of tasks compute gradients; the next 20% accumulate covariance + statistics; the last 20% invert the matrices used to precondition gradients. + The chief worker applies the gradient . + + Args: + task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. + num_worker_tasks: int. Number of workers in this distributed training setup. + num_ps_tasks: int. Number of parameter servers holding variables. If 0, + parameter servers are not used. + master: string. IP and port of TensorFlow runtime process. Set to empty + string to run locally. + checkpoint_dir: string or None. Path to store checkpoints under. + loss: 0-D Tensor. Loss to be minimized. + accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to + run with each step. + layer_collection: LayerCollection instance describing model architecture. + Used by K-FAC to construct preconditioner. + + Returns: + final value for 'accuracy'. + + Raises: + ValueError: if task_id >= num_worker_tasks. + """ + sync_optimizer, optimizer, global_step = _make_distributed_train_op( + task_id, num_worker_tasks, num_ps_tasks, layer_collection) + _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars() + train_op = sync_optimizer.minimize(loss, global_step=global_step) + inv_update_queue = oq.OpQueue(inv_update_ops) + tf.logging.info("Starting training.") is_chief = (task_id == 0) hooks = [sync_optimizer.make_session_run_hook(is_chief)] @@ -306,7 +442,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, if _is_gradient_task(task_id, num_worker_tasks): learning_op = train_op elif _is_cov_update_task(task_id, num_worker_tasks): - learning_op = optimizer.cov_update_op + learning_op = cov_update_op elif _is_inv_update_task(task_id, num_worker_tasks): # TODO(duckworthd): Running this op before cov_update_op has been run a # few times can result in "InvalidArgumentError: Cholesky decomposition @@ -324,13 +460,18 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, return accuracy_ -def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False): +def train_mnist_single_machine(data_dir, + num_epochs, + use_fake_data=False, + device="/gpu:0"): """Train a ConvNet on MNIST. Args: data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. use_fake_data: bool. If True, generate a synthetic dataset. + device: string, Either '/cpu:0' or '/gpu:0'. The covaraince and inverse + update ops are run on this device. Returns: accuracy of model on the final minibatch of training data. @@ -350,22 +491,38 @@ def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False): examples, labels, num_labels=10, layer_collection=layer_collection) # Fit model. - return minimize_loss_single_machine(loss, accuracy, layer_collection) + return minimize_loss_single_machine( + loss, accuracy, layer_collection, device=device) def train_mnist_multitower(data_dir, num_epochs, num_towers, - use_fake_data=True): + use_fake_data=True, devices=None): """Train a ConvNet on MNIST. + Training data is split equally among the towers. Each tower computes loss on + its own batch of data and the loss is aggregated on the CPU. The model + variables are placed on first tower. The covariance and inverse update ops + and variables are placed on GPUs in a round robin manner. + Args: data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. num_towers: int. Number of CPUs to split inference across. use_fake_data: bool. If True, generate a synthetic dataset. + devices: string, Either list of CPU or GPU. The covaraince and inverse + update ops are run on this device. Returns: accuracy of model on the final minibatch of training data. """ + if devices: + device_count = {"GPU": num_towers} + else: + device_count = {"CPU": num_towers} + + devices = devices or [ + "/cpu:{}".format(tower_id) for tower_id in range(num_towers) + ] # Load a dataset. tf.logging.info("Loading MNIST into memory.") tower_batch_size = 128 @@ -388,7 +545,7 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers, layer_collection = lc.LayerCollection() tower_results = [] for tower_id in range(num_towers): - with tf.device("/cpu:%d" % tower_id): + with tf.device(devices[tower_id]): with tf.name_scope("tower%d" % tower_id): with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)): tf.logging.info("Building tower %d." % tower_id) @@ -402,34 +559,79 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers, accuracy = tf.reduce_mean(accuracies) # Fit model. + session_config = tf.ConfigProto( - allow_soft_placement=False, device_count={ - "CPU": num_towers - }) - return minimize_loss_single_machine( - loss, accuracy, layer_collection, session_config=session_config) + allow_soft_placement=False, + device_count=device_count, + ) + g_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=0.0001, + cov_ema_decay=0.95, + damping=0.001, + layer_collection=layer_collection, + placement_strategy="round_robin", + cov_devices=devices, + inv_devices=devices, + momentum=0.9) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() -def train_mnist_distributed(task_id, - num_worker_tasks, - num_ps_tasks, - master, - data_dir, - num_epochs, - use_fake_data=False): - """Train a ConvNet on MNIST. + def make_update_op(update_thunks): + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) + + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([cov_update_op]): + inverse_op = tf.cond( + tf.equal(tf.mod(g_step, _INVERT_EVERY), 0), + lambda: make_update_op(inv_update_thunks), tf.no_op) + with tf.control_dependencies([inverse_op]): + train_op = optimizer.minimize(loss, global_step=g_step) + + tf.logging.info("Starting training.") + with tf.train.MonitoredTrainingSession(config=session_config) as sess: + while not sess.should_stop(): + global_step_, loss_, accuracy_, _ = sess.run( + [g_step, loss, accuracy, train_op]) + + if global_step_ % _INVERT_EVERY == 0: + tf.logging.info("global_step: %d | loss: %f | accuracy: %s", + global_step_, loss_, accuracy_) + + +def train_mnist_distributed_sync_replicas(task_id, + is_chief, + num_worker_tasks, + num_ps_tasks, + master, + data_dir, + num_epochs, + op_strategy, + use_fake_data=False): + """Train a ConvNet on MNIST using Sync replicas optimizer. Args: task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. num_worker_tasks: int. Number of workers in this distributed training setup. num_ps_tasks: int. Number of parameter servers holding variables. master: string. IP and port of TensorFlow runtime process. data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. + op_strategy: `string`, Strategy to run the covariance and inverse + ops. If op_strategy == `chief_worker` then covaraiance and inverse + update ops are run on chief worker otherwise they are run on dedicated + workers. + use_fake_data: bool. If True, generate a synthetic dataset. Returns: accuracy of model on the final minibatch of training data. + + Raises: + ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"]. """ # Load a dataset. tf.logging.info("Loading MNIST into memory.") @@ -448,9 +650,17 @@ def train_mnist_distributed(task_id, # Fit model. checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac") - return minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, - master, checkpoint_dir, loss, accuracy, - layer_collection) + if op_strategy == "chief_worker": + return distributed_grads_only_and_ops_chief_worker( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, + checkpoint_dir, loss, accuracy, layer_collection) + elif op_strategy == "dedicated_workers": + return distributed_grads_and_ops_dedicated_workers( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, + checkpoint_dir, loss, accuracy, layer_collection) + else: + raise ValueError("Only supported op strategies are : {}, {}".format( + "chief_worker", "dedicated_workers")) if __name__ == "__main__": diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c2d4a9e9bfcc4bfb55a25d2f23e66afe5b1375 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py @@ -0,0 +1,62 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train a ConvNet on MNIST using K-FAC. + +Distributed training with sync replicas optimizer. See +`convnet.train_mnist_distributed_sync_replicas` for details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from absl import flags +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import convnet + +FLAGS = flags.FLAGS +flags.DEFINE_integer("task", -1, "Task identifier") +flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir") +flags.DEFINE_string( + "cov_inv_op_strategy", "chief_worker", + "In dist training mode run the cov, inv ops on chief or dedicated workers." +) +flags.DEFINE_string("master", "local", "Session master.") +flags.DEFINE_integer("ps_tasks", 2, + "Number of tasks in the parameter server job.") +flags.DEFINE_integer("replicas_to_aggregate", 5, + "Number of replicas to aggregate.") +flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.") +flags.DEFINE_integer("num_epochs", None, "Number of epochs.") + + +def _is_chief(): + """Determines whether a job is the chief worker.""" + if "chief_worker" in FLAGS.brain_jobs: + return FLAGS.brain_job_name == "chief_worker" + else: + return FLAGS.task == 0 + + +def main(unused_argv): + _ = unused_argv + convnet.train_mnist_distributed_sync_replicas( + FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks, + FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs, FLAGS.cov_inv_op_strategy) + +if __name__ == "__main__": + tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py similarity index 57% rename from tensorflow/contrib/kfac/examples/convnet_mnist_main.py rename to tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py index b0c6fbde198850c76af0bc1600dc23e926227229..4249bf8a8d9d3a5beb87d4140a55b0ee6eadbc64 100644 --- a/tensorflow/contrib/kfac/examples/convnet_mnist_main.py +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py @@ -14,44 +14,35 @@ # ============================================================================== r"""Train a ConvNet on MNIST using K-FAC. -See convnet.py for details. +Multi tower training mode. See `convnet.train_mnist_multitower` for details. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import argparse -import sys +from absl import flags import tensorflow as tf from tensorflow.contrib.kfac.examples import convnet -FLAGS = None +FLAGS = flags.FLAGS +flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir") +flags.DEFINE_integer("num_towers", 2, + "Number of towers for multi tower training.") -def main(argv): - _ = argv - - if FLAGS.num_towers > 1: - convnet.train_mnist_multitower( - FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers) - else: - convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200) +def main(unused_argv): + _ = unused_argv + assert FLAGS.num_towers > 1 + devices = ["/gpu:{}".format(tower_id) for tower_id in range(FLAGS.num_towers)] + convnet.train_mnist_multitower( + FLAGS.data_dir, + num_epochs=200, + num_towers=FLAGS.num_towers, + devices=devices) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--data_dir", - type=str, - default="/tmp/mnist", - help="Directory to store dataset in.") - parser.add_argument( - "--num_towers", - type=int, - default=1, - help="Number of CPUs to split minibatch across.") - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) + tf.app.run(main=main) diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc.py b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py similarity index 63% rename from tensorflow/contrib/bayesflow/python/ops/hmc.py rename to tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py index 7fd5652c5c3e085b23c05baef6e3a42b7a42e08f..2c1f09936073a34816da61d771f59e848b8787af 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc.py +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py @@ -12,20 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm.""" +r"""Train a ConvNet on MNIST using K-FAC. + +Train on single machine. See `convnet.train_mnist_single_machine` for details. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -from tensorflow.contrib.bayesflow.python.ops.hmc_impl import * # pylint: disable=wildcard-import,unused-wildcard-import,g-importing-member -from tensorflow.python.util import all_util -_allowed_symbols = [ - "sample_chain", - "sample_annealed_importance_chain", - "kernel", -] +from absl import flags +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import convnet + +FLAGS = flags.FLAGS +flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir") + + +def main(unused_argv): + convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200) + -all_util.remove_undocumented(__name__, _allowed_symbols) +if __name__ == "__main__": + tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py index 87eed03888c894a04c0521d1ce5ee8975b60776b..ea2b252a05702d5adcdc5f70d713277ba604f691 100644 --- a/tensorflow/contrib/kfac/examples/mlp.py +++ b/tensorflow/contrib/kfac/examples/mlp.py @@ -105,18 +105,21 @@ def build_model(examples, labels, num_labels, layer_collection): return loss, accuracy -def minimize(loss, accuracy, layer_collection, session_config=None): +def minimize(loss, accuracy, layer_collection, num_towers, session_config=None): """Minimize 'loss' with KfacOptimizer. Args: loss: 0-D Tensor. Loss to be minimized. accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. layer_collection: LayerCollection instance. Describes layers in model. + num_towers: int. Number of CPUs to split minibatch across. session_config: tf.ConfigProto. Configuration for tf.Session(). Returns: accuracy of classifier on final minibatch. """ + devices = tuple("/cpu:%d" % tower_id for tower_id in range(num_towers)) + # Train with K-FAC. We'll use a decreasing learning rate that's cut in 1/2 # every 10k iterations. tf.logging.info("Building KFAC Optimizer.") @@ -125,27 +128,38 @@ def minimize(loss, accuracy, layer_collection, session_config=None): learning_rate=tf.train.exponential_decay( 0.00002, global_step, 10000, 0.5, staircase=True), cov_ema_decay=0.95, - damping=0.0001, + damping=0.0005, layer_collection=layer_collection, - momentum=0.99) - train_op = optimizer.minimize(loss, global_step=global_step) + momentum=0.99, + placement_strategy="round_robin", + cov_devices=devices, + inv_devices=devices) + + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + def make_update_op(update_thunks): + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) + + # TODO(b/78537047): change (some) examples to use PeriodicInvCovUpdateKfacOpt + # once that gets moved over? Could still leave more advanced examples as they + # are (e.g. train_mnist_estimator in this file) + + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([cov_update_op]): + # We update the inverses only every 20 iterations. + inverse_op = tf.cond( + tf.equal(tf.mod(global_step, 100), 0), + lambda: make_update_op(inv_update_thunks), tf.no_op) + with tf.control_dependencies([inverse_op]): + train_op = optimizer.minimize(loss, global_step=global_step) tf.logging.info("Starting training.") with tf.train.MonitoredTrainingSession(config=session_config) as sess: while not sess.should_stop(): - # K-FAC has 3 primary ops, - # - train_op: Update the weights with the minibatch's gradient. - # - cov_update_op: Update statistics used for building K-FAC's - # preconditioner matrix. - # - inv_update_op: Update preconditioner matrix using statistics. - # - # The first 2 of these are cheap and should be done with each step. The - # latter is more expensive, and should be updated ~100 iterations. - global_step_, loss_, accuracy_, _, _ = sess.run( - [global_step, loss, accuracy, train_op, optimizer.cov_update_op]) - - if global_step_ % 100 == 0: - sess.run(optimizer.inv_update_op) + global_step_, loss_, accuracy_, _ = sess.run( + [global_step, loss, accuracy, train_op]) if global_step_ % 100 == 0: tf.logging.info("global_step: %d | loss: %f | accuracy: %f", @@ -180,7 +194,7 @@ def train_mnist(data_dir, num_epochs, use_fake_data=False): loss, accuracy = build_model(examples, labels, 10, layer_collection) # Fit model. - minimize(loss, accuracy, layer_collection) + minimize(loss, accuracy, layer_collection, 1) def train_mnist_multitower(data_dir, @@ -238,7 +252,8 @@ def train_mnist_multitower(data_dir, "CPU": num_towers }) return minimize( - loss, accuracy, layer_collection, session_config=session_config) + loss, accuracy, layer_collection, num_towers, + session_config=session_config) def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False): @@ -298,13 +313,26 @@ def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False): layer_collection=layer_collection, momentum=0.99) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + def make_update_op(update_thunks): + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) + + def make_batch_executed_op(update_thunks, batch_size=1): + return tf.group(*tf.contrib.kfac.utils.batch_execute( + global_step, update_thunks, batch_size=batch_size)) + # Run cov_update_op every step. Run 1 inv_update_ops per step. - cov_update_op = optimizer.cov_update_op - inv_update_op = tf.group( - tf.contrib.kfac.utils.batch_execute( - global_step, optimizer.inv_update_thunks, batch_size=1)) - with tf.control_dependencies([cov_update_op, inv_update_op]): - train_op = optimizer.minimize(loss, global_step=global_step) + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([cov_update_op]): + # But make sure to execute all the inverse ops on the first step + inverse_op = tf.cond(tf.equal(global_step, 0), + lambda: make_update_op(inv_update_thunks), + lambda: make_batch_executed_op(inv_update_thunks)) + with tf.control_dependencies([inverse_op]): + train_op = optimizer.minimize(loss, global_step=global_step) # Print metrics every 5 sec. hooks = [ diff --git a/tensorflow/contrib/kfac/examples/tests/BUILD b/tensorflow/contrib/kfac/examples/tests/BUILD index ce7da95c124beaed4773d68ce0d0c41f187f7c9d..ede7f183fe24f26bd86e232e831dea5f8ea1fdc4 100644 --- a/tensorflow/contrib/kfac/examples/tests/BUILD +++ b/tensorflow/contrib/kfac/examples/tests/BUILD @@ -50,15 +50,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py index 8d86c2bb5150cd4bc8a2b21ba050e904929e0fe9..adecda71666ee74bc577859589060fa65baf5166 100644 --- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py +++ b/tensorflow/contrib/kfac/examples/tests/convnet_test.py @@ -112,15 +112,16 @@ class ConvNetTest(tf.test.TestCase): def testMinimizeLossSingleMachine(self): with tf.Graph().as_default(): loss, accuracy, layer_collection = self._build_toy_problem() - accuracy_ = convnet.minimize_loss_single_machine(loss, accuracy, - layer_collection) - self.assertLess(accuracy_, 1.0) + accuracy_ = convnet.minimize_loss_single_machine( + loss, accuracy, layer_collection, device="/cpu:0") + self.assertLess(accuracy_, 2.0) def testMinimizeLossDistributed(self): with tf.Graph().as_default(): loss, accuracy, layer_collection = self._build_toy_problem() - accuracy_ = convnet.minimize_loss_distributed( + accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker( task_id=0, + is_chief=True, num_worker_tasks=1, num_ps_tasks=0, master="", @@ -128,7 +129,7 @@ class ConvNetTest(tf.test.TestCase): loss=loss, accuracy=accuracy, layer_collection=layer_collection) - self.assertLess(accuracy_, 1.0) + self.assertLess(accuracy_, 2.0) def testTrainMnistSingleMachine(self): with tf.Graph().as_default(): @@ -138,7 +139,7 @@ class ConvNetTest(tf.test.TestCase): # but there are too few parameters for the model to effectively memorize # the training set the way an MLP can. convnet.train_mnist_single_machine( - data_dir=None, num_epochs=1, use_fake_data=True) + data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0") def testTrainMnistMultitower(self): with tf.Graph().as_default(): @@ -149,13 +150,15 @@ class ConvNetTest(tf.test.TestCase): def testTrainMnistDistributed(self): with tf.Graph().as_default(): # Ensure model training doesn't crash. - convnet.train_mnist_distributed( + convnet.train_mnist_distributed_sync_replicas( task_id=0, + is_chief=True, num_worker_tasks=1, num_ps_tasks=0, master="", data_dir=None, - num_epochs=1, + num_epochs=2, + op_strategy="chief_worker", use_fake_data=True) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index f4ed978174a9ddd8b54a88e60bfb48a67a2e76d2..6e4a8d71baa85d05d514e4683016c2f4d299ec8e 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -36,6 +36,7 @@ py_test( srcs = ["fisher_factors_test.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_blocks", "//tensorflow/contrib/kfac/python/ops:fisher_factors", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -57,6 +58,7 @@ py_test( deps = [ "//tensorflow/contrib/kfac/python/ops:fisher_blocks", "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/contrib/kfac/python/ops:linear_operator", "//tensorflow/contrib/kfac/python/ops:utils", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -95,6 +97,7 @@ py_test( srcs = ["optimizer_test.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_factors", "//tensorflow/contrib/kfac/python/ops:kfac_optimizer", "//tensorflow/contrib/kfac/python/ops:layer_collection", "//tensorflow/python:array_ops", @@ -113,6 +116,7 @@ py_test( name = "utils_test", srcs = ["utils_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ "//tensorflow/contrib/kfac/python/ops:utils", "//tensorflow/contrib/tpu", @@ -154,15 +158,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py index bfdb69ad02caaa57827e0ae6b3c9fc0d0ed03754..0e65d419a31838a62d8ab37a5f30427c925382b4 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -23,7 +23,6 @@ import numpy as np from tensorflow.contrib.kfac.python.ops import estimator from tensorflow.contrib.kfac.python.ops import layer_collection as lc from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -40,30 +39,6 @@ from tensorflow.python.training import training_util _ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] -class DeviceContextGeneratorTest(test.TestCase): - - def testNoDevice(self): - device_context_generator = estimator._DeviceContextGenerator(None) - with ops.device("/device:CPU:0"): # This is what will be used - with device_context_generator(): # Does nothing - a = constant_op.constant([2.0], name="a") - self.assertEqual("/device:CPU:0", a.op.device) - - def testTwoDevices(self): - device_context_generator = estimator._DeviceContextGenerator( - ["/device:GPU:0", "/device:GPU:1"]) - with ops.device("/device:CPU:0"): # Will be over-ridden by the inner scopes - with device_context_generator(): - a = constant_op.constant([2.0], name="a") - with device_context_generator(): - b = constant_op.constant([2.0], name="b") - with device_context_generator(): - c = constant_op.constant([2.0], name="c") - self.assertEqual("/device:GPU:0", a.op.device) - self.assertEqual("/device:GPU:1", b.op.device) - self.assertEqual("/device:GPU:0", c.op.device) - - class EstimatorTest(test.TestCase): def setUp(self): @@ -90,57 +65,113 @@ class EstimatorTest(test.TestCase): def testEstimatorInitManualRegistration(self): with self._graph.as_default(): # We should be able to build an estimator for only the registered vars. - estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection) + estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection + ) # Check that we throw an error if we try to build an estimator for vars # that were not manually registered. with self.assertRaises(ValueError): - estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2, - self.layer_collection) + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights, self.bias], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection + ) + est.make_vars_and_create_op_thunks() # Check that we throw an error if we don't include registered variables, # i.e. self.weights with self.assertRaises(ValueError): - estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection) + est = estimator.FisherEstimatorRoundRobin( + variables=[], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection) + est.make_vars_and_create_op_thunks() @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42) def testVariableWrongNumberOfUses(self, mock_uses): with self.assertRaises(ValueError): - estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection) + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection) + est.make_vars_and_create_op_thunks() def testInvalidEstimationMode(self): with self.assertRaises(ValueError): - estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection, - "not_a_real_mode") + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="not_a_real_mode") + est.make_vars_and_create_op_thunks() - def testModeListCorrect(self): + def testGradientsModeBuild(self): with self._graph.as_default(): - est = estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection) - self.assertItemsEqual(_ALL_ESTIMATION_MODES, est._gradient_fns.keys()) + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="gradients") + est.make_vars_and_create_op_thunks() - def testAllModesBuild(self): - for mode in _ALL_ESTIMATION_MODES: - with self._graph.as_default(): - estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, mode) + def testEmpiricalModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="empirical") + est.make_vars_and_create_op_thunks() + + def testCurvaturePropModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="curvature_prop") + est.make_vars_and_create_op_thunks() + + def testExactModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="exact") + est.make_vars_and_create_op_thunks() def test_cov_update_thunks(self): """Ensures covariance update ops run once per global_step.""" with self._graph.as_default(), self.test_session() as sess: - fisher_estimator = estimator.FisherEstimator( + fisher_estimator = estimator.FisherEstimatorRoundRobin( variables=[self.weights], layer_collection=self.layer_collection, - cov_ema_decay=0.0, - damping=0.0) + damping=0.2, + cov_ema_decay=0.0) # Construct an op that executes one covariance update per step. global_step = training_util.get_or_create_global_step() + (cov_variable_thunks, cov_update_op_thunks, _, + _) = fisher_estimator.create_ops_and_vars_thunks() + for thunk in cov_variable_thunks: + thunk() cov_matrices = [ fisher_factor.get_cov() for fisher_factor in self.layer_collection.get_factors() ] - cov_update_op_thunks = fisher_estimator.cov_update_thunks cov_update_op = control_flow_ops.case( [(math_ops.equal(global_step, i), thunk) for i, thunk in enumerate(cov_update_op_thunks)]) @@ -172,23 +203,64 @@ class EstimatorTest(test.TestCase): sess.run(cov_update_op) sess.run(increment_global_step) + def test_round_robin_placement(self): + """Check if the ops and variables are placed on devices correctly.""" + with self._graph.as_default(): + fisher_estimator = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + layer_collection=self.layer_collection, + damping=0.2, + cov_ema_decay=0.0, + cov_devices=["/cpu:{}".format(i) for i in range(2)], + inv_devices=["/cpu:{}".format(i) for i in range(2)]) + + # Construct an op that executes one covariance update per step. + (cov_update_thunks, + inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks( + scope="test") + cov_update_ops = tuple(thunk() for thunk in cov_update_thunks) + inv_update_ops = tuple(thunk() for thunk in inv_update_thunks) + self.assertEqual(cov_update_ops[0].device, "/device:CPU:0") + self.assertEqual(cov_update_ops[1].device, "/device:CPU:1") + self.assertEqual(inv_update_ops[0].device, "/device:CPU:0") + self.assertEqual(inv_update_ops[1].device, "/device:CPU:1") + cov_matrices = [ + fisher_factor.get_cov() + for fisher_factor in self.layer_collection.get_factors() + ] + inv_matrices = [ + matrix + for fisher_factor in self.layer_collection.get_factors() + for matrix in fisher_factor._matpower_by_exp_and_damping.values() + ] + self.assertEqual(cov_matrices[0].device, "/device:CPU:0") + self.assertEqual(cov_matrices[1].device, "/device:CPU:1") + # Inverse matrices need to be explicitly placed. + self.assertEqual(inv_matrices[0].device, "") + self.assertEqual(inv_matrices[1].device, "") + def test_inv_update_thunks(self): """Ensures inverse update ops run once per global_step.""" with self._graph.as_default(), self.test_session() as sess: - fisher_estimator = estimator.FisherEstimator( + fisher_estimator = estimator.FisherEstimatorRoundRobin( variables=[self.weights], layer_collection=self.layer_collection, - cov_ema_decay=0.0, - damping=0.0) + damping=0.2, + cov_ema_decay=0.0) # Construct op that updates one inverse per global step. global_step = training_util.get_or_create_global_step() + (cov_variable_thunks, _, inv_variable_thunks, + inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks() + for thunk in cov_variable_thunks: + thunk() + for thunk in inv_variable_thunks: + thunk() inv_matrices = [ matrix for fisher_factor in self.layer_collection.get_factors() - for matrix in fisher_factor._inverses_by_damping.values() + for matrix in fisher_factor._matpower_by_exp_and_damping.values() ] - inv_update_op_thunks = fisher_estimator.inv_update_thunks inv_update_op = control_flow_ops.case( [(math_ops.equal(global_step, i), thunk) for i, thunk in enumerate(inv_update_op_thunks)]) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py index fb4b3a241c1e9fd82e7bf630fd57295917048fbd..86ec7a095afdf4ecf7892a7e4e5d47dcdc239ed1 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -21,7 +21,9 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb +from tensorflow.contrib.kfac.python.ops import fisher_factors as ff from tensorflow.contrib.kfac.python.ops import layer_collection as lc +from tensorflow.contrib.kfac.python.ops import linear_operator as lo from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed @@ -34,6 +36,19 @@ from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import test +# We need to set these constants since the numerical values used in the tests +# were chosen when these used to be the defaults. +ff.set_global_constants(init_covariances_at_zero=False, + zero_debias=False, + init_inverses_at_zero=False) + +# TODO(b/78538100): As far as I can tell, all the tests that say "Make sure our +# inverse is something other than the identity" are actually broken. They never +# run the covariance update ops and so the inverse actually is the identity +# (possible plus the damping term, which would still make it a multiple of the +# identity). + + def _make_psd(dim): """Constructs a PSD matrix of the given dimension.""" mat = np.ones((dim, dim), dtype=np.float32) @@ -46,8 +61,9 @@ class UtilsTest(test.TestCase): def testComputePiTracenorm(self): with ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) - left_factor = array_ops.diag([1., 2., 0., 1.]) - right_factor = array_ops.ones([2., 2.]) + diag = ops.convert_to_tensor([1., 2., 0., 1.]) + left_factor = lo.LinearOperatorDiag(diag) + right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2])) # pi is the sqrt of the left trace norm divided by the right trace norm pi = fb.compute_pi_tracenorm(left_factor, right_factor) @@ -63,7 +79,7 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -72,7 +88,7 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -81,7 +97,7 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors(grads, 0.5) @@ -91,9 +107,12 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() + block.register_inverse() + block._factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -109,9 +128,12 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = array_ops.constant([[1.], [2.]]) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = params**2 block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() + block.register_inverse() + block._factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -127,10 +149,13 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (array_ops.constant([2., 3.]), array_ops.constant(4.)) damping = 0.5 block.instantiate_factors((grads,), damping) + block._factor.instantiate_cov_variables() + block.register_inverse() + block._factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(state_ops.assign(block._factor._cov, _make_psd(3))) @@ -154,7 +179,7 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -163,7 +188,7 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -172,7 +197,7 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors(grads, 0.5) @@ -182,9 +207,10 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -200,9 +226,10 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = array_ops.constant([[1.], [2.]]) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = params**2 block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -217,10 +244,11 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) damping = 0.5 block.instantiate_factors((grads,), damping) + block._factor.instantiate_cov_variables() cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1]) sess.run(state_ops.assign(block._factor._cov, cov)) @@ -233,7 +261,6 @@ class NaiveDiagonalFBTest(test.TestCase): full = sess.run(block.full_fisher_block()) explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat) - self.assertAllClose(output_flat, explicit) @@ -312,8 +339,8 @@ class FullyConnectedDiagonalFBTest(test.TestCase): self.assertAllClose(expected_result, result) - def testRegisterAdditionalMinibatch(self): - """Ensure 1 big minibatch and 2 small minibatches are equivalent.""" + def testRegisterAdditionalTower(self): + """Ensure 1 big tower and 2 small towers are equivalent.""" multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( @@ -364,9 +391,10 @@ class FullyConnectedDiagonalFBTest(test.TestCase): block = fb.FullyConnectedDiagonalFB( lc.LayerCollection(), has_bias=isinstance(params, (tuple, list))) for (i, o) in zip(inputs, outputs): - block.register_additional_minibatch(i, o) + block.register_additional_tower(i, o) block.instantiate_factors((output_grads,), damping=0.0) + block._factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) sess.run(block._factor.make_covariance_update_op(0.0)) @@ -389,12 +417,12 @@ class EmbeddingKFACFBTest(test.TestCase): # Add some examples. inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) outputs = array_ops.constant([[0.], [1.], [2.]]) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) # Instantiate factor's variables. Ensure it doesn't fail. grads = outputs**2. damping = array_ops.constant(0.) - block.instantiate_factors(([grads],), damping) + block.instantiate_factors(((grads,),), damping) def testMultiplyInverse(self): with ops.Graph().as_default(), self.test_session() as sess: @@ -407,12 +435,17 @@ class EmbeddingKFACFBTest(test.TestCase): # Add some examples. inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) outputs = array_ops.constant([[0.], [1.], [2.]]) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) # Instantiate factor's variables. Ensure it doesn't fail. grads = outputs**2. damping = array_ops.constant(0.) - block.instantiate_factors(([grads],), damping) + block.instantiate_factors(((grads,),), damping) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Create a sparse update. indices = array_ops.constant([1, 3, 4]) @@ -443,7 +476,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([1., 2.]) outputs = array_ops.constant([3., 4.]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection()) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) self.assertAllEqual([outputs], block.tensors_to_compute_grads()) @@ -453,10 +486,10 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([[1., 2.], [3., 4.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) def testInstantiateFactorsNoBias(self): with ops.Graph().as_default(): @@ -464,10 +497,10 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([[1., 2.], [3., 4.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) def testMultiplyInverseTuple(self): with ops.Graph().as_default(), self.test_session() as sess: @@ -475,9 +508,15 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) + + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -501,9 +540,14 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([[1., 2.], [3., 4.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -524,13 +568,20 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): outputs = array_ops.zeros([32, output_dim]) params = array_ops.zeros([input_dim, output_dim]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 damping = 0. # This test is only valid without damping. - block.instantiate_factors(([grads],), damping) + block.instantiate_factors(((grads,),), damping) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3))) sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) + + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + sess.run(block._input_factor.make_inverse_update_ops()) sess.run(block._output_factor.make_inverse_update_ops()) @@ -653,8 +704,8 @@ class ConvDiagonalFBTest(test.TestCase): self.assertAllClose(expected_result, result, atol=1e-3) - def testRegisterAdditionalMinibatch(self): - """Ensure 1 big minibatch and 2 small minibatches are equivalent.""" + def testRegisterAdditionalTower(self): + """Ensure 1 big tower and 2 small towers are equivalent.""" multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( @@ -715,9 +766,10 @@ class ConvDiagonalFBTest(test.TestCase): block = fb.ConvDiagonalFB( lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME') for (i, o) in zip(inputs, outputs): - block.register_additional_minibatch(i, o) + block.register_additional_tower(i, o) block.instantiate_factors((output_grads,), damping=0.0) + block._factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) sess.run(block._factor.make_covariance_update_op(0.0)) @@ -727,6 +779,54 @@ class ConvDiagonalFBTest(test.TestCase): return multiply_result, multiply_inverse_result +class DepthwiseConvKFCBasicFBTest(test.TestCase): + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = random_ops.random_normal((3, 3, 8, 2)) + inputs = random_ops.random_normal((32, 5, 5, 8)) + outputs = random_ops.random_normal((32, 5, 5, 16)) + layer_collection = lc.LayerCollection() + block = fb.DepthwiseConvKFCBasicFB( + layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + + def testMultiplyInverse(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = random_ops.random_normal((3, 3, 8, 2)) + inputs = random_ops.random_normal((32, 5, 5, 8)) + outputs = random_ops.random_normal((32, 5, 5, 16)) + layer_collection = lc.LayerCollection() + block = fb.DepthwiseConvKFCBasicFB( + layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Ensure inverse update op doesn't crash. + sess.run(tf_variables.global_variables_initializer()) + sess.run([ + factor.make_inverse_update_ops() + for factor in layer_collection.get_factors() + ]) + + # Ensure inverse-vector multiply doesn't crash. + output = block.multiply_inverse(params) + sess.run(output) + + # Ensure same shape. + self.assertAllEqual(output.shape, params.shape) + + class ConvKFCBasicFBTest(test.TestCase): def _testConvKFCBasicFBInitParams(self, params): @@ -738,16 +838,17 @@ class ConvKFCBasicFBTest(test.TestCase): params = array_ops.constant(params) inputs = random_ops.random_normal((2, 2, 2)) outputs = random_ops.random_normal((2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, [1, 1, 1], 'SAME') - block.register_additional_minibatch(inputs, outputs) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) self.assertAllEqual([outputs], block.tensors_to_compute_grads()) def testConvKFCBasicFBInitParamsParamsTuple(self): - self._testConvKFCBasicFBInitParams([np.array([1., 2.]), np.array(3.)]) + self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])]) def testConvKFCBasicFBInitParamsParamsSingle(self): - self._testConvKFCBasicFBInitParams([np.array([1., 2.])]) + self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])]) def testMultiplyInverseTuple(self): with ops.Graph().as_default(), self.test_session() as sess: @@ -755,11 +856,16 @@ class ConvKFCBasicFBTest(test.TestCase): params = random_ops.random_normal((2, 2, 2, 2)) inputs = random_ops.random_normal((2, 2, 2, 2)) outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), - 'SAME') - block.register_additional_minibatch(inputs, outputs) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -781,12 +887,17 @@ class ConvKFCBasicFBTest(test.TestCase): params = random_ops.random_normal((2, 2, 2, 2)) inputs = random_ops.random_normal((2, 2, 2, 2)) outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), - 'SAME') - block.register_additional_minibatch(inputs, outputs) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) self.assertFalse(block._has_bias) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -804,12 +915,17 @@ class ConvKFCBasicFBTest(test.TestCase): params = [random_ops.random_normal((2, 2, 2, 2))] inputs = random_ops.random_normal((2, 2, 2, 2)) outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), - 'SAME') - block.register_additional_minibatch(inputs, outputs) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) self.assertTrue(block._has_bias) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -827,12 +943,17 @@ class ConvKFCBasicFBTest(test.TestCase): params = array_ops.zeros((2, 2, 2, 2)) inputs = array_ops.zeros((2, 2, 2, 2)) outputs = array_ops.zeros((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), - 'SAME') - block.register_additional_minibatch(inputs, outputs) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) grads = outputs**2 damping = 0. # This test is only valid without damping. - block.instantiate_factors(([grads],), damping) + block.instantiate_factors(((grads,),), damping) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8))) sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) @@ -857,9 +978,9 @@ class FullyConnectedSeriesFBTest(test.TestCase): random_seed.set_random_seed(200) inputs = array_ops.constant([1., 2.]) outputs = array_ops.constant([3., 4.]) - block = fb.FullyConnectedSeriesFB( - lc.LayerCollection(), inputs=[inputs], outputs=[outputs]) - self.assertAllEqual([outputs], block.tensors_to_compute_grads()) + block = fb.FullyConnectedSeriesFB(lc.LayerCollection()) + block.register_additional_tower([inputs], [outputs]) + self.assertAllEqual([[outputs]], block.tensors_to_compute_grads()) def testInstantiateFactorsHasBias(self): with ops.Graph().as_default(): @@ -868,11 +989,10 @@ class FullyConnectedSeriesFBTest(test.TestCase): outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedSeriesFB( lc.LayerCollection(), - inputs=[inputs], - outputs=[outputs], has_bias=True) + block.register_additional_tower([inputs], [outputs]) grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) + block.instantiate_factors((((grads,),),), 0.5) def testInstantiateFactorsNoBias(self): with ops.Graph().as_default(): @@ -881,11 +1001,10 @@ class FullyConnectedSeriesFBTest(test.TestCase): outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedSeriesFB( lc.LayerCollection(), - inputs=[inputs], - outputs=[outputs], has_bias=False) + block.register_additional_tower([inputs], [outputs]) grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) + block.instantiate_factors((((grads,),),), 0.5) def as_tensors(tensor_or_tuple): diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py index 66e18974abfadaad5d7a20b40d0b1352bfda67ee..fad47cd02f372e0b180645b5636965514bafe6b0 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np import numpy.random as npr +from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb from tensorflow.contrib.kfac.python.ops import fisher_factors as ff from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -29,36 +30,20 @@ from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import test -class MaybeColocateTest(test.TestCase): +# We need to set these constants since the numerical values used in the tests +# were chosen when these used to be the defaults. +ff.set_global_constants(init_covariances_at_zero=False, + zero_debias=False, + init_inverses_at_zero=False) - def setUp(self): - self._colocate_cov_ops_with_inputs = ff.COLOCATE_COV_OPS_WITH_INPUTS - - def tearDown(self): - ff.set_global_constants( - colocate_cov_ops_with_inputs=self._colocate_cov_ops_with_inputs) - def testFalse(self): - ff.set_global_constants(colocate_cov_ops_with_inputs=False) - with tf_ops.Graph().as_default(): - a = constant_op.constant([2.0], name='a') - with ff.maybe_colocate_with(a): - b = constant_op.constant(3.0, name='b') - self.assertEqual([b'loc:@a'], a.op.colocation_groups()) - self.assertEqual([b'loc:@b'], b.op.colocation_groups()) - - def testTrue(self): - ff.set_global_constants(colocate_cov_ops_with_inputs=True) - with tf_ops.Graph().as_default(): - a = constant_op.constant([2.0], name='a') - with ff.maybe_colocate_with(a): - b = constant_op.constant(3.0, name='b') - self.assertEqual([b'loc:@a'], a.op.colocation_groups()) - self.assertEqual([b'loc:@a'], b.op.colocation_groups()) +def make_damping_func(damping): + return fb._package_func(lambda: damping, damping) class FisherFactorTestingDummy(ff.FisherFactor): @@ -92,26 +77,44 @@ class FisherFactorTestingDummy(ff.FisherFactor): def get_cov(self): return NotImplementedError - def left_multiply(self, x, damping): + def instantiate_inv_variables(self): return NotImplementedError - def right_multiply(self, x, damping): - return NotImplementedError + def _num_towers(self): + raise NotImplementedError - def left_multiply_inverse(self, x, damping): - return NotImplementedError + def _get_data_device(self): + raise NotImplementedError - def right_multiply_inverse(self, x, damping): - return NotImplementedError + def register_matpower(self, exp, damping_func): + raise NotImplementedError + + def register_cholesky(self, damping_func): + raise NotImplementedError + + def register_cholesky_inverse(self, damping_func): + raise NotImplementedError + + def get_matpower(self, exp, damping_func): + raise NotImplementedError + + def get_cholesky(self, damping_func): + raise NotImplementedError + + def get_cholesky_inverse(self, damping_func): + raise NotImplementedError + + def get_cov_as_linear_operator(self): + raise NotImplementedError -class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): - """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor. +class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor): + """Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor. """ def __init__(self, shape): self._shape = shape - super(InverseProvidingFactorTestingDummy, self).__init__() + super(DenseSquareMatrixFactorTestingDummy, self).__init__() @property def _var_scope(self): @@ -135,6 +138,12 @@ class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): def instantiate_covariance(self): pass + def _num_towers(self): + raise NotImplementedError + + def _get_data_device(self): + raise NotImplementedError + class NumericalUtilsTest(test.TestCase): @@ -237,50 +246,64 @@ class FisherFactorTest(test.TestCase): self.assertEqual(0, len(factor.make_inverse_update_ops())) -class InverseProvidingFactorTest(test.TestCase): +class DenseSquareMatrixFactorTest(test.TestCase): def testRegisterDampedInverse(self): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) shape = [2, 2] - factor = InverseProvidingFactorTestingDummy(shape) + factor = DenseSquareMatrixFactorTestingDummy(shape) factor_var_scope = 'dummy/a_b_c' - dampings = 0.1, 1e-1, 0.00001, 1e-5 + damping_funcs = [make_damping_func(0.1), + make_damping_func(0.1), + make_damping_func(1e-5), + make_damping_func(1e-5)] + for damping_func in damping_funcs: + factor.register_inverse(damping_func) - for damping in dampings: - factor.register_damped_inverse(damping) + factor.instantiate_inv_variables() - self.assertEqual(set(dampings), set(factor._inverses_by_damping.keys())) - inv = factor._inverses_by_damping[dampings[0]] - self.assertEqual(inv, factor._inverses_by_damping[dampings[1]]) - self.assertNotEqual(inv, factor._inverses_by_damping[dampings[2]]) - self.assertEqual(factor._inverses_by_damping[dampings[2]], - factor._inverses_by_damping[dampings[3]]) + inv = factor.get_inverse(damping_funcs[0]).to_dense() + self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense()) + self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense()) + self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(), + factor.get_inverse(damping_funcs[3]).to_dense()) factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, factor_var_scope) - self.assertListEqual([inv, factor._inverses_by_damping[dampings[2]]], - factor_vars) + factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars) + + self.assertEqual(set([inv, + factor.get_inverse(damping_funcs[2]).to_dense()]), + set(factor_tensors)) self.assertEqual(shape, inv.get_shape()) def testRegisterMatpower(self): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) shape = [3, 3] - factor = InverseProvidingFactorTestingDummy(shape) + factor = DenseSquareMatrixFactorTestingDummy(shape) factor_var_scope = 'dummy/a_b_c' - factor.register_matpower(1, 0.5) - factor.register_matpower(2, 0.5) + # TODO(b/74201126): Change to using the same func for both once + # Topohash is in place. + damping_func_1 = make_damping_func(0.5) + damping_func_2 = make_damping_func(0.5) + + factor.register_matpower(-0.5, damping_func_1) + factor.register_matpower(2, damping_func_2) + + factor.instantiate_inv_variables() - self.assertEqual( - set([(1, 0.5), (2, 0.5)]), - set(factor._matpower_by_exp_and_damping.keys())) factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, factor_var_scope) - matpower1 = factor.get_matpower(1, 0.5) - matpower2 = factor.get_matpower(2, 0.5) - self.assertListEqual([matpower1, matpower2], factor_vars) + + factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars) + + matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense() + matpower2 = factor.get_matpower(2, damping_func_2).to_dense() + + self.assertEqual(set([matpower1, matpower2]), set(factor_tensors)) self.assertEqual(shape, matpower1.get_shape()) self.assertEqual(shape, matpower2.get_shape()) @@ -296,20 +319,28 @@ class InverseProvidingFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) cov = np.array([[1., 2.], [3., 4.]]) - factor = InverseProvidingFactorTestingDummy(cov.shape) + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) factor._cov = array_ops.constant(cov, dtype=dtypes.float32) + damping_funcs = [] for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): - factor.register_damped_inverse(1. / i) + damping_funcs.append(make_damping_func(1./i)) + + for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): + factor.register_inverse(damping_funcs[i]) + + factor.instantiate_inv_variables() ops = factor.make_inverse_update_ops() self.assertEqual(1, len(ops)) sess.run(tf_variables.global_variables_initializer()) new_invs = [] sess.run(ops) - for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): + for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): # The inverse op will assign the damped inverse of cov to the inv var. - new_invs.append(sess.run(factor._inverses_by_damping[1. / i])) + new_invs.append( + sess.run(factor.get_inverse(damping_funcs[i]).to_dense())) + # We want to see that the new invs are all different from each other. for i in range(len(new_invs)): for j in range(i + 1, len(new_invs)): @@ -320,18 +351,20 @@ class InverseProvidingFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) cov = np.array([[6., 2.], [2., 4.]]) - factor = InverseProvidingFactorTestingDummy(cov.shape) + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) factor._cov = array_ops.constant(cov, dtype=dtypes.float32) exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power damping = 0.5 + damping_func = make_damping_func(damping) - factor.register_matpower(exp, damping) + factor.register_matpower(exp, damping_func) + factor.instantiate_inv_variables() ops = factor.make_inverse_update_ops() self.assertEqual(1, len(ops)) sess.run(tf_variables.global_variables_initializer()) sess.run(ops[0]) - matpower = sess.run(factor._matpower_by_exp_and_damping[(exp, damping)]) + matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense()) matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp) self.assertAllClose(matpower, matpower_np) @@ -339,21 +372,24 @@ class InverseProvidingFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric - factor = InverseProvidingFactorTestingDummy(cov.shape) + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) factor._cov = array_ops.constant(cov, dtype=dtypes.float32) - factor.register_damped_inverse(0) + damping_func = make_damping_func(0) + + factor.register_inverse(damping_func) + factor.instantiate_inv_variables() ops = factor.make_inverse_update_ops() self.assertEqual(1, len(ops)) sess.run(tf_variables.global_variables_initializer()) # The inverse op will assign the damped inverse of cov to the inv var. - old_inv = sess.run(factor._inverses_by_damping[0]) + old_inv = sess.run(factor.get_inverse(damping_func).to_dense()) self.assertAllClose( sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv) sess.run(ops) - new_inv = sess.run(factor._inverses_by_damping[0]) + new_inv = sess.run(factor.get_inverse(damping_func).to_dense()) self.assertAllClose(new_inv, np.linalg.inv(cov)) @@ -364,6 +400,7 @@ class FullFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), name='a/b/c') factor = ff.FullFactor((tensor,), 32) + factor.instantiate_cov_variables() self.assertEqual([6, 6], factor.get_cov().get_shape().as_list()) def testFullFactorInitFloat64(self): @@ -372,6 +409,7 @@ class FullFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') factor = ff.FullFactor((tensor,), 32) + factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual([6, 6], cov.get_shape().as_list()) @@ -381,6 +419,7 @@ class FullFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.constant([1., 2.], name='a/b/c') factor = ff.FullFactor((tensor,), 2) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) @@ -394,7 +433,8 @@ class NaiveDiagonalFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 32) - self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list()) + factor.instantiate_cov_variables() + self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) def testNaiveDiagonalFactorInitFloat64(self): with tf_ops.Graph().as_default(): @@ -402,7 +442,8 @@ class NaiveDiagonalFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 32) - cov = factor.get_cov_var() + factor.instantiate_cov_variables() + cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual([6, 1], cov.get_shape().as_list()) @@ -411,6 +452,7 @@ class NaiveDiagonalFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.constant([1., 2.], name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 2) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) @@ -424,7 +466,8 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase): input_ids = array_ops.constant([[0], [1], [4]]) vocab_size = 5 factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) - cov = factor.get_cov_var() + factor.instantiate_cov_variables() + cov = factor.get_cov() self.assertEqual(cov.shape.as_list(), [vocab_size]) def testCovarianceUpdateOp(self): @@ -432,6 +475,7 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase): input_ids = array_ops.constant([[0], [1], [4]]) vocab_size = 5 factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) + factor.instantiate_cov_variables() cov_update_op = factor.make_covariance_update_op(0.0) with self.test_session() as sess: @@ -440,6 +484,118 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase): self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov) +class ConvDiagonalFactorTest(test.TestCase): + + def setUp(self): + self.batch_size = 10 + self.height = self.width = 32 + self.in_channels = 3 + self.out_channels = 1 + self.kernel_height = self.kernel_width = 3 + self.strides = [1, 2, 2, 1] + self.data_format = 'NHWC' + self.padding = 'SAME' + self.kernel_shape = [ + self.kernel_height, self.kernel_width, self.in_channels, + self.out_channels + ] + + def testInit(self): + with tf_ops.Graph().as_default(): + inputs = random_ops.random_uniform( + [self.batch_size, self.height, self.width, self.in_channels]) + outputs_grads = [ + random_ops.random_uniform([ + self.batch_size, self.height // self.strides[1], + self.width // self.strides[2], self.out_channels + ]) for _ in range(3) + ] + + factor = ff.ConvDiagonalFactor( + (inputs,), + (outputs_grads,), + self.kernel_shape, + self.strides, + self.padding, + data_format=self.data_format) + factor.instantiate_cov_variables() + + # Ensure covariance matrix's shape makes sense. + self.assertEqual([ + self.kernel_height * self.kernel_width * self.in_channels, + self.out_channels + ], + factor.get_cov().shape.as_list()) + + def testMakeCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(): + # Construct all arguments such that convolution kernel is applied in + # exactly one spatial location. + inputs = np.random.randn( + 1, # batch_size + self.kernel_height, + self.kernel_width, + self.in_channels) # in_channels + outputs_grad = np.random.randn( + 1, # batch_size + 1, # output_height + 1, # output_width + self.out_channels) + + factor = ff.ConvDiagonalFactor( + (constant_op.constant(inputs),), + ((constant_op.constant(outputs_grad),),), + self.kernel_shape, + strides=[1, 1, 1, 1], + padding='VALID') + factor.instantiate_cov_variables() + + # Completely forget initial value on first update. + cov_update_op = factor.make_covariance_update_op(0.0) + + # Ensure new covariance value is same as outer-product of inputs/outputs + # vectorized, squared. + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + cov = sess.run(cov_update_op) + expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2 + self.assertAllClose(expected_cov, cov) + + def testHasBias(self): + with tf_ops.Graph().as_default(): + inputs = random_ops.random_uniform( + [self.batch_size, self.height, self.width, self.in_channels]) + outputs_grads = [ + random_ops.random_uniform([ + self.batch_size, self.height // self.strides[1], + self.width // self.strides[2], self.out_channels + ]) for _ in range(3) + ] + + factor = ff.ConvDiagonalFactor( + (inputs,), + (outputs_grads,), + self.kernel_shape, + self.strides, + self.padding, + data_format=self.data_format, + has_bias=True) + factor.instantiate_cov_variables() + + # Ensure shape accounts for bias. + self.assertEqual([ + self.kernel_height * self.kernel_width * self.in_channels + 1, + self.out_channels + ], + factor.get_cov().shape.as_list()) + + # Ensure update op doesn't crash. + cov_update_op = factor.make_covariance_update_op(0.0) + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(cov_update_op) + + class FullyConnectedKroneckerFactorTest(test.TestCase): def _testFullyConnectedKroneckerFactorInit(self, @@ -449,7 +605,8 @@ class FullyConnectedKroneckerFactorTest(test.TestCase): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=has_bias) + factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=has_bias) + factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual(final_shape, cov.get_shape().as_list()) @@ -466,7 +623,8 @@ class FullyConnectedKroneckerFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=True) + factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=True) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) @@ -476,40 +634,171 @@ class FullyConnectedKroneckerFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor((tensor,)) + factor = ff.FullyConnectedKroneckerFactor(((tensor,),)) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) -class ConvInputKroneckerFactorTest(test.TestCase): +class ConvFactorTestCase(test.TestCase): + + def assertMatrixRank(self, rank, matrix, atol=1e-5): + assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.' + eigvals = np.linalg.eigvals(matrix) + nnz_eigvals = np.sum(eigvals > atol) + self.assertEqual( + rank, + nnz_eigvals, + msg=('Found %d of %d expected non-zero eigenvalues: %s.' % + (nnz_eigvals, rank, eigvals))) + + +class ConvInputKroneckerFactorTest(ConvFactorTestCase): + + def test3DConvolution(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 3**3 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, width, in_channels), seed=0),), + filter_shape=(width, width, width, in_channels, out_channels), + padding='SAME', + strides=(2, 2, 2), + extract_patches_fn='extract_convolution_patches', + has_bias=False) + factor.instantiate_cov_variables() + + # Ensure shape of covariance matches input size of filter. + input_size = in_channels * (width**3) + self.assertEqual([input_size, input_size], + factor.get_cov().shape.as_list()) + + # Ensure cov_update_op doesn't crash. + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov()) + + # Cov should be rank-8, as the filter will be applied at each corner of + # the 4-D cube. + self.assertMatrixRank(8, cov) + + def testPointwiseConv2d(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 3**2 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0),), + filter_shape=(1, 1, in_channels, out_channels), + padding='SAME', + strides=(1, 1, 1, 1), + extract_patches_fn='extract_pointwise_conv2d_patches', + has_bias=False) + factor.instantiate_cov_variables() + + # Ensure shape of covariance matches input size of filter. + self.assertEqual([in_channels, in_channels], + factor.get_cov().shape.as_list()) + + # Ensure cov_update_op doesn't crash. + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov()) + + # Cov should be rank-9, as the filter will be applied at each location. + self.assertMatrixRank(9, cov) + + def testStrides(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 3**2 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0),), + filter_shape=(1, 1, in_channels, out_channels), + padding='SAME', + strides=(1, 2, 1, 1), + extract_patches_fn='extract_image_patches', + has_bias=False) + factor.instantiate_cov_variables() + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov()) + + # Cov should be the sum of 3 * 2 = 6 outer products. + self.assertMatrixRank(6, cov) + + def testDilationRate(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 2 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0),), + filter_shape=(3, 3, in_channels, out_channels), + padding='SAME', + extract_patches_fn='extract_image_patches', + strides=(1, 1, 1, 1), + dilation_rate=(1, width, width, 1), + has_bias=False) + factor.instantiate_cov_variables() + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov()) + + # Cov should be rank = in_channels, as only the center of the filter + # receives non-zero input for each input channel. + self.assertMatrixRank(in_channels, cov) def testConvInputKroneckerFactorInitNoBias(self): with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') factor = ff.ConvInputKroneckerFactor( - tensor, (1, 2, 3, 4), 3, 2, has_bias=False) + inputs=(tensor,), + filter_shape=(1, 2, 3, 4), + padding='SAME', + has_bias=False) + factor.instantiate_cov_variables() self.assertEqual([1 * 2 * 3, 1 * 2 * 3], factor.get_cov().get_shape().as_list()) def testConvInputKroneckerFactorInit(self): with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') factor = ff.ConvInputKroneckerFactor( - tensor, (1, 2, 3, 4), 3, 2, has_bias=True) + (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) + factor.instantiate_cov_variables() self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], factor.get_cov().get_shape().as_list()) def testConvInputKroneckerFactorInitFloat64(self): with tf_ops.Graph().as_default(): dtype = dtypes.float64_ref - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64) factor = ff.ConvInputKroneckerFactor( - tensor, (1, 2, 3, 4), 3, 2, has_bias=True) + (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) + factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], @@ -517,37 +806,82 @@ class ConvInputKroneckerFactorTest(test.TestCase): def testMakeCovarianceUpdateOpWithBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: - random_seed.set_random_seed(200) + input_shape = (2, 1, 1, 1) tensor = array_ops.constant( - np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32) + np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( + np.float32)) factor = ff.ConvInputKroneckerFactor( - tensor, (1, 2, 1, 1), [1, 1, 1, 1], 'SAME', has_bias=True) + (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[34.375, 37, 3.125], [37, 41, 3.5], [3.125, 3.5, 1]], - new_cov) + new_cov = sess.run(factor.make_covariance_update_op(0.)) + self.assertAllClose( + [ + [(1. + 4.) / 2., (1. + 2.) / 2.], # + [(1. + 2.) / 2., (1. + 1.) / 2.] + ], # + new_cov) def testMakeCovarianceUpdateOpNoBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: - random_seed.set_random_seed(200) + input_shape = (2, 1, 1, 1) tensor = array_ops.constant( - np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32) - factor = ff.ConvInputKroneckerFactor(tensor, (1, 2, 1, 1), [1, 1, 1, 1], - 'SAME') + np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( + np.float32)) + factor = ff.ConvInputKroneckerFactor( + (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME') + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[34.375, 37], [37, 41]], new_cov) + new_cov = sess.run(factor.make_covariance_update_op(0.)) + self.assertAllClose([[(1. + 4.) / 2.]], new_cov) + + def testSubSample(self): + with tf_ops.Graph().as_default(): + patches_1 = array_ops.constant(1, shape=(10, 2)) + patches_2 = array_ops.constant(1, shape=(10, 8)) + patches_3 = array_ops.constant(1, shape=(3, 3)) + patches_1_sub = ff._subsample_for_cov_computation(patches_1) + patches_2_sub = ff._subsample_for_cov_computation(patches_2) + patches_3_sub = ff._subsample_for_cov_computation(patches_3) + patches_1_sub_batch_size = patches_1_sub.shape.as_list()[0] + patches_2_sub_batch_size = patches_2_sub.shape.as_list()[0] + patches_3_sub_batch_size = patches_3_sub.shape.as_list()[0] + self.assertEqual(2, patches_1_sub_batch_size) + self.assertEqual(8, patches_2_sub_batch_size) + self.assertEqual(3, patches_3_sub_batch_size) + + +class ConvOutputKroneckerFactorTest(ConvFactorTestCase): + + def test3DConvolution(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + out_channels = width**3 + factor = ff.ConvOutputKroneckerFactor(outputs_grads=([ + random_ops.random_uniform( + (batch_size, width, width, width, out_channels), seed=0) + ],)) + factor.instantiate_cov_variables() -class ConvOutputKroneckerFactorTest(test.TestCase): + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov()) + + # Cov should be rank 3^3, as each spatial position donates a rank-1 + # update. + self.assertMatrixRank(width**3, cov) def testConvOutputKroneckerFactorInit(self): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c') - factor = ff.ConvOutputKroneckerFactor((tensor,)) + factor = ff.ConvOutputKroneckerFactor(((tensor,),)) + factor.instantiate_cov_variables() self.assertEqual([5, 5], factor.get_cov().get_shape().as_list()) def testConvOutputKroneckerFactorInitFloat64(self): @@ -555,23 +889,18 @@ class ConvOutputKroneckerFactorTest(test.TestCase): dtype = dtypes.float64_ref random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c') - factor = ff.ConvOutputKroneckerFactor((tensor,)) + factor = ff.ConvOutputKroneckerFactor(((tensor,),)) + factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual([5, 5], cov.get_shape().as_list()) - def testConvOutputKroneckerFactorInitNotEnoughDims(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') - with self.assertRaises(IndexError): - ff.ConvOutputKroneckerFactor(tensor) - def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32) - factor = ff.ConvOutputKroneckerFactor((array_ops.constant(tensor),)) + factor = ff.ConvOutputKroneckerFactor(((array_ops.constant(tensor),),)) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) @@ -584,8 +913,8 @@ class FullyConnectedMultiKFTest(test.TestCase): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), name='a/b/c') - tensor_list = [tensor] - factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False) + factor.instantiate_cov_variables() self.assertEqual([3, 3], factor.get_cov().get_shape().as_list()) def testFullyConnectedMultiKFInitFloat64(self): @@ -593,8 +922,8 @@ class FullyConnectedMultiKFTest(test.TestCase): dtype = dtypes.float64_ref random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') - tensor_list = [tensor] - factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False) + factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual([3, 3], cov.get_shape().as_list()) @@ -603,8 +932,8 @@ class FullyConnectedMultiKFTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - tensor_list = [tensor] - factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=True) + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=True) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) @@ -614,8 +943,8 @@ class FullyConnectedMultiKFTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - tensor_list = [tensor] - factor = ff.FullyConnectedMultiKF((tensor_list,)) + factor = ff.FullyConnectedMultiKF(((tensor,),)) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py index b8ccbeadd0a9d69edb41fef50e3edb090457adf2..cb80fca3705308f92e308e2a840336fb72d0fa62 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py @@ -35,7 +35,7 @@ from tensorflow.python.platform import test class MockFisherBlock(object): """A fake FisherBlock.""" - num_registered_minibatches = 2 + num_registered_towers = 2 def __init__(self, name='MockFisherBlock'): self.name = name @@ -104,22 +104,53 @@ class LayerCollectionTest(test.TestCase): array_ops.constant(3), approx=layer_collection.APPROX_DIAGONAL_NAME) lc.register_conv2d( - array_ops.constant(4), [1, 1, 1, 1], 'SAME', - array_ops.ones((1, 1, 1, 1)), array_ops.constant(3)) + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=array_ops.ones((1, 2, 3, 4)), + outputs=array_ops.ones((1, 1, 1, 5))) lc.register_conv2d( - array_ops.constant(4), [1, 1, 1, 1], - 'SAME', - array_ops.ones((1, 1, 1, 1)), - array_ops.constant(3), + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=array_ops.ones((1, 2, 3, 4)), + outputs=array_ops.ones((1, 1, 1, 5)), approx=layer_collection.APPROX_DIAGONAL_NAME) + lc.register_separable_conv2d( + depthwise_params=array_ops.ones((3, 3, 1, 2)), + pointwise_params=array_ops.ones((1, 1, 2, 4)), + inputs=array_ops.ones((32, 5, 5, 1)), + depthwise_outputs=array_ops.ones((32, 5, 5, 2)), + pointwise_outputs=array_ops.ones((32, 5, 5, 4)), + strides=[1, 1, 1, 1], + padding='SAME') + lc.register_convolution( + params=array_ops.ones((3, 3, 1, 8)), + inputs=array_ops.ones((32, 5, 5, 1)), + outputs=array_ops.ones((32, 5, 5, 8)), + padding='SAME') lc.register_generic( array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME) lc.register_generic( array_ops.constant(6), 16, approx=layer_collection.APPROX_DIAGONAL_NAME) - - self.assertEqual(6, len(lc.get_blocks())) + lc.register_fully_connected_multi( + array_ops.constant(1), + (array_ops.constant(2), array_ops.constant(3)), + (array_ops.constant(4), array_ops.constant(5))) + lc.register_conv2d_multi( + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))), + outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10)))) + lc.register_embedding_multi( + array_ops.constant((1,)), + (array_ops.constant(2), array_ops.constant(3)), + (array_ops.constant(4), array_ops.constant(5))) + + self.assertEqual(12, len(lc.get_blocks())) def testRegisterBlocksMultipleRegistrations(self): with ops.Graph().as_default(): @@ -237,16 +268,16 @@ class LayerCollectionTest(test.TestCase): # Create a new loss function by name. lc.register_categorical_predictive_distribution(logits, name='loss1') - self.assertEqual(1, len(lc.losses)) + self.assertEqual(1, len(lc.towers_by_loss)) # Add logits to same loss function. lc.register_categorical_predictive_distribution( logits, name='loss1', reuse=True) - self.assertEqual(1, len(lc.losses)) + self.assertEqual(1, len(lc.towers_by_loss)) # Add another new loss function. lc.register_categorical_predictive_distribution(logits, name='loss2') - self.assertEqual(2, len(lc.losses)) + self.assertEqual(2, len(lc.towers_by_loss)) def testLossFunctionWithoutName(self): """Ensure loss functions get unique names if 'name' not specified.""" @@ -298,13 +329,9 @@ class LayerCollectionTest(test.TestCase): name='loss1', reuse=layer_collection.VARIABLE_SCOPE) - self.assertEqual(len(lc.losses), 1) - loss = lc.losses[0] - + self.assertEqual(len(lc.towers_by_loss), 1) # Three successful registrations. - self.assertEqual(loss.params.shape.as_list(), - [3 * batch_size, output_size]) - self.assertEqual(loss.targets.shape.as_list(), [3 * batch_size]) + self.assertEqual(len(lc.towers_by_loss[0]), 3) def testRegisterCategoricalPredictiveDistributionBatchSize1(self): with ops.Graph().as_default(): @@ -441,13 +468,13 @@ class LayerCollectionTest(test.TestCase): b = variable_scope.get_variable('b', [3]) lc = layer_collection.LayerCollection() lc.register_fully_connected(w, inputs, outputs) - self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 1) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) with self.assertRaises(KeyError): lc.register_fully_connected((w, b), inputs, outputs, reuse=True) self.assertNotIn((w, b), lc.fisher_blocks) - self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 1) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) lc.register_fully_connected(w, inputs, outputs, reuse=True) - self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 2) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2) def testMakeOrGetFactor(self): with ops.Graph().as_default(): @@ -479,17 +506,6 @@ class LayerCollectionTest(test.TestCase): variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertTrue(all([var.name.startswith(scope) for var in variables])) - def testGetUseCountMap(self): - """Ensure get_use_count_map() sums 'num_registered_minibatches'.""" - lc = layer_collection.LayerCollection() - lc.fisher_blocks = { - 'a': MockFisherBlock(), - ('a', 'c'): MockFisherBlock(), - ('b', 'c'): MockFisherBlock() - } - use_count_map = lc.get_use_count_map() - self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map) - def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self): x = variable_scope.get_variable('x', shape=()) y = variable_scope.get_variable('y', shape=()) @@ -550,6 +566,32 @@ class LayerCollectionTest(test.TestCase): self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB) self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB) + def testDefaultLayerCollection(self): + with ops.Graph().as_default(): + # Can't get default if there isn't one set. + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + + # Can't set default twice. + lc = layer_collection.LayerCollection() + layer_collection.set_default_layer_collection(lc) + with self.assertRaises(ValueError): + layer_collection.set_default_layer_collection(lc) + + # Same as one set. + self.assertTrue(lc is layer_collection.get_default_layer_collection()) + + # Can set to None. + layer_collection.set_default_layer_collection(None) + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + + # as_default() is the same as setting/clearing. + with lc.as_default(): + self.assertTrue(lc is layer_collection.get_default_layer_collection()) + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py index ae787b6f1ac90218f2ac73d37fb270df0b822de2..c00af5593f085e3b1f3e030a24f4b821115cc869 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py @@ -24,7 +24,6 @@ from tensorflow.contrib.kfac.python.ops import loss_functions from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -97,22 +96,6 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): # difficult to say if the output is correct or not... neg_log_prob = sess.run(neg_log_prob) - def testMultiMinibatchRegistration(self): - """Ensure this loss function supports registering multiple minibatches.""" - with ops.Graph().as_default(): - tower_logits = [] - loss = None - num_towers = 5 - for _ in range(num_towers): - logits = random_ops.random_uniform(shape=[2, 3]) - tower_logits.append(logits) - if loss is None: - loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) - else: - loss.register_additional_minibatch(logits) - self.assertListEqual(loss.input_minibatches, tower_logits) - self.assertEqual(loss.num_registered_minibatches, num_towers) - def testMultiplyFisherSingleVector(self): with ops.Graph().as_default(), self.test_session() as sess: logits = np.array([1., 2., 3.]) @@ -203,23 +186,5 @@ class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase): # difficult to say if the output is correct or not... neg_log_prob = sess.run(neg_log_prob) - def testMultiMinibatchRegistration(self): - """Ensure this loss function supports registering multiple minibatches.""" - with ops.Graph().as_default(): - tower_logits = [] - loss = None - num_towers = 5 - for _ in range(num_towers): - logits = random_ops.random_uniform(shape=[2, 3]) - tower_logits.append(logits) - if loss is None: - loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( - logits) - else: - loss.register_additional_minibatch(logits) - self.assertListEqual(loss.input_minibatches, tower_logits) - self.assertEqual(loss.num_registered_minibatches, num_towers) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py index 9325aa1b7325fa9cf546d66e6505affa1af7db4d..560a9b0b426eccb262296a505df7f782a96d9c1d 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.kfac.python.ops import fisher_factors as ff from tensorflow.contrib.kfac.python.ops import layer_collection as lc from tensorflow.contrib.kfac.python.ops import optimizer from tensorflow.python.framework import ops @@ -32,6 +33,13 @@ from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import test +# We need to set these constants since the numerical values used in the tests +# were chosen when these used to be the defaults. +ff.set_global_constants(init_covariances_at_zero=False, + zero_debias=False, + init_inverses_at_zero=False) + + def dummy_layer_collection(): lcoll = lc.LayerCollection() dummy = array_ops.constant([1., 2.]) @@ -186,6 +194,11 @@ class OptimizerTest(test.TestCase): layer_collection, momentum=0.5, momentum_type='regular') + (cov_update_thunks, + inv_update_thunks) = opt.make_vars_and_create_op_thunks() + cov_update_ops = tuple(thunk() for thunk in cov_update_thunks) + inv_update_ops = tuple(thunk() for thunk in inv_update_thunks) + grads_and_vars = opt.compute_gradients(output, [weights, bias]) all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars] @@ -193,6 +206,8 @@ class OptimizerTest(test.TestCase): sess.run(tf_variables.global_variables_initializer()) old_vars = sess.run(all_vars) + sess.run(cov_update_ops) + sess.run(inv_update_ops) sess.run(op) new_vars = sess.run(all_vars) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py index 97a97adbf5577cd2694d3055acaa59258ad27964..2cee01212a11595669e9df0fc95a5657926c1038 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py @@ -29,6 +29,8 @@ from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -325,6 +327,84 @@ class UtilsTest(test.TestCase): ], values) + def testExtractConvolutionPatches(self): + with ops.Graph().as_default(), self.test_session() as sess: + batch_size = 10 + image_spatial_shape = [9, 10, 11] + in_channels = out_channels = 32 + kernel_spatial_shape = [5, 3, 3] + spatial_strides = [1, 2, 1] + spatial_dilation = [1, 1, 1] + padding = 'SAME' + + images = random_ops.random_uniform( + [batch_size] + image_spatial_shape + [in_channels], seed=0) + kernel_shape = kernel_spatial_shape + [in_channels, out_channels] + kernel = random_ops.random_uniform(kernel_shape, seed=1) + + # Ensure shape matches expectation. + patches = utils.extract_convolution_patches( + images, + kernel_shape, + padding, + strides=spatial_strides, + dilation_rate=spatial_dilation) + result_spatial_shape = ( + patches.shape.as_list()[1:1 + len(image_spatial_shape)]) + self.assertEqual(patches.shape.as_list(), + [batch_size] + result_spatial_shape + + kernel_spatial_shape + [in_channels]) + + # Ensure extract...patches() + matmul() and convolution() implementation + # give the same answer. + outputs = nn_ops.convolution( + images, + kernel, + padding, + strides=spatial_strides, + dilation_rate=spatial_dilation) + + patches_flat = array_ops.reshape( + patches, [-1, np.prod(kernel_spatial_shape) * in_channels]) + kernel_flat = array_ops.reshape(kernel, [-1, out_channels]) + outputs_flat = math_ops.matmul(patches_flat, kernel_flat) + + outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) + self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) + + def testExtractPointwiseConv2dPatches(self): + with ops.Graph().as_default(), self.test_session() as sess: + batch_size = 10 + image_height = image_width = 8 + in_channels = out_channels = 3 + kernel_height = kernel_width = 1 + strides = [1, 1, 1, 1] + padding = 'VALID' + + images = random_ops.random_uniform( + [batch_size, image_height, image_width, in_channels], seed=0) + kernel_shape = [kernel_height, kernel_width, in_channels, out_channels] + kernel = random_ops.random_uniform(kernel_shape, seed=1) + + # Ensure shape matches expectation. + patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape) + self.assertEqual(patches.shape.as_list(), [ + batch_size, image_height, image_width, kernel_height, kernel_width, + in_channels + ]) + + # Ensure extract...patches() + matmul() and conv2d() implementation + # give the same answer. + outputs = nn_ops.conv2d(images, kernel, strides, padding) + + patches_flat = array_ops.reshape( + patches, [-1, kernel_height * kernel_width * in_channels]) + kernel_flat = array_ops.reshape(kernel, [-1, out_channels]) + outputs_flat = math_ops.matmul(patches_flat, kernel_flat) + + outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) + self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index ee6549b109399766579b6ea18a987ae2c8275983..3c01eb65e7a687d6c477b858b8d91ea7f309dc64 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -35,12 +35,16 @@ py_library( srcs = ["fisher_factors.py"], srcs_version = "PY2AND3", deps = [ + ":linear_operator", ":utils", "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", "//tensorflow/python:special_math_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", @@ -60,6 +64,19 @@ py_library( ], ) +py_library( + name = "linear_operator", + srcs = ["linear_operator.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python/ops/linalg", + "@six_archive//:six", + ], +) + py_library( name = "loss_functions", srcs = ["loss_functions.py"], @@ -144,10 +161,13 @@ py_library( ":fisher_estimator", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:state_ops", "//tensorflow/python:training", + "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], ) @@ -168,6 +188,7 @@ py_library( name = "fisher_estimator", srcs = [ "estimator.py", + "placement.py", ], srcs_version = "PY2AND3", deps = [ @@ -177,6 +198,7 @@ py_library( "//tensorflow/python:gradients", "//tensorflow/python:util", "//third_party/py/numpy", + "@six_archive//:six", ], ) @@ -239,15 +261,3 @@ py_library( "//tensorflow/python:util", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index a7b1f9d35c931fc44408be804479e758f28f7110..854f885c26f2b4340555adb91bc3b9749962d869 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -18,68 +18,59 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib -import itertools - +import abc import numpy as np +import six +from tensorflow.contrib.kfac.python.ops import placement from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest -class _DeviceContextGenerator(object): - """Class for generating device contexts in a round-robin fashion.""" - - def __init__(self, devices): - """Creates a _DeviceContextGenerator object. +# The linter is confused. +# pylint: disable=abstract-class-instantiated +def make_fisher_estimator(placement_strategy=None, **kwargs): + """Creates Fisher estimator instances based on the placement strategy. - Example usage: + For example if the `placement_strategy` is 'round_robin' then + `FisherEstimatorRoundRobin` instance is returned. - ```python - dcg = _DeviceContextGenerator(['/gpu:0', 'gpu:1']) - with dcg(): - # All operations in this context will be placed on GPU 0 - ... - with dcg(): - # All operations in this context will be placed on GPU 1 - ... - ``` + Args: + placement_strategy: `string`, Strategy to be used for placing covariance + variables, covariance ops and inverse ops. Check + `placement.FisherEstimatorRoundRobin` for a concrete example. + **kwargs: Arguments to be passed into `FisherEstimator` class initializer. - Args: - devices: An iterable of device strings (or None). Successive calls to - __call__ will give contexts which place devices on these devices in - a round-robin fashion. - """ - self._cycle = None if devices is None else itertools.cycle(devices) + Returns: + An instance of class which inherits from `FisherEstimator` and the mixin + which implements specific placement strategy. See, + `FisherEstimatorRoundRobin` which inherits from `FisherEstimator` and + `RoundRobinPlacementMixin`. - @contextlib.contextmanager - def __call__(self): - """Returns a context manager specifying the default device.""" - if self._cycle is None: - yield - else: - with tf_ops.device(next(self._cycle)): - yield + Raises: + ValueError: If the `placement_strategy` is not equal to 'round_robin'. + """ + if placement_strategy in [None, "round_robin"]: + return FisherEstimatorRoundRobin(**kwargs) + else: + raise ValueError("Unimplemented vars and ops " + "placement strategy : {}".format(placement_strategy)) +# pylint: enable=abstract-class-instantiated +@six.add_metaclass(abc.ABCMeta) class FisherEstimator(object): """Fisher estimator class supporting various approximations of the Fisher. - Attributes: - cov_update_thunks: list of no-arg functions. Executing a function adds - covariance update ops for a single FisherFactor to the graph. - cov_update_ops: List of Ops. Running an op updates covariance matrices for a - single FisherFactor. - cov_update_op: Op. Running updates covariance matrices for all - FisherFactors. - inv_update_thunks: list of no-arg functions. Executing a function adds - inverse update ops for a single FisherFactor to the graph. - inv_update_ops: List of Ops. Running an op updates inverse matrices for a - single FisherFactor. - inv_update_op: Op. Running updates inverse matrices for all FisherFactors. + This is an abstract base class which does not implement a strategy for + placing covariance variables, covariance update ops and inverse update ops. + The placement strategies are implemented in `placement.py`. See + `FisherEstimatorRoundRobin` for example of a concrete subclass with + a round-robin placement strategy. """ def __init__(self, @@ -87,26 +78,33 @@ class FisherEstimator(object): cov_ema_decay, damping, layer_collection, + exps=(-1,), estimation_mode="gradients", colocate_gradients_with_ops=True, - cov_devices=None, - inv_devices=None): + name="FisherEstimator", + compute_cholesky=False, + compute_cholesky_inverse=False): """Create a FisherEstimator object. Args: - variables: A list of the variables for which to estimate the Fisher. This - must match the variables registered in layer_collection (if it is not - None). + variables: A `list` of variables or `callable` which returns the variables + for which to estimate the Fisher. This must match the variables + registered in layer_collection (if it is not None). cov_ema_decay: The decay factor used when calculating the covariance estimate moving averages. - damping: The damping factor used to stabilize training due to errors in - the local approximation with the Fisher information matrix, and to - regularize the update direction by making it closer to the gradient. - (Higher damping means the update looks more like a standard gradient - update - see Tikhonov regularization.) + damping: float. The damping factor used to stabilize training due to + errors in the local approximation with the Fisher information matrix, + and to regularize the update direction by making it closer to the + gradient. (Higher damping means the update looks more like a standard + gradient update - see Tikhonov regularization.) layer_collection: The layer collection object, which holds the fisher blocks, kronecker factors, and losses associated with the graph. + exps: List of floats or ints. These represent the different matrix + powers of the approximate Fisher that the FisherEstimator will be able + to multiply vectors by. If the user asks for a matrix power other + one of these (or 1, which is always supported), there will be a + failure. (Default: (-1,)) estimation_mode: The type of estimator to use for the Fishers. Can be 'gradients', 'empirical', 'curvature_prop', or 'exact'. (Default: 'gradients'). 'gradients' is the basic estimation approach @@ -125,24 +123,23 @@ class FisherEstimator(object): equal to the output dimension, roughly speaking. colocate_gradients_with_ops: Whether we should request gradients be colocated with their respective ops. (Default: True) - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - + name: A string. A name given to this estimator, which is added to the + variable scope when constructing variables and ops. + (Default: "FisherEstimator") + compute_cholesky: Bool. Whether or not the FisherEstimator will be + able to multiply vectors by the Cholesky factor. + (Default: False) + compute_cholesky_inverse: Bool. Whether or not the FisherEstimator + will be able to multiply vectors by the Cholesky factor inverse. + (Default: False) Raises: ValueError: If no losses have been registered with layer_collection. """ - - self._cov_ema_decay = cov_ema_decay self._variables = variables + self._cov_ema_decay = cov_ema_decay self._damping = damping self._estimation_mode = estimation_mode self._layers = layer_collection - self._layers.create_subgraph() - self._layers.check_registration(variables) self._gradient_fns = { "gradients": self._get_grads_lists_gradients, "empirical": self._get_grads_lists_empirical, @@ -151,39 +148,71 @@ class FisherEstimator(object): } self._colocate_gradients_with_ops = colocate_gradients_with_ops - # TODO(b/70674513): Factor device placement outside of this class. - self._cov_device_context_generator = _DeviceContextGenerator(cov_devices) - if inv_devices == cov_devices: - self._inv_device_context_generator = self._cov_device_context_generator - else: - self._inv_device_context_generator = _DeviceContextGenerator(inv_devices) - - self._instantiate_factors() - - self.cov_update_thunks = [ - self._create_cov_update_thunk(factor) - for factor in self._layers.get_factors() - ] - self.cov_update_ops = [thunk() for thunk in self.cov_update_thunks] - self.cov_update_op = control_flow_ops.group( - self.cov_update_ops, name="cov_update_op") + self._made_vars = False + self._exps = exps + self._compute_cholesky = compute_cholesky + self._compute_cholesky_inverse = compute_cholesky_inverse - self.inv_update_thunks = [ - self._create_inv_update_thunk(factor) - for factor in self._layers.get_factors() - ] - self.inv_update_ops = [thunk() for thunk in self.inv_update_thunks] - self.inv_update_op = control_flow_ops.group( - self.inv_update_ops, name="inv_update_op") + self._name = name @property def variables(self): - return self._variables + if callable(self._variables): + return self._variables() + else: + return self._variables @property def damping(self): return self._damping + @property + def blocks(self): + """All registered FisherBlocks.""" + return self._layers.get_blocks() + + @property + def factors(self): + """All registered FisherFactors.""" + return self._layers.get_factors() + + @property + def name(self): + return self._name + + @abc.abstractmethod + def make_vars_and_create_op_thunks(self, scope=None): + """Make vars and create op thunks with a specific placement strategy. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the cov_devices + argument. If cov_devices is None then no explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the inv_devices argument. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all thunks will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + pass + def _apply_transformation(self, vecs_and_vars, transform): """Applies an block-wise transformation to the corresponding vectors. @@ -217,9 +246,7 @@ class FisherEstimator(object): A list of (transformed vector, var) pairs in the same order as vecs_and_vars. """ - - return self._apply_transformation(vecs_and_vars, - lambda fb, vec: fb.multiply_inverse(vec)) + return self.multiply_matpower(-1, vecs_and_vars) def multiply(self, vecs_and_vars): """Multiplies the vectors by the corresponding (damped) blocks. @@ -231,9 +258,67 @@ class FisherEstimator(object): A list of (transformed vector, var) pairs in the same order as vecs_and_vars. """ + return self.multiply_matpower(1, vecs_and_vars) + + def multiply_matpower(self, exp, vecs_and_vars): + """Multiplies the vecs by the corresponding matrix powers of the blocks. - return self._apply_transformation(vecs_and_vars, - lambda fb, vec: fb.multiply(vec)) + Args: + exp: A float representing the power to raise the blocks by before + multiplying it by the vector. + vecs_and_vars: List of (vector, variable) pairs. + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + assert exp in self._exps + + fcn = lambda fb, vec: fb.multiply_matpower(vec, exp) + return self._apply_transformation(vecs_and_vars, fcn) + + def multiply_cholesky(self, vecs_and_vars, transpose=False): + """Multiplies the vecs by the corresponding Cholesky factors. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + transpose: Bool. If true the Cholesky factors are transposed before + multiplying the vecs. (Default: False) + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + assert self._compute_cholesky + + fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose) + return self._apply_transformation(vecs_and_vars, fcn) + + def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False): + """Mults the vecs by the inverses of the corresponding Cholesky factors. + + Note: if you are using Cholesky inverse multiplication to sample from + a matrix-variate Gaussian you will want to multiply by the transpose. + Let L be the Cholesky factor of F and observe that + + L^-T * L^-1 = (L * L^T)^-1 = F^-1 . + + Thus we want to multiply by L^-T in order to sample from Gaussian with + covariance F^-1. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + transpose: Bool. If true the Cholesky factor inverses are transposed + before multiplying the vecs. (Default: False) + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + assert self._compute_cholesky_inverse + + fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose) + return self._apply_transformation(vecs_and_vars, fcn) def _instantiate_factors(self): """Instantiates FisherFactors' variables. @@ -241,9 +326,9 @@ class FisherEstimator(object): Raises: ValueError: If estimation_mode was improperly specified at construction. """ - fisher_blocks_list = self._layers.get_blocks() + blocks = self.blocks tensors_to_compute_grads = [ - fb.tensors_to_compute_grads() for fb in fisher_blocks_list + block.tensors_to_compute_grads() for block in blocks ] try: @@ -253,45 +338,135 @@ class FisherEstimator(object): raise ValueError("Unrecognized value {} for estimation_mode.".format( self._estimation_mode)) - # TODO(b/68033310): This loop round-robins the "concat" operations which - # gather the inputs for the cov_updates. In future, we might do these - # computations locally then communicate the results, which would require a - # modification to this code. - for grads_list, fb in zip(grads_lists, fisher_blocks_list): - with self._cov_device_context_generator(): - fb.instantiate_factors(grads_list, self.damping) + for grads_list, block in zip(grads_lists, blocks): + block.instantiate_factors(grads_list, self.damping) + + def _check_vars_unmade_and_set_made_flag(self): + if self._made_vars: + raise Exception("Already made variables.") + self._made_vars = True + + def made_vars(self): + return self._made_vars + + def _register_matrix_functions(self): + for block in self.blocks: + for exp in self._exps: + block.register_matpower(exp) + if self._compute_cholesky: + block.register_cholesky() + if self._compute_cholesky_inverse: + block.register_cholesky_inverse() + + def _finalize_layer_collection(self): + self._layers.create_subgraph() + self._layers.check_registration(self.variables) + self._instantiate_factors() + self._register_matrix_functions() + + def create_ops_and_vars_thunks(self, scope=None): + """Create thunks that make the ops and vars on demand. + + This function returns 4 lists of thunks: cov_variable_thunks, + cov_update_thunks, inv_variable_thunks, and inv_update_thunks. + + The length of each list is the number of factors and the i-th element of + each list corresponds to the i-th factor (given by the "factors" property). + + Note that the execution of these thunks must happen in a certain + partial order. The i-th element of cov_variable_thunks must execute + before the i-th element of cov_update_thunks (and also the i-th element + of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks + must execute before the i-th element of inv_update_thunks. - def _create_cov_update_thunk(self, factor): + TL;DR (oversimplified): Execute the thunks according to the order that + they are returned. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All thunks will execute inside + of a variable scope of the given name. (Default: None) + Returns: + cov_variable_thunks: A list of thunks that make the cov variables. + cov_update_thunks: A list of thunks that make the cov update ops. + inv_variable_thunks: A list of thunks that make the inv variables. + inv_update_thunks: A list of thunks that make the inv update ops. + """ + self._check_vars_unmade_and_set_made_flag() + + self._finalize_layer_collection() + + scope = self.name if scope is None else scope + + cov_variable_thunks = [ + self._create_cov_variable_thunk(factor, scope) + for factor in self.factors + ] + cov_update_thunks = [ + self._create_cov_update_thunk(factor, scope) for factor in self.factors + ] + inv_variable_thunks = [ + self._create_inv_variable_thunk(factor, scope) + for factor in self.factors + ] + inv_update_thunks = [ + self._create_inv_update_thunk(factor, scope) for factor in self.factors + ] + + return (cov_variable_thunks, cov_update_thunks, + inv_variable_thunks, inv_update_thunks) + + def _create_cov_variable_thunk(self, factor, scope): + """Constructs a covariance variable thunk for a single FisherFactor.""" + + def thunk(): + with variable_scope.variable_scope(scope): + return factor.instantiate_cov_variables() + + return thunk + + def _create_cov_update_thunk(self, factor, scope): """Constructs a covariance update thunk for a single FisherFactor.""" def thunk(): - with tf_ops.name_scope( - "create_cov_update_thunk", values=[self._cov_ema_decay]): + with variable_scope.variable_scope(scope): return factor.make_covariance_update_op(self._cov_ema_decay) return thunk - def _create_inv_update_thunk(self, factor): + def _create_inv_variable_thunk(self, factor, scope): + """Constructs a inverse variable thunk for a single FisherFactor.""" + + def thunk(): + with variable_scope.variable_scope(scope): + return factor.instantiate_inv_variables() + + return thunk + + def _create_inv_update_thunk(self, factor, scope): """Constructs an inverse update thunk for a single FisherFactor.""" def thunk(): - with tf_ops.name_scope("create_inv_update_thunk"): - with self._inv_device_context_generator(): - return control_flow_ops.group(factor.make_inverse_update_ops()) + with variable_scope.variable_scope(scope): + return control_flow_ops.group(factor.make_inverse_update_ops()) return thunk def _get_grads_lists_gradients(self, tensors): + # Passing in a list of loss values is better than passing in the sum as + # the latter creates unnessesary ops on the default device grads_flat = gradients_impl.gradients( - self._layers.total_sampled_loss(), + self._layers.eval_losses_on_samples(), nest.flatten(tensors), colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) return tuple((grad,) for grad in grads_all) def _get_grads_lists_empirical(self, tensors): + # Passing in a list of loss values is better than passing in the sum as + # the latter creates unnessesary ops on the default device grads_flat = gradients_impl.gradients( - self._layers.total_loss(), + self._layers.eval_losses(), nest.flatten(tensors), colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) @@ -300,9 +475,10 @@ class FisherEstimator(object): def _get_transformed_random_signs(self): transformed_random_signs = [] for loss in self._layers.losses: - transformed_random_signs.append( - loss.multiply_fisher_factor( - utils.generate_random_signs(loss.fisher_factor_inner_shape))) + with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]): + transformed_random_signs.append( + loss.multiply_fisher_factor( + utils.generate_random_signs(loss.fisher_factor_inner_shape))) return transformed_random_signs def _get_grads_lists_curvature_prop(self, tensors): @@ -321,13 +497,20 @@ class FisherEstimator(object): # Loop over all coordinates of all losses. grads_all = [] for loss in self._layers.losses: - for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]): - transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot( - index) - grads_flat = gradients_impl.gradients( - loss.inputs, - nest.flatten(tensors), - grad_ys=transformed_one_hot, - colocate_gradients_with_ops=self._colocate_gradients_with_ops) - grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) + with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]): + for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]): + transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot( + index) + grads_flat = gradients_impl.gradients( + loss.inputs, + nest.flatten(tensors), + grad_ys=transformed_one_hot, + colocate_gradients_with_ops=self._colocate_gradients_with_ops) + grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) return zip(*grads_all) + + +class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin, + FisherEstimator): + """Fisher estimator which provides round robin device placement strategy.""" + pass diff --git a/tensorflow/contrib/kfac/python/ops/estimator_lib.py b/tensorflow/contrib/kfac/python/ops/estimator_lib.py index 33c969650615bf8e439c2f669b4a1efaf2f565ff..9c9fef471f8033bec53ceb1e4f073dd921cbe3c7 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator_lib.py +++ b/tensorflow/contrib/kfac/python/ops/estimator_lib.py @@ -25,6 +25,7 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'FisherEstimator', + 'make_fisher_estimator', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index cf38d28b43836dced8babe2ffa7853b1c4b1b369..3a5c8eb5f9630fbcc121e4c502f771af32a96bcb 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -19,11 +19,11 @@ Information matrix. Suppose one has a model that parameterizes a posterior distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its Fisher Information matrix is given by, - F(params) = E[ v(x, y, params) v(x, y, params)^T ] + $$F(params) = E[ v(x, y, params) v(x, y, params)^T ]$$ where, - v(x, y, params) = (d / d params) log p(y | x, params) + $$v(x, y, params) = (d / d params) log p(y | x, params)$$ and the expectation is taken with respect to the data's distribution for 'x' and the model's posterior distribution for 'y', @@ -40,12 +40,15 @@ from __future__ import print_function import abc import enum # pylint: disable=g-bad-import-order +import numpy as np import six from tensorflow.contrib.kfac.python.ops import fisher_factors from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import nest # For blocks corresponding to convolutional layers, or any type of block where # the parameters can be thought of as being replicated in time or space, @@ -80,34 +83,22 @@ def normalize_damping(damping, num_replications): def compute_pi_tracenorm(left_cov, right_cov): - """Computes the scalar constant pi for Tikhonov regularization/damping. + r"""Computes the scalar constant pi for Tikhonov regularization/damping. - pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) ) + $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$ See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. Args: - left_cov: The left Kronecker factor "covariance". - right_cov: The right Kronecker factor "covariance". + left_cov: A LinearOperator object. The left Kronecker factor "covariance". + right_cov: A LinearOperator object. The right Kronecker factor "covariance". Returns: The computed scalar constant pi for these Kronecker Factors (as a Tensor). """ - - def _trace(cov): - if len(cov.shape) == 1: - # Diagonal matrix. - return math_ops.reduce_sum(cov) - elif len(cov.shape) == 2: - # Full matrix. - return math_ops.trace(cov) - else: - raise ValueError( - "What's the trace of a Tensor of rank %d?" % len(cov.shape)) - # Instead of dividing by the dim of the norm, we multiply by the dim of the # other norm. This works out the same in the ratio. - left_norm = _trace(left_cov) * right_cov.shape.as_list()[0] - right_norm = _trace(right_cov) * left_cov.shape.as_list()[0] + left_norm = left_cov.trace() * int(right_cov.domain_dimension) + right_norm = right_cov.trace() * int(left_cov.domain_dimension) return math_ops.sqrt(left_norm / right_norm) @@ -121,12 +112,44 @@ def compute_pi_adjusted_damping(left_cov, right_cov, damping): return (damping, damping) +class PackagedFunc(object): + """A Python thunk with a stable ID. + + Enables stable names for lambdas. + """ + + def __init__(self, func, func_id): + """Initializes PackagedFunc. + + Args: + func: a zero-arg Python function. + func_id: a hashable, function that produces a hashable, or a list/tuple + thereof. + """ + self._func = func + func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,) + self._func_id = func_id + + def __call__(self): + return self._func() + + @property + def func_id(self): + """A hashable identifier for this function.""" + return tuple(elt() if callable(elt) else elt for elt in self._func_id) + + +def _package_func(func, func_id): + return PackagedFunc(func, func_id) + + @six.add_metaclass(abc.ABCMeta) class FisherBlock(object): """Abstract base class for objects modeling approximate Fisher matrix blocks. - Subclasses must implement multiply_inverse(), instantiate_factors(), and - tensors_to_compute_grads() methods. + Subclasses must implement register_matpower, multiply_matpower, + instantiate_factors, tensors_to_compute_grads, and num_registered_towers + methods. """ def __init__(self, layer_collection): @@ -145,6 +168,42 @@ class FisherBlock(object): pass @abc.abstractmethod + def register_matpower(self, exp): + """Registers a matrix power to be computed by the block. + + Args: + exp: A float representing the power to raise the block by. + """ + pass + + @abc.abstractmethod + def register_cholesky(self): + """Registers a Cholesky factor to be computed by the block.""" + pass + + @abc.abstractmethod + def register_cholesky_inverse(self): + """Registers an inverse Cholesky factor to be computed by the block.""" + pass + + def register_inverse(self): + """Registers a matrix inverse to be computed by the block.""" + self.register_matpower(-1) + + @abc.abstractmethod + def multiply_matpower(self, vector, exp): + """Multiplies the vector by the (damped) matrix-power of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + exp: A float representing the power to raise the block by before + multiplying it by the vector. + + Returns: + The vector left-multiplied by the (damped) matrix-power of the block. + """ + pass + def multiply_inverse(self, vector): """Multiplies the vector by the (damped) inverse of the block. @@ -154,9 +213,8 @@ class FisherBlock(object): Returns: The vector left-multiplied by the (damped) inverse of the block. """ - pass + return self.multiply_matpower(vector, -1) - @abc.abstractmethod def multiply(self, vector): """Multiplies the vector by the (damped) block. @@ -166,6 +224,33 @@ class FisherBlock(object): Returns: The vector left-multiplied by the (damped) block. """ + return self.multiply_matpower(vector, 1) + + @abc.abstractmethod + def multiply_cholesky(self, vector, transpose=False): + """Multiplies the vector by the (damped) Cholesky-factor of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + transpose: Bool. If true the Cholesky factor is transposed before + multiplying the vector. (Default: False) + + Returns: + The vector left-multiplied by the (damped) Cholesky-factor of the block. + """ + pass + + @abc.abstractmethod + def multiply_cholesky_inverse(self, vector, transpose=False): + """Multiplies vector by the (damped) inverse Cholesky-factor of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + transpose: Bool. If true the Cholesky factor inverse is transposed + before multiplying the vector. (Default: False) + Returns: + Vector left-multiplied by (damped) inverse Cholesky-factor of the block. + """ pass @abc.abstractmethod @@ -175,8 +260,8 @@ class FisherBlock(object): pass @abc.abstractproperty - def num_registered_minibatches(self): - """Number of minibatches registered for this FisherBlock. + def num_registered_towers(self): + """Number of towers registered for this FisherBlock. Typically equal to the number of towers in a multi-tower setup. """ @@ -207,32 +292,46 @@ class FullFB(FisherBlock): super(FullFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - self._damping = damping + self._damping_func = _package_func(lambda: damping, (damping,)) + self._factor = self._layer_collection.make_or_get_factor( fisher_factors.FullFactor, (grads_list, self._batch_size)) - self._factor.register_damped_inverse(damping) - def multiply_inverse(self, vector): - vector_flat = utils.tensors_to_column(vector) - out_flat = self._factor.left_multiply_inverse( - vector_flat, self._damping) - return utils.column_to_tensors(vector, out_flat) + def register_matpower(self, exp): + self._factor.register_matpower(exp, self._damping_func) - def multiply(self, vector): + def register_cholesky(self): + self._factor.register_cholesky(self._damping_func) + + def register_cholesky_inverse(self): + self._factor.register_cholesky_inverse(self._damping_func) + + def _multiply_matrix(self, matrix, vector, transpose=False): vector_flat = utils.tensors_to_column(vector) - out_flat = self._factor.left_multiply( - vector_flat, self._damping) + out_flat = matrix.matmul(vector_flat, adjoint=transpose) return utils.column_to_tensors(vector, out_flat) + def multiply_matpower(self, vector, exp): + matrix = self._factor.get_matpower(exp, self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky(self, vector, transpose=False): + matrix = self._factor.get_cholesky(self._damping_func) + return self._multiply_matrix(matrix, vector, transpose=transpose) + + def multiply_cholesky_inverse(self, vector, transpose=False): + matrix = self._factor.get_cholesky_inverse(self._damping_func) + return self._multiply_matrix(matrix, vector, transpose=transpose) + def full_fisher_block(self): """Explicitly constructs the full Fisher block.""" - return self._factor.get_cov() + return self._factor.get_cov_as_linear_operator().to_dense() def tensors_to_compute_grads(self): return self._params - def register_additional_minibatch(self, batch_size): - """Register an additional minibatch. + def register_additional_tower(self, batch_size): + """Register an additional tower. Args: batch_size: The batch size, used in the covariance estimator. @@ -240,7 +339,7 @@ class FullFB(FisherBlock): self._batch_sizes.append(batch_size) @property - def num_registered_minibatches(self): + def num_registered_towers(self): return len(self._batch_sizes) @property @@ -248,7 +347,47 @@ class FullFB(FisherBlock): return math_ops.reduce_sum(self._batch_sizes) -class NaiveDiagonalFB(FisherBlock): +@six.add_metaclass(abc.ABCMeta) +class DiagonalFB(FisherBlock): + """A base class for FisherBlocks that use diagonal approximations.""" + + def register_matpower(self, exp): + # Not needed for this. Matrix powers are computed on demand in the + # diagonal case + pass + + def register_cholesky(self): + # Not needed for this. Cholesky's are computed on demand in the + # diagonal case + pass + + def register_cholesky_inverse(self): + # Not needed for this. Cholesky inverses's are computed on demand in the + # diagonal case + pass + + def _multiply_matrix(self, matrix, vector): + vector_flat = utils.tensors_to_column(vector) + out_flat = matrix.matmul(vector_flat) + return utils.column_to_tensors(vector, out_flat) + + def multiply_matpower(self, vector, exp): + matrix = self._factor.get_matpower(exp, self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky(self, vector, transpose=False): + matrix = self._factor.get_cholesky(self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky_inverse(self, vector, transpose=False): + matrix = self._factor.get_cholesky_inverse(self._damping_func) + return self._multiply_matrix(matrix, vector) + + def full_fisher_block(self): + return self._factor.get_cov_as_linear_operator().to_dense() + + +class NaiveDiagonalFB(DiagonalFB): """FisherBlock using a diagonal matrix approximation. This type of approximation is generically applicable but quite primitive. @@ -271,32 +410,16 @@ class NaiveDiagonalFB(FisherBlock): super(NaiveDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - self._damping = damping + self._damping_func = _package_func(lambda: damping, (damping,)) + self._factor = self._layer_collection.make_or_get_factor( fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size)) - def multiply_inverse(self, vector): - vector_flat = utils.tensors_to_column(vector) - print("vector_flat: %s" % vector_flat) - out_flat = self._factor.left_multiply_inverse( - vector_flat, self._damping) - print("out_flat: %s" % out_flat) - return utils.column_to_tensors(vector, out_flat) - - def multiply(self, vector): - vector_flat = utils.tensors_to_column(vector) - out_flat = self._factor.left_multiply( - vector_flat, self._damping) - return utils.column_to_tensors(vector, out_flat) - - def full_fisher_block(self): - return self._factor.get_cov() - def tensors_to_compute_grads(self): return self._params - def register_additional_minibatch(self, batch_size): - """Register an additional minibatch. + def register_additional_tower(self, batch_size): + """Register an additional tower. Args: batch_size: The batch size, used in the covariance estimator. @@ -304,7 +427,7 @@ class NaiveDiagonalFB(FisherBlock): self._batch_sizes.append(batch_size) @property - def num_registered_minibatches(self): + def num_registered_towers(self): return len(self._batch_sizes) @property @@ -312,7 +435,92 @@ class NaiveDiagonalFB(FisherBlock): return math_ops.reduce_sum(self._batch_sizes) -class FullyConnectedDiagonalFB(FisherBlock): +class InputOutputMultiTower(object): + """Mix-in class for blocks with inputs & outputs and multiple mini-batches.""" + + def __init__(self, *args, **kwargs): + self.__inputs = [] + self.__outputs = [] + super(InputOutputMultiTower, self).__init__(*args, **kwargs) + + def _process_data(self, grads_list): + """Process data into the format used by the factors. + + This function takes inputs and grads_lists data and processes it into + one of the formats expected by the FisherFactor classes (depending on + the value of the global configuration variable TOWER_STRATEGY). + + The initial format of self._inputs is expected to be a list of Tensors + over towers. Similarly grads_lists is expected to be a list over sources + of such lists. + + If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing a single + tensor (represented as a PartitionedTensor object) equal to the + concatenation (across towers) of all of the elements of self._inputs. And + similarly grads_list is formatted into a tuple (over sources) of such + tensors (also represented as PartitionedTensors). + + If TOWER_STRATEGY is "separate", formatting of inputs and grads_list + remains unchanged from the initial format (although possibly converting + from lists into tuples). + + Args: + grads_list: grads_list in its initial format (see above). + + Returns: + inputs: self._inputs transformed into the appropriate format (see + above). + grads_list: grads_list transformed into the appropriate format (see + above). + + Raises: + ValueError: if TOWER_STRATEGY is not one of "separate" or "concat". + """ + inputs = self._inputs + # inputs is a list over towers of Tensors + # grads_list is a list of list with the first index being sources and the + # second being towers. + if fisher_factors.TOWER_STRATEGY == "concat": + # Merge towers together into a PartitionedTensor. We package it in + # a singleton tuple since the factors will expect a list over towers + inputs = (utils.PartitionedTensor(inputs),) + # Do the same for grads_list but preserve leading sources dimension + grads_list = tuple((utils.PartitionedTensor(grads),) + for grads in grads_list) + elif fisher_factors.TOWER_STRATEGY == "separate": + inputs = tuple(inputs) + grads_list = tuple(grads_list) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + + return inputs, grads_list + + def tensors_to_compute_grads(self): + """Tensors to compute derivative of loss with respect to.""" + return tuple(self._outputs) + + def register_additional_tower(self, inputs, outputs): + self._inputs.append(inputs) + self._outputs.append(outputs) + + @property + def num_registered_towers(self): + result = len(self._inputs) + assert result == len(self._outputs) + return result + + @property + def _inputs(self): + return self.__inputs + + @property + def _outputs(self): + return self.__outputs + + +class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB): """FisherBlock for fully-connected (dense) layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a fully @@ -322,14 +530,14 @@ class FullyConnectedDiagonalFB(FisherBlock): Let 'params' be a vector parameterizing a model and 'i' an arbitrary index into it. We are interested in Fisher(params)[i, i]. This is, - Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] - = E[ v(x, y, params)[i] ^ 2 ] + $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] + = E[ v(x, y, params)[i] ^ 2 ]$$ Consider fully connected layer in this model with (unshared) weight matrix 'w'. For an example 'x' that produces layer inputs 'a' and output preactivations 's', - v(x, y, w) = vec( a (d loss / d s)^T ) + $$v(x, y, w) = vec( a (d loss / d s)^T )$$ This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding to the layer's parameters 'w'. @@ -344,80 +552,22 @@ class FullyConnectedDiagonalFB(FisherBlock): has_bias: Whether the component Kronecker factors have an additive bias. (Default: False) """ - self._inputs = [] - self._outputs = [] self._has_bias = has_bias super(FullyConnectedDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - inputs = _concat_along_batch_dim(self._inputs) - grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + inputs, grads_list = self._process_data(grads_list) - self._damping = damping self._factor = self._layer_collection.make_or_get_factor( fisher_factors.FullyConnectedDiagonalFactor, (inputs, grads_list, self._has_bias)) - def multiply_inverse(self, vector): - """Approximate damped inverse Fisher-vector product. + self._damping_func = _package_func(lambda: damping, (damping,)) - Args: - vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape - [input_size, output_size] corresponding to layer's weights. If not, a - 2-tuple of the former and a Tensor of shape [output_size] corresponding - to the layer's bias. - Returns: - Tensor of the same shape, corresponding to the inverse Fisher-vector - product. - """ - reshaped_vec = utils.layer_params_to_mat2d(vector) - reshaped_out = self._factor.left_multiply_inverse( - reshaped_vec, self._damping) - return utils.mat2d_to_layer_params(vector, reshaped_out) - - def multiply(self, vector): - """Approximate damped Fisher-vector product. - - Args: - vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape - [input_size, output_size] corresponding to layer's weights. If not, a - 2-tuple of the former and a Tensor of shape [output_size] corresponding - to the layer's bias. - - Returns: - Tensor of the same shape, corresponding to the Fisher-vector product. - """ - reshaped_vec = utils.layer_params_to_mat2d(vector) - reshaped_out = self._factor.left_multiply( - reshaped_vec, self._damping) - return utils.mat2d_to_layer_params(vector, reshaped_out) - - def tensors_to_compute_grads(self): - """Tensors to compute derivative of loss with respect to.""" - return self._outputs - - def register_additional_minibatch(self, inputs, outputs): - """Registers an additional minibatch to the FisherBlock. - - Args: - inputs: Tensor of shape [batch_size, input_size]. Inputs to the - matrix-multiply. - outputs: Tensor of shape [batch_size, output_size]. Layer preactivations. - """ - self._inputs.append(inputs) - self._outputs.append(outputs) - - @property - def num_registered_minibatches(self): - result = len(self._inputs) - assert result == len(self._outputs) - return result - - -class ConvDiagonalFB(FisherBlock): - """FisherBlock for convolutional layers using a diagonal approx. +class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB): + """FisherBlock for 2-D convolutional layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a convolutional layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" @@ -426,14 +576,14 @@ class ConvDiagonalFB(FisherBlock): Let 'params' be a vector parameterizing a model and 'i' an arbitrary index into it. We are interested in Fisher(params)[i, i]. This is, - Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] - = E[ v(x, y, params)[i] ^ 2 ] + $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] + = E[ v(x, y, params)[i] ^ 2 ]$$ Consider a convoluational layer in this model with (unshared) filter matrix 'w'. For an example image 'x' that produces layer inputs 'a' and output preactivations 's', - v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T ) + $$v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )$$ where 'loc' is a single (x, y) location in an image. @@ -441,7 +591,13 @@ class ConvDiagonalFB(FisherBlock): to the layer's parameters 'w'. """ - def __init__(self, layer_collection, params, strides, padding): + def __init__(self, + layer_collection, + params, + strides, + padding, + data_format=None, + dilations=None): """Creates a ConvDiagonalFB block. Args: @@ -453,92 +609,110 @@ class ConvDiagonalFB(FisherBlock): containing the previous and a Tensor of shape [out_channels]. strides: The stride size in this layer (1-D Tensor of length 4). padding: The padding in this layer (e.g. "SAME"). + data_format: str or None. Format of input data. + dilations: List of 4 ints or None. Rate for dilation along all dimensions. + + Raises: + ValueError: if strides is not length-4. + ValueError: if dilations is not length-4. + ValueError: if channel is not last dimension. """ - self._inputs = [] - self._outputs = [] - self._strides = tuple(strides) if isinstance(strides, list) else strides + if len(strides) != 4: + raise ValueError("strides must contain 4 numbers.") + + if dilations is None: + dilations = [1, 1, 1, 1] + + if len(dilations) != 4: + raise ValueError("dilations must contain 4 numbers.") + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + self._strides = maybe_tuple(strides) self._padding = padding + self._data_format = data_format + self._dilations = maybe_tuple(dilations) self._has_bias = isinstance(params, (tuple, list)) fltr = params[0] if self._has_bias else params self._filter_shape = tuple(fltr.shape.as_list()) + if len(self._filter_shape) != 4: + raise ValueError( + "Convolution filter must be of shape" + " [filter_height, filter_width, in_channels, out_channels].") + super(ConvDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - # Concatenate inputs, grads_list into single Tensors. - inputs = _concat_along_batch_dim(self._inputs) - grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + inputs, grads_list = self._process_data(grads_list) # Infer number of locations upon which convolution is applied. - inputs_shape = tuple(inputs.shape.as_list()) - self._num_locations = ( - inputs_shape[1] * inputs_shape[2] // - (self._strides[1] * self._strides[2])) - - self._damping = (self._num_locations - * normalize_damping(damping, self._num_locations)) + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), + self._strides) self._factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvDiagonalFactor, (inputs, grads_list, self._filter_shape, self._strides, self._padding, - self._has_bias)) - - def multiply_inverse(self, vector): - reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = self._factor.left_multiply_inverse( - reshaped_vect, self._damping) - return utils.mat2d_to_layer_params(vector, reshaped_out) - - def multiply(self, vector): - reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = self._factor.left_multiply( - reshaped_vect, self._damping) - return utils.mat2d_to_layer_params(vector, reshaped_out) + self._data_format, self._dilations, self._has_bias)) - def tensors_to_compute_grads(self): - return self._outputs - - def register_additional_minibatch(self, inputs, outputs): - """Registers an additional minibatch to the FisherBlock. - - Args: - inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to - the convolution. - outputs: Tensor of shape [batch_size, height, width, output_size]. Layer - preactivations. - """ - self._inputs.append(inputs) - self._outputs.append(outputs) + def damping_func(): + return self._num_locations * normalize_damping(damping, + self._num_locations) - @property - def num_registered_minibatches(self): - return len(self._inputs) + damping_id = (self._num_locations, "mult", "normalize_damping", damping, + self._num_locations) + self._damping_func = _package_func(damping_func, damping_id) class KroneckerProductFB(FisherBlock): - """A base class for FisherBlocks with separate input and output factors. + """A base class for blocks with separate input and output Kronecker factors. The Fisher block is approximated as a Kronecker product of the input and output factors. """ - def _register_damped_input_and_output_inverses(self, damping): - """Registers damped inverses for both the input and output factors. + def _setup_damping(self, damping, normalization=None): + """Makes functions that compute the damping values for both factors.""" + def compute_damping(): + if normalization is not None: + maybe_normalized_damping = normalize_damping(damping, normalization) + else: + maybe_normalized_damping = damping + + return compute_pi_adjusted_damping( + self._input_factor.get_cov_as_linear_operator(), + self._output_factor.get_cov_as_linear_operator(), + maybe_normalized_damping**0.5) + + if normalization is not None: + damping_id = ("compute_pi_adjusted_damping", + "cov", self._input_factor.name, + "cov", self._output_factor.name, + "normalize_damping", damping, normalization, "power", 0.5) + else: + damping_id = ("compute_pi_adjusted_damping", + "cov", self._input_factor.name, + "cov", self._output_factor.name, + damping, "power", 0.5) + + self._input_damping_func = _package_func(lambda: compute_damping()[0], + damping_id + ("ref", 0)) + self._output_damping_func = _package_func(lambda: compute_damping()[1], + damping_id + ("ref", 1)) - Sets the instance members _input_damping and _output_damping. Requires the - instance members _input_factor and _output_factor. + def register_matpower(self, exp): + self._input_factor.register_matpower(exp, self._input_damping_func) + self._output_factor.register_matpower(exp, self._output_damping_func) - Args: - damping: The base damping factor (float or Tensor) for the damped inverse. - """ - self._input_damping, self._output_damping = compute_pi_adjusted_damping( - self._input_factor.get_cov(), - self._output_factor.get_cov(), - damping**0.5) + def register_cholesky(self): + self._input_factor.register_cholesky(self._input_damping_func) + self._output_factor.register_cholesky(self._output_damping_func) - self._input_factor.register_damped_inverse(self._input_damping) - self._output_factor.register_damped_inverse(self._output_damping) + def register_cholesky_inverse(self): + self._input_factor.register_cholesky_inverse(self._input_damping_func) + self._output_factor.register_cholesky_inverse(self._output_damping_func) @property def _renorm_coeff(self): @@ -552,29 +726,46 @@ class KroneckerProductFB(FisherBlock): """ return 1.0 - def multiply_inverse(self, vector): + def _multiply_factored_matrix(self, left_factor, right_factor, vector, + extra_scale=1.0, transpose_left=False, + transpose_right=False): reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = self._output_factor.right_multiply_inverse( - reshaped_vector, - self._output_damping) - reshaped_out = self._input_factor.left_multiply_inverse( - reshaped_out, self._input_damping) - if self._renorm_coeff != 1.0: - reshaped_out /= math_ops.cast( - self._renorm_coeff, dtype=reshaped_out.dtype) + reshaped_out = right_factor.matmul_right(reshaped_vector, + adjoint=transpose_right) + reshaped_out = left_factor.matmul(reshaped_out, + adjoint=transpose_left) + if extra_scale != 1.0: + reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype) return utils.mat2d_to_layer_params(vector, reshaped_out) - def multiply(self, vector): - reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = self._output_factor.right_multiply( - reshaped_vector, - self._output_damping) - reshaped_out = self._input_factor.left_multiply( - reshaped_out, self._input_damping) - if self._renorm_coeff != 1.0: - reshaped_out *= math_ops.cast( - self._renorm_coeff, dtype=reshaped_out.dtype) - return utils.mat2d_to_layer_params(vector, reshaped_out) + def multiply_matpower(self, vector, exp): + left_factor = self._input_factor.get_matpower( + exp, self._input_damping_func) + right_factor = self._output_factor.get_matpower( + exp, self._output_damping_func) + extra_scale = float(self._renorm_coeff)**exp + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale) + + def multiply_cholesky(self, vector, transpose=False): + left_factor = self._input_factor.get_cholesky(self._input_damping_func) + right_factor = self._output_factor.get_cholesky(self._output_damping_func) + extra_scale = float(self._renorm_coeff)**0.5 + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale, + transpose_left=transpose, + transpose_right=not transpose) + + def multiply_cholesky_inverse(self, vector, transpose=False): + left_factor = self._input_factor.get_cholesky_inverse( + self._input_damping_func) + right_factor = self._output_factor.get_cholesky_inverse( + self._output_damping_func) + extra_scale = float(self._renorm_coeff)**-0.5 + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale, + transpose_left=transpose, + transpose_right=not transpose) def full_fisher_block(self): """Explicitly constructs the full Fisher block. @@ -584,16 +775,16 @@ class KroneckerProductFB(FisherBlock): Returns: The full Fisher block. """ - left_factor = self._input_factor.get_cov() - right_factor = self._output_factor.get_cov() + left_factor = self._input_factor.get_cov_as_linear_operator().to_dense() + right_factor = self._output_factor.get_cov_as_linear_operator().to_dense() return self._renorm_coeff * utils.kronecker_product(left_factor, right_factor) -class EmbeddingKFACFB(KroneckerProductFB): +class EmbeddingKFACFB(InputOutputMultiTower, KroneckerProductFB): """K-FAC FisherBlock for embedding layers. - This FisherBlock is similar to EmbeddingKFACFB, except that its + This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its input factor is approximated by a diagonal matrix. In the case that each example references exactly one embedding, this approximation is exact. @@ -608,8 +799,6 @@ class EmbeddingKFACFB(KroneckerProductFB): Fisher information matrix to which this FisherBlock belongs. vocab_size: int. Size of vocabulary for this embedding layer. """ - self._inputs = [] - self._outputs = [] self._vocab_size = vocab_size super(EmbeddingKFACFB, self).__init__(layer_collection) @@ -624,41 +813,17 @@ class EmbeddingKFACFB(KroneckerProductFB): damping: 0-D Tensor or float. 'damping' * identity is approximately added to this FisherBlock's Fisher approximation. """ - # TODO(b/68033310): Validate which of, - # (1) summing on a single device (as below), or - # (2) on each device in isolation and aggregating - # is faster. - inputs = _concat_along_batch_dim(self._inputs) - grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( # - fisher_factors.EmbeddingInputKroneckerFactor, # - ((inputs,), self._vocab_size)) - self._output_factor = self._layer_collection.make_or_get_factor( # - fisher_factors.FullyConnectedKroneckerFactor, # - (grads_list,)) - self._register_damped_input_and_output_inverses(damping) + inputs, grads_list = self._process_data(grads_list) - def tensors_to_compute_grads(self): - return self._outputs - - def register_additional_minibatch(self, inputs, outputs): - """Registers an additional minibatch to the FisherBlock. - - Args: - inputs: Tensor of shape [batch_size, input_size]. Inputs to the - matrix-multiply. - outputs: Tensor of shape [batch_size, output_size]. Layer preactivations. - """ - self._inputs.append(inputs) - self._outputs.append(outputs) - - @property - def num_registered_minibatches(self): - return len(self._inputs) + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.EmbeddingInputKroneckerFactor, + (inputs, self._vocab_size)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedKroneckerFactor, (grads_list,)) + self._setup_damping(damping) -class FullyConnectedKFACBasicFB(KroneckerProductFB): +class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB): """K-FAC FisherBlock for fully-connected (dense) layers. This uses the Kronecker-factorized approximation from the original @@ -674,8 +839,6 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB): has_bias: Whether the component Kronecker factors have an additive bias. (Default: False) """ - self._inputs = [] - self._outputs = [] self._has_bias = has_bias super(FullyConnectedKFACBasicFB, self).__init__(layer_collection) @@ -690,42 +853,19 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB): damping: 0-D Tensor or float. 'damping' * identity is approximately added to this FisherBlock's Fisher approximation. """ - # TODO(b/68033310): Validate which of, - # (1) summing on a single device (as below), or - # (2) on each device in isolation and aggregating - # is faster. - inputs = _concat_along_batch_dim(self._inputs) - grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( # - fisher_factors.FullyConnectedKroneckerFactor, # + inputs, grads_list = self._process_data(grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedKroneckerFactor, ((inputs,), self._has_bias)) - self._output_factor = self._layer_collection.make_or_get_factor( # - fisher_factors.FullyConnectedKroneckerFactor, # + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedKroneckerFactor, (grads_list,)) - self._register_damped_input_and_output_inverses(damping) - - def tensors_to_compute_grads(self): - return self._outputs - - def register_additional_minibatch(self, inputs, outputs): - """Registers an additional minibatch to the FisherBlock. - - Args: - inputs: Tensor of shape [batch_size, input_size]. Inputs to the - matrix-multiply. - outputs: Tensor of shape [batch_size, output_size]. Layer preactivations. - """ - self._inputs.append(inputs) - self._outputs.append(outputs) - - @property - def num_registered_minibatches(self): - return len(self._inputs) + self._setup_damping(damping) -class ConvKFCBasicFB(KroneckerProductFB): - """FisherBlock for 2D convolutional layers using the basic KFC approx. +class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB): + r"""FisherBlock for convolutional layers using the basic KFC approx. Estimates the Fisher Information matrix's blog for a convolutional layer. @@ -734,12 +874,12 @@ class ConvKFCBasicFB(KroneckerProductFB): 'w'. For a minibatch that produces inputs 'a' and output preactivations 's', this FisherBlock estimates, - F(w) = #locations * kronecker(E[flat(a) flat(a)^T], - E[flat(ds) flat(ds)^T]) + $$F(w) = \#locations * kronecker(E[flat(a) flat(a)^T], + E[flat(ds) flat(ds)^T])$$ where - ds = (d / ds) log p(y | x, w) + $$ds = (d / ds) log p(y | x, w)$$ #locations = number of (x, y) locations where 'w' is applied. where the expectation is taken over all examples and locations and flat() @@ -748,23 +888,40 @@ class ConvKFCBasicFB(KroneckerProductFB): See equation 23 in https://arxiv.org/abs/1602.01407 for details. """ - def __init__(self, layer_collection, params, strides, padding): + def __init__(self, + layer_collection, + params, + padding, + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None): """Creates a ConvKFCBasicFB block. Args: layer_collection: The collection of all layers in the K-FAC approximate Fisher information matrix to which this FisherBlock belongs. params: The parameters (Tensor or tuple of Tensors) of this layer. If - kernel alone, a Tensor of shape [kernel_height, kernel_width, + kernel alone, a Tensor of shape [..spatial_filter_shape.., in_channels, out_channels]. If kernel and bias, a tuple of 2 elements containing the previous and a Tensor of shape [out_channels]. - strides: The stride size in this layer (1-D Tensor of length 4). - padding: The padding in this layer (1-D of Tensor length 4). + padding: str. Padding method. + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". """ - self._inputs = [] - self._outputs = [] - self._strides = tuple(strides) if isinstance(strides, list) else strides self._padding = padding + self._strides = maybe_tuple(strides) + self._dilation_rate = maybe_tuple(dilation_rate) + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn self._has_bias = isinstance(params, (tuple, list)) fltr = params[0] if self._has_bias else params @@ -773,145 +930,614 @@ class ConvKFCBasicFB(KroneckerProductFB): super(ConvKFCBasicFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - # TODO(b/68033310): Validate which of, - # (1) summing on a single device (as below), or - # (2) on each device in isolation and aggregating - # is faster. - inputs = _concat_along_batch_dim(self._inputs) - grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + inputs, grads_list = self._process_data(grads_list) # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(inputs.shape.as_list(), + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), self._strides) self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvInputKroneckerFactor, - (inputs, self._filter_shape, self._strides, self._padding, + (inputs, self._filter_shape, self._padding, self._strides, + self._dilation_rate, self._data_format, self._extract_patches_fn, self._has_bias)) self._output_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) - damping = normalize_damping(damping, self._num_locations) - self._register_damped_input_and_output_inverses(damping) - self._damping = damping + self._setup_damping(damping, normalization=self._num_locations) @property def _renorm_coeff(self): return self._num_locations - def tensors_to_compute_grads(self): - return self._outputs - def register_additional_minibatch(self, inputs, outputs): - """Registers an additional minibatch to the FisherBlock. +class DepthwiseConvDiagonalFB(ConvDiagonalFB): + """FisherBlock for depthwise_conv2d(). + + Equivalent to ConvDiagonalFB applied to each input channel in isolation. + """ + + def __init__(self, + layer_collection, + params, + strides, + padding, + rate=None, + data_format=None): + """Creates a DepthwiseConvKFCBasicFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: Tensor of shape [filter_height, filter_width, in_channels, + channel_multiplier]. + strides: List of 4 ints. Strides along all dimensions. + padding: str. Padding method. + rate: List of 4 ints or None. Rate for dilation along all dimensions. + data_format: str or None. Format of input data. + + Raises: + NotImplementedError: If parameters contains bias. + ValueError: If filter is not 4-D. + ValueError: If strides is not length-4. + ValueError: If rates is not length-2. + ValueError: If channels are not last dimension. + """ + if isinstance(params, (tuple, list)): + raise NotImplementedError("Bias not yet supported.") + + if params.shape.ndims != 4: + raise ValueError("Filter must be 4-D.") + + if len(strides) != 4: + raise ValueError("strides must account for 4 dimensions.") + + if rate is not None: + if len(rate) != 2: + raise ValueError("rate must only account for spatial dimensions.") + rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + super(DepthwiseConvDiagonalFB, self).__init__( + layer_collection=layer_collection, + params=params, + strides=strides, + padding=padding, + dilations=rate, + data_format=data_format) + + # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). + filter_height, filter_width, in_channels, channel_multiplier = ( + params.shape.as_list()) + self._filter_shape = (filter_height, filter_width, in_channels, + in_channels * channel_multiplier) + + def _multiply_matrix(self, matrix, vector): + conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) + conv2d_result = super( + DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector) + return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) + + +class DepthwiseConvKFCBasicFB(ConvKFCBasicFB): + """FisherBlock for depthwise_conv2d(). + + Equivalent to ConvKFCBasicFB applied to each input channel in isolation. + """ + + def __init__(self, + layer_collection, + params, + strides, + padding, + rate=None, + data_format=None): + """Creates a DepthwiseConvKFCBasicFB block. Args: - inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to - the convolution. - outputs: Tensor of shape [batch_size, height, width, output_size]. Layer - preactivations. + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: Tensor of shape [filter_height, filter_width, in_channels, + channel_multiplier]. + strides: List of 4 ints. Strides along all dimensions. + padding: str. Padding method. + rate: List of 4 ints or None. Rate for dilation along all dimensions. + data_format: str or None. Format of input data. + + Raises: + NotImplementedError: If parameters contains bias. + ValueError: If filter is not 4-D. + ValueError: If strides is not length-4. + ValueError: If rates is not length-2. + ValueError: If channels are not last dimension. """ - self._inputs.append(inputs) - self._outputs.append(outputs) + if isinstance(params, (tuple, list)): + raise NotImplementedError("Bias not yet supported.") + + if params.shape.ndims != 4: + raise ValueError("Filter must be 4-D.") + + if len(strides) != 4: + raise ValueError("strides must account for 4 dimensions.") + + if rate is not None: + if len(rate) != 2: + raise ValueError("rate must only account for spatial dimensions.") + rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + super(DepthwiseConvKFCBasicFB, self).__init__( + layer_collection=layer_collection, + params=params, + padding=padding, + strides=strides, + dilation_rate=rate, + data_format=data_format, + extract_patches_fn="extract_image_patches") + + # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). + filter_height, filter_width, in_channels, channel_multiplier = ( + params.shape.as_list()) + self._filter_shape = (filter_height, filter_width, in_channels, + in_channels * channel_multiplier) + + def _multiply_factored_matrix(self, left_factor, right_factor, vector, + extra_scale=1.0, transpose_left=False, + transpose_right=False): + conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) + conv2d_result = super( + DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix( + left_factor, right_factor, conv2d_vector, extra_scale=extra_scale, + transpose_left=transpose_left, transpose_right=transpose_right) + return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) + + +def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin + """Converts a convolution filter for use with conv2d. + + Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's + compatible with tf.nn.conv2d(). - @property - def num_registered_minibatches(self): - return len(self._inputs) + Args: + filter: Tensor of shape [height, width, in_channels, channel_multiplier]. + name: None or str. Name of Op. + Returns: + Tensor of shape [height, width, in_channels, out_channels]. -def _concat_along_batch_dim(tensor_list): - """Concatenate tensors along batch (first) dimension. + """ + with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter", + [filter]): + filter = ops.convert_to_tensor(filter) + filter_height, filter_width, in_channels, channel_multiplier = ( + filter.shape.as_list()) + + results = [] + for i in range(in_channels): + # Slice out one in_channel's filter. Insert zeros around it to force it + # to affect that channel and that channel alone. + elements = [] + if i > 0: + elements.append( + array_ops.zeros( + [filter_height, filter_width, i, channel_multiplier])) + elements.append(filter[:, :, i:(i + 1), :]) + if i + 1 < in_channels: + elements.append( + array_ops.zeros([ + filter_height, filter_width, in_channels - (i + 1), + channel_multiplier + ])) + + # Concat along in_channel. + results.append( + array_ops.concat(elements, axis=-2, name="in_channel_%d" % i)) + + # Concat along out_channel. + return array_ops.concat(results, axis=-1, name="out_channel") + + +def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin + """Converts a convolution filter for use with depthwise_conv2d. + + Transforms a filter for use with tf.nn.conv2d() to one that's + compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along + the diagonal. Args: - tensor_list: list of Tensors or list of tuples of Tensors. + filter: Tensor of shape [height, width, in_channels, out_channels]. + name: None or str. Name of Op. Returns: - Tensor or tuple of Tensors. + Tensor of shape, + [height, width, in_channels, channel_multiplier] Raises: - ValueError: If 'tensor_list' is empty. - + ValueError: if out_channels is not evenly divisible by in_channels. """ - if not tensor_list: - raise ValueError( - "Cannot concatenate Tensors if there are no Tensors to concatenate.") - - if isinstance(tensor_list[0], (tuple, list)): - # [(tensor1a, tensor1b), - # (tensor2a, tensor2b), ...] --> (tensor_a, tensor_b) - return tuple( - array_ops.concat(tensors, axis=0) for tensors in zip(*tensor_list)) - else: - # [tensor1, tensor2] --> tensor - return array_ops.concat(tensor_list, axis=0) + with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter", + [filter]): + filter = ops.convert_to_tensor(filter) + filter_height, filter_width, in_channels, out_channels = ( + filter.shape.as_list()) + + if out_channels % in_channels != 0: + raise ValueError("out_channels must be evenly divisible by in_channels.") + channel_multiplier = out_channels // in_channels + + results = [] + filter = array_ops.reshape(filter, [ + filter_height, filter_width, in_channels, in_channels, + channel_multiplier + ]) + for i in range(in_channels): + # Slice out output corresponding to the correct filter. + filter_slice = array_ops.reshape( + filter[:, :, i, i, :], + [filter_height, filter_width, 1, channel_multiplier]) + results.append(filter_slice) + + # Concat along out_channel. + return array_ops.concat(results, axis=-2, name="in_channels") + + +def maybe_tuple(obj): + if not isinstance(obj, list): + return obj + return tuple(obj) def num_conv_locations(input_shape, strides): """Returns the number of spatial locations a 2D Conv kernel is applied to. Args: - input_shape: list representing shape of inputs to the Conv layer. - strides: list representing strides for the Conv kernel. + input_shape: List of ints representing shape of inputs to + tf.nn.convolution(). + strides: List of ints representing strides along spatial dimensions as + passed in to tf.nn.convolution(). Returns: A scalar |T| denoting the number of spatial locations for the Conv layer. """ - return input_shape[1] * input_shape[2] // (strides[1] * strides[2]) + spatial_input_locations = np.prod(input_shape[1:-1]) + + if strides is None: + spatial_strides_divisor = 1 + else: + spatial_strides_divisor = np.prod(strides) + + return spatial_input_locations // spatial_strides_divisor + +class InputOutputMultiTowerMultiUse(InputOutputMultiTower): + """Adds methods for multi-use/time-step case to InputOutputMultiTower.""" -class FullyConnectedMultiIndepFB(KroneckerProductFB): + def __init__(self, num_uses=None, *args, **kwargs): + self._num_uses = num_uses + super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs) + + def _process_data(self, grads_list): + """Process temporal/multi-use data into the format used by the factors. + + This function takes inputs and grads_lists data and processes it into + one of the formats expected by the FisherFactor classes (depending on + the value of the global configuration variable TOWER_STRATEGY). + + It accepts the data in one of two initial formats. The first possible + format is where self._inputs is a list of list of Tensors. The first index + is tower, the second is use/time-step. grads_list, meanwhile, is a list + over sources of such lists of lists. + + The second possible data format is where self._inputs is a Tensor with + uses/times-steps folded into the batch dimension. i.e. it is a Tensor + of shape [num_uses * size_batch, ...] which represents a reshape of a + Tensor of shape [num_uses, size_batch, ...]. And similarly grads_list is + a list over sources of such Tensors. + + There are two possible formats which inputs and grads_list are transformed + into. + + If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing + a single tensor (represented as a PartitionedTensor object) with all of + the data from the towers, as well as the uses/time-steps, concatenated + together. In this tensor the leading dimension is the batch and + use/time-step dimensions folded together (with 'use' being the major of + these two, so that the tensors can be thought of as reshapes of ones of + shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a + tuple over sources of such tensors. + + If TOWER_STRATEGY is "separate" the inputs are formatted into lists of + tensors over towers. Each of these tensors has a similar format to + the tensor produced by the "concat" option, except that each contains + only the data from a single tower. grads_list is similarly formatted + into a tuple over sources of such tuples. + + Args: + grads_list: grads_list in its initial format (see above). + + Returns: + inputs: self._inputs transformed into the appropriate format (see + above). + grads_list: grads_list transformed into the appropriate format (see + above). + + Raises: + ValueError: If TOWER_STRATEGY is not one of "separate" or "concat". + ValueError: If the given/initial format of self._inputs and grads_list + isn't recognized, or doesn't agree with self._num_uses. + """ + + inputs = self._inputs + + if isinstance(inputs[0], (list, tuple)): + num_uses = len(inputs[0]) + if self._num_uses is not None and self._num_uses != num_uses: + raise ValueError("num_uses argument doesn't match length of inputs.") + else: + self._num_uses = num_uses + + # Check that all mini-batches/towers have the same number of uses + if not all(len(input_) == num_uses for input_ in inputs): + raise ValueError("Length of inputs argument is inconsistent across " + "towers.") + + if fisher_factors.TOWER_STRATEGY == "concat": + # Reverse the tower and use/time-step indices, so that use is now first, + # and towers is second + inputs = tuple(zip(*inputs)) + + # Flatten the two dimensions + inputs = nest.flatten(inputs) + + # Merge everything together into a PartitionedTensor. We package it in + # a singleton tuple since the factors will expect a list over towers + inputs = (utils.PartitionedTensor(inputs),) + + elif fisher_factors.TOWER_STRATEGY == "separate": + # Merge together the uses/time-step dimension into PartitionedTensors, + # but keep the leading dimension (towers) intact for the factors to + # process individually. + inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + else: + inputs = tuple(inputs) + + # Now we perform the analogous processing for grads_list + if isinstance(grads_list[0][0], (list, tuple)): + num_uses = len(grads_list[0][0]) + if self._num_uses is not None and self._num_uses != num_uses: + raise ValueError("num_uses argument doesn't match length of outputs, " + "or length of outputs is inconsistent with length of " + "inputs.") + else: + self._num_uses = num_uses + + if not all(len(grad) == num_uses for grads in grads_list + for grad in grads): + raise ValueError("Length of outputs argument is inconsistent across " + "towers.") + + if fisher_factors.TOWER_STRATEGY == "concat": + # Reverse the tower and use/time-step indices, so that use is now first, + # and towers is second + grads_list = tuple(tuple(zip(*grads)) for grads in grads_list) + + # Flatten the two dimensions, leaving the leading dimension (source) + # intact + grads_list = tuple(nest.flatten(grads) for grads in grads_list) + + # Merge inner dimensions together into PartitionedTensors. We package + # them in a singleton tuple since the factors will expect a list over + # towers + grads_list = tuple((utils.PartitionedTensor(grads),) + for grads in grads_list) + + elif fisher_factors.TOWER_STRATEGY == "separate": + # Merge together the uses/time-step dimension into PartitionedTensors, + # but keep the leading dimension (towers) intact for the factors to + # process individually. + grads_list = tuple(tuple(utils.PartitionedTensor(grad) + for grad in grads) + for grads in grads_list) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + else: + grads_list = tuple(tuple(grads) for grads in grads_list) + + if self._num_uses is None: + raise ValueError("You must supply a value for the num_uses argument if " + "the number of uses cannot be inferred from inputs or " + "outputs arguments (e.g. if they are both given in the " + "single Tensor format, instead of as lists of Tensors.") + + return inputs, grads_list + + +class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): """FisherBlock for fully-connected layers that share parameters. + + This class implements the "independence across time" approximation from the + following paper: + https://openreview.net/pdf?id=HyMTkQZAb """ - def __init__(self, layer_collection, inputs, outputs, has_bias=False): + def __init__(self, layer_collection, has_bias=False, num_uses=None): """Creates a FullyConnectedMultiIndepFB block. Args: layer_collection: LayerCollection instance. - inputs: list or tuple of Tensors. Each Tensor has shape [batch_size, - inputs_size]. - outputs: list or tuple of Tensors. Each Tensor has shape [batch_size, - outputs_size]. has_bias: bool. If True, estimates Fisher with respect to a bias parameter as well as the layer's parameters. + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with uses/time folded into the + batch dimension (instead of uses/time being a list dimension). + (Default: None) """ - - assert len(inputs) == len(outputs) - # We need to make sure inputs and outputs are tuples and not lists so that - # they get hashed by layer_collection.make_or_get_factor properly. - self._inputs = tuple(inputs) - self._outputs = tuple(outputs) self._has_bias = has_bias - self._num_uses = len(inputs) - super(FullyConnectedMultiIndepFB, self).__init__(layer_collection) - - @property - def num_registered_minibatches(self): - # TODO(b/69411207): Add support for registering additional minibatches. - return 1 + super(FullyConnectedMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.FullyConnectedMultiKF, - ((self._inputs,), self._has_bias)) + ((inputs,), self._num_uses, self._has_bias)) self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list,)) + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) - damping = normalize_damping(damping, self._num_uses) - self._register_damped_input_and_output_inverses(damping) + self._setup_damping(damping, normalization=self._num_uses) @property def _renorm_coeff(self): - return self._num_uses + return float(self._num_uses) - def tensors_to_compute_grads(self): - return self._outputs - def num_inputs(self): - return len(self._inputs) +class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): + """FisherBlock for 2D convolutional layers using the basic KFC approx. + + Similar to ConvKFCBasicFB except that this version supports multiple + uses/time-steps via a standard independence approximation. Similar to the + "independence across time" used in FullyConnectedMultiIndepFB but generalized + in the obvious way to conv layers. + """ + + def __init__(self, + layer_collection, + params, + padding, + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None, + num_uses=None): + """Creates a ConvKFCBasicMultiIndepFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters (Tensor or tuple of Tensors) of this layer. If + kernel alone, a Tensor of shape [..spatial_filter_shape.., + in_channels, out_channels]. If kernel and bias, a tuple of 2 elements + containing the previous and a Tensor of shape [out_channels]. + padding: str. Padding method. + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with uses/time folded into the + batch dimension (instead of uses/time being a list dimension). + (Default: None) + """ + self._padding = padding + self._strides = maybe_tuple(strides) + self._dilation_rate = maybe_tuple(dilation_rate) + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn + self._has_bias = isinstance(params, (tuple, list)) + + fltr = params[0] if self._has_bias else params + self._filter_shape = tuple(fltr.shape.as_list()) + + super(ConvKFCBasicMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + + # Infer number of locations upon which convolution is applied. + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), + self._strides) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvInputKroneckerFactor, + (inputs, self._filter_shape, self._padding, self._strides, + self._dilation_rate, self._data_format, self._extract_patches_fn, + self._has_bias)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) + + self._setup_damping(damping, normalization= + (self._num_locations * self._num_uses)) + + @property + def _renorm_coeff(self): + return self._num_locations * self._num_uses + + +class EmbeddingKFACMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): + """K-FAC FisherBlock for embedding layers used multiple times in the graph. + + Similar to EmbeddingKFACFB except that this version supports multiple uses + of the parameter within a single model. These uses could correspond to time + steps in an RNN architecture, but they don't have to. + + Does not support bias parameters. + """ + + def __init__(self, layer_collection, vocab_size, num_uses=None): + """Creates a EmbeddingKFACMultiIndepFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + vocab_size: int. Size of vocabulary for this embedding layer. + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with time folded into the batch + dimension (instead of time being a list dimension). (Default: None) + """ + self._vocab_size = vocab_size + + super(EmbeddingKFACMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + def instantiate_factors(self, grads_list, damping): + """Instantiate Kronecker Factors for this FisherBlock. + + Args: + grads_list: List of list of list of Tensors. grads_list[i][j][k] is the + gradient of the loss with respect to 'outputs' from source 'i', + tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape + [tower_minibatch_size, output_size]. + damping: 0-D Tensor or float. 'damping' * identity is approximately added + to this FisherBlock's Fisher approximation. + """ + inputs, grads_list = self._process_data(grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.EmbeddingInputKroneckerFactor, + (inputs, self._vocab_size)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) + self._setup_damping(damping, normalization=self._num_uses) + + @property + def _renorm_coeff(self): + return float(self._num_uses) class SeriesFBApproximation(enum.IntEnum): @@ -920,34 +1546,35 @@ class SeriesFBApproximation(enum.IntEnum): option2 = 2 -class FullyConnectedSeriesFB(FisherBlock): +class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): """FisherBlock for fully-connected layers that share parameters across time. - See the following preprint for details: + This class implements the "Option 1" and "Option 2" approximation from the + following paper: https://openreview.net/pdf?id=HyMTkQZAb See the end of the appendix of the paper for a pseudo-code of the - algorithm being implemented by multiply_inverse here. Note that we are + algorithm being implemented by multiply_matpower here. Note that we are using pre-computed versions of certain matrix-matrix products to speed things up. This is explicitly explained wherever it is done. """ def __init__(self, layer_collection, - inputs, - outputs, has_bias=False, + num_uses=None, option=SeriesFBApproximation.option2): """Constructs a new `FullyConnectedSeriesFB`. Args: layer_collection: The collection of all layers in the K-FAC approximate Fisher information matrix to which this FisherBlock belongs. - inputs: List of tensors of shape [batch_size, input_size]. - Inputs to the layer. - outputs: List of tensors of shape [batch_size, input_size]. - Outputs of the layer (before activations). has_bias: Whether the layer includes a bias parameter. + num_uses: int or None. Number of time-steps over which the layer + is used. Only required if the data is formatted with time folded into + the batch dimension (instead of time being a list dimension). + (Default: None) option: A `SeriesFBApproximation` specifying the simplifying assumption to be used in this block. `option1` approximates the cross-covariance over time as a symmetric matrix, while `option2` makes @@ -955,48 +1582,58 @@ class FullyConnectedSeriesFB(FisherBlock): 3.5 of the paper for more details. """ - assert len(inputs) == len(outputs) - # We need to make sure inputs and outputs are tuples and not lists so that - # they get hashed by layer_collection.make_or_get_factor properly. - self._inputs = tuple(inputs) - self._outputs = tuple(outputs) self._has_bias = has_bias - self._num_timesteps = len(inputs) self._option = option - super(FullyConnectedSeriesFB, self).__init__(layer_collection) + super(FullyConnectedSeriesFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + @property + def _num_timesteps(self): + return self._num_uses @property - def num_registered_minibatches(self): - # TODO(b/69411207): Add support for registering additional minibatches. - return 1 + def _renorm_coeff(self): + # This should no longer be used since the multiply_X functions from the base + # class have been overridden + assert False def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, ((self._inputs,), self._has_bias)) + fisher_factors.FullyConnectedMultiKF, + ((inputs,), self._num_uses, self._has_bias)) + self._input_factor.register_cov_dt1() self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list,)) + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) + self._output_factor.register_cov_dt1() - damping = normalize_damping(damping, self._num_timesteps) - self._damping_input, self._damping_output = compute_pi_adjusted_damping( - self._input_factor.get_cov(), - self._output_factor.get_cov(), - damping**0.5) + self._setup_damping(damping, normalization=self._num_uses) + + def register_matpower(self, exp): + if exp != -1: + raise NotImplementedError("FullyConnectedSeriesFB only supports inverse" + "multiplications.") if self._option == SeriesFBApproximation.option1: - self._input_factor.register_option1quants(self._damping_input) - self._output_factor.register_option1quants(self._damping_output) + self._input_factor.register_option1quants(self._input_damping_func) + self._output_factor.register_option1quants(self._output_damping_func) elif self._option == SeriesFBApproximation.option2: - self._input_factor.register_option2quants(self._damping_input) - self._output_factor.register_option2quants(self._damping_output) + self._input_factor.register_option2quants(self._input_damping_func) + self._output_factor.register_option2quants(self._output_damping_func) else: raise ValueError( "Unrecognized FullyConnectedSeriesFB approximation: {}".format( self._option)) - def multiply_inverse(self, vector): + def multiply_matpower(self, vector, exp): + if exp != -1: + raise NotImplementedError("FullyConnectedSeriesFB only supports inverse" + "multiplications.") + # pylint: disable=invalid-name Z = utils.layer_params_to_mat2d(vector) @@ -1007,9 +1644,11 @@ class FullyConnectedSeriesFB(FisherBlock): if self._option == SeriesFBApproximation.option1: - # Note that L_A = A0^(-1/2) * U_A and L_G = G0^(-1/2) * U_G. - L_A, psi_A = self._input_factor.get_option1quants(self._damping_input) - L_G, psi_G = self._output_factor.get_option1quants(self._damping_output) + # Note that \\(L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.\\) + L_A, psi_A = self._input_factor.get_option1quants( + self._input_damping_func) + L_G, psi_G = self._output_factor.get_option1quants( + self._output_damping_func) def gamma(x): # We are assuming that each case has the same number of time-steps. @@ -1019,60 +1658,61 @@ class FullyConnectedSeriesFB(FisherBlock): T = self._num_timesteps return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T)) - # Y = gamma( psi_G*psi_A^T ) (computed element-wise) + # \\(Y = \gamma( psi_G*psi_A^T )\\) (computed element-wise) # Even though Y is Z-independent we are recomputing it from the psi's # each since Y depends on both A and G quantities, and it is relatively # cheap to compute. Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A) - # Z = L_G^T * Z * L_A + # \\(Z = L_G^T * Z * L_A\\) # This is equivalent to the following computation from the original # pseudo-code: - # Z = G0^(-1/2) * Z * A0^(-1/2) - # Z = U_G^T * Z * U_A + # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) + # \\(Z = U_G^T * Z * U_A\\) Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True) - # Z = Z .* Y + # \\(Z = Z .* Y\\) Z *= Y - # Z = L_G * Z * L_A^T + # \\(Z = L_G * Z * L_A^T\\) # This is equivalent to the following computation from the original # pseudo-code: - # Z = U_G * Z * U_A^T - # Z = G0^(-1/2) * Z * A0^(-1/2) + # \\(Z = U_G * Z * U_A^T\\) + # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True)) elif self._option == SeriesFBApproximation.option2: - # Note that P_A = A_1^T * A_0^(-1) and P_G = G_1^T * G_0^(-1), - # and K_A = A_0^(-1/2) * E_A and K_G = G_0^(-1/2) * E_G. - P_A, K_A, mu_A = self._input_factor.get_option2quants(self._damping_input) + # Note that \\(P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1}\\), + # and \\(K_A = A_0^{-1/2} * E_A\ and\ K_G = G_0^{-1/2} * E_G.\\) + P_A, K_A, mu_A = self._input_factor.get_option2quants( + self._input_damping_func) P_G, K_G, mu_G = self._output_factor.get_option2quants( - self._damping_output) + self._output_damping_func) # Our approach differs superficially from the pseudo-code in the paper # in order to reduce the total number of matrix-matrix multiplies. # In particular, the first three computations in the pseudo code are - # Z = G0^(-1/2) * Z * A0^(-1/2) - # Z = Z - hPsi_G^T * Z * hPsi_A - # Z = E_G^T * Z * E_A - # Noting that hPsi = C0^(-1/2) * C1 * C0^(-1/2), so that - # C0^(-1/2) * hPsi = C0^(-1) * C1 * C0^(-1/2) = P^T * C0^(-1/2) + # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) + # \\(Z = Z - hPsi_G^T * Z * hPsi_A\\) + # \\(Z = E_G^T * Z * E_A\\) + # Noting that hPsi = C0^{-1/2} * C1 * C0^{-1/2}\\), so that + # \\(C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\\) # the entire computation can be written as - # Z = E_G^T * (G0^(-1/2) * Z * A0^(-1/2) - # - hPsi_G^T * G0^(-1/2) * Z * A0^(-1/2) * hPsi_A) * E_A - # = E_G^T * (G0^(-1/2) * Z * A0^(-1/2) - # - G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2)) * E_A - # = E_G^T * G0^(-1/2) * Z * A0^(-1/2) * E_A - # - E_G^T* G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2) * E_A - # = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A + # \\(Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\) + # \\( - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\\) + # \\( = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\) + # \\( - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\\) + # \\( = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\\) + # \\( - E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\\) + # \\( = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A\\) # This final expression is computed by the following two lines: - # Z = Z - P_G * Z * P_A^T + # \\(Z = Z - P_G * Z * P_A^T\\) Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True)) - # Z = K_G^T * Z * K_A + # \\(Z = K_G^T * Z * K_A\\) Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True) - # Z = Z ./ (1*1^T - mu_G*mu_A^T) + # \\(Z = Z ./ (1*1^T - mu_G*mu_A^T)\\) # Be careful with the outer product. We don't want to accidentally # make it an inner-product instead. tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A @@ -1083,13 +1723,13 @@ class FullyConnectedSeriesFB(FisherBlock): # We now perform the transpose/reverse version of the operations # derived above, whose derivation from the original pseudo-code is # analgous. - # Z = K_G * Z * K_A^T + # \\(Z = K_G * Z * K_A^T\\) Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True)) - # Z = Z - P_G^T * Z * P_A + # \\(Z = Z - P_G^T * Z * P_A\\) Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True) - # Z = normalize (1/E[T]) * Z + # \\(Z = normalize (1/E[T]) * Z\\) # Note that this normalization is done because we compute the statistics # by averaging, not summing, over time. (And the gradient is presumably # summed over time, not averaged, and thus their scales are different.) @@ -1102,11 +1742,11 @@ class FullyConnectedSeriesFB(FisherBlock): # pylint: enable=invalid-name - def multiply(self, vector): - raise NotImplementedError + def multiply_cholesky(self, vector): + raise NotImplementedError("FullyConnectedSeriesFB does not support " + "Cholesky computations.") - def tensors_to_compute_grads(self): - return self._outputs + def multiply_cholesky_inverse(self, vector): + raise NotImplementedError("FullyConnectedSeriesFB does not support " + "Cholesky computations.") - def num_inputs(self): - return len(self._inputs) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index 603d8b8b210279ee6d8f1de0ce10869fde23f4d9..b43232dfafaa6d90ca3feda65e5c412d3b755651 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -24,6 +24,7 @@ import contextlib import numpy as np import six +from tensorflow.contrib.kfac.python.ops import linear_operator as lo from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as tf_ops @@ -32,17 +33,24 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import moving_averages +from tensorflow.python.util import nest + # Whether to initialize covariance estimators at a zero matrix (or the identity # matrix). -INIT_COVARIANCES_AT_ZERO = False +INIT_COVARIANCES_AT_ZERO = True # Whether to zero-debias the moving averages. -ZERO_DEBIAS = False +ZERO_DEBIAS = True + +# Whether to initialize inverse (and other such matrices computed from the cov +# matrices) to the zero matrix (or the identity matrix). +INIT_INVERSES_AT_ZERO = True # When the number of inverses requested from a FisherFactor exceeds this value, # the inverses are computed using an eigenvalue decomposition. @@ -53,63 +61,99 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2 # matrix powers. Must be nonnegative. EIGENVALUE_CLIPPING_THRESHOLD = 0.0 -# Colocate the covariance ops and variables with the input tensors for each -# factor. -COLOCATE_COV_OPS_WITH_INPUTS = True +# Used to subsample the flattened extracted image patches. The number of +# outer products per row of the covariance matrix should not exceed this +# value. This parameter is used only if `_SUB_SAMPLE_OUTER_PRODUCTS` is True. +_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1 +# Used to subsample the inputs passed to the extract image patches. The batch +# size of number of inputs to extract image patches is multiplied by this +# factor. This parameter is used only if `_SUB_SAMPLE_INPUTS` is True. +_INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.5 -@contextlib.contextmanager -def maybe_colocate_with(op): - """Context to colocate with `op` if `COLOCATE_COV_OPS_WITH_INPUTS`.""" - if COLOCATE_COV_OPS_WITH_INPUTS: - if isinstance(op, (list, tuple)): - with tf_ops.colocate_with(op[0]): - yield - else: - with tf_ops.colocate_with(op): - yield - else: - yield +# If True, then subsamples the tensor passed to compute the covaraince matrix. +_SUB_SAMPLE_OUTER_PRODUCTS = False + +# If True, then subsamples the tensor passed to compute the covaraince matrix. +_SUB_SAMPLE_INPUTS = False + +# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data +# passed to the factors from the blocks will be concatenated across towers +# (lazilly via PartitionedTensor objects). Otherwise a tuple of tensors over +# towers will be passed in, and the factors will iterate over this and do the +# cov computations separately for each one, averaging the results together. +TOWER_STRATEGY = "concat" def set_global_constants(init_covariances_at_zero=None, zero_debias=None, + init_inverses_at_zero=None, eigenvalue_decomposition_threshold=None, eigenvalue_clipping_threshold=None, - colocate_cov_ops_with_inputs=None): + max_num_outer_products_per_cov_row=None, + sub_sample_outer_products=None, + inputs_to_extract_patches_factor=None, + sub_sample_inputs=None, + tower_strategy=None): """Sets various global constants used by the classes in this module.""" global INIT_COVARIANCES_AT_ZERO global ZERO_DEBIAS + global INIT_INVERSES_AT_ZERO global EIGENVALUE_DECOMPOSITION_THRESHOLD global EIGENVALUE_CLIPPING_THRESHOLD - global COLOCATE_COV_OPS_WITH_INPUTS + global _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW + global _SUB_SAMPLE_OUTER_PRODUCTS + global _INPUTS_TO_EXTRACT_PATCHES_FACTOR + global _SUB_SAMPLE_INPUTS + global TOWER_STRATEGY if init_covariances_at_zero is not None: INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero if zero_debias is not None: ZERO_DEBIAS = zero_debias + if init_inverses_at_zero is not None: + INIT_INVERSES_AT_ZERO = init_inverses_at_zero if eigenvalue_decomposition_threshold is not None: EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold if eigenvalue_clipping_threshold is not None: EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold - if colocate_cov_ops_with_inputs is not None: - COLOCATE_COV_OPS_WITH_INPUTS = colocate_cov_ops_with_inputs + if max_num_outer_products_per_cov_row is not None: + _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = max_num_outer_products_per_cov_row + if sub_sample_outer_products is not None: + _SUB_SAMPLE_OUTER_PRODUCTS = sub_sample_outer_products + if inputs_to_extract_patches_factor is not None: + _INPUTS_TO_EXTRACT_PATCHES_FACTOR = inputs_to_extract_patches_factor + if sub_sample_inputs is not None: + _SUB_SAMPLE_INPUTS = sub_sample_inputs + if tower_strategy is not None: + TOWER_STRATEGY = tower_strategy def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument - return array_ops.diag(array_ops.ones(shape[0], dtype)) + if INIT_INVERSES_AT_ZERO: + return array_ops.zeros(shape, dtype=dtype) + return linalg_ops.eye(num_rows=shape[0], dtype=dtype) def covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument if INIT_COVARIANCES_AT_ZERO: - return array_ops.diag(array_ops.zeros(shape[0], dtype)) - return array_ops.diag(array_ops.ones(shape[0], dtype)) + return array_ops.zeros(shape, dtype=dtype) + return linalg_ops.eye(num_rows=shape[0], dtype=dtype) -def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: disable=unused-argument +def diagonal_covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument if INIT_COVARIANCES_AT_ZERO: - return array_ops.zeros(shape, dtype) - return array_ops.ones(shape, dtype) + return array_ops.zeros(shape, dtype=dtype) + return array_ops.ones(shape, dtype=dtype) + + +@contextlib.contextmanager +def place_on_device(device): + if device is not None and len(device): + with tf_ops.device(device): + yield + else: + yield def compute_cov(tensor, tensor_right=None, normalizer=None): @@ -181,7 +225,9 @@ def scope_string_from_params(params): name_parts = [] for param in params: - if isinstance(param, (tuple, list)): + if param is None: + name_parts.append("None") + elif isinstance(param, (tuple, list)): if all([isinstance(p, int) for p in param]): name_parts.append("-".join([str(p) for p in param])) else: @@ -190,6 +236,8 @@ def scope_string_from_params(params): name_parts.append(str(param)) elif isinstance(param, (tf_ops.Tensor, variables.Variable)): name_parts.append(scope_string_from_name(param)) + elif isinstance(param, utils.PartitionedTensor): + name_parts.append(scope_string_from_name(param.tensors)) else: raise ValueError("Encountered an unsupported param type {}".format( type(param))) @@ -207,6 +255,74 @@ def scalar_or_tensor_to_string(val): return repr(val) if np.isscalar(val) else scope_string_from_name(val) +def list_to_string(lst): + return "_".join(val if isinstance(val, six.string_types) + else scalar_or_tensor_to_string(val) for val in lst) + + +def graph_func_to_id(func): + """Returns a hashable object that represents func's computation.""" + # TODO(b/74201126): replace with Topohash of func's output + return func.func_id + + +def graph_func_to_string(func): + # TODO(b/74201126): replace with Topohash of func's output + return list_to_string(func.func_id) + + +def _subsample_for_cov_computation(array, name=None): + """Subsamples the first dimension of the array. + + `array`(A) is a tensor of shape `[batch_size, dim_2]`. Then the covariance + matrix(A^TA) is of shape `dim_2 ** 2`. Subsample only if the number of outer + products per row of the covariance matrix is greater than + `_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW`. + + Args: + array: Tensor, of shape `[batch_size, dim_2]`. + name: `string`, Default(None) + + Returns: + A tensor of shape `[max_samples, dim_2]`. + + Raises: + ValueError: If array's is not matrix-shaped. + ValueError: If array's batch_size cannot be inferred. + + """ + with tf_ops.name_scope(name, "subsample", [array]): + array = tf_ops.convert_to_tensor(array) + if len(array.shape) != 2: + raise ValueError("Input param array must be a matrix.") + + batch_size = array.shape.as_list()[0] + if batch_size is None: + raise ValueError("Unable to get batch_size from input param array.") + + num_cov_rows = array.shape.as_list()[-1] + max_batch_size = int(_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW * num_cov_rows) + if batch_size <= max_batch_size: + return array + + return _random_tensor_gather(array, max_batch_size) + + +def _random_tensor_gather(array, max_size): + """Generates a random set of indices and gathers the value at the indcices. + + Args: + array: Tensor, of shape `[batch_size, dim_2]`. + max_size: int, Number of indices to sample. + + Returns: + A tensor of shape `[max_size, ...]`. + """ + batch_size = array.shape.as_list()[0] + indices = random_ops.random_shuffle(math_ops.range(0, batch_size))[:max_size] + return array_ops.gather(array, indices) + + @six.add_metaclass(abc.ABCMeta) class FisherFactor(object): """Base class for objects modeling factors of approximate Fisher blocks. @@ -223,13 +339,10 @@ class FisherFactor(object): Note that for blocks that aren't based on approximations, a 'factor' can be the entire block itself, as is the case for the diagonal and full representations. - - Subclasses must implement the _compute_new_cov() method, and the _var_scope - and _cov_shape properties. """ def __init__(self): - self.instantiate_covariance() + self._cov = None @abc.abstractproperty def _var_scope(self): @@ -240,6 +353,10 @@ class FisherFactor(object): """ pass + @property + def name(self): + return self._var_scope + @abc.abstractproperty def _cov_shape(self): """The shape of the variable backing this FisherFactor.""" @@ -257,6 +374,10 @@ class FisherFactor(object): """ pass + @abc.abstractproperty + def _num_towers(self): + pass + @abc.abstractproperty def _dtype(self): """dtype for variable backing this factor.""" @@ -267,8 +388,9 @@ class FisherFactor(object): """Function for initializing covariance variable.""" return covariance_initializer - def instantiate_covariance(self): - """Instantiates the covariance Variable as the instance member _cov.""" + def instantiate_cov_variables(self): + """Makes the internal cov variable(s).""" + assert self._cov is None with variable_scope.variable_scope(self._var_scope): self._cov = variable_scope.get_variable( "cov", @@ -278,15 +400,17 @@ class FisherFactor(object): dtype=self._dtype) @abc.abstractmethod - def _compute_new_cov(self, idx=0): + def _compute_new_cov(self, source, tower): """Computes minibatch-estimated covariance for a single source. Args: - idx: int in [0, self._num_sources). Which source to use when estimating - covariance. + source: int in [0, self._num_sources). Which source to use when computing + the cov update. + tower: int in [0, self._num_towers). Which tower to use when computing + the cov update. Returns: - Tensor of same shape as self.get_cov_var(). + Tensor of same shape as self.get_cov(). """ pass @@ -298,123 +422,80 @@ class FisherFactor(object): Returns: An Op for updating the covariance Variable referenced by _cov. """ - new_cov_contribs = tuple(self._compute_new_cov(idx) - for idx in range(self._num_sources)) - # This gets the job done but we might want a better solution in the future. - # In particular, we could have a separate way of specifying where the - # the cov variables finally end up, independent of where their various - # contributions are computed. Right now these are the same thing, but in - # the future we might want to perform the cov computations on each tower, - # so that each tower will be considered a "source" (allowing us to reuse - # the existing "source" code for this). - with maybe_colocate_with(new_cov_contribs[0]): - new_cov = math_ops.add_n(new_cov_contribs) - # Synchronize value across all TPU cores. - if utils.on_tpu(): - new_cov = utils.cross_replica_mean(new_cov) - return moving_averages.assign_moving_average( - self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) + new_cov_contribs = [] + for source in range(self._num_sources): + for tower in range(self._num_towers): + device = (self._get_data_device(tower) + if TOWER_STRATEGY == "separate" else None) + with place_on_device(device): + new_cov_contribs.append(self._compute_new_cov(source, tower)) + + new_cov = math_ops.add_n(new_cov_contribs) / float(self._num_towers) + + # Compute average of 'new_cov' across all TPU cores. On a TPU, each + # instance of 'new_cov' will be based on a different minibatch. This ensures + # that by the end of assign_moving_average(), all TPU cores see the same + # value for self._cov. + # + # Other implementations of make_covariance_update_op() that accumulate + # statistics in other variables should mimic this behavior. + if utils.on_tpu(): + new_cov = utils.cross_replica_mean(new_cov) + + return moving_averages.assign_moving_average( + self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) @abc.abstractmethod - def make_inverse_update_ops(self): - """Create and return update ops corresponding to registered computations.""" + def _get_data_device(self, tower): pass @abc.abstractmethod - def get_cov(self): - """Get full covariance matrix. - - Returns: - Tensor of shape [n, n]. Represents all parameter-parameter correlations - captured by this FisherFactor. - """ + def instantiate_inv_variables(self): + """Makes the internal "inverse" variable(s).""" pass - def get_cov_var(self): - """Get variable backing this FisherFactor. - - May or may not be the same as self.get_cov() + @abc.abstractmethod + def make_inverse_update_ops(self): + """Create and return update ops corresponding to registered computations.""" + pass - Returns: - Variable of shape self._cov_shape. - """ + def get_cov(self): return self._cov @abc.abstractmethod - def left_multiply(self, x, damping): - """Multiplies 'x' by the damped covariance of this factor. - - Let C be the covariance matrix this factor represents, and - D = C + damping * I be its damped variant. This method calculates - matmul(D, vec(x)). - - Args: - x: Tensor. Represents a single vector. Shape depends on implementation. - damping: 0-D Tensor. Damping to add to C's diagonal. - - Returns: - Tensor of same shape as 'x'. - """ + def get_cov_as_linear_operator(self): pass @abc.abstractmethod - def right_multiply(self, x, damping): - """Multiplies 'x' by the damped covariance of this factor. - - Let C be the covariance matrix this factor represents, and - D = C + damping * I be its damped variant. This method calculates - matmul(vec(x), D). - - Args: - x: Tensor. Represents a single vector. Shape depends on implementation. - damping: 0-D Tensor. Damping to add to C's diagonal. - - Returns: - Tensor of same shape as 'x'. - """ + def register_matpower(self, exp, damping_func): pass @abc.abstractmethod - def left_multiply_inverse(self, x, damping): - """Multiplies 'x' by damped inverse of this factor. - - Let C be the covariance matrix this factor represents and - E = inv(C + damping * I) be its damped inverse. This method calculates - matmul(E, vec(x)). - - Args: - x: Tensor. Represents a single vector. Shape depends on implementation. - damping: 0-D Tensor. Damping to add to C's diagonal. - - Returns: - Tensor of same shape as 'x'. - """ + def register_cholesky(self, damping_func): pass @abc.abstractmethod - def right_multiply_inverse(self, x, damping): - """Multiplies 'x' by damped inverse of this factor. + def register_cholesky_inverse(self, damping_func): + pass - Let C be the covariance matrix this factor represents and - E = inv(C + damping * I) be its damped inverse. This method calculates - matmul(vec(x), E). + @abc.abstractmethod + def get_matpower(self, exp, damping_func): + pass - Args: - x: Tensor. Represents a single vector. Shape depends on implementation. - damping: 0-D Tensor. Damping to add to C's diagonal. + @abc.abstractmethod + def get_cholesky(self, damping_func): + pass - Returns: - Tensor of same shape as 'x'. - """ + @abc.abstractmethod + def get_cholesky_inverse(self, damping_func): pass -class InverseProvidingFactor(FisherFactor): - """Base class for FisherFactors that maintain inverses explicitly. +class DenseSquareMatrixFactor(FisherFactor): + """Base class for FisherFactors that are stored as dense square matrices. - This class explicitly calculates and stores inverses of covariance matrices - provided by the underlying FisherFactor implementation. It is assumed that - vectors can be represented as 2-D matrices. + This class explicitly calculates and stores inverses of their `cov` matrices, + which must be square dense matrices. Subclasses must implement the _compute_new_cov method, and the _var_scope and _cov_shape properties. @@ -428,47 +509,94 @@ class InverseProvidingFactor(FisherFactor): # the latter. def __init__(self): - self._inverses_by_damping = {} - self._matpower_by_exp_and_damping = {} + self._matpower_by_exp_and_damping = {} # { (float, hashable): variable } + self._matpower_registrations = set() # { (float, hashable) } self._eigendecomp = None + self._damping_funcs_by_id = {} # {hashable: lambda} + + self._cholesky_registrations = set() # { hashable } + self._cholesky_inverse_registrations = set() # { hashable } + + self._cholesky_by_damping = {} # { hashable: variable } + self._cholesky_inverse_by_damping = {} # { hashable: variable } - super(InverseProvidingFactor, self).__init__() + super(DenseSquareMatrixFactor, self).__init__() - def register_damped_inverse(self, damping): - """Registers a damped inverse needed by a FisherBlock. + def get_cov_as_linear_operator(self): + assert self.get_cov().shape.ndims == 2 + return lo.LinearOperatorFullMatrix(self.get_cov(), + is_self_adjoint=True, + is_square=True) + + def _register_damping(self, damping_func): + damping_id = graph_func_to_id(damping_func) + if damping_id not in self._damping_funcs_by_id: + self._damping_funcs_by_id[damping_id] = damping_func + return damping_id + + def register_inverse(self, damping_func): + # Just for backwards compatibility of some old code and tests + self.register_matpower(-1, damping_func) + + def register_matpower(self, exp, damping_func): + """Registers a matrix power to be maintained and served on demand. This creates a variable and signals make_inverse_update_ops to make the corresponding update op. The variable can be read via the method - get_inverse. + get_matpower. Args: - damping: The damping value (float or Tensor) for this factor. + exp: float. The exponent to use in the matrix power. + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). """ - if damping not in self._inverses_by_damping: - damping_string = scalar_or_tensor_to_string(damping) - with variable_scope.variable_scope(self._var_scope): - inv = variable_scope.get_variable( - "inv_damp{}".format(damping_string), - initializer=inverse_initializer, - shape=self._cov_shape, - trainable=False, - dtype=self._dtype) - self._inverses_by_damping[damping] = inv + if exp == 1.0: + return + + damping_id = self._register_damping(damping_func) - def register_matpower(self, exp, damping): - """Registers a matrix power needed by a FisherBlock. + if (exp, damping_id) not in self._matpower_registrations: + self._matpower_registrations.add((exp, damping_id)) + + def register_cholesky(self, damping_func): + """Registers a Cholesky factor to be maintained and served on demand. This creates a variable and signals make_inverse_update_ops to make the corresponding update op. The variable can be read via the method - get_matpower. + get_cholesky. Args: - exp: The exponent (float or Tensor) to raise the matrix to. - damping: The damping value (float or Tensor). + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). """ - if (exp, damping) not in self._matpower_by_exp_and_damping: + damping_id = self._register_damping(damping_func) + + if damping_id not in self._cholesky_registrations: + self._cholesky_registrations.add(damping_id) + + def register_cholesky_inverse(self, damping_func): + """Registers an inverse Cholesky factor to be maintained/served on demand. + + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_cholesky_inverse. + + Args: + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). + """ + damping_id = self._register_damping(damping_func) + + if damping_id not in self._cholesky_inverse_registrations: + self._cholesky_inverse_registrations.add(damping_id) + + def instantiate_inv_variables(self): + """Makes the internal "inverse" variable(s).""" + + for (exp, damping_id) in self._matpower_registrations: exp_string = scalar_or_tensor_to_string(exp) - damping_string = scalar_or_tensor_to_string(damping) + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) with variable_scope.variable_scope(self._var_scope): matpower = variable_scope.get_variable( "matpower_exp{}_damp{}".format(exp_string, damping_string), @@ -476,34 +604,62 @@ class InverseProvidingFactor(FisherFactor): shape=self._cov_shape, trainable=False, dtype=self._dtype) - self._matpower_by_exp_and_damping[(exp, damping)] = matpower + assert (exp, damping_id) not in self._matpower_by_exp_and_damping + self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower + + for damping_id in self._cholesky_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + with variable_scope.variable_scope(self._var_scope): + chol = variable_scope.get_variable( + "cholesky_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + assert damping_id not in self._cholesky_by_damping + self._cholesky_by_damping[damping_id] = chol + + for damping_id in self._cholesky_inverse_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + with variable_scope.variable_scope(self._var_scope): + cholinv = variable_scope.get_variable( + "cholesky_inverse_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + assert damping_id not in self._cholesky_inverse_by_damping + self._cholesky_inverse_by_damping[damping_id] = cholinv def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" ops = [] - # We do this to ensure that we don't reuse the eigendecomp from old calls - # to make_inverse_update_ops that may be placed on different devices. This - # can happen is the user has both a permanent and lazily constructed - # version of the inverse ops (and only uses one of them). - self.reset_eigendecomp() + num_inverses = sum(1 for (exp, _) in self._matpower_by_exp_and_damping + if exp == -1) + + num_other_matpower = len(self._matpower_by_exp_and_damping) - num_inverses + + other_matrix_power_registered = num_other_matpower >= 1 - num_inverses = len(self._inverses_by_damping) - matrix_power_registered = bool(self._matpower_by_exp_and_damping) use_eig = ( - self._eigendecomp or matrix_power_registered or + self._eigendecomp or other_matrix_power_registered or num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) + # We precompute these so we don't need to evaluate them multiple times (for + # each matrix power that uses them) + damping_value_by_id = {damping_id: math_ops.cast( + self._damping_funcs_by_id[damping_id](), self._dtype) + for damping_id in self._damping_funcs_by_id} + if use_eig: eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence - for damping, inv in self._inverses_by_damping.items(): - ops.append( - inv.assign( - math_ops.matmul(eigenvectors / (eigenvalues + damping), - array_ops.transpose(eigenvectors)))) - - for (exp, damping), matpower in self._matpower_by_exp_and_damping.items(): + for (exp, damping_id), matpower in ( + self._matpower_by_exp_and_damping.items()): + damping = damping_value_by_id[damping_id] ops.append( matpower.assign( math_ops.matmul(eigenvectors * @@ -512,30 +668,95 @@ class InverseProvidingFactor(FisherFactor): # These ops share computation and should be run on a single device. ops = [control_flow_ops.group(*ops)] else: - for damping, inv in self._inverses_by_damping.items(): - ops.append(inv.assign(utils.posdef_inv(self._cov, damping))) - + for (exp, damping_id), matpower in ( + self._matpower_by_exp_and_damping.items()): + assert exp == -1 + damping = damping_value_by_id[damping_id] + ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping))) + + # TODO(b/77902055): If inverses are being computed with Cholesky's + # we can share the work. Instead this code currently just computes the + # Cholesky a second time. It does at least share work between requests for + # Cholesky's and Cholesky inverses with the same damping id. + for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items(): + cholesky_ops = [] + + damping = damping_value_by_id[damping_id] + cholesky_value = utils.cholesky(self.get_cov(), damping) + + if damping_id in self._cholesky_by_damping: + cholesky = self._cholesky_by_damping[damping_id] + cholesky_ops.append(cholesky.assign(cholesky_value)) + + identity = linalg_ops.eye(cholesky_value.shape.as_list()[0], + dtype=cholesky_value.dtype) + cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value, + identity) + cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value)) + + ops.append(control_flow_ops.group(*cholesky_ops)) + + for damping_id, cholesky in self._cholesky_by_damping.items(): + if damping_id not in self._cholesky_inverse_by_damping: + damping = damping_value_by_id[damping_id] + cholesky_value = utils.cholesky(self.get_cov(), damping) + ops.append(cholesky.assign(cholesky_value)) + + self._eigendecomp = False return ops - def get_damped_inverse(self, damping): + def get_inverse(self, damping_func): + # Just for backwards compatibility of some old code and tests + return self.get_matpower(-1, damping_func) + + def get_matpower(self, exp, damping_func): # Note that this function returns a variable which gets updated by the # inverse ops. It may be stale / inconsistent with the latest value of # get_cov(). - return self._inverses_by_damping[damping] - - def get_matpower(self, exp, damping): + if exp != 1: + damping_id = graph_func_to_id(damping_func) + matpower = self._matpower_by_exp_and_damping[(exp, damping_id)] + else: + matpower = self.get_cov() + identity = linalg_ops.eye(matpower.shape.as_list()[0], + dtype=matpower.dtype) + matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*identity + + assert matpower.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(matpower, + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + + def get_cholesky(self, damping_func): # Note that this function returns a variable which gets updated by the # inverse ops. It may be stale / inconsistent with the latest value of # get_cov(). - return self._matpower_by_exp_and_damping[(exp, damping)] + damping_id = graph_func_to_id(damping_func) + cholesky = self._cholesky_by_damping[damping_id] + assert cholesky.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(cholesky, + is_non_singular=True, + is_square=True) + + def get_cholesky_inverse(self, damping_func): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). + damping_id = graph_func_to_id(damping_func) + cholesky_inv = self._cholesky_inverse_by_damping[damping_id] + assert cholesky_inv.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(cholesky_inv, + is_non_singular=True, + is_square=True) def get_eigendecomp(self): """Creates or retrieves eigendecomposition of self._cov.""" - # Unlike get_inverse and get_matpower this doesn't retrieve a stored - # variable, but instead always computes a fresh version from the current - # value of get_cov(). + # Unlike get_matpower this doesn't retrieve a stored variable, but instead + # always computes a fresh version from the current value of get_cov(). if not self._eigendecomp: - eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov) + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov()) # The matrix self._cov is positive semidefinite by construction, but the # numerical eigenvalues could be negative due to numerical errors, so here @@ -546,66 +767,8 @@ class InverseProvidingFactor(FisherFactor): return self._eigendecomp - def reset_eigendecomp(self): - self._eigendecomp = None - - def get_cov(self): - # Variable contains full covariance matrix. - return self.get_cov_var() - - def left_multiply(self, x, damping): - n = self.get_cov().shape[0] - damped_cov = self.get_cov() + damping * array_ops.eye(n) - - if isinstance(x, tf_ops.IndexedSlices): - raise NotImplementedError( - "Left-multiply not yet supported for IndexedSlices.") - - if len(x.shape) != 2: - raise ValueError( - "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." - % (x,)) - - return math_ops.matmul(damped_cov, x) - - def right_multiply(self, x, damping): - n = self.get_cov().shape[0] - damped_cov = self.get_cov() + damping * array_ops.eye(n) - - if isinstance(x, tf_ops.IndexedSlices): - return utils.matmul_sparse_dense(x, damped_cov) - - if len(x.shape) != 2: - raise ValueError( - "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." - % (x,)) - - return math_ops.matmul(x, damped_cov) - - def left_multiply_inverse(self, x, damping): - if isinstance(x, tf_ops.IndexedSlices): - raise ValueError("Left-multiply not yet supported for IndexedSlices.") - if x.shape.ndims != 2: - raise ValueError( - "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." - % (x,)) - - return math_ops.matmul(self.get_damped_inverse(damping), x) - - def right_multiply_inverse(self, x, damping): - if isinstance(x, tf_ops.IndexedSlices): - return utils.matmul_sparse_dense(x, self.get_damped_inverse(damping)) - - if x.shape.ndims != 2: - raise ValueError( - "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." - % (x,)) - - return math_ops.matmul(x, self.get_damped_inverse(damping)) - - -class FullFactor(InverseProvidingFactor): +class FullFactor(DenseSquareMatrixFactor): """FisherFactor for a full matrix representation of the Fisher of a parameter. Note that this uses the naive "square the sum estimator", and so is applicable @@ -622,7 +785,7 @@ class FullFactor(InverseProvidingFactor): @property def _var_scope(self): - return "ff_full/" + scope_string_from_params( + return "ff_full_" + scope_string_from_params( [self._params_grads, self._batch_size]) @property @@ -635,17 +798,25 @@ class FullFactor(InverseProvidingFactor): def _num_sources(self): return len(self._params_grads) + @property + def _num_towers(self): + return 1 + @property def _dtype(self): return self._params_grads[0][0].dtype - def _compute_new_cov(self, idx=0): + def _compute_new_cov(self, source, tower): + assert tower == 0 + # This will be a very basic rank 1 estimate - with maybe_colocate_with(self._params_grads[idx]): - params_grads_flat = utils.tensors_to_column(self._params_grads[idx]) - return ((params_grads_flat * array_ops.transpose( - params_grads_flat)) / math_ops.cast(self._batch_size, - params_grads_flat.dtype)) + params_grads_flat = utils.tensors_to_column(self._params_grads[source]) + return ((params_grads_flat * array_ops.transpose( + params_grads_flat)) / math_ops.cast(self._batch_size, + params_grads_flat.dtype)) + + def _get_data_device(self, tower): + return None class DiagonalFactor(FisherFactor): @@ -658,51 +829,49 @@ class DiagonalFactor(FisherFactor): def __init__(self): super(DiagonalFactor, self).__init__() + def get_cov_as_linear_operator(self): + assert self._matrix_diagonal.shape.ndims == 1 + return lo.LinearOperatorDiag(self._matrix_diagonal, + is_self_adjoint=True, + is_square=True) + @property def _cov_initializer(self): return diagonal_covariance_initializer + @property + def _matrix_diagonal(self): + return array_ops.reshape(self.get_cov(), [-1]) + def make_inverse_update_ops(self): return [] - def get_cov(self): - # self.get_cov() could be any shape, but it must have one entry per - # parameter. Flatten it into a vector. - cov_diag_vec = array_ops.reshape(self.get_cov_var(), [-1]) - return array_ops.diag(cov_diag_vec) - - def left_multiply(self, x, damping): - damped_cov = self.get_cov_var() + damping - if isinstance(x, tf_ops.IndexedSlices): - return utils.matmul_diag_sparse(array_ops.reshape(damped_cov, [-1]), x) - - if x.shape != damped_cov.shape: - raise ValueError("x (%s) and cov (%s) must have same shape." % - (x, damped_cov)) - - return damped_cov * x - - def right_multiply(self, x, damping): - raise NotImplementedError("Only left-multiply is currently supported.") + def instantiate_inv_variables(self): + pass - def left_multiply_inverse(self, x, damping): - inverse = 1. / (self.get_cov_var() + damping) + def register_matpower(self, exp, damping_func): + pass - if isinstance(x, tf_ops.IndexedSlices): - return utils.matmul_diag_sparse(array_ops.reshape(inverse, [-1]), x) + def register_cholesky(self, damping_func): + pass - if x.shape != inverse.shape: - raise ValueError("x (%s) and cov (%s) must have same shape." % - (x, inverse)) + def register_cholesky_inverse(self, damping_func): + pass - return inverse * x + def get_matpower(self, exp, damping_func): + matpower_diagonal = (self._matrix_diagonal + + math_ops.cast(damping_func(), self._dtype))**exp + return lo.LinearOperatorDiag(matpower_diagonal, + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) - def right_multiply_inverse(self, x, damping): - raise NotImplementedError("Only left-multiply is currently supported.") + def get_cholesky(self, damping_func): + return self.get_matpower(0.5, damping_func) - def register_damped_inverse(self, damping): - # DiagonalFactors don't keep explicit inverses. - pass + def get_cholesky_inverse(self, damping_func): + return self.get_matpower(-0.5, damping_func) class NaiveDiagonalFactor(DiagonalFactor): @@ -730,7 +899,7 @@ class NaiveDiagonalFactor(DiagonalFactor): @property def _var_scope(self): - return "ff_naivediag/" + scope_string_from_params( + return "ff_naivediag_" + scope_string_from_params( [self._params_grads, self._batch_size]) @property @@ -743,15 +912,23 @@ class NaiveDiagonalFactor(DiagonalFactor): def _num_sources(self): return len(self._params_grads) + @property + def _num_towers(self): + return 1 + @property def _dtype(self): return self._params_grads[0][0].dtype - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._params_grads[idx]): - params_grads_flat = utils.tensors_to_column(self._params_grads[idx]) - return (math_ops.square(params_grads_flat) / math_ops.cast( - self._batch_size, params_grads_flat.dtype)) + def _compute_new_cov(self, source, tower): + assert tower == 0 + + params_grads_flat = utils.tensors_to_column(self._params_grads[source]) + return (math_ops.square(params_grads_flat) / math_ops.cast( + self._batch_size, params_grads_flat.dtype)) + + def _get_data_device(self, tower): + return None class EmbeddingInputKroneckerFactor(DiagonalFactor): @@ -772,8 +949,8 @@ class EmbeddingInputKroneckerFactor(DiagonalFactor): """Instantiate EmbeddingInputKroneckerFactor. Args: - input_ids: Tuple of Tensors of shape [batch_size, input_size] and dtype - int32. Indices into embedding matrix. + input_ids: List of Tensors of shape [batch_size, input_size] and dtype + int32. Indices into embedding matrix. List index is tower. vocab_size: int or 0-D Tensor. Maximum value for entries in 'input_ids'. dtype: dtype for covariance statistics. Must be a floating point type. Defaults to float32. @@ -786,7 +963,7 @@ class EmbeddingInputKroneckerFactor(DiagonalFactor): @property def _var_scope(self): - return "ff_diag_embedding/" + scope_string_from_params(self._input_ids) + return "ff_diag_embedding_" + scope_string_from_params(self._input_ids) @property def _cov_shape(self): @@ -794,42 +971,51 @@ class EmbeddingInputKroneckerFactor(DiagonalFactor): @property def _num_sources(self): + return 1 + + @property + def _num_towers(self): return len(self._input_ids) @property def _dtype(self): return self._cov_dtype - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._input_ids): - input_ids = self._input_ids[idx] - if len(input_ids.shape) > 2: - raise ValueError( - "Input to embeddings must have rank <= 2. Found rank %d." % len( - input_ids.shape)) + def _compute_new_cov(self, source, tower): + assert source == 0 + + input_ids = self._input_ids[tower] + + if len(input_ids.shape) > 2: + raise ValueError( + "Input to embeddings must have rank <= 2. Found rank %d." % len( + input_ids.shape)) - batch_size = array_ops.shape(input_ids)[0] + batch_size = array_ops.shape(input_ids)[0] - # Transform indices into one-hot vectors. - # - # TODO(b/72714822): There must be a faster way to construct the diagonal - # covariance matrix! This operation is O(batch_size * vocab_size), where - # it should be O(batch_size * input_size). - flat_input_ids = array_ops.reshape(input_ids, [-1]) - one_hots = array_ops.one_hot(flat_input_ids, - self._vocab_size) # [?, vocab_size] + # Transform indices into one-hot vectors. + # + # TODO(b/72714822): There must be a faster way to construct the diagonal + # covariance matrix! This operation is O(batch_size * vocab_size), where + # it should be O(batch_size * input_size). + flat_input_ids = array_ops.reshape(input_ids, [-1]) + one_hots = array_ops.one_hot(flat_input_ids, + self._vocab_size) # [?, vocab_size] - # Take average across examples. Note that, because all entries have - # magnitude zero or one, there's no need to square the entries. - # - # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation - # within an example such as average. - # - # TODO(b/72714822): Support for partitioned embeddings. - new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size] - new_cov /= math_ops.cast(batch_size, new_cov.dtype) + # Take average across examples. Note that, because all entries have + # magnitude zero or one, there's no need to square the entries. + # + # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation + # within an example such as average. + # + # TODO(b/72714822): Support for partitioned embeddings. + new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size] + new_cov /= math_ops.cast(batch_size, new_cov.dtype) - return new_cov + return new_cov + + def _get_data_device(self, tower): + return self._input_ids[tower].device class FullyConnectedDiagonalFactor(DiagonalFactor): @@ -850,58 +1036,75 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): """Instantiate FullyConnectedDiagonalFactor. Args: - inputs: Tensor of shape [batch_size, input_size]. Inputs to fully - connected layer. - outputs_grads: List of Tensors of shape [batch_size, output_size]. - Gradient of loss with respect to layer's preactivations. + inputs: List of Tensors of shape [batch_size, input_size]. Inputs to this + layer. List index is towers. + outputs_grads: List of Tensors, each of shape [batch_size, output_size], + which are the gradients of the loss with respect to the layer's + outputs. First index is source, second is tower. + has_bias: bool. If True, append '1' to each input. """ self._inputs = inputs self._has_bias = has_bias self._outputs_grads = outputs_grads - self._batch_size = array_ops.shape(inputs)[0] self._squared_inputs = None super(FullyConnectedDiagonalFactor, self).__init__() @property def _var_scope(self): - return "ff_diagfc/" + scope_string_from_params( - (self._inputs,) + tuple(self._outputs_grads)) + return "ff_diagfc_" + scope_string_from_params( + tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) @property def _cov_shape(self): - input_size = self._inputs.shape[1] + self._has_bias - output_size = self._outputs_grads[0].shape[1] + input_size = self._inputs[0].shape[1] + self._has_bias + output_size = self._outputs_grads[0][0].shape[1] return [input_size, output_size] @property def _num_sources(self): return len(self._outputs_grads) + @property + def _num_towers(self): + return len(self._inputs) + @property def _dtype(self): - return self._outputs_grads[0].dtype + return self._outputs_grads[0][0].dtype + + def make_covariance_update_op(self, ema_decay): + + self._squared_inputs = [] + for tower in range(self._num_towers): + inputs = self._inputs[tower] + + with place_on_device(self._get_data_device(tower)): + if self._has_bias: + inputs = append_homog(inputs) + self._squared_inputs.append(math_ops.square(inputs)) + + return super(FullyConnectedDiagonalFactor, self).make_covariance_update_op( + ema_decay) + + def _compute_new_cov(self, source, tower): + batch_size = array_ops.shape(self._squared_inputs[tower])[0] + outputs_grad = self._outputs_grads[source][tower] - def _compute_new_cov(self, idx=0): # The well-known special formula that uses the fact that the entry-wise # square of an outer product is the outer-product of the entry-wise squares. # The gradient is the outer product of the input and the output gradients, # so we just square both and then take their outer-product. - with maybe_colocate_with(self._outputs_grads[idx]): - # We only need to compute squared_inputs once - if self._squared_inputs is None: - inputs = self._inputs - if self._has_bias: - inputs = append_homog(self._inputs) - self._squared_inputs = math_ops.square(inputs) + new_cov = math_ops.matmul( + self._squared_inputs[tower], + math_ops.square(outputs_grad), + transpose_a=True) + new_cov /= math_ops.cast(batch_size, new_cov.dtype) + return new_cov - new_cov = math_ops.matmul( - self._squared_inputs, - math_ops.square(self._outputs_grads[idx]), - transpose_a=True) - new_cov /= math_ops.cast(self._batch_size, new_cov.dtype) - return new_cov + def _get_data_device(self, tower): + return self._inputs[tower].device class ConvDiagonalFactor(DiagonalFactor): @@ -913,36 +1116,67 @@ class ConvDiagonalFactor(DiagonalFactor): filter_shape, strides, padding, + data_format=None, + dilations=None, has_bias=False): """Creates a ConvDiagonalFactor object. Args: - inputs: Tensor of shape [batch_size, height, width, in_channels]. - Input activations to this layer. - outputs_grads: Tensor of shape [batch_size, height, width, out_channels]. - Per-example gradients to the loss with respect to the layer's output - preactivations. + inputs: List of Tensors of shape [batch_size, height, width, in_channels]. + Input activations to this layer. List index is towers. + outputs_grads: List of Tensors, each of shape [batch_size, + height, width, out_channels], which are the gradients of the loss + with respect to the layer's outputs. First index is source, second + index is tower. filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels, out_channels). Represents shape of kernel used in this layer. strides: The stride size in this layer (1-D Tensor of length 4). padding: The padding in this layer (1-D of Tensor length 4). + data_format: None or str. Format of conv2d inputs. + dilations: None or tuple of 4 ints. has_bias: Python bool. If True, the layer is assumed to have a bias parameter in addition to its filter parameter. + + Raises: + ValueError: If inputs, output_grads, and filter_shape do not agree on + in_channels or out_channels. + ValueError: If strides, dilations are not length-4 lists of ints. + ValueError: If data_format does not put channel last. """ + if not utils.is_data_format_channel_last(data_format): + raise ValueError("Channel must be last.") + if any(input_.shape.ndims != 4 for input_ in inputs): + raise ValueError("inputs must be a list of 4-D Tensors.") + if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs): + raise ValueError("inputs and filter_shape must agree on in_channels.") + for i, outputs_grad in enumerate(outputs_grads): + if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad): + raise ValueError("outputs[%d] must be 4-D Tensor." % i) + if any(output_grad.shape.as_list()[-1] != filter_shape[-1] + for output_grad in outputs_grad): + raise ValueError( + "outputs[%d] and filter_shape must agree on out_channels." % i) + if len(strides) != 4: + raise ValueError("strides must be length-4 list of ints.") + if dilations is not None and len(dilations) != 4: + raise ValueError("dilations must be length-4 list of ints.") + self._inputs = inputs + self._outputs_grads = outputs_grads self._filter_shape = filter_shape self._strides = strides self._padding = padding + self._data_format = data_format + self._dilations = dilations self._has_bias = has_bias - self._outputs_grads = outputs_grads self._patches = None super(ConvDiagonalFactor, self).__init__() @property def _var_scope(self): - return "ff_convdiag/" + scope_string_from_name( - (self._inputs,) + tuple(self._outputs_grads)) + return "ff_convdiag_" + scope_string_from_params( + tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) @property def _cov_shape(self): @@ -956,43 +1190,50 @@ class ConvDiagonalFactor(DiagonalFactor): def _num_sources(self): return len(self._outputs_grads) + @property + def _num_towers(self): + return len(self._inputs) + @property def _dtype(self): - return self._outputs_grads[0].dtype + return self._inputs[0].dtype def make_covariance_update_op(self, ema_decay): - with maybe_colocate_with(self._inputs): - filter_height, filter_width, _, _ = self._filter_shape + filter_height, filter_width, _, _ = self._filter_shape - # TODO(b/64144716): there is potential here for a big savings in terms - # of memory use. - patches = array_ops.extract_image_patches( - self._inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=self._strides, - rates=[1, 1, 1, 1], - padding=self._padding) - - if self._has_bias: - patches = append_homog(patches) - - self._patches = patches + # TODO(b/64144716): there is potential here for a big savings in terms + # of memory use. + if self._dilations is None: + rates = (1, 1, 1, 1) + else: + rates = tuple(self._dilations) + + self._patches = [] + for tower in range(self._num_towers): + with place_on_device(self._get_data_device(tower)): + patches = array_ops.extract_image_patches( + self._inputs[tower], + ksizes=[1, filter_height, filter_width, 1], + strides=self._strides, + rates=rates, + padding=self._padding) - op = super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay) + if self._has_bias: + patches = append_homog(patches) - self._patches = None + self._patches.append(patches) - return op + return super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay) - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._outputs_grads[idx]): - outputs_grad = self._outputs_grads[idx] - batch_size = array_ops.shape(self._patches)[0] + def _compute_new_cov(self, source, tower): + patches = self._patches[tower] + batch_size = array_ops.shape(patches)[0] + outputs_grad = self._outputs_grads[source][tower] - new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad) - new_cov /= math_ops.cast(batch_size, new_cov.dtype) + new_cov = self._convdiag_sum_of_squares(patches, outputs_grad) + new_cov /= math_ops.cast(batch_size, new_cov.dtype) - return new_cov + return new_cov def _convdiag_sum_of_squares(self, patches, outputs_grad): # This computes the sum of the squares of the per-training-case "gradients". @@ -1002,8 +1243,11 @@ class ConvDiagonalFactor(DiagonalFactor): outputs_grad) return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0) + def _get_data_device(self, tower): + return self._inputs[tower].device + -class FullyConnectedKroneckerFactor(InverseProvidingFactor): +class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor): """Kronecker factor for the input or output side of a fully-connected layer. """ @@ -1013,8 +1257,9 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): """Instantiate FullyConnectedKroneckerFactor. Args: - tensors: List of Tensors of shape [batch_size, n]. Represents either a - layer's inputs or its output's gradients. + tensors: List of list of Tensors, each of shape [batch_size, n]. The + Tensors are typically either a layer's inputs or its output's gradients. + The first list index is source, the second is tower. has_bias: bool. If True, append '1' to each row. """ # The tensor argument is either a tensor of input activations or a tensor of @@ -1025,31 +1270,37 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): @property def _var_scope(self): - return "ff_fckron/" + scope_string_from_params( - [self._tensors, self._has_bias]) + return "ff_fckron_" + scope_string_from_params( + tuple(nest.flatten(self._tensors)) + (self._has_bias,)) @property def _cov_shape(self): - size = self._tensors[0].shape[1] + self._has_bias + size = self._tensors[0][0].shape[1] + self._has_bias return [size, size] @property def _num_sources(self): return len(self._tensors) + @property + def _num_towers(self): + return len(self._tensors[0]) + @property def _dtype(self): - return self._tensors[0].dtype + return self._tensors[0][0].dtype + + def _compute_new_cov(self, source, tower): + tensor = self._tensors[source][tower] + if self._has_bias: + tensor = append_homog(tensor) + return compute_cov(tensor) - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._tensors[idx]): - tensor = self._tensors[idx] - if self._has_bias: - tensor = append_homog(tensor) - return compute_cov(tensor) + def _get_data_device(self, tower): + return self._tensors[0][tower].device -class ConvInputKroneckerFactor(InverseProvidingFactor): +class ConvInputKroneckerFactor(DenseSquareMatrixFactor): r"""Kronecker factor for the input side of a convolutional layer. Estimates E[ a a^T ] where a is the inputs to a convolutional layer given @@ -1062,39 +1313,69 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): def __init__(self, inputs, filter_shape, - strides, padding, - has_bias=False): + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None, + has_bias=False, + sub_sample_inputs=None, + sub_sample_patches=None): """Initializes ConvInputKroneckerFactor. Args: - inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs - to layer. - filter_shape: 1-D Tensor of length 4. Contains [kernel_height, - kernel_width, in_channels, out_channels]. - strides: 1-D Tensor of length 4. Contains [batch_stride, height_stride, - width_stride, in_channel_stride]. + inputs: List of Tensors of shape [batch_size, ..spatial_input_size.., + in_channels]. Inputs to layer. List index is tower. + filter_shape: List of ints. Contains [..spatial_filter_size.., + in_channels, out_channels]. Shape of convolution kernel. padding: str. Padding method for layer. "SAME" or "VALID". + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". has_bias: bool. If True, append 1 to in_channel. + sub_sample_inputs: `bool`. If True, then subsample the inputs from which + the image patches are extracted. (Default: None) + sub_sample_patches: `bool`, If `True` then subsample the extracted + patches.(Default: None) """ + self._inputs = inputs self._filter_shape = filter_shape self._strides = strides self._padding = padding + self._dilation_rate = dilation_rate + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn self._has_bias = has_bias - self._inputs = inputs + if sub_sample_inputs is None: + self._sub_sample_inputs = _SUB_SAMPLE_INPUTS + else: + self._sub_sample_inputs = sub_sample_inputs + + if sub_sample_patches is None: + self._sub_sample_patches = _SUB_SAMPLE_OUTER_PRODUCTS + else: + self._sub_sample_patches = sub_sample_patches super(ConvInputKroneckerFactor, self).__init__() @property def _var_scope(self): - return "ff_convinkron/" + scope_string_from_params([ - self._inputs, self._filter_shape, self._strides, self._padding, - self._has_bias - ]) + return "ff_convinkron_" + scope_string_from_params( + tuple(self._inputs) + + tuple((self._filter_shape, self._strides, self._padding, + self._dilation_rate, self._data_format, self._has_bias))) @property def _cov_shape(self): - filter_height, filter_width, in_channels, _ = self._filter_shape - size = filter_height * filter_width * in_channels + self._has_bias + spatial_filter_shape = self._filter_shape[0:-2] + in_channels = self._filter_shape[-2] + size = np.prod(spatial_filter_shape) * in_channels + self._has_bias return [size, size] @property @@ -1102,47 +1383,88 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): return 1 @property - def _dtype(self): - return self._inputs.dtype - - def _compute_new_cov(self, idx=0): - if idx != 0: - raise ValueError("ConvInputKroneckerFactor only supports idx = 0") - - with maybe_colocate_with(self._inputs): - filter_height, filter_width, in_channels, _ = self._filter_shape + def _num_towers(self): + return len(self._inputs) - # TODO(b/64144716): there is potential here for a big savings in terms of - # memory use. + @property + def _dtype(self): + return self._inputs[0].dtype + + def _compute_new_cov(self, source, tower): + assert source == 0 + + inputs = self._inputs[tower] + if self._sub_sample_inputs: + batch_size = inputs.shape.as_list()[0] + max_size = int(batch_size * _INPUTS_TO_EXTRACT_PATCHES_FACTOR) + inputs = _random_tensor_gather(inputs, max_size) + + # TODO(b/64144716): there is potential here for a big savings in terms of + # memory use. + if self._extract_patches_fn in [None, "extract_convolution_patches"]: + patches = utils.extract_convolution_patches( + inputs, + self._filter_shape, + padding=self._padding, + strides=self._strides, + dilation_rate=self._dilation_rate, + data_format=self._data_format) + + elif self._extract_patches_fn == "extract_image_patches": + assert inputs.shape.ndims == 4 + assert len(self._filter_shape) == 4 + assert len(self._strides) == 4, self._strides + if self._dilation_rate is None: + rates = [1, 1, 1, 1] + else: + rates = self._dilation_rate + assert len(rates) == 4 + assert rates[0] == rates[-1] == 1 patches = array_ops.extract_image_patches( - self._inputs, - ksizes=[1, filter_height, filter_width, 1], + inputs, + ksizes=[1] + list(self._filter_shape[0:-2]) + [1], strides=self._strides, - rates=[1, 1, 1, 1], + rates=rates, padding=self._padding) - flatten_size = (filter_height * filter_width * in_channels) - # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde - # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14), - # where M = minibatch size, |T| = number of spatial locations, - # |Delta| = number of spatial offsets, and J = number of input maps - # for convolutional layer l. - patches_flat = array_ops.reshape(patches, [-1, flatten_size]) - # We append a homogenous coordinate to patches_flat if the layer has - # bias parameters. This gives us [[A_l]]_H from the paper. - if self._has_bias: - patches_flat = append_homog(patches_flat) - # We call compute_cov without passing in a normalizer. compute_cov uses - # the first dimension of patches_flat i.e. M|T| as the normalizer by - # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with - # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from - # the paper but has a different scale here for consistency with - # ConvOutputKroneckerFactor. - # (Tilde omitted over A for clarity.) - return compute_cov(patches_flat) - - -class ConvOutputKroneckerFactor(InverseProvidingFactor): + elif self._extract_patches_fn == "extract_pointwise_conv2d_patches": + assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)] + assert self._filter_shape[0] == self._filter_shape[1] == 1 + patches = utils.extract_pointwise_conv2d_patches( + inputs, self._filter_shape, data_format=None) + + else: + raise NotImplementedError(self._extract_patches_fn) + + flatten_size = np.prod(self._filter_shape[0:-1]) + # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde + # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14), + # where M = minibatch size, |T| = number of spatial locations, + # |Delta| = number of spatial offsets, and J = number of input maps + # for convolutional layer l. + patches_flat = array_ops.reshape(patches, [-1, flatten_size]) + + # We append a homogenous coordinate to patches_flat if the layer has + # bias parameters. This gives us [[A_l]]_H from the paper. + if self._sub_sample_patches: + patches_flat = _subsample_for_cov_computation(patches_flat) + + if self._has_bias: + patches_flat = append_homog(patches_flat) + # We call compute_cov without passing in a normalizer. compute_cov uses + # the first dimension of patches_flat i.e. M|T| as the normalizer by + # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with + # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from + # the paper but has a different scale here for consistency with + # ConvOutputKroneckerFactor. + # (Tilde omitted over A for clarity.) + return compute_cov(patches_flat) + + def _get_data_device(self, tower): + return self._inputs[tower].device + + +class ConvOutputKroneckerFactor(DenseSquareMatrixFactor): r"""Kronecker factor for the output side of a convolutional layer. Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer @@ -1153,20 +1475,28 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): Section 3.1 Estimating the factors. """ - def __init__(self, outputs_grads): + def __init__(self, outputs_grads, data_format=None): """Initializes ConvOutputKroneckerFactor. Args: - outputs_grads: list of Tensors. Each Tensor is of shape - [batch_size, height, width, out_channels]. + outputs_grads: List of list of Tensors. Each Tensor is of shape + [batch_size, ..spatial_input_size.., out_channels]. First list index + is source, the second is tower. + data_format: None or str. Format of outputs_grads. + + Raises: + ValueError: If channels are not final dimension. """ - self._out_channels = outputs_grads[0].shape.as_list()[3] + if not utils.is_data_format_channel_last(data_format): + raise ValueError("Channel must be last.") + self._out_channels = outputs_grads[0][0].shape.as_list()[-1] self._outputs_grads = outputs_grads super(ConvOutputKroneckerFactor, self).__init__() @property def _var_scope(self): - return "ff_convoutkron/" + scope_string_from_params(self._outputs_grads) + return "ff_convoutkron_" + scope_string_from_params( + nest.flatten(self._outputs_grads)) @property def _cov_shape(self): @@ -1177,134 +1507,150 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return len(self._outputs_grads) + @property + def _num_towers(self): + return len(self._outputs_grads[0]) + @property def _dtype(self): - return self._outputs_grads[0].dtype - - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._outputs_grads[idx]): - # reshaped_tensor below is the matrix DS_l defined in the KFC paper - # (tilde omitted over S for clarity). It has shape M|T| x I, where - # M = minibatch size, |T| = number of spatial locations, and - # I = number of output maps for convolutional layer l. - reshaped_tensor = array_ops.reshape(self._outputs_grads[idx], - [-1, self._out_channels]) - # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov, - # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l - # as defined in the paper, with shape I x I. - # (Tilde omitted over S for clarity.) - return compute_cov(reshaped_tensor) - - -class FullyConnectedMultiKF(InverseProvidingFactor): - """Kronecker factor for a fully connected recurrent layer.""" + return self._outputs_grads[0][0].dtype + + def _compute_new_cov(self, source, tower): + outputs_grad = self._outputs_grads[source][tower] + + # reshaped_tensor below is the matrix DS_l defined in the KFC paper + # (tilde omitted over S for clarity). It has shape M|T| x I, where + # M = minibatch size, |T| = number of spatial locations, and + # I = number of output maps for convolutional layer l. + reshaped_tensor = array_ops.reshape(outputs_grad, [-1, self._out_channels]) + # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov, + # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l + # as defined in the paper, with shape I x I. + # (Tilde omitted over S for clarity.) + return compute_cov(reshaped_tensor) + + def _get_data_device(self, tower): + return self._outputs_grads[0][tower].device + + +class FullyConnectedMultiKF(FullyConnectedKroneckerFactor): + """Kronecker factor for a fully connected layer used multiple times.""" def __init__(self, - tensor_lists, + tensors, + num_uses=None, has_bias=False): """Constructs a new `FullyConnectedMultiKF`. Args: - tensor_lists: List of lists of Tensors of shape [batch_size, n]. + tensors: List of list of Tensors of shape, each of shape + [num_uses * batch_size, n], and is a reshape version of a Tensor of + shape [num_uses, batch_size, n]. Each of these tensors is usually a + layer's inputs or its output's gradients. The first list index is + sources, the second is towers. + num_uses: int. The number of time-steps / uses. has_bias: bool. If True, '1' is appended to each row. """ - self._tensor_lists = tensor_lists - self._has_bias = has_bias - self._batch_size = array_ops.shape(tensor_lists[0][0])[0] - self._num_timesteps = len(tensor_lists[0]) - self._tensors = [None] * len(tensor_lists) + self._num_uses = num_uses self._cov_dt1 = None + self._make_cov_dt1 = False self._option1quants_by_damping = {} self._option2quants_by_damping = {} + self._option1quants_registrations = set() + self._option2quants_registrations = set() - super(FullyConnectedMultiKF, self).__init__() - - @property - def _var_scope(self): - return "ff_fc_multi/" + scope_string_from_params(self._tensor_lists) + super(FullyConnectedMultiKF, self).__init__(tensors=tensors, + has_bias=has_bias) @property - def _num_sources(self): - return len(self._tensor_lists) + def _num_timesteps(self): + return self._num_uses @property - def _dtype(self): - return self._tensor_lists[0][0].dtype + def _var_scope(self): + return "ff_fc_multi_" + scope_string_from_params( + tuple(nest.flatten(self._tensors)) + + (self._num_timesteps, self._has_bias,)) def make_covariance_update_op(self, ema_decay): op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay) if self._cov_dt1 is not None: - new_cov_dt1_contribs = tuple(self._compute_new_cov_dt1(idx) - for idx in range(self._num_sources)) + new_cov_dt1_contribs = [] + for source in range(self._num_sources): + for tower in range(self._num_towers): + with place_on_device(self._get_data_device(tower)): + new_cov_dt1_contribs.append(self._compute_new_cov_dt1(source, + tower)) + + new_cov_dt1 = (math_ops.add_n(new_cov_dt1_contribs) + / float(self._num_towers)) - with maybe_colocate_with(new_cov_dt1_contribs[0]): - new_cov_dt1 = math_ops.add_n(new_cov_dt1_contribs) + # See comments in FisherFactor.make_covariance_update_op() for details. + if utils.on_tpu(): + new_cov_dt1 = utils.cross_replica_mean(new_cov_dt1) - op2 = moving_averages.assign_moving_average( - self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) + op2 = moving_averages.assign_moving_average( + self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) - # TODO(b/69112164): - # It's important that _cov and _cov_dt1 remain consistent with each - # other while the inverse ops are happening. How can we ensure this? - # We will need to add explicit synchronization for this to - # work with asynchronous training. - op = control_flow_ops.group(op, op2) + # TODO(b/69112164): + # It's important that _cov and _cov_dt1 remain consistent with each + # other while the inverse ops are happening. How can we ensure this? + # We will need to add explicit synchronization for this to + # work with asynchronous training. + op = control_flow_ops.group(op, op2) return op - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._tensor_lists[idx]): - tensor = array_ops.concat(self._tensor_lists[idx], 0) - if self._has_bias: - tensor = append_homog(tensor) - # We save these so they can be used by _compute_new_cov_dt1 - self._tensors[idx] = tensor - return compute_cov(tensor) - - def _compute_new_cov_dt1(self, idx=0): - tensor = self._tensors[idx] - with maybe_colocate_with(tensor): - # Is there a more elegant way to do this computation? - tensor_present = tensor[:-self._batch_size, :] - tensor_future = tensor[self._batch_size:, :] - # We specify a normalizer for this computation to ensure a PSD Fisher - # block estimate. This is equivalent to padding with zeros, as was done - # in Section B.2 of the appendix. - normalizer = self._num_timesteps * self._batch_size - return compute_cov( - tensor_future, tensor_right=tensor_present, normalizer=normalizer) + def _compute_new_cov_dt1(self, source, tower): # pylint: disable=missing-docstring + tensor = self._tensors[source][tower] + if self._has_bias: + # This appending is technically done twice (the other time is for + # _compute_new_cov()) + tensor = append_homog(tensor) - @property - def _cov_shape(self): - size = self._tensor_lists[0][0].shape[1] + self._has_bias - return [size, size] + total_len = array_ops.shape(tensor)[0] + batch_size = total_len // self._num_timesteps + + tensor_present = tensor[:-batch_size, :] + tensor_future = tensor[batch_size:, :] + + # We specify a normalizer for this computation to ensure a PSD Fisher + # block estimate. This is equivalent to padding with zeros, as was done + # in Section B.2 of the appendix. + return compute_cov( + tensor_future, tensor_right=tensor_present, normalizer=total_len) + + def _get_data_device(self, tower): + return self._tensors[0][tower].device @property def _vec_shape(self): - size = self._tensor_lists[0][0].shape[1] + self._has_bias + size = self._tensors[0][0].shape[1] + self._has_bias return [size] - def get_option1quants(self, damping): - return self._option1quants_by_damping[damping] + def get_option1quants(self, damping_func): + damping_id = graph_func_to_id(damping_func) + return self._option1quants_by_damping[damping_id] - def get_option2quants(self, damping): - return self._option2quants_by_damping[damping] + def get_option2quants(self, damping_func): + damping_id = graph_func_to_id(damping_func) + return self._option2quants_by_damping[damping_id] def get_cov_dt1(self): assert self._cov_dt1 is not None return self._cov_dt1 def register_cov_dt1(self): - """Create a variable representing temporal cross-covariance. + self._make_cov_dt1 = True - (This is technically the second moment, not covariance, since it's - not mean subtracted.) - """ - if self._cov_dt1 is None: + def instantiate_cov_variables(self): + super(FullyConnectedMultiKF, self).instantiate_cov_variables() + assert self._cov_dt1 is None + if self._make_cov_dt1: with variable_scope.variable_scope(self._var_scope): self._cov_dt1 = variable_scope.get_variable( "cov_dt1", @@ -1313,15 +1659,25 @@ class FullyConnectedMultiKF(InverseProvidingFactor): trainable=False, dtype=self._dtype) - def register_option1quants(self, damping): + def register_option1quants(self, damping_func): + damping_id = self._register_damping(damping_func) + if damping_id not in self._option1quants_registrations: + self._option1quants_registrations.add(damping_id) - self.register_cov_dt1() + def register_option2quants(self, damping_func): + damping_id = self._register_damping(damping_func) + if damping_id not in self._option2quants_registrations: + self._option2quants_registrations.add(damping_id) - if damping not in self._option1quants_by_damping: + def instantiate_inv_variables(self): + super(FullyConnectedMultiKF, self).instantiate_inv_variables() + + for damping_id in self._option1quants_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) # It's questionable as to whether we should initialize with stuff like # this at all. Ideally these values should never be used until they are # updated at least once. - damping_string = scalar_or_tensor_to_string(damping) with variable_scope.variable_scope(self._var_scope): Lmat = variable_scope.get_variable( # pylint: disable=invalid-name "Lmat_damp{}".format(damping_string), @@ -1336,17 +1692,15 @@ class FullyConnectedMultiKF(InverseProvidingFactor): trainable=False, dtype=self._dtype) - self._option1quants_by_damping[damping] = (Lmat, psi) - - def register_option2quants(self, damping): + assert damping_id not in self._option1quants_by_damping + self._option1quants_by_damping[damping_id] = (Lmat, psi) - self.register_cov_dt1() - - if damping not in self._option2quants_by_damping: + for damping_id in self._option2quants_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) # It's questionable as to whether we should initialize with stuff like # this at all. Ideally these values should never be used until they are # updated at least once. - damping_string = scalar_or_tensor_to_string(damping) with variable_scope.variable_scope(self._var_scope): Pmat = variable_scope.get_variable( # pylint: disable=invalid-name "Lmat_damp{}".format(damping_string), @@ -1367,14 +1721,15 @@ class FullyConnectedMultiKF(InverseProvidingFactor): trainable=False, dtype=self._dtype) - self._option2quants_by_damping[damping] = (Pmat, Kmat, mu) + assert damping_id not in self._option2quants_by_damping + self._option2quants_by_damping[damping_id] = (Pmat, Kmat, mu) def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" # TODO(b/69918258): Add correctness tests for this method. # pylint: disable=invalid-name - ops = super(FullyConnectedMultiKF, self).make_inverse_update_ops() + ops = [] if (len(self._option1quants_by_damping) + len(self._option2quants_by_damping)): @@ -1395,8 +1750,11 @@ class FullyConnectedMultiKF(InverseProvidingFactor): # consistently, or are somehow read between or during the cov updates. # Can this possibly happen? Is there a way to prevent it? - for damping, (Lmat_var, - psi_var) in self._option1quants_by_damping.items(): + for damping_id, (Lmat_var, + psi_var) in self._option1quants_by_damping.items(): + + damping = self._damping_funcs_by_id[damping_id]() + damping = math_ops.cast(damping, self._dtype) invsqrtC0 = math_ops.matmul( eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) @@ -1421,8 +1779,11 @@ class FullyConnectedMultiKF(InverseProvidingFactor): ops.append(Lmat_var.assign(Lmat)) ops.append(psi_var.assign(psi)) - for damping, (Pmat_var, Kmat_var, - mu_var) in self._option2quants_by_damping.items(): + for damping_id, (Pmat_var, Kmat_var, + mu_var) in self._option2quants_by_damping.items(): + + damping = self._damping_funcs_by_id[damping_id]() + damping = math_ops.cast(damping, self._dtype) # compute C0^(-1/2) invsqrtC0 = math_ops.matmul( @@ -1463,6 +1824,7 @@ class FullyConnectedMultiKF(InverseProvidingFactor): ops.append(Kmat_var.assign(Kmat)) ops.append(mu_var.assign(mu)) + ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops() return [control_flow_ops.group(*ops)] # pylint: enable=invalid-name diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index ce9005b9ce99a4efa5f2821c56e199dd2086482e..cbbfe7212c9d946d4b5bf3690796cb248f72e8d3 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -26,7 +26,9 @@ from __future__ import print_function from collections import defaultdict from collections import OrderedDict +from contextlib import contextmanager from functools import partial +import warnings import math import six @@ -59,6 +61,10 @@ _CONV2D_APPROX_TO_BLOCK_TYPES = { APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, } +_EMBEDDING_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.EmbeddingKFACFB +} + APPROX_KRONECKER_INDEP_NAME = "kron_indep" APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1" APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2" @@ -71,10 +77,39 @@ _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = { option=2) } -# Possible value for 'reuse' keyword argument. Sets 'reuse' to +_CONV2D_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB +} + +_EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.EmbeddingKFACMultiIndepFB +} + +# Possible value for `reuse` keyword argument. Sets `reuse` to # tf.get_variable_scope().reuse. VARIABLE_SCOPE = "VARIABLE_SCOPE" +_DEFAULT_LAYER_COLLECTION = None + + +def get_default_layer_collection(): + """Get default LayerCollection.""" + if _DEFAULT_LAYER_COLLECTION is None: + raise ValueError( + "Attempted to retrieve default LayerCollection when none is set. Use " + "LayerCollection.as_default().") + + return _DEFAULT_LAYER_COLLECTION + + +def set_default_layer_collection(layer_collection): + global _DEFAULT_LAYER_COLLECTION + + if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None: + raise ValueError("Default LayerCollection is already set.") + + _DEFAULT_LAYER_COLLECTION = layer_collection + class LayerParametersDict(OrderedDict): """An OrderedDict where keys are Tensors or tuples of Tensors. @@ -130,11 +165,16 @@ class LayerCollection(object): fisher_factors: an OrderedDict mapping tuples to FisherFactor instances. losses: a list of LossFunction objects. The loss to be optimized is their sum. + loss_colocation_ops: ops to colocate loss function evaluations with. These + will typically be the inputs to the losses. """ def __init__(self, graph=None, name="LayerCollection"): + warnings.warn( + "tf.contrib.kfac is deprecated and will be removed by 2018-11-01. " + "Use https://pypi.python.org/pypi/kfac instead.") self.fisher_blocks = LayerParametersDict() self.fisher_factors = OrderedDict() self._linked_parameters = dict( @@ -142,20 +182,30 @@ class LayerCollection(object): self._graph = graph or ops.get_default_graph() self._loss_dict = {} # {str: LossFunction} self._subgraph = None - self._default_generic_approximation = APPROX_FULL_NAME + self._default_generic_approximation = APPROX_DIAGONAL_NAME self._default_embedding_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_approximation = APPROX_KRONECKER_NAME - self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME + self._default_conv2d_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_multi_approximation = ( - APPROX_KRONECKER_SERIES_2_NAME) + APPROX_KRONECKER_INDEP_NAME) + self._default_conv2d_multi_approximation = ( + APPROX_KRONECKER_INDEP_NAME) + self._default_embedding_multi_approximation = APPROX_KRONECKER_INDEP_NAME + self.loss_colocation_ops = {} + self._vars_to_uses = defaultdict(lambda: 0) with variable_scope.variable_scope(None, default_name=name) as scope: self._var_scope = scope.name @property def losses(self): - """LossFunctions registered with this LayerCollection.""" - return list(self._loss_dict.values()) + """Tuple of LossFunction objects registered with this LayerCollection.""" + return nest.flatten(self.towers_by_loss) + + @property + def towers_by_loss(self): + """Tuple across losses of LossFunction objects registered to each tower.""" + return tuple(tuple(lst) for lst in self._loss_dict.values()) @property def registered_variables(self): @@ -214,14 +264,14 @@ class LayerCollection(object): @property def default_conv2d_approximation(self): - return self._default_convolution_2d_approximation + return self._default_conv2d_approximation def set_default_conv2d_approximation(self, value): if value not in _CONV2D_APPROX_TO_BLOCK_TYPES: raise ValueError( "{} is not a valid approximation for 2d convolutional layers.".format( value)) - self._default_convolution_2d_approximation = value + self._default_conv2d_approximation = value @property def default_fully_connected_multi_approximation(self): @@ -233,6 +283,14 @@ class LayerCollection(object): "multi layer.".format(value)) self._default_fully_connected_multi_approximation = value + @property + def default_conv2d_multi_approximation(self): + return self._default_conv2d_multi_approximation + + @property + def default_embedding_multi_approximation(self): + return self._default_embedding_multi_approximation + def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): """Validates and registers the layer_key associated with the fisher_block. @@ -240,8 +298,8 @@ class LayerCollection(object): layer_key: A variable or tuple of variables. The key to check for in existing registrations and to register if valid. fisher_block: The associated `FisherBlock`. - reuse: Method to use for inserting new `FisherBlock`s. One of True, False, - or 'VARIABLE_SCOPE'. + reuse: Method to use for inserting new `FisherBlock's. One of True, False, + or `VARIABLE_SCOPE`. Raises: ValueError: If `layer_key` was already registered and reuse is `False`, @@ -290,23 +348,73 @@ class LayerCollection(object): self.fisher_blocks[layer_key] = fisher_block return fisher_block - def get_use_count_map(self): - """Returns a dict of variables to their number of registrations.""" - # TODO(b/70283403): Reimplement this in the old way, where each - # registration function would be responsible for incrementing the count. - # Also, this version has a bug: it won't do the right thing for generic - # registration for parameters that are shared. i.e. it won't set the use - # count to infinity. - vars_to_uses = defaultdict(int) - for key, block in six.iteritems(self.fisher_blocks): - n = ( - block.num_inputs()*block.num_registered_minibatches if isinstance( - block, (fb.FullyConnectedSeriesFB, fb.FullyConnectedMultiIndepFB)) - else block.num_registered_minibatches) - key = utils.ensure_sequence(key) - for k in key: - vars_to_uses[k] += n - return vars_to_uses + def register_loss_function(self, + loss, + colocation_op, + base_name, + name=None, + reuse=VARIABLE_SCOPE): + """Registers a LossFunction object. + + Args: + loss: The LossFunction object. + colocation_op: The op to colocate the loss function's computations with. + base_name: The name to derive a new unique name from is the name argument + is None. + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) + reuse: (OPTIONAL) bool or str. If True, adds `loss` as an additional + tower for the existing loss function. + + Raises: + ValueError: If reuse == True and name == None. + ValueError: If reuse == True and seed != None. + KeyError: If reuse == True and no existing LossFunction with `name` found. + KeyError: If reuse == False and existing LossFunction with `name` found. + """ + + name = name or self._graph.unique_name(base_name) + + if reuse == VARIABLE_SCOPE: + reuse = variable_scope.get_variable_scope().reuse + + if reuse: + if name is None: + raise ValueError( + "If reuse is enabled, loss function's name must be set.") + + loss_list = self._loss_dict.get(name, None) + + if loss_list is None: + raise KeyError( + "Unable to find loss function named {}. Register a new loss " + "function with reuse=False.".format(name)) + else: + if name in self._loss_dict: + raise KeyError( + "Loss function named {} already exists. Set reuse=True to append " + "another tower.".format(name)) + + loss_list = [] + self._loss_dict[name] = loss_list + + loss_list.append(loss) + self.loss_colocation_ops[loss] = colocation_op + + def _get_use_count_map(self): + """Returns a dict mapping variables to their number of registrations.""" + return self._vars_to_uses + + def _add_uses(self, params, uses): + """Register additional uses by params in the graph. + + Args: + params: Variable or tuple of Variables. Parameters for a layer. + uses: int or float. Number of additional uses for these parameters. + """ + params = params if isinstance(params, (tuple, list)) else (params,) + for var in params: + self._vars_to_uses[var] += uses def check_registration(self, variables): """Checks that all variable uses have been registered properly. @@ -324,7 +432,7 @@ class LayerCollection(object): # Note that overlapping parameters (i.e. those that share variables) will # be caught by layer_collection.LayerParametersDict during registration. - reg_use_map = self.get_use_count_map() + reg_use_map = self._get_use_count_map() error_messages = [] @@ -386,24 +494,24 @@ class LayerCollection(object): """ params = frozenset(utils.ensure_sequence(params)) - # Check if any of the variables in 'params' is already in - # 'self.fisher_blocks.keys()'. + # Check if any of the variables in `params` is already in + # 'self.fisher_blocks.keys()`. for registered_params, fisher_block in self.fisher_blocks.items(): registered_params_set = set(utils.ensure_sequence(registered_params)) for variable in params: if (variable in registered_params_set and params != registered_params_set): raise ValueError( - "Can't link parameters {}, variable {} was already registered in " + "Can`t link parameters {}, variable {} was already registered in " "group {} with layer {}".format(params, variable, registered_params, fisher_block)) - # Check if any of the variables in 'params' is already in - # 'self.linked_parameters'. + # Check if any of the variables in `params` is already in + # 'self.linked_parameters`. for variable in params: for other_linked_params in self.linked_parameters: if variable in other_linked_params: - raise ValueError("Can't link parameters {}, variable {} was already " + raise ValueError("Can`t link parameters {}, variable {} was already " "linked in group {}.".format(params, variable, other_linked_params)) self._linked_parameters[params] = approximation @@ -414,12 +522,27 @@ class LayerCollection(object): inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses)) self._subgraph = utils.SubGraph(inputs_to_losses) + def eval_losses(self): + """Return evaluated losses (colocated with inputs to losses).""" + evals = [] + for loss in self.losses: + with ops.colocate_with(self.loss_colocation_ops[loss]): + evals.append(loss.evaluate()) + return evals + + def eval_losses_on_samples(self): + """Return losses evaluated on samples (colocated with inputs to losses).""" + evals = [] + for loss in self.losses: + with ops.colocate_with(self.loss_colocation_ops[loss]): + evals.append(loss.evaluate_on_sample()) + return evals + def total_loss(self): - return math_ops.add_n(tuple(loss.evaluate() for loss in self.losses)) + return math_ops.add_n(self.eval_losses()) def total_sampled_loss(self): - return math_ops.add_n( - tuple(loss.evaluate_on_sample() for loss in self.losses)) + return math_ops.add_n(self.eval_losses_on_samples()) def _get_linked_approx(self, params): """If params were linked, return their specified approximation.""" @@ -429,45 +552,56 @@ class LayerCollection(object): else: return None + def _get_block_type(self, params, approx, default, approx_to_type): + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = default + + if approx not in approx_to_type: + raise ValueError("Bad value {} for approx.".format(approx)) + + return approx_to_type[approx], approx + def register_embedding(self, params, inputs, outputs, approx=None, reuse=VARIABLE_SCOPE): - """Registers a fully connnected layer. + """Registers an embedding layer. Args: params: Embedding matrix of shape [vocab_size, embedding_size]. inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices into embedding matrix. - outputs: Tensor of shape [batch_size, output_size]. Outputs + outputs: Tensor of shape [batch_size, embedding_size]. Outputs produced by layer. - approx: str. Must be "kron". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + approx: str or None. If not None must be "kron". The Fisher + approximation to use. If None the default value is used. (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") Raises: - ValueError: For improper value to 'approx'. - KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_embedding_approximation - - if approx != APPROX_KRONECKER_NAME: - raise ValueError("Bad value {} for approx.".format(approx)) + block_type, approx = self._get_block_type( + params, approx, self.default_embedding_approximation, + _EMBEDDING_APPROX_TO_BLOCK_TYPES) if isinstance(params, (tuple, list)): raise ValueError("Bias not supported.") - vocab_size = int(params.shape[0]) block = self.register_block( - params, fb.EmbeddingKFACFB(self, vocab_size), reuse=reuse) - block.register_additional_minibatch(inputs, outputs) + params, block_type(self, vocab_size), reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) def register_fully_connected(self, params, @@ -484,29 +618,31 @@ class LayerCollection(object): inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. outputs: Tensor of shape [batch_size, output_size]. Outputs produced by layer. - approx: str. One of "kron" or "diagonal". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") Raises: - ValueError: For improper value to 'approx'. - KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_fully_connected_approximation - if approx not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) + block_type, approx = self._get_block_type( + params, approx, self.default_fully_connected_approximation, + _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES) - block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx] has_bias = isinstance(params, (tuple, list)) + block = self.register_block(params, block_type(self, has_bias=has_bias), + reuse=reuse) + block.register_additional_tower(inputs, outputs) - block = self.register_block(params, block_type(self, has_bias), reuse=reuse) - block.register_additional_minibatch(inputs, outputs) + self._add_uses(params, 1) def register_conv2d(self, params, @@ -514,44 +650,262 @@ class LayerCollection(object): padding, inputs, outputs, + data_format=None, + dilations=None, approx=None, reuse=VARIABLE_SCOPE): - """Registers a convolutional layer. + """Registers a call to tf.nn.conv2d(). Args: params: Tensor or 2-tuple of Tensors corresponding to weight and bias of this layer. Weight matrix should have shape [kernel_height, kernel_width, in_channels, out_channels]. Bias should have shape [out_channels]. - strides: 1-D Tensor of length 4. Strides for convolution kernel. + strides: List of 4 ints. Strides for convolution kernel. padding: string. see tf.nn.conv2d for valid values. inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs to layer. outputs: Tensor of shape [batch_size, height, width, out_channels]. Output produced by layer. - approx: str. One of "kron" or "diagonal". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + data_format: str or None. Format of data. + dilations: List of 4 ints. Dilations along each dimension. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") Raises: - ValueError: For improper value to 'approx'. - KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_conv2d_approximation + block_type, approx = self._get_block_type( + params, approx, self.default_conv2d_approximation, + _CONV2D_APPROX_TO_BLOCK_TYPES) + + # It feels bad to pass in configuration that has to do with the internal + # implementation. And then we can`t use the same constructor for both + # anymore and are thus forced to use this ugly if-statement. + # TODO(b/74793309): Clean this up? + if approx == APPROX_KRONECKER_NAME: + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + data_format=data_format, + dilation_rate=dilations, + extract_patches_fn="extract_image_patches"), + reuse=reuse) + elif approx == APPROX_DIAGONAL_NAME: + assert strides[0] == strides[-1] == 1 + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + dilations=dilations, + data_format=data_format), + reuse=reuse) + else: + raise NotImplementedError(approx) - if approx not in _CONV2D_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_convolution(self, + params, + inputs, + outputs, + padding, + strides=None, + dilation_rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.convolution(). + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [..filter_spatial_size.., + in_channels, out_channels]. Bias should have shape [out_channels]. + inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels]. + Inputs to layer. + outputs: Tensor of shape [batch_size, ..output_spatial_size.., + out_channels]. Output produced by layer. + padding: string. see tf.nn.conv2d for valid values. + strides: List of ints of length len(..input_spatial_size..). Strides for + convolution kernel in spatial dimensions. + dilation_rate: List of ints of length len(..input_spatial_size..). + Dilations along spatial dimension. + data_format: str or None. Format of data. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + # TODO(b/74793309): Have this use _get_block_type like the other + # registration functions? + assert approx is None or approx == APPROX_KRONECKER_NAME - block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx] block = self.register_block( - params, block_type(self, params, strides, padding), reuse=reuse) - block.register_additional_minibatch(inputs, outputs) + params, + fb.ConvKFCBasicFB( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + dilation_rate=dilation_rate, + data_format=data_format), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_depthwise_conv2d(self, + params, + inputs, + outputs, + strides, + padding, + rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.depthwise_conv2d(). + + Args: + params: 4-D Tensor of shape [filter_height, filter_width, + in_channels, channel_multiplier]. Convolutional filter. + inputs: Tensor of shape [batch_size, input_height, input_width, + in_channels]. Inputs to layer. + outputs: Tensor of shape [batch_size, output_height, output_width, + in_channels * channel_multiplier]. Output produced by depthwise conv2d. + strides: List of ints of length 4. Strides along all dimensions. + padding: string. see tf.nn.conv2d for valid values. + rate: None or List of ints of length 2. Dilation rates in spatial + dimensions. + data_format: str or None. Format of data. + approx: str or None. If not None must "diagonal". The Fisher + approximation to use. If None the default value is used. (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + # TODO(b/74793309): Have this use _get_block_type like the other + # registration functions? + assert approx is None or approx == APPROX_DIAGONAL_NAME + assert data_format in [None, "NHWC"] + + block = self.register_block( + params, + fb.DepthwiseConvDiagonalFB( + layer_collection=self, + params=params, + strides=strides, + padding=padding, + rate=rate, + data_format=data_format), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_separable_conv2d(self, + depthwise_params, + pointwise_params, + inputs, + depthwise_outputs, + pointwise_outputs, + strides, + padding, + rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.separable_conv2d(). + + Note: This requires access to intermediate outputs between depthwise and + pointwise convolutions. + + Args: + depthwise_params: 4-D Tensor of shape [filter_height, filter_width, + in_channels, channel_multiplier]. Filter for depthwise conv2d. + pointwise_params: 4-D Tensor of shape [1, 1, in_channels * + channel_multiplier, out_channels]. Filter for pointwise conv2d. + inputs: Tensor of shape [batch_size, input_height, input_width, + in_channels]. Inputs to layer. + depthwise_outputs: Tensor of shape [batch_size, output_height, + output_width, in_channels * channel_multiplier]. Output produced by + depthwise conv2d. + pointwise_outputs: Tensor of shape [batch_size, output_height, + output_width, out_channels]. Output produced by pointwise conv2d. + strides: List of ints of length 4. Strides for depthwise conv2d kernel in + all dimensions. + padding: string. see tf.nn.conv2d for valid values. + rate: None or List of ints of length 2. Dilation rate of depthwise conv2d + kernel in spatial dimensions. + data_format: str or None. Format of data. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + self.register_depthwise_conv2d( + params=depthwise_params, + inputs=inputs, + outputs=depthwise_outputs, + strides=strides, + padding=padding, + rate=rate, + data_format=data_format, + approx=APPROX_DIAGONAL_NAME, + reuse=reuse) + + self.register_conv2d( + params=pointwise_params, + inputs=depthwise_outputs, + outputs=pointwise_outputs, + strides=[1, 1, 1, 1], + padding="VALID", + data_format=data_format, + approx=approx, + reuse=reuse) def register_generic(self, params, @@ -562,32 +916,32 @@ class LayerCollection(object): Args: params: Tensor or tuple of Tensors corresponding to the parameters. - batch_size: 0-D Tensor. Size of the minibatch. - approx: str. One of "full" or "diagonal". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + batch_size: 0-D Tensor. Size of the minibatch (for this tower). + approx: str or None. It not None, must be one of "full" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `batch_size` to the total + mini-batch size use when estimating the Fisher block for this layer + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") Raises: - ValueError: For improper value to 'approx'. - KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ + block_type, approx = self._get_block_type( + params, approx, self.default_generic_approximation, + _GENERIC_APPROX_TO_BLOCK_TYPES) - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_generic_approximation - - if approx not in _GENERIC_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) - - block_type = _GENERIC_APPROX_TO_BLOCK_TYPES[approx] block = self.register_block(params, block_type(self, params), reuse=reuse) - block.register_additional_minibatch(batch_size) + block.register_additional_tower(batch_size) + + self._add_uses(params, float("inf")) def register_fully_connected_multi(self, params, inputs, outputs, - approx=None): + num_uses=None, approx=None, + reuse=VARIABLE_SCOPE): """Register fully connected layers with shared parameters. This can handle general fully-connected layers with shared parameters, but @@ -598,34 +952,194 @@ class LayerCollection(object): params: Tensor or 2-tuple of Tensors corresponding to weight and bias of this layer. Weight matrix should have shape [input_size, output_size]. Bias should have shape [output_size]. - inputs: A list of tensors, each of shape [batch_size, input_size]. Inputs - to layer. In the case of RNNs, one Tensor per time step. - outputs: A list of tensors, the same length as 'inputs', each of shape - [batch_size, output_size]. Outputs produced by layer. In the case of - RNNs, one Tensor per time step. - approx: str. One of "kron_indep", "kron_series_1", or "kron_series_2". + inputs: A list of Tensors, each of shape [batch_size, input_size]. Inputs + to layer. The list indexes each use in the graph (which might + correspond to a "time-step" in an RNN). OR, can be single Tensor, of + shape [num_uses * batch_size , input_size], which is a reshaped version + of a Tensor of shape [num_uses, batch_size, input_size]. + outputs: A list of Tensors, the same length as `inputs`, each of shape + [batch_size, output_size]. Outputs produced by layer. The list indexes + each use in the graph (which might correspond to a "time-step" in an + RNN). Needs to correspond with the order used in `inputs`. OR, can be + a single Tensor of shape [num_uses * batch_size, output_size], which is + a reshaped version of a Tensor of shape [num_uses, batch_size, + output_size]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + approx: str or None. If not None, must be of "kron_indep", "kron_series_1" + or "kron_series_2". The Fisher approximation to use. If None the default + value is used. (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word `use` here has a completely different meaning to "use in the graph" + as it perturns to the `inputs`, `outputs`, and `num_uses` arguments.) + (Default: "VARIABLE_SCOPE") Raises: - ValueError: For improper value to 'approx'. + ValueError: For improper value to `approx`. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_fully_connected_multi_approximation - has_bias = isinstance(params, (tuple, list)) + block_type, approx = self._get_block_type( + params, approx, self.default_fully_connected_multi_approximation, + _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES) # TODO(b/70283649): something along the lines of find_canonical_output # should be added back in here (and for the other block types, arguably). - if approx not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) - block_type = _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES[approx] + has_bias = isinstance(params, (tuple, list)) + block = self.register_block(params, block_type(self, has_bias=has_bias, + num_uses=num_uses), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + if isinstance(inputs, (tuple, list)): + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) + + def register_conv2d_multi(self, + params, + strides, + padding, + inputs, + outputs, + num_uses=None, + data_format=None, + dilations=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers convolutional layers with shared parameters. + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [kernel_height, + kernel_width, in_channels, out_channels]. Bias should have shape + [out_channels]. + strides: 1-D Tensor of length 4. Strides for convolution kernel. + padding: string. see tf.nn.conv2d for valid values. + inputs: A list of Tensors, each of shape [batch_size, height, width, + in_channels]. Inputs to layer. The list indexes each use in the graph + (which might correspond to a "time-step" in an RNN). OR, can be single + Tensor, of shape [num_uses * batch_size, height, width, in_channels], + which is a reshaped version of a Tensor of shape [num_uses, batch_size, + height, width, in_channels]. + outputs: A list of Tensors, each of shape [batch_size, height, width, + out_channels]. Output produced by layer. The list indexes each use + in the graph (which might correspond to a "time-step" in an RNN). + Needs to correspond with the order used in `inputs`. OR, can be a + single Tensor, of shape [num_uses * batch_size, height, width, + out_channels], which is a reshaped version of a Tensor of shape + [num_uses, batch_size, height, width, out_channels]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + data_format: str or None. Format of data. + dilations: List of 4 ints. Dilations along each dimension. + approx: str or None. If not None must by "kron_indep". The Fisher + approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word `use` here has a completely different meaning to "use in the graph" + as it perturns to the `inputs`, `outputs`, and `num_uses` arguments.) + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_conv2d_multi_approximation, + _CONV2D_MULTI_APPROX_TO_BLOCK_TYPES) + + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + data_format=data_format, + dilation_rate=dilations, + extract_patches_fn="extract_image_patches", + num_uses=num_uses), + reuse=reuse) + + block.register_additional_tower(inputs, outputs) + if isinstance(inputs, (tuple, list)): + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) + + # TODO(b/74108452): change the loss registration functions names to refer + # to "loss functions" instead of distributions. Following naming convention + # of the loss function classes themselves. + + def register_embedding_multi(self, + params, + inputs, + outputs, + num_uses=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers embedding layers with shared parameters. + + Args: + params: Embedding matrix of shape [vocab_size, embedding_size]. + inputs: A list of Tensors, each of shape [batch_size, input_size] and + dtype int32. Indices into embedding matrix. The list indexes each use + in the graph (which might correspond to a "time-step" in an RNN). + OR, can be single Tensor, of shape [num_uses*batch_size, input_size], + which is a reshaped version of a Tensor of shape [num_uses, batch_size, + input_size]. + outputs: A list of Tensors, each of shape [batch_size, embedding_size]. + Outputs produced by layer. The list indexes each use in the graph + (which might correspond to a "time-step" in an RNN). Needs to + correspond with the order used in `inputs`. OR, can be a + single Tensor, of shape [num_uses * batch_size, embedding_size], which + is a reshaped version of a Tensor of shape [num_uses, batch_size, + embedding_size]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + approx: str or None. If not None must by "kron_indep". The Fisher + approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word `use` here has a completely different meaning to "use in the graph" + as it perturns to the `inputs`, `outputs`, and `num_uses` arguments.) + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_embedding_multi_approximation, + _EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES) + + if isinstance(params, (tuple, list)): + raise ValueError("Bias not supported.") + vocab_size = int(params.shape[0]) + + block = self.register_block( + params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse) + block.register_additional_tower(inputs, outputs) - # For now we don't support multiple minibatches for this type of layer, so - # we set reuse=False - self.register_block(params, - block_type(self, inputs, outputs, has_bias=has_bias), - reuse=False) + if isinstance(inputs, (tuple, list)): + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) def register_categorical_predictive_distribution(self, logits, @@ -645,53 +1159,24 @@ class LayerCollection(object): (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) - reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock. - If False, create a new FisherBlock. If VARIABLE_SCOPE, use - tf.get_variable_scope().reuse. - - Raises: - ValueError: If reuse == True and name == None. - ValueError: If reuse == True and seed != None. - KeyError: If reuse == True and no existing LossFunction with 'name' found. - KeyError: If reuse == False and existing LossFunction with 'name' found. + reuse: bool or str. If True, this adds `logits` as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") """ - name = name or self._graph.unique_name( - "register_categorical_predictive_distribution") - - if reuse == VARIABLE_SCOPE: - reuse = variable_scope.get_variable_scope().reuse - - if reuse: - if name is None: - raise ValueError( - "If reuse is enabled, loss function's name must be set.") - if seed is not None: - raise ValueError( - "Seed can only be specified at LossFunction instantiation.") - - loss = self._loss_dict.get(name, None) - - if loss is None: - raise KeyError( - "Unable to find loss function named {}. Create a new LossFunction " - "with reuse=False.".format(name)) - - loss.register_additional_minibatch(logits, targets=targets) - else: - if name in self._loss_dict: - raise KeyError( - "Loss function named {} already exists. Set reuse=True to append " - "another minibatch.".format(name)) - loss = lf.CategoricalLogitsNegativeLogProbLoss( - logits, targets=targets, seed=seed) - self._loss_dict[name] = loss + loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets, + seed=seed) + self.register_loss_function(loss, logits, + "categorical_predictive_distribution", + name=name, reuse=reuse) def register_normal_predictive_distribution(self, mean, var=0.5, seed=None, targets=None, - name=None): + name=None, + reuse=VARIABLE_SCOPE): """Registers a normal predictive distribution. Args: @@ -708,21 +1193,23 @@ class LayerCollection(object): (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) + reuse: bool or str. If True, this adds `mean` and `var` as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") """ - name = name or self._graph.unique_name( - "register_normal_predictive_distribution") - if name in self._loss_dict: - raise NotImplementedError( - "Adding logits to an existing LossFunction not yet supported.") - loss = lf.NormalMeanNegativeLogProbLoss( - mean, var, targets=targets, seed=seed) - self._loss_dict[name] = loss + loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets, + seed=seed) + self.register_loss_function(loss, mean, + "normal_predictive_distribution", + name=name, reuse=reuse) def register_multi_bernoulli_predictive_distribution(self, logits, seed=None, targets=None, - name=None): + name=None, + reuse=VARIABLE_SCOPE): """Registers a multi-Bernoulli predictive distribution. Args: @@ -735,29 +1222,30 @@ class LayerCollection(object): (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) + reuse: bool or str. If True, this adds `logits` as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") """ - name = name or self._graph.unique_name( - "register_multi_bernoulli_predictive_distribution") - if name in self._loss_dict: - raise NotImplementedError( - "Adding logits to an existing LossFunction not yet supported.") - loss = lf.MultiBernoulliNegativeLogProbLoss( - logits, targets=targets, seed=seed) - self._loss_dict[name] = loss + loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets, + seed=seed) + self.register_loss_function(loss, logits, + "multi_bernoulli_predictive_distribution", + name=name, reuse=reuse) def make_or_get_factor(self, cls, args): - """Insert 'cls(args)' into 'self.fisher_factors' if not already present. + """Insert `cls(args)` into 'self.fisher_factors` if not already present. - Wraps constructor in 'tf.variable_scope()' to ensure variables constructed - in 'cls.__init__' are placed under this LayerCollection's scope. + Wraps constructor in `tf.variable_scope()` to ensure variables constructed + in `cls.__init__` are placed under this LayerCollection's scope. Args: cls: Class that implements FisherFactor. - args: Tuple of arguments to pass into 'cls's constructor. Must be + args: Tuple of arguments to pass into `cls's constructor. Must be hashable. Returns: - Instance of 'cls' found in self.fisher_factors. + Instance of `cls` found in self.fisher_factors. """ try: hash(args) @@ -772,3 +1260,10 @@ class LayerCollection(object): with variable_scope.variable_scope(self._var_scope): self.fisher_factors[key] = cls(*args) return self.fisher_factors[key] + + @contextmanager + def as_default(self): + """Sets this LayerCollection as the default.""" + set_default_layer_collection(self) + yield + set_default_layer_collection(None) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py index f8aa230d9ca1f542950f56b1e6cf1ab7ccd3d05f..9f4685380705bd409dbcd7e85d0e3bb4189a6adc 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py @@ -30,6 +30,8 @@ from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import _allowed_symbols = [ + "get_default_layer_collection", + "set_default_layer_collection", "LayerParametersDict", "LayerCollection", "APPROX_KRONECKER_NAME", diff --git a/tensorflow/contrib/kfac/python/ops/linear_operator.py b/tensorflow/contrib/kfac/python/ops/linear_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..61cb955ae85df9e56cbe165acba98ece750cba90 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/linear_operator.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SmartMatrices definitions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.linalg import linalg +from tensorflow.python.ops.linalg import linalg_impl +from tensorflow.python.ops.linalg import linear_operator_util as lou + + +class LinearOperatorExtras(object): # pylint: disable=missing-docstring + + def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): + + with self._name_scope(name, values=[x]): + if isinstance(x, ops.IndexedSlices): + return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + x = ops.convert_to_tensor(x, name="x") + self._check_input_dtype(x) + + self_dim = -2 if adjoint else -1 + arg_dim = -1 if adjoint_arg else -2 + self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) + + return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"): + + with self._name_scope(name, values=[x]): + + if isinstance(x, ops.IndexedSlices): + return self._matmul_right_sparse( + x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + x = ops.convert_to_tensor(x, name="x") + self._check_input_dtype(x) + + self_dim = -1 if adjoint else -2 + arg_dim = -2 if adjoint_arg else -1 + self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) + + return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + +class LinearOperatorFullMatrix(LinearOperatorExtras, + linalg.LinearOperatorFullMatrix): + + # TODO(b/78117889) Remove this definition once core LinearOperator + # has _matmul_right. + def _matmul_right(self, x, adjoint=False, adjoint_arg=False): + return lou.matmul_with_broadcast( + x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint) + + def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False): + raise NotImplementedError + + def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False): + assert not adjoint and not adjoint_arg + return utils.matmul_sparse_dense(x, self._matrix) + + +class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring + linalg.LinearOperatorDiag): + + def _matmul_right(self, x, adjoint=False, adjoint_arg=False): + diag_mat = math_ops.conj(self._diag) if adjoint else self._diag + x = linalg_impl.adjoint(x) if adjoint_arg else x + return diag_mat * x + + def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False): + diag_mat = math_ops.conj(self._diag) if adjoint else self._diag + assert not adjoint_arg + return utils.matmul_diag_sparse(diag_mat, x) + + def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False): + raise NotImplementedError diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py index cb3e698b9ceab920785adf735f88bd8e535a628f..42d525c2c21f5ba3457cba041261dc3b225dc11e 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -57,30 +57,6 @@ class LossFunction(object): """The inputs to the loss function (excluding the targets).""" pass - @property - def input_minibatches(self): - """A `list` of inputs to the loss function, separated by minibatch. - - Typically there will be one minibatch per tower in a multi-tower setup. - Returns a list consisting of `self.inputs` by default; `LossFunction`s - supporting registering multiple minibatches should override this method. - - Returns: - A `list` of `Tensor`s representing - """ - return [self.inputs] - - @property - def num_registered_minibatches(self): - """Number of minibatches registered for this LossFunction. - - Typically equal to the number of towers in a multi-tower setup. - - Returns: - An `int` representing the number of registered minibatches. - """ - return len(self.input_minibatches) - def evaluate(self): """Evaluate the loss function on the targets.""" if self.targets is not None: @@ -474,7 +450,6 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): assert len(variance.shape) == 2, "Expect 2D variance tensor." self._mean = mean self._variance = variance - self._scale = math_ops.sqrt(variance) self._targets = targets super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed) @@ -484,7 +459,7 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): @property def dist(self): - return normal.Normal(loc=self._mean, scale=self._scale) + return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._variance)) @property def params(self): @@ -502,7 +477,7 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): @property def _fisher_mean_factor(self): - return 1. / self._scale + return 1. / math_ops.sqrt(self._variance) @property def _fisher_var(self): @@ -611,36 +586,13 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, index in [0, output_size). seed: int or None. Default random seed when sampling. """ - self._logits_components = [] - self._targets_components = [] - self.register_additional_minibatch(logits, targets=targets) + self._logits = logits + self._targets = targets super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed) - def register_additional_minibatch(self, logits, targets=None): - """Register an additiona minibatch's worth of parameters. - - Args: - logits: Tensor of shape [batch_size, output_size]. Parameters for - underlying distribution. - targets: None or Tensor of shape [batch_size, output_size]. Each row must - be a one-hot vector. - """ - self._logits_components.append(logits) - self._targets_components.append(targets) - - @property - def _logits(self): - return array_ops.concat(self._logits_components, axis=0) - - @property - def input_minibatches(self): - return self._logits_components - @property def targets(self): - if all(target is None for target in self._targets_components): - return None - return array_ops.concat(self._targets_components, axis=0) + return self._targets @property def dist(self): @@ -661,19 +613,19 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, def multiply_fisher(self, vector): probs = self._probs return vector * probs - probs * math_ops.reduce_sum( - vector * probs, axis=-1, keep_dims=True) + vector * probs, axis=-1, keepdims=True) def multiply_fisher_factor(self, vector): probs = self._probs sqrt_probs = self._sqrt_probs return sqrt_probs * vector - probs * math_ops.reduce_sum( - sqrt_probs * vector, axis=-1, keep_dims=True) + sqrt_probs * vector, axis=-1, keepdims=True) def multiply_fisher_factor_transpose(self, vector): probs = self._probs sqrt_probs = self._sqrt_probs return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum( - probs * vector, axis=-1, keep_dims=True) + probs * vector, axis=-1, keepdims=True) def multiply_fisher_factor_replicated_one_hot(self, index): assert len(index) == 1, "Length of index was {}".format(len(index)) diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py index 705a871d482565897e7ac850327729a6186f1746..4279cb2792854249e3e076d200e2656bc615779d 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py @@ -33,7 +33,6 @@ _allowed_symbols = [ "CategoricalLogitsNegativeLogProbLoss", "OnehotCategoricalLogitsNegativeLogProbLoss", "MultiBernoulliNegativeLogProbLoss", - "MultiBernoulliNegativeLogProbLoss", "insert_slice_in_zeros", ] diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index 1974b07acfc879dc4bc844db9af88fd1043d6698..03b9da793307b966632789fd11162306e6cd19f9 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -18,16 +18,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import warnings + # pylint disable=long-line from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp from tensorflow.contrib.kfac.python.ops import estimator as est # pylint enable=long-line +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as tf_variables from tensorflow.python.training import gradient_descent @@ -47,8 +52,9 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): name="KFAC", estimation_mode="gradients", colocate_gradients_with_ops=True, - cov_devices=None, - inv_devices=None): + batch_size=None, + placement_strategy=None, + **kwargs): """Initializes the KFAC optimizer with the given settings. Args: @@ -61,6 +67,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): damping: The damping factor used to stabilize training due to errors in the local approximation with the Fisher information matrix, and to regularize the update direction by making it closer to the gradient. + If damping is adapted during training then this value is used for + initializing damping variable. (Higher damping means the update looks more like a standard gradient update - see Tikhonov regularization.) layer_collection: The layer collection object, which holds the fisher @@ -86,12 +94,13 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): colocate_gradients_with_ops: Whether we should request gradients we compute in the estimator be colocated with their respective ops. (Default: True) - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. + batch_size: The size of the mini-batch. Only needed when momentum_type + == 'qmodel' or when automatic adjustment is used. (Default: None) + placement_strategy: string, Device placement strategy used when creating + covariance variables, covariance ops, and inverse ops. + (Default: `None`) + **kwargs: Arguments to be passesd to specific placement + strategy mixin. Check `placement.RoundRobinPlacementMixin` for example. Raises: ValueError: If the momentum type is unsupported. @@ -100,20 +109,36 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): ValueError: If momentum is non-zero and momentum_type is not 'regular' or 'adam'. """ - - variables = var_list - if variables is None: - variables = tf_variables.trainable_variables() - - self._fisher_est = est.FisherEstimator( - variables, - cov_ema_decay, - damping, - layer_collection, - estimation_mode=estimation_mode, - colocate_gradients_with_ops=colocate_gradients_with_ops, - cov_devices=cov_devices, - inv_devices=inv_devices) + warnings.warn( + "third_party.tensorflow.contrib.kfac is deprecated." + "This will be removed on 15-07-2018. Check README for further details.", + DeprecationWarning) + # Parameters to be passed to the Fisher estimator: + self._variables = var_list or tf_variables.trainable_variables + self._cov_ema_decay = cov_ema_decay + self._layers = layer_collection + self._estimation_mode = estimation_mode + self._colocate_gradients_with_ops = colocate_gradients_with_ops + + # The below parameters are required only if damping needs to be adapated. + # These parameters can be set by calling + # set_damping_adaptation_params() explicitly. + self._damping_adaptation_decay = 0.95 + self._damping_adaptation_interval = 5 + # Check section 6.5 KFAC paper. omega(1) = pow(damping decay, interval) + self._omega = ( + self._damping_adaptation_decay**self._damping_adaptation_interval) + self._adapt_damping = False + self._min_damping = 1e-5 + self._prev_train_batch = None + self._is_chief = False + self._loss_fn = None + self._damping_constant = damping + self._damping = None + self._rho = None + self._prev_loss = None + self._q_model_change = None + self._update_damping_op = None momentum_type = momentum_type.lower() legal_momentum_types = ["regular", "adam", "qmodel"] @@ -122,46 +147,91 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): raise ValueError("Unsupported momentum type {}. Must be one of {}." .format(momentum_type, legal_momentum_types)) if momentum_type != "regular" and norm_constraint is not None: - raise ValueError("Update clipping is only supported with momentum" + raise ValueError("Update clipping is only supported with momentum " "type 'regular'.") if momentum_type not in ["regular", "adam"] and momentum != 0: raise ValueError("Momentum must be unspecified if using a momentum_type " "other than 'regular' or 'adam'.") + # Extra parameters of the optimizer self._momentum = momentum self._momentum_type = momentum_type self._norm_constraint = norm_constraint - - # this is a bit of a hack - # TODO(duckworthd): Handle this in a better way (e.g. pass it in?) - self._batch_size = array_ops.shape(layer_collection.losses[0].inputs)[0] - self._losses = layer_collection.losses + self._batch_size = batch_size + self._placement_strategy = placement_strategy + + with variable_scope.variable_scope(name): + self._fisher_est = est.make_fisher_estimator( + placement_strategy=placement_strategy, + variables=self._variables, + cov_ema_decay=self._cov_ema_decay, + damping=self.damping, + layer_collection=self._layers, + exps=(-1,), + estimation_mode=self._estimation_mode, + colocate_gradients_with_ops=self._colocate_gradients_with_ops, + **kwargs) super(KfacOptimizer, self).__init__(learning_rate, name=name) - @property - def cov_update_thunks(self): - return self._fisher_est.cov_update_thunks - - @property - def cov_update_ops(self): - return self._fisher_est.cov_update_ops + def set_damping_adaptation_params(self, + is_chief, + prev_train_batch, + loss_fn, + min_damping=1e-5, + damping_adaptation_decay=0.99, + damping_adaptation_interval=5): + """Sets parameters required to adapt damping during training. - @property - def cov_update_op(self): - return self._fisher_est.cov_update_op + When called, enables damping adaptation according to the Levenberg-Marquardt + style rule described in Section 6.5 of "Optimizing Neural Networks with + Kronecker-factored Approximate Curvature". - @property - def inv_update_thunks(self): - return self._fisher_est.inv_update_thunks + Note that this function creates Tensorflow variables which store a few + scalars and are accessed by the ops which update the damping (as part + of the training op returned by the minimize() method). - @property - def inv_update_ops(self): - return self._fisher_est.inv_update_ops + Args: + is_chief: `Boolean`, `True` if the worker is chief. + prev_train_batch: Training data used to minimize loss in the previous + step. This will be used to evaluate loss by calling + `loss_fn(prev_train_batch)`. + loss_fn: `function` that takes as input training data tensor and returns + a scalar loss. + min_damping: `float`(Optional), Minimum value the damping parameter + can take. Default value 1e-5. + damping_adaptation_decay: `float`(Optional), The `damping` parameter is + multiplied by the `damping_adaptation_decay` every + `damping_adaptation_interval` number of iterations. Default value 0.99. + damping_adaptation_interval: `int`(Optional), Number of steps in between + updating the `damping` parameter. Default value 5. - @property - def inv_update_op(self): - return self._fisher_est.inv_update_op + Raises: + ValueError: If `set_damping_adaptation_params` is already called and the + the `adapt_damping` is `True`. + """ + if self._adapt_damping: + raise ValueError("Damping adaptation parameters already set.") + + with variable_scope.variable_scope(self.get_name()): + self._adapt_damping = True + self._is_chief = is_chief + self._prev_train_batch = prev_train_batch + self._loss_fn = loss_fn + self._damping_adaptation_decay = damping_adaptation_decay + self._damping_adaptation_interval = damping_adaptation_interval + self._omega = ( + self._damping_adaptation_decay**self._damping_adaptation_interval) + self._min_damping = min_damping + + self._rho = variable_scope.get_variable( + "rho", shape=(), dtype=dtypes.float32, trainable=False) # LM ratio. + self._prev_loss = variable_scope.get_variable( + "prev_loss", shape=(), dtype=dtypes.float32, trainable=False) + self._q_model_change = variable_scope.get_variable( + "q_model_change", shape=(), dtype=dtypes.float32, trainable=False) + self._damping = variable_scope.get_variable( + "damping", initializer=self._damping_constant, trainable=False) @property def variables(self): @@ -169,14 +239,76 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): @property def damping(self): - return self._fisher_est.damping + if self._damping: + return self._damping + else: + return self._damping_constant + + @property + def damping_adaptation_interval(self): + return self._damping_adaptation_interval + + def make_vars_and_create_op_thunks(self): + """Make vars and create op thunks. + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + scope = self.get_name() + "/" + self._fisher_est.name + return self._fisher_est.make_vars_and_create_op_thunks(scope=scope) + + def create_ops_and_vars_thunks(self): + """Create thunks that make the ops and vars on demand. + + This function returns 4 lists of thunks: cov_variable_thunks, + cov_update_thunks, inv_variable_thunks, and inv_update_thunks. + + The length of each list is the number of factors and the i-th element of + each list corresponds to the i-th factor (given by the "factors" property). + + Note that the execution of these thunks must happen in a certain + partial order. The i-th element of cov_variable_thunks must execute + before the i-th element of cov_update_thunks (and also the i-th element + of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks + must execute before the i-th element of inv_update_thunks. + + TL;DR (oversimplified): Execute the thunks according to the order that + they are returned. + + Returns: + cov_variable_thunks: A list of thunks that make the cov variables. + cov_update_thunks: A list of thunks that make the cov update ops. + inv_variable_thunks: A list of thunks that make the inv variables. + inv_update_thunks: A list of thunks that make the inv update ops. + """ + scope = self.get_name() + "/" + self._fisher_est.name + return self._fisher_est.create_ops_and_vars_thunks(scope=scope) def minimize(self, *args, **kwargs): - kwargs["var_list"] = kwargs.get("var_list") or self.variables - if set(kwargs["var_list"]) != set(self.variables): - raise ValueError("var_list doesn't match with set of Fisher-estimating " - "variables.") - return super(KfacOptimizer, self).minimize(*args, **kwargs) + # Should this variable scope encompass everything below? Or will the super- + # class make another copy of the same name scope? + with variable_scope.variable_scope(self.get_name()): + kwargs["var_list"] = kwargs.get("var_list") or self.variables + if set(kwargs["var_list"]) != set(self.variables): + raise ValueError("var_list doesn't match with set of Fisher-estimating " + "variables.") + if self._adapt_damping and self._is_chief: + global_step = kwargs.get("global_step", None) + if not global_step: + raise KeyError("global_step needs to be passed to optimizer.minimize " + "if damping parameter is adapted.") + update_damping_op = self._update_damping(self._prev_train_batch, + global_step) + with ops.control_dependencies([update_damping_op]): + loss = args[0] + loss_assign_op = state_ops.assign(self._prev_loss, loss) + train_op = super(KfacOptimizer, self).minimize(*args, **kwargs) + return control_flow_ops.group(loss_assign_op, train_op) + else: + return super(KfacOptimizer, self).minimize(*args, **kwargs) def compute_gradients(self, *args, **kwargs): # args[1] could be our var_list @@ -185,6 +317,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): else: kwargs["var_list"] = kwargs.get("var_list") or self.variables var_list = kwargs["var_list"] + if set(var_list) != set(self.variables): raise ValueError("var_list doesn't match with set of Fisher-estimating " "variables.") @@ -296,6 +429,20 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars) return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars] + def _compute_prev_updates(self, variables): + """Computes previous updates as negative velocities scaled by learning rate. + + Args: + variables: List of variables in the graph that the update will be + applied to. + + Returns: + List of previous updates applied to the `variables`. + """ + return list( + -1 * self._learning_rate * self._zeros_slot(var, "velocity", self._name) + for var in variables) + def _compute_qmodel_hyperparams(self, precon_grads, prev_updates, grads, variables): """Compute optimal update hyperparameters from the quadratic model. @@ -336,12 +483,12 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): = qmodel(alpha*precon_grad + mu*prev_update) - L(theta). """ - cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._losses, variables) + cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses, + variables) # compute the matrix-vector products with the transposed Fisher factor fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads) fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates) - batch_size = math_ops.cast( self._batch_size, dtype=fft_precon_grads[0].dtype) @@ -374,9 +521,9 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): c = ops.convert_to_tensor([[_inner_product_list(grads, precon_grads)], [_inner_product_list(grads, prev_updates)]]) - sol = _two_by_two_solve(m, c) - alpha = -sol[0] - mu = -sol[1] + sol = -1. * _two_by_two_solve(m, c) + alpha = sol[0] + mu = sol[1] qmodel_change = 0.5 * math_ops.reduce_sum(sol * c) return alpha, mu, qmodel_change @@ -404,6 +551,52 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): return control_flow_ops.cond( math_ops.equal(m_22, 0.0), zero_prevupd_case, non_zero_prevupd_case) + def _assign_q_model_change(self, q_model_change): + """Assigns `q_model_change` to `self._q_model_change` if damping is adapted. + + Note only the chief worker does the assignment. + + Args: + q_model_change: Scalar tensor of type `float32`. + + Returns: + If `adapt_damping` is `True` then returns an assign op, Otherwise returns + a no_op(). + """ + if self._adapt_damping and self._is_chief: + q_model_assign_op = state_ops.assign(self._q_model_change, q_model_change) + else: + q_model_assign_op = control_flow_ops.no_op() + return q_model_assign_op + + def _compute_qmodel_hyperparams_wrapper(self, grads_and_vars, + precon_grads_and_vars): + """Wrapper function for `self._compute_qmodel_hyperparams`. + + Constructs a list of preconditioned gradients and variables. Also creates a + op to asssign the computed q model change to `self._q_model_change`. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + precon_grads_and_vars: List of (preconditioned gradients, variable) + pairs. + + Returns: + (alpha, mu, q_model_assign_op), where alpha and mu are chosen to optimize + the quadratic model, `q_model_assign_op` assigns the computed q model + change to `self._q_model_change`. + """ + precon_grads = list( + precon_grad for (precon_grad, _) in precon_grads_and_vars) + grads = list(grad for (grad, _) in grads_and_vars) + variables = list(var for (_, var) in grads_and_vars) + prev_updates = self._compute_prev_updates(variables) + # Compute optimal velocity update parameters according to quadratic model + alpha, mu, q_model_change = self._compute_qmodel_hyperparams( + precon_grads, prev_updates, grads, variables) + + return alpha, mu, self._assign_q_model_change(q_model_change) + def _compute_update_steps(self, grads_and_vars): """Computes the update steps for the variables given the gradients. @@ -411,8 +604,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): grads_and_vars: List of (gradient, variable) pairs. Returns: - An 'Operation that computes the update steps for the given variables. + A list of tuple (assign_op ,var) where `assign_op` assigns the update + steps to `var`. """ + if self._momentum_type == "regular": # Compute "preconditioned" gradient. precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars) @@ -423,8 +618,13 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): precon_grads_and_vars) # Update the velocity with this and return it as the step. - return self._update_velocities(precon_grads_and_vars, self._momentum) - + if self._adapt_damping and self._is_chief: + _, _, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper( + grads_and_vars, precon_grads_and_vars) + with ops.control_dependencies([q_model_assign_op]): + return self._update_velocities(precon_grads_and_vars, self._momentum) + else: + return self._update_velocities(precon_grads_and_vars, self._momentum) elif self._momentum_type == "adam": # Update velocity. velocities_and_vars = self._update_velocities(grads_and_vars, @@ -436,23 +636,13 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): # Compute "preconditioned" gradient. precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars) - # Extract out singleton lists from the tuple-lists - precon_grads = list( - precon_grad for (precon_grad, _) in precon_grads_and_vars) - grads = list(grad for (grad, _) in grads_and_vars) - variables = list(var for (_, var) in grads_and_vars) - # previous updates are the negative velocities (up to scaling by LR) - prev_updates = list( - -self._zeros_slot(var, "velocity", self._name) for var in variables) - # Compute optimal velocity update parameters according to quadratic model - alpha, mu, _ = self._compute_qmodel_hyperparams( - precon_grads, prev_updates, grads, variables) + alpha, mu, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper( + grads_and_vars, precon_grads_and_vars) - # Update the velocity with precon_grads according to these params - # and return it as the step. - return self._update_velocities( - precon_grads_and_vars, mu, vec_coeff=-alpha) + with ops.control_dependencies([q_model_assign_op]): + return self._update_velocities( + precon_grads_and_vars, mu, vec_coeff=-alpha) def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0): """Updates the velocities of the variables with the given vectors. @@ -482,6 +672,50 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): # Go through variable and update its associated part of the velocity vector. return [_update_velocity(vec, var) for vec, var in vecs_and_vars] + def _update_damping(self, prev_batch, global_step): + """Adapts damping parameter. Check KFAC (Section 6.5) for the details. + + The damping parameter is updated according to the Levenberg-Marquardt rule + every `self._damping_adaptation_interval` iterations. + + Args: + prev_batch: Tensor or tuple of tensors which can be passed to + `self._loss_fn` to evaluate loss. + global_step: `Variable` which keeps track of number of times the training + variables have been updated. + Returns: + A `tf.cond` op which updates the damping parameter. + """ + def compute_damping(): + """"Adapts damping parameter based on "reduction ratio". + + Reduction ratio captures how closely the quadratic approximation to the + loss function approximates the actual loss within a trust region. The + damping update tries to make the damping as small as possible while + maintaining the property that the quadratic model remains a good local + approximation to the loss function. + + Returns: + An Op to assign newly computed damping value to `self._damping`. + """ + prev_batch_loss = self._loss_fn(prev_batch) + with ops.control_dependencies([prev_batch_loss]): + rho_assign = self._rho.assign( + (prev_batch_loss - self._prev_loss) / self._q_model_change) + with ops.control_dependencies([rho_assign]): + new_damping = control_flow_ops.case( + [(self._rho < 0.25, lambda: self.damping / self._omega), + (self._rho > 0.75, lambda: self.damping * self._omega)], + lambda: self.damping) + with ops.control_dependencies([new_damping]): + new_damping_min = math_ops.maximum(new_damping, self._min_damping) + return control_flow_ops.group(self._damping.assign(new_damping_min)) + + return control_flow_ops.cond( + math_ops.equal( + math_ops.mod(global_step + 1, self._damping_adaptation_interval), + 0), compute_damping, control_flow_ops.no_op) + def _inner_product_list(list1, list2): return math_ops.add_n( diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py new file mode 100644 index 0000000000000000000000000000000000000000..c4454325aebe131058282ff15c2734bf10d1cc49 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/placement.py @@ -0,0 +1,114 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements placement strategies for cov and inv ops, cov variables.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from tensorflow.python.framework import ops as tf_ops + + +def _make_thunk_on_device(func, device): + def thunk(): + with tf_ops.device(device): + return func() + return thunk + + +class RoundRobinPlacementMixin(object): + """Implements round robin placement strategy for ops and variables.""" + + def __init__(self, cov_devices=None, inv_devices=None, **kwargs): + """Initializes the RoundRobinPlacementMixin class. + + Args: + cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + **kwargs: Need something here? + + """ + super(RoundRobinPlacementMixin, self).__init__(**kwargs) + self._cov_devices = cov_devices + self._inv_devices = inv_devices + + def make_vars_and_create_op_thunks(self, scope=None): + """Make vars and create op thunks w/ a round-robin device placement start. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the + `self._cov_devices` attribute. If `self._cov_devices` is `Non`e then no + explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the `self._inv_devices` attribute. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all thunks will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + # Note: `create_ops_and_vars_thunks` is implemented in `FisherEstimator`. + (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw, + inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope) + + if self._cov_devices: + cov_update_thunks = [] + for cov_variable_thunk, cov_update_thunk, device in zip( + cov_variable_thunks_raw, cov_update_thunks_raw, + itertools.cycle(self._cov_devices)): + with tf_ops.device(device): + cov_variable_thunk() + cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk, + device)) + else: + for cov_variable_thunk in cov_variable_thunks_raw: + cov_variable_thunk() + cov_update_thunks = cov_update_thunks_raw + + for inv_variable_thunk in inv_variable_thunks_raw: + inv_variable_thunk() + + if self._inv_devices: + inv_update_thunks = [] + for inv_update_thunk, device in zip(inv_update_thunks_raw, + itertools.cycle(self._inv_devices)): + inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk, + device)) + else: + inv_update_thunks = inv_update_thunks_raw + + return cov_update_thunks, inv_update_thunks diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index f5bd97cb4e7d547394050e944f75b43a40887f34..144295f4c7e36f61b4bae4178a6f57f6657204c5 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -24,11 +24,13 @@ from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -233,6 +235,13 @@ posdef_eig_functions = { } +def cholesky(tensor, damping): + """Computes the inverse of tensor + damping * identity.""" + identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) + damping = math_ops.cast(damping, dtype=tensor.dtype) + return linalg_ops.cholesky(tensor + damping * identity) + + class SubGraph(object): """Defines a subgraph given by all the dependencies of a given set of outputs. """ @@ -241,19 +250,22 @@ class SubGraph(object): # Set of all ancestor Tensors, Ops to 'outputs'. self._members = set() - self._recurse_add(outputs) + self._iter_add(outputs) - def _recurse_add(self, nodes): - """Recursively adds all of nodes' ancestors.""" - for node in nodes: - if node in self._members: - continue - self._members.add(node) + def _iter_add(self, root): + """Iteratively adds all of nodes' ancestors using depth first search.""" + stack = [root] + while stack: + nodes = stack.pop() + for node in nodes: + if node in self._members: + continue + self._members.add(node) - if isinstance(node, ops.Tensor): - self._recurse_add((node.op,)) - elif isinstance(node, ops.Operation): - self._recurse_add(node.inputs) + if isinstance(node, ops.Tensor): + stack.append((node.op,)) + elif isinstance(node, ops.Operation): + stack.append(node.inputs) def is_member(self, node): """Check if 'node' is in this subgraph.""" @@ -427,13 +439,138 @@ def batch_execute(global_step, thunks, batch_size, name=None): return result -def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name +def extract_convolution_patches(inputs, + filter_shape, + padding, + strides=None, + dilation_rate=None, + name=None, + data_format=None): + """Extracts inputs to each output coordinate in tf.nn.convolution. + + This is a generalization of tf.extract_image_patches() to tf.nn.convolution(), + where the number of spatial dimensions may be something other than 2. + + Assumes, + - First dimension of inputs is batch_size + - Convolution filter is applied to all input channels. + + Args: + inputs: Tensor of shape [batch_size, ..spatial_image_shape.., + ..spatial_filter_shape.., in_channels]. Inputs to tf.nn.convolution(). + filter_shape: List of ints. Shape of filter passed to tf.nn.convolution(). + padding: string. Padding method. One of "VALID", "SAME". + strides: None or list of ints. Strides along spatial dimensions. + dilation_rate: None or list of ints. Dilation along spatial dimensions. + name: None or str. Name of Op. + data_format: None or str. Format of data. + + Returns: + Tensor of shape [batch_size, ..spatial_image_shape.., + ..spatial_filter_shape.., in_channels] + + Raises: + ValueError: If data_format does not put channel last. + ValueError: If inputs and filter disagree on in_channels. + """ + if not is_data_format_channel_last(data_format): + raise ValueError("Channel must be last dimension.") + with ops.name_scope(name, "extract_convolution_patches", + [inputs, filter_shape, padding, strides, dilation_rate]): + batch_size = inputs.shape.as_list()[0] + in_channels = inputs.shape.as_list()[-1] + + # filter_shape = spatial_filter_shape + [in_channels, out_channels] + spatial_filter_shape = filter_shape[:-2] + if in_channels != filter_shape[-2]: + raise ValueError("inputs and filter_shape must agree on in_channels.") + + # Map each input feature to a location in the output. + out_channels = np.prod(spatial_filter_shape) * in_channels + filters = linalg_ops.eye(out_channels) + filters = array_ops.reshape( + filters, + list(spatial_filter_shape) + [in_channels, out_channels]) + + result = nn_ops.convolution( + inputs, + filters, + padding=padding, + strides=strides, + dilation_rate=dilation_rate) + spatial_output_shape = result.shape.as_list()[1:-1] + result = array_ops.reshape(result, + [batch_size or -1] + spatial_output_shape + + list(spatial_filter_shape) + [in_channels]) + + return result + + +def extract_pointwise_conv2d_patches(inputs, + filter_shape, + name=None, + data_format=None): + """Extract patches for a 1x1 conv2d. + + Args: + inputs: 4-D Tensor of shape [batch_size, height, width, in_channels]. + filter_shape: List of 4 ints. Shape of filter to apply with conv2d() + name: None or str. Name for Op. + data_format: None or str. Format for data. See 'data_format' in + tf.nn.conv2d() for details. + + Returns: + Tensor of shape [batch_size, ..spatial_input_shape.., + ..spatial_filter_shape.., in_channels] + + Raises: + ValueError: if inputs is not 4-D. + ValueError: if filter_shape is not [1, 1, ?, ?] + ValueError: if data_format is not channels-last. + """ + if inputs.shape.ndims != 4: + raise ValueError("inputs must have 4 dims.") + if len(filter_shape) != 4: + raise ValueError("filter_shape must have 4 dims.") + if filter_shape[0] != 1 or filter_shape[1] != 1: + raise ValueError("filter_shape must have shape 1 along spatial dimensions.") + if not is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels last.") + with ops.name_scope(name, "extract_pointwise_conv2d_patches", + [inputs, filter_shape]): + ksizes = [1, 1, 1, 1] # Spatial shape is 1x1. + strides = [1, 1, 1, 1] # Operate on all pixels. + rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1. + padding = "VALID" # Doesn't matter. + result = array_ops.extract_image_patches(inputs, ksizes, strides, rates, + padding) + + batch_size, input_height, input_width, in_channels = inputs.shape.as_list() + filter_height, filter_width, in_channels, _ = filter_shape + return array_ops.reshape(result, [ + batch_size, input_height, input_width, filter_height, filter_width, + in_channels + ]) + + +def is_data_format_channel_last(data_format): + """True if data_format puts channel last.""" + if data_format is None: + return True + return data_format.endswith("C") + + +def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name """Computes matmul(A, B) where A is sparse, B is dense. Args: A: tf.IndexedSlices with dense shape [m, n]. B: tf.Tensor with shape [n, k]. name: str. Name of op. + transpose_a: Bool. If true we transpose A before multiplying it by B. + (Default: False) + transpose_b: Bool. If true we transpose B before multiplying it by A. + (Default: False) Returns: tf.IndexedSlices resulting from matmul(A, B). @@ -447,7 +584,8 @@ def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name raise ValueError("A must represent a matrix. Found: %s." % A) if B.shape.ndims != 2: raise ValueError("B must be a matrix.") - new_values = math_ops.matmul(A.values, B) + new_values = math_ops.matmul( + A.values, B, transpose_a=transpose_a, transpose_b=transpose_b) return ops.IndexedSlices( new_values, A.indices, @@ -479,5 +617,93 @@ def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1)) return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape) + +class PartitionedTensor(object): + """A Tensor partitioned across its 0-th dimension.""" + + def __init__(self, tensors): + """Initializes PartitionedTensor. + + Args: + tensors: List of Tensors. All Tensors must agree on shape (excepting + batch dimension) and dtype. + + Raises: + ValueError: If 'tensors' has length zero. + ValueError: if contents of 'tensors' don't agree on shape or dtype. + """ + if not tensors: + raise ValueError("tensors must be a list of 1+ Tensors.") + + dtype = tensors[0].dtype + if not all(tensor.dtype == dtype for tensor in tensors): + raise ValueError("all tensors must have dtype = %s." % dtype) + + shape = tensors[0].shape[1:] + if not all(tensor.shape[1:] == shape for tensor in tensors): + raise ValueError("All tensors must have shape = %s (excluding batch " + "dimension)." % shape) + + self.tensors = tensors + self._concats = {} # {device: Tensor} + + @property + def shape(self): + feature_shape = self.tensors[0].shape[1:] + batch_size = sum([tensor.shape[0] for tensor in self.tensors], + tensor_shape.Dimension(0)) + return tensor_shape.TensorShape([batch_size]).concatenate(feature_shape) + + def get_shape(self): + return self.shape + + @property + def dtype(self): + return self.tensors[0].dtype + + def __str__(self): + return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % ( + self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list())) + + def __hash__(self): + return hash(tuple(self.tensors)) + + def __eq__(self, other): + if not isinstance(other, PartitionedTensor): + return False + return self.tensors == other.tensors + + def __ne__(self, other): + return not self == other # pylint: disable=g-comparison-negation + + def __getitem__(self, key): + return self.as_tensor()[key] + + def as_tensor(self, dtype=None, name=None, as_ref=False): + with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors): + assert not as_ref + assert dtype in [None, self.dtype] + result = array_ops.concat(self.tensors, axis=0) + + # Cache 'result' if we haven't already cached a value for this device. + if result.device not in self._concats: + self._concats[result.device] = result + return self._concats[result.device] + + @property + def device(self): + # PartitionedTensors in general do not live on a single device. If the + # device cannot be determined unambiguously this property will return None. + device = self.tensors[0].device + if all(tensor.device == device for tensor in self.tensors): + return device + return None + + +ops.register_tensor_conversion_function( + PartitionedTensor, + lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref)) + + # TODO(b/69623235): Add a function for finding tensors that share gradients # to eliminate redundant fisher factor computations. diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py index 8e424a794691484fdea7d8481677aa641c433d4c..330d222dbf70fcfa02ffd47261c0513d9dd6e0e9 100644 --- a/tensorflow/contrib/kfac/python/ops/utils_lib.py +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -40,6 +40,9 @@ _allowed_symbols = [ "fwd_gradients", "ensure_sequence", "batch_execute", + "extract_convolution_patches", + "extract_pointwise_conv2d_patches", + "is_data_format_channel_last", "matmul_sparse_dense", "matmul_diag_sparse", ] diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD index 894e6f6946bb59810a9da2d304cc0dd43d25201d..c8812d4b23f94102d093db878a709b090a3318d6 100644 --- a/tensorflow/contrib/labeled_tensor/BUILD +++ b/tensorflow/contrib/labeled_tensor/BUILD @@ -70,6 +70,7 @@ py_test( "python/ops/core_test.py", ], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":_typecheck", ":core", @@ -213,14 +214,3 @@ py_test( "//tensorflow/python:math_ops", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py index 0727f4cf88728dc3d919e662d65c93a658ac730b..39e9d65407f3b1e79804317023ea03dd81484ff5 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py @@ -660,7 +660,7 @@ class ReduceSumTest(Base): sum_lt = ops.reduce_sum(self.original_lt, {('channel', 'hihowareyou')}) golden_lt = core.LabeledTensor( math_ops.reduce_sum( - self.original_lt.tensor, 1, keep_dims=True), + self.original_lt.tensor, 1, keepdims=True), [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3]) self.assertLabeledTensorsEqual(sum_lt, golden_lt) @@ -668,7 +668,7 @@ class ReduceSumTest(Base): sum_lt = ops.reduce_sum(self.original_lt, ('channel', 'hihowareyou')) golden_lt = core.LabeledTensor( math_ops.reduce_sum( - self.original_lt.tensor, 1, keep_dims=True), + self.original_lt.tensor, 1, keepdims=True), [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3]) self.assertLabeledTensorsEqual(sum_lt, golden_lt) diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index 852d06e1e3cc8f8deecd15b7436cd4e4a393ad66..7355a403aeef78cc7e76d58adfe114e4729f6595 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -188,6 +188,7 @@ py_test( size = "small", srcs = ["python/layers/normalization_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":layers_py", "//tensorflow/contrib/framework:framework_py", @@ -353,6 +354,7 @@ py_test( size = "small", srcs = ["python/ops/sparse_ops_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":layers_py", "//tensorflow/python:array_ops", @@ -379,7 +381,7 @@ py_test( py_test( name = "rev_block_lib_test", - size = "small", + size = "medium", srcs = ["python/layers/rev_block_lib_test.py"], srcs_version = "PY2AND3", deps = [ @@ -390,15 +392,3 @@ py_test( "//tensorflow/python:variables", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index 337c9e06b870b2cca53fcdbf3d94225660e193c4..bc3359693562deb1229a78a2db5c256c76f7fd8d 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -19,6 +19,8 @@ See the @{$python/contrib.layers} guide. @@avg_pool2d @@avg_pool3d @@batch_norm +@@convolution +@@convolution1d @@convolution2d @@convolution3d @@conv2d_in_plane @@ -104,6 +106,7 @@ See the @{$python/contrib.layers} guide. @@infer_real_valued_columns @@sequence_input_from_feature_columns +@@group_norm @@instance_norm """ @@ -122,6 +125,7 @@ _allowed_symbols = ['bias_add', 'conv3d', 'elu', 'feature_column', + 'group_norm', 'instance_norm', 'legacy_fully_connected', 'legacy_linear', diff --git a/tensorflow/contrib/layers/kernels/BUILD b/tensorflow/contrib/layers/kernels/BUILD index e407a9ce015603094c7bbab72856403e2f0eb1a1..7aae09ff3e9995b2d92b05211b3bf8a94a26ff43 100644 --- a/tensorflow/contrib/layers/kernels/BUILD +++ b/tensorflow/contrib/layers/kernels/BUILD @@ -18,14 +18,3 @@ cc_library( ], alwayslink = 1, ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py b/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py index f701647c2b297015f025eb53bd191a1a8c54ec62..28ddaa69a14776e0c157c2e68105ee9e17bc3cbb 100644 --- a/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py +++ b/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py @@ -200,7 +200,7 @@ class SparseCrossOpTest(test.TestCase): self._assert_sparse_tensor_equals(expected_out, sess.run(op)) def test_large_batch(self): - """Tests with large batch size to force multithreding. + """Tests with large batch size to force multithreading. """ batch_size = 5000 col1 = [] diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py index b62e3050cd7003f1ba72061b133ff9b5d6b616da..60e1d85ea9c08a51763fdaf08853f8d9b67347e5 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py @@ -140,6 +140,9 @@ def safe_embedding_lookup_sparse(embedding_weights, # Prune invalid ids and weights. sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights) + if combiner != "sum": + sparse_ids, sparse_weights = _prune_invalid_weights( + sparse_ids, sparse_weights) # Fill in dummy values for empty features, if necessary. sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids, @@ -188,13 +191,23 @@ def _prune_invalid_ids(sparse_ids, sparse_weights): is_id_valid = math_ops.greater_equal(sparse_ids.values, 0) if sparse_weights is not None: is_id_valid = math_ops.logical_and( - is_id_valid, math_ops.greater(sparse_weights.values, 0)) + is_id_valid, + array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool)) sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid) if sparse_weights is not None: sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid) return sparse_ids, sparse_weights +def _prune_invalid_weights(sparse_ids, sparse_weights): + """Prune invalid weights (< 0) from the input ids and weights.""" + if sparse_weights is not None: + is_weights_valid = math_ops.greater(sparse_weights.values, 0) + sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid) + sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid) + return sparse_ids, sparse_weights + + def scattered_embedding_lookup(params, values, dimension, @@ -445,7 +458,7 @@ def scattered_embedding_lookup_sparse(params, return embeddings -def embedding_lookup_unique(params, ids, name=None): +def embedding_lookup_unique(params, ids, partition_strategy="mod", name=None): """Version of embedding_lookup that avoids duplicate lookups. This can save communication in the case of repeated ids. @@ -457,6 +470,9 @@ def embedding_lookup_unique(params, ids, name=None): `PartitionedVariable`. Shape `[index, d1, d2, ...]`. ids: A one-dimensional `Tensor` with type `int32` or `int64` containing the ids to be looked up in `params`. Shape `[ids1, ids2, ...]`. + partition_strategy: A string specifying the partitioning strategy, relevant + if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default + is `"mod"`. name: A name for this operation (optional). Returns: @@ -470,9 +486,10 @@ def embedding_lookup_unique(params, ids, name=None): ids = ops.convert_to_tensor(ids) shape = array_ops.shape(ids) ids_flat = array_ops.reshape( - ids, math_ops.reduce_prod(shape, keep_dims=True)) + ids, math_ops.reduce_prod(shape, keepdims=True)) unique_ids, idx = array_ops.unique(ids_flat) - unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids) + unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids, + partition_strategy) embeds_flat = array_ops.gather(unique_embeddings, idx) embed_shape = array_ops.concat( [shape, array_ops.shape(unique_embeddings)[1:]], 0) diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py index bf2514498202e9227c2d74c036c7eecba5ccdf2c..dd2395f8c9748dadbecfe47df5511874d5f848ea 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import init_ops @@ -691,11 +692,12 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase): index += num_val return grouped_vals + @test_util.enable_c_shapes def testEmbeddingLookupSparse(self): vocab_size = 13 batch_size = 10 param_shape = [2, 5] - expected_lookup_result_shape = [None] + param_shape + expected_lookup_result_shape = param_shape sp_ids, sp_weights, ids, weights, vals_per_batch_entry = ( self._RandomIdsAndWeights(batch_size, vocab_size)) @@ -719,7 +721,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase): None if ignore_weights else sp_weights, combiner=combiner) - self.assertEqual(embedding_sum.get_shape().as_list(), + self.assertEqual(embedding_sum.get_shape().as_list()[1:], expected_lookup_result_shape) tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict) diff --git a/tensorflow/contrib/layers/python/layers/encoders.py b/tensorflow/contrib/layers/python/layers/encoders.py index 89c9d37bd09cb6c43eebb91f3a16600eae9cb490..f42112206d0db9d2e42bd4cff19f6a6533951d46 100644 --- a/tensorflow/contrib/layers/python/layers/encoders.py +++ b/tensorflow/contrib/layers/python/layers/encoders.py @@ -125,7 +125,7 @@ def embed_sequence(ids, `reuse` is `None` or `False`. """ if not (reuse or (vocab_size and embed_dim)): - raise ValueError('Must specify vocab size and embedding dimension when not' + raise ValueError('Must specify vocab size and embedding dimension when not ' 'reusing. Got vocab_size=%s and embed_dim=%s' % ( vocab_size, embed_dim)) with variable_scope.variable_scope( diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index 9ccb589d698ad83c9654f5523ccdcb35b031b3da..3ae07cedab0be2da8ec633cfd84e07cfdfb11457 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -48,7 +48,7 @@ you should choose depends on (1) the feature type and (2) the model type. recommended. embedded_dept_column = embedding_column( - sparse_column_with_keys("department", ["math", "philosphy", ...]), + sparse_column_with_keys("department", ["math", "philosophy", ...]), dimension=10) * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`). diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py index 78affea44cbfb92523063968dbc1be98841854db..06060b99e7e58787994f20f037ffa451abbc7459 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py @@ -815,7 +815,7 @@ class _Transformer(object): """ def __init__(self, columns_to_tensors): - """Initializes transfomer. + """Initializes transformer. Args: columns_to_tensors: A mapping from feature columns to tensors. 'string' @@ -908,7 +908,7 @@ def _gather_feature_columns(feature_columns): def _check_forbidden_sequence_columns(feature_columns): - """Recursively cecks `feature_columns` for `_FORBIDDEN_SEQUENCE_COLUMNS`.""" + """Recursively checks `feature_columns` for `_FORBIDDEN_SEQUENCE_COLUMNS`.""" all_feature_columns = _gather_feature_columns(feature_columns) for feature_column in all_feature_columns: if isinstance(feature_column, _FORBIDDEN_SEQUENCE_COLUMNS): diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index e27b36908eba7cc1b992079b1abb5c5367340de1..b6d63c9640611abdda65f1205f544ee505dae1f0 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -51,17 +51,16 @@ from tensorflow.python.ops import standard_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as tf_variables from tensorflow.python.training import moving_averages -from tensorflow.python.layers.maxout import maxout # TODO(b/28426988): Replace legacy_* fns migrated from slim. # TODO(b/28426988): Remove legacy_* when all uses have migrated to new API. __all__ = [ 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d', 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution', - 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose', - 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse', - 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', 'gdn', - 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d', + 'convolution1d', 'convolution2d', 'convolution2d_in_plane', + 'convolution2d_transpose', 'convolution3d', 'convolution3d_transpose', + 'dense_to_sparse', 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', + 'gdn', 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d', 'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', 'repeat', 'scale_gradient', 'separable_conv2d', 'separable_convolution2d', 'sequence_to_images', 'softmax', 'spatial_softmax', 'stack', 'unit_norm', @@ -933,7 +932,8 @@ def convolution(inputs, variables_collections=None, outputs_collections=None, trainable=True, - scope=None): + scope=None, + conv_dims=None): """Adds an N-D convolution followed by an optional batch_norm layer. It is required that 1 <= N <= 3. @@ -994,6 +994,10 @@ def convolution(inputs, trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). scope: Optional scope for `variable_scope`. + conv_dims: Optional convolution dimensionality, when set it would use the + corresponding convolution (e.g. 2 for Conv 2D, 3 for Conv 3D, ..). When + leaved to None it would select the convolution dimensionality based on + the input rank (i.e. Conv ND, with N = input_rank - 2). Returns: A tensor representing the output of the operation. @@ -1016,6 +1020,9 @@ def convolution(inputs, inputs = ops.convert_to_tensor(inputs) input_rank = inputs.get_shape().ndims + if conv_dims is not None and conv_dims + 2 != input_rank: + raise ValueError('Convolution expects input with rank %d, got %d' % + (conv_dims + 2, input_rank)) if input_rank == 3: layer_class = convolutional_layers.Convolution1D elif input_rank == 4: @@ -1062,10 +1069,134 @@ def convolution(inputs, outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.name, outputs) +@add_arg_scope +def convolution1d(inputs, + num_outputs, + kernel_size, + stride=1, + padding='SAME', + data_format=None, + rate=1, + activation_fn=nn.relu, + normalizer_fn=None, + normalizer_params=None, + weights_initializer=initializers.xavier_initializer(), + weights_regularizer=None, + biases_initializer=init_ops.zeros_initializer(), + biases_regularizer=None, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + scope=None): + return convolution(inputs, + num_outputs, + kernel_size, + stride, + padding, + data_format, + rate, + activation_fn, + normalizer_fn, + normalizer_params, + weights_initializer, + weights_regularizer, + biases_initializer, + biases_regularizer, + reuse, + variables_collections, + outputs_collections, + trainable, + scope, + conv_dims=1) + +convolution1d.__doc__ = convolution.__doc__ -convolution2d = convolution -convolution3d = convolution +@add_arg_scope +def convolution2d(inputs, + num_outputs, + kernel_size, + stride=1, + padding='SAME', + data_format=None, + rate=1, + activation_fn=nn.relu, + normalizer_fn=None, + normalizer_params=None, + weights_initializer=initializers.xavier_initializer(), + weights_regularizer=None, + biases_initializer=init_ops.zeros_initializer(), + biases_regularizer=None, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + scope=None): + return convolution(inputs, + num_outputs, + kernel_size, + stride, + padding, + data_format, + rate, + activation_fn, + normalizer_fn, + normalizer_params, + weights_initializer, + weights_regularizer, + biases_initializer, + biases_regularizer, + reuse, + variables_collections, + outputs_collections, + trainable, + scope, + conv_dims=2) + +convolution2d.__doc__ = convolution.__doc__ +@add_arg_scope +def convolution3d(inputs, + num_outputs, + kernel_size, + stride=1, + padding='SAME', + data_format=None, + rate=1, + activation_fn=nn.relu, + normalizer_fn=None, + normalizer_params=None, + weights_initializer=initializers.xavier_initializer(), + weights_regularizer=None, + biases_initializer=init_ops.zeros_initializer(), + biases_regularizer=None, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + scope=None): + return convolution(inputs, + num_outputs, + kernel_size, + stride, + padding, + data_format, + rate, + activation_fn, + normalizer_fn, + normalizer_params, + weights_initializer, + weights_regularizer, + biases_initializer, + biases_regularizer, + reuse, + variables_collections, + outputs_collections, + trainable, + scope, + conv_dims=3) + +convolution3d.__doc__ = convolution.__doc__ @add_arg_scope def convolution2d_in_plane( @@ -1405,13 +1536,14 @@ def convolution3d_transpose( @add_arg_scope def dense_to_sparse(tensor, eos_token=0, outputs_collections=None, scope=None): """Converts a dense tensor into a sparse tensor. + An example use would be to convert dense labels to sparse ones so that they can be fed to the ctc_loss. Args: tensor: An `int` `Tensor` to be converted to a `Sparse`. eos_token: An integer. - It is part of the target label that signfies the end of a sentence. + It is part of the target label that signifies the end of a sentence. outputs_collections: Collection to add the outputs. scope: Optional scope for name_scope. """ @@ -1555,7 +1687,7 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None): output_collections: Collection to which the outputs will be added. scope: Optional scope for `name_scope`. Returns: - A `Tensor` or `SparseTensor` conataining the same values as `inputs`, but + A `Tensor` or `SparseTensor` containing the same values as `inputs`, but with innermost dimensions flattened to obtain rank `new_rank`. Raises: @@ -1890,6 +2022,7 @@ class GDN(base.Layer): def beta_initializer(shape, dtype=None, partition_info=None): del partition_info # unused + pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype) return math_ops.sqrt(array_ops.ones(shape, dtype=dtype) + pedestal) def gamma_initializer(shape, dtype=None, partition_info=None): @@ -1897,6 +2030,7 @@ class GDN(base.Layer): assert len(shape) == 2 assert shape[0] == shape[1] eye = linalg_ops.eye(shape[0], dtype=dtype) + pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype) return math_ops.sqrt(self._gamma_init * eye + pedestal) beta = self.add_variable( @@ -2192,11 +2326,16 @@ def images_to_sequence(inputs, outputs_collections=None, scope=None): """Convert a batch of images into a batch of sequences. + Args: inputs: a (num_images, height, width, depth) tensor data_format: A string. `NHWC` (default) and `NCHW` are supported. outputs_collections: The collections to which the outputs are added. scope: Optional scope for name_scope. + + Raises: + ValueError: If `data_format` is not either NCHW or NHWC. + Returns: (width, num_images*height, depth) sequence tensor """ @@ -2702,6 +2841,7 @@ def sequence_to_images(inputs, outputs_collections=None, scope=None): """Convert a batch of sequences into a batch of images. + Args: inputs: (num_steps, num_batches, depth) sequence tensor height: the height of the images @@ -2709,6 +2849,7 @@ def sequence_to_images(inputs, Currently supports `'channels_first'` and `'channels_last'`. outputs_collections: The collections to which the outputs are added. scope: Optional scope for name_scope. + Returns: A tensor representing the output of the operation. """ @@ -2718,7 +2859,7 @@ def sequence_to_images(inputs, if num_batches is None: num_batches = -1 else: - num_batches = num_batches // height + num_batches //= height reshaped = array_ops.reshape(inputs, [width, num_batches, height, depth]) if output_data_format == 'channels_first': @@ -2748,7 +2889,7 @@ def softmax(logits, scope=None): logits_2d = array_ops.reshape(logits, [-1, num_logits]) predictions = nn.softmax(logits_2d) predictions = array_ops.reshape(predictions, array_ops.shape(logits)) - if context.in_graph_mode(): + if not context.executing_eagerly(): predictions.set_shape(logits.get_shape()) return predictions @@ -2941,6 +3082,53 @@ def unit_norm(inputs, dim, epsilon=1e-7, scope=None): return math_ops.div(inputs, array_ops.tile(lengths, multiples)) +@add_arg_scope +def maxout(inputs, num_units, axis=-1, scope=None): + """Adds a maxout op from https://arxiv.org/abs/1302.4389 + + "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron + Courville, + Yoshua Bengio + + Usually the operation is performed in the filter/channel dimension. This can + also be + used after fully-connected layers to reduce number of features. + + Arguments: + inputs: Tensor input + num_units: Specifies how many features will remain after maxout + in the `axis` dimension (usually channel). + This must be multiple of number of `axis`. + axis: The dimension where max pooling will be performed. Default is the + last dimension. + scope: Optional scope for variable_scope. + + Returns: + A `Tensor` representing the results of the pooling operation. + + Raises: + ValueError: if num_units is not multiple of number of features. + """ + with variable_scope.variable_scope(scope, 'MaxOut', [inputs]): + inputs = ops.convert_to_tensor(inputs) + shape = inputs.get_shape().as_list() + num_channels = shape[axis] + if num_channels % num_units: + raise ValueError('number of features({}) is not ' + 'a multiple of num_units({})'.format( + num_channels, num_units)) + shape[axis] = -1 + shape += [num_channels // num_units] + + # Dealing with batches with arbitrary sizes + for i in range(len(shape)): + if shape[i] is None: + shape[i] = array_ops.shape(inputs)[i] + outputs = math_ops.reduce_max( + array_ops.reshape(inputs, shape), -1, keepdims=False) + return outputs + + def poincare_normalize(x, axis=1, epsilon=1e-5, name=None): """Project into the Poincare ball with norm <= 1.0 - epsilon. @@ -2999,16 +3187,16 @@ def legacy_fully_connected(x, `activation_fn` is `None`, the result of `y = w * x + b` is returned. - If `x` has shape [\\\(\\text{dim}_0, \\text{dim}_1, ..., \\text{dim}_n\\\)] - with more than 2 dimensions (\\\(n > 1\\\)), then we repeat the matrix + If `x` has shape [\\(\text{dim}_0, \text{dim}_1, ..., \text{dim}_n\\)] + with more than 2 dimensions (\\(n > 1\\)), then we repeat the matrix multiply along the first dimensions. The result r is a tensor of shape - [\\\(\\text{dim}_0, ..., \\text{dim}_{n-1},\\\) `num_output_units`], - where \\\( r_{i_0, ..., i_{n-1}, k} = - \\sum_{0 \\leq j < \\text{dim}_n} x_{i_0, ... i_{n-1}, j} \cdot w_{j, k}\\\). + [\\(\text{dim}_0, ..., \text{dim}_{n-1},\\) `num_output_units`], + where \\( r_{i_0, ..., i_{n-1}, k} = + \sum_{0 \leq j < \text{dim}_n} x_{i_0, ... i_{n-1}, j} \cdot w_{j, k}\\). This is accomplished by reshaping `x` to 2-D - [\\\(\\text{dim}_0 \\cdot ... \\cdot \\text{dim}_{n-1}, \\text{dim}_n\\\)] + [\\(\text{dim}_0 \cdot ... \cdot \text{dim}_{n-1}, \text{dim}_n\\)] before the matrix multiply and afterwards reshaping it to - [\\\(\\text{dim}_0, ..., \\text{dim}_{n-1},\\\) `num_output_units`]. + [\\(\text{dim}_0, ..., \text{dim}_{n-1},\\) `num_output_units`]. This op creates `w` and optionally `b`. Bias (`b`) can be disabled by setting `bias_init` to `None`. diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 0f062adbab3ca9acfb89543b69c7c957bbdf5dd8..c5c7269b1f15849956e90654e3bcf8ab0eebc393 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -310,6 +310,17 @@ class BiasAddTest(test.TestCase): class ConvolutionTest(test.TestCase): + def testInvalidShape(self): + with self.test_session(): + images_2d = random_ops.random_uniform((5, 7, 9, 3), seed=1) + with self.assertRaisesRegexp( + ValueError, 'Convolution expects input with rank 5, got 4'): + layers_lib.convolution3d(images_2d, 32, 3) + images_3d = random_ops.random_uniform((5, 6, 7, 9, 3), seed=1) + with self.assertRaisesRegexp( + ValueError, 'Convolution expects input with rank 4, got 5'): + layers_lib.convolution2d(images_3d, 32, 3) + def testInvalidDataFormat(self): height, width = 7, 9 with self.test_session(): @@ -1301,6 +1312,29 @@ class ConvolutionInPlaneTest(test.TestCase): self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5) + def testConv1dShape(self): + width = 7 + with self.test_session(): + images = random_ops.random_uniform((5, width, 3), seed=1) + output = layers_lib.convolution1d(images, 32, 3) + self.assertEqual(output.op.name, 'Conv/Relu') + self.assertListEqual(output.get_shape().as_list(), [5, width, 32]) + + def testConvInferSpatialDims(self): + depth, height, width = 7, 9, 11 + with self.test_session(): + images = np.random.uniform(size=(5, width, 4)).astype(np.float32) + output = layers_lib.convolution(images, 32, [3]) + self.assertListEqual(output.get_shape().as_list(), [5, width, 32]) + images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32) + output = layers_lib.convolution(images, 32, [3, 3]) + self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32]) + images = np.random.uniform(size=(5, depth, height, width, + 4)).astype(np.float32) + output = layers_lib.convolution(images, 32, [3, 3, 3]) + self.assertListEqual(output.get_shape().as_list(), + [5, depth, height, width, 32]) + class DenseToSparseTest(test.TestCase): @@ -1322,7 +1356,7 @@ class DropoutTest(test.TestCase): with self.test_session(): images = np.random.uniform(size=(5, height, width, 3)) output = _layers.dropout(images) - self.assertEqual(output.op.name, 'Dropout/dropout/mul') + self.assertEqual(output.op.name, 'Dropout/dropout_1/mul') output.get_shape().assert_is_compatible_with( ops.convert_to_tensor(images).get_shape()) @@ -3155,7 +3189,7 @@ class RepeatTests(test.TestCase): with self.test_session(): images = np.random.uniform(size=(5, height, width, 3)).astype(np.float32) output = _layers.repeat(images, 3, layers_lib.conv2d, 32, [3, 3]) - self.assertEqual(output.op.name, 'Repeat/convolution_3/Relu') + self.assertEqual(output.op.name, 'Repeat/convolution2d_3/Relu') self.assertListEqual(output.get_shape().as_list(), [5, 3, 3, 32]) def testRepeatWithScope(self): @@ -3749,7 +3783,7 @@ class StackTests(test.TestCase): layers_lib.convolution2d, [10, 20, 30], kernel_size=[3, 3], padding='SAME') - self.assertEqual(output.op.name, 'Stack/convolution_3/Relu') + self.assertEqual(output.op.name, 'Stack/convolution2d_3/Relu') self.assertListEqual(output.get_shape().as_list(), [5, 3, 3, 30]) def testStackWithScope(self): @@ -4135,5 +4169,31 @@ class LegacyFullyConnectedTest(test.TestCase): _layers.legacy_fully_connected(x, 2, activation_fn=nn_ops.softmax) +class MaxOutTest(test.TestCase): + + def test_simple(self): + inputs = random_ops.random_uniform((64, 10, 36), seed=1) + graph = _layers.maxout(inputs, num_units=3) + self.assertEqual(graph.get_shape().as_list(), [64, 10, 3]) + + def test_fully_connected(self): + inputs = random_ops.random_uniform((64, 50), seed=1) + graph = _layers.fully_connected(inputs, 50) + graph = _layers.maxout(graph, num_units=10) + self.assertEqual(graph.get_shape().as_list(), [64, 10]) + + def test_nchw(self): + inputs = random_ops.random_uniform((10, 100, 100, 3), seed=1) + graph = _layers.conv2d(inputs, 10, 3, padding='SAME') + graph = _layers.maxout(graph, num_units=1) + self.assertEqual(graph.get_shape().as_list(), [10, 100, 100, 1]) + + def test_invalid_shape(self): + inputs = random_ops.random_uniform((10, 100, 100, 3), seed=1) + graph = _layers.conv2d(inputs, 3, 10) + with self.assertRaisesRegexp(ValueError, 'number of features'): + graph = _layers.maxout(graph, num_units=2) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/layers/python/layers/normalization.py b/tensorflow/contrib/layers/python/layers/normalization.py index e7d4080ff769327cc74b6629a7705ddfa552169b..c807ab0f2e5c8ac3ec2ae1d84a5b36b5f4ba76a4 100644 --- a/tensorflow/contrib/layers/python/layers/normalization.py +++ b/tensorflow/contrib/layers/python/layers/normalization.py @@ -24,11 +24,13 @@ from tensorflow.contrib.layers.python.layers import utils from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import variable_scope __all__ = [ + 'group_norm', 'instance_norm', ] @@ -158,3 +160,196 @@ def instance_norm(inputs, if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.name, outputs) + + +@add_arg_scope +def group_norm(inputs, + groups=32, + channels_axis=-1, + reduction_axes=(-3, -2), + center=True, + scale=True, + epsilon=1e-6, + activation_fn=None, + param_initializers=None, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + scope=None): + """Functional interface for the group normalization layer. + + Reference: https://arxiv.org/abs/1803.08494. + + "Group Normalization", Yuxin Wu, Kaiming He + + Args: + inputs: A Tensor with at least 2 dimensions one which is channels. All + shape dimensions must be fully defined. + groups: Integer. Divide the channels into this number of groups over which + normalization statistics are computed. This number must be commensurate + with the number of channels in `inputs`. + channels_axis: An integer. Specifies index of channels axis which will be + broken into `groups`, each of which whose statistics will be computed + across. Must be mutually exclusive with `reduction_axes`. Preferred usage + is to specify negative integers to be agnostic as to whether a batch + dimension is included. + reduction_axes: Tuple of integers. Specifies dimensions over which + statistics will be accumulated. Must be mutually exclusive with + `channels_axis`. Statistics will not be accumulated across axes not + specified in `reduction_axes` nor `channel_axis`. Preferred usage is to + specify negative integers to be agnostic to whether a batch dimension is + included. + + Some sample usage cases: + NHWC format: channels_axis=-1, reduction_axes=[-3, -2] + NCHW format: channels_axis=-3, reduction_axes=[-2, -1] + + center: If True, add offset of `beta` to normalized tensor. If False, `beta` + is ignored. + scale: If True, multiply by `gamma`. If False, `gamma` is + not used. When the next layer is linear (also e.g. `nn.relu`), this can be + disabled since the scaling can be done by the next layer. + epsilon: Small float added to variance to avoid dividing by zero. + activation_fn: Activation function, default set to None to skip it and + maintain a linear activation. + param_initializers: Optional initializers for beta, gamma, moving mean and + moving variance. + reuse: Whether or not the layer and its variables should be reused. To be + able to reuse the layer scope must be given. + variables_collections: Optional collections for the variables. + outputs_collections: Collections to add the outputs. + trainable: If `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + scope: Optional scope for `variable_scope`. + + Returns: + A `Tensor` representing the output of the operation. + + Raises: + ValueError: If the rank of `inputs` is undefined. + ValueError: If rank or channels dimension of `inputs` is undefined. + ValueError: If number of groups is not commensurate with number of channels. + ValueError: If reduction_axes or channels_axis are out of bounds. + ValueError: If reduction_axes are not mutually exclusive with channels_axis. + """ + # TODO(shlens): Support partially defined shapes for the inputs. + inputs = ops.convert_to_tensor(inputs) + original_shape = inputs.shape + + if inputs.shape.ndims is None: + raise ValueError('Inputs %s has undefined rank.' % inputs.name) + if channels_axis > (inputs.shape.ndims - 1): + raise ValueError('Axis is out of bounds.') + + # Standardize the channels_axis to be positive and identify # of channels. + if channels_axis < 0: + channels_axis = inputs.shape.ndims + channels_axis + channels = inputs.shape[channels_axis].value + + if channels is None: + raise ValueError('Inputs %s has undefined channel dimension: %d.' % ( + inputs.name, channels_axis)) + + # Standardize the reduction_axes to be positive. + reduction_axes = list(reduction_axes) + for i in range(len(reduction_axes)): + if reduction_axes[i] < 0: + reduction_axes[i] += inputs.shape.ndims + + for a in reduction_axes: + if a > inputs.shape.ndims: + raise ValueError('Axis is out of bounds.') + if inputs.shape[a].value is None: + raise ValueError('Inputs %s has undefined dimensions %d.' % ( + inputs.name, a)) + if channels_axis == a: + raise ValueError('reduction_axis must be mutually exclusive ' + 'with channels_axis') + if groups > channels: + raise ValueError('Invalid groups %d for %d channels.' % (groups, channels)) + if channels % groups != 0: + raise ValueError('%d channels is not commensurate with %d groups.' % + (channels, groups)) + + # Determine axes before channels. Some examples of common image formats: + # 'NCHW': before = [N], after = [HW] + # 'NHWC': before = [NHW], after = [] + axes_before_channels = inputs.shape.as_list()[:channels_axis] + axes_after_channels = inputs.shape.as_list()[channels_axis+1:] + + # Manually broadcast the parameters to conform to the number of groups. + params_shape_broadcast = ([1] * len(axes_before_channels) + + [groups, channels // groups] + + [1] * len(axes_after_channels)) + + # Reshape the input by the group within the channel dimension. + inputs_shape = (axes_before_channels + [groups, channels // groups] + + axes_after_channels) + inputs = array_ops.reshape(inputs, inputs_shape) + + # Determine the dimensions across which moments are calculated. + moments_axes = [channels_axis + 1] + for a in reduction_axes: + if a > channels_axis: + moments_axes.append(a + 1) + else: + moments_axes.append(a) + + with variable_scope.variable_scope( + scope, 'GroupNorm', [inputs], reuse=reuse) as sc: + # Note that the params_shape is the number of channels always. + params_shape = [channels] + + # Allocate parameters for the beta and gamma of the normalization. + beta, gamma = None, None + dtype = inputs.dtype.base_dtype + if param_initializers is None: + param_initializers = {} + if center: + beta_collections = utils.get_variable_collections( + variables_collections, 'beta') + beta_initializer = param_initializers.get( + 'beta', init_ops.zeros_initializer()) + beta = variables.model_variable('beta', + shape=params_shape, + dtype=dtype, + initializer=beta_initializer, + collections=beta_collections, + trainable=trainable) + beta = array_ops.reshape(beta, params_shape_broadcast) + + if scale: + gamma_collections = utils.get_variable_collections( + variables_collections, 'gamma') + gamma_initializer = param_initializers.get( + 'gamma', init_ops.ones_initializer()) + gamma = variables.model_variable('gamma', + shape=params_shape, + dtype=dtype, + initializer=gamma_initializer, + collections=gamma_collections, + trainable=trainable) + gamma = array_ops.reshape(gamma, params_shape_broadcast) + + # Calculate the moments. + mean, variance = nn.moments(inputs, moments_axes, keep_dims=True) + + # Compute normalization. + # TODO(shlens): Fix nn.batch_normalization to handle the 5-D Tensor + # appropriately so that this operation may be faster. + gain = math_ops.rsqrt(variance + epsilon) + offset = -mean * gain + if gamma is not None: + gain *= gamma + offset *= gamma + if beta is not None: + offset += beta + outputs = inputs * gain + offset + + # Collapse the groups into the channel dimension. + outputs = array_ops.reshape(outputs, original_shape) + + if activation_fn is not None: + outputs = activation_fn(outputs) + return utils.collect_named_outputs(outputs_collections, sc.name, outputs) diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py index 5cff1bf0ebb2fe8bc6933de882ecd47a9edf0f94..b6e96350db92baf4770683273be7e5dde73dbcec 100644 --- a/tensorflow/contrib/layers/python/layers/normalization_test.py +++ b/tensorflow/contrib/layers/python/layers/normalization_test.py @@ -166,5 +166,231 @@ class InstanceNormTest(test.TestCase): def testOutputBigInput5DNCHW(self): self.doOutputTest((1, 100, 100, 1, 1), 'NCHW', tol=1e-3) + +class GroupNormTest(test.TestCase): + + def testInvalidGroupSize(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(5, 2, 10, 10)) + with self.assertRaisesRegexp(ValueError, + 'Invalid groups 10 for 2 channels.'): + normalization.group_norm(inputs, groups=10, + reduction_axes=[-2, -1], channels_axis=-3) + + def testBadCommensurateGroup(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(5, 4, 10, 10)) + with self.assertRaisesRegexp(ValueError, + '4 channels is not commensurate with ' + '3 groups.'): + normalization.group_norm(inputs, groups=3, + reduction_axes=[-2, -1], channels_axis=-3) + + def testAxisIsBad(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 2, 4, 5)) + with self.assertRaisesRegexp(ValueError, + 'Axis is out of bounds.'): + normalization.group_norm(inputs, channels_axis=5) + with self.assertRaisesRegexp(ValueError, + 'Axis is out of bounds.'): + normalization.group_norm(inputs, reduction_axes=[1, 5]) + + def testNotMutuallyExclusiveAxis(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(10, 32, 32, 32)) + # Specify axis with negative values. + with self.assertRaisesRegexp(ValueError, 'mutually exclusive'): + normalization.group_norm(inputs, channels_axis=-2, reduction_axes=[-2]) + # Specify axis with positive values. + with self.assertRaisesRegexp(ValueError, 'mutually exclusive'): + normalization.group_norm(inputs, channels_axis=1, reduction_axes=[1, 3]) + # Specify axis with mixed positive and negative values. + with self.assertRaisesRegexp(ValueError, 'mutually exclusive'): + normalization.group_norm(inputs, channels_axis=-2, reduction_axes=[2]) + + def testUnknownShape(self): + inputs = array_ops.placeholder(dtypes.float32) + with self.assertRaisesRegexp(ValueError, 'undefined rank'): + normalization.group_norm(inputs) + + def testParamsShapeNotFullyDefinedReductionAxes(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 32, None, 4)) + with self.assertRaisesRegexp(ValueError, 'undefined dimensions'): + normalization.group_norm(inputs) + + def testParamsShapeNotFullyDefinedChannelsAxis(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 3, 4, None)) + with self.assertRaisesRegexp(ValueError, 'undefined channel dimension'): + normalization.group_norm(inputs, channels_axis=-1, + reduction_axes=[-3, -2]) + + def testCreateOp(self): + height, width, groups = 3, 3, 4 + images = random_ops.random_uniform((5, height, width, 2*groups), seed=1) + output = normalization.group_norm(images, groups=groups, channels_axis=-1, + reduction_axes=[-3, -2]) + print('name: ', output.op.name) + self.assertListEqual([5, height, width, 2*groups], output.shape.as_list()) + + def testCreateOpFloat64(self): + height, width, groups = 3, 3, 5 + images = random_ops.random_uniform( + (5, height, width, 4*groups), dtype=dtypes.float64, seed=1) + output = normalization.group_norm(images, groups=groups) + self.assertEqual(dtypes.float64, output.dtype) + self.assertListEqual([5, height, width, 4*groups], output.shape.as_list()) + + def testCreateOpNoScaleCenter(self): + height, width, groups = 3, 3, 7 + images = random_ops.random_uniform( + (5, height, width, 3*groups), dtype=dtypes.float32, seed=1) + output = normalization.group_norm(images, groups=groups, center=False, + scale=False) + self.assertListEqual([5, height, width, 3*groups], output.shape.as_list()) + self.assertEqual(0, len(contrib_variables.get_variables_by_name('beta'))) + self.assertEqual(0, len(contrib_variables.get_variables_by_name('gamma'))) + + def testCreateVariables_NHWC(self): + height, width = 3, 3 + images = random_ops.random_uniform((5, height, width, 8), seed=1) + normalization.group_norm(images, groups=4, + channels_axis=-1, reduction_axes=(-3, -2), + center=True, scale=True) + beta = contrib_variables.get_variables_by_name('beta')[0] + gamma = contrib_variables.get_variables_by_name('gamma')[0] + self.assertEqual('GroupNorm/beta', beta.op.name) + self.assertEqual('GroupNorm/gamma', gamma.op.name) + + def testCreateVariables_NCHW(self): + height, width, groups = 3, 3, 4 + images = random_ops.random_uniform((5, 2*groups, height, width), seed=1) + normalization.group_norm(images, groups=4, + channels_axis=-3, reduction_axes=(-2, -1), + center=True, scale=True) + beta = contrib_variables.get_variables_by_name('beta')[0] + gamma = contrib_variables.get_variables_by_name('gamma')[0] + self.assertEqual('GroupNorm/beta', beta.op.name) + self.assertEqual('GroupNorm/gamma', gamma.op.name) + + def testReuseVariables(self): + height, width = 3, 3 + images = random_ops.random_uniform((5, height, width, 4), seed=1) + normalization.group_norm(images, groups=2, scale=True, scope='IN') + normalization.group_norm(images, groups=2, scale=True, scope='IN', + reuse=True) + beta = contrib_variables.get_variables_by_name('beta') + gamma = contrib_variables.get_variables_by_name('gamma') + self.assertEqual(1, len(beta)) + self.assertEqual(1, len(gamma)) + + def testValueCorrectWithReuseVars(self): + height, width = 3, 3 + image_shape = (10, height, width, 4) + images = random_ops.random_uniform(image_shape, seed=1) + output_train = normalization.group_norm(images, groups=2, scope='IN') + output_eval = normalization.group_norm(images, groups=2, scope='IN', + reuse=True) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + # output_train and output_eval should be the same. + train_np, eval_np = sess.run([output_train, output_eval]) + self.assertAllClose(train_np, eval_np) + + def doOutputTest(self, input_shape, channels_axis=None, reduction_axes=None, + groups=2, tol=1e-2): + # Select the axis for the channel and the dimensions along which statistics + # are accumulated. + if channels_axis < 0: + channels_axis += len(input_shape) + reduced_axes = [channels_axis + 1] + for a in reduction_axes: + if a < 0: + a += len(input_shape) + if a < channels_axis: + reduced_axes.append(a) + else: + reduced_axes.append(a+1) + reduced_axes = tuple(reduced_axes) + + # Calculate the final shape for the output Tensor. + axes_before_channels = input_shape[:channels_axis] + axes_after_channels = input_shape[channels_axis+1:] + channels = input_shape[channels_axis] + outputs_shape = (axes_before_channels + [groups, channels // groups] + + axes_after_channels) + + # Calculate the final shape for the output statistics. + reduced_shape = [] + for i, a in enumerate(outputs_shape): + if i not in reduced_axes: + reduced_shape.append(a) + + for mu in (0.0, 1e2): + for sigma in (1.0, 0.1): + # Determine shape of Tensor after normalization. + expected_mean = np.zeros(reduced_shape) + expected_var = np.ones(reduced_shape) + + inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu + output_op = normalization.group_norm( + inputs, groups=groups, center=False, scale=False, + channels_axis=channels_axis, + reduction_axes=reduction_axes) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + outputs = sess.run(output_op) + # Make sure that there are no NaNs + self.assertFalse(np.isnan(outputs).any()) + + outputs = np.reshape(outputs, outputs_shape) + mean = np.mean(outputs, axis=reduced_axes) + var = np.var(outputs, axis=reduced_axes) + # The mean and variance of each example should be close to 0 and 1 + # respectively. + self.assertAllClose(expected_mean, mean, rtol=tol, atol=tol) + self.assertAllClose(expected_var, var, rtol=tol, atol=tol) + + def testOutputSmallInput4D_NHWC(self): + input_shape = [10, 10, 10, 30] + # Specify axes with positive values. + self.doOutputTest(input_shape, channels_axis=3, reduction_axes=[1, 2]) + # Specify axes with negative values. + self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2]) + + def testOutputSmallInput3D_NHWC(self): + input_shape = [10, 10, 30] + # Specify axes with positive values. + self.doOutputTest(input_shape, channels_axis=2, reduction_axes=[0, 1]) + # Specify axes with negative values. + self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2]) + + def testOutputSmallInput4D_NCHW(self): + input_shape = [10, 10, 10, 30] + # Specify axes with positive values. + self.doOutputTest(input_shape, channels_axis=1, reduction_axes=[2, 3]) + # Specify axes with negative values. + self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1]) + + def testOutputSmallInput3D_NCHW(self): + input_shape = [10, 10, 30] + # Specify axes with positive values. + self.doOutputTest(input_shape, channels_axis=0, reduction_axes=[1, 2]) + # Specify axes with negative values. + self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1]) + + def testOutputBigInput4D_NHWC(self): + self.doOutputTest([5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2], + groups=1) + + def testOutputBigInput4D_NCHW(self): + self.doOutputTest([1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3], + groups=4) + + def testOutputSmallInput2D_NC(self): + self.doOutputTest([10, 7*100], channels_axis=1, reduction_axes=[], groups=7) + + def testOutputSmallInput5D_NCXXX(self): + self.doOutputTest([10, 10, 20, 40, 5], + channels_axis=1, + reduction_axes=[2, 3, 4], + groups=5) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py index cdceea6fee5bdb5aeb6537ea55d25ccf107def4c..69d927e1b3001d14dd1af2f890b07c1a57ab2cfc 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers.py +++ b/tensorflow/contrib/layers/python/layers/optimizers.py @@ -41,7 +41,7 @@ OPTIMIZER_CLS_NAMES = { "Adagrad": train.AdagradOptimizer, "Adam": train.AdamOptimizer, "Ftrl": train.FtrlOptimizer, - "Momentum": lambda lr: train.MomentumOptimizer(lr, momentum=0.9), + "Momentum": lambda learning_rate: train.MomentumOptimizer(learning_rate, momentum=0.9), # pylint: disable=line-too-long "RMSProp": train.RMSPropOptimizer, "SGD": train.GradientDescentOptimizer, } diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py index 1ea25bd1a5685eb6f840e621b5739029a660aa0f..a4461a20e54c289886f1a1beb255de12fc054afe 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers_test.py +++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py @@ -61,7 +61,8 @@ class OptimizersTest(test.TestCase): optimizers = [ "SGD", gradient_descent.GradientDescentOptimizer, gradient_descent.GradientDescentOptimizer(learning_rate=0.1), - lambda lr: gradient_descent.GradientDescentOptimizer(learning_rate=lr) + lambda lr: gradient_descent.GradientDescentOptimizer(learning_rate=lr), + "Momentum" ] for optimizer in optimizers: with ops.Graph().as_default() as g: diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py index 123275e1fde047cd3772528641b2e3b09742fbdc..0e35b1aa8bf682c1b4f7e8d974d3e8fad69e33cb 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -29,23 +29,36 @@ from __future__ import print_function import functools import re +import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.framework.python import ops as contrib_framework_ops -from tensorflow.python.framework import function +from tensorflow.python.eager import backprop +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as framework_ops from tensorflow.python.layers import base from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +from tensorflow.python.util import tf_inspect __all__ = ["rev_block", "RevBlock", "recompute_grad"] LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*") +_USE_DEFAULT = "__rev_block_lib_default" +_WRONG_VARS_ERR = """\ +The variables used on recompute were different than the variables originally +used. The function wrapped with @recompute_grad likley creates its own variable +scope with a default name and has been called twice in the same enclosing scope. +To fix, ensure each call to the function happens in its own unique variable +scope. +""" def _acc_grads(*lists_of_grads): @@ -142,7 +155,7 @@ def _scope_wrap(fn, scope): @functools.wraps(fn) def wrap(*args, **kwargs): - with variable_scope.variable_scope(scope): + with variable_scope.variable_scope(scope, use_resource=True): return fn(*args, **kwargs) return wrap @@ -217,91 +230,95 @@ class RevBlock(base.Layer): "build.") self.built = True - def _efficient_grad_fn(self, inputs, variables, ys, grad_ys): - """Custom gradient fn for a block of reversible residual layers.""" - side_inputs = inputs[2:] - f_side_idxs = [None] * len(self.f_side_input) - g_side_idxs = [None] * len(self.g_side_input) - assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input) - - for i, t in enumerate(side_inputs): - if t in self.f_side_input: - f_side_idxs[self.f_side_input.index(t)] = i - elif t in self.g_side_input: - g_side_idxs[self.g_side_input.index(t)] = i - else: - assert False - - f_vars = [[] for _ in range(self.num_layers)] - g_vars = [[] for _ in range(self.num_layers)] - f_vars_idxs = [[] for _ in range(self.num_layers)] - g_vars_idxs = [[] for _ in range(self.num_layers)] - - for i, t in enumerate(variables): - ref = _underlying_variable_ref(t) - - # Use the name to identify the layer number and function (f or g) - regex = LAYER_RE.match(ref.name) - layer_no = int(regex.group(1)) - fn_name = regex.group(2) - if fn_name == "f": - f_vars[layer_no].append(ref) - f_vars_idxs[layer_no].append(i) - else: - assert fn_name == "g" - g_vars[layer_no].append(ref) - g_vars_idxs[layer_no].append(i) - - f_var_grads = [] - g_var_grads = [] - f_side_grads = [] - g_side_grads = [] - - # Reverse variable containers to go backward - f_vars.reverse() - g_vars.reverse() - f = list(self.f) - g = list(self.g) - f.reverse() - g.reverse() - - with variable_scope.variable_scope(self.scope_name, reuse=True): - for i in xrange(self.num_layers): - ys, grad_ys, f_ret, g_ret = _rev_layer_backward( - ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i], - self.g_side_input) - - grad_f_vars, grad_f_side = f_ret - grad_g_vars, grad_g_side = g_ret - f_var_grads.append(grad_f_vars) - g_var_grads.append(grad_g_vars) - f_side_grads.append(grad_f_side) - g_side_grads.append(grad_g_side) - - # Accumulate layer gradients for f_side_input and g_side_input - acc_f_side_grads = _acc_grads(*f_side_grads) - acc_g_side_grads = _acc_grads(*g_side_grads) - - # Use the stored idxs to put gradients in the passed-in order. - side_input_grads = [None] * len(side_inputs) - variable_grads = [None] * len(variables) - - # Variable gradients were collected in reverse layer order. Reverse to match - # idxs. - f_var_grads.reverse() - g_var_grads.reverse() - for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list( - zip(g_vars_idxs, g_var_grads)): - for i, grad in zip(idxs, grads): - variable_grads[i] = grad - - for i, grad in zip(f_side_idxs, acc_f_side_grads): - side_input_grads[i] = grad - for i, grad in zip(g_side_idxs, acc_g_side_grads): - side_input_grads[i] = grad - - grad_x1, grad_x2 = grad_ys - return [grad_x1, grad_x2] + side_input_grads, variable_grads + def _make_efficient_grad_fn(self, inputs_, ys_): + def _efficient_grad_fn(*grad_ys, **kwargs): + """Custom gradient fn for a block of reversible residual layers.""" + inputs = inputs_ + ys = ys_ + variables = kwargs["variables"] + side_inputs = inputs[2:] + + f_side_idxs = [None] * len(self.f_side_input) + g_side_idxs = [None] * len(self.g_side_input) + assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input) + + for i, t in enumerate(side_inputs): + if t in self.f_side_input: + f_side_idxs[self.f_side_input.index(t)] = i + elif t in self.g_side_input: + g_side_idxs[self.g_side_input.index(t)] = i + else: + assert False + + f_vars = [[] for _ in range(self.num_layers)] + g_vars = [[] for _ in range(self.num_layers)] + f_vars_idxs = [[] for _ in range(self.num_layers)] + g_vars_idxs = [[] for _ in range(self.num_layers)] + + for i, ref in enumerate(variables): + # Use the name to identify the layer number and function (f or g) + regex = LAYER_RE.match(ref.name) + layer_no = int(regex.group(1)) + fn_name = regex.group(2) + if fn_name == "f": + f_vars[layer_no].append(ref) + f_vars_idxs[layer_no].append(i) + else: + assert fn_name == "g" + g_vars[layer_no].append(ref) + g_vars_idxs[layer_no].append(i) + + f_var_grads = [] + g_var_grads = [] + f_side_grads = [] + g_side_grads = [] + + # Reverse variable containers to go backward + f_vars.reverse() + g_vars.reverse() + f = list(self.f) + g = list(self.g) + f.reverse() + g.reverse() + + with variable_scope.variable_scope(self.scope_name, reuse=True): + for i in xrange(self.num_layers): + ys, grad_ys, f_ret, g_ret = _rev_layer_backward( + ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i], + self.g_side_input) + + grad_f_vars, grad_f_side = f_ret + grad_g_vars, grad_g_side = g_ret + f_var_grads.append(grad_f_vars) + g_var_grads.append(grad_g_vars) + f_side_grads.append(grad_f_side) + g_side_grads.append(grad_g_side) + + # Accumulate layer gradients for f_side_input and g_side_input + acc_f_side_grads = _acc_grads(*f_side_grads) + acc_g_side_grads = _acc_grads(*g_side_grads) + + # Use the stored idxs to put gradients in the passed-in order. + side_input_grads = [None] * len(side_inputs) + variable_grads = [None] * len(variables) + + # Variable gradients were collected in reverse layer order. Reverse to + # match idxs. + f_var_grads.reverse() + g_var_grads.reverse() + for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list( + zip(g_vars_idxs, g_var_grads)): + for i, grad in zip(idxs, grads): + variable_grads[i] = grad + + for i, grad in zip(f_side_idxs, acc_f_side_grads): + side_input_grads[i] = grad + for i, grad in zip(g_side_idxs, acc_g_side_grads): + side_input_grads[i] = grad + + grad_x1, grad_x2 = grad_ys + return [grad_x1, grad_x2] + side_input_grads, variable_grads + return _efficient_grad_fn def _forward(self, x1, x2): """Run forward through the reversible layers.""" @@ -309,10 +326,6 @@ class RevBlock(base.Layer): side_inputs = [self.f_side_input, self.g_side_input] flat_side_inputs = nest.flatten(side_inputs) - custom_grad_fn = ( - self._efficient_grad_fn if self._use_efficient_backprop else None) - - @_fn_with_custom_grad(custom_grad_fn) def _forward_wrap(x1_, x2_, *flat_side_inputs): f_side, g_side = nest.pack_sequence_as(side_inputs, flat_side_inputs) return _rev_block_forward( @@ -325,7 +338,16 @@ class RevBlock(base.Layer): g_side_input=g_side, gate_outputs=self._use_efficient_backprop) - return _forward_wrap(x1, x2, *flat_side_inputs) + @custom_gradient.custom_gradient + def _forward_with_custom_grad(*args): + out = _forward_wrap(*args) # pylint: disable=no-value-for-parameter + grad_fn = self._make_efficient_grad_fn(args, out) + return out, grad_fn + + if self._use_efficient_backprop: + return _forward_with_custom_grad(x1, x2, *flat_side_inputs) + else: + return _forward_wrap(x1, x2, *flat_side_inputs) def _backward(self, y1, y2): """Run backward through the reversible layers.""" @@ -405,12 +427,49 @@ def rev_block(x1, return block.forward(x1, x2) -def recompute_grad(fn): +def enable_with_args(dec): + """A decorator for decorators to enable their usage with or without args.""" + + @functools.wraps(dec) + def new_dec(*args, **kwargs): + if len(args) == 1 and not kwargs and callable(args[0]): + # Used as decorator without args + fn = args[0] + return dec(fn) + else: + return lambda fn: dec(fn, *args, **kwargs) + + return new_dec + + +@enable_with_args +def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False): """Decorator that recomputes the function on the backwards pass. + To use this function, you must use `ResourceVariable`s (i.e. + `variable_scope(name, use_resource=True), which are the default in Eager mode + and when running on TPU. + + Warning: Because the function will be called again on the backwards pass, the + user should be careful to not use ops in their function that mutate state or + have randomness (for example, batch normalization or dropout). If the function + does have such operations, it is recommended that the function take the + `is_recomputing` keyword argument which will be `False` on the forward pass + and `True` on the backwards pass so that it can disable state changes when + `is_recomputing=True` (for example, not updating the moving averages in batch + normalization). + Args: fn: a function that takes Tensors (all as positional arguments) and returns a tuple of Tensors. + use_data_dep: `bool`, if `True` will use a dummy data dependency to force + the recompute to happen. If `False` will use a control dependency. By + default will be `True` if in an XLA context and `False` otherwise. XLA + ignores control dependencies and so this data dependency is necessary. + tupleize_grads: `bool`, if `True` will use control dependencies to ensure + that all gradients are produced before any are consumed by downstream ops. + If `use_data_dep` is also `True`, will use a data dependency instead of + a control dependency. Returns: A wrapped fn that is identical to fn when called, but its activations will @@ -420,43 +479,82 @@ def recompute_grad(fn): @functools.wraps(fn) def wrapped(*args): - return _recompute_grad(fn, args) + return _recompute_grad( + fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads) return wrapped -def _recompute_grad(fn, args): - """See recompute_grad.""" +def _is_on_tpu(): + ctxt = framework_ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + return control_flow_util.GetContainingXLAContext(ctxt) is not None + - cached_vs = [] - cached_arg_scope = [] - - def grad_fn(inputs, variables, outputs, output_grads): - """Recompute outputs for gradient computation.""" - del outputs - # Recompute outputs - with framework_ops.control_dependencies(output_grads): - with contrib_framework_ops.arg_scope(cached_arg_scope[0]): - with variable_scope.variable_scope(cached_vs[0], reuse=True): - outputs = fn(*inputs) - - if not (isinstance(outputs, list) or isinstance(outputs, tuple)): - outputs = [outputs] - outputs = list(outputs) - grads = gradients_impl.gradients(outputs, inputs + variables, output_grads) - grad_inputs = grads[:len(inputs)] - grad_vars = grads[len(inputs):] - return grad_inputs, grad_vars - - @_fn_with_custom_grad(grad_fn) +def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): + """See recompute_grad.""" + has_is_recompute_kwarg = "is_recomputing" in tf_inspect.getargspec(fn).args + for arg in args: + if not isinstance(arg, framework_ops.Tensor): + raise ValueError("All inputs to function must be Tensors") + use_data_dep_ = use_data_dep + if use_data_dep_ == _USE_DEFAULT: + use_data_dep_ = _is_on_tpu() + + @custom_gradient.custom_gradient def fn_with_recompute(*args): - cached_vs.append(variable_scope.get_variable_scope()) - # TODO(rsepassi): Rm conditional in TF 1.4 - if hasattr(contrib_framework_ops, "current_arg_scope"): - cached_arg_scope.append(contrib_framework_ops.current_arg_scope()) - else: - cached_arg_scope.append({}) - return fn(*args) + """Wrapper for fn.""" + # Forward pass + vs = variable_scope.get_variable_scope() + arg_scope = contrib_framework_ops.current_arg_scope() + with backprop.GradientTape() as tape: + fn_kwargs = {} + if has_is_recompute_kwarg: + fn_kwargs["is_recomputing"] = False + outputs = fn(*args, **fn_kwargs) + original_vars = set(tape.watched_variables()) + + # Backward pass + def grad_fn(*output_grads, **kwargs): + """Recompute outputs for gradient computation.""" + variables = [] + if original_vars: + variables = kwargs["variables"] + if set(variables) != original_vars: + raise ValueError(_WRONG_VARS_ERR) + del kwargs + inputs = list(args) + # Recompute outputs + with framework_ops.control_dependencies(output_grads): + if use_data_dep_: + inputs = _force_data_dependency(output_grads, inputs) + with contrib_framework_ops.arg_scope(arg_scope): + with variable_scope.variable_scope(vs, reuse=True): + with backprop.GradientTape() as tape: + fn_kwargs = {} + if has_is_recompute_kwarg: + fn_kwargs["is_recomputing"] = True + outputs = fn(*inputs, **fn_kwargs) + recompute_vars = set(tape.watched_variables()) + if original_vars != recompute_vars: + raise ValueError(_WRONG_VARS_ERR) + + if not (isinstance(outputs, list) or isinstance(outputs, tuple)): + outputs = [outputs] + outputs = list(outputs) + grads = gradients_impl.gradients(outputs, inputs + variables, + output_grads) + + if tupleize_grads: + if use_data_dep_: + grads = _tuple_with_data_dep(grads) + else: + grads = control_flow_ops.tuple(grads) + + grad_inputs = grads[:len(inputs)] + grad_vars = grads[len(inputs):] + return grad_inputs, grad_vars + + return outputs, grad_fn return fn_with_recompute(*args) @@ -483,101 +581,46 @@ def _underlying_variable_ref(t): return None -def _fn_with_custom_grad(grad_fn, use_global_vars=False): - """Decorator to create a subgraph with a custom gradient function. +def _force_data_dependency(first_compute, then_compute): + """Force all of `then_compute` to depend on all of `first_compute`. - The subgraph created by the decorated function is NOT put in a Defun and so - does not suffer from the limitations of the Defun (all subgraph ops on the - same device, no summaries). + Uses a dummy data dependency, which is useful when running on TPUs because + XLA ignores control dependencies. Only supports float arguments. Args: - grad_fn: function with signature - (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars), - all of which are lists of Tensors. - use_global_vars: if True, variables will be the global variables created. - If False, will be the trainable variables. + first_compute: `list`. These will be made to run before the + `Tensor`s `then_compute`. + then_compute: `list`. These will run after all the `Tensor`s in + `first_compute`. Returns: - Decorator for function such that the gradient is defined by grad_fn. - """ - - def dec(fn): - - @functools.wraps(fn) - def wrapped(*args): - return _fn_with_custom_grad_internal( - fn, args, grad_fn, use_global_vars=use_global_vars) + `list`, same length as `then_compute`. - return wrapped - - return dec - - -def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False): - """Create a subgraph with a custom gradient. - - Args: - fn: function that takes inputs as arguments and produces 1 or more Tensors. - inputs: list, will be passed as fn(*inputs). - grad_fn: function with signature - (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars), - all of which are lists of Tensors. - use_global_vars: if True, variables will be the global variables created. - If False, will be the trainable variables. - - Returns: - fn(*inputs) + Raises: + ValueError: if ranks are unknown or types are not floating. """ - vs = variable_scope.get_variable_scope() - get_vars_fn = ( - vs.global_variables if use_global_vars else vs.trainable_variables) - len_before_vars = len(get_vars_fn()) - inputs = list(inputs) - outputs = fn(*inputs) - train_vars = get_vars_fn()[len_before_vars:] - - if grad_fn is None: - return outputs - - if not (isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = [outputs] - outputs = list(outputs) - - defun_inputs = [inputs, train_vars, outputs] - - def custom_grad_fn(op, *dys): - """Custom grad fn applying grad_fn for identity Defun.""" - fn_inputs, fn_vars, fn_outputs = nest.pack_sequence_as( - defun_inputs, list(op.inputs)) - dys = list(dys) - assert len(fn_outputs) == len(outputs) - assert len(fn_outputs) == len(dys) - - grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys) - grad_outputs = [None] * len(fn_outputs) - return tuple(grad_inputs + grad_vars + grad_outputs) - - # The Defun takes as input the original inputs, the trainable variables - # created in fn, and the outputs. In the forward it passes through the - # outputs. In the backwards, it produces gradients for the original inputs - # and the trainable variables. - in_types = [t.dtype for t in inputs] - out_types = [t.dtype for t in outputs] - var_types = [t.dtype for t in train_vars] - - # Get a unique name for the Defun - with framework_ops.name_scope("identity_custom_grad") as ns: - defun_name = ns - - @function.Defun( - *(in_types + var_types + out_types), - func_name=defun_name, - python_grad_func=custom_grad_fn, - shape_func=lambda _: [t.get_shape() for t in outputs]) - def identity(*args): - _, _, outs = nest.pack_sequence_as(defun_inputs, args) - return tuple([array_ops.identity(t) for t in outs]) - - flat_inputs = nest.flatten(defun_inputs) - id_out = identity(*flat_inputs) - return id_out + + def _first_element(x): + if x.get_shape().ndims is None: + raise ValueError("Rank of Tensor %s must be known" % x) + ndims = x.get_shape().ndims + begin = framework_ops.convert_to_tensor([0] * ndims, dtype=dtypes.int32) + size = framework_ops.convert_to_tensor([1] * ndims, dtype=dtypes.int32) + return array_ops.reshape(array_ops.slice(x, begin, size), []) + + first_compute_sum = math_ops.add_n( + [_first_element(x) for x in first_compute if x is not None]) + dtype = first_compute_sum.dtype + if not dtype.is_floating: + raise ValueError("_force_data_dependency only supports floating dtypes.") + epsilon = np.finfo(dtype.as_numpy_dtype).tiny + zero = array_ops.stop_gradient(epsilon * first_compute_sum) + + return [ + array_ops.identity(x) + zero if x is not None else None + for x in then_compute + ] + + +def _tuple_with_data_dep(tensors): + return _force_data_dependency(tensors, tensors) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index cbcbcd75114a522b95631e4e7e95c1641b0a9987..bc09ba8d439808c1582f207a99504012afcf33a6 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -21,9 +21,11 @@ from __future__ import print_function from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.layers.python.layers import rev_block_lib from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.layers import convolutional from tensorflow.python.layers import core as core_layers +from tensorflow.python.layers import normalization as normalization_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops @@ -60,8 +62,8 @@ class RevBlockTest(test.TestCase): sess.run(variables.global_variables_initializer()) x1, x2, x1_inv, x2_inv = sess.run([x1, x2, x1_inv, x2_inv]) - self.assertAllClose(x1, x1_inv) - self.assertAllClose(x2, x2_inv) + self.assertAllClose(x1, x1_inv, atol=1e-5) + self.assertAllClose(x2, x2_inv, atol=1e-5) def testBackwardForward(self): @@ -83,8 +85,8 @@ class RevBlockTest(test.TestCase): sess.run(variables.global_variables_initializer()) y1, y2, y1_inv, y2_inv = sess.run([y1, y2, y1_inv, y2_inv]) - self.assertAllClose(y1, y1_inv) - self.assertAllClose(y2, y2_inv) + self.assertAllClose(y1, y1_inv, rtol=1e-5) + self.assertAllClose(y2, y2_inv, rtol=1e-5) def _testRevBlock(self, x=None, @@ -154,7 +156,7 @@ class RevBlockTest(test.TestCase): y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads]) self.assertAllClose(y_val, yd_val) for g1, g2 in zip(gd_val, g_val): - self.assertAllClose(g1, g2) + self.assertAllClose(g1, g2, rtol=1e-5) def testRevBlock(self): self._testRevBlock() @@ -179,18 +181,16 @@ class RevBlockTest(test.TestCase): self._testRevBlock(f=[f1, f2, f1, f2]) - # TODO(rsepassi): Recent change to conv seems to have broken this test. Find - # out why. - def _testConvAndBatchNorm(self): + def testConvAndBatchNorm(self): x = random_ops.random_uniform( [self.BATCH_SIZE, 10, self.CHANNELS], dtype=dtypes.float32) def f(x): x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same") - x = layers.batch_norm(x, is_training=True) + x = layers.batch_norm(x, is_training=False) x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same") - x = layers.batch_norm(x, is_training=True) + x = layers.batch_norm(x, is_training=False) return x self._testRevBlock(x=x, f=f) @@ -255,109 +255,122 @@ class RecomputeTest(test.TestCase): def fn_recompute(x): return fn(x) + @rev_block_lib.recompute_grad(use_data_dep=True) + def fn_use_data_dep(x): + return fn(x) + + @rev_block_lib.recompute_grad(tupleize_grads=True) + def fn_tupleize(x): + return fn(x) + + @rev_block_lib.recompute_grad(use_data_dep=True, tupleize_grads=True) + def fn_both(x): + return fn(x) + x = random_ops.random_uniform((3, 1, 3)) - recompute_vars = None - with variable_scope.variable_scope("recompute") as vs: - out1 = math_ops.reduce_sum(fn_recompute(x)) - recompute_vars = vs.trainable_variables() - reg_vars = None - with variable_scope.variable_scope("regular") as vs: - out2 = math_ops.reduce_sum(fn(x)) - reg_vars = vs.trainable_variables() - - grad1 = gradients_impl.gradients(out1, recompute_vars) - grad2 = gradients_impl.gradients(out2, reg_vars) + + names_and_fns = [ + ("recompute", fn_recompute), + ("regular", fn), + ("use_data_dep", fn_use_data_dep), + ("tupleize", fn_tupleize), + ("tuple_and_data_dep", fn_both), + ] + outputs_and_vars = [] + for name, wrapped_fn in names_and_fns: + with variable_scope.variable_scope(name, use_resource=True) as vs: + out = math_ops.reduce_sum(wrapped_fn(x)) + outputs_and_vars.append((out, vs.trainable_variables())) + + all_grads = [] + for out, scope_vars in outputs_and_vars: + all_grads.append(gradients_impl.gradients(out, scope_vars)) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) - outs = sess.run([out1, out2, grad1, grad2]) - self.assertAllClose(outs[0], outs[1]) - for g1, g2 in zip(outs[2], outs[3]): - self.assertAllClose(g1, g2) + outputs = list(zip(*outputs_and_vars))[0] + outs, all_grads_val = sess.run([outputs, all_grads]) + # All outputs are the same + current = outs[0] + for out in outs[1:]: + self.assertAllClose(current, out) + current = out -class FnWithCustomGradTest(test.TestCase): + # All gradients are the same + for grads in zip(all_grads_val): + current = grads[0] + for g in grads[1:]: + self.assertAllClose(current, g) + current = g - def testCorrectness(self): + def testDoubleCallInSameScopeFails(self): - w = random_ops.random_uniform([6, 10]) + @rev_block_lib.recompute_grad + def layer_with_recompute(inputs): + return core_layers.dense(inputs, 2) - def fn(a, b, c): - return core_layers.dense( - a, - 10, - use_bias=False, - kernel_initializer=lambda shape, dtype, partition_info: w - ) + math_ops.matmul(b, c) - - def grad_fn(inputs, trainable_variables, outputs, grad_outputs): - outputs = outputs[0] - grad_outputs = grad_outputs[0] - grad_inputs = gradients_impl.gradients( - outputs, inputs, grad_ys=grad_outputs) - grad_vars = gradients_impl.gradients( - outputs, trainable_variables, grad_ys=grad_outputs) - return grad_inputs, grad_vars - - custom_fn = rev_block_lib._fn_with_custom_grad(grad_fn)(fn) - - a = random_ops.random_uniform([11, 6]) - b = random_ops.random_uniform([11, 7]) - c = random_ops.random_uniform([7, 10]) - - out = fn(a, b, c) - custom_out = custom_fn(a, b, c) - self.assertEqual(out.get_shape().as_list(), - custom_out.get_shape().as_list()) - - loss = math_ops.reduce_mean(out) - custom_loss = math_ops.reduce_mean(custom_out) - - grads = gradients_impl.gradients( - loss, [a, b, c] + [variables.trainable_variables()[0]]) - custom_grads = gradients_impl.gradients( - custom_loss, [a, b, c] + [variables.trainable_variables()[1]]) + with variable_scope.variable_scope("layer", use_resource=True): + inputs = array_ops.ones((2, 4), dtypes.float32) + out1 = layer_with_recompute(inputs) + out2 = layer_with_recompute(inputs) + out1 + out = math_ops.reduce_sum(out2) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - out_val, custom_out_val, grads_val, custom_grads_val = sess.run( - [out, custom_out, grads, custom_grads]) - self.assertAllClose(out_val, custom_out_val) - for g1, g2 in zip(grads_val, custom_grads_val): - self.assertAllClose(g1, g2) - - def testCustomGrad(self): - - def fn(a, b, c): - return core_layers.dense(a, 10, use_bias=False) + math_ops.matmul(b, c) - - def grad_fn(inputs, trainable_variables, unused_outputs, - unused_grad_outputs): - grad_inputs = [ - array_ops.ones_like(t) * (i + 1.) for i, t in enumerate(inputs) - ] - grad_vars = [ - array_ops.ones_like(t) * (i + len(inputs) + 1.) - for i, t in enumerate(trainable_variables) - ] - return grad_inputs, grad_vars - - a = random_ops.random_uniform([11, 6]) - b = random_ops.random_uniform([11, 7]) - c = random_ops.random_uniform([7, 10]) - w = random_ops.random_uniform([6, 10]) - out = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)(a, b, c) - loss = math_ops.reduce_mean(out) - grads = gradients_impl.gradients( - loss, [a, b, c, variables.trainable_variables()[0]]) - expected_grads = [ - array_ops.ones_like(t) * (i + 1.) for i, t in enumerate([a, b, c, w]) - ] - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - g_val, eg_val = sess.run([grads, expected_grads]) - for g1, g2 in zip(g_val, eg_val): - self.assertAllClose(g1, g2) + tvars = variables.trainable_variables() + assert len(tvars) == 4 + with self.assertRaisesWithPredicateMatch( + ValueError, "called twice in the same enclosing scope"): + gradients_impl.gradients(out, [inputs] + tvars) + + def testDoubleCallInUniqueScope(self): + + @rev_block_lib.recompute_grad + def layer_with_recompute(inputs): + with variable_scope.variable_scope("inner", use_resource=True): + return core_layers.dense(inputs, 2) + + with variable_scope.variable_scope("layer", use_resource=True): + inputs = array_ops.ones((2, 4), dtypes.float32) + + with variable_scope.variable_scope("layer1", use_resource=True): + out1 = layer_with_recompute(inputs) + with variable_scope.variable_scope("layer2", use_resource=True): + out2 = layer_with_recompute(inputs) + out1 + out = math_ops.reduce_sum(out2) + + tvars = variables.trainable_variables() + assert len(tvars) == 4 + grads = gradients_impl.gradients(out, [inputs] + tvars) + for grad in grads: + self.assertTrue(grad is not None) + + def testWithIsRecomputeKwarg(self): + + kwarg_values = [] + + @rev_block_lib.recompute_grad + def layer_with_recompute(inputs, is_recomputing=False): + kwarg_values.append(is_recomputing) + out = core_layers.dense(inputs, 2) + out = normalization_layers.batch_normalization(out, training=True) + if is_recomputing: + # Ensure that the updates are not duplicated by popping off the latest + # 2 additions. + update_ops = ops.get_collection_ref(ops.GraphKeys.UPDATE_OPS) + update_ops.pop() + update_ops.pop() + return out + + x = array_ops.ones((2, 4), dtypes.float32) + with variable_scope.variable_scope("layer1", use_resource=True): + y = layer_with_recompute(x) + loss = math_ops.reduce_sum(y) + tvars = variables.trainable_variables() + gradients_impl.gradients(loss, [x] + tvars) + + update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS) + self.assertEqual(2, len(update_ops)) + self.assertEqual([False, True], kwarg_values) if __name__ == "__main__": diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py index 3e639a180ef11af5f7f498c647eb25417f918eb9..69bb6be81453f5f5487f25547f017dc5f87c2f2c 100644 --- a/tensorflow/contrib/layers/python/layers/target_column.py +++ b/tensorflow/contrib/layers/python/layers/target_column.py @@ -270,7 +270,7 @@ class _RegressionTargetColumn(_TargetColumn): def logits_to_predictions(self, logits, proba=False): if self.num_label_columns == 1: - return array_ops.squeeze(logits, squeeze_dims=[1]) + return array_ops.squeeze(logits, axis=[1]) return logits def get_eval_ops(self, features, logits, labels, metrics=None): @@ -418,7 +418,7 @@ def _softmax_cross_entropy_loss(logits, target): "Instead got %s." % target.dtype) # sparse_softmax_cross_entropy_with_logits requires [batch_size] target. if len(target.get_shape()) == 2: - target = array_ops.squeeze(target, squeeze_dims=[1]) + target = array_ops.squeeze(target, axis=[1]) loss_vec = nn.sparse_softmax_cross_entropy_with_logits( labels=target, logits=logits) return loss_vec diff --git a/tensorflow/contrib/layers/python/layers/utils_test.py b/tensorflow/contrib/layers/python/layers/utils_test.py index 3409860add8f8c393ffd342633e7023931867dd9..645dc1291eb6370a5e504306fc00a5454dde77ed 100644 --- a/tensorflow/contrib/layers/python/layers/utils_test.py +++ b/tensorflow/contrib/layers/python/layers/utils_test.py @@ -294,7 +294,6 @@ class NPositiveIntegersTest(test.TestCase): self.assertEqual(utils.n_positive_integers(2, 2), (2, 2)) self.assertEqual(utils.n_positive_integers(2, (2, 3)), (2, 3)) self.assertEqual(utils.n_positive_integers(3, (2, 3, 1)), (2, 3, 1)) - self.assertEqual(utils.n_positive_integers(3, (2, 3, 1)), (2, 3, 1)) self.assertEqual( utils.n_positive_integers(3, tensor_shape.TensorShape([2, 3, 1])), (2, 3, 1)) diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index abf6e393bb0fbbce4e43f6d209e9b30517df36c3..b56a88659bbd4467600788fc8e3e9dbf38ce8244 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -5,6 +5,8 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) +load("//tensorflow:tensorflow.bzl", "py_test") + package(default_visibility = [ "//engedu/ml/tf_from_scratch:__pkg__", "//tensorflow:internal", @@ -115,6 +117,7 @@ py_test( size = "small", srcs = ["python/learn/learn_io/data_feeder_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":learn", "//tensorflow/python:client_testlib", @@ -170,6 +173,7 @@ tf_py_test( "//tensorflow/python:variables", "//tensorflow/python/estimator", ], + tags = ["no_windows"], # TODO: needs investigation on Windows ) py_test( @@ -188,6 +192,7 @@ py_test( size = "small", srcs = ["python/learn/graph_actions_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":learn", "//tensorflow/contrib/framework:framework_py", @@ -224,6 +229,7 @@ py_test( size = "small", srcs = ["python/learn/monitors_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip_gpu"], # b/74437598 deps = [ ":learn", "//tensorflow/contrib/framework:framework_py", @@ -275,7 +281,11 @@ py_test( size = "medium", srcs = ["python/learn/estimators/estimator_test.py"], srcs_version = "PY2AND3", - tags = ["manual"], + tags = [ + "manual", + "noasan", # times out + "optonly", # test is flaky without optimization. + ], deps = [ ":learn", "//tensorflow/contrib/framework:framework_py", @@ -425,7 +435,12 @@ py_test( name = "kmeans_test", size = "medium", srcs = ["python/learn/estimators/kmeans_test.py"], + shard_count = 4, srcs_version = "PY2AND3", + tags = [ + "noasan", # b/73741358 + "nomac", + ], deps = [ ":learn", "//tensorflow/python:array_ops", @@ -472,6 +487,7 @@ py_test( name = "state_saving_rnn_estimator_test", size = "medium", srcs = ["python/learn/estimators/state_saving_rnn_estimator_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = ["noasan"], deps = [ @@ -584,6 +600,7 @@ py_test( size = "small", srcs = ["python/learn/learn_io/io_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":learn", "//tensorflow/contrib/learn/python/learn/datasets", @@ -730,7 +747,7 @@ py_test( tf_py_test( name = "graph_io_test", - size = "small", + size = "medium", srcs = ["python/learn/learn_io/graph_io_test.py"], additional_deps = [ ":learn", @@ -813,6 +830,7 @@ py_test( size = "small", srcs = ["python/learn/utils/saved_model_export_utils_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", @@ -867,15 +885,3 @@ py_binary( "//tensorflow/python:platform", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/learn/README.md b/tensorflow/contrib/learn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d516bffc5e0327a3400068b35de5503e5a925a54 --- /dev/null +++ b/tensorflow/contrib/learn/README.md @@ -0,0 +1,143 @@ +EVERYTHING IN THIS DIRECTORY IS DEPRECATED. + +Using functions or classes will result in warnings. + +Instructions for converting to current alternatives are included in the +warnings. A high-level overview is below. + +## Canned Estimators + +Many canned estimators (subclasses of `Estimator`) have equivalents in core: +`DNNClassifier`, `DNNRegressor`, `DNNEstimator`, `LinearClassifier`, +`LinearRegressor`, `DNNLinearCombinedClassifier` and +`DNNLinearCombinedRegressor`. They are exposed under `tf.estimator`. +`DNNEstimator`, `LinearEstimator` and `DNNLinearCombinedEstimator` +are exposed under `tf.contrib.estimator`. + +To migrate to the new api, users need to take the following steps: + +* Replace `tf.contrib.learn` with `tf.estimator`. +* If you subclass any of the estimators, stop doing that. You should be able to + write a factory method that returns a canned estimator instead. If this is not + possible (if you override methods from the canned estimator), consider writing + a custom estimator instead. See `tf.estimator.Estimator`. +* Set `loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE` to preserve loss + reduction as the average over batch. +* Some optimizer-related arguments are no longer passed in the estimator + constructor. Instead, we provide methods that perform the same job by wrapping + an optimizer. Specifically: + * `gradient_clip_norm`: Use `tf.contrib.estimator.clip_gradients_by_norm` + * `embedding_lr_multipliers`: Not supported. + Other arguments: + * `input_layer_min_slice_size`: Replaced by `input_layer_partitioner` + * `enable_centered_bias`: Not supported. Dropping this argument is unlikely to + harm your model. + * `feature_engineering_fn`: Not supported. You can call your + `feature_engineering_fn` inside your input_fn: + ```python + def new_input_fn(): + features, labels = old_input_fn() + return feature_engineering_fn(features, labels) + ``` +* Use `tf.reshape` to reshape labels in your `input_fn`. `tf.estimator` + classifiers and regressors expect labels as a 2D Tensor of shape + `[batch_size, 1]`, or `[batch_size, n_labels]`. In contrast, + `tf.contrib.learn` classifiers and regressors supported labels with shape + `[batch_size]`. +* If you pass custom metrics from the `evaluate()` method call, use + `tf.contrib.estimator.add_metrics`. +* Replace your `serving_input_fn` with a `serving_input_receiver_fn`. + Note this should be entirely distinct from your training `input_fn`, so if you + previously had one `input_fn` with different "modes", you should now factor + that apart. Where the former returned either a simple `(features, labels)` + tuple or `InputFnOps`, you should now return a `ServingInputReceiver`. + If you were generating your `serving_input_fn` using the + `build_parsing_serving_input_fn` helper, you can simply drop in the + replacement `build_parsing_serving_input_receiver_fn`. + +Some remaining estimators/classes: + +* `DynamicRnnEstimator`: Consider a custom `model_fn`. +* `KMeansClustering`: Use `tf.contrib.factorization.KMeansClustering`. +* `LogisticRegressor`: Not supported. Instead, use `binary_classification_head` + with a custom `model_fn`, or with `DNNEstimator`. +* `StateSavingRnnEstimator`: Consider a custom `model_fn`. +* SVM: Consider a custom `model_fn`. +* `LinearComposableModel` and `DNNComposableModel`: Not supported. + Consider `tf.contrib.estimator.DNNEstimator`, or write a custom model_fn. +* `MetricSpec`: Deprecated. For adding custom metrics to canned Estimators, use + `tf.contrib.estimator.add_metrics`. + +## Estimator +`tf.contrib.learn.Estimator` is migrated to `tf.estimator.Estimator`. + +To migrate, users need to take the following steps: + +* Replace `tf.contrib.learn.Estimator` with `tf.estimator.Estimator`. +* If you pass a `config` argument to `Estimator`, this must be + `tf.estimator.RunConfig`. You may need to edit your code accordingly. +* Edit your `model_fn` to return `tf.estimator.EstimatorSpec`. Refer to + `EstimatorSpec` for documentation of specific fields. +* If your `model_fn` uses the `mode` argument, use `tf.estimator.ModeKeys`. + +Some related classes: +* `Evaluable`, `Trainable`: Not supported, merged into `tf.estimator.Estimator`. +* ExportStrategy: Replaced by `tf.estimator.Exporter`. + +## Head/MultiHead +These classes are now supported under `tf.contrib.estimator`, e.g. +`tf.contrib.estimator.multi_class_head` and `tf.contrib.estimator.multi_head`. + +Some differences: + +* `multi_class_head`: If you use `tf.contrib.learn.multi_class_head` with + `n_classes=2`, switch to `tf.contrib.estimator.binary_classification_head`. +* `loss_only_head`: Not supported. +* `poisson_regression_head`: Not supported (yet). +* `binary_svm_head`: Not supported (yet). +* `no_op_train_fn`: Replace it with `tf.no_op`. + +Some arguments are renamed, please refer to documentation. In addition: + +* `loss_fn`: Supported for `multi_label_head`. If you need it for other heads, + please open an issue. +* `metric_class_ids`: Not supported (yet). +* `enable_centered_bias`: Not supported. Dropping this argument is unlikely to + harm your model. +* `label_name`: Not needed in `tf.estimator`. If you don’t use `multi_head`, + drop this argument. If you use `multi_head`, refer to + `tf.contrib.estimator.multi_head` documentation. + +## Experiment Class - Distributed Training Tooling + +Switch to `tf.estimator.train_and_evaluate`. Some differences: + +* Most of the constructor arguments, like `train_input_fn`, `eval_input_fn`, + should be wrapped into `tf.estimator.TrainSpec` and `tf.estimator.EvalSpec`. +* Remove the `experiment_fn`. Instead, create the `Estimator`, + `train_spec` and `eval_spec`, then call `tf.estimator.train_and_evaluate` + directly. +* Inside `tf.estimator.EvalSpec`, the `exporter` field is the replacement + for `export_strategy`. To be precise, `tf.estimator.LatestExporter` is the + replacement for `tf.contrib.learn.make_export_strategy`. If you want to export + only at the end of training use `tf.estimator.FinalExporter`. +* If the `TF_CONFIG` environment variable is constructed manually, please read + the `train_and_evaluate` documentation for the new requirementds (in + particular, the chief node and evaluator node). + +## Others Classes and Functions + +* `tf.contrib.learn.datasets` is deprecated. We are adding ready to use datasets + to tensorflow/models. Many smaller datasets are available from other sources, + such as scikits.learn. Some Python processing may have to be written, but this + is straightforward to implement using the standard modules. +* `tf.contrib.learn.preprocessing`: Deprecated. The python-only preprocessing + functions are not a good fit for TensorFlow. Please use `tf.data`, and + consider tensorflow/transform for more complex use cases. +* `tf.contrib.learn.models`: Not supported, use canned estimators instead. +* `tf.contrib.learn.monitors`: Implement `SessionRunHook` instead. Hook + implementations are in `tf.train`. +* `tf.contrib.learn.learn_io`: Use the methods in `tf.estimator.inputs`, such as + `tf.estimator.inputs.numpy_input_fn`. Some utility functions have no + equivalent, we encourage the use of `tf.data`. + diff --git a/tensorflow/contrib/learn/__init__.py b/tensorflow/contrib/learn/__init__.py index 3698af027e38f1063ad829c26eb179734968f813..79bd73faaf1301a2fc4999b64f88d30542577980 100644 --- a/tensorflow/contrib/learn/__init__.py +++ b/tensorflow/contrib/learn/__init__.py @@ -13,8 +13,11 @@ # limitations under the License. # ============================================================================== -# TODO(ptucker,ipolosukhin): Improve descriptions. -"""High level API for learning. +"""High level API for learning (DEPRECATED). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. See the @{$python/contrib.learn} guide. diff --git a/tensorflow/contrib/learn/python/__init__.py b/tensorflow/contrib/learn/python/__init__.py index bbebd5ab9792cb937219cf937f08c4d4e6e44a92..df23aeb2c433c2b4392f706730f715246ce01cea 100644 --- a/tensorflow/contrib/learn/python/__init__.py +++ b/tensorflow/contrib/learn/python/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""High level API for learning with TensorFlow.""" +"""High level API for learning with TensorFlow (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/__init__.py b/tensorflow/contrib/learn/python/learn/__init__.py index cdc67c77d5fd1df61016835dc75ba44feb458cf9..76e0e8ac8f19026086959f3b197cfd1a81e65a3e 100644 --- a/tensorflow/contrib/learn/python/learn/__init__.py +++ b/tensorflow/contrib/learn/python/learn/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""High level API for learning with TensorFlow.""" +"""High level API for learning with TensorFlow (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py b/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py index 2284ec46e971731af74f17678fc0d1d3888419e2..fed1c44d1970bf07c808ace817aa9972d7776d88 100644 --- a/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py +++ b/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py @@ -12,20 +12,47 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Some common SessionRunHook classes.""" +"""Some common SessionRunHook classes (deprected). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.util.deprecation import deprecated_alias # pylint: disable=invalid-name -LoggingTensorHook = basic_session_run_hooks.LoggingTensorHook -StopAtStepHook = basic_session_run_hooks.StopAtStepHook -CheckpointSaverHook = basic_session_run_hooks.CheckpointSaverHook -StepCounterHook = basic_session_run_hooks.StepCounterHook -NanLossDuringTrainingError = basic_session_run_hooks.NanLossDuringTrainingError -NanTensorHook = basic_session_run_hooks.NanTensorHook -SummarySaverHook = basic_session_run_hooks.SummarySaverHook +LoggingTensorHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.LoggingTensorHook', + 'tf.train.LoggingTensorHook', + basic_session_run_hooks.LoggingTensorHook) +StopAtStepHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.StopAtStepHook', + 'tf.train.StopAtStepHook', + basic_session_run_hooks.StopAtStepHook) +CheckpointSaverHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.CheckpointSaverHook', + 'tf.train.CheckpointSaverHook', + basic_session_run_hooks.CheckpointSaverHook) +StepCounterHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.StepCounterHook', + 'tf.train.StepCounterHook', + basic_session_run_hooks.StepCounterHook) +NanLossDuringTrainingError = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.NanLossDuringTrainingError', + 'tf.train.NanLossDuringTrainingError', + basic_session_run_hooks.NanLossDuringTrainingError) +NanTensorHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.NanTensorHook', + 'tf.train.NanTensorHook', + basic_session_run_hooks.NanTensorHook) +SummarySaverHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.SummarySaverHook', + 'tf.train.SummarySaverHook', + basic_session_run_hooks.SummarySaverHook) # pylint: enable=invalid-name diff --git a/tensorflow/contrib/learn/python/learn/datasets/BUILD b/tensorflow/contrib/learn/python/learn/datasets/BUILD index 8bf372841d04dc9e1339925474801d5aa3af4ccd..2c7215bba3816ff3762e5b7927f650d1c9cbf617 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/BUILD +++ b/tensorflow/contrib/learn/python/learn/datasets/BUILD @@ -44,18 +44,6 @@ py_binary( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - py_test( name = "base_test", size = "small", diff --git a/tensorflow/contrib/learn/python/learn/datasets/__init__.py b/tensorflow/contrib/learn/python/learn/datasets/__init__.py index 7240b0de149051afa045a8113f9e9b212840c311..3c34712ac859d32f549468345950a93d2ed2aa56 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/__init__.py +++ b/tensorflow/contrib/learn/python/learn/datasets/__init__.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Dataset utilities and synthetic/reference datasets.""" +"""Dataset utilities and synthetic/reference datasets (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -27,6 +32,7 @@ from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.learn.python.learn.datasets import mnist from tensorflow.contrib.learn.python.learn.datasets import synthetic from tensorflow.contrib.learn.python.learn.datasets import text_datasets +from tensorflow.python.util.deprecation import deprecated # Export load_iris and load_boston. load_iris = base.load_iris @@ -51,6 +57,7 @@ SYNTHETIC = { } +@deprecated(None, 'Please use tf.data.') def load_dataset(name, size='small', test_with_fake_data=False): """Loads dataset by name. @@ -73,8 +80,9 @@ def load_dataset(name, size='small', test_with_fake_data=False): return DATASETS[name]() +@deprecated(None, 'Please use tf.data.') def make_dataset(name, n_samples=100, noise=None, seed=42, *args, **kwargs): - """Creates binary synthetic datasets + """Creates binary synthetic datasets. Args: name: str, name of the dataset to generate diff --git a/tensorflow/contrib/learn/python/learn/datasets/base.py b/tensorflow/contrib/learn/python/learn/datasets/base.py index ca720ae5ed26e74da12bd6c5a37231b41442f76f..4676eedb206147d178c6a652aa7c2cb48ef888c0 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/base.py +++ b/tensorflow/contrib/learn/python/learn/datasets/base.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Base utilities for loading datasets.""" + +"""Base utilities for loading datasets (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -29,11 +35,14 @@ import numpy as np from six.moves import urllib from tensorflow.python.platform import gfile +from tensorflow.python.util.deprecation import deprecated + Dataset = collections.namedtuple('Dataset', ['data', 'target']) Datasets = collections.namedtuple('Datasets', ['train', 'validation', 'test']) +@deprecated(None, 'Use tf.data instead.') def load_csv_with_header(filename, target_dtype, features_dtype, @@ -53,6 +62,7 @@ def load_csv_with_header(filename, return Dataset(data=data, target=target) +@deprecated(None, 'Use tf.data instead.') def load_csv_without_header(filename, target_dtype, features_dtype, @@ -70,6 +80,7 @@ def load_csv_without_header(filename, return Dataset(data=data, target=target) +@deprecated(None, 'Use tf.data instead.') def shrink_csv(filename, ratio): """Create a smaller dataset of only 1/ratio of original data.""" filename_small = filename.replace('.', '_small.') @@ -84,6 +95,7 @@ def shrink_csv(filename, ratio): i += 1 +@deprecated(None, 'Use scikits.learn.datasets.') def load_iris(data_path=None): """Load Iris dataset. @@ -100,6 +112,7 @@ def load_iris(data_path=None): data_path, target_dtype=np.int, features_dtype=np.float) +@deprecated(None, 'Use scikits.learn.datasets.') def load_boston(data_path=None): """Load Boston housing dataset. @@ -116,20 +129,58 @@ def load_boston(data_path=None): data_path, target_dtype=np.float, features_dtype=np.float) -def retry(initial_delay, max_delay, factor=2.0, jitter=0.25, is_retriable=None): +@deprecated(None, 'Use the retry module or similar alternatives.') +def retry(initial_delay, + max_delay, + factor=2.0, + jitter=0.25, + is_retriable=None): """Simple decorator for wrapping retriable functions. Args: initial_delay: the initial delay. + max_delay: the maximum delay allowed (actual max is + max_delay * (1 + jitter). factor: each subsequent retry, the delay is multiplied by this value. (must be >= 1). jitter: to avoid lockstep, the returned delay is multiplied by a random number between (1-jitter) and (1+jitter). To add a 20% jitter, set jitter = 0.2. Must be < 1. + is_retriable: (optional) a function that takes an Exception as an argument + and returns true if retry should be applied. + + Returns: + A function that wraps another function to automatically retry it. + """ + return _internal_retry( + initial_delay=initial_delay, + max_delay=max_delay, + factor=factor, + jitter=jitter, + is_retriable=is_retriable) + + +def _internal_retry(initial_delay, + max_delay, + factor=2.0, + jitter=0.25, + is_retriable=None): + """Simple decorator for wrapping retriable functions, for internal use only. + + Args: + initial_delay: the initial delay. max_delay: the maximum delay allowed (actual max is max_delay * (1 + jitter). + factor: each subsequent retry, the delay is multiplied by this value. + (must be >= 1). + jitter: to avoid lockstep, the returned delay is multiplied by a random + number between (1-jitter) and (1+jitter). To add a 20% jitter, set + jitter = 0.2. Must be < 1. is_retriable: (optional) a function that takes an Exception as an argument and returns true if retry should be applied. + + Returns: + A function that wraps another function to automatically retry it. """ if factor < 1: raise ValueError('factor must be >= 1; was %f' % (factor,)) @@ -152,7 +203,7 @@ def retry(initial_delay, max_delay, factor=2.0, jitter=0.25, is_retriable=None): for delay in delays(): try: return fn(*args, **kwargs) - except Exception as e: # pylint: disable=broad-except) + except Exception as e: # pylint: disable=broad-except if is_retriable is None: continue @@ -176,11 +227,13 @@ def _is_retriable(e): return isinstance(e, IOError) and e.errno in _RETRIABLE_ERRNOS -@retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable) +@deprecated(None, 'Please use urllib or similar directly.') +@_internal_retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable) def urlretrieve_with_retry(url, filename=None): return urllib.request.urlretrieve(url, filename) +@deprecated(None, 'Please write your own downloading logic.') def maybe_download(filename, work_directory, source_url): """Download the data from source url, unless it's already here. diff --git a/tensorflow/contrib/learn/python/learn/datasets/mnist.py b/tensorflow/contrib/learn/python/learn/datasets/mnist.py index 37f9175015a239f763c7721cf36ab8063c0a3e32..abbb44c2f5b701829ce16f64eadd8ebc04c84e2c 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/mnist.py +++ b/tensorflow/contrib/learn/python/learn/datasets/mnist.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functions for downloading and reading MNIST data.""" +"""Functions for downloading and reading MNIST data (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -27,6 +32,7 @@ from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.platform import gfile +from tensorflow.python.util.deprecation import deprecated # CVDF mirror of http://yann.lecun.com/exdb/mnist/ DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/' @@ -37,6 +43,7 @@ def _read32(bytestream): return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] +@deprecated(None, 'Please use tf.data to implement this functionality.') def extract_images(f): """Extract the images into a 4D uint8 numpy array [index, y, x, depth]. @@ -65,6 +72,7 @@ def extract_images(f): return data +@deprecated(None, 'Please use tf.one_hot on tensors.') def dense_to_one_hot(labels_dense, num_classes): """Convert class labels from scalars to one-hot vectors.""" num_labels = labels_dense.shape[0] @@ -74,6 +82,7 @@ def dense_to_one_hot(labels_dense, num_classes): return labels_one_hot +@deprecated(None, 'Please use tf.data to implement this functionality.') def extract_labels(f, one_hot=False, num_classes=10): """Extract the labels into a 1D uint8 numpy array [index]. @@ -103,7 +112,15 @@ def extract_labels(f, one_hot=False, num_classes=10): class DataSet(object): + """Container class for a dataset (deprecated). + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please use alternatives such as official/mnist/dataset.py' + ' from tensorflow/models.') def __init__(self, images, labels, @@ -210,6 +227,8 @@ class DataSet(object): return self._images[start:end], self._labels[start:end] +@deprecated(None, 'Please use alternatives such as official/mnist/dataset.py' + ' from tensorflow/models.') def read_data_sets(train_dir, fake_data=False, one_hot=False, @@ -275,5 +294,7 @@ def read_data_sets(train_dir, return base.Datasets(train=train, validation=validation, test=test) +@deprecated(None, 'Please use alternatives such as official/mnist/dataset.py' + ' from tensorflow/models.') def load_mnist(train_dir='MNIST-data'): return read_data_sets(train_dir) diff --git a/tensorflow/contrib/learn/python/learn/datasets/produce_small_datasets.py b/tensorflow/contrib/learn/python/learn/datasets/produce_small_datasets.py index 6e0ba38941ce4650ede9f7210e284bde2ed8e6a9..a4848fa64a72f031ef35c0c3256e97a7326acd60 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/produce_small_datasets.py +++ b/tensorflow/contrib/learn/python/learn/datasets/produce_small_datasets.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Produce DBpedia datasets of a smaller size.""" +"""Produce DBpedia datasets of a smaller size (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/datasets/synthetic.py b/tensorflow/contrib/learn/python/learn/datasets/synthetic.py index 9a843168c27d9cae3f55efe4fe4c688d86c745f3..6a0e3350b3d1052249160a2a997a76de7a5040c3 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/synthetic.py +++ b/tensorflow/contrib/learn/python/learn/datasets/synthetic.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Synthetic dataset generators.""" +"""Synthetic dataset generators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -21,8 +26,10 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.learn.python.learn.datasets.base import Dataset +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Consider using synthetic datasets from scikits.learn.') def circles(n_samples=100, noise=None, seed=None, @@ -93,6 +100,7 @@ def circles(n_samples=100, return Dataset(data=X[indices], target=y[indices]) +@deprecated(None, 'Consider using synthetic datasets from scikits.learn.') def spirals(n_samples=100, noise=None, seed=None, diff --git a/tensorflow/contrib/learn/python/learn/datasets/text_datasets.py b/tensorflow/contrib/learn/python/learn/datasets/text_datasets.py index 2596a2ecaf1572506504831e8b08fab9b5dbc119..ce9466301728082f8e9d99c90989ba8fe623bcf0 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/text_datasets.py +++ b/tensorflow/contrib/learn/python/learn/datasets/text_datasets.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Text datasets.""" +"""Text datasets (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -26,10 +31,12 @@ import numpy as np from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.python.platform import gfile +from tensorflow.python.util.deprecation import deprecated DBPEDIA_URL = 'https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz' +@deprecated(None, 'See contrib/learn/README.md') def maybe_download_dbpedia(data_dir): """Download if DBpedia data is not present.""" train_path = os.path.join(data_dir, 'dbpedia_csv/train.csv') @@ -41,6 +48,7 @@ def maybe_download_dbpedia(data_dir): tfile.extractall(data_dir) +@deprecated(None, 'See contrib/learn/README.md') def load_dbpedia(size='small', test_with_fake_data=False): """Get DBpedia datasets from CSV files.""" if not test_with_fake_data: diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py index 4981750c94c7ac31e23b7a3f71ca30e3c9573a20..3e64595f312bcc2a2e8dcba589fb993a249b684b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py +++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""An estimator is a rule for calculating an estimate of a given quantity. +"""An estimator is a rule for calculating an estimate of a given quantity (deprecated). + +These classes are deprecated and replaced with `tf.estimator`. + +See [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. # Estimators diff --git a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py index 15277415a1ce83dc1d4a334e60fe1933ba244df0..1f0e4663d060a3850e2002b27f809fde1db47e48 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -"""sklearn cross-support.""" +"""sklearn cross-support (deprecated).""" from __future__ import absolute_import from __future__ import division @@ -132,6 +132,8 @@ class _TransformerMixin(): class NotFittedError(ValueError, AttributeError): """Exception class to raise if estimator is used before fitting. + USE OF THIS EXCEPTION IS DEPRECATED. + This class inherits from both ValueError and AttributeError to help with exception handling and backward compatibility. diff --git a/tensorflow/contrib/learn/python/learn/estimators/composable_model.py b/tensorflow/contrib/learn/python/learn/estimators/composable_model.py index a02c726c74946d93b8e1726473db746220b00795..1fa58271e2b886cd143683a759145fd750791473 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/composable_model.py +++ b/tensorflow/contrib/learn/python/learn/estimators/composable_model.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TensorFlow composable models used as building blocks for estimators.""" +"""TensorFlow composable models used as building blocks for estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -34,6 +39,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.summary import summary +from tensorflow.python.util.deprecation import deprecated class _ComposableModel(object): @@ -46,6 +52,7 @@ class _ComposableModel(object): _ComposableModel and its subclasses are not part of the public tf.learn API. """ + @deprecated(None, "Please use model_fns in tf.estimator.") def __init__(self, num_label_columns, optimizer, @@ -141,6 +148,10 @@ class _ComposableModel(object): class LinearComposableModel(_ComposableModel): """A _ComposableModel that implements linear regression. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Instances of this class can be used to build estimators through the use of composition. """ @@ -252,6 +263,10 @@ class LinearComposableModel(_ComposableModel): class DNNComposableModel(_ComposableModel): """A _ComposableModel that implements a DNN. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Instances of this class can be used to build estimators through the use of composition. """ diff --git a/tensorflow/contrib/learn/python/learn/estimators/constants.py b/tensorflow/contrib/learn/python/learn/estimators/constants.py index fc69e810244a182b864be856e6720f8584f7aa65..d2548946bc77dea7c452d61c7e2b6e12c3d6239a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/constants.py +++ b/tensorflow/contrib/learn/python/learn/estimators/constants.py @@ -13,9 +13,11 @@ # limitations under the License. # ============================================================================== -"""Constants regarding Estimators. +"""Constants regarding Estimators (deprecated). -This file is obsoleted in the move of Estimator to core. +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. """ from __future__ import absolute_import from __future__ import division @@ -25,6 +27,8 @@ from __future__ import print_function class ProblemType(object): """Enum-like values for the type of problem that the model solves. + THIS CLASS IS DEPRECATED. + These values are used when exporting the model to produce the appropriate signature function for serving. diff --git a/tensorflow/contrib/learn/python/learn/estimators/debug.py b/tensorflow/contrib/learn/python/learn/estimators/debug.py index 9d5f6c2bf969d7c85d251bf1b06a0307a41b2297..24b067b7e38b12df3d1d0c49f626344217218571 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/debug.py +++ b/tensorflow/contrib/learn/python/learn/estimators/debug.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Debug estimators. +"""Debug estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Debug estimators are bias-only estimators that can be used for debugging and as simple baselines. @@ -118,6 +122,10 @@ def debug_model_fn(features, labels, mode, params, config=None): class DebugClassifier(estimator.Estimator): """A classifier for TensorFlow Debug models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python @@ -237,6 +245,10 @@ class DebugClassifier(estimator.Estimator): class DebugRegressor(estimator.Estimator): """A regressor for TensorFlow Debug models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index c17b41c0f767e19d9c3635a8f60347a49b297cfb..eabebb7e881558471c343c0573cc9a8f4a425312 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Deep Neural Network estimators.""" +"""Deep Neural Network estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -212,6 +217,10 @@ def _dnn_model_fn(features, labels, mode, params, config=None): class DNNClassifier(estimator.Estimator): """A classifier for TensorFlow DNN models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python @@ -521,6 +530,10 @@ class DNNClassifier(estimator.Estimator): class DNNRegressor(estimator.Estimator): """A regressor for TensorFlow DNN models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python @@ -796,6 +809,10 @@ class DNNRegressor(estimator.Estimator): class DNNEstimator(estimator.Estimator): """A Estimator for TensorFlow DNN models with user specified _Head. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py index 726612235050def6e7addb503cc6646a25de0e42..3d85533d92d17095bae9a69f229171e1bf61ba10 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorFlow estimators for Linear and DNN joined training models.""" +"""TensorFlow estimators for Linear and DNN joined training models (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -372,6 +377,10 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None): class DNNLinearCombinedEstimator(estimator.Estimator): """An estimator for TensorFlow Linear and DNN joined training models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note: New users must set `fix_global_step_increment_bug=True` when creating an estimator. @@ -490,6 +499,10 @@ class DNNLinearCombinedEstimator(estimator.Estimator): class DNNLinearCombinedClassifier(estimator.Estimator): """A classifier for TensorFlow Linear and DNN joined training models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note: New users must set `fix_global_step_increment_bug=True` when creating an estimator. @@ -832,6 +845,10 @@ class DNNLinearCombinedClassifier(estimator.Estimator): class DNNLinearCombinedRegressor(estimator.Estimator): """A regressor for TensorFlow Linear and DNN joined training models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note: New users must set `fix_global_step_increment_bug=True` when creating an estimator. diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index 69440e823ef1ed2d739f28bc14587891f2de80bb..a703dc66e922d48ceb64edc2a979061b8e45db49 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Estimator for Dynamic RNNs.""" +"""Estimator for Dynamic RNNs (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -540,6 +545,12 @@ def _get_dynamic_rnn_model_fn( class DynamicRnnEstimator(estimator.Estimator): + """Dynamically unrolled RNN (deprecated). + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, problem_type, diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 4b63e08ab3372849309ee5d28d754de82e9632f4..7a026a15e4aeea0dde4ed9f7de053a757a0abb58 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Base Estimator class.""" +"""Base Estimator class (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -138,6 +143,7 @@ def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1): return df.input_builder, df.get_feed_dict_fn() +@deprecated(None, 'Please specify feature columns explicitly.') def infer_real_valued_columns_from_input_fn(input_fn): """Creates `FeatureColumn` objects for inputs defined by `input_fn`. @@ -158,6 +164,7 @@ def infer_real_valued_columns_from_input_fn(input_fn): return layers.infer_real_valued_columns(features) +@deprecated(None, 'Please specify feature columns explicitly.') def infer_real_valued_columns_from_input(x): """Creates `FeatureColumn` objects for inputs defined by input `x`. @@ -389,6 +396,10 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable): """Abstract BaseEstimator class to train and evaluate TensorFlow models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Users should not instantiate or subclass this class. Instead, use an `Estimator`. """ @@ -399,6 +410,8 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, # TODO(wicke): Remove this once launcher takes over config functionality _Config = run_config.RunConfig # pylint: disable=invalid-name + @deprecated(None, 'Please replace uses of any Estimator from tf.contrib.learn' + ' with an Estimator from tf.estimator.*') def __init__(self, model_dir=None, config=None): """Initializes a BaseEstimator instance. @@ -457,6 +470,20 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, # TODO(wicke): make RunConfig immutable, and then return it without a copy. return copy.deepcopy(self._config) + @property + def model_fn(self): + """Returns the model_fn which is bound to self.params. + + Returns: + The model_fn with the following signature: + `def model_fn(features, labels, mode, metrics)` + """ + + def public_model_fn(features, labels, mode, config): + return self._call_model_fn(features, labels, mode, config=config) + + return public_model_fn + @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), ('y', None), ('batch_size', None)) def fit(self, @@ -890,8 +917,8 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, if feed_fn: hooks.append(basic_session_run_hooks.FeedFnHook(feed_fn)) if steps == 0: - logging.warning('evaluation steps are 0. If `input_fn` does not raise' - 'OutOfRangeError`, the evaluation will never stop.' + logging.warning('evaluation steps are 0. If `input_fn` does not raise ' + '`OutOfRangeError`, the evaluation will never stop. ' 'Use steps=None if intended.') if steps: hooks.append( @@ -1074,6 +1101,10 @@ def _identity_feature_engineering_fn(features, labels): class Estimator(BaseEstimator): """Estimator class is the basic TensorFlow model trainer/evaluator. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. """ def __init__(self, @@ -1162,7 +1193,7 @@ class Estimator(BaseEstimator): self._feature_engineering_fn = ( feature_engineering_fn or _identity_feature_engineering_fn) - def _call_model_fn(self, features, labels, mode, metrics=None): + def _call_model_fn(self, features, labels, mode, metrics=None, config=None): """Calls model function with support of 2, 3 or 4 arguments. Args: @@ -1170,6 +1201,7 @@ class Estimator(BaseEstimator): labels: labels dict. mode: ModeKeys metrics: Dict of metrics. + config: RunConfig. Returns: A `ModelFnOps` object. If model_fn returns a tuple, wraps them up in a @@ -1186,7 +1218,10 @@ class Estimator(BaseEstimator): if 'params' in model_fn_args: kwargs['params'] = self.params if 'config' in model_fn_args: - kwargs['config'] = self.config + if config: + kwargs['config'] = config + else: + kwargs['config'] = self.config if 'model_dir' in model_fn_args: kwargs['model_dir'] = self.model_dir model_fn_results = self._model_fn(features, labels, **kwargs) @@ -1458,8 +1493,14 @@ class Estimator(BaseEstimator): # For time of deprecation x,y from Estimator allow direct access. # pylint: disable=protected-access class SKCompat(sklearn.BaseEstimator): - """Scikit learn wrapper for TensorFlow Learn Estimator.""" + """Scikit learn wrapper for TensorFlow Learn Estimator. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please switch to the Estimator interface.') def __init__(self, estimator): self._estimator = estimator diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index d81a534b79bc90fe91ffd3cb97a7865a7cb4c2a9..9e5aaf3118dfed4ce64dd244a915860b5a2eef44 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -715,7 +715,9 @@ class EstimatorTest(test.TestCase): ckpt = checkpoint_state_pb2.CheckpointState() text_format.Merge(checkpoint_file_content, ckpt) self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5') - self.assertAllEqual(['model.ckpt-1', 'model.ckpt-5'], + # TODO(b/78461127): Please modify tests to not directly rely on names of + # checkpoints. + self.assertAllEqual(['model.ckpt-0', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths) def test_train_save_copy_reload(self): diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py index fd47710e3015de9ae6a453f98978b0ef8f88968c..e4c31396baf8271c49395a2b87b454dbc77177e2 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utils for Estimator.""" +"""Utils for Estimator (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 9b124b2c19f16bbc9b2afeadb82a32006e1a0ae9..339c4e0e360ed9ef9906f0e51b64a0dc13826259 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Abstractions for the head(s) of a model. +"""Abstractions for the head(s) of a model (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. """ + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -47,11 +52,16 @@ from tensorflow.python.summary import summary from tensorflow.python.training import training from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect +from tensorflow.python.util.deprecation import deprecated class Head(object): """Interface for the head/top of a model. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Given logits (or output of a hidden layer), a Head knows how to compute predictions, loss, default metric and export signature. It is meant to, @@ -177,6 +187,7 @@ class Head(object): raise NotImplementedError("Calling an abstract method.") +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def regression_head(label_name=None, weight_column_name=None, label_dimension=1, @@ -216,6 +227,7 @@ def regression_head(label_name=None, link_fn=(link_fn if link_fn is not None else array_ops.identity)) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def poisson_regression_head(label_name=None, weight_column_name=None, label_dimension=1, @@ -254,6 +266,7 @@ def poisson_regression_head(label_name=None, # TODO(zakaria): Consider adding a _RegressionHead for logistic_regression +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def multi_class_head(n_classes, label_name=None, weight_column_name=None, @@ -335,6 +348,7 @@ def multi_class_head(n_classes, label_keys=label_keys) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def binary_svm_head( label_name=None, weight_column_name=None, @@ -370,6 +384,7 @@ def binary_svm_head( thresholds=thresholds) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def multi_label_head(n_classes, label_name=None, weight_column_name=None, @@ -430,6 +445,7 @@ def multi_label_head(n_classes, loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def loss_only_head(loss_fn, head_name=None): """Creates a Head that contains only loss terms. @@ -447,6 +463,7 @@ def loss_only_head(loss_fn, head_name=None): return _LossOnlyHead(loss_fn, head_name=head_name) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def multi_head(heads, loss_weights=None): """Creates a MultiHead stemming from same logits/hidden layer. @@ -479,6 +496,7 @@ def multi_head(heads, loss_weights=None): return _MultiHead(heads, loss_merger=_weighted_loss_merger) +@deprecated(None, "Use 'lambda _: tf.no_op()'.") def no_op_train_fn(loss): del loss return control_flow_ops.no_op() @@ -759,7 +777,7 @@ class _RegressionHead(_SingleHead): key = prediction_key.PredictionKey.SCORES with ops.name_scope(None, "predictions", (logits,)): if self.logits_dimension == 1: - logits = array_ops.squeeze(logits, squeeze_dims=(1,), name=key) + logits = array_ops.squeeze(logits, axis=(1,), name=key) return {key: self._link_fn(logits)} def _metrics(self, eval_loss, predictions, labels, weights): @@ -956,7 +974,7 @@ def _softmax_cross_entropy_loss(labels, logits, weights=None): is_squeezed_labels = False # TODO(ptucker): This will break for dynamic shapes. if len(labels.get_shape()) == 2: - labels = array_ops.squeeze(labels, squeeze_dims=(1,)) + labels = array_ops.squeeze(labels, axis=(1,)) is_squeezed_labels = True loss = nn.sparse_softmax_cross_entropy_with_logits( @@ -1844,12 +1862,12 @@ def _get_arguments(func): if hasattr(func, "__code__"): # Regular function. return tf_inspect.getargspec(func) - elif hasattr(func, "__call__"): - # Callable object. - return _get_arguments(func.__call__) elif hasattr(func, "func"): # Partial function. return _get_arguments(func.func) + elif hasattr(func, "__call__"): + # Callable object. + return _get_arguments(func.__call__) def _verify_loss_fn_args(loss_fn): diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 6d5da81b4c2087fb9c5307902e452a6220a17cd0..7c2d9bb0767cb979dae9c84b5342d129225677ed 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -362,7 +362,7 @@ class MultiLabelHeadTest(test.TestCase): "auc_precision_recall": 0.166667, "auc_precision_recall/class0": 0, "auc_precision_recall/class1": 0., - "auc_precision_recall/class2": 0.49999, + "auc_precision_recall/class2": 1., "labels/actual_label_mean/class0": self._labels[0][0], "labels/actual_label_mean/class1": self._labels[0][1], "labels/actual_label_mean/class2": self._labels[0][2], @@ -748,7 +748,7 @@ class BinaryClassificationHeadTest(test.TestCase): "accuracy/baseline_label_mean": label_mean, "accuracy/threshold_0.500000_mean": 1. / 2, "auc": 1. / 2, - "auc_precision_recall": 0.25, + "auc_precision_recall": 0.749999, "labels/actual_label_mean": label_mean, "labels/prediction_mean": .731059, # softmax "loss": expected_loss, diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py index 8f9d6fc318a357853bdb8e3264f6691b410006b1..66ebcfd1d81904b9afe5be6bd1a648fe325e1e0b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of k-means clustering on top of `Estimator` API. +"""Implementation of k-means clustering on top of `Estimator` API (deprecated). This module is deprecated. Please use @{tf.contrib.factorization.KMeansClustering} instead of @@ -153,7 +153,12 @@ def _kmeans_clustering_model_fn(features, labels, mode, params, config): # TODO(agarwal,ands): support sharded input. class KMeansClustering(estimator.Estimator): - """An Estimator for K-Means clustering.""" + """An Estimator for K-Means clustering. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE COSINE_DISTANCE = clustering_ops.COSINE_DISTANCE RANDOM_INIT = clustering_ops.RANDOM_INIT diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py index b28835a809736a099ad2f08d127dc68d7977a3c1..584556992a0db2345e182e92c4a7f7582d3cd8dc 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py @@ -36,7 +36,6 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops import random_ops from tensorflow.python.platform import benchmark from tensorflow.python.platform import flags from tensorflow.python.platform import test diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index 37aa8b339622415d082933cdf66d2472a4119b48..e100bc7a1e7be4896e9ab1c965775b5185b38897 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Linear Estimators.""" +"""Linear Estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -26,7 +31,6 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib @@ -46,6 +50,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import session_run_hook from tensorflow.python.training import training as train +from tensorflow.python.training import training_util # The default learning rate of 0.2 is a historical artifact of the initial @@ -238,8 +243,10 @@ def sdca_model_fn(features, labels, mode, params): parent_scope = "linear" - with variable_scope.variable_op_scope( - features.values(), parent_scope) as scope: + with variable_scope.variable_scope( + values=features.values(), + name_or_scope=parent_scope, + partitioner=optimizer.partitioner) as scope: features = features.copy() features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( @@ -305,6 +312,10 @@ class _SdcaUpdateWeightsHook(session_run_hook.SessionRunHook): class LinearClassifier(estimator.Estimator): """Linear classifier model. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Train a linear model to classify instances into one of multiple possible classes. When number of possible classes is 2, this is binary classification. @@ -625,6 +636,10 @@ class LinearClassifier(estimator.Estimator): class LinearRegressor(estimator.Estimator): """Linear regressor model. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Train a linear regression model to predict label value given observation of feature values. @@ -860,6 +875,10 @@ class LinearRegressor(estimator.Estimator): class LinearEstimator(estimator.Estimator): """Linear model with user specified head. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Train a generalized linear model to predict label value given observation of feature values. diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py index d3bb0fda5765d88ec064047f523de853d3de6a3f..597ca4e86dbf66c86182f14a2a364b662d52fb0a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py @@ -43,6 +43,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.platform import test from tensorflow.python.training import ftrl from tensorflow.python.training import input as input_lib @@ -863,6 +864,38 @@ class LinearClassifierTest(test.TestCase): scores = classifier.evaluate(input_fn=input_fn, steps=1) self.assertGreater(scores['accuracy'], 0.9) + def testSdcaOptimizerWeightedSparseFeaturesOOVWithNoOOVBuckets(self): + """LinearClassifier with SDCAOptimizer with OOV features (-1 IDs).""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + sparse_tensor.SparseTensor( + values=[2., 3., 1.], + indices=[[0, 0], [1, 0], [2, 0]], + dense_shape=[3, 5]), + 'country': + sparse_tensor.SparseTensor( + # 'GB' is out of the vocabulary. + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 0], [2, 0]], + dense_shape=[3, 5]) + }, constant_op.constant([[1], [0], [1]]) + + country = feature_column_lib.sparse_column_with_keys( + 'country', keys=['US', 'CA', 'MK', 'IT', 'CN']) + country_weighted_by_price = feature_column_lib.weighted_sparse_column( + country, 'price') + sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( + example_id_column='example_id') + classifier = linear.LinearClassifier( + feature_columns=[country_weighted_by_price], optimizer=sdca_optimizer) + classifier.fit(input_fn=input_fn, steps=50) + scores = classifier.evaluate(input_fn=input_fn, steps=1) + self.assertGreater(scores['accuracy'], 0.9) + def testSdcaOptimizerCrossedFeatures(self): """Tests LinearClassifier with SDCAOptimizer and crossed features.""" @@ -934,6 +967,63 @@ class LinearClassifierTest(test.TestCase): scores = classifier.evaluate(input_fn=input_fn, steps=1) self.assertGreater(scores['accuracy'], 0.9) + def testSdcaOptimizerPartitionedVariables(self): + """Tests LinearClassifier with SDCAOptimizer with partitioned variables.""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([[0.6], [0.8], [0.3]]), + 'sq_footage': + constant_op.constant([[900.0], [700.0], [600.0]]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [1.0], [1.0]]) + }, constant_op.constant([[1], [0], [1]]) + + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + + sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( + example_id_column='example_id', + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + + tf_config = { + 'cluster': { + run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + config = run_config.RunConfig() + # Because we did not start a distributed cluster, we need to pass an + # empty ClusterSpec, otherwise the device_setter will look for + # distributed jobs, such as "/job:ps" which are not present. + config._cluster_spec = server_lib.ClusterSpec({}) + + classifier = linear.LinearClassifier( + feature_columns=[price, sq_footage_bucket, country, sq_footage_country], + weight_column_name='weights', + optimizer=sdca_optimizer, + config=config) + classifier.fit(input_fn=input_fn, steps=50) + scores = classifier.evaluate(input_fn=input_fn, steps=1) + print('all scores = {}'.format(scores)) + self.assertGreater(scores['accuracy'], 0.9) + def testEval(self): """Tests that eval produces correct metrics. """ @@ -1508,6 +1598,60 @@ class LinearRegressorTest(test.TestCase): loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] self.assertLess(loss, 0.05) + def testSdcaOptimizerPartitionedVariables(self): + """Tests LinearRegressor with SDCAOptimizer with partitioned variables.""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([0.6, 0.8, 0.3]), + 'sq_footage': + constant_op.constant([[900.0], [700.0], [600.0]]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [5.0], [7.0]]) + }, constant_op.constant([[1.55], [-1.25], [-3.0]]) + + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( + example_id_column='example_id', symmetric_l2_regularization=1.0, + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + tf_config = { + 'cluster': { + run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + config = run_config.RunConfig() + # Because we did not start a distributed cluster, we need to pass an + # empty ClusterSpec, otherwise the device_setter will look for + # distributed jobs, such as "/job:ps" which are not present. + config._cluster_spec = server_lib.ClusterSpec({}) + + regressor = linear.LinearRegressor( + feature_columns=[price, sq_footage_bucket, country, sq_footage_country], + weight_column_name='weights', + optimizer=sdca_optimizer, + config=config) + regressor.fit(input_fn=input_fn, steps=20) + loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] + self.assertLess(loss, 0.05) + def testSdcaOptimizerSparseFeaturesWithL1Reg(self): """Tests LinearClassifier with SDCAOptimizer and sparse features.""" diff --git a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py index fb339160d58e09d4ffd50090f2dbbcec08bebe47..3cbcc6e98de1c915c302617e4591c9baa33adeaf 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py +++ b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Logistic regression (aka binary classifier) class. +"""Logistic regression (aka binary classifier) class (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. This defines some useful basic metrics for using logistic regression to classify a binary event (0 vs 1). @@ -75,6 +79,10 @@ def LogisticRegressor( # pylint: disable=invalid-name feature_engineering_fn=None): """Builds a logistic regression Estimator for binary classification. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + This method provides a basic Estimator with some additional metrics for custom binary classification models, including AUC, precision/recall and accuracy. diff --git a/tensorflow/contrib/learn/python/learn/estimators/metric_key.py b/tensorflow/contrib/learn/python/learn/estimators/metric_key.py index 99388f116b345bd038f2985606c6203011597ea2..f264248e44d9aa48f26ee32e36746bd4c3145a8d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/metric_key.py +++ b/tensorflow/contrib/learn/python/learn/estimators/metric_key.py @@ -12,14 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Enum for metric keys.""" +"""Enum for metric keys (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function class MetricKey(object): - """Metric key strings.""" + """Metric key strings (deprecated).""" + LOSS = "loss" AUC = "auc" AUC_PR = "auc_precision_recall" diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py index 44e6c7c52dac524a22e9099e33e2aef82f8fe7ba..dcb161180c99ce71195c820217e8bdaf79d70901 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Classes and methods related to model_fn.""" +"""Classes and methods related to model_fn (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -37,10 +42,13 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import session_run_hook +from tensorflow.python.util.deprecation import deprecated class ModeKeys(object): - """Standard names for model modes. + """Standard names for model modes (deprecated). + + THIS CLASS IS DEPRECATED. The following standard keys are defined: @@ -65,8 +73,16 @@ class ModelFnOps( 'output_alternatives', 'training_chief_hooks', 'training_hooks', 'scaffold', 'mode' ])): - """Ops returned from a model_fn.""" + """Ops returned from a model_fn. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'When switching to tf.estimator.Estimator, use ' + 'tf.estimator.EstimatorSpec. You can use the `estimator_spec`' + ' method to create an equivalent one.') def __new__(cls, mode, predictions=None, diff --git a/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py b/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py index f8d87b8914307a86eb2f46123a28ff11eb925eda..6fd2fc9d592cef4e44a640e2f27cb28b367d44d5 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py +++ b/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Enum for model prediction keys. +"""Enum for model prediction keys (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. This file is obsoleted in the move of Estimator to core. """ @@ -22,6 +26,8 @@ from __future__ import print_function class PredictionKey(object): + """THIS CLASS IS DEPRECATED.""" + CLASSES = "classes" PROBABILITIES = "probabilities" LOGITS = "logits" diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py index 2752bc2d90ee0f51b2c40ccc4d24a4eb21cff38f..215022e5d9e5d3cd5d6a96583b325b19a1719568 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py +++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Common operations for RNN Estimators.""" +"""Common operations for RNN Estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index fd90fd1cc6277e7d80287aefdbab6134dac7c0d5..14ee2ba6094760d52180d6de7763ea88b8ee98c8 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Run Config.""" +"""Run Config (deprecated, use tf.estimator.RunConfig instead). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -29,11 +34,12 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.estimator import run_config as core_run_config from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib +from tensorflow.python.util.deprecation import deprecated # A list of the property names in RunConfig user allows to change. They will # not affect the execution framework, so when execution framework checks the -# `uid` of the RunConfig, it should be ingored. +# `uid` of the RunConfig, it should be ignored. _DEFAULT_UID_WHITE_LIST = [ 'tf_random_seed', 'save_summary_steps', @@ -47,6 +53,7 @@ _DEFAULT_UID_WHITE_LIST = [ class Environment(object): + """DEPRECATED CLASS.""" # For running general distributed training. CLOUD = 'cloud' # For running Google-internal distributed training. @@ -56,6 +63,7 @@ class Environment(object): class TaskType(object): + """DEPRECATED CLASS.""" MASTER = 'master' PS = 'ps' WORKER = 'worker' @@ -64,6 +72,8 @@ class TaskType(object): class ClusterConfig(object): """This class specifies the configurations for a distributed run. + THIS CLASS IS DEPRECATED. Use tf.estimator.RunConfig instead. + If you're using an `Estimator`, you should probably use the subclass RunConfig instead. """ @@ -211,10 +221,13 @@ class ClusterConfig(object): class RunConfig(ClusterConfig, core_run_config.RunConfig): """This class specifies the configurations for an `Estimator` run. - This class is the implementation of @{tf.estimator.RunConfig} interface. + This class is a deprecated implementation of @{tf.estimator.RunConfig} + interface. """ _USE_DEFAULT = 0 + @deprecated(None, 'When switching to tf.estimator.Estimator, use' + ' tf.estimator.RunConfig instead.') def __init__(self, master=None, num_cores=0, @@ -277,8 +290,16 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): Note - using this argument, it is easy to provide settings which break otherwise perfectly good models. Use with care. """ - super(RunConfig, self).__init__( - master=master, evaluation_master=evaluation_master) + # Neither parent class calls super().__init__(), so here we have to + # manually call their __init__() methods. + ClusterConfig.__init__( + self, master=master, evaluation_master=evaluation_master) + # For too long this code didn't call: + # core_run_config.RunConfig.__init__(self) + # so instead of breaking compatibility with that assumption, we + # just manually initialize this field: + self._train_distribute = None + self._device_fn = None gpu_options = config_pb2.GPUOptions( per_process_gpu_memory_fraction=gpu_memory_fraction) diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py index 0cea35e219a4457417a161a3ac4ac4292fd24f53..de78c72c3ae3ef14f5f7c46b1d47f82e8266c7c6 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Estimator for State Saving RNNs.""" +"""Estimator for State Saving RNNs (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -528,6 +533,12 @@ def _get_rnn_model_fn(cell_type, class StateSavingRnnEstimator(estimator.Estimator): + """RNN with static unrolling and state saving (deprecated). + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, problem_type, diff --git a/tensorflow/contrib/learn/python/learn/estimators/svm.py b/tensorflow/contrib/learn/python/learn/estimators/svm.py index 72920d73c0c92886e54f533ad7fe170fe27d9870..3459997baba16fc0d4045e50819ecdd0e7121657 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/svm.py +++ b/tensorflow/contrib/learn/python/learn/estimators/svm.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Support Vector Machine (SVM) Estimator.""" +"""Support Vector Machine (SVM) Estimator (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -36,6 +41,10 @@ def _as_iterable(preds, output): class SVM(estimator.Estimator): """Support Vector Machine (SVM) model for binary classification. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Currently, only linear SVMs are supported. For the underlying optimization problem, the `SDCAOptimizer` is used. For performance and convergence tuning, the num_loss_partitions parameter passed to `SDCAOptimizer` (see `__init__()` diff --git a/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py b/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py index a120bc6cc3975a3d4559d018c8aa74ff42a16d2d..71b5658dd174d2b47e33860844359f68e6768026 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py +++ b/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorSignature class and utilities.""" +"""TensorSignature class and utilities (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -33,6 +38,10 @@ class TensorSignature(collections.namedtuple( "TensorSignature", ["dtype", "shape", "is_sparse"])): """Signature of the `Tensor` object. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Useful to check compatibility of tensors. Example: diff --git a/tensorflow/contrib/learn/python/learn/estimators/test_data.py b/tensorflow/contrib/learn/python/learn/estimators/test_data.py index ed201bfc58f273e6587850032386c2686aea4148..e4b057b4f5a9e081c2d891bd9828ffc315e51e91 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/test_data.py +++ b/tensorflow/contrib/learn/python/learn/estimators/test_data.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Test data utilities.""" +"""Test data utilities (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/evaluable.py b/tensorflow/contrib/learn/python/learn/evaluable.py index 8f6cd39864b437f163dd7c1140dc88755ce98529..10881ca885599bc81386e15f814a2687d907f63b 100644 --- a/tensorflow/contrib/learn/python/learn/evaluable.py +++ b/tensorflow/contrib/learn/python/learn/evaluable.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""`Evaluable` interface.""" +"""`Evaluable` interface (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -23,6 +28,10 @@ import abc class Evaluable(object): """Interface for objects that are evaluatable by, e.g., `Experiment`. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. """ __metaclass__ = abc.ABCMeta diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index bec976afd2719138117976381669ca3292360480..f8a3709ee57a32734afa7ac8133271c75d152b2c 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Experiment class collecting information needed for a single training run.""" +"""Experiment class collecting information for a single training run (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -25,7 +30,6 @@ import os import time from tensorflow.contrib.framework import deprecated -from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import export_strategy @@ -34,19 +38,19 @@ from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.tpu.python.tpu import tpu_estimator from tensorflow.python.estimator import estimator as core_estimator -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import saver from tensorflow.python.training import server_lib from tensorflow.python.util import compat +from tensorflow.python.util import function_utils __all__ = ["Experiment"] def _get_standardized_predicate_fn(predicate_fn): - pred_fn_args = estimator_util.fn_args(predicate_fn) + pred_fn_args = function_utils.fn_args(predicate_fn) if "checkpoint_path" not in pred_fn_args: # pylint: disable=unused-argument def _pred_fn_wrapper(eval_results, checkpoint_path): @@ -118,6 +122,10 @@ class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener): class Experiment(object): """Experiment is a class containing all information needed to train a model. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + After an experiment is created (by passing an Estimator and inputs for training and evaluation), an Experiment instance knows how to invoke training and eval loops in a sensible fashion for distributed training. @@ -125,16 +133,8 @@ class Experiment(object): # TODO(ispir): remove delay_workers_by_global_step and make global step based # waiting as only behavior. - @deprecated_args( - "2016-10-23", - "local_eval_frequency is deprecated as local_run will be renamed to " - "train_and_evaluate. Use min_eval_frequency and call train_and_evaluate " - "instead. Note, however, that the default for min_eval_frequency is 1, " - "meaning models will be evaluated every time a new checkpoint is " - "available. In contrast, the default for local_eval_frequency is None, " - "resulting in evaluation occurring only after training has completed. " - "min_eval_frequency is ignored when calling the deprecated local_run.", - "local_eval_frequency") + @deprecated(None, "Please switch to tf.estimator.train_and_evaluate. You will" + " also have to convert to a tf.estimator.Estimator.") def __init__(self, estimator, train_input_fn, @@ -152,7 +152,8 @@ class Experiment(object): export_strategies=None, train_steps_per_iteration=None, checkpoint_and_export=False, - saving_listeners=None): + saving_listeners=None, + check_interval_secs=5): """Constructor for `Experiment`. Creates an Experiment instance. None of the functions passed to this @@ -190,8 +191,9 @@ class Experiment(object): number of steps between evaluations. Of course, evaluation does not occur if no new snapshot is available, hence, this is the minimum. If 0, the evaluation will only happen after training. - If None, defaults to 1, unless model_dir is on GCS, in which case the - default is 1000. + If None, defaults to 1. To avoid checking for new checkpoints too + frequent, the interval is further limited to be at least + check_interval_secs between checks. delay_workers_by_global_step: if `True` delays training workers based on global step instead of time. export_strategies: Iterable of `ExportStrategy`s, or a single one, or @@ -215,7 +217,10 @@ class Experiment(object): saving_listeners: list of `CheckpointSaverListener` objects. Used by tf.estimator.Estimator for callbacks that run immediately before or after checkpoint savings. - + check_interval_secs: + Minimum time between subsequent checks for a new checkpoint. This + mostly applies if both min_eval_frequency and the time spent per + training step is low. Raises: ValueError: if `estimator` does not implement Estimator interface, or if export_strategies has the wrong type. @@ -261,13 +266,9 @@ class Experiment(object): self._continuous_eval_throttle_secs = continuous_eval_throttle_secs self._checkpoint_and_export = checkpoint_and_export self._saving_listeners = saving_listeners - # Using 1 on a non-cached file system requires a lot of overhead to - # read the checkpoint state file. This is particular bad on GCS, so - # we use a different default. This is a temporary band-aid, to be - # fixed holistically later (b/36498507). - default_min_eval_frequency = 1000 if _is_gcs(estimator.model_dir) else 1 self._min_eval_frequency = min_eval_frequency if ( - min_eval_frequency is not None) else default_min_eval_frequency + min_eval_frequency is not None) else 1 + self._check_interval_secs = check_interval_secs self._delay_workers_by_global_step = delay_workers_by_global_step self._train_monitors = train_monitors[:] if train_monitors else [] self._eval_hooks = eval_hooks[:] if eval_hooks else [] @@ -357,7 +358,7 @@ class Experiment(object): self._start_server() elif config.cluster_spec and config.master: raise ValueError( - "For distributed runtime, Experiment class only works with" + "For distributed runtime, Experiment class only works with " "tf.contrib.learn.RunConfig for now, but provided {}".format( type(config))) @@ -467,10 +468,15 @@ class Experiment(object): on which that evaluation was based. At the beginning of evaluation, the passed `eval_results` will be None so it's expected that the predicate function handles that gracefully. - When `predicate_fn` is not specified, continuous eval will run in an - infinite loop (if `train_steps` is None). or exit once global step - reaches `train_steps`. - + Continuous eval behavior under different conditions: + * When `predicate_fn` is specified: + + if `train_steps` is None, run until `predicate_fn` returns False. + + if `train_steps` is specified, run until either global step + reaches `train_steps` or `predicate_fn` returns False. + * When `predicate_fn` is not specified: + + if `train_steps` is None, run in an infinite loop. + + if `train_steps` is specified, run until global step reaches + `train_steps`. export: Whether to export from this step. Default is 'True'. Raises: @@ -499,7 +505,7 @@ class Experiment(object): eval_result = None last_warning_time = 0 while (not predicate_fn or predicate_fn( - eval_result, checkpoint_path=previous_path if eval_result else None)): + eval_result, checkpoint_path=previous_path)): # Exit if we have already reached number of steps to train. if self._has_training_stopped(eval_result): logging.info("Exiting continuous eval, global_step=%s >= " @@ -646,12 +652,19 @@ class Experiment(object): self._train_monitors += [saver_hook] else: if self._min_eval_frequency: + # Using low min_eval_frequency (default is 1) on a non-cached file + # system requires a lot of overhead to read the checkpoint state file. + # This is particular bad on GCS and CNS. See also b/36498507 for + # context. `check_interval_secs = 5` avoids polling a remote + # fileystem too often. + self._train_monitors += [ monitors.ValidationMonitor( input_fn=self._eval_input_fn, eval_steps=self._eval_steps, metrics=self._eval_metrics, every_n_steps=self._min_eval_frequency, + check_interval_secs=self._check_interval_secs, name=eval_dir_suffix, hooks=self._eval_hooks) ] @@ -928,7 +941,3 @@ def _new_attr_context(obj, attr): yield finally: setattr(obj, attr, saved) - - -def _is_gcs(model_dir): - return model_dir and model_dir.startswith("gs://") diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index 545d7d8924c0c10544e6113e2968b7ae3d2090fc..fb16c94c29660e2777942ea9cf30da51dbf90571 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -500,7 +500,7 @@ class ExperimentTest(test.TestCase): noop_hook = _NoopHook() def _predicate_fn(eval_result, checkpoint_path): - self.assertEqual(not eval_result, + self.assertEqual(eval_result is None, checkpoint_path is None) return est.eval_count < 3 # pylint: disable=cell-var-from-loop @@ -674,37 +674,11 @@ class ExperimentTest(test.TestCase): def test_min_eval_frequency_defaults(self): def dummy_model_fn(features, labels): # pylint: disable=unused-argument pass - - # The default value when model_dir is on GCS is 1000 - estimator = core_estimator.Estimator(dummy_model_fn, 'gs://dummy_bucket') - ex = experiment.Experiment( - estimator, train_input_fn=None, eval_input_fn=None) - self.assertEquals(ex._min_eval_frequency, 1000) - - # The default value when model_dir is not on GCS is 1 estimator = core_estimator.Estimator(dummy_model_fn, '/tmp/dummy') ex = experiment.Experiment( estimator, train_input_fn=None, eval_input_fn=None) self.assertEquals(ex._min_eval_frequency, 1) - # Make sure default not used when explicitly set - estimator = core_estimator.Estimator(dummy_model_fn, 'gs://dummy_bucket') - ex = experiment.Experiment( - estimator, - min_eval_frequency=123, - train_input_fn=None, - eval_input_fn=None) - self.assertEquals(ex._min_eval_frequency, 123) - - # Make sure default not used when explicitly set as 0 - estimator = core_estimator.Estimator(dummy_model_fn, 'gs://dummy_bucket') - ex = experiment.Experiment( - estimator, - min_eval_frequency=0, - train_input_fn=None, - eval_input_fn=None) - self.assertEquals(ex._min_eval_frequency, 0) - def test_continuous_train_and_eval(self): for est in self._estimators_for_tests(eval_dict={'global_step': 100}): if isinstance(est, core_estimator.Estimator): diff --git a/tensorflow/contrib/learn/python/learn/export_strategy.py b/tensorflow/contrib/learn/python/learn/export_strategy.py index 55a8b824312b89e0ac66513242191f4201ac212a..075cab536ecb5279e7e6f23abb0b70c75043a7ec 100644 --- a/tensorflow/contrib/learn/python/learn/export_strategy.py +++ b/tensorflow/contrib/learn/python/learn/export_strategy.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""ExportStrategy class represents different flavors of model export.""" +"""ExportStrategy class represents different flavors of model export (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -21,6 +26,7 @@ from __future__ import print_function import collections from tensorflow.python.util import tf_inspect +from tensorflow.python.util.deprecation import deprecated __all__ = ['ExportStrategy'] @@ -30,6 +36,10 @@ class ExportStrategy( ['name', 'export_fn', 'strip_default_attrs'])): """A class representing a type of model export. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Typically constructed by a utility function specific to the exporter, such as `saved_model_export_utils.make_export_strategy()`. @@ -56,6 +66,8 @@ class ExportStrategy( forward compatibility of the resulting `SavedModel`. """ + @deprecated(None, 'Please switch to tf.estimator.train_and_evaluate, and use ' + 'tf.estimator.Exporter.') def __new__(cls, name, export_fn, strip_default_attrs=None): return super(ExportStrategy, cls).__new__( cls, name, export_fn, strip_default_attrs) diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index 98365c05f663e5d2a06703457fc5663d7135f7d9..a997fab723a16dddf150aa9397863605e4e77933 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""High level operations on graphs.""" +"""High level operations on graphs (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -68,6 +73,7 @@ def clear_summary_writers(): return summary_io.SummaryWriterCache.clear() +@deprecated(None, 'Use `SummaryWriterCache.get` directly.') def get_summary_writer(logdir): """Returns single SummaryWriter per logdir in current run. diff --git a/tensorflow/contrib/learn/python/learn/learn_io/__init__.py b/tensorflow/contrib/learn/python/learn/learn_io/__init__.py index 06c3782a471537cf3879450e6bd20899a35d96ac..8b133a4440d8cbc19abca64f972791fc16ade6f8 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/__init__.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/__init__.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tools to allow different io formats.""" +"""Tools to allow different io formats (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/learn_io/dask_io.py b/tensorflow/contrib/learn/python/learn/learn_io/dask_io.py index 7d666391cea3c0a52a2cb7e324c00d5f480710d5..e0a1948d95a727675dac8ff3ce9f55c35d5f8d8d 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/dask_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/dask_io.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Methods to allow dask.DataFrame.""" +"""Methods to allow dask.DataFrame (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -21,6 +26,8 @@ from __future__ import print_function import numpy as np +from tensorflow.python.util.deprecation import deprecated + try: # pylint: disable=g-import-not-at-top import dask.dataframe as dd @@ -60,6 +67,7 @@ def _construct_dask_df_with_divisions(df): return dd.Series(merge(dsk, df.dask), name, df.name, divisions) +@deprecated(None, 'Please feed input to tf.data to support dask.') def extract_dask_data(data): """Extract data from dask.Series or dask.DataFrame for predictors. @@ -81,6 +89,7 @@ def extract_dask_data(data): return data +@deprecated(None, 'Please feed input to tf.data to support dask.') def extract_dask_labels(labels): """Extract data from dask.Series or dask.DataFrame for labels. diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py index 96be8b1bc402479d5611965f27abb197363cb939..c45b1d186471125776d6536112aebb66bb5ad558 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementations of different data feeders to provide data for TF trainer.""" +"""Implementations of different data feeders to provide data for TF trainer (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" # TODO(ipolosukhin): Replace this module with feed-dict queue runners & queues. @@ -31,6 +36,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.deprecation import deprecated # pylint: disable=g-multiple-import,g-bad-import-order from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels @@ -101,6 +107,7 @@ def _is_iterable(x): return hasattr(x, 'next') or hasattr(x, '__next__') +@deprecated(None, 'Please use tensorflow/transform or tf.data.') def setup_train_data_feeder(x, y, n_classes, @@ -188,6 +195,7 @@ def _batch_data(x, batch_size=None): yield np.matrix(chunk) +@deprecated(None, 'Please use tensorflow/transform or tf.data.') def setup_predict_data_feeder(x, batch_size=None): """Returns an iterable for feeding into predict step. @@ -219,6 +227,7 @@ def setup_predict_data_feeder(x, batch_size=None): return [x] +@deprecated(None, 'Please use tensorflow/transform or tf.data.') def setup_processor_data_feeder(x): """Sets up processor iterable. @@ -233,6 +242,7 @@ def setup_processor_data_feeder(x): return x +@deprecated(None, 'Please convert numpy dtypes explicitly.') def check_array(array, dtype): """Checks array on dtype and converts it if different. @@ -275,8 +285,14 @@ def _check_dtype(dtype): class DataFeeder(object): - """Data feeder is an example class to sample data for TF trainer.""" + """Data feeder is an example class to sample data for TF trainer. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please use tensorflow/transform or tf.data.') def __init__(self, x, y, @@ -563,6 +579,10 @@ class DataFeeder(object): class StreamingDataFeeder(DataFeeder): """Data feeder for TF trainer that reads data from iterator. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Streaming data feeder allows to read data as it comes it from disk or somewhere else. It's custom to have this iterators rotate infinetly over the dataset, to allow control of how much to learn on the trainer side. @@ -771,11 +791,16 @@ class StreamingDataFeeder(DataFeeder): class DaskDataFeeder(object): """Data feeder for that reads data from dask.Series and dask.DataFrame. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Numpy arrays can be serialized to disk and it's possible to do random seeks into them. DaskDataFeeder will remove requirement to have full dataset in the memory and still do random seeks for sampling of batches. """ + @deprecated(None, 'Please feed input to tf.data to support dask.') def __init__(self, x, y, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py index 82848be7df653dd60219317d28f233767746f544..1f439965daf956665bbedc919281df0ee07b5d62 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os.path import numpy as np import six from six.moves import xrange # pylint: disable=redefined-builtin @@ -26,6 +27,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.learn.python.learn.learn_io import * from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.lib.io import file_io from tensorflow.python.platform import test # pylint: enable=wildcard-import @@ -35,6 +37,13 @@ class DataFeederTest(test.TestCase): # pylint: disable=undefined-variable """Tests for `DataFeeder`.""" + def setUp(self): + self._base_dir = os.path.join(self.get_temp_dir(), 'base_dir') + file_io.create_dir(self._base_dir) + + def tearDown(self): + file_io.delete_recursively(self._base_dir) + def _wrap_dict(self, data, prepend=''): return {prepend + '1': data, prepend + '2': data} @@ -45,14 +54,14 @@ class DataFeederTest(test.TestCase): def _assert_dtype(self, expected_np_dtype, expected_tf_dtype, input_data): feeder = data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1) if isinstance(input_data, dict): - for k, v in list(feeder.input_dtype.items()): + for v in list(feeder.input_dtype.values()): self.assertEqual(expected_np_dtype, v) else: self.assertEqual(expected_np_dtype, feeder.input_dtype) with ops.Graph().as_default() as g, self.test_session(g): inp, _ = feeder.input_builder() if isinstance(inp, dict): - for k, v in list(inp.items()): + for v in list(inp.values()): self.assertEqual(expected_tf_dtype, v.dtype) else: self.assertEqual(expected_tf_dtype, inp.dtype) @@ -301,7 +310,10 @@ class DataFeederTest(test.TestCase): [0.60000002, 0.2]]) self.assertAllClose(feed_dict[out.name], [[0., 0., 1.], [0., 1., 0.]]) - def test_hdf5_data_feeder(self): + # TODO(rohanj): Fix this test by fixing data_feeder. Currently, h5py doesn't + # support permutation based indexing lookups (More documentation at + # http://docs.h5py.org/en/latest/high/dataset.html#fancy-indexing) + def DISABLED_test_hdf5_data_feeder(self): def func(df): inp, out = df.input_builder() @@ -314,11 +326,12 @@ class DataFeederTest(test.TestCase): import h5py # pylint: disable=g-import-not-at-top x = np.matrix([[1, 2], [3, 4]]) y = np.array([1, 2]) - h5f = h5py.File('test_hdf5.h5', 'w') + file_path = os.path.join(self._base_dir, 'test_hdf5.h5') + h5f = h5py.File(file_path, 'w') h5f.create_dataset('x', data=x) h5f.create_dataset('y', data=y) h5f.close() - h5f = h5py.File('test_hdf5.h5', 'r') + h5f = h5py.File(file_path, 'r') x = h5f['x'] y = h5f['y'] func(data_feeder.DataFeeder(x, y, n_classes=0, batch_size=3)) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py index 884faf8335e2a3ca1d27d2d93b4c817131648774..f8aaa0c9e3e5b589a6ad47678dba3dc38de7c471 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Methods to allow generator of dict with numpy arrays.""" +"""Methods to allow generator of dict with numpy arrays (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -23,8 +28,10 @@ from types import FunctionType from types import GeneratorType from tensorflow.python.estimator.inputs.queues.feeding_functions import _enqueue_data as enqueue_data +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Please use tf.data.') def generator_input_fn(x, target_key=None, batch_size=128, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index 3a46c239688017f9204d2c6182a6f81cd325a417..9e816f54b6cf8dee84c6d62406ab3db700054d06 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Methods to read data in the graph.""" +"""Methods to read data in the graph (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -34,11 +39,13 @@ from tensorflow.python.platform import gfile from tensorflow.python.summary import summary from tensorflow.python.training import input as input_ops from tensorflow.python.training import queue_runner +from tensorflow.python.util.deprecation import deprecated # Default name for key in the feature dict. KEY_FEATURE_NAME = '__key__' +@deprecated(None, 'Use tf.data.') def read_batch_examples(file_pattern, batch_size, reader, @@ -106,6 +113,7 @@ def read_batch_examples(file_pattern, return examples +@deprecated(None, 'Use tf.data.') def read_keyed_batch_examples(file_pattern, batch_size, reader, @@ -175,6 +183,7 @@ def read_keyed_batch_examples(file_pattern, seed=seed) +@deprecated(None, 'Use tf.data.') def read_keyed_batch_examples_shared_queue(file_pattern, batch_size, reader, @@ -452,6 +461,7 @@ def _read_keyed_batch_examples_helper(file_pattern, return queued_examples_with_keys +@deprecated(None, 'Use tf.data.') def read_keyed_batch_features(file_pattern, batch_size, features, @@ -540,6 +550,7 @@ def read_keyed_batch_features(file_pattern, name=scope) +@deprecated(None, 'Use tf.data.') def read_keyed_batch_features_shared_queue(file_pattern, batch_size, features, @@ -620,6 +631,7 @@ def read_keyed_batch_features_shared_queue(file_pattern, name=scope) +@deprecated(None, 'Use tf.data.') def queue_parsed_features(parsed_features, keys=None, feature_queue_capacity=100, @@ -742,6 +754,7 @@ def queue_parsed_features(parsed_features, return dequeued_keys, dequeued_parsed_features +@deprecated(None, 'Use tf.data.') def read_batch_features(file_pattern, batch_size, features, @@ -821,6 +834,7 @@ def read_batch_features(file_pattern, return features +@deprecated(None, 'Use tf.data.') def read_batch_record_features(file_pattern, batch_size, features, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py index 692438807fbd7febb156d4db73b5d3deba6c987d..29552d24f1eaa0d85a99c8b09f69d007e7e4fe9f 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py @@ -12,15 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Methods to allow dict of numpy arrays.""" +"""Methods to allow dict of numpy arrays (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.estimator.inputs.numpy_io import numpy_input_fn as core_numpy_input_fn +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Use tf.estimator.inputs.numpy_input_fn.') def numpy_input_fn(x, y=None, batch_size=128, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py index ede7558eafa9237dc63aa95a62e599c5e9755822..b4ef055f5ae484ec704ad42efcf2c00c4a7a4f56 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py @@ -13,13 +13,19 @@ # limitations under the License. # ============================================================================== -"""Methods to allow pandas.DataFrame.""" +"""Methods to allow pandas.DataFrame (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.estimator.inputs.pandas_io import pandas_input_fn as core_pandas_input_fn +from tensorflow.python.util.deprecation import deprecated try: # pylint: disable=g-import-not-at-top @@ -47,6 +53,7 @@ PANDAS_DTYPES = { } +@deprecated(None, 'Please use tf.estimator.inputs.pandas_input_fn') def pandas_input_fn(x, y=None, batch_size=128, @@ -66,6 +73,7 @@ def pandas_input_fn(x, target_column=target_column) +@deprecated(None, 'Please access pandas data directly.') def extract_pandas_data(data): """Extract data from pandas.DataFrame for predictors. @@ -96,6 +104,7 @@ def extract_pandas_data(data): 'float, or bool. Found: ' + ', '.join(error_report)) +@deprecated(None, 'Please access pandas data directly.') def extract_pandas_matrix(data): """Extracts numpy matrix from pandas DataFrame. @@ -111,6 +120,7 @@ def extract_pandas_matrix(data): return data.as_matrix() +@deprecated(None, 'Please access pandas data directly.') def extract_pandas_labels(labels): """Extract data from pandas.DataFrame for labels. diff --git a/tensorflow/contrib/learn/python/learn/learn_runner.py b/tensorflow/contrib/learn/python/learn/learn_runner.py index 2af723a0d64822e81fa0fbeb106ab812de6ab4e8..d719a3e488b9905ef7903e21d90dbaae0449735c 100644 --- a/tensorflow/contrib/learn/python/learn/learn_runner.py +++ b/tensorflow/contrib/learn/python/learn/learn_runner.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Runs an Experiment.""" +"""Runs an Experiment (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -22,6 +27,7 @@ from tensorflow.contrib.learn.python.learn.estimators import run_config as run_c from tensorflow.contrib.learn.python.learn.experiment import Experiment from tensorflow.contrib.training.python.training import hparam as hparam_lib from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.deprecation import deprecated # TODO(xiejw): Refactor the learn_runner to make code reusable. @@ -99,6 +105,7 @@ def _wrapped_experiment_fn_with_uid_check(experiment_fn, require_hparams=False): return wrapped_experiment_fn +@deprecated(None, 'Use tf.estimator.train_and_evaluate.') def run(experiment_fn, output_dir=None, schedule=None, run_config=None, hparams=None): """Make and run an experiment. @@ -218,6 +225,7 @@ def run(experiment_fn, output_dir=None, schedule=None, run_config=None, return _execute_schedule(experiment, schedule) +@deprecated(None, 'Use tf.estimator.train_and_evaluate.') def tune(experiment_fn, tuner): """Tune an experiment with hyper-parameters. diff --git a/tensorflow/contrib/learn/python/learn/learn_runner_lib.py b/tensorflow/contrib/learn/python/learn/learn_runner_lib.py index 7d9b1c7716f0ab1f2274ca53406175240b613027..ba2d067787c1dfd4e4820ecc916f1053e9f3cf60 100644 --- a/tensorflow/contrib/learn/python/learn/learn_runner_lib.py +++ b/tensorflow/contrib/learn/python/learn/learn_runner_lib.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities to run and tune an Experiment. +"""Utilities to run and tune an Experiment (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. @@run @@tune diff --git a/tensorflow/contrib/learn/python/learn/metric_spec.py b/tensorflow/contrib/learn/python/learn/metric_spec.py index 6440bc204b8e339ff51311dcc87b36f556b94092..97220365d5dddb82b602369f06bea021a86d584f 100644 --- a/tensorflow/contrib/learn/python/learn/metric_spec.py +++ b/tensorflow/contrib/learn/python/learn/metric_spec.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The metric spec class to flexibly connect models and metrics.""" +"""The metric spec class to flexibly connect models and metrics (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -22,6 +27,7 @@ import six from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_inspect +from tensorflow.python.util.deprecation import deprecated def _assert_named_args(sentinel): @@ -223,6 +229,10 @@ def _adapt_metric_fn( class MetricSpec(object): """MetricSpec connects a model to metric functions. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + The MetricSpec class contains all information necessary to connect the output of a `model_fn` to the metrics (usually, streaming metrics) that are used in evaluation. @@ -284,6 +294,7 @@ class MetricSpec(object): """ + @deprecated(None, 'Use tf.estimator.EstimatorSpec.eval_metric_ops.') def __init__(self, metric_fn, prediction_key=None, diff --git a/tensorflow/contrib/learn/python/learn/models.py b/tensorflow/contrib/learn/python/learn/models.py index 4283240d018c949bb35aeb12032d2ee8b75884a5..bd4bbf9f8c9ad7e8a0fc06d8c0dc24672536c158 100644 --- a/tensorflow/contrib/learn/python/learn/models.py +++ b/tensorflow/contrib/learn/python/learn/models.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Various high level TF models.""" +"""Various high level TF models (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -28,8 +33,10 @@ from tensorflow.python.ops import array_ops as array_ops_ from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.summary import summary +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Consider using a tf.estimator.LinearRegressor') def linear_regression_zero_init(x, y): """Linear regression subgraph with zero-value initial weights and bias. @@ -43,6 +50,7 @@ def linear_regression_zero_init(x, y): return linear_regression(x, y, init_mean=0.0, init_stddev=0.0) +@deprecated(None, 'Consider using a class from tf.estimator.LinearClassifier') def logistic_regression_zero_init(x, y): """Logistic regression subgraph with zero-value initial weights and bias. @@ -56,6 +64,7 @@ def logistic_regression_zero_init(x, y): return logistic_regression(x, y, init_mean=0.0, init_stddev=0.0) +@deprecated(None, 'Consider using a class from tf.estimator.') def linear_regression(x, y, init_mean=None, init_stddev=1.0): """Creates linear regression TensorFlow subgraph. @@ -107,6 +116,7 @@ def linear_regression(x, y, init_mean=None, init_stddev=1.0): return losses_ops.mean_squared_error_regressor(x, y, weights, bias) +@deprecated(None, 'Consider using a class from tf.estimator.') def logistic_regression(x, y, class_weight=None, @@ -203,6 +213,7 @@ def _reverse_seq(input_seq, lengths): return result +@deprecated(None, 'Please consider `tf.nn.bidirectional_dynamic_rnn`.') def bidirectional_rnn(cell_fw, cell_bw, inputs, @@ -283,6 +294,7 @@ def bidirectional_rnn(cell_fw, # End of TensorFlow 0.7 +@deprecated(None, 'Please consider tensorflow/tensor2tensor.') def get_rnn_model(rnn_size, cell_type, num_layers, input_op_fn, bidirectional, target_predictor_fn, sequence_length, initial_state, attn_length, attn_size, attn_vec_size): diff --git a/tensorflow/contrib/learn/python/learn/monitored_session.py b/tensorflow/contrib/learn/python/learn/monitored_session.py index 22602e9f69d972505d83a66a6f9183b5e4d15c44..ac0433f1775feeed2ec3cf49291da01500bef01b 100644 --- a/tensorflow/contrib/learn/python/learn/monitored_session.py +++ b/tensorflow/contrib/learn/python/learn/monitored_session.py @@ -13,7 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A wrapper of Session API which runs hooks.""" +"""A wrapper of Session API which runs hooks (deprecated). + +These are deprecated aliases for classes and functions in `tf.train`. Please use +those directly. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index 51381a7427c919592b8e818c4b46dba974992610..77f7c73d5412d40b338eaff4cf04d99fd0892723 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Monitors instrument the training process. +"""Monitors instrument the training process (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. @@get_default_monitors @@BaseMonitor @@ -59,6 +63,10 @@ from tensorflow.python.util import tf_inspect class BaseMonitor(object): """Base class for Monitors. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Defines basic interfaces of Monitors. Monitors can either be run on all workers or, more commonly, restricted to run exclusively on the elected chief worker. @@ -229,6 +237,10 @@ def _extract_output(outputs, request): class EveryN(BaseMonitor): """Base class for monitors that execute callbacks every N steps. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + This class adds three new callbacks: - every_n_step_begin - every_n_step_end @@ -418,6 +430,10 @@ class StopAtStep(BaseMonitor): class PrintTensor(EveryN): """Prints given tensors every N steps. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + This is an `EveryN` monitor and has consistent semantic for `every_n` and `first_n`. @@ -455,9 +471,12 @@ class PrintTensor(EveryN): class LoggingTrainable(EveryN): """Writes trainable variable values into log every N steps. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Write the tensors in trainable variables `every_n` steps, starting with the `first_n`th step. - """ def __init__(self, scope=None, every_n=100, first_n=1): @@ -493,7 +512,12 @@ class LoggingTrainable(EveryN): class SummarySaver(EveryN): - """Saves summaries every N steps.""" + """Saves summaries every N steps. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, summary_op, @@ -554,6 +578,10 @@ class SummarySaver(EveryN): class ValidationMonitor(EveryN): """Runs evaluation of a given estimator, at most every N steps. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note that the evaluation is done based on the saved checkpoint, which will usually be older than the current step. @@ -573,7 +601,8 @@ class ValidationMonitor(EveryN): early_stopping_rounds=None, early_stopping_metric="loss", early_stopping_metric_minimize=True, - name=None): + name=None, + check_interval_secs=5): """Initializes a ValidationMonitor. Args: @@ -600,6 +629,9 @@ class ValidationMonitor(EveryN): loss metrics like mean squared error, and False for performance metrics like accuracy. name: See `BaseEstimator.evaluate`. + check_interval_secs: Only check for new checkpoint if at least + `check_interval_secs` have passed. Ignore if None. Default is 5 secs. + Raises: ValueError: If both x and input_fn are provided. @@ -626,6 +658,8 @@ class ValidationMonitor(EveryN): self._early_stopped = False self._latest_path = None self._latest_path_step = None + self._last_checkpoint_check_time = None + self._check_interval_secs = check_interval_secs @property def early_stopped(self): @@ -690,6 +724,16 @@ class ValidationMonitor(EveryN): # that's what is being evaluated. if self._estimator is None: raise ValueError("Missing call to set_estimator.") + current_time = time.time() + if (self._check_interval_secs is not None and + self._last_checkpoint_check_time is not None and + current_time - self._last_checkpoint_check_time <= + self._check_interval_secs): + logging.debug( + "Skipping evaluation since less than %d seconds have passed since " + "last check for a new checkpoint.", self._check_interval_secs) + return False + self._last_checkpoint_check_time = current_time # Check that we are not running evaluation on the same checkpoint. latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir) if latest_path is None: @@ -740,6 +784,10 @@ class ValidationMonitor(EveryN): class CaptureVariable(EveryN): """Captures a variable's values into a collection. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + This monitor is useful for unit testing. You should exercise caution when using this monitor in production, since it never discards values. @@ -778,6 +826,7 @@ class CaptureVariable(EveryN): self._var_values[step] = _extract_output(outputs, self._var_name) +@deprecation.deprecated(None, "Use tf.train.MonitoredTrainingSession.") def get_default_monitors(loss_op=None, summary_op=None, save_summary_steps=100, @@ -812,6 +861,10 @@ def get_default_monitors(loss_op=None, class GraphDump(BaseMonitor): """Dumps almost all tensors in the graph at every step. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note, this is very expensive, prefer `PrintTensor` in production. """ @@ -901,7 +954,12 @@ class GraphDump(BaseMonitor): class ExportMonitor(EveryN): - """Monitor that exports Estimator every N steps.""" + """Monitor that exports Estimator every N steps. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ @deprecation.deprecated("2017-03-25", "ExportMonitor is deprecated. Please pass an " @@ -1024,7 +1082,12 @@ class ExportMonitor(EveryN): class CheckpointSaver(BaseMonitor): - """Saves checkpoints every N steps or N seconds.""" + """Saves checkpoints every N steps or N seconds. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, checkpoint_dir, @@ -1109,7 +1172,12 @@ class CheckpointSaver(BaseMonitor): class StepCounter(EveryN): - """Steps per second monitor.""" + """Steps per second monitor. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, every_n_steps=100, output_dir=None, summary_writer=None): super(StepCounter, self).__init__(every_n_steps=every_n_steps) @@ -1149,6 +1217,10 @@ class NanLossDuringTrainingError(RuntimeError): class NanLoss(EveryN): """NaN Loss monitor. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Monitors loss and stops training if loss is NaN. Can either fail with exception or just stop training. """ diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py index b2b24776c60183113a5f936dd276ff312d6d0079..5c34d0ddb01f3bcdc407e6926e7c5b73be1863b4 100644 --- a/tensorflow/contrib/learn/python/learn/monitors_test.py +++ b/tensorflow/contrib/learn/python/learn/monitors_test.py @@ -385,7 +385,11 @@ class MonitorsTest(test.TestCase): estimator.evaluate.return_value = validation_outputs monitor = learn.monitors.ValidationMonitor( - x=constant_op.constant(2.0), every_n_steps=0, early_stopping_rounds=2) + x=constant_op.constant(2.0), + every_n_steps=0, + early_stopping_rounds=2, + check_interval_secs=None) + self._assert_validation_monitor(monitor) monitor.set_estimator(estimator) with ops.Graph().as_default() as g, self.test_session(g): diff --git a/tensorflow/contrib/learn/python/learn/ops/__init__.py b/tensorflow/contrib/learn/python/learn/ops/__init__.py index 33962e34cc685ce2c830a7bbfd1b5c626bcd8b31..efb1f47cf5bb2dcd0fb37b7b85cd8f170d56e4d1 100644 --- a/tensorflow/contrib/learn/python/learn/ops/__init__.py +++ b/tensorflow/contrib/learn/python/learn/ops/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Various TensorFlow Ops.""" +"""Various TensorFlow Ops (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py b/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py index fa3b7323e343371e986b763d30a8a44620894549..8f9811cf251ae0af1e0055a56e1358c2771b1367 100644 --- a/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py +++ b/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py @@ -13,7 +13,11 @@ # limitations under the License. # ============================================================================== -"""TensorFlow Ops to work with embeddings. +"""TensorFlow Ops to work with embeddings (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Note: categorical variables are handled via embeddings in many cases. For example, in case of words. @@ -57,7 +61,7 @@ def embedding_lookup(params, ids, name='embedding_lookup'): ids = ops.convert_to_tensor(ids) shape = array_ops_.shape(ids) ids_flat = array_ops_.reshape( - ids, math_ops.reduce_prod(shape, keep_dims=True)) + ids, math_ops.reduce_prod(shape, keepdims=True)) embeds_flat = nn.embedding_lookup(params, ids_flat, name) embed_shape = array_ops_.concat([shape, [-1]], 0) embeds = array_ops_.reshape(embeds_flat, embed_shape) diff --git a/tensorflow/contrib/learn/python/learn/ops/losses_ops.py b/tensorflow/contrib/learn/python/learn/ops/losses_ops.py index b040ab3bb6c516158589a8e30d56fff1f7728951..9f2cadb01747c5a8e4ee75ac38f423f85e11bbba 100644 --- a/tensorflow/contrib/learn/python/learn/ops/losses_ops.py +++ b/tensorflow/contrib/learn/python/learn/ops/losses_ops.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorFlow Ops for loss computation.""" +"""TensorFlow Ops for loss computation (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -35,7 +40,7 @@ def mean_squared_error_regressor(tensor_in, labels, weights, biases, name=None): [tensor_in, labels]): predictions = nn.xw_plus_b(tensor_in, weights, biases) if len(labels.get_shape()) == 1 and len(predictions.get_shape()) == 2: - predictions = array_ops_.squeeze(predictions, squeeze_dims=[1]) + predictions = array_ops_.squeeze(predictions, axis=[1]) return predictions, losses.mean_squared_error(labels, predictions) diff --git a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py index 45727faab4362abeab18f77861353eb53976023a..aa37cb4a76e2a6157bf077d327248353bd516472 100644 --- a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py +++ b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorFlow Ops for Sequence to Sequence models.""" +"""TensorFlow Ops for Sequence to Sequence models (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -26,8 +31,10 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Please use tf.nn/tf.layers directly.') def sequence_classifier(decoding, labels, sampling_decoding=None, name=None): """Returns predictions and loss for sequence of predictions. @@ -57,6 +64,7 @@ def sequence_classifier(decoding, labels, sampling_decoding=None, name=None): return array_ops.stack(predictions, axis=1), loss +@deprecated(None, 'Please use tf.nn/tf.layers directly.') def seq2seq_inputs(x, y, input_length, output_length, sentinel=None, name=None): """Processes inputs for Sequence to Sequence models. @@ -87,6 +95,7 @@ def seq2seq_inputs(x, y, input_length, output_length, sentinel=None, name=None): return in_x, in_y, out_y +@deprecated(None, 'Please use tf.nn/tf.layers directly.') def rnn_decoder(decoder_inputs, initial_state, cell, scope=None): """RNN Decoder that creates training and sampling sub-graphs. @@ -123,6 +132,7 @@ def rnn_decoder(decoder_inputs, initial_state, cell, scope=None): return outputs, states, sampling_outputs, sampling_states +@deprecated(None, 'Please use tf.nn/tf.layers directly.') def rnn_seq2seq(encoder_inputs, decoder_inputs, encoder_cell, diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/__init__.py b/tensorflow/contrib/learn/python/learn/preprocessing/__init__.py index 7bcc177d4ea0ab57f092d68888a72de2b2fd5edc..e8c6e1acf80f0791421bee59aff30e67bccb44b2 100644 --- a/tensorflow/contrib/learn/python/learn/preprocessing/__init__.py +++ b/tensorflow/contrib/learn/python/learn/preprocessing/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Preprocessing tools useful for building models.""" +"""Preprocessing tools useful for building models (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/categorical.py b/tensorflow/contrib/learn/python/learn/preprocessing/categorical.py index 154739d497ec1029026eaca1e93b37cd225f1050..faba3b2025e8abb51d1989c3fafbd5e711d6559b 100644 --- a/tensorflow/contrib/learn/python/learn/preprocessing/categorical.py +++ b/tensorflow/contrib/learn/python/learn/preprocessing/categorical.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Implements preprocessing transformers for categorical variables.""" +"""Implements preprocessing transformers for categorical variables (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -22,6 +27,8 @@ from __future__ import print_function import math import numpy as np +from tensorflow.python.util.deprecation import deprecated + # pylint: disable=g-bad-import-order from . import categorical_vocabulary from ..learn_io.data_feeder import setup_processor_data_feeder @@ -31,10 +38,16 @@ from ..learn_io.data_feeder import setup_processor_data_feeder class CategoricalProcessor(object): """Maps documents to sequences of word ids. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + As a common convention, Nan values are handled as unknown tokens. Both float('nan') and np.nan are accepted. """ + @deprecated(None, 'Please use tensorflow/transform or tf.data for sequence ' + 'processing.') def __init__(self, min_frequency=0, share=False, vocabularies=None): """Initializes a CategoricalProcessor instance. diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/categorical_vocabulary.py b/tensorflow/contrib/learn/python/learn/preprocessing/categorical_vocabulary.py index 5709955c49fba50ca4a299a443a2902bbd9c6b23..3ac370a6ab4423846e810900514445ad5269b680 100644 --- a/tensorflow/contrib/learn/python/learn/preprocessing/categorical_vocabulary.py +++ b/tensorflow/contrib/learn/python/learn/preprocessing/categorical_vocabulary.py @@ -13,7 +13,11 @@ # limitations under the License. # ============================================================================== -"""Categorical vocabulary classes to map categories to indexes. +"""Categorical vocabulary classes to map categories to indexes (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Can be used for categorical variables, sparse variables and words. """ @@ -25,14 +29,21 @@ from __future__ import print_function import collections import six +from tensorflow.python.util.deprecation import deprecated + class CategoricalVocabulary(object): """Categorical variables vocabulary class. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Accumulates and provides mapping from classes to indexes. Can be easily used for words. """ + @deprecated(None, 'Please use tensorflow/transform or tf.data.') def __init__(self, unknown_token="", support_reverse=True): self._unknown_token = unknown_token self._mapping = {unknown_token: 0} diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/text.py b/tensorflow/contrib/learn/python/learn/preprocessing/text.py index 3af2074c2a46f0258c04111fff0235ba8309625e..f2b6776be7789a9433bfe41eb9354b74347059ec 100644 --- a/tensorflow/contrib/learn/python/learn/preprocessing/text.py +++ b/tensorflow/contrib/learn/python/learn/preprocessing/text.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Implements a number of text preprocessing utilities.""" +"""Implements a number of text preprocessing utilities (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -24,6 +29,7 @@ import numpy as np import six from tensorflow.python.platform import gfile +from tensorflow.python.util.deprecation import deprecated from .categorical_vocabulary import CategoricalVocabulary # pylint: disable=g-bad-import-order @@ -38,6 +44,7 @@ TOKENIZER_RE = re.compile(r"[A-Z]{2,}(?![a-z])|[A-Z][a-z]+(?=[A-Z])|[\'\w\-]+", re.UNICODE) +@deprecated(None, 'Please use tensorflow/transform or tf.data.') def tokenizer(iterator): """Tokenizer generator. @@ -51,9 +58,16 @@ def tokenizer(iterator): yield TOKENIZER_RE.findall(value) +@deprecated(None, 'Please use tensorflow/transform or tf.data.') class ByteProcessor(object): - """Maps documents into sequence of ids for bytes.""" + """Maps documents into sequence of ids for bytes. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please use tensorflow/transform or tf.data.') def __init__(self, max_document_length): self.max_document_length = max_document_length @@ -108,8 +122,14 @@ class ByteProcessor(object): class VocabularyProcessor(object): - """Maps documents to sequences of word ids.""" + """Maps documents to sequences of word ids. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please use tensorflow/transform or tf.data.') def __init__(self, max_document_length, min_frequency=0, diff --git a/tensorflow/contrib/learn/python/learn/session_run_hook.py b/tensorflow/contrib/learn/python/learn/session_run_hook.py index a8ba2be97206f2b974d256ad2c62c21a4e3e55d8..87edc9b720bdb3edcd5f2dcd1662d14da53c51cf 100644 --- a/tensorflow/contrib/learn/python/learn/session_run_hook.py +++ b/tensorflow/contrib/learn/python/learn/session_run_hook.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""This file is deprecated. Use tensorflow.python.training.session_run_hook.""" +"""This file is deprecated. Use `tensorflow.python.training.session_run_hook`. + +See [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/summary_writer_cache.py b/tensorflow/contrib/learn/python/learn/summary_writer_cache.py index 919d415c302b8ec17202aad34ff0bee69bfee2c7..d663cf5fb79c428b0e70d66b0f1305f0559a05c9 100644 --- a/tensorflow/contrib/learn/python/learn/summary_writer_cache.py +++ b/tensorflow/contrib/learn/python/learn/summary_writer_cache.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Wrapper for a Session-like object that handles threads and recovery. +"""Wrapper for a Session-like object that handles threads and recovery (deprecated). + +These are deprecated aliases for classes and functions in `tf.train`. Please use +those directly. Based on an original design of Illia Polosukhin. """ diff --git a/tensorflow/contrib/learn/python/learn/trainable.py b/tensorflow/contrib/learn/python/learn/trainable.py index 429b6040be21d8cbe1f2bba58090366552fdfbe7..a1a3f20dcd8cb5ff7baa559ac41d5e5c40780511 100644 --- a/tensorflow/contrib/learn/python/learn/trainable.py +++ b/tensorflow/contrib/learn/python/learn/trainable.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""`Trainable` interface.""" +"""`Trainable` interface (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -23,6 +28,8 @@ import abc class Trainable(object): """Interface for objects that are trainable by, e.g., `Experiment`. + + THIS CLASS IS DEPRECATED. """ __metaclass__ = abc.ABCMeta diff --git a/tensorflow/contrib/learn/python/learn/utils/__init__.py b/tensorflow/contrib/learn/python/learn/utils/__init__.py index 48978d0ac34cec2b18e6794dcf3b260bc3b683c4..66d8dc6fd43b383919a16515bc96be492a253bf6 100644 --- a/tensorflow/contrib/learn/python/learn/utils/__init__.py +++ b/tensorflow/contrib/learn/python/learn/utils/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorFlow Learn Utils.""" +"""TensorFlow Learn Utils (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index cb34cb1d26b6812c7f3f39e9f965615de5a8ef07..3eacac7a3d3dcff4d39025fdee88e16e385b1b84 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -13,14 +13,18 @@ # limitations under the License. # ============================================================================== -"""Export utilities.""" +"""Export utilities (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.contrib.framework import deprecated -from tensorflow.python.training import training_util from tensorflow.contrib.session_bundle import exporter from tensorflow.contrib.session_bundle import gc from tensorflow.python.client import session as tf_session @@ -32,6 +36,7 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as tf_saver +from tensorflow.python.training import training_util @deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.') diff --git a/tensorflow/contrib/learn/python/learn/utils/gc.py b/tensorflow/contrib/learn/python/learn/utils/gc.py index 226915987a4934626066b12810f579ae675107b2..916aecbea88b10bbef316ffb89d4c4d89667cb29 100644 --- a/tensorflow/contrib/learn/python/learn/utils/gc.py +++ b/tensorflow/contrib/learn/python/learn/utils/gc.py @@ -13,7 +13,11 @@ # limitations under the License. # ============================================================================== -r"""System for specifying garbage collection (GC) of path based data. +r"""System for specifying garbage collection (GC) of path based data (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. This framework allows for GC of data specified by path names, for example files on disk. gc.Path objects each represent a single item stored at a path and may @@ -73,10 +77,12 @@ import os from tensorflow.python.platform import gfile from tensorflow.python.util import compat +from tensorflow.python.util.deprecation import deprecated Path = collections.namedtuple('Path', 'path export_version') +@deprecated(None, 'Please implement your own file management or use Saver.') def largest_export_versions(n): """Creates a filter that keeps the largest n export versions. @@ -97,6 +103,7 @@ def largest_export_versions(n): return keep +@deprecated(None, 'Please implement your own file management or use Saver.') def one_of_every_n_export_versions(n): """Creates a filter that keeps one of every n export versions. @@ -128,6 +135,7 @@ def one_of_every_n_export_versions(n): return keep +@deprecated(None, 'Please implement your own file management or use Saver.') def mod_export_version(n): """Creates a filter that keeps every export that is a multiple of n. @@ -146,6 +154,7 @@ def mod_export_version(n): return keep +@deprecated(None, 'Please implement your own file management or use Saver.') def union(lf, rf): """Creates a filter that keeps the union of two filters. @@ -163,6 +172,7 @@ def union(lf, rf): return keep +@deprecated(None, 'Please implement your own file management or use Saver.') def negation(f): """Negate a filter. @@ -179,6 +189,7 @@ def negation(f): return keep +@deprecated(None, 'Please implement your own file name management.') def get_paths(base_dir, parser): """Gets a list of Paths in a given directory. diff --git a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py index b2521933e524e7ec24d73d4b5171f33e507dd88c..b92eb9fea8b7ccea56c781df74dcfa1cc5508e48 100644 --- a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities for creating input_fns. +"""Utilities for creating input_fns (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Contents of this file are moved to tensorflow/python/estimator/export.py. InputFnOps is renamed to ServingInputReceiver. @@ -32,13 +36,17 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.util.deprecation import deprecated class InputFnOps(collections.namedtuple('InputFnOps', ['features', 'labels', 'default_inputs'])): - """A return type for an input_fn. + """A return type for an input_fn (deprecated). + + THIS CLASS IS DEPRECATED. Please use tf.estimator.export.ServingInputReceiver + instead. This return type is currently only supported for serving input_fn. Training and eval input_fn should return a `(features, labels)` tuple. @@ -56,6 +64,8 @@ class InputFnOps(collections.namedtuple('InputFnOps', """ +@deprecated(None, 'Please use ' + 'tf.estimator.export.build_parsing_serving_input_receiver_fn.') def build_parsing_serving_input_fn(feature_spec, default_batch_size=None): """Build an input_fn appropriate for serving, expecting fed tf.Examples. @@ -84,6 +94,8 @@ def build_parsing_serving_input_fn(feature_spec, default_batch_size=None): return input_fn +@deprecated(None, 'Please use ' + 'tf.estimator.export.build_raw_serving_input_receiver_fn.') def build_default_serving_input_fn(features, default_batch_size=None): """Build an input_fn appropriate for serving, expecting feature Tensors. diff --git a/tensorflow/contrib/learn/python/learn/utils/inspect_checkpoint.py b/tensorflow/contrib/learn/python/learn/utils/inspect_checkpoint.py index 6a63fb545a56e6040b0b0c3bbb6a17cd96925895..6dbaa15f8391b0044be8e30ca191753beb88db93 100644 --- a/tensorflow/contrib/learn/python/learn/utils/inspect_checkpoint.py +++ b/tensorflow/contrib/learn/python/learn/utils/inspect_checkpoint.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A simple script for inspect checkpoint files.""" +"""A simple script for inspect checkpoint files (deprecated).""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index 1593380007b2799fb1d17e92408ab19a7b47fe1e..f8106d1e4a7e79f1cd651c40995be480721a8129 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities supporting export to SavedModel. +"""Utilities supporting export to SavedModel (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Some contents of this file are moved to tensorflow/python/estimator/export.py: @@ -52,8 +56,9 @@ from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.summary import summary_iterator from tensorflow.python.training import saver - from tensorflow.python.util import compat +from tensorflow.python.util.deprecation import deprecated + # A key for use in the input_alternatives dict indicating the default input. # This is the input that will be expected when a serving request does not @@ -77,6 +82,7 @@ FEATURES_INPUT_ALTERNATIVE_KEY = 'features_input_alternative' _FALLBACK_DEFAULT_OUTPUT_ALTERNATIVE_KEY = 'default_output_alternative' +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def build_standardized_signature_def(input_tensors, output_tensors, problem_type): """Build a SignatureDef using problem type and input and output Tensors. @@ -156,6 +162,7 @@ def _is_regression_problem(problem_type, input_tensors, output_tensors): len(input_tensors) == 1 and len(output_tensors) == 1) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_input_alternatives(input_ops): """Obtain all input alternatives using the input_fn output and heuristics.""" input_alternatives = {} @@ -181,6 +188,7 @@ def get_input_alternatives(input_ops): return input_alternatives, features +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_output_alternatives(model_fn_ops, default_output_alternative_key=None): """Obtain all output alternatives using the model_fn output and heuristics. @@ -246,6 +254,7 @@ def get_output_alternatives(model_fn_ops, default_output_alternative_key=None): sorted(output_alternatives.keys()))) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def build_all_signature_defs(input_alternatives, output_alternatives, actual_default_output_alternative_key): """Build `SignatureDef`s from all pairs of input and output alternatives.""" @@ -279,6 +288,7 @@ def build_all_signature_defs(input_alternatives, output_alternatives, MAX_DIRECTORY_CREATION_ATTEMPTS = 10 +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_timestamped_export_dir(export_dir_base): """Builds a path to a new subdirectory within the base directory. @@ -317,6 +327,7 @@ def get_timestamped_export_dir(export_dir_base): '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS)) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_temp_export_dir(timestamped_export_dir): """Builds a directory name based on the argument but starting with 'temp-'. @@ -332,7 +343,8 @@ def get_temp_export_dir(timestamped_export_dir): """ (dirname, basename) = os.path.split(timestamped_export_dir) temp_export_dir = os.path.join( - compat.as_bytes(dirname), compat.as_bytes('temp-{}'.format(basename))) + compat.as_bytes(dirname), + compat.as_bytes('temp-{}'.format(compat.as_text(basename)))) return temp_export_dir @@ -344,6 +356,7 @@ def _export_version_parser(path): return path._replace(export_version=int(filename)) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_most_recent_export(export_dir_base): """Locate the most recent SavedModel export in a directory of many exports. @@ -363,6 +376,7 @@ def get_most_recent_export(export_dir_base): return next(iter(results or []), None) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def garbage_collect_exports(export_dir_base, exports_to_keep): """Deletes older exports, retaining only a given number of the most recent. @@ -387,6 +401,7 @@ def garbage_collect_exports(export_dir_base, exports_to_keep): logging.warn('Can not delete %s recursively: %s', p.path, e) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def make_export_strategy(serving_input_fn, default_output_alternative_key=None, assets_extra=None, @@ -400,7 +415,7 @@ def make_export_strategy(serving_input_fn, `InputFnOps`. default_output_alternative_key: the name of the head to serve when an incoming serving request does not explicitly request a specific head. - Must be `None` if the estimator inherits from ${tf.estimator.Estimator} + Must be `None` if the estimator inherits from @{tf.estimator.Estimator} or for single-headed models. assets_extra: A dict specifying how to populate the assets.extra directory within the exported SavedModel. Each key should give the destination @@ -438,7 +453,7 @@ def make_export_strategy(serving_input_fn, The string path to the exported directory. Raises: - ValueError: If `estimator` is a ${tf.estimator.Estimator} instance + ValueError: If `estimator` is a @{tf.estimator.Estimator} instance and `default_output_alternative_key` was specified. """ if isinstance(estimator, core_estimator.Estimator): @@ -469,6 +484,8 @@ def make_export_strategy(serving_input_fn, return export_strategy.ExportStrategy('Servo', export_fn, strip_default_attrs) +@deprecated(None, + 'Use tf.estimator.export.build_parsing_serving_input_receiver_fn') def make_parsing_export_strategy(feature_columns, default_output_alternative_key=None, assets_extra=None, @@ -487,7 +504,7 @@ def make_parsing_export_strategy(feature_columns, that must be provided at serving time (excluding labels!). default_output_alternative_key: the name of the head to serve when an incoming serving request does not explicitly request a specific head. - Must be `None` if the estimator inherits from ${tf.estimator.Estimator} + Must be `None` if the estimator inherits from @{tf.estimator.Estimator} or for single-headed models. assets_extra: A dict specifying how to populate the assets.extra directory within the exported SavedModel. Each key should give the destination @@ -555,8 +572,14 @@ def _default_compare_fn(curr_best_eval_result, cand_eval_result): class BestModelSelector(object): - """A helper that keeps track of export selection candidates.""" + """A helper that keeps track of export selection candidates. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def __init__(self, event_file_pattern=None, compare_fn=None): """Constructor of this class. @@ -622,6 +645,7 @@ class BestModelSelector(object): return best_eval_result +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def make_best_model_export_strategy( serving_input_fn, exports_to_keep=1, @@ -707,6 +731,7 @@ def make_best_model_export_strategy( # TODO(b/67013778): Revisit this approach when corresponding changes to # TF Core are finalized. +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def extend_export_strategy(base_export_strategy, post_export_fn, post_export_name=None): @@ -741,7 +766,7 @@ def extend_export_strategy(base_export_strategy, The string path to the SavedModel indicated by post_export_fn. Raises: - ValueError: If `estimator` is a ${tf.estimator.Estimator} instance + ValueError: If `estimator` is a @{tf.estimator.Estimator} instance and `default_output_alternative_key` was specified or if post_export_fn does not return a valid directory. RuntimeError: If unable to create temporary or final export directory. diff --git a/tensorflow/contrib/legacy_seq2seq/BUILD b/tensorflow/contrib/legacy_seq2seq/BUILD index 1fa55132b1fc0cd3367ca2eb331b6870edc30c3b..4ce91a140f816ddc8bdc60287e4cbc807172ec6d 100644 --- a/tensorflow/contrib/legacy_seq2seq/BUILD +++ b/tensorflow/contrib/legacy_seq2seq/BUILD @@ -58,17 +58,8 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], - tags = ["noasan"], # times out b/63678675 -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], + tags = [ + "noasan", # times out b/63678675 + "optonly", # times out (flaky) + ], ) diff --git a/tensorflow/contrib/libsvm/BUILD b/tensorflow/contrib/libsvm/BUILD index df96402a4ffd51840f77d58d8066487030362340..4dccb9be7cd2e603edcf10c020cc0ee1675f518a 100644 --- a/tensorflow/contrib/libsvm/BUILD +++ b/tensorflow/contrib/libsvm/BUILD @@ -88,15 +88,3 @@ tf_py_test( "//tensorflow/python:platform_test", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD index 208e7bc69be76680868c766bc99429eea5870c80..78b7970069fec2d67f816b39d8fa4c58021cef85 100644 --- a/tensorflow/contrib/linalg/BUILD +++ b/tensorflow/contrib/linalg/BUILD @@ -42,15 +42,3 @@ cuda_py_test( "//tensorflow/python:platform_test", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py index 4720692c3384ba1bede1f486c1b1e0e69d10a63a..a262a099cf8f843a4d228ce5d53664cb85fd046f 100644 --- a/tensorflow/contrib/linalg/__init__.py +++ b/tensorflow/contrib/linalg/__init__.py @@ -17,10 +17,15 @@ See the @{$python/contrib.linalg} guide. @@LinearOperator +@@LinearOperatorBlockDiag +@@LinearOperatorCirculant +@@LinearOperatorCirculant2D +@@LinearOperatorCirculant3D @@LinearOperatorDiag @@LinearOperatorIdentity @@LinearOperatorScaledIdentity @@LinearOperatorFullMatrix +@@LinearOperatorKronecker @@LinearOperatorLowerTriangular @@LinearOperatorLowRankUpdate @@LinearOperatorComposition @@ -35,14 +40,18 @@ from __future__ import print_function from tensorflow.contrib.linalg.python.ops.linear_operator_addition import * from tensorflow.python.ops.linalg.linear_operator import * +from tensorflow.python.ops.linalg.linear_operator_block_diag import * +from tensorflow.python.ops.linalg.linear_operator_circulant import * from tensorflow.python.ops.linalg.linear_operator_composition import * from tensorflow.python.ops.linalg.linear_operator_diag import * from tensorflow.python.ops.linalg.linear_operator_full_matrix import * from tensorflow.python.ops.linalg.linear_operator_identity import * +from tensorflow.python.ops.linalg.linear_operator_kronecker import * from tensorflow.python.ops.linalg.linear_operator_low_rank_update import * from tensorflow.python.ops.linalg.linear_operator_lower_triangular import * # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member from tensorflow.python.util.all_util import remove_undocumented + remove_undocumented(__name__) diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD index cea3627ed565f0de86d8d9bb6b45c4b19c5b5558..5b89c6cef9fa9fdef7c26ddee1efa03f3056d881 100644 --- a/tensorflow/contrib/linear_optimizer/BUILD +++ b/tensorflow/contrib/linear_optimizer/BUILD @@ -138,14 +138,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index 70f777f08bd5b8157e601f19019075d3e7543811..ef0e08a777779e04f70d11fe83280ccaf1c178fd 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import random import threading from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel @@ -34,12 +35,14 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_sdca_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import googletest _MAX_ITERATIONS = 100 -_SHARD_NUMBERS = [None, 1, 3, 10] -_NUM_LOSS_PARTITIONS = [2, 4] +_SHARD_NUMBERS = [None, 1, 3] +_NUM_LOSS_PARTITIONS = [4] def make_example_proto(feature_dict, target, value=1.0): @@ -102,15 +105,51 @@ def make_example_dict(example_protos, example_weights): example_ids=['%d' % i for i in range(0, len(example_protos))]) -def make_variable_dict(max_age, max_gender): +def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero): + random.seed(1) + + sparse_features = [ + SparseFeatureColumn( + [i for i in range(num_examples) for _ in range(num_non_zero)], [ + i for _ in range(num_examples) + for i in random.sample(range(dim), num_non_zero) + ], + [num_non_zero**(-0.5) for _ in range(num_examples * num_non_zero)]) + ] + examples_dict = dict( + sparse_features=sparse_features, + dense_features=[], + example_weights=[random.random() for _ in range(num_examples)], + example_labels=[ + 1. if random.random() > 0.5 else 0. for _ in range(num_examples) + ], + example_ids=[str(i) for i in range(num_examples)]) + + weights = variables_lib.Variable( + array_ops.zeros([dim], dtype=dtypes.float32)) + variables_dict = dict( + sparse_features_weights=[weights], + dense_features_weights=[]) + + return examples_dict, variables_dict + + +def make_variable_dict(max_age, max_gender, partitioned=False): # TODO(sibyl-toe9oF2e): Figure out how to derive max_age & max_gender from # examples_dict. - age_weights = variables_lib.Variable( - array_ops.zeros( - [max_age + 1], dtype=dtypes.float32)) - gender_weights = variables_lib.Variable( - array_ops.zeros( - [max_gender + 1], dtype=dtypes.float32)) + partitioner = None + if partitioned: + partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2, + axis=0) + with variable_scope.variable_scope( + name_or_scope='variables', + partitioner=partitioner): + age_weights = variables_lib.Variable( + array_ops.zeros( + [max_age + 1], dtype=dtypes.float32)) + gender_weights = variables_lib.Variable( + array_ops.zeros( + [max_gender + 1], dtype=dtypes.float32)) return dict( sparse_features_weights=[age_weights, gender_weights], dense_features_weights=[]) @@ -235,8 +274,113 @@ class SdcaWithLogisticLossTest(SdcaModelTest): self.assertAllClose( 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) - def testDistributedSimple(self): + def testPartitionedPrimals(self): + # Setup test data + example_protos = [ + make_example_proto({ + 'age': [0], + 'gender': [0] + }, 0), + make_example_proto({ + 'age': [1], + 'gender': [1] + }, 1), + ] + example_weights = [1.0, 1.0] + for num_shards in _SHARD_NUMBERS: + with self._single_threaded_test_session(): + examples = make_example_dict(example_protos, example_weights) + variables = make_variable_dict(1, 1, partitioned=True) + options = dict( + symmetric_l2_regularization=1, + symmetric_l1_regularization=0, + num_table_shards=num_shards, + loss_type='logistic_loss') + + lr = SdcaModel(examples, variables, options) + variables_lib.global_variables_initializer().run() + unregularized_loss = lr.unregularized_loss(examples) + loss = lr.regularized_loss(examples) + predictions = lr.predictions(examples) + self.assertAllClose(0.693147, unregularized_loss.eval()) + self.assertAllClose(0.693147, loss.eval()) + train_op = lr.minimize() + for _ in range(_MAX_ITERATIONS): + train_op.run() + lr.update_weights(train_op).run() + # The high tolerance in unregularized_loss comparisons is due to the + # fact that it's possible to trade off unregularized_loss vs. + # regularization and still have a sum that is quite close to the + # optimal regularized_loss value. SDCA's duality gap only ensures that + # the regularized_loss is within 0.01 of optimal. + # 0.525457 is the optimal regularized_loss. + # 0.411608 is the unregularized_loss at that optimum. + self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05) + self.assertAllClose(0.525457, loss.eval(), atol=0.01) + predicted_labels = get_binary_predictions_for_logistic(predictions) + self.assertAllEqual([0, 1], predicted_labels.eval()) + self.assertAllClose( + 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) + + def testSparseRandom(self): + dim = 20 + num_examples = 1000 + # Number of non-zero features per example. + non_zeros = 10 + # Setup test data. + with self._single_threaded_test_session(): + examples, variables = make_random_examples_and_variables_dicts( + num_examples, dim, non_zeros) + options = dict( + symmetric_l2_regularization=.1, + symmetric_l1_regularization=0, + num_table_shards=1, + adaptive=False, + loss_type='logistic_loss') + + lr = SdcaModel(examples, variables, options) + variables_lib.global_variables_initializer().run() + train_op = lr.minimize() + for _ in range(4): + train_op.run() + lr.update_weights(train_op).run() + # Duality gap is 1.4e-5. + # It would be 0.01 without shuffling and 0.02 with adaptive sampling. + self.assertNear(0.0, lr.approximate_duality_gap().eval(), err=1e-3) + + def testSparseDuplicate(self): # Setup test data + example_protos = [ + make_example_proto({ + 'age': [0] * 5, + 'gender': [0] * 5 + }, 0), + make_example_proto({ + 'age': [1] * 5, + 'gender': [1] * 5 + }, 1), + ] + example_weights = [1.0, 1.0] + with self._single_threaded_test_session(): + examples = make_example_dict(example_protos, example_weights) + variables = make_variable_dict(1, 1) + options = dict( + symmetric_l2_regularization=1, + symmetric_l1_regularization=0, + loss_type='logistic_loss') + + lr = SdcaModel(examples, variables, options) + variables_lib.global_variables_initializer().run() + train_op = lr.minimize() + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + 'Duplicate'): + train_op.run() + + def testDistributedSimple(self): + # Distributed SDCA may not converge if the workers update concurrently the + # same example. In this test the examples are partitioned across workers. + # The examples are the same for all workers, just the example_ids are + # different. example_protos = [ make_example_proto({ 'age': [0], @@ -248,13 +392,19 @@ class SdcaWithLogisticLossTest(SdcaModelTest): }, 1), ] example_weights = [1.0, 1.0] + examples = make_example_dict(example_protos, example_weights) + example_ids = array_ops.placeholder( + dtypes.string, shape=(len(example_weights),)) + examples['example_ids'] = example_ids + variables = make_variable_dict(1, 1) for num_shards in _SHARD_NUMBERS: for num_loss_partitions in _NUM_LOSS_PARTITIONS: with self._single_threaded_test_session(): - examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) options = dict( - symmetric_l2_regularization=1, + # Keep the same solution as for TestSimple: since the number of + # examples is multplied by num_loss_partitions, multiply also + # L2 by the same value. + symmetric_l2_regularization=num_loss_partitions, symmetric_l1_regularization=0, loss_type='logistic_loss', num_table_shards=num_shards, @@ -270,32 +420,30 @@ class SdcaWithLogisticLossTest(SdcaModelTest): train_op = lr.minimize() - def Minimize(): + def minimize(worker_id): with self._single_threaded_test_session(): + feed_dict = {example_ids: [ + str(i + worker_id*len(example_weights)) for i in range( + len(example_weights))]} for _ in range(_MAX_ITERATIONS): - train_op.run() + train_op.run(feed_dict=feed_dict) # pylint: disable=cell-var-from-loop threads = [] - for _ in range(num_loss_partitions): - threads.append(threading.Thread(target=Minimize)) + for worker_id in range(num_loss_partitions): + threads.append(threading.Thread(target=minimize, args=(worker_id,))) threads[-1].start() for t in threads: t.join() - lr.update_weights(train_op).run() - - # The high tolerance in unregularized_loss comparisons is due to the - # fact that it's possible to trade off unregularized_loss vs. - # regularization and still have a sum that is quite close to the - # optimal regularized_loss value. SDCA's duality gap only ensures - # that the regularized_loss is within 0.01 of optimal. - # 0.525457 is the optimal regularized_loss. - # 0.411608 is the unregularized_loss at that optimum. - self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05) - self.assertAllClose(0.525457, loss.eval(), atol=0.01) + lr.update_weights(train_op).run(feed_dict={ + example_ids: [str(i) for i in range(len(example_weights))]}) + + # Test only the unregularized loss because the optimal value of the + # regularized loss depends on num_loss_partitions. + self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.02) predicted_labels = get_binary_predictions_for_logistic(predictions) self.assertAllEqual([0, 1], predicted_labels.eval()) - self.assertTrue(lr.approximate_duality_gap().eval() < 0.02) + self.assertNear(0.0, lr.approximate_duality_gap().eval(), 0.02) def testSimpleNoL2(self): # Same as test above (so comments from above apply) but without an L2. @@ -395,7 +543,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): predicted_labels = get_binary_predictions_for_logistic(predictions) self.assertAllClose([0, 1, 1, 1], predicted_labels.eval()) self.assertAllClose( - 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) + 0.0, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) def testFractionalExampleLabel(self): # Setup test data with 1 positive, and 1 mostly-negative example. @@ -407,7 +555,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): make_example_proto({ 'age': [1], 'gender': [1] - }, 1), + }, 0.9), ] example_weights = [1.0, 1.0] for num_shards in _SHARD_NUMBERS: diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index 3f5fdc18bb8f47cceee8f81dd5ded02059344b8b..0047d5753a773ce814d685f89da9ae6b04d21cb6 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -22,12 +22,14 @@ import collections from six.moves import range from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework.ops import internal_convert_to_tensor from tensorflow.python.framework.ops import name_scope from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gen_sdca_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -43,9 +45,6 @@ __all__ = ['SdcaModel'] class SdcaModel(object): """Stochastic dual coordinate ascent solver for linear models. - This class currently only supports a single machine (multi-threaded) - implementation. We expect the weights and duals to fit in a single machine. - Loss functions supported: * Binary logistic loss @@ -168,6 +167,10 @@ class SdcaModel(object): # of workers return self._options.get('num_loss_partitions', 1) + def _adaptive(self): + # Perform adaptive sampling. + return self._options.get('adaptive', True) + def _num_table_shards(self): # Number of hash table shards. # Return 1 if not specified or if the value is 'None' @@ -178,18 +181,41 @@ class SdcaModel(object): # TODO(sibyl-Aix6ihai): Use optimizer interface to make use of slot creation logic. def _create_slots(self): - # Make internal variables which have the updates before applying L1 - # regularization. + """Make unshrinked internal variables (slots).""" + # Unshrinked variables have the updates before applying L1 regularization. + # Each unshrinked slot variable is either a `Variable` or list of + # `Variable`, depending on the value of its corresponding primary variable. + # We avoid using `PartitionedVariable` for the unshrinked slots since we do + # not need any of the extra information. self._slots = collections.defaultdict(list) for name in ['sparse_features_weights', 'dense_features_weights']: for var in self._variables[name]: - with ops.device(var.device): - # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is - # fixed - self._slots['unshrinked_' + name].append( - var_ops.Variable( - array_ops.zeros_like(var.initialized_value(), dtypes.float32), - name=var.op.name + '_unshrinked/SDCAOptimizer')) + # Our primary variable may be either a PartitionedVariable, or a list + # of Variables (each representing a partition). + if (isinstance(var, var_ops.PartitionedVariable) or + isinstance(var, list)): + var_list = [] + # pylint: disable=protected-access + for v in var: + with ops.colocate_with(v): + # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 + # is fixed. + slot_var = var_ops.Variable( + initial_value=array_ops.zeros_like(v.initialized_value(), + dtypes.float32), + name=v.op.name + '_unshrinked/SDCAOptimizer') + var_list.append(slot_var) + self._slots['unshrinked_' + name].append(var_list) + # pylint: enable=protected-access + else: + with ops.device(var.device): + # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is + # fixed. + self._slots['unshrinked_' + name].append( + var_ops.Variable( + array_ops.zeros_like(var.initialized_value(), + dtypes.float32), + name=var.op.name + '_unshrinked/SDCAOptimizer')) def _assertSpecified(self, items, check_in): for x in items: @@ -201,16 +227,25 @@ class SdcaModel(object): if not isinstance(check_in[x], list): raise ValueError(x + ' must be a list.') + def _var_to_list(self, var): + """Wraps var in a list if it is not a list or PartitionedVariable.""" + if not (isinstance(var, list) or + isinstance(var, var_ops.PartitionedVariable)): + var = [var] + return var + def _l1_loss(self): """Computes the (un-normalized) l1 loss of the model.""" with name_scope('sdca/l1_loss'): sums = [] for name in ['sparse_features_weights', 'dense_features_weights']: - for weights in self._convert_n_to_tensor(self._variables[name]): - with ops.device(weights.device): - sums.append( - math_ops.reduce_sum( - math_ops.abs(math_ops.cast(weights, dtypes.float64)))) + for var in self._variables[name]: + for v in self._var_to_list(var): + weights = internal_convert_to_tensor(v) + with ops.device(weights.device): + sums.append( + math_ops.reduce_sum( + math_ops.abs(math_ops.cast(weights, dtypes.float64)))) # SDCA L1 regularization cost is: l1 * sum(|weights|) return self._options['symmetric_l1_regularization'] * math_ops.add_n(sums) @@ -219,17 +254,37 @@ class SdcaModel(object): with name_scope('sdca/l2_loss'): sums = [] for name in ['sparse_features_weights', 'dense_features_weights']: - for weights in self._convert_n_to_tensor(self._variables[name]): - with ops.device(weights.device): - sums.append( - math_ops.reduce_sum( - math_ops.square(math_ops.cast(weights, dtypes.float64)))) + for var in self._variables[name]: + for v in self._var_to_list(var): + weights = internal_convert_to_tensor(v) + with ops.device(weights.device): + sums.append(math_ops.reduce_sum(math_ops.square(math_ops.cast( + weights, dtypes.float64)))) # SDCA L2 regularization cost is: l2 * sum(weights^2) / 2 return l2 * math_ops.add_n(sums) / 2.0 def _convert_n_to_tensor(self, input_list, as_ref=False): """Converts input list to a set of tensors.""" - return [internal_convert_to_tensor(x, as_ref=as_ref) for x in input_list] + # input_list can be a list of Variables (that are implicitly partitioned), + # in which case the underlying logic in internal_convert_to_tensor will not + # concatenate the partitions together. This method takes care of the + # concatenating (we only allow partitioning on the first axis). + output_list = [] + for x in input_list: + tensor_to_convert = x + if isinstance(x, list) or isinstance(x, var_ops.PartitionedVariable): + # We only allow for partitioning on the first axis. + tensor_to_convert = array_ops.concat(x, axis=0) + output_list.append(internal_convert_to_tensor( + tensor_to_convert, as_ref=as_ref)) + return output_list + + def _get_first_dimension_size_statically(self, w, num_partitions): + """Compute the static size of the first dimension for a sharded variable.""" + dim_0_size = w[0].get_shape()[0] + for p in range(1, num_partitions): + dim_0_size += w[p].get_shape()[0] + return dim_0_size def _linear_predictions(self, examples): """Returns predictions of the form w*x.""" @@ -282,6 +337,28 @@ class SdcaModel(object): result = math_ops.sigmoid(result) return result + def _get_partitioned_update_ops(self, + v_num, + num_partitions_by_var, + p_assignments_by_var, + gather_ids_by_var, + weights, + full_update, + p_assignments, + num_partitions): + """Get updates for partitioned variables.""" + num_partitions = num_partitions_by_var[v_num] + p_assignments = p_assignments_by_var[v_num] + gather_ids = gather_ids_by_var[v_num] + updates = data_flow_ops.dynamic_partition( + full_update, p_assignments, num_partitions) + update_ops = [] + for p in range(num_partitions): + with ops.colocate_with(weights[p]): + result = state_ops.scatter_add(weights[p], gather_ids[p], updates[p]) + update_ops.append(result) + return update_ops + def minimize(self, global_step=None, name=None): """Add operations to train a linear model by minimizing the loss function. @@ -314,18 +391,89 @@ class SdcaModel(object): # Solver returns example_state_update, new delta sparse_feature_weights # and delta dense_feature_weights. - weights_tensor = self._convert_n_to_tensor(self._slots[ - 'unshrinked_sparse_features_weights']) sparse_weights = [] sparse_indices = [] - for w, i in zip(weights_tensor, sparse_feature_indices): - # Find the feature ids to lookup in the variables. - with ops.device(w.device): - sparse_indices.append( - math_ops.cast( - array_ops.unique(math_ops.cast(i, dtypes.int32))[0], - dtypes.int64)) - sparse_weights.append(array_ops.gather(w, sparse_indices[-1])) + # If we have partitioned variables, keep a few lists of Tensors around + # that we need for the assign_add after the op call to + # gen_sdca_ops.sdca_optimizer(). + num_partitions_by_var = [] + p_assignments_by_var = [] + gather_ids_by_var = [] + for w, i in zip(self._slots['unshrinked_sparse_features_weights'], + sparse_feature_indices): + # Append the sparse_indices (in full-variable space). + sparse_idx = math_ops.cast( + array_ops.unique(math_ops.cast(i, dtypes.int32))[0], + dtypes.int64) + sparse_indices.append(sparse_idx) + if isinstance(w, list) or isinstance(w, var_ops.PartitionedVariable): + num_partitions = len(w) + flat_ids = array_ops.reshape(sparse_idx, [-1]) + # We use div partitioning, which is easiest to support downstream. + # Compute num_total_ids as the sum of dim-0 of w, then assign + # to partitions based on a constant number of ids per partition. + # Optimize if we already know the full shape statically. + dim_0_size = self._get_first_dimension_size_statically( + w, num_partitions) + + if dim_0_size.value: + num_total_ids = constant_op.constant(dim_0_size.value, + flat_ids.dtype) + else: + dim_0_sizes = [] + for p in range(num_partitions): + if w[p].get_shape()[0].value is not None: + dim_0_sizes.append(w[p].get_shape()[0].value) + else: + with ops.colocate_with(w[p]): + dim_0_sizes.append(array_ops.shape(w[p])[0]) + num_total_ids = math_ops.reduce_sum( + math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype)) + ids_per_partition = num_total_ids // num_partitions + extras = num_total_ids % num_partitions + + p_assignments = math_ops.maximum( + flat_ids // (ids_per_partition + 1), + (flat_ids - extras) // ids_per_partition) + + # Emulate a conditional using a boolean indicator tensor + new_ids = array_ops.where(p_assignments < extras, + flat_ids % (ids_per_partition + 1), + (flat_ids - extras) % ids_per_partition) + + # Cast partition assignments to int32 for use in dynamic_partition. + # There really should not be more than 2^32 partitions. + p_assignments = math_ops.cast(p_assignments, dtypes.int32) + # Partition list of ids based on assignments into num_partitions + # separate lists. + gather_ids = data_flow_ops.dynamic_partition(new_ids, + p_assignments, + num_partitions) + # Append these to the lists for use in the later update. + num_partitions_by_var.append(num_partitions) + p_assignments_by_var.append(p_assignments) + gather_ids_by_var.append(gather_ids) + + # Gather the weights from each partition. + partition_gathered_weights = [] + for p in range(num_partitions): + with ops.colocate_with(w[p]): + partition_gathered_weights.append( + array_ops.gather(w[p], gather_ids[p])) + + # Stitch the weights back together in the same order they were before + # we dynamic_partitioned them. + condition_indices = data_flow_ops.dynamic_partition( + math_ops.range(array_ops.shape(new_ids)[0]), + p_assignments, num_partitions) + batch_gathered_weights = data_flow_ops.dynamic_stitch( + condition_indices, partition_gathered_weights) + else: + w_as_tensor = internal_convert_to_tensor(w) + with ops.device(w_as_tensor.device): + batch_gathered_weights = array_ops.gather( + w_as_tensor, sparse_idx) + sparse_weights.append(batch_gathered_weights) # pylint: disable=protected-access esu, sfw, dfw = gen_sdca_ops.sdca_optimizer( @@ -344,18 +492,32 @@ class SdcaModel(object): l1=self._options['symmetric_l1_regularization'], l2=self._symmetric_l2_regularization(), num_loss_partitions=self._num_loss_partitions(), - num_inner_iterations=1) + num_inner_iterations=1, + adaptative=self._adaptive()) # pylint: enable=protected-access with ops.control_dependencies([esu]): update_ops = [self._hashtable.insert(example_ids_hashed, esu)] # Update the weights before the proximal step. - for w, i, u in zip(self._slots['unshrinked_sparse_features_weights'], - sparse_indices, sfw): - update_ops.append(state_ops.scatter_add(w, i, u)) + for v_num, (w, i, u) in enumerate( + zip(self._slots['unshrinked_sparse_features_weights'], + sparse_indices, sfw)): + if (isinstance(w, var_ops.PartitionedVariable) or + isinstance(w, list)): + update_ops += self._get_partitioned_update_ops( + v_num, num_partitions_by_var, p_assignments_by_var, + gather_ids_by_var, w, u, p_assignments, num_partitions) + else: + update_ops.append(state_ops.scatter_add(w, i, u)) for w, u in zip(self._slots['unshrinked_dense_features_weights'], dfw): - update_ops.append(w.assign_add(u)) - + if (isinstance(w, var_ops.PartitionedVariable) or + isinstance(w, list)): + split_updates = array_ops.split( + u, num_or_size_splits=[v.shape.as_list()[0] for v in w]) + for v, split_update in zip(w, split_updates): + update_ops.append(state_ops.assign_add(v, split_update)) + else: + update_ops.append(state_ops.assign_add(w, u)) if not global_step: return control_flow_ops.group(*update_ops) with ops.control_dependencies(update_ops): @@ -380,21 +542,22 @@ class SdcaModel(object): for name in ['sparse_features_weights', 'dense_features_weights']: for var, slot_var in zip(self._variables[name], self._slots['unshrinked_' + name]): - update_ops.append(var.assign(slot_var)) + for v, sv in zip(self._var_to_list(var), self._var_to_list(slot_var)): + update_ops.append(v.assign(sv)) # Apply proximal step. with ops.control_dependencies(update_ops): update_ops = [] for name in ['sparse_features_weights', 'dense_features_weights']: for var in self._variables[name]: - with ops.device(var.device): - # pylint: disable=protected-access - update_ops.append( - gen_sdca_ops.sdca_shrink_l1( - self._convert_n_to_tensor( - [var], as_ref=True), - l1=self._symmetric_l1_regularization(), - l2=self._symmetric_l2_regularization())) + for v in self._var_to_list(var): + with ops.device(v.device): + # pylint: disable=protected-access + update_ops.append( + gen_sdca_ops.sdca_shrink_l1( + self._convert_n_to_tensor([v], as_ref=True), + l1=self._symmetric_l1_regularization(), + l2=self._symmetric_l2_regularization())) return control_flow_ops.group(*update_ops) def approximate_duality_gap(self): diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index ec726bbed41a86eb314e3591ecaedaa6bf0e5e9b..5015fb0848107950dd27eb81431dd308f22858bc 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -49,6 +49,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface): default_value, empty_key, num_shards=1, + checkpoint=True, name='ShardedMutableHashTable'): with ops.name_scope(name, 'sharded_mutable_hash_table') as scope: super(ShardedMutableDenseHashTable, self).__init__(key_dtype, @@ -61,6 +62,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface): value_dtype=value_dtype, default_value=default_value, empty_key=empty_key, + checkpoint=checkpoint, name='%s-%d-of-%d' % (name, i + 1, num_shards))) self._table_shards = table_shards # TODO(andreasst): add a value_shape() method to LookupInterface diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py index 05794a42c5f2d0eece6adab36fb5610078cece31..200e7de6b95f17672c6ef51f887b15f9d185f775 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py @@ -116,6 +116,7 @@ def sdca_model_fn(features, labels, mode, params, config=None): num_loss_partitions = params["num_loss_partitions"] weight_column_name = params["weight_column_name"] update_weights_hook = params.get("update_weights_hook", None) + partitioner = params["partitioner"] loss_type = None if isinstance(head, head_lib._BinarySvmHead): # pylint: disable=protected-access @@ -136,12 +137,14 @@ def sdca_model_fn(features, labels, mode, params, config=None): example_id_column=example_id_column, num_loss_partitions=n_loss_partitions, symmetric_l1_regularization=l1_regularization, - symmetric_l2_regularization=l2_regularization) + symmetric_l2_regularization=l2_regularization, + partitioner=partitioner) parent_scope = "linear" - with variable_scope.variable_op_scope(features.values(), - parent_scope) as scope: + with variable_scope.variable_scope( + values=features.values(), name_or_scope=parent_scope, + partitioner=partitioner) as scope: features = features.copy() features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( @@ -213,7 +216,8 @@ class _SDCAEstimator(estimator.Estimator): l2_regularization=1.0, num_loss_partitions=None, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + partitioner=None): """Construct a `_SDCAEstimator` estimator object. Args: @@ -241,6 +245,8 @@ class _SDCAEstimator(estimator.Estimator): feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + partitioner: Variable partitioner for the primal weights (`div` + partitioning strategy will be used). Returns: A `_SDCAEstimator` estimator. @@ -267,6 +273,7 @@ class _SDCAEstimator(estimator.Estimator): "l2_regularization": l2_regularization, "weight_column_name": weight_column_name, "update_weights_hook": _SdcaUpdateWeightsHook(), + "partitioner": partitioner, } super(_SDCAEstimator, self).__init__( @@ -336,7 +343,8 @@ class SDCALogisticClassifier(_SDCAEstimator): l2_regularization=1.0, num_loss_partitions=None, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + partitioner=None): """Construct a `SDCALogisticClassifier` object. Args: @@ -361,6 +369,8 @@ class SDCALogisticClassifier(_SDCAEstimator): feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + partitioner: Variable partitioner for the primal weights (`div` + partitioning strategy will be used). Returns: A `SDCALogisiticClassifier` estimator. @@ -376,7 +386,8 @@ class SDCALogisticClassifier(_SDCAEstimator): l2_regularization=l2_regularization, num_loss_partitions=num_loss_partitions, config=config, - feature_engineering_fn=None) + feature_engineering_fn=None, + partitioner=partitioner) def predict_classes(self, input_fn=None): """Runs inference to determine the predicted class. @@ -463,7 +474,8 @@ class SDCALinearRegressor(_SDCAEstimator): l2_regularization=1.0, num_loss_partitions=None, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + partitioner=None): """Construct a `SDCALinearRegressor` estimator object. @@ -489,6 +501,8 @@ class SDCALinearRegressor(_SDCAEstimator): feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + partitioner: Variable partitioner for the primal weights (`div` + partitioning strategy will be used). Returns: A `SDCALinearRegressor` estimator. @@ -503,7 +517,8 @@ class SDCALinearRegressor(_SDCAEstimator): l2_regularization=l2_regularization, num_loss_partitions=num_loss_partitions, config=config, - feature_engineering_fn=None) + feature_engineering_fn=None, + partitioner=partitioner) def predict_scores(self, input_fn): """Returns predicted scores for given features. diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py index 79a5928a21cb9a2633b2aac178f185ba333790d6..647667188238dc18b137eaad98356a79b3a549b4 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py @@ -25,11 +25,19 @@ from tensorflow.contrib.linear_optimizer.python import sdca_estimator from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import partitioned_variables from tensorflow.python.platform import test class SDCALogisticClassifierTest(test.TestCase): + def _single_threaded_test_session(self): + # TODO(andreasst): figure out why SDCALinearRegressor needs a single + # threaded session to pass in tsan mode but SDCALogisticClassifier does not. + config = config_pb2.ConfigProto( + inter_op_parallelism_threads=1, intra_op_parallelism_threads=1) + return self.test_session(config=config) + def testRealValuedFeatures(self): """Tests SDCALogisticClassifier works with real valued features.""" @@ -41,7 +49,7 @@ class SDCALogisticClassifierTest(test.TestCase): 'weights': constant_op.constant([[1.0], [1.0]]) }, constant_op.constant([[0], [1]]) - with self.test_session(): + with self._single_threaded_test_session(): maintenance_cost = feature_column_lib.real_valued_column( 'maintenance_cost') sq_footage = feature_column_lib.real_valued_column('sq_footage') @@ -66,7 +74,7 @@ class SDCALogisticClassifierTest(test.TestCase): constant_op.constant([[500.0, 800.0], [200.0, 600.0]]) }, constant_op.constant([[0], [1]]) - with self.test_session(): + with self._single_threaded_test_session(): dense_feature = feature_column_lib.real_valued_column( 'dense_feature', dimension=2) classifier = sdca_estimator.SDCALogisticClassifier( @@ -86,7 +94,7 @@ class SDCALogisticClassifierTest(test.TestCase): 'weights': constant_op.constant([[1.0], [1.0], [1.0]]) }, constant_op.constant([[1], [0], [1]]) - with self.test_session(): + with self._single_threaded_test_session(): price_bucket = feature_column_lib.bucketized_column( feature_column_lib.real_valued_column('price'), boundaries=[500.0, 700.0]) @@ -120,7 +128,7 @@ class SDCALogisticClassifierTest(test.TestCase): constant_op.constant([[1.0], [1.0], [1.0]]) }, constant_op.constant([[1], [0], [1]]) - with self.test_session(): + with self._single_threaded_test_session(): price = feature_column_lib.real_valued_column('price') country = feature_column_lib.sparse_column_with_hash_bucket( 'country', hash_bucket_size=5) @@ -151,7 +159,7 @@ class SDCALogisticClassifierTest(test.TestCase): dense_shape=[3, 5]) }, constant_op.constant([[1], [0], [1]]) - with self.test_session(): + with self._single_threaded_test_session(): country = feature_column_lib.sparse_column_with_hash_bucket( 'country', hash_bucket_size=5) country_weighted_by_price = feature_column_lib.weighted_sparse_column( @@ -163,6 +171,38 @@ class SDCALogisticClassifierTest(test.TestCase): metrics = classifier.evaluate(input_fn=input_fn, steps=1) self.assertGreater(metrics['accuracy'], 0.9) + def testSparseFeaturesWithDuplicates(self): + """Tests SDCALogisticClassifier with duplicated sparse features.""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2']), + 'age': + sparse_tensor.SparseTensor( + values=['20-29'] * 5 + ['31-40'] * 5, + indices=[[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [1, 0], + [1, 0], [1, 0], [1, 0], [1, 0]], + dense_shape=[2, 1]), + 'gender': + sparse_tensor.SparseTensor( + values=['m'] * 5 + ['f'] * 5, + indices=[[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [1, 0], + [1, 0], [1, 0], [1, 0], [1, 0]], + dense_shape=[2, 1]), + }, constant_op.constant([[1], [0]]) + + with self._single_threaded_test_session(): + age = feature_column_lib.sparse_column_with_hash_bucket( + 'age', hash_bucket_size=10) + gender = feature_column_lib.sparse_column_with_hash_bucket( + 'gender', hash_bucket_size=10) + classifier = sdca_estimator.SDCALogisticClassifier( + example_id_column='example_id', feature_columns=[age, gender]) + classifier.fit(input_fn=input_fn, steps=50) + metrics = classifier.evaluate(input_fn=input_fn, steps=1) + self.assertLess(metrics['loss'], 0.060) + def testCrossedFeatures(self): """Tests SDCALogisticClassifier with crossed features.""" @@ -182,7 +222,7 @@ class SDCALogisticClassifierTest(test.TestCase): dense_shape=[3, 1]) }, constant_op.constant([[0], [0], [1]]) - with self.test_session(): + with self._single_threaded_test_session(): language = feature_column_lib.sparse_column_with_hash_bucket( 'language', hash_bucket_size=5) country = feature_column_lib.sparse_column_with_hash_bucket( @@ -215,7 +255,7 @@ class SDCALogisticClassifierTest(test.TestCase): constant_op.constant([[3.0], [1.0], [1.0]]) }, constant_op.constant([[1], [0], [1]]) - with self.test_session(): + with self._single_threaded_test_session(): price = feature_column_lib.real_valued_column('price') sq_footage_bucket = feature_column_lib.bucketized_column( feature_column_lib.real_valued_column('sq_footage'), @@ -234,6 +274,47 @@ class SDCALogisticClassifierTest(test.TestCase): metrics = classifier.evaluate(input_fn=input_fn, steps=1) self.assertGreater(metrics['accuracy'], 0.9) + def testPartitionedMixedFeatures(self): + """Tests SDCALogisticClassifier with a mix of features (partitioned).""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([[0.6], [0.8], [0.3]]), + 'sq_footage': + constant_op.constant([900.0, 700.0, 600.0]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [1.0], [1.0]]) + }, constant_op.constant([[1], [0], [1]]) + + with self._single_threaded_test_session(): + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + classifier = sdca_estimator.SDCALogisticClassifier( + example_id_column='example_id', + feature_columns=[ + price, sq_footage_bucket, country, sq_footage_country + ], + weight_column_name='weights', + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + classifier.fit(input_fn=input_fn, steps=50) + metrics = classifier.evaluate(input_fn=input_fn, steps=1) + self.assertGreater(metrics['accuracy'], 0.9) + class SDCALinearRegressorTest(test.TestCase): @@ -311,6 +392,48 @@ class SDCALinearRegressorTest(test.TestCase): loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] self.assertLess(loss, 0.05) + def testMixedFeaturesArbitraryWeightsPartitioned(self): + """Tests SDCALinearRegressor works with a mix of features (partitioned).""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([[0.6], [0.8], [0.3]]), + 'sq_footage': + constant_op.constant([[900.0], [700.0], [600.0]]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [5.0], [7.0]]) + }, constant_op.constant([[1.55], [-1.25], [-3.0]]) + + with self._single_threaded_test_session(): + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + regressor = sdca_estimator.SDCALinearRegressor( + example_id_column='example_id', + feature_columns=[ + price, sq_footage_bucket, country, sq_footage_country + ], + l2_regularization=1.0, + weight_column_name='weights', + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + regressor.fit(input_fn=input_fn, steps=20) + loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] + self.assertLess(loss, 0.05) + def testSdcaOptimizerSparseFeaturesWithL1Reg(self): """SDCALinearRegressor works with sparse features and L1 regularization.""" diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py index 92d022f2a30ffeb77e81d3bd01365afcd14826b5..9872c6f97c879d8994b6c26e65df33e368a0603e 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.contrib import layers from tensorflow.contrib.linear_optimizer.python.ops import sdca_ops from tensorflow.contrib.linear_optimizer.python.ops.sparse_feature_column import SparseFeatureColumn +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -36,18 +37,18 @@ class SDCAOptimizer(object): Example usage: ```python - real_feature_column = real_valued_column(...) - sparse_feature_column = sparse_column_with_hash_bucket(...) - sdca_optimizer = linear.SDCAOptimizer(example_id_column='example_id', - num_loss_partitions=1, - num_table_shards=1, - symmetric_l2_regularization=2.0) - classifier = tf.contrib.learn.LinearClassifier( - feature_columns=[real_feature_column, sparse_feature_column], - weight_column_name=..., - optimizer=sdca_optimizer) - classifier.fit(input_fn_train, steps=50) - classifier.evaluate(input_fn=input_fn_eval) + real_feature_column = real_valued_column(...) + sparse_feature_column = sparse_column_with_hash_bucket(...) + sdca_optimizer = linear.SDCAOptimizer(example_id_column='example_id', + num_loss_partitions=1, + num_table_shards=1, + symmetric_l2_regularization=2.0) + classifier = tf.contrib.learn.LinearClassifier( + feature_columns=[real_feature_column, sparse_feature_column], + weight_column_name=..., + optimizer=sdca_optimizer) + classifier.fit(input_fn_train, steps=50) + classifier.evaluate(input_fn=input_fn_eval) ``` Here the expectation is that the `input_fn_*` functions passed to train and @@ -63,7 +64,8 @@ class SDCAOptimizer(object): of workers running the train steps. It defaults to 1 (single machine). `num_table_shards` defines the number of shards for the internal state table, typically set to match the number of parameter servers for large - data sets. + data sets. You can also specify a `partitioner` object to partition the primal + weights during training (`div` partitioning strategy will be used). """ def __init__(self, @@ -71,12 +73,16 @@ class SDCAOptimizer(object): num_loss_partitions=1, num_table_shards=None, symmetric_l1_regularization=0.0, - symmetric_l2_regularization=1.0): + symmetric_l2_regularization=1.0, + adaptive=True, + partitioner=None): self._example_id_column = example_id_column self._num_loss_partitions = num_loss_partitions self._num_table_shards = num_table_shards self._symmetric_l1_regularization = symmetric_l1_regularization self._symmetric_l2_regularization = symmetric_l2_regularization + self._adaptive = adaptive + self._partitioner = partitioner def get_name(self): return 'SDCAOptimizer' @@ -101,6 +107,14 @@ class SDCAOptimizer(object): def symmetric_l2_regularization(self): return self._symmetric_l2_regularization + @property + def adaptive(self): + return self._adaptive + + @property + def partitioner(self): + return self._partitioner + def get_train_step(self, columns_to_variables, weight_column_name, loss_type, features, targets, global_step): """Returns the training operation of an SdcaModel optimizer.""" @@ -168,37 +182,65 @@ class SDCAOptimizer(object): sparse_feature_column = _dense_tensor_to_sparse_feature_column( dense_bucket_tensor) sparse_feature_with_values.append(sparse_feature_column) - # For bucketized columns, the variables list contains exactly one - # element. - sparse_feature_with_values_weights.append( - columns_to_variables[column][0]) + # If a partitioner was used during variable creation, we will have a + # list of Variables here larger than 1. + vars_to_append = columns_to_variables[column][0] + if len(columns_to_variables[column]) > 1: + vars_to_append = columns_to_variables[column] + sparse_feature_with_values_weights.append(vars_to_append) elif isinstance( column, ( + layers.feature_column._WeightedSparseColumn, # pylint: disable=protected-access layers.feature_column._CrossedColumn, # pylint: disable=protected-access layers.feature_column._SparseColumn)): # pylint: disable=protected-access - sparse_features.append( - SparseFeatureColumn( - array_ops.reshape( - array_ops.split( - value=transformed_tensor.indices, - num_or_size_splits=2, - axis=1)[0], [-1]), - array_ops.reshape(transformed_tensor.values, [-1]), None)) - sparse_feature_weights.append(columns_to_variables[column][0]) - elif isinstance(column, layers.feature_column._WeightedSparseColumn): # pylint: disable=protected-access - id_tensor = column.id_tensor(transformed_tensor) - weight_tensor = column.weight_tensor(transformed_tensor) + + if isinstance(column, layers.feature_column._WeightedSparseColumn): # pylint: disable=protected-access + id_tensor = column.id_tensor(transformed_tensor) + weight_tensor = array_ops.reshape( + column.weight_tensor(transformed_tensor).values, [-1]) + else: + id_tensor = transformed_tensor + weight_tensor = array_ops.ones( + [array_ops.shape(id_tensor.indices)[0]], dtypes.float32) + + example_ids = array_ops.reshape(id_tensor.indices[:, 0], [-1]) + + flat_ids = array_ops.reshape(id_tensor.values, [-1]) + # Prune invalid IDs (< 0) from the flat_ids, example_ids, and + # weight_tensor. These can come from looking up an OOV entry in the + # vocabulary (default value being -1). + is_id_valid = math_ops.greater_equal(flat_ids, 0) + flat_ids = array_ops.boolean_mask(flat_ids, is_id_valid) + example_ids = array_ops.boolean_mask(example_ids, is_id_valid) + weight_tensor = array_ops.boolean_mask(weight_tensor, is_id_valid) + + projection_length = math_ops.reduce_max(flat_ids) + 1 + # project ids based on example ids so that we can dedup ids that + # occur multiple times for a single example. + projected_ids = projection_length * example_ids + flat_ids + + # Remove any redudant ids. + ids, idx = array_ops.unique(projected_ids) + # Keep only one example id per duplicated ids. + example_ids_filtered = math_ops.unsorted_segment_min( + example_ids, idx, + array_ops.shape(ids)[0]) + + # reproject ids back feature id space. + reproject_ids = (ids - projection_length * example_ids_filtered) + + weights = array_ops.reshape( + math_ops.unsorted_segment_sum(weight_tensor, idx, + array_ops.shape(ids)[0]), [-1]) sparse_feature_with_values.append( - SparseFeatureColumn( - array_ops.reshape( - array_ops.split( - value=id_tensor.indices, num_or_size_splits=2, axis=1) - [0], [-1]), - array_ops.reshape(id_tensor.values, [-1]), - array_ops.reshape(weight_tensor.values, [-1]))) - sparse_feature_with_values_weights.append( - columns_to_variables[column][0]) + SparseFeatureColumn(example_ids_filtered, reproject_ids, weights)) + # If a partitioner was used during variable creation, we will have a + # list of Variables here larger than 1. + vars_to_append = columns_to_variables[column][0] + if len(columns_to_variables[column]) > 1: + vars_to_append = columns_to_variables[column] + sparse_feature_with_values_weights.append(vars_to_append) else: raise ValueError('SDCAOptimizer does not support column type %s.' % type(column).__name__) @@ -228,6 +270,7 @@ class SDCAOptimizer(object): options=dict( symmetric_l1_regularization=self._symmetric_l1_regularization, symmetric_l2_regularization=self._symmetric_l2_regularization, + adaptive=self._adaptive, num_loss_partitions=self._num_loss_partitions, num_table_shards=self._num_table_shards, loss_type=loss_type)) diff --git a/tensorflow/contrib/lite/Android.bp b/tensorflow/contrib/lite/Android.bp index bb1ea46b61bb7330c688dc42cd8a32f1dfeeb0b0..ff3d18b4b117863ab483b4ff9c5dac71fb5379ee 100644 --- a/tensorflow/contrib/lite/Android.bp +++ b/tensorflow/contrib/lite/Android.bp @@ -42,16 +42,20 @@ cc_library_static { "allocation.cc", "arena_planner.cc", "error_reporter.cc", - "graph_info.cc", + "graph_info.cc", "interpreter.cc", "model.cc", + "op_resolver.cc", "nnapi_delegate.cc", "optional_debug_tools.cc", "simple_memory_arena.cc", "string_util.cc", + "util.cc", + "kernels/eigen_support.cc", "kernels/gemm_support.cc", ], header_libs: [ + "libeigen", "flatbuffer_headers", "gemmlowp_headers", ], diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 44c4a7e2ca8d019ca602c7f2b492cd1e70b17561..9c804d27854b8004d34c65691b48ca2b0d3bbf7c 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -6,8 +6,6 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") -exports_files(["LICENSE"]) - exports_files(glob([ "testdata/*.bin", "testdata/*.pb", @@ -89,8 +87,21 @@ cc_library( hdrs = [ "builtin_op_data.h", ], + deps = [":context"], +) + +cc_library( + name = "kernel_api", + hdrs = [ + "builtin_op_data.h", + "builtin_ops.h", + "context.h", + "context_util.h", + ], ) +exports_files(["builtin_ops.h"]) + cc_library( name = "string", hdrs = [ @@ -111,6 +122,7 @@ cc_library( "interpreter.cc", "model.cc", "nnapi_delegate.cc", + "op_resolver.cc", "optional_debug_tools.cc", ], hdrs = [ @@ -121,6 +133,7 @@ cc_library( "interpreter.h", "model.h", "nnapi_delegate.h", + "op_resolver.h", "optional_debug_tools.h", ], copts = tflite_copts(), @@ -132,10 +145,12 @@ cc_library( ":memory_planner", ":schema_fbs_version", ":simple_memory_arena", + ":util", + "//tensorflow/contrib/lite/kernels:eigen_support", "//tensorflow/contrib/lite/kernels:gemm_support", "//tensorflow/contrib/lite/nnapi:nnapi_lib", + "//tensorflow/contrib/lite/profiling:profiler", "//tensorflow/contrib/lite/schema:schema_fbs", - "//tensorflow/core:lib_platform", ], ) @@ -169,6 +184,7 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/kernels:kernel_util", "//tensorflow/contrib/lite/kernels/internal:tensor_utils", "//tensorflow/contrib/lite/schema:schema_fbs", "//tensorflow/contrib/lite/testing:util", @@ -220,6 +236,18 @@ cc_test( ], ) +# Test OpResolver. +cc_test( + name = "op_resolver_test", + size = "small", + srcs = ["op_resolver_test.cc"], + deps = [ + ":framework", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + # Test the C extension API code. cc_test( name = "context_test", @@ -232,6 +260,27 @@ cc_test( ], ) +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + deps = [ + ":context", + ], +) + +cc_test( + name = "util_test", + size = "small", + srcs = ["util_test.cc"], + deps = [ + ":context", + ":util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + # Test the serialization of a model with optional tensors. # Model tests @@ -248,18 +297,3 @@ cc_test( # ], # }), #) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "downloads", - "examples", - "gen", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile index 7f316292724ea0baaf034d4e914773ad97a957d4..cc8a8035d1dadeec98886ba1dae4cdf403f26de4 100644 --- a/tensorflow/contrib/lite/Makefile +++ b/tensorflow/contrib/lite/Makefile @@ -1,4 +1,3 @@ - # Find where we're running from, so we can store generated files here. ifeq ($(origin MAKEFILE_DIR), undefined) MAKEFILE_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) @@ -27,10 +26,10 @@ LIBDIR := $(MAKEFILE_DIR)/gen/lib/ GENDIR := $(MAKEFILE_DIR)/gen/obj/ # Settings for the host compiler. -CXX := $(CC_PREFIX) gcc +CXX := $(CC_PREFIX)gcc CXXFLAGS := --std=c++11 -O3 -DNDEBUG -CC := $(CC_PREFIX) gcc -CFLAGS := +CC := $(CC_PREFIX)gcc +CCFLAGS := -O3 -DNDEBUG LDOPTS := LDOPTS += -L/usr/local/lib ARFLAGS := -r @@ -57,10 +56,11 @@ LIBS := \ # If we're on Linux, also link in the dl library. ifeq ($(HOST_OS),LINUX) - LIBS += -ldl -lpthread + LIBS += -ldl endif include $(MAKEFILE_DIR)/ios_makefile.inc +include $(MAKEFILE_DIR)/rpi_makefile.inc # This library is the main target for this makefile. It will contain a minimal # runtime that can be linked in to other programs. @@ -68,12 +68,12 @@ LIB_NAME := libtensorflow-lite.a LIB_PATH := $(LIBDIR)$(LIB_NAME) # A small example program that shows how to link against the library. -BENCHMARK_PATH := $(BINDIR)benchmark_model +MINIMAL_PATH := $(BINDIR)minimal -BENCHMARK_SRCS := \ -tensorflow/contrib/lite/tools/benchmark_model.cc -BENCHMARK_OBJS := $(addprefix $(OBJDIR), \ -$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS)))) +MINIMAL_SRCS := \ +tensorflow/contrib/lite/examples/minimal/minimal.cc +MINIMAL_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS)))) # What sources we want to compile, must be kept in sync with the main Bazel # build files. @@ -89,7 +89,8 @@ $(wildcard tensorflow/contrib/lite/kernels/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.c) \ -$(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc) +$(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc) \ +$(wildcard tensorflow/contrib/lite/downloads/fft2d/fftsg.c) # Remove any duplicates. CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS)) CORE_CC_EXCLUDE_SRCS := \ @@ -98,7 +99,7 @@ $(wildcard tensorflow/contrib/lite/*/*test.cc) \ $(wildcard tensorflow/contrib/lite/*/*/*test.cc) \ $(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) \ $(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \ -$(BENCHMARK_SRCS) +$(MINIMAL_SRCS) # Filter out all the excluded files. TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) # File names of the intermediate files target compilation generates. @@ -117,17 +118,17 @@ $(OBJDIR)%.o: %.c $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@ # The target that's compiled if there's no command-line arguments. -all: $(LIB_PATH) $(BENCHMARK_PATH) +all: $(LIB_PATH) $(MINIMAL_PATH) # Gathers together all the objects we've compiled into a single '.a' archive. $(LIB_PATH): $(LIB_OBJS) @mkdir -p $(dir $@) $(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS) -$(BENCHMARK_PATH): $(BENCHMARK_OBJS) $(LIB_PATH) +$(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH) @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) \ - -o $(BENCHMARK_PATH) $(BENCHMARK_OBJS) \ + -o $(MINIMAL_PATH) $(MINIMAL_OBJS) \ $(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS) # Gets rid of all generated files. diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md index 00e93d2c4f3ab27057b855fba6fccf2ec8d7a1c1..a676b705f143b393c7e5bfa9e40d23f9adb68dcc 100644 --- a/tensorflow/contrib/lite/README.md +++ b/tensorflow/contrib/lite/README.md @@ -1,235 +1,8 @@ # TensorFlow Lite -TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded devices. It enables low-latency inference of on-device machine learning models with a small binary size and fast performance supporting hardware acceleration. -TensorFlow Lite uses many techniques for achieving low latency like optimizing the kernels for specific mobile apps, pre-fused activations, quantized kernels that allow smaller and faster (fixed-point math) models, and in the future, leverage specialized machine learning hardware to get the best possible performance for a particular model on a particular device. +TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded +devices. It enables low-latency inference of on-device machine learning models +with a small binary size and fast performance supporting hardware acceleration. -![image](g3doc/TFLite-Architecture.jpg) -# Getting Started with an Android Demo App - -This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using either a quantized Mobilenet model or a floating point Inception-v3 model. A device running Android 5.0 ( API 21) or higher is required to run the demo. - -There are 3 ways to get the demo app to your device - - Download the prebuilt binary or - - Use Android Studio to build the application or - - Download the source code for TensorFlow Lite and the demo and build it using bazel - -## Description -In the demo app, inference is done using the TensorFlow Lite Java API. The demo app classifies frames in real-time, displaying the top most probable classifications. It also displays the time taken to detect the object. - -## Downloading the pre-built binary -The fastest path to trying the demo, is to download the pre-built binary -[TfLiteCameraDemo.apk](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) - -Once the apk is installed, click the app icon to start the app. The first-time the app is opened, the app asks for runtime permissions to access the device camera. The demo app opens the back-camera of the device and recognizes the objects in the camera's field of view. At the bottom of the image (or at the left of the image if the device is in landscape mode), it shows the latency of classification and the top three objects classified. - -## Building in Android Studio using TensorFlow Lite AAR from JCenter -The simplest way to compile the demo app, and try out changes to the project code is to use AndroidStudio. - - - Install the latest version of Android Studio 3 as specified [here](https://developer.android.com/studio/index.html). - - Make sure the Android SDK version is greater than 26 and NDK version is greater than 14 (in the Android Studio Settings). - - Import the `tensorflow/contrib/lite/java/demo` directory as a new Android Studio project. - - Click through installing all the Gradle extensions it requests. - - Either - - Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip) - - unzip and copy mobilenet_quant_v1_224.tflite to the assets directory: - `tensorflow/contrib/lite/java/demo/app/src/main/assets/` - - Or download the floating point Inception-v3 model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip) - - unzip and copy inceptionv3_non_slim_2015.tflite to the assets directory - - change the chosen classifier in [Camera2BasicFragment.java](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java) from - `classifier = new ImageClassifierQuantizedMobileNet(getActivity());` - to - `classifier = new ImageClassifierFloatInception(getActivity());` - - Build and run the demo app - -## Building TensorFlow Lite and the demo app from source - -### Clone the TensorFlow repo -- git clone - [https://github.com/tensorflow/tensorflow](https://github.com/tensorflow/tensorflow) - -### Install Bazel -If bazel is not installed on your system, install it now by following [these directions](https://bazel.build/versions/master/docs/install.html) - -NOTE: Bazel does not fully support building Android on Windows yet. Full support for Gradle/CMake builds is coming soon, but in the meantime Windows users should download the [prebuilt binary](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) instead. - -### Install Android NDK and SDK -Bazel is the primary build system for TensorFlow. Bazel and the Android NDK and SDK must be installed on your system. - - Install the latest version of Bazel as per the instructions on the [Bazel website](https://bazel.build/versions/master/docs/install.html) - - The Android NDK is required to build the native (C/C++) TensorFlow Lite code. The current recommended version is 14b, which can be found [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads). - - The Android SDK and build tools may be obtained [here](https://developer.android.com/tools/revisions/build-tools.html), or alternatively as part of [Android Studio](https://developer.android.com/studio/index.html). Build tools API >= 23 is required to build the TF Android demo (though it will run on API >= 21 devices). - - In the root of the TensorFlow repository update the `WORKSPACE` file with the `api_level` and location of the SDK and NDK. If you installed it with AndroidStudio the SDK path can be found in the SDK manager, and the default NDK path is:`{SDK path}/ndk-bundle.` - -``` -android_sdk_repository ( - name = "androidsdk", - api_level = 23, - build_tools_version = "23.0.2", - path = "/home/xxxx/android-sdk-linux/", -) - -android_ndk_repository( - name = "androidndk", - path = "/home/xxxx/android-ndk-r10e/", - api_level = 19, -) -``` - -Additional details on building with Android can be found [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md). - -### Build the source code -Run bazel with the following command to build the demo. - -Build the demo app: - -``` -bazel build --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo -``` - -### Note - -Currently, we only support building the Android demo app within a Python 2 -environment (due to a Bazel bug). - -### More about the demo -The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used (229 * 229 for Inception-v3). The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch. 224 * 224 (299 * 299) is the width and height of the image. 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. Both models have 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The model file must be downloaded and bundled within the assets directory of the app. - -# iOS Demo App - -Similar to the Android demo app, there's an iOS camera app that uses exactly the same model (224 * 224 quantized Mobilenet). - -This demo app requires a camera so it doesn't work with simulators. It need to be executed on a real iOS device. Follow the instructions to build and run the demo app: - -1. Run `third_party/tensorflow/contrib/lite/examples/ios/download_models.sh` to download the model files used by the demo app. -1. Install [CocoaPods](https://cocoapods.org/) if it wasn't installed yet: `sudo gem install cocoapods`. -1. Run `pod install` in `tensorflow/contrib/lite/examples/ios/camera` to generate the workspace file. -1. Open the project by running `open tflite_camera_example.xcworkspace`, and build the app in XCode. - -# TensorFlow Lite Quick Start - -## Step 1. Decide which GraphDef to use - Depending on the use case, the developer may choose to use one of the popular - open-sourced models such as InceptionV3 or MobileNets, re-train these models - with their own custom data set or even build their own custom model. - -### Using a pre-trained model - -[MobileNets](https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html) is a family of mobile-first computer vision models for [TensorFlow](https://www.tensorflow.org/) designed to effectively maximize accuracy while being mindful of the restricted resources for an on-device or embedded application. MobileNets are small, low-latency, low-power models parameterized to meet the resource constraints of a variety of use cases. They can be built upon for classification, detection, embeddings and segmentation similar to how other popular large scale models, such as [Inception](https://arxiv.org/pdf/1602.07261.pdf), are used. Google provides 16 pre-trained [ImageNet](http://www.image-net.org/challenges/LSVRC/) classification checkpoints for MobileNets for use in mobile projects of all sizes. - -[Inception-v3](https://arxiv.org/abs/1512.00567) is an image recognition model which achieves fairly high accuracy in recognizing general objects with 1000 classes, like "Zebra", "Dalmatian", and "Dishwasher". The model extracts general features from input images using a convolutional neural network and classifies them based on those features with fully-connected and softmax layers. - -[On Device Smart Reply](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html) is an on-device model which provides one-touch replies for an incoming text message by suggesting contextually relevant messages. The model is built specifically for memory constrained devices such as watches & phones and it has been successfully used to surface [Smart Replies on Android Wear](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html). Note that this model only works on Android as of now. - -These pre-trained models can be downloaded from [here](g3doc/models.md). - -### Retrain Inception-V3 or MobileNet for a custom data set -The above pre-trained models have been trained on the ImageNet data set, which consists of 1000 predefined classes. A model will need to be re-trained if these classes are not relevant or useful for a given use case. This technique is called transfer learning, which starts with a model that has been already trained on a problem and will then be retrained on a similar problem. Deep learning from scratch can take days, but transfer learning can be done fairly quickly. In order to do this, a developer will need to generate their custom data set labeled with the relevant classes. - -The [TensorFlow for Poets](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/) codelab walks through this process step-by-step. The retraining code supports retraining for both floating point and quantized inference. - - -### Train a custom model -A developer may choose to train a custom model using Tensorflow. TensorFlow documentation has [several tutorials](https://www.tensorflow.org/tutorials/) for building and training models. If the user has written a model using TensorFlow's Slim Framework the first step is to export this to a GraphDef file. This is necessary because Slim does not store the model structure outside the code, so to communicate with other parts of the framework it needs to be exported. Documentation for the export can be found [here](https://github.com/tensorflow/models/tree/master/research/slim#Export). The output of this step will be a .pb file for the custom model. - -TensorFlow Lite currently supports a subset of TensorFlow operators. Please refer to [this document](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for details of supported operators and their usage. This -set will continue to expand in future releases of Tensorflow Lite. - - -## Step 2. Model format conversion - -The model generated in Step 1 is a standard Tensorflow model. After the completion of Step 1 a user should have a standard .pb or .pbtxt GraphDef file. If the application developer is using a pre-trained model (as defined in Step 1 above), they can download a ready to use, already converted model for use from [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/models.md). Models generated using retraining (aka transfer learning) or custom models will need to be converted using the steps mentioned below. - -A prerequisite to converting the model to the Tensorflow Lite format is to freeze the graph. - -Since we employ several formats, the following definitions may be useful: - - GraphDef (.pb) - a protobuf that represents the TensorFlow training and or computation graph. This contains operators, tensors, and variables definitions. - - - CheckPoint (.ckpt) - Serialized variables from a TensorFlow graph. Note, this does not contain the graph structure, so alone it cannot typically be interpreted. - - - FrozenGraphDef - a subclass of GraphDef that contains no variables. A GraphDef can be converted to a frozen graphdef by taking a checkpoint and a graphdef and converting every variable into a constant with the value looked up in the checkpoint. - - - SavedModel - A collection of GraphDef and CheckPoint together with a signature that labels input and output arguments to a model. A GraphDef and Checkpoint can be extracted from a saved model. - - - TensorFlow lite model (.tflite) - a serialized flatbuffer, containing TensorFlow lite operators and Tensors for the TensorFlow lite interpreter. This is most analogous to TensorFlow frozen GraphDefs. - -### Freeze Graph -To use this .pb GraphDef file within TensorFlow Lite, the application developer will need checkpoints containing trained weight parameters. The .pb contains only the structure of the graph. The process of merging the checkpoint values with the graph structure is known as "freezing" the graph. - -The developer should know where the checkpoints folder is present or checkpoints can also be downloaded for a pre-trained model (Example: Here is a link to the [MobileNets](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md)). - -Graph freezing can be done using the command below (and modifying the arguments appropriately) - -``` -bazel build tensorflow/python/tools:freeze_graph - -bazel-bin/tensorflow/python/tools/freeze_graph\ - --input_graph=/tmp/mobilenet_v1_224.pb \ - --input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \ - --input_binary=true --output_graph=/tmp/frozen_mobilenet_v1_224.pb \ - --output_node_names=MobileNet/Predictions/Reshape_1 -``` - -The user has to first build the freeze_graph script using bazel and then run the script. The input_binary flag has to be enabled to ensure that the protobuf is read and written in binary format. The user has to input the .pb and the .ckpt files to freeze the graph The output_node_names may not be obvious outside of the code that built the model. The easiest way to find them is to visualize the graph, either with -graphviz, or [in tensorboard](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/#3). - -This frozen Graphdef is now ready to be converted to flatbuffer format (.tflite) for use on Android or iOS. On Android users have the flexibility to use either the float or quantized versions of the frozen graphdef, if available, using the Tensorflow Optimizing Converter tool. - -Here is a sample command line to convert the frozen Graphdef to '.tflite' format for The Tensorflow Optimizing Converter supports both float and quantized models, however, different configuration parameters are needed depending on whether a FLOAT or QUANTIZED mode is being used. -(Here is a link to the pb [file](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz)). - -``` -bazel build tensorflow/contrib/lite/toco:toco - -bazel-bin/tensorflow/contrib/lite/toco/toco \ - --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \ - --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \ - --output_file=/tmp/mobilenet_v1_1.0_224.tflite --inference_type=FLOAT \ - --input_type=FLOAT --input_arrays=input \ - --output_arrays=MobilenetV1/Predictions/Reshape_1 --input_shapes=1,224,224,3 -``` - -- The input_file argument should point to the frozen GraphDef file that holds the model architecture. -- The output_file argument should point to where the TensorFlow Lite model file should be generated. -- The input_type and inference_type arguments should be set to FLOAT, unless converted a [quantized](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/) model. -- Setting the input_array, output_array and input_shape arguments are a bit trickier. The easiest way to find these values is to explore the graph in tensorboard . The user should reuse the arguments that were used for specifying the output nodes for inference in the `freeze_graph`step. - -Note, it is also possible to use the Tensorflow Optimizing Converter through protos either from Python or from the command line see the -documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/python/toco_from_protos.py). A developer can then integrate the conversion step into their model design workflow to ensure that a model will be easily convertible to a mobile inference graph. For example, - -```python -import tensorflow as tf - -img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) -val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) -out = tf.identity(val, name="out") -with tf.Session() as sess: - tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out]) - open("converteds_model.tflite", "wb").write(tflite_model) - -``` -For detailed instructions on how to use the Tensorflow Optimizing Converter, please see [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md). - -You may refer to the [Ops compatibility guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for troubleshooting help. If that doesn't help, please file an [issue](https://github.com/tensorflow/tensorflow/issues). - -If you would like to see a visual description of your TensorFlow Lite model after conversion, you can use tensorflow/contrib/lite/tools/visualize.py by running -```sh -bazel run tensorflow/contrib/lite/tools:visualize -- model.tflite model_viz.html -``` -and then visualize the resulting HTML file in a browser. - -## Step 3. Use the TensorFlow Lite model for inference in a mobile app - -After completion of Step 2 the developer should have a .tflite model. - -### For Android -Because Android apps need to be written in Java, and core TensorFlow is in C++, a JNI library is provided to interface between the two. Its interface is aimed only at inference, so it provides the ability to load a graph, set up inputs, and run the model to calculate particular outputs. The full documentation for the set of methods can be seen [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/). The demo app is also open sourced on [github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app). - -The [demo app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app) uses this interface, so it's a good place to look for example usage. You can also download the prebuilt binary [here](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk). - -Note that you'd need to follow instructions for installing TensorFlow on Android, setting up bazel and Android Studio outlined [here](https://www.tensorflow.org/mobile/android_build). - -### For iOS -Follow the documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md) to get integrate a TFLite model into your app. - -## Core ML support - -Core ML is a machine learning framework used across Apple products. In addition to using Tensorflow Lite models directly in their applications, developers have the option to convert their trained Tensorflow models to the [CoreML](https://developer.apple.com/machine-learning/) format for use on Apple devices. For information on how to use the converter please refer to the [Tensorflow-CoreML converter documentation](https://github.com/tf-coreml/tf-coreml). +See the documentation: https://www.tensorflow.org/mobile/tflite/ +Documentation edits can be made here: [tensorflow/docs_src/mobile/tflite](../../docs_src/mobile/tflite) diff --git a/tensorflow/contrib/lite/RELEASE.md b/tensorflow/contrib/lite/RELEASE.md new file mode 100644 index 0000000000000000000000000000000000000000..8fd63d5cee7db38fadf63ab8530bef7a3d99dd0d --- /dev/null +++ b/tensorflow/contrib/lite/RELEASE.md @@ -0,0 +1,8 @@ +# Release 0.1.7 + +* TensorFlow Lite 0.1.7 is based on tag `tflite-v0.1.7` (git commit + fa1db5eb0da85b5baccc2a46d534fdeb3bb473d0). +* To reproduce the iOS library, it's required to cherry pick git commit + f1f1d5172fe5bfeaeb2cf657ffc43ba744187bee to fix a dependency issue. +* The code is based on TensorFlow 1.8.0 release candidate and it's very close + to TensorFlow 1.8.0 release. diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc index 4b322e027d48f4bf9f90d5b873c449d1ec31cc49..a4772731ecda92431c412672610a39c188dabf27 100644 --- a/tensorflow/contrib/lite/allocation.cc +++ b/tensorflow/contrib/lite/allocation.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/context.h" diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc index 87b17c338e7afc33d32dd9688cc0825ac319dd19..4f836d367747e06de682b5764206d33f6e2fb983 100644 --- a/tensorflow/contrib/lite/arena_planner.cc +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/arena_planner.h" +#include namespace tflite { @@ -128,6 +129,11 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { } TfLiteStatus ArenaPlanner::ExecuteAllocations(int first_node, int last_node) { + // Grow the size of `allocs_` if necessary. This allows allocating temporary + // tensors in op's `prepare` function. + TF_LITE_ENSURE(context_, graph_info_->num_tensors() >= allocs_.size()); + allocs_.resize(graph_info_->num_tensors()); + TF_LITE_ENSURE_STATUS(CalculateAllocations(first_node, last_node)); TF_LITE_ENSURE_STATUS(Commit()); diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h index 58bc164619c2c053b9492e9a0e5de2da30e199af..e9d0fbc5a9b5aec06e28da8757466b25f40da2f5 100644 --- a/tensorflow/contrib/lite/arena_planner.h +++ b/tensorflow/contrib/lite/arena_planner.h @@ -25,7 +25,7 @@ limitations under the License. namespace tflite { -class AllocationInfo; +struct AllocationInfo; // A memory planner that makes all the allocations using arenas. // @@ -33,7 +33,7 @@ class AllocationInfo; // each tensor needs to be allocated and deallocated, and preallocates all the // necessary memory (the PlanAllocations phase). It then assigns portions of // this memory buffer to each tensor (the ExecuteAllocations phase). Tensors may -// share some of the bufer if a tensor B is to be allocated after another tensor +// share some of the buffer if a tensor B is to be allocated after another tensor // A has been deallocated. // // If dynamic tensors are used the planning steps can be repeated during model diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc index a8a8755e2c9e81474f2ff9cd2b85c0eb3d5c3441..16171df10a7b18b22919c6e54fe3d1e8e0120f69 100644 --- a/tensorflow/contrib/lite/arena_planner_test.cc +++ b/tensorflow/contrib/lite/arena_planner_test.cc @@ -209,11 +209,8 @@ TEST_F(ArenaPlannerTest, ZeroSizedTensors) { TestGraph graph({1}, {{{1}, {2}, {}}}, {2}); (*graph.tensors())[1].bytes = 0; SetGraph(&graph); - // TODO(ahentz): this is currently broken because the arena finds two - // allocations with the same offset and returns an error. - ASSERT_FALSE(planner_->ExecuteAllocations(0, 10) == kTfLiteOk); - // EXPECT_EQ(GetOffset(1), 0); - // EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + ASSERT_EQ(planner_->ExecuteAllocations(0, 10), kTfLiteOk); + EXPECT_EQ((*graph_->tensors())[1].data.raw, nullptr); } TEST_F(ArenaPlannerTest, SimpleGraph) { diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 19829e4991651111e13fc1805f97daef8bc016a7..974e6c5d98e5691a3733495e915d919c4bf57d3a 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -1,4 +1,8 @@ """Generate Flatbuffer binary from json.""" +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) def tflite_copts(): """Defines compile time flags.""" @@ -104,7 +108,7 @@ def tflite_jni_binary(name, """Builds a jni binary for TFLite.""" linkopts = linkopts + [ "-Wl,--version-script", # Export only jni functions & classes. - linkscript, + "$(location {})".format(linkscript), ] native.cc_binary( name=name, @@ -124,19 +128,19 @@ def tf_to_tflite(name, src, options, out): out: name of the output flatbuffer file. """ - toco = "//tensorflow/contrib/lite/toco:toco" + toco_cmdline = " ".join([ + "//tensorflow/contrib/lite/toco:toco", + "--input_format=TENSORFLOW_GRAPHDEF", + "--output_format=TFLITE", + ("--input_file=$(location %s)" % src), + ("--output_file=$(location %s)" % out), + ] + options ) native.genrule( name = name, - srcs=[src, options], + srcs=[src], outs=[out], - cmd = ("$(location %s) " + - " --input_file=$(location %s) " + - " --output_file=$(location %s) " + - " --input_format=TENSORFLOW_GRAPHDEF" + - " --output_format=TFLITE" + - " `cat $(location %s)`") - % (toco, src, out, options), - tools= [toco], + cmd = toco_cmdline, + tools= ["//tensorflow/contrib/lite/toco:toco"], ) def tflite_to_json(name, src, out): @@ -185,33 +189,110 @@ def json_to_tflite(name, src, out): tools = [flatc], ) -def gen_zipped_test_files(name, files): +# This is the master list of generated examples that will be made into tests. A +# function called make_XXX_tests() must also appear in generate_examples.py. +# Disable a test by commenting it out. If you do, add a link to a bug or issue. +def generated_test_models(): + return [ + "add", + "arg_max", + "avg_pool", + "batch_to_space_nd", + "concat", + "constant", + "control_dep", + "conv", + "depthwiseconv", + "div", + "equal", + "exp", + "expand_dims", + "floor", + "fully_connected", + "fused_batch_norm", + "gather", + "global_batch_norm", + "greater", + "greater_equal", + "l2norm", + "l2_pool", + "less", + "less_equal", + "local_response_norm", + "log_softmax", + "log", + # TODO(b/110143200): Enable after resolving issues with LSTM conversion. + # "lstm", + "max_pool", + "maximum", + "mean", + "minimum", + "mul", + "neg", + "not_equal", + "pad", + "padv2", + # "prelu", + "relu", + "relu1", + "relu6", + "reshape", + "resize_bilinear", + "sigmoid", + "sin", + "slice", + "softmax", + "space_to_batch_nd", + "space_to_depth", + "sparse_to_dense", + "split", + "squeeze", + "strided_slice", + "strided_slice_1d_exhaustive", + "sub", + "tile", + "topk", + "transpose", + "transpose_conv", + "where", + ] + +def gen_zip_test(name, test_name, **kwargs): + """Generate a zipped-example test and its dependent zip files. + + Args: + name: Resulting cc_test target name + test_name: Test targets this model. Comes from the list above. + **kwargs: tf_cc_test kwargs. + """ + gen_zipped_test_file( + name = "zip_%s" % test_name, + file = "%s.zip" % test_name, + ) + tf_cc_test(name, **kwargs) + +def gen_zipped_test_file(name, file): """Generate a zip file of tests by using :generate_examples. Args: - name: Name of output. We will produce "`name`_files" as a target. - files: A list of zip file basenames. + name: Name of output. We will produce "`file`.files" as a target. + file: The name of one of the generated_examples targets, e.g. "transpose" """ toco = "//tensorflow/contrib/lite/toco:toco" - out_files = [] - for f in files: - out_file = name + "/" + f - out_files.append(out_file) - native.genrule( - name = name + "_" + f + ".files", - cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco - + " --zip_to_output " + f + - " $(@D) zipped"), - outs = [out_file], - tools = [ - ":generate_examples", - toco, - ], - ) + native.genrule( + name = file + ".files", + cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco + + " --zip_to_output " + file + " $(@D)"), + outs = [file], + tools = [ + ":generate_examples", + toco, + ], + ) native.filegroup( name = name, - srcs = out_files, + srcs = [file], ) def gen_selected_ops(name, model): diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh index 4a9023ff33de15dd384531d51e39de4ffeecdb8b..9f398f4a9f3dcafd7bd49fd5d95e9991b8b36b75 100755 --- a/tensorflow/contrib/lite/build_ios_universal_lib.sh +++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh @@ -19,11 +19,16 @@ set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" cd "$SCRIPT_DIR/../../.." -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8 -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8 +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 \ +$SCRIPT_DIR/gen/lib/ios_x86_64/libtensorflow-lite.a +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 \ +$SCRIPT_DIR/gen/lib/ios_i386/libtensorflow-lite.a +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 \ +$SCRIPT_DIR/gen/lib/ios_armv7/libtensorflow-lite.a +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8 \ +$SCRIPT_DIR/gen/lib/ios_armv7s/libtensorflow-lite.a +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8 \ +$SCRIPT_DIR/gen/lib/ios_arm64/libtensorflow-lite.a lipo \ tensorflow/contrib/lite/gen/lib/ios_x86_64/libtensorflow-lite.a \ diff --git a/tensorflow/python/keras/estimator/__init__.py b/tensorflow/contrib/lite/build_rpi_lib.sh old mode 100644 new mode 100755 similarity index 65% rename from tensorflow/python/keras/estimator/__init__.py rename to tensorflow/contrib/lite/build_rpi_lib.sh index 6f931f415810999a042b2c2060bf6494209677ed..3824b16412ed26a6cab79df3242da6017c3322b0 --- a/tensorflow/python/keras/estimator/__init__.py +++ b/tensorflow/contrib/lite/build_rpi_lib.sh @@ -1,4 +1,5 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +#!/bin/bash -x +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Keras estimator API.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +set -e -from tensorflow.python.keras._impl.keras.estimator import model_to_estimator +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR/../../.." -del absolute_import -del division -del print_function +CC_PREFIX=arm-linux-gnueabihf- make -j 3 -f tensorflow/contrib/lite/Makefile TARGET=RPI TARGET_ARCH=armv7 diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 5fc8954743e5b3b458e5c2004f4378cbad6056c0..c1cc4476fbd45fa6b3f5b3a1ed2cba39cc2ad54b 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -17,6 +17,8 @@ limitations under the License. #include +#include "tensorflow/contrib/lite/context.h" + #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -51,6 +53,8 @@ typedef struct { TfLitePadding padding; int stride_width; int stride_height; + int dilation_width_factor; + int dilation_height_factor; TfLiteFusedActivation activation; } TfLiteConvParams; @@ -144,10 +148,20 @@ typedef struct { float beta; } TfLiteLocalResponseNormParams; +typedef enum { + kTfLiteLSTMFullKernel = 0, + kTfLiteLSTMBasicKernel +} TfLiteLSTMKernelType; + typedef struct { + // Parameters for LSTM version 1. TfLiteFusedActivation activation; float cell_clip; float proj_clip; + + // Parameters for LSTM version 2. + // kTfLiteLSTMBasicKernel is only supported in version 2 or above. + TfLiteLSTMKernelType kernel_type; } TfLiteLSTMParams; typedef struct { @@ -157,6 +171,9 @@ typedef struct { typedef struct { } TfLitePadParams; +typedef struct { +} TfLitePadV2Params; + typedef struct { // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. // For now we will fix the maximum possible number of dimensions. @@ -174,6 +191,11 @@ typedef struct { int block_size; } TfLiteSpaceToDepthParams; +typedef struct { + TfLiteType in_data_type; + TfLiteType out_data_type; +} TfLiteCastParams; + typedef enum { kTfLiteCombinerTypeSum = 0, kTfLiteCombinerTypeMean = 1, @@ -214,6 +236,20 @@ typedef struct { int shrink_axis_mask; } TfLiteStridedSliceParams; +typedef struct { + TfLiteType output_type; +} TfLiteArgMaxParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; +} TfLiteTransposeConvParams; + +typedef struct { + bool validate_indices; +} TfLiteSparseToDenseParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 4ebd1586de791eecf0304637bde76232d9f0a11d..aef9a92883f18dabfc36058507d739856c3c2af7 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -17,19 +17,23 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ // DO NOT EDIT MANUALLY: This file is automatically generated by -// `schema_builtin_ops_header_generator.py`. +// `schema/builtin_ops_header/generator.cc`. #ifdef __cplusplus extern "C" { #endif // __cplusplus +// The enum for builtin operators. +// Note: CUSTOM and DELEGATE are 2 special ops which are not real built-in ops. typedef enum { kTfLiteBuiltinAdd = 0, kTfLiteBuiltinAveragePool2d = 1, kTfLiteBuiltinConcatenation = 2, kTfLiteBuiltinConv2d = 3, kTfLiteBuiltinDepthwiseConv2d = 4, + kTfLiteBuiltinDequantize = 6, kTfLiteBuiltinEmbeddingLookup = 7, + kTfLiteBuiltinFloor = 8, kTfLiteBuiltinFullyConnected = 9, kTfLiteBuiltinHashtableLookup = 10, kTfLiteBuiltinL2Normalization = 11, @@ -71,10 +75,33 @@ typedef enum { kTfLiteBuiltinExp = 47, kTfLiteBuiltinTopkV2 = 48, kTfLiteBuiltinSplit = 49, + kTfLiteBuiltinLogSoftmax = 50, + kTfLiteBuiltinDelegate = 51, + kTfLiteBuiltinBidirectionalSequenceLstm = 52, + kTfLiteBuiltinCast = 53, + kTfLiteBuiltinPrelu = 54, + kTfLiteBuiltinMaximum = 55, + kTfLiteBuiltinArgMax = 56, + kTfLiteBuiltinMinimum = 57, + kTfLiteBuiltinLess = 58, + kTfLiteBuiltinNeg = 59, + kTfLiteBuiltinPadv2 = 60, + kTfLiteBuiltinGreater = 61, + kTfLiteBuiltinGreaterEqual = 62, + kTfLiteBuiltinLessEqual = 63, + kTfLiteBuiltinSelect = 64, + kTfLiteBuiltinSlice = 65, + kTfLiteBuiltinSin = 66, + kTfLiteBuiltinTransposeConv = 67, + kTfLiteBuiltinSparseToDense = 68, + kTfLiteBuiltinTile = 69, + kTfLiteBuiltinExpandDims = 70, + kTfLiteBuiltinEqual = 71, + kTfLiteBuiltinNotEqual = 72, + kTfLiteBuiltinLog = 73, } TfLiteBuiltinOperator; #ifdef __cplusplus } // extern "C" #endif // __cplusplus #endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ -} diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/context.c index c09e838c5c2e50e0f4a38eaf66e55246fd9a6f7f..5c6f5e72a47180cd98be46f60cfa8eaf28197806 100644 --- a/tensorflow/contrib/lite/context.c +++ b/tensorflow/contrib/lite/context.c @@ -17,9 +17,14 @@ limitations under the License. #include #include +int TfLiteIntArrayGetSizeInBytes(int size) { + static TfLiteIntArray dummy; + return sizeof(dummy) + sizeof(dummy.data[0]) * size; +} + TfLiteIntArray* TfLiteIntArrayCreate(int size) { TfLiteIntArray* ret = - (TfLiteIntArray*)malloc(sizeof(*ret) + sizeof(ret->data[0]) * size); + (TfLiteIntArray*)malloc(TfLiteIntArrayGetSizeInBytes(size)); ret->size = size; return ret; } @@ -55,12 +60,16 @@ TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src) { void TfLiteIntArrayFree(TfLiteIntArray* a) { free(a); } -void TfLiteTensorFree(TfLiteTensor* t) { +void TfLiteTensorDataFree(TfLiteTensor* t) { if (t->allocation_type == kTfLiteDynamic && t->data.raw) { free(t->data.raw); } - if (t->dims) TfLiteIntArrayFree(t->dims); t->data.raw = NULL; +} + +void TfLiteTensorFree(TfLiteTensor* t) { + TfLiteTensorDataFree(t); + if (t->dims) TfLiteIntArrayFree(t->dims); t->dims = NULL; } diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index b0c4d3431f9a67bc87d51ada91ed73f1661023a2..4eb66cc225eb04923be9aaa445a335ad822c8a6f 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -29,6 +29,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ #define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#include #include #include @@ -40,6 +41,7 @@ typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; // Forward declare so GetNode can use this is in Context. typedef struct _TfLiteRegistration TfLiteRegistration; +typedef struct _TfLiteDelegate TfLiteDelegate; #define kOptionalTensor (-1) @@ -57,6 +59,10 @@ typedef struct { #endif } TfLiteIntArray; +// Given the size (number of elements) in a TfLiteIntArray, calculate its size +// in bytes. +int TfLiteIntArrayGetSizeInBytes(int size); + // Create a array of a given `size` (uninitialized entries). // This returns a pointer, that you must free using TfLiteIntArrayFree(). TfLiteIntArray* TfLiteIntArrayCreate(int size); @@ -131,6 +137,7 @@ typedef enum { kTfLiteUInt8 = 3, kTfLiteInt64 = 4, kTfLiteString = 5, + kTfLiteBool = 6, } TfLiteType; // Parameters for asymmetric quantization. Quantized values can be converted @@ -149,6 +156,7 @@ typedef union { char* raw; const char* raw_const; uint8_t* uint8; + bool* b; } TfLitePtrUnion; // Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped @@ -162,6 +170,11 @@ typedef enum { kTfLiteDynamic, } TfLiteAllocationType; +// The delegates should use zero or positive integers to represent handles. +// -1 is reserved from unallocated status. +typedef int TfLiteBufferHandle; +const TfLiteBufferHandle kTfLiteNullBufferHandle = -1; + // An tensor in the interpreter system which is a wrapper around a buffer of // data including a dimensionality (or NULL if not currently defined). typedef struct { @@ -194,8 +207,27 @@ typedef struct { // Null-terminated name of this tensor. const char* name; + + // The delegate which knows how to handle `buffer_handle`. + // WARNING: This is an experimental interface that is subject to change. + TfLiteDelegate* delegate; + + // An integer buffer handle that can be handled by `delegate`. + // The value is valid only when delegate is not null. + // WARNING: This is an experimental interface that is subject to change. + TfLiteBufferHandle buffer_handle; + + // If the delegate uses its own buffer (e.g. GPU memory), the delegate is + // responsible to set data_is_stale to true. + // `delegate->CopyFromBufferHandle` can be called to copy the data from + // delegate buffer. + // WARNING: This is an // experimental interface that is subject to change. + bool data_is_stale; } TfLiteTensor; +// Free data memory of tensor `t`; +void TfLiteTensorDataFree(TfLiteTensor* t); + // Free memory of tensor `t`; void TfLiteTensorFree(TfLiteTensor* t); @@ -234,11 +266,16 @@ typedef struct { // WARNING: This is an experimental interface that is subject to change. const void* custom_initial_data; int custom_initial_data_size; + + // The pointer to the delegate. This is non-null only when the node is + // created by calling `interpreter.ModifyGraphWithDelegate`. + // WARNING: This is an experimental interface that is subject to change. + TfLiteDelegate* delegate; } TfLiteNode; typedef struct TfLiteContext { // Number of tensors in the context. - int tensors_size; + size_t tensors_size; // The execution plan contains a list of the node indices in execution // order. execution_plan->size is the current number of nodes. And, @@ -258,7 +295,7 @@ typedef struct TfLiteContext { TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context, TfLiteIntArray** execution_plan); - // An tensor of tensors in the interpreter context (of length `tensors_size`) + // An array of tensors in the interpreter context (of length `tensors_size`) TfLiteTensor* tensors; // opaque full context ptr (an opaque c++ data structure) @@ -283,14 +320,20 @@ typedef struct TfLiteContext { TfLiteNode** node, TfLiteRegistration** registration); - // Replace ops with delegate. + // Replace ops with one or more stub delegate operations. This function + // does not take ownership of `nodes_to_replace`. TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)( struct TfLiteContext*, TfLiteRegistration registration, - const TfLiteIntArray* nodes_to_replace); + const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate); + + // Number of threads that are recommended to subsystems like gemmlowp and + // eigen. + int recommended_num_threads; // TODO(ahentz): we should create a more general mechanism for this sort of // library-global objects. void* gemm_context; + void* eigen_context; } TfLiteContext; typedef struct _TfLiteRegistration { @@ -327,29 +370,65 @@ typedef struct _TfLiteRegistration { // Builtin codes. If this kernel refers to a builtin this is the code // of the builtin. This is so we can do marshaling to other frameworks like - // NN API. Note, it is the responsibility of the registration binder to - // set this properly. + // NN API. + // Note: It is the responsibility of the registration binder to set this + // properly. int32_t builtin_code; // Custom op name. If the op is a builtin, this will be null. + // Note: It is the responsibility of the registration binder to set this + // properly. // WARNING: This is an experimental interface that is subject to change. const char* custom_name; + + // The version of the op. + // Note: It is the responsibility of the registration binder to set this + // properly. + int version; } TfLiteRegistration; // WARNING: This is an experimental interface that is subject to change. -typedef struct { +typedef struct _TfLiteDelegate { // Data that delegate needs to identify itself. This data is owned by the // delegate. The delegate is owned in the user code, so the delegate is // responsible for doing this when it is destroyed. void* data_; + // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the // delegate a view of the current graph through TfLiteContext*. It typically // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels() // to ask the TensorFlow lite runtime to create macro-nodes to represent // delegated subgraphs of the original graph. - TfLiteStatus (*Prepare)(TfLiteContext* context, void* data); + TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate); + + // Copy the data from delegate buffer handle to raw memory. + // This can be null if the delegate doesn't use its own buffer. + TfLiteStatus (*CopyFromBufferHandle)(TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, size_t size); + + // Copy the data from raw memory to delegate buffer handle. + // This can be null if the delegate doesn't use its own buffer. + TfLiteStatus (*CopyToBufferHandle)(TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, size_t size); + + // Free the Delegate Buffer Handle. Note: This only frees the handle, but + // this doesn't release the underlying resource (e.g. textures). The + // resources are either owned by application layer or the delegate. + // This can be null if the delegate doesn't use its own buffer. + void (*FreeBufferHandle)(TfLiteDelegate* delegate, + TfLiteBufferHandle* handle); } TfLiteDelegate; +// WARNING: This is an experimental interface that is subject to change. +typedef struct { + TfLiteDelegate* delegate; + TfLiteIntArray* nodes_to_replace; + TfLiteIntArray* input_tensors; + TfLiteIntArray* output_tensors; +} TfLiteDelegateParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/context_util.h b/tensorflow/contrib/lite/context_util.h new file mode 100644 index 0000000000000000000000000000000000000000..abe802e34214caf4d5063da827b3aca4a82aa56d --- /dev/null +++ b/tensorflow/contrib/lite/context_util.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This provides a few C++ helpers that are useful for manipulating C structures +// in C++. +#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// Provide a range iterable wrapper for TfLiteIntArray* (C lists that TfLite +// C api uses. Can't use the google array_view, since we can't depend on even +// absl for embedded device reasons. +class TfLiteIntArrayView { + public: + // Construct a view of a TfLiteIntArray*. Note, `int_array` should be non-null + // and this view does not take ownership of it. + explicit TfLiteIntArrayView(const TfLiteIntArray* int_array) + : int_array_(int_array) {} + + TfLiteIntArrayView(const TfLiteIntArrayView&) = default; + TfLiteIntArrayView& operator=(const TfLiteIntArrayView& rhs) = default; + + typedef const int* const_iterator; + const_iterator begin() const { return int_array_->data; } + const_iterator end() const { return &int_array_->data[int_array_->size]; } + size_t size() const { return end() - begin(); } + + private: + const TfLiteIntArray* int_array_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..35a8f6ca4166e373ea1a0af5d4a013327b30d2b6 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD @@ -0,0 +1,31 @@ +package(default_visibility = [ + "//visibility:public", +]) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "nnapi_delegate", + srcs = ["nnapi_delegate.cc"], + hdrs = ["nnapi_delegate.h"], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/contrib/lite/nnapi:nnapi_lib", + ], +) + +tf_cc_test( + name = "nnapi_delegate_test", + size = "small", + srcs = ["nnapi_delegate_test.cc"], + deps = [ + ":nnapi_delegate", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc new file mode 100644 index 0000000000000000000000000000000000000000..0731d14419d2dec2ea5efa48ef5d4b7728af635f --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc @@ -0,0 +1,464 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/builtin_ops.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" + +namespace tflite { +namespace { + +// TODO(b/80621585): Consider printing error string, but don't for now to +// minimize binary size. +#define CHECK_NN(context, code) \ + if (code != ANEURALNETWORKS_NO_ERROR) { \ + context->ReportError(context, "NN API returned error (%d).\n", code); \ + return kTfLiteError; \ + } + +// RAII NN API Model Destructor for use with std::unique_ptr +struct NNFreeModel { + void operator()(ANeuralNetworksModel* model) { + ANeuralNetworksModel_free(model); + } +}; +// RAII NN API Compilation Destructor for use with std::unique_ptr +struct NNFreeCompilation { + void operator()(ANeuralNetworksCompilation* model) { + ANeuralNetworksCompilation_free(model); + } +}; + +// Track tensor indices to NN API tensor indices mapping. +class OperandMapping { + public: + // Given a TFLite index return the ANN index. If it doesn't exist + // return -1. + int lite_index_to_ann(int index) const { + if (index < lite_tensor_to_ann_tensor_.size()) + return lite_tensor_to_ann_tensor_[index]; + else + return -1; + } + + // NN API uses non tensor operands instead of structs. This creates one + // and returns the index. It uses a std::vector and resizes it as needed + // keeping -1 to unmapped values. Intermediate tensors likely will not + // be mapped. + int add_new_non_tensor_operand() { return next_ann_tensor_index_++; } + + // Add a new mapping from `tflite_index` and return the NN API tensor index. + int add_new_ann_tensor_index(int tflite_index) { + if (tflite_index >= lite_tensor_to_ann_tensor_.size()) { + lite_tensor_to_ann_tensor_.resize(tflite_index + 1); + } + int new_tensor_index = next_ann_tensor_index_++; + lite_tensor_to_ann_tensor_[tflite_index] = new_tensor_index; + return new_tensor_index; + } + + private: + // Next index of ann tensor + int next_ann_tensor_index_ = 0; + + // Mapping from lite index. Use a std::vector for speed and code size + // rather than a map. + std::vector lite_tensor_to_ann_tensor_; +}; + +// Abstract builder for building an op in the NN API graph. This handles +// the disparity between TFLite and NN API operand types. NN API has singular +// operands for both tensors and parameters, and TFLite separates the two. +class NNAPIOpBuilder { + public: + NNAPIOpBuilder(TfLiteContext* context, OperandMapping* tensor_mapping, + ANeuralNetworksModel* nn_model) + : context_(context), + operand_mapping_(tensor_mapping), + nn_model_(nn_model) {} + + TfLiteStatus AddScalarInt32Operand(int value) { + ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_INT32}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + int ann_operand = operand_mapping_->add_new_non_tensor_operand(); + CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( + nn_model_, ann_operand, &value, sizeof(int32_t))); + augmented_inputs_.push_back(ann_operand); + return kTfLiteOk; + } + + TfLiteStatus AddTensorInput(int tensor_index) { + int ann_index; + TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index)); + augmented_inputs_.push_back(ann_index); + return kTfLiteOk; + } + + TfLiteStatus AddTensorOutput(int tensor_index) { + int ann_index; + TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index)); + augmented_outputs_.push_back(ann_index); + return kTfLiteOk; + } + + // Adds a new NN API tensor that shadows the TF Lite tensor `tensor_index`. + // This returns the NN API tensor index corresponding to the created tensor. + // If another caller previously created a NN API tensor for `tensor_index` + // then the existing one is returned. + TfLiteStatus AddTensor(int tensor_index, int* ann_tensor_index_out) { + int ann_tensor_index = operand_mapping_->lite_index_to_ann(tensor_index); + if (ann_tensor_index != -1) { + *ann_tensor_index_out = ann_tensor_index; + return kTfLiteOk; + } + // Allocate a new tensor index + ann_tensor_index = operand_mapping_->add_new_ann_tensor_index(tensor_index); + + // Parameters needed for new type. + int32_t nn_type = 0; + float scale = 0.0f; + int32_t zeroPoint = 0; + TfLiteTensor* tensor = &context_->tensors[tensor_index]; + switch (tensor->type) { + case kTfLiteNoType: + // Tensors added during initialization of Ops don't have a type yet and + // should not be registered with the NNAPI. + *ann_tensor_index_out = -1; + return kTfLiteOk; + case kTfLiteFloat32: + nn_type = ANEURALNETWORKS_TENSOR_FLOAT32; + scale = 0.f; + break; + case kTfLiteUInt8: + nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM; + scale = tensor->params.scale; + zeroPoint = tensor->params.zero_point; + break; + case kTfLiteInt32: + nn_type = ANEURALNETWORKS_TENSOR_INT32; + scale = 0.f; + zeroPoint = 0; + break; + default: + context_->ReportError(context_, "Logic error in NN API Delegate.\n"); + return kTfLiteError; + } + + ANeuralNetworksOperandType operand_type{ + nn_type, static_cast(tensor->dims->size), + reinterpret_cast(tensor->dims->data), scale, zeroPoint}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + + if (tensor->allocation_type == kTfLiteMmapRo) { + // TODO(b/80630405): Use NNAPIAllocation. + CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( + nn_model_, ann_tensor_index, tensor->data.raw, + tensor->bytes)); + } + + *ann_tensor_index_out = ann_tensor_index; + return kTfLiteOk; + } + + // Finish emitting the op (of type `type`) into the NN API. + TfLiteStatus FinalizeAddOperation(ANeuralNetworksOperationType type) { + // Actually add a NN API operation + CHECK_NN(context_, ANeuralNetworksModel_addOperation( + nn_model_, type, + static_cast(augmented_inputs_.size()), + augmented_inputs_.data(), + static_cast(augmented_outputs_.size()), + augmented_outputs_.data())); + augmented_outputs_.clear(); + augmented_outputs_.clear(); + return kTfLiteOk; + } + + private: + // TfLiteContext for error handling. Must be named context for macros to + // work. + TfLiteContext* context_; + + // Tracks relationship between indices + OperandMapping* operand_mapping_; + + // The model + ANeuralNetworksModel* nn_model_; + + // Inputs and outputs for the current op. These are augmented in the sense + // that NN API uses operands for all arguments, not just tensors, unlike + // TensorFlow lite. + std::vector augmented_inputs_; + std::vector augmented_outputs_; +}; + +// The kernel that represents the subgraph of TF Lite being run on NN API. +class NNAPIDelegateKernel { + public: + NNAPIDelegateKernel() = default; + + typedef ANeuralNetworksOperationType (*MappingFn)(TfLiteContext*, + NNAPIOpBuilder* builder, + TfLiteNode* node); + + // Return a function that knows how to translate a node into its operands + // when called. You can use this function to see if a node is supported + // (i.e. that MappingFn is not nullptr). + MappingFn Map(TfLiteContext* context, int builtin_code, TfLiteNode* node) { + switch (builtin_code) { + case kTfLiteBuiltinAdd: + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_ADD; + }; + break; + case kTfLiteBuiltinAveragePool2d: + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->padding); + builder->AddScalarInt32Operand(builtin->stride_width); + builder->AddScalarInt32Operand(builtin->stride_height); + builder->AddScalarInt32Operand(builtin->filter_width); + builder->AddScalarInt32Operand(builtin->filter_height); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_AVERAGE_POOL_2D; + }; + break; + default: + return nullptr; + } + } + + // Initialize the kernel (a NN model). + TfLiteStatus Init(TfLiteContext* context, + const TfLiteDelegateParams* params) { + for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) { + nodes_.push_back(node_index); + } + + if (!nn_model_) { + ANeuralNetworksModel* model; + CHECK_NN(context, ANeuralNetworksModel_create(&model)); + nn_model_.reset(model); + + TF_LITE_ENSURE_STATUS( + BuildGraph(context, params->input_tensors, params->output_tensors)); + } + + if (!nn_compilation_) { + ANeuralNetworksCompilation* compilation; + CHECK_NN(context, ANeuralNetworksCompilation_create(nn_model_.get(), + &compilation)); + CHECK_NN(context, ANeuralNetworksCompilation_finish(compilation)); + nn_compilation_.reset(compilation); + } + return kTfLiteOk; + } + + TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) { + ANeuralNetworksExecution* execution = nullptr; + CHECK_NN(context, ANeuralNetworksExecution_create(nn_compilation_.get(), + &execution)); + + // Set the input tensor buffers. Note: we access tflite tensors using + // absolute indices but NN api indices inputs by relative indices. + int relative_input_index = 0; + for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) { + TfLiteTensor* tensor = &context->tensors[absolute_input_index]; + CHECK_NN(context, ANeuralNetworksExecution_setInput( + execution, relative_input_index, nullptr, + tensor->data.raw, tensor->bytes)); + relative_input_index++; + } + + // Set the output tensor buffers. + int relative_output_index = 0; + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + TfLiteTensor* tensor = &context->tensors[output_index]; + CHECK_NN(context, ANeuralNetworksExecution_setOutput( + execution, relative_output_index, nullptr, + tensor->data.raw, tensor->bytes)); + relative_output_index++; + } + // Invoke ANN in blocking fashion. + ANeuralNetworksEvent* event = nullptr; + CHECK_NN(context, ANeuralNetworksExecution_startCompute(execution, &event)); + CHECK_NN(context, ANeuralNetworksEvent_wait(event)); + ANeuralNetworksEvent_free(event); + ANeuralNetworksExecution_free(execution); + + return kTfLiteOk; + } + + private: + // ANN API state. + std::unique_ptr nn_model_; + std::unique_ptr + nn_compilation_; + // Node indices that this delegate is responsible for. Indices here + // indexes into the nodes array in the TfLiteContext. + std::vector nodes_; + // Track indices we use + OperandMapping operand_mapping_; + + TfLiteStatus AddOpsAndTensors(TfLiteContext* context) { + // The operand builder allows creating a single op. We create it at this + // reduced power position rather than in the for loop to avoid reallocating + // the vectors. + NNAPIOpBuilder builder(context, &operand_mapping_, nn_model_.get()); + // Add Tensors + // allocate outside to avoid realloc + for (auto node_index : nodes_) { + // Obtain the op and registration. + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + // Map inputs to NN API tensor indices. + for (auto input_index : TfLiteIntArrayView(node->inputs)) { + TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index)); + } + // Get op type and operands + int nn_op_type = + Map(context, reg->builtin_code, node)(context, &builder, node); + // Map outputs to NN API tensor indices. + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index)); + } + + builder.FinalizeAddOperation(nn_op_type); + } + return kTfLiteOk; + } + + TfLiteStatus BuildGraph(TfLiteContext* context, + const TfLiteIntArray* input_tensors, + const TfLiteIntArray* output_tensors) { + // Build the ops and tensors. + TF_LITE_ENSURE_STATUS(AddOpsAndTensors(context)); + // Map input and output tensor indices to ANN + std::vector inputs; + inputs.reserve(input_tensors->size); + std::vector outputs; + outputs.reserve(output_tensors->size); + // Make the TensorFlow lite inputs and outputs to ann_indices. + for (int i : TfLiteIntArrayView(input_tensors)) + inputs.push_back(operand_mapping_.lite_index_to_ann(i)); + for (int i : TfLiteIntArrayView(output_tensors)) + outputs.push_back(operand_mapping_.lite_index_to_ann(i)); + // Tell ANN to declare inputs/outputs + CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs( + nn_model_.get(), inputs.size(), inputs.data(), + outputs.size(), outputs.data())); + // Finalize the model + CHECK_NN(context, ANeuralNetworksModel_finish(nn_model_.get())); + + return kTfLiteOk; + } +}; + +} // namespace + +// Return a NN API Delegate struct that can check for support of ops. +TfLiteDelegate* NnApiDelegate() { + static TfLiteDelegate delegate = { + .data_ = nullptr, + .Prepare = [](TfLiteContext* context, + TfLiteDelegate* delegate) -> TfLiteStatus { + // Do not check nodes_ if NN API is unavailable. + if (!NNAPIExists()) return kTfLiteOk; + + std::vector supported_nodes(1); + // We don't care about all nodes_, we only care about ones in the + // current plan. + TfLiteIntArray* plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + int total_supported_nodes = 0; + // Check for every node if it is supported + // TODO(b/80625235): Fix this to do more careful checking of versioning. + for (int node_index : TfLiteIntArrayView(plan)) { + TfLiteNode* node; + TfLiteRegistration* registration; + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( + context, node_index, &node, ®istration)); + NNAPIDelegateKernel dummy_kernel; + if (dummy_kernel.Map(context, registration->builtin_code, node)) { + supported_nodes.push_back(node_index); + } + total_supported_nodes += 1; + } + // Put the size at the beginning of the array. + supported_nodes[0] = supported_nodes.size() - 1; + + // NN API Delegate Registration (the pseudo kernel that will invoke NN + // API subgraphs) + static const TfLiteRegistration nnapi_delegate_kernel = { + .init = [](TfLiteContext* context, const char* buffer, + size_t length) -> void* { + const TfLiteDelegateParams* params = + reinterpret_cast(buffer); + NNAPIDelegateKernel* kernel_state = new NNAPIDelegateKernel; + kernel_state->Init(context, params); + return kernel_state; + }, + + .free = [](TfLiteContext* context, void* buffer) -> void { + delete reinterpret_cast(buffer); + }, + + .prepare = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + // Since the underlying resize happened ahead of delegation + // worked. This does nothing. + return kTfLiteOk; + }, + + .invoke = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + NNAPIDelegateKernel* state = + reinterpret_cast(node->user_data); + return state->Invoke(context, node); + }, + + .builtin_code = kTfLiteBuiltinDelegate, + }; + + // Request TFLite to partition the graph and make kernels + // for each independent subgraph a new nnapi_delegate_kernel. + context->ReplaceSubgraphsWithDelegateKernels( + context, nnapi_delegate_kernel, + reinterpret_cast(supported_nodes.data()), + delegate); + return kTfLiteOk; + }}; + + return &delegate; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h new file mode 100644 index 0000000000000000000000000000000000000000..44cca2fd285370d700525f98ba33c861fb97be1e --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// Return a delegate that can be used to use the NN API. +// e.g. +// NnApiDelegate* delegate = NnApiDelegate(); +// interpreter->ModifyGraphWithDelegate(&delegate); +// NnApiDelegate() returns a singleton, so you should not free this +// pointer or worry about its lifetime. +TfLiteDelegate* NnApiDelegate(); +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff2e721423f07889f36746a2889afcc3369f28fc --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h" +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class FloatAddOpModel : public SingleOpModel { + public: + FloatAddOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, + CreateAddOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +// Do a test with the NN API using no activation. +TEST(NNAPIDelegate, AddWithNoActivation) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); +} + +// Do a test with the NN api with relu. +TEST(NNAPIDelegate, AddWithRelu) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0.0, 0.4, 1.0, 1.3})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh index a93ed201d647ddf2359a57254a959871c13fb94f..840015a7fad173dbd2ea353786871dd4e89abb98 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -36,6 +36,7 @@ ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_ NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" FARMHASH_URL="https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz" FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/master.zip" +FFT2D_URL="https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, # so work around it by patching the source. @@ -91,6 +92,7 @@ download_and_extract "${ABSL_URL}" "${DOWNLOADS_DIR}/absl" download_and_extract "${NEON_2_SSE_URL}" "${DOWNLOADS_DIR}/neon_2_sse" download_and_extract "${FARMHASH_URL}" "${DOWNLOADS_DIR}/farmhash" download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers" +download_and_extract "${FFT2D_URL}" "${DOWNLOADS_DIR}/fft2d" replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \ "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h index da193d2586e9123341b9a41be049ee2a4382017a..3c5f805f12f6a1fb7185c140604f692ac282a143 100644 --- a/tensorflow/contrib/lite/error_reporter.h +++ b/tensorflow/contrib/lite/error_reporter.h @@ -30,7 +30,7 @@ namespace tflite { // va_list args; // foo.Report("test %d", args); // where args is va_list // -// Sublclass ErrorReporter to provide another reporting destination. +// Subclass ErrorReporter to provide another reporting destination. // For example, if you have a GUI program, you might redirect to a buffer // that drives a GUI error log box. class ErrorReporter { diff --git a/tensorflow/contrib/lite/examples/android/AndroidManifest.xml b/tensorflow/contrib/lite/examples/android/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..bc9574d646b7661de8ac9b745bd53cbba1eb9f31 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/AndroidManifest.xml @@ -0,0 +1,65 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..57000072561303e8457f61b1ebe95d382fc01f10 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/BUILD @@ -0,0 +1,85 @@ +# Description: +# TensorFlow camera demo app for Android. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +# Build the demo native demo lib from the original directory to reduce code +# reuse. Note that the Java counterparts (ObjectTracker.java and +# ImageUtils.java) are still duplicated. +cc_library( + name = "tensorflow_native_libs", + srcs = [ + "//tensorflow/examples/android:libtensorflow_demo.so", + ], + tags = [ + "manual", + "notap", + ], +) + +android_binary( + name = "tflite_demo", + srcs = glob([ + "src/**/*.java", + ]), + # Package assets from assets dir as well as all model targets. + # Remove undesired models (and corresponding Activities in source) + # to reduce APK size. + assets = [ + "//tensorflow/contrib/lite/examples/android/assets:labels_mobilenet_quant_v1_224.txt", + "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", + "@tflite_conv_actions_frozen//:conv_actions_frozen.tflite", + "//tensorflow/contrib/lite/examples/android/assets:conv_actions_labels.txt", + "@tflite_mobilenet_ssd//:mobilenet_ssd.tflite", + "//tensorflow/contrib/lite/examples/android/assets:box_priors.txt", + "//tensorflow/contrib/lite/examples/android/assets:coco_labels_list.txt", + ], + assets_dir = "", + custom_package = "org.tensorflow.lite.demo", + inline_constants = 1, + manifest = "AndroidManifest.xml", + nocompress_extensions = [ + ".tflite", + ], + resource_files = glob(["res/**"]), + tags = [ + "manual", + "notap", + ], + deps = [ + ":tensorflow_native_libs", + "//tensorflow/contrib/lite/java:tensorflowlite", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + "bin/**", + "gen/**", + "gradleBuild/**", + "libs/**", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +filegroup( + name = "java_files", + srcs = glob(["src/**/*.java"]), +) + +filegroup( + name = "resource_files", + srcs = glob(["res/**"]), +) + +exports_files(["AndroidManifest.xml"]) diff --git a/tensorflow/contrib/lite/examples/android/assets/BUILD b/tensorflow/contrib/lite/examples/android/assets/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..dd0cd6c98ff878e9c41875cab74c12191cadb173 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/assets/BUILD @@ -0,0 +1,24 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files( + glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/examples/android/assets/box_priors.txt b/tensorflow/contrib/lite/examples/android/assets/box_priors.txt new file mode 100644 index 0000000000000000000000000000000000000000..7246b073fe7fd8b1d1340536457c8aeac24cd5a3 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/assets/box_priors.txt @@ -0,0 +1,5 @@ + 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.099999994 0.1 0.099999994 0.1 0.099999994 0.099999994 0.099999994 0.1 0.099999994 0.1 0.099999994 0.099999994 0.099999994 0.1 0.099999994 0.1 0.099999994 0.099999994 0.099999994 0.1 0.099999994 0.1 0.099999994 0.099999994 0.099999994 0.1 0.099999994 0.1 0.099999994 0.099999994 0.30000004 0.3 0.3 0.3 0.3 0.30000004 0.30000004 0.3 0.3 0.3 0.3 0.30000004 0.30000004 0.3 0.3 0.3 0.3 0.30000004 0.30000004 0.3 0.3 0.3 0.3 0.30000004 0.30000004 0.3 0.3 0.3 0.3 0.30000004 0.49999997 0.5 0.5 0.5 0.5 0.49999997 0.49999997 0.5 0.5 0.5 0.5 0.49999997 0.49999997 0.5 0.5 0.5 0.5 0.49999997 0.49999997 0.5 0.5 0.5 0.5 0.49999997 0.49999997 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.90000004 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.90000004 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.90000004 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.90000004 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.90000004 0.16666667 0.16666667 0.16666666 0.16666667 0.16666669 0.16666667 0.16666667 0.16666667 0.16666666 0.16666667 0.16666669 0.16666667 0.16666667 0.16666667 0.16666666 0.16666667 0.16666669 0.16666667 0.5 0.5 0.49999997 0.5 0.5 0.5 0.5 0.5 0.49999997 0.5 0.5 0.5 0.5 0.5 0.49999997 0.5 0.5 0.5 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.25 0.25 0.25 0.24999999 0.25 0.25 0.25 0.25 0.25 0.24999999 0.25 0.25 0.75 0.75 0.75 0.75 0.74999994 0.75 0.75 0.75 0.75 0.75 0.74999994 0.75 0.5 0.5 0.5 0.5 0.5 0.5 + 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.10000001 0.099999994 0.1 0.099999994 0.1 0.099999994 0.3 0.3 0.3 0.29999998 0.3 0.30000004 0.5 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.9 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.10000001 0.099999994 0.1 0.099999994 0.1 0.099999994 0.3 0.3 0.3 0.29999998 0.3 0.30000004 0.5 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.9 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.10000001 0.099999994 0.1 0.099999994 0.1 0.099999994 0.3 0.3 0.3 0.29999998 0.3 0.30000004 0.5 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.9 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.10000001 0.099999994 0.1 0.099999994 0.1 0.099999994 0.3 0.3 0.3 0.29999998 0.3 0.30000004 0.5 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.9 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.10000001 0.099999994 0.1 0.099999994 0.1 0.099999994 0.3 0.3 0.3 0.29999998 0.3 0.30000004 0.5 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.9 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.16666667 0.16666669 0.16666667 0.16666669 0.16666667 0.16666667 0.49999997 0.5 0.5 0.50000006 0.5 0.5 0.8333334 0.8333334 0.8333334 0.8333333 0.8333334 0.8333334 0.16666667 0.16666669 0.16666667 0.16666669 0.16666667 0.16666667 0.49999997 0.5 0.5 0.50000006 0.5 0.5 0.8333334 0.8333334 0.8333334 0.8333333 0.8333334 0.8333334 0.16666667 0.16666669 0.16666667 0.16666669 0.16666667 0.16666667 0.49999997 0.5 0.5 0.50000006 0.5 0.5 0.8333334 0.8333334 0.8333334 0.8333333 0.8333334 0.8333334 0.25 0.25 0.25 0.25 0.25 0.25 0.75 0.75 0.75 0.75 0.75 0.75 0.25 0.25 0.25 0.25 0.25 0.25 0.75 0.75 0.75 0.75 0.75 0.75 0.5 0.5 0.5 0.5 0.5 0.5 + 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.50000006 0.3535534 0.7071068 0.28867513 0.8660687 0.57008773 0.50000006 0.3535534 0.7071068 0.28867513 0.8660687 0.57008773 0.50000006 0.3535534 0.7071068 0.28867513 0.8660687 0.57008773 0.50000006 0.3535534 0.7071068 0.28867513 0.8660687 0.57008773 0.50000006 0.3535534 0.7071068 0.28867513 0.8660687 0.57008773 0.5000001 0.3535534 0.7071068 0.28867513 0.8660687 0.5700878 0.5000001 0.3535534 0.7071068 0.28867513 0.8660687 0.5700878 0.5000001 0.3535534 0.7071068 0.28867513 0.8660687 0.5700878 0.5000001 0.3535534 0.7071068 0.28867513 0.8660687 0.5700878 0.5000001 0.3535534 0.7071068 0.28867513 0.8660687 0.5700878 0.5 0.3535534 0.7071068 0.2886751 0.8660687 0.5700877 0.5 0.3535534 0.7071068 0.2886751 0.8660687 0.5700877 0.5 0.3535534 0.7071068 0.2886751 0.8660687 0.5700877 0.5 0.3535534 0.7071068 0.2886751 0.8660687 0.5700877 0.5 0.3535534 0.7071068 0.2886751 0.8660687 0.5700877 0.5 0.3535534 0.7071068 0.28867507 0.8660688 0.5700877 0.5 0.3535534 0.7071068 0.28867507 0.8660688 0.5700877 0.5 0.3535534 0.7071068 0.28867507 0.8660688 0.5700877 0.5 0.3535534 0.7071068 0.28867507 0.8660688 0.5700877 0.5 0.3535534 0.7071068 0.28867507 0.8660688 0.5700877 0.5000001 0.3535534 0.70710677 0.2886752 0.8660687 0.5700878 0.5000001 0.3535534 0.70710677 0.2886752 0.8660687 0.5700878 0.5000001 0.3535534 0.70710677 0.2886752 0.8660687 0.5700878 0.5000001 0.3535534 0.70710677 0.2886752 0.8660687 0.5700878 0.5000001 0.3535534 0.70710677 0.2886752 0.8660687 0.5700878 0.65000004 0.45961943 0.91923887 0.37527767 1.1258893 0.7211102 0.65000004 0.45961943 0.91923887 0.37527767 1.1258893 0.7211102 0.65000004 0.45961943 0.91923887 0.37527767 1.1258893 0.7211102 0.6500001 0.4596194 0.9192388 0.37527764 1.1258893 0.7211102 0.6500001 0.4596194 0.9192388 0.37527764 1.1258893 0.7211102 0.6500001 0.4596194 0.9192388 0.37527764 1.1258893 0.7211102 0.6500001 0.45961946 0.9192388 0.3752777 1.1258893 0.72111017 0.6500001 0.45961946 0.9192388 0.3752777 1.1258893 0.72111017 0.6500001 0.45961946 0.9192388 0.3752777 1.1258893 0.72111017 0.8000001 0.5656855 1.131371 0.4618802 1.3857099 0.8717798 0.8000001 0.5656855 1.131371 0.4618802 1.3857099 0.8717798 0.80000013 0.5656855 1.131371 0.4618802 1.3857098 0.87177986 0.80000013 0.5656855 1.131371 0.4618802 1.3857098 0.87177986 0.95000005 0.6717515 1.343503 0.5484828 1.6455305 0.97467947 + 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.49999997 0.7071067 0.35355335 0.8660254 0.2886607 0.57008773 0.5 0.7071067 0.35355335 0.8660253 0.2886607 0.5700878 0.5 0.7071067 0.35355332 0.86602545 0.28866073 0.5700877 0.5 0.70710665 0.3535533 0.86602545 0.28866076 0.5700877 0.49999994 0.7071067 0.3535534 0.8660253 0.28866065 0.5700878 0.49999997 0.7071067 0.35355335 0.8660254 0.2886607 0.57008773 0.5 0.7071067 0.35355335 0.8660253 0.2886607 0.5700878 0.5 0.7071067 0.35355332 0.86602545 0.28866073 0.5700877 0.5 0.70710665 0.3535533 0.86602545 0.28866076 0.5700877 0.49999994 0.7071067 0.3535534 0.8660253 0.28866065 0.5700878 0.49999997 0.7071067 0.35355335 0.8660254 0.2886607 0.57008773 0.5 0.7071067 0.35355335 0.8660253 0.2886607 0.5700878 0.5 0.7071067 0.35355332 0.86602545 0.28866073 0.5700877 0.5 0.70710665 0.3535533 0.86602545 0.28866076 0.5700877 0.49999994 0.7071067 0.3535534 0.8660253 0.28866065 0.5700878 0.49999997 0.7071067 0.35355335 0.8660254 0.2886607 0.57008773 0.5 0.7071067 0.35355335 0.8660253 0.2886607 0.5700878 0.5 0.7071067 0.35355332 0.86602545 0.28866073 0.5700877 0.5 0.70710665 0.3535533 0.86602545 0.28866076 0.5700877 0.49999994 0.7071067 0.3535534 0.8660253 0.28866065 0.5700878 0.49999997 0.7071067 0.35355335 0.8660254 0.2886607 0.57008773 0.5 0.7071067 0.35355335 0.8660253 0.2886607 0.5700878 0.5 0.7071067 0.35355332 0.86602545 0.28866073 0.5700877 0.5 0.70710665 0.3535533 0.86602545 0.28866076 0.5700877 0.49999994 0.7071067 0.3535534 0.8660253 0.28866065 0.5700878 0.6499999 0.9192387 0.45961934 1.1258329 0.3752589 0.7211102 0.64999986 0.9192387 0.4596193 1.125833 0.37525892 0.7211102 0.64999986 0.91923875 0.45961928 1.1258328 0.37525892 0.72111017 0.6499999 0.9192387 0.45961934 1.1258329 0.3752589 0.7211102 0.64999986 0.9192387 0.4596193 1.125833 0.37525892 0.7211102 0.64999986 0.91923875 0.45961928 1.1258328 0.37525892 0.72111017 0.6499999 0.9192387 0.45961934 1.1258329 0.3752589 0.7211102 0.64999986 0.9192387 0.4596193 1.125833 0.37525892 0.7211102 0.64999986 0.91923875 0.45961928 1.1258328 0.37525892 0.72111017 0.79999995 1.1313708 0.5656854 1.3856406 0.46185714 0.8717798 0.79999995 1.1313708 0.56568533 1.3856406 0.46185708 0.87177986 0.79999995 1.1313708 0.5656854 1.3856406 0.46185714 0.8717798 0.79999995 1.1313708 0.56568533 1.3856406 0.46185708 0.87177986 0.9499999 1.3435028 0.6717514 1.6454482 0.54845536 0.97467947 + diff --git a/tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt b/tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..5a70ff82aa7b0fa7315ca591820e4cf7d2f5ad18 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt @@ -0,0 +1,91 @@ +??? +person +bicycle +car +motorcycle +airplane +bus +train +truck +boat +traffic light +fire hydrant +??? +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +??? +backpack +umbrella +??? +??? +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +??? +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +couch +potted plant +bed +??? +dining table +??? +??? +toilet +??? +tv +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +??? +book +clock +vase +scissors +teddy bear +hair drier +toothbrush diff --git a/tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt b/tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt new file mode 100644 index 0000000000000000000000000000000000000000..ba416458b011a7f4b96739eb6fcb6275a6ab3bec --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt @@ -0,0 +1,12 @@ +_silence_ +_unknown_ +yes +no +up +down +left +right +on +off +stop +go \ No newline at end of file diff --git a/tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt b/tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt new file mode 100644 index 0000000000000000000000000000000000000000..fe811239d8e2989de19fecabb1ebb0c9dddac514 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt @@ -0,0 +1,1001 @@ +background +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/tensorflow/contrib/lite/examples/android/build.gradle b/tensorflow/contrib/lite/examples/android/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..0d4de358156a5d139e35cc542b8d36ab24e763b9 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/build.gradle @@ -0,0 +1,52 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 26 + buildToolsVersion "26.0.1" + defaultConfig { + applicationId "org.tensorflow.lite.demo" + minSdkVersion 15 + targetSdkVersion 26 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + + // Remove this block. + jackOptions { + enabled true + } + } + lintOptions { + abortOnError false + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "tflite" + } + + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +repositories { + maven { + url 'https://google.bintray.com/tensorflow' + } +} + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) + androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + exclude group: 'com.android.support', module: 'support-annotations' + }) + compile 'org.tensorflow:tensorflow-lite:+' + + testCompile 'junit:junit:4.12' +} diff --git a/tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml b/tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml new file mode 100644 index 0000000000000000000000000000000000000000..891d8cc1d4f3e59d0371030fd763c5ad468e7887 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml @@ -0,0 +1,30 @@ + + + + + diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..32bd1aabcabb85ded957230533c00e735183a323 Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..b3113cd15c3255405ee34c622a1e83674e6e5487 Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png new file mode 100644 index 0000000000000000000000000000000000000000..135862883e26eddce2b19db021adf62e10357ad0 Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..8efbbf8b3c44418551699db9388cd77a88362112 Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..51f87ee6507cebec6bff32b1a03b36ffc711689d Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..ba143ea7a80f03b0e850775ad672ccb2d6195e4c Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..6361d792dacd8ce09a14258878b5ce6db5e0debb Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..394eb7e534905e36fd24c3defac92c09b403ee39 Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..2e27bec9785d4d51fe597bced7f04508994aa10c Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable/border.xml b/tensorflow/contrib/lite/examples/android/res/drawable/border.xml new file mode 100644 index 0000000000000000000000000000000000000000..dd1d64d1d61f359422c79533f726991c78e47d99 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/res/drawable/border.xml @@ -0,0 +1,19 @@ + + + + + diff --git a/tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml b/tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml new file mode 100644 index 0000000000000000000000000000000000000000..1a22d4b33ebbd755104272863c5cc6c93793b86b --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml @@ -0,0 +1,22 @@ + + diff --git a/tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml b/tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml new file mode 100644 index 0000000000000000000000000000000000000000..2fe1338da57122c7e26c64c653076b6746a25497 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml @@ -0,0 +1,55 @@ + + + + + + + +